Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
fix: 320 Subnet validation in Network Configuration fails when multip…
Browse files Browse the repository at this point in the history
…le subnets exist in the VPC (#321)
  • Loading branch information
chotalia authored Oct 10, 2023
1 parent 525c917 commit 8650dbf
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Changed `push_to_s3` deployment step function to write paths `as_posix()` to allow support for deploying from windows [#314](https://github.com/PrefectHQ/prefect-aws/pull/314)

### Fixed

- Resolved an issue where defining a custom network configuration with a subnet would erroneously report it as missing from the VPC when more than one subnet exists in the VPC. [#321](https://github.com/PrefectHQ/prefect-aws/pull/321)

### Deprecated

### Removed
Expand Down
6 changes: 3 additions & 3 deletions prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,10 +1349,10 @@ def _custom_network_configuration(
+ "Network configuration cannot be inferred."
)

subnet_ids = [subnet["SubnetId"] for subnet in subnets]

config_subnets = network_configuration.get("subnets", [])
if not all(
[conf_sn in sn.values() for conf_sn in config_subnets for sn in subnets]
):
if not all(conf_sn in subnet_ids for conf_sn in config_subnets):
raise ValueError(
f"Subnets {config_subnets} not found within {vpc_message}."
+ "Please check that VPC is associated with supplied subnets."
Expand Down
145 changes: 144 additions & 1 deletion tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ async def test_network_config_from_vpc_id(


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_custom_settings(
async def test_network_config_1_subnet_in_custom_settings_1_in_vpc(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
Expand Down Expand Up @@ -937,6 +937,107 @@ async def test_network_config_from_custom_settings(
}


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_1_sn_in_custom_settings_many_in_vpc(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
ec2_resource = session.resource("ec2")
vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16")
subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id)
ec2_resource.create_subnet(CidrBlock="10.0.3.0/24", VpcId=vpc.id)
ec2_resource.create_subnet(CidrBlock="10.0.4.0/24", VpcId=vpc.id)

security_group = ec2_resource.create_security_group(
GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id
)

configuration = await construct_configuration(
aws_credentials=aws_credentials,
vpc_id=vpc.id,
override_network_configuration=True,
network_configuration={
"subnets": [subnet.id],
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
},
)

session = aws_credentials.get_boto3_session()

async with ECSWorker(work_pool_name="test") as worker:
# Capture the task run call because moto does not track 'networkConfiguration'
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration")

# Subnet ids are copied from the vpc
assert network_configuration == {
"awsvpcConfiguration": {
"subnets": [subnet.id],
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
}
}


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_many_subnet_in_custom_settings_many_in_vpc(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
ec2_resource = session.resource("ec2")
vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16")
subnets = [
ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id),
ec2_resource.create_subnet(CidrBlock="10.0.33.0/24", VpcId=vpc.id),
ec2_resource.create_subnet(CidrBlock="10.0.44.0/24", VpcId=vpc.id),
]
subnet_ids = [subnet.id for subnet in subnets]

security_group = ec2_resource.create_security_group(
GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id
)

configuration = await construct_configuration(
aws_credentials=aws_credentials,
vpc_id=vpc.id,
override_network_configuration=True,
network_configuration={
"subnets": subnet_ids,
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
},
)

session = aws_credentials.get_boto3_session()

async with ECSWorker(work_pool_name="test") as worker:
# Capture the task run call because moto does not track 'networkConfiguration'
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration")

# Subnet ids are copied from the vpc
assert network_configuration == {
"awsvpcConfiguration": {
"subnets": subnet_ids,
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
}
}


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_custom_settings_invalid_subnet(
aws_credentials: AwsCredentials, flow_run: FlowRun
Expand Down Expand Up @@ -978,6 +1079,48 @@ async def test_network_config_from_custom_settings_invalid_subnet(
await run_then_stop_task(worker, configuration, flow_run)


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_subnets(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
ec2_resource = session.resource("ec2")
vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16")
security_group = ec2_resource.create_security_group(
GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id
)
subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id)
invalid_subnet_id = "subnet-3bf19de7"

configuration = await construct_configuration(
aws_credentials=aws_credentials,
vpc_id=vpc.id,
override_network_configuration=True,
network_configuration={
"subnets": [invalid_subnet_id, subnet.id],
"assignPublicIp": "DISABLED",
"securityGroups": [security_group.id],
},
)

session = aws_credentials.get_boto3_session()

with pytest.raises(
ValueError,
match=(
rf"Subnets \['{invalid_subnet_id}', '{subnet.id}'\] not found within VPC"
f" with ID {vpc.id}.Please check that VPC is associated with supplied"
" subnets."
),
):
async with ECSWorker(work_pool_name="test") as worker:
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

await run_then_stop_task(worker, configuration, flow_run)


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_configure_network_requires_vpc_id(
aws_credentials: AwsCredentials, flow_run: FlowRun
Expand Down

0 comments on commit 8650dbf

Please sign in to comment.