Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: parse connector id from tool parameters map #846

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Features
- Adds reprovision API to support updating search pipelines, ingest pipelines index settings ([#804](https://github.com/opensearch-project/flow-framework/pull/804))
- Adds user level access control based on backend roles ([#838](https://github.com/opensearch-project/flow-framework/pull/838))
- Support parsing connector_id when creating tools ([#846](https://github.com/opensearch-project/flow-framework/pull/846))

### Enhancements
### Bug Fixes
Expand Down
53 changes: 26 additions & 27 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

/**
Expand Down Expand Up @@ -64,7 +65,15 @@ public PlainActionFuture<WorkflowData> execute(
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
Boolean includeOutputInAgentResponse = ParseUtils.parseIfExists(inputs, INCLUDE_OUTPUT_IN_AGENT_RESPONSE, Boolean.class);
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);

// parse connector_id, model_id and agent_id from previous node inputs
Set<String> toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID);
Map<String, String> parameters = getToolsParametersMap(
inputs.get(PARAMETERS_FIELD),
previousNodeInputs,
outputs,
toolParameterKeys
);

MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();

Expand Down Expand Up @@ -110,39 +119,29 @@ public String getName() {
private Map<String, String> getToolsParametersMap(
Object parameters,
Map<String, String> previousNodeInputs,
Map<String, WorkflowData> outputs
Map<String, WorkflowData> outputs,
Set<String> toolParameterKeys
) {
@SuppressWarnings("unchecked")
Map<String, String> parametersMap = (Map<String, String>) parameters;
Optional<String> previousNodeModel = previousNodeInputs.entrySet()
.stream()
.filter(e -> MODEL_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

Optional<String> previousNodeAgent = previousNodeInputs.entrySet()
.stream()
.filter(e -> AGENT_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

// Case when modelId is passed through previousSteps and not present already in parameters
if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
}
}

// Case when agentId is passed through previousSteps and not present already in parameters
if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) {
parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString());
for (String toolParameterKey : toolParameterKeys) {
Optional<String> previousNodeParameter = previousNodeInputs.entrySet()
.stream()
.filter(e -> toolParameterKey.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();

// Case when toolParameterKey is passed through previousSteps and not present already in parameters
if (previousNodeParameter.isPresent() && !parametersMap.containsKey(toolParameterKey)) {
WorkflowData previousNodeOutput = outputs.get(previousNodeParameter.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(toolParameterKey)) {
parametersMap.put(toolParameterKey, previousNodeOutput.getContent().get(toolParameterKey).toString());
}
}
}

// For other cases where modelId is already present in the parameters or not return the parametersMap
// For other cases where toolParameterKey is already present in the parameters or not return the parametersMap
return parametersMap;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,23 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
);
}

/**
* Helper method to invoke the Get Agent Rest Action
* @param client the rest client
* @return rest response
* @throws Exception
*/
protected Response getAgent(RestClient client, String agentId) throws Exception {
return TestHelpers.makeRequest(
client,
"GET",
String.format(Locale.ROOT, "/_plugins/_ml/agents/%s", agentId),
Collections.emptyMap(),
"",
null
);
}

/**
* Helper method to invoke the Search Workflow Rest Action with the given query
* @param client the rest client
Expand All @@ -668,7 +685,6 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
* @throws Exception if the request fails
*/
protected SearchResponse searchWorkflows(RestClient client, String query) throws Exception {

// Execute search
Response restSearchResponse = TestHelpers.makeRequest(
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
Expand All @@ -56,7 +57,6 @@ public void waitToStart() throws Exception {
}

public void testSearchWorkflows() throws Exception {

// Create a Workflow that has a credential 12345
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");
Response response = createWorkflow(client(), template);
Expand Down Expand Up @@ -228,7 +228,6 @@ public void testCreateAndProvisionCyclicalTemplate() throws Exception {
}

public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {

// Using a 3 step template to create a connector, register remote model and deploy model
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");

Expand Down Expand Up @@ -331,6 +330,79 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception {
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
}

public void testCreateAndProvisionConnectorToolAgentFrameworkWorkflow() throws Exception {
// Create a Workflow that has a credential 12345
Template template = TestHelpers.createTemplateFromFile("createconnector-createconnectortool-createflowagent.json");

// Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter
Response response = createWorkflowWithProvision(client(), template);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
Map<String, Object> responseMap = entityAsMap(response);
String workflowId = (String) responseMap.get(WORKFLOW_ID);
// wait and ensure state is completed/done
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); },
120,
TimeUnit.SECONDS
);

// Assert based on the agent-framework template
List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 120);
Map<String, ResourceCreated> resourceMap = resourcesCreated.stream()
.collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r));
assertEquals(2, resourceMap.size());
assertTrue(resourceMap.containsKey("create_connector"));
assertTrue(resourceMap.containsKey("register_agent"));
String connectorId = resourceMap.get("create_connector").resourceId();
String agentId = resourceMap.get("register_agent").resourceId();
assertNotNull(connectorId);
assertNotNull(agentId);

// Assert that the agent contains the correct connector_id
response = getAgent(client(), agentId);
Map<String, Object> agentResponse = entityAsMap(response);
assertTrue(agentResponse.containsKey("tools"));
@SuppressWarnings("unchecked")
ArrayList<Map<String, Object>> tools = (ArrayList<Map<String, Object>>) agentResponse.get("tools");
assertEquals(1, tools.size());
Map<String, Object> tool = tools.getFirst();
assertTrue(tool.containsKey("parameters"));
@SuppressWarnings("unchecked")
Map<String, String> toolParameters = (Map<String, String>) tool.get("parameters");
assertEquals(toolParameters, Map.of("connector_id", connectorId));

// Hit Deprovision API
// By design, this may not completely deprovision the first time if it takes >2s to process removals
Response deprovisionResponse = deprovisionWorkflow(client(), workflowId);
try {
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
30,
TimeUnit.SECONDS
);
} catch (ComparisonFailure e) {
// 202 return if still processing
assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse));
}
if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) {
// Short wait before we try again
Thread.sleep(10000);
deprovisionResponse = deprovisionWorkflow(client(), workflowId);
assertBusy(
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
30,
TimeUnit.SECONDS
);
}
assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse));
// Hit Delete API
Response deleteResponse = deleteWorkflow(client(), workflowId);
assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse));

// Verify state doc is deleted
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
}

public void testReprovisionWorkflow() throws Exception {
// Begin with a template to register a local pretrained model
Template template = TestHelpers.createTemplateFromFile("registerremotemodel.json");
Expand Down Expand Up @@ -650,7 +722,6 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
}

public void testDefaultCohereUseCase() throws Exception {

// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCaseWithNoValidation(
client(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,26 @@
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

public class ToolStepTests extends OpenSearchTestCase {
private WorkflowData inputData;
private WorkflowData inputDataWithConnectorId;
private WorkflowData inputDataWithModelId;
private WorkflowData inputDataWithAgentId;
private static final String mockedConnectorId = "mocked-connector-id";
private static final String mockedModelId = "mocked-model-id";
private static final String mockedAgentId = "mocked-agent-id";
private static final String createConnectorNodeId = "create_connector_node_id";
private static final String createModelNodeId = "create_model_node_id";
private static final String createAgentNodeId = "create_agent_node_id";

private WorkflowData boolStringInputData;
private WorkflowData badBoolInputData;

Expand All @@ -39,6 +52,9 @@ public void setUp() throws Exception {
"test-id",
"test-node-id"
);
inputDataWithConnectorId = new WorkflowData(Map.of(CONNECTOR_ID, mockedConnectorId), "test-id", createConnectorNodeId);
inputDataWithModelId = new WorkflowData(Map.of(MODEL_ID, mockedModelId), "test-id", createModelNodeId);
inputDataWithAgentId = new WorkflowData(Map.of(AGENT_ID, mockedAgentId), "test-id", createAgentNodeId);
boolStringInputData = new WorkflowData(
Map.ofEntries(
Map.entry("type", "type"),
Expand All @@ -63,7 +79,7 @@ public void setUp() throws Exception {
);
}

public void testTool() throws IOException, ExecutionException, InterruptedException {
public void testTool() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
Expand All @@ -88,7 +104,7 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept
assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
}

public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
public void testBoolParseFail() {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
Expand All @@ -100,10 +116,61 @@ public void testBoolParseFail() throws IOException, ExecutionException, Interrup
);

assertTrue(future.isDone());
ExecutionException e = assertThrows(ExecutionException.class, () -> future.get());
ExecutionException e = assertThrows(ExecutionException.class, future::get);
assertEquals(WorkflowStepException.class, e.getCause().getClass());
WorkflowStepException w = (WorkflowStepException) e.getCause();
assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage());
assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
}

public void testToolWithConnectorId() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Map.of(createConnectorNodeId, inputDataWithConnectorId),
Map.of(createConnectorNodeId, CONNECTOR_ID),
Collections.emptyMap()
);
assertTrue(future.isDone());
Object tools = future.get().getContent().get("tools");
assertEquals(MLToolSpec.class, tools.getClass());
MLToolSpec mlToolSpec = (MLToolSpec) tools;
assertEquals(mlToolSpec.getParameters(), Map.of(CONNECTOR_ID, mockedConnectorId));
}

public void testToolWithModelId() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Map.of(createModelNodeId, inputDataWithModelId),
Map.of(createModelNodeId, MODEL_ID),
Collections.emptyMap()
);
assertTrue(future.isDone());
Object tools = future.get().getContent().get("tools");
assertEquals(MLToolSpec.class, tools.getClass());
MLToolSpec mlToolSpec = (MLToolSpec) tools;
assertEquals(mlToolSpec.getParameters(), Map.of(MODEL_ID, mockedModelId));
}

public void testToolWithAgentId() throws ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

PlainActionFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Map.of(createAgentNodeId, inputDataWithAgentId),
Map.of(createAgentNodeId, AGENT_ID),
Collections.emptyMap()
);
assertTrue(future.isDone());
Object tools = future.get().getContent().get("tools");
assertEquals(MLToolSpec.class, tools.getClass());
MLToolSpec mlToolSpec = (MLToolSpec) tools;
assertEquals(mlToolSpec.getParameters(), Map.of(AGENT_ID, mockedAgentId));
}
}
Loading
Loading