diff --git a/dbt_coves/config/config.py b/dbt_coves/config/config.py index c280c69c..68effc5d 100644 --- a/dbt_coves/config/config.py +++ b/dbt_coves/config/config.py @@ -156,7 +156,7 @@ class DataSyncModel(BaseModel): class BlueGreenModel(BaseModel): - service_connection_name: Optional[str] = "" + prod_db_env_var: Optional[str] = "" staging_database: Optional[str] = "" staging_suffix: Optional[str] = "" drop_staging_db_at_start: Optional[bool] = False @@ -264,7 +264,7 @@ class DbtCovesConfig: "load.fivetran.secrets_key", "data_sync.redshift.tables", "data_sync.snowflake.tables", - "blue_green.service_connection_name", + "blue_green.prod_db_env_var", "blue_green.staging_database", "blue_green.staging_suffix", "blue_green.drop_staging_db_at_start", diff --git a/dbt_coves/core/main.py b/dbt_coves/core/main.py index b6c505ef..a79fc95f 100644 --- a/dbt_coves/core/main.py +++ b/dbt_coves/core/main.py @@ -295,7 +295,8 @@ def main(parser: argparse.ArgumentParser = parser, test_cli_args: List[str] = li console.print( "[red]The process was killed by the OS due to running out of memory.[/red]" ) - console.print(f"[red]:cross_mark:[/red] {cpe.stderr}") + if cpe.stderr: + console.print(f"[red]:cross_mark:[/red] {cpe.stderr}") return cpe.returncode except Exception as ex: diff --git a/dbt_coves/tasks/blue_green/main.py b/dbt_coves/tasks/blue_green/main.py index ab25f808..2e9b2149 100644 --- a/dbt_coves/tasks/blue_green/main.py +++ b/dbt_coves/tasks/blue_green/main.py @@ -3,10 +3,9 @@ import snowflake.connector from rich.console import Console -from rich.text import Text from dbt_coves.core.exceptions import DbtCovesException -from dbt_coves.tasks.base import NonDbtBaseConfiguredTask +from dbt_coves.tasks.base import BaseConfiguredTask from dbt_coves.utils.tracking import trackable from .clone_db import CloneDB @@ -14,7 +13,7 @@ console = Console() -class BlueGreenTask(NonDbtBaseConfiguredTask): +class BlueGreenTask(BaseConfiguredTask): """ Task that performs a blue-green 
deployment """ @@ -32,7 +31,7 @@ def register_parser(cls, sub_parsers, base_subparser): ext_subparser.set_defaults(cls=cls, which="blue-green") cls.arg_parser = ext_subparser ext_subparser.add_argument( - "--service-connection-name", + "--prod-db-env-var", type=str, - help="Snowflake service connection name", + help="Name of the environment variable that contains the production database name", ) @@ -79,15 +78,14 @@ def get_config_value(self, key): @trackable def run(self) -> int: - self.service_connection_name = self.get_config_value("service_connection_name").upper() + self.prod_db_env_var = self.get_config_value("prod_db_env_var").upper() try: - self.production_database = os.environ[ - f"DATACOVES__{self.service_connection_name}__DATABASE" - ] + self.production_database = os.environ[self.prod_db_env_var] except KeyError: raise DbtCovesException( - f"There is no Database defined for Service Connection {self.service_connection_name}" + f"Environment variable {self.prod_db_env_var} not found. Please provide a production database" ) + self.con = self.snowflake_connection() staging_database = self.get_config_value("staging_database") staging_suffix = self.get_config_value("staging_suffix") if staging_database and staging_suffix: @@ -101,7 +99,6 @@ def run(self) -> int: f"{self.staging_database}" ) self.drop_staging_db_at_start = self.get_config_value("drop_staging_db_at_start") - self.con = self.snowflake_connection() self.cdb = CloneDB( self.production_database, @@ -134,21 +131,21 @@ def run(self) -> int: def _run_dbt_build(self, env): dbt_build_command: list = self._get_dbt_build_command() - env[f"DATACOVES__{self.service_connection_name}__DATABASE"] = self.staging_database + env[self.prod_db_env_var] = self.staging_database self._run_command(dbt_build_command, env=env) def _run_command(self, command: list, env=os.environ.copy()): command_string = " ".join(command) console.print(f"Running [b][i]{command_string}[/i][/b]") try: - output = subprocess.check_output(command, env=env, stderr=subprocess.PIPE) - console.print( - 
f"{Text.from_ansi(output.decode())}\n" - f"[green]{command_string} :heavy_check_mark:[/green]" + subprocess.run( + command, + env=env, + check=True, ) + console.print(f"[green]{command_string} :heavy_check_mark:[/green]") except subprocess.CalledProcessError as e: - formatted = f"{Text.from_ansi(e.stderr.decode()) if e.stderr else Text.from_ansi(e.stdout.decode())}" - e.stderr = f"An error has occurred running [red]{command_string}[/red]:\n{formatted}" + console.print(f"Error running [red]{e.cmd}[/red], see the output above for details") raise def _get_dbt_command(self, command): @@ -198,20 +195,46 @@ def _check_and_drop_staging_db(self): f"Green database {self.staging_database} already exists. Please either drop it or use a different name." ) + def _get_snowflake_credentials_from_dbt_adapter(self): + connection_dict = { + "account": self.config.credentials.account, + "warehouse": self.config.credentials.warehouse, + "database": self.config.credentials.database, + "role": self.config.credentials.role, + "schema": self.config.credentials.schema, + "user": self.config.credentials.user, + "session_parameters": { + "QUERY_TAG": "blue_green_swap", + }, + } + if self.config.credentials.password: + connection_dict["password"] = self.config.credentials.password + else: + connection_dict["private_key"] = self._get_snowflake_private_key() + connection_dict["login_timeout"] = 10 + + return connection_dict + + def _get_snowflake_private_key(self): + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + with open(self.config.credentials.private_key_path, "rb") as key_file: + private_key = serialization.load_pem_private_key( + key_file.read(), password=None, backend=default_backend() + ) + + # Convert the private key to the required format + return private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + def 
snowflake_connection(self): + connection_dict = self._get_snowflake_credentials_from_dbt_adapter() try: - return snowflake.connector.connect( - account=os.environ.get(f"DATACOVES__{self.service_connection_name}__ACCOUNT"), - warehouse=os.environ.get(f"DATACOVES__{self.service_connection_name}__WAREHOUSE"), - database=os.environ.get(f"DATACOVES__{self.service_connection_name}__DATABASE"), - role=os.environ.get(f"DATACOVES__{self.service_connection_name}__ROLE"), - schema=os.environ.get(f"DATACOVES__{self.service_connection_name}__SCHEMA"), - user=os.environ.get(f"DATACOVES__{self.service_connection_name}__USER"), - password=os.environ.get(f"DATACOVES__{self.service_connection_name}__PASSWORD"), - session_parameters={ - "QUERY_TAG": "blue_green_swap", - }, - ) + return snowflake.connector.connect(**connection_dict) except Exception as e: raise DbtCovesException( f"Couldn't establish Snowflake connection with {self.production_database}: {e}" diff --git a/dbt_coves/utils/flags.py b/dbt_coves/utils/flags.py index 928b4e3b..911a8d47 100644 --- a/dbt_coves/utils/flags.py +++ b/dbt_coves/utils/flags.py @@ -137,7 +137,7 @@ def __init__(self, cli_parser: ArgumentParser) -> None: self.dbt = {"command": None, "project_dir": None, "virtualenv": None, "cleanup": False} self.data_sync = {"redshift": {"tables": []}, "snowflake": {"tables": []}} self.blue_green = { - "service_connection_name": None, + "prod_db_env_var": None, "staging_database": None, "staging_suffix": None, "drop_staging_db_at_start": False, @@ -421,8 +421,8 @@ def parse_args(self, cli_args: List[str] = list()) -> None: # blue green if self.args.cls.__name__ == "BlueGreenTask": - if self.args.service_connection_name: - self.blue_green["service_connection_name"] = self.args.service_connection_name + if self.args.prod_db_env_var: + self.blue_green["prod_db_env_var"] = self.args.prod_db_env_var if self.args.staging_database: self.blue_green["staging_database"] = self.args.staging_database if self.args.staging_suffix: 
diff --git a/tests/blue_green/profiles.yml b/tests/blue_green/profiles.yml index 9cc63613..8d694629 100644 --- a/tests/blue_green/profiles.yml +++ b/tests/blue_green/profiles.yml @@ -2,7 +2,7 @@ default: outputs: dev: account: "{{ env_var('DATACOVES__DBT_COVES_TEST__ACCOUNT') }}" - database: DBT_COVES_TEST_STAGING + database: "{{ env_var('DATACOVES__DBT_COVES_TEST__DATABASE') }}" password: "{{ env_var('DATACOVES__DBT_COVES_TEST__PASSWORD') }}" role: "{{ env_var('DATACOVES__DBT_COVES_TEST__ROLE') }}" schema: TESTS_BLUE_GREEN diff --git a/tests/blue_green_test.py b/tests/blue_green_test.py index ce36aac9..3cdea352 100644 --- a/tests/blue_green_test.py +++ b/tests/blue_green_test.py @@ -39,7 +39,7 @@ @pytest.fixture(scope="class") def snowflake_connection(request): # Check env vars - assert "DATACOVES__DBT_COVES_TEST__USER" in os.environ + assert "DATACOVES__DBT_COVES_TEST__DATABASE" in os.environ assert "DATACOVES__DBT_COVES_TEST__PASSWORD" in os.environ assert "DATACOVES__DBT_COVES_TEST__ACCOUNT" in os.environ assert "DATACOVES__DBT_COVES_TEST__WAREHOUSE" in os.environ @@ -103,8 +103,8 @@ def test_dbt_coves_bluegreen(self): str(FIXTURE_DIR), "--profiles-dir", str(FIXTURE_DIR), - "--service-connection-name", - self.production_database, + "--prod-db-env-var", + "DATACOVES__DBT_COVES_TEST__DATABASE", "--keep-staging-db-on-success", ] if DBT_COVES_SETTINGS.get("drop_staging_db_at_start"):