Skip to content

Commit

Permalink
DCV-2857 changes to blue-green run (#494)
Browse files Browse the repository at this point in the history
* DCV-2857 changes to blue-green run

* Deprecate creating Snowflake conn from env_vars, change service_connection for prod_db_env_var

* Make dbt build output real-time, implement Snowflake login_timeout in key-pair cases
  • Loading branch information
BAntonellini authored Sep 11, 2024
1 parent ea48aac commit 2442472
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 39 deletions.
4 changes: 2 additions & 2 deletions dbt_coves/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class DataSyncModel(BaseModel):


class BlueGreenModel(BaseModel):
service_connection_name: Optional[str] = ""
prod_db_env_var: Optional[str] = ""
staging_database: Optional[str] = ""
staging_suffix: Optional[str] = ""
drop_staging_db_at_start: Optional[bool] = False
Expand Down Expand Up @@ -264,7 +264,7 @@ class DbtCovesConfig:
"load.fivetran.secrets_key",
"data_sync.redshift.tables",
"data_sync.snowflake.tables",
"blue_green.service_connection_name",
"blue_green.prod_db_env_var",
"blue_green.staging_database",
"blue_green.staging_suffix",
"blue_green.drop_staging_db_at_start",
Expand Down
3 changes: 2 additions & 1 deletion dbt_coves/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def main(parser: argparse.ArgumentParser = parser, test_cli_args: List[str] = li
console.print(
"[red]The process was killed by the OS due to running out of memory.[/red]"
)
console.print(f"[red]:cross_mark:[/red] {cpe.stderr}")
if cpe.stderr:
console.print(f"[red]:cross_mark:[/red] {cpe.stderr}")

return cpe.returncode
except Exception as ex:
Expand Down
81 changes: 52 additions & 29 deletions dbt_coves/tasks/blue_green/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@

import snowflake.connector
from rich.console import Console
from rich.text import Text

from dbt_coves.core.exceptions import DbtCovesException
from dbt_coves.tasks.base import NonDbtBaseConfiguredTask
from dbt_coves.tasks.base import BaseConfiguredTask
from dbt_coves.utils.tracking import trackable

from .clone_db import CloneDB

console = Console()


class BlueGreenTask(NonDbtBaseConfiguredTask):
class BlueGreenTask(BaseConfiguredTask):
"""
Task that performs a blue-green deployment
"""
Expand All @@ -32,7 +31,7 @@ def register_parser(cls, sub_parsers, base_subparser):
ext_subparser.set_defaults(cls=cls, which="blue-green")
cls.arg_parser = ext_subparser
ext_subparser.add_argument(
"--service-connection-name",
"--prod-db-env-var",
type=str,
help="Snowflake service connection name",
)
Expand Down Expand Up @@ -79,15 +78,14 @@ def get_config_value(self, key):

@trackable
def run(self) -> int:
self.service_connection_name = self.get_config_value("service_connection_name").upper()
self.prod_db_env_var = self.get_config_value("prod_db_env_var").upper()
try:
self.production_database = os.environ[
f"DATACOVES__{self.service_connection_name}__DATABASE"
]
self.production_database = os.environ[self.prod_db_env_var]
except KeyError:
raise DbtCovesException(
f"There is no Database defined for Service Connection {self.service_connection_name}"
f"Environment variable {self.prod_db_env_var} not found. Please provide a production database"
)
self.con = self.snowflake_connection()
staging_database = self.get_config_value("staging_database")
staging_suffix = self.get_config_value("staging_suffix")
if staging_database and staging_suffix:
Expand All @@ -101,7 +99,6 @@ def run(self) -> int:
f"{self.staging_database}"
)
self.drop_staging_db_at_start = self.get_config_value("drop_staging_db_at_start")
self.con = self.snowflake_connection()

self.cdb = CloneDB(
self.production_database,
Expand Down Expand Up @@ -134,21 +131,21 @@ def run(self) -> int:

def _run_dbt_build(self, env):
dbt_build_command: list = self._get_dbt_build_command()
env[f"DATACOVES__{self.service_connection_name}__DATABASE"] = self.staging_database
env[self.prod_db_env_var] = self.staging_database
self._run_command(dbt_build_command, env=env)

def _run_command(self, command: list, env=os.environ.copy()):
command_string = " ".join(command)
console.print(f"Running [b][i]{command_string}[/i][/b]")
try:
output = subprocess.check_output(command, env=env, stderr=subprocess.PIPE)
console.print(
f"{Text.from_ansi(output.decode())}\n"
f"[green]{command_string} :heavy_check_mark:[/green]"
subprocess.run(
command,
env=env,
check=True,
)
console.print(f"[green]{command_string} :heavy_check_mark:[/green]")
except subprocess.CalledProcessError as e:
formatted = f"{Text.from_ansi(e.stderr.decode()) if e.stderr else Text.from_ansi(e.stdout.decode())}"
e.stderr = f"An error has occurred running [red]{command_string}[/red]:\n{formatted}"
console.print(f"Error running [red]{e.cmd}[/red], see stack above for details")
raise

def _get_dbt_command(self, command):
Expand Down Expand Up @@ -198,20 +195,46 @@ def _check_and_drop_staging_db(self):
f"Green database {self.staging_database} already exists. Please either drop it or use a different name."
)

def _get_snowflake_credentials_from_dbt_adapter(self):
connection_dict = {
"account": self.config.credentials.account,
"warehouse": self.config.credentials.warehouse,
"database": self.config.credentials.database,
"role": self.config.credentials.role,
"schema": self.config.credentials.schema,
"user": self.config.credentials.user,
"session_parameters": {
"QUERY_TAG": "blue_green_swap",
},
}
if self.config.credentials.password:
connection_dict["password"] = self.config.credentials.password
else:
connection_dict["private_key"] = self._get_snowflake_private_key()
connection_dict["login_timeout"] = 10

return connection_dict

def _get_snowflake_private_key(self):
    """Load the adapter's PEM private key and return it as DER-encoded PKCS#8 bytes.

    ``snowflake.connector`` expects the ``private_key`` connection argument in
    this binary layout rather than the PEM text the dbt profile points at.

    NOTE(review): assumes the key file is unencrypted (``password=None``) —
    confirm encrypted private keys are out of scope for this task.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    with open(self.config.credentials.private_key_path, "rb") as fh:
        pem_bytes = fh.read()

    key = serialization.load_pem_private_key(
        pem_bytes, password=None, backend=default_backend()
    )

    # Re-serialize in the DER/PKCS8 format the Snowflake connector requires.
    return key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

def snowflake_connection(self):
connection_dict = self._get_snowflake_credentials_from_dbt_adapter()
try:
return snowflake.connector.connect(
account=os.environ.get(f"DATACOVES__{self.service_connection_name}__ACCOUNT"),
warehouse=os.environ.get(f"DATACOVES__{self.service_connection_name}__WAREHOUSE"),
database=os.environ.get(f"DATACOVES__{self.service_connection_name}__DATABASE"),
role=os.environ.get(f"DATACOVES__{self.service_connection_name}__ROLE"),
schema=os.environ.get(f"DATACOVES__{self.service_connection_name}__SCHEMA"),
user=os.environ.get(f"DATACOVES__{self.service_connection_name}__USER"),
password=os.environ.get(f"DATACOVES__{self.service_connection_name}__PASSWORD"),
session_parameters={
"QUERY_TAG": "blue_green_swap",
},
)
return snowflake.connector.connect(**connection_dict)
except Exception as e:
raise DbtCovesException(
f"Couldn't establish Snowflake connection with {self.production_database}: {e}"
Expand Down
6 changes: 3 additions & 3 deletions dbt_coves/utils/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self, cli_parser: ArgumentParser) -> None:
self.dbt = {"command": None, "project_dir": None, "virtualenv": None, "cleanup": False}
self.data_sync = {"redshift": {"tables": []}, "snowflake": {"tables": []}}
self.blue_green = {
"service_connection_name": None,
"prod_db_env_var": None,
"staging_database": None,
"staging_suffix": None,
"drop_staging_db_at_start": False,
Expand Down Expand Up @@ -421,8 +421,8 @@ def parse_args(self, cli_args: List[str] = list()) -> None:

# blue green
if self.args.cls.__name__ == "BlueGreenTask":
if self.args.service_connection_name:
self.blue_green["service_connection_name"] = self.args.service_connection_name
if self.args.prod_db_env_var:
self.blue_green["prod_db_env_var"] = self.args.prod_db_env_var
if self.args.staging_database:
self.blue_green["staging_database"] = self.args.staging_database
if self.args.staging_suffix:
Expand Down
2 changes: 1 addition & 1 deletion tests/blue_green/profiles.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default:
outputs:
dev:
account: "{{ env_var('DATACOVES__DBT_COVES_TEST__ACCOUNT') }}"
database: DBT_COVES_TEST_STAGING
database: "{{ env_var('DATACOVES__DBT_COVES_TEST__DATABASE') }}"
password: "{{ env_var('DATACOVES__DBT_COVES_TEST__PASSWORD') }}"
role: "{{ env_var('DATACOVES__DBT_COVES_TEST__ROLE') }}"
schema: TESTS_BLUE_GREEN
Expand Down
6 changes: 3 additions & 3 deletions tests/blue_green_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
@pytest.fixture(scope="class")
def snowflake_connection(request):
# Check env vars
assert "DATACOVES__DBT_COVES_TEST__USER" in os.environ
assert "DATACOVES__DBT_COVES_TEST__DATABASE" in os.environ
assert "DATACOVES__DBT_COVES_TEST__PASSWORD" in os.environ
assert "DATACOVES__DBT_COVES_TEST__ACCOUNT" in os.environ
assert "DATACOVES__DBT_COVES_TEST__WAREHOUSE" in os.environ
Expand Down Expand Up @@ -103,8 +103,8 @@ def test_dbt_coves_bluegreen(self):
str(FIXTURE_DIR),
"--profiles-dir",
str(FIXTURE_DIR),
"--service-connection-name",
self.production_database,
"--prod-db-env-var",
"DATACOVES__DBT_COVES_TEST__DATABASE",
"--keep-staging-db-on-success",
]
if DBT_COVES_SETTINGS.get("drop_staging_db_at_start"):
Expand Down

0 comments on commit 2442472

Please sign in to comment.