diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index a5894687ac..b75412bd64 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -6,6 +6,7 @@ import re import sys import tempfile +import warnings from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union from pydantic import ConfigDict, Field, field_validator, model_validator @@ -613,11 +614,23 @@ def check_provider(cls, data: Any) -> Any: data[provider] = provider_enum_model_map[provider]() else: # if the provider field is invalid, it won't be set when this validator is called - # so we need to check for it explicitly here, and set the `pre` to True + # so we need to check for it explicitly here, and set mode to "before" # TODO: this is a workaround, check if there is a better way to do this in Pydantic v2 raise ValueError( f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure" ) + set_providers = { + provider + for provider in provider_name_abbreviation_map.keys() + if provider in data and data[provider] + } + expected_provider_config = provider_enum_name_map[provider] + extra_provider_config = set_providers - {expected_provider_config} + if extra_provider_config: + warnings.warn( + f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}" + ) + else: set_providers = [ provider diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index 5c21aef8d6..e445ba37da 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider): result_config_dict = config.model_dump() assert provider in result_config_dict assert result_config_dict[provider]["kube_context"] == "some_context" + + +def test_provider_config_mismatch_warning(config_schema): + config_dict = { + "project_name": "test", + "provider": "local", + "existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider + } + with pytest.warns(UserWarning, match="configuration defined for other providers"): + config_schema(**config_dict) diff --git a/tests/tests_unit/test_stages.py b/tests/tests_unit/test_stages.py index c716d93030..c15aa6d9fc 100644 --- a/tests/tests_unit/test_stages.py +++ b/tests/tests_unit/test_stages.py @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change( mock_model_fields, mock_get_state, terraform_state_stage, mock_config ): old_config = mock_config.model_copy(deep=True) + old_config.local = None old_config.provider = schema.ProviderEnum.gcp mock_get_state.return_value = old_config.model_dump()