Skip to content

Commit

Permalink
raise warning if extra provider config given
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam-D-Lewis committed Dec 13, 2024
1 parent 2682982 commit 68b5064
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -613,11 +614,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand Down
10 changes: 10 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider):
result_config_dict = config.model_dump()
assert provider in result_config_dict
assert result_config_dict[provider]["kube_context"] == "some_context"


def test_provider_config_mismatch_warning(config_schema):
config_dict = {
"project_name": "test",
"provider": "local",
"existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider
}
with pytest.warns(UserWarning, match="configuration defined for other providers"):
config_schema(**config_dict)
1 change: 1 addition & 0 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change(
mock_model_fields, mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy(deep=True)
old_config.local = None
old_config.provider = schema.ProviderEnum.gcp
mock_get_state.return_value = old_config.model_dump()

Expand Down

0 comments on commit 68b5064

Please sign in to comment.