Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
revert ecs worker test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano committed Apr 2, 2024
1 parent d52689f commit 4f58f7d
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_TASK_DEFINITION_CACHE,
ECS_DEFAULT_CONTAINER_NAME,
ECS_DEFAULT_CPU,
ECS_DEFAULT_FAMILY,
ECS_DEFAULT_MEMORY,
AwsCredentials,
ECSJobConfiguration,
Expand Down Expand Up @@ -648,6 +649,7 @@ async def test_task_definition_arn(aws_credentials: AwsCredentials, flow_run: Fl
_, task_arn = parse_identifier(result.identifier)

task = describe_task(ecs_client, task_arn)
print(task)
assert (
task["taskDefinitionArn"] == task_definition_arn
), "The task definition should be used without registering a new one"
Expand Down Expand Up @@ -1322,20 +1324,8 @@ async def write_fake_log(task_arn):


@pytest.mark.usefixtures("ecs_mocks")
@pytest.mark.parametrize(
"cloudwatch_logs_options",
[
{
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
{
"max-buffer-size": "2m",
},
],
)
async def test_cloudwatch_log_options(
aws_credentials: AwsCredentials, flow_run: FlowRun, cloudwatch_logs_options: dict
aws_credentials: AwsCredentials, flow_run: FlowRun
):
session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")
Expand All @@ -1344,10 +1334,12 @@ async def test_cloudwatch_log_options(
aws_credentials=aws_credentials,
configure_cloudwatch_logs=True,
execution_role_arn="test",
cloudwatch_logs_options=cloudwatch_logs_options,
cloudwatch_logs_options={
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
)
work_pool_name = "test"
async with ECSWorker(work_pool_name=work_pool_name) as worker:
async with ECSWorker(work_pool_name="test") as worker:
result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
Expand All @@ -1357,9 +1349,6 @@ async def test_cloudwatch_log_options(
task_definition = describe_task_definition(ecs_client, task)

for container in task_definition["containerDefinitions"]:
prefix = f"prefect-logs_{work_pool_name}_{flow_run.deployment_id}"
if cloudwatch_logs_options.get("awslogs-stream-prefix"):
prefix = cloudwatch_logs_options["awslogs-stream-prefix"]
if container["name"] == ECS_DEFAULT_CONTAINER_NAME:
# Assert that the container has logging configured with user
# provided options
Expand All @@ -1369,7 +1358,7 @@ async def test_cloudwatch_log_options(
"awslogs-create-group": "true",
"awslogs-group": "prefect",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": prefix,
"awslogs-stream-prefix": "override-prefix",
"max-buffer-size": "2m",
},
}
Expand Down Expand Up @@ -2329,3 +2318,27 @@ async def test_mask_sensitive_env_values():
res["overrides"]["containerOverrides"][0]["environment"][1]["value"]
== "NORMAL_VALUE"
)


@pytest.mark.usefixtures("ecs_mocks")
async def test_get_or_generate_family(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
configuration = await construct_configuration(
aws_credentials=aws_credentials,
)

work_pool_name = "test"
session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")
family = f"{ECS_DEFAULT_FAMILY}_{work_pool_name}_{flow_run.deployment_id}"

async with ECSWorker(work_pool_name=work_pool_name) as worker:
result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
_, task_arn = parse_identifier(result.identifier)

task = describe_task(ecs_client, task_arn)
task_definition = describe_task_definition(ecs_client, task)
assert task_definition["family"] == family

0 comments on commit 4f58f7d

Please sign in to comment.