From 48c701959f21be671ebdc91d4d59a26c8802c91e Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Mon, 5 Aug 2024 11:23:36 -0700 Subject: [PATCH] Adds reprovision API to support updating search pipelines, ingest pipelines index settings (#804) * Initial commit, Adds ReprovisionWorkflowTransportAction, reprovision param for RestCreateWorkflowAction, creates and registers Update Ingest/Search pipeline steps in WorkflowResources, registers update steps in WorkflowStepFactory Signed-off-by: Joshua Palis * Initial reprovisiontransportaction implementation, Added UpdateIndexStep, improved WorkflowProcessSorter.createReprovisionSequence Signed-off-by: Joshua Palis * Implements Update index Step to support updating index settings, modifies updating resource created script to remove error if any Signed-off-by: Joshua Palis * Improves workflow node comparision Signed-off-by: Joshua Palis * Adding comments Signed-off-by: Joshua Palis * Fixing tests, adding javadocs Signed-off-by: Joshua Palis * Adding changelog Signed-off-by: Joshua Palis * Updating parse utils, RestCreateWorkflowAction, CreateWorkflowTransportAction tests. Adding check for reprovision without workflowID. 
Signed-off-by: Joshua Palis * Adding update step and get resource step tests Signed-off-by: Joshua Palis * Adding check for filtered setting list size Signed-off-by: Joshua Palis * Addign reprovision workflow transport action tests Signed-off-by: Joshua Palis * Adding tests for reprovision sequence creation Signed-off-by: Joshua Palis * Addressing comments Signed-off-by: Joshua Palis * Changing GetResourceStep to WorkflowDataStep Signed-off-by: Joshua Palis * Addressing PR comments Signed-off-by: Joshua Palis * Fixing state check for reprovision transport action Signed-off-by: Joshua Palis * Adding state eror check to reprovision transport action to remove error field Signed-off-by: Joshua Palis * removing error check from flowframeworkindices handler Signed-off-by: Joshua Palis * Adding check for no updated settings Signed-off-by: Joshua Palis * refactor reprovision sequence creation Signed-off-by: Joshua Palis * Fixing workflowrequest serialization Signed-off-by: Joshua Palis * Addressing PR comments Signed-off-by: Joshua Palis * Moving flattenSettings method to ParseUtils, added flatten settings tests Signed-off-by: Joshua Palis * updating workflowrequest Signed-off-by: Joshua Palis * fixing workflowrequest Signed-off-by: Joshua Palis * spotlessApply Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- CHANGELOG.md | 10 +- build.gradle | 1 + .../flowframework/FlowFrameworkPlugin.java | 5 +- .../flowframework/common/CommonValue.java | 2 + .../common/WorkflowResources.java | 98 ++++-- .../indices/FlowFrameworkIndicesHandler.java | 2 +- .../rest/RestCreateWorkflowAction.java | 43 ++- .../CreateWorkflowTransportAction.java | 122 ++++--- .../transport/ReprovisionWorkflowAction.java | 28 ++ .../transport/ReprovisionWorkflowRequest.java | 98 ++++++ .../ReprovisionWorkflowTransportAction.java | 333 ++++++++++++++++++ .../transport/WorkflowRequest.java | 50 ++- .../flowframework/util/ParseUtils.java | 35 ++ .../workflow/AbstractUpdatePipelineStep.java | 
133 +++++++ .../workflow/UpdateIndexStep.java | 178 ++++++++++ .../workflow/UpdateIngestPipelineStep.java | 53 +++ .../workflow/UpdateSearchPipelineStep.java | 53 +++ .../workflow/WorkflowDataStep.java | 61 ++++ .../workflow/WorkflowProcessSorter.java | 261 ++++++++++++++ .../workflow/WorkflowStepFactory.java | 25 ++ .../FlowFrameworkPluginTests.java | 2 +- .../model/ResourceCreatedTests.java | 2 +- .../model/WorkflowValidatorTests.java | 2 +- .../rest/RestCreateWorkflowActionTests.java | 18 + .../CreateWorkflowTransportActionTests.java | 92 ++++- .../ReprovisionWorkflowRequestTests.java | 87 +++++ ...provisionWorkflowTransportActionTests.java | 307 ++++++++++++++++ .../WorkflowRequestResponseTests.java | 14 +- .../flowframework/util/ParseUtilsTests.java | 77 ++++ .../workflow/UpdateIndexStepTests.java | 242 +++++++++++++ .../UpdateIngestPipelineStepTests.java | 150 ++++++++ .../UpdateSearchPipelineStepTests.java | 151 ++++++++ .../workflow/WorkflowDataStepTests.java | 58 +++ .../workflow/WorkflowProcessSorterTests.java | 261 ++++++++++++++ 34 files changed, 2936 insertions(+), 118 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/AbstractUpdatePipelineStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/UpdateIndexStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/WorkflowDataStep.java create mode 100644 
src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStepTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStepTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/WorkflowDataStepTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 87f92d385..bdaed0ea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,19 +16,11 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.14...2.x) ### Features -- Support editing of certain workflow fields on a provisioned workflow ([#757](https://github.com/opensearch-project/flow-framework/pull/757)) -- Add allow_delete parameter to Deprovision API ([#763](https://github.com/opensearch-project/flow-framework/pull/763)) +- Adds reprovision API to support updating search pipelines, ingest pipelines index settings ([#804](https://github.com/opensearch-project/flow-framework/pull/804)) ### Enhancements -- Register system index descriptors through SystemIndexPlugin.getSystemIndexDescriptors ([#750](https://github.com/opensearch-project/flow-framework/pull/750)) - ### Bug Fixes -- Handle Not Found exceptions as successful deletions for agents and models ([#805](https://github.com/opensearch-project/flow-framework/pull/805)) -- Wrap CreateIndexRequest mappings in _doc key as required ([#809](https://github.com/opensearch-project/flow-framework/pull/809)) -- Have FlowFrameworkException status recognized by ExceptionsHelper ([#811](https://github.com/opensearch-project/flow-framework/pull/811)) - ### Infrastructure ### 
Documentation ### Maintenance ### Refactoring -- Improve Template and WorkflowState builders ([#778](https://github.com/opensearch-project/flow-framework/pull/778)) diff --git a/build.gradle b/build.gradle index e0f38276c..b9f617683 100644 --- a/build.gradle +++ b/build.gradle @@ -175,6 +175,7 @@ dependencies { implementation "jakarta.json.bind:jakarta.json.bind-api:3.0.1" implementation "org.glassfish:jakarta.json:2.0.1" implementation "org.eclipse:yasson:3.0.3" + implementation "com.google.code.gson:gson:2.10.1" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index ec4c05145..f69534a77 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -50,6 +50,8 @@ import org.opensearch.flowframework.transport.GetWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; +import org.opensearch.flowframework.transport.ReprovisionWorkflowAction; +import org.opensearch.flowframework.transport.ReprovisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; import org.opensearch.flowframework.transport.SearchWorkflowStateAction; import org.opensearch.flowframework.transport.SearchWorkflowStateTransportAction; @@ -170,7 +172,8 @@ public List getRestHandlers( new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class), new ActionHandler<>(GetWorkflowStepAction.INSTANCE, GetWorkflowStepTransportAction.class), - new 
ActionHandler<>(SearchWorkflowStateAction.INSTANCE, SearchWorkflowStateTransportAction.class) + new ActionHandler<>(SearchWorkflowStateAction.INSTANCE, SearchWorkflowStateTransportAction.class), + new ActionHandler<>(ReprovisionWorkflowAction.INSTANCE, ReprovisionWorkflowTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 8bd8e9871..f291cff1c 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -78,6 +78,8 @@ private CommonValue() {} public static final String WORKFLOW_STEP = "workflow_step"; /** The param name for default use case, used by the create workflow API */ public static final String USE_CASE = "use_case"; + /** The param name for reprovisioning, used by the create workflow API */ + public static final String REPROVISION_WORKFLOW = "reprovision"; /* * Constants associated with plugin configuration diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index ad94f7b21..1b36b5e6f 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -32,6 +32,9 @@ import org.opensearch.flowframework.workflow.RegisterRemoteModelStep; import org.opensearch.flowframework.workflow.ReindexStep; import org.opensearch.flowframework.workflow.UndeployModelStep; +import org.opensearch.flowframework.workflow.UpdateIndexStep; +import org.opensearch.flowframework.workflow.UpdateIngestPipelineStep; +import org.opensearch.flowframework.workflow.UpdateSearchPipelineStep; import java.util.Set; import java.util.stream.Collectors; @@ -43,29 +46,39 @@ public enum WorkflowResources { /** Workflow steps for creating/deleting a connector and 
associated created resource */ - CREATE_CONNECTOR(CreateConnectorStep.NAME, WorkflowResources.CONNECTOR_ID, DeleteConnectorStep.NAME), + CREATE_CONNECTOR(CreateConnectorStep.NAME, null, DeleteConnectorStep.NAME, WorkflowResources.CONNECTOR_ID), /** Workflow steps for registering/deleting a remote model and associated created resource */ - REGISTER_REMOTE_MODEL(RegisterRemoteModelStep.NAME, WorkflowResources.MODEL_ID, DeleteModelStep.NAME), + REGISTER_REMOTE_MODEL(RegisterRemoteModelStep.NAME, null, DeleteModelStep.NAME, WorkflowResources.MODEL_ID), /** Workflow steps for registering/deleting a local model and associated created resource */ - REGISTER_LOCAL_MODEL(RegisterLocalCustomModelStep.NAME, WorkflowResources.MODEL_ID, DeleteModelStep.NAME), + REGISTER_LOCAL_MODEL(RegisterLocalCustomModelStep.NAME, null, DeleteModelStep.NAME, WorkflowResources.MODEL_ID), /** Workflow steps for registering/deleting a local sparse encoding model and associated created resource */ - REGISTER_LOCAL_SPARSE_ENCODING_MODEL(RegisterLocalSparseEncodingModelStep.NAME, WorkflowResources.MODEL_ID, DeleteModelStep.NAME), + REGISTER_LOCAL_SPARSE_ENCODING_MODEL(RegisterLocalSparseEncodingModelStep.NAME, null, DeleteModelStep.NAME, WorkflowResources.MODEL_ID), /** Workflow steps for registering/deleting a local OpenSearch provided pretrained model and associated created resource */ - REGISTER_LOCAL_PRETRAINED_MODEL(RegisterLocalPretrainedModelStep.NAME, WorkflowResources.MODEL_ID, DeleteModelStep.NAME), + REGISTER_LOCAL_PRETRAINED_MODEL(RegisterLocalPretrainedModelStep.NAME, null, DeleteModelStep.NAME, WorkflowResources.MODEL_ID), /** Workflow steps for registering/deleting a model group and associated created resource */ - REGISTER_MODEL_GROUP(RegisterModelGroupStep.NAME, WorkflowResources.MODEL_GROUP_ID, NoOpStep.NAME), + REGISTER_MODEL_GROUP(RegisterModelGroupStep.NAME, null, NoOpStep.NAME, WorkflowResources.MODEL_GROUP_ID), /** Workflow steps for deploying/undeploying a model and 
associated created resource */ - DEPLOY_MODEL(DeployModelStep.NAME, WorkflowResources.MODEL_ID, UndeployModelStep.NAME), + DEPLOY_MODEL(DeployModelStep.NAME, null, UndeployModelStep.NAME, WorkflowResources.MODEL_ID), /** Workflow steps for creating an ingest-pipeline and associated created resource */ - CREATE_INGEST_PIPELINE(CreateIngestPipelineStep.NAME, WorkflowResources.PIPELINE_ID, DeleteIngestPipelineStep.NAME), + CREATE_INGEST_PIPELINE( + CreateIngestPipelineStep.NAME, + UpdateIngestPipelineStep.NAME, + DeleteIngestPipelineStep.NAME, + WorkflowResources.PIPELINE_ID + ), /** Workflow steps for creating an ingest-pipeline and associated created resource */ - CREATE_SEARCH_PIPELINE(CreateSearchPipelineStep.NAME, WorkflowResources.PIPELINE_ID, DeleteSearchPipelineStep.NAME), + CREATE_SEARCH_PIPELINE( + CreateSearchPipelineStep.NAME, + UpdateSearchPipelineStep.NAME, + DeleteSearchPipelineStep.NAME, + WorkflowResources.PIPELINE_ID + ), /** Workflow steps for creating an index and associated created resource */ - CREATE_INDEX(CreateIndexStep.NAME, WorkflowResources.INDEX_NAME, DeleteIndexStep.NAME), + CREATE_INDEX(CreateIndexStep.NAME, UpdateIndexStep.NAME, DeleteIndexStep.NAME, WorkflowResources.INDEX_NAME), /** Workflow steps for reindex a source index to destination index and associated created resource */ - REINDEX(ReindexStep.NAME, WorkflowResources.INDEX_NAME, NoOpStep.NAME), + REINDEX(ReindexStep.NAME, null, NoOpStep.NAME, WorkflowResources.INDEX_NAME), /** Workflow steps for registering/deleting an agent and the associated created resource */ - REGISTER_AGENT(RegisterAgentStep.NAME, WorkflowResources.AGENT_ID, DeleteAgentStep.NAME); + REGISTER_AGENT(RegisterAgentStep.NAME, null, DeleteAgentStep.NAME, WorkflowResources.AGENT_ID); /** Connector Id for a remote model connector */ public static final String CONNECTOR_ID = "connector_id"; @@ -80,34 +93,37 @@ public enum WorkflowResources { /** Agent Id */ public static final String AGENT_ID = "agent_id"; - 
private final String workflowStep; - private final String resourceCreated; + private final String createStep; + private final String updateStep; private final String deprovisionStep; + private final String resourceCreated; + private static final Logger logger = LogManager.getLogger(WorkflowResources.class); private static final Set allResources = Stream.of(values()) .map(WorkflowResources::getResourceCreated) .collect(Collectors.toSet()); - WorkflowResources(String workflowStep, String resourceCreated, String deprovisionStep) { - this.workflowStep = workflowStep; - this.resourceCreated = resourceCreated; + WorkflowResources(String createStep, String updateStep, String deprovisionStep, String resourceCreated) { + this.createStep = createStep; + this.updateStep = updateStep; this.deprovisionStep = deprovisionStep; + this.resourceCreated = resourceCreated; } /** - * Returns the workflowStep for the given enum Constant - * @return the workflowStep of this data. + * Returns the create step for the given enum Constant + * @return the create step of this data. */ - public String getWorkflowStep() { - return workflowStep; + public String getCreateStep() { + return createStep; } /** - * Returns the resourceCreated for the given enum Constant - * @return the resourceCreated of this data. + * Returns the updateStep for the given enum Constant + * @return the updateStep of this data. */ - public String getResourceCreated() { - return resourceCreated; + public String getUpdateStep() { + return updateStep; } /** @@ -118,6 +134,14 @@ public String getDeprovisionStep() { return deprovisionStep; } + /** + * Returns the resourceCreated for the given enum Constant + * @return the resourceCreated of this data. + */ + public String getResourceCreated() { + return resourceCreated; + } + /** * Gets the resources created type based on the workflowStep. 
* @param workflowStep workflow step name @@ -127,7 +151,9 @@ public String getDeprovisionStep() { public static String getResourceByWorkflowStep(String workflowStep) throws FlowFrameworkException { if (workflowStep != null && !workflowStep.isEmpty()) { for (WorkflowResources mapping : values()) { - if (workflowStep.equals(mapping.getWorkflowStep()) || workflowStep.equals(mapping.getDeprovisionStep())) { + if (workflowStep.equals(mapping.getCreateStep()) + || workflowStep.equals(mapping.getDeprovisionStep()) + || workflowStep.equals(mapping.getUpdateStep())) { return mapping.getResourceCreated(); } } @@ -145,7 +171,7 @@ public static String getResourceByWorkflowStep(String workflowStep) throws FlowF public static String getDeprovisionStepByWorkflowStep(String workflowStep) throws FlowFrameworkException { if (workflowStep != null && !workflowStep.isEmpty()) { for (WorkflowResources mapping : values()) { - if (mapping.getWorkflowStep().equals(workflowStep)) { + if (mapping.getCreateStep().equals(workflowStep)) { return mapping.getDeprovisionStep(); } } @@ -154,6 +180,24 @@ public static String getDeprovisionStepByWorkflowStep(String workflowStep) throw throw new FlowFrameworkException("Unable to find deprovision step for step: " + workflowStep, RestStatus.BAD_REQUEST); } + /** + * Gets the update step type based on the workflowStep. 
+ * @param workflowStep workflow step name + * @return the corresponding step to update + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static String getUpdateStepByWorkflowStep(String workflowStep) throws FlowFrameworkException { + if (workflowStep != null && !workflowStep.isEmpty()) { + for (WorkflowResources mapping : values()) { + if (mapping.getCreateStep().equals(workflowStep)) { + return mapping.getUpdateStep(); + } + } + } + logger.error("Unable to find update step for step: {}", workflowStep); + throw new FlowFrameworkException("Unable to find update step for step: " + workflowStep, RestStatus.BAD_REQUEST); + } + /** * Returns all the possible resource created types in enum * @return a set of all the resource created types diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 43c00d230..63ac6f7d4 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -690,7 +690,7 @@ public void updateResourceInStateIndex( Script script = new Script( ScriptType.INLINE, "painless", - "ctx._source.resources_created.add(params.newResource)", + "ctx._source.resources_created.add(params.newResource);", Collections.singletonMap("newResource", newResource.resourceMap()) ); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index bb604e8d6..032b4b898 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -39,6 +39,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static 
org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; @@ -74,7 +75,7 @@ public List routes() { return List.of( // Create new workflow new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s", WORKFLOW_URI)), - // Update use case template + // Update use case template/ reprovision existing workflow new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, WORKFLOW_ID)) ); } @@ -84,8 +85,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); + boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false); boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); String useCase = request.param(USE_CASE); + // If provisioning, consume all other params and pass to provision transport action Map params = provision ? 
request.params() @@ -108,28 +111,32 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } if (!provision && !params.isEmpty()) { - // Consume params and content so custom exception is processed - params.keySet().stream().forEach(request::param); - request.content(); FlowFrameworkException ffe = new FlowFrameworkException( "Only the parameters " + request.consumedParams() + " are permitted unless the provision parameter is set to true.", RestStatus.BAD_REQUEST ); - return channel -> channel.sendResponse( - new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) - ); + return processError(ffe, params, request); } if (provision && updateFields) { - // Consume params and content so custom exception is processed - params.keySet().stream().forEach(request::param); - request.content(); FlowFrameworkException ffe = new FlowFrameworkException( "You can not use both the " + PROVISION_WORKFLOW + " and " + UPDATE_WORKFLOW_FIELDS + " parameters in the same request.", RestStatus.BAD_REQUEST ); - return channel -> channel.sendResponse( - new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + return processError(ffe, params, request); + } + if (reprovision && workflowId == null) { + FlowFrameworkException ffe = new FlowFrameworkException( + "You can not use the " + REPROVISION_WORKFLOW + " parameter to create a new template.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + if (reprovision && useCase != null) { + FlowFrameworkException ffe = new FlowFrameworkException( + "You cannot use the " + REPROVISION_WORKFLOW + " and " + USE_CASE + " parameters in the same request.", + RestStatus.BAD_REQUEST ); + return processError(ffe, params, request); } try { Template template; @@ -213,7 +220,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli provision || updateFields, 
params, useCase, - useCaseDefaultsMap + useCaseDefaultsMap, + reprovision ); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { @@ -249,4 +257,13 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } } + + private RestChannelConsumer processError(FlowFrameworkException ffe, Map params, RestRequest request) { + // Consume params and content so custom exception is processed + params.keySet().stream().forEach(request::param); + request.content(); + return channel -> channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 930feef90..ecb015ffc 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -240,6 +240,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { context.restore(); if (getResponse.isExists()) { + Template existingTemplate = Template.parse(getResponse.getSourceAsString()); Template template = isFieldUpdate ? 
Template.updateExistingTemplate(existingTemplate, templateWithUser) @@ -248,53 +249,82 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - // Ignore state index if updating fields - if (!isFieldUpdate) { - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( - request.getWorkflowId(), - Map.ofEntries( - Map.entry(STATE_FIELD, State.NOT_STARTED), - Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED) - ), - ActionListener.wrap(updateResponse -> { - logger.info( - "updated workflow {} state to {}", - request.getWorkflowId(), - State.NOT_STARTED.name() - ); - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); - }, exception -> { - String errorMessage = "Failed to update workflow " - + request.getWorkflowId() - + " in template index"; - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) + + if (request.isReprovision()) { + + // Reprovision request + ReprovisionWorkflowRequest reprovisionRequest = new ReprovisionWorkflowRequest( + getResponse.getId(), + existingTemplate, + template + ); + logger.info("Reprovisioning parameter is set, continuing to reprovision workflow {}", getResponse.getId()); + client.execute( + ReprovisionWorkflowAction.INSTANCE, + reprovisionRequest, + ActionListener.wrap(reprovisionResponse -> { + listener.onResponse(new WorkflowResponse(reprovisionResponse.getWorkflowId())); + }, exception -> { + String errorMessage = "Reprovisioning failed for workflow " + workflowId; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } else { + + // Update existing entry, full document replacement + 
flowFrameworkIndicesHandler.updateTemplateInGlobalContext( + request.getWorkflowId(), + template, + ActionListener.wrap(response -> { + // Regular update, reset provisioning status, ignore state index if updating fields + if (!isFieldUpdate) { + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + request.getWorkflowId(), + Map.ofEntries( + Map.entry(STATE_FIELD, State.NOT_STARTED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED) + ), + ActionListener.wrap(updateResponse -> { + logger.info( + "updated workflow {} state to {}", + request.getWorkflowId(), + State.NOT_STARTED.name() ); - } - }) - ); - } else { - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); - } - }, exception -> { - String errorMessage = "Failed to update use case template " + request.getWorkflowId(); - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - } - }), - isFieldUpdate - ); + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + }, exception -> { + String errorMessage = "Failed to update workflow " + + request.getWorkflowId() + + " in template index"; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) + ); + } + }) + ); + } else { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } + }, exception -> { + String errorMessage = "Failed to update use case template " + request.getWorkflowId(); + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + 
}), + isFieldUpdate + ); + } } else { String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; logger.error(errorMessage); diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowAction.java new file mode 100644 index 000000000..0c159837e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowAction.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestCreateWorkflowAction + */ +public class ReprovisionWorkflowAction extends ActionType { + + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/reprovision"; + /** An instance of this action */ + public static final ReprovisionWorkflowAction INSTANCE = new ReprovisionWorkflowAction(); + + private ReprovisionWorkflowAction() { + super(NAME, WorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java new file mode 100644 index 000000000..f6cde633e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.flowframework.model.Template; + +import java.io.IOException; + +/** + * Transport request to reprovision a workflow + */ +public class ReprovisionWorkflowRequest extends ActionRequest { + + /** + * The workflow Id + */ + private String workflowId; + /** + * The original template + */ + private Template originalTemplate; + /** + * The updated template + */ + private Template updatedTemplate; + + /** + * Instantiates a new ReprovisionWorkflowRequest + * @param workflowId the workflow ID + * @param originalTemplate the original Template + * @param updatedTemplate the updated Template + */ + public ReprovisionWorkflowRequest(String workflowId, Template originalTemplate, Template updatedTemplate) { + this.workflowId = workflowId; + this.originalTemplate = originalTemplate; + this.updatedTemplate = updatedTemplate; + } + + /** + * Instantiates a new ReprovisionWorkflow request + * @param in The input stream to read from + * @throws IOException If the stream cannot be read properly + */ + public ReprovisionWorkflowRequest(StreamInput in) throws IOException { + super(in); + this.workflowId = in.readString(); + this.originalTemplate = Template.parse(in.readString()); + this.updatedTemplate = Template.parse(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(workflowId); + out.writeString(originalTemplate.toJson()); + out.writeString(updatedTemplate.toJson()); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + /** + * Gets the workflow Id of the request + * @return the workflow Id + */ + public String getWorkflowId() 
{ + return this.workflowId; + } + + /** + * Gets the original template of the request + * @return the original template + */ + public Template getOriginalTemplate() { + return this.originalTemplate; + } + + /** + * Gets the updated template of the request + * @return the updated template + */ + public Template getUpdatedTemplate() { + return this.updatedTemplate; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java new file mode 100644 index 000000000..90fe8066c --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -0,0 +1,333 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.State; +import 
org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.plugins.PluginsService; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; +import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; +import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; + +/** + * Transport Action to reprovision a provisioned template + */ +public class ReprovisionWorkflowTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(ReprovisionWorkflowTransportAction.class); + + private final ThreadPool threadPool; + private final Client client; + private final WorkflowStepFactory workflowStepFactory; + private final 
WorkflowProcessSorter workflowProcessSorter; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private final FlowFrameworkSettings flowFrameworkSettings; + private final PluginsService pluginsService; + private final EncryptorUtils encryptorUtils; + + /** + * Instantiates a new ReprovisionWorkflowTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param threadPool The OpenSearch thread pool + * @param client The node client to retrieve a stored use case template + * @param workflowStepFactory The factory instantiating workflow steps + * @param workflowProcessSorter Utility class to generate a togologically sorted list of Process nodes + * @param flowFrameworkIndicesHandler Class to handle all internal system indices actions + * @param flowFrameworkSettings Whether this API is enabled + * @param encryptorUtils Utility class to handle encryption/decryption + * @param pluginsService The Plugins Service + */ + @Inject + public ReprovisionWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + Client client, + WorkflowStepFactory workflowStepFactory, + WorkflowProcessSorter workflowProcessSorter, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, + FlowFrameworkSettings flowFrameworkSettings, + EncryptorUtils encryptorUtils, + PluginsService pluginsService + ) { + super(ReprovisionWorkflowAction.NAME, transportService, actionFilters, ReprovisionWorkflowRequest::new); + this.threadPool = threadPool; + this.client = client; + this.workflowStepFactory = workflowStepFactory; + this.workflowProcessSorter = workflowProcessSorter; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + this.flowFrameworkSettings = flowFrameworkSettings; + this.encryptorUtils = encryptorUtils; + this.pluginsService = pluginsService; + } + + @Override + protected void doExecute(Task task, ReprovisionWorkflowRequest request, 
ActionListener listener) { + + String workflowId = request.getWorkflowId(); + + // Retrieve state and resources created + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + logger.info("Querying state for workflow: {}", workflowId); + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { + context.restore(); + + State currentState = State.valueOf(response.getWorkflowState().getState()); + if (State.PROVISIONING.equals(currentState) || State.NOT_STARTED.equals(currentState)) { + String errorMessage = "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: " + + workflowId; + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + + // Generate reprovision sequence + List resourceCreated = response.getWorkflowState().resourcesCreated(); + + // Original template is retrieved from index, attempt to decrypt any exisiting credentials before processing + Template originalTemplate = encryptorUtils.decryptTemplateCredentials(request.getOriginalTemplate()); + Template updatedTemplate = request.getUpdatedTemplate(); + + // Validate updated template prior to execution + Workflow provisionWorkflow = updatedTemplate.workflows().get(PROVISION_WORKFLOW); + List updatedProcessSequence = workflowProcessSorter.sortProcessNodes( + provisionWorkflow, + request.getWorkflowId(), + Collections.emptyMap() // TODO : Add suport to reprovision substitution templates + ); + + try { + workflowProcessSorter.validate(updatedProcessSequence, pluginsService); + } catch (Exception e) { + String errormessage = "Workflow validation failed for workflow " + request.getWorkflowId(); + logger.error(errormessage, e); + listener.onFailure(new FlowFrameworkException(errormessage, RestStatus.BAD_REQUEST)); + } + List reprovisionProcessSequence = 
workflowProcessSorter.createReprovisionSequence( + workflowId, + originalTemplate, + updatedTemplate, + resourceCreated + ); + + // Remove error field if any prior to subsequent execution + if (response.getWorkflowState().getError() != null) { + Script script = new Script( + ScriptType.INLINE, + "painless", + "if(ctx._source.containsKey('error')){ctx._source.remove('error')}", + Collections.emptyMap() + ); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( + WORKFLOW_STATE_INDEX, + workflowId, + script, + ActionListener.wrap(updateResponse -> { + + }, exception -> { + String errorMessage = "Failed to update workflow state: " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + }) + ); + } + + // Update State Index, maintain resources created for subsequent execution + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.PROVISIONING), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, resourceCreated) + ), + ActionListener.wrap(updateResponse -> { + + logger.info("Updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); + + // Attach last provisioned time to updated template and execute reprovisioning + Template updatedTemplateWithProvisionedTime = Template.builder(updatedTemplate) + .lastProvisionedTime(Instant.now()) + .build(); + executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); + + listener.onResponse(new WorkflowResponse(workflowId)); + + }, exception -> { + String errorMessage = "Failed to update workflow state: " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, 
ExceptionsHelper.status(exception))); + }) + ); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + String errorMessage = "Failed to get workflow state for workflow " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + })); + } catch (Exception e) { + String errorMessage = "Failed to get workflow state for workflow " + workflowId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } + + } + + /** + * Retrieves a thread from the provision thread pool to execute a workflow + * @param workflowId The id of the workflow + * @param template The updated template to store upon successful execution + * @param workflowSequence The sorted workflow to execute + * @param listener ActionListener for any failures that don't get caught earlier in below step + */ + private void executeWorkflowAsync( + String workflowId, + Template template, + List workflowSequence, + ActionListener listener + ) { + try { + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(template, workflowSequence, workflowId); }); + } catch (Exception exception) { + listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); + } + } + + /** + * Executes the given workflow sequence + * @param template The template to store after reprovisioning completes successfully + * @param workflowSequence The topologically sorted workflow to execute + * @param workflowId The workflowId associated with the workflow that is executing + */ + private void executeWorkflow(Template template, List workflowSequence, String workflowId) { + String currentStepId = ""; + try { + Map> workflowFutureMap = new LinkedHashMap<>(); + for (ProcessNode processNode : workflowSequence) { + List 
predecessors = processNode.predecessors(); + logger.info( + "Queueing process [{}].{}", + processNode.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + + workflowFutureMap.put(processNode.id(), processNode.execute()); + } + + // Attempt to complete each workflow step future, may throw a ExecutionException if any step completes exceptionally + for (Map.Entry> e : workflowFutureMap.entrySet()) { + currentStepId = e.getKey(); + e.getValue().actionGet(); + } + + logger.info("Reprovisioning completed successfully for workflow {}", workflowId); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.COMPLETED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.DONE), + Map.entry(PROVISION_END_TIME_FIELD, Instant.now().toEpochMilli()) + ), + ActionListener.wrap(updateResponse -> { + + logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + + // Replace template document + flowFrameworkIndicesHandler.updateTemplateInGlobalContext( + workflowId, + template, + ActionListener.wrap(templateResponse -> { + logger.info("Updated template for {}", workflowId, State.COMPLETED); + }, exception -> { + String errorMessage = "Failed to update use case template for " + workflowId; + logger.error(errorMessage, exception); + }), + true // ignores NOT_STARTED state if request is to reprovision + ); + }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) + ); + } catch (Exception ex) { + RestStatus status; + if (ex instanceof FlowFrameworkException) { + status = ((FlowFrameworkException) ex).getRestStatus(); + } else { + status = ExceptionsHelper.status(ex); + } + logger.error("Reprovisioning failed for workflow {} during step {}.", workflowId, 
currentStepId, ex); + String errorMessage = (ex.getCause() == null ? ex.getMessage() : ex.getCause().getClass().getName()) + + " during step " + + currentStepId + + ", restStatus: " + + status.toString(); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.FAILED), + Map.entry(ERROR_FIELD, errorMessage), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.FAILED), + Map.entry(PROVISION_END_TIME_FIELD, Instant.now().toEpochMilli()) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", workflowId, State.FAILED); + }, exceptionState -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exceptionState); }) + ); + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index dc0543656..a258e6e10 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.Map; +import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; /** @@ -46,6 +47,11 @@ public class WorkflowRequest extends ActionRequest { */ private boolean provision; + /** + * Reprovision flag + */ + private boolean reprovision; + /** * Update Fields flag */ @@ -72,7 +78,7 @@ public class WorkflowRequest extends ActionRequest { * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), null, Collections.emptyMap()); + this(workflowId, template, new String[] { "all" }, false, 
Collections.emptyMap(), null, Collections.emptyMap(), false); } /** @@ -82,7 +88,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param params The parameters from the REST path */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { - this(workflowId, template, new String[] { "all" }, true, params, null, Collections.emptyMap()); + this(workflowId, template, new String[] { "all" }, true, params, null, Collections.emptyMap(), false); } /** @@ -93,7 +99,17 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, * @param defaultParams The parameters from the REST body when a use case is given */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String useCase, Map defaultParams) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), useCase, defaultParams); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), useCase, defaultParams, false); + } + + /** + * Instantiates a new WorkflowRequest, set validation to all, sets reprovision flag + * @param workflowId the documentId of the workflow + * @param template the updated template + * @param reprovision the reprovision flag + */ + public WorkflowRequest(String workflowId, Template template, boolean reprovision) { + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), null, Collections.emptyMap(), reprovision); } /** @@ -105,6 +121,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, * @param params map of REST path params. If provisionOrUpdate is false, must be an empty map. If update_fields key is present, must be only key. * @param useCase default use case given * @param defaultParams the params to be used in the substitution based on the default use case. 
+ * @param reprovision flag to indicate if request is to reprovision */ public WorkflowRequest( @Nullable String workflowId, @@ -113,7 +130,8 @@ public WorkflowRequest( boolean provisionOrUpdate, Map params, String useCase, - Map defaultParams + Map defaultParams, + boolean reprovision ) { this.workflowId = workflowId; this.template = template; @@ -126,6 +144,7 @@ public WorkflowRequest( this.params = this.updateFields ? Collections.emptyMap() : params; this.useCase = useCase; this.defaultParams = defaultParams; + this.reprovision = reprovision; } /** @@ -139,13 +158,18 @@ public WorkflowRequest(StreamInput in) throws IOException { String templateJson = in.readOptionalString(); this.template = templateJson == null ? null : Template.parse(templateJson); this.validation = in.readStringArray(); - boolean provisionOrUpdate = in.readBoolean(); - this.params = provisionOrUpdate ? in.readMap(StreamInput::readString, StreamInput::readString) : Collections.emptyMap(); - this.provision = provisionOrUpdate && !params.containsKey(UPDATE_WORKFLOW_FIELDS); + boolean provisionOrUpdateOrReprovision = in.readBoolean(); + this.params = provisionOrUpdateOrReprovision + ? 
in.readMap(StreamInput::readString, StreamInput::readString) + : Collections.emptyMap(); + this.provision = provisionOrUpdateOrReprovision + && !params.containsKey(UPDATE_WORKFLOW_FIELDS) + && !params.containsKey(REPROVISION_WORKFLOW); this.updateFields = !provision && Boolean.parseBoolean(params.get(UPDATE_WORKFLOW_FIELDS)); if (this.updateFields) { this.params = Collections.emptyMap(); } + this.reprovision = !provision && Boolean.parseBoolean(params.get(REPROVISION_WORKFLOW)); } /** @@ -214,17 +238,27 @@ public Map getDefaultParams() { return Map.copyOf(this.defaultParams); } + /** + * Gets the reprovision flag + * @return the reprovision boolean + */ + public boolean isReprovision() { + return this.reprovision; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(workflowId); out.writeOptionalString(template == null ? null : template.toJson()); out.writeStringArray(validation); - out.writeBoolean(provision || updateFields); + out.writeBoolean(provision || updateFields || reprovision); if (provision) { out.writeMap(params, StreamOutput::writeString, StreamOutput::writeString); } else if (updateFields) { out.writeMap(Map.of(UPDATE_WORKFLOW_FIELDS, "true"), StreamOutput::writeString, StreamOutput::writeString); + } else if (reprovision) { + out.writeMap(Map.of(REPROVISION_WORKFLOW, "true"), StreamOutput::writeString, StreamOutput::writeString); } } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 7cd8645fe..e20b2ed3b 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.util; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import 
org.opensearch.client.Client; @@ -498,4 +500,37 @@ public static T parseIfExists(Map inputs, String key, Class< throw new IllegalArgumentException("Unsupported type: " + type); } } + + /** + * Compares workflow node user inputs + * @param originalInputs the original node user inputs + * @param updatedInputs the updated node user inputs + * @throws Exception for issues processing map + * @return boolean if equivalent + */ + public static boolean userInputsEquals(Map originalInputs, Map updatedInputs) throws Exception { + String originalInputsJson = parseArbitraryStringToObjectMapToString(originalInputs); + String updatedInputsJson = parseArbitraryStringToObjectMapToString(updatedInputs); + JsonElement elem1 = JsonParser.parseString(originalInputsJson); + JsonElement elem2 = JsonParser.parseString(updatedInputsJson); + return elem1.equals(elem2); + } + + /** + * Flattens a nested map of settings, delimitted by a period + * @param prefix the setting prefix + * @param settings the nested setting map + * @param flattenedSettings the final flattend map of settings + */ + public static void flattenSettings(String prefix, Map settings, Map flattenedSettings) { + for (Map.Entry entry : settings.entrySet()) { + String key = prefix.isEmpty() ? entry.getKey() : prefix + "." 
+ entry.getKey(); + Object value = entry.getValue(); + if (value instanceof Map) { + flattenSettings(key, (Map) value, flattenedSettings); + } else { + flattenedSettings.put(key, value.toString()); + } + } + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractUpdatePipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractUpdatePipelineStep.java new file mode 100644 index 000000000..e57796f02 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractUpdatePipelineStep.java @@ -0,0 +1,133 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; +import org.opensearch.flowframework.util.ParseUtils; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; +import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import 
/**
 * Step to update either a search or ingest pipeline.
 *
 * Subclasses supply the concrete put-pipeline call via
 * {@link #executePutPipelineRequest}; this base class handles input gathering,
 * configuration transformation, and future/listener wiring common to both.
 */
public abstract class AbstractUpdatePipelineStep implements WorkflowStep {
    private static final Logger logger = LogManager.getLogger(AbstractUpdatePipelineStep.class);

    // Client to store a pipeline in the cluster state
    private final ClusterAdminClient clusterAdminClient;

    /**
     * Instantiate this class
     * @param client client to access cluster admin client
     */
    protected AbstractUpdatePipelineStep(Client client) {
        this.clusterAdminClient = client.admin().cluster();
    }

    /**
     * Executes a put search or ingest pipeline request
     * @param pipelineId the pipeline id
     * @param configuration the pipeline configuration bytes
     * @param clusterAdminClient the cluster admin client
     * @param listener listener
     */
    public abstract void executePutPipelineRequest(
        String pipelineId,
        BytesReference configuration,
        ClusterAdminClient clusterAdminClient,
        ActionListener<AcknowledgedResponse> listener
    );

    @Override
    public PlainActionFuture<WorkflowData> execute(
        String currentNodeId,
        WorkflowData currentNodeInputs,
        Map<String, WorkflowData> outputs,
        Map<String, String> previousNodeInputs,
        Map<String, String> params
    ) {
        PlainActionFuture<WorkflowData> createPipelineFuture = PlainActionFuture.newFuture();

        Set<String> requiredKeys = Set.of(PIPELINE_ID, CONFIGURATIONS);

        // currently, we are supporting an optional param of model ID into the various processors
        Set<String> optionalKeys = Set.of(MODEL_ID);

        try {
            // Gathers PIPELINE_ID and CONFIGURATIONS from this node's inputs and prior step outputs;
            // throws FlowFrameworkException if a required key is missing
            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
                requiredKeys,
                optionalKeys,
                currentNodeInputs,
                outputs,
                previousNodeInputs,
                params
            );

            String pipelineId = (String) inputs.get(PIPELINE_ID);
            String configurations = (String) inputs.get(CONFIGURATIONS);

            // Special case for processors that have arrays that need to have the quotes around or
            // backslashes around strings in array removed
            String transformedJsonStringForStringArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(configurations);

            byte[] byteArr = transformedJsonStringForStringArray.getBytes(StandardCharsets.UTF_8);
            BytesReference configurationsBytes = new BytesArray(byteArr);

            String pipelineToBeCreated = this.getName();
            ActionListener<AcknowledgedResponse> putPipelineActionListener = new ActionListener<>() {

                @Override
                public void onResponse(AcknowledgedResponse acknowledgedResponse) {

                    // Not necessary to update state index entry since the resource ID remains unchanged
                    String resourceName = getResourceByWorkflowStep(getName());
                    logger.info("Successfully updated resource: {}", pipelineId);
                    createPipelineFuture.onResponse(
                        new WorkflowData(Map.of(resourceName, pipelineId), currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId())
                    );
                }

                @Override
                public void onFailure(Exception ex) {
                    // getSafeException strips unsafe wrapper exceptions; may return null
                    Exception e = getSafeException(ex);
                    String errorMessage = (e == null ? "Failed step " + pipelineToBeCreated : e.getMessage());
                    logger.error(errorMessage, e);
                    createPipelineFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
                }

            };

            // Delegates to the subclass (search vs ingest) to issue the actual put-pipeline request
            executePutPipelineRequest(pipelineId, configurationsBytes, clusterAdminClient, putPipelineActionListener);

        } catch (FlowFrameworkException e) {
            createPipelineFuture.onFailure(e);
        }
        return createPipelineFuture;

    }
}
+ */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; +import org.opensearch.flowframework.util.ParseUtils; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; +import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; + +/** + * Step to update index settings and mappings, currently only update settings is implemented + */ +public class UpdateIndexStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(UpdateIndexStep.class); + private final Client client; + + /** The name of this step */ + public static final String NAME = "update_index"; + + /** + * Instantiate this class + * + * @param client Client to update an index + */ + public UpdateIndexStep(Client client) { + 
this.client = client; + } + + @Override + public PlainActionFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs, + Map params + ) { + PlainActionFuture updateIndexFuture = PlainActionFuture.newFuture(); + + Set requiredKeys = Set.of(INDEX_NAME, CONFIGURATIONS); + Set optionalKeys = Collections.emptySet(); + + try { + + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs, + params + ); + + String indexName = (String) inputs.get(INDEX_NAME); + String configurations = (String) inputs.get(CONFIGURATIONS); + byte[] byteArr = configurations.getBytes(StandardCharsets.UTF_8); + BytesReference configurationsBytes = new BytesArray(byteArr); + + UpdateSettingsRequest updateSettingsRequest = new UpdateSettingsRequest(indexName); + + if (configurations.isEmpty()) { + String errorMessage = "Failed to update index settings for index " + indexName + ", index configuration is not given"; + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } else { + + Map sourceAsMap = XContentHelper.convertToMap(configurationsBytes, false, MediaTypeRegistry.JSON).v2(); + + // TODO : Add support to update index mappings + + // extract index settings from configuration + if (!sourceAsMap.containsKey("settings")) { + String errorMessage = "Failed to update index settings for index " + + indexName + + ", settings are not found in the index configuration"; + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } else { + + @SuppressWarnings("unchecked") + Map updatedSettings = (Map) sourceAsMap.get("settings"); + + // check if settings are flattened or expanded + Map flattenedSettings = new HashMap<>(); + if (updatedSettings.containsKey("index")) { + ParseUtils.flattenSettings("", updatedSettings, flattenedSettings); + } else { + flattenedSettings.putAll(updatedSettings); + } + + Map filteredSettings = new 
HashMap<>(); + + // Retrieve current Index Settings + GetSettingsRequest getSettingsRequest = new GetSettingsRequest(); + getSettingsRequest.indices(indexName); + getSettingsRequest.includeDefaults(true); + client.admin().indices().getSettings(getSettingsRequest, ActionListener.wrap(response -> { + Map indexToSettings = new HashMap(response.getIndexToSettings()); + + // Include in the update request only settings with updated values + Settings currentIndexSettings = indexToSettings.get(indexName); + for (Map.Entry e : flattenedSettings.entrySet()) { + String val = e.getValue().toString(); + if (!val.equals(currentIndexSettings.get(e.getKey()))) { + filteredSettings.put(e.getKey(), e.getValue()); + } + } + }, ex -> { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to retrieve the index settings for index " + indexName : e.getMessage()); + logger.error(errorMessage, e); + updateIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); + })); + + updateSettingsRequest.settings(filteredSettings); + } + } + + if (updateSettingsRequest.settings().size() == 0) { + String errorMessage = "Failed to update index settings for index " + indexName + ", no settings have been updated"; + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } else { + client.admin().indices().updateSettings(updateSettingsRequest, ActionListener.wrap(acknowledgedResponse -> { + String resourceName = getResourceByWorkflowStep(getName()); + logger.info("Updated index settings for index {}", indexName); + updateIndexFuture.onResponse( + new WorkflowData(Map.of(resourceName, indexName), currentNodeInputs.getWorkflowId(), currentNodeId) + ); + + }, ex -> { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? 
"Failed to update the index settings for index " + indexName : e.getMessage()); + logger.error(errorMessage, e); + updateIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); + })); + } + } catch (Exception e) { + updateIndexFuture.onFailure(new WorkflowStepException(e.getMessage(), ExceptionsHelper.status(e))); + } + + return updateIndexFuture; + + } + + @Override + public String getName() { + return NAME; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStep.java new file mode 100644 index 000000000..f1107c1c5 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStep.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ingest.PutPipelineRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; + +/** + * Step to update an ingest pipeline + */ +public class UpdateIngestPipelineStep extends AbstractUpdatePipelineStep { + private static final Logger logger = LogManager.getLogger(UpdateIngestPipelineStep.class); + + /** The name of this step, used as a key in the {@link WorkflowStepFactory} */ + public static final String NAME = "update_ingest_pipeline"; + + /** + * Instantiates a new UpdateIngestPipelineStep + * @param client The client to create a pipeline and store workflow data into the global context index + */ + public UpdateIngestPipelineStep(Client client) { + super(client); + } + + @Override + public void executePutPipelineRequest( + String pipelineId, + BytesReference configuration, + ClusterAdminClient clusterAdminClient, + ActionListener listener + ) { + PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configuration, XContentType.JSON); + clusterAdminClient.putPipeline(putPipelineRequest, listener); + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStep.java new file mode 100644 index 000000000..0a6048eed --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStep.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch 
Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.PutSearchPipelineRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; + +/** + * Step to update a search pipeline + */ +public class UpdateSearchPipelineStep extends AbstractUpdatePipelineStep { + private static final Logger logger = LogManager.getLogger(UpdateSearchPipelineStep.class); + + /** The name of this step, used as a key in the {@link WorkflowStepFactory} */ + public static final String NAME = "update_search_pipeline"; + + /** + * Instantiates a new UpdateSearchPipelineStep + * @param client The client to create a pipeline and store workflow data into the global context index + */ + public UpdateSearchPipelineStep(Client client) { + super(client); + } + + @Override + public void executePutPipelineRequest( + String pipelineId, + BytesReference configuration, + ClusterAdminClient clusterAdminClient, + ActionListener listener + ) { + PutSearchPipelineRequest putSearchPipelineRequest = new PutSearchPipelineRequest(pipelineId, configuration, XContentType.JSON); + clusterAdminClient.putSearchPipeline(putSearchPipelineRequest, listener); + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowDataStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowDataStep.java new file mode 100644 index 000000000..2d8ed0dcb --- /dev/null +++ 
b/src/main/java/org/opensearch/flowframework/workflow/WorkflowDataStep.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.flowframework.model.ResourceCreated; + +import java.util.Map; + +/** + * Internal step to pass created resources to dependent nodes. Only used in reprovisioning + */ +public class WorkflowDataStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(WorkflowDataStep.class); + private final ResourceCreated resourceCreated; + + /** The name of this step */ + public static final String NAME = "workflow_data_step"; + + /** + * Instantiate this class + * @param resourceCreated the created resource + */ + public WorkflowDataStep(ResourceCreated resourceCreated) { + this.resourceCreated = resourceCreated; + } + + @Override + public PlainActionFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs, + Map params + ) { + PlainActionFuture workflowDataFuture = PlainActionFuture.newFuture(); + workflowDataFuture.onResponse( + new WorkflowData( + Map.of(resourceCreated.resourceType(), resourceCreated.resourceId()), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + return workflowDataFuture; + } + + @Override + public String getName() { + return NAME; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index a13240642..88a5d67e5 100644 --- 
a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -13,10 +13,14 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.plugins.PluginInfo; import org.opensearch.plugins.PluginsService; import org.opensearch.threadpool.ThreadPool; @@ -35,6 +39,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; @@ -143,6 +148,262 @@ public List sortProcessNodes(Workflow workflow, String workflowId, return nodes; } + /** + * Sort an updated workflow into a topologically sorted list of create/update process nodes + * @param workflowId the workflow ID associated with the template + * @param originalTemplate the original template currently indexed + * @param updatedTemplate the updated template to be executed + * @param resourcesCreated the resources previously created for the workflow + * @throws Exception for issues creating the reprovision sequence + * @return A list of ProcessNode + */ + public List 
createReprovisionSequence( + String workflowId, + Template originalTemplate, + Template updatedTemplate, + List resourcesCreated + ) throws Exception { + + Workflow updatedWorkflow = updatedTemplate.workflows().get(PROVISION_WORKFLOW); + if (updatedWorkflow.nodes().size() > this.maxWorkflowSteps) { + throw new FlowFrameworkException( + "Workflow " + + workflowId + + " has " + + updatedWorkflow.nodes().size() + + " nodes, which exceeds the maximum of " + + this.maxWorkflowSteps + + ". Change the setting [" + + MAX_WORKFLOW_STEPS.getKey() + + "] to increase this.", + RestStatus.BAD_REQUEST + ); + } + + // Topologically sort the updated workflow + List sortedUpdatedNodes = topologicalSort(updatedWorkflow.nodes(), updatedWorkflow.edges()); + + // Convert original template into node id map + Map originalTemplateMap = originalTemplate.workflows() + .get(PROVISION_WORKFLOW) + .nodes() + .stream() + .collect(Collectors.toMap(WorkflowNode::id, node -> node)); + + // Temporarily block node deletions until fine-grained deprovisioning is implemented + if (!originalTemplateMap.values().stream().allMatch(sortedUpdatedNodes::contains)) { + throw new FlowFrameworkException( + "Workflow Step deletion is not supported when reprovisioning a template.", + RestStatus.BAD_REQUEST + ); + } + + List reprovisionSequence = createReprovisionSequence( + workflowId, + updatedWorkflow, + sortedUpdatedNodes, + originalTemplateMap, + resourcesCreated + ); + + // If the reprovision sequence consists entirely of WorkflowDataSteps, then no modifications were made to the existing template. 
+ if (reprovisionSequence.stream().allMatch(n -> n.workflowStep().getName().equals(WorkflowDataStep.NAME))) { + throw new FlowFrameworkException("Template does not contain any modifications", RestStatus.BAD_REQUEST); + } + + return reprovisionSequence; + } + + /** + * Compares an original and updated template and creates a list of update, create or WorkflowDataStep nodes + * @param workflowId the workflow ID associated with the template + * @param updatedWorkflow the updated workflow to be processed + * @param sortedUpdatedNodes the topologically sorted updated template nodes + * @param originalTemplateMap a map of node Id to workflow node of the original template + * @param resourcesCreated a list of resources created for this template + * @return a list of process node representing the reprovision sequence + * @throws Exception for issues creating the reprovision sequence + */ + private List createReprovisionSequence( + String workflowId, + Workflow updatedWorkflow, + List sortedUpdatedNodes, + Map originalTemplateMap, + List resourcesCreated + ) throws Exception { + Map idToNodeMap = new HashMap<>(); + List reprovisionSequence = new ArrayList<>(); + + for (WorkflowNode node : sortedUpdatedNodes) { + ProcessNode processNode = createProcessNode( + updatedWorkflow, + node, + originalTemplateMap, + resourcesCreated, + workflowId, + idToNodeMap + ); + if (processNode != null) { + idToNodeMap.put(processNode.id(), processNode); + reprovisionSequence.add(processNode); + } + } + + return reprovisionSequence; + } + + /** + * Determines which type of process node to create for a reprovision sequence + * @param updatedWorkflow the updated workflow to be processed + * @param node the current workflow node + * @param originalTemplateMap a map of node Id to workflow node of the original template + * @param resourcesCreated a list of resources created for this template + * @param workflowId the workflow ID associated with the template + * @param idToNodeMap a map of the current 
reprovision sequence + * @return a ProcessNode + * @throws Exception for issues creating the process node + */ + private ProcessNode createProcessNode( + Workflow updatedWorkflow, + WorkflowNode node, + Map originalTemplateMap, + List resourcesCreated, + String workflowId, + Map idToNodeMap + ) throws Exception { + WorkflowData data = new WorkflowData(node.userInputs(), updatedWorkflow.userParams(), workflowId, node.id()); + List predecessorNodes = updatedWorkflow.edges() + .stream() + .filter(e -> e.destination().equals(node.id())) + // since we are iterating in topological order we know all predecessors will be in the map + .map(e -> idToNodeMap.get(e.source())) + .collect(Collectors.toList()); + TimeValue nodeTimeout = parseTimeout(node); + + if (!originalTemplateMap.containsKey(node.id())) { + // Case 1: Additive modification, create new node + return createNewProcessNode(node, data, predecessorNodes, nodeTimeout); + } else { + WorkflowNode originalNode = originalTemplateMap.get(node.id()); + if (shouldUpdateNode(node, originalNode)) { + // Case 2: Existing modification, create update step + return createUpdateProcessNode(node, data, predecessorNodes, nodeTimeout); + } else { + // Case 4: No modification to existing node, create proxy step + return createWorkflowDataStepNode(node, data, predecessorNodes, nodeTimeout, resourcesCreated); + } + } + } + + /** + * Creates a process node to create a new resource + * @param node the current node + * @param data the current node data + * @param predecessorNodes the current node predecessors + * @param nodeTimeout the current node timeout + * @return a Process Node + */ + private ProcessNode createNewProcessNode( + WorkflowNode node, + WorkflowData data, + List predecessorNodes, + TimeValue nodeTimeout + ) { + WorkflowStep step = workflowStepFactory.createStep(node.type()); + return new ProcessNode( + node.id(), + step, + node.previousNodeInputs(), + Collections.emptyMap(), // TODO Add support to reprovision 
substitution templates + data, + predecessorNodes, + threadPool, + PROVISION_WORKFLOW_THREAD_POOL, + nodeTimeout + ); + } + + /** + * Creates a process node to update an existing resource + * @param node the current node + * @param data the current node data + * @param predecessorNodes the current node predecessors + * @param nodeTimeout the current node timeout + * @return a ProcessNode + * @throws FlowFrameworkException if the current node does not support updates + */ + private ProcessNode createUpdateProcessNode( + WorkflowNode node, + WorkflowData data, + List predecessorNodes, + TimeValue nodeTimeout + ) throws FlowFrameworkException { + String updateStepName = WorkflowResources.getUpdateStepByWorkflowStep(node.type()); + if (updateStepName != null) { + WorkflowStep step = workflowStepFactory.createStep(updateStepName); + return new ProcessNode( + node.id(), + step, + node.previousNodeInputs(), + Collections.emptyMap(), // TODO Add support to reprovision substitution templates + data, + predecessorNodes, + threadPool, + PROVISION_WORKFLOW_THREAD_POOL, + nodeTimeout + ); + } else { + // Case 3 : Cannot update step (not supported) + throw new FlowFrameworkException( + "Workflow Step " + node.id() + " does not support updates when reprovisioning.", + RestStatus.BAD_REQUEST + ); + } + } + + /** + * Creates a process node to pass workflow data to the next step in the reprovision sequence + * @param node the current node + * @param data the current node data + * @param predecessorNodes the current node predecessors + * @param nodeTimeout the current node timeout + * @param resourcesCreated the list of resources created for the template associated with this node + * @return a Process node + */ + private ProcessNode createWorkflowDataStepNode( + WorkflowNode node, + WorkflowData data, + List predecessorNodes, + TimeValue nodeTimeout, + List resourcesCreated + ) { + ResourceCreated nodeResource = resourcesCreated.stream() + .filter(rc -> 
rc.workflowStepId().equals(node.id())) + .findFirst() + .orElse(null); + + if (nodeResource != null) { + return new ProcessNode( + node.id(), + new WorkflowDataStep(nodeResource), + node.previousNodeInputs(), + Collections.emptyMap(), + data, + predecessorNodes, + threadPool, + PROVISION_WORKFLOW_THREAD_POOL, + nodeTimeout + ); + } else { + return null; + } + } + + private boolean shouldUpdateNode(WorkflowNode node, WorkflowNode originalNode) throws Exception { + return !node.previousNodeInputs().equals(originalNode.previousNodeInputs()) + || !ParseUtils.userInputsEquals(originalNode.userInputs(), node.userInputs()); + } + /** * Validates inputs and ensures the required plugins are installed for each step in a topologically sorted graph * @param processNodes the topologically sorted list of process nodes diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 5f99b7289..9fc8baada 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -118,6 +118,10 @@ public WorkflowStepFactory( stepMap.put(CreateIngestPipelineStep.NAME, () -> new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); stepMap.put(DeleteIngestPipelineStep.NAME, () -> new DeleteIngestPipelineStep(client)); stepMap.put(CreateSearchPipelineStep.NAME, () -> new CreateSearchPipelineStep(client, flowFrameworkIndicesHandler)); + stepMap.put(UpdateIngestPipelineStep.NAME, () -> new UpdateIngestPipelineStep(client)); + stepMap.put(UpdateSearchPipelineStep.NAME, () -> new UpdateSearchPipelineStep(client)); + stepMap.put(UpdateIndexStep.NAME, () -> new UpdateIndexStep(client)); + stepMap.put(DeleteSearchPipelineStep.NAME, () -> new DeleteSearchPipelineStep(client)); } @@ -256,6 +260,27 @@ public enum WorkflowSteps { null ), + /** Update Ingest Pipeline Step */ + 
UPDATE_INGEST_PIPELINE( + UpdateIngestPipelineStep.NAME, + List.of(PIPELINE_ID, CONFIGURATIONS), + List.of(PIPELINE_ID), + Collections.emptyList(), + null + ), + + /** Update Search Pipeline Step */ + UPDATE_SEARCH_PIPELINE( + UpdateSearchPipelineStep.NAME, + List.of(PIPELINE_ID, CONFIGURATIONS), + List.of(PIPELINE_ID), + Collections.emptyList(), + null + ), + + /** Update Index Step */ + UPDATE_INDEX(UpdateIndexStep.NAME, List.of(INDEX_NAME, CONFIGURATIONS), List.of(INDEX_NAME), Collections.emptyList(), null), + /** Delete Search Pipeline Step */ DELETE_SEARCH_PIPELINE( DeleteSearchPipelineStep.NAME, diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 86224ca26..8f540868d 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -85,7 +85,7 @@ public void testPlugin() throws IOException { ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(9, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(9, ffp.getActions().size()); + assertEquals(10, ffp.getActions().size()); assertEquals(3, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); diff --git a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java index dadfd6d24..c24d7817c 100644 --- a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java +++ b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java @@ -23,7 +23,7 @@ public void setUp() throws Exception { } public void testParseFeature() throws IOException { - String workflowStepName = CREATE_CONNECTOR.getWorkflowStep(); + String workflowStepName = 
CREATE_CONNECTOR.getCreateStep(); String resourceType = getResourceByWorkflowStep(workflowStepName); ResourceCreated resourceCreated = new ResourceCreated(workflowStepName, "workflow_step_1", resourceType, "L85p1IsBbfF"); assertEquals(workflowStepName, resourceCreated.workflowStepName()); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index e685e07b9..37526f1b0 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -47,7 +47,7 @@ public void testParseWorkflowValidator() throws IOException { WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(21, validator.getWorkflowStepValidators().size()); + assertEquals(24, validator.getWorkflowStepValidators().size()); } public void testWorkflowStepFactoryHasValidators() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 7e537566c..e4f22e947 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -36,6 +36,7 @@ import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; @@ -160,6 +161,23 @@ public void 
testCreateWorkflowRequestWithUpdateAndProvision() throws Exception { ); } + public void testCreateWorkflowRequestWithCreateAndReprovision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.ofEntries(Map.entry(REPROVISION_WORKFLOW, "true"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains("You can not use the " + REPROVISION_WORKFLOW + " parameter to create a new template.") + ); + } + public void testCreateWorkflowRequestWithUpdateAndParams() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 265d1d52d..90b60d1d3 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -213,7 +213,7 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), null, Collections.emptyMap()); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), null, Collections.emptyMap(), false); doAnswer(invocation -> { 
ActionListener searchListener = invocation.getArgument(1); @@ -257,7 +257,8 @@ public void testFailedToCreateNewWorkflow() { false, Collections.emptyMap(), null, - Collections.emptyMap() + Collections.emptyMap(), + false ); // Bypass checkMaxWorkflows and force onResponse @@ -296,7 +297,8 @@ public void testCreateNewWorkflow() { false, Collections.emptyMap(), null, - Collections.emptyMap() + Collections.emptyMap(), + false ); // Bypass checkMaxWorkflows and force onResponse @@ -336,6 +338,78 @@ public void testCreateNewWorkflow() { assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); } + public void testUpdateWorkflowWithReprovision() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap(), + true + ); + + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(template.toJson()); + getListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new WorkflowResponse("1")); + return null; + }).when(client).execute(any(), any(), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + + assertEquals("1", responseCaptor.getValue().getWorkflowId()); + } + + public void testFailedToUpdateWorkflowWithReprovision() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest 
workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap(), + true + ); + + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(template.toJson()); + getListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onFailure(new Exception("failed")); + return null; + }).when(client).execute(any(), any(), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(responseCaptor.capture()); + + assertEquals("Reprovisioning failed for workflow 1", responseCaptor.getValue().getMessage()); + } + public void testFailedToUpdateWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); @@ -511,7 +585,8 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc true, Collections.emptyMap(), null, - Collections.emptyMap() + Collections.emptyMap(), + false ); // Bypass checkMaxWorkflows and force onResponse @@ -572,7 +647,8 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() true, Collections.emptyMap(), null, - Collections.emptyMap() + Collections.emptyMap(), + false ); // Bypass checkMaxWorkflows and force onResponse @@ -620,7 +696,7 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() private Template generateValidTemplate() { WorkflowNode createConnector = new WorkflowNode( "workflow_step_1", - CREATE_CONNECTOR.getWorkflowStep(), + CREATE_CONNECTOR.getCreateStep(), 
Collections.emptyMap(), Map.ofEntries( Map.entry("name", ""), @@ -634,13 +710,13 @@ private Template generateValidTemplate() { ); WorkflowNode registerModel = new WorkflowNode( "workflow_step_2", - REGISTER_REMOTE_MODEL.getWorkflowStep(), + REGISTER_REMOTE_MODEL.getCreateStep(), Map.ofEntries(Map.entry("workflow_step_1", CONNECTOR_ID)), Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) ); WorkflowNode deployModel = new WorkflowNode( "workflow_step_3", - DEPLOY_MODEL.getWorkflowStep(), + DEPLOY_MODEL.getCreateStep(), Map.ofEntries(Map.entry("workflow_step_2", MODEL_ID)), Collections.emptyMap() ); diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java new file mode 100644 index 000000000..2448d9372 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class ReprovisionWorkflowRequestTests extends OpenSearchTestCase { + + private Template originalTemplate; + private Template updatedTemplate; + + @Override + public void setUp() throws Exception { + super.setUp(); + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Collections.emptyMap(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Collections.emptyMap(), Map.of("baz", "qux")); + WorkflowNode nodeC = new WorkflowNode("C", "c-type", Collections.emptyMap(), Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + WorkflowEdge edgeBC = new WorkflowEdge("B", "C"); + Workflow originalWorkflow = new Workflow(Map.of("key", "value"), List.of(nodeA, nodeB), List.of(edgeAB)); + Workflow updatedWorkflow = new Workflow(Map.of("key", "value"), List.of(nodeA, nodeB, nodeC), List.of(edgeAB, edgeBC)); + + this.originalTemplate = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("workflow", originalWorkflow), + Collections.emptyMap(), + TestHelpers.randomUser(), + null, + null, + null + ); + + this.updatedTemplate = new 
Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("workflow", updatedWorkflow), + Collections.emptyMap(), + TestHelpers.randomUser(), + null, + null, + null + ); + } + + public void testReprovisionWorkflowRequest() throws IOException { + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest("123", originalTemplate, updatedTemplate); + + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + ReprovisionWorkflowRequest requestFromStreamInput = new ReprovisionWorkflowRequest(in); + assertEquals(request.getWorkflowId(), requestFromStreamInput.getWorkflowId()); + assertEquals(request.getOriginalTemplate().toJson(), requestFromStreamInput.getOriginalTemplate().toJson()); + assertEquals(request.getUpdatedTemplate().toJson(), requestFromStreamInput.getUpdatedTemplate().toJson()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java new file mode 100644 index 000000000..ab2485be4 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -0,0 +1,307 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.State; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.plugins.PluginsService; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ReprovisionWorkflowTransportActionTests extends OpenSearchTestCase { + + private TransportService transportService; + private ActionFilters actionFilters; + private ThreadPool threadPool; + private Client client; + 
private WorkflowStepFactory workflowStepFactory; + private WorkflowProcessSorter workflowProcessSorter; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private FlowFrameworkSettings flowFrameworkSettings; + private EncryptorUtils encryptorUtils; + private PluginsService pluginsService; + + private ReprovisionWorkflowTransportAction reprovisionWorkflowTransportAction; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.transportService = mock(TransportService.class); + this.actionFilters = mock(ActionFilters.class); + this.threadPool = mock(ThreadPool.class); + this.client = mock(Client.class); + this.workflowStepFactory = mock(WorkflowStepFactory.class); + this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.encryptorUtils = mock(EncryptorUtils.class); + this.pluginsService = mock(PluginsService.class); + + this.reprovisionWorkflowTransportAction = new ReprovisionWorkflowTransportAction( + transportService, + actionFilters, + threadPool, + client, + workflowStepFactory, + workflowProcessSorter, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + encryptorUtils, + pluginsService + ); + + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testReprovisionWorkflow() throws Exception { + + String workflowId = "1"; + + Template mockTemplate = mock(Template.class); + Workflow mockWorkflow = mock(Workflow.class); + Map mockWorkflows = new HashMap<>(); + mockWorkflows.put(PROVISION_WORKFLOW, mockWorkflow); + + // Stub validations + when(mockTemplate.workflows()).thenReturn(mockWorkflows); + when(workflowProcessSorter.sortProcessNodes(any(), any(), any())).thenReturn(List.of()); + 
doNothing().when(workflowProcessSorter).validate(any(), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(mockTemplate); + + // Stub state and resources created + doAnswer(invocation -> { + + ActionListener listener = invocation.getArgument(2); + + WorkflowState state = mock(WorkflowState.class); + ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); + when(state.getState()).thenReturn(State.COMPLETED.toString()); + when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); + when(state.getError()).thenReturn(null); + listener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(), any(GetWorkflowStateRequest.class), any()); + + // Stub reprovision sequence creation + when(workflowProcessSorter.createReprovisionSequence(any(), any(), any(), any())).thenReturn(List.of(mock(ProcessNode.class))); + + // Bypass updateFlowFrameworkSystemIndexDoc and stub on response + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(mock(UpdateResponse.class)); + return null; + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), any(), any()); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + + reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + } + + public void testReprovisionProvisioningWorkflow() throws Exception { + String workflowId = "1"; + + Template mockTemplate = mock(Template.class); + Workflow mockWorkflow = mock(Workflow.class); + Map mockWorkflows 
= new HashMap<>(); + mockWorkflows.put(PROVISION_WORKFLOW, mockWorkflow); + + // Stub validations + when(mockTemplate.workflows()).thenReturn(mockWorkflows); + when(workflowProcessSorter.sortProcessNodes(any(), any(), any())).thenReturn(List.of()); + doNothing().when(workflowProcessSorter).validate(any(), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(mockTemplate); + + // Stub state and resources created + doAnswer(invocation -> { + + ActionListener listener = invocation.getArgument(2); + + WorkflowState state = mock(WorkflowState.class); + ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); + when(state.getState()).thenReturn(State.PROVISIONING.toString()); + when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); + listener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(), any(GetWorkflowStateRequest.class), any()); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + + reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: 1", + exceptionCaptor.getValue().getMessage() + ); + } + + public void testReprovisionNotStartedWorkflow() throws Exception { + String workflowId = "1"; + + Template mockTemplate = mock(Template.class); + Workflow mockWorkflow = mock(Workflow.class); + Map mockWorkflows = new HashMap<>(); + mockWorkflows.put(PROVISION_WORKFLOW, mockWorkflow); + + // Stub validations + when(mockTemplate.workflows()).thenReturn(mockWorkflows); + 
when(workflowProcessSorter.sortProcessNodes(any(), any(), any())).thenReturn(List.of()); + doNothing().when(workflowProcessSorter).validate(any(), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(mockTemplate); + + // Stub state and resources created + doAnswer(invocation -> { + + ActionListener listener = invocation.getArgument(2); + + WorkflowState state = mock(WorkflowState.class); + ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); + when(state.getState()).thenReturn(State.NOT_STARTED.toString()); + when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); + listener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(), any(GetWorkflowStateRequest.class), any()); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + + reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: 1", + exceptionCaptor.getValue().getMessage() + ); + } + + public void testFailedStateUpdate() throws Exception { + String workflowId = "1"; + + Template mockTemplate = mock(Template.class); + Workflow mockWorkflow = mock(Workflow.class); + Map mockWorkflows = new HashMap<>(); + mockWorkflows.put(PROVISION_WORKFLOW, mockWorkflow); + + // Stub validations + when(mockTemplate.workflows()).thenReturn(mockWorkflows); + when(workflowProcessSorter.sortProcessNodes(any(), any(), any())).thenReturn(List.of()); + doNothing().when(workflowProcessSorter).validate(any(), any()); + 
when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(mockTemplate); + + // Stub state and resources created + doAnswer(invocation -> { + + ActionListener listener = invocation.getArgument(2); + + WorkflowState state = mock(WorkflowState.class); + ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); + when(state.getState()).thenReturn(State.COMPLETED.toString()); + when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); + when(state.getError()).thenReturn(null); + listener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(), any(GetWorkflowStateRequest.class), any()); + + // Stub reprovision sequence creation + when(workflowProcessSorter.createReprovisionSequence(any(), any(), any(), any())).thenReturn(List.of(mock(ProcessNode.class))); + + // Bypass updateFlowFrameworkSystemIndexDoc and stub on response + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new Exception("failed")); + return null; + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), any(), any()); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + + reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update workflow state: 1", exceptionCaptor.getValue().getMessage()); + } + + public void testFailedWorkflowStateRetrieval() throws Exception { + String workflowId = "1"; + + Template mockTemplate = mock(Template.class); + Workflow mockWorkflow = mock(Workflow.class); + Map mockWorkflows = new HashMap<>(); + 
mockWorkflows.put(PROVISION_WORKFLOW, mockWorkflow); + + // Stub validations + when(mockTemplate.workflows()).thenReturn(mockWorkflows); + when(workflowProcessSorter.sortProcessNodes(any(), any(), any())).thenReturn(List.of()); + doNothing().when(workflowProcessSorter).validate(any(), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(mockTemplate); + + // Stub state index retrieval failure + doAnswer(invocation -> { + + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("failed")); + return null; + }).when(client).execute(any(), any(GetWorkflowStateRequest.class), any()); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + + reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to get workflow state for workflow 1", exceptionCaptor.getValue().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index e15199e06..c06ae2e36 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -206,7 +206,16 @@ public void testWorkflowRequestWithUseCaseAndParamsInBody() throws IOException { public void testWorkflowRequestWithParamsNoProvision() throws IOException { IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), null, Collections.emptyMap()) + () -> new 
WorkflowRequest( + "123", + template, + new String[] { "all" }, + false, + Map.of("foo", "bar"), + null, + Collections.emptyMap(), + false + ) ); assertEquals("Params may only be included when provisioning.", ex.getMessage()); } @@ -219,7 +228,8 @@ public void testWorkflowRequestWithOnlyUpdateParamNoProvision() throws IOExcepti true, Map.of(UPDATE_WORKFLOW_FIELDS, "true"), null, - Collections.emptyMap() + Collections.emptyMap(), + false ); assertNotNull(workflowRequest.getWorkflowId()); assertEquals(template, workflowRequest.getTemplate()); diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 193616b20..1cdb0c50e 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -13,8 +13,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParser.Token; +import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -25,6 +27,9 @@ import java.util.Map; import java.util.Set; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + public class ParseUtilsTests extends OpenSearchTestCase { public void testResourceToStringToJson() throws IOException { String json = ParseUtils.resourceToString("/template/finaltemplate.json"); @@ -253,4 +258,76 @@ public void testParseIfExistWhenWrongTypeIsPassed() { assertThrows(IllegalArgumentException.class, () -> ParseUtils.parseIfExists(inputs, "key1", Integer.class)); } + + public void testUserInputsEquals() throws Exception { + + Map params = 
Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); + Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); + Map[] originalActions = new Map[] { + Map.ofEntries( + Map.entry(ConnectorAction.ACTION_TYPE_FIELD, ConnectorAction.ActionType.PREDICT.name()), + Map.entry(ConnectorAction.METHOD_FIELD, "post"), + Map.entry(ConnectorAction.URL_FIELD, "foo.test"), + Map.entry( + ConnectorAction.REQUEST_BODY_FIELD, + "{ \"model\": \"${parameters.model1}\", \"messages\": ${parameters.messages1} }" + ) + ) }; + + Map[] updatedActions = new Map[] { + Map.ofEntries( + Map.entry(ConnectorAction.ACTION_TYPE_FIELD, ConnectorAction.ActionType.PREDICT.name()), + Map.entry(ConnectorAction.METHOD_FIELD, "put"), + Map.entry(ConnectorAction.URL_FIELD, "bar.test"), + Map.entry( + ConnectorAction.REQUEST_BODY_FIELD, + "{ \"model\": \"${parameters.model2}\", \"messages\": ${parameters.messages2} }" + ) + ) }; + + Map originalInputs = Map.ofEntries( + Map.entry(CommonValue.NAME_FIELD, "test"), + Map.entry(CommonValue.DESCRIPTION_FIELD, "description"), + Map.entry(CommonValue.VERSION_FIELD, "1"), + Map.entry(CommonValue.PROTOCOL_FIELD, "test"), + Map.entry(CommonValue.PARAMETERS_FIELD, params), + Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), + Map.entry(CommonValue.ACTIONS_FIELD, originalActions) + ); + + Map updatedInputs = Map.ofEntries( + Map.entry(CommonValue.NAME_FIELD, "test"), + Map.entry(CommonValue.DESCRIPTION_FIELD, "description"), + Map.entry(CommonValue.VERSION_FIELD, "1"), + Map.entry(CommonValue.PROTOCOL_FIELD, "test"), + Map.entry(CommonValue.PARAMETERS_FIELD, params), + Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), + Map.entry(CommonValue.ACTIONS_FIELD, updatedActions) + ); + + assertFalse(ParseUtils.userInputsEquals(originalInputs, updatedInputs)); + + } + + public void testFlattenSettings() throws Exception { + + Map indexSettingsMap = new HashMap<>(); + indexSettingsMap.put( + "index", + 
Map.ofEntries( + Map.entry("knn", "true"), + Map.entry("number_of_shards", "2"), + Map.entry("number_of_replicas", "1"), + Map.entry("default_pipeline", "_none"), + Map.entry("search", Map.of("default_pipeline", "_none")) + ) + ); + Map<String, Object> flattenedSettings = new HashMap<>(); + ParseUtils.flattenSettings("", indexSettingsMap, flattenedSettings); + assertEquals(5, flattenedSettings.size()); + + // every setting should start with index + assertTrue(flattenedSettings.entrySet().stream().allMatch(x -> x.getKey().startsWith("index."))); + + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java new file mode 100644 index 000000000..7dade5607 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java @@ -0,0 +1,242 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class UpdateIndexStepTests extends OpenSearchTestCase { + + private Client client; + private AdminClient adminClient; + private IndicesAdminClient indicesAdminClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + client = mock(Client.class); + adminClient = mock(AdminClient.class); + indicesAdminClient = mock(IndicesAdminClient.class); + + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + } + + public void testUpdateIndexStepWithUpdatedSettings() throws ExecutionException, InterruptedException, IOException { + + UpdateIndexStep updateIndexStep = new UpdateIndexStep(client); + + String indexName = 
"test-index"; + + // Create existing settings for default pipelines + Settings.Builder builder = Settings.builder(); + builder.put("index.number_of_shards", 2); + builder.put("index.number_of_replicas", 1); + builder.put("index.knn", true); + builder.put("index.default_pipeline", "ingest_pipeline_id"); + builder.put("index.search.default_pipeline", "search_pipeline_id"); + Map indexToSettings = new HashMap<>(); + indexToSettings.put(indexName, builder.build()); + + // Stub get index settings request/response + doAnswer(invocation -> { + ActionListener getSettingsResponseListener = invocation.getArgument(1); + getSettingsResponseListener.onResponse(new GetSettingsResponse(indexToSettings, indexToSettings)); + return null; + }).when(indicesAdminClient).getSettings(any(), any()); + + // validate update settings request content + @SuppressWarnings({ "unchecked" }) + ArgumentCaptor updateSettingsRequestCaptor = ArgumentCaptor.forClass(UpdateSettingsRequest.class); + + // Configurations has updated search/ingest pipeline default values of _none + String configurations = + "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":2,\"number_of_replicas\":1,\"default_pipeline\":\"_none\",\"search\":{\"default_pipeline\":\"_none\"}}},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + WorkflowData data = new WorkflowData( + Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)), + "test-id", + "test-node-id" + ); + PlainActionFuture future = updateIndexStep.execute( + data.getNodeId(), + data, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(indicesAdminClient, times(1)).getSettings(any(GetSettingsRequest.class), any()); + verify(indicesAdminClient, times(1)).updateSettings(updateSettingsRequestCaptor.capture(), any()); + + Settings settingsToUpdate = updateSettingsRequestCaptor.getValue().settings(); + assertEquals(2, 
settingsToUpdate.size()); + assertEquals("_none", settingsToUpdate.get("index.default_pipeline")); + assertEquals("_none", settingsToUpdate.get("index.search.default_pipeline")); + } + + public void testMissingSettings() throws InterruptedException { + UpdateIndexStep updateIndexStep = new UpdateIndexStep(client); + + String configurations = "{\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + + // Data with empty configuration field + WorkflowData incorrectData = new WorkflowData( + Map.ofEntries(Map.entry(INDEX_NAME, "index-name"), Map.entry(CONFIGURATIONS, configurations)), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = updateIndexStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals( + "Failed to update index settings for index index-name, settings are not found in the index configuration", + exception.getCause().getMessage() + ); + } + + public void testEmptyConfiguration() throws InterruptedException { + + UpdateIndexStep updateIndexStep = new UpdateIndexStep(client); + + // Data with empty configuration field + WorkflowData incorrectData = new WorkflowData( + Map.ofEntries(Map.entry(INDEX_NAME, "index-name"), Map.entry(CONFIGURATIONS, "")), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = updateIndexStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals( + "Failed to update index settings for index index-name, 
index configuration is not given", + exception.getCause().getMessage() + ); + } + + public void testMissingInputs() throws InterruptedException { + + UpdateIndexStep updateIndexStep = new UpdateIndexStep(client); + + // Data with missing configuration field + WorkflowData incorrectData = new WorkflowData(Map.ofEntries(Map.entry(INDEX_NAME, "index-name")), "test-id", "test-node-id"); + + PlainActionFuture future = updateIndexStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals( + "Missing required inputs [configurations] in workflow [test-id] node [test-node-id]", + exception.getCause().getMessage() + ); + + } + + public void testNoSettingsChanged() throws InterruptedException { + UpdateIndexStep updateIndexStep = new UpdateIndexStep(client); + + String indexName = "test-index"; + + // Create existing settings for default pipelines + Settings.Builder builder = Settings.builder(); + builder.put("index.number_of_shards", 2); + builder.put("index.number_of_replicas", 1); + builder.put("index.knn", true); + builder.put("index.default_pipeline", "ingest_pipeline_id"); + builder.put("index.search.default_pipeline", "search_pipeline_id"); + Map indexToSettings = new HashMap<>(); + indexToSettings.put(indexName, builder.build()); + + // Stub get index settings request/response + doAnswer(invocation -> { + ActionListener getSettingsResponseListener = invocation.getArgument(1); + getSettingsResponseListener.onResponse(new GetSettingsResponse(indexToSettings, indexToSettings)); + return null; + }).when(indicesAdminClient).getSettings(any(), any()); + + // validate update settings request content + @SuppressWarnings({ "unchecked" }) + ArgumentCaptor updateSettingsRequestCaptor = 
ArgumentCaptor.forClass(UpdateSettingsRequest.class); + + // Configurations have no change + String configurations = + "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":2,\"number_of_replicas\":1,\"default_pipeline\":\"ingest_pipeline_id\",\"search\":{\"default_pipeline\":\"search_pipeline_id\"}}},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + WorkflowData data = new WorkflowData( + Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)), + "test-id", + "test-node-id" + ); + PlainActionFuture future = updateIndexStep.execute( + data.getNodeId(), + data, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals( + "Failed to update index settings for index test-index, no settings have been updated", + exception.getCause().getMessage() + ); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStepTests.java new file mode 100644 index 000000000..75dc5b584 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/UpdateIngestPipelineStepTests.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.ingest.PutPipelineRequest; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class UpdateIngestPipelineStepTests extends OpenSearchTestCase { + + private WorkflowData inputData; + private WorkflowData outpuData; + private Client client; + private AdminClient adminClient; + private ClusterAdminClient clusterAdminClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + String configurations = + "{\"description\":\"An neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"sdsadsadasd\"}}]}"; + inputData = new WorkflowData( + Map.ofEntries(Map.entry(CONFIGURATIONS, configurations), Map.entry(PIPELINE_ID, "pipelineId")), + "test-id", + "test-node-id" + ); + + // Set output data to returned pipelineId + outpuData = new WorkflowData(Map.ofEntries(Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id"); + + client = mock(Client.class); + adminClient = mock(AdminClient.class); + 
clusterAdminClient = mock(ClusterAdminClient.class); + + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + } + + public void testUpdateIngestPipelineStep() throws InterruptedException, ExecutionException, IOException { + + UpdateIngestPipelineStep updateIngestPipelineStep = new UpdateIngestPipelineStep(client); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + PlainActionFuture future = updateIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertFalse(future.isDone()); + + // Mock put pipeline request execution and return true + verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true)); + + assertTrue(future.isDone()); + assertEquals(outpuData.getContent(), future.get().getContent()); + } + + public void testUpdateIngestPipelineStepFailure() throws InterruptedException { + + UpdateIngestPipelineStep updateIngestPipelineStep = new UpdateIngestPipelineStep(client); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + PlainActionFuture future = updateIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertFalse(future.isDone()); + + // Mock put pipeline request execution and return false + verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onFailure(new Exception("Failed step update_ingest_pipeline")); + + assertTrue(future.isDone()); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + 
assertTrue(exception.getCause() instanceof Exception); + assertEquals("Failed step update_ingest_pipeline", exception.getCause().getMessage()); + } + + public void testMissingData() throws InterruptedException { + UpdateIngestPipelineStep updateIngestPipelineStep = new UpdateIngestPipelineStep(client); + + // Data with missing input and output fields + WorkflowData incorrectData = new WorkflowData( + Map.ofEntries( + Map.entry("id", PIPELINE_ID), + Map.entry("description", "some description"), + Map.entry("type", "text_embedding"), + Map.entry(MODEL_ID, MODEL_ID) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = updateIngestPipelineStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals( + "Missing required inputs [configurations, pipeline_id] in workflow [test-id] node [test-node-id]", + exception.getCause().getMessage() + ); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStepTests.java new file mode 100644 index 000000000..214b47547 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/UpdateSearchPipelineStepTests.java @@ -0,0 +1,151 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.search.PutSearchPipelineRequest; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class UpdateSearchPipelineStepTests extends OpenSearchTestCase { + + private WorkflowData inputData; + private WorkflowData outpuData; + private Client client; + private AdminClient adminClient; + private ClusterAdminClient clusterAdminClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + String configurations = + "{\"response_processors\":[{\"retrieval_augmented_generation\":{\"context_field_list\":[\"text\"],\"user_instructions\":\"Generate a concise and informative answer in less than 100 words for the given question\",\"description\":\"Demo pipeline Using OpenAI Connector\",\"tag\":\"openai_pipeline_demo\",\"model_id\":\"tbFoNI4BW58L8XKV4RF3\",\"system_prompt\":\"You are a helpful assistant\"}}]}"; + inputData = new WorkflowData( + Map.ofEntries(Map.entry(CONFIGURATIONS, configurations), Map.entry(PIPELINE_ID, "pipelineId")), + "test-id", + 
"test-node-id" + ); + + // Set output data to returned pipelineId + outpuData = new WorkflowData(Map.ofEntries(Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id"); + + client = mock(Client.class); + adminClient = mock(AdminClient.class); + clusterAdminClient = mock(ClusterAdminClient.class); + + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + } + + public void testUpdateSearchPipelineStep() throws InterruptedException, ExecutionException, IOException { + + UpdateSearchPipelineStep updateSearchPipelineStep = new UpdateSearchPipelineStep(client); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + PlainActionFuture future = updateSearchPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertFalse(future.isDone()); + + // Mock put pipeline request execution and return true + verify(clusterAdminClient, times(1)).putSearchPipeline(any(PutSearchPipelineRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true)); + + assertTrue(future.isDone()); + assertEquals(outpuData.getContent(), future.get().getContent()); + + } + + public void testUpdateSearchPipelineStepFailure() throws InterruptedException { + + UpdateSearchPipelineStep updateSearchPipelineStep = new UpdateSearchPipelineStep(client); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + PlainActionFuture future = updateSearchPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertFalse(future.isDone()); + + // Mock put pipeline request execution and return false + verify(clusterAdminClient, 
times(1)).putSearchPipeline(any(PutSearchPipelineRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onFailure(new Exception("Failed step update_search_pipeline")); + + assertTrue(future.isDone()); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals("Failed step update_search_pipeline", exception.getCause().getMessage()); + } + + public void testMissingData() throws InterruptedException { + UpdateSearchPipelineStep updateSearchPipelineStep = new UpdateSearchPipelineStep(client); + // Data with missing input and output fields + WorkflowData incorrectData = new WorkflowData( + Map.ofEntries( + Map.entry("id", PIPELINE_ID), + Map.entry("description", "some description"), + Map.entry("type", "text_embedding"), + Map.entry(MODEL_ID, MODEL_ID) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = updateSearchPipelineStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals( + "Missing required inputs [configurations, pipeline_id] in workflow [test-id] node [test-node-id]", + exception.getCause().getMessage() + ); + + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataStepTests.java new file mode 100644 index 000000000..f6e751b6f --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataStepTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the 
Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertTrue; + +public class WorkflowDataStepTests extends OpenSearchTestCase { + + private WorkflowDataStep workflowDataStep; + private WorkflowData inputData; + private WorkflowData outputData; + + private String workflowId = "test-id"; + private String workflowStepId = "test-node-id"; + private String resourceId = "resourceId"; + private String resourceType = "resourceType"; + + @Override + public void setUp() throws Exception { + super.setUp(); + + ResourceCreated resourceCreated = new ResourceCreated("step_name", workflowStepId, resourceType, resourceId); + this.workflowDataStep = new WorkflowDataStep(resourceCreated); + this.inputData = new WorkflowData(Map.of(), workflowId, workflowStepId); + this.outputData = new WorkflowData(Map.ofEntries(Map.entry(resourceType, resourceId)), workflowId, workflowStepId); + } + + public void testExecuteWorkflowDataStep() throws ExecutionException, InterruptedException { + + @SuppressWarnings("unchecked") + PlainActionFuture future = workflowDataStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(outputData.getContent().get(resourceType), future.get().getContent().get(resourceType)); + + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 32df9c86a..5d4624b7d 100644 --- 
a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.Version; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -20,6 +21,8 @@ import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; @@ -33,6 +36,7 @@ import org.junit.BeforeClass; import java.io.IOException; +import java.time.Instant; import java.util.Collections; import java.util.List; import java.util.Locale; @@ -42,14 +46,17 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.DEPROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static 
org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; +import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; @@ -57,6 +64,7 @@ import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndPreviousNodes; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -88,6 +96,12 @@ private static List parse(String json) throws IOException { private static FlowFrameworkSettings flowFrameworkSettings; private static WorkflowStepFactory workflowStepFactory; + private static Version templateVersion; + private static List compatibilityVersions; + private static Template reprovisionTemplate; + private static ResourceCreated pipelineResource; + private static ResourceCreated indexResource; + @BeforeClass public static void setup() throws IOException { AdminClient adminClient = mock(AdminClient.class); @@ -120,6 +134,53 @@ public static void setup() throws IOException { ); workflowStepFactory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings, client); workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, testThreadPool, flowFrameworkSettings); + + templateVersion = Version.fromString("1.0.0"); + compatibilityVersions = List.of(Version.fromString("2.1.6"), Version.fromString("3.0.0")); + + // Register Search Pipeline Step + String pipelineId = "pipelineId"; + String pipelineConfigurations 
= + "{\"description\":\"An neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"sdsadsadasd\"}}]}"; + WorkflowNode createSearchPipeline = new WorkflowNode( + "workflow_step_1", + CreateSearchPipelineStep.NAME, + Map.of(), + Map.ofEntries(Map.entry(CONFIGURATIONS, pipelineConfigurations), Map.entry(PIPELINE_ID, pipelineId)) + ); + + // Create Index Step + String indexName = "indexName"; + String configurations = + "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":2,\"number_of_replicas\":1,\"default_pipeline\":\"_none\",\"search\":{\"default_pipeline\":\"${{workflow_step_1.pipeline_id}}\"}}},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + WorkflowNode createIndex = new WorkflowNode( + "workflow_step_2", + CreateIndexStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", PIPELINE_ID)), + Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)) + ); + List nodes = List.of(createSearchPipeline, createIndex); + List edges = List.of(new WorkflowEdge("workflow_step_1", "workflow_step_2")); + Workflow workflow = new Workflow(Map.of(), nodes, edges); + Map uiMetadata = null; + + Instant now = Instant.now(); + reprovisionTemplate = new Template( + "test", + "a test template", + "test use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + uiMetadata, + null, + now, + now, + null + ); + + pipelineResource = new ResourceCreated(CreateSearchPipelineStep.NAME, "workflow_step_1", PIPELINE_ID, pipelineId); + indexResource = new ResourceCreated(CreateIndexStep.NAME, "workflow_step_2", INDEX_NAME, indexName); } @AfterClass @@ -559,4 +620,204 @@ public void testReadWorkflowStepFile_withDefaultTimeout() throws IOException { TimeValue registerRemoteModelTimeout = workflowProcessSorter.parseTimeout(registerModel); assertEquals(10, registerRemoteModelTimeout.getSeconds()); } + + public void 
testCreateReprovisionSequenceWithNoChange() { + FlowFrameworkException ex = expectThrows( + FlowFrameworkException.class, + () -> workflowProcessSorter.createReprovisionSequence( + "1", + reprovisionTemplate, + reprovisionTemplate, + List.of(pipelineResource, indexResource) + ) + ); + + assertEquals("Template does not contain any modifications", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + + } + + public void testCreateReprovisionSequenceWithDeletion() { + // Register Search Pipeline Step + String pipelineId = "pipelineId"; + String pipelineConfigurations = + "{\"description\":\"An neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"sdsadsadasd\"}}]}"; + WorkflowNode createSearchPipeline = new WorkflowNode( + "workflow_step_1", + CreateSearchPipelineStep.NAME, + Map.of(), + Map.ofEntries(Map.entry(CONFIGURATIONS, pipelineConfigurations), Map.entry(PIPELINE_ID, pipelineId)) + ); + List nodes = List.of(createSearchPipeline); + Workflow workflow = new Workflow(Map.of(), nodes, List.of()); + Map uiMetadata = null; + + Instant now = Instant.now(); + Template templateWithNoCreateIndex = new Template( + "test", + "a test template", + "test use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + uiMetadata, + null, + now, + now, + null + ); + + FlowFrameworkException ex = expectThrows( + FlowFrameworkException.class, + () -> workflowProcessSorter.createReprovisionSequence( + "1", + reprovisionTemplate, + templateWithNoCreateIndex, + List.of(pipelineResource, indexResource) + ) + ); + + assertEquals("Workflow Step deletion is not supported when reprovisioning a template.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + + } + + public void testCreateReprovisionSequenceWithAdditiveModification() throws Exception { + + // Register Search Pipeline Step + String pipelineId = "pipelineId"; + String pipelineConfigurations = + 
"{\"description\":\"An neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"sdsadsadasd\"}}]}"; + WorkflowNode createSearchPipeline = new WorkflowNode( + "workflow_step_1", + CreateSearchPipelineStep.NAME, + Map.of(), + Map.ofEntries(Map.entry(CONFIGURATIONS, pipelineConfigurations), Map.entry(PIPELINE_ID, pipelineId)) + ); + + // Create Index Step + String indexName = "indexName"; + String configurations = + "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":2,\"number_of_replicas\":1,\"default_pipeline\":\"_none\",\"search\":{\"default_pipeline\":\"${{workflow_step_1.pipeline_id}}\"}}},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + WorkflowNode createIndex = new WorkflowNode( + "workflow_step_2", + CreateIndexStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", PIPELINE_ID)), + Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)) + ); + + // Register ingest pipeline step + String ingestPipelineId = "pipelineId"; + String ingestPipelineConfigurations = + "{\"description\":\"An neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"sdsadsadasd\"}}]}"; + WorkflowNode createIngestPipeline = new WorkflowNode( + "workflow_step_3", + CreateIngestPipelineStep.NAME, + Map.of(), + Map.ofEntries(Map.entry(CONFIGURATIONS, ingestPipelineConfigurations), Map.entry(PIPELINE_ID, ingestPipelineId)) + ); + + List nodes = List.of(createSearchPipeline, createIndex, createIngestPipeline); + List edges = List.of(new WorkflowEdge("workflow_step_1", "workflow_step_2")); + Workflow workflow = new Workflow(Map.of(), nodes, edges); + Map uiMetadata = null; + + Instant now = Instant.now(); + Template templateWithAdditiveModification = new Template( + "test", + "a test template", + "test use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + uiMetadata, + 
null, + now, + now, + null + ); + + List reprovisionSequence = workflowProcessSorter.createReprovisionSequence( + "1", + reprovisionTemplate, + templateWithAdditiveModification, + List.of(pipelineResource, indexResource) + ); + + // Should result in a 3 step sequence + assertTrue(reprovisionSequence.size() == 3); + List reprovisionWorkflowStepNames = reprovisionSequence.stream() + .map(ProcessNode::workflowStep) + .map(WorkflowStep::getName) + .collect(Collectors.toList()); + // Assert 1 create ingest pipeline step in the sequence + assertTrue(reprovisionWorkflowStepNames.contains(CreateIngestPipelineStep.NAME)); + // Assert 2 get resource steps in the sequence + assertTrue( + reprovisionWorkflowStepNames.stream().filter(x -> x.equals(WorkflowDataStep.NAME)).collect(Collectors.toList()).size() == 2 + ); + } + + public void testCreateReprovisionSequenceWithUpdates() throws Exception { + // Register Search Pipeline Step with modified model ID + String pipelineId = "pipelineId"; + String pipelineConfigurations = + "{\"description\":\"An neural ingest pipeline\",\"processors\":[{\"text_embedding\":{\"field_map\":{\"text\":\"analyzed_text\"},\"model_id\":\"abcdefgg\"}}]}"; + WorkflowNode createSearchPipeline = new WorkflowNode( + "workflow_step_1", + CreateSearchPipelineStep.NAME, + Map.of(), + Map.ofEntries(Map.entry(CONFIGURATIONS, pipelineConfigurations), Map.entry(PIPELINE_ID, pipelineId)) + ); + + // Create Index Step with modifies index settings + String indexName = "indexName"; + String configurations = + "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":2,\"number_of_replicas\":1,\"default_pipeline\":\"test_pipeline_id\",\"search\":{\"default_pipeline\":\"${{workflow_step_1.pipeline_id}}\"}}},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + WorkflowNode createIndex = new WorkflowNode( + "workflow_step_2", + CreateIndexStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", PIPELINE_ID)), + 
Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)) + ); + + List nodes = List.of(createSearchPipeline, createIndex); + List edges = List.of(new WorkflowEdge("workflow_step_1", "workflow_step_2")); + Workflow workflow = new Workflow(Map.of(), nodes, edges); + Map uiMetadata = null; + + Instant now = Instant.now(); + Template templateWithModifiedNodes = new Template( + "test", + "a test template", + "test use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + uiMetadata, + null, + now, + now, + null + ); + + List reprovisionSequence = workflowProcessSorter.createReprovisionSequence( + "1", + reprovisionTemplate, + templateWithModifiedNodes, + List.of(pipelineResource, indexResource) + ); + + // Should result in a 2 step sequence + assertTrue(reprovisionSequence.size() == 2); + List reprovisionWorkflowStepNames = reprovisionSequence.stream() + .map(ProcessNode::workflowStep) + .map(WorkflowStep::getName) + .collect(Collectors.toList()); + // Assert 1 update search pipeline step in the sequence + assertTrue(reprovisionWorkflowStepNames.contains(UpdateSearchPipelineStep.NAME)); + // Assert update index step in the sequence + assertTrue(reprovisionWorkflowStepNames.contains(UpdateIndexStep.NAME)); + } + }