Skip to content

Commit

Permalink
refactor: cleanup ContractValidationService (eclipse-edc#3878)
Browse files Browse the repository at this point in the history
refactor: cleanup ContractValidationService
  • Loading branch information
ndr-brt authored Feb 19, 2024
1 parent 4d190ae commit 6c48818
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 257 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.eclipse.edc.runtime.metamodel.annotation.Inject;
import org.eclipse.edc.runtime.metamodel.annotation.Provides;
import org.eclipse.edc.runtime.metamodel.annotation.Setting;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.asset.AssetIndex;
import org.eclipse.edc.spi.event.EventRouter;
import org.eclipse.edc.spi.message.RemoteMessageDispatcherRegistry;
Expand Down Expand Up @@ -109,9 +108,6 @@ public class ContractCoreExtension implements ServiceExtension {
@Inject
private ContractNegotiationStore store;

@Inject
private ParticipantAgentService agentService;

@Inject
private PolicyEngine policyEngine;

Expand Down Expand Up @@ -180,7 +176,7 @@ private void registerServices(ServiceExtensionContext context) {
var participantId = context.getParticipantId();

var policyEquality = new PolicyEquality(typeManager);
var validationService = new ContractValidationServiceImpl(agentService, assetIndex, policyEngine, policyEquality);
var validationService = new ContractValidationServiceImpl(assetIndex, policyEngine, policyEquality);
context.registerService(ContractValidationService.class, validationService);

// bind/register rule to evaluate contract expiry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
import org.eclipse.edc.policy.engine.spi.PolicyEngine;
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.asset.AssetIndex;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.query.Criterion;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.types.domain.agreement.ContractAgreement;
Expand All @@ -51,16 +49,13 @@
*/
public class ContractValidationServiceImpl implements ContractValidationService {

private final ParticipantAgentService agentService;
private final AssetIndex assetIndex;
private final PolicyEngine policyEngine;
private final PolicyEquality policyEquality;

public ContractValidationServiceImpl(ParticipantAgentService agentService,
AssetIndex assetIndex,
public ContractValidationServiceImpl(AssetIndex assetIndex,
PolicyEngine policyEngine,
PolicyEquality policyEquality) {
this.agentService = agentService;
this.assetIndex = assetIndex;
this.policyEngine = policyEngine;
this.policyEquality = policyEquality;
Expand Down Expand Up @@ -125,34 +120,6 @@ public ContractValidationServiceImpl(ParticipantAgentService agentService,
return success();
}

@Override
public @NotNull Result<ValidatedConsumerOffer> validateInitialOffer(ClaimToken token, ValidatableConsumerOffer consumerOffer) {
return validateInitialOffer(agentService.createFor(token), consumerOffer);
}

@Override
@NotNull
public Result<ContractAgreement> validateAgreement(ClaimToken token, ContractAgreement agreement) {
return validateAgreement(agentService.createFor(token), agreement);
}

@Override
public @NotNull Result<Void> validateRequest(ClaimToken token, ContractAgreement agreement) {
return validateRequest(agentService.createFor(token), agreement);
}

@Override
@NotNull
public Result<Void> validateRequest(ClaimToken token, ContractNegotiation negotiation) {
return validateRequest(agentService.createFor(token), negotiation);
}

@Override
@NotNull
public Result<Void> validateConfirmed(ClaimToken token, ContractAgreement agreement, ContractOffer latestOffer) {
return validateConfirmed(agentService.createFor(token), agreement, latestOffer);
}

/**
* Validates an initial contract offer, ensuring that the referenced asset exists, is selected by the corresponding policy definition and the agent fulfills the contract policy.
* A sanitized policy definition is returned to avoid clients injecting manipulated policies.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
import org.eclipse.edc.policy.model.Permission;
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.asset.AssetIndex;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.types.domain.agreement.ContractAgreement;
import org.eclipse.edc.spi.types.domain.asset.Asset;
Expand Down Expand Up @@ -77,10 +75,9 @@ class ContractValidationServiceImplTest {
private final AssetIndex assetIndex = mock();
private final PolicyEngine policyEngine = mock();
private final PolicyEquality policyEquality = mock();
private final ParticipantAgentService agentService = mock();

private final ContractValidationService validationService =
new ContractValidationServiceImpl(agentService, assetIndex, policyEngine, policyEquality);
new ContractValidationServiceImpl(assetIndex, policyEngine, policyEquality);

private static ContractDefinition.Builder createContractDefinitionBuilder() {
return ContractDefinition.Builder.newInstance()
Expand Down Expand Up @@ -343,15 +340,13 @@ void validateInitialOffer_fails_whenContractPolicyEvaluationFails() {

var validatableOffer = createValidatableConsumerOffer();
var participantAgent = new ParticipantAgent(emptyMap(), Map.of(PARTICIPANT_IDENTITY, CONSUMER_ID));
var claimToken = ClaimToken.Builder.newInstance().build();

when(agentService.createFor(eq(claimToken))).thenReturn(participantAgent);
when(policyEngine.evaluate(eq(CATALOGING_SCOPE), any(), isA(PolicyContext.class))).thenReturn(Result.success());
when(policyEngine.evaluate(eq(NEGOTIATION_SCOPE), any(), isA(PolicyContext.class))).thenReturn(Result.failure("evaluation failure"));
when(assetIndex.findById(anyString())).thenReturn(Asset.Builder.newInstance().build());
when(assetIndex.countAssets(anyList())).thenReturn(1L);

var result = validationService.validateInitialOffer(claimToken, validatableOffer);
var result = validationService.validateInitialOffer(participantAgent, validatableOffer);

assertThat(result).isFailed().detail()
.startsWith("Policy in scope %s not fulfilled for offer %s, policy evaluation".formatted(NEGOTIATION_SCOPE, validatableOffer.getOfferId().toString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public CatalogService catalogService() {

@Provider
public CatalogProtocolService catalogProtocolService(ServiceExtensionContext context) {
return new CatalogProtocolServiceImpl(datasetResolver, participantAgentService, dataServiceRegistry,
return new CatalogProtocolServiceImpl(datasetResolver, dataServiceRegistry,
protocolTokenValidator(), context.getParticipantId(), transactionContext);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import org.eclipse.edc.connector.spi.protocol.ProtocolTokenValidator;
import org.eclipse.edc.policy.engine.spi.PolicyScope;
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.result.ServiceResult;
import org.eclipse.edc.transaction.spi.TransactionContext;
Expand All @@ -41,21 +40,18 @@ public class CatalogProtocolServiceImpl implements CatalogProtocolService {
private static final String PARTICIPANT_ID_PROPERTY_KEY = "participantId";

private final DatasetResolver datasetResolver;
private final ParticipantAgentService participantAgentService;
private final DataServiceRegistry dataServiceRegistry;
private final String participantId;
private final TransactionContext transactionContext;

private final ProtocolTokenValidator protocolTokenValidator;

public CatalogProtocolServiceImpl(DatasetResolver datasetResolver,
ParticipantAgentService participantAgentService,
DataServiceRegistry dataServiceRegistry,
ProtocolTokenValidator protocolTokenValidator,
String participantId,
TransactionContext transactionContext) {
this.datasetResolver = datasetResolver;
this.participantAgentService = participantAgentService;
this.dataServiceRegistry = dataServiceRegistry;
this.protocolTokenValidator = protocolTokenValidator;
this.participantId = participantId;
Expand All @@ -66,7 +62,6 @@ public CatalogProtocolServiceImpl(DatasetResolver datasetResolver,
@NotNull
public ServiceResult<Catalog> getCatalog(CatalogRequestMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> verifyToken(tokenRepresentation)
.map(participantAgentService::createFor)
.map(agent -> {
try (var datasets = datasetResolver.query(agent, message.getQuerySpec())) {
var dataServices = dataServiceRegistry.getDataServices();
Expand All @@ -84,7 +79,6 @@ public ServiceResult<Catalog> getCatalog(CatalogRequestMessage message, TokenRep
@Override
public @NotNull ServiceResult<Dataset> getDataset(String datasetId, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> verifyToken(tokenRepresentation)
.map(participantAgentService::createFor)
.map(agent -> datasetResolver.getById(agent, datasetId))
.compose(dataset -> {
if (dataset == null) {
Expand All @@ -95,8 +89,8 @@ public ServiceResult<Catalog> getCatalog(CatalogRequestMessage message, TokenRep
}));
}

private ServiceResult<ClaimToken> verifyToken(TokenRepresentation tokenRepresentation) {
return protocolTokenValidator.verifyToken(tokenRepresentation, CATALOGING_REQUEST_SCOPE, Policy.Builder.newInstance().build());
private ServiceResult<ParticipantAgent> verifyToken(TokenRepresentation tokenRepresentation) {
return protocolTokenValidator.verify(tokenRepresentation, CATALOGING_REQUEST_SCOPE, Policy.Builder.newInstance().build());
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.iam.IdentityService;
import org.eclipse.edc.spi.iam.RequestScope;
import org.eclipse.edc.spi.iam.TokenRepresentation;
Expand All @@ -35,7 +34,6 @@
public class ProtocolTokenValidatorImpl implements ProtocolTokenValidator {

private final IdentityService identityService;

private final PolicyEngine policyEngine;
private final ParticipantAgentService agentService;

Expand All @@ -49,25 +47,6 @@ public ProtocolTokenValidatorImpl(IdentityService identityService, PolicyEngine
this.agentService = agentService;
}

/**
* Validate and extract the {@link ClaimToken} from the input {@link TokenRepresentation} by using the {@link IdentityService}
*
* @param tokenRepresentation The input {@link TokenRepresentation}
* @param policyScope The policy scope
* @param policy The {@link Policy}
* @return The {@link ClaimToken} if success, failure otherwise
*/
@Override
public ServiceResult<ClaimToken> verifyToken(TokenRepresentation tokenRepresentation, String policyScope, Policy policy) {
var result = identityService.verifyJwtToken(tokenRepresentation, createVerificationContext(policyScope, policy));

if (result.failed()) {
monitor.debug(() -> "Unauthorized: %s".formatted(result.getFailureDetail()));
return ServiceResult.unauthorized("Unauthorized");
}
return ServiceResult.success(result.getContent());
}

@Override
public ServiceResult<ParticipantAgent> verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy) {
var tokenValidation = identityService.verifyJwtToken(tokenRepresentation, createVerificationContext(policyScope, policy));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.eclipse.edc.connector.transfer.spi.types.protocol.TransferStartMessage;
import org.eclipse.edc.connector.transfer.spi.types.protocol.TransferTerminationMessage;
import org.eclipse.edc.policy.engine.spi.PolicyScope;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.ServiceResult;
Expand Down Expand Up @@ -105,7 +105,7 @@ public ServiceResult<TransferProcess> notifyRequested(TransferRequestMessage mes
public ServiceResult<TransferProcess> notifyStarted(TransferStartMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> onMessageDo(message, context.claimToken(), context.agreement(), transferProcess -> startedAction(message, transferProcess)))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> startedAction(message, transferProcess)))
);
}

Expand All @@ -115,7 +115,7 @@ public ServiceResult<TransferProcess> notifyStarted(TransferStartMessage message
public ServiceResult<TransferProcess> notifyCompleted(TransferCompletionMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> onMessageDo(message, context.claimToken(), context.agreement(), transferProcess -> completedAction(message, transferProcess)))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> completedAction(message, transferProcess)))
);
}

Expand All @@ -125,7 +125,7 @@ public ServiceResult<TransferProcess> notifyCompleted(TransferCompletionMessage
public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> onMessageDo(message, context.claimToken(), context.agreement(), transferProcess -> terminatedAction(message, transferProcess)))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> terminatedAction(message, transferProcess)))
);
}

Expand All @@ -135,7 +135,7 @@ public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessag
public ServiceResult<TransferProcess> findById(String id, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(id, this::findTransferProcessById)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> validateCounterParty(context.claimToken(), context.agreement(), context.transferProcess())));
.compose(context -> validateCounterParty(context.participantAgent(), context.agreement(), context.transferProcess())));
}

@NotNull
Expand Down Expand Up @@ -229,7 +229,7 @@ private ServiceResult<ClaimTokenContext> validateDestination(TransferRequestMess
}

private ServiceResult<ClaimTokenContext> validateAgreement(TransferRemoteMessage message, ClaimTokenContext context) {
var validationResult = contractValidationService.validateAgreement(context.claimToken(), context.agreement());
var validationResult = contractValidationService.validateAgreement(context.participantAgent(), context.agreement());
if (validationResult.failed()) {
return ServiceResult.conflict(format("Cannot process %s because %s", message.getClass().getSimpleName(), "agreement not found or not valid"));
}
Expand All @@ -248,7 +248,7 @@ private <T> ServiceResult<TransferRequestMessageContext> fetchRequestContext(T i
}

private ServiceResult<ClaimTokenContext> verifyRequest(TokenRepresentation tokenRepresentation, TransferRequestMessageContext context) {
var result = protocolTokenValidator.verifyToken(tokenRepresentation, TRANSFER_PROCESS_REQUEST_SCOPE, context.agreement().getPolicy());
var result = protocolTokenValidator.verify(tokenRepresentation, TRANSFER_PROCESS_REQUEST_SCOPE, context.agreement().getPolicy());
if (result.failed()) {
monitor.debug(() -> "Verification Failed: %s".formatted(result.getFailureDetail()));
return ServiceResult.notFound("Not found");
Expand All @@ -265,9 +265,9 @@ private ServiceResult<ContractAgreement> findContractByTransferProcess(TransferP
return ServiceResult.success(agreement);
}

private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message, ClaimToken claimToken, ContractAgreement agreement, Function<TransferProcess, ServiceResult<TransferProcess>> action) {
private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message, ParticipantAgent participantAgent, ContractAgreement agreement, Function<TransferProcess, ServiceResult<TransferProcess>> action) {
return findAndLease(message)
.compose(transferProcess -> validateCounterParty(claimToken, agreement, transferProcess)
.compose(transferProcess -> validateCounterParty(participantAgent, agreement, transferProcess)
.compose(p -> {
if (p.shouldIgnoreIncomingMessage(message.getId())) {
return ServiceResult.success(p);
Expand All @@ -278,8 +278,8 @@ private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message
.onFailure(f -> breakLease(transferProcess)));
}

private ServiceResult<TransferProcess> validateCounterParty(ClaimToken claimToken, ContractAgreement agreement, TransferProcess transferProcess) {
var validation = contractValidationService.validateRequest(claimToken, agreement);
private ServiceResult<TransferProcess> validateCounterParty(ParticipantAgent participantAgent, ContractAgreement agreement, TransferProcess transferProcess) {
var validation = contractValidationService.validateRequest(participantAgent, agreement);
if (validation.failed()) {
return ServiceResult.badRequest(validation.getFailureMessages());
}
Expand Down Expand Up @@ -326,7 +326,7 @@ private void update(TransferProcess transferProcess) {
private record TransferRequestMessageContext(ContractAgreement agreement, TransferProcess transferProcess) {
}

private record ClaimTokenContext(ClaimToken claimToken, ContractAgreement agreement,
private record ClaimTokenContext(ParticipantAgent participantAgent, ContractAgreement agreement,
TransferProcess transferProcess) {
}
}
Loading

0 comments on commit 6c48818

Please sign in to comment.