diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 27cb521e..52fc0e70 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -1,81 +1,94 @@ name: 🐞 Bug -description: Report a bug or an issue you've found with dbt-adapters +description: Report a bug or an issue you've found title: "[Bug]
+ +
+ + + +- [Features](#features) + - [Quick start](#quick-start) + - [Installation](#installation) + - [Prerequisites](#prerequisites) + - [Credentials](#credentials) + - [Configuring your profile](#configuring-your-profile) + - [Additional information](#additional-information) + - [Models](#models) + - [Table configuration](#table-configuration) + - [Table location](#table-location) + - [Incremental models](#incremental-models) + - [On schema change](#on-schema-change) + - [Iceberg](#iceberg) + - [Highly available table (HA)](#highly-available-table-ha) + - [HA known issues](#ha-known-issues) + - [Update glue data catalog](#update-glue-data-catalog) + - [Snapshots](#snapshots) + - [Timestamp strategy](#timestamp-strategy) + - [Check strategy](#check-strategy) + - [Hard-deletes](#hard-deletes) + - [Working example](#working-example) + - [Snapshots known issues](#snapshots-known-issues) + - [AWS Lake Formation integration](#aws-lake-formation-integration) + - [Python models](#python-models) + - [Contracts](#contracts) + - [Contributing](#contributing) + - [Contributors ✨](#contributors-) + + +# Features + +- Supports dbt version `1.7.*` +- Support for Python +- Supports [seeds][seeds] +- Correctly detects views and their columns +- Supports [table materialization][table] + - [Iceberg tables][athena-iceberg] are supported **only with Athena Engine v3** and **a unique table location** + (see table location section below) + - Hive tables are supported by both Athena engines +- Supports [incremental models][incremental] + - On Iceberg tables: + - Supports the use of `unique_key` only with the `merge` strategy + - Supports the `append` strategy + - On Hive tables: + - Supports two incremental update strategies: `insert_overwrite` and `append` + - Does **not** support the use of `unique_key` +- Supports [snapshots][snapshots] +- Supports [Python models][python-models] + +[seeds]: https://docs.getdbt.com/docs/building-a-dbt-project/seeds + +[incremental]: https://docs.getdbt.com/docs/build/incremental-models + +[table]: https://docs.getdbt.com/docs/build/materializations#table + +[python-models]: https://docs.getdbt.com/docs/build/python-models#configuring-python-models + +[athena-iceberg]: https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html + +[snapshots]: https://docs.getdbt.com/docs/build/snapshots + +## Quick start + +### Installation + +- `pip install dbt-athena-community` +- Or `pip install git+https://github.com/dbt-athena/dbt-athena.git` + +### Prerequisites + +To start, you will need an S3 bucket, for instance `my-bucket` and an Athena database: + +```sql +CREATE DATABASE IF NOT EXISTS analytics_dev +COMMENT 'Analytics models generated by dbt (development)' +LOCATION 's3://my-bucket/' +WITH DBPROPERTIES ('creator'='Foo Bar', 'email'='foo@bar.com'); +``` + +Notes: + +- Take note of your AWS region code (e.g. `us-west-2` or `eu-west-2`, etc.). +- You can also use [AWS Glue](https://docs.aws.amazon.com/athena/latest/ug/glue-athena.html) to create and manage Athena + databases. + +### Credentials + +Credentials can be passed directly to the adapter, or they can +be [determined automatically](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) based +on `aws cli`/`boto3` conventions. +You can either: + +- Configure `aws_access_key_id` and `aws_secret_access_key` +- Configure `aws_profile_name` to match a profile defined in your AWS credentials file. + Checkout dbt profile configuration below for details. 
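As an aside (not part of this diff), here is a minimal sketch of how the two credential options above typically map onto a `boto3` session; it roughly mirrors the `get_boto3_session_from_credentials` helper added in `session.py` further down in this changeset. The region, keys and profile name are placeholders taken from the examples in this README.

```python
# Illustrative sketch only: turning either explicit keys or a named profile
# into a boto3 session. Anything left as None falls back to boto3's standard
# resolution chain (environment variables, shared credentials file,
# instance/role credentials).
from typing import Optional

import boto3


def build_session(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    region_name: str = "eu-west-1",  # placeholder region
) -> boto3.session.Session:
    return boto3.session.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        profile_name=aws_profile_name,
        region_name=region_name,
    )


# Option 1: explicit access keys (placeholder values from the AWS docs)
session = build_session(
    aws_access_key_id="AKIAIOSFODNN7EXAMPLE",
    aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
)

# Option 2: a profile from your AWS shared credentials file
session = build_session(aws_profile_name="my-profile")
```

The adapter performs this wiring for you; the sketch is only meant to show which profile fields feed the underlying session.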
+ +### Configuring your profile + +A dbt profile can be configured to run against AWS Athena using the following configuration: + +| Option | Description | Required? | Example | +|-----------------------|------------------------------------------------------------------------------------------|-----------|--------------------------------------------| +| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` | +| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` | +| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` | +| s3_tmp_table_dir | Prefix for storing temporary tables, if different from the connection's `s3_data_dir` | Optional | `s3://bucket3/dbt/` | +| region_name | AWS region of your Athena instance | Required | `eu-west-1` | +| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` | +| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` | +| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` | +| debug_query_state | Flag if debug message with Athena query state is needed | Optional | `false` | +| aws_access_key_id | Access key ID of the user performing requests | Optional | `AKIAIOSFODNN7EXAMPLE` | +| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` | +| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` | +| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` | +| skip_workgroup_check | Indicates if the WorkGroup check (additional AWS call) can be skipped | Optional | `true` | +| num_retries | Number of times to retry a failing query | Optional | `3` | +| num_boto3_retries | Number of times to retry boto3 requests (e.g. 
deleting S3 files for materialized tables) | Optional | `5` | +| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` | +| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` | +| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` | +| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` | + +**Example profiles.yml entry:** + +```yaml +athena: + target: dev + outputs: + dev: + type: athena + s3_staging_dir: s3://athena-query-results/dbt/ + s3_data_dir: s3://your_s3_bucket/dbt/ + s3_data_naming: schema_table + s3_tmp_table_dir: s3://your_s3_bucket/temp/ + region_name: eu-west-1 + schema: dbt + database: awsdatacatalog + threads: 4 + aws_profile_name: my-profile + work_group: my-workgroup + spark_work_group: my-spark-workgroup + seed_s3_upload_args: + ACL: bucket-owner-full-control +``` + +### Additional information + +- `threads` is supported +- `database` and `catalog` can be used interchangeably + +## Models + +### Table configuration + +- `external_location` (`default=none`) + - If set, the full S3 path to which the table will be saved + - Works only with incremental models + - Does not work with Hive table with `ha` set to true +- `partitioned_by` (`default=none`) + - An array list of columns by which the table will be partitioned + - Limited to creation of 100 partitions (*currently*) +- `bucketed_by` (`default=none`) + - An array list of columns to bucket data, ignored if using Iceberg +- `bucket_count` (`default=none`) + - The number of buckets for bucketing your data, ignored if using Iceberg +- `table_type` (`default='hive'`) + - The type of table + - Supports `hive` or `iceberg` +- `ha` (`default=false`) + - If the table should be built using the high-availability method. This option is only available for Hive tables + since it is by default for Iceberg tables (see the section [below](#highly-available-table-ha)) +- `format` (`default='parquet'`) + - The data format for the table + - Supports `ORC`, `PARQUET`, `AVRO`, `JSON`, `TEXTFILE` +- `write_compression` (`default=none`) + - The compression type to use for any storage format that allows compression to be specified. To see which options are + available, check out [CREATE TABLE AS][create-table-as] +- `field_delimiter` (`default=none`) + - Custom field delimiter, for when format is set to `TEXTFILE` +- `table_properties`: table properties to add to the table, valid for Iceberg only +- `native_drop`: Relation drop operations will be performed with SQL, not direct Glue API calls. No S3 calls will be + made to manage data in S3. Data in S3 will only be cleared up for Iceberg + tables [see AWS docs](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-managing-tables.html). Note that + Iceberg DROP TABLE operations may timeout if they take longer than 60 seconds. +- `seed_by_insert` (`default=false`) + - Default behaviour uploads seed data to S3. 
This flag will create seeds using an SQL insert statement + - Large seed files cannot use `seed_by_insert`, as the SQL insert statement would + exceed [the Athena limit of 262144 bytes](https://docs.aws.amazon.com/athena/latest/ug/service-limits.html) +- `force_batch` (`default=false`) + - Skip creating the table as CTAS and run the operation directly in batch insert mode + - This is particularly useful when the standard table creation process fails due to partition limitations, + allowing you to work with temporary tables and persist the dataset more efficiently +- `unique_tmp_table_suffix` (`default=false`) + - For incremental models using insert overwrite strategy on hive table + - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid + - Useful if you are looking to run multiple dbt build inserting in the same table in parallel +- `temp_schema` (`default=none`) + - For incremental models, it allows to define a schema to hold temporary create statements + used in incremental model runs + - Schema will be created in the model target database if does not exist +- `lf_tags_config` (`default=none`) + - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns + - `enabled` (`default=False`) whether LF tags management is enabled for a model + - `tags` dictionary with tags and their values to assign for the model + - `tags_columns` dictionary with a tag key, value and list of columns they must be assigned to + - `lf_inherited_tags` (`default=none`) + - List of Lake Formation tag keys that are intended to be inherited from the database level and thus shouldn't be + removed during association of those defined in `lf_tags_config` + - i.e., the default behavior of `lf_tags_config` is to be exhaustive and first remove any pre-existing tags from + tables and columns before associating the ones currently defined for a given model + - This breaks tag inheritance as inherited tags appear on tables and columns like those associated directly + +```sql +{{ + config( + materialized='incremental', + incremental_strategy='append', + on_schema_change='append_new_columns', + table_type='iceberg', + schema='test_schema', + lf_tags_config={ + 'enabled': true, + 'tags': { + 'tag1': 'value1', + 'tag2': 'value2' + }, + 'tags_columns': { + 'tag1': { + 'value1': ['column1', 'column2'], + 'value2': ['column3', 'column4'] + } + }, + 'inherited_tags': ['tag1', 'tag2'] + } + ) +}} +``` + +- Format for `dbt_project.yml`: + +```yaml + +lf_tags_config: + enabled: true + tags: + tag1: value1 + tag2: value2 + tags_columns: + tag1: + value1: [ column1, column2 ] + inherited_tags: [ tag1, tag2 ] +``` + +- `lf_grants` (`default=none`) + - Lake Formation grants config for data_cell filters + - Format: + + ```python + lf_grants={ + 'data_cell_filters': { + 'enabled': True | False, + 'filters': { + 'filter_name': { + 'row_filter': '+ +
+ + + +- [Features](#features) + - [Quick start](#quick-start) + - [Installation](#installation) + - [Prerequisites](#prerequisites) + - [Credentials](#credentials) + - [Configuring your profile](#configuring-your-profile) + - [Additional information](#additional-information) + - [Models](#models) + - [Table configuration](#table-configuration) + - [Table location](#table-location) + - [Incremental models](#incremental-models) + - [On schema change](#on-schema-change) + - [Iceberg](#iceberg) + - [Highly available table (HA)](#highly-available-table-ha) + - [HA known issues](#ha-known-issues) + - [Update glue data catalog](#update-glue-data-catalog) + - [Snapshots](#snapshots) + - [Timestamp strategy](#timestamp-strategy) + - [Check strategy](#check-strategy) + - [Hard-deletes](#hard-deletes) + - [Working example](#working-example) + - [Snapshots known issues](#snapshots-known-issues) + - [AWS Lake Formation integration](#aws-lake-formation-integration) + - [Python models](#python-models) + - [Contracts](#contracts) + - [Contributing](#contributing) + - [Contributors ✨](#contributors-) + + +# Features + +- Supports dbt version `1.7.*` +- Support for Python +- Supports [seeds][seeds] +- Correctly detects views and their columns +- Supports [table materialization][table] + - [Iceberg tables][athena-iceberg] are supported **only with Athena Engine v3** and **a unique table location** + (see table location section below) + - Hive tables are supported by both Athena engines +- Supports [incremental models][incremental] + - On Iceberg tables: + - Supports the use of `unique_key` only with the `merge` strategy + - Supports the `append` strategy + - On Hive tables: + - Supports two incremental update strategies: `insert_overwrite` and `append` + - Does **not** support the use of `unique_key` +- Supports [snapshots][snapshots] +- Supports [Python models][python-models] + +[seeds]: https://docs.getdbt.com/docs/building-a-dbt-project/seeds + +[incremental]: https://docs.getdbt.com/docs/build/incremental-models + +[table]: https://docs.getdbt.com/docs/build/materializations#table + +[python-models]: https://docs.getdbt.com/docs/build/python-models#configuring-python-models + +[athena-iceberg]: https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html + +[snapshots]: https://docs.getdbt.com/docs/build/snapshots + +## Quick start + +### Installation + +- `pip install dbt-athena-community` +- Or `pip install git+https://github.com/dbt-athena/dbt-athena.git` + +### Prerequisites + +To start, you will need an S3 bucket, for instance `my-bucket` and an Athena database: + +```sql +CREATE DATABASE IF NOT EXISTS analytics_dev +COMMENT 'Analytics models generated by dbt (development)' +LOCATION 's3://my-bucket/' +WITH DBPROPERTIES ('creator'='Foo Bar', 'email'='foo@bar.com'); +``` + +Notes: + +- Take note of your AWS region code (e.g. `us-west-2` or `eu-west-2`, etc.). +- You can also use [AWS Glue](https://docs.aws.amazon.com/athena/latest/ug/glue-athena.html) to create and manage Athena + databases. + +### Credentials + +Credentials can be passed directly to the adapter, or they can +be [determined automatically](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) based +on `aws cli`/`boto3` conventions. +You can either: + +- Configure `aws_access_key_id` and `aws_secret_access_key` +- Configure `aws_profile_name` to match a profile defined in your AWS credentials file. + Checkout dbt profile configuration below for details. 
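The `CREATE DATABASE` statement shown in the prerequisites above can also be submitted programmatically. Below is a hedged sketch using boto3's Athena client; it is an illustrative aside rather than part of the adapter code in this diff, and the bucket, result location and region are placeholders.

```python
# Illustrative sketch only: running the quick-start CREATE DATABASE DDL
# through boto3's Athena client and polling until it finishes (the same idea
# as the adapter's poll_interval setting).
import time

import boto3

athena = boto3.client("athena", region_name="eu-west-2")  # placeholder region

ddl = """
CREATE DATABASE IF NOT EXISTS analytics_dev
COMMENT 'Analytics models generated by dbt (development)'
LOCATION 's3://my-bucket/'
WITH DBPROPERTIES ('creator'='Foo Bar', 'email'='foo@bar.com')
"""

query_id = athena.start_query_execution(
    QueryString=ddl,
    ResultConfiguration={"OutputLocation": "s3://my-bucket/athena-results/"},
)["QueryExecutionId"]

# Poll the query state until the DDL reaches a terminal state.
while True:
    state = athena.get_query_execution(QueryExecutionId=query_id)[
        "QueryExecution"
    ]["Status"]["State"]
    if state in ("SUCCEEDED", "FAILED", "CANCELLED"):
        break
    time.sleep(2)

print(f"CREATE DATABASE finished with state: {state}")
```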
+ +### Configuring your profile + +A dbt profile can be configured to run against AWS Athena using the following configuration: + +| Option | Description | Required? | Example | +|-----------------------|------------------------------------------------------------------------------------------|-----------|--------------------------------------------| +| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` | +| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` | +| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` | +| s3_tmp_table_dir | Prefix for storing temporary tables, if different from the connection's `s3_data_dir` | Optional | `s3://bucket3/dbt/` | +| region_name | AWS region of your Athena instance | Required | `eu-west-1` | +| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` | +| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` | +| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` | +| debug_query_state | Flag if debug message with Athena query state is needed | Optional | `false` | +| aws_access_key_id | Access key ID of the user performing requests | Optional | `AKIAIOSFODNN7EXAMPLE` | +| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` | +| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` | +| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` | +| skip_workgroup_check | Indicates if the WorkGroup check (additional AWS call) can be skipped | Optional | `true` | +| num_retries | Number of times to retry a failing query | Optional | `3` | +| num_boto3_retries | Number of times to retry boto3 requests (e.g. 
deleting S3 files for materialized tables) | Optional | `5` | +| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` | +| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` | +| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` | +| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` | + +**Example profiles.yml entry:** + +```yaml +athena: + target: dev + outputs: + dev: + type: athena + s3_staging_dir: s3://athena-query-results/dbt/ + s3_data_dir: s3://your_s3_bucket/dbt/ + s3_data_naming: schema_table + s3_tmp_table_dir: s3://your_s3_bucket/temp/ + region_name: eu-west-1 + schema: dbt + database: awsdatacatalog + threads: 4 + aws_profile_name: my-profile + work_group: my-workgroup + spark_work_group: my-spark-workgroup + seed_s3_upload_args: + ACL: bucket-owner-full-control +``` + +### Additional information + +- `threads` is supported +- `database` and `catalog` can be used interchangeably + +## Models + +### Table configuration + +- `external_location` (`default=none`) + - If set, the full S3 path to which the table will be saved + - Works only with incremental models + - Does not work with Hive table with `ha` set to true +- `partitioned_by` (`default=none`) + - An array list of columns by which the table will be partitioned + - Limited to creation of 100 partitions (*currently*) +- `bucketed_by` (`default=none`) + - An array list of columns to bucket data, ignored if using Iceberg +- `bucket_count` (`default=none`) + - The number of buckets for bucketing your data, ignored if using Iceberg +- `table_type` (`default='hive'`) + - The type of table + - Supports `hive` or `iceberg` +- `ha` (`default=false`) + - If the table should be built using the high-availability method. This option is only available for Hive tables + since it is by default for Iceberg tables (see the section [below](#highly-available-table-ha)) +- `format` (`default='parquet'`) + - The data format for the table + - Supports `ORC`, `PARQUET`, `AVRO`, `JSON`, `TEXTFILE` +- `write_compression` (`default=none`) + - The compression type to use for any storage format that allows compression to be specified. To see which options are + available, check out [CREATE TABLE AS][create-table-as] +- `field_delimiter` (`default=none`) + - Custom field delimiter, for when format is set to `TEXTFILE` +- `table_properties`: table properties to add to the table, valid for Iceberg only +- `native_drop`: Relation drop operations will be performed with SQL, not direct Glue API calls. No S3 calls will be + made to manage data in S3. Data in S3 will only be cleared up for Iceberg + tables [see AWS docs](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-managing-tables.html). Note that + Iceberg DROP TABLE operations may timeout if they take longer than 60 seconds. +- `seed_by_insert` (`default=false`) + - Default behaviour uploads seed data to S3. 
This flag will create seeds using an SQL insert statement + - Large seed files cannot use `seed_by_insert`, as the SQL insert statement would + exceed [the Athena limit of 262144 bytes](https://docs.aws.amazon.com/athena/latest/ug/service-limits.html) +- `force_batch` (`default=false`) + - Skip creating the table as CTAS and run the operation directly in batch insert mode + - This is particularly useful when the standard table creation process fails due to partition limitations, + allowing you to work with temporary tables and persist the dataset more efficiently +- `unique_tmp_table_suffix` (`default=false`) + - For incremental models using insert overwrite strategy on hive table + - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid + - Useful if you are looking to run multiple dbt build inserting in the same table in parallel +- `temp_schema` (`default=none`) + - For incremental models, it allows to define a schema to hold temporary create statements + used in incremental model runs + - Schema will be created in the model target database if does not exist +- `lf_tags_config` (`default=none`) + - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns + - `enabled` (`default=False`) whether LF tags management is enabled for a model + - `tags` dictionary with tags and their values to assign for the model + - `tags_columns` dictionary with a tag key, value and list of columns they must be assigned to + - `lf_inherited_tags` (`default=none`) + - List of Lake Formation tag keys that are intended to be inherited from the database level and thus shouldn't be + removed during association of those defined in `lf_tags_config` + - i.e., the default behavior of `lf_tags_config` is to be exhaustive and first remove any pre-existing tags from + tables and columns before associating the ones currently defined for a given model + - This breaks tag inheritance as inherited tags appear on tables and columns like those associated directly + +```sql +{{ + config( + materialized='incremental', + incremental_strategy='append', + on_schema_change='append_new_columns', + table_type='iceberg', + schema='test_schema', + lf_tags_config={ + 'enabled': true, + 'tags': { + 'tag1': 'value1', + 'tag2': 'value2' + }, + 'tags_columns': { + 'tag1': { + 'value1': ['column1', 'column2'], + 'value2': ['column3', 'column4'] + } + }, + 'inherited_tags': ['tag1', 'tag2'] + } + ) +}} +``` + +- Format for `dbt_project.yml`: + +```yaml + +lf_tags_config: + enabled: true + tags: + tag1: value1 + tag2: value2 + tags_columns: + tag1: + value1: [ column1, column2 ] + inherited_tags: [ tag1, tag2 ] +``` + +- `lf_grants` (`default=none`) + - Lake Formation grants config for data_cell filters + - Format: + + ```python + lf_grants={ + 'data_cell_filters': { + 'enabled': True | False, + 'filters': { + 'filter_name': { + 'row_filter': 'Tuple[str, str]: + """Formats a value based on its column type for inclusion in a SQL query.""" + comp_func = "=" # Default comparison function + if value is None: + return "null", " is " + elif column_type == "integer": + return str(value), comp_func + elif column_type == "string": + # Properly escape single quotes in the string value + escaped_value = str(value).replace("'", "''") + return f"'{escaped_value}'", comp_func + elif column_type == "date": + return f"DATE'{value}'", comp_func + elif column_type == "timestamp": + return f"TIMESTAMP'{value}'", comp_func + else: + # Raise an error for unsupported column types + raise 
ValueError(f"Unsupported column type: {column_type}") + + @available + def run_operation_with_potential_multiple_runs(self, query: str, op: str) -> None: + while True: + try: + self._run_query(query, catch_partitions_limit=False) + break + except OperationalError as e: + if f"ICEBERG_{op.upper()}_MORE_RUNS_NEEDED" not in str(e): + raise e + + def _run_query(self, sql: str, catch_partitions_limit: bool) -> AthenaCursor: + query = self.connections._add_query_comment(sql) + conn = self.connections.get_thread_connection() + cursor: AthenaCursor = conn.handle.cursor() + LOGGER.debug(f"Running Athena query:\n{query}") + try: + cursor.execute(query, catch_partitions_limit=catch_partitions_limit) + except OperationalError as e: + LOGGER.debug(f"CAUGHT EXCEPTION: {e}") + raise e + return cursor diff --git a/dbt-athena/src/dbt/adapters/athena/lakeformation.py b/dbt-athena/src/dbt/adapters/athena/lakeformation.py new file mode 100644 index 00000000..86b51d01 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/lakeformation.py @@ -0,0 +1,287 @@ +"""AWS Lakeformation permissions management helper utilities.""" + +from typing import Dict, List, Optional, Sequence, Set, Union + +from dbt_common.exceptions import DbtRuntimeError +from mypy_boto3_lakeformation import LakeFormationClient +from mypy_boto3_lakeformation.type_defs import ( + AddLFTagsToResourceResponseTypeDef, + BatchPermissionsRequestEntryTypeDef, + ColumnLFTagTypeDef, + DataCellsFilterTypeDef, + GetResourceLFTagsResponseTypeDef, + LFTagPairTypeDef, + RemoveLFTagsFromResourceResponseTypeDef, + ResourceTypeDef, +) +from pydantic import BaseModel + +from dbt.adapters.athena.relation import AthenaRelation +from dbt.adapters.events.logging import AdapterLogger + +logger = AdapterLogger("AthenaLakeFormation") + + +class LfTagsConfig(BaseModel): + enabled: bool = False + tags: Optional[Dict[str, str]] = None + tags_columns: Optional[Dict[str, Dict[str, List[str]]]] = None + inherited_tags: Optional[List[str]] = None + + +class LfTagsManager: + def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_tags_config: LfTagsConfig): + self.lf_client = lf_client + self.database = relation.schema + self.table = relation.identifier + self.lf_tags = lf_tags_config.tags + self.lf_tags_columns = lf_tags_config.tags_columns + self.lf_inherited_tags = set(lf_tags_config.inherited_tags) if lf_tags_config.inherited_tags else set() + + def process_lf_tags_database(self) -> None: + if self.lf_tags: + database_resource = {"Database": {"Name": self.database}} + response = self.lf_client.add_lf_tags_to_resource( + Resource=database_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags) + + def process_lf_tags(self) -> None: + table_resource = {"Table": {"DatabaseName": self.database, "Name": self.table}} + existing_lf_tags = self.lf_client.get_resource_lf_tags(Resource=table_resource) + self._remove_lf_tags_columns(existing_lf_tags) + self._apply_lf_tags_table(table_resource, existing_lf_tags) + self._apply_lf_tags_columns() + + @staticmethod + def _column_tags_to_remove( + lf_tags_columns: List[ColumnLFTagTypeDef], lf_inherited_tags: Set[str] + ) -> Dict[str, Dict[str, List[str]]]: + to_remove = {} + + for column in lf_tags_columns: + non_inherited_tags = [tag for tag in column["LFTags"] if not tag["TagKey"] in lf_inherited_tags] + for tag in non_inherited_tags: + tag_key = tag["TagKey"] + tag_value = tag["TagValues"][0] + if tag_key not in 
to_remove: + to_remove[tag_key] = {tag_value: [column["Name"]]} + elif tag_value not in to_remove[tag_key]: + to_remove[tag_key][tag_value] = [column["Name"]] + else: + to_remove[tag_key][tag_value].append(column["Name"]) + + return to_remove + + def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTypeDef) -> None: + lf_tags_columns = existing_lf_tags.get("LFTagsOnColumns", []) + logger.debug(f"COLUMNS: {lf_tags_columns}") + if lf_tags_columns: + to_remove = LfTagsManager._column_tags_to_remove(lf_tags_columns, self.lf_inherited_tags) + logger.debug(f"TO REMOVE: {to_remove}") + for tag_key, tag_config in to_remove.items(): + for tag_value, columns in tag_config.items(): + resource = { + "TableWithColumns": {"DatabaseName": self.database, "Name": self.table, "ColumnNames": columns} + } + response = self.lf_client.remove_lf_tags_from_resource( + Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}] + ) + self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove") + + @staticmethod + def _table_tags_to_remove( + lf_tags_table: List[LFTagPairTypeDef], lf_tags: Optional[Dict[str, str]], lf_inherited_tags: Set[str] + ) -> Dict[str, Sequence[str]]: + return { + tag["TagKey"]: tag["TagValues"] + for tag in lf_tags_table + if tag["TagKey"] not in (lf_tags or {}) + if tag["TagKey"] not in lf_inherited_tags + } + + def _apply_lf_tags_table( + self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef + ) -> None: + lf_tags_table = existing_lf_tags.get("LFTagsOnTable", []) + logger.debug(f"EXISTING TABLE TAGS: {lf_tags_table}") + logger.debug(f"CONFIG TAGS: {self.lf_tags}") + + to_remove = LfTagsManager._table_tags_to_remove(lf_tags_table, self.lf_tags, self.lf_inherited_tags) + + logger.debug(f"TAGS TO REMOVE: {to_remove}") + if to_remove: + response = self.lf_client.remove_lf_tags_from_resource( + Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": v} for k, v in to_remove.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags, "remove") + + if self.lf_tags: + response = self.lf_client.add_lf_tags_to_resource( + Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags) + + def _apply_lf_tags_columns(self) -> None: + if self.lf_tags_columns: + for tag_key, tag_config in self.lf_tags_columns.items(): + for tag_value, columns in tag_config.items(): + resource = { + "TableWithColumns": {"DatabaseName": self.database, "Name": self.table, "ColumnNames": columns} + } + response = self.lf_client.add_lf_tags_to_resource( + Resource=resource, + LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}], + ) + self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}) + + def _parse_and_log_lf_response( + self, + response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef], + columns: Optional[List[str]] = None, + lf_tags: Optional[Dict[str, str]] = None, + verb: str = "add", + ) -> None: + table_appendix = f".{self.table}" if self.table else "" + columns_appendix = f" for columns {columns}" if columns else "" + resource_msg = self.database + table_appendix + columns_appendix + if failures := response.get("Failures", []): + base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg + for failure in failures: + tag = failure.get("LFTag", {}).get("TagKey") + error = failure.get("Error", {}).get("ErrorMessage") + 
logger.error(f"Failed to {verb} {tag} for " + resource_msg + f" - {error}") + raise DbtRuntimeError(base_msg) + logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg) + + +class FilterConfig(BaseModel): + row_filter: str + column_names: List[str] = [] + principals: List[str] = [] + + def to_api_repr(self, catalog_id: str, database: str, table: str, name: str) -> DataCellsFilterTypeDef: + return { + "TableCatalogId": catalog_id, + "DatabaseName": database, + "TableName": table, + "Name": name, + "RowFilter": {"FilterExpression": self.row_filter}, + "ColumnNames": self.column_names, + "ColumnWildcard": {"ExcludedColumnNames": []}, + } + + def to_update(self, existing: DataCellsFilterTypeDef) -> bool: + return self.row_filter != existing["RowFilter"]["FilterExpression"] or set(self.column_names) != set( + existing["ColumnNames"] + ) + + +class DataCellFiltersConfig(BaseModel): + enabled: bool = False + filters: Dict[str, FilterConfig] + + +class LfGrantsConfig(BaseModel): + data_cell_filters: DataCellFiltersConfig + + +class LfPermissions: + def __init__(self, catalog_id: str, relation: AthenaRelation, lf_client: LakeFormationClient) -> None: + self.catalog_id = catalog_id + self.relation = relation + self.database: str = relation.schema + self.table: str = relation.identifier + self.lf_client = lf_client + + def get_filters(self) -> Dict[str, DataCellsFilterTypeDef]: + table_resource = {"CatalogId": self.catalog_id, "DatabaseName": self.database, "Name": self.table} + return {f["Name"]: f for f in self.lf_client.list_data_cells_filter(Table=table_resource)["DataCellsFilters"]} + + def process_filters(self, config: LfGrantsConfig) -> None: + current_filters = self.get_filters() + logger.debug(f"CURRENT FILTERS: {current_filters}") + + to_drop = [f for name, f in current_filters.items() if name not in config.data_cell_filters.filters] + logger.debug(f"FILTERS TO DROP: {to_drop}") + for f in to_drop: + self.lf_client.delete_data_cells_filter( + TableCatalogId=f["TableCatalogId"], + DatabaseName=f["DatabaseName"], + TableName=f["TableName"], + Name=f["Name"], + ) + + to_add = [ + f.to_api_repr(self.catalog_id, self.database, self.table, name) + for name, f in config.data_cell_filters.filters.items() + if name not in current_filters + ] + logger.debug(f"FILTERS TO ADD: {to_add}") + for f in to_add: + self.lf_client.create_data_cells_filter(TableData=f) + + to_update = [ + f.to_api_repr(self.catalog_id, self.database, self.table, name) + for name, f in config.data_cell_filters.filters.items() + if name in current_filters and f.to_update(current_filters[name]) + ] + logger.debug(f"FILTERS TO UPDATE: {to_update}") + for f in to_update: + self.lf_client.update_data_cells_filter(TableData=f) + + def process_permissions(self, config: LfGrantsConfig) -> None: + for name, f in config.data_cell_filters.filters.items(): + logger.debug(f"Start processing permissions for filter: {name}") + current_permissions = self.lf_client.list_permissions( + Resource={ + "DataCellsFilter": { + "TableCatalogId": self.catalog_id, + "DatabaseName": self.database, + "TableName": self.table, + "Name": name, + } + } + )["PrincipalResourcePermissions"] + + current_principals = {p["Principal"]["DataLakePrincipalIdentifier"] for p in current_permissions} + + to_revoke = {p for p in current_principals if p not in f.principals} + if to_revoke: + self.lf_client.batch_revoke_permissions( + CatalogId=self.catalog_id, + Entries=[self._permission_entry(name, principal, idx) for idx, principal in enumerate(to_revoke)], + ) 
+ revoke_principals_msg = "\n".join(to_revoke) + logger.debug(f"Revoked permissions for filter {name} from principals:\n{revoke_principals_msg}") + else: + logger.debug(f"No redundant permissions found for filter: {name}") + + to_add = {p for p in f.principals if p not in current_principals} + if to_add: + self.lf_client.batch_grant_permissions( + CatalogId=self.catalog_id, + Entries=[self._permission_entry(name, principal, idx) for idx, principal in enumerate(to_add)], + ) + add_principals_msg = "\n".join(to_add) + logger.debug(f"Granted permissions for filter {name} to principals:\n{add_principals_msg}") + else: + logger.debug(f"No new permissions added for filter {name}") + + logger.debug(f"Permissions are set to be consistent with config for filter: {name}") + + def _permission_entry(self, filter_name: str, principal: str, idx: int) -> BatchPermissionsRequestEntryTypeDef: + return { + "Id": str(idx), + "Principal": {"DataLakePrincipalIdentifier": principal}, + "Resource": { + "DataCellsFilter": { + "TableCatalogId": self.catalog_id, + "DatabaseName": self.database, + "TableName": self.table, + "Name": filter_name, + } + }, + "Permissions": ["SELECT"], + "PermissionsWithGrantOption": [], + } diff --git a/dbt-athena/src/dbt/adapters/athena/python_submissions.py b/dbt-athena/src/dbt/adapters/athena/python_submissions.py new file mode 100644 index 00000000..5a3799ec --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/python_submissions.py @@ -0,0 +1,256 @@ +import time +from functools import cached_property +from typing import Any, Dict + +import botocore +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena.config import AthenaSparkSessionConfig +from dbt.adapters.athena.connections import AthenaCredentials +from dbt.adapters.athena.constants import LOGGER +from dbt.adapters.athena.session import AthenaSparkSessionManager +from dbt.adapters.base import PythonJobHelper + +SUBMISSION_LANGUAGE = "python" + + +class AthenaPythonJobHelper(PythonJobHelper): + """ + Default helper to execute python models with Athena Spark. + + Args: + PythonJobHelper (PythonJobHelper): The base python helper class + """ + + def __init__(self, parsed_model: Dict[Any, Any], credentials: AthenaCredentials) -> None: + """ + Initialize spark config and connection. + + Args: + parsed_model (Dict[Any, Any]): The parsed python model. + credentials (AthenaCredentials): Credentials for Athena connection. + """ + self.relation_name = parsed_model.get("relation_name", None) + self.config = AthenaSparkSessionConfig( + parsed_model.get("config", {}), + polling_interval=credentials.poll_interval, + retry_attempts=credentials.num_retries, + ) + self.spark_connection = AthenaSparkSessionManager( + credentials, self.timeout, self.polling_interval, self.engine_config, self.relation_name + ) + + @cached_property + def timeout(self) -> int: + """ + Get the timeout value. + + Returns: + int: The timeout value in seconds. + """ + return self.config.set_timeout() + + @cached_property + def session_id(self) -> str: + """ + Get the session ID. + + Returns: + str: The session ID as a string. + """ + return str(self.spark_connection.get_session_id()) + + @cached_property + def polling_interval(self) -> float: + """ + Get the polling interval. + + Returns: + float: The polling interval in seconds. + """ + return self.config.set_polling_interval() + + @cached_property + def engine_config(self) -> Dict[str, int]: + """ + Get the engine configuration. 
+ + Returns: + Dict[str, int]: A dictionary containing the engine configuration. + """ + return self.config.set_engine_config() + + @cached_property + def athena_client(self) -> Any: + """ + Get the Athena client. + + Returns: + Any: The Athena client object. + """ + return self.spark_connection.athena_client + + def get_current_session_status(self) -> Any: + """ + Get the current session status. + + Returns: + Any: The status of the session + """ + return self.spark_connection.get_session_status(self.session_id) + + def submit(self, compiled_code: str) -> Any: + """ + Submit a calculation to Athena. + + This function submits a calculation to Athena for execution using the provided compiled code. + It starts a calculation execution with the current session ID and the compiled code as the code block. + The function then polls until the calculation execution is completed, and retrieves the result. + If the execution is successful and completed, the result S3 URI is returned. Otherwise, a DbtRuntimeError + is raised with the execution status. + + Args: + compiled_code (str): The compiled code to submit for execution. + + Returns: + dict: The result S3 URI if the execution is successful and completed. + + Raises: + DbtRuntimeError: If the execution ends in a state other than "COMPLETED". + + """ + # Seeing an empty calculation along with main python model code calculation is submitted for almost every model + # Also, if not returning the result json, we are getting green ERROR messages instead of OK messages. + # And with this handling, the run model code in target folder every model under run folder seems to be empty + # Need to fix this work around solution + if compiled_code.strip(): + while True: + try: + LOGGER.debug( + f"Model {self.relation_name} - Using session: {self.session_id} to start calculation execution." + ) + calculation_execution_id = self.athena_client.start_calculation_execution( + SessionId=self.session_id, CodeBlock=compiled_code.lstrip() + )["CalculationExecutionId"] + break + except botocore.exceptions.ClientError as ce: + LOGGER.exception(f"Encountered client error: {ce}") + if ( + ce.response["Error"]["Code"] == "InvalidRequestException" + and "Session is in the BUSY state; needs to be IDLE to accept Calculations." + in ce.response["Error"]["Message"] + ): + LOGGER.exception("Going to poll until session is IDLE") + self.poll_until_session_idle() + except Exception as e: + raise DbtRuntimeError(f"Unable to start spark python code execution. Got: {e}") + execution_status = self.poll_until_execution_completion(calculation_execution_id) + LOGGER.debug(f"Model {self.relation_name} - Received execution status {execution_status}") + if execution_status == "COMPLETED": + try: + result = self.athena_client.get_calculation_execution( + CalculationExecutionId=calculation_execution_id + )["Result"] + except Exception as e: + LOGGER.error(f"Unable to retrieve results: Got: {e}") + result = {} + return result + else: + return {"ResultS3Uri": "string", "ResultType": "string", "StdErrorS3Uri": "string", "StdOutS3Uri": "string"} + + def poll_until_session_idle(self) -> None: + """ + Polls the session status until it becomes idle or exceeds the timeout. + + Raises: + DbtRuntimeError: If the session chosen is not available or if it does not become idle within the timeout. 
+ """ + polling_interval = self.polling_interval + timer: float = 0 + while True: + session_status = self.get_current_session_status()["State"] + if session_status in ["TERMINATING", "TERMINATED", "DEGRADED", "FAILED"]: + LOGGER.debug( + f"Model {self.relation_name} - The session: {self.session_id} was not available. " + f"Got status: {session_status}. Will try with a different session." + ) + self.spark_connection.remove_terminated_session(self.session_id) + if "session_id" in self.__dict__: + del self.__dict__["session_id"] + break + if session_status == "IDLE": + break + time.sleep(polling_interval) + timer += polling_interval + if timer > self.timeout: + LOGGER.debug( + f"Model {self.relation_name} - Session {self.session_id} did not become free within {self.timeout}" + " seconds. Will try with a different session." + ) + if "session_id" in self.__dict__: + del self.__dict__["session_id"] + break + + def poll_until_execution_completion(self, calculation_execution_id: str) -> Any: + """ + Poll the status of a calculation execution until it is completed, failed, or canceled. + + This function polls the status of a calculation execution identified by the given `calculation_execution_id` + until it is completed, failed, or canceled. It uses the Athena client to retrieve the status of the execution + and checks if the state is one of "COMPLETED", "FAILED", or "CANCELED". If the execution is not yet completed, + the function sleeps for a certain polling interval, which starts with the value of `self.polling_interval` and + doubles after each iteration until it reaches the `self.timeout` period. If the execution does not complete + within the timeout period, a `DbtRuntimeError` is raised. + + Args: + calculation_execution_id (str): The ID of the calculation execution to poll. + + Returns: + str: The final state of the calculation execution, which can be one of "COMPLETED", "FAILED" or "CANCELED". + + Raises: + DbtRuntimeError: If the calculation execution does not complete within the timeout period. + + """ + try: + polling_interval = self.polling_interval + timer: float = 0 + while True: + execution_response = self.athena_client.get_calculation_execution( + CalculationExecutionId=calculation_execution_id + ) + execution_session = execution_response.get("SessionId", None) + execution_status = execution_response.get("Status", None) + execution_result = execution_response.get("Result", None) + execution_stderr_s3_path = "" + if execution_result: + execution_stderr_s3_path = execution_result.get("StdErrorS3Uri", None) + + execution_status_state = "" + execution_status_reason = "" + if execution_status: + execution_status_state = execution_status.get("State", None) + execution_status_reason = execution_status.get("StateChangeReason", None) + + if execution_status_state in ["FAILED", "CANCELED"]: + raise DbtRuntimeError( + f"""Calculation Id: {calculation_execution_id} +Session Id: {execution_session} +Status: {execution_status_state} +Reason: {execution_status_reason} +Stderr s3 path: {execution_stderr_s3_path} +""" + ) + + if execution_status_state == "COMPLETED": + return execution_status_state + + time.sleep(polling_interval) + timer += polling_interval + if timer > self.timeout: + self.athena_client.stop_calculation_execution(CalculationExecutionId=calculation_execution_id) + raise DbtRuntimeError( + f"Execution {calculation_execution_id} did not complete within {self.timeout} seconds." 
+ ) + finally: + self.spark_connection.set_spark_session_load(self.session_id, -1) diff --git a/dbt-athena/src/dbt/adapters/athena/query_headers.py b/dbt-athena/src/dbt/adapters/athena/query_headers.py new file mode 100644 index 00000000..5220299a --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/query_headers.py @@ -0,0 +1,43 @@ +from typing import Any, Dict + +from dbt.adapters.base.query_headers import MacroQueryStringSetter, _QueryComment +from dbt.adapters.contracts.connection import AdapterRequiredConfig + + +class AthenaMacroQueryStringSetter(MacroQueryStringSetter): + def __init__(self, config: AdapterRequiredConfig, query_header_context: Dict[str, Any]): + super().__init__(config, query_header_context) + self.comment = _AthenaQueryComment(None) + + +class _AthenaQueryComment(_QueryComment): + """ + Athena DDL does not always respect /* ... */ block quotations. + This function is the same as _QueryComment.add except that + a leading "-- " is prepended to the query_comment and any newlines + in the query_comment are replaced with " ". This allows the default + query_comment to be added to `create external table` statements. + """ + + def add(self, sql: str) -> str: + if not self.query_comment: + return sql + + # alter or vacuum statements don't seem to support properly query comments + # let's just exclude them + sql = sql.lstrip() + if any(sql.lower().startswith(keyword) for keyword in ["alter", "drop", "optimize", "vacuum", "msck"]): + return sql + + cleaned_query_comment = self.query_comment.strip().replace("\n", " ") + + if self.append: + # replace last ';' with ';' + sql = sql.rstrip() + if sql[-1] == ";": + sql = sql[:-1] + return f"{sql}\n-- /* {cleaned_query_comment} */;" + + return f"{sql}\n-- /* {cleaned_query_comment} */" + + return f"-- /* {cleaned_query_comment} */\n{sql}" diff --git a/dbt-athena/src/dbt/adapters/athena/relation.py b/dbt-athena/src/dbt/adapters/athena/relation.py new file mode 100644 index 00000000..0de8d784 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/relation.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Optional, Set + +from mypy_boto3_glue.type_defs import TableTypeDef + +from dbt.adapters.athena.constants import LOGGER +from dbt.adapters.base.relation import BaseRelation, InformationSchema, Policy + + +class TableType(Enum): + TABLE = "table" + VIEW = "view" + CTE = "cte" + MATERIALIZED_VIEW = "materializedview" + ICEBERG = "iceberg_table" + + def is_physical(self) -> bool: + return self in [TableType.TABLE, TableType.ICEBERG] + + +@dataclass +class AthenaIncludePolicy(Policy): + database: bool = True + schema: bool = True + identifier: bool = True + + +@dataclass +class AthenaHiveIncludePolicy(Policy): + database: bool = False + schema: bool = True + identifier: bool = True + + +@dataclass(frozen=True, eq=False, repr=False) +class AthenaRelation(BaseRelation): + quote_character: str = '"' # Presto quote character + include_policy: Policy = field(default_factory=lambda: AthenaIncludePolicy()) + s3_path_table_part: Optional[str] = None + detailed_table_type: Optional[str] = None # table_type option from the table Parameters in Glue Catalog + require_alias: bool = False + + def render_hive(self) -> str: + """ + Render relation with Hive format. Athena uses a Hive format for some DDL statements. + + See: + - https://aws.amazon.com/athena/faqs/ "Q: How do I create tables and schemas for my data on Amazon S3?" 
+ - https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + """ + + old_quote_character = self.quote_character + object.__setattr__(self, "quote_character", "`") # Hive quote char + old_include_policy = self.include_policy + object.__setattr__(self, "include_policy", AthenaHiveIncludePolicy()) + rendered = self.render() + object.__setattr__(self, "quote_character", old_quote_character) + object.__setattr__(self, "include_policy", old_include_policy) + return str(rendered) + + def render_pure(self) -> str: + """ + Render relation without quotes characters. + This is needed for not standard executions like optimize and vacuum + """ + old_value = self.quote_character + object.__setattr__(self, "quote_character", "") + rendered = self.render() + object.__setattr__(self, "quote_character", old_value) + return str(rendered) + + +class AthenaSchemaSearchMap(Dict[InformationSchema, Dict[str, Set[Optional[str]]]]): + """A utility class to keep track of what information_schema tables to + search for what schemas and relations. The schema and relation values are all + lowercase to avoid duplication. + """ + + def add(self, relation: AthenaRelation) -> None: + key = relation.information_schema_only() + if key not in self: + self[key] = {} + if relation.schema is not None: + schema = relation.schema.lower() + relation_name = relation.name.lower() + if schema not in self[key]: + self[key][schema] = set() + self[key][schema].add(relation_name) + + +RELATION_TYPE_MAP = { + "EXTERNAL_TABLE": TableType.TABLE, + "EXTERNAL": TableType.TABLE, # type returned by federated query tables + "GOVERNED": TableType.TABLE, + "MANAGED_TABLE": TableType.TABLE, + "VIRTUAL_VIEW": TableType.VIEW, + "table": TableType.TABLE, + "view": TableType.VIEW, + "cte": TableType.CTE, + "materializedview": TableType.MATERIALIZED_VIEW, +} + + +def get_table_type(table: TableTypeDef) -> TableType: + table_full_name = ".".join(filter(None, [table.get("CatalogId"), table.get("DatabaseName"), table["Name"]])) + + input_table_type = table.get("TableType") + if input_table_type and input_table_type not in RELATION_TYPE_MAP: + raise ValueError(f"Table type {table['TableType']} is not supported for table {table_full_name}") + + if table.get("Parameters", {}).get("table_type", "").lower() == "iceberg": + _type = TableType.ICEBERG + elif not input_table_type: + raise ValueError(f"Table type cannot be None for table {table_full_name}") + else: + _type = RELATION_TYPE_MAP[input_table_type] + + LOGGER.debug(f"table_name : {table_full_name}") + LOGGER.debug(f"table type : {_type}") + + return _type diff --git a/dbt-athena/src/dbt/adapters/athena/s3.py b/dbt-athena/src/dbt/adapters/athena/s3.py new file mode 100644 index 00000000..94e1270f --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/s3.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class S3DataNaming(Enum): + UNIQUE = "unique" + TABLE = "table" + TABLE_UNIQUE = "table_unique" + SCHEMA_TABLE = "schema_table" + SCHEMA_TABLE_UNIQUE = "schema_table_unique" diff --git a/dbt-athena/src/dbt/adapters/athena/session.py b/dbt-athena/src/dbt/adapters/athena/session.py new file mode 100644 index 00000000..b346d13e --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/session.py @@ -0,0 +1,263 @@ +import json +import threading +import time +from functools import cached_property +from hashlib import md5 +from typing import Any, Dict +from uuid import UUID + +import boto3 +import boto3.session +from dbt_common.exceptions import DbtRuntimeError +from dbt_common.invocation import 
get_invocation_id + +from dbt.adapters.athena.config import get_boto3_config +from dbt.adapters.athena.constants import ( + DEFAULT_THREAD_COUNT, + LOGGER, + SESSION_IDLE_TIMEOUT_MIN, +) +from dbt.adapters.contracts.connection import Connection + +invocation_id = get_invocation_id() +spark_session_list: Dict[UUID, str] = {} +spark_session_load: Dict[UUID, int] = {} + + +def get_boto3_session(connection: Connection) -> boto3.session.Session: + return boto3.session.Session( + aws_access_key_id=connection.credentials.aws_access_key_id, + aws_secret_access_key=connection.credentials.aws_secret_access_key, + aws_session_token=connection.credentials.aws_session_token, + region_name=connection.credentials.region_name, + profile_name=connection.credentials.aws_profile_name, + ) + + +def get_boto3_session_from_credentials(credentials: Any) -> boto3.session.Session: + return boto3.session.Session( + aws_access_key_id=credentials.aws_access_key_id, + aws_secret_access_key=credentials.aws_secret_access_key, + aws_session_token=credentials.aws_session_token, + region_name=credentials.region_name, + profile_name=credentials.aws_profile_name, + ) + + +class AthenaSparkSessionManager: + """ + A helper class to manage Athena Spark Sessions. + """ + + def __init__( + self, + credentials: Any, + timeout: int, + polling_interval: float, + engine_config: Dict[str, int], + relation_name: str = "N/A", + ) -> None: + """ + Initialize the AthenaSparkSessionManager instance. + + Args: + credentials (Any): The credentials to be used. + timeout (int): The timeout value in seconds. + polling_interval (float): The polling interval in seconds. + engine_config (Dict[str, int]): The engine configuration. + + """ + self.credentials = credentials + self.timeout = timeout + self.polling_interval = polling_interval + self.engine_config = engine_config + self.lock = threading.Lock() + self.relation_name = relation_name + + @cached_property + def spark_threads(self) -> int: + """ + Get the number of Spark threads. + + Returns: + int: The number of Spark threads. If not found in the profile, returns the default thread count. + """ + if not DEFAULT_THREAD_COUNT: + LOGGER.debug(f"""Threads not found in profile. Got: {DEFAULT_THREAD_COUNT}""") + return 1 + return int(DEFAULT_THREAD_COUNT) + + @cached_property + def spark_work_group(self) -> str: + """ + Get the Spark work group. + + Returns: + str: The Spark work group. Raises an exception if not found in the profile. + """ + if not self.credentials.spark_work_group: + raise DbtRuntimeError(f"Expected spark_work_group in profile. Got: {self.credentials.spark_work_group}") + return str(self.credentials.spark_work_group) + + @cached_property + def athena_client(self) -> Any: + """ + Get the AWS Athena client. + + This function returns an AWS Athena client object that can be used to interact with the Athena service. + The client is created using the region name and profile name provided during object instantiation. + + Returns: + Any: The Athena client object. 
+ + """ + return get_boto3_session_from_credentials(self.credentials).client( + "athena", config=get_boto3_config(num_retries=self.credentials.effective_num_retries) + ) + + @cached_property + def session_description(self) -> str: + """ + Converts the engine configuration to md5 hash value + + Returns: + str: A concatenated text of dbt invocation_id and engine configuration's md5 hash + """ + hash_desc = md5(json.dumps(self.engine_config, sort_keys=True, ensure_ascii=True).encode("utf-8")).hexdigest() + return f"dbt: {invocation_id} - {hash_desc}" + + def get_session_id(self, session_query_capacity: int = 1) -> UUID: + """ + Get a session ID for the Spark session. + When does a new session get created: + - When thread limit not reached + - When thread limit reached but same engine configuration session is not available + - When thread limit reached and same engine configuration session exist and it is busy running a python model + and has one python model in queue (determined by session_query_capacity). + + Returns: + UUID: The session ID. + """ + session_list = list(spark_session_list.items()) + + if len(session_list) < self.spark_threads: + LOGGER.debug( + f"Within thread limit, creating new session for model: {self.relation_name}" + f" with session description: {self.session_description}." + ) + return self.start_session() + else: + matching_session_id = next( + ( + session_id + for session_id, description in session_list + if description == self.session_description + and spark_session_load.get(session_id, 0) <= session_query_capacity + ), + None, + ) + if matching_session_id: + LOGGER.debug( + f"Over thread limit, matching session found for model: {self.relation_name}" + f" with session description: {self.session_description} and has capacity." + ) + self.set_spark_session_load(str(matching_session_id), 1) + return matching_session_id + else: + LOGGER.debug( + f"Over thread limit, matching session not found or found with over capacity. Creating new session" + f" for model: {self.relation_name} with session description: {self.session_description}." + ) + return self.start_session() + + def start_session(self) -> UUID: + """ + Start an Athena session. + + This function sends a request to the Athena service to start a session in the specified Spark workgroup. + It configures the session with specific engine configurations. If the session state is not IDLE, the function + polls until the session creation is complete. The response containing session information is returned. + + Returns: + dict: The session information dictionary. + + """ + description = self.session_description + response = self.athena_client.start_session( + Description=description, + WorkGroup=self.credentials.spark_work_group, + EngineConfiguration=self.engine_config, + SessionIdleTimeoutInMinutes=SESSION_IDLE_TIMEOUT_MIN, + ) + session_id = response["SessionId"] + if response["State"] != "IDLE": + self.poll_until_session_creation(session_id) + + with self.lock: + spark_session_list[UUID(session_id)] = self.session_description + spark_session_load[UUID(session_id)] = 1 + + return UUID(session_id) + + def poll_until_session_creation(self, session_id: str) -> None: + """ + Polls the status of an Athena session creation until it is completed or reaches the timeout. + + Args: + session_id (str): The ID of the session being created. + + Returns: + str: The final status of the session, which will be "IDLE" if the session creation is successful. 
+ + Raises: + DbtRuntimeError: If the session creation fails, is terminated, or degrades during polling. + DbtRuntimeError: If the session does not become IDLE within the specified timeout. + + """ + polling_interval = self.polling_interval + timer: float = 0 + while True: + creation_status_response = self.get_session_status(session_id) + creation_status_state = creation_status_response.get("State", "") + creation_status_reason = creation_status_response.get("StateChangeReason", "") + if creation_status_state in ["FAILED", "TERMINATED", "DEGRADED"]: + raise DbtRuntimeError( + f"Unable to create session: {session_id}. Got status: {creation_status_state}" + f" with reason: {creation_status_reason}." + ) + elif creation_status_state == "IDLE": + LOGGER.debug(f"Session: {session_id} created") + break + time.sleep(polling_interval) + timer += polling_interval + if timer > self.timeout: + self.remove_terminated_session(session_id) + raise DbtRuntimeError(f"Session {session_id} did not create within {self.timeout} seconds.") + + def get_session_status(self, session_id: str) -> Any: + """ + Get the session status. + + Returns: + Any: The status of the session + """ + return self.athena_client.get_session_status(SessionId=session_id)["Status"] + + def remove_terminated_session(self, session_id: str) -> None: + """ + Removes session uuid from session list variable + + Returns: None + """ + with self.lock: + spark_session_list.pop(UUID(session_id), "Session id not found") + spark_session_load.pop(UUID(session_id), "Session id not found") + + def set_spark_session_load(self, session_id: str, change: int) -> None: + """ + Increase or decrease the session load variable + + Returns: None + """ + with self.lock: + spark_session_load[UUID(session_id)] = spark_session_load.get(UUID(session_id), 0) + change diff --git a/dbt-athena/src/dbt/adapters/athena/utils.py b/dbt-athena/src/dbt/adapters/athena/utils.py new file mode 100644 index 00000000..39ec755b --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/utils.py @@ -0,0 +1,63 @@ +import json +import re +from enum import Enum +from typing import Any, Generator, List, Optional, TypeVar + +from mypy_boto3_athena.type_defs import DataCatalogTypeDef + +from dbt.adapters.athena.constants import LOGGER + + +def clean_sql_comment(comment: str) -> str: + split_and_strip = [line.strip() for line in comment.split("\n")] + return " ".join(line for line in split_and_strip if line) + + +def stringify_table_parameter_value(value: Any) -> Optional[str]: + """Convert any variable to string for Glue Table property.""" + try: + if isinstance(value, (dict, list)): + value_str: str = json.dumps(value) + else: + value_str = str(value) + return value_str[:512000] + except (TypeError, ValueError) as e: + # Handle non-stringifiable objects and non-serializable objects + LOGGER.warning(f"Non-stringifiable object. 
Error: {str(e)}") + return None + + +def is_valid_table_parameter_key(key: str) -> bool: + """Check if key is valid for Glue Table property according to official documentation.""" + # Simplified version of key pattern which works with re + # Original pattern can be found here https://docs.aws.amazon.com/glue/latest/webapi/API_Table.html + key_pattern: str = r"^[\u0020-\uD7FF\uE000-\uFFFD\t]*$" + return len(key) <= 255 and bool(re.match(key_pattern, key)) + + +def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: + return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None + + +class AthenaCatalogType(Enum): + GLUE = "GLUE" + LAMBDA = "LAMBDA" + HIVE = "HIVE" + + +def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]: + return AthenaCatalogType(catalog["Type"]) if catalog else None + + +T = TypeVar("T") + + +def get_chunks(lst: List[T], n: int) -> Generator[List[T], None, None]: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def ellipsis_comment(s: str, max_len: int = 255) -> str: + """Ellipsis string if it exceeds max length""" + return f"{s[:(max_len - 3)]}..." if len(s) > max_len else s diff --git a/dbt-athena/src/dbt/include/athena/__init__.py b/dbt-athena/src/dbt/include/athena/__init__.py new file mode 100644 index 00000000..b177e5d4 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/__init__.py @@ -0,0 +1,3 @@ +import os + +PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt-athena/src/dbt/include/athena/dbt_project.yml b/dbt-athena/src/dbt/include/athena/dbt_project.yml new file mode 100644 index 00000000..471a15ed --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/dbt_project.yml @@ -0,0 +1,5 @@ +name: dbt_athena +version: 1.0 +config-version: 2 + +macro-paths: ['macros'] diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/columns.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/columns.sql new file mode 100644 index 00000000..bb106d3b --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/columns.sql @@ -0,0 +1,20 @@ +{% macro athena__get_columns_in_relation(relation) -%} + {{ return(adapter.get_columns_in_relation(relation)) }} +{% endmacro %} + +{% macro athena__get_empty_schema_sql(columns) %} + {%- set col_err = [] -%} + select + {% for i in columns %} + {%- set col = columns[i] -%} + {%- if col['data_type'] is not defined -%} + {{ col_err.append(col['name']) }} + {%- else -%} + {% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %} + cast(null as {{ dml_data_type(col['data_type']) }}) as {{ col_name }}{{ ", " if not loop.last }} + {%- endif -%} + {%- endfor -%} + {%- if (col_err | length) > 0 -%} + {{ exceptions.column_type_missing(column_names=col_err) }} + {%- endif -%} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/metadata.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/metadata.sql new file mode 100644 index 00000000..ae7a187c --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/metadata.sql @@ -0,0 +1,17 @@ +{% macro athena__get_catalog(information_schema, schemas) -%} + {{ return(adapter.get_catalog()) }} +{%- endmacro %} + + +{% macro athena__list_schemas(database) -%} + {{ return(adapter.list_schemas(database)) }} +{% endmacro %} + + +{% macro athena__list_relations_without_caching(schema_relation) %} + {{ 
return(adapter.list_relations_without_caching(schema_relation)) }} +{% endmacro %} + +{% macro athena__get_catalog_relations(information_schema, relations) %} + {{ return(adapter.get_catalog_by_relations(information_schema, relations)) }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/persist_docs.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/persist_docs.sql new file mode 100644 index 00000000..7b95bd7b --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/persist_docs.sql @@ -0,0 +1,14 @@ +{% macro athena__persist_docs(relation, model, for_relation=true, for_columns=true) -%} + {% set persist_relation_docs = for_relation and config.persist_relation_docs()%} + {% set persist_column_docs = for_columns and config.persist_column_docs() and model.columns %} + {% if persist_relation_docs or persist_column_docs %} + {% do adapter.persist_docs_to_glue( + relation=relation, + model=model, + persist_relation_docs=persist_relation_docs, + persist_column_docs=persist_column_docs, + skip_archive_table_version=true + ) + %}} + {% endif %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/python_submissions.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/python_submissions.sql new file mode 100644 index 00000000..f668acf7 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/python_submissions.sql @@ -0,0 +1,95 @@ +{%- macro athena__py_save_table_as(compiled_code, target_relation, optional_args={}) -%} + {% set location = optional_args.get("location") %} + {% set format = optional_args.get("format", "parquet") %} + {% set mode = optional_args.get("mode", "overwrite") %} + {% set write_compression = optional_args.get("write_compression", "snappy") %} + {% set partitioned_by = optional_args.get("partitioned_by") %} + {% set bucketed_by = optional_args.get("bucketed_by") %} + {% set sorted_by = optional_args.get("sorted_by") %} + {% set merge_schema = optional_args.get("merge_schema", true) %} + {% set bucket_count = optional_args.get("bucket_count") %} + {% set field_delimiter = optional_args.get("field_delimiter") %} + {% set spark_ctas = optional_args.get("spark_ctas", "") %} + +import pyspark + + +{{ compiled_code }} +def materialize(spark_session, df, target_relation): + import pandas + if isinstance(df, pyspark.sql.dataframe.DataFrame): + pass + elif isinstance(df, pandas.core.frame.DataFrame): + df = spark_session.createDataFrame(df) + else: + msg = f"{type(df)} is not a supported type for dbt Python materialization" + raise Exception(msg) + +{% if spark_ctas|length > 0 %} + df.createOrReplaceTempView("{{ target_relation.schema}}_{{ target_relation.identifier }}_tmpvw") + spark_session.sql(""" + {{ spark_ctas }} + select * from {{ target_relation.schema}}_{{ target_relation.identifier }}_tmpvw + """) +{% else %} + writer = df.write \ + .format("{{ format }}") \ + .mode("{{ mode }}") \ + .option("path", "{{ location }}") \ + .option("compression", "{{ write_compression }}") \ + .option("mergeSchema", "{{ merge_schema }}") \ + .option("delimiter", "{{ field_delimiter }}") + if {{ partitioned_by }} is not None: + writer = writer.partitionBy({{ partitioned_by }}) + if {{ bucketed_by }} is not None: + writer = writer.bucketBy({{ bucket_count }},{{ bucketed_by }}) + if {{ sorted_by }} is not None: + writer = writer.sortBy({{ sorted_by }}) + + writer.saveAsTable( + name="{{ target_relation.schema}}.{{ target_relation.identifier }}", + ) +{% endif %} + + return "Success: {{ 
target_relation.schema}}.{{ target_relation.identifier }}" + +{{ athena__py_get_spark_dbt_object() }} + +dbt = SparkdbtObj() +df = model(dbt, spark) +materialize(spark, df, dbt.this) +{%- endmacro -%} + +{%- macro athena__py_execute_query(query) -%} +{{ athena__py_get_spark_dbt_object() }} + +def execute_query(spark_session): + spark_session.sql("""{{ query }}""") + return "OK" + +dbt = SparkdbtObj() +execute_query(spark) +{%- endmacro -%} + +{%- macro athena__py_get_spark_dbt_object() -%} +def get_spark_df(identifier): + """ + Override the arguments to ref and source dynamically + + spark.table('awsdatacatalog.analytics_dev.model') + Raises pyspark.sql.utils.AnalysisException: + spark_catalog requires a single-part namespace, + but got [awsdatacatalog, analytics_dev] + + So the override removes the catalog component and only + provides the schema and identifer to spark.table() + """ + return spark.table(".".join(identifier.split(".")[1:]).replace('"', '')) + +class SparkdbtObj(dbtObj): + def __init__(self): + super().__init__(load_df_function=get_spark_df) + self.source = lambda *args: source(*args, dbt_load_df_function=get_spark_df) + self.ref = lambda *args: ref(*args, dbt_load_df_function=get_spark_df) + +{%- endmacro -%} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/relation.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/relation.sql new file mode 100644 index 00000000..611ffc59 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/relation.sql @@ -0,0 +1,58 @@ +{% macro athena__drop_relation(relation) -%} + {%- set native_drop = config.get('native_drop', default=false) -%} + {%- set rel_type_object = adapter.get_glue_table_type(relation) -%} + {%- set rel_type = none if rel_type_object == none else rel_type_object.value -%} + {%- set natively_droppable = rel_type == 'iceberg_table' or relation.type == 'view' -%} + + {%- if native_drop and natively_droppable -%} + {%- do drop_relation_sql(relation) -%} + {%- else -%} + {%- do drop_relation_glue(relation) -%} + {%- endif -%} +{% endmacro %} + +{% macro drop_relation_glue(relation) -%} + {%- do log('Dropping relation via Glue and S3 APIs') -%} + {%- do adapter.clean_up_table(relation) -%} + {%- do adapter.delete_from_glue_catalog(relation) -%} +{% endmacro %} + +{% macro drop_relation_sql(relation) -%} + + {%- do log('Dropping relation via SQL only') -%} + {% call statement('drop_relation', auto_begin=False) -%} + {%- if relation.type == 'view' -%} + drop {{ relation.type }} if exists {{ relation.render() }} + {%- else -%} + drop {{ relation.type }} if exists {{ relation.render_hive() }} + {% endif %} + {%- endcall %} +{% endmacro %} + +{% macro set_table_classification(relation) -%} + {%- set format = config.get('format', default='parquet') -%} + {% call statement('set_table_classification', auto_begin=False) -%} + alter table {{ relation.render_hive() }} set tblproperties ('classification' = '{{ format }}') + {%- endcall %} +{%- endmacro %} + +{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %} + {%- set temp_identifier = base_relation.identifier ~ suffix -%} + {%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%} + + {%- if temp_schema is not none -%} + {%- set temp_relation = temp_relation.incorporate(path={ + "identifier": temp_identifier, + "schema": temp_schema + }) -%} + {%- do create_schema(temp_relation) -%} + {% endif %} + + {{ return(temp_relation) }} +{% endmacro %} + +{% macro 
athena__rename_relation(from_relation, to_relation) %} + {% call statement('rename_relation') -%} + alter table {{ from_relation.render_hive() }} rename to `{{ to_relation.schema }}`.`{{ to_relation.identifier }}` + {%- endcall %} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/schema.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/schema.sql new file mode 100644 index 00000000..2750a7f9 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/schema.sql @@ -0,0 +1,15 @@ +{% macro athena__create_schema(relation) -%} + {%- call statement('create_schema') -%} + create schema if not exists {{ relation.without_identifier().render_hive() }} + {% endcall %} + + {{ adapter.add_lf_tags_to_database(relation) }} + +{% endmacro %} + + +{% macro athena__drop_schema(relation) -%} + {%- call statement('drop_schema') -%} + drop schema if exists {{ relation.without_identifier().render_hive() }} cascade + {% endcall %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/hooks.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/hooks.sql new file mode 100644 index 00000000..07c32d35 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/hooks.sql @@ -0,0 +1,17 @@ +{% macro run_hooks(hooks, inside_transaction=True) %} + {% set re = modules.re %} + {% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %} + {% set rendered = render(hook.get('sql')) | trim %} + {% if (rendered | length) > 0 %} + {%- if re.match("optimize\W+\S+\W+rewrite data using bin_pack", rendered.lower(), re.MULTILINE) -%} + {%- do adapter.run_operation_with_potential_multiple_runs(rendered, "optimize") -%} + {%- elif re.match("vacuum\W+\S+", rendered.lower(), re.MULTILINE) -%} + {%- do adapter.run_operation_with_potential_multiple_runs(rendered, "vacuum") -%} + {%- else -%} + {% call statement(auto_begin=inside_transaction) %} + {{ rendered }} + {% endcall %} + {%- endif -%} + {% endif %} + {% endfor %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql new file mode 100644 index 00000000..c1bf6505 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql @@ -0,0 +1,88 @@ +{% macro get_partition_batches(sql, as_subquery=True) -%} + {# Retrieve partition configuration and set default partition limit #} + {%- set partitioned_by = config.get('partitioned_by') -%} + {%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%} + {%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%} + {% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %} + + {# Retrieve distinct partitions from the given SQL #} + {% call statement('get_partitions', fetch_result=True) %} + {%- if as_subquery -%} + select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }}; + {%- else -%} + select distinct {{ partitioned_keys }} from {{ sql }} order by {{ partitioned_keys }}; + {%- endif -%} + {% endcall %} + + {# Initialize variables to store partition info #} + {%- set table = load_result('get_partitions').table -%} + {%- set rows = table.rows -%} + {%- set ns = namespace(partitions = [], bucket_conditions = {}, bucket_numbers = [], bucket_column = None, is_bucketed = false) -%} + + {# Process each partition row #} + {%- for row 
in rows -%} + {%- set single_partition = [] -%} + {# Use Namespace to hold the counter for loop index #} + {%- set counter = namespace(value=0) -%} + {# Loop through each column in the row #} + {%- for col, partition_key in zip(row, partitioned_by) -%} + {# Process bucketed columns using the new macro with the index #} + {%- do process_bucket_column(col, partition_key, table, ns, counter.value) -%} + + {# Logic for non-bucketed columns #} + {%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%} + {%- if not bucket_match -%} + {# For non-bucketed columns, format partition key and value #} + {%- set column_type = adapter.convert_type(table, counter.value) -%} + {%- set value, comp_func = adapter.format_value_for_partition(col, column_type) -%} + {%- set partition_key_formatted = adapter.format_one_partition_key(partitioned_by[counter.value]) -%} + {%- do single_partition.append(partition_key_formatted + comp_func + value) -%} + {%- endif -%} + {# Increment the counter #} + {%- set counter.value = counter.value + 1 -%} + {%- endfor -%} + + {# Concatenate conditions for a single partition #} + {%- set single_partition_expression = single_partition | join(' and ') -%} + {%- if single_partition_expression not in ns.partitions %} + {%- do ns.partitions.append(single_partition_expression) -%} + {%- endif -%} + {%- endfor -%} + + {# Calculate total batches based on bucketing and partitioning #} + {%- if ns.is_bucketed -%} + {%- set total_batches = ns.partitions | length * ns.bucket_numbers | length -%} + {%- else -%} + {%- set total_batches = ns.partitions | length -%} + {%- endif -%} + {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ total_batches) %} + + {# Determine the number of batches per partition limit #} + {%- set batches_per_partition_limit = (total_batches // athena_partitions_limit) + (total_batches % athena_partitions_limit > 0) -%} + + {# Create conditions for each batch #} + {%- set partitions_batches = [] -%} + {%- for i in range(batches_per_partition_limit) -%} + {%- set batch_conditions = [] -%} + {%- if ns.is_bucketed -%} + {# Combine partition and bucket conditions for each batch #} + {%- for partition_expression in ns.partitions -%} + {%- for bucket_num in ns.bucket_numbers -%} + {%- set bucket_condition = ns.bucket_column + " IN (" + ns.bucket_conditions[bucket_num] | join(", ") + ")" -%} + {%- set combined_condition = "(" + partition_expression + ' and ' + bucket_condition + ")" -%} + {%- do batch_conditions.append(combined_condition) -%} + {%- endfor -%} + {%- endfor -%} + {%- else -%} + {# Extend batch conditions with partitions for non-bucketed columns #} + {%- do batch_conditions.extend(ns.partitions) -%} + {%- endif -%} + {# Calculate batch start and end index and append batch conditions #} + {%- set start_index = i * athena_partitions_limit -%} + {%- set end_index = start_index + athena_partitions_limit -%} + {%- do partitions_batches.append(batch_conditions[start_index:end_index] | join(' or ')) -%} + {%- endfor -%} + + {{ return(partitions_batches) }} + +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql new file mode 100644 index 00000000..3790fbba --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql @@ -0,0 +1,20 @@ +{% macro process_bucket_column(col, partition_key, table, ns, col_index) %} + {# 
Extract bucket information from the partition key #} + {%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%} + + {%- if bucket_match -%} + {# For bucketed columns, compute bucket numbers and conditions #} + {%- set column_type = adapter.convert_type(table, col_index) -%} + {%- set ns.is_bucketed = true -%} + {%- set ns.bucket_column = bucket_match[1] -%} + {%- set bucket_num = adapter.murmur3_hash(col, bucket_match[2] | int) -%} + {%- set formatted_value, comp_func = adapter.format_value_for_partition(col, column_type) -%} + + {%- if bucket_num not in ns.bucket_numbers %} + {%- do ns.bucket_numbers.append(bucket_num) %} + {%- do ns.bucket_conditions.update({bucket_num: [formatted_value]}) -%} + {%- elif formatted_value not in ns.bucket_conditions[bucket_num] %} + {%- do ns.bucket_conditions[bucket_num].append(formatted_value) -%} + {%- endif -%} + {%- endif -%} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql new file mode 100644 index 00000000..952d6ecb --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql @@ -0,0 +1,58 @@ +{% macro alter_relation_add_columns(relation, add_columns = none, table_type = 'hive') -%} + {% if add_columns is none %} + {% set add_columns = [] %} + {% endif %} + + {% set sql -%} + alter {{ relation.type }} {{ relation.render_hive() }} + add columns ( + {%- for column in add_columns -%} + {{ column.name }} {{ ddl_data_type(column.data_type, table_type) }}{{ ', ' if not loop.last }} + {%- endfor -%} + ) + {%- endset -%} + + {% if (add_columns | length) > 0 %} + {{ return(run_query(sql)) }} + {% endif %} +{% endmacro %} + +{% macro alter_relation_drop_columns(relation, remove_columns = none) -%} + {% if remove_columns is none %} + {% set remove_columns = [] %} + {% endif %} + + {%- for column in remove_columns -%} + {% set sql -%} + alter {{ relation.type }} {{ relation.render_hive() }} drop column {{ column.name }} + {% endset %} + {% do run_query(sql) %} + {%- endfor -%} +{% endmacro %} + +{% macro alter_relation_replace_columns(relation, replace_columns = none, table_type = 'hive') -%} + {% if replace_columns is none %} + {% set replace_columns = [] %} + {% endif %} + + {% set sql -%} + alter {{ relation.type }} {{ relation.render_hive() }} + replace columns ( + {%- for column in replace_columns -%} + {{ column.name }} {{ ddl_data_type(column.data_type, table_type) }}{{ ', ' if not loop.last }} + {%- endfor -%} + ) + {%- endset -%} + + {% if (replace_columns | length) > 0 %} + {{ return(run_query(sql)) }} + {% endif %} +{% endmacro %} + +{% macro alter_relation_rename_column(relation, source_column, target_column, target_column_type, table_type = 'hive') -%} + {% set sql -%} + alter {{ relation.type }} {{ relation.render_pure() }} + change column {{ source_column }} {{ target_column }} {{ ddl_data_type(target_column_type, table_type) }} + {%- endset -%} + {{ return(run_query(sql)) }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/helpers.sql new file mode 100644 index 00000000..13011578 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/helpers.sql @@ -0,0 +1,122 @@ +{% macro 
validate_get_incremental_strategy(raw_strategy, table_type) %} + {%- if table_type == 'iceberg' -%} + {% set invalid_strategy_msg -%} + Invalid incremental strategy provided: {{ raw_strategy }} + Incremental models on Iceberg tables only work with 'append' or 'merge' (v3 only) strategy. + {%- endset %} + {% if raw_strategy not in ['append', 'merge'] %} + {% do exceptions.raise_compiler_error(invalid_strategy_msg) %} + {% endif %} + {%- else -%} + {% set invalid_strategy_msg -%} + Invalid incremental strategy provided: {{ raw_strategy }} + Expected one of: 'append', 'insert_overwrite' + {%- endset %} + + {% if raw_strategy not in ['append', 'insert_overwrite'] %} + {% do exceptions.raise_compiler_error(invalid_strategy_msg) %} + {% endif %} + {% endif %} + + {% do return(raw_strategy) %} +{% endmacro %} + + +{% macro batch_incremental_insert(tmp_relation, target_relation, dest_cols_csv) %} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches|length) -%} + {%- set insert_batch_partitions -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + ); + {%- endset -%} + {%- do run_query(insert_batch_partitions) -%} + {%- endfor -%} +{% endmacro %} + + +{% macro incremental_insert( + on_schema_change, + tmp_relation, + target_relation, + existing_relation, + force_batch, + statement_name="main" + ) +%} + {%- set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} + {%- if not dest_columns -%} + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} + {%- endif -%} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + + {% if force_batch %} + {% do batch_incremental_insert(tmp_relation, target_relation, dest_cols_csv) %} + {% else %} + {%- set insert_full -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + ); + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(insert_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% do batch_incremental_insert(tmp_relation, target_relation, dest_cols_csv) %} + {%- endif -%} + {%- endif -%} + + SELECT '{{query_result}}' + +{%- endmacro %} + + +{% macro delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} + {%- set partitioned_keys = partitioned_by | tojson | replace('\"', '') | replace('[', '') | replace(']', '') -%} + {% call statement('get_partitions', fetch_result=True) %} + select distinct {{partitioned_keys}} from {{ tmp_relation }}; + {% endcall %} + {%- set table = load_result('get_partitions').table -%} + {%- set rows = table.rows -%} + {%- set partitions = [] -%} + {%- for row in rows -%} + {%- set single_partition = [] -%} + {%- for col in row -%} + {%- set column_type = adapter.convert_type(table, loop.index0) -%} + {%- if column_type == 'integer' or column_type is none -%} + {%- set value = col|string -%} + {%- elif column_type == 'string' -%} + {%- set value = "'" + col + "'" -%} + {%- elif column_type == 'date' -%} + {%- set value = "'" + col|string + "'" -%} + {%- elif column_type == 'timestamp' -%} + {%- set value = "'" + col|string + "'" -%} + {%- else -%} + {%- do 
exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%} + {%- endif -%} + {%- do single_partition.append(partitioned_by[loop.index0] + '=' + value) -%} + {%- endfor -%} + {%- set single_partition_expression = single_partition | join(' and ') -%} + {%- do partitions.append('(' + single_partition_expression + ')') -%} + {%- endfor -%} + {%- for i in range(partitions | length) %} + {%- do adapter.clean_up_partitions(target_relation, partitions[i]) -%} + {%- endfor -%} +{%- endmacro %} + +{% macro remove_partitions_from_columns(columns_with_partitions, partition_keys) %} + {%- set columns = [] -%} + {%- for column in columns_with_partitions -%} + {%- if column.name not in partition_keys -%} + {%- do columns.append(column) -%} + {%- endif -%} + {%- endfor -%} + {{ return(columns) }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/incremental.sql new file mode 100644 index 00000000..c18ac681 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -0,0 +1,175 @@ +{% materialization incremental, adapter='athena', supported_languages=['sql', 'python'] -%} + {% set raw_strategy = config.get('incremental_strategy') or 'insert_overwrite' %} + {% set table_type = config.get('table_type', default='hive') | lower %} + {% set model_language = model['language'] %} + {% set strategy = validate_get_incremental_strategy(raw_strategy, table_type) %} + {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} + {% set versions_to_keep = config.get('versions_to_keep', 1) | as_number %} + {% set lf_tags_config = config.get('lf_tags_config') %} + {% set lf_grants = config.get('lf_grants') %} + {% set partitioned_by = config.get('partitioned_by') %} + {% set force_batch = config.get('force_batch', False) | as_bool -%} + {% set unique_tmp_table_suffix = config.get('unique_tmp_table_suffix', False) | as_bool -%} + {% set temp_schema = config.get('temp_schema') %} + {% set target_relation = this.incorporate(type='table') %} + {% set existing_relation = load_relation(this) %} + -- If using insert_overwrite on Hive table, allow to set a unique tmp table suffix + {% if unique_tmp_table_suffix == True and strategy == 'insert_overwrite' and table_type == 'hive' %} + {% set tmp_table_suffix = adapter.generate_unique_temporary_table_suffix() %} + {% else %} + {% set tmp_table_suffix = '__dbt_tmp' %} + {% endif %} + + {% if unique_tmp_table_suffix == True and table_type == 'iceberg' %} + {% set tmp_table_suffix = adapter.generate_unique_temporary_table_suffix() %} + {% endif %} + + {% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix, + schema=schema, + database=database) %} + {% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix, temp_schema=temp_schema) %} + + -- If no partitions are used with insert_overwrite, we fall back to append mode. 
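For orientation, here is a rough sketch of the statements this `insert_overwrite` path ends up issuing for a hypothetical Hive model partitioned by `event_date`. Every identifier, column and S3 path below is invented, and the partition clean-up in step 2 actually happens through the Glue/S3 APIs rather than through SQL.

```sql
-- 1. The compiled model is materialized into the temporary relation (default suffix "__dbt_tmp"):
create table "analytics_dev"."my_model__dbt_tmp"
with (
    table_type='hive',
    is_external=true,
    external_location='s3://my-bucket/my_model__dbt_tmp/',
    partitioned_by=ARRAY['event_date'],
    format='parquet'
)
as
select id, status, event_date
from "analytics_dev"."stg_events";

-- 2. delete_overlapping_partitions() lists the partitions present in the temporary relation ...
select distinct event_date from "analytics_dev"."my_model__dbt_tmp";
-- ... and removes the matching partitions from the target via adapter.clean_up_partitions (Glue/S3, no SQL).

-- 3. incremental_insert() copies the rows across. With force_batch, or after a
--    TOO_MANY_OPEN_PARTITIONS error, the insert is repeated once per partition batch,
--    each restricted by a where clause like the one below:
insert into "analytics_dev"."my_model" ("id", "status", "event_date")
(
    select "id", "status", "event_date"
    from "analytics_dev"."my_model__dbt_tmp"
    where event_date = date '2024-01-02'
);
```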
+ {% if partitioned_by is none and strategy == 'insert_overwrite' %} + {% set strategy = 'append' %} + {% endif %} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + -- `BEGIN` happens here: + {{ run_hooks(pre_hooks, inside_transaction=True) }} + + {% set to_drop = [] %} + {% if existing_relation is none %} + {% set query_result = safe_create_table_as(False, target_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = "select '" ~ query_result ~ "'" -%} + {% elif existing_relation.is_view or should_full_refresh() %} + {% do drop_relation(existing_relation) %} + {% set query_result = safe_create_table_as(False, target_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = "select '" ~ query_result ~ "'" -%} + {% elif partitioned_by is not none and strategy == 'insert_overwrite' %} + {% if old_tmp_relation is not none %} + {% do drop_relation(old_tmp_relation) %} + {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} + {% set build_sql = incremental_insert( + on_schema_change, tmp_relation, target_relation, existing_relation, force_batch + ) + %} + {% do to_drop.append(tmp_relation) %} + {% elif strategy == 'append' %} + {% if old_tmp_relation is not none %} + {% do drop_relation(old_tmp_relation) %} + {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = incremental_insert( + on_schema_change, tmp_relation, target_relation, existing_relation, force_batch + ) + %} + {% do to_drop.append(tmp_relation) %} + {% elif strategy == 'merge' and table_type == 'iceberg' %} + {% set unique_key = config.get('unique_key') %} + {% set incremental_predicates = config.get('incremental_predicates') %} + {% set delete_condition = config.get('delete_condition') %} + {% set update_condition = config.get('update_condition') %} + {% set insert_condition = config.get('insert_condition') %} + {% set empty_unique_key -%} + Merge strategy must implement unique_key as a single column or a list of columns. + {%- endset %} + {% if unique_key is none %} + {% do exceptions.raise_compiler_error(empty_unique_key) %} + {% endif %} + {% if incremental_predicates is not none %} + {% set inc_predicates_not_list -%} + Merge strategy must implement incremental_predicates as a list of predicates. 
+ {%- endset %} + {% if not adapter.is_list(incremental_predicates) %} + {% do exceptions.raise_compiler_error(inc_predicates_not_list) %} + {% endif %} + {% endif %} + {% if old_tmp_relation is not none %} + {% do drop_relation(old_tmp_relation) %} + {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = iceberg_merge( + on_schema_change=on_schema_change, + tmp_relation=tmp_relation, + target_relation=target_relation, + unique_key=unique_key, + incremental_predicates=incremental_predicates, + existing_relation=existing_relation, + delete_condition=delete_condition, + update_condition=update_condition, + insert_condition=insert_condition, + force_batch=force_batch, + ) + %} + {% do to_drop.append(tmp_relation) %} + {% endif %} + + {% call statement("main", language=model_language) %} + {% if model_language == 'sql' %} + {{ build_sql }} + {% else %} + {{ log(build_sql) }} + {% do athena__py_execute_query(query=build_sql) %} + {% endif %} + {% endcall %} + + -- set table properties + {% if not to_drop and table_type != 'iceberg' and model_language != 'python' %} + {{ set_table_classification(target_relation) }} + {% endif %} + + {{ run_hooks(post_hooks, inside_transaction=True) }} + + -- `COMMIT` happens here + {% do adapter.commit() %} + + {% for rel in to_drop %} + {% do drop_relation(rel) %} + {% endfor %} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {% do persist_docs(target_relation, model) %} + + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, False) %} + + {{ return({'relations': [target_relation]}) }} + +{%- endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/merge.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/merge.sql new file mode 100644 index 00000000..8741adb7 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/merge.sql @@ -0,0 +1,162 @@ +{% macro get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns) %} + + {%- if merge_update_columns and merge_exclude_columns -%} + {{ exceptions.raise_compiler_error( + 'Model cannot specify merge_update_columns and merge_exclude_columns. Please update model to use only one config' + )}} + {%- elif merge_update_columns -%} + {%- set update_columns = [] -%} + {%- for column in dest_columns -%} + {% if column.column | lower in merge_update_columns | map("lower") | list %} + {%- do update_columns.append(column) -%} + {% endif %} + {%- endfor -%} + {%- elif merge_exclude_columns -%} + {%- set update_columns = [] -%} + {%- for column in dest_columns -%} + {% if column.column | lower not in merge_exclude_columns | map("lower") | list %} + {%- do update_columns.append(column) -%} + {% endif %} + {%- endfor -%} + {%- else -%} + {%- set update_columns = dest_columns -%} + {%- endif -%} + + {{ return(update_columns) }} + +{% endmacro %} + +{%- macro get_update_statement(col, rule, is_last) -%} + {%- if rule == "coalesce" -%} + {{ col.quoted }} = {{ 'coalesce(src.' 
+ col.quoted + ', target.' + col.quoted + ')' }} + {%- elif rule == "sum" -%} + {%- if col.data_type.startswith("map") -%} + {{ col.quoted }} = {{ 'map_zip_with(coalesce(src.' + col.quoted + ', map()), coalesce(target.' + col.quoted + ', map()), (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0))' }} + {%- else -%} + {{ col.quoted }} = {{ 'src.' + col.quoted + ' + target.' + col.quoted }} + {%- endif -%} + {%- elif rule == "append" -%} + {{ col.quoted }} = {{ 'src.' + col.quoted + ' || target.' + col.quoted }} + {%- elif rule == "append_distinct" -%} + {{ col.quoted }} = {{ 'array_distinct(src.' + col.quoted + ' || target.' + col.quoted + ')' }} + {%- elif rule == "replace" -%} + {{ col.quoted }} = {{ 'src.' + col.quoted }} + {%- else -%} + {{ col.quoted }} = {{ rule | replace("_new_", 'src.' + col.quoted) | replace("_old_", 'target.' + col.quoted) }} + {%- endif -%} + {{ "," if not is_last }} +{%- endmacro -%} + + +{% macro batch_iceberg_merge(tmp_relation, target_relation, merge_part, dest_cols_csv) %} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + {%- set src_batch_part -%} + merge into {{ target_relation }} as target + using (select {{ dest_cols_csv }} from {{ tmp_relation }} where {{ batch }}) as src + {%- endset -%} + {%- set merge_batch -%} + {{ src_batch_part }} + {{ merge_part }} + {%- endset -%} + {%- do run_query(merge_batch) -%} + {%- endfor -%} +{%- endmacro -%} + + +{% macro iceberg_merge( + on_schema_change, + tmp_relation, + target_relation, + unique_key, + incremental_predicates, + existing_relation, + delete_condition, + update_condition, + insert_condition, + force_batch, + statement_name="main" + ) +%} + {%- set merge_update_columns = config.get('merge_update_columns') -%} + {%- set merge_exclude_columns = config.get('merge_exclude_columns') -%} + {%- set merge_update_columns_default_rule = config.get('merge_update_columns_default_rule', 'replace') -%} + {%- set merge_update_columns_rules = config.get('merge_update_columns_rules') -%} + + {% set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} + {% if not dest_columns %} + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} + {% endif %} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + {%- if unique_key is sequence and unique_key is not string -%} + {%- set unique_key_cols = unique_key -%} + {%- else -%} + {%- set unique_key_cols = [unique_key] -%} + {%- endif -%} + {%- set src_columns_quoted = [] -%} + {%- set dest_columns_wo_keys = [] -%} + {%- for col in dest_columns -%} + {%- do src_columns_quoted.append('src.' 
+ col.quoted ) -%} + {%- if col.name not in unique_key_cols -%} + {%- do dest_columns_wo_keys.append(col) -%} + {%- endif -%} + {%- endfor -%} + {%- set update_columns = get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns_wo_keys) -%} + {%- set src_cols_csv = src_columns_quoted | join(', ') -%} + + {%- set merge_part -%} + on ( + {%- for key in unique_key_cols -%} + target.{{ key }} = src.{{ key }} + {{ " and " if not loop.last }} + {%- endfor -%} + {% if incremental_predicates is not none -%} + and ( + {%- for inc_predicate in incremental_predicates %} + {{ inc_predicate }} {{ "and " if not loop.last }} + {%- endfor %} + ) + {%- endif %} + ) + {% if delete_condition is not none -%} + when matched and ({{ delete_condition }}) + then delete + {%- endif %} + {% if update_columns -%} + when matched {% if update_condition is not none -%} and {{ update_condition }} {%- endif %} + then update set + {%- for col in update_columns %} + {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %} + {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }} + {%- else -%} + {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }} + {%- endif -%} + {%- endfor %} + {%- endif %} + when not matched {% if insert_condition is not none -%} and {{ insert_condition }} {%- endif %} + then insert ({{ dest_cols_csv }}) + values ({{ src_cols_csv }}) + {%- endset -%} + + {%- if force_batch -%} + {% do batch_iceberg_merge(tmp_relation, target_relation, merge_part, dest_cols_csv) %} + {%- else -%} + {%- set src_part -%} + merge into {{ target_relation }} as target using {{ tmp_relation }} as src + {%- endset -%} + {%- set merge_full -%} + {{ src_part }} + {{ merge_part }} + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(merge_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% do batch_iceberg_merge(tmp_relation, target_relation, merge_part, dest_cols_csv) %} + {%- endif -%} + {%- endif -%} + + SELECT '{{query_result}}' +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql new file mode 100644 index 00000000..fc7fcbcb --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql @@ -0,0 +1,89 @@ +{% macro sync_column_schemas(on_schema_change, target_relation, schema_changes_dict) %} + {%- set partitioned_by = config.get('partitioned_by', default=none) -%} + {% set table_type = config.get('table_type', default='hive') | lower %} + {%- if partitioned_by is none -%} + {%- set partitioned_by = [] -%} + {%- endif %} + {%- set add_to_target_arr = schema_changes_dict['source_not_in_target'] -%} + {%- if on_schema_change == 'append_new_columns'-%} + {%- if add_to_target_arr | length > 0 -%} + {%- do alter_relation_add_columns(target_relation, add_to_target_arr, table_type) -%} + {%- endif -%} + {% elif on_schema_change == 'sync_all_columns' %} + {%- set remove_from_target_arr = schema_changes_dict['target_not_in_source'] -%} + {%- set new_target_types = schema_changes_dict['new_target_types'] -%} + {% if table_type == 'iceberg' %} + {# + If last run of alter_column_type was failed on rename tmp column to origin. + Do rename to protect origin column from deletion and losing data. 
+ #} + {% for remove_col in remove_from_target_arr if remove_col.column.endswith('__dbt_alter') %} + {%- set origin_col_name = remove_col.column | replace('__dbt_alter', '') -%} + {% for add_col in add_to_target_arr if add_col.column == origin_col_name %} + {%- do alter_relation_rename_column(target_relation, remove_col.name, add_col.name, add_col.data_type, table_type) -%} + {%- do remove_from_target_arr.remove(remove_col) -%} + {%- do add_to_target_arr.remove(add_col) -%} + {% endfor %} + {% endfor %} + + {% if add_to_target_arr | length > 0 %} + {%- do alter_relation_add_columns(target_relation, add_to_target_arr, table_type) -%} + {% endif %} + {% if remove_from_target_arr | length > 0 %} + {%- do alter_relation_drop_columns(target_relation, remove_from_target_arr) -%} + {% endif %} + {% if new_target_types != [] %} + {% for ntt in new_target_types %} + {% set column_name = ntt['column_name'] %} + {% set new_type = ntt['new_type'] %} + {% do alter_column_type(target_relation, column_name, new_type) %} + {% endfor %} + {% endif %} + {% else %} + {%- set replace_with_target_arr = remove_partitions_from_columns(schema_changes_dict['source_columns'], partitioned_by) -%} + {% if add_to_target_arr | length > 0 or remove_from_target_arr | length > 0 or new_target_types | length > 0 %} + {%- do alter_relation_replace_columns(target_relation, replace_with_target_arr, table_type) -%} + {% endif %} + {% endif %} + {% endif %} + {% set schema_change_message %} + In {{ target_relation }}: + Schema change approach: {{ on_schema_change }} + Columns added: {{ add_to_target_arr }} + Columns removed: {{ remove_from_target_arr }} + Data types changed: {{ new_target_types }} + {% endset %} + {% do log(schema_change_message) %} +{% endmacro %} + +{% macro athena__alter_column_type(relation, column_name, new_column_type) -%} + {# + 1. Create a new column (w/ temp name and correct type) + 2. Copy data over to it + 3. Drop the existing column + 4. 
Rename the new column to existing column + #} + {%- set table_type = config.get('table_type', 'hive') -%} + {%- set tmp_column = column_name + '__dbt_alter' -%} + {%- set new_ddl_data_type = ddl_data_type(new_column_type, table_type) -%} + + {#- do alter_relation_add_columns(relation, [ tmp_column ], table_type) -#} + {%- set add_column_query -%} + alter {{ relation.type }} {{ relation.render_pure() }} add columns({{ tmp_column }} {{ new_ddl_data_type }}); + {%- endset -%} + {%- do run_query(add_column_query) -%} + + {%- set update_query -%} + update {{ relation.render_pure() }} set {{ tmp_column }} = cast({{ column_name }} as {{ new_column_type }}); + {%- endset -%} + {%- do run_query(update_query) -%} + + {#- do alter_relation_drop_columns(relation, [ column_name ]) -#} + {%- set drop_column_query -%} + alter {{ relation.type }} {{ relation.render_pure() }} drop column {{ column_name }}; + {%- endset -%} + {%- do run_query(drop_column_query) -%} + + {%- do alter_relation_rename_column(relation, tmp_column, column_name, new_column_type, table_type) -%} + +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/create_table_as.sql new file mode 100644 index 00000000..e20c4020 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -0,0 +1,223 @@ +{% macro create_table_as(temporary, relation, compiled_code, language='sql', skip_partitioning=false) -%} + {{ adapter.dispatch('create_table_as', 'athena')(temporary, relation, compiled_code, language, skip_partitioning) }} +{%- endmacro %} + + +{% macro athena__create_table_as(temporary, relation, compiled_code, language='sql', skip_partitioning=false) -%} + {%- set materialized = config.get('materialized', default='table') -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- do log("Skip partitioning: " ~ skip_partitioning) -%} + {%- set partitioned_by = config.get('partitioned_by', default=none) if not skip_partitioning else none -%} + {%- set bucketed_by = config.get('bucketed_by', default=none) -%} + {%- set bucket_count = config.get('bucket_count', default=none) -%} + {%- set field_delimiter = config.get('field_delimiter', default=none) -%} + {%- set table_type = config.get('table_type', default='hive') | lower -%} + {%- set format = config.get('format', default='parquet') -%} + {%- set write_compression = config.get('write_compression', default=none) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} + {%- set s3_tmp_table_dir = config.get('s3_tmp_table_dir', default=target.s3_tmp_table_dir) -%} + {%- set extra_table_properties = config.get('table_properties', default=none) -%} + + {%- set location_property = 'external_location' -%} + {%- set partition_property = 'partitioned_by' -%} + {%- set work_group_output_location_enforced = adapter.is_work_group_output_location_enforced() -%} + {%- set location = adapter.generate_s3_location(relation, + s3_data_dir, + s3_data_naming, + s3_tmp_table_dir, + external_location, + temporary, + ) -%} + {%- set native_drop = config.get('native_drop', default=false) -%} + + {%- set contract_config = config.get('contract') -%} + {%- if contract_config.enforced -%} + {{ get_assert_columns_equivalent(compiled_code) }} + {%- endif -%} + + {%- if native_drop and 
table_type == 'iceberg' -%} + {% do log('Config native_drop enabled, skipping direct S3 delete') %} + {%- else -%} + {% do adapter.delete_from_s3(location) %} + {%- endif -%} + + {%- if language == 'python' -%} + {%- set spark_ctas = '' -%} + {%- if table_type == 'iceberg' -%} + {%- set spark_ctas -%} + create table {{ relation.schema | replace('\"', '`') }}.{{ relation.identifier | replace('\"', '`') }} + using iceberg + location '{{ location }}/' + + {%- if partitioned_by is not none %} + partitioned by ( + {%- for prop_value in partitioned_by -%} + {{ prop_value }} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ) + {%- endif %} + + {%- if extra_table_properties is not none %} + tblproperties( + {%- for prop_name, prop_value in extra_table_properties.items() -%} + '{{ prop_name }}'='{{ prop_value }}' + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ) + {% endif %} + + as + {%- endset -%} + {%- endif -%} + + {# {% do log('Creating table with spark and compiled code: ' ~ compiled_code) %} #} + {{ athena__py_save_table_as( + compiled_code, + relation, + optional_args={ + 'location': location, + 'format': format, + 'mode': 'overwrite', + 'partitioned_by': partitioned_by, + 'bucketed_by': bucketed_by, + 'write_compression': write_compression, + 'bucket_count': bucket_count, + 'field_delimiter': field_delimiter, + 'spark_ctas': spark_ctas + } + ) + }} + {%- else -%} + {%- if table_type == 'iceberg' -%} + {%- set location_property = 'location' -%} + {%- set partition_property = 'partitioning' -%} + {%- if bucketed_by is not none or bucket_count is not none -%} + {%- set ignored_bucket_iceberg -%} + bucketed_by or bucket_count cannot be used with Iceberg tables. You have to use the bucket function + when partitioning. Will be ignored + {%- endset -%} + {%- set bucketed_by = none -%} + {%- set bucket_count = none -%} + {% do log(ignored_bucket_iceberg) %} + {%- endif -%} + {%- if materialized == 'table' and ( 'unique' not in s3_data_naming or external_location is not none) -%} + {%- set error_unique_location_iceberg -%} + You need to have an unique table location when creating Iceberg table since we use the RENAME feature + to have near-zero downtime. 
+ {%- endset -%} + {% do exceptions.raise_compiler_error(error_unique_location_iceberg) %} + {%- endif -%} + {%- endif %} + + create table {{ relation }} + with ( + table_type='{{ table_type }}', + is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %}, + {%- if not work_group_output_location_enforced or table_type == 'iceberg' -%} + {{ location_property }}='{{ location }}', + {%- endif %} + {%- if partitioned_by is not none %} + {{ partition_property }}=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }}, + {%- endif %} + {%- if bucketed_by is not none %} + bucketed_by=ARRAY{{ bucketed_by | tojson | replace('\"', '\'') }}, + {%- endif %} + {%- if bucket_count is not none %} + bucket_count={{ bucket_count }}, + {%- endif %} + {%- if field_delimiter is not none %} + field_delimiter='{{ field_delimiter }}', + {%- endif %} + {%- if write_compression is not none %} + write_compression='{{ write_compression }}', + {%- endif %} + format='{{ format }}' + {%- if extra_table_properties is not none -%} + {%- for prop_name, prop_value in extra_table_properties.items() -%} + , + {{ prop_name }}={{ prop_value }} + {%- endfor -%} + {% endif %} + ) + as + {{ compiled_code }} + {%- endif -%} +{%- endmacro -%} + +{% macro create_table_as_with_partitions(temporary, relation, compiled_code, language='sql') -%} + + {%- set tmp_relation = api.Relation.create( + identifier=relation.identifier ~ '__tmp_not_partitioned', + schema=relation.schema, + database=relation.database, + s3_path_table_part=relation.identifier ~ '__tmp_not_partitioned' , + type='table' + ) + -%} + + {%- if tmp_relation is not none -%} + {%- do drop_relation(tmp_relation) -%} + {%- endif -%} + + {%- do log('CREATE NON-PARTIONED STAGING TABLE: ' ~ tmp_relation) -%} + {%- do run_query(create_table_as(temporary, tmp_relation, compiled_code, language, true)) -%} + + {% set partitions_batches = get_partition_batches(sql=tmp_relation, as_subquery=False) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + + {%- set dest_columns = adapter.get_columns_in_relation(tmp_relation) -%} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + + {%- if loop.index == 1 -%} + {%- set create_target_relation_sql -%} + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + {%- endset -%} + {%- do run_query(create_table_as(temporary, relation, create_target_relation_sql, language)) -%} + {%- else -%} + {%- set insert_batch_partitions_sql -%} + insert into {{ relation }} ({{ dest_cols_csv }}) + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + {%- endset -%} + + {%- do run_query(insert_batch_partitions_sql) -%} + {%- endif -%} + + + {%- endfor -%} + + {%- do drop_relation(tmp_relation) -%} + + select 'SUCCESSFULLY CREATED TABLE {{ relation }}' + +{%- endmacro %} + +{% macro safe_create_table_as(temporary, relation, compiled_code, language='sql', force_batch=False) -%} + {%- if language != 'sql' -%} + {{ return(create_table_as(temporary, relation, compiled_code, language)) }} + {%- elif force_batch -%} + {%- do create_table_as_with_partitions(temporary, relation, compiled_code, language) -%} + {%- set query_result = relation ~ ' with many partitions created' -%} + {%- else -%} + {%- if temporary -%} + {%- do run_query(create_table_as(temporary, relation, compiled_code, language, true)) -%} + {%- set 
compiled_code_result = relation ~ ' as temporary relation without partitioning created' -%} + {%- else -%} + {%- set compiled_code_result = adapter.run_query_with_partitions_limit_catching(create_table_as(temporary, relation, compiled_code)) -%} + {%- do log('COMPILED CODE RESULT: ' ~ compiled_code_result) -%} + {%- if compiled_code_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {%- do create_table_as_with_partitions(temporary, relation, compiled_code, language) -%} + {%- set compiled_code_result = relation ~ ' with many partitions created' -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {{ return(compiled_code_result) }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/table.sql new file mode 100644 index 00000000..1f94361c --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/table.sql @@ -0,0 +1,183 @@ +-- TODO create a drop_relation_with_versions, to be sure to remove all historical versions of a table +{% materialization table, adapter='athena', supported_languages=['sql', 'python'] -%} + {%- set identifier = model['alias'] -%} + {%- set language = model['language'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {%- set table_type = config.get('table_type', default='hive') | lower -%} + {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} + {%- set old_tmp_relation = adapter.get_relation(identifier=identifier ~ '__ha', + schema=schema, + database=database) -%} + {%- set old_bkp_relation = adapter.get_relation(identifier=identifier ~ '__bkp', + schema=schema, + database=database) -%} + {%- set is_ha = config.get('ha', default=false) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', default='table_unique') -%} + {%- set full_refresh_config = config.get('full_refresh', default=False) -%} + {%- set is_full_refresh_mode = (flags.FULL_REFRESH == True or full_refresh_config == True) -%} + {%- set versions_to_keep = config.get('versions_to_keep', default=4) -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- set force_batch = config.get('force_batch', False) | as_bool -%} + {%- set target_relation = api.Relation.create(identifier=identifier, + schema=schema, + database=database, + type='table') -%} + {%- set tmp_relation = api.Relation.create(identifier=target_relation.identifier ~ '__ha', + schema=schema, + database=database, + s3_path_table_part=target_relation.identifier, + type='table') -%} + + {%- if ( + table_type == 'hive' + and is_ha + and ('unique' not in s3_data_naming or external_location is not none) + ) -%} + {%- set error_unique_location_hive_ha -%} + You need to have an unique table location when using ha config with hive table. + Use s3_data_naming unique, table_unique or schema_table_unique, and avoid to set an explicit + external_location. 
+ {%- endset -%} + {% do exceptions.raise_compiler_error(error_unique_location_hive_ha) %} + {%- endif -%} + + {{ run_hooks(pre_hooks) }} + + {%- if table_type == 'hive' -%} + + -- for HA tables that are not in full-refresh mode and whose relation already exists, we use the swap behavior + {%- if is_ha and not is_full_refresh_mode and old_relation is not none -%} + -- drop the old_tmp_relation if it exists + {%- if old_tmp_relation is not none -%} + {%- do adapter.delete_from_glue_catalog(old_tmp_relation) -%} + {%- endif -%} + + -- create tmp table + {%- set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + -- swap table + {%- set swap_table = adapter.swap_table(tmp_relation, target_relation) -%} + + -- delete glue tmp table, do not use drop_relation, as it will remove data of the target table + {%- do adapter.delete_from_glue_catalog(tmp_relation) -%} + + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, True) %} + + {%- else -%} + -- Here we handle non-HA tables, or HA tables running in full-refresh mode. + {%- if old_relation is not none -%} + {{ drop_relation(old_relation) }} + {%- endif -%} + {%- set query_result = safe_create_table_as(False, target_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {%- endif -%} + + {%- if language != 'python' -%} + {{ set_table_classification(target_relation) }} + {%- endif -%} + {%- else -%} + + {%- if old_relation is none -%} + {%- set query_result = safe_create_table_as(False, target_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {%- else -%} + {%- if old_relation.is_view -%} + {%- set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {%- do drop_relation(old_relation) -%} + {%- do rename_relation(tmp_relation, target_relation) -%} + {%- else -%} + -- delete old tmp iceberg table if it exists + {%- if old_tmp_relation is not none -%} + {%- do drop_relation(old_tmp_relation) -%} + {%- endif -%} + + -- If the backup relation exists, at least the first renaming of a previous run occurred but something + -- failed afterwards, so we are in an inconsistent state. The easiest and cleanest fix is to remove + -- the backup relation. This has no impact: since we are in the else condition, the old relation + -- still exists, so there is no downtime yet. 
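+ -- Illustrative outline of the publish sequence below, assuming a model named my_table (the name is a hypothetical example, not part of this change): + -- 1. create my_table__ha from the compiled model code + -- 2. rename my_table to my_table__bkp (only when the old relation is an iceberg table; otherwise the old relation is simply dropped from the Glue catalog) + -- 3. rename my_table__ha to my_table (final publish, no downtime for readers) + -- 4. drop my_table__bkp (only when step 2 happened)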
+ {%- if old_bkp_relation is not none -%} + {%- do drop_relation(old_bkp_relation) -%} + {%- endif -%} + + {% set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language, force_batch) %} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + + {%- set old_relation_table_type = adapter.get_glue_table_type(old_relation).value if old_relation else none -%} + + -- we cannot use old_bkp_relation, because it is None if the relation doesn't exist; + -- we need to create a relation object via make_temp_relation instead + {%- set old_relation_bkp = make_temp_relation(old_relation, '__bkp') -%} + + {%- if old_relation_table_type == 'iceberg_table' -%} + {{ rename_relation(old_relation, old_relation_bkp) }} + {%- else -%} + {%- do drop_relation_glue(old_relation) -%} + {%- endif -%} + + -- publish the target table by doing a final rename + {{ rename_relation(tmp_relation, target_relation) }} + + -- if the old relation is an iceberg_table, we have a backup, so we can drop the old relation backup; + -- in all other cases there is nothing to do: + -- in case of a switch from hive to iceberg the backup table does not exist, + -- and on the first run the backup table does not exist either + {%- if old_relation_table_type == 'iceberg_table' -%} + {%- do drop_relation(old_relation_bkp) -%} + {%- endif -%} + + {%- endif -%} + {%- endif -%} + + {%- endif -%} + + {% call statement("main", language=language) %} + {%- if language=='sql' -%} + SELECT '{{ query_result }}'; + {%- endif -%} + {% endcall %} + + {{ run_hooks(post_hooks) }} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {% do persist_docs(target_relation, model) %} + + {{ return({'relations': [target_relation]}) }} + +{% endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql new file mode 100644 index 00000000..655354ac --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql @@ -0,0 +1,54 @@ +{% macro create_or_replace_view(run_outside_transaction_hooks=True) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_inherited_tags = config.get('lf_inherited_tags') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} + {%- set exists_as_view = (old_relation is not none and old_relation.is_view) -%} + {%- set target_relation = api.Relation.create( + identifier=identifier, + schema=schema, + database=database, + type='view', + ) -%} + + {% if run_outside_transaction_hooks %} + -- no transactions on BigQuery + {{ run_hooks(pre_hooks, inside_transaction=False) }} + {% endif %} + -- `BEGIN` happens here on Snowflake + {{ run_hooks(pre_hooks, inside_transaction=True) }} + -- If there's a table with the same name and we weren't told to full refresh, + -- that's an error. If we were told to full refresh, drop it. This behavior differs + -- for Snowflake and BigQuery, so multiple dispatch is used. 
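+ -- (Note: should_full_refresh() is typically true when dbt is invoked with --full-refresh or when the model sets full_refresh: true; in that case the existing table is dropped before the view is created.)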
+ {%- if old_relation is not none and old_relation.is_table -%} + {{ handle_existing_table(should_full_refresh(), old_relation) }} + {%- endif -%} + + -- build model + {% call statement('main') -%} + {{ create_view_as(target_relation, compiled_code) }} + {%- endcall %} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {{ run_hooks(post_hooks, inside_transaction=True) }} + + {{ adapter.commit() }} + + {% if run_outside_transaction_hooks %} + -- No transactions on BigQuery + {{ run_hooks(post_hooks, inside_transaction=False) }} + {% endif %} + + {{ return({'relations': [target_relation]}) }} + +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_view_as.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_view_as.sql new file mode 100644 index 00000000..019676db --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_view_as.sql @@ -0,0 +1,10 @@ +{% macro athena__create_view_as(relation, sql) -%} + {%- set contract_config = config.get('contract') -%} + {%- if contract_config.enforced -%} + {{ get_assert_columns_equivalent(sql) }} + {%- endif -%} + create or replace view + {{ relation }} + as + {{ sql }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/view.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/view.sql new file mode 100644 index 00000000..449e0833 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/view.sql @@ -0,0 +1,17 @@ +{% materialization view, adapter='athena' -%} + {%- set identifier = model['alias'] -%} + {%- set versions_to_keep = config.get('versions_to_keep', default=4) -%} + {%- set target_relation = api.Relation.create(identifier=identifier, + schema=schema, + database=database, + type='view') -%} + + {% set to_return = create_or_replace_view(run_outside_transaction_hooks=False) %} + + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, False) %} + + {% set target_relation = this.incorporate(type='view') %} + {% do persist_docs(target_relation, model) %} + + {% do return(to_return) %} +{%- endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/seeds/helpers.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/seeds/helpers.sql new file mode 100644 index 00000000..f44c7c9c --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/seeds/helpers.sql @@ -0,0 +1,215 @@ +{% macro default__reset_csv_table(model, full_refresh, old_relation, agate_table) %} + {% set sql = "" %} + -- No truncate in Athena so always drop CSV table and recreate + {{ drop_relation(old_relation) }} + {% set sql = create_csv_table(model, agate_table) %} + + {{ return(sql) }} +{% endmacro %} + +{% macro try_cast_timestamp(col) %} + {% set date_formats = [ + '%Y-%m-%d %H:%i:%s', + '%Y/%m/%d %H:%i:%s', + '%d %M %Y %H:%i:%s', + '%d/%m/%Y %H:%i:%s', + '%d-%m-%Y %H:%i:%s', + '%Y-%m-%d %H:%i:%s.%f', + '%Y/%m/%d %H:%i:%s.%f', + '%d %M %Y %H:%i:%s.%f', + '%d/%m/%Y %H:%i:%s.%f', + '%Y-%m-%dT%H:%i:%s.%fZ', + '%Y-%m-%dT%H:%i:%sZ', + '%Y-%m-%dT%H:%i:%s', + ]%} + + coalesce( + {% for date_format in date_formats %} + try(date_parse({{ col }}, '{{ date_format }}')) + {%- if not loop.last -%}, {% endif -%} + {% 
endfor %} + ) as {{ col }} +{% endmacro %} + +{% macro create_csv_table_insert(model, agate_table) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + {%- set column_override = config.get('column_types', {}) -%} + {%- set quote_seed_column = config.get('quote_columns') -%} + {%- set s3_data_dir = config.get('s3_data_dir', target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', target.s3_data_naming) -%} + {%- set s3_tmp_table_dir = config.get('s3_tmp_table_dir', default=target.s3_tmp_table_dir) -%} + {%- set external_location = config.get('external_location') -%} + + {%- set relation = api.Relation.create( + identifier=identifier, + schema=model.schema, + database=model.database, + type='table' + ) -%} + + {%- set location = adapter.generate_s3_location(relation, + s3_data_dir, + s3_data_naming, + s3_tmp_table_dir, + external_location, + temporary) -%} + + {% set sql_table %} + create external table {{ relation.render_hive() }} ( + {%- for col_name in agate_table.column_names -%} + {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} + {%- set type = column_override.get(col_name, inferred_type) -%} + {%- set type = type if type != "string" else "varchar" -%} + {%- set column_name = (col_name | string) -%} + {{ adapter.quote_seed_column(column_name, quote_seed_column, "`") }} {{ ddl_data_type(type) }} {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + ) + location '{{ location }}' + + {% endset %} + + {% call statement('_') -%} + {{ sql_table }} + {%- endcall %} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(relation, lf_grants) }} + {% endif %} + + {{ return(sql) }} +{% endmacro %} + + +{% macro create_csv_table_upload(model, agate_table) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {%- set column_override = config.get('column_types', {}) -%} + {%- set quote_seed_column = config.get('quote_columns', None) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', target.s3_data_naming) -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- set seed_s3_upload_args = config.get('seed_s3_upload_args', default=target.seed_s3_upload_args) -%} + + {%- set tmp_relation = api.Relation.create( + identifier=identifier + "__dbt_tmp", + schema=model.schema, + database=model.database, + type='table' + ) -%} + + {%- set tmp_s3_location = adapter.upload_seed_to_s3( + tmp_relation, + agate_table, + s3_data_dir, + s3_data_naming, + external_location, + seed_s3_upload_args=seed_s3_upload_args + ) -%} + + -- create target relation + {%- set relation = api.Relation.create( + identifier=identifier, + schema=model.schema, + database=model.database, + type='table' + ) -%} + + -- drop tmp relation if exists + {{ drop_relation(tmp_relation) }} + + {% set sql_tmp_table %} + create external table {{ tmp_relation.render_hive() }} ( + {%- for col_name in agate_table.column_names -%} + {%- set column_name = (col_name | string) -%} + {{ adapter.quote_seed_column(column_name, quote_seed_column, "`") }} string {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + ) + row format serde 
'org.apache.hadoop.hive.serde2.OpenCSVSerde' + location '{{ tmp_s3_location }}' + tblproperties ( + 'skip.header.line.count'='1' + ) + {% endset %} + + -- casting to type string is not allowed needs to be varchar + {% set sql %} + select + {% for col_name in agate_table.column_names -%} + {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} + {%- set type = column_override.get(col_name, inferred_type) -%} + {%- set type = type if type != "string" else "varchar" -%} + {%- set column_name = (col_name | string) -%} + {%- set quoted_column_name = adapter.quote_seed_column(column_name, quote_seed_column) -%} + {% if type == 'timestamp' %} + {{ try_cast_timestamp(quoted_column_name) }} + {% else %} + cast(nullif({{quoted_column_name}}, '') as {{ type }}) as {{quoted_column_name}} + {% endif %} + {%- if not loop.last -%}, {% endif -%} + {%- endfor %} + from + {{ tmp_relation }} + {% endset %} + + -- create tmp table + {% call statement('_') -%} + {{ sql_tmp_table }} + {%- endcall -%} + + -- create target table from tmp table + {% set sql_table = create_table_as(false, relation, sql) %} + {% call statement('_') -%} + {{ sql_table }} + {%- endcall %} + + -- drop tmp table + {{ drop_relation(tmp_relation) }} + + -- delete csv file from s3 + {% do adapter.delete_from_s3(tmp_s3_location) %} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(relation, lf_grants) }} + {% endif %} + + {{ return(sql_table) }} +{% endmacro %} + +{% macro athena__create_csv_table(model, agate_table) %} + + {%- set seed_by_insert = config.get('seed_by_insert', False) | as_bool -%} + + {%- if seed_by_insert -%} + {% do log('seed by insert...') %} + {%- set sql_table = create_csv_table_insert(model, agate_table) -%} + {%- else -%} + {% do log('seed by upload...') %} + {%- set sql_table = create_csv_table_upload(model, agate_table) -%} + {%- endif -%} + + {{ return(sql_table) }} +{% endmacro %} + +{# Overwrite to satisfy dbt-core logic #} +{% macro athena__load_csv_rows(model, agate_table) %} + {%- set seed_by_insert = config.get('seed_by_insert', False) | as_bool -%} + {%- if seed_by_insert %} + {{ default__load_csv_rows(model, agate_table) }} + {%- else -%} + select 1 + {% endif %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/snapshots/snapshot.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/snapshots/snapshot.sql new file mode 100644 index 00000000..23dd2608 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/snapshots/snapshot.sql @@ -0,0 +1,244 @@ +{# + Hash function to generate dbt_scd_id. dbt by default uses md5(coalesce(cast(field as varchar))) + but it does not work with athena. It throws an error Unexpected parameters (varchar) for function + md5. 
Expected: md5(varbinary) +#} +{% macro athena__snapshot_hash_arguments(args) -%} + to_hex(md5(to_utf8({%- for arg in args -%} + coalesce(cast({{ arg }} as varchar ), '') + {% if not loop.last %} || '|' || {% endif %} + {%- endfor -%}))) +{%- endmacro %} + + +{# + If hive table then Recreate the snapshot table from the new_snapshot_table + If iceberg table then Update the standard snapshot merge to include the DBT_INTERNAL_SOURCE prefix in the src_cols_csv +#} +{% macro hive_snapshot_merge_sql(target, source, insert_cols, table_type) -%} + {%- set target_relation = adapter.get_relation(database=target.database, schema=target.schema, identifier=target.identifier) -%} + {%- if target_relation is not none -%} + {% do adapter.drop_relation(target_relation) %} + {%- endif -%} + + {% set sql -%} + select * from {{ source }}; + {%- endset -%} + + {{ create_table_as(False, target_relation, sql) }} +{% endmacro %} + +{% macro iceberg_snapshot_merge_sql(target, source, insert_cols) %} + {%- set insert_cols_csv = insert_cols | join(', ') -%} + {%- set src_columns = [] -%} + {%- for col in insert_cols -%} + {%- do src_columns.append('dbt_internal_source.' + col) -%} + {%- endfor -%} + {%- set src_cols_csv = src_columns | join(', ') -%} + + merge into {{ target }} as dbt_internal_dest + using {{ source }} as dbt_internal_source + on dbt_internal_source.dbt_scd_id = dbt_internal_dest.dbt_scd_id + + when matched + and dbt_internal_dest.dbt_valid_to is null + and dbt_internal_source.dbt_change_type in ('update', 'delete') + then update + set dbt_valid_to = dbt_internal_source.dbt_valid_to + + when not matched + and dbt_internal_source.dbt_change_type = 'insert' + then insert ({{ insert_cols_csv }}) + values ({{ src_cols_csv }}) +{% endmacro %} + + +{# + Create a new temporary table that will hold the new snapshot results. + This table will then be used to overwrite the target snapshot table. +#} +{% macro hive_create_new_snapshot_table(target, source, insert_cols) %} + {%- set temp_relation = make_temp_relation(target, '__dbt_tmp_snapshot') -%} + {%- set preexisting_tmp_relation = load_cached_relation(temp_relation) -%} + {%- if preexisting_tmp_relation is not none -%} + {%- do adapter.drop_relation(preexisting_tmp_relation) -%} + {%- endif -%} + + {# TODO: Add insert_cols #} + {%- set src_columns = [] -%} + {%- set dst_columns = [] -%} + {%- set updated_columns = [] -%} + {%- for col in insert_cols -%} + {%- do src_columns.append('dbt_internal_source.' + col) -%} + {%- do dst_columns.append('dbt_internal_dest.' + col) -%} + {%- if col.replace('"', '') in ['dbt_valid_to'] -%} + {%- do updated_columns.append('dbt_internal_source.' + col) -%} + {%- else -%} + {%- do updated_columns.append('dbt_internal_dest.' 
+ col) -%} + {%- endif -%} + {%- endfor -%} + {%- set src_cols_csv = src_columns | join(', ') -%} + {%- set dst_cols_csv = dst_columns | join(', ') -%} + {%- set updated_cols_csv = updated_columns | join(', ') -%} + + {%- set source_columns = adapter.get_columns_in_relation(source) -%} + + {% set sql -%} + -- Unchanged rows + select {{ dst_cols_csv }} + from {{ target }} as dbt_internal_dest + left join {{ source }} as dbt_internal_source + on dbt_internal_source.dbt_scd_id = dbt_internal_dest.dbt_scd_id + where dbt_internal_source.dbt_scd_id is null + + union all + + -- Updated or deleted rows + select {{ updated_cols_csv }} + from {{ target }} as dbt_internal_dest + inner join {{ source }} as dbt_internal_source + on dbt_internal_source.dbt_scd_id = dbt_internal_dest.dbt_scd_id + where dbt_internal_dest.dbt_valid_to is null + and dbt_internal_source.dbt_change_type in ('update', 'delete') + + union all + + -- New rows + select {{ src_cols_csv }} + from {{ source }} as dbt_internal_source + left join {{ target }} as dbt_internal_dest + on dbt_internal_dest.dbt_scd_id = dbt_internal_source.dbt_scd_id + where dbt_internal_dest.dbt_scd_id is null + + {%- endset -%} + + {% call statement('create_new_snapshot_table') %} + {{ create_table_as(False, temp_relation, sql) }} + {% endcall %} + + {% do return(temp_relation) %} +{% endmacro %} + +{% materialization snapshot, adapter='athena' %} + {%- set config = model['config'] -%} + + {%- set target_table = model.get('alias', model.get('name')) -%} + {%- set strategy_name = config.get('strategy') -%} + {%- set file_format = config.get('file_format', 'parquet') -%} + {%- set table_type = config.get('table_type', 'hive') -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {{ log('Checking if target table exists') }} + {% set target_relation_exists, target_relation = get_or_create_relation( + database=model.database, + schema=model.schema, + identifier=target_table, + type='table') -%} + + {%- if not target_relation.is_table -%} + {% do exceptions.relation_wrong_type(target_relation, 'table') %} + {%- endif -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {{ run_hooks(pre_hooks, inside_transaction=True) }} + + {% set strategy_macro = strategy_dispatch(strategy_name) %} + {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %} + + {% if not target_relation_exists %} + + {% set build_sql = build_snapshot_table(strategy, model['compiled_sql']) %} + {% set final_sql = create_table_as(False, target_relation, build_sql) %} + + {% else %} + + {{ adapter.valid_snapshot_target(target_relation) }} + + {% set staging_relation = make_temp_relation(target_relation) %} + {%- set preexisting_staging_relation = load_cached_relation(staging_relation) -%} + {%- if preexisting_staging_relation is not none -%} + {%- do adapter.drop_relation(preexisting_staging_relation) -%} + {%- endif -%} + + {% set staging_table = build_snapshot_staging_table(strategy, sql, target_relation) %} + + {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) + | rejectattr('name', 'equalto', 'dbt_change_type') + | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') + | rejectattr('name', 'equalto', 'dbt_unique_key') + | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | list %} + + + {% if missing_columns %} + {% do alter_relation_add_columns(target_relation, missing_columns, table_type) %} + {% endif %} + + + {% set 
source_columns = adapter.get_columns_in_relation(staging_table) + | rejectattr('name', 'equalto', 'dbt_change_type') + | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') + | rejectattr('name', 'equalto', 'dbt_unique_key') + | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | list %} + + {% set quoted_source_columns = [] %} + {% for column in source_columns %} + {% do quoted_source_columns.append(adapter.quote(column.name)) %} + {% endfor %} + + {% if table_type == 'iceberg' %} + {% set final_sql = iceberg_snapshot_merge_sql( + target = target_relation, + source = staging_table, + insert_cols = quoted_source_columns, + ) + %} + {% else %} + {% set new_snapshot_table = hive_create_new_snapshot_table( + target = target_relation, + source = staging_table, + insert_cols = quoted_source_columns, + ) + %} + {% set final_sql = hive_snapshot_merge_sql( + target = target_relation, + source = new_snapshot_table + ) + %} + {% endif %} + + {% endif %} + + {% call statement('main') %} + {{ final_sql }} + {% endcall %} + + {{ run_hooks(post_hooks, inside_transaction=True) }} + + {% if staging_table is defined %} + {% do adapter.drop_relation(staging_table) %} + {% endif %} + + {% if new_snapshot_table is defined %} + {% do adapter.drop_relation(new_snapshot_table) %} + {% endif %} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {% do persist_docs(target_relation, model) %} + + {{ return({'relations': [target_relation]}) }} + +{% endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/any_value.sql b/dbt-athena/src/dbt/include/athena/macros/utils/any_value.sql new file mode 100644 index 00000000..2a8bf0b5 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/any_value.sql @@ -0,0 +1,3 @@ +{% macro athena__any_value(expression) -%} + arbitrary({{ expression }}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/array_append.sql b/dbt-athena/src/dbt/include/athena/macros/utils/array_append.sql new file mode 100644 index 00000000..e9868945 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/array_append.sql @@ -0,0 +1,3 @@ +{% macro athena__array_append(array, new_element) -%} + {{ array }} || {{ new_element }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/array_concat.sql b/dbt-athena/src/dbt/include/athena/macros/utils/array_concat.sql new file mode 100644 index 00000000..0300db98 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/array_concat.sql @@ -0,0 +1,3 @@ +{% macro athena__array_concat(array_1, array_2) -%} + {{ array_1 }} || {{ array_2 }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/array_construct.sql b/dbt-athena/src/dbt/include/athena/macros/utils/array_construct.sql new file mode 100644 index 00000000..e52c8f3d --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/array_construct.sql @@ -0,0 +1,7 @@ +{% macro athena__array_construct(inputs, data_type) -%} + {% if inputs|length > 0 %} + array[ {{ inputs|join(' , ') }} ] + {% else %} + {{ safe_cast('array[]', 'array(' ~ data_type ~ ')') }} + {% endif %} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/bool_or.sql b/dbt-athena/src/dbt/include/athena/macros/utils/bool_or.sql new file mode 100644 index 
00000000..0b70d2f8 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/bool_or.sql @@ -0,0 +1,3 @@ +{% macro athena__bool_or(expression) -%} + bool_or({{ expression }}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/datatypes.sql b/dbt-athena/src/dbt/include/athena/macros/utils/datatypes.sql new file mode 100644 index 00000000..4b80643d --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/datatypes.sql @@ -0,0 +1,18 @@ +{# Implements Athena-specific datatypes where they differ from the dbt-core defaults #} +{# See https://docs.aws.amazon.com/athena/latest/ug/data-types.html #} + +{%- macro athena__type_float() -%} + DOUBLE +{%- endmacro -%} + +{%- macro athena__type_numeric() -%} + DECIMAL(38,6) +{%- endmacro -%} + +{%- macro athena__type_int() -%} + INTEGER +{%- endmacro -%} + +{%- macro athena__type_string() -%} + VARCHAR +{%- endmacro -%} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/date_trunc.sql b/dbt-athena/src/dbt/include/athena/macros/utils/date_trunc.sql new file mode 100644 index 00000000..aa3400b4 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/date_trunc.sql @@ -0,0 +1,3 @@ +{% macro athena__date_trunc(datepart, date) -%} + date_trunc('{{datepart}}', {{date}}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/dateadd.sql b/dbt-athena/src/dbt/include/athena/macros/utils/dateadd.sql new file mode 100644 index 00000000..7d06bea7 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/dateadd.sql @@ -0,0 +1,3 @@ +{% macro athena__dateadd(datepart, interval, from_date_or_timestamp) -%} + date_add('{{ datepart }}', {{ interval }}, {{ from_date_or_timestamp }}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/datediff.sql b/dbt-athena/src/dbt/include/athena/macros/utils/datediff.sql new file mode 100644 index 00000000..4c2dfeeb --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/datediff.sql @@ -0,0 +1,28 @@ +{% macro athena__datediff(first_date, second_date, datepart) -%} + {%- if datepart == 'year' -%} + (year(CAST({{ second_date }} AS TIMESTAMP)) - year(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'quarter' -%} + ({{ datediff(first_date, second_date, 'year') }} * 4) + quarter(CAST({{ second_date }} AS TIMESTAMP)) - quarter(CAST({{ first_date }} AS TIMESTAMP)) + {%- elif datepart == 'month' -%} + ({{ datediff(first_date, second_date, 'year') }} * 12) + month(CAST({{ second_date }} AS TIMESTAMP)) - month(CAST({{ first_date }} AS TIMESTAMP)) + {%- elif datepart == 'day' -%} + ((to_milliseconds((CAST(CAST({{ second_date }} AS TIMESTAMP) AS DATE) - CAST(CAST({{ first_date }} AS TIMESTAMP) AS DATE)))) / 86400000) + {%- elif datepart == 'week' -%} + ({{ datediff(first_date, second_date, 'day') }} / 7 + case + when dow(CAST({{first_date}} AS TIMESTAMP)) <= dow(CAST({{second_date}} AS TIMESTAMP)) then + case when {{first_date}} <= {{second_date}} then 0 else -1 end + else + case when {{first_date}} <= {{second_date}} then 1 else 0 end + end) + {%- elif datepart == 'hour' -%} + ({{ datediff(first_date, second_date, 'day') }} * 24 + hour(CAST({{ second_date }} AS TIMESTAMP)) - hour(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'minute' -%} + ({{ datediff(first_date, second_date, 'hour') }} * 60 + minute(CAST({{ second_date }} AS TIMESTAMP)) - minute(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'second' -%} + ({{ datediff(first_date, second_date, 'minute') }} * 60 + 
second(CAST({{ second_date }} AS TIMESTAMP)) - second(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'millisecond' -%} + (to_milliseconds((CAST({{ second_date }} AS TIMESTAMP) - CAST({{ first_date }} AS TIMESTAMP)))) + {%- else -%} + {% if execute %}{{ exceptions.raise_compiler_error("Unsupported datepart for macro datediff in Athena: {!r}".format(datepart)) }}{% endif %} + {%- endif -%} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/ddl_dml_data_type.sql b/dbt-athena/src/dbt/include/athena/macros/utils/ddl_dml_data_type.sql new file mode 100644 index 00000000..f93ed907 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/ddl_dml_data_type.sql @@ -0,0 +1,41 @@ +{# Athena has different types between DML and DDL #} +{# ref: https://docs.aws.amazon.com/athena/latest/ug/data-types.html #} +{% macro ddl_data_type(col_type, table_type = 'hive') -%} + -- transform varchar + {% set re = modules.re %} + {% set data_type = re.sub('(?:varchar|character varying)(?:\(\d+\))?', 'string', col_type) %} + + -- transform array and map + {%- if 'array' in data_type or 'map' in data_type -%} + {% set data_type = data_type.replace('(', '<').replace(')', '>') -%} + {%- endif -%} + + -- transform int + {%- if 'integer' in data_type -%} + {% set data_type = data_type.replace('integer', 'int') -%} + {%- endif -%} + + -- transform timestamp + {%- if table_type == 'iceberg' -%} + {%- if 'timestamp' in data_type -%} + {% set data_type = 'timestamp' -%} + {%- endif -%} + + {%- if 'binary' in data_type -%} + {% set data_type = 'binary' -%} + {%- endif -%} + {%- endif -%} + + {{ return(data_type) }} +{% endmacro %} + +{% macro dml_data_type(col_type) -%} + {%- set re = modules.re -%} + -- transform int to integer + {%- set data_type = re.sub('\bint\b', 'integer', col_type) -%} + -- transform string to varchar because string does not work in DML + {%- set data_type = re.sub('string', 'varchar', data_type) -%} + -- transform float to real because float does not work in DML + {%- set data_type = re.sub('float', 'real', data_type) -%} + {{ return(data_type) }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/hash.sql b/dbt-athena/src/dbt/include/athena/macros/utils/hash.sql new file mode 100644 index 00000000..773eb507 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/hash.sql @@ -0,0 +1,3 @@ +{% macro athena__hash(field) -%} + lower(to_hex(md5(to_utf8(cast({{field}} as varchar))))) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/listagg.sql b/dbt-athena/src/dbt/include/athena/macros/utils/listagg.sql new file mode 100644 index 00000000..265e1bd1 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/listagg.sql @@ -0,0 +1,18 @@ +{% macro athena__listagg(measure, delimiter_text, order_by_clause, limit_num) -%} + array_join( + {%- if limit_num %} + slice( + {%- endif %} + array_agg( + {{ measure }} + {%- if order_by_clause %} + {{ order_by_clause }} + {%- endif %} + ) + {%- if limit_num %} + , 1, {{ limit_num }} + ) + {%- endif %} + , {{ delimiter_text }} + ) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/right.sql b/dbt-athena/src/dbt/include/athena/macros/utils/right.sql new file mode 100644 index 00000000..a0203d61 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/right.sql @@ -0,0 +1,7 @@ +{% macro athena__right(string_text, length_expression) %} + case when {{ length_expression }} = 0 + then '' + else + substr({{ string_text }}, -1 * 
({{ length_expression }})) + end +{%- endmacro -%} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/safe_cast.sql b/dbt-athena/src/dbt/include/athena/macros/utils/safe_cast.sql new file mode 100644 index 00000000..aed0866e --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/safe_cast.sql @@ -0,0 +1,4 @@ +-- TODO: make safe_cast supports complex structures +{% macro athena__safe_cast(field, type) -%} + try_cast({{field}} as {{type}}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/timestamps.sql b/dbt-athena/src/dbt/include/athena/macros/utils/timestamps.sql new file mode 100644 index 00000000..c746b6d2 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/timestamps.sql @@ -0,0 +1,31 @@ +{% + pyathena converts time zoned timestamps to strings so lets avoid them now() +%} + +{% macro athena__current_timestamp() -%} + {{ cast_timestamp('now()') }} +{%- endmacro %} + + +{% macro cast_timestamp(timestamp_col) -%} + {%- set config = model.get('config', {}) -%} + {%- set table_type = config.get('table_type', 'hive') -%} + {%- set is_view = config.get('materialized', 'table') in ['view', 'ephemeral'] -%} + {%- if table_type == 'iceberg' and not is_view -%} + cast({{ timestamp_col }} as timestamp(6)) + {%- else -%} + cast({{ timestamp_col }} as timestamp) + {%- endif -%} +{%- endmacro %} + +{% + Macro to get the end_of_time timestamp +%} + +{% macro end_of_time() -%} + {{ return(adapter.dispatch('end_of_time')()) }} +{%- endmacro %} + +{% macro athena__end_of_time() -%} + cast('9999-01-01' AS timestamp) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/profile_template.yml b/dbt-athena/src/dbt/include/athena/profile_template.yml new file mode 100644 index 00000000..87be52ed --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/profile_template.yml @@ -0,0 +1,23 @@ +fixed: + type: athena +prompts: + s3_staging_dir: + hint: S3 location to store Athena query results and metadata, e.g. s3://athena_query_result/prefix/ + + s3_data_dir: + hint: S3 location where to store data/tables, e.g. 
s3://bucket_name/prefix/ + + region_name: + hint: AWS region of your Athena instance + + schema: + hint: Specify the schema (Athena database) to build models into (lowercase only) + + database: + hint: Specify the database (Data catalog) to build models into (lowercase only) + default: awsdatacatalog + + threads: + hint: '1 or more' + type: 'int' + default: 1 diff --git a/dbt-athena/src/dbt/include/athena/sample_profiles.yml b/dbt-athena/src/dbt/include/athena/sample_profiles.yml new file mode 100644 index 00000000..f9cbd0cf --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/sample_profiles.yml @@ -0,0 +1,19 @@ +default: + outputs: + dev: + type: athena + s3_staging_dir: [s3_staging_dir] + s3_data_dir: [s3_data_dir] + region_name: [region_name] + database: [database name] + schema: [dev_schema] + + prod: + type: athena + s3_staging_dir: [s3_staging_dir] + s3_data_dir: [s3_data_dir] + region_name: [region_name] + database: [database name] + schema: [prod_schema] + + target: dev diff --git a/dbt-athena/tests/__init__.py b/dbt-athena/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/tests/functional/__init__.py b/dbt-athena/tests/functional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/tests/functional/adapter/fixture_datediff.py b/dbt-athena/tests/functional/adapter/fixture_datediff.py new file mode 100644 index 00000000..4986ae57 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/fixture_datediff.py @@ -0,0 +1,52 @@ +seeds__data_datediff_csv = """first_date,second_date,datepart,result +2018-01-01 01:00:00,2018-01-02 01:00:00,day,1 +2018-01-01 01:00:00,2018-02-01 01:00:00,month,1 +2018-01-01 01:00:00,2019-01-01 01:00:00,year,1 +2018-01-01 01:00:00,2018-01-01 02:00:00,hour,1 +2018-01-01 01:00:00,2018-01-01 02:01:00,minute,61 +2018-01-01 01:00:00,2018-01-01 02:00:01,second,3601 +2019-12-31 00:00:00,2019-12-27 00:00:00,week,-1 +2019-12-31 00:00:00,2019-12-30 00:00:00,week,0 +2019-12-31 00:00:00,2020-01-02 00:00:00,week,0 +2019-12-31 00:00:00,2020-01-06 02:00:00,week,1 +,2018-01-01 02:00:00,hour, +2018-01-01 02:00:00,,hour, +""" + +models__test_datediff_sql = """ +with data as ( + select * from {{ ref('data_datediff') }} +) +select + case + when datepart = 'second' then {{ datediff('first_date', 'second_date', 'second') }} + when datepart = 'minute' then {{ datediff('first_date', 'second_date', 'minute') }} + when datepart = 'hour' then {{ datediff('first_date', 'second_date', 'hour') }} + when datepart = 'day' then {{ datediff('first_date', 'second_date', 'day') }} + when datepart = 'week' then {{ datediff('first_date', 'second_date', 'week') }} + when datepart = 'month' then {{ datediff('first_date', 'second_date', 'month') }} + when datepart = 'year' then {{ datediff('first_date', 'second_date', 'year') }} + else null + end as actual, + result as expected +from data +-- Also test correct casting of literal values. 
+union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "millisecond") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "second") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "minute") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "hour") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "day") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-03 00:00:00.000'", "week") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "month") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "quarter") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "year") }} as actual, 1 as expected +""" diff --git a/dbt-athena/tests/functional/adapter/fixture_split_parts.py b/dbt-athena/tests/functional/adapter/fixture_split_parts.py new file mode 100644 index 00000000..c9ff3d7c --- /dev/null +++ b/dbt-athena/tests/functional/adapter/fixture_split_parts.py @@ -0,0 +1,39 @@ +models__test_split_part_sql = """ +with data as ( + + select * from {{ ref('data_split_part') }} + +) + +select + {{ split_part('parts', 'split_on', 1) }} as actual, + result_1 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 2) }} as actual, + result_2 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 3) }} as actual, + result_3 as expected + +from data +""" + +models__test_split_part_yml = """ +version: 2 +models: + - name: test_split_part + tests: + - assert_equal: + actual: actual + expected: expected +""" diff --git a/dbt-athena/tests/functional/adapter/test_basic_hive.py b/dbt-athena/tests/functional/adapter/test_basic_hive.py new file mode 100644 index 00000000..b4026884 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_hive.py @@ -0,0 +1,63 @@ +""" +Run the basic dbt test suite on hive tables when applicable. 
+""" +import pytest + +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_empty import BaseEmpty +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_singular_tests import BaseSingularTests +from dbt.tests.adapter.basic.test_singular_tests_ephemeral import ( + BaseSingularTestsEphemeral, +) +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + + +class TestSimpleMaterializationsHive(BaseSimpleMaterializations): + pass + + +class TestSingularTestsHive(BaseSingularTests): + pass + + +class TestSingularTestsEphemeralHive(BaseSingularTestsEphemeral): + pass + + +class TestEmptyHive(BaseEmpty): + pass + + +class TestEphemeralHive(BaseEphemeral): + pass + + +class TestIncrementalHive(BaseIncremental): + pass + + +class TestGenericTestsHive(BaseGenericTests): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsHive(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampHive(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." +) +class TestBaseAdapterMethodHive(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_basic_hive_native_drop.py b/dbt-athena/tests/functional/adapter/test_basic_hive_native_drop.py new file mode 100644 index 00000000..e532cf7b --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_hive_native_drop.py @@ -0,0 +1,151 @@ +""" +Run the basic dbt test suite on hive tables when applicable. + +Some test classes are not included here, because they don't contain table models. +Those are run in the hive test suite. 
+""" +import pytest + +from dbt.tests.adapter.basic.files import ( + base_ephemeral_sql, + base_materialized_var_sql, + base_table_sql, + base_view_sql, + config_materialized_incremental, + config_materialized_table, + ephemeral_table_sql, + ephemeral_view_sql, + generic_test_table_yml, + generic_test_view_yml, + incremental_sql, + model_base, + model_ephemeral, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + +modified_config_materialized_table = """ + {{ config(materialized="table", table_type="hive", native_drop="True") }} +""" + +modified_config_materialized_incremental = """ + {{ config(materialized="incremental", + table_type="hive", + incremental_strategy="append", + unique_key="id", + native_drop="True") }} +""" + +modified_model_base = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ source('raw', 'seed') }} +""" + +modified_model_ephemeral = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ ref('ephemeral') }} +""" + + +def configure_single_model_to_use_iceberg(model): + """Adjust a given model configuration to use iceberg instead of hive.""" + replacements = [ + (config_materialized_table, modified_config_materialized_table), + (config_materialized_incremental, modified_config_materialized_incremental), + (model_base, modified_model_base), + (model_ephemeral, modified_model_ephemeral), + ] + for original, new in replacements: + model = model.replace(original.strip(), new.strip()) + + return model + + +def configure_models_to_use_iceberg(models): + """Loop over all the dbt models and set the table configuration to iceberg.""" + return {key: configure_single_model_to_use_iceberg(val) for key, val in models.items()} + + +@pytest.mark.skip( + reason="The materialized var doesn't work well, because we only want to change tables, not views. " + "It's hard to come up with an elegant fix." 
+) +class TestSimpleMaterializationsIceberg(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestEphemeralIceberg(BaseEphemeral): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "ephemeral.sql": base_ephemeral_sql, + "view_model.sql": ephemeral_view_sql, + "table_model.sql": ephemeral_table_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestIncrementalIceberg(BaseIncremental): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "incremental.sql": incremental_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestGenericTestsIceberg(BaseGenericTests): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "schema.yml": schema_base_yml, + "schema_view.yml": generic_test_view_yml, + "schema_table.yml": generic_test_table_yml, + } + ) + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsIceberg(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampIceberg(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." +) +class TestBaseAdapterMethodIceberg(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_basic_iceberg.py b/dbt-athena/tests/functional/adapter/test_basic_iceberg.py new file mode 100644 index 00000000..f092b8f9 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_iceberg.py @@ -0,0 +1,147 @@ +""" +Run the basic dbt test suite on hive tables when applicable. + +Some test classes are not included here, because they don't contain table models. +Those are run in the hive test suite. 
+""" +import pytest + +from dbt.tests.adapter.basic.files import ( + base_ephemeral_sql, + base_materialized_var_sql, + base_table_sql, + base_view_sql, + config_materialized_incremental, + config_materialized_table, + ephemeral_table_sql, + ephemeral_view_sql, + generic_test_table_yml, + generic_test_view_yml, + incremental_sql, + model_base, + model_ephemeral, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + +iceberg_config_materialized_table = """ + {{ config(materialized="table", table_type="iceberg") }} +""" + +iceberg_config_materialized_incremental = """ + {{ config(materialized="incremental", table_type="iceberg", incremental_strategy="merge", unique_key="id") }} +""" + +iceberg_model_base = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ source('raw', 'seed') }} +""" + +iceberg_model_ephemeral = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ ref('ephemeral') }} +""" + + +def configure_single_model_to_use_iceberg(model): + """Adjust a given model configuration to use iceberg instead of hive.""" + replacements = [ + (config_materialized_table, iceberg_config_materialized_table), + (config_materialized_incremental, iceberg_config_materialized_incremental), + (model_base, iceberg_model_base), + (model_ephemeral, iceberg_model_ephemeral), + ] + for original, new in replacements: + model = model.replace(original.strip(), new.strip()) + + return model + + +def configure_models_to_use_iceberg(models): + """Loop over all the dbt models and set the table configuration to iceberg.""" + return {key: configure_single_model_to_use_iceberg(val) for key, val in models.items()} + + +@pytest.mark.skip( + reason="The materialized var doesn't work well, because we only want to change tables, not views. " + "It's hard to come up with an elegant fix." 
+) +class TestSimpleMaterializationsIceberg(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestEphemeralIceberg(BaseEphemeral): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "ephemeral.sql": base_ephemeral_sql, + "view_model.sql": ephemeral_view_sql, + "table_model.sql": ephemeral_table_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestIncrementalIceberg(BaseIncremental): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "incremental.sql": incremental_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestGenericTestsIceberg(BaseGenericTests): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "schema.yml": schema_base_yml, + "schema_view.yml": generic_test_view_yml, + "schema_table.yml": generic_test_table_yml, + } + ) + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsIceberg(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampIceberg(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." +) +class TestBaseAdapterMethodIceberg(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_basic_iceberg_native_drop.py b/dbt-athena/tests/functional/adapter/test_basic_iceberg_native_drop.py new file mode 100644 index 00000000..020551c2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_iceberg_native_drop.py @@ -0,0 +1,152 @@ +""" +Run the basic dbt test suite on hive tables when applicable. + +Some test classes are not included here, because they don't contain table models. +Those are run in the hive test suite. 
+""" +import pytest + +from dbt.tests.adapter.basic.files import ( + base_ephemeral_sql, + base_materialized_var_sql, + base_table_sql, + base_view_sql, + config_materialized_incremental, + config_materialized_table, + ephemeral_table_sql, + ephemeral_view_sql, + generic_test_table_yml, + generic_test_view_yml, + incremental_sql, + model_base, + model_ephemeral, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + +iceberg_config_materialized_table = """ + {{ config(materialized="table", table_type="iceberg", native_drop="True") }} +""" + +iceberg_config_materialized_incremental = """ + {{ config(materialized="incremental", + table_type="iceberg", + incremental_strategy="merge", + unique_key="id", + native_drop="True") }} +""" + +iceberg_model_base = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ source('raw', 'seed') }} +""" + +iceberg_model_ephemeral = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ ref('ephemeral') }} +""" + + +def configure_single_model_to_use_iceberg(model): + """Adjust a given model configuration to use iceberg instead of hive.""" + replacements = [ + (config_materialized_table, iceberg_config_materialized_table), + (config_materialized_incremental, iceberg_config_materialized_incremental), + (model_base, iceberg_model_base), + (model_ephemeral, iceberg_model_ephemeral), + ] + for original, new in replacements: + model = model.replace(original.strip(), new.strip()) + + return model + + +def configure_models_to_use_iceberg(models): + """Loop over all the dbt models and set the table configuration to iceberg.""" + return {key: configure_single_model_to_use_iceberg(val) for key, val in models.items()} + + +@pytest.mark.skip( + reason="The materialized var doesn't work well, because we only want to change tables, not views. " + "It's hard to come up with an elegant fix." 
+) +class TestSimpleMaterializationsIceberg(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestEphemeralIceberg(BaseEphemeral): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "ephemeral.sql": base_ephemeral_sql, + "view_model.sql": ephemeral_view_sql, + "table_model.sql": ephemeral_table_sql, + "schema.yml": schema_base_yml, + } + ) + + +@pytest.mark.skip(reason="The native drop will usually time out in the test") +class TestIncrementalIceberg(BaseIncremental): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "incremental.sql": incremental_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestGenericTestsIceberg(BaseGenericTests): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "schema.yml": schema_base_yml, + "schema_view.yml": generic_test_view_yml, + "schema_table.yml": generic_test_table_yml, + } + ) + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsIceberg(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampIceberg(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." 
+)
+class TestBaseAdapterMethodIceberg(BaseAdapterMethod):
+    pass
diff --git a/dbt-athena/tests/functional/adapter/test_change_relation_types.py b/dbt-athena/tests/functional/adapter/test_change_relation_types.py
new file mode 100644
index 00000000..047cac75
--- /dev/null
+++ b/dbt-athena/tests/functional/adapter/test_change_relation_types.py
@@ -0,0 +1,26 @@
+import pytest
+
+from dbt.tests.adapter.relations.test_changing_relation_type import (
+    BaseChangeRelationTypeValidator,
+)
+
+
+class TestChangeRelationTypesHive(BaseChangeRelationTypeValidator):
+    pass
+
+
+class TestChangeRelationTypesIceberg(BaseChangeRelationTypeValidator):
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+table_type": "iceberg",
+            }
+        }
+
+    def test_changing_materialization_changes_relation_type(self, project):
+        self._run_and_check_materialization("view")
+        self._run_and_check_materialization("table")
+        self._run_and_check_materialization("view")
+        # skip incremental, which doesn't work with Iceberg
+        self._run_and_check_materialization("table", extra_args=["--full-refresh"])
diff --git a/dbt-athena/tests/functional/adapter/test_constraints.py b/dbt-athena/tests/functional/adapter/test_constraints.py
new file mode 100644
index 00000000..44ff8c57
--- /dev/null
+++ b/dbt-athena/tests/functional/adapter/test_constraints.py
@@ -0,0 +1,28 @@
+import pytest
+
+from dbt.tests.adapter.constraints.fixtures import (
+    model_quoted_column_schema_yml,
+    my_model_with_quoted_column_name_sql,
+)
+from dbt.tests.adapter.constraints.test_constraints import BaseConstraintQuotedColumn
+
+
+class TestAthenaConstraintQuotedColumn(BaseConstraintQuotedColumn):
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "my_model.sql": my_model_with_quoted_column_name_sql,
+            # we replace the text type with varchar
+            # do not replace text with string, because string would then be replaced with a capital TEXT, which makes the test fail
+            "constraints_schema.yml": model_quoted_column_schema_yml,
+        }
+
+    @pytest.fixture(scope="class")
+    def expected_sql(self):
+        # FIXME: dbt-athena outputs a stats query into the `target/run/` directory,
+        # while dbt-core expects the query to be a DDL statement that creates a table.
+        # This is a workaround to pass the test for now.
+
+        # NOTE: for the reason above, this test only checks that the query can be executed without errors.
+        # The query itself is not checked.
+ return 'SELECT \'{"rowcount":1,"data_scanned_in_bytes":0}\';' diff --git a/dbt-athena/tests/functional/adapter/test_detailed_table_type.py b/dbt-athena/tests/functional/adapter/test_detailed_table_type.py new file mode 100644 index 00000000..ba6c885c --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_detailed_table_type.py @@ -0,0 +1,44 @@ +import re + +import pytest + +from dbt.tests.util import run_dbt, run_dbt_and_capture + +get_detailed_table_type_sql = """ +{% macro get_detailed_table_type(schema) %} + {% if execute %} + {% set relation = api.Relation.create(database="awsdatacatalog", schema=schema) %} + {% set schema_tables = adapter.list_relations_without_caching(schema_relation = relation) %} + {% for rel in schema_tables %} + {% do log('Detailed Table Type: ' ~ rel.detailed_table_type, info=True) %} + {% endfor %} + {% endif %} +{% endmacro %} +""" + +# Model SQL for an Iceberg table +iceberg_model_sql = """ + select 1 as id, 'iceberg' as name + {{ config(materialized='table', table_type='iceberg') }} +""" + + +@pytest.mark.usefixtures("project") +class TestDetailedTableType: + @pytest.fixture(scope="class") + def macros(self): + return {"get_detailed_table_type.sql": get_detailed_table_type_sql} + + @pytest.fixture(scope="class") + def models(self): + return {"iceberg_model.sql": iceberg_model_sql} + + def test_detailed_table_type(self, project): + # Run the models + run_results = run_dbt(["run"]) + assert len(run_results) == 1 # Ensure model ran successfully + + args_str = f'{{"schema": "{project.test_schema}"}}' + run_macro, stdout = run_dbt_and_capture(["run-operation", "get_detailed_table_type", "--args", args_str]) + iceberg_table_type = re.search(r"Detailed Table Type: (\w+)", stdout).group(1) + assert iceberg_table_type == "ICEBERG" diff --git a/dbt-athena/tests/functional/adapter/test_docs.py b/dbt-athena/tests/functional/adapter/test_docs.py new file mode 100644 index 00000000..e8bc5978 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_docs.py @@ -0,0 +1,135 @@ +import os + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.adapter.basic.expected_catalog import base_expected_catalog, no_stats +from dbt.tests.adapter.basic.test_docs_generate import ( + BaseDocsGenerate, + run_and_generate, + verify_metadata, +) +from dbt.tests.util import get_artifact, run_dbt + +model_sql = """ +select 1 as id +""" + +iceberg_model_sql = """ +{{ + config( + materialized="table", + table_type="iceberg", + post_hook="alter table model drop column to_drop" + ) +}} +select 1 as id, 'to_drop' as to_drop +""" + +override_macros_sql = """ +{% macro get_catalog_relations(information_schema, relations) %} + {{ return(adapter.get_catalog_by_relations(information_schema, relations)) }} +{% endmacro %} +""" + + +def custom_verify_catalog_athena(project, expected_catalog, start_time): + # get the catalog.json + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + assert os.path.exists(catalog_path) + catalog = get_artifact(catalog_path) + + # verify the catalog + assert set(catalog) == {"errors", "nodes", "sources", "metadata"} + verify_metadata( + catalog["metadata"], + "https://schemas.getdbt.com/dbt/catalog/v1.json", + start_time, + ) + assert not catalog["errors"] + + for key in "nodes", "sources": + for unique_id, expected_node in expected_catalog[key].items(): + found_node = catalog[key][unique_id] + for node_key in expected_node: + assert node_key in found_node + # the value of found_node[node_key] is not exactly 
expected_node[node_key] + + +class TestDocsGenerate(BaseDocsGenerate): + """ + Override of BaseDocsGenerate to make it working with Athena + """ + + @pytest.fixture(scope="class") + def expected_catalog(self, project): + return base_expected_catalog( + project, + role="test", + id_type="integer", + text_type="text", + time_type="timestamp without time zone", + view_type="VIEW", + table_type="BASE TABLE", + model_stats=no_stats(), + ) + + def test_run_and_generate_no_compile(self, project, expected_catalog): + start_time = run_and_generate(project, ["--no-compile"]) + assert not os.path.exists(os.path.join(project.project_root, "target", "manifest.json")) + custom_verify_catalog_athena(project, expected_catalog, start_time) + + # Test generic "docs generate" command + def test_run_and_generate(self, project, expected_catalog): + start_time = run_and_generate(project) + custom_verify_catalog_athena(project, expected_catalog, start_time) + + # Check that assets have been copied to the target directory for use in the docs html page + assert os.path.exists(os.path.join(".", "target", "assets")) + assert os.path.exists(os.path.join(".", "target", "assets", "lorem-ipsum.txt")) + assert not os.path.exists(os.path.join(".", "target", "non-existent-assets")) + + +class TestDocsGenerateOverride: + @pytest.fixture(scope="class") + def models(self): + return {"model.sql": model_sql} + + @pytest.fixture(scope="class") + def macros(self): + return {"override_macros_sql.sql": override_macros_sql} + + def test_generate_docs(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + docs_generate = run_dbt(["--warn-error", "docs", "generate"]) + assert len(docs_generate._compile_results.results) == 1 + assert docs_generate._compile_results.results[0].status == RunStatus.Success + assert docs_generate.errors is None + + +class TestDocsGenerateIcebergNonCurrentColumn: + @pytest.fixture(scope="class") + def models(self): + return {"model.sql": iceberg_model_sql} + + @pytest.fixture(scope="class") + def macros(self): + return {"override_macros_sql.sql": override_macros_sql} + + def test_generate_docs(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + docs_generate = run_dbt(["--warn-error", "docs", "generate"]) + assert len(docs_generate._compile_results.results) == 1 + assert docs_generate._compile_results.results[0].status == RunStatus.Success + assert docs_generate.errors is None + + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + assert os.path.exists(catalog_path) + catalog = get_artifact(catalog_path) + columns = catalog["nodes"]["model.test.model"]["columns"] + assert "to_drop" not in columns + assert "id" in columns diff --git a/dbt-athena/tests/functional/adapter/test_empty.py b/dbt-athena/tests/functional/adapter/test_empty.py new file mode 100644 index 00000000..c79fdfc2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_empty.py @@ -0,0 +1,9 @@ +from dbt.tests.adapter.empty.test_empty import BaseTestEmpty, BaseTestEmptyInlineSourceRef + + +class TestAthenaEmpty(BaseTestEmpty): + pass + + +class TestAthenaEmptyInlineSourceRef(BaseTestEmptyInlineSourceRef): + pass diff --git a/dbt-athena/tests/functional/adapter/test_force_batch.py b/dbt-athena/tests/functional/adapter/test_force_batch.py new file mode 100644 index 00000000..4d1fcc30 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_force_batch.py @@ -0,0 +1,136 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import 
run_dbt + +models__force_batch_sql = """ +{{ config( + materialized='table', + partitioned_by=['date_column'], + force_batch=true + ) +}} + +select + random() as rnd, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +models_append_force_batch_sql = """ +{{ config( + materialized='incremental', + incremental_strategy='append', + partitioned_by=['date_column'], + force_batch=true + ) +}} + +select + random() as rnd, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +models_merge_force_batch_sql = """ +{{ config( + table_type='iceberg', + materialized='incremental', + incremental_strategy='merge', + unique_key=['date_column'], + partitioned_by=['date_column'], + force_batch=true + ) +}} +{% if is_incremental() %} + select + 1 as rnd, + cast(date_column as date) as date_column + from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) + ) as t1(date_array) + cross join unnest(date_array) as t2(date_column) +{% else %} + select + 2 as rnd, + cast(date_column as date) as date_column + from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) + ) as t1(date_array) + cross join unnest(date_array) as t2(date_column) +{% endif %} +""" + + +class TestForceBatchInsertParam: + @pytest.fixture(scope="class") + def models(self): + return {"force_batch.sql": models__force_batch_sql} + + def test__force_batch_param(self, project): + relation_name = "force_batch" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert models_records_count == 212 + + +class TestAppendForceBatch: + @pytest.fixture(scope="class") + def models(self): + return {"models_append_force_batch.sql": models_append_force_batch_sql} + + def test__append_force_batch_param(self, project): + relation_name = "models_append_force_batch" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count == 212 + + +class TestMergeForceBatch: + @pytest.fixture(scope="class") + def models(self): + return {"models_merge_force_batch.sql": models_merge_force_batch_sql} + + def test__merge_force_batch_param(self, project): + relation_name = "models_merge_force_batch" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + model_run_result_distinct_query = f"select distinct rnd from {project.test_schema}.{relation_name}" + + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + 
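+        # Second run: the merge strategy updates the existing rows in place rather
+        # than appending, so the row count stays at 212 and the `rnd` column
+        # collapses to a single distinct value.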
+        model_update = run_dbt(["run", "--select", relation_name])
+        model_update_result = model_update.results[0]
+        assert model_update_result.status == RunStatus.Success
+
+        models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+        assert models_records_count == 212
+
+        models_distinct_records = project.run_sql(model_run_result_distinct_query, fetch="all")[0][0]
+        assert models_distinct_records == 1
diff --git a/dbt-athena/tests/functional/adapter/test_ha_iceberg.py b/dbt-athena/tests/functional/adapter/test_ha_iceberg.py
new file mode 100644
index 00000000..769ce581
--- /dev/null
+++ b/dbt-athena/tests/functional/adapter/test_ha_iceberg.py
@@ -0,0 +1,90 @@
+import pytest
+
+from dbt.contracts.results import RunStatus
+from dbt.tests.util import run_dbt
+
+models__table_iceberg_naming_table_unique = """
+{{ config(
+        materialized='table',
+        table_type='iceberg',
+        s3_data_naming='table_unique'
+    )
+}}
+
+select
+    1 as id,
+    'test 1' as name
+union all
+select
+    2 as id,
+    'test 2' as name
+"""
+
+models__table_iceberg_naming_table = """
+{{ config(
+        materialized='table',
+        table_type='iceberg',
+        s3_data_naming='table'
+    )
+}}
+
+select
+    1 as id,
+    'test 1' as name
+union all
+select
+    2 as id,
+    'test 2' as name
+"""
+
+
+class TestTableIcebergTableUnique:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"table_iceberg_table_unique.sql": models__table_iceberg_naming_table_unique}
+
+    def test__table_creation(self, project, capsys):
+        relation_name = "table_iceberg_table_unique"
+        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+        first_model_run = run_dbt(["run", "--select", relation_name])
+        first_model_run_result = first_model_run.results[0]
+        assert first_model_run_result.status == RunStatus.Success
+
+        first_models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert first_models_records_count == 2
+
+        second_model_run = run_dbt(["run", "-d", "--select", relation_name])
+        second_model_run_result = second_model_run.results[0]
+        assert second_model_run_result.status == RunStatus.Success
+
+        out, _ = capsys.readouterr()
+        # on the second run we expect the existing target table to be renamed to __bkp
+        alter_statement = (
+            f"alter table `{project.test_schema}`.`{relation_name}` "
+            f"rename to `{project.test_schema}`.`{relation_name}__bkp`"
+        )
+        delete_bkp_table_log = (
+            f'Deleted table from glue catalog: "awsdatacatalog"."{project.test_schema}"."{relation_name}__bkp"'
+        )
+        assert alter_statement in out
+        assert delete_bkp_table_log in out
+
+        second_models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert second_models_records_count == 2
+
+
+# if s3_data_naming=table is used for an Iceberg model, a compile error must be raised;
+# this test makes sure that behaviour is enforced
+class TestTableIcebergTable:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"table_iceberg_table.sql": models__table_iceberg_naming_table}
+
+    def test__table_creation(self, project):
+        relation_name = "table_iceberg_table"
+
+        with pytest.raises(Exception):
+            run_dbt(["run", "--select", relation_name])
diff --git a/dbt-athena/tests/functional/adapter/test_hive_iceberg.py b/dbt-athena/tests/functional/adapter/test_hive_iceberg.py
new file mode 100644
index 00000000..c52a720b
--- /dev/null
+++ b/dbt-athena/tests/functional/adapter/test_hive_iceberg.py
@@ -0,0 +1,67 @@
+import pytest
+
+from 
dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__table_base_model = """ +{{ config( + materialized='table', + table_type=var("table_type"), + s3_data_naming='table_unique' + ) +}} + +select + 1 as id, + 'test 1' as name, + {{ cast_timestamp('current_timestamp') }} as created_at +union all +select + 2 as id, + 'test 2' as name, + {{ cast_timestamp('current_timestamp') }} as created_at +""" + + +class TestTableFromHiveToIceberg: + @pytest.fixture(scope="class") + def models(self): + return {"table_hive_to_iceberg.sql": models__table_base_model} + + def test__table_creation(self, project): + relation_name = "table_hive_to_iceberg" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run_hive = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"hive"}']) + model_run_result_hive = model_run_hive.results[0] + assert model_run_result_hive.status == RunStatus.Success + models_records_count_hive = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_hive == 2 + + model_run_iceberg = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"iceberg"}']) + model_run_result_iceberg = model_run_iceberg.results[0] + assert model_run_result_iceberg.status == RunStatus.Success + models_records_count_iceberg = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_iceberg == 2 + + +class TestTableFromIcebergToHive: + @pytest.fixture(scope="class") + def models(self): + return {"table_iceberg_to_hive.sql": models__table_base_model} + + def test__table_creation(self, project): + relation_name = "table_iceberg_to_hive" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run_iceberg = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"iceberg"}']) + model_run_result_iceberg = model_run_iceberg.results[0] + assert model_run_result_iceberg.status == RunStatus.Success + models_records_count_iceberg = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_iceberg == 2 + + model_run_hive = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"hive"}']) + model_run_result_hive = model_run_hive.results[0] + assert model_run_result_hive.status == RunStatus.Success + models_records_count_hive = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_hive == 2 diff --git a/dbt-athena/tests/functional/adapter/test_incremental_iceberg.py b/dbt-athena/tests/functional/adapter/test_incremental_iceberg.py new file mode 100644 index 00000000..4e45403f --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_incremental_iceberg.py @@ -0,0 +1,390 @@ +""" +Run optional dbt functional tests for Iceberg incremental merge, including delete and incremental predicates. 
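+Covers merge with a unique key, incremental_predicates, delete_condition,
+update_condition, insert_condition and merge_exclude_columns.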
+ +""" +import re + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.adapter.incremental.test_incremental_merge_exclude_columns import ( + BaseMergeExcludeColumns, +) +from dbt.tests.adapter.incremental.test_incremental_predicates import ( + BaseIncrementalPredicates, +) +from dbt.tests.adapter.incremental.test_incremental_unique_id import ( + BaseIncrementalUniqueKey, + models__duplicated_unary_unique_key_list_sql, + models__empty_str_unique_key_sql, + models__empty_unique_key_list_sql, + models__expected__one_str__overwrite_sql, + models__expected__unique_key_list__inplace_overwrite_sql, + models__no_unique_key_sql, + models__nontyped_trinary_unique_key_list_sql, + models__not_found_unique_key_list_sql, + models__not_found_unique_key_sql, + models__str_unique_key_sql, + models__trinary_unique_key_list_sql, + models__unary_unique_key_list_sql, + seeds__add_new_rows_sql, + seeds__duplicate_insert_sql, + seeds__seed_csv, +) +from dbt.tests.util import check_relations_equal, run_dbt + +seeds__expected_incremental_predicates_csv = """id,msg,color +3,anyway,purple +1,hey,blue +2,goodbye,red +2,yo,green +""" + +seeds__expected_delete_condition_csv = """id,msg,color +1,hey,blue +3,anyway,purple +""" + +seeds__expected_predicates_and_delete_condition_csv = """id,msg,color +1,hey,blue +1,hello,blue +3,anyway,purple +""" + +models__merge_exclude_all_columns_sql = """ +{{ config( + materialized = 'incremental', + unique_key = 'id', + incremental_strategy='merge', + merge_exclude_columns=['msg', 'color'] +) }} + +{% if not is_incremental() %} + +-- data for first invocation of model + +select 1 as id, 'hello' as msg, 'blue' as color +union all +select 2 as id, 'goodbye' as msg, 'red' as color + +{% else %} + +-- data for subsequent incremental update + +select 1 as id, 'hey' as msg, 'blue' as color +union all +select 2 as id, 'yo' as msg, 'green' as color +union all +select 3 as id, 'anyway' as msg, 'purple' as color + +{% endif %} +""" + +seeds__expected_merge_exclude_all_columns_csv = """id,msg,color +1,hello,blue +2,goodbye,red +3,anyway,purple +""" + +models__update_condition_sql = """ +{{ config( + table_type='iceberg', + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + update_condition='target.id > 1' + ) +}} + +{% if is_incremental() %} + +select * from ( + values + (1, 'v1-updated') + , (2, 'v2-updated') +) as t (id, value) + +{% else %} + +select * from ( + values + (-1, 'v-1') + , (0, 'v0') + , (1, 'v1') + , (2, 'v2') +) as t (id, value) + +{% endif %} +""" + +seeds__expected_update_condition_csv = """id,value +-1,v-1 +0,v0 +1,v1 +2,v2-updated +""" + +models__insert_condition_sql = """ +{{ config( + table_type='iceberg', + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + insert_condition='src.status != 0' + ) +}} + +{% if is_incremental() %} + +select * from ( + values + (1, -1) + , (2, 0) + , (3, 1) +) as t (id, status) + +{% else %} + +select * from ( + values + (0, 1) +) as t (id, status) + +{% endif %} + +""" + +seeds__expected_insert_condition_csv = """id,status +0, 1 +1,-1 +3,1 +""" + + +class TestIcebergIncrementalUniqueKey(BaseIncrementalUniqueKey): + @pytest.fixture(scope="class") + def project_config_update(self): + return {"models": {"+table_type": "iceberg", "incremental_strategy": "merge"}} + + @pytest.fixture(scope="class") + def models(self): + return { + "trinary_unique_key_list.sql": models__trinary_unique_key_list_sql, + "nontyped_trinary_unique_key_list.sql": 
models__nontyped_trinary_unique_key_list_sql, + "unary_unique_key_list.sql": models__unary_unique_key_list_sql, + "not_found_unique_key.sql": models__not_found_unique_key_sql, + "empty_unique_key_list.sql": models__empty_unique_key_list_sql, + "no_unique_key.sql": models__no_unique_key_sql, + "empty_str_unique_key.sql": models__empty_str_unique_key_sql, + "str_unique_key.sql": models__str_unique_key_sql, + "duplicated_unary_unique_key_list.sql": models__duplicated_unary_unique_key_list_sql, + "not_found_unique_key_list.sql": models__not_found_unique_key_list_sql, + "expected": { + "one_str__overwrite.sql": replace_cast_date(models__expected__one_str__overwrite_sql), + "unique_key_list__inplace_overwrite.sql": replace_cast_date( + models__expected__unique_key_list__inplace_overwrite_sql + ), + }, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "duplicate_insert.sql": replace_cast_date(seeds__duplicate_insert_sql), + "seed.csv": seeds__seed_csv, + "add_new_rows.sql": replace_cast_date(seeds__add_new_rows_sql), + } + + @pytest.mark.xfail(reason="Model config 'unique_keys' is required for incremental merge.") + def test__no_unique_keys(self, project): + super().test__no_unique_keys(project) + + @pytest.mark.skip( + reason="""" + If 'unique_keys' does not contain columns then the join condition will fail. + The adapter isn't handling this input scenario. + """ + ) + def test__empty_str_unique_key(self): + pass + + @pytest.mark.skip( + reason=""" + If 'unique_keys' does not contain columns then the join condition will fail. + The adapter isn't handling this input scenario. + """ + ) + def test__empty_unique_key_list(self): + pass + + +class TestIcebergIncrementalPredicates(BaseIncrementalPredicates): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + "+incremental_predicates": ["src.id <> 3", "target.id <> 2"], + } + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_delete_insert_incremental_predicates.csv": seeds__expected_incremental_predicates_csv} + + +class TestIcebergDeleteCondition(BaseIncrementalPredicates): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + "+delete_condition": "src.id = 2 and target.color = 'red'", + } + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_delete_insert_incremental_predicates.csv": seeds__expected_delete_condition_csv} + + # Modifying the seed_rows number from the base class method + def test__incremental_predicates(self, project): + """seed should match model after two incremental runs""" + + expected_fields = self.get_expected_fields( + relation="expected_delete_insert_incremental_predicates", seed_rows=2 + ) + test_case_fields = self.get_test_fields( + project, + seed="expected_delete_insert_incremental_predicates", + incremental_model="delete_insert_incremental_predicates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) + + +class TestIcebergPredicatesAndDeleteCondition(BaseIncrementalPredicates): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + "+delete_condition": "src.msg = 'yo' and target.color = 'red'", + "+incremental_predicates": ["src.id <> 1", "target.msg <> 'blue'"], + } + } + + 
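+    # The expected seed reflects both settings: delete_condition removes matching
+    # target rows during the merge, while incremental_predicates limit which rows
+    # take part in the merge at all.
+    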
@pytest.fixture(scope="class") + def seeds(self): + return { + "expected_delete_insert_incremental_predicates.csv": seeds__expected_predicates_and_delete_condition_csv + } + + # Modifying the seed_rows number from the base class method + def test__incremental_predicates(self, project): + """seed should match model after two incremental runs""" + + expected_fields = self.get_expected_fields( + relation="expected_delete_insert_incremental_predicates", seed_rows=3 + ) + test_case_fields = self.get_test_fields( + project, + seed="expected_delete_insert_incremental_predicates", + incremental_model="delete_insert_incremental_predicates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) + + +class TestIcebergMergeExcludeColumns(BaseMergeExcludeColumns): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + } + } + + +class TestIcebergMergeExcludeAllColumns(BaseMergeExcludeColumns): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + } + } + + @pytest.fixture(scope="class") + def models(self): + return {"merge_exclude_columns.sql": models__merge_exclude_all_columns_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_exclude_columns.csv": seeds__expected_merge_exclude_all_columns_csv} + + +class TestIcebergUpdateCondition: + @pytest.fixture(scope="class") + def models(self): + return {"merge_update_condition.sql": models__update_condition_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_update_condition.csv": seeds__expected_update_condition_csv} + + def test__merge_update_condition(self, project): + """Seed should match the model after incremental run""" + + expected_seed_name = "expected_merge_update_condition" + run_dbt(["seed", "--select", expected_seed_name, "--full-refresh"]) + + relation_name = "merge_update_condition" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + model_update = run_dbt(["run", "--select", relation_name]) + model_update_result = model_update.results[0] + assert model_update_result.status == RunStatus.Success + + check_relations_equal(project.adapter, [relation_name, expected_seed_name]) + + +def replace_cast_date(model: str) -> str: + """Wrap all date strings with a cast date function""" + + new_model = re.sub("'[0-9]{4}-[0-9]{2}-[0-9]{2}'", r"cast(\g<0> as date)", model) + return new_model + + +class TestIcebergInsertCondition: + @pytest.fixture(scope="class") + def models(self): + return {"merge_insert_condition.sql": models__insert_condition_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_insert_condition.csv": seeds__expected_insert_condition_csv} + + def test__merge_insert_condition(self, project): + """Seed should match the model after run""" + + expected_seed_name = "expected_merge_insert_condition" + run_dbt(["seed", "--select", expected_seed_name, "--full-refresh"]) + + relation_name = "merge_insert_condition" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + model_update = run_dbt(["run", "--select", relation_name]) + model_update_result = model_update.results[0] + assert 
model_update_result.status == RunStatus.Success + + check_relations_equal(project.adapter, [relation_name, expected_seed_name]) diff --git a/dbt-athena/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py b/dbt-athena/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py new file mode 100644 index 00000000..039da2d8 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py @@ -0,0 +1,115 @@ +from collections import namedtuple + +import pytest + +from dbt.tests.util import check_relations_equal, run_dbt + +models__merge_no_updates_sql = """ +{{ config( + materialized = 'incremental', + unique_key = 'id', + incremental_strategy = 'merge', + merge_update_columns = ['id'], + table_type = 'iceberg', +) }} + +{% if not is_incremental() %} + +-- data for first invocation of model + +select 1 as id, 'hello' as msg, 'blue' as color +union all +select 2 as id, 'goodbye' as msg, 'red' as color + +{% else %} + +-- data for subsequent incremental update + +select 1 as id, 'hey' as msg, 'blue' as color +union all +select 2 as id, 'yo' as msg, 'green' as color +union all +select 3 as id, 'anyway' as msg, 'purple' as color + +{% endif %} +""" + +seeds__expected_merge_no_updates_csv = """id,msg,color +1,hello,blue +2,goodbye,red +3,anyway,purple +""" + +ResultHolder = namedtuple( + "ResultHolder", + [ + "seed_count", + "model_count", + "seed_rows", + "inc_test_model_count", + "relation", + ], +) + + +class TestIncrementalIcebergMergeNoUpdates: + @pytest.fixture(scope="class") + def models(self): + return {"merge_no_updates.sql": models__merge_no_updates_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_no_updates.csv": seeds__expected_merge_no_updates_csv} + + def update_incremental_model(self, incremental_model): + """update incremental model after the seed table has been updated""" + model_result_set = run_dbt(["run", "--select", incremental_model]) + return len(model_result_set) + + def get_test_fields(self, project, seed, incremental_model, update_sql_file): + seed_count = len(run_dbt(["seed", "--select", seed, "--full-refresh"])) + + model_count = len(run_dbt(["run", "--select", incremental_model, "--full-refresh"])) + + relation = incremental_model + # update seed in anticipation of incremental model update + row_count_query = f"select * from {project.test_schema}.{seed}" + + seed_rows = len(project.run_sql(row_count_query, fetch="all")) + + # propagate seed state to incremental model according to unique keys + inc_test_model_count = self.update_incremental_model(incremental_model=incremental_model) + + return ResultHolder(seed_count, model_count, seed_rows, inc_test_model_count, relation) + + def check_scenario_correctness(self, expected_fields, test_case_fields, project): + """Invoke assertions to verify correct build functionality""" + # 1. test seed(s) should build afresh + assert expected_fields.seed_count == test_case_fields.seed_count + # 2. test model(s) should build afresh + assert expected_fields.model_count == test_case_fields.model_count + # 3. seeds should have intended row counts post update + assert expected_fields.seed_rows == test_case_fields.seed_rows + # 4. incremental test model(s) should be updated + assert expected_fields.inc_test_model_count == test_case_fields.inc_test_model_count + # 5. 
result table should match intended result set (itself a relation) + check_relations_equal(project.adapter, [expected_fields.relation, test_case_fields.relation]) + + def test__merge_no_updates(self, project): + """seed should match model after incremental run""" + + expected_fields = ResultHolder( + seed_count=1, + model_count=1, + inc_test_model_count=1, + seed_rows=3, + relation="expected_merge_no_updates", + ) + + test_case_fields = self.get_test_fields( + project, + seed="expected_merge_no_updates", + incremental_model="merge_no_updates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) diff --git a/dbt-athena/tests/functional/adapter/test_incremental_tmp_schema.py b/dbt-athena/tests/functional/adapter/test_incremental_tmp_schema.py new file mode 100644 index 00000000..d06e95f2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_incremental_tmp_schema.py @@ -0,0 +1,108 @@ +import pytest +import yaml +from tests.functional.adapter.utils.parse_dbt_run_output import ( + extract_create_statement_table_names, + extract_running_create_statements, +) + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__schema_tmp_sql = """ +{{ config( + materialized='incremental', + incremental_strategy='insert_overwrite', + partitioned_by=['date_column'], + temp_schema=var('temp_schema_name') + ) +}} +select + random() as rnd, + cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column +""" + + +class TestIncrementalTmpSchema: + @pytest.fixture(scope="class") + def models(self): + return {"schema_tmp.sql": models__schema_tmp_sql} + + def test__schema_tmp(self, project, capsys): + relation_name = "schema_tmp" + temp_schema_name = f"{project.test_schema}_tmp" + drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + vars_dict = { + "temp_schema_name": temp_schema_name, + "logical_date": "2024-01-01", + } + + first_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + yaml.safe_dump(vars_dict), + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + first_model_run_result = first_model_run.results[0] + + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 1 + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert len(athena_running_create_statements) == 1 + + incremental_model_run_result_table_name = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + assert temp_schema_name not in incremental_model_run_result_table_name + + vars_dict["logical_date"] = "2024-01-02" + incremental_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + yaml.safe_dump(vars_dict), + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + incremental_model_run_result = incremental_model_run.results[0] + + assert incremental_model_run_result.status == RunStatus.Success + + records_count_incremental_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_incremental_run == 2 + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert 
len(athena_running_create_statements) == 1 + + incremental_model_run_result_table_name = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + assert temp_schema_name == incremental_model_run_result_table_name.split(".")[1].strip('"') + + project.run_sql(drop_temp_schema) diff --git a/dbt-athena/tests/functional/adapter/test_on_schema_change.py b/dbt-athena/tests/functional/adapter/test_on_schema_change.py new file mode 100644 index 00000000..bebd9110 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_on_schema_change.py @@ -0,0 +1,100 @@ +import json + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt, run_dbt_and_capture + +models__table_base_model = """ +{{ + config( + materialized='incremental', + incremental_strategy='append', + on_schema_change=var("on_schema_change"), + table_type=var("table_type"), + ) +}} + +select + 1 as id, + 'test 1' as name +{%- if is_incremental() -%} + ,current_date as updated_at +{%- endif -%} +""" + + +class TestOnSchemaChange: + @pytest.fixture(scope="class") + def models(self): + models = {} + for table_type in ["hive", "iceberg"]: + for on_schema_change in ["sync_all_columns", "append_new_columns", "ignore", "fail"]: + models[f"{table_type}_on_schema_change_{on_schema_change}.sql"] = models__table_base_model + return models + + def _column_names(self, project, relation_name): + result = project.run_sql(f"show columns from {relation_name}", fetch="all") + column_names = [row[0].strip() for row in result] + return column_names + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__sync_all_columns(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_sync_all_columns" + vars = {"on_schema_change": "sync_all_columns", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental = run_dbt(args) + assert model_run_incremental.results[0].status == RunStatus.Success + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == ["id", "name", "updated_at"] + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__append_new_columns(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_append_new_columns" + vars = {"on_schema_change": "append_new_columns", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental = run_dbt(args) + assert model_run_incremental.results[0].status == RunStatus.Success + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == ["id", "name", "updated_at"] + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__ignore(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_ignore" + vars = {"on_schema_change": "ignore", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental = run_dbt(args) + assert model_run_incremental.results[0].status == RunStatus.Success + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == 
["id", "name"] + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__fail(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_fail" + vars = {"on_schema_change": "fail", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental, log = run_dbt_and_capture(args, expect_pass=False) + assert model_run_incremental.results[0].status == RunStatus.Error + assert "The source and target schemas on this incremental model are out of sync!" in log + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == ["id", "name"] diff --git a/dbt-athena/tests/functional/adapter/test_partitions.py b/dbt-athena/tests/functional/adapter/test_partitions.py new file mode 100644 index 00000000..da2e5955 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_partitions.py @@ -0,0 +1,352 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +# this query generates 212 records +test_partitions_model_sql = """ +select + random() as rnd, + cast(date_column as date) as date_column, + doy(date_column) as doy +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +test_single_nullable_partition_model_sql = """ +with data as ( + select + random() as col_1, + row_number() over() as id + from + unnest(sequence(1, 200)) +) + +select + col_1, id +from data +union all +select random() as col_1, NULL as id +union all +select random() as col_1, NULL as id +""" + +test_nullable_partitions_model_sql = """ +{{ config( + materialized='table', + format='parquet', + s3_data_naming='table', + partitioned_by=['id', 'date_column'] +) }} + +with data as ( + select + random() as rnd, + row_number() over() as id, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +) + +select + rnd, + case when id <= 50 then null else id end as id, + date_column +from data +union all +select + random() as rnd, + NULL as id, + NULL as date_column +union all +select + random() as rnd, + NULL as id, + cast('2023-09-02' as date) as date_column +union all +select + random() as rnd, + 40 as id, + NULL as date_column +""" + +test_bucket_partitions_sql = """ +with non_random_strings as ( + select + chr(cast(65 + (row_number() over () % 26) as bigint)) || + chr(cast(65 + ((row_number() over () + 1) % 26) as bigint)) || + chr(cast(65 + ((row_number() over () + 4) % 26) as bigint)) as non_random_str + from + (select 1 union all select 2 union all select 3) as temp_table +) +select + cast(date_column as date) as date_column, + doy(date_column) as doy, + nrnd.non_random_str +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-24'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +join non_random_strings nrnd on true +""" + + +class TestHiveTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"models": {"+table_type": "hive", "+materialized": "table", "+partitioned_by": ["date_column", "doy"]}} + + @pytest.fixture(scope="class") + def 
models(self): + return { + "test_hive_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_hive_partitions" + model_run_result_row_count_query = "select count(*) as records from {}.{}".format( + project.test_schema, relation_name + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_iceberg_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergIncrementalPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "incremental", + "+incremental_strategy": "merge", + "+unique_key": "doy", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions_incremental.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + """ + Check that the incremental run works with iceberg and partitioned datasets + """ + + relation_name = "test_iceberg_partitions_incremental" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name, "--full-refresh"]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + incremental_model_run = run_dbt(["run", "--select", relation_name]) + + incremental_model_run_result = incremental_model_run.results[0] + + # check that the model run successfully after incremental run + assert incremental_model_run_result.status == RunStatus.Success + + incremental_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert incremental_records_count == 212 + + +class TestHiveNullValuedPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "hive", + "+materialized": "table", + "+partitioned_by": ["id", "date_column"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + 
"test_nullable_partitions_model.sql": test_nullable_partitions_model_sql, + } + + def test__check_run_with_partitions(self, project): + relation_name = "test_nullable_partitions_model" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + model_run_result_null_id_count_query = ( + f"select count(*) as records from {project.test_schema}.{relation_name} where id is null" + ) + model_run_result_null_date_count_query = ( + f"select count(*) as records from {project.test_schema}.{relation_name} where date_column is null" + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 215 + + null_id_count_first_run = project.run_sql(model_run_result_null_id_count_query, fetch="all")[0][0] + + assert null_id_count_first_run == 52 + + null_date_count_first_run = project.run_sql(model_run_result_null_date_count_query, fetch="all")[0][0] + + assert null_date_count_first_run == 2 + + +class TestHiveSingleNullValuedPartition: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "hive", + "+materialized": "table", + "+partitioned_by": ["id"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_single_nullable_partition_model.sql": test_single_nullable_partition_model_sql, + } + + def test__check_run_with_partitions(self, project): + relation_name = "test_single_nullable_partition_model" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 202 + + +class TestIcebergTablePartitionsBuckets: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["DAY(date_column)", "doy", "bucket(non_random_str, 5)"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_bucket_partitions.sql": test_bucket_partitions_sql, + } + + def test__check_run_with_bucket_and_partitions(self, project): + relation_name = "test_bucket_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 615 + + +class TestIcebergTableBuckets: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["bucket(non_random_str, 5)"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + 
"test_bucket_partitions.sql": test_bucket_partitions_sql, + } + + def test__check_run_with_bucket_in_partitions(self, project): + relation_name = "test_bucket_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 615 diff --git a/dbt-athena/tests/functional/adapter/test_python_submissions.py b/dbt-athena/tests/functional/adapter/test_python_submissions.py new file mode 100644 index 00000000..55f8beb9 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_python_submissions.py @@ -0,0 +1,177 @@ +# import os +# import shutil +# from collections import Counter +# from copy import deepcopy + +# import pytest +# import yaml + +# from dbt.tests.adapter.python_model.test_python_model import ( +# BasePythonIncrementalTests, +# BasePythonModelTests, +# ) +# from dbt.tests.util import run_dbt + +# basic_sql = """ +# {{ config(materialized="table") }} +# select 1 as column_1, 2 as column_2, '{{ run_started_at.strftime("%Y-%m-%d") }}' as run_date +# """ + +# basic_python = """ +# def model(dbt, spark): +# dbt.config( +# materialized='table', +# ) +# df = dbt.ref("model") +# return df +# """ + +# basic_spark_python = """ +# def model(dbt, spark_session): +# dbt.config(materialized="table") + +# data = [(1,), (2,), (3,), (4,)] + +# df = spark_session.createDataFrame(data, ["A"]) + +# return df +# """ + +# second_sql = """ +# select * from {{ref('my_python_model')}} +# """ + +# schema_yml = """version: 2 +# models: +# - name: model +# versions: +# - v: 1 +# """ + + +# class TestBasePythonModelTests(BasePythonModelTests): +# @pytest.fixture(scope="class") +# def models(self): +# return { +# "schema.yml": schema_yml, +# "model.sql": basic_sql, +# "my_python_model.py": basic_python, +# "spark_model.py": basic_spark_python, +# "second_sql_model.sql": second_sql, +# } + + +# incremental_python = """ +# def model(dbt, spark_session): +# dbt.config(materialized="incremental") +# df = dbt.ref("model") + +# if dbt.is_incremental: +# max_from_this = ( +# f"select max(run_date) from {dbt.this.schema}.{dbt.this.identifier}" +# ) +# df = df.filter(df.run_date >= spark_session.sql(max_from_this).collect()[0][0]) + +# return df +# """ + + +# class TestBasePythonIncrementalTests(BasePythonIncrementalTests): +# @pytest.fixture(scope="class") +# def project_config_update(self): +# return {"models": {"+incremental_strategy": "append"}} + +# @pytest.fixture(scope="class") +# def models(self): +# return {"model.sql": basic_sql, "incremental.py": incremental_python} + +# def test_incremental(self, project): +# vars_dict = { +# "test_run_schema": project.test_schema, +# } + +# results = run_dbt(["run", "--vars", yaml.safe_dump(vars_dict)]) +# assert len(results) == 2 + + +# class TestPythonClonePossible: +# """Test that basic clone operations are possible on Python models. 
This +# class has been adapted from the BaseClone and BaseClonePossible classes in +# dbt-core and modified to test Python models in addition to SQL models.""" + +# @pytest.fixture(scope="class") +# def models(self): +# return { +# "schema.yml": schema_yml, +# "model.sql": basic_sql, +# "my_python_model.py": basic_python, +# } + +# @pytest.fixture(scope="class") +# def other_schema(self, unique_schema): +# return unique_schema + "_other" + +# @pytest.fixture(scope="class") +# def profiles_config_update(self, dbt_profile_target, unique_schema, other_schema): +# """Update the profiles config to duplicate the default schema to a +# separate schema called `otherschema`.""" +# outputs = {"default": dbt_profile_target, "otherschema": deepcopy(dbt_profile_target)} +# outputs["default"]["schema"] = unique_schema +# outputs["otherschema"]["schema"] = other_schema +# return {"test": {"outputs": outputs, "target": "default"}} + +# def copy_state(self, project_root): +# """Copy the manifest.json project for a run into a separate `state/` +# directory inside the project root, so that we can reference it +# for cloning.""" +# state_path = os.path.join(project_root, "state") +# if not os.path.exists(state_path): +# os.makedirs(state_path) +# shutil.copyfile(f"{project_root}/target/manifest.json", f"{state_path}/manifest.json") + +# def run_and_save_state(self, project_root): +# """Run models and save the state to a separate directory to prepare +# for testing clone operations.""" +# results = run_dbt(["run"]) +# assert len(results) == 2 +# self.copy_state(project_root) + +# def assert_relation_types_match_counter(self, project, schema, counter): +# """Check that relation types in a given database and schema match the +# counts specified by a Counter object.""" +# schema_relations = project.adapter.list_relations(database=project.database, schema=schema) +# schema_types = [str(r.type) for r in schema_relations] +# assert Counter(schema_types) == counter + +# def test_can_clone_true(self, project, unique_schema, other_schema): +# """Test that Python models can be cloned using `dbt clone`. 
Adapted from +# the BaseClonePossible.test_can_clone_true test in dbt-core.""" +# project.create_test_schema(other_schema) +# self.run_and_save_state(project.project_root) + +# # Base models should be materialized as tables +# self.assert_relation_types_match_counter(project, unique_schema, Counter({"table": 2})) + +# clone_args = [ +# "clone", +# "--state", +# "state", +# "--target", +# "otherschema", +# ] + +# results = run_dbt(clone_args) +# assert len(results) == 2 + +# # Cloned models should be materialized as views +# self.assert_relation_types_match_counter(project, other_schema, Counter({"view": 2})) + +# # Objects already exist, so this is a no-op +# results = run_dbt(clone_args) +# assert len(results) == 2 +# assert all("no-op" in r.message.lower() for r in results) + +# # Recreate all objects +# results = run_dbt([*clone_args, "--full-refresh"]) +# assert len(results) == 2 +# assert not any("no-op" in r.message.lower() for r in results) diff --git a/dbt-athena/tests/functional/adapter/test_quote_seed_column.py b/dbt-athena/tests/functional/adapter/test_quote_seed_column.py new file mode 100644 index 00000000..c0845c39 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_quote_seed_column.py @@ -0,0 +1,39 @@ +import pytest + +from dbt.tests.adapter.basic.files import ( + base_materialized_var_sql, + base_table_sql, + base_view_sql, + schema_base_yml, + seeds_base_csv, +) +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations + +seed_base_csv_underscore_column = seeds_base_csv.replace("name", "_name") + +quote_columns_seed_schema = """ + +seeds: +- name: base + config: + quote_columns: true + +""" + + +class TestSimpleMaterializationsHive(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + schema = schema_base_yml + quote_columns_seed_schema + return { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "base.csv": seed_base_csv_underscore_column, + } diff --git a/dbt-athena/tests/functional/adapter/test_retries_iceberg.py b/dbt-athena/tests/functional/adapter/test_retries_iceberg.py new file mode 100644 index 00000000..f4711d9d --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_retries_iceberg.py @@ -0,0 +1,136 @@ +"""Test parallel insert into iceberg table.""" +import copy +import os + +import pytest + +from dbt.artifacts.schemas.results import RunStatus +from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture + +PARALLELISM = 10 + +base_dbt_profile = { + "type": "athena", + "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"), + "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"), + "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"), + "database": os.getenv("DBT_TEST_ATHENA_DATABASE"), + "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"), + "threads": PARALLELISM, + "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")), + "num_retries": 0, + "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"), + "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, +} + +models__target = """ +{{ + config( + table_type='iceberg', + materialized='table' + ) +}} + +select * from ( + values + (1, -1) +) as t (id, status) +limit 0 + +""" + +models__source = { + f"model_{i}.sql": f""" +{{{{ + config( + table_type='iceberg', + materialized='table', + tags=['src'], + pre_hook='insert into 
target values ({i}, {i})' + ) +}}}} + +select 1 as col +""" + for i in range(PARALLELISM) +} + +seeds__expected_target_init = "id,status" +seeds__expected_target_post = "id,status\n" + "\n".join([f"{i},{i}" for i in range(PARALLELISM)]) + + +class TestIcebergRetriesDisabled: + @pytest.fixture(scope="class") + def dbt_profile_target(self): + profile = copy.deepcopy(base_dbt_profile) + profile["num_iceberg_retries"] = 0 + return profile + + @pytest.fixture(scope="class") + def models(self): + return {**{"target.sql": models__target}, **models__source} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "expected_target_init.csv": seeds__expected_target_init, + "expected_target_post.csv": seeds__expected_target_post, + } + + def test__retries_iceberg(self, project): + """Seed should match the model after run""" + + expected__init_seed_name = "expected_target_init" + run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"]) + + relation_name = "target" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + check_relations_equal(project.adapter, [relation_name, expected__init_seed_name]) + + expected__post_seed_name = "expected_target_post" + run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"]) + + run, log = run_dbt_and_capture(["run", "--select", "tag:src"], expect_pass=False) + assert any(model_run_result.status == RunStatus.Error for model_run_result in run.results) + assert "ICEBERG_COMMIT_ERROR" in log + + +class TestIcebergRetriesEnabled: + @pytest.fixture(scope="class") + def dbt_profile_target(self): + profile = copy.deepcopy(base_dbt_profile) + # we set iceberg retries to a high number to ensure that the test will pass + profile["num_iceberg_retries"] = PARALLELISM * 5 + return profile + + @pytest.fixture(scope="class") + def models(self): + return {**{"target.sql": models__target}, **models__source} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "expected_target_init.csv": seeds__expected_target_init, + "expected_target_post.csv": seeds__expected_target_post, + } + + def test__retries_iceberg(self, project): + """Seed should match the model after run""" + + expected__init_seed_name = "expected_target_init" + run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"]) + + relation_name = "target" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + check_relations_equal(project.adapter, [relation_name, expected__init_seed_name]) + + expected__post_seed_name = "expected_target_post" + run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"]) + + run = run_dbt(["run", "--select", "tag:src"]) + assert all([model_run_result.status == RunStatus.Success for model_run_result in run.results]) + check_relations_equal(project.adapter, [relation_name, expected__post_seed_name]) diff --git a/dbt-athena/tests/functional/adapter/test_seed_by_insert.py b/dbt-athena/tests/functional/adapter/test_seed_by_insert.py new file mode 100644 index 00000000..cb1f91c5 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_seed_by_insert.py @@ -0,0 +1,30 @@ +import pytest + +from dbt.tests.adapter.basic.files import ( + base_materialized_var_sql, + base_table_sql, + base_view_sql, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations + +seed_by_insert_schema = """ 
+ +seeds: +- name: base + config: + seed_by_insert: True + +""" + + +class TestSimpleMaterializationsHive(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + schema = schema_base_yml + seed_by_insert_schema + return { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema, + } diff --git a/dbt-athena/tests/functional/adapter/test_snapshot.py b/dbt-athena/tests/functional/adapter/test_snapshot.py new file mode 100644 index 00000000..3debfb04 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_snapshot.py @@ -0,0 +1,300 @@ +""" +These are the test suites from `dbt.tests.adapter.basic.test_snapshot_check_cols` +and `dbt.tests.adapter.basic.test_snapshot_timestamp`, but slightly adapted. +The original test suites for snapshots didn't work out-of-the-box, because Athena +has no support for UPDATE statements on hive tables. This is required by those +tests to update the input seeds. + +This file also includes test suites for the iceberg tables. +""" +import pytest + +from dbt.tests.adapter.basic.files import ( + cc_all_snapshot_sql, + cc_date_snapshot_sql, + cc_name_snapshot_sql, + seeds_added_csv, + seeds_base_csv, + ts_snapshot_sql, +) +from dbt.tests.util import relation_from_name, run_dbt + + +def check_relation_rows(project, snapshot_name, count): + """Assert that the relation has the given number of rows""" + relation = relation_from_name(project.adapter, snapshot_name) + result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one") + assert result[0] == count + + +def check_relation_columns(project, snapshot_name, count): + """Assert that the relation has the given number of columns""" + relation = relation_from_name(project.adapter, snapshot_name) + result = project.run_sql(f"select * from {relation}", fetch="one") + assert len(result) == count + + +seeds_altered_added_csv = """ +11,Mateo,2014-09-08T17:04:27 +12,Julian,2000-02-05T11:48:30 +13,Gabriel,2001-07-11T07:32:52 +14,Isaac,2002-11-25T03:22:28 +15,Levi,2009-11-16T11:57:15 +16,Elizabeth,2005-04-10T03:50:11 +17,Grayson,2019-08-07T19:28:17 +18,Dylan,2014-03-02T11:50:41 +19,Jayden,2009-06-07T07:12:49 +20,Luke,2003-12-06T21:42:18 +""".lstrip() + + +seeds_altered_base_csv = """ +id,name,some_date +1,Easton_updated,1981-05-20T06:46:51 +2,Lillian_updated,1978-09-03T18:10:33 +3,Jeremiah_updated,1982-03-11T03:59:51 +4,Nolan_updated,1976-05-06T20:21:35 +5,Hannah_updated,1982-06-23T05:41:26 +6,Eleanor_updated,1991-08-10T23:12:21 +7,Lily_updated,1971-03-29T14:58:02 +8,Jonathan_updated,1988-02-26T02:55:24 +9,Adrian_updated,1994-02-09T13:14:23 +10,Nora_updated,1976-03-01T16:51:39 +""".lstrip() + + +iceberg_cc_all_snapshot_sql = """ +{% snapshot cc_all_snapshot %} + {{ config( + check_cols='all', + unique_key='id', + strategy='check', + target_database=database, + target_schema=schema, + table_type='iceberg' + ) }} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +iceberg_cc_name_snapshot_sql = """ +{% snapshot cc_name_snapshot %} + {{ config( + check_cols=['name'], + unique_key='id', + strategy='check', + target_database=database, + target_schema=schema, + table_type='iceberg' + ) }} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +iceberg_cc_date_snapshot_sql = """ +{% snapshot cc_date_snapshot %} + 
{{ config( + check_cols=['some_date'], + unique_key='id', + strategy='check', + target_database=database, + target_schema=schema, + table_type='iceberg' + ) }} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +iceberg_ts_snapshot_sql = """ +{% snapshot ts_snapshot %} + {{ config( + strategy='timestamp', + unique_key='id', + updated_at='some_date', + target_database=database, + target_schema=schema, + table_type='iceberg', + )}} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +class TestSnapshotCheckColsHive: + @pytest.fixture(scope="class") + def seeds(self): + return { + "base.csv": seeds_base_csv, + "added.csv": seeds_added_csv, + "updated_1.csv": seeds_base_csv + seeds_altered_added_csv, + "updated_2.csv": seeds_altered_base_csv + seeds_altered_added_csv, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "hive_snapshot_strategy_check_cols"} + + @pytest.fixture(scope="class") + def snapshots(self): + return { + "cc_all_snapshot.sql": cc_all_snapshot_sql, + "cc_date_snapshot.sql": cc_date_snapshot_sql, + "cc_name_snapshot.sql": cc_name_snapshot_sql, + } + + def test_snapshot_check_cols(self, project): + # seed command + results = run_dbt(["seed"]) + assert len(results) == 4 + + # snapshot command + results = run_dbt(["snapshot"]) + for result in results: + assert result.status == "success" + + check_relation_columns(project, "cc_all_snapshot", 7) + check_relation_columns(project, "cc_name_snapshot", 7) + check_relation_columns(project, "cc_date_snapshot", 7) + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 10) + check_relation_rows(project, "cc_name_snapshot", 10) + check_relation_rows(project, "cc_date_snapshot", 10) + + relation_from_name(project.adapter, "cc_all_snapshot") + + # point at the "added" seed so the snapshot sees 10 new rows + results = run_dbt(["--no-partial-parse", "snapshot", "--vars", "seed_name: added"]) + for result in results: + assert result.status == "success" + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 20) + check_relation_rows(project, "cc_name_snapshot", 20) + check_relation_rows(project, "cc_date_snapshot", 20) + + # re-run snapshots, using "updated_1" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_1"]) + for result in results: + assert result.status == "success" + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 30) + check_relation_rows(project, "cc_date_snapshot", 30) + # unchanged: only the timestamp changed + check_relation_rows(project, "cc_name_snapshot", 20) + + # re-run snapshots, using "updated_2" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_2"]) + for result in results: + assert result.status == "success" + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 40) + check_relation_rows(project, "cc_name_snapshot", 30) + # does not see name updates + check_relation_rows(project, "cc_date_snapshot", 30) + + +class TestSnapshotTimestampHive: + @pytest.fixture(scope="class") + def seeds(self): + return { + "base.csv": seeds_base_csv, + "added.csv": seeds_added_csv, + "updated_1.csv": seeds_base_csv + seeds_altered_added_csv, + "updated_2.csv": seeds_altered_base_csv + seeds_altered_added_csv, + } + + 
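+    # Note on the seed variants above: relative to the "added" seed, "updated_1" only alters the timestamps of rows 11-20, while "updated_2" additionally renames the base rows (the "*_updated" names in seeds_altered_base_csv) without changing their timestamps; the row counts asserted in the tests below depend on this.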
@pytest.fixture(scope="class") + def snapshots(self): + return { + "ts_snapshot.sql": ts_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "hive_snapshot_strategy_timestamp"} + + def test_snapshot_timestamp(self, project): + # seed command + results = run_dbt(["seed"]) + assert len(results) == 4 + + # snapshot command + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relation_columns(project, "ts_snapshot", 7) + + # snapshot has 10 rows + check_relation_rows(project, "ts_snapshot", 10) + + # point at the "added" seed so the snapshot sees 10 new rows + results = run_dbt(["snapshot", "--vars", "seed_name: added"]) + for result in results: + assert result.status == "success" + + # snapshot now has 20 rows + check_relation_rows(project, "ts_snapshot", 20) + + # re-run snapshots, using "updated_1" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_1"]) + for result in results: + assert result.status == "success" + + # snapshot now has 30 rows + check_relation_rows(project, "ts_snapshot", 30) + + # re-run snapshots, using "updated_2" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_2"]) + for result in results: + assert result.status == "success" + + # snapshot still has 30 rows because timestamp not updated + check_relation_rows(project, "ts_snapshot", 30) + + +class TestIcebergSnapshotTimestamp(TestSnapshotTimestampHive): + @pytest.fixture(scope="class") + def snapshots(self): + return { + "ts_snapshot.sql": iceberg_ts_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "iceberg_snapshot_strategy_timestamp"} + + +class TestIcebergSnapshotCheckCols(TestSnapshotCheckColsHive): + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "iceberg_snapshot_strategy_check_cols"} + + @pytest.fixture(scope="class") + def snapshots(self): + return { + "cc_all_snapshot.sql": iceberg_cc_all_snapshot_sql, + "cc_date_snapshot.sql": iceberg_cc_date_snapshot_sql, + "cc_name_snapshot.sql": iceberg_cc_name_snapshot_sql, + } diff --git a/dbt-athena/tests/functional/adapter/test_unique_tmp_table_suffix.py b/dbt-athena/tests/functional/adapter/test_unique_tmp_table_suffix.py new file mode 100644 index 00000000..0f188ce1 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_unique_tmp_table_suffix.py @@ -0,0 +1,130 @@ +import re + +import pytest +from tests.functional.adapter.utils.parse_dbt_run_output import ( + extract_create_statement_table_names, + extract_running_create_statements, +) + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__unique_tmp_table_suffix_sql = """ +{{ config( + materialized='incremental', + incremental_strategy='insert_overwrite', + partitioned_by=['date_column'], + unique_tmp_table_suffix=True + ) +}} +select + random() as rnd, + cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column +""" + + +class TestUniqueTmpTableSuffix: + @pytest.fixture(scope="class") + def models(self): + return {"unique_tmp_table_suffix.sql": models__unique_tmp_table_suffix_sql} + + def test__unique_tmp_table_suffix(self, project, capsys): + relation_name = "unique_tmp_table_suffix" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + expected_unique_table_name_re = ( + r"unique_tmp_table_suffix__dbt_tmp_" + r"[0-9a-fA-F]{8}_[0-9a-fA-F]{4}_[0-9a-fA-F]{4}_[0-9a-fA-F]{4}_[0-9a-fA-F]{12}" + ) + + 
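+        # The expected suffix is "__dbt_tmp_" followed by a UUID-shaped string (8_4_4_4_12 hex groups, underscore-separated); only the incremental runs below should log a create statement containing it.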
first_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + '{"logical_date": "2024-01-01"}', + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + first_model_run_result = first_model_run.results[0] + + assert first_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert len(athena_running_create_statements) == 1 + + first_model_run_result_table_name = extract_create_statement_table_names(athena_running_create_statements[0])[0] + + # Run statements logged output should not contain unique table suffix after first run + assert not bool(re.search(expected_unique_table_name_re, first_model_run_result_table_name)) + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 1 + + incremental_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + '{"logical_date": "2024-01-02"}', + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + incremental_model_run_result = incremental_model_run.results[0] + + assert incremental_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert len(athena_running_create_statements) == 1 + + incremental_model_run_result_table_name = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + # Run statements logged for subsequent incremental model runs should use unique table suffix + assert bool(re.search(expected_unique_table_name_re, incremental_model_run_result_table_name)) + + assert first_model_run_result_table_name != incremental_model_run_result_table_name + + incremental_model_run_2 = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + '{"logical_date": "2024-01-03"}', + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + incremental_model_run_result = incremental_model_run_2.results[0] + + assert incremental_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + incremental_model_run_result_table_name_2 = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + assert incremental_model_run_result_table_name != incremental_model_run_result_table_name_2 + + assert first_model_run_result_table_name != incremental_model_run_result_table_name_2 diff --git a/dbt-athena/tests/functional/adapter/test_unit_testing.py b/dbt-athena/tests/functional/adapter/test_unit_testing.py new file mode 100644 index 00000000..5ec246c2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_unit_testing.py @@ -0,0 +1,36 @@ +import pytest + +from dbt.tests.adapter.unit_testing.test_case_insensitivity import ( + BaseUnitTestCaseInsensivity, +) +from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput +from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes + + +class TestAthenaUnitTestingTypes(BaseUnitTestingTypes): + @pytest.fixture + def data_types(self): + # sql_value, yaml_value + return [ + ["1", "1"], + ["2.0", "2.0"], + ["'12345'", "12345"], + ["'string'", "string"], + ["true", "true"], + ["date '2024-04-01'", "2024-04-01"], + ["timestamp '2024-04-01 00:00:00.000'", "'2024-04-01 00:00:00.000'"], + # TODO: activate once safe_cast 
supports complex structures + # ["array[1, 2, 3]", "[1, 2, 3]"], + # [ + # "map(array['10', '15', '20'], array['t', 'f', NULL])", + # """'{"10: "t", "15": "f", "20": null}'""", + # ], + ] + + +class TestAthenaUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity): + pass + + +class TestAthenaUnitTestInvalidInput(BaseUnitTestInvalidInput): + pass diff --git a/dbt-athena/tests/functional/adapter/utils/parse_dbt_run_output.py b/dbt-athena/tests/functional/adapter/utils/parse_dbt_run_output.py new file mode 100644 index 00000000..4f448420 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/utils/parse_dbt_run_output.py @@ -0,0 +1,36 @@ +import json +import re +from typing import List + + +def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]: + sql_create_statements = [] + # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..." + for events_msg in dbt_run_capsys_output.split("\n")[1:]: + base_msg_data = None + # Best effort solution to avoid invalid records and blank lines + try: + base_msg_data = json.loads(events_msg).get("data") + except json.JSONDecodeError: + pass + """First run will not produce data.sql object in the execution logs, only data.base_msg + containing the "Running Athena query:" initial create statement. + Subsequent incremental runs will only contain the insert from the tmp table into the model + table destination. + Since we want to compare both run create statements, we need to handle both cases""" + if base_msg_data: + base_msg = base_msg_data.get("base_msg") + if "Running Athena query:" in str(base_msg): + if "create table" in base_msg: + sql_create_statements.append(base_msg) + + if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data: + if "create table" in base_msg_data.get("sql"): + sql_create_statements.append(base_msg_data.get("sql")) + + return sql_create_statements + + +def extract_create_statement_table_names(sql_create_statement: str) -> List[str]: + table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement) + return [table_name.rstrip() for table_name in table_names] diff --git a/dbt-athena/tests/functional/adapter/utils/test_utils.py b/dbt-athena/tests/functional/adapter/utils/test_utils.py new file mode 100644 index 00000000..ef13c062 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/utils/test_utils.py @@ -0,0 +1,196 @@ +import pytest +from tests.functional.adapter.fixture_datediff import ( + models__test_datediff_sql, + seeds__data_datediff_csv, +) +from tests.functional.adapter.fixture_split_parts import ( + models__test_split_part_sql, + models__test_split_part_yml, +) + +from dbt.tests.adapter.utils.fixture_datediff import models__test_datediff_yml +from dbt.tests.adapter.utils.test_any_value import BaseAnyValue +from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend +from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat +from dbt.tests.adapter.utils.test_array_construct import BaseArrayConstruct +from dbt.tests.adapter.utils.test_bool_or import BaseBoolOr +from dbt.tests.adapter.utils.test_concat import BaseConcat +from dbt.tests.adapter.utils.test_current_timestamp import BaseCurrentTimestampNaive +from dbt.tests.adapter.utils.test_date_trunc import BaseDateTrunc +from dbt.tests.adapter.utils.test_dateadd import BaseDateAdd +from dbt.tests.adapter.utils.test_datediff import BaseDateDiff +from dbt.tests.adapter.utils.test_escape_single_quotes import ( + 
BaseEscapeSingleQuotesQuote, +) +from dbt.tests.adapter.utils.test_except import BaseExcept +from dbt.tests.adapter.utils.test_hash import BaseHash +from dbt.tests.adapter.utils.test_intersect import BaseIntersect +from dbt.tests.adapter.utils.test_length import BaseLength +from dbt.tests.adapter.utils.test_listagg import BaseListagg +from dbt.tests.adapter.utils.test_position import BasePosition +from dbt.tests.adapter.utils.test_replace import BaseReplace +from dbt.tests.adapter.utils.test_right import BaseRight +from dbt.tests.adapter.utils.test_split_part import BaseSplitPart +from dbt.tests.adapter.utils.test_string_literal import BaseStringLiteral + +models__array_concat_expected_sql = """ +select 1 as id, {{ array_construct([1,2,3,4,5,6]) }} as array_col +""" + + +models__array_concat_actual_sql = """ +select 1 as id, {{ array_concat(array_construct([1,2,3]), array_construct([4,5,6])) }} as array_col +""" + + +class TestAnyValue(BaseAnyValue): + pass + + +class TestArrayAppend(BaseArrayAppend): + pass + + +class TestArrayConstruct(BaseArrayConstruct): + pass + + +# Altered test because can't merge empty and non-empty arrays. +class TestArrayConcat(BaseArrayConcat): + @pytest.fixture(scope="class") + def models(self): + return { + "actual.sql": models__array_concat_actual_sql, + "expected.sql": models__array_concat_expected_sql, + } + + +class TestBoolOr(BaseBoolOr): + pass + + +class TestConcat(BaseConcat): + pass + + +class TestEscapeSingleQuotes(BaseEscapeSingleQuotesQuote): + pass + + +class TestExcept(BaseExcept): + pass + + +class TestHash(BaseHash): + pass + + +class TestIntersect(BaseIntersect): + pass + + +class TestLength(BaseLength): + pass + + +class TestPosition(BasePosition): + pass + + +class TestReplace(BaseReplace): + pass + + +class TestRight(BaseRight): + pass + + +class TestSplitPart(BaseSplitPart): + @pytest.fixture(scope="class") + def models(self): + return { + "test_split_part.yml": models__test_split_part_yml, + "test_split_part.sql": self.interpolate_macro_namespace(models__test_split_part_sql, "split_part"), + } + + +class TestStringLiteral(BaseStringLiteral): + pass + + +class TestDateTrunc(BaseDateTrunc): + pass + + +class TestDateAdd(BaseDateAdd): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "name": "test_date_add", + "seeds": { + "test": { + "date_add": { + "data_dateadd": { + "+column_types": { + "from_time": "timestamp", + "result": "timestamp", + }, + } + } + }, + }, + } + + +class TestDateDiff(BaseDateDiff): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "name": "test_date_diff", + "seeds": { + "test": { + "date_diff": { + "data_datediff": { + "+column_types": { + "first_date": "timestamp", + "second_date": "timestamp", + }, + } + } + }, + }, + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"data_datediff.csv": seeds__data_datediff_csv} + + @pytest.fixture(scope="class") + def models(self): + return { + "test_datediff.yml": models__test_datediff_yml, + "test_datediff.sql": self.interpolate_macro_namespace(models__test_datediff_sql, "datediff"), + } + + +# TODO: Activate this once we have datatypes.sql macro +# class TestCastBoolToText(BaseCastBoolToText): +# pass + + +# TODO: Activate this once we have datatypes.sql macro +# class TestSafeCast(BaseSafeCast): +# pass + + +class TestListagg(BaseListagg): + pass + + +# TODO: Implement this macro when needed +# class TestLastDay(BaseLastDay): +# pass + + +class TestCurrentTimestamp(BaseCurrentTimestampNaive): + pass 
diff --git a/dbt-athena/tests/functional/conftest.py b/dbt-athena/tests/functional/conftest.py new file mode 100644 index 00000000..1591459b --- /dev/null +++ b/dbt-athena/tests/functional/conftest.py @@ -0,0 +1,26 @@ +import os + +import pytest + +# Import the functional fixtures as a plugin +# Note: fixtures with session scope need to be local +pytest_plugins = ["dbt.tests.fixtures.project"] + + +# The profile dictionary, used to write out profiles.yml +@pytest.fixture(scope="class") +def dbt_profile_target(): + return { + "type": "athena", + "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"), + "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"), + "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"), + "database": os.getenv("DBT_TEST_ATHENA_DATABASE"), + "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"), + "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"), + "threads": int(os.getenv("DBT_TEST_ATHENA_THREADS", "1")), + "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")), + "num_retries": int(os.getenv("DBT_TEST_ATHENA_NUM_RETRIES", "2")), + "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, + "spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"), + } diff --git a/dbt-athena/tests/unit/__init__.py b/dbt-athena/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/tests/unit/conftest.py b/dbt-athena/tests/unit/conftest.py new file mode 100644 index 00000000..21d051c7 --- /dev/null +++ b/dbt-athena/tests/unit/conftest.py @@ -0,0 +1,75 @@ +from io import StringIO +import os +from unittest.mock import MagicMock, patch + +import boto3 +import pytest + +from dbt_common.events import get_event_manager +from dbt_common.events.base_types import EventLevel +from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter + +from dbt.adapters.athena import connections +from dbt.adapters.athena.connections import AthenaCredentials + +from tests.unit.utils import MockAWSService +from tests.unit import constants + + +@pytest.fixture(scope="class") +def athena_client(): + with patch.object(boto3.session.Session, "client", return_value=MagicMock()) as mock_athena_client: + return mock_athena_client + + +@pytest.fixture(scope="function") +def aws_credentials(): + """Mocked AWS Credentials for moto.""" + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = constants.AWS_REGION + + +@patch.object(connections, "AthenaCredentials") +@pytest.fixture(scope="class") +def athena_credentials(): + return AthenaCredentials( + database=constants.DATA_CATALOG_NAME, + schema=constants.DATABASE_NAME, + s3_staging_dir=constants.S3_STAGING_DIR, + region_name=constants.AWS_REGION, + work_group=constants.ATHENA_WORKGROUP, + spark_work_group=constants.SPARK_WORKGROUP, + ) + + +@pytest.fixture() +def mock_aws_service(aws_credentials) -> MockAWSService: + return MockAWSService() + + +@pytest.fixture(scope="function") +def dbt_error_caplog() -> StringIO: + return _setup_custom_caplog("dbt_error", EventLevel.ERROR) + + +@pytest.fixture(scope="function") +def dbt_debug_caplog() -> StringIO: + return _setup_custom_caplog("dbt_debug", EventLevel.DEBUG) + + +def _setup_custom_caplog(name: str, level: EventLevel): + string_buf = StringIO() + capture_config = LoggerConfig( + name=name, + level=level, + use_colors=False, + 
line_format=LineFormat.PlainText, + filter=NoFilter, + output_stream=string_buf, + ) + event_manager = get_event_manager() + event_manager.add_logger(capture_config) + return string_buf diff --git a/dbt-athena/tests/unit/constants.py b/dbt-athena/tests/unit/constants.py new file mode 100644 index 00000000..f549ad64 --- /dev/null +++ b/dbt-athena/tests/unit/constants.py @@ -0,0 +1,11 @@ +CATALOG_ID = "12345678910" +DATA_CATALOG_NAME = "awsdatacatalog" +SHARED_DATA_CATALOG_NAME = "9876543210" +FEDERATED_QUERY_CATALOG_NAME = "federated_query_data_source" +DATABASE_NAME = "test_dbt_athena" +BUCKET = "test-dbt-athena" +AWS_REGION = "eu-west-1" +S3_STAGING_DIR = "s3://my-bucket/test-dbt/" +S3_TMP_TABLE_DIR = "s3://my-bucket/test-dbt-temp/" +ATHENA_WORKGROUP = "dbt-athena-adapter" +SPARK_WORKGROUP = "spark" diff --git a/dbt-athena/tests/unit/fixtures.py b/dbt-athena/tests/unit/fixtures.py new file mode 100644 index 00000000..ba3d048e --- /dev/null +++ b/dbt-athena/tests/unit/fixtures.py @@ -0,0 +1 @@ +seed_data = [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}] diff --git a/dbt-athena/tests/unit/test_adapter.py b/dbt-athena/tests/unit/test_adapter.py new file mode 100644 index 00000000..e3c3b923 --- /dev/null +++ b/dbt-athena/tests/unit/test_adapter.py @@ -0,0 +1,1330 @@ +import datetime +import decimal +from multiprocessing import get_context +from unittest import mock +from unittest.mock import patch + +import agate +import boto3 +import botocore +import pytest +from dbt_common.clients import agate_helper +from dbt_common.exceptions import ConnectionError, DbtRuntimeError +from moto import mock_aws +from moto.core import DEFAULT_ACCOUNT_ID + +from dbt.adapters.athena import AthenaAdapter +from dbt.adapters.athena import Plugin as AthenaPlugin +from dbt.adapters.athena.column import AthenaColumn +from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter +from dbt.adapters.athena.exceptions import S3LocationException +from dbt.adapters.athena.relation import AthenaRelation, TableType +from dbt.adapters.athena.utils import AthenaCatalogType +from dbt.adapters.contracts.connection import ConnectionState +from dbt.adapters.contracts.relation import RelationType + +from .constants import ( + ATHENA_WORKGROUP, + AWS_REGION, + BUCKET, + DATA_CATALOG_NAME, + DATABASE_NAME, + FEDERATED_QUERY_CATALOG_NAME, + S3_STAGING_DIR, + SHARED_DATA_CATALOG_NAME, +) +from .fixtures import seed_data +from .utils import TestAdapterConversions, config_from_parts_or_dicts, inject_adapter + + +class TestAthenaAdapter: + def setup_method(self, _): + self.config = TestAthenaAdapter._config_from_settings() + self._adapter = None + self.used_schemas = frozenset( + { + ("awsdatacatalog", "foo"), + ("awsdatacatalog", "quux"), + ("awsdatacatalog", "baz"), + (SHARED_DATA_CATALOG_NAME, "foo"), + (FEDERATED_QUERY_CATALOG_NAME, "foo"), + } + ) + + @property + def adapter(self): + if self._adapter is None: + self._adapter = AthenaAdapter(self.config, get_context("spawn")) + inject_adapter(self._adapter, AthenaPlugin) + return self._adapter + + @staticmethod + def _config_from_settings(settings={}): + project_cfg = { + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "config-version": 2, + } + + profile_cfg = { + "outputs": { + "test": { + **{ + "type": "athena", + "s3_staging_dir": S3_STAGING_DIR, + "region_name": AWS_REGION, + "database": DATA_CATALOG_NAME, + "work_group": ATHENA_WORKGROUP, + "schema": DATABASE_NAME, + }, + **settings, + } + }, + 
"target": "test", + } + + return config_from_parts_or_dicts(project_cfg, profile_cfg) + + @mock.patch("dbt.adapters.athena.connections.AthenaConnection") + def test_acquire_connection_validations(self, connection_cls): + try: + connection = self.adapter.acquire_connection("dummy") + except DbtRuntimeError as e: + pytest.fail(f"got ValidationException: {e}") + except BaseException as e: + pytest.fail(f"acquiring connection failed with unknown exception: {e}") + + connection_cls.assert_not_called() + connection.handle + connection_cls.assert_called_once() + _, arguments = connection_cls.call_args_list[0] + assert arguments["s3_staging_dir"] == "s3://my-bucket/test-dbt/" + assert arguments["endpoint_url"] is None + assert arguments["schema_name"] == "test_dbt_athena" + assert arguments["work_group"] == "dbt-athena-adapter" + assert arguments["cursor_class"] == AthenaCursor + assert isinstance(arguments["formatter"], AthenaParameterFormatter) + assert arguments["poll_interval"] == 1.0 + assert arguments["retry_config"].attempt == 6 + assert arguments["retry_config"].exceptions == ( + "ThrottlingException", + "TooManyRequestsException", + "InternalServerException", + ) + + @mock.patch("dbt.adapters.athena.connections.AthenaConnection") + def test_acquire_connection(self, connection_cls): + connection = self.adapter.acquire_connection("dummy") + + connection_cls.assert_not_called() + connection.handle + assert connection.state == ConnectionState.OPEN + assert connection.handle is not None + connection_cls.assert_called_once() + + @mock.patch("dbt.adapters.athena.connections.AthenaConnection") + def test_acquire_connection_exc(self, connection_cls, dbt_error_caplog): + connection_cls.side_effect = lambda **_: (_ for _ in ()).throw(Exception("foobar")) + connection = self.adapter.acquire_connection("dummy") + conn_res = None + with pytest.raises(ConnectionError) as exc: + conn_res = connection.handle + + assert conn_res is None + assert connection.state == ConnectionState.FAIL + assert exc.value.__str__() == "foobar" + assert "Got an error when attempting to open a Athena connection due to foobar" in dbt_error_caplog.getvalue() + + @pytest.mark.parametrize( + ( + "s3_data_dir", + "s3_data_naming", + "s3_path_table_part", + "s3_tmp_table_dir", + "external_location", + "is_temporary_table", + "expected", + ), + ( + pytest.param( + None, "table", None, None, None, False, "s3://my-bucket/test-dbt/tables/table", id="table naming" + ), + pytest.param( + None, "unique", None, None, None, False, "s3://my-bucket/test-dbt/tables/uuid", id="unique naming" + ), + pytest.param( + None, + "table_unique", + None, + None, + None, + False, + "s3://my-bucket/test-dbt/tables/table/uuid", + id="table_unique naming", + ), + pytest.param( + None, + "schema_table", + None, + None, + None, + False, + "s3://my-bucket/test-dbt/tables/schema/table", + id="schema_table naming", + ), + pytest.param( + None, + "schema_table_unique", + None, + None, + None, + False, + "s3://my-bucket/test-dbt/tables/schema/table/uuid", + id="schema_table_unique naming", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + None, + None, + False, + "s3://my-data-bucket/schema/table/uuid", + id="data_dir set", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + None, + "s3://path/to/external/", + False, + "s3://path/to/external", + id="external_location set and not temporary", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + "s3://my-bucket/test-dbt-temp/", 
+ "s3://path/to/external/", + True, + "s3://my-bucket/test-dbt-temp/schema/table/uuid", + id="s3_tmp_table_dir set, external_location set and temporary", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + None, + "s3://path/to/external/", + True, + "s3://my-data-bucket/schema/table/uuid", + id="s3_tmp_table_dir is empty, external_location set and temporary", + ), + pytest.param( + None, + "schema_table_unique", + "other_table", + None, + None, + False, + "s3://my-bucket/test-dbt/tables/schema/other_table/uuid", + id="s3_path_table_part set", + ), + ), + ) + @patch("dbt.adapters.athena.impl.uuid4", return_value="uuid") + def test_generate_s3_location( + self, + _, + s3_data_dir, + s3_data_naming, + s3_tmp_table_dir, + external_location, + s3_path_table_part, + is_temporary_table, + expected, + ): + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema="schema", + identifier="table", + s3_path_table_part=s3_path_table_part, + ) + assert expected == self.adapter.generate_s3_location( + relation, s3_data_dir, s3_data_naming, s3_tmp_table_dir, external_location, is_temporary_table + ) + + @mock_aws + def test_get_table_location(self, dbt_debug_caplog, mock_aws_service): + table_name = "test_table" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table(table_name) + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + assert self.adapter.get_glue_table_location(relation) == "s3://test-dbt-athena/tables/test_table" + + @mock_aws + def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service): + table_name = "test_table" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table(table_name, location="") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + with pytest.raises(S3LocationException) as exc: + self.adapter.get_glue_table_location(relation) + assert exc.value.args[0] == ( + 'Relation "awsdatacatalog"."test_dbt_athena"."test_table" is of type \'table\' which requires a ' + "location, but no location returned by Glue." 
+ ) + + @mock_aws + def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service): + view_name = "view" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_view(view_name) + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=view_name, type=RelationType.View + ) + assert self.adapter.get_glue_table_location(relation) is None + + @mock_aws + def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_service): + table_name = "test_table" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + assert self.adapter.get_glue_table_location(relation) is None + assert f"Table {relation.render()} does not exists - Ignoring" in dbt_debug_caplog.getvalue() + + @mock_aws + def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service): + table_name = "table" + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table(table_name) + mock_aws_service.add_data_in_table(table_name) + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.acquire_connection("dummy") + self.adapter.clean_up_partitions(relation, "dt < '2022-01-03'") + log_records = dbt_debug_caplog.getvalue() + assert ( + "Deleting table data: path=" + "'s3://test-dbt-athena/tables/table/dt=2022-01-01', " + "bucket='test-dbt-athena', " + "prefix='tables/table/dt=2022-01-01/'" in log_records + ) + assert ( + "Deleting table data: path=" + "'s3://test-dbt-athena/tables/table/dt=2022-01-02', " + "bucket='test-dbt-athena', " + "prefix='tables/table/dt=2022-01-02/'" in log_records + ) + s3 = boto3.client("s3", region_name=AWS_REGION) + keys = [obj["Key"] for obj in s3.list_objects_v2(Bucket=BUCKET)["Contents"]] + assert set(keys) == {"tables/table/dt=2022-01-03/data1.parquet", "tables/table/dt=2022-01-03/data2.parquet"} + + @mock_aws + def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="table", + ) + result = self.adapter.clean_up_table(relation) + assert result is None + assert ( + 'Table "awsdatacatalog"."test_dbt_athena"."table" does not exists - Ignoring' in dbt_debug_caplog.getvalue() + ) + + @mock_aws + def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + mock_aws_service.create_view("test_view") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="test_view", + type=RelationType.View, + ) + result = self.adapter.clean_up_table(relation) + assert result is None + + @mock_aws + def test_clean_up_table_delete_table(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("table") + mock_aws_service.add_data_in_table("table") + self.adapter.acquire_connection("dummy") + 
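+        # clean_up_table should remove every S3 object under the table's Glue location; this is verified below via the debug log message and an empty list_objects_v2 result.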
relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="table", + ) + self.adapter.clean_up_table(relation) + assert ( + "Deleting table data: path='s3://test-dbt-athena/tables/table', " + "bucket='test-dbt-athena', " + "prefix='tables/table/'" in dbt_debug_caplog.getvalue() + ) + s3 = boto3.client("s3", region_name=AWS_REGION) + objs = s3.list_objects_v2(Bucket=BUCKET) + assert objs["KeyCount"] == 0 + + @pytest.mark.parametrize( + "column,quote_config,quote_character,expected", + [ + pytest.param("col", False, None, "col"), + pytest.param("col", True, None, '"col"'), + pytest.param("col", False, "`", "col"), + pytest.param("col", True, "`", "`col`"), + ], + ) + def test_quote_seed_column(self, column, quote_config, quote_character, expected): + assert self.adapter.quote_seed_column(column, quote_config, quote_character) == expected + + @mock_aws + def test__get_one_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database("foo") + mock_aws_service.create_database("quux") + mock_aws_service.create_database("baz") + mock_aws_service.create_table(table_name="bar", database_name="foo") + mock_aws_service.create_table(table_name="bar", database_name="quux") + mock_information_schema = mock.MagicMock() + mock_information_schema.database = "awsdatacatalog" + + self.adapter.acquire_connection("dummy") + actual = self.adapter._get_one_catalog(mock_information_schema, {"foo", "quux"}, self.used_schemas) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + expected_rows = [ + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None), + ("awsdatacatalog", "quux", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "quux", "bar", "table", None, "country", 1, "string", None), + ("awsdatacatalog", "quux", "bar", "table", None, "dt", 2, "date", None), + ] + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + + mock_aws_service.create_table_without_type(table_name="qux", database_name="baz") + with pytest.raises(ValueError): + self.adapter._get_one_catalog(mock_information_schema, {"baz"}, self.used_schemas) + + @mock_aws + def test__get_one_catalog_by_relations(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database("foo") + mock_aws_service.create_database("quux") + mock_aws_service.create_table(database_name="foo", table_name="bar") + # we create another relation + mock_aws_service.create_table(table_name="bar", database_name="quux") + + mock_information_schema = mock.MagicMock() + mock_information_schema.database = "awsdatacatalog" + + self.adapter.acquire_connection("dummy") + + rel_1 = self.adapter.Relation.create( + database="awsdatacatalog", + schema="foo", + identifier="bar", + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + + expected_rows = [ + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "foo", "bar", 
"table", None, "country", 1, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None), + ] + + actual = self.adapter._get_one_catalog_by_relations(mock_information_schema, [rel_1], self.used_schemas) + assert actual.column_names == expected_column_names + assert actual.rows == expected_rows + + @mock_aws + def test__get_one_catalog_shared_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog(catalog_name=SHARED_DATA_CATALOG_NAME, catalog_id=SHARED_DATA_CATALOG_NAME) + mock_aws_service.create_database("foo", catalog_id=SHARED_DATA_CATALOG_NAME) + mock_aws_service.create_table(table_name="bar", database_name="foo", catalog_id=SHARED_DATA_CATALOG_NAME) + mock_information_schema = mock.MagicMock() + mock_information_schema.database = SHARED_DATA_CATALOG_NAME + + self.adapter.acquire_connection("dummy") + actual = self.adapter._get_one_catalog( + mock_information_schema, + {"foo"}, + self.used_schemas, + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + expected_rows = [ + ("9876543210", "foo", "bar", "table", None, "id", 0, "string", None), + ("9876543210", "foo", "bar", "table", None, "country", 1, "string", None), + ("9876543210", "foo", "bar", "table", None, "dt", 2, "date", None), + ] + + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + + @mock_aws + def test__get_one_catalog_federated_query_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog( + catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA + ) + mock_information_schema = mock.MagicMock() + mock_information_schema.database = FEDERATED_QUERY_CATALOG_NAME + + # Original botocore _make_api_call function + orig = botocore.client.BaseClient._make_api_call + + # Mocking this as list_table_metadata and creating non-glue tables is not supported by moto. 
+ # Followed this guide: http://docs.getmoto.org/en/latest/docs/services/patching_other_services.html + def mock_athena_list_table_metadata(self, operation_name, kwarg): + if operation_name == "ListTableMetadata": + return { + "TableMetadataList": [ + { + "Name": "bar", + "TableType": "EXTERNAL_TABLE", + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + } + ], + } + # If we don't want to patch the API call + return orig(self, operation_name, kwarg) + + self.adapter.acquire_connection("dummy") + with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata): + actual = self.adapter._get_one_catalog( + mock_information_schema, + {"foo"}, + self.used_schemas, + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + expected_rows = [ + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None), + ] + + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + + @mock_aws + def test__get_data_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + self.adapter.acquire_connection("dummy") + res = self.adapter._get_data_catalog(DATA_CATALOG_NAME) + assert {"Name": "awsdatacatalog", "Type": "GLUE", "Parameters": {"catalog-id": DEFAULT_ACCOUNT_ID}} == res + + def _test_list_relations_without_caching(self, schema_relation): + self.adapter.acquire_connection("dummy") + relations = self.adapter.list_relations_without_caching(schema_relation) + assert len(relations) == 4 + assert all(isinstance(rel, AthenaRelation) for rel in relations) + relations.sort(key=lambda rel: rel.name) + iceberg_table = relations[0] + other = relations[1] + table = relations[2] + view = relations[3] + assert iceberg_table.name == "iceberg" + assert iceberg_table.type == "table" + assert iceberg_table.detailed_table_type == "ICEBERG" + assert other.name == "other" + assert other.type == "table" + assert other.detailed_table_type == "" + assert table.name == "table" + assert table.type == "table" + assert table.detailed_table_type == "" + assert view.name == "view" + assert view.type == "view" + assert view.detailed_table_type == "" + + @mock_aws + def test_list_relations_without_caching_with_awsdatacatalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("table") + mock_aws_service.create_table("other") + mock_aws_service.create_view("view") + mock_aws_service.create_table_without_table_type("without_table_type") + mock_aws_service.create_iceberg_table("iceberg") + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + quote_policy=self.adapter.config.quoting, + ) + self._test_list_relations_without_caching(schema_relation) + + @mock_aws + def test_list_relations_without_caching_with_other_glue_data_catalog(self, mock_aws_service): + data_catalog_name = "other_data_catalog" + mock_aws_service.create_data_catalog(data_catalog_name) + 
mock_aws_service.create_database() + mock_aws_service.create_table("table") + mock_aws_service.create_table("other") + mock_aws_service.create_view("view") + mock_aws_service.create_table_without_table_type("without_table_type") + mock_aws_service.create_iceberg_table("iceberg") + schema_relation = self.adapter.Relation.create( + database=data_catalog_name, + schema=DATABASE_NAME, + quote_policy=self.adapter.config.quoting, + ) + self._test_list_relations_without_caching(schema_relation) + + @mock_aws + def test_list_relations_without_caching_on_unknown_schema(self, mock_aws_service): + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema="unknown_schema", + quote_policy=self.adapter.config.quoting, + ) + self.adapter.acquire_connection("dummy") + relations = self.adapter.list_relations_without_caching(schema_relation) + assert relations == [] + + @mock_aws + @patch("dbt.adapters.athena.impl.SQLAdapter.list_relations_without_caching", return_value=[]) + def test_list_relations_without_caching_with_non_glue_data_catalog( + self, parent_list_relations_without_caching, mock_aws_service + ): + data_catalog_name = "other_data_catalog" + mock_aws_service.create_data_catalog(data_catalog_name, AthenaCatalogType.HIVE) + schema_relation = self.adapter.Relation.create( + database=data_catalog_name, + schema=DATABASE_NAME, + quote_policy=self.adapter.config.quoting, + ) + self.adapter.acquire_connection("dummy") + self.adapter.list_relations_without_caching(schema_relation) + parent_list_relations_without_caching.assert_called_once_with(schema_relation) + + @pytest.mark.parametrize( + "s3_path,expected", + [ + ("s3://my-bucket/test-dbt/tables/schema/table", ("my-bucket", "test-dbt/tables/schema/table/")), + ("s3://my-bucket/test-dbt/tables/schema/table/", ("my-bucket", "test-dbt/tables/schema/table/")), + ], + ) + def test_parse_s3_path(self, s3_path, expected): + assert self.adapter._parse_s3_path(s3_path) == expected + + @mock_aws + def test_swap_table_with_partitions(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + mock_aws_service.create_table(source_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table) + mock_aws_service.create_table(target_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, target_table) + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=target_table, + ) + self.adapter.swap_table(source_relation, target_relation) + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + + @mock_aws + def test_swap_table_without_partitions(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + mock_aws_service.create_table_without_partitions(source_table) + mock_aws_service.create_table_without_partitions(target_table) + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + 
identifier=target_table, + ) + self.adapter.swap_table(source_relation, target_relation) + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + + @mock_aws + def test_swap_table_with_partitions_to_one_without(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + # source table does not have partitions + mock_aws_service.create_table_without_partitions(source_table) + + # the target table has partitions + mock_aws_service.create_table(target_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, target_table) + + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=target_table, + ) + + self.adapter.swap_table(source_relation, target_relation) + glue_client = boto3.client("glue", region_name=AWS_REGION) + + target_table_partitions = glue_client.get_partitions(DatabaseName=DATABASE_NAME, TableName=target_table).get( + "Partitions" + ) + + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + assert len(target_table_partitions) == 0 + + @mock_aws + def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + mock_aws_service.create_table(source_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table) + mock_aws_service.create_table_without_partitions(target_table) + glue_client = boto3.client("glue", region_name=AWS_REGION) + target_table_partitions = glue_client.get_partitions(DatabaseName=DATABASE_NAME, TableName=target_table).get( + "Partitions" + ) + assert len(target_table_partitions) == 0 + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=target_table, + ) + self.adapter.swap_table(source_relation, target_relation) + target_table_partitions_after = glue_client.get_partitions( + DatabaseName=DATABASE_NAME, TableName=target_table + ).get("Partitions") + + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + assert len(target_table_partitions_after) == 26 + + @mock_aws + def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_caplog): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + glue = boto3.client("glue", region_name=AWS_REGION) + table_versions = glue.get_table_versions(DatabaseName=DATABASE_NAME, TableName=table_name).get("TableVersions") + assert len(table_versions) == 4 + version_to_keep = 1 + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + 
identifier=table_name, + ) + versions_to_expire = self.adapter._get_glue_table_versions_to_expire(relation, version_to_keep) + assert len(versions_to_expire) == 3 + assert [v["VersionId"] for v in versions_to_expire] == ["3", "2", "1"] + + @mock_aws + def test_expire_glue_table_versions(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + glue = boto3.client("glue", region_name=AWS_REGION) + table_versions = glue.get_table_versions(DatabaseName=DATABASE_NAME, TableName=table_name).get("TableVersions") + assert len(table_versions) == 4 + version_to_keep = 1 + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.expire_glue_table_versions(relation, version_to_keep, False) + # TODO delete_table_version is not implemented in moto + # TODO moto issue https://github.com/getmoto/moto/issues/5952 + # assert len(result) == 3 + + @mock_aws + def test_upload_seed_to_s3(self, mock_aws_service): + seed_table = agate.Table.from_object(seed_data) + self.adapter.acquire_connection("dummy") + + database = "db_seeds" + table = "data" + + s3_client = boto3.client("s3", region_name=AWS_REGION) + s3_client.create_bucket(Bucket=BUCKET, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=database, + identifier=table, + ) + + location = self.adapter.upload_seed_to_s3( + relation, + seed_table, + s3_data_dir=f"s3://{BUCKET}", + s3_data_naming="schema_table", + external_location=None, + ) + + prefix = "db_seeds/data" + objects = s3_client.list_objects(Bucket=BUCKET, Prefix=prefix).get("Contents") + + assert location == f"s3://{BUCKET}/{prefix}" + assert len(objects) == 1 + assert objects[0].get("Key").endswith(".csv") + + @mock_aws + def test_upload_seed_to_s3_external_location(self, mock_aws_service): + seed_table = agate.Table.from_object(seed_data) + self.adapter.acquire_connection("dummy") + + bucket = "my-external-location" + prefix = "seeds/one" + external_location = f"s3://{bucket}/{prefix}" + + s3_client = boto3.client("s3", region_name=AWS_REGION) + s3_client.create_bucket(Bucket=bucket, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema="db_seeds", + identifier="data", + ) + + location = self.adapter.upload_seed_to_s3( + relation, + seed_table, + s3_data_dir=None, + s3_data_naming="schema_table", + external_location=external_location, + ) + + objects = s3_client.list_objects(Bucket=bucket, Prefix=prefix).get("Contents") + + assert location == f"s3://{bucket}/{prefix}" + assert len(objects) == 1 + assert objects[0].get("Key").endswith(".csv") + + @mock_aws + def test_get_work_group_output_location(self, mock_aws_service): + self.adapter.acquire_connection("dummy") + mock_aws_service.create_work_group_with_output_location_enforced(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert work_group_location_enforced + + def test_get_work_group_output_location_if_workgroup_check_is_skipepd(self): + settings = { + 
"skip_workgroup_check": True, + } + + self.config = TestAthenaAdapter._config_from_settings(settings) + self.adapter.acquire_connection("dummy") + + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_aws + def test_get_work_group_output_location_no_location(self, mock_aws_service): + self.adapter.acquire_connection("dummy") + mock_aws_service.create_work_group_no_output_location(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_aws + def test_get_work_group_output_location_not_enforced(self, mock_aws_service): + self.adapter.acquire_connection("dummy") + mock_aws_service.create_work_group_with_output_location_not_enforced(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_aws + def test_persist_docs_to_glue_no_comment(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.persist_docs_to_glue( + schema_relation, + { + "description": """ + A table with str, 123, &^% \" and ' + + and an other paragraph. + """, + "columns": { + "id": { + "meta": {"primary_key": "true"}, + "description": """ + A column with str, 123, &^% \" and ' + + and an other paragraph. + """, + } + }, + }, + False, + False, + ) + glue = boto3.client("glue", region_name=AWS_REGION) + table = glue.get_table(DatabaseName=DATABASE_NAME, Name=table_name).get("Table") + assert not table.get("Description", "") + assert not table["Parameters"].get("comment") + assert all(not col.get("Comment") for col in table["StorageDescriptor"]["Columns"]) + assert all(not col.get("Parameters") for col in table["StorageDescriptor"]["Columns"]) + + @mock_aws + def test_persist_docs_to_glue_comment(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.persist_docs_to_glue( + schema_relation, + { + "description": """ + A table with str, 123, &^% \" and ' + + and an other paragraph. + """, + "columns": { + "id": { + "meta": {"primary_key": True}, + "description": """ + A column with str, 123, &^% \" and ' + + and an other paragraph. + """, + } + }, + }, + True, + True, + ) + glue = boto3.client("glue", region_name=AWS_REGION) + table = glue.get_table(DatabaseName=DATABASE_NAME, Name=table_name).get("Table") + assert table["Description"] == "A table with str, 123, &^% \" and ' and an other paragraph." + assert table["Parameters"]["comment"] == "A table with str, 123, &^% \" and ' and an other paragraph." + col_id = [col for col in table["StorageDescriptor"]["Columns"] if col["Name"] == "id"][0] + assert col_id["Comment"] == "A column with str, 123, &^% \" and ' and an other paragraph." 
+ assert col_id["Parameters"] == {"primary_key": "True"} + + @mock_aws + def test_list_schemas(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database(name="foo") + mock_aws_service.create_database(name="bar") + mock_aws_service.create_database(name="quux") + self.adapter.acquire_connection("dummy") + res = self.adapter.list_schemas("") + assert sorted(res) == ["bar", "foo", "quux"] + + @mock_aws + def test_get_columns_in_relation(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("tbl_name") + self.adapter.acquire_connection("dummy") + columns = self.adapter.get_columns_in_relation( + self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="tbl_name", + ) + ) + assert columns == [ + AthenaColumn(column="id", dtype="string", table_type=TableType.TABLE), + AthenaColumn(column="country", dtype="string", table_type=TableType.TABLE), + AthenaColumn(column="dt", dtype="date", table_type=TableType.TABLE), + ] + + @mock_aws + def test_get_columns_in_relation_not_found_table(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + columns = self.adapter.get_columns_in_relation( + self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="tbl_name", + ) + ) + assert columns == [] + + @mock_aws + def test_delete_from_glue_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("tbl_name") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name") + self.adapter.delete_from_glue_catalog(relation) + glue = boto3.client("glue", region_name=AWS_REGION) + tables_list = glue.get_tables(DatabaseName=DATABASE_NAME).get("TableList") + assert tables_list == [] + + @mock_aws + def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("tbl_name") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_does_not_exist" + ) + delete_table = self.adapter.delete_from_glue_catalog(relation) + assert delete_table is None + error_msg = f"Table {relation.render()} does not exist and will not be deleted, ignoring" + assert error_msg in dbt_debug_caplog.getvalue() + + @mock_aws + def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("test_table") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="test_table" + ) + table_type = self.adapter.get_glue_table_type(relation) + assert table_type == TableType.TABLE + + @mock_aws + def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table_without_table_type("test_table") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, 
schema=DATABASE_NAME, identifier="test_table" + ) + with pytest.raises(ValueError): + self.adapter.get_glue_table_type(relation) + + @mock_aws + def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_view("test_view") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="test_view" + ) + table_type = self.adapter.get_glue_table_type(relation) + assert table_type == TableType.VIEW + + @mock_aws + def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_iceberg_table("test_iceberg") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="test_iceberg" + ) + table_type = self.adapter.get_glue_table_type(relation) + assert table_type == TableType.ICEBERG + + @pytest.mark.parametrize( + "column,expected", + [ + pytest.param({"Name": "user_id", "Type": "int", "Parameters": {"iceberg.field.current": "true"}}, True), + pytest.param({"Name": "user_id", "Type": "int", "Parameters": {"iceberg.field.current": "false"}}, False), + pytest.param({"Name": "user_id", "Type": "int"}, True), + ], + ) + def test__is_current_column(self, column, expected): + assert self.adapter._is_current_column(column) == expected + + @pytest.mark.parametrize( + "partition_keys, expected_result", + [ + ( + ["year(date_col)", "bucket(col_name, 10)", "default_partition_key"], + "date_trunc('year', date_col), col_name, default_partition_key", + ), + ], + ) + def test_format_partition_keys(self, partition_keys, expected_result): + assert self.adapter.format_partition_keys(partition_keys) == expected_result + + @pytest.mark.parametrize( + "partition_key, expected_result", + [ + ("month(hidden)", "date_trunc('month', hidden)"), + ("bucket(bucket_col, 10)", "bucket_col"), + ("regular_col", "regular_col"), + ], + ) + def test_format_one_partition_key(self, partition_key, expected_result): + assert self.adapter.format_one_partition_key(partition_key) == expected_result + + def test_murmur3_hash_with_int(self): + bucket_number = self.adapter.murmur3_hash(123, 100) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + assert bucket_number == 54 + + def test_murmur3_hash_with_date(self): + d = datetime.date.today() + bucket_number = self.adapter.murmur3_hash(d, 100) + assert isinstance(d, datetime.date) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + + def test_murmur3_hash_with_datetime(self): + dt = datetime.datetime.now() + bucket_number = self.adapter.murmur3_hash(dt, 100) + assert isinstance(dt, datetime.datetime) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + + def test_murmur3_hash_with_str(self): + bucket_number = self.adapter.murmur3_hash("test_string", 100) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + assert bucket_number == 88 + + def test_murmur3_hash_uniqueness(self): + # Ensuring different inputs produce different hashes + hash1 = self.adapter.murmur3_hash("string1", 100) + hash2 = self.adapter.murmur3_hash("string2", 100) + assert hash1 != hash2 + + def test_murmur3_hash_with_unsupported_type(self): + with pytest.raises(TypeError): + self.adapter.murmur3_hash([1, 2, 3], 100) + + 
@pytest.mark.parametrize( + "value, column_type, expected_result", + [ + (None, "integer", ("null", " is ")), + (42, "integer", ("42", "=")), + ("O'Reilly", "string", ("'O''Reilly'", "=")), + ("test", "string", ("'test'", "=")), + ("2021-01-01", "date", ("DATE'2021-01-01'", "=")), + ("2021-01-01 12:00:00", "timestamp", ("TIMESTAMP'2021-01-01 12:00:00'", "=")), + ], + ) + def test_format_value_for_partition(self, value, column_type, expected_result): + assert self.adapter.format_value_for_partition(value, column_type) == expected_result + + def test_format_unsupported_type(self): + with pytest.raises(ValueError): + self.adapter.format_value_for_partition("test", "unsupported_type") + + +class TestAthenaFilterCatalog: + def test__catalog_filter_table(self): + column_names = ["table_name", "table_database", "table_schema", "something"] + rows = [ + ["foo", "a", "b", "1234"], # include + ["foo", "a", "1234", "1234"], # include, w/ table schema as str + ["foo", "c", "B", "1234"], # skip + ["1234", "A", "B", "1234"], # include, w/ table name as str + ] + table = agate.Table(rows, column_names, agate_helper.DEFAULT_TYPE_TESTER) + + result = AthenaAdapter._catalog_filter_table(table, frozenset({("a", "B"), ("a", "1234")})) + assert len(result) == 3 + for row in result.rows: + assert isinstance(row["table_schema"], str) + assert isinstance(row["table_database"], str) + assert isinstance(row["table_name"], str) + assert isinstance(row["something"], decimal.Decimal) + + +class TestAthenaAdapterConversions(TestAdapterConversions): + def test_convert_text_type(self): + rows = [ + ["", "a1", "stringval1"], + ["", "a2", "stringvalasdfasdfasdfa"], + ["", "a3", "stringval3"], + ] + agate_table = self._make_table_of(rows, agate.Text) + expected = ["string", "string", "string"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_text_type(agate_table, col_idx) == expect + + def test_convert_number_type(self): + rows = [ + ["", "23.98", "-1"], + ["", "12.78", "-2"], + ["", "79.41", "-3"], + ] + agate_table = self._make_table_of(rows, agate.Number) + expected = ["integer", "double", "integer"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_number_type(agate_table, col_idx) == expect + + def test_convert_boolean_type(self): + rows = [ + ["", "false", "true"], + ["", "false", "false"], + ["", "false", "true"], + ] + agate_table = self._make_table_of(rows, agate.Boolean) + expected = ["boolean", "boolean", "boolean"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_boolean_type(agate_table, col_idx) == expect + + def test_convert_datetime_type(self): + rows = [ + ["", "20190101T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190102T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190103T01:01:01Z", "2019-01-01 01:01:01"], + ] + agate_table = self._make_table_of(rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime]) + expected = ["timestamp", "timestamp", "timestamp"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_datetime_type(agate_table, col_idx) == expect + + def test_convert_date_type(self): + rows = [ + ["", "2019-01-01", "2019-01-04"], + ["", "2019-01-02", "2019-01-04"], + ["", "2019-01-03", "2019-01-04"], + ] + agate_table = self._make_table_of(rows, agate.Date) + expected = ["date", "date", "date"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_date_type(agate_table, col_idx) == expect diff --git a/dbt-athena/tests/unit/test_column.py 
b/dbt-athena/tests/unit/test_column.py new file mode 100644 index 00000000..a2b15bf9 --- /dev/null +++ b/dbt-athena/tests/unit/test_column.py @@ -0,0 +1,108 @@ +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena.column import AthenaColumn +from dbt.adapters.athena.relation import TableType + + +class TestAthenaColumn: + def setup_column(self, **kwargs): + base_kwargs = {"column": "foo", "dtype": "varchar"} + return AthenaColumn(**{**base_kwargs, **kwargs}) + + @pytest.mark.parametrize( + "table_type,expected", + [ + pytest.param(TableType.TABLE, False), + pytest.param(TableType.ICEBERG, True), + ], + ) + def test_is_iceberg(self, table_type, expected): + column = self.setup_column(table_type=table_type) + assert column.is_iceberg() is expected + + @pytest.mark.parametrize( + "dtype,expected_type_func", + [ + pytest.param("varchar", "is_string"), + pytest.param("string", "is_string"), + pytest.param("binary", "is_binary"), + pytest.param("varbinary", "is_binary"), + pytest.param("timestamp", "is_timestamp"), + pytest.param("array<string>", "is_array"), + pytest.param("array(string)", "is_array"), + ], + ) + def test_is_type(self, dtype, expected_type_func): + column = self.setup_column(dtype=dtype) + for type_func in ["is_string", "is_binary", "is_timestamp", "is_array"]: + if type_func == expected_type_func: + assert getattr(column, type_func)() + else: + assert not getattr(column, type_func)() + + @pytest.mark.parametrize("size,expected", [pytest.param(1, "varchar(1)"), pytest.param(0, "varchar")]) + def test_string_type(self, size, expected): + assert AthenaColumn.string_type(size) == expected + + @pytest.mark.parametrize( + "table_type,expected", + [pytest.param(TableType.TABLE, "timestamp"), pytest.param(TableType.ICEBERG, "timestamp(6)")], + ) + def test_timestamp_type(self, table_type, expected): + column = self.setup_column(table_type=table_type) + assert column.timestamp_type() == expected + + def test_array_type(self): + assert AthenaColumn.array_type("varchar") == "array(varchar)" + + @pytest.mark.parametrize( + "dtype,expected", + [ + pytest.param("array<string>", "string"), + pytest.param("array<varchar(10)>", "varchar(10)"), + pytest.param("array<array<string>>", "array<string>"), + pytest.param("array", "array"), + ], + ) + def test_array_inner_type(self, dtype, expected): + column = self.setup_column(dtype=dtype) + assert column.array_inner_type() == expected + + def test_array_inner_type_raises_for_non_array_type(self): + column = self.setup_column(dtype="varchar") + with pytest.raises(DbtRuntimeError, match=r"Called array_inner_type\(\) on non-array field!"): + column.array_inner_type() + + @pytest.mark.parametrize( + "char_size,expected", + [ + pytest.param(10, 10), + pytest.param(None, 0), + ], + ) + def test_string_size(self, char_size, expected): + column = self.setup_column(dtype="varchar", char_size=char_size) + assert column.string_size() == expected + + def test_string_size_raises_for_non_string_type(self): + column = self.setup_column(dtype="int") + with pytest.raises(DbtRuntimeError, match=r"Called string_size\(\) on non-string field!"): + column.string_size() + + @pytest.mark.parametrize( + "dtype,expected", + [ + pytest.param("string", "varchar(10)"), + pytest.param("decimal", "decimal(1,2)"), + pytest.param("binary", "varbinary"), + pytest.param("timestamp", "timestamp(6)"), + pytest.param("array<string>", "array(varchar(10))"), + pytest.param("array<array<string>>", "array(array(varchar(10)))"), + ], + ) + def test_data_type(self, dtype, expected): + column = self.setup_column(
table_type=TableType.ICEBERG, dtype=dtype, char_size=10, numeric_precision=1, numeric_scale=2 + ) + assert column.data_type == expected diff --git a/dbt-athena/tests/unit/test_config.py b/dbt-athena/tests/unit/test_config.py new file mode 100644 index 00000000..55a3672e --- /dev/null +++ b/dbt-athena/tests/unit/test_config.py @@ -0,0 +1,114 @@ +import importlib.metadata +from unittest.mock import Mock + +import pytest + +from dbt.adapters.athena.config import AthenaSparkSessionConfig, get_boto3_config + + +class TestConfig: + def test_get_boto3_config(self): + importlib.metadata.version = Mock(return_value="2.4.6") + num_boto3_retries = 5 + get_boto3_config.cache_clear() + config = get_boto3_config(num_retries=num_boto3_retries) + assert config._user_provided_options["user_agent_extra"] == "dbt-athena/2.4.6" + assert config.retries == {"max_attempts": num_boto3_retries, "mode": "standard"} + + +class TestAthenaSparkSessionConfig: + """ + A class to test AthenaSparkSessionConfig + """ + + @pytest.fixture + def spark_config(self, request): + """ + Fixture for providing Spark configuration parameters. + + This fixture returns a dictionary containing the Spark configuration parameters. The parameters can be + customized using the `request.param` object. The default values are: + - `timeout`: 7200 seconds + - `polling_interval`: 5 seconds + - `engine_config`: {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1} + + Args: + self: The test class instance. + request: The pytest request object. + + Returns: + dict: The Spark configuration parameters. + + """ + return { + "timeout": request.param.get("timeout", 7200), + "polling_interval": request.param.get("polling_interval", 5), + "engine_config": request.param.get( + "engine_config", {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1} + ), + } + + @pytest.fixture + def spark_config_helper(self, spark_config): + """Fixture for testing AthenaSparkSessionConfig class. + + Args: + spark_config (dict): Fixture for default spark config. + + Returns: + AthenaSparkSessionConfig: An instance of AthenaSparkSessionConfig class. 
+ """ + return AthenaSparkSessionConfig(spark_config) + + @pytest.mark.parametrize( + "spark_config", + [ + {"timeout": 5}, + {"timeout": 10}, + {"timeout": 20}, + {}, + pytest.param({"timeout": -1}, marks=pytest.mark.xfail), + pytest.param({"timeout": None}, marks=pytest.mark.xfail), + ], + indirect=True, + ) + def test_set_timeout(self, spark_config_helper): + timeout = spark_config_helper.set_timeout() + assert timeout == spark_config_helper.config.get("timeout", 7200) + + @pytest.mark.parametrize( + "spark_config", + [ + {"polling_interval": 5}, + {"polling_interval": 10}, + {"polling_interval": 20}, + {}, + pytest.param({"polling_interval": -1}, marks=pytest.mark.xfail), + ], + indirect=True, + ) + def test_set_polling_interval(self, spark_config_helper): + polling_interval = spark_config_helper.set_polling_interval() + assert polling_interval == spark_config_helper.config.get("polling_interval", 5) + + @pytest.mark.parametrize( + "spark_config", + [ + {"engine_config": {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}}, + {"engine_config": {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 2}}, + {}, + pytest.param({"engine_config": {"CoordinatorDpuSize": 1}}, marks=pytest.mark.xfail), + pytest.param({"engine_config": [1, 1, 1]}, marks=pytest.mark.xfail), + ], + indirect=True, + ) + def test_set_engine_config(self, spark_config_helper): + engine_config = spark_config_helper.set_engine_config() + diff = set(engine_config.keys()) - { + "CoordinatorDpuSize", + "MaxConcurrentDpus", + "DefaultExecutorDpuSize", + "SparkProperties", + "AdditionalConfigs", + } + assert len(diff) == 0 diff --git a/dbt-athena/tests/unit/test_connection_manager.py b/dbt-athena/tests/unit/test_connection_manager.py new file mode 100644 index 00000000..fd48a62f --- /dev/null +++ b/dbt-athena/tests/unit/test_connection_manager.py @@ -0,0 +1,35 @@ +from multiprocessing import get_context +from unittest import mock + +import pytest +from pyathena.model import AthenaQueryExecution + +from dbt.adapters.athena import AthenaConnectionManager +from dbt.adapters.athena.connections import AthenaAdapterResponse + + +class TestAthenaConnectionManager: + @pytest.mark.parametrize( + ("state", "result"), + ( + pytest.param(AthenaQueryExecution.STATE_SUCCEEDED, "OK"), + pytest.param(AthenaQueryExecution.STATE_CANCELLED, "ERROR"), + ), + ) + def test_get_response(self, state, result): + cursor = mock.MagicMock() + cursor.rowcount = 1 + cursor.state = state + cursor.data_scanned_in_bytes = 123 + cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn")) + response = cm.get_response(cursor) + assert isinstance(response, AthenaAdapterResponse) + assert response.code == result + assert response.rows_affected == 1 + assert response.data_scanned_in_bytes == 123 + + def test_data_type_code_to_name(self): + cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn")) + assert cm.data_type_code_to_name("array ") == "ARRAY" + assert cm.data_type_code_to_name("map ") == "MAP" + assert cm.data_type_code_to_name("DECIMAL(3, 7)") == "DECIMAL" diff --git a/dbt-athena/tests/unit/test_formatter.py b/dbt-athena/tests/unit/test_formatter.py new file mode 100644 index 00000000..a645635a --- /dev/null +++ b/dbt-athena/tests/unit/test_formatter.py @@ -0,0 +1,118 @@ +import textwrap +from datetime import date, datetime +from decimal import Decimal + +import agate +import pytest +from pyathena.error import ProgrammingError + +from dbt.adapters.athena.connections import 
AthenaParameterFormatter + + +class TestAthenaParameterFormatter: + formatter = AthenaParameterFormatter() + + @pytest.mark.parametrize( + "sql", + [ + "", + """ + + """, + ], + ) + def test_query_none_or_empty(self, sql): + with pytest.raises(ProgrammingError) as exc: + self.formatter.format(sql) + assert exc.value.__str__() == "Query is none or empty." + + def test_query_none_parameters(self): + sql = self.formatter.format( + """ + + ALTER TABLE table ADD partition (dt = '2022-01-01') + """ + ) + assert sql == "ALTER TABLE table ADD partition (dt = '2022-01-01')" + + def test_query_parameters_not_list(self): + with pytest.raises(ProgrammingError) as exc: + self.formatter.format( + """ + SELECT * + FROM table + WHERE country = %(country)s + """, + {"country": "FR"}, + ) + assert exc.value.__str__() == "Unsupported parameter (Support for list only): {'country': 'FR'}" + + def test_query_parameters_unknown_formatter(self): + with pytest.raises(TypeError) as exc: + self.formatter.format( + """ + SELECT * + FROM table + WHERE country = %s + """, + [agate.Table(rows=[("a", 1), ("b", 2)], column_names=["str", "int"])], + ) + assert exc.value.__str__() == " is not defined formatter." + + def test_query_parameters_list(self): + res = self.formatter.format( + textwrap.dedent( + """ + SELECT * + FROM table + WHERE nullable_field = %s + AND dt = %s + AND dti = %s + AND int_field > %s + AND float_field > %.2f + AND fake_decimal_field < %s + AND str_field = %s + AND str_list_field IN %s + AND int_list_field IN %s + AND float_list_field IN %s + AND dt_list_field IN %s + AND dti_list_field IN %s + AND bool_field = %s + """ + ), + [ + None, + date(2022, 1, 1), + datetime(2022, 1, 1, 0, 0, 0), + 1, + 1.23, + Decimal(2), + "test", + ["a", "b"], + [1, 2, 3, 4], + (1.1, 1.2, 1.3), + (date(2022, 2, 1), date(2022, 3, 1)), + (datetime(2022, 2, 1, 1, 2, 3), datetime(2022, 3, 1, 4, 5, 6)), + True, + ], + ) + expected = textwrap.dedent( + """ + SELECT * + FROM table + WHERE nullable_field = null + AND dt = DATE '2022-01-01' + AND dti = TIMESTAMP '2022-01-01 00:00:00.000' + AND int_field > 1 + AND float_field > 1.23 + AND fake_decimal_field < 2 + AND str_field = 'test' + AND str_list_field IN ('a', 'b') + AND int_list_field IN (1, 2, 3, 4) + AND float_list_field IN (1.100000, 1.200000, 1.300000) + AND dt_list_field IN (DATE '2022-02-01', DATE '2022-03-01') + AND dti_list_field IN (TIMESTAMP '2022-02-01 01:02:03.000', TIMESTAMP '2022-03-01 04:05:06.000') + AND bool_field = True + """ + ).strip() + assert res == expected diff --git a/dbt-athena/tests/unit/test_lakeformation.py b/dbt-athena/tests/unit/test_lakeformation.py new file mode 100644 index 00000000..3a05030c --- /dev/null +++ b/dbt-athena/tests/unit/test_lakeformation.py @@ -0,0 +1,144 @@ +import boto3 +import pytest +from tests.unit.constants import AWS_REGION, DATA_CATALOG_NAME, DATABASE_NAME + +import dbt.adapters.athena.lakeformation as lakeformation +from dbt.adapters.athena.lakeformation import LfTagsConfig, LfTagsManager +from dbt.adapters.athena.relation import AthenaRelation + + +# TODO: add more tests for lakeformation once moto library implements required methods: +# https://docs.getmoto.org/en/4.1.9/docs/services/lakeformation.html +# get_resource_lf_tags +class TestLfTagsManager: + @pytest.mark.parametrize( + "response,identifier,columns,lf_tags,verb,expected", + [ + pytest.param( + { + "Failures": [ + { + "LFTag": {"CatalogId": "test_catalog", "TagKey": "test_key", "TagValues": ["test_values"]}, + "Error": {"ErrorCode": "test_code", 
"ErrorMessage": "test_err_msg"}, + } + ] + }, + "tbl_name", + ["column1", "column2"], + {"tag_key": "tag_value"}, + "add", + None, + id="lf_tag error", + marks=pytest.mark.xfail, + ), + pytest.param( + {"Failures": []}, + "tbl_name", + None, + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", + id="add lf_tag", + ), + pytest.param( + {"Failures": []}, + None, + None, + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena", + id="add lf_tag_to_database", + ), + pytest.param( + {"Failures": []}, + "tbl_name", + None, + {"tag_key": "tag_value"}, + "remove", + "Success: remove LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", + id="remove lf_tag", + ), + pytest.param( + {"Failures": []}, + "tbl_name", + ["c1", "c2"], + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']", + id="lf_tag database table and columns", + ), + ], + ) + def test__parse_lf_response(self, dbt_debug_caplog, response, identifier, columns, lf_tags, verb, expected): + relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=identifier) + lf_client = boto3.client("lakeformation", region_name=AWS_REGION) + manager = LfTagsManager(lf_client, relation, LfTagsConfig()) + manager._parse_and_log_lf_response(response, columns, lf_tags, verb) + assert expected in dbt_debug_caplog.getvalue() + + @pytest.mark.parametrize( + "lf_tags_columns,lf_inherited_tags,expected", + [ + pytest.param( + [{"Name": "my_column", "LFTags": [{"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}]}], + {"inherited"}, + {}, + id="retains-inherited-tag", + ), + pytest.param( + [{"Name": "my_column", "LFTags": [{"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}]}], + {}, + {"not-inherited": {"oh-no-im-not": ["my_column"]}}, + id="removes-non-inherited-tag", + ), + pytest.param( + [ + { + "Name": "my_column", + "LFTags": [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + } + ], + {"inherited"}, + {"not-inherited": {"oh-no-im-not": ["my_column"]}}, + id="removes-non-inherited-tag-among-inherited", + ), + pytest.param([], {}, {}, id="handles-empty"), + ], + ) + def test__column_tags_to_remove(self, lf_tags_columns, lf_inherited_tags, expected): + assert lakeformation.LfTagsManager._column_tags_to_remove(lf_tags_columns, lf_inherited_tags) == expected + + @pytest.mark.parametrize( + "lf_tags_table,lf_tags,lf_inherited_tags,expected", + [ + pytest.param( + [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + {"not-inherited": "some-preexisting-value"}, + {"inherited"}, + {}, + id="retains-being-set-and-inherited", + ), + pytest.param( + [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + {}, + {"inherited"}, + {"not-inherited": ["oh-no-im-not"]}, + id="removes-preexisting-not-being-set", + ), + pytest.param( + [{"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}], {}, {"inherited"}, {}, id="retains-inherited" + ), + pytest.param([], None, {}, {}, id="handles-empty"), + ], + ) + def test__table_tags_to_remove(self, lf_tags_table, lf_tags, lf_inherited_tags, expected): + assert lakeformation.LfTagsManager._table_tags_to_remove(lf_tags_table, lf_tags, lf_inherited_tags) == expected diff 
--git a/dbt-athena/tests/unit/test_python_submissions.py b/dbt-athena/tests/unit/test_python_submissions.py new file mode 100644 index 00000000..65961d60 --- /dev/null +++ b/dbt-athena/tests/unit/test_python_submissions.py @@ -0,0 +1,209 @@ +import time +import uuid +from unittest.mock import Mock, patch + +import pytest + +from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper +from dbt.adapters.athena.session import AthenaSparkSessionManager + +from .constants import DATABASE_NAME + + +@pytest.mark.usefixtures("athena_credentials", "athena_client") +class TestAthenaPythonJobHelper: + """ + A class to test the AthenaPythonJobHelper + """ + + @pytest.fixture + def parsed_model(self, request): + config: dict[str, int] = request.param.get("config", {"timeout": 1, "polling_interval": 5}) + + return { + "alias": "test_model", + "schema": DATABASE_NAME, + "config": { + "timeout": config["timeout"], + "polling_interval": config["polling_interval"], + "engine_config": request.param.get( + "engine_config", {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1} + ), + }, + } + + @pytest.fixture + def athena_spark_session_manager(self, athena_credentials, parsed_model): + return AthenaSparkSessionManager( + athena_credentials, + timeout=parsed_model["config"]["timeout"], + polling_interval=parsed_model["config"]["polling_interval"], + engine_config=parsed_model["config"]["engine_config"], + ) + + @pytest.fixture + def athena_job_helper( + self, athena_credentials, athena_client, athena_spark_session_manager, parsed_model, monkeypatch + ): + mock_job_helper = AthenaPythonJobHelper(parsed_model, athena_credentials) + monkeypatch.setattr(mock_job_helper, "athena_client", athena_client) + monkeypatch.setattr(mock_job_helper, "spark_connection", athena_spark_session_manager) + return mock_job_helper + + @pytest.mark.parametrize( + "parsed_model, session_status_response, expected_response", + [ + ( + {"config": {"timeout": 5, "polling_interval": 5}}, + { + "State": "IDLE", + }, + None, + ), + pytest.param( + {"config": {"timeout": 5, "polling_interval": 5}}, + { + "State": "FAILED", + }, + None, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 5, "polling_interval": 5}}, + { + "State": "TERMINATED", + }, + None, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "State": "CREATING", + }, + None, + marks=pytest.mark.xfail, + ), + ], + indirect=["parsed_model"], + ) + def test_poll_session_idle( + self, session_status_response, expected_response, athena_job_helper, athena_spark_session_manager, monkeypatch + ): + with patch.multiple( + athena_spark_session_manager, + get_session_status=Mock(return_value=session_status_response), + get_session_id=Mock(return_value="test_session_id"), + ): + + def mock_sleep(_): + pass + + monkeypatch.setattr(time, "sleep", mock_sleep) + poll_response = athena_job_helper.poll_until_session_idle() + assert poll_response == expected_response + + @pytest.mark.parametrize( + "parsed_model, execution_status, expected_response", + [ + ( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "Status": { + "State": "COMPLETED", + } + }, + "COMPLETED", + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "Status": { + "State": "FAILED", + } + }, + None, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "Status": { + "State": "RUNNING", + } + }, + "RUNNING", + 
marks=pytest.mark.xfail, + ), + ], + indirect=["parsed_model"], + ) + def test_poll_execution( + self, + execution_status, + expected_response, + athena_job_helper, + athena_spark_session_manager, + athena_client, + monkeypatch, + ): + with patch.multiple( + athena_spark_session_manager, + get_session_id=Mock(return_value=uuid.uuid4()), + ), patch.multiple( + athena_client, + get_calculation_execution=Mock(return_value=execution_status), + ): + + def mock_sleep(_): + pass + + monkeypatch.setattr(time, "sleep", mock_sleep) + poll_response = athena_job_helper.poll_until_execution_completion("test_calculation_id") + assert poll_response == expected_response + + @pytest.mark.parametrize( + "parsed_model, test_calculation_execution_id, test_calculation_execution", + [ + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + {"CalculationExecutionId": "test_execution_id"}, + { + "Result": {"ResultS3Uri": "test_results_s3_uri"}, + "Status": {"State": "COMPLETED"}, + }, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + {"CalculationExecutionId": "test_execution_id"}, + {"Result": {}, "Status": {"State": "FAILED"}}, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + {}, + {"Result": {}, "Status": {"State": "FAILED"}}, + marks=pytest.mark.xfail, + ), + ], + indirect=["parsed_model"], + ) + def test_submission( + self, + test_calculation_execution_id, + test_calculation_execution, + athena_job_helper, + athena_spark_session_manager, + athena_client, + ): + with patch.multiple( + athena_spark_session_manager, get_session_id=Mock(return_value=uuid.uuid4()) + ), patch.multiple( + athena_client, + start_calculation_execution=Mock(return_value=test_calculation_execution_id), + get_calculation_execution=Mock(return_value=test_calculation_execution), + ), patch.multiple( + athena_job_helper, poll_until_session_idle=Mock(return_value="IDLE") + ): + result = athena_job_helper.submit("hello world") + assert result == test_calculation_execution["Result"] diff --git a/dbt-athena/tests/unit/test_query_headers.py b/dbt-athena/tests/unit/test_query_headers.py new file mode 100644 index 00000000..89a99ef5 --- /dev/null +++ b/dbt-athena/tests/unit/test_query_headers.py @@ -0,0 +1,88 @@ +from unittest import mock + +from dbt.adapters.athena.query_headers import AthenaMacroQueryStringSetter +from dbt.context.query_header import generate_query_header_context + +from .constants import AWS_REGION, DATA_CATALOG_NAME, DATABASE_NAME +from .utils import config_from_parts_or_dicts + + +class TestQueryHeaders: + def setup_method(self, _): + config = config_from_parts_or_dicts( + { + "name": "query_headers", + "version": "0.1", + "profile": "test", + "config-version": 2, + }, + { + "outputs": { + "test": { + "type": "athena", + "s3_staging_dir": "s3://my-bucket/test-dbt/", + "region_name": AWS_REGION, + "database": DATA_CATALOG_NAME, + "work_group": "dbt-athena-adapter", + "schema": DATABASE_NAME, + } + }, + "target": "test", + }, + ) + self.query_header = AthenaMacroQueryStringSetter( + config, generate_query_header_context(config, mock.MagicMock(macros={})) + ) + + def test_append_comment_with_semicolon(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = self.query_header.add("SELECT 1;") + assert sql == "SELECT 1\n-- /* executed by dbt */;" + + def test_append_comment_without_semicolon(self): + self.query_header.comment.query_comment = "executed by dbt" + 
self.query_header.comment.append = True + sql = self.query_header.add("SELECT 1") + assert sql == "SELECT 1\n-- /* executed by dbt */" + + def test_comment_multiple_lines(self): + self.query_header.comment.query_comment = """executed by dbt\nfor table table""" + self.query_header.comment.append = False + sql = self.query_header.add("insert into table(id) values (1);") + assert sql == "-- /* executed by dbt for table table */\ninsert into table(id) values (1);" + + def test_disable_query_comment(self): + self.query_header.comment.query_comment = "" + self.query_header.comment.append = True + assert self.query_header.add("SELECT 1;") == "SELECT 1;" + + def test_no_query_comment_on_alter(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "alter table table add column time time;" + assert self.query_header.add(sql) == sql + + def test_no_query_comment_on_vacuum(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "VACUUM table" + assert self.query_header.add(sql) == sql + + def test_no_query_comment_on_msck(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "MSCK REPAIR TABLE" + assert self.query_header.add(sql) == sql + + def test_no_query_comment_on_vacuum_with_leading_whitespaces(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = " VACUUM table" + assert self.query_header.add(sql) == "VACUUM table" + + def test_no_query_comment_on_vacuum_with_lowercase(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "vacuum table" + assert self.query_header.add(sql) == sql diff --git a/dbt-athena/tests/unit/test_relation.py b/dbt-athena/tests/unit/test_relation.py new file mode 100644 index 00000000..4e8c61c3 --- /dev/null +++ b/dbt-athena/tests/unit/test_relation.py @@ -0,0 +1,55 @@ +import pytest + +from dbt.adapters.athena.relation import AthenaRelation, TableType, get_table_type + +from .constants import DATA_CATALOG_NAME, DATABASE_NAME + +TABLE_NAME = "test_table" + + +class TestRelation: + @pytest.mark.parametrize( + ("table", "expected"), + [ + ({"Name": "n", "TableType": "table"}, TableType.TABLE), + ({"Name": "n", "TableType": "VIRTUAL_VIEW"}, TableType.VIEW), + ({"Name": "n", "TableType": "EXTERNAL_TABLE", "Parameters": {"table_type": "ICEBERG"}}, TableType.ICEBERG), + ], + ) + def test__get_relation_type(self, table, expected): + assert get_table_type(table) == expected + + def test__get_relation_type_with_no_type(self): + with pytest.raises(ValueError): + get_table_type({"Name": "name"}) + + def test__get_relation_type_with_unknown_type(self): + with pytest.raises(ValueError): + get_table_type({"Name": "name", "TableType": "test"}) + + +class TestAthenaRelation: + def test_render_hive_uses_hive_style_quotation_and_only_schema_and_table_names(self): + relation = AthenaRelation.create( + identifier=TABLE_NAME, + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + ) + assert relation.render_hive() == f"`{DATABASE_NAME}`.`{TABLE_NAME}`" + + def test_render_hive_resets_quote_character_and_include_policy_after_call(self): + relation = AthenaRelation.create( + identifier=TABLE_NAME, + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + ) + relation.render_hive() + assert relation.render() == f'"{DATA_CATALOG_NAME}"."{DATABASE_NAME}"."{TABLE_NAME}"' + + def 
test_render_pure_resets_quote_character_after_call(self): + relation = AthenaRelation.create( + identifier=TABLE_NAME, + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + ) + assert relation.render_pure() == f"{DATA_CATALOG_NAME}.{DATABASE_NAME}.{TABLE_NAME}" diff --git a/dbt-athena/tests/unit/test_session.py b/dbt-athena/tests/unit/test_session.py new file mode 100644 index 00000000..3c1f4324 --- /dev/null +++ b/dbt-athena/tests/unit/test_session.py @@ -0,0 +1,176 @@ +from unittest.mock import Mock, patch +from uuid import UUID + +import botocore.session +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena import AthenaCredentials +from dbt.adapters.athena.session import AthenaSparkSessionManager, get_boto3_session +from dbt.adapters.contracts.connection import Connection + + +class TestSession: + @pytest.mark.parametrize( + ("credentials_profile_name", "boto_profile_name"), + ( + pytest.param(None, "default", id="no_profile_in_credentials"), + pytest.param("my_profile", "my_profile", id="profile_in_credentials"), + ), + ) + def test_session_should_be_called_with_correct_parameters( + self, monkeypatch, credentials_profile_name, boto_profile_name + ): + def mock___build_profile_map(_): + return {**{"default": {}}, **({} if not credentials_profile_name else {credentials_profile_name: {}})} + + monkeypatch.setattr(botocore.session.Session, "_build_profile_map", mock___build_profile_map) + connection = Connection( + type="test", + name="test_session", + credentials=AthenaCredentials( + database="db", + schema="schema", + s3_staging_dir="dir", + region_name="eu-west-1", + aws_profile_name=credentials_profile_name, + ), + ) + session = get_boto3_session(connection) + assert session.region_name == "eu-west-1" + assert session.profile_name == boto_profile_name + + +@pytest.mark.usefixtures("athena_credentials", "athena_client") +class TestAthenaSparkSessionManager: + """ + A class to test the AthenaSparkSessionManager + """ + + @pytest.fixture + def spark_session_manager(self, athena_credentials, athena_client, monkeypatch): + """ + Fixture for creating a mock Spark session manager. + + This fixture creates an instance of AthenaSparkSessionManager with the provided Athena credentials, + timeout, polling interval, and engine configuration. It then patches the Athena client of the manager + with the provided `athena_client` object. The fixture returns the mock Spark session manager. + + Args: + self: The test class instance. + athena_credentials: The Athena credentials. + athena_client: The Athena client object. + monkeypatch: The monkeypatch object for mocking. + + Returns: + The mock Spark session manager. 
+ + """ + mock_session_manager = AthenaSparkSessionManager( + athena_credentials, + timeout=10, + polling_interval=5, + engine_config={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}, + ) + monkeypatch.setattr(mock_session_manager, "athena_client", athena_client) + return mock_session_manager + + @pytest.mark.parametrize( + "session_status_response, expected_response", + [ + pytest.param( + {"Status": {"SessionId": "test_session_id", "State": "CREATING"}}, + DbtRuntimeError( + """Session + did not create within 10 seconds.""" + ), + marks=pytest.mark.xfail, + ), + ( + {"Status": {"SessionId": "635c1c6d-766c-408b-8bce-fae8ea7006f7", "State": "IDLE"}}, + UUID("635c1c6d-766c-408b-8bce-fae8ea7006f7"), + ), + pytest.param( + {"Status": {"SessionId": "test_session_id", "State": "TERMINATED"}}, + DbtRuntimeError("Unable to create session: test_session_id. Got status: TERMINATED."), + marks=pytest.mark.xfail, + ), + ], + ) + def test_start_session( + self, session_status_response, expected_response, spark_session_manager, athena_client + ) -> None: + """ + Test the start_session method of the AthenaJobHelper class. + + Args: + session_status_response (dict): A dictionary containing the response from the Athena session + creation status. + expected_response (Union[dict, DbtRuntimeError]): The expected response from the start_session method. + athena_job_helper (AthenaPythonJobHelper): An instance of the AthenaPythonJobHelper class. + athena_client (botocore.client.BaseClient): An instance of the botocore Athena client. + + Returns: + None + """ + with patch.multiple( + spark_session_manager, + poll_until_session_creation=Mock(return_value=session_status_response), + ), patch.multiple( + athena_client, + get_session_status=Mock(return_value=session_status_response), + start_session=Mock(return_value=session_status_response.get("Status")), + ): + response = spark_session_manager.start_session() + assert response == expected_response + + @pytest.mark.parametrize( + "session_status_response, expected_status", + [ + ( + { + "SessionId": "test_session_id", + "Status": { + "State": "CREATING", + }, + }, + { + "State": "CREATING", + }, + ), + ( + { + "SessionId": "test_session_id", + "Status": { + "State": "IDLE", + }, + }, + { + "State": "IDLE", + }, + ), + ], + ) + def test_get_session_status(self, session_status_response, expected_status, spark_session_manager, athena_client): + """ + Test the get_session_status function. + + Args: + self: The test class instance. + session_status_response (dict): The response from get_session_status. + expected_status (dict): The expected session status. + spark_session_manager: The Spark session manager object. + athena_client: The Athena client object. + + Returns: + None + + Raises: + AssertionError: If the retrieved session status is not correct. 
+ """ + with patch.multiple(athena_client, get_session_status=Mock(return_value=session_status_response)): + response = spark_session_manager.get_session_status("test_session_id") + assert response == expected_status + + def test_get_session_id(self): + pass diff --git a/dbt-athena/tests/unit/test_utils.py b/dbt-athena/tests/unit/test_utils.py new file mode 100644 index 00000000..69a2dc9d --- /dev/null +++ b/dbt-athena/tests/unit/test_utils.py @@ -0,0 +1,72 @@ +import pytest + +from dbt.adapters.athena.utils import ( + clean_sql_comment, + ellipsis_comment, + get_chunks, + is_valid_table_parameter_key, + stringify_table_parameter_value, +) + + +def test_clean_comment(): + assert ( + clean_sql_comment( + """ + my long comment + on several lines + with weird spaces and indents. + """ + ) + == "my long comment on several lines with weird spaces and indents." + ) + + +def test_stringify_table_parameter_value(): + class NonStringifiableObject: + def __str__(self): + raise ValueError("Non-stringifiable object") + + assert stringify_table_parameter_value(True) == "True" + assert stringify_table_parameter_value(123) == "123" + assert stringify_table_parameter_value("dbt-athena") == "dbt-athena" + assert stringify_table_parameter_value(["a", "b", 3]) == '["a", "b", 3]' + assert stringify_table_parameter_value({"a": 1, "b": "c"}) == '{"a": 1, "b": "c"}' + assert len(stringify_table_parameter_value("a" * 512001)) == 512000 + assert stringify_table_parameter_value(NonStringifiableObject()) is None + assert stringify_table_parameter_value([NonStringifiableObject()]) is None + + +def test_is_valid_table_parameter_key(): + assert is_valid_table_parameter_key("valid_key") is True + assert is_valid_table_parameter_key("Valid Key 123*!") is True + assert is_valid_table_parameter_key("invalid \n key") is False + assert is_valid_table_parameter_key("long_key" * 100) is False + + +def test_get_chunks_empty(): + assert len(list(get_chunks([], 5))) == 0 + + +def test_get_chunks_uneven(): + chunks = list(get_chunks([1, 2, 3], 2)) + assert chunks[0] == [1, 2] + assert chunks[1] == [3] + assert len(chunks) == 2 + + +def test_get_chunks_more_elements_than_chunk(): + chunks = list(get_chunks([1, 2, 3], 4)) + assert chunks[0] == [1, 2, 3] + assert len(chunks) == 1 + + +@pytest.mark.parametrize( + ("max_len", "expected"), + ( + pytest.param(12, "abc def ghi", id="ok string"), + pytest.param(6, "abc...", id="ellipsis"), + ), +) +def test_ellipsis_comment(max_len, expected): + assert expected == ellipsis_comment("abc def ghi", max_len=max_len) diff --git a/dbt-athena/tests/unit/utils.py b/dbt-athena/tests/unit/utils.py new file mode 100644 index 00000000..320fdc13 --- /dev/null +++ b/dbt-athena/tests/unit/utils.py @@ -0,0 +1,485 @@ +import os +import string +from typing import Optional + +import agate +import boto3 + +from dbt.adapters.athena.utils import AthenaCatalogType +from dbt.config.project import PartialProject + +from .constants import AWS_REGION, BUCKET, CATALOG_ID, DATA_CATALOG_NAME, DATABASE_NAME + + +class Obj: + which = "blah" + single_threaded = False + + +def profile_from_dict(profile, profile_name, cli_vars="{}"): + from dbt.config import Profile + from dbt.config.renderer import ProfileRenderer + from dbt.config.utils import parse_cli_vars + + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + renderer = ProfileRenderer(cli_vars) + + # in order to call dbt's internal profile rendering, we need to set the + # flags global. This is a bit of a hack, but it's the best way to do it. 
+ from argparse import Namespace + + from dbt.flags import set_from_args + + set_from_args(Namespace(), None) + return Profile.from_raw_profile_info( + profile, + profile_name, + renderer, + ) + + +def project_from_dict(project, profile, packages=None, selectors=None, cli_vars="{}"): + from dbt.config.renderer import DbtProjectYamlRenderer + from dbt.config.utils import parse_cli_vars + + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + renderer = DbtProjectYamlRenderer(profile, cli_vars) + + project_root = project.pop("project-root", os.getcwd()) + + partial = PartialProject.from_dicts( + project_root=project_root, + project_dict=project, + packages_dict=packages, + selectors_dict=selectors, + ) + return partial.render(renderer) + + +def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars="{}"): + from copy import deepcopy + + from dbt.config import Profile, Project, RuntimeConfig + from dbt.config.utils import parse_cli_vars + + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + if isinstance(project, Project): + profile_name = project.profile_name + else: + profile_name = project.get("profile") + + if not isinstance(profile, Profile): + profile = profile_from_dict( + deepcopy(profile), + profile_name, + cli_vars, + ) + + if not isinstance(project, Project): + project = project_from_dict( + deepcopy(project), + profile, + packages, + selectors, + cli_vars, + ) + + args = Obj() + args.vars = cli_vars + args.profile_dir = "/dev/null" + return RuntimeConfig.from_parts(project=project, profile=profile, args=args) + + +def inject_plugin(plugin): + from dbt.adapters.factory import FACTORY + + key = plugin.adapter.type() + FACTORY.plugins[key] = plugin + + +def inject_adapter(value, plugin): + """Inject the given adapter into the adapter factory, so your hand-crafted + artisanal adapter will be available from get_adapter() as if dbt loaded it. + """ + inject_plugin(plugin) + from dbt.adapters.factory import FACTORY + + key = value.type() + FACTORY.adapters[key] = value + + +def clear_plugin(plugin): + from dbt.adapters.factory import FACTORY + + key = plugin.adapter.type() + FACTORY.plugins.pop(key, None) + FACTORY.adapters.pop(key, None) + + +class TestAdapterConversions: + def _get_tester_for(self, column_type): + from dbt_common.clients import agate_helper + + if column_type is agate.TimeDelta: # dbt never makes this! 
+ return agate.TimeDelta() + + for instance in agate_helper.DEFAULT_TYPE_TESTER._possible_types: + if isinstance(instance, column_type): # include child types + return instance + + raise ValueError(f"no tester for {column_type}") + + def _make_table_of(self, rows, column_types): + column_names = list(string.ascii_letters[: len(rows[0])]) + if isinstance(column_types, type): + column_types = [self._get_tester_for(column_types) for _ in column_names] + else: + column_types = [self._get_tester_for(typ) for typ in column_types] + table = agate.Table(rows, column_names=column_names, column_types=column_types) + return table + + +class MockAWSService: + def create_data_catalog( + self, + catalog_name: str = DATA_CATALOG_NAME, + catalog_type: AthenaCatalogType = AthenaCatalogType.GLUE, + catalog_id: str = CATALOG_ID, + ): + athena = boto3.client("athena", region_name=AWS_REGION) + parameters = {} + if catalog_type == AthenaCatalogType.GLUE: + parameters = {"catalog-id": catalog_id} + else: + parameters = {"catalog": catalog_name} + athena.create_data_catalog(Name=catalog_name, Type=catalog_type.value, Parameters=parameters) + + def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_database(DatabaseInput={"Name": name}, CatalogId=catalog_id) + + def create_view(self, view_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": view_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": "", + }, + "TableType": "VIRTUAL_VIEW", + "Parameters": { + "TableOwner": "John Doe", + }, + }, + ) + + def create_table( + self, + table_name: str, + database_name: str = DATABASE_NAME, + catalog_id: str = CATALOG_ID, + location: Optional[str] = "auto", + ): + glue = boto3.client("glue", region_name=AWS_REGION) + if location == "auto": + location = f"s3://{BUCKET}/tables/{table_name}" + glue.create_table( + CatalogId=catalog_id, + DatabaseName=database_name, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "Location": location, + }, + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + "TableType": "table", + "Parameters": { + "compressionType": "snappy", + "classification": "parquet", + "projection.enabled": "false", + "typeOfData": "file", + }, + }, + ) + + def create_table_without_type(self, table_name: str, database_name: str = DATABASE_NAME): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=database_name, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}", + }, + "Parameters": { + "compressionType": "snappy", + "classification": "parquet", + "projection.enabled": "false", + "typeOfData": "file", + }, + }, + ) + + def create_table_without_partitions(self, table_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + 
"Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}", + }, + "PartitionKeys": [], + "TableType": "table", + "Parameters": { + "compressionType": "snappy", + "classification": "parquet", + "projection.enabled": "false", + "typeOfData": "file", + }, + }, + ) + + def create_iceberg_table(self, table_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/data/{table_name}", + }, + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + "TableType": "EXTERNAL_TABLE", + "Parameters": { + "metadata_location": f"s3://{BUCKET}/tables/metadata/{table_name}/123.json", + "table_type": "ICEBERG", + }, + }, + ) + + def create_table_without_table_type(self, table_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}", + }, + "Parameters": { + "TableOwner": "John Doe", + }, + }, + ) + + def create_work_group_with_output_location_enforced(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://pre-configured-output-location/", + }, + "EnforceWorkGroupConfiguration": True, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + + def create_work_group_with_output_location_not_enforced(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://pre-configured-output-location/", + }, + "EnforceWorkGroupConfiguration": False, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + + def create_work_group_no_output_location(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "EnforceWorkGroupConfiguration": True, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + + def add_data_in_table(self, table_name: str): + s3 = boto3.client("s3", region_name=AWS_REGION) + s3.create_bucket(Bucket=BUCKET, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-01/data1.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-01/data2.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-02/data.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-03/data1.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, 
Key=f"tables/{table_name}/dt=2022-01-03/data2.parquet") + partition_input_list = [ + { + "Values": [dt], + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}/dt={dt}", + }, + } + for dt in ["2022-01-01", "2022-01-02", "2022-01-03"] + ] + glue = boto3.client("glue", region_name=AWS_REGION) + glue.batch_create_partition( + DatabaseName="test_dbt_athena", TableName=table_name, PartitionInputList=partition_input_list + ) + + def add_partitions_to_table(self, database, table_name): + partition_input_list = [ + { + "Values": [dt], + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}/dt={dt}", + }, + "Parameters": {"compressionType": "snappy", "classification": "parquet"}, + } + for dt in [f"2022-01-{day:02d}" for day in range(1, 27)] + ] + glue = boto3.client("glue", region_name=AWS_REGION) + glue.batch_create_partition( + DatabaseName=database, TableName=table_name, PartitionInputList=partition_input_list + ) + + def add_table_version(self, database, table_name): + glue = boto3.client("glue", region_name=AWS_REGION) + table = glue.get_table(DatabaseName=database, Name=table_name).get("Table") + new_table_version = { + "Name": table_name, + "StorageDescriptor": table["StorageDescriptor"], + "PartitionKeys": table["PartitionKeys"], + "TableType": table["TableType"], + "Parameters": table["Parameters"], + } + glue.update_table(DatabaseName=database, TableInput=new_table_version) diff --git a/scripts/migrate-adapter.sh b/scripts/migrate-adapter.sh index 58d70018..9993476c 100644 --- a/scripts/migrate-adapter.sh +++ b/scripts/migrate-adapter.sh @@ -8,7 +8,7 @@ git remote add old https://github.com/dbt-labs/$repo.git git fetch old # merge the updated branch from the legacy repo into the dbt-adapters repo -git checkout -b $target_branch +git checkout $target_branch git merge old/$source_branch --allow-unrelated-histories # remove the remote that was created by this process diff --git a/static/images/dbt-athena-black.png b/static/images/dbt-athena-black.png new file mode 100644 index 00000000..b7ba90b7 Binary files /dev/null and b/static/images/dbt-athena-black.png differ diff --git a/static/images/dbt-athena-color.png b/static/images/dbt-athena-color.png new file mode 100644 index 00000000..aca1b51b Binary files /dev/null and b/static/images/dbt-athena-color.png differ diff --git a/static/images/dbt-athena-long.png b/static/images/dbt-athena-long.png new file mode 100644 index 00000000..22cb5ccd Binary files /dev/null and b/static/images/dbt-athena-long.png differ