diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index e17bf2aa0..7cf3d2898 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -16,6 +16,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -24,8 +25,10 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.CompletableFuture; @@ -84,38 +87,35 @@ public void onFailure(Exception e) { String description = null; String version = null; String protocol = null; - Map parameters = new HashMap<>(); - Map credentials = new HashMap<>(); - List actions = new ArrayList<>(); + Map parameters = Collections.emptyMap(); + Map credentials = Collections.emptyMap(); + List actions = Collections.emptyList(); for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { + for (Entry entry : workflowData.getContent().entrySet()) { switch (entry.getKey()) { case NAME_FIELD: - name = (String) content.get(NAME_FIELD); + name = (String) entry.getValue(); break; case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); + description = (String) entry.getValue(); break; case VERSION_FIELD: - version = (String) content.get(VERSION_FIELD); + version = (String) entry.getValue(); break; case PROTOCOL_FIELD: - protocol = (String) content.get(PROTOCOL_FIELD); + protocol = (String) entry.getValue(); break; case PARAMETERS_FIELD: - parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); + parameters = getParameterMap(entry.getValue()); break; case CREDENTIALS_FIELD: - credentials = (Map) content.get(CREDENTIALS_FIELD); + credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); break; case ACTIONS_FIELD: - actions = (List) content.get(ACTIONS_FIELD); + actions = getConnectorActionList(entry.getValue()); break; } - } } @@ -145,14 +145,20 @@ public String getName() { return NAME; } - private static Map getParameterMap(Map params) { + @SuppressWarnings("unchecked") + private static Map getStringToStringMap(Object map, String fieldName) { + if (map instanceof Map) { + return (Map) map; + } + throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); + } + private static Map getParameterMap(Object parameterMap) { Map parameters = new HashMap<>(); - for (String key : params.keySet()) { - String value = params.get(key); + for (Entry entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) { try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - parameters.put(key, value); + parameters.put(entry.getKey(), entry.getValue()); return null; }); } catch (PrivilegedActionException e) { @@ -162,4 +168,29 @@ private static Map getParameterMap(Map params) { return parameters; } + private static List getConnectorActionList(Object array) { + if (!(array instanceof Map[])) { + throw new IllegalArgumentException("[" + ACTIONS_FIELD + "] must be an array of key-value maps."); + } + List actions = new ArrayList<>(); + for (Map map : (Map[]) array) { + String actionType = (String) map.get(ConnectorAction.ACTION_TYPE_FIELD); + if (actionType == null) { + throw new IllegalArgumentException("[" + ConnectorAction.ACTION_TYPE_FIELD + "] is missing."); + } + @SuppressWarnings("unchecked") + ConnectorAction action = ConnectorAction.builder() + .actionType(ActionType.valueOf(actionType.toUpperCase(Locale.ROOT))) + .method((String) map.get(ConnectorAction.METHOD_FIELD)) + .url((String) map.get(ConnectorAction.URL_FIELD)) + .headers((Map) map.get(ConnectorAction.HEADERS_FIELD)) + .requestBody((String) map.get(ConnectorAction.REQUEST_BODY_FIELD)) + .preProcessFunction((String) map.get(ConnectorAction.ACTION_PRE_PROCESS_FUNCTION)) + .postProcessFunction((String) map.get(ConnectorAction.ACTION_POST_PROCESS_FUNCTION)) + .build(); + actions.add(action); + } + return actions; + } + }