Skip to content

Commit

Permalink
Validate tenant id existence in workflow transport actions
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 13, 2024
1 parent 5b71c22 commit 60e881d
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.TenantAwareHelper;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -43,6 +45,7 @@ public class DeleteWorkflowTransportAction extends HandledTransportAction<Workfl
private final Logger logger = LogManager.getLogger(DeleteWorkflowTransportAction.class);

private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final FlowFrameworkSettings flowFrameworkSettings;
private final Client client;
private volatile Boolean filterByEnabled;
private final ClusterService clusterService;
Expand All @@ -53,6 +56,7 @@ public class DeleteWorkflowTransportAction extends HandledTransportAction<Workfl
* @param transportService the transport service
* @param actionFilters action filters
* @param flowFrameworkIndicesHandler The Flow Framework indices handler
* @param flowFrameworkSettings The Flow Framework settings
* @param client the OpenSearch Client
* @param clusterService the cluster service
* @param xContentRegistry contentRegister to parse get response
Expand All @@ -63,13 +67,15 @@ public DeleteWorkflowTransportAction(
TransportService transportService,
ActionFilters actionFilters,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkSettings flowFrameworkSettings,
Client client,
ClusterService clusterService,
NamedXContentRegistry xContentRegistry,
Settings settings
) {
super(DeleteWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.flowFrameworkSettings = flowFrameworkSettings;
this.client = client;
filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings);
this.xContentRegistry = xContentRegistry;
Expand Down Expand Up @@ -115,6 +121,10 @@ private void executeDeleteRequest(
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId();
if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) {
return;
}
DeleteRequest deleteRequest = new DeleteRequest(GLOBAL_CONTEXT_INDEX, workflowId);
logger.info("Deleting workflow doc: {}", workflowId);
client.delete(deleteRequest, ActionListener.runBefore(listener, context::restore));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.util.TenantAwareHelper;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
Expand Down Expand Up @@ -149,6 +150,10 @@ private void executeDeprovisionRequest(
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId();
if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) {
return;
}
String allowDelete = request.getParams().get(ALLOW_DELETE);
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
logger.info("Querying state for workflow: {}", workflowId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.util.TenantAwareHelper;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -45,6 +47,7 @@ public class GetWorkflowTransportAction extends HandledTransportAction<WorkflowR
private final Logger logger = LogManager.getLogger(GetWorkflowTransportAction.class);

private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final FlowFrameworkSettings flowFrameworkSettings;
private final Client client;
private final EncryptorUtils encryptorUtils;
private volatile Boolean filterByEnabled;
Expand All @@ -56,6 +59,7 @@ public class GetWorkflowTransportAction extends HandledTransportAction<WorkflowR
* @param transportService the transport service
* @param actionFilters action filters
* @param flowFrameworkIndicesHandler The Flow Framework indices handler
* @param flowFrameworkSettings The Flow Framework settings
* @param encryptorUtils Encryptor utils
* @param client the Opensearch Client
* @param xContentRegistry contentRegister to parse get response
Expand All @@ -67,6 +71,7 @@ public GetWorkflowTransportAction(
TransportService transportService,
ActionFilters actionFilters,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkSettings flowFrameworkSettings,
Client client,
EncryptorUtils encryptorUtils,
ClusterService clusterService,
Expand All @@ -75,6 +80,7 @@ public GetWorkflowTransportAction(
) {
super(GetWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.flowFrameworkSettings = flowFrameworkSettings;
this.client = client;
this.encryptorUtils = encryptorUtils;
filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings);
Expand Down Expand Up @@ -132,6 +138,10 @@ private void executeGetRequest(
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId();
if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) {
return;
}
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);
logger.info("Querying workflow from global context: {}", workflowId);
client.get(getRequest, ActionListener.wrap(response -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.TenantAwareHelper;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.plugins.PluginsService;
Expand Down Expand Up @@ -71,6 +73,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor
private final Client client;
private final WorkflowProcessSorter workflowProcessSorter;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final FlowFrameworkSettings flowFrameworkSettings;
private final EncryptorUtils encryptorUtils;
private final PluginsService pluginsService;
private volatile Boolean filterByEnabled;
Expand All @@ -85,6 +88,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor
* @param client The node client to retrieve a stored use case template
* @param workflowProcessSorter Utility class to generate a togologically sorted list of Process nodes
* @param flowFrameworkIndicesHandler Class to handle all internal system indices actions
* @param flowFrameworkSettings The Flow Framework settings
* @param encryptorUtils Utility class to handle encryption/decryption
* @param pluginsService The Plugins Service
* @param clusterService the cluster service
Expand All @@ -99,6 +103,7 @@ public ProvisionWorkflowTransportAction(
Client client,
WorkflowProcessSorter workflowProcessSorter,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
FlowFrameworkSettings flowFrameworkSettings,
EncryptorUtils encryptorUtils,
PluginsService pluginsService,
ClusterService clusterService,
Expand All @@ -110,6 +115,7 @@ public ProvisionWorkflowTransportAction(
this.client = client;
this.workflowProcessSorter = workflowProcessSorter;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.flowFrameworkSettings = flowFrameworkSettings;
this.encryptorUtils = encryptorUtils;
this.pluginsService = pluginsService;
filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings);
Expand Down Expand Up @@ -167,6 +173,10 @@ private void executeProvisionRequest(
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId();
if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) {
return;
}
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);
logger.info("Querying workflow from global context: {}", workflowId);
client.get(getRequest, ActionListener.wrap(response -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.TenantAwareHelper;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand Down Expand Up @@ -168,6 +169,10 @@ private void executeReprovisionRequest(
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
String tenantId = request.getUpdatedTemplate() == null ? null : request.getUpdatedTemplate().getTenantId();
if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) {
return;
}
logger.info("Querying state for workflow: {}", workflowId);
// Retrieve state and resources created
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.util;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.util.Objects;

public class TenantAwareHelper {

/**
* Validates the tenant ID based on the multi-tenancy feature setting.
*
* @param isMultiTenancyEnabled whether the multi-tenancy feature is enabled.
* @param tenantId The tenant ID to validate.
* @param listener The action listener to handle failure cases.
* @return true if the tenant ID is valid or if multi-tenancy is not enabled; false if the tenant ID is invalid and multi-tenancy is enabled.
*/
public static boolean validateTenantId(boolean isMultiTenancyEnabled, String tenantId, ActionListener<?> listener) {
if (isMultiTenancyEnabled && tenantId == null) {
listener.onFailure(new FlowFrameworkException("You don't have permission to access this resource", RestStatus.FORBIDDEN));
return false;
} else {
return true;
}
}

/**
* Validates the tenant resource by comparing the tenant ID from the request with the tenant ID from the resource.
*
* @param isMultiTenancyEnabled whether the multi-tenancy feature is enabled.
* @param tenantIdFromRequest The tenant ID obtained from the request.
* @param tenantIdFromResource The tenant ID obtained from the resource.
* @param listener The action listener to handle failure cases.
* @return true if the tenant IDs match or if multi-tenancy is not enabled; false if the tenant IDs do not match and multi-tenancy is enabled.
*/
public static boolean validateTenantResource(
boolean isMultiTenancyEnabled,
String tenantIdFromRequest,
String tenantIdFromResource,
ActionListener<?> listener
) {
if (isMultiTenancyEnabled && !Objects.equals(tenantIdFromRequest, tenantIdFromResource)) {
listener.onFailure(new FlowFrameworkException("You don't have permission to access this resource", RestStatus.FORBIDDEN));
return false;
} else return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ public class DeleteWorkflowTransportActionTests extends OpenSearchTestCase {
private Client client;
private DeleteWorkflowTransportAction deleteWorkflowTransportAction;
private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private FlowFrameworkSettings flowFrameworkSettings;

@Override
public void setUp() throws Exception {
super.setUp();
this.client = mock(Client.class);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
this.flowFrameworkSettings = mock(FlowFrameworkSettings.class);

ClusterService clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(
Expand All @@ -64,6 +66,7 @@ public void setUp() throws Exception {
mock(TransportService.class),
mock(ActionFilters.class),
flowFrameworkIndicesHandler,
flowFrameworkSettings,
client,
clusterService,
xContentRegistry(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public class GetWorkflowTransportActionTests extends OpenSearchTestCase {
private NamedXContentRegistry xContentRegistry;
private GetWorkflowTransportAction getTemplateTransportAction;
private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private FlowFrameworkSettings flowFrameworkSettings;
private Template template;
private EncryptorUtils encryptorUtils;

Expand All @@ -71,6 +72,7 @@ public void setUp() throws Exception {
this.client = mock(Client.class);
this.xContentRegistry = mock(NamedXContentRegistry.class);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
this.flowFrameworkSettings = mock(FlowFrameworkSettings.class);
this.sdkClient = SdkClientFactory.createSdkClient(client, xContentRegistry, Collections.emptyMap());
this.encryptorUtils = new EncryptorUtils(mock(ClusterService.class), client, sdkClient, xContentRegistry);
ClusterService clusterService = mock(ClusterService.class);
Expand All @@ -84,6 +86,7 @@ public void setUp() throws Exception {
mock(TransportService.class),
mock(ActionFilters.class),
flowFrameworkIndicesHandler,
flowFrameworkSettings,
client,
encryptorUtils,
clusterService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase {
private ProvisionWorkflowTransportAction provisionWorkflowTransportAction;
private Template template;
private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private FlowFrameworkSettings flowFrameworkSettings;
private EncryptorUtils encryptorUtils;
private PluginsService pluginsService;

Expand All @@ -79,6 +80,7 @@ public void setUp() throws Exception {
this.client = mock(Client.class);
this.workflowProcessSorter = mock(WorkflowProcessSorter.class);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
this.flowFrameworkSettings = mock(FlowFrameworkSettings.class);
this.encryptorUtils = mock(EncryptorUtils.class);
this.pluginsService = mock(PluginsService.class);
ClusterService clusterService = mock(ClusterService.class);
Expand All @@ -95,6 +97,7 @@ public void setUp() throws Exception {
client,
workflowProcessSorter,
flowFrameworkIndicesHandler,
flowFrameworkSettings,
encryptorUtils,
pluginsService,
clusterService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public void setUp() throws Exception {
this.workflowStepFactory = mock(WorkflowStepFactory.class);
this.workflowProcessSorter = mock(WorkflowProcessSorter.class);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
this.flowFrameworkSettings = mock(FlowFrameworkSettings.class);
this.encryptorUtils = mock(EncryptorUtils.class);
this.pluginsService = mock(PluginsService.class);

Expand Down
Loading

0 comments on commit 60e881d

Please sign in to comment.