Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Allow custom awsvpcConfiguration for ECS Worker (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
HughZurname authored Aug 18, 2023
1 parent 2e17708 commit bde96a7
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added retries to ECS task run creation for ECS worker - [#303](https://github.com/PrefectHQ/prefect-aws/pull/303)
- Added support to `ECSWorker` for `awsvpcConfiguration` [#304](https://github.com/PrefectHQ/prefect-aws/pull/304)

### Changed

Expand Down
92 changes: 86 additions & 6 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class ECSJobConfiguration(BaseJobConfiguration):
)
configure_cloudwatch_logs: Optional[bool] = Field(default=None)
cloudwatch_logs_options: Dict[str, str] = Field(default_factory=dict)
network_configuration: Dict[str, Any] = Field(default_factory=dict)
stream_output: Optional[bool] = Field(default=None)
task_start_timeout_seconds: int = Field(default=300)
task_watch_poll_interval: float = Field(default=5.0)
Expand Down Expand Up @@ -321,6 +322,18 @@ def cloudwatch_logs_options_requires_configure_cloudwatch_logs(
)
return values

@root_validator
def network_configuration_requires_vpc_id(cls, values: dict) -> dict:
"""
Enforces a `vpc_id` is provided when custom network configuration mode is
enabled for network settings.
"""
if values.get("network_configuration") and not values.get("vpc_id"):
raise ValueError(
"You must provide a `vpc_id` to enable custom `network_configuration`."
)
return values


class ECSVariables(BaseVariables):
"""
Expand Down Expand Up @@ -459,10 +472,21 @@ class ECSVariables(BaseVariables):
"When `configure_cloudwatch_logs` is enabled, this setting may be used to"
" pass additional options to the CloudWatch logs configuration or override"
" the default options. See the [AWS"
" documentation](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options.)" # noqa
" documentation](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options)" # noqa
" for available options. "
),
)

network_configuration: Dict[str, Any] = Field(
default_factory=dict,
description=(
"When `network_configuration` is supplied it will override ECS Worker's"
"awsvpcConfiguration that defined in the ECS task executing your workload. "
"See the [AWS documentation](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-ecs-service-awsvpcconfiguration.html)" # noqa
" for available options."
),
)

stream_output: bool = Field(
default=None,
description=(
Expand Down Expand Up @@ -1242,7 +1266,7 @@ def _prepare_task_definition(

return task_definition

def _load_vpc_network_config(
def _load_network_configuration(
self, vpc_id: Optional[str], boto_session: boto3.Session
) -> dict:
"""
Expand Down Expand Up @@ -1289,6 +1313,47 @@ def _load_vpc_network_config(
}
}

def _custom_network_configuration(
self, vpc_id: str, network_configuration: dict, boto_session: boto3.Session
) -> dict:
"""
Load settings from a specific VPC or the default VPC and generate a task
run request's network configuration.
"""
ec2_client = boto_session.client("ec2")
vpc_message = f"VPC with ID {vpc_id}"

vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs")

if not vpcs:
raise ValueError(
f"Failed to find {vpc_message}. "
+ "Network configuration cannot be inferred. "
+ "Pass an explicit `vpc_id`."
)

vpc_id = vpcs[0]["VpcId"]
subnets = ec2_client.describe_subnets(
Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]
)["Subnets"]

if not subnets:
raise ValueError(
f"Failed to find subnets for {vpc_message}. "
+ "Network configuration cannot be inferred."
)

config_subnets = network_configuration.get("subnets", [])
if not all(
[conf_sn in sn.values() for conf_sn in config_subnets for sn in subnets]
):
raise ValueError(
f"Subnets {config_subnets} not found within {vpc_message}."
+ "Please check that VPC is associated with supplied subnets."
)

return {"awsvpcConfiguration": network_configuration}

def _prepare_task_run_request(
self,
boto_session: boto3.Session,
Expand Down Expand Up @@ -1318,14 +1383,29 @@ def _prepare_task_run_request(
container_overrides = overrides.get("containerOverrides", [])

# Ensure the network configuration is present if using awsvpc for network mode

if task_definition.get("networkMode") == "awsvpc" and not task_run_request.get(
"networkConfiguration"
if (
task_definition.get("networkMode") == "awsvpc"
and not task_run_request.get("networkConfiguration")
and not configuration.network_configuration
):
task_run_request["networkConfiguration"] = self._load_vpc_network_config(
task_run_request["networkConfiguration"] = self._load_network_configuration(
configuration.vpc_id, boto_session
)

# Use networkConfiguration if supplied by user
if (
task_definition.get("networkMode") == "awsvpc"
and configuration.network_configuration
and configuration.vpc_id
):
task_run_request["networkConfiguration"] = (
self._custom_network_configuration(
configuration.vpc_id,
configuration.network_configuration,
boto_session,
)
)

# Ensure the container name is set if not provided at template time

container_name = (
Expand Down
107 changes: 107 additions & 0 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from moto.ec2.utils import generate_instance_identity_document
from prefect.server.schemas.core import FlowRun
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from pydantic import ValidationError
from tenacity import RetryError

from prefect_aws.workers.ecs_worker import (
Expand Down Expand Up @@ -884,6 +885,112 @@ async def test_network_config_from_vpc_id(
}


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_custom_settings(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
ec2_resource = session.resource("ec2")
vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16")
subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id)
security_group = ec2_resource.create_security_group(
GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id
)

configuration = await construct_configuration(
aws_credentials=aws_credentials,
vpc_id=vpc.id,
override_network_configuration=True,
network_configuration={
"subnets": [subnet.id],
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
},
)

session = aws_credentials.get_boto3_session()

async with ECSWorker(work_pool_name="test") as worker:
# Capture the task run call because moto does not track 'networkConfiguration'
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration")

# Subnet ids are copied from the vpc
assert network_configuration == {
"awsvpcConfiguration": {
"subnets": [subnet.id],
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
}
}


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_custom_settings_invalid_subnet(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
ec2_resource = session.resource("ec2")
vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16")
security_group = ec2_resource.create_security_group(
GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id
)
ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id)

configuration = await construct_configuration(
aws_credentials=aws_credentials,
vpc_id=vpc.id,
override_network_configuration=True,
network_configuration={
"subnets": ["sn-8asdas"],
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
},
)

session = aws_credentials.get_boto3_session()

with pytest.raises(
ValueError,
match=(
r"Subnets \['sn-8asdas'\] not found within VPC with ID "
+ vpc.id
+ r"\.Please check that VPC is associated with supplied subnets\."
),
):
async with ECSWorker(work_pool_name="test") as worker:
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

await run_then_stop_task(worker, configuration, flow_run)


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_configure_network_requires_vpc_id(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
with pytest.raises(
ValidationError,
match="You must provide a `vpc_id` to enable custom `network_configuration`.",
):
await construct_configuration(
aws_credentials=aws_credentials,
override_network_configuration=True,
network_configuration={
"subnets": [],
"assignPublicIp": "ENABLED",
"securityGroups": [],
},
)


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_default_vpc(
aws_credentials: AwsCredentials, flow_run: FlowRun
Expand Down

0 comments on commit bde96a7

Please sign in to comment.