diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 27cb521e..52fc0e70 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -1,81 +1,94 @@ name: 🐞 Bug -description: Report a bug or an issue you've found with dbt-adapters +description: Report a bug or an issue you've found title: "[Bug] " labels: ["bug", "triage"] body: - - type: markdown +- type: markdown attributes: - value: | - Thanks for taking the time to fill out this bug report! - - type: checkboxes + value: Thanks for taking the time to fill out this bug report! +- type: checkboxes attributes: - label: Is this a new bug? - description: > - In other words, is this an error, flaw, failure or fault in our software? + label: Is this a new bug? + description: > + In other words, is this an error, flaw, failure or fault in our software? - If this is a bug that broke existing functionality that used to work, please open a regression issue. - If this is a bug experienced while using dbt Cloud, please report to [support](mailto:support@getdbt.com). - If this is a request for help or troubleshooting code in your own dbt project, please join our [dbt Community Slack](https://www.getdbt.com/community/join-the-community/) or open a [Discussion question](https://github.com/dbt-labs/docs.getdbt.com/discussions). + If this is a bug that broke existing functionality that used to work, please open a regression issue. + If this is a bug experienced while using dbt Cloud, please report to [support](mailto:support@getdbt.com). + If this is a request for help or troubleshooting code in your own dbt project, please join our [dbt Community Slack](https://www.getdbt.com/community/join-the-community/) or open a [Discussion question](https://github.com/dbt-labs/docs.getdbt.com/discussions). - Please search to see if an issue already exists for the bug you encountered. - options: - - label: I believe this is a new bug - required: true - - label: I have searched the existing issues, and I could not find an existing issue for this bug - required: true - - type: textarea + Please search to see if an issue already exists for the bug you encountered. + options: + - label: I believe this is a new bug + required: true + - label: I have searched the existing issues, and I could not find an existing issue for this bug + required: true +- type: checkboxes attributes: - label: Current Behavior - description: A concise description of what you're experiencing. + label: Which packages are affected? + description: Select one or more options below. + options: + - label: dbt-adapters + required: false + - label: dbt-tests-adapter + required: false + - label: dbt-athena + required: false + - label: dbt-athena-community + required: false +- type: textarea + attributes: + label: Current Behavior + description: A concise description of what you're experiencing. validations: - required: true - - type: textarea + required: true +- type: textarea attributes: - label: Expected Behavior - description: A concise description of what you expected to happen. + label: Expected Behavior + description: A concise description of what you expected to happen. validations: - required: true - - type: textarea + required: true +- type: textarea attributes: - label: Steps To Reproduce - description: Steps to reproduce the behavior. - placeholder: | - 1. In this environment... - 2. With this config... - 3. Run '...' - 4. See error... + label: Steps To Reproduce + description: Steps to reproduce the behavior. 
+ placeholder: | + 1. In this environment... + 2. With this config... + 3. Run '...' + 4. See error... validations: - required: true - - type: textarea + required: true +- type: textarea id: logs attributes: - label: Relevant log output - description: | - If applicable, log output to help explain your problem. - render: shell + label: Relevant log output + description: If applicable, log output to help explain your problem. + render: shell validations: - required: false - - type: textarea + required: false +- type: textarea attributes: - label: Environment - description: | - examples: - - **OS**: Ubuntu 20.04 - - **Python**: 3.11.6 (`python3 --version`) - - **dbt-adapters**: 1.0.0 - value: | - - OS: - - Python: - - dbt-adapters: - render: markdown + label: Environment + description: | + examples: + - **OS**: Ubuntu 20.04 + - **Python**: 3.11.6 (`python3 --version`) + - **dbt-adapters**: 1.10.0 + - **dbt-postgres**: 1.9.0 + value: | + - OS: + - Python: + - dbt-adapters: + - <adapter>: + render: markdown validations: - required: false - - type: textarea + required: false +- type: textarea attributes: - label: Additional Context - description: | - Links? References? Anything that will give us more context about the issue you are encountering! + label: Additional Context + description: | + Links? References? Anything that will give us more context about the issue you are encountering! - Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. + Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. validations: - required: false + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index a89889af..fab83f44 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,11 +1,14 @@ blank_issues_enabled: false contact_links: - - name: Ask the community for help +- name: Ask the community for help url: https://github.com/dbt-labs/docs.getdbt.com/discussions about: Need help troubleshooting? Check out our guide on how to ask - - name: Contact dbt Cloud support +- name: Contact dbt Cloud support url: mailto:support@getdbt.com about: Are you using dbt Cloud? Contact our support team for help! - - name: Participate in Discussions +- name: Participate in Discussions url: https://github.com/dbt-labs/dbt-adapters/discussions - about: Do you have a Big Idea for dbt-adapters? Read open discussions, or start a new one + about: Do you have a Big Idea for dbt-adapters or one of the adapter implementations? Read open discussions, or start a new one +- name: Create an issue for dbt-core + url: https://github.com/dbt-labs/dbt-core/issues/new/choose + about: Report a bug or request a feature for dbt-core diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml index 22960c2d..cf376cd1 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.yml +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -1,59 +1,55 @@ name: ✨ Feature -description: Propose a straightforward extension of dbt-adapters functionality +description: Propose a straightforward extension of dbt functionality title: "[Feature] <title>" labels: ["enhancement", "triage"] body: - - type: markdown +- type: markdown attributes: - value: | - Thanks for taking the time to fill out this feature request! - - type: checkboxes + value: Thanks for taking the time to fill out this feature request! 
+- type: checkboxes attributes: - label: Is this your first time submitting a feature request? - description: > - We want to make sure that features are distinct and discoverable, - so that other members of the community can find them and offer their thoughts. + label: Is this your first time submitting a feature request? + description: > + We want to make sure that features are distinct and discoverable, + so that other members of the community can find them and offer their thoughts. - Issues are the right place to request straightforward extensions of existing dbt-adapters functionality. - For "big ideas" about future capabilities of dbt-adapters, we ask that you open a - [discussion](https://github.com/dbt-labs/dbt-adapters/discussions/new?category=ideas) in the "Ideas" category instead. - options: - - label: I have read the [expectations for open source contributors](https://docs.getdbt.com/docs/contributing/oss-expectations) - required: true - - label: I have searched the existing issues, and I could not find an existing issue for this feature - required: true - - label: I am requesting a straightforward extension of existing dbt-adapters functionality, rather than a Big Idea better suited to a discussion - required: true - - type: textarea + Issues are the right place to request straightforward extensions of existing dbt functionality. + For "big ideas" about future capabilities of dbt, we ask that you open a + [discussion](https://github.com/dbt-labs/dbt-adapters/discussions/new?category=ideas) in the "Ideas" category instead. + options: + - label: I have read the [expectations for open source contributors](https://docs.getdbt.com/docs/contributing/oss-expectations) + required: true + - label: I have searched the existing issues, and I could not find an existing issue for this feature + required: true + - label: I am requesting a straightforward extension of existing dbt functionality, rather than a Big Idea better suited to a discussion + required: true +- type: textarea attributes: - label: Describe the feature - description: A clear and concise description of what you want to happen. + label: Describe the feature + description: A clear and concise description of what you want to happen. validations: - required: true - - type: textarea + required: true +- type: textarea attributes: - label: Describe alternatives you've considered - description: | - A clear and concise description of any alternative solutions or features you've considered. + label: Describe alternatives you've considered + description: A clear and concise description of any alternative solutions or features you've considered. validations: - required: false - - type: textarea + required: false +- type: textarea attributes: - label: Who will this benefit? - description: | - What kind of use case will this feature be useful for? Please be specific and provide examples, this will help us prioritize properly. + label: Who will this benefit? + description: What kind of use case will this feature be useful for? Please be specific and provide examples, this will help us prioritize properly. validations: - required: false - - type: input + required: false +- type: input attributes: - label: Are you interested in contributing this feature? - description: Let us know if you want to write some code, and how we can help. + label: Are you interested in contributing this feature? + description: Let us know if you want to write some code, and how we can help. 
validations: - required: false - - type: textarea + required: false +- type: textarea attributes: - label: Anything else? - description: | - Links? References? Anything that will give us more context about the feature you are suggesting! + label: Anything else? + description: Links? References? Anything that will give us more context about the feature you are suggesting! validations: - required: false + required: false diff --git a/.github/ISSUE_TEMPLATE/internal-epic.yml b/.github/ISSUE_TEMPLATE/internal-epic.yml deleted file mode 100644 index 8cfb3aef..00000000 --- a/.github/ISSUE_TEMPLATE/internal-epic.yml +++ /dev/null @@ -1,91 +0,0 @@ -name: 🧠️ Epic -description: This is an epic ticket intended for use by the maintainers and product managers of `dbt-adapters` -title: "[<project>] <title>" -labels: ["epic"] -body: - - type: markdown - attributes: - value: This is an epic ticket intended for use by the maintainers and product managers of `dbt-adapters` - - - type: input - attributes: - label: Short description - description: | - What's the one-liner that would tell everyone what we're doing? - validations: - required: true - - - type: textarea - attributes: - label: Context - description: | - Provide the "why". What's the motivation? Why are we doing it now? - Will this apply to all adapters or does this support a specific subset (e.g. only SQL adapters, only Snowflake, etc.)? - Is this updating existing functionality or providing all new functionality? - validations: - required: true - - - type: textarea - attributes: - label: Objectives - description: | - What are the high level goals we are trying to achieve? Provide use cases if available. - - Example: - - [ ] Allow adapter maintainers to support custom materializations - - [ ] Reduce maintenance burden for incremental users by offering materialized views - value: | - ```[tasklist] - - [ ] Objective - ``` - validations: - required: true - - - type: textarea - attributes: - label: Implementation tasks - description: | - Provide a list of GH issues that will build out this functionality. - This may start empty, or as a checklist of items. - However, it should eventually become a list of Feature Implementation tickets. - - Example: - - [ ] Create new macro to select warehouse - - [ ] https://github.com/dbt-labs/dbt-adapters/issues/42 - value: | - ```[tasklist] - - [ ] Task - ``` - validations: - required: false - - - type: textarea - attributes: - label: Documentation - description: | - Provide a list of relevant documentation. Is there a proof of concept? - Does this require and RFCs, ADRs, etc.? - If the documentation exists, link it; if it does not exist yet, reference it descriptively. - - Example: - - [ ] RFC for updating connection interface to accept new parameters - - [ ] POC: https://github.com/dbt-labs/dbt-adapters/pull/42 - value: | - ```[tasklist] - - [ ] Task - ``` - validations: - required: false - - - type: textarea - attributes: - label: Consequences - description: | - Will this impact dbt Labs' ability, or a partner's ability, to make a related change? Call that out for discussion. - Review `Impact:<team>` labels to ensure they capture these consequences. - placeholder: | - Example: - - This change impacts `dbt-core` because we updated how relations are managed. (Add the `Impact:[Core]` label.) - - This allows internal analytics to do more robust analysis of infrastructure usage (Add the `Impact:[Analytics]` label.) 
- validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/internal-feature-implementation.yml b/.github/ISSUE_TEMPLATE/internal-feature-implementation.yml deleted file mode 100644 index 7a99365b..00000000 --- a/.github/ISSUE_TEMPLATE/internal-feature-implementation.yml +++ /dev/null @@ -1,103 +0,0 @@ -name: 🛠️ Feature Implementation -description: This is a feature implementation ticket intended for use by the maintainers of `dbt-adapters` -title: "[<project>] <title>" -labels: ["user docs","enhancement","refinement"] -body: - - type: markdown - attributes: - value: This is a feature implementation ticket intended for use by the maintainers of `dbt-adapters` - - - type: checkboxes - attributes: - label: Housekeeping - description: > - A couple friendly reminders: - 1. Remove the `user docs` label if the scope of this work does not require changes to https://docs.getdbt.com/docs: no end-user interface (e.g. yml spec, CLI, error messages, etc) or functional changes. - 2. Only remove the `refinement` label if you're confident this is ready to estimate and/or you've received feedback from at least one other engineer. - 3. Will this change need to be backported? Add the appropriate `backport 1.x.latest` label(s). - 4. Will this impact other teams? Add the appropriate `Impact:[Team]` labels. - options: - - label: I am a maintainer of `dbt-adapters` - required: true - - - type: textarea - attributes: - label: Short description - description: | - Describe the scope of this feature, a high-level implementation approach and any tradeoffs to consider. - validations: - required: true - - - type: textarea - attributes: - label: Context - description: | - Provide the "why", motivation, and alternative approaches considered -- linking to previous refinement issues, spikes and documentation as appropriate. - validations: - required: false - - - type: textarea - attributes: - label: Acceptance criteria - description: | - What is the definition of done for this feature? Include any relevant edge cases and/or test cases. - - Example: - - [ ] If there are no config changes, don't alter the materialized view - - [ ] If the materialized view is scheduled to refresh, a manual refresh should not be issued - value: | - ```[tasklist] - - [ ] Criterion - ``` - validations: - required: true - - - type: textarea - attributes: - label: Testing - description: | - Provide scenarios to test. Include both positive and negative tests if possible. - Link to existing similar tests if appropriate. - - Example: - - [ ] Test with no `materialized` field in the model config. Expect pass. - - [ ] Test with a `materialized` field in the model config that is not valid. Expect ConfigError. - value: | - ```[tasklist] - - [ ] Test - ``` - validations: - required: true - - - type: textarea - attributes: - label: Security - description: | - Are there any security concerns with these changes? - When in doubt, run it by the security team. - placeholder: | - Example: Logging sensitive data - validations: - required: true - - - type: textarea - attributes: - label: Docs - description: | - Are there any docs the will need to be added or updated? - placeholder: | - Example: We need to document how to configure this new authentication method. - validations: - required: true - - - type: textarea - attributes: - label: Consequences - description: | - Will this impact dbt Labs' ability, or a partner's ability, to make a related change? Call that out for discussion. - Review `Impact:<team>` labels to ensure they capture these consequences. 
- placeholder: | - Example: - - This change impacts `dbt-databricks` because we updated a macro in `dbt-spark`. (Add the `Impact:[Databricks]` label.) - validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/regression-report.yml b/.github/ISSUE_TEMPLATE/regression-report.yml index 6831ede2..95f073ca 100644 --- a/.github/ISSUE_TEMPLATE/regression-report.yml +++ b/.github/ISSUE_TEMPLATE/regression-report.yml @@ -1,78 +1,97 @@ name: ☣️ Regression -description: Report a regression you've observed in a newer version of dbt-adapters +description: Report a regression you've observed in a newer version title: "[Regression] <title>" labels: ["regression", "triage"] body: - - type: markdown +- type: markdown attributes: - value: | - Thanks for taking the time to fill out this regression report! - - type: checkboxes + value: Thanks for taking the time to fill out this regression report! +- type: checkboxes attributes: - label: Is this a regression? - description: > - A regression is when documented functionality works as expected in an older version of the software - and no longer works after upgrading to a newer version the software - options: - - label: I believe this is a regression in functionality - required: true - - label: I have searched the existing issues, and I could not find an existing issue for this regression - required: true - - type: textarea + label: Is this a regression? + description: > + A regression is when documented functionality works as expected in an older version of the software + and no longer works after upgrading to a newer version the software + options: + - label: I believe this is a regression in functionality + required: true + - label: I have searched the existing issues, and I could not find an existing issue for this regression + required: true +- type: checkboxes attributes: - label: Current Behavior - description: A concise description of what you're experiencing. + label: Which packages are affected? + description: Select one or more options below. + options: + - label: dbt-adapters + required: false + - label: dbt-tests-adapter + required: false + - label: dbt-athena + required: false + - label: dbt-athena-community + required: false +- type: textarea + attributes: + label: Current Behavior + description: A concise description of what you're experiencing. validations: - required: true - - type: textarea + required: true +- type: textarea attributes: - label: Expected/Previous Behavior - description: A concise description of what you expected to happen. + label: Expected/Previous Behavior + description: A concise description of what you expected to happen. validations: - required: true - - type: textarea + required: true +- type: textarea attributes: - label: Steps To Reproduce - description: Steps to reproduce the behavior. - placeholder: | - 1. In this environment... - 2. With this config... - 3. Run '...' - 4. See error... + label: Steps To Reproduce + description: Steps to reproduce the behavior. + placeholder: | + 1. In this environment... + 2. With this config... + 3. Run '...' + 4. See error... validations: - required: true - - type: textarea + required: true +- type: textarea id: logs attributes: - label: Relevant log output - description: | - If applicable, log output to help explain your problem. - render: shell + label: Relevant log output + description: If applicable, log output to help explain your problem. 
+ render: shell validations: - required: false - - type: textarea + required: false +- type: textarea attributes: - label: Environment - description: | - examples: - - **OS**: Ubuntu 20.04 - - **Python**: 3.11.6 (`python3 --version`) - - **dbt-adapters (working version)**: 1.1.0 - - **dbt-adapters (regression version)**: 1.2.0 - value: | - - OS: - - Python: - - dbt-adapters (working version): - - dbt-adapters (regression version): - render: markdown + label: Environment + description: | + examples: + - **OS**: Ubuntu 20.04 + - **Python**: 3.11.6 (`python3 --version`) + - **dbt-adapters (working version)**: 1.1.0 + - **dbt-adapters (regression version)**: 1.2.0 + - **dbt-core (working version)**: 1.8.1 (`dbt --version`) + - **dbt-core (regression version)**: 1.9.0 (`dbt --version`) + - **dbt-postgres (working version)**: 1.8.0 (`dbt --version`) + - **dbt-postgres (regression version)**: 1.9.0 (`dbt --version`) + value: | + - OS: + - Python: + - dbt-adapters (working version): + - dbt-adapters (regression version): + - dbt-core (working version): + - dbt-core (regression version): + - <adapter> (working version): + - <adapter> (regression version): + render: markdown validations: - required: true - - type: textarea + required: true +- type: textarea attributes: - label: Additional Context - description: | - Links? References? Anything that will give us more context about the issue you are encountering! + label: Additional Context + description: | + Links? References? Anything that will give us more context about the issue you are encountering! - Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. + Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. validations: - required: false + required: false diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..5193eb3a --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,12 @@ +resolves #<issue-number> + +# Description + +<!-- What approach was taken to resolve the issue and why? 
--> + +## Checklist + +- [ ] This PR considers the [contributing guide](https://github.com/dbt-labs/dbt-adapters#contributing) +- [ ] This PR is small and focused on a single feature or bug fix +- [ ] This PR includes unit testing, or unit testing is not necessary +- [ ] This PR includes functional testing, or functional testing is not necessary diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 907926a3..4d450671 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,29 +1,47 @@ version: 2 updates: - - package-ecosystem: "pip" +- package-ecosystem: "pip" directory: "/" schedule: - interval: "daily" + interval: "daily" rebase-strategy: "disabled" ignore: - - dependency-name: "*" + - dependency-name: "*" update-types: - - version-update:semver-patch - - package-ecosystem: "pip" + - version-update:semver-patch +- package-ecosystem: "pip" directory: "/dbt-tests-adapter" schedule: - interval: "daily" + interval: "daily" rebase-strategy: "disabled" ignore: - - dependency-name: "*" + - dependency-name: "*" update-types: - - version-update:semver-patch - - package-ecosystem: "github-actions" + - version-update:semver-patch +- package-ecosystem: "pip" + directory: "/dbt-athena" + schedule: + interval: "daily" + rebase-strategy: "disabled" + ignore: + - dependency-name: "*" + update-types: + - version-update:semver-patch +- package-ecosystem: "pip" + directory: "/dbt-athena-community" + schedule: + interval: "daily" + rebase-strategy: "disabled" + ignore: + - dependency-name: "*" + update-types: + - version-update:semver-patch +- package-ecosystem: "github-actions" directory: "/" schedule: - interval: "weekly" + interval: "weekly" rebase-strategy: "disabled" ignore: - - dependency-name: "*" + - dependency-name: "*" update-types: - - version-update:semver-patch + - version-update:semver-patch diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md deleted file mode 100644 index bc82fd93..00000000 --- a/.github/pull_request_template.md +++ /dev/null @@ -1,50 +0,0 @@ -resolves # -[docs](https://github.com/dbt-labs/docs.getdbt.com/issues/new/choose) dbt-labs/docs.getdbt.com/# - -<!--- - Include the number of the issue addressed by this PR above if applicable. - PRs for code changes without an associated issue *will not be merged*. - See CONTRIBUTING.md for more information. - - Include the number of the docs issue that was opened for this PR. If - this change has no user-facing implications, "N/A" suffices instead. New - docs tickets can be created by clicking the link above or by going to - https://github.com/dbt-labs/docs.getdbt.com/issues/new/choose. ---> - -### Problem - -<!--- - Describe the problem this PR is solving. What is the application state - before this PR is merged? ---> - -### Solution - -<!--- - Describe the way this PR solves the above problem. Add as much detail as you - can to help reviewers understand your changes. Include any alternatives and - tradeoffs you considered. ---> - -### Concrete Adapter Testing - -At the appropriate stage of development or review, please use an integration test workflow in each of the following repos against your branch. - -Use these to confirm that your feature add or bug fix (1) achieves the desired behavior (2) does not disrupt other concrete adapters: -* [ ] Postgres -* [ ] Snowflake -* [ ] Spark -* [ ] Redshift -* [ ] Bigquery - -Please link to each CI invocation workflow in this checklist here or in a separate PR comment. 
- -*Note*: Before hitting merge, best practice is to test against your PR's latest SHA. - -### Checklist - -- [ ] I have read [the contributing guide](https://github.com/dbt-labs/dbt-adapters/blob/main/CONTRIBUTING.md) and understand what's expected of me -- [ ] I have run this code in development, and it appears to resolve the stated issue -- [ ] This PR includes tests, or tests are not required/relevant for this PR -- [ ] This PR has no interface changes (e.g. macros, cli, logs, json artifacts, config files, adapter interface, etc.) or this PR has already received feedback and approval from Product or DX diff --git a/.github/workflows/_integration-tests.yml b/.github/workflows/_integration-tests.yml new file mode 100644 index 00000000..8adf0b27 --- /dev/null +++ b/.github/workflows/_integration-tests.yml @@ -0,0 +1,81 @@ +name: "Integration tests" + +on: + workflow_call: + inputs: + package: + description: "Choose the package to test" + type: string + default: "dbt-athena" + branch: + description: "Choose the branch to test" + type: string + default: "main" + repository: + description: "Choose the repository to test, when using a fork" + type: string + default: "dbt-labs/dbt-athena" + os: + description: "Choose the OS to test against" + type: string + default: "ubuntu-22.04" + python-version: + description: "Choose the Python version to test against" + type: string + default: "3.9" + workflow_dispatch: + inputs: + package: + description: "Choose the package to test" + type: choice + options: ["dbt-athena", "dbt-athena-community"] + branch: + description: "Choose the branch to test" + type: string + default: "main" + repository: + description: "Choose the repository to test, when using a fork" + type: string + default: "dbt-labs/dbt-athena" + os: + description: "Choose the OS to test against" + type: string + default: "ubuntu-22.04" + python-version: + description: "Choose the Python version to test against" + type: choice + options: ["3.9", "3.10", "3.11", "3.12"] + +permissions: + id-token: write + contents: read + +env: + DBT_TEST_ATHENA_S3_STAGING_DIR: ${{ vars.DBT_TEST_ATHENA_S3_BUCKET }}/staging/ + DBT_TEST_ATHENA_S3_TMP_TABLE_DIR: ${{ vars.DBT_TEST_ATHENA_S3_BUCKET }}/tmp_tables/ + DBT_TEST_ATHENA_REGION_NAME: ${{ vars.DBT_TEST_ATHENA_REGION_NAME }} + DBT_TEST_ATHENA_DATABASE: awsdatacatalog + DBT_TEST_ATHENA_SCHEMA: dbt-tests + DBT_TEST_ATHENA_WORK_GROUP: athena-dbt-tests + DBT_TEST_ATHENA_THREADS: 16 + DBT_TEST_ATHENA_POLL_INTERVAL: 0.5 + DBT_TEST_ATHENA_NUM_RETRIES: 3 + +jobs: + integration-tests: + runs-on: ${{ inputs.os }} + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.branch }} + repository: ${{ inputs.repository }} + - uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + - uses: pypa/hatch@install + - uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::${{ secrets.AWS_ACCOUNT_ID }}:role/${{ secrets.ASSUMABLE_ROLE_NAME }} + aws-region: ${{ vars.DBT_TEST_ATHENA_REGION_NAME }} + - run: hatch run integration-tests + working-directory: ./${{ inputs.package }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 421a66ad..8535ddaf 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -39,13 +39,21 @@ jobs: package: ${{ inputs.package }} branch: ${{ inputs.branch }} + integration-tests: + uses: ./.github/workflows/_integration-tests.yml + with: + branch: ${{ inputs.branch }} + repository: ${{ github.repository }} + secrets: inherit + 
generate-changelog: - needs: unit-tests + needs: [unit-tests, integration-tests] uses: ./.github/workflows/_generate-changelog.yml with: package: ${{ inputs.package }} merge: ${{ inputs.deploy-to == 'prod' }} branch: ${{ inputs.branch }} + repository: ${{ github.repository }} secrets: inherit publish-internal: diff --git a/.github/workflows/pull-request-checks.yml b/.github/workflows/pull-request-checks.yml index 0fd958ee..312e7cf6 100644 --- a/.github/workflows/pull-request-checks.yml +++ b/.github/workflows/pull-request-checks.yml @@ -14,7 +14,9 @@ jobs: changelog-entry: uses: ./.github/workflows/_changelog-entry-check.yml with: + package: "dbt-athena" pull-request: ${{ github.event.pull_request.number }} + secrets: inherit code-quality: uses: ./.github/workflows/_code-quality.yml @@ -25,31 +27,54 @@ jobs: verify-builds: uses: ./.github/workflows/_verify-build.yml strategy: + fail-fast: false matrix: - package: ["dbt-adapters", "dbt-tests-adapter"] + package: ["dbt-adapters", "dbt-tests-adapter", "dbt-athena", "dbt-athena-community"] + os: [ubuntu-22.04] python-version: ["3.9", "3.10", "3.11", "3.12"] with: package: ${{ matrix.package }} branch: ${{ github.event.pull_request.head.ref }} repository: ${{ github.event.pull_request.head.repo.full_name }} + os: ${{ matrix.os }} python-version: ${{ matrix.python-version }} unit-tests: uses: ./.github/workflows/_unit-tests.yml strategy: + fail-fast: false matrix: - package: ["dbt-adapters"] + package: ["dbt-adapters", "dbt-athena", "dbt-athena-community"] + os: [ ubuntu-22.04 ] python-version: ["3.9", "3.10", "3.11", "3.12"] with: package: ${{ matrix.package }} branch: ${{ github.event.pull_request.head.ref }} repository: ${{ github.event.pull_request.head.repo.full_name }} + os: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + + integration-tests: + uses: ./.github/workflows/_integration-tests.yml + strategy: + fail-fast: false + matrix: + package: ["dbt-athena", "dbt-athena-community"] + os: [ubuntu-22.04] + python-version: ["3.9", "3.10", "3.11", "3.12"] + with: + package: ${{ matrix.package }} + branch: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + os: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + secrets: inherit # This job does nothing and is only used for branch protection results: name: "Pull request checks" # keep this name, branch protection references it if: always() - needs: [changelog-entry, code-quality, verify-builds, unit-tests] + needs: [changelog-entry, code-quality, verify-builds, unit-tests, integration-tests] runs-on: ${{ vars.DEFAULT_RUNNER }} steps: - uses: re-actors/alls-green@release/v1 diff --git a/dbt-athena-community/README.md b/dbt-athena-community/README.md new file mode 100644 index 00000000..c5487c05 --- /dev/null +++ b/dbt-athena-community/README.md @@ -0,0 +1,882 @@ +<!-- markdownlint-disable-next-line MD041 --> +<p align="center"> + <img + src="https://raw.githubusercontent.com/dbt-labs/dbt/ec7dee39f793aa4f7dd3dae37282cc87664813e4/etc/dbt-logo-full.svg" + alt="dbt logo" width="500"/> +</p> +<p align="center"> + <a href="https://pypi.org/project/dbt-athena-community/"> + <img src="https://badge.fury.io/py/dbt-athena-community.svg" /> + </a> + <a target="_blank" href="https://pypi.org/project/dbt-athena-community/" style="background:none"> + <img src="https://img.shields.io/pypi/pyversions/dbt-athena-community"> + </a> + <a href="https://pycqa.github.io/isort/"> + <img 
src="https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336" /> + </a> + <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a> + <a href="https://github.com/python/mypy"><img src="https://www.mypy-lang.org/static/mypy_badge.svg" /></a> + <a href="https://pepy.tech/project/dbt-athena-community"> + <img src="https://static.pepy.tech/badge/dbt-athena-community/month" /> + </a> +</p> + +<!-- TOC --> +- [Features](#features) + - [Quick start](#quick-start) + - [Installation](#installation) + - [Prerequisites](#prerequisites) + - [Credentials](#credentials) + - [Configuring your profile](#configuring-your-profile) + - [Additional information](#additional-information) + - [Models](#models) + - [Table configuration](#table-configuration) + - [Table location](#table-location) + - [Incremental models](#incremental-models) + - [On schema change](#on-schema-change) + - [Iceberg](#iceberg) + - [Highly available table (HA)](#highly-available-table-ha) + - [HA known issues](#ha-known-issues) + - [Update glue data catalog](#update-glue-data-catalog) + - [Snapshots](#snapshots) + - [Timestamp strategy](#timestamp-strategy) + - [Check strategy](#check-strategy) + - [Hard-deletes](#hard-deletes) + - [Working example](#working-example) + - [Snapshots known issues](#snapshots-known-issues) + - [AWS Lake Formation integration](#aws-lake-formation-integration) + - [Python models](#python-models) + - [Contracts](#contracts) + - [Contributing](#contributing) + - [Contributors ✨](#contributors-) +<!-- TOC --> + +# Features + +- Supports dbt version `1.7.*` +- Support for Python +- Supports [seeds][seeds] +- Correctly detects views and their columns +- Supports [table materialization][table] + - [Iceberg tables][athena-iceberg] are supported **only with Athena Engine v3** and **a unique table location** + (see table location section below) + - Hive tables are supported by both Athena engines +- Supports [incremental models][incremental] + - On Iceberg tables: + - Supports the use of `unique_key` only with the `merge` strategy + - Supports the `append` strategy + - On Hive tables: + - Supports two incremental update strategies: `insert_overwrite` and `append` + - Does **not** support the use of `unique_key` +- Supports [snapshots][snapshots] +- Supports [Python models][python-models] + +[seeds]: https://docs.getdbt.com/docs/building-a-dbt-project/seeds + +[incremental]: https://docs.getdbt.com/docs/build/incremental-models + +[table]: https://docs.getdbt.com/docs/build/materializations#table + +[python-models]: https://docs.getdbt.com/docs/build/python-models#configuring-python-models + +[athena-iceberg]: https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html + +[snapshots]: https://docs.getdbt.com/docs/build/snapshots + +## Quick start + +### Installation + +- `pip install dbt-athena-community` +- Or `pip install git+https://github.com/dbt-athena/dbt-athena.git` + +### Prerequisites + +To start, you will need an S3 bucket, for instance `my-bucket` and an Athena database: + +```sql +CREATE DATABASE IF NOT EXISTS analytics_dev +COMMENT 'Analytics models generated by dbt (development)' +LOCATION 's3://my-bucket/' +WITH DBPROPERTIES ('creator'='Foo Bar', 'email'='foo@bar.com'); +``` + +Notes: + +- Take note of your AWS region code (e.g. `us-west-2` or `eu-west-2`, etc.). +- You can also use [AWS Glue](https://docs.aws.amazon.com/athena/latest/ug/glue-athena.html) to create and manage Athena + databases. 
+ +### Credentials + +Credentials can be passed directly to the adapter, or they can +be [determined automatically](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) based +on `aws cli`/`boto3` conventions. +You can either: + +- Configure `aws_access_key_id` and `aws_secret_access_key` +- Configure `aws_profile_name` to match a profile defined in your AWS credentials file. + Checkout dbt profile configuration below for details. + +### Configuring your profile + +A dbt profile can be configured to run against AWS Athena using the following configuration: + +| Option | Description | Required? | Example | +|-----------------------|------------------------------------------------------------------------------------------|-----------|--------------------------------------------| +| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` | +| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` | +| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` | +| s3_tmp_table_dir | Prefix for storing temporary tables, if different from the connection's `s3_data_dir` | Optional | `s3://bucket3/dbt/` | +| region_name | AWS region of your Athena instance | Required | `eu-west-1` | +| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` | +| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` | +| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` | +| debug_query_state | Flag if debug message with Athena query state is needed | Optional | `false` | +| aws_access_key_id | Access key ID of the user performing requests | Optional | `AKIAIOSFODNN7EXAMPLE` | +| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` | +| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` | +| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` | +| skip_workgroup_check | Indicates if the WorkGroup check (additional AWS call) can be skipped | Optional | `true` | +| num_retries | Number of times to retry a failing query | Optional | `3` | +| num_boto3_retries | Number of times to retry boto3 requests (e.g. 
deleting S3 files for materialized tables) | Optional | `5` | +| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` | +| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` | +| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` | +| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` | + +**Example profiles.yml entry:** + +```yaml +athena: + target: dev + outputs: + dev: + type: athena + s3_staging_dir: s3://athena-query-results/dbt/ + s3_data_dir: s3://your_s3_bucket/dbt/ + s3_data_naming: schema_table + s3_tmp_table_dir: s3://your_s3_bucket/temp/ + region_name: eu-west-1 + schema: dbt + database: awsdatacatalog + threads: 4 + aws_profile_name: my-profile + work_group: my-workgroup + spark_work_group: my-spark-workgroup + seed_s3_upload_args: + ACL: bucket-owner-full-control +``` + +### Additional information + +- `threads` is supported +- `database` and `catalog` can be used interchangeably + +## Models + +### Table configuration + +- `external_location` (`default=none`) + - If set, the full S3 path to which the table will be saved + - Works only with incremental models + - Does not work with Hive table with `ha` set to true +- `partitioned_by` (`default=none`) + - An array list of columns by which the table will be partitioned + - Limited to creation of 100 partitions (*currently*) +- `bucketed_by` (`default=none`) + - An array list of columns to bucket data, ignored if using Iceberg +- `bucket_count` (`default=none`) + - The number of buckets for bucketing your data, ignored if using Iceberg +- `table_type` (`default='hive'`) + - The type of table + - Supports `hive` or `iceberg` +- `ha` (`default=false`) + - If the table should be built using the high-availability method. This option is only available for Hive tables + since it is by default for Iceberg tables (see the section [below](#highly-available-table-ha)) +- `format` (`default='parquet'`) + - The data format for the table + - Supports `ORC`, `PARQUET`, `AVRO`, `JSON`, `TEXTFILE` +- `write_compression` (`default=none`) + - The compression type to use for any storage format that allows compression to be specified. To see which options are + available, check out [CREATE TABLE AS][create-table-as] +- `field_delimiter` (`default=none`) + - Custom field delimiter, for when format is set to `TEXTFILE` +- `table_properties`: table properties to add to the table, valid for Iceberg only +- `native_drop`: Relation drop operations will be performed with SQL, not direct Glue API calls. No S3 calls will be + made to manage data in S3. Data in S3 will only be cleared up for Iceberg + tables [see AWS docs](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-managing-tables.html). Note that + Iceberg DROP TABLE operations may timeout if they take longer than 60 seconds. +- `seed_by_insert` (`default=false`) + - Default behaviour uploads seed data to S3. 
This flag will create seeds using an SQL insert statement + - Large seed files cannot use `seed_by_insert`, as the SQL insert statement would + exceed [the Athena limit of 262144 bytes](https://docs.aws.amazon.com/athena/latest/ug/service-limits.html) +- `force_batch` (`default=false`) + - Skip creating the table as CTAS and run the operation directly in batch insert mode + - This is particularly useful when the standard table creation process fails due to partition limitations, + allowing you to work with temporary tables and persist the dataset more efficiently +- `unique_tmp_table_suffix` (`default=false`) + - For incremental models using insert overwrite strategy on hive table + - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid + - Useful if you are looking to run multiple dbt build inserting in the same table in parallel +- `temp_schema` (`default=none`) + - For incremental models, it allows to define a schema to hold temporary create statements + used in incremental model runs + - Schema will be created in the model target database if does not exist +- `lf_tags_config` (`default=none`) + - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns + - `enabled` (`default=False`) whether LF tags management is enabled for a model + - `tags` dictionary with tags and their values to assign for the model + - `tags_columns` dictionary with a tag key, value and list of columns they must be assigned to + - `lf_inherited_tags` (`default=none`) + - List of Lake Formation tag keys that are intended to be inherited from the database level and thus shouldn't be + removed during association of those defined in `lf_tags_config` + - i.e., the default behavior of `lf_tags_config` is to be exhaustive and first remove any pre-existing tags from + tables and columns before associating the ones currently defined for a given model + - This breaks tag inheritance as inherited tags appear on tables and columns like those associated directly + +```sql +{{ + config( + materialized='incremental', + incremental_strategy='append', + on_schema_change='append_new_columns', + table_type='iceberg', + schema='test_schema', + lf_tags_config={ + 'enabled': true, + 'tags': { + 'tag1': 'value1', + 'tag2': 'value2' + }, + 'tags_columns': { + 'tag1': { + 'value1': ['column1', 'column2'], + 'value2': ['column3', 'column4'] + } + }, + 'inherited_tags': ['tag1', 'tag2'] + } + ) +}} +``` + +- Format for `dbt_project.yml`: + +```yaml + +lf_tags_config: + enabled: true + tags: + tag1: value1 + tag2: value2 + tags_columns: + tag1: + value1: [ column1, column2 ] + inherited_tags: [ tag1, tag2 ] +``` + +- `lf_grants` (`default=none`) + - Lake Formation grants config for data_cell filters + - Format: + + ```python + lf_grants={ + 'data_cell_filters': { + 'enabled': True | False, + 'filters': { + 'filter_name': { + 'row_filter': '<filter_condition>', + 'principals': ['principal_arn1', 'principal_arn2'] + } + } + } + } + ``` + +> Notes: +> +> - `lf_tags` and `lf_tags_columns` configs support only attaching lf tags to corresponding resources. +> We recommend managing LF Tags permissions somewhere outside dbt. For example, you may use +> [terraform](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_permissions) or +> [aws cdk](https://docs.aws.amazon.com/cdk/api/v1/docs/aws-lakeformation-readme.html) for such purpose. 
+> - `data_cell_filters` management can't be automated outside dbt because the filter can't be attached to the table +> which doesn't exist. Once you `enable` this config, dbt will set all filters and their permissions during every +> dbt run. Such approach keeps the actual state of row level security configuration actual after every dbt run and +> apply changes if they occur: drop, create, update filters and their permissions. +> - Any tags listed in `lf_inherited_tags` should be strictly inherited from the database level and never overridden at + the table and column level +> - Currently `dbt-athena` does not differentiate between an inherited tag association and an override of same it made +> previously +> - e.g. If an inherited tag is overridden by an `lf_tags_config` value in one DBT run, and that override is removed + prior to a subsequent run, the prior override will linger and no longer be encoded anywhere (in e.g. Terraform + where the inherited value is configured nor in the DBT project where the override previously existed but now is + gone) + +[create-table-as]: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties + +### Table location + +The location a table is saved to is determined by: + +1. If `external_location` is defined, that value is used +2. If `s3_data_dir` is defined, the path is determined by that and `s3_data_naming` +3. If `s3_data_dir` is not defined, data is stored under `s3_staging_dir/tables/` + +Here all the options available for `s3_data_naming`: + +- `unique`: `{s3_data_dir}/{uuid4()}/` +- `table`: `{s3_data_dir}/{table}/` +- `table_unique`: `{s3_data_dir}/{table}/{uuid4()}/` +- `schema_table`: `{s3_data_dir}/{schema}/{table}/` +- `s3_data_naming=schema_table_unique`: `{s3_data_dir}/{schema}/{table}/{uuid4()}/` + +It's possible to set the `s3_data_naming` globally in the target profile, or overwrite the value in the table config, +or setting up the value for groups of model in dbt_project.yml. + +> Note: when using a workgroup with a default output location configured, `s3_data_naming` and any configured buckets +> are ignored and the location configured in the workgroup is used. + +### Incremental models + +Support for [incremental models](https://docs.getdbt.com/docs/build/incremental-models). + +These strategies are supported: + +- `insert_overwrite` (default): The insert overwrite strategy deletes the overlapping partitions from the destination + table, and then inserts the new records from the source. This strategy depends on the `partitioned_by` keyword! If no + partitions are defined, dbt will fall back to the `append` strategy. +- `append`: Insert new records without updating, deleting or overwriting any existing data. There might be duplicate + data (e.g. great for log or historical data). +- `merge`: Conditionally updates, deletes, or inserts rows into an Iceberg table. Used in combination with `unique_key`. + Only available when using Iceberg. + +### On schema change + +`on_schema_change` is an option to reflect changes of schema in incremental models. +The following options are supported: + +- `ignore` (default) +- `fail` +- `append_new_columns` +- `sync_all_columns` + +For details, please refer +to [dbt docs](https://docs.getdbt.com/docs/build/incremental-models#what-if-the-columns-of-my-incremental-model-change). + +### Iceberg + +The adapter supports table materialization for Iceberg. 
+ +To get started just add this as your model: + +```sql +{{ config( + materialized='table', + table_type='iceberg', + format='parquet', + partitioned_by=['bucket(user_id, 5)'], + table_properties={ + 'optimize_rewrite_delete_file_threshold': '2' + } +) }} + +select 'A' as user_id, + 'pi' as name, + 'active' as status, + 17.89 as cost, + 1 as quantity, + 100000000 as quantity_big, + current_date as my_date +``` + +Iceberg supports bucketing as hidden partitions, therefore use the `partitioned_by` config to add specific bucketing +conditions. + +Iceberg supports several table formats for data : `PARQUET`, `AVRO` and `ORC`. + +It is possible to use Iceberg in an incremental fashion, specifically two strategies are supported: + +- `append`: New records are appended to the table, this can lead to duplicates. +- `merge`: Performs an upsert (and optional delete), where new records are added and existing records are updated. Only + available with Athena engine version 3. + - `unique_key` **(required)**: columns that define a unique record in the source and target tables. + - `incremental_predicates` (optional): SQL conditions that enable custom join clauses in the merge statement. This can + be useful for improving performance via predicate pushdown on the target table. + - `delete_condition` (optional): SQL condition used to identify records that should be deleted. + - `update_condition` (optional): SQL condition used to identify records that should be updated. + - `insert_condition` (optional): SQL condition used to identify records that should be inserted. + - `incremental_predicates`, `delete_condition`, `update_condition` and `insert_condition` can include any column of + the incremental table (`src`) or the final table (`target`). + Column names must be prefixed by either `src` or `target` to prevent a `Column is ambiguous` error. + +`delete_condition` example: + +```sql +{{ config( + materialized='incremental', + table_type='iceberg', + incremental_strategy='merge', + unique_key='user_id', + incremental_predicates=["src.quantity > 1", "target.my_date >= now() - interval '4' year"], + delete_condition="src.status != 'active' and target.my_date < now() - interval '2' year", + format='parquet' +) }} + +select 'A' as user_id, + 'pi' as name, + 'active' as status, + 17.89 as cost, + 1 as quantity, + 100000000 as quantity_big, + current_date as my_date +``` + +`update_condition` example: + +```sql +{{ config( + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + update_condition='target.id > 1', + schema='sandbox' + ) +}} + +{% if is_incremental() %} + +select * from ( + values + (1, 'v1-updated') + , (2, 'v2-updated') +) as t (id, value) + +{% else %} + +select * from ( + values + (-1, 'v-1') + , (0, 'v0') + , (1, 'v1') + , (2, 'v2') +) as t (id, value) + +{% endif %} +``` + +`insert_condition` example: + +```sql +{{ config( + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + insert_condition='target.status != 0', + schema='sandbox' + ) +}} + +select * from ( + values + (1, 0) + , (2, 1) +) as t (id, status) + +``` + +### Highly available table (HA) + +The current implementation of the table materialization can lead to downtime, as the target table is +dropped and re-created. To have the less destructive behavior it's possible to use the `ha` config on +your `table` materialized models. It leverages the table versions feature of glue catalog, creating +a temp table and swapping the target table to the location of the temp table. 
This materialization +is only available for `table_type=hive` and requires using unique locations. For iceberg, high +availability is the default. + +```sql +{{ config( + materialized='table', + ha=true, + format='parquet', + table_type='hive', + partitioned_by=['status'], + s3_data_naming='table_unique' +) }} + +select 'a' as user_id, + 'pi' as user_name, + 'active' as status +union all +select 'b' as user_id, + 'sh' as user_name, + 'disabled' as status +``` + +By default, the materialization keeps the last 4 table versions, you can change it by setting `versions_to_keep`. + +#### HA known issues + +- When swapping from a table with partitions to a table without (and the other way around), there could be a little + downtime. + If high performances is needed consider bucketing instead of partitions +- By default, Glue "duplicates" the versions internally, so the last two versions of a table point to the same location +- It's recommended to set `versions_to_keep` >= 4, as this will avoid having the older location removed + +### Update glue data catalog + +Optionally persist resource descriptions as column and relation comments to the glue data catalog, and meta as +[glue table properties](https://docs.aws.amazon.com/glue/latest/dg/tables-described.html#table-properties) +and [column parameters](https://docs.aws.amazon.com/glue/latest/webapi/API_Column.html). +By default, documentation persistence is disabled, but it can be enabled for specific resources or +groups of resources as needed. + +For example: + +```yaml +models: + - name: test_deduplicate + description: another value + config: + persist_docs: + relation: true + columns: true + meta: + test: value + columns: + - name: id + meta: + primary_key: true +``` + +See [persist docs](https://docs.getdbt.com/reference/resource-configs/persist_docs) for more details. + +## Snapshots + +The adapter supports snapshot materialization. It supports both timestamp and check strategy. To create a snapshot +create a snapshot file in the snapshots directory. If the directory does not exist create one. + +### Timestamp strategy + +To use the timestamp strategy refer to +the [dbt docs](https://docs.getdbt.com/docs/build/snapshots#timestamp-strategy-recommended) + +### Check strategy + +To use the check strategy refer to the [dbt docs](https://docs.getdbt.com/docs/build/snapshots#check-strategy) + +### Hard-deletes + +The materialization also supports invalidating hard deletes. Check +the [docs](https://docs.getdbt.com/docs/build/snapshots#hard-deletes-opt-in) to understand usage. 
+ +### Working example + +seed file - employent_indicators_november_2022_csv_tables.csv + +```csv +Series_reference,Period,Data_value,Suppressed +MEIM.S1WA,1999.04,80267, +MEIM.S1WA,1999.05,70803, +MEIM.S1WA,1999.06,65792, +MEIM.S1WA,1999.07,66194, +MEIM.S1WA,1999.08,67259, +MEIM.S1WA,1999.09,69691, +MEIM.S1WA,1999.1,72475, +MEIM.S1WA,1999.11,79263, +MEIM.S1WA,1999.12,86540, +MEIM.S1WA,2000.01,82552, +MEIM.S1WA,2000.02,81709, +MEIM.S1WA,2000.03,84126, +MEIM.S1WA,2000.04,77089, +MEIM.S1WA,2000.05,73811, +MEIM.S1WA,2000.06,70070, +MEIM.S1WA,2000.07,69873, +MEIM.S1WA,2000.08,71468, +MEIM.S1WA,2000.09,72462, +MEIM.S1WA,2000.1,74897, +``` + +model.sql + +```sql +{{ config( + materialized='table' +) }} + +select row_number() over() as id + , * + , cast(from_unixtime(to_unixtime(now())) as timestamp(6)) as refresh_timestamp +from {{ ref('employment_indicators_november_2022_csv_tables') }} +``` + +timestamp strategy - model_snapshot_1 + +```sql +{% snapshot model_snapshot_1 %} + +{{ + config( + strategy='timestamp', + updated_at='refresh_timestamp', + unique_key='id' + ) +}} + +select * +from {{ ref('model') }} {% endsnapshot %} +``` + +invalidate hard deletes - model_snapshot_2 + +```sql +{% snapshot model_snapshot_2 %} + +{{ + config + ( + unique_key='id', + strategy='timestamp', + updated_at='refresh_timestamp', + invalidate_hard_deletes=True, + ) +}} +select * +from {{ ref('model') }} {% endsnapshot %} +``` + +check strategy - model_snapshot_3 + +```sql +{% snapshot model_snapshot_3 %} + +{{ + config + ( + unique_key='id', + strategy='check', + check_cols=['series_reference','data_value'] + ) +}} +select * +from {{ ref('model') }} {% endsnapshot %} +``` + +### Snapshots known issues + +- Incremental Iceberg models - Sync all columns on schema change can't remove columns used for partitioning. + The only way, from a dbt perspective, is to do a full-refresh of the incremental model. + +- Tables, schemas and database names should only be lowercase + +- In order to avoid potential conflicts, make sure [`dbt-athena-adapter`](https://github.com/Tomme/dbt-athena) is not + installed in the target environment. + See <https://github.com/dbt-athena/dbt-athena/issues/103> for more details. + +- Snapshot does not support dropping columns from the source table. If you drop a column make sure to drop the column + from the snapshot as well. Another workaround is to NULL the column in the snapshot definition to preserve history + +## AWS Lake Formation integration + +The adapter implements AWS Lake Formation tags management in the following way: + +- You can enable or disable lf-tags management via [config](#table-configuration) (disabled by default) +- Once you enable the feature, lf-tags will be updated on every dbt run +- First, all lf-tags for columns are removed to avoid inheritance issues +- Then, all redundant lf-tags are removed from tables and actual tags from table configs are applied +- Finally, lf-tags for columns are applied + +It's important to understand the following points: + +- dbt does not manage lf-tags for databases +- dbt does not manage Lake Formation permissions + +That's why you should handle this by yourself manually or using an automation tool like terraform, AWS CDK etc. 
+You may find the following links useful to manage that: + +<!-- markdownlint-disable --> +* [terraform aws_lakeformation_permissions](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_permissions) +* [terraform aws_lakeformation_resource_lf_tags](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_resource_lf_tags) +<!-- markdownlint-restore --> + +## Python models + +The adapter supports Python models using [`spark`](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark.html). + +### Setup + +- A Spark-enabled workgroup created in Athena +- Spark execution role granted access to Athena, Glue and S3 +- The Spark workgroup is added to the `~/.dbt/profiles.yml` file and the profile to be used + is referenced in `dbt_project.yml` + +### Spark-specific table configuration + +- `timeout` (`default=43200`) + - Time out in seconds for each Python model execution. Defaults to 12 hours/43200 seconds. +- `spark_encryption` (`default=false`) + - If this flag is set to true, encrypts data in transit between Spark nodes and also encrypts data at rest stored + locally by Spark. +- `spark_cross_account_catalog` (`default=false`) + - When using the Spark Athena workgroup, queries can only be made against catalogs located on the same + AWS account by default. However, sometimes you want to query another catalog located on an external AWS + account. Setting this additional Spark properties parameter to true will enable querying external catalogs. + You can use the syntax `external_catalog_id/database.table` to access the external table on the external + catalog (ex: `999999999999/mydatabase.cloudfront_logs` where 999999999999 is the external catalog ID) +- `spark_requester_pays` (`default=false`) + - When an Amazon S3 bucket is configured as requester pays, the account of the user running the query is charged for + data access and data transfer fees associated with the query. + - If this flag is set to true, requester pays S3 buckets are enabled in Athena for Spark. + +### Spark notes + +- A session is created for each unique engine configuration defined in the models that are part of the invocation. +- A session's idle timeout is set to 10 minutes. Within the timeout period, if there is a new calculation + (Spark Python model) ready for execution and the engine configuration matches, the process will reuse the same session. +- The number of Python models running at a time depends on the `threads`. The number of sessions created for the + entire run depends on the number of unique engine configurations and the availability of sessions to maintain + thread concurrency. +- For Iceberg tables, it is recommended to use `table_properties` configuration to set the `format_version` to 2. + This is to maintain compatibility between Iceberg tables created by Trino with those created by Spark. 
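+
+As an illustration of the last note above, a minimal sketch of a Python model pinning the Iceberg format version; it
+assumes `table_type` and `table_properties` are accepted through `dbt.config()` in the same way as in the SQL models
+shown earlier, and the sample data is made up:
+
+```python
+def model(dbt, spark_session):
+    # keep format_version at 2 so Iceberg tables written by Spark stay compatible
+    # with tables created through Trino/Athena SQL (see the note above)
+    dbt.config(
+        materialized="table",
+        table_type="iceberg",
+        table_properties={"format_version": "2"},
+    )
+
+    data = [(1, "a"), (2, "b")]
+    return spark_session.createDataFrame(data, ["id", "value"])
+```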
+
+### Example models
+
+#### Simple pandas model
+
+```python
+import pandas as pd
+
+
+def model(dbt, session):
+    dbt.config(materialized="table")
+
+    model_df = pd.DataFrame({"A": [1, 2, 3, 4]})
+
+    return model_df
+```
+
+#### Simple spark
+
+```python
+def model(dbt, spark_session):
+    dbt.config(materialized="table")
+
+    data = [(1,), (2,), (3,), (4,)]
+
+    df = spark_session.createDataFrame(data, ["A"])
+
+    return df
+```
+
+#### Spark incremental
+
+```python
+def model(dbt, spark_session):
+    dbt.config(materialized="incremental")
+    df = dbt.ref("model")
+
+    if dbt.is_incremental:
+        max_from_this = (
+            f"select max(run_date) from {dbt.this.schema}.{dbt.this.identifier}"
+        )
+        df = df.filter(df.run_date >= spark_session.sql(max_from_this).collect()[0][0])
+
+    return df
+```
+
+#### Config spark model
+
+```python
+def model(dbt, spark_session):
+    dbt.config(
+        materialized="table",
+        engine_config={
+            "CoordinatorDpuSize": 1,
+            "MaxConcurrentDpus": 3,
+            "DefaultExecutorDpuSize": 1
+        },
+        spark_encryption=True,
+        spark_cross_account_catalog=True,
+        spark_requester_pays=True,
+        polling_interval=15,
+        timeout=120,
+    )
+
+    data = [(1,), (2,), (3,), (4,)]
+
+    df = spark_session.createDataFrame(data, ["A"])
+
+    return df
+```
+
+#### Create a pySpark UDF using imported external Python files
+
+```python
+def model(dbt, spark_session):
+    dbt.config(
+        materialized="incremental",
+        incremental_strategy="merge",
+        unique_key="num",
+    )
+    sc = spark_session.sparkContext
+    sc.addPyFile("s3://athena-dbt/test/file1.py")
+    sc.addPyFile("s3://athena-dbt/test/file2.py")
+
+    def func(iterator):
+        from file2 import transform
+
+        return [transform(i) for i in iterator]
+
+    from pyspark.sql.functions import udf
+    from pyspark.sql.functions import col
+
+    udf_with_import = udf(func)
+
+    data = [(1, "a"), (2, "b"), (3, "c")]
+    cols = ["num", "alpha"]
+    df = spark_session.createDataFrame(data, cols)
+
+    return df.withColumn("udf_test_col", udf_with_import(col("alpha")))
+```
+
+### Known issues in Python models
+
+- Python models cannot
+  [reference Athena SQL views](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark.html).
+- Third-party Python libraries can be used, but they must be [included in the pre-installed list][pre-installed list]
+  or [imported manually][imported manually].
+- Python models can only reference or write to tables with names meeting the
+  regular expression: `^[0-9a-zA-Z_]+$`. Dashes and special characters are not
+  supported by Spark, even though Athena supports them.
+- Incremental models do not fully utilize Spark capabilities. They depend partially on existing SQL-based logic which
+  runs on Trino.
+- Snapshot materializations are not supported.
+- Spark can only reference tables within the same catalog.
+- For tables created outside of the dbt tool, be sure to populate the location field or dbt will throw an error
+  when trying to create the table.
+
+[pre-installed list]: https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-preinstalled-python-libraries.html
+[imported manually]: https://docs.aws.amazon.com/athena/latest/ug/notebooks-import-files-libraries.html
+
+## Contracts
+
+The adapter partly supports contract definitions:
+
+- `data_type` is supported but needs to be adjusted for complex types. Types must be specified
+  entirely (for instance `array<int>`) even though they won't be checked. Indeed, as dbt recommends, we only compare
+  the broader type (array, map, int, varchar).
The complete definition is used in order to check that the data types + defined in Athena are ok (pre-flight check). +- The adapter does not support the constraints since there is no constraint concept in Athena. + +## Contributing + +See [CONTRIBUTING](CONTRIBUTING.md) for more information on how to contribute to this project. + +## Contributors ✨ + +Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): + +<a href="https://github.com/dbt-athena/dbt-athena/graphs/contributors"> + <img src="https://contrib.rocks/image?repo=dbt-athena/dbt-athena" /> +</a> + +Contributions of any kind welcome! diff --git a/dbt-athena-community/hatch.toml b/dbt-athena-community/hatch.toml new file mode 100644 index 00000000..f1ef474c --- /dev/null +++ b/dbt-athena-community/hatch.toml @@ -0,0 +1,56 @@ +[build.targets.sdist] +include = ["src/dbt"] + +[build.targets.wheel] +packages = ["src/dbt"] + +[envs.default] +# the only build dependency is dbt-athena, which will never be published when running this +# because the versions need to be identical +detached = true +dependencies = [ + "dbt-athena @ {root:uri}/../dbt-athena", + "dbt-tests-adapter~=1.9.2", + "isort~=5.13", + "moto~=5.0.13", + "pre-commit~=3.5", + "pyparsing~=3.1.4", + "pytest~=8.3", + "pytest-cov~=5.0", + "pytest-dotenv~=0.5", + "pytest-xdist~=3.6", +] +[envs.default.scripts] +setup = [ + "pre-commit install", + "cp -n ../dbt-athena/test.env.example test.env", +] +code-quality = "pre-commit run --all-files" +unit-tests = "pytest --cov=dbt --cov-report=html:htmlcov {args:../dbt-athena/tests/unit}" +integration-tests = "python -m pytest -n auto {args:../dbt-athena/tests/functional}" +all-tests = ["unit-tests", "integration-tests"] + +[envs.build] +detached = true +dependencies = [ + "wheel", + "twine", + "check-wheel-contents", +] +[envs.build.scripts] +check-all = [ + "- check-wheel", + "- check-sdist", +] +check-wheel = [ + "check-wheel-contents dist/*.whl --ignore W007,W008", + "find ./dist/dbt_athena_community-*.whl -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/", + "pip freeze | grep dbt-athena-community", + "pip freeze | grep dbt-athena", +] +check-sdist = [ + "twine check dist/*", + "find ./dist/dbt_athena_community-*.gz -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/", + "pip freeze | grep dbt-athena-community", + "pip freeze | grep dbt-athena", +] diff --git a/dbt-athena-community/pyproject.toml b/dbt-athena-community/pyproject.toml new file mode 100644 index 00000000..b5de59bb --- /dev/null +++ b/dbt-athena-community/pyproject.toml @@ -0,0 +1,45 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "dbt-athena-community" +description = "The athena adapter plugin for dbt (data build tool)" +readme = "README.md" +keywords = ["dbt", "adapter", "adapters", "database", "elt", "dbt-core", "dbt Core", "dbt Cloud", "dbt Labs", "athena"] +requires-python = ">=3.9.0" +authors = [ + { name = "dbt Labs", email = "info@dbtlabs.com" }, +] +maintainers = [ + { name = "dbt Labs", email = "info@dbtlabs.com" }, +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", 
+ "Programming Language :: Python :: 3.12", +] +# these versions should always match and they both should match the local version of dbt-athena +dependencies = ["dbt-athena==1.9.0"] +version = "1.9.0" +[project.urls] +Homepage = "https://github.com/dbt-labs/dbt-athena/dbt-athena" +Documentation = "https://docs.getdbt.com" +Repository = "https://github.com/dbt-labs/dbt-athena.git#subdirectory=dbt-athena" +Issues = "https://github.com/dbt-labs/dbt-athena/issues" + +[tool.pytest] +testpaths = [ + "../dbt-athena/tests/unit", + "../dbt-athena/tests/functional", +] +filterwarnings = [ + "ignore:.*'soft_unicode' has been renamed to 'soft_str'*:DeprecationWarning", + "ignore:unclosed file .*:ResourceWarning", +] diff --git a/dbt-athena-community/src/dbt/adapters/athena_community/__init__.py b/dbt-athena-community/src/dbt/adapters/athena_community/__init__.py new file mode 100644 index 00000000..c10eb7e8 --- /dev/null +++ b/dbt-athena-community/src/dbt/adapters/athena_community/__init__.py @@ -0,0 +1 @@ +# this is a shell package that allows us to publish dbt-athena as dbt-athena-community diff --git a/dbt-athena/.changes/0.0.0.md b/dbt-athena/.changes/0.0.0.md new file mode 100644 index 00000000..3e0643e7 --- /dev/null +++ b/dbt-athena/.changes/0.0.0.md @@ -0,0 +1,5 @@ +# Previous Releases + +For information on prior major and minor releases, see their changelogs: + +- [1.8](https://github.com/dbt-labs/dbt-athena/blob/main/CHANGELOG.md) diff --git a/dbt-athena/.changes/header.tpl.md b/dbt-athena/.changes/header.tpl.md new file mode 100644 index 00000000..44118343 --- /dev/null +++ b/dbt-athena/.changes/header.tpl.md @@ -0,0 +1,7 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), +and is generated by [Changie](https://github.com/miniscruff/changie). 
diff --git a/dbt-athena/.changes/unreleased/.gitkeep b/dbt-athena/.changes/unreleased/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/.changie.yaml b/dbt-athena/.changie.yaml new file mode 100644 index 00000000..649b8763 --- /dev/null +++ b/dbt-athena/.changie.yaml @@ -0,0 +1,144 @@ +changesDir: .changes +unreleasedDir: unreleased +headerPath: header.tpl.md +versionHeaderPath: "" +changelogPath: CHANGELOG.md +versionExt: md +envPrefix: "CHANGIE_" +versionFormat: '## dbt-athena {{.Version}} - {{.Time.Format "January 02, 2006"}}' +kindFormat: '### {{.Kind}}' +changeFormat: |- + {{- $IssueList := list }} + {{- $changes := splitList " " $.Custom.Issue }} + {{- range $issueNbr := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-athena/issues/nbr)" | replace "nbr" $issueNbr }} + {{- $IssueList = append $IssueList $changeLink }} + {{- end -}} + - {{.Body}} ({{ range $index, $element := $IssueList }}{{if $index}}, {{end}}{{$element}}{{end}}) +kinds: +- label: Breaking Changes +- label: Features +- label: Fixes +- label: Under the Hood +- label: Dependencies + changeFormat: |- + {{- $PRList := list }} + {{- $changes := splitList " " $.Custom.PR }} + {{- range $pullrequest := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-athena/pull/nbr)" | replace "nbr" $pullrequest }} + {{- $PRList = append $PRList $changeLink }} + {{- end -}} + - {{.Body}} ({{ range $index, $element := $PRList }}{{if $index}}, {{end}}{{$element}}{{end}}) + skipGlobalChoices: true + additionalChoices: + - key: Author + label: GitHub Username(s) (separated by a single space if multiple) + type: string + minLength: 3 + - key: PR + label: GitHub Pull Request Number (separated by a single space if multiple) + type: string + minLength: 1 +- label: Security + changeFormat: |- + {{- $PRList := list }} + {{- $changes := splitList " " $.Custom.PR }} + {{- range $pullrequest := $changes }} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-athena/pull/nbr)" | replace "nbr" $pullrequest }} + {{- $PRList = append $PRList $changeLink }} + {{- end -}} + - {{.Body}} ({{ range $index, $element := $PRList }}{{if $index}}, {{end}}{{$element}}{{end}}) + skipGlobalChoices: true + additionalChoices: + - key: Author + label: GitHub Username(s) (separated by a single space if multiple) + type: string + minLength: 3 + - key: PR + label: GitHub Pull Request Number (separated by a single space if multiple) + type: string + minLength: 1 +newlines: + afterChangelogHeader: 1 + afterKind: 1 + afterChangelogVersion: 1 + beforeKind: 1 + endOfVersion: 1 +custom: +- key: Author + label: GitHub Username(s) (separated by a single space if multiple) + type: string + minLength: 3 +- key: Issue + label: GitHub Issue Number (separated by a single space if multiple) + type: string + minLength: 1 +footerFormat: | + {{- /* We only want to include non-dbt contributors, so build a list of Core and Adapters maintainers for exclusion */ -}} + {{- $maintainers := list -}} + {{- $core_team := splitList " " .Env.CORE_TEAM -}} + {{- range $team_member := $core_team -}} + {{- /* ensure all names in this list are all lowercase for later matching purposes */ -}} + {{- $team_member_lower := lower $team_member -}} + {{- $maintainers = append $maintainers $team_member_lower -}} + {{- end -}} + + {{- /* Ensure we always skip dependabot */ -}} + {{- $maintainers = append $maintainers "dependabot[bot]" -}} + + {{- /* Build the list of contributors along with their PRs */ -}} + {{- $contributorDict 
:= dict -}} + {{- range $change := .Changes -}} + {{- /* PRs can have multiple authors */ -}} + {{- $authorList := splitList " " $change.Custom.Author -}} + + {{- /* Loop through all non-dbt authors for this changelog */ -}} + {{- range $author := $authorList -}} + {{- $authorLower := lower $author -}} + {{- if not (has $authorLower $maintainers) -}} + + {{- $changeList := splitList " " $change.Custom.Author -}} + {{- $IssueList := list -}} + {{- $changeLink := $change.Kind -}} + + {{- /* Build the issue link */ -}} + {{- if or (eq $change.Kind "Dependencies") (eq $change.Kind "Security") -}} + {{- $changes := splitList " " $change.Custom.PR -}} + {{- range $issueNbr := $changes -}} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-athena/pull/nbr)" | replace "nbr" $issueNbr -}} + {{- $IssueList = append $IssueList $changeLink -}} + {{- end -}} + + {{- else -}} + {{- $changes := splitList " " $change.Custom.Issue -}} + {{- range $issueNbr := $changes -}} + {{- $changeLink := "[#nbr](https://github.com/dbt-labs/dbt-athena/issues/nbr)" | replace "nbr" $issueNbr -}} + {{- $IssueList = append $IssueList $changeLink -}} + {{- end -}} + + {{- end -}} + + {{- /* If this contributor has other changes associated with them already, add this issue to the list */ -}} + {{- if hasKey $contributorDict $author -}} + {{- $contributionList := get $contributorDict $author -}} + {{- $contributionList = concat $contributionList $IssueList -}} + {{- $contributorDict := set $contributorDict $author $contributionList -}} + + {{- /* Otherwise create a new entry for this contributor */ -}} + {{- else -}} + {{- $contributionList := $IssueList -}} + {{- $contributorDict := set $contributorDict $author $contributionList -}} + + {{- end -}} + + {{- end -}} + {{- end -}} + {{- end -}} + + {{- /* no indentation here for formatting so the final markdown doesn't have unneeded indentations */ -}} + {{- if $contributorDict }} + ### Contributors + {{- range $k,$v := $contributorDict }} + - [@{{ $k }}](https://github.com/{{ $k }}) ({{ range $index, $element := $v }}{{ if $index }}, {{ end }}{{ $element }}{{ end }}) + {{- end }} + {{- end }} diff --git a/dbt-athena/.env.example b/dbt-athena/.env.example new file mode 100644 index 00000000..ca571cfb --- /dev/null +++ b/dbt-athena/.env.example @@ -0,0 +1,10 @@ +DBT_TEST_ATHENA_S3_STAGING_DIR= +DBT_TEST_ATHENA_S3_TMP_TABLE_DIR= +DBT_TEST_ATHENA_REGION_NAME= +DBT_TEST_ATHENA_DATABASE= +DBT_TEST_ATHENA_SCHEMA= +DBT_TEST_ATHENA_WORK_GROUP= +DBT_TEST_ATHENA_THREADS= +DBT_TEST_ATHENA_POLL_INTERVAL= +DBT_TEST_ATHENA_NUM_RETRIES= +DBT_TEST_ATHENA_AWS_PROFILE_NAME= diff --git a/dbt-athena/CHANGELOG.md b/dbt-athena/CHANGELOG.md new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/CONTRIBUTING.md b/dbt-athena/CONTRIBUTING.md new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/README.md b/dbt-athena/README.md new file mode 100644 index 00000000..c5487c05 --- /dev/null +++ b/dbt-athena/README.md @@ -0,0 +1,882 @@ +<!-- markdownlint-disable-next-line MD041 --> +<p align="center"> + <img + src="https://raw.githubusercontent.com/dbt-labs/dbt/ec7dee39f793aa4f7dd3dae37282cc87664813e4/etc/dbt-logo-full.svg" + alt="dbt logo" width="500"/> +</p> +<p align="center"> + <a href="https://pypi.org/project/dbt-athena-community/"> + <img src="https://badge.fury.io/py/dbt-athena-community.svg" /> + </a> + <a target="_blank" href="https://pypi.org/project/dbt-athena-community/" style="background:none"> + <img 
src="https://img.shields.io/pypi/pyversions/dbt-athena-community"> + </a> + <a href="https://pycqa.github.io/isort/"> + <img src="https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336" /> + </a> + <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a> + <a href="https://github.com/python/mypy"><img src="https://www.mypy-lang.org/static/mypy_badge.svg" /></a> + <a href="https://pepy.tech/project/dbt-athena-community"> + <img src="https://static.pepy.tech/badge/dbt-athena-community/month" /> + </a> +</p> + +<!-- TOC --> +- [Features](#features) + - [Quick start](#quick-start) + - [Installation](#installation) + - [Prerequisites](#prerequisites) + - [Credentials](#credentials) + - [Configuring your profile](#configuring-your-profile) + - [Additional information](#additional-information) + - [Models](#models) + - [Table configuration](#table-configuration) + - [Table location](#table-location) + - [Incremental models](#incremental-models) + - [On schema change](#on-schema-change) + - [Iceberg](#iceberg) + - [Highly available table (HA)](#highly-available-table-ha) + - [HA known issues](#ha-known-issues) + - [Update glue data catalog](#update-glue-data-catalog) + - [Snapshots](#snapshots) + - [Timestamp strategy](#timestamp-strategy) + - [Check strategy](#check-strategy) + - [Hard-deletes](#hard-deletes) + - [Working example](#working-example) + - [Snapshots known issues](#snapshots-known-issues) + - [AWS Lake Formation integration](#aws-lake-formation-integration) + - [Python models](#python-models) + - [Contracts](#contracts) + - [Contributing](#contributing) + - [Contributors ✨](#contributors-) +<!-- TOC --> + +# Features + +- Supports dbt version `1.7.*` +- Support for Python +- Supports [seeds][seeds] +- Correctly detects views and their columns +- Supports [table materialization][table] + - [Iceberg tables][athena-iceberg] are supported **only with Athena Engine v3** and **a unique table location** + (see table location section below) + - Hive tables are supported by both Athena engines +- Supports [incremental models][incremental] + - On Iceberg tables: + - Supports the use of `unique_key` only with the `merge` strategy + - Supports the `append` strategy + - On Hive tables: + - Supports two incremental update strategies: `insert_overwrite` and `append` + - Does **not** support the use of `unique_key` +- Supports [snapshots][snapshots] +- Supports [Python models][python-models] + +[seeds]: https://docs.getdbt.com/docs/building-a-dbt-project/seeds + +[incremental]: https://docs.getdbt.com/docs/build/incremental-models + +[table]: https://docs.getdbt.com/docs/build/materializations#table + +[python-models]: https://docs.getdbt.com/docs/build/python-models#configuring-python-models + +[athena-iceberg]: https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html + +[snapshots]: https://docs.getdbt.com/docs/build/snapshots + +## Quick start + +### Installation + +- `pip install dbt-athena-community` +- Or `pip install git+https://github.com/dbt-athena/dbt-athena.git` + +### Prerequisites + +To start, you will need an S3 bucket, for instance `my-bucket` and an Athena database: + +```sql +CREATE DATABASE IF NOT EXISTS analytics_dev +COMMENT 'Analytics models generated by dbt (development)' +LOCATION 's3://my-bucket/' +WITH DBPROPERTIES ('creator'='Foo Bar', 'email'='foo@bar.com'); +``` + +Notes: + +- Take note of your AWS region code (e.g. `us-west-2` or `eu-west-2`, etc.). 
+- You can also use [AWS Glue](https://docs.aws.amazon.com/athena/latest/ug/glue-athena.html) to create and manage Athena + databases. + +### Credentials + +Credentials can be passed directly to the adapter, or they can +be [determined automatically](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) based +on `aws cli`/`boto3` conventions. +You can either: + +- Configure `aws_access_key_id` and `aws_secret_access_key` +- Configure `aws_profile_name` to match a profile defined in your AWS credentials file. + Checkout dbt profile configuration below for details. + +### Configuring your profile + +A dbt profile can be configured to run against AWS Athena using the following configuration: + +| Option | Description | Required? | Example | +|-----------------------|------------------------------------------------------------------------------------------|-----------|--------------------------------------------| +| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` | +| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` | +| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` | +| s3_tmp_table_dir | Prefix for storing temporary tables, if different from the connection's `s3_data_dir` | Optional | `s3://bucket3/dbt/` | +| region_name | AWS region of your Athena instance | Required | `eu-west-1` | +| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` | +| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` | +| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` | +| debug_query_state | Flag if debug message with Athena query state is needed | Optional | `false` | +| aws_access_key_id | Access key ID of the user performing requests | Optional | `AKIAIOSFODNN7EXAMPLE` | +| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` | +| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` | +| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` | +| skip_workgroup_check | Indicates if the WorkGroup check (additional AWS call) can be skipped | Optional | `true` | +| num_retries | Number of times to retry a failing query | Optional | `3` | +| num_boto3_retries | Number of times to retry boto3 requests (e.g. 
deleting S3 files for materialized tables) | Optional | `5` | +| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` | +| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` | +| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` | +| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` | + +**Example profiles.yml entry:** + +```yaml +athena: + target: dev + outputs: + dev: + type: athena + s3_staging_dir: s3://athena-query-results/dbt/ + s3_data_dir: s3://your_s3_bucket/dbt/ + s3_data_naming: schema_table + s3_tmp_table_dir: s3://your_s3_bucket/temp/ + region_name: eu-west-1 + schema: dbt + database: awsdatacatalog + threads: 4 + aws_profile_name: my-profile + work_group: my-workgroup + spark_work_group: my-spark-workgroup + seed_s3_upload_args: + ACL: bucket-owner-full-control +``` + +### Additional information + +- `threads` is supported +- `database` and `catalog` can be used interchangeably + +## Models + +### Table configuration + +- `external_location` (`default=none`) + - If set, the full S3 path to which the table will be saved + - Works only with incremental models + - Does not work with Hive table with `ha` set to true +- `partitioned_by` (`default=none`) + - An array list of columns by which the table will be partitioned + - Limited to creation of 100 partitions (*currently*) +- `bucketed_by` (`default=none`) + - An array list of columns to bucket data, ignored if using Iceberg +- `bucket_count` (`default=none`) + - The number of buckets for bucketing your data, ignored if using Iceberg +- `table_type` (`default='hive'`) + - The type of table + - Supports `hive` or `iceberg` +- `ha` (`default=false`) + - If the table should be built using the high-availability method. This option is only available for Hive tables + since it is by default for Iceberg tables (see the section [below](#highly-available-table-ha)) +- `format` (`default='parquet'`) + - The data format for the table + - Supports `ORC`, `PARQUET`, `AVRO`, `JSON`, `TEXTFILE` +- `write_compression` (`default=none`) + - The compression type to use for any storage format that allows compression to be specified. To see which options are + available, check out [CREATE TABLE AS][create-table-as] +- `field_delimiter` (`default=none`) + - Custom field delimiter, for when format is set to `TEXTFILE` +- `table_properties`: table properties to add to the table, valid for Iceberg only +- `native_drop`: Relation drop operations will be performed with SQL, not direct Glue API calls. No S3 calls will be + made to manage data in S3. Data in S3 will only be cleared up for Iceberg + tables [see AWS docs](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-managing-tables.html). Note that + Iceberg DROP TABLE operations may timeout if they take longer than 60 seconds. +- `seed_by_insert` (`default=false`) + - Default behaviour uploads seed data to S3. 
This flag will create seeds using an SQL insert statement + - Large seed files cannot use `seed_by_insert`, as the SQL insert statement would + exceed [the Athena limit of 262144 bytes](https://docs.aws.amazon.com/athena/latest/ug/service-limits.html) +- `force_batch` (`default=false`) + - Skip creating the table as CTAS and run the operation directly in batch insert mode + - This is particularly useful when the standard table creation process fails due to partition limitations, + allowing you to work with temporary tables and persist the dataset more efficiently +- `unique_tmp_table_suffix` (`default=false`) + - For incremental models using insert overwrite strategy on hive table + - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid + - Useful if you are looking to run multiple dbt build inserting in the same table in parallel +- `temp_schema` (`default=none`) + - For incremental models, it allows to define a schema to hold temporary create statements + used in incremental model runs + - Schema will be created in the model target database if does not exist +- `lf_tags_config` (`default=none`) + - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns + - `enabled` (`default=False`) whether LF tags management is enabled for a model + - `tags` dictionary with tags and their values to assign for the model + - `tags_columns` dictionary with a tag key, value and list of columns they must be assigned to + - `lf_inherited_tags` (`default=none`) + - List of Lake Formation tag keys that are intended to be inherited from the database level and thus shouldn't be + removed during association of those defined in `lf_tags_config` + - i.e., the default behavior of `lf_tags_config` is to be exhaustive and first remove any pre-existing tags from + tables and columns before associating the ones currently defined for a given model + - This breaks tag inheritance as inherited tags appear on tables and columns like those associated directly + +```sql +{{ + config( + materialized='incremental', + incremental_strategy='append', + on_schema_change='append_new_columns', + table_type='iceberg', + schema='test_schema', + lf_tags_config={ + 'enabled': true, + 'tags': { + 'tag1': 'value1', + 'tag2': 'value2' + }, + 'tags_columns': { + 'tag1': { + 'value1': ['column1', 'column2'], + 'value2': ['column3', 'column4'] + } + }, + 'inherited_tags': ['tag1', 'tag2'] + } + ) +}} +``` + +- Format for `dbt_project.yml`: + +```yaml + +lf_tags_config: + enabled: true + tags: + tag1: value1 + tag2: value2 + tags_columns: + tag1: + value1: [ column1, column2 ] + inherited_tags: [ tag1, tag2 ] +``` + +- `lf_grants` (`default=none`) + - Lake Formation grants config for data_cell filters + - Format: + + ```python + lf_grants={ + 'data_cell_filters': { + 'enabled': True | False, + 'filters': { + 'filter_name': { + 'row_filter': '<filter_condition>', + 'principals': ['principal_arn1', 'principal_arn2'] + } + } + } + } + ``` + +> Notes: +> +> - `lf_tags` and `lf_tags_columns` configs support only attaching lf tags to corresponding resources. +> We recommend managing LF Tags permissions somewhere outside dbt. For example, you may use +> [terraform](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_permissions) or +> [aws cdk](https://docs.aws.amazon.com/cdk/api/v1/docs/aws-lakeformation-readme.html) for such purpose. 
+> - `data_cell_filters` management can't be automated outside dbt because the filter can't be attached to the table +> which doesn't exist. Once you `enable` this config, dbt will set all filters and their permissions during every +> dbt run. Such approach keeps the actual state of row level security configuration actual after every dbt run and +> apply changes if they occur: drop, create, update filters and their permissions. +> - Any tags listed in `lf_inherited_tags` should be strictly inherited from the database level and never overridden at + the table and column level +> - Currently `dbt-athena` does not differentiate between an inherited tag association and an override of same it made +> previously +> - e.g. If an inherited tag is overridden by an `lf_tags_config` value in one DBT run, and that override is removed + prior to a subsequent run, the prior override will linger and no longer be encoded anywhere (in e.g. Terraform + where the inherited value is configured nor in the DBT project where the override previously existed but now is + gone) + +[create-table-as]: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties + +### Table location + +The location a table is saved to is determined by: + +1. If `external_location` is defined, that value is used +2. If `s3_data_dir` is defined, the path is determined by that and `s3_data_naming` +3. If `s3_data_dir` is not defined, data is stored under `s3_staging_dir/tables/` + +Here all the options available for `s3_data_naming`: + +- `unique`: `{s3_data_dir}/{uuid4()}/` +- `table`: `{s3_data_dir}/{table}/` +- `table_unique`: `{s3_data_dir}/{table}/{uuid4()}/` +- `schema_table`: `{s3_data_dir}/{schema}/{table}/` +- `s3_data_naming=schema_table_unique`: `{s3_data_dir}/{schema}/{table}/{uuid4()}/` + +It's possible to set the `s3_data_naming` globally in the target profile, or overwrite the value in the table config, +or setting up the value for groups of model in dbt_project.yml. + +> Note: when using a workgroup with a default output location configured, `s3_data_naming` and any configured buckets +> are ignored and the location configured in the workgroup is used. + +### Incremental models + +Support for [incremental models](https://docs.getdbt.com/docs/build/incremental-models). + +These strategies are supported: + +- `insert_overwrite` (default): The insert overwrite strategy deletes the overlapping partitions from the destination + table, and then inserts the new records from the source. This strategy depends on the `partitioned_by` keyword! If no + partitions are defined, dbt will fall back to the `append` strategy. +- `append`: Insert new records without updating, deleting or overwriting any existing data. There might be duplicate + data (e.g. great for log or historical data). +- `merge`: Conditionally updates, deletes, or inserts rows into an Iceberg table. Used in combination with `unique_key`. + Only available when using Iceberg. + +### On schema change + +`on_schema_change` is an option to reflect changes of schema in incremental models. +The following options are supported: + +- `ignore` (default) +- `fail` +- `append_new_columns` +- `sync_all_columns` + +For details, please refer +to [dbt docs](https://docs.getdbt.com/docs/build/incremental-models#what-if-the-columns-of-my-incremental-model-change). + +### Iceberg + +The adapter supports table materialization for Iceberg. 
+ +To get started just add this as your model: + +```sql +{{ config( + materialized='table', + table_type='iceberg', + format='parquet', + partitioned_by=['bucket(user_id, 5)'], + table_properties={ + 'optimize_rewrite_delete_file_threshold': '2' + } +) }} + +select 'A' as user_id, + 'pi' as name, + 'active' as status, + 17.89 as cost, + 1 as quantity, + 100000000 as quantity_big, + current_date as my_date +``` + +Iceberg supports bucketing as hidden partitions, therefore use the `partitioned_by` config to add specific bucketing +conditions. + +Iceberg supports several table formats for data : `PARQUET`, `AVRO` and `ORC`. + +It is possible to use Iceberg in an incremental fashion, specifically two strategies are supported: + +- `append`: New records are appended to the table, this can lead to duplicates. +- `merge`: Performs an upsert (and optional delete), where new records are added and existing records are updated. Only + available with Athena engine version 3. + - `unique_key` **(required)**: columns that define a unique record in the source and target tables. + - `incremental_predicates` (optional): SQL conditions that enable custom join clauses in the merge statement. This can + be useful for improving performance via predicate pushdown on the target table. + - `delete_condition` (optional): SQL condition used to identify records that should be deleted. + - `update_condition` (optional): SQL condition used to identify records that should be updated. + - `insert_condition` (optional): SQL condition used to identify records that should be inserted. + - `incremental_predicates`, `delete_condition`, `update_condition` and `insert_condition` can include any column of + the incremental table (`src`) or the final table (`target`). + Column names must be prefixed by either `src` or `target` to prevent a `Column is ambiguous` error. + +`delete_condition` example: + +```sql +{{ config( + materialized='incremental', + table_type='iceberg', + incremental_strategy='merge', + unique_key='user_id', + incremental_predicates=["src.quantity > 1", "target.my_date >= now() - interval '4' year"], + delete_condition="src.status != 'active' and target.my_date < now() - interval '2' year", + format='parquet' +) }} + +select 'A' as user_id, + 'pi' as name, + 'active' as status, + 17.89 as cost, + 1 as quantity, + 100000000 as quantity_big, + current_date as my_date +``` + +`update_condition` example: + +```sql +{{ config( + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + update_condition='target.id > 1', + schema='sandbox' + ) +}} + +{% if is_incremental() %} + +select * from ( + values + (1, 'v1-updated') + , (2, 'v2-updated') +) as t (id, value) + +{% else %} + +select * from ( + values + (-1, 'v-1') + , (0, 'v0') + , (1, 'v1') + , (2, 'v2') +) as t (id, value) + +{% endif %} +``` + +`insert_condition` example: + +```sql +{{ config( + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + insert_condition='target.status != 0', + schema='sandbox' + ) +}} + +select * from ( + values + (1, 0) + , (2, 1) +) as t (id, status) + +``` + +### Highly available table (HA) + +The current implementation of the table materialization can lead to downtime, as the target table is +dropped and re-created. To have the less destructive behavior it's possible to use the `ha` config on +your `table` materialized models. It leverages the table versions feature of glue catalog, creating +a temp table and swapping the target table to the location of the temp table. 
This materialization
+is only available for `table_type=hive` and requires using unique locations. For Iceberg, high
+availability is the default.
+
+```sql
+{{ config(
+    materialized='table',
+    ha=true,
+    format='parquet',
+    table_type='hive',
+    partitioned_by=['status'],
+    s3_data_naming='table_unique'
+) }}
+
+select 'a' as user_id,
+       'pi' as user_name,
+       'active' as status
+union all
+select 'b' as user_id,
+       'sh' as user_name,
+       'disabled' as status
+```
+
+By default, the materialization keeps the last 4 table versions; you can change this by setting `versions_to_keep`.
+
+#### HA known issues
+
+- When swapping from a table with partitions to a table without (and the other way around), there could be a little
+  downtime.
+  If high performance is needed, consider bucketing instead of partitions.
+- By default, Glue "duplicates" the versions internally, so the last two versions of a table point to the same location
+- It's recommended to set `versions_to_keep` >= 4, as this will avoid having the older location removed
+
+### Update glue data catalog
+
+Optionally persist resource descriptions as column and relation comments to the glue data catalog, and meta as
+[glue table properties](https://docs.aws.amazon.com/glue/latest/dg/tables-described.html#table-properties)
+and [column parameters](https://docs.aws.amazon.com/glue/latest/webapi/API_Column.html).
+By default, documentation persistence is disabled, but it can be enabled for specific resources or
+groups of resources as needed.
+
+For example:
+
+```yaml
+models:
+  - name: test_deduplicate
+    description: another value
+    config:
+      persist_docs:
+        relation: true
+        columns: true
+      meta:
+        test: value
+    columns:
+      - name: id
+        meta:
+          primary_key: true
+```
+
+See [persist docs](https://docs.getdbt.com/reference/resource-configs/persist_docs) for more details.
+
+## Snapshots
+
+The adapter supports snapshot materialization with both the timestamp and check strategies. To create a snapshot,
+add a snapshot file in the snapshots directory. If the directory does not exist, create one.
+
+### Timestamp strategy
+
+To use the timestamp strategy, refer to
+the [dbt docs](https://docs.getdbt.com/docs/build/snapshots#timestamp-strategy-recommended).
+
+### Check strategy
+
+To use the check strategy, refer to the [dbt docs](https://docs.getdbt.com/docs/build/snapshots#check-strategy).
+
+### Hard-deletes
+
+The materialization also supports invalidating hard deletes. Check
+the [docs](https://docs.getdbt.com/docs/build/snapshots#hard-deletes-opt-in) to understand usage.
+ +### Working example + +seed file - employent_indicators_november_2022_csv_tables.csv + +```csv +Series_reference,Period,Data_value,Suppressed +MEIM.S1WA,1999.04,80267, +MEIM.S1WA,1999.05,70803, +MEIM.S1WA,1999.06,65792, +MEIM.S1WA,1999.07,66194, +MEIM.S1WA,1999.08,67259, +MEIM.S1WA,1999.09,69691, +MEIM.S1WA,1999.1,72475, +MEIM.S1WA,1999.11,79263, +MEIM.S1WA,1999.12,86540, +MEIM.S1WA,2000.01,82552, +MEIM.S1WA,2000.02,81709, +MEIM.S1WA,2000.03,84126, +MEIM.S1WA,2000.04,77089, +MEIM.S1WA,2000.05,73811, +MEIM.S1WA,2000.06,70070, +MEIM.S1WA,2000.07,69873, +MEIM.S1WA,2000.08,71468, +MEIM.S1WA,2000.09,72462, +MEIM.S1WA,2000.1,74897, +``` + +model.sql + +```sql +{{ config( + materialized='table' +) }} + +select row_number() over() as id + , * + , cast(from_unixtime(to_unixtime(now())) as timestamp(6)) as refresh_timestamp +from {{ ref('employment_indicators_november_2022_csv_tables') }} +``` + +timestamp strategy - model_snapshot_1 + +```sql +{% snapshot model_snapshot_1 %} + +{{ + config( + strategy='timestamp', + updated_at='refresh_timestamp', + unique_key='id' + ) +}} + +select * +from {{ ref('model') }} {% endsnapshot %} +``` + +invalidate hard deletes - model_snapshot_2 + +```sql +{% snapshot model_snapshot_2 %} + +{{ + config + ( + unique_key='id', + strategy='timestamp', + updated_at='refresh_timestamp', + invalidate_hard_deletes=True, + ) +}} +select * +from {{ ref('model') }} {% endsnapshot %} +``` + +check strategy - model_snapshot_3 + +```sql +{% snapshot model_snapshot_3 %} + +{{ + config + ( + unique_key='id', + strategy='check', + check_cols=['series_reference','data_value'] + ) +}} +select * +from {{ ref('model') }} {% endsnapshot %} +``` + +### Snapshots known issues + +- Incremental Iceberg models - Sync all columns on schema change can't remove columns used for partitioning. + The only way, from a dbt perspective, is to do a full-refresh of the incremental model. + +- Tables, schemas and database names should only be lowercase + +- In order to avoid potential conflicts, make sure [`dbt-athena-adapter`](https://github.com/Tomme/dbt-athena) is not + installed in the target environment. + See <https://github.com/dbt-athena/dbt-athena/issues/103> for more details. + +- Snapshot does not support dropping columns from the source table. If you drop a column make sure to drop the column + from the snapshot as well. Another workaround is to NULL the column in the snapshot definition to preserve history + +## AWS Lake Formation integration + +The adapter implements AWS Lake Formation tags management in the following way: + +- You can enable or disable lf-tags management via [config](#table-configuration) (disabled by default) +- Once you enable the feature, lf-tags will be updated on every dbt run +- First, all lf-tags for columns are removed to avoid inheritance issues +- Then, all redundant lf-tags are removed from tables and actual tags from table configs are applied +- Finally, lf-tags for columns are applied + +It's important to understand the following points: + +- dbt does not manage lf-tags for databases +- dbt does not manage Lake Formation permissions + +That's why you should handle this by yourself manually or using an automation tool like terraform, AWS CDK etc. 
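+
+For instance, a minimal `boto3` sketch (hypothetical role ARN and permission set; adjust to your own setup) of granting
+the dbt execution role database-level Lake Formation permissions outside of dbt:
+
+```python
+import boto3
+
+# hypothetical example: dbt does not manage Lake Formation permissions, so grant the
+# execution role access to the target database yourself (script, terraform, CDK, ...)
+lakeformation = boto3.client("lakeformation", region_name="eu-west-1")
+
+lakeformation.grant_permissions(
+    Principal={"DataLakePrincipalIdentifier": "arn:aws:iam::123456789012:role/dbt-athena-role"},
+    Resource={"Database": {"Name": "analytics_dev"}},
+    Permissions=["CREATE_TABLE", "ALTER", "DROP", "DESCRIBE"],
+)
+```
+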
+You may find the following links useful to manage that: + +<!-- markdownlint-disable --> +* [terraform aws_lakeformation_permissions](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_permissions) +* [terraform aws_lakeformation_resource_lf_tags](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_resource_lf_tags) +<!-- markdownlint-restore --> + +## Python models + +The adapter supports Python models using [`spark`](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark.html). + +### Setup + +- A Spark-enabled workgroup created in Athena +- Spark execution role granted access to Athena, Glue and S3 +- The Spark workgroup is added to the `~/.dbt/profiles.yml` file and the profile to be used + is referenced in `dbt_project.yml` + +### Spark-specific table configuration + +- `timeout` (`default=43200`) + - Time out in seconds for each Python model execution. Defaults to 12 hours/43200 seconds. +- `spark_encryption` (`default=false`) + - If this flag is set to true, encrypts data in transit between Spark nodes and also encrypts data at rest stored + locally by Spark. +- `spark_cross_account_catalog` (`default=false`) + - When using the Spark Athena workgroup, queries can only be made against catalogs located on the same + AWS account by default. However, sometimes you want to query another catalog located on an external AWS + account. Setting this additional Spark properties parameter to true will enable querying external catalogs. + You can use the syntax `external_catalog_id/database.table` to access the external table on the external + catalog (ex: `999999999999/mydatabase.cloudfront_logs` where 999999999999 is the external catalog ID) +- `spark_requester_pays` (`default=false`) + - When an Amazon S3 bucket is configured as requester pays, the account of the user running the query is charged for + data access and data transfer fees associated with the query. + - If this flag is set to true, requester pays S3 buckets are enabled in Athena for Spark. + +### Spark notes + +- A session is created for each unique engine configuration defined in the models that are part of the invocation. +- A session's idle timeout is set to 10 minutes. Within the timeout period, if there is a new calculation + (Spark Python model) ready for execution and the engine configuration matches, the process will reuse the same session. +- The number of Python models running at a time depends on the `threads`. The number of sessions created for the + entire run depends on the number of unique engine configurations and the availability of sessions to maintain + thread concurrency. +- For Iceberg tables, it is recommended to use `table_properties` configuration to set the `format_version` to 2. + This is to maintain compatibility between Iceberg tables created by Trino with those created by Spark. 
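+
+As an illustration of the last note above, a minimal sketch of a Python model pinning the Iceberg format version; it
+assumes `table_type` and `table_properties` are accepted through `dbt.config()` in the same way as in the SQL models
+shown earlier, and the sample data is made up:
+
+```python
+def model(dbt, spark_session):
+    # keep format_version at 2 so Iceberg tables written by Spark stay compatible
+    # with tables created through Trino/Athena SQL (see the note above)
+    dbt.config(
+        materialized="table",
+        table_type="iceberg",
+        table_properties={"format_version": "2"},
+    )
+
+    data = [(1, "a"), (2, "b")]
+    return spark_session.createDataFrame(data, ["id", "value"])
+```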
+
+### Example models
+
+#### Simple pandas model
+
+```python
+import pandas as pd
+
+
+def model(dbt, session):
+    dbt.config(materialized="table")
+
+    model_df = pd.DataFrame({"A": [1, 2, 3, 4]})
+
+    return model_df
+```
+
+#### Simple spark
+
+```python
+def model(dbt, spark_session):
+    dbt.config(materialized="table")
+
+    data = [(1,), (2,), (3,), (4,)]
+
+    df = spark_session.createDataFrame(data, ["A"])
+
+    return df
+```
+
+#### Spark incremental
+
+```python
+def model(dbt, spark_session):
+    dbt.config(materialized="incremental")
+    df = dbt.ref("model")
+
+    if dbt.is_incremental:
+        max_from_this = (
+            f"select max(run_date) from {dbt.this.schema}.{dbt.this.identifier}"
+        )
+        df = df.filter(df.run_date >= spark_session.sql(max_from_this).collect()[0][0])
+
+    return df
+```
+
+#### Config spark model
+
+```python
+def model(dbt, spark_session):
+    dbt.config(
+        materialized="table",
+        engine_config={
+            "CoordinatorDpuSize": 1,
+            "MaxConcurrentDpus": 3,
+            "DefaultExecutorDpuSize": 1
+        },
+        spark_encryption=True,
+        spark_cross_account_catalog=True,
+        spark_requester_pays=True,
+        polling_interval=15,
+        timeout=120,
+    )
+
+    data = [(1,), (2,), (3,), (4,)]
+
+    df = spark_session.createDataFrame(data, ["A"])
+
+    return df
+```
+
+#### Create a pySpark UDF using imported external Python files
+
+```python
+def model(dbt, spark_session):
+    dbt.config(
+        materialized="incremental",
+        incremental_strategy="merge",
+        unique_key="num",
+    )
+    sc = spark_session.sparkContext
+    sc.addPyFile("s3://athena-dbt/test/file1.py")
+    sc.addPyFile("s3://athena-dbt/test/file2.py")
+
+    def func(iterator):
+        from file2 import transform
+
+        return [transform(i) for i in iterator]
+
+    from pyspark.sql.functions import udf
+    from pyspark.sql.functions import col
+
+    udf_with_import = udf(func)
+
+    data = [(1, "a"), (2, "b"), (3, "c")]
+    cols = ["num", "alpha"]
+    df = spark_session.createDataFrame(data, cols)
+
+    return df.withColumn("udf_test_col", udf_with_import(col("alpha")))
+```
+
+### Known issues in Python models
+
+- Python models cannot
+  [reference Athena SQL views](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark.html).
+- Third-party Python libraries can be used, but they must be [included in the pre-installed list][pre-installed list]
+  or [imported manually][imported manually].
+- Python models can only reference or write to tables with names meeting the
+  regular expression: `^[0-9a-zA-Z_]+$`. Dashes and special characters are not
+  supported by Spark, even though Athena supports them.
+- Incremental models do not fully utilize Spark capabilities. They depend partially on existing SQL-based logic which
+  runs on Trino.
+- Snapshot materializations are not supported.
+- Spark can only reference tables within the same catalog.
+- For tables created outside of the dbt tool, be sure to populate the location field or dbt will throw an error
+  when trying to create the table.
+
+[pre-installed list]: https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-preinstalled-python-libraries.html
+[imported manually]: https://docs.aws.amazon.com/athena/latest/ug/notebooks-import-files-libraries.html
+
+## Contracts
+
+The adapter partly supports contract definitions:
+
+- `data_type` is supported but needs to be adjusted for complex types. Types must be specified
+  entirely (for instance `array<int>`) even though they won't be checked. Indeed, as dbt recommends, we only compare
+  the broader type (array, map, int, varchar).
The complete definition is used in order to check that the data types + defined in Athena are ok (pre-flight check). +- The adapter does not support the constraints since there is no constraint concept in Athena. + +## Contributing + +See [CONTRIBUTING](CONTRIBUTING.md) for more information on how to contribute to this project. + +## Contributors ✨ + +Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): + +<a href="https://github.com/dbt-athena/dbt-athena/graphs/contributors"> + <img src="https://contrib.rocks/image?repo=dbt-athena/dbt-athena" /> +</a> + +Contributions of any kind welcome! diff --git a/dbt-athena/hatch.toml b/dbt-athena/hatch.toml new file mode 100644 index 00000000..8f0e1263 --- /dev/null +++ b/dbt-athena/hatch.toml @@ -0,0 +1,57 @@ +[version] +path = "src/dbt/adapters/athena/__version__.py" + +[build.targets.sdist] +packages = ["src/dbt/adapters", "src/dbt/include"] +sources = ["src"] + +[build.targets.wheel] +packages = ["src/dbt/adapters", "src/dbt/include"] +sources = ["src"] + +[envs.default] +dependencies = [ + "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git", + "dbt-common @ git+https://github.com/dbt-labs/dbt-common.git", + "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter", + "dbt-core @ git+https://github.com/dbt-labs/dbt-core.git#subdirectory=core", + "moto~=5.0.13", + "pre-commit~=3.5", + "pyparsing~=3.1.4", + "pytest~=8.3", + "pytest-cov~=5.0", + "pytest-dotenv~=0.5", + "pytest-xdist~=3.6", +] +[envs.default.scripts] +setup = [ + "pre-commit install", + "cp -n test.env.example test.env", +] +code-quality = "pre-commit run --all-files" +unit-tests = "pytest --cov=dbt --cov-report=html:htmlcov {args:tests/unit}" +integration-tests = "python -m pytest -n auto {args:tests/functional}" +all-tests = ["unit-tests", "integration-tests"] + +[envs.build] +detached = true +dependencies = [ + "wheel", + "twine", + "check-wheel-contents", +] +[envs.build.scripts] +check-all = [ + "- check-wheel", + "- check-sdist", +] +check-wheel = [ + "check-wheel-contents dist/*.whl --ignore W007,W008", + "find ./dist/dbt_athena-*.whl -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/", + "pip freeze | grep dbt-athena", +] +check-sdist = [ + "twine check dist/*", + "find ./dist/dbt_athena-*.gz -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/", + "pip freeze | grep dbt-athena", +] diff --git a/dbt-athena/pyproject.toml b/dbt-athena/pyproject.toml new file mode 100644 index 00000000..69ec0018 --- /dev/null +++ b/dbt-athena/pyproject.toml @@ -0,0 +1,55 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +dynamic = ["version"] +name = "dbt-athena" +description = "The athena adapter plugin for dbt (data build tool)" +readme = "README.md" +keywords = ["dbt", "adapter", "adapters", "database", "elt", "dbt-core", "dbt Core", "dbt Cloud", "dbt Labs", "athena"] +requires-python = ">=3.9.0" +authors = [ + { name = "dbt Labs", email = "info@dbtlabs.com" }, +] +maintainers = [ + { name = "dbt Labs", email = "info@dbtlabs.com" }, +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + 
"Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies=[ + "dbt-adapters>=1.0.0,<2.0", + "dbt-common>=1.0.0,<2.0", + # add dbt-core to ensure backwards compatibility of installation, this is not a functional dependency + "dbt-core>=1.8.0", + "boto3>=1.28", + "boto3-stubs[athena,glue,lakeformation,sts]>=1.28", + "mmh3>=4.0.1,<4.2.0", + "pyathena>=2.25,<4.0", + "pydantic>=1.10,<3.0", + "tenacity>=8.2,<10.0", +] +[project.urls] +Homepage = "https://github.com/dbt-labs/dbt-athena/dbt-athena" +Documentation = "https://docs.getdbt.com" +Repository = "https://github.com/dbt-labs/dbt-athena.git#subdirectory=dbt-athena" +Issues = "https://github.com/dbt-labs/dbt-athena/issues" +Changelog = "https://github.com/dbt-labs/dbt-athena/blob/main/dbt-athena/CHANGELOG.md" + +[tool.pytest] +testpaths = ["tests/unit", "tests/functional"] +color = true +csv = "results.csv" +filterwarnings = [ + "ignore:.*'soft_unicode' has been renamed to 'soft_str'*:DeprecationWarning", + "ignore:unclosed file .*:ResourceWarning", +] diff --git a/dbt-athena/src/dbt/__init__.py b/dbt-athena/src/dbt/__init__.py new file mode 100644 index 00000000..b36383a6 --- /dev/null +++ b/dbt-athena/src/dbt/__init__.py @@ -0,0 +1,3 @@ +from pkgutil import extend_path + +__path__ = extend_path(__path__, __name__) diff --git a/dbt-athena/src/dbt/adapters/athena/__init__.py b/dbt-athena/src/dbt/adapters/athena/__init__.py new file mode 100644 index 00000000..c2f140db --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/__init__.py @@ -0,0 +1,15 @@ +from dbt.adapters.athena.connections import AthenaConnectionManager, AthenaCredentials +from dbt.adapters.athena.impl import AthenaAdapter +from dbt.adapters.base import AdapterPlugin +from dbt.include import athena + +Plugin: AdapterPlugin = AdapterPlugin( + adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH +) + +__all__ = [ + "AthenaConnectionManager", + "AthenaCredentials", + "AthenaAdapter", + "Plugin", +] diff --git a/dbt-athena/src/dbt/adapters/athena/__version__.py b/dbt-athena/src/dbt/adapters/athena/__version__.py new file mode 100644 index 00000000..7aba6409 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/__version__.py @@ -0,0 +1 @@ +version = "1.9.0" diff --git a/dbt-athena/src/dbt/adapters/athena/column.py b/dbt-athena/src/dbt/adapters/athena/column.py new file mode 100644 index 00000000..a220bf3b --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/column.py @@ -0,0 +1,97 @@ +import re +from dataclasses import dataclass +from typing import ClassVar, Dict + +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena.relation import TableType +from dbt.adapters.base.column import Column + + +@dataclass +class AthenaColumn(Column): + table_type: TableType = TableType.TABLE + + TYPE_LABELS: ClassVar[Dict[str, str]] = { + "STRING": "VARCHAR", + "TEXT": "VARCHAR", + } + + def is_iceberg(self) -> bool: + return self.table_type == TableType.ICEBERG + + def is_string(self) -> bool: + return self.dtype.lower() in {"varchar", "string"} + + def is_binary(self) -> bool: + return self.dtype.lower() in {"binary", "varbinary"} + + def is_timestamp(self) -> bool: + return self.dtype.lower() in {"timestamp"} + + def is_array(self) -> bool: + return self.dtype.lower().startswith("array") # type: ignore + + @classmethod + def string_type(cls, size: int) -> str: + return f"varchar({size})" if size > 0 else "varchar" + + @classmethod + def binary_type(cls) -> str: + return "varbinary" + + def 
timestamp_type(self) -> str: + return "timestamp(6)" if self.is_iceberg() else "timestamp" + + @classmethod + def array_type(cls, inner_type: str) -> str: + return f"array({inner_type})" + + def array_inner_type(self) -> str: + if not self.is_array(): + raise DbtRuntimeError("Called array_inner_type() on non-array field!") + # Match either `array<inner_type>` or `array(inner_type)`. Don't bother + # parsing nested arrays here, since we will expect the caller to be + # responsible for formatting the inner type, including nested arrays + pattern = r"^array[<(](.*)[>)]$" + match = re.match(pattern, self.dtype) + if match: + return match.group(1) + # If for some reason there's no match, fall back to the original string + return self.dtype # type: ignore + + def string_size(self) -> int: + if not self.is_string(): + raise DbtRuntimeError("Called string_size() on non-string field!") + # Handle error: '>' not supported between instances of 'NoneType' and 'NoneType' for union relations macro + return self.char_size or 0 + + @property + def data_type(self) -> str: + if self.is_string(): + return self.string_type(self.string_size()) + + if self.is_numeric(): + return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) # type: ignore + + if self.is_binary(): + return self.binary_type() + + if self.is_timestamp(): + return self.timestamp_type() + + if self.is_array(): + # Resolve the inner type of the array, using an AthenaColumn + # instance to properly convert the inner type. Note that this will + # cause recursion in cases of nested arrays + inner_type = self.array_inner_type() + inner_type_col = AthenaColumn( + column=self.column, + dtype=inner_type, + char_size=self.char_size, + numeric_precision=self.numeric_precision, + numeric_scale=self.numeric_scale, + ) + return self.array_type(inner_type_col.data_type) + + return self.dtype # type: ignore diff --git a/dbt-athena/src/dbt/adapters/athena/config.py b/dbt-athena/src/dbt/adapters/athena/config.py new file mode 100644 index 00000000..eacc10e0 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/config.py @@ -0,0 +1,159 @@ +import importlib.metadata +from functools import lru_cache +from typing import Any, Dict + +from botocore import config + +from dbt.adapters.athena.constants import ( + DEFAULT_CALCULATION_TIMEOUT, + DEFAULT_POLLING_INTERVAL, + DEFAULT_SPARK_COORDINATOR_DPU_SIZE, + DEFAULT_SPARK_EXECUTOR_DPU_SIZE, + DEFAULT_SPARK_MAX_CONCURRENT_DPUS, + DEFAULT_SPARK_PROPERTIES, + LOGGER, +) + + +@lru_cache() +def get_boto3_config(num_retries: int) -> config.Config: + return config.Config( + user_agent_extra="dbt-athena/" + importlib.metadata.version("dbt-athena"), + retries={"max_attempts": num_retries, "mode": "standard"}, + ) + + +class AthenaSparkSessionConfig: + """ + A helper class to manage Athena Spark Session Configuration. + """ + + def __init__(self, config: Dict[str, Any], **session_kwargs: Any) -> None: + self.config = config + self.session_kwargs = session_kwargs + + def set_timeout(self) -> int: + """ + Get the timeout value. + + This function retrieves the timeout value from the parsed model's configuration. If the timeout value + is not defined, it falls back to the default timeout value. If the retrieved timeout value is less than or + equal to 0, a ValueError is raised as timeout must be a positive integer. + + Returns: + int: The timeout value in seconds. + + Raises: + ValueError: If the timeout value is not a positive integer. 
+ + """ + timeout = self.config.get("timeout", DEFAULT_CALCULATION_TIMEOUT) + if not isinstance(timeout, int): + raise TypeError("Timeout must be an integer") + if timeout <= 0: + raise ValueError("Timeout must be a positive integer") + LOGGER.debug(f"Setting timeout: {timeout}") + return timeout + + def get_polling_interval(self) -> Any: + """ + Get the polling interval for the configuration. + + Returns: + Any: The polling interval value. + + Raises: + KeyError: If the polling interval is not found in either `self.config` + or `self.session_kwargs`. + """ + try: + return self.config["polling_interval"] + except KeyError: + try: + return self.session_kwargs["polling_interval"] + except KeyError: + return DEFAULT_POLLING_INTERVAL + + def set_polling_interval(self) -> float: + """ + Set the polling interval for the configuration. + + Returns: + float: The polling interval value. + + Raises: + ValueError: If the polling interval is not a positive integer. + """ + polling_interval = self.get_polling_interval() + if not (isinstance(polling_interval, float) or isinstance(polling_interval, int)) or polling_interval <= 0: + raise ValueError(f"Polling_interval must be a positive number. Got: {polling_interval}") + LOGGER.debug(f"Setting polling_interval: {polling_interval}") + return float(polling_interval) + + def set_engine_config(self) -> Dict[str, Any]: + """Set the engine configuration. + + Returns: + Dict[str, Any]: The engine configuration. + + Raises: + TypeError: If the engine configuration is not of type dict. + KeyError: If the keys of the engine configuration dictionary do not match the expected format. + """ + table_type = self.config.get("table_type", "hive") + spark_encryption = self.config.get("spark_encryption", False) + spark_cross_account_catalog = self.config.get("spark_cross_account_catalog", False) + spark_requester_pays = self.config.get("spark_requester_pays", False) + + default_spark_properties: Dict[str, str] = dict( + **DEFAULT_SPARK_PROPERTIES.get(table_type) + if table_type.lower() in ["iceberg", "hudi", "delta_lake"] + else {}, + **DEFAULT_SPARK_PROPERTIES.get("spark_encryption") if spark_encryption else {}, + **DEFAULT_SPARK_PROPERTIES.get("spark_cross_account_catalog") if spark_cross_account_catalog else {}, + **DEFAULT_SPARK_PROPERTIES.get("spark_requester_pays") if spark_requester_pays else {}, + ) + + default_engine_config = { + "CoordinatorDpuSize": DEFAULT_SPARK_COORDINATOR_DPU_SIZE, + "MaxConcurrentDpus": DEFAULT_SPARK_MAX_CONCURRENT_DPUS, + "DefaultExecutorDpuSize": DEFAULT_SPARK_EXECUTOR_DPU_SIZE, + "SparkProperties": default_spark_properties, + } + engine_config = self.config.get("engine_config", None) + + if engine_config: + provided_spark_properties = engine_config.get("SparkProperties", None) + if provided_spark_properties: + default_spark_properties.update(provided_spark_properties) + default_engine_config["SparkProperties"] = default_spark_properties + engine_config.pop("SparkProperties") + default_engine_config.update(engine_config) + engine_config = default_engine_config + + if not isinstance(engine_config, dict): + raise TypeError("Engine configuration has to be of type dict") + + expected_keys = { + "CoordinatorDpuSize", + "MaxConcurrentDpus", + "DefaultExecutorDpuSize", + "SparkProperties", + "AdditionalConfigs", + } + + if set(engine_config.keys()) - { + "CoordinatorDpuSize", + "MaxConcurrentDpus", + "DefaultExecutorDpuSize", + "SparkProperties", + "AdditionalConfigs", + }: + raise KeyError( + f"The engine configuration keys provided do not match 
the expected athena engine keys: {expected_keys}" + ) + + if engine_config["MaxConcurrentDpus"] == 1: + raise KeyError("The lowest value supported for MaxConcurrentDpus is 2") + LOGGER.debug(f"Setting engine configuration: {engine_config}") + return engine_config diff --git a/dbt-athena/src/dbt/adapters/athena/connections.py b/dbt-athena/src/dbt/adapters/athena/connections.py new file mode 100644 index 00000000..39e4c59e --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/connections.py @@ -0,0 +1,371 @@ +import json +import re +import time +from concurrent.futures.thread import ThreadPoolExecutor +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from decimal import Decimal +from typing import Any, ContextManager, Dict, List, Optional, Tuple + +from dbt_common.exceptions import ConnectionError, DbtRuntimeError +from dbt_common.utils import md5 +from pyathena.connection import Connection as AthenaConnection +from pyathena.cursor import Cursor +from pyathena.error import OperationalError, ProgrammingError + +# noinspection PyProtectedMember +from pyathena.formatter import ( + _DEFAULT_FORMATTERS, + Formatter, + _escape_hive, + _escape_presto, +) +from pyathena.model import AthenaQueryExecution +from pyathena.result_set import AthenaResultSet +from pyathena.util import RetryConfig +from tenacity import ( + retry, + retry_if_exception, + stop_after_attempt, + wait_random_exponential, +) +from typing_extensions import Self + +from dbt.adapters.athena.config import get_boto3_config +from dbt.adapters.athena.constants import LOGGER +from dbt.adapters.athena.query_headers import AthenaMacroQueryStringSetter +from dbt.adapters.athena.session import get_boto3_session +from dbt.adapters.contracts.connection import ( + AdapterResponse, + Connection, + ConnectionState, + Credentials, +) +from dbt.adapters.sql import SQLConnectionManager + + +@dataclass +class AthenaAdapterResponse(AdapterResponse): + data_scanned_in_bytes: Optional[int] = None + + +@dataclass +class AthenaCredentials(Credentials): + s3_staging_dir: str + region_name: str + endpoint_url: Optional[str] = None + work_group: Optional[str] = None + skip_workgroup_check: bool = False + aws_profile_name: Optional[str] = None + aws_access_key_id: Optional[str] = None + aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None + poll_interval: float = 1.0 + debug_query_state: bool = False + _ALIASES = {"catalog": "database"} + num_retries: int = 5 + num_boto3_retries: Optional[int] = None + num_iceberg_retries: int = 3 + s3_data_dir: Optional[str] = None + s3_data_naming: str = "schema_table_unique" + spark_work_group: Optional[str] = None + s3_tmp_table_dir: Optional[str] = None + # Unfortunately we can not just use dict, must be Dict because we'll get the following error: + # Credentials in profile "athena", target "athena" invalid: Unable to create schema for 'dict' + seed_s3_upload_args: Optional[Dict[str, Any]] = None + lf_tags_database: Optional[Dict[str, str]] = None + + @property + def type(self) -> str: + return "athena" + + @property + def unique_field(self) -> str: + return f"athena-{md5(self.s3_staging_dir)}" + + @property + def effective_num_retries(self) -> int: + return self.num_boto3_retries or self.num_retries + + def _connection_keys(self) -> Tuple[str, ...]: + return ( + "s3_staging_dir", + "work_group", + "skip_workgroup_check", + "region_name", + "database", + "schema", + "poll_interval", + "aws_profile_name", + "aws_access_key_id", + 
"endpoint_url", + "s3_data_dir", + "s3_data_naming", + "s3_tmp_table_dir", + "debug_query_state", + "seed_s3_upload_args", + "lf_tags_database", + "spark_work_group", + ) + + +class AthenaCursor(Cursor): + def __init__(self, **kwargs) -> None: # type: ignore + super().__init__(**kwargs) + self._executor = ThreadPoolExecutor() + + def _collect_result_set(self, query_id: str) -> AthenaResultSet: + query_execution = self._poll(query_id) + return self._result_set_class( + connection=self._connection, + converter=self._converter, + query_execution=query_execution, + arraysize=self._arraysize, + retry_config=self._retry_config, + ) + + def _poll(self, query_id: str) -> AthenaQueryExecution: + try: + query_execution = self.__poll(query_id) + except KeyboardInterrupt as e: + if self._kill_on_interrupt: + LOGGER.warning("Query canceled by user.") + self._cancel(query_id) + query_execution = self.__poll(query_id) + else: + raise e + return query_execution + + def __poll(self, query_id: str) -> AthenaQueryExecution: + while True: + query_execution = self._get_query_execution(query_id) + if query_execution.state in [ + AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, + AthenaQueryExecution.STATE_CANCELLED, + ]: + return query_execution + + if self.connection.cursor_kwargs.get("debug_query_state", False): + LOGGER.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...") + time.sleep(self._poll_interval) + + def execute( + self, + operation: str, + parameters: Optional[Dict[str, Any]] = None, + work_group: Optional[str] = None, + s3_staging_dir: Optional[str] = None, + endpoint_url: Optional[str] = None, + cache_size: int = 0, + cache_expiration_time: int = 0, + catch_partitions_limit: bool = False, + **kwargs: Dict[str, Any], + ) -> Self: + @retry( + # No need to retry if TOO_MANY_OPEN_PARTITIONS occurs. 
+ # Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry, + # because not all files are removed immediately after first try to create table + retry=retry_if_exception( + lambda e: False if catch_partitions_limit and "TOO_MANY_OPEN_PARTITIONS" in str(e) else True + ), + stop=stop_after_attempt(self._retry_config.attempt), + wait=wait_random_exponential( + multiplier=self._retry_config.attempt, + max=self._retry_config.max_delay, + exp_base=self._retry_config.exponential_base, + ), + reraise=True, + ) + def inner() -> AthenaCursor: + num_iceberg_retries = self.connection.cursor_kwargs.get("num_iceberg_retries") + 1 + + @retry( + # Nested retry is needed to handle ICEBERG_COMMIT_ERROR for parallel inserts + retry=retry_if_exception(lambda e: "ICEBERG_COMMIT_ERROR" in str(e)), + stop=stop_after_attempt(num_iceberg_retries), + wait=wait_random_exponential( + multiplier=num_iceberg_retries, + max=self._retry_config.max_delay, + exp_base=self._retry_config.exponential_base, + ), + reraise=True, + ) + def execute_with_iceberg_retries() -> AthenaCursor: + query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + cache_expiration_time=cache_expiration_time, + ) + + LOGGER.debug(f"Athena query ID {query_id}") + + query_execution = self._executor.submit(self._collect_result_set, query_id).result() + if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: + self.result_set = self._result_set_class( + self._connection, + self._converter, + query_execution, + self.arraysize, + self._retry_config, + ) + return self + raise OperationalError(query_execution.state_change_reason) + + return execute_with_iceberg_retries() # type: ignore + + return inner() # type: ignore + + +class AthenaConnectionManager(SQLConnectionManager): + TYPE = "athena" + + def set_query_header(self, query_header_context: Dict[str, Any]) -> None: + self.query_header = AthenaMacroQueryStringSetter(self.profile, query_header_context) + + @classmethod + def data_type_code_to_name(cls, type_code: str) -> str: + """ + Get the string representation of the data type from the Athena metadata. Dbt performs a + query to retrieve the types of the columns in the SQL query. Then these types are compared + to the types in the contract config, simplified because they need to match what is returned + by Athena metadata (we are only interested in the broader type, without subtypes nor granularity). 
+ """ + return type_code.split("(")[0].split("<")[0].upper() + + @contextmanager # type: ignore + def exception_handler(self, sql: str) -> ContextManager: # type: ignore + try: + yield + except Exception as e: + LOGGER.debug(f"Error running SQL: {sql}") + raise DbtRuntimeError(str(e)) from e + + @classmethod + def open(cls, connection: Connection) -> Connection: + if connection.state == "open": + LOGGER.debug("Connection is already open, skipping open.") + return connection + + try: + creds: AthenaCredentials = connection.credentials + + handle = AthenaConnection( + s3_staging_dir=creds.s3_staging_dir, + endpoint_url=creds.endpoint_url, + catalog_name=creds.database, + schema_name=creds.schema, + work_group=creds.work_group, + cursor_class=AthenaCursor, + cursor_kwargs={ + "debug_query_state": creds.debug_query_state, + "num_iceberg_retries": creds.num_iceberg_retries, + }, + formatter=AthenaParameterFormatter(), + poll_interval=creds.poll_interval, + session=get_boto3_session(connection), + retry_config=RetryConfig( + attempt=creds.num_retries + 1, + exceptions=("ThrottlingException", "TooManyRequestsException", "InternalServerException"), + ), + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + connection.state = ConnectionState.OPEN + connection.handle = handle + + except Exception as exc: + LOGGER.exception(f"Got an error when attempting to open a Athena connection due to {exc}") + connection.handle = None + connection.state = ConnectionState.FAIL + raise ConnectionError(str(exc)) + + return connection + + @classmethod + def get_response(cls, cursor: AthenaCursor) -> AthenaAdapterResponse: + code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR" + rowcount, data_scanned_in_bytes = cls.process_query_stats(cursor) + return AthenaAdapterResponse( + _message=f"{code} {rowcount}", + rows_affected=rowcount, + code=code, + data_scanned_in_bytes=data_scanned_in_bytes, + ) + + @staticmethod + def process_query_stats(cursor: AthenaCursor) -> Tuple[int, int]: + """ + Helper function to parse query statistics from SELECT statements. + The function looks for all statements that contains rowcount or data_scanned_in_bytes, + then strip the SELECT statements, and pick the value between curly brackets. 
+ """ + if all(map(cursor.query.__contains__, ["rowcount", "data_scanned_in_bytes"])): + try: + query_split = cursor.query.lower().split("select")[-1] + # query statistics are in the format {"rowcount":1, "data_scanned_in_bytes": 3} + # the following statement extract the content between { and } + query_stats = re.search("{(.*)}", query_split) + if query_stats: + stats = json.loads("{" + query_stats.group(1) + "}") + return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0) + except Exception as err: + LOGGER.debug(f"There was an error parsing query stats {err}") + return -1, 0 + return cursor.rowcount, cursor.data_scanned_in_bytes + + def cancel(self, connection: Connection) -> None: + pass + + def add_begin_query(self) -> None: + pass + + def add_commit_query(self) -> None: + pass + + def begin(self) -> None: + pass + + def commit(self) -> None: + pass + + +class AthenaParameterFormatter(Formatter): + def __init__(self) -> None: + super().__init__(mappings=deepcopy(_DEFAULT_FORMATTERS), default=None) + + def format(self, operation: str, parameters: Optional[List[str]] = None) -> str: + if not operation or not operation.strip(): + raise ProgrammingError("Query is none or empty.") + operation = operation.strip() + + if operation.upper().startswith(("SELECT", "WITH", "INSERT")): + escaper = _escape_presto + elif operation.upper().startswith(("VACUUM", "OPTIMIZE")): + operation = operation.replace('"', "") + else: + # Fixes ParseException that comes with newer version of PyAthena + operation = operation.replace("\n\n ", "\n") + + escaper = _escape_hive + + kwargs: Optional[List[str]] = None + if parameters is not None: + kwargs = list() + if isinstance(parameters, list): + for v in parameters: + # TODO Review this annoying Decimal hack, unsure if issue in dbt, agate or pyathena + if isinstance(v, Decimal) and v == int(v): + v = int(v) + + func = self.get(v) + if not func: + raise TypeError(f"{type(v)} is not defined formatter.") + kwargs.append(func(self, escaper, v)) + else: + raise ProgrammingError(f"Unsupported parameter (Support for list only): {parameters}") + return (operation % tuple(kwargs)).strip() if kwargs is not None else operation.strip() diff --git a/dbt-athena/src/dbt/adapters/athena/constants.py b/dbt-athena/src/dbt/adapters/athena/constants.py new file mode 100644 index 00000000..9f132d54 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/constants.py @@ -0,0 +1,41 @@ +from dbt.adapters.events.logging import AdapterLogger + +DEFAULT_THREAD_COUNT = 4 +DEFAULT_RETRY_ATTEMPTS = 3 +DEFAULT_POLLING_INTERVAL = 5 +DEFAULT_SPARK_COORDINATOR_DPU_SIZE = 1 +DEFAULT_SPARK_MAX_CONCURRENT_DPUS = 2 +DEFAULT_SPARK_EXECUTOR_DPU_SIZE = 1 +DEFAULT_CALCULATION_TIMEOUT = 43200 # seconds = 12 hours +SESSION_IDLE_TIMEOUT_MIN = 10 # minutes + +DEFAULT_SPARK_PROPERTIES = { + # https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-table-formats.html + "iceberg": { + "spark.sql.catalog.spark_catalog": "org.apache.iceberg.spark.SparkSessionCatalog", + "spark.sql.catalog.spark_catalog.catalog-impl": "org.apache.iceberg.aws.glue.GlueCatalog", + "spark.sql.catalog.spark_catalog.io-impl": "org.apache.iceberg.aws.s3.S3FileIO", + "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions", + }, + "hudi": { + "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.hudi.catalog.HoodieCatalog", + "spark.serializer": "org.apache.spark.serializer.KryoSerializer", + "spark.sql.extensions": "org.apache.spark.sql.hudi.HoodieSparkSessionExtension", + }, + 
"delta_lake": { + "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog", + "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension", + }, + # https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-encryption.html + "spark_encryption": { + "spark.authenticate": "true", + "spark.io.encryption.enabled": "true", + "spark.network.crypto.enabled": "true", + }, + # https://docs.aws.amazon.com/athena/latest/ug/spark-notebooks-cross-account-glue.html + "spark_cross_account_catalog": {"spark.hadoop.aws.glue.catalog.separator": "/"}, + # https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-requester-pays.html + "spark_requester_pays": {"spark.hadoop.fs.s3.useRequesterPaysHeader": "true"}, +} + +LOGGER = AdapterLogger(__name__) diff --git a/dbt-athena/src/dbt/adapters/athena/exceptions.py b/dbt-athena/src/dbt/adapters/athena/exceptions.py new file mode 100644 index 00000000..adf6928f --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/exceptions.py @@ -0,0 +1,9 @@ +from dbt_common.exceptions import CompilationError, DbtRuntimeError + + +class SnapshotMigrationRequired(CompilationError): + """Hive snapshot requires a manual operation due to backward incompatible changes.""" + + +class S3LocationException(DbtRuntimeError): + pass diff --git a/dbt-athena/src/dbt/adapters/athena/impl.py b/dbt-athena/src/dbt/adapters/athena/impl.py new file mode 100755 index 00000000..e70cca3c --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/impl.py @@ -0,0 +1,1365 @@ +import csv +import os +import posixpath as path +import re +import struct +import tempfile +from dataclasses import dataclass +from datetime import date, datetime +from functools import lru_cache +from textwrap import dedent +from threading import Lock +from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type +from urllib.parse import urlparse +from uuid import uuid4 + +import agate +import mmh3 +from botocore.exceptions import ClientError +from dbt_common.clients.agate_helper import table_from_rows +from dbt_common.contracts.constraints import ConstraintType +from dbt_common.exceptions import DbtRuntimeError +from mypy_boto3_athena.type_defs import DataCatalogTypeDef, GetWorkGroupOutputTypeDef +from mypy_boto3_glue.type_defs import ( + ColumnTypeDef, + GetTableResponseTypeDef, + TableInputTypeDef, + TableTypeDef, + TableVersionTypeDef, +) +from pyathena.error import OperationalError + +from dbt.adapters.athena import AthenaConnectionManager +from dbt.adapters.athena.column import AthenaColumn +from dbt.adapters.athena.config import get_boto3_config +from dbt.adapters.athena.connections import AthenaCursor +from dbt.adapters.athena.constants import LOGGER +from dbt.adapters.athena.exceptions import ( + S3LocationException, + SnapshotMigrationRequired, +) +from dbt.adapters.athena.lakeformation import ( + LfGrantsConfig, + LfPermissions, + LfTagsConfig, + LfTagsManager, +) +from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper +from dbt.adapters.athena.relation import ( + AthenaRelation, + AthenaSchemaSearchMap, + TableType, + get_table_type, +) +from dbt.adapters.athena.s3 import S3DataNaming +from dbt.adapters.athena.utils import ( + AthenaCatalogType, + clean_sql_comment, + ellipsis_comment, + get_catalog_id, + get_catalog_type, + get_chunks, + is_valid_table_parameter_key, + stringify_table_parameter_value, +) +from dbt.adapters.base import ConstraintSupport, PythonJobHelper, available +from dbt.adapters.base.impl import AdapterConfig +from 
dbt.adapters.base.relation import BaseRelation, InformationSchema +from dbt.adapters.contracts.connection import AdapterResponse +from dbt.adapters.contracts.relation import RelationConfig +from dbt.adapters.sql import SQLAdapter + +boto3_client_lock = Lock() + + +@dataclass +class AthenaConfig(AdapterConfig): + """ + Database and relation-level configs. + + Args: + work_group: Identifier of Athena workgroup. + skip_workgroup_check: Indicates if the WorkGroup check (additional AWS call) can be skipped. + s3_staging_dir: S3 location to store Athena query results and metadata. + external_location: If set, the full S3 path in which the table will be saved. + partitioned_by: An array list of columns by which the table will be partitioned. + bucketed_by: An array list of columns to bucket data, ignored if using Iceberg. + bucket_count: The number of buckets for bucketing your data, ignored if using Iceberg. + table_type: The type of table, supports hive or iceberg. + ha: If the table should be built using the high-availability method. + format: The data format for the table. Supports ORC, PARQUET, AVRO, JSON, TEXTFILE. + write_compression: The compression type to use for any storage format + that allows compression to be specified. + field_delimiter: Custom field delimiter, for when format is set to TEXTFILE. + table_properties : Table properties to add to the table, valid for Iceberg only. + native_drop: Relation drop operations will be performed with SQL, not direct Glue API calls. + seed_by_insert: default behaviour uploads seed data to S3. + lf_tags_config: AWS lakeformation tags to associate with the table and columns. + seed_s3_upload_args: Dictionary containing boto3 ExtraArgs when uploading to S3. + partitions_limit: Maximum numbers of partitions when batching. + force_batch: Skip creating the table as ctas and run the operation directly in batch insert mode. + unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp. + temp_schema: Define in which schema to create temporary tables used in incremental runs. 
+ """ + + work_group: Optional[str] = None + skip_workgroup_check: bool = False + s3_staging_dir: Optional[str] = None + external_location: Optional[str] = None + partitioned_by: Optional[str] = None + bucketed_by: Optional[str] = None + bucket_count: Optional[str] = None + table_type: str = "hive" + ha: bool = False + format: str = "parquet" + write_compression: Optional[str] = None + field_delimiter: Optional[str] = None + table_properties: Optional[str] = None + native_drop: Optional[str] = None + seed_by_insert: bool = False + lf_tags_config: Optional[Dict[str, Any]] = None + seed_s3_upload_args: Optional[Dict[str, Any]] = None + partitions_limit: Optional[int] = None + force_batch: bool = False + unique_tmp_table_suffix: bool = False + temp_schema: Optional[str] = None + + +class AthenaAdapter(SQLAdapter): + BATCH_CREATE_PARTITION_API_LIMIT = 100 + BATCH_DELETE_PARTITION_API_LIMIT = 25 + INTEGER_MAX_VALUE_32_BIT_SIGNED = 0x7FFFFFFF + + ConnectionManager = AthenaConnectionManager + Relation = AthenaRelation + AdapterSpecificConfigs = AthenaConfig + Column = AthenaColumn + + quote_character: str = '"' # Presto quote character + + # There is no such concept as constraints in Athena + CONSTRAINT_SUPPORT = { + ConstraintType.check: ConstraintSupport.NOT_SUPPORTED, + ConstraintType.not_null: ConstraintSupport.NOT_SUPPORTED, + ConstraintType.unique: ConstraintSupport.NOT_SUPPORTED, + ConstraintType.primary_key: ConstraintSupport.NOT_SUPPORTED, + ConstraintType.foreign_key: ConstraintSupport.NOT_SUPPORTED, + } + + @classmethod + def date_function(cls) -> str: + return "now()" + + @classmethod + def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: + return "string" + + @classmethod + def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) + return "double" if decimals else "integer" + + @classmethod + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: + return "timestamp" + + @available + def add_lf_tags_to_database(self, relation: AthenaRelation) -> None: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + if lf_tags := conn.credentials.lf_tags_database: + config = LfTagsConfig(enabled=True, tags=lf_tags) + with boto3_client_lock: + lf_client = client.session.client( + "lakeformation", + client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + manager = LfTagsManager(lf_client, relation, config) + manager.process_lf_tags_database() + else: + LOGGER.debug(f"Lakeformation is disabled for {relation}") + + @available + def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any]) -> None: + config = LfTagsConfig(**lf_tags_config) + if config.enabled: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + with boto3_client_lock: + lf_client = client.session.client( + "lakeformation", + client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + manager = LfTagsManager(lf_client, relation, config) + manager.process_lf_tags() + return + LOGGER.debug(f"Lakeformation is disabled for {relation}") + + @available + def apply_lf_grants(self, relation: AthenaRelation, lf_grants_config: Dict[str, Any]) -> None: + lf_config = LfGrantsConfig(**lf_grants_config) + if lf_config.data_cell_filters.enabled: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = 
conn.handle + with boto3_client_lock: + lf = client.session.client( + "lakeformation", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(catalog) + lf_permissions = LfPermissions(catalog_id, relation, lf) # type: ignore + lf_permissions.process_filters(lf_config) + lf_permissions.process_permissions(lf_config) + + @lru_cache() + def _get_work_group(self, work_group: str) -> GetWorkGroupOutputTypeDef: + """ + helper function to cache the result of the get_work_group to avoid APIs throttling + """ + LOGGER.debug("get_work_group for %s", work_group) + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + with boto3_client_lock: + athena_client = client.session.client( + "athena", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + return athena_client.get_work_group(WorkGroup=work_group) + + @available + def is_work_group_output_location_enforced(self) -> bool: + conn = self.connections.get_thread_connection() + creds = conn.credentials + + if creds.work_group and not creds.skip_workgroup_check: + work_group = self._get_work_group(creds.work_group) + output_location = ( + work_group.get("WorkGroup", {}) + .get("Configuration", {}) + .get("ResultConfiguration", {}) + .get("OutputLocation", None) + ) + + output_location_enforced = ( + work_group.get("WorkGroup", {}).get("Configuration", {}).get("EnforceWorkGroupConfiguration", False) + ) + + return output_location is not None and output_location_enforced + else: + return False + + def _s3_table_prefix( + self, s3_data_dir: Optional[str], s3_tmp_table_dir: Optional[str], is_temporary_table: bool + ) -> str: + """ + Returns the root location for storing tables in S3. + This is `s3_data_dir`, if set at the model level, the s3_data_dir of the connection if provided, + and `s3_staging_dir/tables/` if nothing provided as data dir. + We generate a value here even if `s3_data_dir` is not set, + since creating a seed table requires a non-default location. + + When `s3_tmp_table_dir` is set, we use that as the root location for temporary tables. + """ + conn = self.connections.get_thread_connection() + creds = conn.credentials + + s3_tmp_table_dir = s3_tmp_table_dir or creds.s3_tmp_table_dir + if s3_tmp_table_dir and is_temporary_table: + return s3_tmp_table_dir + + if s3_data_dir is not None: + return s3_data_dir + return path.join(creds.s3_staging_dir, "tables") + + def _s3_data_naming(self, s3_data_naming: Optional[str]) -> S3DataNaming: + """ + Returns the s3 data naming strategy if provided, otherwise the value from the connection. 
+ """ + conn = self.connections.get_thread_connection() + creds = conn.credentials + if s3_data_naming is not None: + return S3DataNaming(s3_data_naming) + + return S3DataNaming(creds.s3_data_naming) or S3DataNaming.TABLE_UNIQUE + + @available + def generate_s3_location( + self, + relation: AthenaRelation, + s3_data_dir: Optional[str] = None, + s3_data_naming: Optional[str] = None, + s3_tmp_table_dir: Optional[str] = None, + external_location: Optional[str] = None, + is_temporary_table: bool = False, + ) -> str: + """ + Returns either a UUID or database/table prefix for storing a table, + depending on the value of s3_table + """ + if external_location and not is_temporary_table: + return external_location.rstrip("/") + s3_path_table_part = relation.s3_path_table_part or relation.identifier + schema_name = relation.schema + table_prefix = self._s3_table_prefix(s3_data_dir, s3_tmp_table_dir, is_temporary_table) + + mapping = { + S3DataNaming.UNIQUE: path.join(table_prefix, str(uuid4())), + S3DataNaming.TABLE: path.join(table_prefix, s3_path_table_part), + S3DataNaming.TABLE_UNIQUE: path.join(table_prefix, s3_path_table_part, str(uuid4())), + S3DataNaming.SCHEMA_TABLE: path.join(table_prefix, schema_name, s3_path_table_part), + S3DataNaming.SCHEMA_TABLE_UNIQUE: path.join(table_prefix, schema_name, s3_path_table_part, str(uuid4())), + } + + return mapping[self._s3_data_naming(s3_data_naming)] + + @available + def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseTypeDef]: + """ + Helper function to get a relation via Glue + """ + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + try: + table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier) + except ClientError as e: + if e.response["Error"]["Code"] == "EntityNotFoundException": + LOGGER.debug(f"Table {relation.render()} does not exists - Ignoring") + return None + raise e + return table + + @available + def get_glue_table_type(self, relation: AthenaRelation) -> Optional[TableType]: + """ + Get the table type of the relation from Glue + """ + table = self.get_glue_table(relation) + if not table: + LOGGER.debug(f"Table {relation.render()} does not exist - Ignoring") + return None + + return get_table_type(table["Table"]) + + @available + def get_glue_table_location(self, relation: AthenaRelation) -> Optional[str]: + """ + Helper function to get location of a relation in S3. + Will return None if the table does not exist or does not have a location (views) + """ + table = self.get_glue_table(relation) + if not table: + LOGGER.debug(f"Table {relation.render()} does not exist - Ignoring") + return None + + table_type = get_table_type(table["Table"]) + table_location = table["Table"].get("StorageDescriptor", {}).get("Location") + if table_type.is_physical(): + if not table_location: + raise S3LocationException( + f"Relation {relation.render()} is of type '{table_type.value}' which requires a location, " + f"but no location returned by Glue." 
+ ) + LOGGER.debug(f"{relation.render()} is stored in {table_location}") + return str(table_location) + return None + + @available + def clean_up_partitions(self, relation: AthenaRelation, where_condition: str) -> None: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + paginator = glue_client.get_paginator("get_partitions") + partition_params = { + "CatalogId": catalog_id, + "DatabaseName": relation.schema, + "TableName": relation.identifier, + "Expression": where_condition, + "ExcludeColumnSchema": True, + } + partition_pg = paginator.paginate(**partition_params) + partitions = partition_pg.build_full_result().get("Partitions") + for partition in partitions: + self.delete_from_s3(partition["StorageDescriptor"]["Location"]) + glue_client.delete_partition( + CatalogId=catalog_id, + DatabaseName=relation.schema, + TableName=relation.identifier, + PartitionValues=partition["Values"], + ) + + @available + def clean_up_table(self, relation: AthenaRelation) -> None: + # this check avoids issues for when the table location is an empty string + # or when the table does not exist and table location is None + if table_location := self.get_glue_table_location(relation): + self.delete_from_s3(table_location) + + @available + def generate_unique_temporary_table_suffix(self, suffix_initial: str = "__dbt_tmp") -> str: + return f"{suffix_initial}_{str(uuid4()).replace('-', '_')}" + + def quote(self, identifier: str) -> str: + return f"{self.quote_character}{identifier}{self.quote_character}" + + @available + def quote_seed_column( + self, column: str, quote_config: Optional[bool], quote_character: Optional[str] = None + ) -> str: + if quote_character: + old_value = self.quote_character + object.__setattr__(self, "quote_character", quote_character) + quoted_column = str(super().quote_seed_column(column, quote_config)) + object.__setattr__(self, "quote_character", old_value) + else: + quoted_column = str(super().quote_seed_column(column, quote_config)) + + return quoted_column + + @available + def upload_seed_to_s3( + self, + relation: AthenaRelation, + table: agate.Table, + s3_data_dir: Optional[str] = None, + s3_data_naming: Optional[str] = None, + external_location: Optional[str] = None, + seed_s3_upload_args: Optional[Dict[str, Any]] = None, + ) -> str: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + # TODO: consider using the workgroup default location when configured + s3_location = self.generate_s3_location( + relation, s3_data_dir, s3_data_naming, external_location=external_location + ) + bucket, prefix = self._parse_s3_path(s3_location) + + file_name = f"{relation.identifier}.csv" + object_name = path.join(prefix, file_name) + + with boto3_client_lock: + s3_client = client.session.client( + "s3", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + # This ensures cross-platform support, tempfile.NamedTemporaryFile does not + tmpfile = os.path.join(tempfile.gettempdir(), os.urandom(24).hex()) + table.to_csv(tmpfile, quoting=csv.QUOTE_NONNUMERIC) + s3_client.upload_file(tmpfile, bucket, object_name, ExtraArgs=seed_s3_upload_args) + os.remove(tmpfile) + + 
return str(s3_location) + + @available + def delete_from_s3(self, s3_path: str) -> None: + """ + Deletes files from s3 given a s3 path in the format: s3://my_bucket/prefix + Additionally, parses the response from the s3 delete request and raises + a DbtRuntimeError in case it included errors. + """ + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + bucket_name, prefix = self._parse_s3_path(s3_path) + if self._s3_path_exists(bucket_name, prefix): + s3_resource = client.session.resource( + "s3", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + s3_bucket = s3_resource.Bucket(bucket_name) + LOGGER.debug(f"Deleting table data: path='{s3_path}', bucket='{bucket_name}', prefix='{prefix}'") + response = s3_bucket.objects.filter(Prefix=prefix).delete() + is_all_successful = True + for res in response: + if "Errors" in res: + for err in res["Errors"]: + is_all_successful = False + LOGGER.error( + "Failed to delete files: Key='{}', Code='{}', Message='{}', s3_bucket='{}'", + err["Key"], + err["Code"], + err["Message"], + bucket_name, + ) + if is_all_successful is False: + raise DbtRuntimeError("Failed to delete files from S3.") + else: + LOGGER.debug("S3 path does not exist") + + @staticmethod + def _parse_s3_path(s3_path: str) -> Tuple[str, str]: + """ + Parses and splits a s3 path into bucket name and prefix. + This assumes that s3_path is a prefix instead of a URI. It adds a + trailing slash to the prefix, if there is none. + """ + o = urlparse(s3_path, allow_fragments=False) + bucket_name = o.netloc + prefix = o.path.lstrip("/").rstrip("/") + "/" + return bucket_name, prefix + + def _s3_path_exists(self, s3_bucket: str, s3_prefix: str) -> bool: + """Checks whether a given s3 path exists.""" + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + with boto3_client_lock: + s3_client = client.session.client( + "s3", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + response = s3_client.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix) + return True if "Contents" in response else False + + def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List[Dict[str, Any]]: + table_catalog = { + "table_database": database, + "table_schema": table["DatabaseName"], + "table_name": table["Name"], + "table_type": get_table_type(table).value, + "table_comment": table.get("Parameters", {}).get("comment", table.get("Description", "")), + } + return [ + { + **table_catalog, + **{ + "column_name": col["Name"], + "column_index": idx, + "column_type": col["Type"], + "column_comment": col.get("Comment", ""), + }, + } + for idx, col in enumerate(table["StorageDescriptor"]["Columns"] + table.get("PartitionKeys", [])) + if self._is_current_column(col) + ] + + @staticmethod + def _get_one_table_for_non_glue_catalog(table: TableTypeDef, schema: str, database: str) -> List[Dict[str, Any]]: + table_catalog = { + "table_database": database, + "table_schema": schema, + "table_name": table["Name"], + "table_type": get_table_type(table).value, + "table_comment": table.get("Parameters", {}).get("comment", ""), + } + return [ + { + **table_catalog, + **{ + "column_name": col["Name"], + "column_index": idx, + "column_type": col["Type"], + "column_comment": col.get("Comment", ""), + }, + } + # TODO: review this code part as TableTypeDef class does not contain "Columns" attribute + for idx, col in 
enumerate(table["Columns"] + table.get("PartitionKeys", [])) + ] + + def _get_one_catalog( + self, + information_schema: InformationSchema, + schemas: Set[str], + used_schemas: FrozenSet[Tuple[str, str]], + ) -> agate.Table: + """ + This function is invoked by Adapter.get_catalog for each schema. + """ + data_catalog = self._get_data_catalog(information_schema.database) + data_catalog_type = get_catalog_type(data_catalog) + + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + if data_catalog_type == AthenaCatalogType.GLUE: + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + catalog = [] + paginator = glue_client.get_paginator("get_tables") + for schema in schemas: + kwargs = { + "DatabaseName": schema, + "MaxResults": 100, + } + # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 + # infers it from the account Id. + catalog_id = get_catalog_id(data_catalog) + if catalog_id: + kwargs["CatalogId"] = catalog_id + + for page in paginator.paginate(**kwargs): + for table in page["TableList"]: + catalog.extend(self._get_one_table_for_catalog(table, information_schema.database)) + table = agate.Table.from_object(catalog) + else: + with boto3_client_lock: + athena_client = client.session.client( + "athena", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + catalog = [] + paginator = athena_client.get_paginator("list_table_metadata") + for schema in schemas: + for page in paginator.paginate( + CatalogName=information_schema.database, + DatabaseName=schema, + MaxResults=50, # Limit supported by this operation + ): + for table in page["TableMetadataList"]: + catalog.extend( + self._get_one_table_for_non_glue_catalog(table, schema, information_schema.database) + ) + table = agate.Table.from_object(catalog) + + return self._catalog_filter_table(table, used_schemas) + + def _get_catalog_schemas(self, relation_configs: Iterable[RelationConfig]) -> AthenaSchemaSearchMap: + """ + Get the schemas from the catalog. + It's called by the `get_catalog` method. 
+ """ + info_schema_name_map = AthenaSchemaSearchMap() + for relation_config in relation_configs: + relation = self.Relation.create_from(quoting=self.config, relation_config=relation_config) + info_schema_name_map.add(relation) + return info_schema_name_map + + def _get_data_catalog(self, database: str) -> Optional[DataCatalogTypeDef]: + if database: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + if database.lower() == "awsdatacatalog": + with boto3_client_lock: + sts = client.session.client( + "sts", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + catalog_id = sts.get_caller_identity()["Account"] + return {"Name": database, "Type": "GLUE", "Parameters": {"catalog-id": catalog_id}} + with boto3_client_lock: + athena = client.session.client( + "athena", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + return athena.get_data_catalog(Name=database)["DataCatalog"] + return None + + @available + def list_relations_without_caching(self, schema_relation: AthenaRelation) -> List[BaseRelation]: + data_catalog = self._get_data_catalog(schema_relation.database) + if data_catalog and data_catalog["Type"] != "GLUE": + # For non-Glue Data Catalogs, use the original Athena query against INFORMATION_SCHEMA approach + return super().list_relations_without_caching(schema_relation) # type: ignore + + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + kwargs = { + "DatabaseName": schema_relation.schema, + } + if catalog_id := get_catalog_id(data_catalog): + kwargs["CatalogId"] = catalog_id + paginator = glue_client.get_paginator("get_tables") + try: + tables = paginator.paginate(**kwargs).build_full_result().get("TableList") + except ClientError as e: + # don't error out when schema doesn't exist + # this allows dbt to create and manage schemas/databases + if e.response["Error"]["Code"] == "EntityNotFoundException": + LOGGER.debug(f"Schema '{schema_relation.schema}' does not exist - Ignoring: {e}") + return [] + else: + raise e + + relations: List[BaseRelation] = [] + quote_policy = {"database": True, "schema": True, "identifier": True} + for table in tables: + if "TableType" not in table: + LOGGER.info(f"Table '{table['Name']}' has no TableType attribute - Ignoring") + continue + _type = table["TableType"] + _detailed_table_type = table.get("Parameters", {}).get("table_type", "") + if _type == "VIRTUAL_VIEW": + _type = self.Relation.View + else: + _type = self.Relation.Table + + relations.append( + self.Relation.create( + schema=schema_relation.schema, + database=schema_relation.database, + identifier=table["Name"], + quote_policy=quote_policy, + type=_type, + detailed_table_type=_detailed_table_type, + ) + ) + return relations + + def _get_one_catalog_by_relations( + self, + information_schema: InformationSchema, + relations: List[AthenaRelation], + used_schemas: FrozenSet[Tuple[str, str]], + ) -> "agate.Table": + """ + Overwrite of _get_one_catalog_by_relations for Athena, in order to use glue apis. + This function is invoked by Adapter.get_catalog_by_relations. 
+ """ + _table_definitions = [] + for _rel in relations: + glue_table_definition = self.get_glue_table(_rel) + if glue_table_definition: + _table_definition = self._get_one_table_for_catalog(glue_table_definition["Table"], _rel.database) + _table_definitions.extend(_table_definition) + table = agate.Table.from_object(_table_definitions) + # picked from _catalog_filter_table, force database + schema to be strings + return table_from_rows( + table.rows, + table.column_names, + text_only_columns=["table_database", "table_schema", "table_name"], + ) + + @available + def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelation) -> None: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(src_relation.database) + src_catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + src_table = glue_client.get_table( + CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier + ).get("Table") + + src_table_get_partitions_paginator = glue_client.get_paginator("get_partitions") + src_table_partitions_result = src_table_get_partitions_paginator.paginate( + **{ + "CatalogId": src_catalog_id, + "DatabaseName": src_relation.schema, + "TableName": src_relation.identifier, + } + ) + src_table_partitions = src_table_partitions_result.build_full_result().get("Partitions") + + data_catalog = self._get_data_catalog(src_relation.database) + target_catalog_id = get_catalog_id(data_catalog) + + target_get_partitions_paginator = glue_client.get_paginator("get_partitions") + target_table_partitions_result = target_get_partitions_paginator.paginate( + **{ + "CatalogId": target_catalog_id, + "DatabaseName": target_relation.schema, + "TableName": target_relation.identifier, + } + ) + target_table_partitions = target_table_partitions_result.build_full_result().get("Partitions") + + target_table_version = { + "Name": target_relation.identifier, + "StorageDescriptor": src_table["StorageDescriptor"], + "PartitionKeys": src_table["PartitionKeys"], + "TableType": src_table["TableType"], + "Parameters": src_table["Parameters"], + "Description": src_table.get("Description", ""), + } + + # perform a table swap + glue_client.update_table( + CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableInput=target_table_version + ) + LOGGER.debug(f"Table {target_relation.render()} swapped with the content of {src_relation.render()}") + + # we delete the target table partitions in any case + # if source table has partitions we need to delete and add partitions + # it source table hasn't any partitions we need to delete target table partitions + if target_table_partitions: + for partition_batch in get_chunks(target_table_partitions, AthenaAdapter.BATCH_DELETE_PARTITION_API_LIMIT): + glue_client.batch_delete_partition( + CatalogId=target_catalog_id, + DatabaseName=target_relation.schema, + TableName=target_relation.identifier, + PartitionsToDelete=[{"Values": partition["Values"]} for partition in partition_batch], + ) + + if src_table_partitions: + for partition_batch in get_chunks(src_table_partitions, AthenaAdapter.BATCH_CREATE_PARTITION_API_LIMIT): + glue_client.batch_create_partition( + CatalogId=target_catalog_id, + DatabaseName=target_relation.schema, + TableName=target_relation.identifier, + PartitionInputList=[ 
+ { + "Values": partition["Values"], + "StorageDescriptor": partition["StorageDescriptor"], + "Parameters": partition["Parameters"], + } + for partition in partition_batch + ], + ) + + def _get_glue_table_versions_to_expire(self, relation: AthenaRelation, to_keep: int) -> List[TableVersionTypeDef]: + """ + Given a table and the amount of its version to keep, it returns the versions to delete + """ + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + paginator = glue_client.get_paginator("get_table_versions") + response_iterator = paginator.paginate( + **{ + "DatabaseName": relation.schema, + "TableName": relation.identifier, + } + ) + table_versions = response_iterator.build_full_result().get("TableVersions") + LOGGER.debug(f"Total table versions: {[v['VersionId'] for v in table_versions]}") + table_versions_ordered = sorted(table_versions, key=lambda i: int(i["Table"]["VersionId"]), reverse=True) + return table_versions_ordered[int(to_keep) :] + + @available + def expire_glue_table_versions(self, relation: AthenaRelation, to_keep: int, delete_s3: bool) -> List[str]: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + versions_to_delete = self._get_glue_table_versions_to_expire(relation, to_keep) + LOGGER.debug(f"Versions to delete: {[v['VersionId'] for v in versions_to_delete]}") + + deleted_versions = [] + for v in versions_to_delete: + version = v["Table"]["VersionId"] + location = v["Table"]["StorageDescriptor"]["Location"] + try: + glue_client.delete_table_version( + CatalogId=catalog_id, + DatabaseName=relation.schema, + TableName=relation.identifier, + VersionId=str(version), + ) + deleted_versions.append(version) + LOGGER.debug(f"Deleted version {version} of table {relation.render()} ") + if delete_s3: + self.delete_from_s3(location) + LOGGER.debug(f"{location} was deleted") + except Exception as err: + LOGGER.debug(f"There was an error when expiring table version {version} with error: {err}") + return deleted_versions + + @available + def persist_docs_to_glue( + self, + relation: AthenaRelation, + model: Dict[str, Any], + persist_relation_docs: bool = False, + persist_column_docs: bool = False, + skip_archive_table_version: bool = False, + ) -> None: + """Save model/columns description to Glue Table metadata. + + :param relation: Relation we are performing the docs persist + :param model: The dbt model definition as a dict + :param persist_relation_docs: Flag indicating whether we want to persist the model description as glue table + description + :param persist_column_docs: Flag indicating whether we want to persist column description as glue column + description + :param skip_archive_table_version: if True, current table version will not be archived before creating new one. + The purpose is to avoid creating redundant table version if it already was created during the same dbt run + after CREATE OR REPLACE VIEW or ALTER TABLE statements. + Every dbt run should create not more than one table version. 
+ """ + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + # By default, there is no need to update Glue Table + need_to_update_table = False + # Get Table from Glue + table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.name)["Table"] + # Prepare new version of Glue Table picking up significant fields + table_input = self._get_table_input(table) + table_parameters = table_input["Parameters"] + + # Update table description + if persist_relation_docs: + # Prepare dbt description + clean_table_description = ellipsis_comment(clean_sql_comment(model["description"])) + # Get current description from Glue + glue_table_description = table.get("Description", "") + # Get current description parameter from Glue + glue_table_comment = table["Parameters"].get("comment", "") + # Check that description is already attached to Glue table + if clean_table_description != glue_table_description or clean_table_description != glue_table_comment: + need_to_update_table = True + # Save dbt description + table_input["Description"] = clean_table_description + table_parameters["comment"] = clean_table_description + + # Get dbt model meta if available + meta: Dict[str, Any] = model.get("config", {}).get("meta", {}) + # Add some of dbt model config fields as table meta + meta["unique_id"] = model.get("unique_id") + meta["materialized"] = model.get("config", {}).get("materialized") + # Add dbt project metadata to table meta + meta["dbt_project_name"] = self.config.project_name + meta["dbt_project_version"] = self.config.version + # Prepare meta values for table properties and check if update is required + for meta_key, meta_value_raw in meta.items(): + if is_valid_table_parameter_key(meta_key): + meta_value = stringify_table_parameter_value(meta_value_raw) + if meta_value is not None: + # Check that meta value is already attached to Glue table + current_meta_value: Optional[str] = table_parameters.get(meta_key) + if current_meta_value is None or current_meta_value != meta_value: + need_to_update_table = True + # Save Glue table parameter + table_parameters[meta_key] = meta_value + else: + LOGGER.warning(f"Meta value for key '{meta_key}' is not supported and will be ignored") + else: + LOGGER.warning(f"Meta key '{meta_key}' is not supported and will be ignored") + + # Update column comments + if persist_column_docs: + # Process every column + for col_obj in table_input["StorageDescriptor"]["Columns"]: + # Get column description from dbt + col_name = col_obj["Name"] + if col_name in model["columns"]: + col_comment = model["columns"][col_name]["description"] + # Prepare column description from dbt + clean_col_comment = ellipsis_comment(clean_sql_comment(col_comment)) + # Get current column comment from Glue + glue_col_comment = col_obj.get("Comment", "") + # Check that column description is already attached to Glue table + if glue_col_comment != clean_col_comment: + need_to_update_table = True + # Save column description from dbt + col_obj["Comment"] = clean_col_comment + + # Get dbt model column meta if available + col_meta: Dict[str, Any] = model["columns"][col_name].get("meta", {}) + # Add empty Parameters dictionary if not present + if 
col_meta and "Parameters" not in col_obj.keys(): + col_obj["Parameters"] = {} + # Prepare meta values for column properties and check if update is required + for meta_key, meta_value_raw in col_meta.items(): + if is_valid_table_parameter_key(meta_key): + meta_value = stringify_table_parameter_value(meta_value_raw) + if meta_value is not None: + # Check if meta value is already attached to Glue column + col_current_meta_value: Optional[str] = col_obj["Parameters"].get(meta_key) + if col_current_meta_value is None or col_current_meta_value != meta_value: + need_to_update_table = True + # Save Glue column parameter + col_obj["Parameters"][meta_key] = meta_value + else: + LOGGER.warning( + f"Column meta value for key '{meta_key}' is not supported and will be ignored" + ) + else: + LOGGER.warning(f"Column meta key '{meta_key}' is not supported and will be ignored") + + # Update Glue Table only if table/column description is modified. + # It prevents redundant schema version creating after incremental runs. + if need_to_update_table: + table_input["Parameters"] = table_parameters + glue_client.update_table( + CatalogId=catalog_id, + DatabaseName=relation.schema, + TableInput=table_input, + SkipArchive=skip_archive_table_version, + ) + + def generate_python_submission_response(self, submission_result: Any) -> AdapterResponse: + if not submission_result: + return AdapterResponse(_message="ERROR") + return AdapterResponse(_message="OK") + + @property + def default_python_submission_method(self) -> str: + return "athena_helper" + + @property + def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]: + return {"athena_helper": AthenaPythonJobHelper} + + @available + def list_schemas(self, database: str) -> List[str]: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + paginator = glue_client.get_paginator("get_databases") + result = [] + for page in paginator.paginate(): + result.extend([schema["Name"] for schema in page["DatabaseList"]]) + return result + + @staticmethod + def _is_current_column(col: ColumnTypeDef) -> bool: + """ + Check if a column is explicitly set as not current. If not, it is considered as current. 
+ """ + return bool(col.get("Parameters", {}).get("iceberg.field.current") != "false") + + @available + def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn]: + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + get_table_kwargs = dict( + DatabaseName=relation.schema, + Name=relation.identifier, + ) + if catalog_id: + get_table_kwargs["CatalogId"] = catalog_id + + try: + table = glue_client.get_table(**get_table_kwargs)["Table"] + except ClientError as e: + if e.response["Error"]["Code"] == "EntityNotFoundException": + LOGGER.debug("table not exist, catching the error") + return [] + else: + LOGGER.error(e) + raise e + + table_type = get_table_type(table) + + columns = [c for c in table["StorageDescriptor"]["Columns"] if self._is_current_column(c)] + partition_keys = table.get("PartitionKeys", []) + + LOGGER.debug(f"Columns in relation {relation.identifier}: {columns + partition_keys}") + + return [ + AthenaColumn(column=c["Name"], dtype=c["Type"], table_type=table_type) for c in columns + partition_keys + ] + + @available + def delete_from_glue_catalog(self, relation: AthenaRelation) -> None: + schema_name = relation.schema + table_name = relation.identifier + + conn = self.connections.get_thread_connection() + creds = conn.credentials + client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + + with boto3_client_lock: + glue_client = client.session.client( + "glue", + region_name=client.region_name, + config=get_boto3_config(num_retries=creds.effective_num_retries), + ) + + try: + glue_client.delete_table(CatalogId=catalog_id, DatabaseName=schema_name, Name=table_name) + LOGGER.debug(f"Deleted table from glue catalog: {relation.render()}") + except ClientError as e: + if e.response["Error"]["Code"] == "EntityNotFoundException": + LOGGER.debug(f"Table {relation.render()} does not exist and will not be deleted, ignoring") + else: + LOGGER.error(e) + raise e + + @available.parse_none + def valid_snapshot_target(self, relation: BaseRelation) -> None: + """Log an error to help developers migrate to the new snapshot logic""" + super().valid_snapshot_target(relation) + columns = self.get_columns_in_relation(relation) + names = {c.name.lower() for c in columns} + + table_columns = [col for col in names if not col.startswith("dbt_") and col != "is_current_record"] + + if "dbt_unique_key" in names: + sql = self._generate_snapshot_migration_sql(relation=relation, table_columns=table_columns) + msg = ( + f"{'!' * 90}\n" + "The snapshot logic of dbt-athena has changed in an incompatible way to be more consistent " + "with the dbt-core implementation.\nYou will need to migrate your existing snapshot tables to be " + "able to keep using them with the latest dbt-athena version.\nYou can find more information " + "in the release notes:\nhttps://github.com/dbt-athena/dbt-athena/releases\n" + f"{'!' 
* 90}\n\n" + "You can use the example query below as a baseline to perform the migration:\n\n" + f"{'-' * 90}\n" + f"{sql}\n" + f"{'-' * 90}\n\n" + ) + LOGGER.error(msg) + raise SnapshotMigrationRequired("Look into 1.5 dbt-athena docs for the complete migration procedure") + + def _generate_snapshot_migration_sql(self, relation: AthenaRelation, table_columns: List[str]) -> str: + """Generate a sequence of queries that can be used to migrate the existing table to the new format. + + The queries perform the following steps: + - Backup the existing table + - Make the necessary modifications and store the results in a staging table + - Delete the target table (users might have to delete the S3 files manually) + - Copy the content of the staging table to the final table + - Delete the staging table + """ + col_csv = f", \n{' ' * 16}".join(table_columns) + staging_relation = relation.incorporate( + path={"identifier": relation.identifier + "__dbt_tmp_migration_staging"} + ) + ctas = dedent( + f"""\ + select + {col_csv} , + dbt_snapshot_at as dbt_updated_at, + dbt_valid_from, + if(dbt_valid_to > cast('9000-01-01' as timestamp), null, dbt_valid_to) as dbt_valid_to, + dbt_scd_id + from {relation} + where dbt_change_type != 'delete' + ;""" + ) + staging_sql = self.execute_macro( + "create_table_as", kwargs=dict(temporary=True, relation=staging_relation, compiled_code=ctas) + ) + + backup_relation = relation.incorporate(path={"identifier": relation.identifier + "__dbt_tmp_migration_backup"}) + backup_sql = self.execute_macro( + "create_table_as", + kwargs=dict(temporary=True, relation=backup_relation, compiled_code=f"select * from {relation};"), + ) + + drop_target_sql = f"drop table {relation.render_hive()};" + + copy_to_target_sql = self.execute_macro( + "create_table_as", kwargs=dict(relation=relation, compiled_code=f"select * from {staging_relation};") + ) + + drop_staging_sql = f"drop table {staging_relation.render_hive()};" + + return "\n".join( + [ + "-- Backup original table", + backup_sql.strip(), + "\n\n-- Store new results in staging table", + staging_sql.strip(), + "\n\n-- Drop target table\n" + "-- Note: you will need to manually remove the S3 files if you have a static table location\n", + drop_target_sql.strip(), + "\n\n-- Copy staging to target", + copy_to_target_sql.strip(), + "\n\n-- Drop staging table", + drop_staging_sql.strip(), + ] + ) + + @available + def is_list(self, value: Any) -> bool: + """ + This function is intended to test whether a Jinja object is + a list since this is complicated with purely Jinja syntax. + """ + return isinstance(value, list) + + @staticmethod + def _get_table_input(table: TableTypeDef) -> TableInputTypeDef: + """ + Prepare Glue Table dictionary to be a table_input argument of update_table() method. + + This is needed because update_table() does not accept some read-only fields of table dictionary + returned by get_table() method. 
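+
+        For example, read-only keys returned by get_table() such as DatabaseName,
+        CreateTime, UpdateTime, CreatedBy, CatalogId and VersionId are not part of
+        TableInputTypeDef and are therefore dropped here (an illustrative, not
+        exhaustive, list based on the Glue API type definitions).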
+ """ + return {k: v for k, v in table.items() if k in TableInputTypeDef.__annotations__} + + @available + def run_query_with_partitions_limit_catching(self, sql: str) -> str: + try: + cursor = self._run_query(sql, catch_partitions_limit=True) + except OperationalError as e: + if "TOO_MANY_OPEN_PARTITIONS" in str(e): + return "TOO_MANY_OPEN_PARTITIONS" + raise e + return f'{{"rowcount":{cursor.rowcount},"data_scanned_in_bytes":{cursor.data_scanned_in_bytes}}}' + + @available + def format_partition_keys(self, partition_keys: List[str]) -> str: + return ", ".join([self.format_one_partition_key(k) for k in partition_keys]) + + @available + def format_one_partition_key(self, partition_key: str) -> str: + """Check if partition key uses Iceberg hidden partitioning or bucket partitioning""" + hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower()) + bucket = re.search(r"bucket\((.+),", partition_key.lower()) + if hidden: + return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" + elif bucket: + return bucket.group(1) + else: + return partition_key.lower() + + @available + def murmur3_hash(self, value: Any, num_buckets: int) -> int: + """ + Computes a hash for the given value using the MurmurHash3 algorithm and returns a bucket number. + + This method was adopted from https://github.com/apache/iceberg-python/blob/main/pyiceberg/transforms.py#L240 + """ + if isinstance(value, int): # int, long + hash_value = mmh3.hash(struct.pack("<q", value)) + elif isinstance(value, (datetime, date)): # date, time, timestamp, timestampz + timestamp = int(value.timestamp()) if isinstance(value, datetime) else int(value.strftime("%s")) + hash_value = mmh3.hash(struct.pack("<q", timestamp)) + elif isinstance(value, (str, bytes)): # string + hash_value = mmh3.hash(value) + else: + raise TypeError(f"Need to add support data type for hashing: {type(value)}") + + return int((hash_value & self.INTEGER_MAX_VALUE_32_BIT_SIGNED) % num_buckets) + + @available + def format_value_for_partition(self, value: Any, column_type: str) -> Tuple[str, str]: + """Formats a value based on its column type for inclusion in a SQL query.""" + comp_func = "=" # Default comparison function + if value is None: + return "null", " is " + elif column_type == "integer": + return str(value), comp_func + elif column_type == "string": + # Properly escape single quotes in the string value + escaped_value = str(value).replace("'", "''") + return f"'{escaped_value}'", comp_func + elif column_type == "date": + return f"DATE'{value}'", comp_func + elif column_type == "timestamp": + return f"TIMESTAMP'{value}'", comp_func + else: + # Raise an error for unsupported column types + raise ValueError(f"Unsupported column type: {column_type}") + + @available + def run_operation_with_potential_multiple_runs(self, query: str, op: str) -> None: + while True: + try: + self._run_query(query, catch_partitions_limit=False) + break + except OperationalError as e: + if f"ICEBERG_{op.upper()}_MORE_RUNS_NEEDED" not in str(e): + raise e + + def _run_query(self, sql: str, catch_partitions_limit: bool) -> AthenaCursor: + query = self.connections._add_query_comment(sql) + conn = self.connections.get_thread_connection() + cursor: AthenaCursor = conn.handle.cursor() + LOGGER.debug(f"Running Athena query:\n{query}") + try: + cursor.execute(query, catch_partitions_limit=catch_partitions_limit) + except OperationalError as e: + LOGGER.debug(f"CAUGHT EXCEPTION: {e}") + raise e + return cursor diff --git a/dbt-athena/src/dbt/adapters/athena/lakeformation.py 
b/dbt-athena/src/dbt/adapters/athena/lakeformation.py new file mode 100644 index 00000000..86b51d01 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/lakeformation.py @@ -0,0 +1,287 @@ +"""AWS Lakeformation permissions management helper utilities.""" + +from typing import Dict, List, Optional, Sequence, Set, Union + +from dbt_common.exceptions import DbtRuntimeError +from mypy_boto3_lakeformation import LakeFormationClient +from mypy_boto3_lakeformation.type_defs import ( + AddLFTagsToResourceResponseTypeDef, + BatchPermissionsRequestEntryTypeDef, + ColumnLFTagTypeDef, + DataCellsFilterTypeDef, + GetResourceLFTagsResponseTypeDef, + LFTagPairTypeDef, + RemoveLFTagsFromResourceResponseTypeDef, + ResourceTypeDef, +) +from pydantic import BaseModel + +from dbt.adapters.athena.relation import AthenaRelation +from dbt.adapters.events.logging import AdapterLogger + +logger = AdapterLogger("AthenaLakeFormation") + + +class LfTagsConfig(BaseModel): + enabled: bool = False + tags: Optional[Dict[str, str]] = None + tags_columns: Optional[Dict[str, Dict[str, List[str]]]] = None + inherited_tags: Optional[List[str]] = None + + +class LfTagsManager: + def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_tags_config: LfTagsConfig): + self.lf_client = lf_client + self.database = relation.schema + self.table = relation.identifier + self.lf_tags = lf_tags_config.tags + self.lf_tags_columns = lf_tags_config.tags_columns + self.lf_inherited_tags = set(lf_tags_config.inherited_tags) if lf_tags_config.inherited_tags else set() + + def process_lf_tags_database(self) -> None: + if self.lf_tags: + database_resource = {"Database": {"Name": self.database}} + response = self.lf_client.add_lf_tags_to_resource( + Resource=database_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags) + + def process_lf_tags(self) -> None: + table_resource = {"Table": {"DatabaseName": self.database, "Name": self.table}} + existing_lf_tags = self.lf_client.get_resource_lf_tags(Resource=table_resource) + self._remove_lf_tags_columns(existing_lf_tags) + self._apply_lf_tags_table(table_resource, existing_lf_tags) + self._apply_lf_tags_columns() + + @staticmethod + def _column_tags_to_remove( + lf_tags_columns: List[ColumnLFTagTypeDef], lf_inherited_tags: Set[str] + ) -> Dict[str, Dict[str, List[str]]]: + to_remove = {} + + for column in lf_tags_columns: + non_inherited_tags = [tag for tag in column["LFTags"] if not tag["TagKey"] in lf_inherited_tags] + for tag in non_inherited_tags: + tag_key = tag["TagKey"] + tag_value = tag["TagValues"][0] + if tag_key not in to_remove: + to_remove[tag_key] = {tag_value: [column["Name"]]} + elif tag_value not in to_remove[tag_key]: + to_remove[tag_key][tag_value] = [column["Name"]] + else: + to_remove[tag_key][tag_value].append(column["Name"]) + + return to_remove + + def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTypeDef) -> None: + lf_tags_columns = existing_lf_tags.get("LFTagsOnColumns", []) + logger.debug(f"COLUMNS: {lf_tags_columns}") + if lf_tags_columns: + to_remove = LfTagsManager._column_tags_to_remove(lf_tags_columns, self.lf_inherited_tags) + logger.debug(f"TO REMOVE: {to_remove}") + for tag_key, tag_config in to_remove.items(): + for tag_value, columns in tag_config.items(): + resource = { + "TableWithColumns": {"DatabaseName": self.database, "Name": self.table, "ColumnNames": columns} + } + response = 
self.lf_client.remove_lf_tags_from_resource( + Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}] + ) + self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove") + + @staticmethod + def _table_tags_to_remove( + lf_tags_table: List[LFTagPairTypeDef], lf_tags: Optional[Dict[str, str]], lf_inherited_tags: Set[str] + ) -> Dict[str, Sequence[str]]: + return { + tag["TagKey"]: tag["TagValues"] + for tag in lf_tags_table + if tag["TagKey"] not in (lf_tags or {}) + if tag["TagKey"] not in lf_inherited_tags + } + + def _apply_lf_tags_table( + self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef + ) -> None: + lf_tags_table = existing_lf_tags.get("LFTagsOnTable", []) + logger.debug(f"EXISTING TABLE TAGS: {lf_tags_table}") + logger.debug(f"CONFIG TAGS: {self.lf_tags}") + + to_remove = LfTagsManager._table_tags_to_remove(lf_tags_table, self.lf_tags, self.lf_inherited_tags) + + logger.debug(f"TAGS TO REMOVE: {to_remove}") + if to_remove: + response = self.lf_client.remove_lf_tags_from_resource( + Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": v} for k, v in to_remove.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags, "remove") + + if self.lf_tags: + response = self.lf_client.add_lf_tags_to_resource( + Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags) + + def _apply_lf_tags_columns(self) -> None: + if self.lf_tags_columns: + for tag_key, tag_config in self.lf_tags_columns.items(): + for tag_value, columns in tag_config.items(): + resource = { + "TableWithColumns": {"DatabaseName": self.database, "Name": self.table, "ColumnNames": columns} + } + response = self.lf_client.add_lf_tags_to_resource( + Resource=resource, + LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}], + ) + self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}) + + def _parse_and_log_lf_response( + self, + response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef], + columns: Optional[List[str]] = None, + lf_tags: Optional[Dict[str, str]] = None, + verb: str = "add", + ) -> None: + table_appendix = f".{self.table}" if self.table else "" + columns_appendix = f" for columns {columns}" if columns else "" + resource_msg = self.database + table_appendix + columns_appendix + if failures := response.get("Failures", []): + base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg + for failure in failures: + tag = failure.get("LFTag", {}).get("TagKey") + error = failure.get("Error", {}).get("ErrorMessage") + logger.error(f"Failed to {verb} {tag} for " + resource_msg + f" - {error}") + raise DbtRuntimeError(base_msg) + logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg) + + +class FilterConfig(BaseModel): + row_filter: str + column_names: List[str] = [] + principals: List[str] = [] + + def to_api_repr(self, catalog_id: str, database: str, table: str, name: str) -> DataCellsFilterTypeDef: + return { + "TableCatalogId": catalog_id, + "DatabaseName": database, + "TableName": table, + "Name": name, + "RowFilter": {"FilterExpression": self.row_filter}, + "ColumnNames": self.column_names, + "ColumnWildcard": {"ExcludedColumnNames": []}, + } + + def to_update(self, existing: DataCellsFilterTypeDef) -> bool: + return self.row_filter != existing["RowFilter"]["FilterExpression"] or set(self.column_names) != set( + 
existing["ColumnNames"] + ) + + +class DataCellFiltersConfig(BaseModel): + enabled: bool = False + filters: Dict[str, FilterConfig] + + +class LfGrantsConfig(BaseModel): + data_cell_filters: DataCellFiltersConfig + + +class LfPermissions: + def __init__(self, catalog_id: str, relation: AthenaRelation, lf_client: LakeFormationClient) -> None: + self.catalog_id = catalog_id + self.relation = relation + self.database: str = relation.schema + self.table: str = relation.identifier + self.lf_client = lf_client + + def get_filters(self) -> Dict[str, DataCellsFilterTypeDef]: + table_resource = {"CatalogId": self.catalog_id, "DatabaseName": self.database, "Name": self.table} + return {f["Name"]: f for f in self.lf_client.list_data_cells_filter(Table=table_resource)["DataCellsFilters"]} + + def process_filters(self, config: LfGrantsConfig) -> None: + current_filters = self.get_filters() + logger.debug(f"CURRENT FILTERS: {current_filters}") + + to_drop = [f for name, f in current_filters.items() if name not in config.data_cell_filters.filters] + logger.debug(f"FILTERS TO DROP: {to_drop}") + for f in to_drop: + self.lf_client.delete_data_cells_filter( + TableCatalogId=f["TableCatalogId"], + DatabaseName=f["DatabaseName"], + TableName=f["TableName"], + Name=f["Name"], + ) + + to_add = [ + f.to_api_repr(self.catalog_id, self.database, self.table, name) + for name, f in config.data_cell_filters.filters.items() + if name not in current_filters + ] + logger.debug(f"FILTERS TO ADD: {to_add}") + for f in to_add: + self.lf_client.create_data_cells_filter(TableData=f) + + to_update = [ + f.to_api_repr(self.catalog_id, self.database, self.table, name) + for name, f in config.data_cell_filters.filters.items() + if name in current_filters and f.to_update(current_filters[name]) + ] + logger.debug(f"FILTERS TO UPDATE: {to_update}") + for f in to_update: + self.lf_client.update_data_cells_filter(TableData=f) + + def process_permissions(self, config: LfGrantsConfig) -> None: + for name, f in config.data_cell_filters.filters.items(): + logger.debug(f"Start processing permissions for filter: {name}") + current_permissions = self.lf_client.list_permissions( + Resource={ + "DataCellsFilter": { + "TableCatalogId": self.catalog_id, + "DatabaseName": self.database, + "TableName": self.table, + "Name": name, + } + } + )["PrincipalResourcePermissions"] + + current_principals = {p["Principal"]["DataLakePrincipalIdentifier"] for p in current_permissions} + + to_revoke = {p for p in current_principals if p not in f.principals} + if to_revoke: + self.lf_client.batch_revoke_permissions( + CatalogId=self.catalog_id, + Entries=[self._permission_entry(name, principal, idx) for idx, principal in enumerate(to_revoke)], + ) + revoke_principals_msg = "\n".join(to_revoke) + logger.debug(f"Revoked permissions for filter {name} from principals:\n{revoke_principals_msg}") + else: + logger.debug(f"No redundant permissions found for filter: {name}") + + to_add = {p for p in f.principals if p not in current_principals} + if to_add: + self.lf_client.batch_grant_permissions( + CatalogId=self.catalog_id, + Entries=[self._permission_entry(name, principal, idx) for idx, principal in enumerate(to_add)], + ) + add_principals_msg = "\n".join(to_add) + logger.debug(f"Granted permissions for filter {name} to principals:\n{add_principals_msg}") + else: + logger.debug(f"No new permissions added for filter {name}") + + logger.debug(f"Permissions are set to be consistent with config for filter: {name}") + + def _permission_entry(self, filter_name: str, 
principal: str, idx: int) -> BatchPermissionsRequestEntryTypeDef: + return { + "Id": str(idx), + "Principal": {"DataLakePrincipalIdentifier": principal}, + "Resource": { + "DataCellsFilter": { + "TableCatalogId": self.catalog_id, + "DatabaseName": self.database, + "TableName": self.table, + "Name": filter_name, + } + }, + "Permissions": ["SELECT"], + "PermissionsWithGrantOption": [], + } diff --git a/dbt-athena/src/dbt/adapters/athena/python_submissions.py b/dbt-athena/src/dbt/adapters/athena/python_submissions.py new file mode 100644 index 00000000..5a3799ec --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/python_submissions.py @@ -0,0 +1,256 @@ +import time +from functools import cached_property +from typing import Any, Dict + +import botocore +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena.config import AthenaSparkSessionConfig +from dbt.adapters.athena.connections import AthenaCredentials +from dbt.adapters.athena.constants import LOGGER +from dbt.adapters.athena.session import AthenaSparkSessionManager +from dbt.adapters.base import PythonJobHelper + +SUBMISSION_LANGUAGE = "python" + + +class AthenaPythonJobHelper(PythonJobHelper): + """ + Default helper to execute python models with Athena Spark. + + Args: + PythonJobHelper (PythonJobHelper): The base python helper class + """ + + def __init__(self, parsed_model: Dict[Any, Any], credentials: AthenaCredentials) -> None: + """ + Initialize spark config and connection. + + Args: + parsed_model (Dict[Any, Any]): The parsed python model. + credentials (AthenaCredentials): Credentials for Athena connection. + """ + self.relation_name = parsed_model.get("relation_name", None) + self.config = AthenaSparkSessionConfig( + parsed_model.get("config", {}), + polling_interval=credentials.poll_interval, + retry_attempts=credentials.num_retries, + ) + self.spark_connection = AthenaSparkSessionManager( + credentials, self.timeout, self.polling_interval, self.engine_config, self.relation_name + ) + + @cached_property + def timeout(self) -> int: + """ + Get the timeout value. + + Returns: + int: The timeout value in seconds. + """ + return self.config.set_timeout() + + @cached_property + def session_id(self) -> str: + """ + Get the session ID. + + Returns: + str: The session ID as a string. + """ + return str(self.spark_connection.get_session_id()) + + @cached_property + def polling_interval(self) -> float: + """ + Get the polling interval. + + Returns: + float: The polling interval in seconds. + """ + return self.config.set_polling_interval() + + @cached_property + def engine_config(self) -> Dict[str, int]: + """ + Get the engine configuration. + + Returns: + Dict[str, int]: A dictionary containing the engine configuration. + """ + return self.config.set_engine_config() + + @cached_property + def athena_client(self) -> Any: + """ + Get the Athena client. + + Returns: + Any: The Athena client object. + """ + return self.spark_connection.athena_client + + def get_current_session_status(self) -> Any: + """ + Get the current session status. + + Returns: + Any: The status of the session + """ + return self.spark_connection.get_session_status(self.session_id) + + def submit(self, compiled_code: str) -> Any: + """ + Submit a calculation to Athena. + + This function submits a calculation to Athena for execution using the provided compiled code. + It starts a calculation execution with the current session ID and the compiled code as the code block. 
+ The function then polls until the calculation execution is completed, and retrieves the result. + If the execution is successful and completed, the result S3 URI is returned. Otherwise, a DbtRuntimeError + is raised with the execution status. + + Args: + compiled_code (str): The compiled code to submit for execution. + + Returns: + dict: The result S3 URI if the execution is successful and completed. + + Raises: + DbtRuntimeError: If the execution ends in a state other than "COMPLETED". + + """ + # Seeing an empty calculation along with main python model code calculation is submitted for almost every model + # Also, if not returning the result json, we are getting green ERROR messages instead of OK messages. + # And with this handling, the run model code in target folder every model under run folder seems to be empty + # Need to fix this work around solution + if compiled_code.strip(): + while True: + try: + LOGGER.debug( + f"Model {self.relation_name} - Using session: {self.session_id} to start calculation execution." + ) + calculation_execution_id = self.athena_client.start_calculation_execution( + SessionId=self.session_id, CodeBlock=compiled_code.lstrip() + )["CalculationExecutionId"] + break + except botocore.exceptions.ClientError as ce: + LOGGER.exception(f"Encountered client error: {ce}") + if ( + ce.response["Error"]["Code"] == "InvalidRequestException" + and "Session is in the BUSY state; needs to be IDLE to accept Calculations." + in ce.response["Error"]["Message"] + ): + LOGGER.exception("Going to poll until session is IDLE") + self.poll_until_session_idle() + except Exception as e: + raise DbtRuntimeError(f"Unable to start spark python code execution. Got: {e}") + execution_status = self.poll_until_execution_completion(calculation_execution_id) + LOGGER.debug(f"Model {self.relation_name} - Received execution status {execution_status}") + if execution_status == "COMPLETED": + try: + result = self.athena_client.get_calculation_execution( + CalculationExecutionId=calculation_execution_id + )["Result"] + except Exception as e: + LOGGER.error(f"Unable to retrieve results: Got: {e}") + result = {} + return result + else: + return {"ResultS3Uri": "string", "ResultType": "string", "StdErrorS3Uri": "string", "StdOutS3Uri": "string"} + + def poll_until_session_idle(self) -> None: + """ + Polls the session status until it becomes idle or exceeds the timeout. + + Raises: + DbtRuntimeError: If the session chosen is not available or if it does not become idle within the timeout. + """ + polling_interval = self.polling_interval + timer: float = 0 + while True: + session_status = self.get_current_session_status()["State"] + if session_status in ["TERMINATING", "TERMINATED", "DEGRADED", "FAILED"]: + LOGGER.debug( + f"Model {self.relation_name} - The session: {self.session_id} was not available. " + f"Got status: {session_status}. Will try with a different session." + ) + self.spark_connection.remove_terminated_session(self.session_id) + if "session_id" in self.__dict__: + del self.__dict__["session_id"] + break + if session_status == "IDLE": + break + time.sleep(polling_interval) + timer += polling_interval + if timer > self.timeout: + LOGGER.debug( + f"Model {self.relation_name} - Session {self.session_id} did not become free within {self.timeout}" + " seconds. Will try with a different session." 
+ ) + if "session_id" in self.__dict__: + del self.__dict__["session_id"] + break + + def poll_until_execution_completion(self, calculation_execution_id: str) -> Any: + """ + Poll the status of a calculation execution until it is completed, failed, or canceled. + + This function polls the status of a calculation execution identified by the given `calculation_execution_id` + until it is completed, failed, or canceled. It uses the Athena client to retrieve the status of the execution + and checks if the state is one of "COMPLETED", "FAILED", or "CANCELED". If the execution is not yet completed, + the function sleeps for a certain polling interval, which starts with the value of `self.polling_interval` and + doubles after each iteration until it reaches the `self.timeout` period. If the execution does not complete + within the timeout period, a `DbtRuntimeError` is raised. + + Args: + calculation_execution_id (str): The ID of the calculation execution to poll. + + Returns: + str: The final state of the calculation execution, which can be one of "COMPLETED", "FAILED" or "CANCELED". + + Raises: + DbtRuntimeError: If the calculation execution does not complete within the timeout period. + + """ + try: + polling_interval = self.polling_interval + timer: float = 0 + while True: + execution_response = self.athena_client.get_calculation_execution( + CalculationExecutionId=calculation_execution_id + ) + execution_session = execution_response.get("SessionId", None) + execution_status = execution_response.get("Status", None) + execution_result = execution_response.get("Result", None) + execution_stderr_s3_path = "" + if execution_result: + execution_stderr_s3_path = execution_result.get("StdErrorS3Uri", None) + + execution_status_state = "" + execution_status_reason = "" + if execution_status: + execution_status_state = execution_status.get("State", None) + execution_status_reason = execution_status.get("StateChangeReason", None) + + if execution_status_state in ["FAILED", "CANCELED"]: + raise DbtRuntimeError( + f"""Calculation Id: {calculation_execution_id} +Session Id: {execution_session} +Status: {execution_status_state} +Reason: {execution_status_reason} +Stderr s3 path: {execution_stderr_s3_path} +""" + ) + + if execution_status_state == "COMPLETED": + return execution_status_state + + time.sleep(polling_interval) + timer += polling_interval + if timer > self.timeout: + self.athena_client.stop_calculation_execution(CalculationExecutionId=calculation_execution_id) + raise DbtRuntimeError( + f"Execution {calculation_execution_id} did not complete within {self.timeout} seconds." + ) + finally: + self.spark_connection.set_spark_session_load(self.session_id, -1) diff --git a/dbt-athena/src/dbt/adapters/athena/query_headers.py b/dbt-athena/src/dbt/adapters/athena/query_headers.py new file mode 100644 index 00000000..5220299a --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/query_headers.py @@ -0,0 +1,43 @@ +from typing import Any, Dict + +from dbt.adapters.base.query_headers import MacroQueryStringSetter, _QueryComment +from dbt.adapters.contracts.connection import AdapterRequiredConfig + + +class AthenaMacroQueryStringSetter(MacroQueryStringSetter): + def __init__(self, config: AdapterRequiredConfig, query_header_context: Dict[str, Any]): + super().__init__(config, query_header_context) + self.comment = _AthenaQueryComment(None) + + +class _AthenaQueryComment(_QueryComment): + """ + Athena DDL does not always respect /* ... */ block quotations. 
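+    For example (illustrative), with a query_comment of "dbt run" a statement such as
+    "create external table my_table ..." is emitted as:
+
+        -- /* dbt run */
+        create external table my_table ...
+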
+ This function is the same as _QueryComment.add except that + a leading "-- " is prepended to the query_comment and any newlines + in the query_comment are replaced with " ". This allows the default + query_comment to be added to `create external table` statements. + """ + + def add(self, sql: str) -> str: + if not self.query_comment: + return sql + + # alter or vacuum statements don't seem to support properly query comments + # let's just exclude them + sql = sql.lstrip() + if any(sql.lower().startswith(keyword) for keyword in ["alter", "drop", "optimize", "vacuum", "msck"]): + return sql + + cleaned_query_comment = self.query_comment.strip().replace("\n", " ") + + if self.append: + # replace last ';' with '<comment>;' + sql = sql.rstrip() + if sql[-1] == ";": + sql = sql[:-1] + return f"{sql}\n-- /* {cleaned_query_comment} */;" + + return f"{sql}\n-- /* {cleaned_query_comment} */" + + return f"-- /* {cleaned_query_comment} */\n{sql}" diff --git a/dbt-athena/src/dbt/adapters/athena/relation.py b/dbt-athena/src/dbt/adapters/athena/relation.py new file mode 100644 index 00000000..0de8d784 --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/relation.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Optional, Set + +from mypy_boto3_glue.type_defs import TableTypeDef + +from dbt.adapters.athena.constants import LOGGER +from dbt.adapters.base.relation import BaseRelation, InformationSchema, Policy + + +class TableType(Enum): + TABLE = "table" + VIEW = "view" + CTE = "cte" + MATERIALIZED_VIEW = "materializedview" + ICEBERG = "iceberg_table" + + def is_physical(self) -> bool: + return self in [TableType.TABLE, TableType.ICEBERG] + + +@dataclass +class AthenaIncludePolicy(Policy): + database: bool = True + schema: bool = True + identifier: bool = True + + +@dataclass +class AthenaHiveIncludePolicy(Policy): + database: bool = False + schema: bool = True + identifier: bool = True + + +@dataclass(frozen=True, eq=False, repr=False) +class AthenaRelation(BaseRelation): + quote_character: str = '"' # Presto quote character + include_policy: Policy = field(default_factory=lambda: AthenaIncludePolicy()) + s3_path_table_part: Optional[str] = None + detailed_table_type: Optional[str] = None # table_type option from the table Parameters in Glue Catalog + require_alias: bool = False + + def render_hive(self) -> str: + """ + Render relation with Hive format. Athena uses a Hive format for some DDL statements. + + See: + - https://aws.amazon.com/athena/faqs/ "Q: How do I create tables and schemas for my data on Amazon S3?" + - https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + """ + + old_quote_character = self.quote_character + object.__setattr__(self, "quote_character", "`") # Hive quote char + old_include_policy = self.include_policy + object.__setattr__(self, "include_policy", AthenaHiveIncludePolicy()) + rendered = self.render() + object.__setattr__(self, "quote_character", old_quote_character) + object.__setattr__(self, "include_policy", old_include_policy) + return str(rendered) + + def render_pure(self) -> str: + """ + Render relation without quotes characters. 
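+        For example (illustrative; schema and table names are placeholders), a relation
+        that would otherwise render with the '"' quote character as
+        "awsdatacatalog"."analytics"."orders" is rendered as awsdatacatalog.analytics.orders.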
+ This is needed for not standard executions like optimize and vacuum + """ + old_value = self.quote_character + object.__setattr__(self, "quote_character", "") + rendered = self.render() + object.__setattr__(self, "quote_character", old_value) + return str(rendered) + + +class AthenaSchemaSearchMap(Dict[InformationSchema, Dict[str, Set[Optional[str]]]]): + """A utility class to keep track of what information_schema tables to + search for what schemas and relations. The schema and relation values are all + lowercase to avoid duplication. + """ + + def add(self, relation: AthenaRelation) -> None: + key = relation.information_schema_only() + if key not in self: + self[key] = {} + if relation.schema is not None: + schema = relation.schema.lower() + relation_name = relation.name.lower() + if schema not in self[key]: + self[key][schema] = set() + self[key][schema].add(relation_name) + + +RELATION_TYPE_MAP = { + "EXTERNAL_TABLE": TableType.TABLE, + "EXTERNAL": TableType.TABLE, # type returned by federated query tables + "GOVERNED": TableType.TABLE, + "MANAGED_TABLE": TableType.TABLE, + "VIRTUAL_VIEW": TableType.VIEW, + "table": TableType.TABLE, + "view": TableType.VIEW, + "cte": TableType.CTE, + "materializedview": TableType.MATERIALIZED_VIEW, +} + + +def get_table_type(table: TableTypeDef) -> TableType: + table_full_name = ".".join(filter(None, [table.get("CatalogId"), table.get("DatabaseName"), table["Name"]])) + + input_table_type = table.get("TableType") + if input_table_type and input_table_type not in RELATION_TYPE_MAP: + raise ValueError(f"Table type {table['TableType']} is not supported for table {table_full_name}") + + if table.get("Parameters", {}).get("table_type", "").lower() == "iceberg": + _type = TableType.ICEBERG + elif not input_table_type: + raise ValueError(f"Table type cannot be None for table {table_full_name}") + else: + _type = RELATION_TYPE_MAP[input_table_type] + + LOGGER.debug(f"table_name : {table_full_name}") + LOGGER.debug(f"table type : {_type}") + + return _type diff --git a/dbt-athena/src/dbt/adapters/athena/s3.py b/dbt-athena/src/dbt/adapters/athena/s3.py new file mode 100644 index 00000000..94e1270f --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/s3.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class S3DataNaming(Enum): + UNIQUE = "unique" + TABLE = "table" + TABLE_UNIQUE = "table_unique" + SCHEMA_TABLE = "schema_table" + SCHEMA_TABLE_UNIQUE = "schema_table_unique" diff --git a/dbt-athena/src/dbt/adapters/athena/session.py b/dbt-athena/src/dbt/adapters/athena/session.py new file mode 100644 index 00000000..b346d13e --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/session.py @@ -0,0 +1,263 @@ +import json +import threading +import time +from functools import cached_property +from hashlib import md5 +from typing import Any, Dict +from uuid import UUID + +import boto3 +import boto3.session +from dbt_common.exceptions import DbtRuntimeError +from dbt_common.invocation import get_invocation_id + +from dbt.adapters.athena.config import get_boto3_config +from dbt.adapters.athena.constants import ( + DEFAULT_THREAD_COUNT, + LOGGER, + SESSION_IDLE_TIMEOUT_MIN, +) +from dbt.adapters.contracts.connection import Connection + +invocation_id = get_invocation_id() +spark_session_list: Dict[UUID, str] = {} +spark_session_load: Dict[UUID, int] = {} + + +def get_boto3_session(connection: Connection) -> boto3.session.Session: + return boto3.session.Session( + aws_access_key_id=connection.credentials.aws_access_key_id, + 
aws_secret_access_key=connection.credentials.aws_secret_access_key, + aws_session_token=connection.credentials.aws_session_token, + region_name=connection.credentials.region_name, + profile_name=connection.credentials.aws_profile_name, + ) + + +def get_boto3_session_from_credentials(credentials: Any) -> boto3.session.Session: + return boto3.session.Session( + aws_access_key_id=credentials.aws_access_key_id, + aws_secret_access_key=credentials.aws_secret_access_key, + aws_session_token=credentials.aws_session_token, + region_name=credentials.region_name, + profile_name=credentials.aws_profile_name, + ) + + +class AthenaSparkSessionManager: + """ + A helper class to manage Athena Spark Sessions. + """ + + def __init__( + self, + credentials: Any, + timeout: int, + polling_interval: float, + engine_config: Dict[str, int], + relation_name: str = "N/A", + ) -> None: + """ + Initialize the AthenaSparkSessionManager instance. + + Args: + credentials (Any): The credentials to be used. + timeout (int): The timeout value in seconds. + polling_interval (float): The polling interval in seconds. + engine_config (Dict[str, int]): The engine configuration. + + """ + self.credentials = credentials + self.timeout = timeout + self.polling_interval = polling_interval + self.engine_config = engine_config + self.lock = threading.Lock() + self.relation_name = relation_name + + @cached_property + def spark_threads(self) -> int: + """ + Get the number of Spark threads. + + Returns: + int: The number of Spark threads. If not found in the profile, returns the default thread count. + """ + if not DEFAULT_THREAD_COUNT: + LOGGER.debug(f"""Threads not found in profile. Got: {DEFAULT_THREAD_COUNT}""") + return 1 + return int(DEFAULT_THREAD_COUNT) + + @cached_property + def spark_work_group(self) -> str: + """ + Get the Spark work group. + + Returns: + str: The Spark work group. Raises an exception if not found in the profile. + """ + if not self.credentials.spark_work_group: + raise DbtRuntimeError(f"Expected spark_work_group in profile. Got: {self.credentials.spark_work_group}") + return str(self.credentials.spark_work_group) + + @cached_property + def athena_client(self) -> Any: + """ + Get the AWS Athena client. + + This function returns an AWS Athena client object that can be used to interact with the Athena service. + The client is created using the region name and profile name provided during object instantiation. + + Returns: + Any: The Athena client object. + + """ + return get_boto3_session_from_credentials(self.credentials).client( + "athena", config=get_boto3_config(num_retries=self.credentials.effective_num_retries) + ) + + @cached_property + def session_description(self) -> str: + """ + Converts the engine configuration to md5 hash value + + Returns: + str: A concatenated text of dbt invocation_id and engine configuration's md5 hash + """ + hash_desc = md5(json.dumps(self.engine_config, sort_keys=True, ensure_ascii=True).encode("utf-8")).hexdigest() + return f"dbt: {invocation_id} - {hash_desc}" + + def get_session_id(self, session_query_capacity: int = 1) -> UUID: + """ + Get a session ID for the Spark session. + When does a new session get created: + - When thread limit not reached + - When thread limit reached but same engine configuration session is not available + - When thread limit reached and same engine configuration session exist and it is busy running a python model + and has one python model in queue (determined by session_query_capacity). + + Returns: + UUID: The session ID. 
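+
+        Example (illustrative): with spark_threads = 2 and two sessions already tracked,
+        a model whose engine configuration hashes to an existing session_description
+        reuses that session while its load is at or below session_query_capacity;
+        otherwise a new session is started.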
+ """ + session_list = list(spark_session_list.items()) + + if len(session_list) < self.spark_threads: + LOGGER.debug( + f"Within thread limit, creating new session for model: {self.relation_name}" + f" with session description: {self.session_description}." + ) + return self.start_session() + else: + matching_session_id = next( + ( + session_id + for session_id, description in session_list + if description == self.session_description + and spark_session_load.get(session_id, 0) <= session_query_capacity + ), + None, + ) + if matching_session_id: + LOGGER.debug( + f"Over thread limit, matching session found for model: {self.relation_name}" + f" with session description: {self.session_description} and has capacity." + ) + self.set_spark_session_load(str(matching_session_id), 1) + return matching_session_id + else: + LOGGER.debug( + f"Over thread limit, matching session not found or found with over capacity. Creating new session" + f" for model: {self.relation_name} with session description: {self.session_description}." + ) + return self.start_session() + + def start_session(self) -> UUID: + """ + Start an Athena session. + + This function sends a request to the Athena service to start a session in the specified Spark workgroup. + It configures the session with specific engine configurations. If the session state is not IDLE, the function + polls until the session creation is complete. The response containing session information is returned. + + Returns: + dict: The session information dictionary. + + """ + description = self.session_description + response = self.athena_client.start_session( + Description=description, + WorkGroup=self.credentials.spark_work_group, + EngineConfiguration=self.engine_config, + SessionIdleTimeoutInMinutes=SESSION_IDLE_TIMEOUT_MIN, + ) + session_id = response["SessionId"] + if response["State"] != "IDLE": + self.poll_until_session_creation(session_id) + + with self.lock: + spark_session_list[UUID(session_id)] = self.session_description + spark_session_load[UUID(session_id)] = 1 + + return UUID(session_id) + + def poll_until_session_creation(self, session_id: str) -> None: + """ + Polls the status of an Athena session creation until it is completed or reaches the timeout. + + Args: + session_id (str): The ID of the session being created. + + Returns: + str: The final status of the session, which will be "IDLE" if the session creation is successful. + + Raises: + DbtRuntimeError: If the session creation fails, is terminated, or degrades during polling. + DbtRuntimeError: If the session does not become IDLE within the specified timeout. + + """ + polling_interval = self.polling_interval + timer: float = 0 + while True: + creation_status_response = self.get_session_status(session_id) + creation_status_state = creation_status_response.get("State", "") + creation_status_reason = creation_status_response.get("StateChangeReason", "") + if creation_status_state in ["FAILED", "TERMINATED", "DEGRADED"]: + raise DbtRuntimeError( + f"Unable to create session: {session_id}. Got status: {creation_status_state}" + f" with reason: {creation_status_reason}." + ) + elif creation_status_state == "IDLE": + LOGGER.debug(f"Session: {session_id} created") + break + time.sleep(polling_interval) + timer += polling_interval + if timer > self.timeout: + self.remove_terminated_session(session_id) + raise DbtRuntimeError(f"Session {session_id} did not create within {self.timeout} seconds.") + + def get_session_status(self, session_id: str) -> Any: + """ + Get the session status. 
+ + Returns: + Any: The status of the session + """ + return self.athena_client.get_session_status(SessionId=session_id)["Status"] + + def remove_terminated_session(self, session_id: str) -> None: + """ + Removes session uuid from session list variable + + Returns: None + """ + with self.lock: + spark_session_list.pop(UUID(session_id), "Session id not found") + spark_session_load.pop(UUID(session_id), "Session id not found") + + def set_spark_session_load(self, session_id: str, change: int) -> None: + """ + Increase or decrease the session load variable + + Returns: None + """ + with self.lock: + spark_session_load[UUID(session_id)] = spark_session_load.get(UUID(session_id), 0) + change diff --git a/dbt-athena/src/dbt/adapters/athena/utils.py b/dbt-athena/src/dbt/adapters/athena/utils.py new file mode 100644 index 00000000..39ec755b --- /dev/null +++ b/dbt-athena/src/dbt/adapters/athena/utils.py @@ -0,0 +1,63 @@ +import json +import re +from enum import Enum +from typing import Any, Generator, List, Optional, TypeVar + +from mypy_boto3_athena.type_defs import DataCatalogTypeDef + +from dbt.adapters.athena.constants import LOGGER + + +def clean_sql_comment(comment: str) -> str: + split_and_strip = [line.strip() for line in comment.split("\n")] + return " ".join(line for line in split_and_strip if line) + + +def stringify_table_parameter_value(value: Any) -> Optional[str]: + """Convert any variable to string for Glue Table property.""" + try: + if isinstance(value, (dict, list)): + value_str: str = json.dumps(value) + else: + value_str = str(value) + return value_str[:512000] + except (TypeError, ValueError) as e: + # Handle non-stringifiable objects and non-serializable objects + LOGGER.warning(f"Non-stringifiable object. Error: {str(e)}") + return None + + +def is_valid_table_parameter_key(key: str) -> bool: + """Check if key is valid for Glue Table property according to official documentation.""" + # Simplified version of key pattern which works with re + # Original pattern can be found here https://docs.aws.amazon.com/glue/latest/webapi/API_Table.html + key_pattern: str = r"^[\u0020-\uD7FF\uE000-\uFFFD\t]*$" + return len(key) <= 255 and bool(re.match(key_pattern, key)) + + +def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: + return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None + + +class AthenaCatalogType(Enum): + GLUE = "GLUE" + LAMBDA = "LAMBDA" + HIVE = "HIVE" + + +def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]: + return AthenaCatalogType(catalog["Type"]) if catalog else None + + +T = TypeVar("T") + + +def get_chunks(lst: List[T], n: int) -> Generator[List[T], None, None]: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def ellipsis_comment(s: str, max_len: int = 255) -> str: + """Ellipsis string if it exceeds max length""" + return f"{s[:(max_len - 3)]}..." 
if len(s) > max_len else s diff --git a/dbt-athena/src/dbt/include/athena/__init__.py b/dbt-athena/src/dbt/include/athena/__init__.py new file mode 100644 index 00000000..b177e5d4 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/__init__.py @@ -0,0 +1,3 @@ +import os + +PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt-athena/src/dbt/include/athena/dbt_project.yml b/dbt-athena/src/dbt/include/athena/dbt_project.yml new file mode 100644 index 00000000..471a15ed --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/dbt_project.yml @@ -0,0 +1,5 @@ +name: dbt_athena +version: 1.0 +config-version: 2 + +macro-paths: ['macros'] diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/columns.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/columns.sql new file mode 100644 index 00000000..bb106d3b --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/columns.sql @@ -0,0 +1,20 @@ +{% macro athena__get_columns_in_relation(relation) -%} + {{ return(adapter.get_columns_in_relation(relation)) }} +{% endmacro %} + +{% macro athena__get_empty_schema_sql(columns) %} + {%- set col_err = [] -%} + select + {% for i in columns %} + {%- set col = columns[i] -%} + {%- if col['data_type'] is not defined -%} + {{ col_err.append(col['name']) }} + {%- else -%} + {% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %} + cast(null as {{ dml_data_type(col['data_type']) }}) as {{ col_name }}{{ ", " if not loop.last }} + {%- endif -%} + {%- endfor -%} + {%- if (col_err | length) > 0 -%} + {{ exceptions.column_type_missing(column_names=col_err) }} + {%- endif -%} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/metadata.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/metadata.sql new file mode 100644 index 00000000..ae7a187c --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/metadata.sql @@ -0,0 +1,17 @@ +{% macro athena__get_catalog(information_schema, schemas) -%} + {{ return(adapter.get_catalog()) }} +{%- endmacro %} + + +{% macro athena__list_schemas(database) -%} + {{ return(adapter.list_schemas(database)) }} +{% endmacro %} + + +{% macro athena__list_relations_without_caching(schema_relation) %} + {{ return(adapter.list_relations_without_caching(schema_relation)) }} +{% endmacro %} + +{% macro athena__get_catalog_relations(information_schema, relations) %} + {{ return(adapter.get_catalog_by_relations(information_schema, relations)) }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/persist_docs.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/persist_docs.sql new file mode 100644 index 00000000..7b95bd7b --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/persist_docs.sql @@ -0,0 +1,14 @@ +{% macro athena__persist_docs(relation, model, for_relation=true, for_columns=true) -%} + {% set persist_relation_docs = for_relation and config.persist_relation_docs()%} + {% set persist_column_docs = for_columns and config.persist_column_docs() and model.columns %} + {% if persist_relation_docs or persist_column_docs %} + {% do adapter.persist_docs_to_glue( + relation=relation, + model=model, + persist_relation_docs=persist_relation_docs, + persist_column_docs=persist_column_docs, + skip_archive_table_version=true + ) + %}} + {% endif %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/python_submissions.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/python_submissions.sql new file mode 100644 index 
00000000..f668acf7 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/python_submissions.sql @@ -0,0 +1,95 @@ +{%- macro athena__py_save_table_as(compiled_code, target_relation, optional_args={}) -%} + {% set location = optional_args.get("location") %} + {% set format = optional_args.get("format", "parquet") %} + {% set mode = optional_args.get("mode", "overwrite") %} + {% set write_compression = optional_args.get("write_compression", "snappy") %} + {% set partitioned_by = optional_args.get("partitioned_by") %} + {% set bucketed_by = optional_args.get("bucketed_by") %} + {% set sorted_by = optional_args.get("sorted_by") %} + {% set merge_schema = optional_args.get("merge_schema", true) %} + {% set bucket_count = optional_args.get("bucket_count") %} + {% set field_delimiter = optional_args.get("field_delimiter") %} + {% set spark_ctas = optional_args.get("spark_ctas", "") %} + +import pyspark + + +{{ compiled_code }} +def materialize(spark_session, df, target_relation): + import pandas + if isinstance(df, pyspark.sql.dataframe.DataFrame): + pass + elif isinstance(df, pandas.core.frame.DataFrame): + df = spark_session.createDataFrame(df) + else: + msg = f"{type(df)} is not a supported type for dbt Python materialization" + raise Exception(msg) + +{% if spark_ctas|length > 0 %} + df.createOrReplaceTempView("{{ target_relation.schema}}_{{ target_relation.identifier }}_tmpvw") + spark_session.sql(""" + {{ spark_ctas }} + select * from {{ target_relation.schema}}_{{ target_relation.identifier }}_tmpvw + """) +{% else %} + writer = df.write \ + .format("{{ format }}") \ + .mode("{{ mode }}") \ + .option("path", "{{ location }}") \ + .option("compression", "{{ write_compression }}") \ + .option("mergeSchema", "{{ merge_schema }}") \ + .option("delimiter", "{{ field_delimiter }}") + if {{ partitioned_by }} is not None: + writer = writer.partitionBy({{ partitioned_by }}) + if {{ bucketed_by }} is not None: + writer = writer.bucketBy({{ bucket_count }},{{ bucketed_by }}) + if {{ sorted_by }} is not None: + writer = writer.sortBy({{ sorted_by }}) + + writer.saveAsTable( + name="{{ target_relation.schema}}.{{ target_relation.identifier }}", + ) +{% endif %} + + return "Success: {{ target_relation.schema}}.{{ target_relation.identifier }}" + +{{ athena__py_get_spark_dbt_object() }} + +dbt = SparkdbtObj() +df = model(dbt, spark) +materialize(spark, df, dbt.this) +{%- endmacro -%} + +{%- macro athena__py_execute_query(query) -%} +{{ athena__py_get_spark_dbt_object() }} + +def execute_query(spark_session): + spark_session.sql("""{{ query }}""") + return "OK" + +dbt = SparkdbtObj() +execute_query(spark) +{%- endmacro -%} + +{%- macro athena__py_get_spark_dbt_object() -%} +def get_spark_df(identifier): + """ + Override the arguments to ref and source dynamically + + spark.table('awsdatacatalog.analytics_dev.model') + Raises pyspark.sql.utils.AnalysisException: + spark_catalog requires a single-part namespace, + but got [awsdatacatalog, analytics_dev] + + So the override removes the catalog component and only + provides the schema and identifer to spark.table() + """ + return spark.table(".".join(identifier.split(".")[1:]).replace('"', '')) + +class SparkdbtObj(dbtObj): + def __init__(self): + super().__init__(load_df_function=get_spark_df) + self.source = lambda *args: source(*args, dbt_load_df_function=get_spark_df) + self.ref = lambda *args: ref(*args, dbt_load_df_function=get_spark_df) + +{%- endmacro -%} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/relation.sql 
b/dbt-athena/src/dbt/include/athena/macros/adapters/relation.sql new file mode 100644 index 00000000..611ffc59 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/relation.sql @@ -0,0 +1,58 @@ +{% macro athena__drop_relation(relation) -%} + {%- set native_drop = config.get('native_drop', default=false) -%} + {%- set rel_type_object = adapter.get_glue_table_type(relation) -%} + {%- set rel_type = none if rel_type_object == none else rel_type_object.value -%} + {%- set natively_droppable = rel_type == 'iceberg_table' or relation.type == 'view' -%} + + {%- if native_drop and natively_droppable -%} + {%- do drop_relation_sql(relation) -%} + {%- else -%} + {%- do drop_relation_glue(relation) -%} + {%- endif -%} +{% endmacro %} + +{% macro drop_relation_glue(relation) -%} + {%- do log('Dropping relation via Glue and S3 APIs') -%} + {%- do adapter.clean_up_table(relation) -%} + {%- do adapter.delete_from_glue_catalog(relation) -%} +{% endmacro %} + +{% macro drop_relation_sql(relation) -%} + + {%- do log('Dropping relation via SQL only') -%} + {% call statement('drop_relation', auto_begin=False) -%} + {%- if relation.type == 'view' -%} + drop {{ relation.type }} if exists {{ relation.render() }} + {%- else -%} + drop {{ relation.type }} if exists {{ relation.render_hive() }} + {% endif %} + {%- endcall %} +{% endmacro %} + +{% macro set_table_classification(relation) -%} + {%- set format = config.get('format', default='parquet') -%} + {% call statement('set_table_classification', auto_begin=False) -%} + alter table {{ relation.render_hive() }} set tblproperties ('classification' = '{{ format }}') + {%- endcall %} +{%- endmacro %} + +{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %} + {%- set temp_identifier = base_relation.identifier ~ suffix -%} + {%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%} + + {%- if temp_schema is not none -%} + {%- set temp_relation = temp_relation.incorporate(path={ + "identifier": temp_identifier, + "schema": temp_schema + }) -%} + {%- do create_schema(temp_relation) -%} + {% endif %} + + {{ return(temp_relation) }} +{% endmacro %} + +{% macro athena__rename_relation(from_relation, to_relation) %} + {% call statement('rename_relation') -%} + alter table {{ from_relation.render_hive() }} rename to `{{ to_relation.schema }}`.`{{ to_relation.identifier }}` + {%- endcall %} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/adapters/schema.sql b/dbt-athena/src/dbt/include/athena/macros/adapters/schema.sql new file mode 100644 index 00000000..2750a7f9 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/adapters/schema.sql @@ -0,0 +1,15 @@ +{% macro athena__create_schema(relation) -%} + {%- call statement('create_schema') -%} + create schema if not exists {{ relation.without_identifier().render_hive() }} + {% endcall %} + + {{ adapter.add_lf_tags_to_database(relation) }} + +{% endmacro %} + + +{% macro athena__drop_schema(relation) -%} + {%- call statement('drop_schema') -%} + drop schema if exists {{ relation.without_identifier().render_hive() }} cascade + {% endcall %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/hooks.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/hooks.sql new file mode 100644 index 00000000..07c32d35 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/hooks.sql @@ -0,0 +1,17 @@ +{% macro run_hooks(hooks, inside_transaction=True) %} + {% set re = 
modules.re %} + {% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %} + {% set rendered = render(hook.get('sql')) | trim %} + {% if (rendered | length) > 0 %} + {%- if re.match("optimize\W+\S+\W+rewrite data using bin_pack", rendered.lower(), re.MULTILINE) -%} + {%- do adapter.run_operation_with_potential_multiple_runs(rendered, "optimize") -%} + {%- elif re.match("vacuum\W+\S+", rendered.lower(), re.MULTILINE) -%} + {%- do adapter.run_operation_with_potential_multiple_runs(rendered, "vacuum") -%} + {%- else -%} + {% call statement(auto_begin=inside_transaction) %} + {{ rendered }} + {% endcall %} + {%- endif -%} + {% endif %} + {% endfor %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql new file mode 100644 index 00000000..c1bf6505 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql @@ -0,0 +1,88 @@ +{% macro get_partition_batches(sql, as_subquery=True) -%} + {# Retrieve partition configuration and set default partition limit #} + {%- set partitioned_by = config.get('partitioned_by') -%} + {%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%} + {%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%} + {% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %} + + {# Retrieve distinct partitions from the given SQL #} + {% call statement('get_partitions', fetch_result=True) %} + {%- if as_subquery -%} + select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }}; + {%- else -%} + select distinct {{ partitioned_keys }} from {{ sql }} order by {{ partitioned_keys }}; + {%- endif -%} + {% endcall %} + + {# Initialize variables to store partition info #} + {%- set table = load_result('get_partitions').table -%} + {%- set rows = table.rows -%} + {%- set ns = namespace(partitions = [], bucket_conditions = {}, bucket_numbers = [], bucket_column = None, is_bucketed = false) -%} + + {# Process each partition row #} + {%- for row in rows -%} + {%- set single_partition = [] -%} + {# Use Namespace to hold the counter for loop index #} + {%- set counter = namespace(value=0) -%} + {# Loop through each column in the row #} + {%- for col, partition_key in zip(row, partitioned_by) -%} + {# Process bucketed columns using the new macro with the index #} + {%- do process_bucket_column(col, partition_key, table, ns, counter.value) -%} + + {# Logic for non-bucketed columns #} + {%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%} + {%- if not bucket_match -%} + {# For non-bucketed columns, format partition key and value #} + {%- set column_type = adapter.convert_type(table, counter.value) -%} + {%- set value, comp_func = adapter.format_value_for_partition(col, column_type) -%} + {%- set partition_key_formatted = adapter.format_one_partition_key(partitioned_by[counter.value]) -%} + {%- do single_partition.append(partition_key_formatted + comp_func + value) -%} + {%- endif -%} + {# Increment the counter #} + {%- set counter.value = counter.value + 1 -%} + {%- endfor -%} + + {# Concatenate conditions for a single partition #} + {%- set single_partition_expression = single_partition | join(' and ') -%} + {%- if single_partition_expression not in ns.partitions %} + {%- do ns.partitions.append(single_partition_expression) -%} + {%- endif -%} 
+ {%- endfor -%} + + {# Calculate total batches based on bucketing and partitioning #} + {%- if ns.is_bucketed -%} + {%- set total_batches = ns.partitions | length * ns.bucket_numbers | length -%} + {%- else -%} + {%- set total_batches = ns.partitions | length -%} + {%- endif -%} + {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ total_batches) %} + + {# Determine the number of batches per partition limit #} + {%- set batches_per_partition_limit = (total_batches // athena_partitions_limit) + (total_batches % athena_partitions_limit > 0) -%} + + {# Create conditions for each batch #} + {%- set partitions_batches = [] -%} + {%- for i in range(batches_per_partition_limit) -%} + {%- set batch_conditions = [] -%} + {%- if ns.is_bucketed -%} + {# Combine partition and bucket conditions for each batch #} + {%- for partition_expression in ns.partitions -%} + {%- for bucket_num in ns.bucket_numbers -%} + {%- set bucket_condition = ns.bucket_column + " IN (" + ns.bucket_conditions[bucket_num] | join(", ") + ")" -%} + {%- set combined_condition = "(" + partition_expression + ' and ' + bucket_condition + ")" -%} + {%- do batch_conditions.append(combined_condition) -%} + {%- endfor -%} + {%- endfor -%} + {%- else -%} + {# Extend batch conditions with partitions for non-bucketed columns #} + {%- do batch_conditions.extend(ns.partitions) -%} + {%- endif -%} + {# Calculate batch start and end index and append batch conditions #} + {%- set start_index = i * athena_partitions_limit -%} + {%- set end_index = start_index + athena_partitions_limit -%} + {%- do partitions_batches.append(batch_conditions[start_index:end_index] | join(' or ')) -%} + {%- endfor -%} + + {{ return(partitions_batches) }} + +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql new file mode 100644 index 00000000..3790fbba --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql @@ -0,0 +1,20 @@ +{% macro process_bucket_column(col, partition_key, table, ns, col_index) %} + {# Extract bucket information from the partition key #} + {%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%} + + {%- if bucket_match -%} + {# For bucketed columns, compute bucket numbers and conditions #} + {%- set column_type = adapter.convert_type(table, col_index) -%} + {%- set ns.is_bucketed = true -%} + {%- set ns.bucket_column = bucket_match[1] -%} + {%- set bucket_num = adapter.murmur3_hash(col, bucket_match[2] | int) -%} + {%- set formatted_value, comp_func = adapter.format_value_for_partition(col, column_type) -%} + + {%- if bucket_num not in ns.bucket_numbers %} + {%- do ns.bucket_numbers.append(bucket_num) %} + {%- do ns.bucket_conditions.update({bucket_num: [formatted_value]}) -%} + {%- elif formatted_value not in ns.bucket_conditions[bucket_num] %} + {%- do ns.bucket_conditions[bucket_num].append(formatted_value) -%} + {%- endif -%} + {%- endif -%} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql new file mode 100644 index 00000000..952d6ecb --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql @@ -0,0 +1,58 @@ +{% macro alter_relation_add_columns(relation, 
add_columns = none, table_type = 'hive') -%} + {% if add_columns is none %} + {% set add_columns = [] %} + {% endif %} + + {% set sql -%} + alter {{ relation.type }} {{ relation.render_hive() }} + add columns ( + {%- for column in add_columns -%} + {{ column.name }} {{ ddl_data_type(column.data_type, table_type) }}{{ ', ' if not loop.last }} + {%- endfor -%} + ) + {%- endset -%} + + {% if (add_columns | length) > 0 %} + {{ return(run_query(sql)) }} + {% endif %} +{% endmacro %} + +{% macro alter_relation_drop_columns(relation, remove_columns = none) -%} + {% if remove_columns is none %} + {% set remove_columns = [] %} + {% endif %} + + {%- for column in remove_columns -%} + {% set sql -%} + alter {{ relation.type }} {{ relation.render_hive() }} drop column {{ column.name }} + {% endset %} + {% do run_query(sql) %} + {%- endfor -%} +{% endmacro %} + +{% macro alter_relation_replace_columns(relation, replace_columns = none, table_type = 'hive') -%} + {% if replace_columns is none %} + {% set replace_columns = [] %} + {% endif %} + + {% set sql -%} + alter {{ relation.type }} {{ relation.render_hive() }} + replace columns ( + {%- for column in replace_columns -%} + {{ column.name }} {{ ddl_data_type(column.data_type, table_type) }}{{ ', ' if not loop.last }} + {%- endfor -%} + ) + {%- endset -%} + + {% if (replace_columns | length) > 0 %} + {{ return(run_query(sql)) }} + {% endif %} +{% endmacro %} + +{% macro alter_relation_rename_column(relation, source_column, target_column, target_column_type, table_type = 'hive') -%} + {% set sql -%} + alter {{ relation.type }} {{ relation.render_pure() }} + change column {{ source_column }} {{ target_column }} {{ ddl_data_type(target_column_type, table_type) }} + {%- endset -%} + {{ return(run_query(sql)) }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/helpers.sql new file mode 100644 index 00000000..13011578 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/helpers.sql @@ -0,0 +1,122 @@ +{% macro validate_get_incremental_strategy(raw_strategy, table_type) %} + {%- if table_type == 'iceberg' -%} + {% set invalid_strategy_msg -%} + Invalid incremental strategy provided: {{ raw_strategy }} + Incremental models on Iceberg tables only work with 'append' or 'merge' (v3 only) strategy. 
+ {%- endset %} + {% if raw_strategy not in ['append', 'merge'] %} + {% do exceptions.raise_compiler_error(invalid_strategy_msg) %} + {% endif %} + {%- else -%} + {% set invalid_strategy_msg -%} + Invalid incremental strategy provided: {{ raw_strategy }} + Expected one of: 'append', 'insert_overwrite' + {%- endset %} + + {% if raw_strategy not in ['append', 'insert_overwrite'] %} + {% do exceptions.raise_compiler_error(invalid_strategy_msg) %} + {% endif %} + {% endif %} + + {% do return(raw_strategy) %} +{% endmacro %} + + +{% macro batch_incremental_insert(tmp_relation, target_relation, dest_cols_csv) %} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches|length) -%} + {%- set insert_batch_partitions -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + ); + {%- endset -%} + {%- do run_query(insert_batch_partitions) -%} + {%- endfor -%} +{% endmacro %} + + +{% macro incremental_insert( + on_schema_change, + tmp_relation, + target_relation, + existing_relation, + force_batch, + statement_name="main" + ) +%} + {%- set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} + {%- if not dest_columns -%} + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} + {%- endif -%} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + + {% if force_batch %} + {% do batch_incremental_insert(tmp_relation, target_relation, dest_cols_csv) %} + {% else %} + {%- set insert_full -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + ); + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(insert_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% do batch_incremental_insert(tmp_relation, target_relation, dest_cols_csv) %} + {%- endif -%} + {%- endif -%} + + SELECT '{{query_result}}' + +{%- endmacro %} + + +{% macro delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} + {%- set partitioned_keys = partitioned_by | tojson | replace('\"', '') | replace('[', '') | replace(']', '') -%} + {% call statement('get_partitions', fetch_result=True) %} + select distinct {{partitioned_keys}} from {{ tmp_relation }}; + {% endcall %} + {%- set table = load_result('get_partitions').table -%} + {%- set rows = table.rows -%} + {%- set partitions = [] -%} + {%- for row in rows -%} + {%- set single_partition = [] -%} + {%- for col in row -%} + {%- set column_type = adapter.convert_type(table, loop.index0) -%} + {%- if column_type == 'integer' or column_type is none -%} + {%- set value = col|string -%} + {%- elif column_type == 'string' -%} + {%- set value = "'" + col + "'" -%} + {%- elif column_type == 'date' -%} + {%- set value = "'" + col|string + "'" -%} + {%- elif column_type == 'timestamp' -%} + {%- set value = "'" + col|string + "'" -%} + {%- else -%} + {%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%} + {%- endif -%} + {%- do single_partition.append(partitioned_by[loop.index0] + '=' + value) -%} + {%- endfor -%} + {%- set single_partition_expression = single_partition | join(' and ') -%} + {%- do 
partitions.append('(' + single_partition_expression + ')') -%} + {%- endfor -%} + {%- for i in range(partitions | length) %} + {%- do adapter.clean_up_partitions(target_relation, partitions[i]) -%} + {%- endfor -%} +{%- endmacro %} + +{% macro remove_partitions_from_columns(columns_with_partitions, partition_keys) %} + {%- set columns = [] -%} + {%- for column in columns_with_partitions -%} + {%- if column.name not in partition_keys -%} + {%- do columns.append(column) -%} + {%- endif -%} + {%- endfor -%} + {{ return(columns) }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/incremental.sql new file mode 100644 index 00000000..c18ac681 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -0,0 +1,175 @@ +{% materialization incremental, adapter='athena', supported_languages=['sql', 'python'] -%} + {% set raw_strategy = config.get('incremental_strategy') or 'insert_overwrite' %} + {% set table_type = config.get('table_type', default='hive') | lower %} + {% set model_language = model['language'] %} + {% set strategy = validate_get_incremental_strategy(raw_strategy, table_type) %} + {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} + {% set versions_to_keep = config.get('versions_to_keep', 1) | as_number %} + {% set lf_tags_config = config.get('lf_tags_config') %} + {% set lf_grants = config.get('lf_grants') %} + {% set partitioned_by = config.get('partitioned_by') %} + {% set force_batch = config.get('force_batch', False) | as_bool -%} + {% set unique_tmp_table_suffix = config.get('unique_tmp_table_suffix', False) | as_bool -%} + {% set temp_schema = config.get('temp_schema') %} + {% set target_relation = this.incorporate(type='table') %} + {% set existing_relation = load_relation(this) %} + -- If using insert_overwrite on Hive table, allow to set a unique tmp table suffix + {% if unique_tmp_table_suffix == True and strategy == 'insert_overwrite' and table_type == 'hive' %} + {% set tmp_table_suffix = adapter.generate_unique_temporary_table_suffix() %} + {% else %} + {% set tmp_table_suffix = '__dbt_tmp' %} + {% endif %} + + {% if unique_tmp_table_suffix == True and table_type == 'iceberg' %} + {% set tmp_table_suffix = adapter.generate_unique_temporary_table_suffix() %} + {% endif %} + + {% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix, + schema=schema, + database=database) %} + {% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix, temp_schema=temp_schema) %} + + -- If no partitions are used with insert_overwrite, we fall back to append mode. 
+ {% if partitioned_by is none and strategy == 'insert_overwrite' %} + {% set strategy = 'append' %} + {% endif %} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + -- `BEGIN` happens here: + {{ run_hooks(pre_hooks, inside_transaction=True) }} + + {% set to_drop = [] %} + {% if existing_relation is none %} + {% set query_result = safe_create_table_as(False, target_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = "select '" ~ query_result ~ "'" -%} + {% elif existing_relation.is_view or should_full_refresh() %} + {% do drop_relation(existing_relation) %} + {% set query_result = safe_create_table_as(False, target_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = "select '" ~ query_result ~ "'" -%} + {% elif partitioned_by is not none and strategy == 'insert_overwrite' %} + {% if old_tmp_relation is not none %} + {% do drop_relation(old_tmp_relation) %} + {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} + {% set build_sql = incremental_insert( + on_schema_change, tmp_relation, target_relation, existing_relation, force_batch + ) + %} + {% do to_drop.append(tmp_relation) %} + {% elif strategy == 'append' %} + {% if old_tmp_relation is not none %} + {% do drop_relation(old_tmp_relation) %} + {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = incremental_insert( + on_schema_change, tmp_relation, target_relation, existing_relation, force_batch + ) + %} + {% do to_drop.append(tmp_relation) %} + {% elif strategy == 'merge' and table_type == 'iceberg' %} + {% set unique_key = config.get('unique_key') %} + {% set incremental_predicates = config.get('incremental_predicates') %} + {% set delete_condition = config.get('delete_condition') %} + {% set update_condition = config.get('update_condition') %} + {% set insert_condition = config.get('insert_condition') %} + {% set empty_unique_key -%} + Merge strategy must implement unique_key as a single column or a list of columns. + {%- endset %} + {% if unique_key is none %} + {% do exceptions.raise_compiler_error(empty_unique_key) %} + {% endif %} + {% if incremental_predicates is not none %} + {% set inc_predicates_not_list -%} + Merge strategy must implement incremental_predicates as a list of predicates. 
+ {%- endset %} + {% if not adapter.is_list(incremental_predicates) %} + {% do exceptions.raise_compiler_error(inc_predicates_not_list) %} + {% endif %} + {% endif %} + {% if old_tmp_relation is not none %} + {% do drop_relation(old_tmp_relation) %} + {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, model_language, force_batch) -%} + {%- if model_language == 'python' -%} + {% call statement('create_table', language=model_language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {% set build_sql = iceberg_merge( + on_schema_change=on_schema_change, + tmp_relation=tmp_relation, + target_relation=target_relation, + unique_key=unique_key, + incremental_predicates=incremental_predicates, + existing_relation=existing_relation, + delete_condition=delete_condition, + update_condition=update_condition, + insert_condition=insert_condition, + force_batch=force_batch, + ) + %} + {% do to_drop.append(tmp_relation) %} + {% endif %} + + {% call statement("main", language=model_language) %} + {% if model_language == 'sql' %} + {{ build_sql }} + {% else %} + {{ log(build_sql) }} + {% do athena__py_execute_query(query=build_sql) %} + {% endif %} + {% endcall %} + + -- set table properties + {% if not to_drop and table_type != 'iceberg' and model_language != 'python' %} + {{ set_table_classification(target_relation) }} + {% endif %} + + {{ run_hooks(post_hooks, inside_transaction=True) }} + + -- `COMMIT` happens here + {% do adapter.commit() %} + + {% for rel in to_drop %} + {% do drop_relation(rel) %} + {% endfor %} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {% do persist_docs(target_relation, model) %} + + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, False) %} + + {{ return({'relations': [target_relation]}) }} + +{%- endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/merge.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/merge.sql new file mode 100644 index 00000000..8741adb7 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/merge.sql @@ -0,0 +1,162 @@ +{% macro get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns) %} + + {%- if merge_update_columns and merge_exclude_columns -%} + {{ exceptions.raise_compiler_error( + 'Model cannot specify merge_update_columns and merge_exclude_columns. Please update model to use only one config' + )}} + {%- elif merge_update_columns -%} + {%- set update_columns = [] -%} + {%- for column in dest_columns -%} + {% if column.column | lower in merge_update_columns | map("lower") | list %} + {%- do update_columns.append(column) -%} + {% endif %} + {%- endfor -%} + {%- elif merge_exclude_columns -%} + {%- set update_columns = [] -%} + {%- for column in dest_columns -%} + {% if column.column | lower not in merge_exclude_columns | map("lower") | list %} + {%- do update_columns.append(column) -%} + {% endif %} + {%- endfor -%} + {%- else -%} + {%- set update_columns = dest_columns -%} + {%- endif -%} + + {{ return(update_columns) }} + +{% endmacro %} + +{%- macro get_update_statement(col, rule, is_last) -%} + {%- if rule == "coalesce" -%} + {{ col.quoted }} = {{ 'coalesce(src.' 
+ col.quoted + ', target.' + col.quoted + ')' }} + {%- elif rule == "sum" -%} + {%- if col.data_type.startswith("map") -%} + {{ col.quoted }} = {{ 'map_zip_with(coalesce(src.' + col.quoted + ', map()), coalesce(target.' + col.quoted + ', map()), (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0))' }} + {%- else -%} + {{ col.quoted }} = {{ 'src.' + col.quoted + ' + target.' + col.quoted }} + {%- endif -%} + {%- elif rule == "append" -%} + {{ col.quoted }} = {{ 'src.' + col.quoted + ' || target.' + col.quoted }} + {%- elif rule == "append_distinct" -%} + {{ col.quoted }} = {{ 'array_distinct(src.' + col.quoted + ' || target.' + col.quoted + ')' }} + {%- elif rule == "replace" -%} + {{ col.quoted }} = {{ 'src.' + col.quoted }} + {%- else -%} + {{ col.quoted }} = {{ rule | replace("_new_", 'src.' + col.quoted) | replace("_old_", 'target.' + col.quoted) }} + {%- endif -%} + {{ "," if not is_last }} +{%- endmacro -%} + + +{% macro batch_iceberg_merge(tmp_relation, target_relation, merge_part, dest_cols_csv) %} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + {%- set src_batch_part -%} + merge into {{ target_relation }} as target + using (select {{ dest_cols_csv }} from {{ tmp_relation }} where {{ batch }}) as src + {%- endset -%} + {%- set merge_batch -%} + {{ src_batch_part }} + {{ merge_part }} + {%- endset -%} + {%- do run_query(merge_batch) -%} + {%- endfor -%} +{%- endmacro -%} + + +{% macro iceberg_merge( + on_schema_change, + tmp_relation, + target_relation, + unique_key, + incremental_predicates, + existing_relation, + delete_condition, + update_condition, + insert_condition, + force_batch, + statement_name="main" + ) +%} + {%- set merge_update_columns = config.get('merge_update_columns') -%} + {%- set merge_exclude_columns = config.get('merge_exclude_columns') -%} + {%- set merge_update_columns_default_rule = config.get('merge_update_columns_default_rule', 'replace') -%} + {%- set merge_update_columns_rules = config.get('merge_update_columns_rules') -%} + + {% set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} + {% if not dest_columns %} + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} + {% endif %} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + {%- if unique_key is sequence and unique_key is not string -%} + {%- set unique_key_cols = unique_key -%} + {%- else -%} + {%- set unique_key_cols = [unique_key] -%} + {%- endif -%} + {%- set src_columns_quoted = [] -%} + {%- set dest_columns_wo_keys = [] -%} + {%- for col in dest_columns -%} + {%- do src_columns_quoted.append('src.' 
+ col.quoted ) -%} + {%- if col.name not in unique_key_cols -%} + {%- do dest_columns_wo_keys.append(col) -%} + {%- endif -%} + {%- endfor -%} + {%- set update_columns = get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns_wo_keys) -%} + {%- set src_cols_csv = src_columns_quoted | join(', ') -%} + + {%- set merge_part -%} + on ( + {%- for key in unique_key_cols -%} + target.{{ key }} = src.{{ key }} + {{ " and " if not loop.last }} + {%- endfor -%} + {% if incremental_predicates is not none -%} + and ( + {%- for inc_predicate in incremental_predicates %} + {{ inc_predicate }} {{ "and " if not loop.last }} + {%- endfor %} + ) + {%- endif %} + ) + {% if delete_condition is not none -%} + when matched and ({{ delete_condition }}) + then delete + {%- endif %} + {% if update_columns -%} + when matched {% if update_condition is not none -%} and {{ update_condition }} {%- endif %} + then update set + {%- for col in update_columns %} + {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %} + {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }} + {%- else -%} + {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }} + {%- endif -%} + {%- endfor %} + {%- endif %} + when not matched {% if insert_condition is not none -%} and {{ insert_condition }} {%- endif %} + then insert ({{ dest_cols_csv }}) + values ({{ src_cols_csv }}) + {%- endset -%} + + {%- if force_batch -%} + {% do batch_iceberg_merge(tmp_relation, target_relation, merge_part, dest_cols_csv) %} + {%- else -%} + {%- set src_part -%} + merge into {{ target_relation }} as target using {{ tmp_relation }} as src + {%- endset -%} + {%- set merge_full -%} + {{ src_part }} + {{ merge_part }} + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(merge_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% do batch_iceberg_merge(tmp_relation, target_relation, merge_part, dest_cols_csv) %} + {%- endif -%} + {%- endif -%} + + SELECT '{{query_result}}' +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql new file mode 100644 index 00000000..fc7fcbcb --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql @@ -0,0 +1,89 @@ +{% macro sync_column_schemas(on_schema_change, target_relation, schema_changes_dict) %} + {%- set partitioned_by = config.get('partitioned_by', default=none) -%} + {% set table_type = config.get('table_type', default='hive') | lower %} + {%- if partitioned_by is none -%} + {%- set partitioned_by = [] -%} + {%- endif %} + {%- set add_to_target_arr = schema_changes_dict['source_not_in_target'] -%} + {%- if on_schema_change == 'append_new_columns'-%} + {%- if add_to_target_arr | length > 0 -%} + {%- do alter_relation_add_columns(target_relation, add_to_target_arr, table_type) -%} + {%- endif -%} + {% elif on_schema_change == 'sync_all_columns' %} + {%- set remove_from_target_arr = schema_changes_dict['target_not_in_source'] -%} + {%- set new_target_types = schema_changes_dict['new_target_types'] -%} + {% if table_type == 'iceberg' %} + {# + If last run of alter_column_type was failed on rename tmp column to origin. + Do rename to protect origin column from deletion and losing data. 
+ #} + {% for remove_col in remove_from_target_arr if remove_col.column.endswith('__dbt_alter') %} + {%- set origin_col_name = remove_col.column | replace('__dbt_alter', '') -%} + {% for add_col in add_to_target_arr if add_col.column == origin_col_name %} + {%- do alter_relation_rename_column(target_relation, remove_col.name, add_col.name, add_col.data_type, table_type) -%} + {%- do remove_from_target_arr.remove(remove_col) -%} + {%- do add_to_target_arr.remove(add_col) -%} + {% endfor %} + {% endfor %} + + {% if add_to_target_arr | length > 0 %} + {%- do alter_relation_add_columns(target_relation, add_to_target_arr, table_type) -%} + {% endif %} + {% if remove_from_target_arr | length > 0 %} + {%- do alter_relation_drop_columns(target_relation, remove_from_target_arr) -%} + {% endif %} + {% if new_target_types != [] %} + {% for ntt in new_target_types %} + {% set column_name = ntt['column_name'] %} + {% set new_type = ntt['new_type'] %} + {% do alter_column_type(target_relation, column_name, new_type) %} + {% endfor %} + {% endif %} + {% else %} + {%- set replace_with_target_arr = remove_partitions_from_columns(schema_changes_dict['source_columns'], partitioned_by) -%} + {% if add_to_target_arr | length > 0 or remove_from_target_arr | length > 0 or new_target_types | length > 0 %} + {%- do alter_relation_replace_columns(target_relation, replace_with_target_arr, table_type) -%} + {% endif %} + {% endif %} + {% endif %} + {% set schema_change_message %} + In {{ target_relation }}: + Schema change approach: {{ on_schema_change }} + Columns added: {{ add_to_target_arr }} + Columns removed: {{ remove_from_target_arr }} + Data types changed: {{ new_target_types }} + {% endset %} + {% do log(schema_change_message) %} +{% endmacro %} + +{% macro athena__alter_column_type(relation, column_name, new_column_type) -%} + {# + 1. Create a new column (w/ temp name and correct type) + 2. Copy data over to it + 3. Drop the existing column + 4. 
Rename the new column to existing column + #} + {%- set table_type = config.get('table_type', 'hive') -%} + {%- set tmp_column = column_name + '__dbt_alter' -%} + {%- set new_ddl_data_type = ddl_data_type(new_column_type, table_type) -%} + + {#- do alter_relation_add_columns(relation, [ tmp_column ], table_type) -#} + {%- set add_column_query -%} + alter {{ relation.type }} {{ relation.render_pure() }} add columns({{ tmp_column }} {{ new_ddl_data_type }}); + {%- endset -%} + {%- do run_query(add_column_query) -%} + + {%- set update_query -%} + update {{ relation.render_pure() }} set {{ tmp_column }} = cast({{ column_name }} as {{ new_column_type }}); + {%- endset -%} + {%- do run_query(update_query) -%} + + {#- do alter_relation_drop_columns(relation, [ column_name ]) -#} + {%- set drop_column_query -%} + alter {{ relation.type }} {{ relation.render_pure() }} drop column {{ column_name }}; + {%- endset -%} + {%- do run_query(drop_column_query) -%} + + {%- do alter_relation_rename_column(relation, tmp_column, column_name, new_column_type, table_type) -%} + +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/create_table_as.sql new file mode 100644 index 00000000..e20c4020 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -0,0 +1,223 @@ +{% macro create_table_as(temporary, relation, compiled_code, language='sql', skip_partitioning=false) -%} + {{ adapter.dispatch('create_table_as', 'athena')(temporary, relation, compiled_code, language, skip_partitioning) }} +{%- endmacro %} + + +{% macro athena__create_table_as(temporary, relation, compiled_code, language='sql', skip_partitioning=false) -%} + {%- set materialized = config.get('materialized', default='table') -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- do log("Skip partitioning: " ~ skip_partitioning) -%} + {%- set partitioned_by = config.get('partitioned_by', default=none) if not skip_partitioning else none -%} + {%- set bucketed_by = config.get('bucketed_by', default=none) -%} + {%- set bucket_count = config.get('bucket_count', default=none) -%} + {%- set field_delimiter = config.get('field_delimiter', default=none) -%} + {%- set table_type = config.get('table_type', default='hive') | lower -%} + {%- set format = config.get('format', default='parquet') -%} + {%- set write_compression = config.get('write_compression', default=none) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} + {%- set s3_tmp_table_dir = config.get('s3_tmp_table_dir', default=target.s3_tmp_table_dir) -%} + {%- set extra_table_properties = config.get('table_properties', default=none) -%} + + {%- set location_property = 'external_location' -%} + {%- set partition_property = 'partitioned_by' -%} + {%- set work_group_output_location_enforced = adapter.is_work_group_output_location_enforced() -%} + {%- set location = adapter.generate_s3_location(relation, + s3_data_dir, + s3_data_naming, + s3_tmp_table_dir, + external_location, + temporary, + ) -%} + {%- set native_drop = config.get('native_drop', default=false) -%} + + {%- set contract_config = config.get('contract') -%} + {%- if contract_config.enforced -%} + {{ get_assert_columns_equivalent(compiled_code) }} + {%- endif -%} + + {%- if native_drop and 
table_type == 'iceberg' -%} + {% do log('Config native_drop enabled, skipping direct S3 delete') %} + {%- else -%} + {% do adapter.delete_from_s3(location) %} + {%- endif -%} + + {%- if language == 'python' -%} + {%- set spark_ctas = '' -%} + {%- if table_type == 'iceberg' -%} + {%- set spark_ctas -%} + create table {{ relation.schema | replace('\"', '`') }}.{{ relation.identifier | replace('\"', '`') }} + using iceberg + location '{{ location }}/' + + {%- if partitioned_by is not none %} + partitioned by ( + {%- for prop_value in partitioned_by -%} + {{ prop_value }} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ) + {%- endif %} + + {%- if extra_table_properties is not none %} + tblproperties( + {%- for prop_name, prop_value in extra_table_properties.items() -%} + '{{ prop_name }}'='{{ prop_value }}' + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ) + {% endif %} + + as + {%- endset -%} + {%- endif -%} + + {# {% do log('Creating table with spark and compiled code: ' ~ compiled_code) %} #} + {{ athena__py_save_table_as( + compiled_code, + relation, + optional_args={ + 'location': location, + 'format': format, + 'mode': 'overwrite', + 'partitioned_by': partitioned_by, + 'bucketed_by': bucketed_by, + 'write_compression': write_compression, + 'bucket_count': bucket_count, + 'field_delimiter': field_delimiter, + 'spark_ctas': spark_ctas + } + ) + }} + {%- else -%} + {%- if table_type == 'iceberg' -%} + {%- set location_property = 'location' -%} + {%- set partition_property = 'partitioning' -%} + {%- if bucketed_by is not none or bucket_count is not none -%} + {%- set ignored_bucket_iceberg -%} + bucketed_by or bucket_count cannot be used with Iceberg tables. Use the bucket function + when partitioning instead; these settings will be ignored + {%- endset -%} + {%- set bucketed_by = none -%} + {%- set bucket_count = none -%} + {% do log(ignored_bucket_iceberg) %} + {%- endif -%} + {%- if materialized == 'table' and ( 'unique' not in s3_data_naming or external_location is not none) -%} + {%- set error_unique_location_iceberg -%} + You need to have a unique table location when creating an Iceberg table, since we use the RENAME feature + to have near-zero downtime. 
{%- endset -%} + {% do exceptions.raise_compiler_error(error_unique_location_iceberg) %} + {%- endif -%} + {%- endif %} + + create table {{ relation }} + with ( + table_type='{{ table_type }}', + is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %}, + {%- if not work_group_output_location_enforced or table_type == 'iceberg' -%} + {{ location_property }}='{{ location }}', + {%- endif %} + {%- if partitioned_by is not none %} + {{ partition_property }}=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }}, + {%- endif %} + {%- if bucketed_by is not none %} + bucketed_by=ARRAY{{ bucketed_by | tojson | replace('\"', '\'') }}, + {%- endif %} + {%- if bucket_count is not none %} + bucket_count={{ bucket_count }}, + {%- endif %} + {%- if field_delimiter is not none %} + field_delimiter='{{ field_delimiter }}', + {%- endif %} + {%- if write_compression is not none %} + write_compression='{{ write_compression }}', + {%- endif %} + format='{{ format }}' + {%- if extra_table_properties is not none -%} + {%- for prop_name, prop_value in extra_table_properties.items() -%} + , + {{ prop_name }}={{ prop_value }} + {%- endfor -%} + {% endif %} + ) + as + {{ compiled_code }} + {%- endif -%} +{%- endmacro -%} + +{% macro create_table_as_with_partitions(temporary, relation, compiled_code, language='sql') -%} + + {%- set tmp_relation = api.Relation.create( + identifier=relation.identifier ~ '__tmp_not_partitioned', + schema=relation.schema, + database=relation.database, + s3_path_table_part=relation.identifier ~ '__tmp_not_partitioned', + type='table' + ) + -%} + + {%- if tmp_relation is not none -%} + {%- do drop_relation(tmp_relation) -%} + {%- endif -%} + + {%- do log('CREATE NON-PARTITIONED STAGING TABLE: ' ~ tmp_relation) -%} + {%- do run_query(create_table_as(temporary, tmp_relation, compiled_code, language, true)) -%} + + {% set partitions_batches = get_partition_batches(sql=tmp_relation, as_subquery=False) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + + {%- set dest_columns = adapter.get_columns_in_relation(tmp_relation) -%} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + + {%- if loop.index == 1 -%} + {%- set create_target_relation_sql -%} + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + {%- endset -%} + {%- do run_query(create_table_as(temporary, relation, create_target_relation_sql, language)) -%} + {%- else -%} + {%- set insert_batch_partitions_sql -%} + insert into {{ relation }} ({{ dest_cols_csv }}) + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + {%- endset -%} + + {%- do run_query(insert_batch_partitions_sql) -%} + {%- endif -%} + + + {%- endfor -%} + + {%- do drop_relation(tmp_relation) -%} + + select 'SUCCESSFULLY CREATED TABLE {{ relation }}' + +{%- endmacro %} + +{% macro safe_create_table_as(temporary, relation, compiled_code, language='sql', force_batch=False) -%} + {%- if language != 'sql' -%} + {{ return(create_table_as(temporary, relation, compiled_code, language)) }} + {%- elif force_batch -%} + {%- do create_table_as_with_partitions(temporary, relation, compiled_code, language) -%} + {%- set query_result = relation ~ ' with many partitions created' -%} + {%- else -%} + {%- if temporary -%} + {%- do run_query(create_table_as(temporary, relation, compiled_code, language, true)) -%} + {%- set 
compiled_code_result = relation ~ ' as temporary relation without partitioning created' -%} + {%- else -%} + {%- set compiled_code_result = adapter.run_query_with_partitions_limit_catching(create_table_as(temporary, relation, compiled_code)) -%} + {%- do log('COMPILED CODE RESULT: ' ~ compiled_code_result) -%} + {%- if compiled_code_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {%- do create_table_as_with_partitions(temporary, relation, compiled_code, language) -%} + {%- set compiled_code_result = relation ~ ' with many partitions created' -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {{ return(compiled_code_result) }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/table.sql new file mode 100644 index 00000000..1f94361c --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/table/table.sql @@ -0,0 +1,183 @@ +-- TODO create a drop_relation_with_versions, to be sure to remove all historical versions of a table +{% materialization table, adapter='athena', supported_languages=['sql', 'python'] -%} + {%- set identifier = model['alias'] -%} + {%- set language = model['language'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {%- set table_type = config.get('table_type', default='hive') | lower -%} + {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} + {%- set old_tmp_relation = adapter.get_relation(identifier=identifier ~ '__ha', + schema=schema, + database=database) -%} + {%- set old_bkp_relation = adapter.get_relation(identifier=identifier ~ '__bkp', + schema=schema, + database=database) -%} + {%- set is_ha = config.get('ha', default=false) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', default='table_unique') -%} + {%- set full_refresh_config = config.get('full_refresh', default=False) -%} + {%- set is_full_refresh_mode = (flags.FULL_REFRESH == True or full_refresh_config == True) -%} + {%- set versions_to_keep = config.get('versions_to_keep', default=4) -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- set force_batch = config.get('force_batch', False) | as_bool -%} + {%- set target_relation = api.Relation.create(identifier=identifier, + schema=schema, + database=database, + type='table') -%} + {%- set tmp_relation = api.Relation.create(identifier=target_relation.identifier ~ '__ha', + schema=schema, + database=database, + s3_path_table_part=target_relation.identifier, + type='table') -%} + + {%- if ( + table_type == 'hive' + and is_ha + and ('unique' not in s3_data_naming or external_location is not none) + ) -%} + {%- set error_unique_location_hive_ha -%} + You need to have a unique table location when using the ha config with a hive table. + Use s3_data_naming unique, table_unique, or schema_table_unique, and avoid setting an explicit + external_location. 
+ {%- endset -%} + {% do exceptions.raise_compiler_error(error_unique_location_hive_ha) %} + {%- endif -%} + + {{ run_hooks(pre_hooks) }} + + {%- if table_type == 'hive' -%} + + -- for ha tables that are not in full refresh mode and when the relation exists we use the swap behavior + {%- if is_ha and not is_full_refresh_mode and old_relation is not none -%} + -- drop the old_tmp_relation if it exists + {%- if old_tmp_relation is not none -%} + {%- do adapter.delete_from_glue_catalog(old_tmp_relation) -%} + {%- endif -%} + + -- create tmp table + {%- set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + -- swap table + {%- set swap_table = adapter.swap_table(tmp_relation, target_relation) -%} + + -- delete glue tmp table, do not use drop_relation, as it will remove data of the target table + {%- do adapter.delete_from_glue_catalog(tmp_relation) -%} + + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, True) %} + + {%- else -%} + -- Here we are in the case of non-ha tables or ha tables but in case of full refresh. + {%- if old_relation is not none -%} + {{ drop_relation(old_relation) }} + {%- endif -%} + {%- set query_result = safe_create_table_as(False, target_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {%- endif -%} + + {%- if language != 'python' -%} + {{ set_table_classification(target_relation) }} + {%- endif -%} + {%- else -%} + + {%- if old_relation is none -%} + {%- set query_result = safe_create_table_as(False, target_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {%- else -%} + {%- if old_relation.is_view -%} + {%- set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language, force_batch) -%} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + {%- do drop_relation(old_relation) -%} + {%- do rename_relation(tmp_relation, target_relation) -%} + {%- else -%} + -- delete old tmp iceberg table if it exists + {%- if old_tmp_relation is not none -%} + {%- do drop_relation(old_tmp_relation) -%} + {%- endif -%} + + -- If we have this, it means that at least the first renaming occurred but there was an issue + -- afterwards, therefore we are in weird state. The easiest and cleanest should be to remove + -- the backup relation. It won't have an impact because since we are in the else condition, + -- that means that old relation exists therefore no downtime yet. 
{%- if old_bkp_relation is not none -%} + {%- do drop_relation(old_bkp_relation) -%} + {%- endif -%} + + {% set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language, force_batch) %} + -- Execute python code that is available in query result object + {%- if language == 'python' -%} + {% call statement('create_table', language=language) %} + {{ query_result }} + {% endcall %} + {%- endif -%} + + {%- set old_relation_table_type = adapter.get_glue_table_type(old_relation).value if old_relation else none -%} + + -- we cannot use old_bkp_relation, because it returns None if the relation doesn't exist + -- we need to create a python object via the make_temp_relation instead + {%- set old_relation_bkp = make_temp_relation(old_relation, '__bkp') -%} + + {%- if old_relation_table_type == 'iceberg_table' -%} + {{ rename_relation(old_relation, old_relation_bkp) }} + {%- else -%} + {%- do drop_relation_glue(old_relation) -%} + {%- endif -%} + + -- publish the target table by doing a final rename + {{ rename_relation(tmp_relation, target_relation) }} + + -- if the old relation is an iceberg_table, we have a backup + -- therefore we can drop the old relation backup; in all other cases there is nothing to do + -- in case of a switch from hive to iceberg, the backup table does not exist + -- in case of a first run, the backup table does not exist + {%- if old_relation_table_type == 'iceberg_table' -%} + {%- do drop_relation(old_relation_bkp) -%} + {%- endif -%} + + {%- endif -%} + {%- endif -%} + + {%- endif -%} + + {% call statement("main", language=language) %} + {%- if language=='sql' -%} + SELECT '{{ query_result }}'; + {%- endif -%} + {% endcall %} + + {{ run_hooks(post_hooks) }} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {% do persist_docs(target_relation, model) %} + + {{ return({'relations': [target_relation]}) }} + +{% endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql new file mode 100644 index 00000000..655354ac --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql @@ -0,0 +1,54 @@ +{% macro create_or_replace_view(run_outside_transaction_hooks=True) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_inherited_tags = config.get('lf_inherited_tags') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} + {%- set exists_as_view = (old_relation is not none and old_relation.is_view) -%} + {%- set target_relation = api.Relation.create( + identifier=identifier, + schema=schema, + database=database, + type='view', + ) -%} + + {% if run_outside_transaction_hooks %} + -- no transactions on BigQuery + {{ run_hooks(pre_hooks, inside_transaction=False) }} + {% endif %} + -- `BEGIN` happens here on Snowflake + {{ run_hooks(pre_hooks, inside_transaction=True) }} + -- If there's a table with the same name and we weren't told to full refresh, + -- that's an error. If we were told to full refresh, drop it. This behavior differs + -- for Snowflake and BigQuery, so multiple dispatch is used. 
+ {%- if old_relation is not none and old_relation.is_table -%} + {{ handle_existing_table(should_full_refresh(), old_relation) }} + {%- endif -%} + + -- build model + {% call statement('main') -%} + {{ create_view_as(target_relation, compiled_code) }} + {%- endcall %} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {{ run_hooks(post_hooks, inside_transaction=True) }} + + {{ adapter.commit() }} + + {% if run_outside_transaction_hooks %} + -- No transactions on BigQuery + {{ run_hooks(post_hooks, inside_transaction=False) }} + {% endif %} + + {{ return({'relations': [target_relation]}) }} + +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_view_as.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_view_as.sql new file mode 100644 index 00000000..019676db --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/create_view_as.sql @@ -0,0 +1,10 @@ +{% macro athena__create_view_as(relation, sql) -%} + {%- set contract_config = config.get('contract') -%} + {%- if contract_config.enforced -%} + {{ get_assert_columns_equivalent(sql) }} + {%- endif -%} + create or replace view + {{ relation }} + as + {{ sql }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/view.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/view.sql new file mode 100644 index 00000000..449e0833 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/models/view/view.sql @@ -0,0 +1,17 @@ +{% materialization view, adapter='athena' -%} + {%- set identifier = model['alias'] -%} + {%- set versions_to_keep = config.get('versions_to_keep', default=4) -%} + {%- set target_relation = api.Relation.create(identifier=identifier, + schema=schema, + database=database, + type='view') -%} + + {% set to_return = create_or_replace_view(run_outside_transaction_hooks=False) %} + + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, False) %} + + {% set target_relation = this.incorporate(type='view') %} + {% do persist_docs(target_relation, model) %} + + {% do return(to_return) %} +{%- endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/seeds/helpers.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/seeds/helpers.sql new file mode 100644 index 00000000..f44c7c9c --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/seeds/helpers.sql @@ -0,0 +1,215 @@ +{% macro default__reset_csv_table(model, full_refresh, old_relation, agate_table) %} + {% set sql = "" %} + -- No truncate in Athena so always drop CSV table and recreate + {{ drop_relation(old_relation) }} + {% set sql = create_csv_table(model, agate_table) %} + + {{ return(sql) }} +{% endmacro %} + +{% macro try_cast_timestamp(col) %} + {% set date_formats = [ + '%Y-%m-%d %H:%i:%s', + '%Y/%m/%d %H:%i:%s', + '%d %M %Y %H:%i:%s', + '%d/%m/%Y %H:%i:%s', + '%d-%m-%Y %H:%i:%s', + '%Y-%m-%d %H:%i:%s.%f', + '%Y/%m/%d %H:%i:%s.%f', + '%d %M %Y %H:%i:%s.%f', + '%d/%m/%Y %H:%i:%s.%f', + '%Y-%m-%dT%H:%i:%s.%fZ', + '%Y-%m-%dT%H:%i:%sZ', + '%Y-%m-%dT%H:%i:%s', + ]%} + + coalesce( + {% for date_format in date_formats %} + try(date_parse({{ col }}, '{{ date_format }}')) + {%- if not loop.last -%}, {% endif -%} + {% 
endfor %} + ) as {{ col }} +{% endmacro %} + +{% macro create_csv_table_insert(model, agate_table) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + {%- set column_override = config.get('column_types', {}) -%} + {%- set quote_seed_column = config.get('quote_columns') -%} + {%- set s3_data_dir = config.get('s3_data_dir', target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', target.s3_data_naming) -%} + {%- set s3_tmp_table_dir = config.get('s3_tmp_table_dir', default=target.s3_tmp_table_dir) -%} + {%- set external_location = config.get('external_location') -%} + + {%- set relation = api.Relation.create( + identifier=identifier, + schema=model.schema, + database=model.database, + type='table' + ) -%} + + {%- set location = adapter.generate_s3_location(relation, + s3_data_dir, + s3_data_naming, + s3_tmp_table_dir, + external_location, + temporary) -%} + + {% set sql_table %} + create external table {{ relation.render_hive() }} ( + {%- for col_name in agate_table.column_names -%} + {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} + {%- set type = column_override.get(col_name, inferred_type) -%} + {%- set type = type if type != "string" else "varchar" -%} + {%- set column_name = (col_name | string) -%} + {{ adapter.quote_seed_column(column_name, quote_seed_column, "`") }} {{ ddl_data_type(type) }} {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + ) + location '{{ location }}' + + {% endset %} + + {% call statement('_') -%} + {{ sql_table }} + {%- endcall %} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(relation, lf_grants) }} + {% endif %} + + {{ return(sql) }} +{% endmacro %} + + +{% macro create_csv_table_upload(model, agate_table) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {%- set column_override = config.get('column_types', {}) -%} + {%- set quote_seed_column = config.get('quote_columns', None) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', target.s3_data_naming) -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- set seed_s3_upload_args = config.get('seed_s3_upload_args', default=target.seed_s3_upload_args) -%} + + {%- set tmp_relation = api.Relation.create( + identifier=identifier + "__dbt_tmp", + schema=model.schema, + database=model.database, + type='table' + ) -%} + + {%- set tmp_s3_location = adapter.upload_seed_to_s3( + tmp_relation, + agate_table, + s3_data_dir, + s3_data_naming, + external_location, + seed_s3_upload_args=seed_s3_upload_args + ) -%} + + -- create target relation + {%- set relation = api.Relation.create( + identifier=identifier, + schema=model.schema, + database=model.database, + type='table' + ) -%} + + -- drop tmp relation if exists + {{ drop_relation(tmp_relation) }} + + {% set sql_tmp_table %} + create external table {{ tmp_relation.render_hive() }} ( + {%- for col_name in agate_table.column_names -%} + {%- set column_name = (col_name | string) -%} + {{ adapter.quote_seed_column(column_name, quote_seed_column, "`") }} string {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + ) + row format serde 
'org.apache.hadoop.hive.serde2.OpenCSVSerde' + location '{{ tmp_s3_location }}' + tblproperties ( + 'skip.header.line.count'='1' + ) + {% endset %} + + -- casting to type string is not allowed needs to be varchar + {% set sql %} + select + {% for col_name in agate_table.column_names -%} + {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} + {%- set type = column_override.get(col_name, inferred_type) -%} + {%- set type = type if type != "string" else "varchar" -%} + {%- set column_name = (col_name | string) -%} + {%- set quoted_column_name = adapter.quote_seed_column(column_name, quote_seed_column) -%} + {% if type == 'timestamp' %} + {{ try_cast_timestamp(quoted_column_name) }} + {% else %} + cast(nullif({{quoted_column_name}}, '') as {{ type }}) as {{quoted_column_name}} + {% endif %} + {%- if not loop.last -%}, {% endif -%} + {%- endfor %} + from + {{ tmp_relation }} + {% endset %} + + -- create tmp table + {% call statement('_') -%} + {{ sql_tmp_table }} + {%- endcall -%} + + -- create target table from tmp table + {% set sql_table = create_table_as(false, relation, sql) %} + {% call statement('_') -%} + {{ sql_table }} + {%- endcall %} + + -- drop tmp table + {{ drop_relation(tmp_relation) }} + + -- delete csv file from s3 + {% do adapter.delete_from_s3(tmp_s3_location) %} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(relation, lf_grants) }} + {% endif %} + + {{ return(sql_table) }} +{% endmacro %} + +{% macro athena__create_csv_table(model, agate_table) %} + + {%- set seed_by_insert = config.get('seed_by_insert', False) | as_bool -%} + + {%- if seed_by_insert -%} + {% do log('seed by insert...') %} + {%- set sql_table = create_csv_table_insert(model, agate_table) -%} + {%- else -%} + {% do log('seed by upload...') %} + {%- set sql_table = create_csv_table_upload(model, agate_table) -%} + {%- endif -%} + + {{ return(sql_table) }} +{% endmacro %} + +{# Overwrite to satisfy dbt-core logic #} +{% macro athena__load_csv_rows(model, agate_table) %} + {%- set seed_by_insert = config.get('seed_by_insert', False) | as_bool -%} + {%- if seed_by_insert %} + {{ default__load_csv_rows(model, agate_table) }} + {%- else -%} + select 1 + {% endif %} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/materializations/snapshots/snapshot.sql b/dbt-athena/src/dbt/include/athena/macros/materializations/snapshots/snapshot.sql new file mode 100644 index 00000000..23dd2608 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/materializations/snapshots/snapshot.sql @@ -0,0 +1,244 @@ +{# + Hash function to generate dbt_scd_id. dbt by default uses md5(coalesce(cast(field as varchar))) + but it does not work with athena. It throws an error Unexpected parameters (varchar) for function + md5. 
Expected: md5(varbinary) +#} +{% macro athena__snapshot_hash_arguments(args) -%} + to_hex(md5(to_utf8({%- for arg in args -%} + coalesce(cast({{ arg }} as varchar ), '') + {% if not loop.last %} || '|' || {% endif %} + {%- endfor -%}))) +{%- endmacro %} + + +{# + If hive table then Recreate the snapshot table from the new_snapshot_table + If iceberg table then Update the standard snapshot merge to include the DBT_INTERNAL_SOURCE prefix in the src_cols_csv +#} +{% macro hive_snapshot_merge_sql(target, source, insert_cols, table_type) -%} + {%- set target_relation = adapter.get_relation(database=target.database, schema=target.schema, identifier=target.identifier) -%} + {%- if target_relation is not none -%} + {% do adapter.drop_relation(target_relation) %} + {%- endif -%} + + {% set sql -%} + select * from {{ source }}; + {%- endset -%} + + {{ create_table_as(False, target_relation, sql) }} +{% endmacro %} + +{% macro iceberg_snapshot_merge_sql(target, source, insert_cols) %} + {%- set insert_cols_csv = insert_cols | join(', ') -%} + {%- set src_columns = [] -%} + {%- for col in insert_cols -%} + {%- do src_columns.append('dbt_internal_source.' + col) -%} + {%- endfor -%} + {%- set src_cols_csv = src_columns | join(', ') -%} + + merge into {{ target }} as dbt_internal_dest + using {{ source }} as dbt_internal_source + on dbt_internal_source.dbt_scd_id = dbt_internal_dest.dbt_scd_id + + when matched + and dbt_internal_dest.dbt_valid_to is null + and dbt_internal_source.dbt_change_type in ('update', 'delete') + then update + set dbt_valid_to = dbt_internal_source.dbt_valid_to + + when not matched + and dbt_internal_source.dbt_change_type = 'insert' + then insert ({{ insert_cols_csv }}) + values ({{ src_cols_csv }}) +{% endmacro %} + + +{# + Create a new temporary table that will hold the new snapshot results. + This table will then be used to overwrite the target snapshot table. +#} +{% macro hive_create_new_snapshot_table(target, source, insert_cols) %} + {%- set temp_relation = make_temp_relation(target, '__dbt_tmp_snapshot') -%} + {%- set preexisting_tmp_relation = load_cached_relation(temp_relation) -%} + {%- if preexisting_tmp_relation is not none -%} + {%- do adapter.drop_relation(preexisting_tmp_relation) -%} + {%- endif -%} + + {# TODO: Add insert_cols #} + {%- set src_columns = [] -%} + {%- set dst_columns = [] -%} + {%- set updated_columns = [] -%} + {%- for col in insert_cols -%} + {%- do src_columns.append('dbt_internal_source.' + col) -%} + {%- do dst_columns.append('dbt_internal_dest.' + col) -%} + {%- if col.replace('"', '') in ['dbt_valid_to'] -%} + {%- do updated_columns.append('dbt_internal_source.' + col) -%} + {%- else -%} + {%- do updated_columns.append('dbt_internal_dest.' 
+ col) -%} + {%- endif -%} + {%- endfor -%} + {%- set src_cols_csv = src_columns | join(', ') -%} + {%- set dst_cols_csv = dst_columns | join(', ') -%} + {%- set updated_cols_csv = updated_columns | join(', ') -%} + + {%- set source_columns = adapter.get_columns_in_relation(source) -%} + + {% set sql -%} + -- Unchanged rows + select {{ dst_cols_csv }} + from {{ target }} as dbt_internal_dest + left join {{ source }} as dbt_internal_source + on dbt_internal_source.dbt_scd_id = dbt_internal_dest.dbt_scd_id + where dbt_internal_source.dbt_scd_id is null + + union all + + -- Updated or deleted rows + select {{ updated_cols_csv }} + from {{ target }} as dbt_internal_dest + inner join {{ source }} as dbt_internal_source + on dbt_internal_source.dbt_scd_id = dbt_internal_dest.dbt_scd_id + where dbt_internal_dest.dbt_valid_to is null + and dbt_internal_source.dbt_change_type in ('update', 'delete') + + union all + + -- New rows + select {{ src_cols_csv }} + from {{ source }} as dbt_internal_source + left join {{ target }} as dbt_internal_dest + on dbt_internal_dest.dbt_scd_id = dbt_internal_source.dbt_scd_id + where dbt_internal_dest.dbt_scd_id is null + + {%- endset -%} + + {% call statement('create_new_snapshot_table') %} + {{ create_table_as(False, temp_relation, sql) }} + {% endcall %} + + {% do return(temp_relation) %} +{% endmacro %} + +{% materialization snapshot, adapter='athena' %} + {%- set config = model['config'] -%} + + {%- set target_table = model.get('alias', model.get('name')) -%} + {%- set strategy_name = config.get('strategy') -%} + {%- set file_format = config.get('file_format', 'parquet') -%} + {%- set table_type = config.get('table_type', 'hive') -%} + + {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_grants = config.get('lf_grants') -%} + + {{ log('Checking if target table exists') }} + {% set target_relation_exists, target_relation = get_or_create_relation( + database=model.database, + schema=model.schema, + identifier=target_table, + type='table') -%} + + {%- if not target_relation.is_table -%} + {% do exceptions.relation_wrong_type(target_relation, 'table') %} + {%- endif -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {{ run_hooks(pre_hooks, inside_transaction=True) }} + + {% set strategy_macro = strategy_dispatch(strategy_name) %} + {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %} + + {% if not target_relation_exists %} + + {% set build_sql = build_snapshot_table(strategy, model['compiled_sql']) %} + {% set final_sql = create_table_as(False, target_relation, build_sql) %} + + {% else %} + + {{ adapter.valid_snapshot_target(target_relation) }} + + {% set staging_relation = make_temp_relation(target_relation) %} + {%- set preexisting_staging_relation = load_cached_relation(staging_relation) -%} + {%- if preexisting_staging_relation is not none -%} + {%- do adapter.drop_relation(preexisting_staging_relation) -%} + {%- endif -%} + + {% set staging_table = build_snapshot_staging_table(strategy, sql, target_relation) %} + + {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) + | rejectattr('name', 'equalto', 'dbt_change_type') + | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') + | rejectattr('name', 'equalto', 'dbt_unique_key') + | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | list %} + + + {% if missing_columns %} + {% do alter_relation_add_columns(target_relation, missing_columns, table_type) %} + {% endif %} + + + {% set 
source_columns = adapter.get_columns_in_relation(staging_table) + | rejectattr('name', 'equalto', 'dbt_change_type') + | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') + | rejectattr('name', 'equalto', 'dbt_unique_key') + | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | list %} + + {% set quoted_source_columns = [] %} + {% for column in source_columns %} + {% do quoted_source_columns.append(adapter.quote(column.name)) %} + {% endfor %} + + {% if table_type == 'iceberg' %} + {% set final_sql = iceberg_snapshot_merge_sql( + target = target_relation, + source = staging_table, + insert_cols = quoted_source_columns, + ) + %} + {% else %} + {% set new_snapshot_table = hive_create_new_snapshot_table( + target = target_relation, + source = staging_table, + insert_cols = quoted_source_columns, + ) + %} + {% set final_sql = hive_snapshot_merge_sql( + target = target_relation, + source = new_snapshot_table + ) + %} + {% endif %} + + {% endif %} + + {% call statement('main') %} + {{ final_sql }} + {% endcall %} + + {{ run_hooks(post_hooks, inside_transaction=True) }} + + {% if staging_table is defined %} + {% do adapter.drop_relation(staging_table) %} + {% endif %} + + {% if new_snapshot_table is defined %} + {% do adapter.drop_relation(new_snapshot_table) %} + {% endif %} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {% if lf_tags_config is not none %} + {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {% endif %} + + {% if lf_grants is not none %} + {{ adapter.apply_lf_grants(target_relation, lf_grants) }} + {% endif %} + + {% do persist_docs(target_relation, model) %} + + {{ return({'relations': [target_relation]}) }} + +{% endmaterialization %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/any_value.sql b/dbt-athena/src/dbt/include/athena/macros/utils/any_value.sql new file mode 100644 index 00000000..2a8bf0b5 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/any_value.sql @@ -0,0 +1,3 @@ +{% macro athena__any_value(expression) -%} + arbitrary({{ expression }}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/array_append.sql b/dbt-athena/src/dbt/include/athena/macros/utils/array_append.sql new file mode 100644 index 00000000..e9868945 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/array_append.sql @@ -0,0 +1,3 @@ +{% macro athena__array_append(array, new_element) -%} + {{ array }} || {{ new_element }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/array_concat.sql b/dbt-athena/src/dbt/include/athena/macros/utils/array_concat.sql new file mode 100644 index 00000000..0300db98 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/array_concat.sql @@ -0,0 +1,3 @@ +{% macro athena__array_concat(array_1, array_2) -%} + {{ array_1 }} || {{ array_2 }} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/array_construct.sql b/dbt-athena/src/dbt/include/athena/macros/utils/array_construct.sql new file mode 100644 index 00000000..e52c8f3d --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/array_construct.sql @@ -0,0 +1,7 @@ +{% macro athena__array_construct(inputs, data_type) -%} + {% if inputs|length > 0 %} + array[ {{ inputs|join(' , ') }} ] + {% else %} + {{ safe_cast('array[]', 'array(' ~ data_type ~ ')') }} + {% endif %} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/bool_or.sql b/dbt-athena/src/dbt/include/athena/macros/utils/bool_or.sql new file mode 100644 index 
00000000..0b70d2f8 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/bool_or.sql @@ -0,0 +1,3 @@ +{% macro athena__bool_or(expression) -%} + bool_or({{ expression }}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/datatypes.sql b/dbt-athena/src/dbt/include/athena/macros/utils/datatypes.sql new file mode 100644 index 00000000..4b80643d --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/datatypes.sql @@ -0,0 +1,18 @@ +{# Implements Athena-specific datatypes where they differ from the dbt-core defaults #} +{# See https://docs.aws.amazon.com/athena/latest/ug/data-types.html #} + +{%- macro athena__type_float() -%} + DOUBLE +{%- endmacro -%} + +{%- macro athena__type_numeric() -%} + DECIMAL(38,6) +{%- endmacro -%} + +{%- macro athena__type_int() -%} + INTEGER +{%- endmacro -%} + +{%- macro athena__type_string() -%} + VARCHAR +{%- endmacro -%} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/date_trunc.sql b/dbt-athena/src/dbt/include/athena/macros/utils/date_trunc.sql new file mode 100644 index 00000000..aa3400b4 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/date_trunc.sql @@ -0,0 +1,3 @@ +{% macro athena__date_trunc(datepart, date) -%} + date_trunc('{{datepart}}', {{date}}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/dateadd.sql b/dbt-athena/src/dbt/include/athena/macros/utils/dateadd.sql new file mode 100644 index 00000000..7d06bea7 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/dateadd.sql @@ -0,0 +1,3 @@ +{% macro athena__dateadd(datepart, interval, from_date_or_timestamp) -%} + date_add('{{ datepart }}', {{ interval }}, {{ from_date_or_timestamp }}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/datediff.sql b/dbt-athena/src/dbt/include/athena/macros/utils/datediff.sql new file mode 100644 index 00000000..4c2dfeeb --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/datediff.sql @@ -0,0 +1,28 @@ +{% macro athena__datediff(first_date, second_date, datepart) -%} + {%- if datepart == 'year' -%} + (year(CAST({{ second_date }} AS TIMESTAMP)) - year(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'quarter' -%} + ({{ datediff(first_date, second_date, 'year') }} * 4) + quarter(CAST({{ second_date }} AS TIMESTAMP)) - quarter(CAST({{ first_date }} AS TIMESTAMP)) + {%- elif datepart == 'month' -%} + ({{ datediff(first_date, second_date, 'year') }} * 12) + month(CAST({{ second_date }} AS TIMESTAMP)) - month(CAST({{ first_date }} AS TIMESTAMP)) + {%- elif datepart == 'day' -%} + ((to_milliseconds((CAST(CAST({{ second_date }} AS TIMESTAMP) AS DATE) - CAST(CAST({{ first_date }} AS TIMESTAMP) AS DATE)))) / 86400000) + {%- elif datepart == 'week' -%} + ({{ datediff(first_date, second_date, 'day') }} / 7 + case + when dow(CAST({{first_date}} AS TIMESTAMP)) <= dow(CAST({{second_date}} AS TIMESTAMP)) then + case when {{first_date}} <= {{second_date}} then 0 else -1 end + else + case when {{first_date}} <= {{second_date}} then 1 else 0 end + end) + {%- elif datepart == 'hour' -%} + ({{ datediff(first_date, second_date, 'day') }} * 24 + hour(CAST({{ second_date }} AS TIMESTAMP)) - hour(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'minute' -%} + ({{ datediff(first_date, second_date, 'hour') }} * 60 + minute(CAST({{ second_date }} AS TIMESTAMP)) - minute(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'second' -%} + ({{ datediff(first_date, second_date, 'minute') }} * 60 + 
second(CAST({{ second_date }} AS TIMESTAMP)) - second(CAST({{ first_date }} AS TIMESTAMP))) + {%- elif datepart == 'millisecond' -%} + (to_milliseconds((CAST({{ second_date }} AS TIMESTAMP) - CAST({{ first_date }} AS TIMESTAMP)))) + {%- else -%} + {% if execute %}{{ exceptions.raise_compiler_error("Unsupported datepart for macro datediff in Athena: {!r}".format(datepart)) }}{% endif %} + {%- endif -%} +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/ddl_dml_data_type.sql b/dbt-athena/src/dbt/include/athena/macros/utils/ddl_dml_data_type.sql new file mode 100644 index 00000000..f93ed907 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/ddl_dml_data_type.sql @@ -0,0 +1,41 @@ +{# Athena has different types between DML and DDL #} +{# ref: https://docs.aws.amazon.com/athena/latest/ug/data-types.html #} +{% macro ddl_data_type(col_type, table_type = 'hive') -%} + -- transform varchar + {% set re = modules.re %} + {% set data_type = re.sub('(?:varchar|character varying)(?:\(\d+\))?', 'string', col_type) %} + + -- transform array and map + {%- if 'array' in data_type or 'map' in data_type -%} + {% set data_type = data_type.replace('(', '<').replace(')', '>') -%} + {%- endif -%} + + -- transform int + {%- if 'integer' in data_type -%} + {% set data_type = data_type.replace('integer', 'int') -%} + {%- endif -%} + + -- transform timestamp + {%- if table_type == 'iceberg' -%} + {%- if 'timestamp' in data_type -%} + {% set data_type = 'timestamp' -%} + {%- endif -%} + + {%- if 'binary' in data_type -%} + {% set data_type = 'binary' -%} + {%- endif -%} + {%- endif -%} + + {{ return(data_type) }} +{% endmacro %} + +{% macro dml_data_type(col_type) -%} + {%- set re = modules.re -%} + -- transform int to integer + {%- set data_type = re.sub('\bint\b', 'integer', col_type) -%} + -- transform string to varchar because string does not work in DML + {%- set data_type = re.sub('string', 'varchar', data_type) -%} + -- transform float to real because float does not work in DML + {%- set data_type = re.sub('float', 'real', data_type) -%} + {{ return(data_type) }} +{% endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/hash.sql b/dbt-athena/src/dbt/include/athena/macros/utils/hash.sql new file mode 100644 index 00000000..773eb507 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/hash.sql @@ -0,0 +1,3 @@ +{% macro athena__hash(field) -%} + lower(to_hex(md5(to_utf8(cast({{field}} as varchar))))) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/listagg.sql b/dbt-athena/src/dbt/include/athena/macros/utils/listagg.sql new file mode 100644 index 00000000..265e1bd1 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/listagg.sql @@ -0,0 +1,18 @@ +{% macro athena__listagg(measure, delimiter_text, order_by_clause, limit_num) -%} + array_join( + {%- if limit_num %} + slice( + {%- endif %} + array_agg( + {{ measure }} + {%- if order_by_clause %} + {{ order_by_clause }} + {%- endif %} + ) + {%- if limit_num %} + , 1, {{ limit_num }} + ) + {%- endif %} + , {{ delimiter_text }} + ) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/right.sql b/dbt-athena/src/dbt/include/athena/macros/utils/right.sql new file mode 100644 index 00000000..a0203d61 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/right.sql @@ -0,0 +1,7 @@ +{% macro athena__right(string_text, length_expression) %} + case when {{ length_expression }} = 0 + then '' + else + substr({{ string_text }}, -1 * 
({{ length_expression }})) + end +{%- endmacro -%} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/safe_cast.sql b/dbt-athena/src/dbt/include/athena/macros/utils/safe_cast.sql new file mode 100644 index 00000000..aed0866e --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/safe_cast.sql @@ -0,0 +1,4 @@ +-- TODO: make safe_cast supports complex structures +{% macro athena__safe_cast(field, type) -%} + try_cast({{field}} as {{type}}) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/macros/utils/timestamps.sql b/dbt-athena/src/dbt/include/athena/macros/utils/timestamps.sql new file mode 100644 index 00000000..c746b6d2 --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/macros/utils/timestamps.sql @@ -0,0 +1,31 @@ +{% + pyathena converts time zoned timestamps to strings so lets avoid them now() +%} + +{% macro athena__current_timestamp() -%} + {{ cast_timestamp('now()') }} +{%- endmacro %} + + +{% macro cast_timestamp(timestamp_col) -%} + {%- set config = model.get('config', {}) -%} + {%- set table_type = config.get('table_type', 'hive') -%} + {%- set is_view = config.get('materialized', 'table') in ['view', 'ephemeral'] -%} + {%- if table_type == 'iceberg' and not is_view -%} + cast({{ timestamp_col }} as timestamp(6)) + {%- else -%} + cast({{ timestamp_col }} as timestamp) + {%- endif -%} +{%- endmacro %} + +{% + Macro to get the end_of_time timestamp +%} + +{% macro end_of_time() -%} + {{ return(adapter.dispatch('end_of_time')()) }} +{%- endmacro %} + +{% macro athena__end_of_time() -%} + cast('9999-01-01' AS timestamp) +{%- endmacro %} diff --git a/dbt-athena/src/dbt/include/athena/profile_template.yml b/dbt-athena/src/dbt/include/athena/profile_template.yml new file mode 100644 index 00000000..87be52ed --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/profile_template.yml @@ -0,0 +1,23 @@ +fixed: + type: athena +prompts: + s3_staging_dir: + hint: S3 location to store Athena query results and metadata, e.g. s3://athena_query_result/prefix/ + + s3_data_dir: + hint: S3 location where to store data/tables, e.g. 
s3://bucket_name/prefix/ + + region_name: + hint: AWS region of your Athena instance + + schema: + hint: Specify the schema (Athena database) to build models into (lowercase only) + + database: + hint: Specify the database (Data catalog) to build models into (lowercase only) + default: awsdatacatalog + + threads: + hint: '1 or more' + type: 'int' + default: 1 diff --git a/dbt-athena/src/dbt/include/athena/sample_profiles.yml b/dbt-athena/src/dbt/include/athena/sample_profiles.yml new file mode 100644 index 00000000..f9cbd0cf --- /dev/null +++ b/dbt-athena/src/dbt/include/athena/sample_profiles.yml @@ -0,0 +1,19 @@ +default: + outputs: + dev: + type: athena + s3_staging_dir: [s3_staging_dir] + s3_data_dir: [s3_data_dir] + region_name: [region_name] + database: [database name] + schema: [dev_schema] + + prod: + type: athena + s3_staging_dir: [s3_staging_dir] + s3_data_dir: [s3_data_dir] + region_name: [region_name] + database: [database name] + schema: [prod_schema] + + target: dev diff --git a/dbt-athena/tests/__init__.py b/dbt-athena/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/tests/functional/__init__.py b/dbt-athena/tests/functional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/tests/functional/adapter/fixture_datediff.py b/dbt-athena/tests/functional/adapter/fixture_datediff.py new file mode 100644 index 00000000..4986ae57 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/fixture_datediff.py @@ -0,0 +1,52 @@ +seeds__data_datediff_csv = """first_date,second_date,datepart,result +2018-01-01 01:00:00,2018-01-02 01:00:00,day,1 +2018-01-01 01:00:00,2018-02-01 01:00:00,month,1 +2018-01-01 01:00:00,2019-01-01 01:00:00,year,1 +2018-01-01 01:00:00,2018-01-01 02:00:00,hour,1 +2018-01-01 01:00:00,2018-01-01 02:01:00,minute,61 +2018-01-01 01:00:00,2018-01-01 02:00:01,second,3601 +2019-12-31 00:00:00,2019-12-27 00:00:00,week,-1 +2019-12-31 00:00:00,2019-12-30 00:00:00,week,0 +2019-12-31 00:00:00,2020-01-02 00:00:00,week,0 +2019-12-31 00:00:00,2020-01-06 02:00:00,week,1 +,2018-01-01 02:00:00,hour, +2018-01-01 02:00:00,,hour, +""" + +models__test_datediff_sql = """ +with data as ( + select * from {{ ref('data_datediff') }} +) +select + case + when datepart = 'second' then {{ datediff('first_date', 'second_date', 'second') }} + when datepart = 'minute' then {{ datediff('first_date', 'second_date', 'minute') }} + when datepart = 'hour' then {{ datediff('first_date', 'second_date', 'hour') }} + when datepart = 'day' then {{ datediff('first_date', 'second_date', 'day') }} + when datepart = 'week' then {{ datediff('first_date', 'second_date', 'week') }} + when datepart = 'month' then {{ datediff('first_date', 'second_date', 'month') }} + when datepart = 'year' then {{ datediff('first_date', 'second_date', 'year') }} + else null + end as actual, + result as expected +from data +-- Also test correct casting of literal values. 
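+-- For reference (a reading of the fixture below, not additional assertions): each literal pair
+-- crosses the 1999-12-31 -> 2000-01-01 boundary by one millisecond, so every datepart from
+-- millisecond up to year is expected to come out as exactly 1; the week case uses 2000-01-03 so
+-- that the day-of-week comparison inside athena__datediff also lands in the following week.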
+union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "millisecond") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "second") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "minute") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "hour") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "day") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-03 00:00:00.000'", "week") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "month") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "quarter") }} as actual, 1 as expected +union all select + {{ datediff("'1999-12-31 23:59:59.999'", "'2000-01-01 00:00:00.000'", "year") }} as actual, 1 as expected +""" diff --git a/dbt-athena/tests/functional/adapter/fixture_split_parts.py b/dbt-athena/tests/functional/adapter/fixture_split_parts.py new file mode 100644 index 00000000..c9ff3d7c --- /dev/null +++ b/dbt-athena/tests/functional/adapter/fixture_split_parts.py @@ -0,0 +1,39 @@ +models__test_split_part_sql = """ +with data as ( + + select * from {{ ref('data_split_part') }} + +) + +select + {{ split_part('parts', 'split_on', 1) }} as actual, + result_1 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 2) }} as actual, + result_2 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 3) }} as actual, + result_3 as expected + +from data +""" + +models__test_split_part_yml = """ +version: 2 +models: + - name: test_split_part + tests: + - assert_equal: + actual: actual + expected: expected +""" diff --git a/dbt-athena/tests/functional/adapter/test_basic_hive.py b/dbt-athena/tests/functional/adapter/test_basic_hive.py new file mode 100644 index 00000000..b4026884 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_hive.py @@ -0,0 +1,63 @@ +""" +Run the basic dbt test suite on hive tables when applicable. 
+""" +import pytest + +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_empty import BaseEmpty +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_singular_tests import BaseSingularTests +from dbt.tests.adapter.basic.test_singular_tests_ephemeral import ( + BaseSingularTestsEphemeral, +) +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + + +class TestSimpleMaterializationsHive(BaseSimpleMaterializations): + pass + + +class TestSingularTestsHive(BaseSingularTests): + pass + + +class TestSingularTestsEphemeralHive(BaseSingularTestsEphemeral): + pass + + +class TestEmptyHive(BaseEmpty): + pass + + +class TestEphemeralHive(BaseEphemeral): + pass + + +class TestIncrementalHive(BaseIncremental): + pass + + +class TestGenericTestsHive(BaseGenericTests): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsHive(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampHive(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." +) +class TestBaseAdapterMethodHive(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_basic_hive_native_drop.py b/dbt-athena/tests/functional/adapter/test_basic_hive_native_drop.py new file mode 100644 index 00000000..e532cf7b --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_hive_native_drop.py @@ -0,0 +1,151 @@ +""" +Run the basic dbt test suite on hive tables when applicable. + +Some test classes are not included here, because they don't contain table models. +Those are run in the hive test suite. 
+""" +import pytest + +from dbt.tests.adapter.basic.files import ( + base_ephemeral_sql, + base_materialized_var_sql, + base_table_sql, + base_view_sql, + config_materialized_incremental, + config_materialized_table, + ephemeral_table_sql, + ephemeral_view_sql, + generic_test_table_yml, + generic_test_view_yml, + incremental_sql, + model_base, + model_ephemeral, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + +modified_config_materialized_table = """ + {{ config(materialized="table", table_type="hive", native_drop="True") }} +""" + +modified_config_materialized_incremental = """ + {{ config(materialized="incremental", + table_type="hive", + incremental_strategy="append", + unique_key="id", + native_drop="True") }} +""" + +modified_model_base = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ source('raw', 'seed') }} +""" + +modified_model_ephemeral = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ ref('ephemeral') }} +""" + + +def configure_single_model_to_use_iceberg(model): + """Adjust a given model configuration to use iceberg instead of hive.""" + replacements = [ + (config_materialized_table, modified_config_materialized_table), + (config_materialized_incremental, modified_config_materialized_incremental), + (model_base, modified_model_base), + (model_ephemeral, modified_model_ephemeral), + ] + for original, new in replacements: + model = model.replace(original.strip(), new.strip()) + + return model + + +def configure_models_to_use_iceberg(models): + """Loop over all the dbt models and set the table configuration to iceberg.""" + return {key: configure_single_model_to_use_iceberg(val) for key, val in models.items()} + + +@pytest.mark.skip( + reason="The materialized var doesn't work well, because we only want to change tables, not views. " + "It's hard to come up with an elegant fix." 
+) +class TestSimpleMaterializationsIceberg(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestEphemeralIceberg(BaseEphemeral): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "ephemeral.sql": base_ephemeral_sql, + "view_model.sql": ephemeral_view_sql, + "table_model.sql": ephemeral_table_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestIncrementalIceberg(BaseIncremental): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "incremental.sql": incremental_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestGenericTestsIceberg(BaseGenericTests): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "schema.yml": schema_base_yml, + "schema_view.yml": generic_test_view_yml, + "schema_table.yml": generic_test_table_yml, + } + ) + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsIceberg(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampIceberg(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." +) +class TestBaseAdapterMethodIceberg(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_basic_iceberg.py b/dbt-athena/tests/functional/adapter/test_basic_iceberg.py new file mode 100644 index 00000000..f092b8f9 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_iceberg.py @@ -0,0 +1,147 @@ +""" +Run the basic dbt test suite on hive tables when applicable. + +Some test classes are not included here, because they don't contain table models. +Those are run in the hive test suite. 
+""" +import pytest + +from dbt.tests.adapter.basic.files import ( + base_ephemeral_sql, + base_materialized_var_sql, + base_table_sql, + base_view_sql, + config_materialized_incremental, + config_materialized_table, + ephemeral_table_sql, + ephemeral_view_sql, + generic_test_table_yml, + generic_test_view_yml, + incremental_sql, + model_base, + model_ephemeral, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + +iceberg_config_materialized_table = """ + {{ config(materialized="table", table_type="iceberg") }} +""" + +iceberg_config_materialized_incremental = """ + {{ config(materialized="incremental", table_type="iceberg", incremental_strategy="merge", unique_key="id") }} +""" + +iceberg_model_base = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ source('raw', 'seed') }} +""" + +iceberg_model_ephemeral = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ ref('ephemeral') }} +""" + + +def configure_single_model_to_use_iceberg(model): + """Adjust a given model configuration to use iceberg instead of hive.""" + replacements = [ + (config_materialized_table, iceberg_config_materialized_table), + (config_materialized_incremental, iceberg_config_materialized_incremental), + (model_base, iceberg_model_base), + (model_ephemeral, iceberg_model_ephemeral), + ] + for original, new in replacements: + model = model.replace(original.strip(), new.strip()) + + return model + + +def configure_models_to_use_iceberg(models): + """Loop over all the dbt models and set the table configuration to iceberg.""" + return {key: configure_single_model_to_use_iceberg(val) for key, val in models.items()} + + +@pytest.mark.skip( + reason="The materialized var doesn't work well, because we only want to change tables, not views. " + "It's hard to come up with an elegant fix." 
+) +class TestSimpleMaterializationsIceberg(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestEphemeralIceberg(BaseEphemeral): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "ephemeral.sql": base_ephemeral_sql, + "view_model.sql": ephemeral_view_sql, + "table_model.sql": ephemeral_table_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestIncrementalIceberg(BaseIncremental): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "incremental.sql": incremental_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestGenericTestsIceberg(BaseGenericTests): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "schema.yml": schema_base_yml, + "schema_view.yml": generic_test_view_yml, + "schema_table.yml": generic_test_table_yml, + } + ) + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsIceberg(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampIceberg(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." +) +class TestBaseAdapterMethodIceberg(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_basic_iceberg_native_drop.py b/dbt-athena/tests/functional/adapter/test_basic_iceberg_native_drop.py new file mode 100644 index 00000000..020551c2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_basic_iceberg_native_drop.py @@ -0,0 +1,152 @@ +""" +Run the basic dbt test suite on hive tables when applicable. + +Some test classes are not included here, because they don't contain table models. +Those are run in the hive test suite. 
+""" +import pytest + +from dbt.tests.adapter.basic.files import ( + base_ephemeral_sql, + base_materialized_var_sql, + base_table_sql, + base_view_sql, + config_materialized_incremental, + config_materialized_table, + ephemeral_table_sql, + ephemeral_view_sql, + generic_test_table_yml, + generic_test_view_yml, + incremental_sql, + model_base, + model_ephemeral, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations +from dbt.tests.adapter.basic.test_ephemeral import BaseEphemeral +from dbt.tests.adapter.basic.test_generic_tests import BaseGenericTests +from dbt.tests.adapter.basic.test_incremental import BaseIncremental +from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols +from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp + +iceberg_config_materialized_table = """ + {{ config(materialized="table", table_type="iceberg", native_drop="True") }} +""" + +iceberg_config_materialized_incremental = """ + {{ config(materialized="incremental", + table_type="iceberg", + incremental_strategy="merge", + unique_key="id", + native_drop="True") }} +""" + +iceberg_model_base = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ source('raw', 'seed') }} +""" + +iceberg_model_ephemeral = """ + select + id, + name, + {{ cast_timestamp('some_date') }} as some_date + from {{ ref('ephemeral') }} +""" + + +def configure_single_model_to_use_iceberg(model): + """Adjust a given model configuration to use iceberg instead of hive.""" + replacements = [ + (config_materialized_table, iceberg_config_materialized_table), + (config_materialized_incremental, iceberg_config_materialized_incremental), + (model_base, iceberg_model_base), + (model_ephemeral, iceberg_model_ephemeral), + ] + for original, new in replacements: + model = model.replace(original.strip(), new.strip()) + + return model + + +def configure_models_to_use_iceberg(models): + """Loop over all the dbt models and set the table configuration to iceberg.""" + return {key: configure_single_model_to_use_iceberg(val) for key, val in models.items()} + + +@pytest.mark.skip( + reason="The materialized var doesn't work well, because we only want to change tables, not views. " + "It's hard to come up with an elegant fix." 
+) +class TestSimpleMaterializationsIceberg(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestEphemeralIceberg(BaseEphemeral): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "ephemeral.sql": base_ephemeral_sql, + "view_model.sql": ephemeral_view_sql, + "table_model.sql": ephemeral_table_sql, + "schema.yml": schema_base_yml, + } + ) + + +@pytest.mark.skip(reason="The native drop will usually time out in the test") +class TestIncrementalIceberg(BaseIncremental): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "incremental.sql": incremental_sql, + "schema.yml": schema_base_yml, + } + ) + + +class TestGenericTestsIceberg(BaseGenericTests): + @pytest.fixture(scope="class") + def models(self): + return configure_models_to_use_iceberg( + { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "schema.yml": schema_base_yml, + "schema_view.yml": generic_test_view_yml, + "schema_table.yml": generic_test_table_yml, + } + ) + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotCheckColsIceberg(BaseSnapshotCheckCols): + pass + + +@pytest.mark.skip(reason="The in-place update is not supported for seeds. We need our own implementation instead.") +class TestSnapshotTimestampIceberg(BaseSnapshotTimestamp): + pass + + +@pytest.mark.skip( + reason="Fails because the test tries to fetch the table metadata during the compile step, " + "before the models are actually run. Not sure how this test is intended to work." 
+) +class TestBaseAdapterMethodIceberg(BaseAdapterMethod): + pass diff --git a/dbt-athena/tests/functional/adapter/test_change_relation_types.py b/dbt-athena/tests/functional/adapter/test_change_relation_types.py new file mode 100644 index 00000000..047cac75 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_change_relation_types.py @@ -0,0 +1,26 @@ +import pytest + +from dbt.tests.adapter.relations.test_changing_relation_type import ( + BaseChangeRelationTypeValidator, +) + + +class TestChangeRelationTypesHive(BaseChangeRelationTypeValidator): + pass + + +class TestChangeRelationTypesIceberg(BaseChangeRelationTypeValidator): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + } + } + + def test_changing_materialization_changes_relation_type(self, project): + self._run_and_check_materialization("view") + self._run_and_check_materialization("table") + self._run_and_check_materialization("view") + # skip incremntal that doesn't work with Iceberg + self._run_and_check_materialization("table", extra_args=["--full-refresh"]) diff --git a/dbt-athena/tests/functional/adapter/test_constraints.py b/dbt-athena/tests/functional/adapter/test_constraints.py new file mode 100644 index 00000000..44ff8c57 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_constraints.py @@ -0,0 +1,28 @@ +import pytest + +from dbt.tests.adapter.constraints.fixtures import ( + model_quoted_column_schema_yml, + my_model_with_quoted_column_name_sql, +) +from dbt.tests.adapter.constraints.test_constraints import BaseConstraintQuotedColumn + + +class TestAthenaConstraintQuotedColumn(BaseConstraintQuotedColumn): + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_with_quoted_column_name_sql, + # we replace text type with varchar + # do not replace text with string, because then string is replaced with a capital TEXT that leads to failure + "constraints_schema.yml": model_quoted_column_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_sql(self): + # FIXME: dbt-athena outputs a query about stats into `target/run/` directory. + # dbt-core expects the query to be a ddl statement to create a table. + # This is a workaround to pass the test for now. + + # NOTE: by the above reason, this test just checks the query can be executed without errors. + # The query itself is not checked. 
+ return 'SELECT \'{"rowcount":1,"data_scanned_in_bytes":0}\';' diff --git a/dbt-athena/tests/functional/adapter/test_detailed_table_type.py b/dbt-athena/tests/functional/adapter/test_detailed_table_type.py new file mode 100644 index 00000000..ba6c885c --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_detailed_table_type.py @@ -0,0 +1,44 @@ +import re + +import pytest + +from dbt.tests.util import run_dbt, run_dbt_and_capture + +get_detailed_table_type_sql = """ +{% macro get_detailed_table_type(schema) %} + {% if execute %} + {% set relation = api.Relation.create(database="awsdatacatalog", schema=schema) %} + {% set schema_tables = adapter.list_relations_without_caching(schema_relation = relation) %} + {% for rel in schema_tables %} + {% do log('Detailed Table Type: ' ~ rel.detailed_table_type, info=True) %} + {% endfor %} + {% endif %} +{% endmacro %} +""" + +# Model SQL for an Iceberg table +iceberg_model_sql = """ + select 1 as id, 'iceberg' as name + {{ config(materialized='table', table_type='iceberg') }} +""" + + +@pytest.mark.usefixtures("project") +class TestDetailedTableType: + @pytest.fixture(scope="class") + def macros(self): + return {"get_detailed_table_type.sql": get_detailed_table_type_sql} + + @pytest.fixture(scope="class") + def models(self): + return {"iceberg_model.sql": iceberg_model_sql} + + def test_detailed_table_type(self, project): + # Run the models + run_results = run_dbt(["run"]) + assert len(run_results) == 1 # Ensure model ran successfully + + args_str = f'{{"schema": "{project.test_schema}"}}' + run_macro, stdout = run_dbt_and_capture(["run-operation", "get_detailed_table_type", "--args", args_str]) + iceberg_table_type = re.search(r"Detailed Table Type: (\w+)", stdout).group(1) + assert iceberg_table_type == "ICEBERG" diff --git a/dbt-athena/tests/functional/adapter/test_docs.py b/dbt-athena/tests/functional/adapter/test_docs.py new file mode 100644 index 00000000..e8bc5978 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_docs.py @@ -0,0 +1,135 @@ +import os + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.adapter.basic.expected_catalog import base_expected_catalog, no_stats +from dbt.tests.adapter.basic.test_docs_generate import ( + BaseDocsGenerate, + run_and_generate, + verify_metadata, +) +from dbt.tests.util import get_artifact, run_dbt + +model_sql = """ +select 1 as id +""" + +iceberg_model_sql = """ +{{ + config( + materialized="table", + table_type="iceberg", + post_hook="alter table model drop column to_drop" + ) +}} +select 1 as id, 'to_drop' as to_drop +""" + +override_macros_sql = """ +{% macro get_catalog_relations(information_schema, relations) %} + {{ return(adapter.get_catalog_by_relations(information_schema, relations)) }} +{% endmacro %} +""" + + +def custom_verify_catalog_athena(project, expected_catalog, start_time): + # get the catalog.json + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + assert os.path.exists(catalog_path) + catalog = get_artifact(catalog_path) + + # verify the catalog + assert set(catalog) == {"errors", "nodes", "sources", "metadata"} + verify_metadata( + catalog["metadata"], + "https://schemas.getdbt.com/dbt/catalog/v1.json", + start_time, + ) + assert not catalog["errors"] + + for key in "nodes", "sources": + for unique_id, expected_node in expected_catalog[key].items(): + found_node = catalog[key][unique_id] + for node_key in expected_node: + assert node_key in found_node + # the value of found_node[node_key] is not exactly 
expected_node[node_key] + + +class TestDocsGenerate(BaseDocsGenerate): + """ + Override of BaseDocsGenerate to make it working with Athena + """ + + @pytest.fixture(scope="class") + def expected_catalog(self, project): + return base_expected_catalog( + project, + role="test", + id_type="integer", + text_type="text", + time_type="timestamp without time zone", + view_type="VIEW", + table_type="BASE TABLE", + model_stats=no_stats(), + ) + + def test_run_and_generate_no_compile(self, project, expected_catalog): + start_time = run_and_generate(project, ["--no-compile"]) + assert not os.path.exists(os.path.join(project.project_root, "target", "manifest.json")) + custom_verify_catalog_athena(project, expected_catalog, start_time) + + # Test generic "docs generate" command + def test_run_and_generate(self, project, expected_catalog): + start_time = run_and_generate(project) + custom_verify_catalog_athena(project, expected_catalog, start_time) + + # Check that assets have been copied to the target directory for use in the docs html page + assert os.path.exists(os.path.join(".", "target", "assets")) + assert os.path.exists(os.path.join(".", "target", "assets", "lorem-ipsum.txt")) + assert not os.path.exists(os.path.join(".", "target", "non-existent-assets")) + + +class TestDocsGenerateOverride: + @pytest.fixture(scope="class") + def models(self): + return {"model.sql": model_sql} + + @pytest.fixture(scope="class") + def macros(self): + return {"override_macros_sql.sql": override_macros_sql} + + def test_generate_docs(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + docs_generate = run_dbt(["--warn-error", "docs", "generate"]) + assert len(docs_generate._compile_results.results) == 1 + assert docs_generate._compile_results.results[0].status == RunStatus.Success + assert docs_generate.errors is None + + +class TestDocsGenerateIcebergNonCurrentColumn: + @pytest.fixture(scope="class") + def models(self): + return {"model.sql": iceberg_model_sql} + + @pytest.fixture(scope="class") + def macros(self): + return {"override_macros_sql.sql": override_macros_sql} + + def test_generate_docs(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + docs_generate = run_dbt(["--warn-error", "docs", "generate"]) + assert len(docs_generate._compile_results.results) == 1 + assert docs_generate._compile_results.results[0].status == RunStatus.Success + assert docs_generate.errors is None + + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + assert os.path.exists(catalog_path) + catalog = get_artifact(catalog_path) + columns = catalog["nodes"]["model.test.model"]["columns"] + assert "to_drop" not in columns + assert "id" in columns diff --git a/dbt-athena/tests/functional/adapter/test_empty.py b/dbt-athena/tests/functional/adapter/test_empty.py new file mode 100644 index 00000000..c79fdfc2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_empty.py @@ -0,0 +1,9 @@ +from dbt.tests.adapter.empty.test_empty import BaseTestEmpty, BaseTestEmptyInlineSourceRef + + +class TestAthenaEmpty(BaseTestEmpty): + pass + + +class TestAthenaEmptyInlineSourceRef(BaseTestEmptyInlineSourceRef): + pass diff --git a/dbt-athena/tests/functional/adapter/test_force_batch.py b/dbt-athena/tests/functional/adapter/test_force_batch.py new file mode 100644 index 00000000..4d1fcc30 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_force_batch.py @@ -0,0 +1,136 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import 
run_dbt + +models__force_batch_sql = """ +{{ config( + materialized='table', + partitioned_by=['date_column'], + force_batch=true + ) +}} + +select + random() as rnd, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +models_append_force_batch_sql = """ +{{ config( + materialized='incremental', + incremental_strategy='append', + partitioned_by=['date_column'], + force_batch=true + ) +}} + +select + random() as rnd, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +models_merge_force_batch_sql = """ +{{ config( + table_type='iceberg', + materialized='incremental', + incremental_strategy='merge', + unique_key=['date_column'], + partitioned_by=['date_column'], + force_batch=true + ) +}} +{% if is_incremental() %} + select + 1 as rnd, + cast(date_column as date) as date_column + from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) + ) as t1(date_array) + cross join unnest(date_array) as t2(date_column) +{% else %} + select + 2 as rnd, + cast(date_column as date) as date_column + from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) + ) as t1(date_array) + cross join unnest(date_array) as t2(date_column) +{% endif %} +""" + + +class TestForceBatchInsertParam: + @pytest.fixture(scope="class") + def models(self): + return {"force_batch.sql": models__force_batch_sql} + + def test__force_batch_param(self, project): + relation_name = "force_batch" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert models_records_count == 212 + + +class TestAppendForceBatch: + @pytest.fixture(scope="class") + def models(self): + return {"models_append_force_batch.sql": models_append_force_batch_sql} + + def test__append_force_batch_param(self, project): + relation_name = "models_append_force_batch" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count == 212 + + +class TestMergeForceBatch: + @pytest.fixture(scope="class") + def models(self): + return {"models_merge_force_batch.sql": models_merge_force_batch_sql} + + def test__merge_force_batch_param(self, project): + relation_name = "models_merge_force_batch" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + model_run_result_distinct_query = f"select distinct rnd from {project.test_schema}.{relation_name}" + + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + 
model_update = run_dbt(["run", "--select", relation_name]) + model_update_result = model_update.results[0] + assert model_update_result.status == RunStatus.Success + + models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count == 212 + + models_distinct_records = project.run_sql(model_run_result_distinct_query, fetch="all")[0][0] + assert models_distinct_records == 1 diff --git a/dbt-athena/tests/functional/adapter/test_ha_iceberg.py b/dbt-athena/tests/functional/adapter/test_ha_iceberg.py new file mode 100644 index 00000000..769ce581 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_ha_iceberg.py @@ -0,0 +1,90 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__table_iceberg_naming_table_unique = """ +{{ config( + materialized='table', + table_type='iceberg', + s3_data_naming='table_unique' + ) +}} + +select + 1 as id, + 'test 1' as name +union all +select + 2 as id, + 'test 2' as name +""" + +models__table_iceberg_naming_table = """ +{{ config( + materialized='table', + table_type='iceberg', + s3_data_naming='table' + ) +}} + +select + 1 as id, + 'test 1' as name +union all +select + 2 as id, + 'test 2' as name +""" + + +class TestTableIcebergTableUnique: + @pytest.fixture(scope="class") + def models(self): + return {"table_iceberg_table_unique.sql": models__table_iceberg_naming_table_unique} + + def test__table_creation(self, project, capsys): + relation_name = "table_iceberg_table_unique" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + assert first_model_run_result.status == RunStatus.Success + + first_models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert first_models_records_count == 2 + + second_model_run = run_dbt(["run", "-d", "--select", relation_name]) + second_model_run_result = second_model_run.results[0] + assert second_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + # on the second run we expect the existing target table to be renamed to __bkp + alter_statement = ( + f"alter table `{project.test_schema}`.`{relation_name}` " + f"rename to `{project.test_schema}`.`{relation_name}__bkp`" + ) + delete_bkp_table_log = ( + f'Deleted table from glue catalog: "awsdatacatalog"."{project.test_schema}"."{relation_name}__bkp"' + ) + assert alter_statement in out + assert delete_bkp_table_log in out + + second_models_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert second_models_records_count == 2 + + +# when s3_data_naming=table is used for iceberg, a compile error must be raised +# this test makes sure that behavior is not violated +class TestTableIcebergTable: + @pytest.fixture(scope="class") + def models(self): + return {"table_iceberg_table.sql": models__table_iceberg_naming_table} + + def test__table_creation(self, project): + relation_name = "table_iceberg_table" + + with pytest.raises(Exception): + run_dbt(["run", "--select", relation_name]) diff --git a/dbt-athena/tests/functional/adapter/test_hive_iceberg.py b/dbt-athena/tests/functional/adapter/test_hive_iceberg.py new file mode 100644 index 00000000..c52a720b --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_hive_iceberg.py @@ -0,0 +1,67 @@ +import pytest + +from 
dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__table_base_model = """ +{{ config( + materialized='table', + table_type=var("table_type"), + s3_data_naming='table_unique' + ) +}} + +select + 1 as id, + 'test 1' as name, + {{ cast_timestamp('current_timestamp') }} as created_at +union all +select + 2 as id, + 'test 2' as name, + {{ cast_timestamp('current_timestamp') }} as created_at +""" + + +class TestTableFromHiveToIceberg: + @pytest.fixture(scope="class") + def models(self): + return {"table_hive_to_iceberg.sql": models__table_base_model} + + def test__table_creation(self, project): + relation_name = "table_hive_to_iceberg" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run_hive = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"hive"}']) + model_run_result_hive = model_run_hive.results[0] + assert model_run_result_hive.status == RunStatus.Success + models_records_count_hive = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_hive == 2 + + model_run_iceberg = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"iceberg"}']) + model_run_result_iceberg = model_run_iceberg.results[0] + assert model_run_result_iceberg.status == RunStatus.Success + models_records_count_iceberg = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_iceberg == 2 + + +class TestTableFromIcebergToHive: + @pytest.fixture(scope="class") + def models(self): + return {"table_iceberg_to_hive.sql": models__table_base_model} + + def test__table_creation(self, project): + relation_name = "table_iceberg_to_hive" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + model_run_iceberg = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"iceberg"}']) + model_run_result_iceberg = model_run_iceberg.results[0] + assert model_run_result_iceberg.status == RunStatus.Success + models_records_count_iceberg = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_iceberg == 2 + + model_run_hive = run_dbt(["run", "--select", relation_name, "--vars", '{"table_type":"hive"}']) + model_run_result_hive = model_run_hive.results[0] + assert model_run_result_hive.status == RunStatus.Success + models_records_count_hive = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + assert models_records_count_hive == 2 diff --git a/dbt-athena/tests/functional/adapter/test_incremental_iceberg.py b/dbt-athena/tests/functional/adapter/test_incremental_iceberg.py new file mode 100644 index 00000000..4e45403f --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_incremental_iceberg.py @@ -0,0 +1,390 @@ +""" +Run optional dbt functional tests for Iceberg incremental merge, including delete and incremental predicates. 
+ +""" +import re + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.adapter.incremental.test_incremental_merge_exclude_columns import ( + BaseMergeExcludeColumns, +) +from dbt.tests.adapter.incremental.test_incremental_predicates import ( + BaseIncrementalPredicates, +) +from dbt.tests.adapter.incremental.test_incremental_unique_id import ( + BaseIncrementalUniqueKey, + models__duplicated_unary_unique_key_list_sql, + models__empty_str_unique_key_sql, + models__empty_unique_key_list_sql, + models__expected__one_str__overwrite_sql, + models__expected__unique_key_list__inplace_overwrite_sql, + models__no_unique_key_sql, + models__nontyped_trinary_unique_key_list_sql, + models__not_found_unique_key_list_sql, + models__not_found_unique_key_sql, + models__str_unique_key_sql, + models__trinary_unique_key_list_sql, + models__unary_unique_key_list_sql, + seeds__add_new_rows_sql, + seeds__duplicate_insert_sql, + seeds__seed_csv, +) +from dbt.tests.util import check_relations_equal, run_dbt + +seeds__expected_incremental_predicates_csv = """id,msg,color +3,anyway,purple +1,hey,blue +2,goodbye,red +2,yo,green +""" + +seeds__expected_delete_condition_csv = """id,msg,color +1,hey,blue +3,anyway,purple +""" + +seeds__expected_predicates_and_delete_condition_csv = """id,msg,color +1,hey,blue +1,hello,blue +3,anyway,purple +""" + +models__merge_exclude_all_columns_sql = """ +{{ config( + materialized = 'incremental', + unique_key = 'id', + incremental_strategy='merge', + merge_exclude_columns=['msg', 'color'] +) }} + +{% if not is_incremental() %} + +-- data for first invocation of model + +select 1 as id, 'hello' as msg, 'blue' as color +union all +select 2 as id, 'goodbye' as msg, 'red' as color + +{% else %} + +-- data for subsequent incremental update + +select 1 as id, 'hey' as msg, 'blue' as color +union all +select 2 as id, 'yo' as msg, 'green' as color +union all +select 3 as id, 'anyway' as msg, 'purple' as color + +{% endif %} +""" + +seeds__expected_merge_exclude_all_columns_csv = """id,msg,color +1,hello,blue +2,goodbye,red +3,anyway,purple +""" + +models__update_condition_sql = """ +{{ config( + table_type='iceberg', + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + update_condition='target.id > 1' + ) +}} + +{% if is_incremental() %} + +select * from ( + values + (1, 'v1-updated') + , (2, 'v2-updated') +) as t (id, value) + +{% else %} + +select * from ( + values + (-1, 'v-1') + , (0, 'v0') + , (1, 'v1') + , (2, 'v2') +) as t (id, value) + +{% endif %} +""" + +seeds__expected_update_condition_csv = """id,value +-1,v-1 +0,v0 +1,v1 +2,v2-updated +""" + +models__insert_condition_sql = """ +{{ config( + table_type='iceberg', + materialized='incremental', + incremental_strategy='merge', + unique_key=['id'], + insert_condition='src.status != 0' + ) +}} + +{% if is_incremental() %} + +select * from ( + values + (1, -1) + , (2, 0) + , (3, 1) +) as t (id, status) + +{% else %} + +select * from ( + values + (0, 1) +) as t (id, status) + +{% endif %} + +""" + +seeds__expected_insert_condition_csv = """id,status +0, 1 +1,-1 +3,1 +""" + + +class TestIcebergIncrementalUniqueKey(BaseIncrementalUniqueKey): + @pytest.fixture(scope="class") + def project_config_update(self): + return {"models": {"+table_type": "iceberg", "incremental_strategy": "merge"}} + + @pytest.fixture(scope="class") + def models(self): + return { + "trinary_unique_key_list.sql": models__trinary_unique_key_list_sql, + "nontyped_trinary_unique_key_list.sql": 
models__nontyped_trinary_unique_key_list_sql, + "unary_unique_key_list.sql": models__unary_unique_key_list_sql, + "not_found_unique_key.sql": models__not_found_unique_key_sql, + "empty_unique_key_list.sql": models__empty_unique_key_list_sql, + "no_unique_key.sql": models__no_unique_key_sql, + "empty_str_unique_key.sql": models__empty_str_unique_key_sql, + "str_unique_key.sql": models__str_unique_key_sql, + "duplicated_unary_unique_key_list.sql": models__duplicated_unary_unique_key_list_sql, + "not_found_unique_key_list.sql": models__not_found_unique_key_list_sql, + "expected": { + "one_str__overwrite.sql": replace_cast_date(models__expected__one_str__overwrite_sql), + "unique_key_list__inplace_overwrite.sql": replace_cast_date( + models__expected__unique_key_list__inplace_overwrite_sql + ), + }, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "duplicate_insert.sql": replace_cast_date(seeds__duplicate_insert_sql), + "seed.csv": seeds__seed_csv, + "add_new_rows.sql": replace_cast_date(seeds__add_new_rows_sql), + } + + @pytest.mark.xfail(reason="Model config 'unique_keys' is required for incremental merge.") + def test__no_unique_keys(self, project): + super().test__no_unique_keys(project) + + @pytest.mark.skip( + reason="""" + If 'unique_keys' does not contain columns then the join condition will fail. + The adapter isn't handling this input scenario. + """ + ) + def test__empty_str_unique_key(self): + pass + + @pytest.mark.skip( + reason=""" + If 'unique_keys' does not contain columns then the join condition will fail. + The adapter isn't handling this input scenario. + """ + ) + def test__empty_unique_key_list(self): + pass + + +class TestIcebergIncrementalPredicates(BaseIncrementalPredicates): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + "+incremental_predicates": ["src.id <> 3", "target.id <> 2"], + } + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_delete_insert_incremental_predicates.csv": seeds__expected_incremental_predicates_csv} + + +class TestIcebergDeleteCondition(BaseIncrementalPredicates): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + "+delete_condition": "src.id = 2 and target.color = 'red'", + } + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_delete_insert_incremental_predicates.csv": seeds__expected_delete_condition_csv} + + # Modifying the seed_rows number from the base class method + def test__incremental_predicates(self, project): + """seed should match model after two incremental runs""" + + expected_fields = self.get_expected_fields( + relation="expected_delete_insert_incremental_predicates", seed_rows=2 + ) + test_case_fields = self.get_test_fields( + project, + seed="expected_delete_insert_incremental_predicates", + incremental_model="delete_insert_incremental_predicates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) + + +class TestIcebergPredicatesAndDeleteCondition(BaseIncrementalPredicates): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + "+delete_condition": "src.msg = 'yo' and target.color = 'red'", + "+incremental_predicates": ["src.id <> 1", "target.msg <> 'blue'"], + } + } + + 
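# expected rows after the incremental run applies both incremental_predicates and the delete_condition + 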
@pytest.fixture(scope="class") + def seeds(self): + return { + "expected_delete_insert_incremental_predicates.csv": seeds__expected_predicates_and_delete_condition_csv + } + + # Modifying the seed_rows number from the base class method + def test__incremental_predicates(self, project): + """seed should match model after two incremental runs""" + + expected_fields = self.get_expected_fields( + relation="expected_delete_insert_incremental_predicates", seed_rows=3 + ) + test_case_fields = self.get_test_fields( + project, + seed="expected_delete_insert_incremental_predicates", + incremental_model="delete_insert_incremental_predicates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) + + +class TestIcebergMergeExcludeColumns(BaseMergeExcludeColumns): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + } + } + + +class TestIcebergMergeExcludeAllColumns(BaseMergeExcludeColumns): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + } + } + + @pytest.fixture(scope="class") + def models(self): + return {"merge_exclude_columns.sql": models__merge_exclude_all_columns_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_exclude_columns.csv": seeds__expected_merge_exclude_all_columns_csv} + + +class TestIcebergUpdateCondition: + @pytest.fixture(scope="class") + def models(self): + return {"merge_update_condition.sql": models__update_condition_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_update_condition.csv": seeds__expected_update_condition_csv} + + def test__merge_update_condition(self, project): + """Seed should match the model after incremental run""" + + expected_seed_name = "expected_merge_update_condition" + run_dbt(["seed", "--select", expected_seed_name, "--full-refresh"]) + + relation_name = "merge_update_condition" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + model_update = run_dbt(["run", "--select", relation_name]) + model_update_result = model_update.results[0] + assert model_update_result.status == RunStatus.Success + + check_relations_equal(project.adapter, [relation_name, expected_seed_name]) + + +def replace_cast_date(model: str) -> str: + """Wrap all date strings with a cast date function""" + + new_model = re.sub("'[0-9]{4}-[0-9]{2}-[0-9]{2}'", r"cast(\g<0> as date)", model) + return new_model + + +class TestIcebergInsertCondition: + @pytest.fixture(scope="class") + def models(self): + return {"merge_insert_condition.sql": models__insert_condition_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_insert_condition.csv": seeds__expected_insert_condition_csv} + + def test__merge_insert_condition(self, project): + """Seed should match the model after run""" + + expected_seed_name = "expected_merge_insert_condition" + run_dbt(["seed", "--select", expected_seed_name, "--full-refresh"]) + + relation_name = "merge_insert_condition" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + + model_update = run_dbt(["run", "--select", relation_name]) + model_update_result = model_update.results[0] + assert 
model_update_result.status == RunStatus.Success + + check_relations_equal(project.adapter, [relation_name, expected_seed_name]) diff --git a/dbt-athena/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py b/dbt-athena/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py new file mode 100644 index 00000000..039da2d8 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py @@ -0,0 +1,115 @@ +from collections import namedtuple + +import pytest + +from dbt.tests.util import check_relations_equal, run_dbt + +models__merge_no_updates_sql = """ +{{ config( + materialized = 'incremental', + unique_key = 'id', + incremental_strategy = 'merge', + merge_update_columns = ['id'], + table_type = 'iceberg', +) }} + +{% if not is_incremental() %} + +-- data for first invocation of model + +select 1 as id, 'hello' as msg, 'blue' as color +union all +select 2 as id, 'goodbye' as msg, 'red' as color + +{% else %} + +-- data for subsequent incremental update + +select 1 as id, 'hey' as msg, 'blue' as color +union all +select 2 as id, 'yo' as msg, 'green' as color +union all +select 3 as id, 'anyway' as msg, 'purple' as color + +{% endif %} +""" + +seeds__expected_merge_no_updates_csv = """id,msg,color +1,hello,blue +2,goodbye,red +3,anyway,purple +""" + +ResultHolder = namedtuple( + "ResultHolder", + [ + "seed_count", + "model_count", + "seed_rows", + "inc_test_model_count", + "relation", + ], +) + + +class TestIncrementalIcebergMergeNoUpdates: + @pytest.fixture(scope="class") + def models(self): + return {"merge_no_updates.sql": models__merge_no_updates_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_no_updates.csv": seeds__expected_merge_no_updates_csv} + + def update_incremental_model(self, incremental_model): + """update incremental model after the seed table has been updated""" + model_result_set = run_dbt(["run", "--select", incremental_model]) + return len(model_result_set) + + def get_test_fields(self, project, seed, incremental_model, update_sql_file): + seed_count = len(run_dbt(["seed", "--select", seed, "--full-refresh"])) + + model_count = len(run_dbt(["run", "--select", incremental_model, "--full-refresh"])) + + relation = incremental_model + # update seed in anticipation of incremental model update + row_count_query = f"select * from {project.test_schema}.{seed}" + + seed_rows = len(project.run_sql(row_count_query, fetch="all")) + + # propagate seed state to incremental model according to unique keys + inc_test_model_count = self.update_incremental_model(incremental_model=incremental_model) + + return ResultHolder(seed_count, model_count, seed_rows, inc_test_model_count, relation) + + def check_scenario_correctness(self, expected_fields, test_case_fields, project): + """Invoke assertions to verify correct build functionality""" + # 1. test seed(s) should build afresh + assert expected_fields.seed_count == test_case_fields.seed_count + # 2. test model(s) should build afresh + assert expected_fields.model_count == test_case_fields.model_count + # 3. seeds should have intended row counts post update + assert expected_fields.seed_rows == test_case_fields.seed_rows + # 4. incremental test model(s) should be updated + assert expected_fields.inc_test_model_count == test_case_fields.inc_test_model_count + # 5. 
result table should match intended result set (itself a relation) + check_relations_equal(project.adapter, [expected_fields.relation, test_case_fields.relation]) + + def test__merge_no_updates(self, project): + """seed should match model after incremental run""" + + expected_fields = ResultHolder( + seed_count=1, + model_count=1, + inc_test_model_count=1, + seed_rows=3, + relation="expected_merge_no_updates", + ) + + test_case_fields = self.get_test_fields( + project, + seed="expected_merge_no_updates", + incremental_model="merge_no_updates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) diff --git a/dbt-athena/tests/functional/adapter/test_incremental_tmp_schema.py b/dbt-athena/tests/functional/adapter/test_incremental_tmp_schema.py new file mode 100644 index 00000000..d06e95f2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_incremental_tmp_schema.py @@ -0,0 +1,108 @@ +import pytest +import yaml +from tests.functional.adapter.utils.parse_dbt_run_output import ( + extract_create_statement_table_names, + extract_running_create_statements, +) + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__schema_tmp_sql = """ +{{ config( + materialized='incremental', + incremental_strategy='insert_overwrite', + partitioned_by=['date_column'], + temp_schema=var('temp_schema_name') + ) +}} +select + random() as rnd, + cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column +""" + + +class TestIncrementalTmpSchema: + @pytest.fixture(scope="class") + def models(self): + return {"schema_tmp.sql": models__schema_tmp_sql} + + def test__schema_tmp(self, project, capsys): + relation_name = "schema_tmp" + temp_schema_name = f"{project.test_schema}_tmp" + drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + vars_dict = { + "temp_schema_name": temp_schema_name, + "logical_date": "2024-01-01", + } + + first_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + yaml.safe_dump(vars_dict), + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + first_model_run_result = first_model_run.results[0] + + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 1 + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert len(athena_running_create_statements) == 1 + + incremental_model_run_result_table_name = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + assert temp_schema_name not in incremental_model_run_result_table_name + + vars_dict["logical_date"] = "2024-01-02" + incremental_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + yaml.safe_dump(vars_dict), + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + incremental_model_run_result = incremental_model_run.results[0] + + assert incremental_model_run_result.status == RunStatus.Success + + records_count_incremental_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_incremental_run == 2 + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert 
len(athena_running_create_statements) == 1 + + incremental_model_run_result_table_name = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + assert temp_schema_name == incremental_model_run_result_table_name.split(".")[1].strip('"') + + project.run_sql(drop_temp_schema) diff --git a/dbt-athena/tests/functional/adapter/test_on_schema_change.py b/dbt-athena/tests/functional/adapter/test_on_schema_change.py new file mode 100644 index 00000000..bebd9110 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_on_schema_change.py @@ -0,0 +1,100 @@ +import json + +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt, run_dbt_and_capture + +models__table_base_model = """ +{{ + config( + materialized='incremental', + incremental_strategy='append', + on_schema_change=var("on_schema_change"), + table_type=var("table_type"), + ) +}} + +select + 1 as id, + 'test 1' as name +{%- if is_incremental() -%} + ,current_date as updated_at +{%- endif -%} +""" + + +class TestOnSchemaChange: + @pytest.fixture(scope="class") + def models(self): + models = {} + for table_type in ["hive", "iceberg"]: + for on_schema_change in ["sync_all_columns", "append_new_columns", "ignore", "fail"]: + models[f"{table_type}_on_schema_change_{on_schema_change}.sql"] = models__table_base_model + return models + + def _column_names(self, project, relation_name): + result = project.run_sql(f"show columns from {relation_name}", fetch="all") + column_names = [row[0].strip() for row in result] + return column_names + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__sync_all_columns(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_sync_all_columns" + vars = {"on_schema_change": "sync_all_columns", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental = run_dbt(args) + assert model_run_incremental.results[0].status == RunStatus.Success + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == ["id", "name", "updated_at"] + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__append_new_columns(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_append_new_columns" + vars = {"on_schema_change": "append_new_columns", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental = run_dbt(args) + assert model_run_incremental.results[0].status == RunStatus.Success + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == ["id", "name", "updated_at"] + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__ignore(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_ignore" + vars = {"on_schema_change": "ignore", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental = run_dbt(args) + assert model_run_incremental.results[0].status == RunStatus.Success + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == 
["id", "name"] + + @pytest.mark.parametrize("table_type", ["hive", "iceberg"]) + def test__fail(self, project, table_type): + relation_name = f"{table_type}_on_schema_change_fail" + vars = {"on_schema_change": "fail", "table_type": table_type} + args = ["run", "--select", relation_name, "--vars", json.dumps(vars)] + + model_run_initial = run_dbt(args) + assert model_run_initial.results[0].status == RunStatus.Success + + model_run_incremental, log = run_dbt_and_capture(args, expect_pass=False) + assert model_run_incremental.results[0].status == RunStatus.Error + assert "The source and target schemas on this incremental model are out of sync!" in log + + new_column_names = self._column_names(project, relation_name) + assert new_column_names == ["id", "name"] diff --git a/dbt-athena/tests/functional/adapter/test_partitions.py b/dbt-athena/tests/functional/adapter/test_partitions.py new file mode 100644 index 00000000..da2e5955 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_partitions.py @@ -0,0 +1,352 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +# this query generates 212 records +test_partitions_model_sql = """ +select + random() as rnd, + cast(date_column as date) as date_column, + doy(date_column) as doy +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +test_single_nullable_partition_model_sql = """ +with data as ( + select + random() as col_1, + row_number() over() as id + from + unnest(sequence(1, 200)) +) + +select + col_1, id +from data +union all +select random() as col_1, NULL as id +union all +select random() as col_1, NULL as id +""" + +test_nullable_partitions_model_sql = """ +{{ config( + materialized='table', + format='parquet', + s3_data_naming='table', + partitioned_by=['id', 'date_column'] +) }} + +with data as ( + select + random() as rnd, + row_number() over() as id, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +) + +select + rnd, + case when id <= 50 then null else id end as id, + date_column +from data +union all +select + random() as rnd, + NULL as id, + NULL as date_column +union all +select + random() as rnd, + NULL as id, + cast('2023-09-02' as date) as date_column +union all +select + random() as rnd, + 40 as id, + NULL as date_column +""" + +test_bucket_partitions_sql = """ +with non_random_strings as ( + select + chr(cast(65 + (row_number() over () % 26) as bigint)) || + chr(cast(65 + ((row_number() over () + 1) % 26) as bigint)) || + chr(cast(65 + ((row_number() over () + 4) % 26) as bigint)) as non_random_str + from + (select 1 union all select 2 union all select 3) as temp_table +) +select + cast(date_column as date) as date_column, + doy(date_column) as doy, + nrnd.non_random_str +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-24'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +join non_random_strings nrnd on true +""" + + +class TestHiveTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"models": {"+table_type": "hive", "+materialized": "table", "+partitioned_by": ["date_column", "doy"]}} + + @pytest.fixture(scope="class") + def 
models(self): + return { + "test_hive_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_hive_partitions" + model_run_result_row_count_query = "select count(*) as records from {}.{}".format( + project.test_schema, relation_name + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_iceberg_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergIncrementalPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "incremental", + "+incremental_strategy": "merge", + "+unique_key": "doy", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions_incremental.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + """ + Check that the incremental run works with iceberg and partitioned datasets + """ + + relation_name = "test_iceberg_partitions_incremental" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name, "--full-refresh"]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + incremental_model_run = run_dbt(["run", "--select", relation_name]) + + incremental_model_run_result = incremental_model_run.results[0] + + # check that the model run successfully after incremental run + assert incremental_model_run_result.status == RunStatus.Success + + incremental_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert incremental_records_count == 212 + + +class TestHiveNullValuedPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "hive", + "+materialized": "table", + "+partitioned_by": ["id", "date_column"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + 
"test_nullable_partitions_model.sql": test_nullable_partitions_model_sql, + } + + def test__check_run_with_partitions(self, project): + relation_name = "test_nullable_partitions_model" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + model_run_result_null_id_count_query = ( + f"select count(*) as records from {project.test_schema}.{relation_name} where id is null" + ) + model_run_result_null_date_count_query = ( + f"select count(*) as records from {project.test_schema}.{relation_name} where date_column is null" + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 215 + + null_id_count_first_run = project.run_sql(model_run_result_null_id_count_query, fetch="all")[0][0] + + assert null_id_count_first_run == 52 + + null_date_count_first_run = project.run_sql(model_run_result_null_date_count_query, fetch="all")[0][0] + + assert null_date_count_first_run == 2 + + +class TestHiveSingleNullValuedPartition: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "hive", + "+materialized": "table", + "+partitioned_by": ["id"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_single_nullable_partition_model.sql": test_single_nullable_partition_model_sql, + } + + def test__check_run_with_partitions(self, project): + relation_name = "test_single_nullable_partition_model" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 202 + + +class TestIcebergTablePartitionsBuckets: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["DAY(date_column)", "doy", "bucket(non_random_str, 5)"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_bucket_partitions.sql": test_bucket_partitions_sql, + } + + def test__check_run_with_bucket_and_partitions(self, project): + relation_name = "test_bucket_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 615 + + +class TestIcebergTableBuckets: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["bucket(non_random_str, 5)"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + 
"test_bucket_partitions.sql": test_bucket_partitions_sql, + } + + def test__check_run_with_bucket_in_partitions(self, project): + relation_name = "test_bucket_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 615 diff --git a/dbt-athena/tests/functional/adapter/test_python_submissions.py b/dbt-athena/tests/functional/adapter/test_python_submissions.py new file mode 100644 index 00000000..55f8beb9 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_python_submissions.py @@ -0,0 +1,177 @@ +# import os +# import shutil +# from collections import Counter +# from copy import deepcopy + +# import pytest +# import yaml + +# from dbt.tests.adapter.python_model.test_python_model import ( +# BasePythonIncrementalTests, +# BasePythonModelTests, +# ) +# from dbt.tests.util import run_dbt + +# basic_sql = """ +# {{ config(materialized="table") }} +# select 1 as column_1, 2 as column_2, '{{ run_started_at.strftime("%Y-%m-%d") }}' as run_date +# """ + +# basic_python = """ +# def model(dbt, spark): +# dbt.config( +# materialized='table', +# ) +# df = dbt.ref("model") +# return df +# """ + +# basic_spark_python = """ +# def model(dbt, spark_session): +# dbt.config(materialized="table") + +# data = [(1,), (2,), (3,), (4,)] + +# df = spark_session.createDataFrame(data, ["A"]) + +# return df +# """ + +# second_sql = """ +# select * from {{ref('my_python_model')}} +# """ + +# schema_yml = """version: 2 +# models: +# - name: model +# versions: +# - v: 1 +# """ + + +# class TestBasePythonModelTests(BasePythonModelTests): +# @pytest.fixture(scope="class") +# def models(self): +# return { +# "schema.yml": schema_yml, +# "model.sql": basic_sql, +# "my_python_model.py": basic_python, +# "spark_model.py": basic_spark_python, +# "second_sql_model.sql": second_sql, +# } + + +# incremental_python = """ +# def model(dbt, spark_session): +# dbt.config(materialized="incremental") +# df = dbt.ref("model") + +# if dbt.is_incremental: +# max_from_this = ( +# f"select max(run_date) from {dbt.this.schema}.{dbt.this.identifier}" +# ) +# df = df.filter(df.run_date >= spark_session.sql(max_from_this).collect()[0][0]) + +# return df +# """ + + +# class TestBasePythonIncrementalTests(BasePythonIncrementalTests): +# @pytest.fixture(scope="class") +# def project_config_update(self): +# return {"models": {"+incremental_strategy": "append"}} + +# @pytest.fixture(scope="class") +# def models(self): +# return {"model.sql": basic_sql, "incremental.py": incremental_python} + +# def test_incremental(self, project): +# vars_dict = { +# "test_run_schema": project.test_schema, +# } + +# results = run_dbt(["run", "--vars", yaml.safe_dump(vars_dict)]) +# assert len(results) == 2 + + +# class TestPythonClonePossible: +# """Test that basic clone operations are possible on Python models. 
This +# class has been adapted from the BaseClone and BaseClonePossible classes in +# dbt-core and modified to test Python models in addition to SQL models.""" + +# @pytest.fixture(scope="class") +# def models(self): +# return { +# "schema.yml": schema_yml, +# "model.sql": basic_sql, +# "my_python_model.py": basic_python, +# } + +# @pytest.fixture(scope="class") +# def other_schema(self, unique_schema): +# return unique_schema + "_other" + +# @pytest.fixture(scope="class") +# def profiles_config_update(self, dbt_profile_target, unique_schema, other_schema): +# """Update the profiles config to duplicate the default schema to a +# separate schema called `otherschema`.""" +# outputs = {"default": dbt_profile_target, "otherschema": deepcopy(dbt_profile_target)} +# outputs["default"]["schema"] = unique_schema +# outputs["otherschema"]["schema"] = other_schema +# return {"test": {"outputs": outputs, "target": "default"}} + +# def copy_state(self, project_root): +# """Copy the manifest.json project for a run into a separate `state/` +# directory inside the project root, so that we can reference it +# for cloning.""" +# state_path = os.path.join(project_root, "state") +# if not os.path.exists(state_path): +# os.makedirs(state_path) +# shutil.copyfile(f"{project_root}/target/manifest.json", f"{state_path}/manifest.json") + +# def run_and_save_state(self, project_root): +# """Run models and save the state to a separate directory to prepare +# for testing clone operations.""" +# results = run_dbt(["run"]) +# assert len(results) == 2 +# self.copy_state(project_root) + +# def assert_relation_types_match_counter(self, project, schema, counter): +# """Check that relation types in a given database and schema match the +# counts specified by a Counter object.""" +# schema_relations = project.adapter.list_relations(database=project.database, schema=schema) +# schema_types = [str(r.type) for r in schema_relations] +# assert Counter(schema_types) == counter + +# def test_can_clone_true(self, project, unique_schema, other_schema): +# """Test that Python models can be cloned using `dbt clone`. 
Adapted from +# the BaseClonePossible.test_can_clone_true test in dbt-core.""" +# project.create_test_schema(other_schema) +# self.run_and_save_state(project.project_root) + +# # Base models should be materialized as tables +# self.assert_relation_types_match_counter(project, unique_schema, Counter({"table": 2})) + +# clone_args = [ +# "clone", +# "--state", +# "state", +# "--target", +# "otherschema", +# ] + +# results = run_dbt(clone_args) +# assert len(results) == 2 + +# # Cloned models should be materialized as views +# self.assert_relation_types_match_counter(project, other_schema, Counter({"view": 2})) + +# # Objects already exist, so this is a no-op +# results = run_dbt(clone_args) +# assert len(results) == 2 +# assert all("no-op" in r.message.lower() for r in results) + +# # Recreate all objects +# results = run_dbt([*clone_args, "--full-refresh"]) +# assert len(results) == 2 +# assert not any("no-op" in r.message.lower() for r in results) diff --git a/dbt-athena/tests/functional/adapter/test_quote_seed_column.py b/dbt-athena/tests/functional/adapter/test_quote_seed_column.py new file mode 100644 index 00000000..c0845c39 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_quote_seed_column.py @@ -0,0 +1,39 @@ +import pytest + +from dbt.tests.adapter.basic.files import ( + base_materialized_var_sql, + base_table_sql, + base_view_sql, + schema_base_yml, + seeds_base_csv, +) +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations + +seed_base_csv_underscore_column = seeds_base_csv.replace("name", "_name") + +quote_columns_seed_schema = """ + +seeds: +- name: base + config: + quote_columns: true + +""" + + +class TestSimpleMaterializationsHive(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + schema = schema_base_yml + quote_columns_seed_schema + return { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "base.csv": seed_base_csv_underscore_column, + } diff --git a/dbt-athena/tests/functional/adapter/test_retries_iceberg.py b/dbt-athena/tests/functional/adapter/test_retries_iceberg.py new file mode 100644 index 00000000..f4711d9d --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_retries_iceberg.py @@ -0,0 +1,136 @@ +"""Test parallel insert into iceberg table.""" +import copy +import os + +import pytest + +from dbt.artifacts.schemas.results import RunStatus +from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture + +PARALLELISM = 10 + +base_dbt_profile = { + "type": "athena", + "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"), + "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"), + "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"), + "database": os.getenv("DBT_TEST_ATHENA_DATABASE"), + "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"), + "threads": PARALLELISM, + "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")), + "num_retries": 0, + "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"), + "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, +} + +models__target = """ +{{ + config( + table_type='iceberg', + materialized='table' + ) +}} + +select * from ( + values + (1, -1) +) as t (id, status) +limit 0 + +""" + +models__source = { + f"model_{i}.sql": f""" +{{{{ + config( + table_type='iceberg', + materialized='table', + tags=['src'], + pre_hook='insert into 
target values ({i}, {i})' + ) +}}}} + +select 1 as col +""" + for i in range(PARALLELISM) +} + +seeds__expected_target_init = "id,status" +seeds__expected_target_post = "id,status\n" + "\n".join([f"{i},{i}" for i in range(PARALLELISM)]) + + +class TestIcebergRetriesDisabled: + @pytest.fixture(scope="class") + def dbt_profile_target(self): + profile = copy.deepcopy(base_dbt_profile) + profile["num_iceberg_retries"] = 0 + return profile + + @pytest.fixture(scope="class") + def models(self): + return {**{"target.sql": models__target}, **models__source} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "expected_target_init.csv": seeds__expected_target_init, + "expected_target_post.csv": seeds__expected_target_post, + } + + def test__retries_iceberg(self, project): + """Seed should match the model after run""" + + expected__init_seed_name = "expected_target_init" + run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"]) + + relation_name = "target" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + check_relations_equal(project.adapter, [relation_name, expected__init_seed_name]) + + expected__post_seed_name = "expected_target_post" + run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"]) + + run, log = run_dbt_and_capture(["run", "--select", "tag:src"], expect_pass=False) + assert any(model_run_result.status == RunStatus.Error for model_run_result in run.results) + assert "ICEBERG_COMMIT_ERROR" in log + + +class TestIcebergRetriesEnabled: + @pytest.fixture(scope="class") + def dbt_profile_target(self): + profile = copy.deepcopy(base_dbt_profile) + # we set iceberg retries to a high number to ensure that the test will pass + profile["num_iceberg_retries"] = PARALLELISM * 5 + return profile + + @pytest.fixture(scope="class") + def models(self): + return {**{"target.sql": models__target}, **models__source} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "expected_target_init.csv": seeds__expected_target_init, + "expected_target_post.csv": seeds__expected_target_post, + } + + def test__retries_iceberg(self, project): + """Seed should match the model after run""" + + expected__init_seed_name = "expected_target_init" + run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"]) + + relation_name = "target" + model_run = run_dbt(["run", "--select", relation_name]) + model_run_result = model_run.results[0] + assert model_run_result.status == RunStatus.Success + check_relations_equal(project.adapter, [relation_name, expected__init_seed_name]) + + expected__post_seed_name = "expected_target_post" + run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"]) + + run = run_dbt(["run", "--select", "tag:src"]) + assert all([model_run_result.status == RunStatus.Success for model_run_result in run.results]) + check_relations_equal(project.adapter, [relation_name, expected__post_seed_name]) diff --git a/dbt-athena/tests/functional/adapter/test_seed_by_insert.py b/dbt-athena/tests/functional/adapter/test_seed_by_insert.py new file mode 100644 index 00000000..cb1f91c5 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_seed_by_insert.py @@ -0,0 +1,30 @@ +import pytest + +from dbt.tests.adapter.basic.files import ( + base_materialized_var_sql, + base_table_sql, + base_view_sql, + schema_base_yml, +) +from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations + +seed_by_insert_schema = """ 
+ +seeds: +- name: base + config: + seed_by_insert: True + +""" + + +class TestSimpleMaterializationsHive(BaseSimpleMaterializations): + @pytest.fixture(scope="class") + def models(self): + schema = schema_base_yml + seed_by_insert_schema + return { + "view_model.sql": base_view_sql, + "table_model.sql": base_table_sql, + "swappable.sql": base_materialized_var_sql, + "schema.yml": schema, + } diff --git a/dbt-athena/tests/functional/adapter/test_snapshot.py b/dbt-athena/tests/functional/adapter/test_snapshot.py new file mode 100644 index 00000000..3debfb04 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_snapshot.py @@ -0,0 +1,300 @@ +""" +These are the test suites from `dbt.tests.adapter.basic.test_snapshot_check_cols` +and `dbt.tests.adapter.basic.test_snapshot_timestamp`, but slightly adapted. +The original test suites for snapshots didn't work out-of-the-box, because Athena +has no support for UPDATE statements on hive tables. This is required by those +tests to update the input seeds. + +This file also includes test suites for the iceberg tables. +""" +import pytest + +from dbt.tests.adapter.basic.files import ( + cc_all_snapshot_sql, + cc_date_snapshot_sql, + cc_name_snapshot_sql, + seeds_added_csv, + seeds_base_csv, + ts_snapshot_sql, +) +from dbt.tests.util import relation_from_name, run_dbt + + +def check_relation_rows(project, snapshot_name, count): + """Assert that the relation has the given number of rows""" + relation = relation_from_name(project.adapter, snapshot_name) + result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one") + assert result[0] == count + + +def check_relation_columns(project, snapshot_name, count): + """Assert that the relation has the given number of columns""" + relation = relation_from_name(project.adapter, snapshot_name) + result = project.run_sql(f"select * from {relation}", fetch="one") + assert len(result) == count + + +seeds_altered_added_csv = """ +11,Mateo,2014-09-08T17:04:27 +12,Julian,2000-02-05T11:48:30 +13,Gabriel,2001-07-11T07:32:52 +14,Isaac,2002-11-25T03:22:28 +15,Levi,2009-11-16T11:57:15 +16,Elizabeth,2005-04-10T03:50:11 +17,Grayson,2019-08-07T19:28:17 +18,Dylan,2014-03-02T11:50:41 +19,Jayden,2009-06-07T07:12:49 +20,Luke,2003-12-06T21:42:18 +""".lstrip() + + +seeds_altered_base_csv = """ +id,name,some_date +1,Easton_updated,1981-05-20T06:46:51 +2,Lillian_updated,1978-09-03T18:10:33 +3,Jeremiah_updated,1982-03-11T03:59:51 +4,Nolan_updated,1976-05-06T20:21:35 +5,Hannah_updated,1982-06-23T05:41:26 +6,Eleanor_updated,1991-08-10T23:12:21 +7,Lily_updated,1971-03-29T14:58:02 +8,Jonathan_updated,1988-02-26T02:55:24 +9,Adrian_updated,1994-02-09T13:14:23 +10,Nora_updated,1976-03-01T16:51:39 +""".lstrip() + + +iceberg_cc_all_snapshot_sql = """ +{% snapshot cc_all_snapshot %} + {{ config( + check_cols='all', + unique_key='id', + strategy='check', + target_database=database, + target_schema=schema, + table_type='iceberg' + ) }} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +iceberg_cc_name_snapshot_sql = """ +{% snapshot cc_name_snapshot %} + {{ config( + check_cols=['name'], + unique_key='id', + strategy='check', + target_database=database, + target_schema=schema, + table_type='iceberg' + ) }} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +iceberg_cc_date_snapshot_sql = """ +{% snapshot cc_date_snapshot %} + 
{{ config( + check_cols=['some_date'], + unique_key='id', + strategy='check', + target_database=database, + target_schema=schema, + table_type='iceberg' + ) }} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +iceberg_ts_snapshot_sql = """ +{% snapshot ts_snapshot %} + {{ config( + strategy='timestamp', + unique_key='id', + updated_at='some_date', + target_database=database, + target_schema=schema, + table_type='iceberg', + )}} + select + id, + name, + cast(some_date as timestamp(6)) as some_date + from {{ ref(var('seed_name', 'base')) }} +{% endsnapshot %} +""".strip() + + +class TestSnapshotCheckColsHive: + @pytest.fixture(scope="class") + def seeds(self): + return { + "base.csv": seeds_base_csv, + "added.csv": seeds_added_csv, + "updated_1.csv": seeds_base_csv + seeds_altered_added_csv, + "updated_2.csv": seeds_altered_base_csv + seeds_altered_added_csv, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "hive_snapshot_strategy_check_cols"} + + @pytest.fixture(scope="class") + def snapshots(self): + return { + "cc_all_snapshot.sql": cc_all_snapshot_sql, + "cc_date_snapshot.sql": cc_date_snapshot_sql, + "cc_name_snapshot.sql": cc_name_snapshot_sql, + } + + def test_snapshot_check_cols(self, project): + # seed command + results = run_dbt(["seed"]) + assert len(results) == 4 + + # snapshot command + results = run_dbt(["snapshot"]) + for result in results: + assert result.status == "success" + + check_relation_columns(project, "cc_all_snapshot", 7) + check_relation_columns(project, "cc_name_snapshot", 7) + check_relation_columns(project, "cc_date_snapshot", 7) + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 10) + check_relation_rows(project, "cc_name_snapshot", 10) + check_relation_rows(project, "cc_date_snapshot", 10) + + relation_from_name(project.adapter, "cc_all_snapshot") + + # point at the "added" seed so the snapshot sees 10 new rows + results = run_dbt(["--no-partial-parse", "snapshot", "--vars", "seed_name: added"]) + for result in results: + assert result.status == "success" + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 20) + check_relation_rows(project, "cc_name_snapshot", 20) + check_relation_rows(project, "cc_date_snapshot", 20) + + # re-run snapshots, using "updated_1" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_1"]) + for result in results: + assert result.status == "success" + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 30) + check_relation_rows(project, "cc_date_snapshot", 30) + # unchanged: only the timestamp changed + check_relation_rows(project, "cc_name_snapshot", 20) + + # re-run snapshots, using "updated_2" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_2"]) + for result in results: + assert result.status == "success" + + # check rowcounts for all snapshots + check_relation_rows(project, "cc_all_snapshot", 40) + check_relation_rows(project, "cc_name_snapshot", 30) + # does not see name updates + check_relation_rows(project, "cc_date_snapshot", 30) + + +class TestSnapshotTimestampHive: + @pytest.fixture(scope="class") + def seeds(self): + return { + "base.csv": seeds_base_csv, + "added.csv": seeds_added_csv, + "updated_1.csv": seeds_base_csv + seeds_altered_added_csv, + "updated_2.csv": seeds_altered_base_csv + seeds_altered_added_csv, + } + + 
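# same seed progression as the check_cols suites, snapshotted with the timestamp strategy + 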
@pytest.fixture(scope="class") + def snapshots(self): + return { + "ts_snapshot.sql": ts_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "hive_snapshot_strategy_timestamp"} + + def test_snapshot_timestamp(self, project): + # seed command + results = run_dbt(["seed"]) + assert len(results) == 4 + + # snapshot command + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relation_columns(project, "ts_snapshot", 7) + + # snapshot has 10 rows + check_relation_rows(project, "ts_snapshot", 10) + + # point at the "added" seed so the snapshot sees 10 new rows + results = run_dbt(["snapshot", "--vars", "seed_name: added"]) + for result in results: + assert result.status == "success" + + # snapshot now has 20 rows + check_relation_rows(project, "ts_snapshot", 20) + + # re-run snapshots, using "updated_1" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_1"]) + for result in results: + assert result.status == "success" + + # snapshot now has 30 rows + check_relation_rows(project, "ts_snapshot", 30) + + # re-run snapshots, using "updated_2" + results = run_dbt(["snapshot", "--vars", "seed_name: updated_2"]) + for result in results: + assert result.status == "success" + + # snapshot still has 30 rows because timestamp not updated + check_relation_rows(project, "ts_snapshot", 30) + + +class TestIcebergSnapshotTimestamp(TestSnapshotTimestampHive): + @pytest.fixture(scope="class") + def snapshots(self): + return { + "ts_snapshot.sql": iceberg_ts_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "iceberg_snapshot_strategy_timestamp"} + + +class TestIcebergSnapshotCheckCols(TestSnapshotCheckColsHive): + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "iceberg_snapshot_strategy_check_cols"} + + @pytest.fixture(scope="class") + def snapshots(self): + return { + "cc_all_snapshot.sql": iceberg_cc_all_snapshot_sql, + "cc_date_snapshot.sql": iceberg_cc_date_snapshot_sql, + "cc_name_snapshot.sql": iceberg_cc_name_snapshot_sql, + } diff --git a/dbt-athena/tests/functional/adapter/test_unique_tmp_table_suffix.py b/dbt-athena/tests/functional/adapter/test_unique_tmp_table_suffix.py new file mode 100644 index 00000000..0f188ce1 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_unique_tmp_table_suffix.py @@ -0,0 +1,130 @@ +import re + +import pytest +from tests.functional.adapter.utils.parse_dbt_run_output import ( + extract_create_statement_table_names, + extract_running_create_statements, +) + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +models__unique_tmp_table_suffix_sql = """ +{{ config( + materialized='incremental', + incremental_strategy='insert_overwrite', + partitioned_by=['date_column'], + unique_tmp_table_suffix=True + ) +}} +select + random() as rnd, + cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column +""" + + +class TestUniqueTmpTableSuffix: + @pytest.fixture(scope="class") + def models(self): + return {"unique_tmp_table_suffix.sql": models__unique_tmp_table_suffix_sql} + + def test__unique_tmp_table_suffix(self, project, capsys): + relation_name = "unique_tmp_table_suffix" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + expected_unique_table_name_re = ( + r"unique_tmp_table_suffix__dbt_tmp_" + r"[0-9a-fA-F]{8}_[0-9a-fA-F]{4}_[0-9a-fA-F]{4}_[0-9a-fA-F]{4}_[0-9a-fA-F]{12}" + ) + + 
first_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + '{"logical_date": "2024-01-01"}', + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + first_model_run_result = first_model_run.results[0] + + assert first_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert len(athena_running_create_statements) == 1 + + first_model_run_result_table_name = extract_create_statement_table_names(athena_running_create_statements[0])[0] + + # Run statements logged output should not contain unique table suffix after first run + assert not bool(re.search(expected_unique_table_name_re, first_model_run_result_table_name)) + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 1 + + incremental_model_run = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + '{"logical_date": "2024-01-02"}', + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + incremental_model_run_result = incremental_model_run.results[0] + + assert incremental_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + assert len(athena_running_create_statements) == 1 + + incremental_model_run_result_table_name = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + # Run statements logged for subsequent incremental model runs should use unique table suffix + assert bool(re.search(expected_unique_table_name_re, incremental_model_run_result_table_name)) + + assert first_model_run_result_table_name != incremental_model_run_result_table_name + + incremental_model_run_2 = run_dbt( + [ + "run", + "--select", + relation_name, + "--vars", + '{"logical_date": "2024-01-03"}', + "--log-level", + "debug", + "--log-format", + "json", + ] + ) + + incremental_model_run_result = incremental_model_run_2.results[0] + + assert incremental_model_run_result.status == RunStatus.Success + + out, _ = capsys.readouterr() + athena_running_create_statements = extract_running_create_statements(out, relation_name) + + incremental_model_run_result_table_name_2 = extract_create_statement_table_names( + athena_running_create_statements[0] + )[0] + + assert incremental_model_run_result_table_name != incremental_model_run_result_table_name_2 + + assert first_model_run_result_table_name != incremental_model_run_result_table_name_2 diff --git a/dbt-athena/tests/functional/adapter/test_unit_testing.py b/dbt-athena/tests/functional/adapter/test_unit_testing.py new file mode 100644 index 00000000..5ec246c2 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/test_unit_testing.py @@ -0,0 +1,36 @@ +import pytest + +from dbt.tests.adapter.unit_testing.test_case_insensitivity import ( + BaseUnitTestCaseInsensivity, +) +from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput +from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes + + +class TestAthenaUnitTestingTypes(BaseUnitTestingTypes): + @pytest.fixture + def data_types(self): + # sql_value, yaml_value + return [ + ["1", "1"], + ["2.0", "2.0"], + ["'12345'", "12345"], + ["'string'", "string"], + ["true", "true"], + ["date '2024-04-01'", "2024-04-01"], + ["timestamp '2024-04-01 00:00:00.000'", "'2024-04-01 00:00:00.000'"], + # TODO: activate once safe_cast 
supports complex structures + # ["array[1, 2, 3]", "[1, 2, 3]"], + # [ + # "map(array['10', '15', '20'], array['t', 'f', NULL])", + # """'{"10: "t", "15": "f", "20": null}'""", + # ], + ] + + +class TestAthenaUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity): + pass + + +class TestAthenaUnitTestInvalidInput(BaseUnitTestInvalidInput): + pass diff --git a/dbt-athena/tests/functional/adapter/utils/parse_dbt_run_output.py b/dbt-athena/tests/functional/adapter/utils/parse_dbt_run_output.py new file mode 100644 index 00000000..4f448420 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/utils/parse_dbt_run_output.py @@ -0,0 +1,36 @@ +import json +import re +from typing import List + + +def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]: + sql_create_statements = [] + # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..." + for events_msg in dbt_run_capsys_output.split("\n")[1:]: + base_msg_data = None + # Best effort solution to avoid invalid records and blank lines + try: + base_msg_data = json.loads(events_msg).get("data") + except json.JSONDecodeError: + pass + """First run will not produce data.sql object in the execution logs, only data.base_msg + containing the "Running Athena query:" initial create statement. + Subsequent incremental runs will only contain the insert from the tmp table into the model + table destination. + Since we want to compare both run create statements, we need to handle both cases""" + if base_msg_data: + base_msg = base_msg_data.get("base_msg") + if "Running Athena query:" in str(base_msg): + if "create table" in base_msg: + sql_create_statements.append(base_msg) + + if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data: + if "create table" in base_msg_data.get("sql"): + sql_create_statements.append(base_msg_data.get("sql")) + + return sql_create_statements + + +def extract_create_statement_table_names(sql_create_statement: str) -> List[str]: + table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement) + return [table_name.rstrip() for table_name in table_names] diff --git a/dbt-athena/tests/functional/adapter/utils/test_utils.py b/dbt-athena/tests/functional/adapter/utils/test_utils.py new file mode 100644 index 00000000..ef13c062 --- /dev/null +++ b/dbt-athena/tests/functional/adapter/utils/test_utils.py @@ -0,0 +1,196 @@ +import pytest +from tests.functional.adapter.fixture_datediff import ( + models__test_datediff_sql, + seeds__data_datediff_csv, +) +from tests.functional.adapter.fixture_split_parts import ( + models__test_split_part_sql, + models__test_split_part_yml, +) + +from dbt.tests.adapter.utils.fixture_datediff import models__test_datediff_yml +from dbt.tests.adapter.utils.test_any_value import BaseAnyValue +from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend +from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat +from dbt.tests.adapter.utils.test_array_construct import BaseArrayConstruct +from dbt.tests.adapter.utils.test_bool_or import BaseBoolOr +from dbt.tests.adapter.utils.test_concat import BaseConcat +from dbt.tests.adapter.utils.test_current_timestamp import BaseCurrentTimestampNaive +from dbt.tests.adapter.utils.test_date_trunc import BaseDateTrunc +from dbt.tests.adapter.utils.test_dateadd import BaseDateAdd +from dbt.tests.adapter.utils.test_datediff import BaseDateDiff +from dbt.tests.adapter.utils.test_escape_single_quotes import ( + 
BaseEscapeSingleQuotesQuote, +) +from dbt.tests.adapter.utils.test_except import BaseExcept +from dbt.tests.adapter.utils.test_hash import BaseHash +from dbt.tests.adapter.utils.test_intersect import BaseIntersect +from dbt.tests.adapter.utils.test_length import BaseLength +from dbt.tests.adapter.utils.test_listagg import BaseListagg +from dbt.tests.adapter.utils.test_position import BasePosition +from dbt.tests.adapter.utils.test_replace import BaseReplace +from dbt.tests.adapter.utils.test_right import BaseRight +from dbt.tests.adapter.utils.test_split_part import BaseSplitPart +from dbt.tests.adapter.utils.test_string_literal import BaseStringLiteral + +models__array_concat_expected_sql = """ +select 1 as id, {{ array_construct([1,2,3,4,5,6]) }} as array_col +""" + + +models__array_concat_actual_sql = """ +select 1 as id, {{ array_concat(array_construct([1,2,3]), array_construct([4,5,6])) }} as array_col +""" + + +class TestAnyValue(BaseAnyValue): + pass + + +class TestArrayAppend(BaseArrayAppend): + pass + + +class TestArrayConstruct(BaseArrayConstruct): + pass + + +# Altered test because can't merge empty and non-empty arrays. +class TestArrayConcat(BaseArrayConcat): + @pytest.fixture(scope="class") + def models(self): + return { + "actual.sql": models__array_concat_actual_sql, + "expected.sql": models__array_concat_expected_sql, + } + + +class TestBoolOr(BaseBoolOr): + pass + + +class TestConcat(BaseConcat): + pass + + +class TestEscapeSingleQuotes(BaseEscapeSingleQuotesQuote): + pass + + +class TestExcept(BaseExcept): + pass + + +class TestHash(BaseHash): + pass + + +class TestIntersect(BaseIntersect): + pass + + +class TestLength(BaseLength): + pass + + +class TestPosition(BasePosition): + pass + + +class TestReplace(BaseReplace): + pass + + +class TestRight(BaseRight): + pass + + +class TestSplitPart(BaseSplitPart): + @pytest.fixture(scope="class") + def models(self): + return { + "test_split_part.yml": models__test_split_part_yml, + "test_split_part.sql": self.interpolate_macro_namespace(models__test_split_part_sql, "split_part"), + } + + +class TestStringLiteral(BaseStringLiteral): + pass + + +class TestDateTrunc(BaseDateTrunc): + pass + + +class TestDateAdd(BaseDateAdd): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "name": "test_date_add", + "seeds": { + "test": { + "date_add": { + "data_dateadd": { + "+column_types": { + "from_time": "timestamp", + "result": "timestamp", + }, + } + } + }, + }, + } + + +class TestDateDiff(BaseDateDiff): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "name": "test_date_diff", + "seeds": { + "test": { + "date_diff": { + "data_datediff": { + "+column_types": { + "first_date": "timestamp", + "second_date": "timestamp", + }, + } + } + }, + }, + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"data_datediff.csv": seeds__data_datediff_csv} + + @pytest.fixture(scope="class") + def models(self): + return { + "test_datediff.yml": models__test_datediff_yml, + "test_datediff.sql": self.interpolate_macro_namespace(models__test_datediff_sql, "datediff"), + } + + +# TODO: Activate this once we have datatypes.sql macro +# class TestCastBoolToText(BaseCastBoolToText): +# pass + + +# TODO: Activate this once we have datatypes.sql macro +# class TestSafeCast(BaseSafeCast): +# pass + + +class TestListagg(BaseListagg): + pass + + +# TODO: Implement this macro when needed +# class TestLastDay(BaseLastDay): +# pass + + +class TestCurrentTimestamp(BaseCurrentTimestampNaive): + pass 
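
The utility tests above all follow one pattern: inherit a base case from `dbt-tests-adapter` and override only the fixtures (`models`, `seeds`, `project_config_update`) that need Athena-specific column types or macro-namespaced SQL. The sketch below is commentary, not part of the patch: `BaseConcat` and the fixture name come from the imports above, while the project name, seed name, and column type are illustrative placeholders.

import pytest

from dbt.tests.adapter.utils.test_concat import BaseConcat


class TestConcatWithSeedTypes(BaseConcat):
    # Inherit the base models and assertions; only pin seed column types,
    # mirroring the TestDateAdd / TestDateDiff overrides above.
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {
            "name": "example_concat_override",  # placeholder project name
            "seeds": {
                "test": {
                    "data_concat": {  # placeholder seed name
                        "+column_types": {"some_column": "timestamp"},
                    }
                }
            },
        }

In an actual run the base class supplies the models, seeds, and assertions; only the fixture overridden here changes, which is why the classes above can stay so small.
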
diff --git a/dbt-athena/tests/functional/conftest.py b/dbt-athena/tests/functional/conftest.py new file mode 100644 index 00000000..1591459b --- /dev/null +++ b/dbt-athena/tests/functional/conftest.py @@ -0,0 +1,26 @@ +import os + +import pytest + +# Import the functional fixtures as a plugin +# Note: fixtures with session scope need to be local +pytest_plugins = ["dbt.tests.fixtures.project"] + + +# The profile dictionary, used to write out profiles.yml +@pytest.fixture(scope="class") +def dbt_profile_target(): + return { + "type": "athena", + "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"), + "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"), + "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"), + "database": os.getenv("DBT_TEST_ATHENA_DATABASE"), + "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"), + "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"), + "threads": int(os.getenv("DBT_TEST_ATHENA_THREADS", "1")), + "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")), + "num_retries": int(os.getenv("DBT_TEST_ATHENA_NUM_RETRIES", "2")), + "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, + "spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"), + } diff --git a/dbt-athena/tests/unit/__init__.py b/dbt-athena/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbt-athena/tests/unit/conftest.py b/dbt-athena/tests/unit/conftest.py new file mode 100644 index 00000000..21d051c7 --- /dev/null +++ b/dbt-athena/tests/unit/conftest.py @@ -0,0 +1,75 @@ +from io import StringIO +import os +from unittest.mock import MagicMock, patch + +import boto3 +import pytest + +from dbt_common.events import get_event_manager +from dbt_common.events.base_types import EventLevel +from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter + +from dbt.adapters.athena import connections +from dbt.adapters.athena.connections import AthenaCredentials + +from tests.unit.utils import MockAWSService +from tests.unit import constants + + +@pytest.fixture(scope="class") +def athena_client(): + with patch.object(boto3.session.Session, "client", return_value=MagicMock()) as mock_athena_client: + return mock_athena_client + + +@pytest.fixture(scope="function") +def aws_credentials(): + """Mocked AWS Credentials for moto.""" + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = constants.AWS_REGION + + +@patch.object(connections, "AthenaCredentials") +@pytest.fixture(scope="class") +def athena_credentials(): + return AthenaCredentials( + database=constants.DATA_CATALOG_NAME, + schema=constants.DATABASE_NAME, + s3_staging_dir=constants.S3_STAGING_DIR, + region_name=constants.AWS_REGION, + work_group=constants.ATHENA_WORKGROUP, + spark_work_group=constants.SPARK_WORKGROUP, + ) + + +@pytest.fixture() +def mock_aws_service(aws_credentials) -> MockAWSService: + return MockAWSService() + + +@pytest.fixture(scope="function") +def dbt_error_caplog() -> StringIO: + return _setup_custom_caplog("dbt_error", EventLevel.ERROR) + + +@pytest.fixture(scope="function") +def dbt_debug_caplog() -> StringIO: + return _setup_custom_caplog("dbt_debug", EventLevel.DEBUG) + + +def _setup_custom_caplog(name: str, level: EventLevel): + string_buf = StringIO() + capture_config = LoggerConfig( + name=name, + level=level, + use_colors=False, + 
line_format=LineFormat.PlainText, + filter=NoFilter, + output_stream=string_buf, + ) + event_manager = get_event_manager() + event_manager.add_logger(capture_config) + return string_buf diff --git a/dbt-athena/tests/unit/constants.py b/dbt-athena/tests/unit/constants.py new file mode 100644 index 00000000..f549ad64 --- /dev/null +++ b/dbt-athena/tests/unit/constants.py @@ -0,0 +1,11 @@ +CATALOG_ID = "12345678910" +DATA_CATALOG_NAME = "awsdatacatalog" +SHARED_DATA_CATALOG_NAME = "9876543210" +FEDERATED_QUERY_CATALOG_NAME = "federated_query_data_source" +DATABASE_NAME = "test_dbt_athena" +BUCKET = "test-dbt-athena" +AWS_REGION = "eu-west-1" +S3_STAGING_DIR = "s3://my-bucket/test-dbt/" +S3_TMP_TABLE_DIR = "s3://my-bucket/test-dbt-temp/" +ATHENA_WORKGROUP = "dbt-athena-adapter" +SPARK_WORKGROUP = "spark" diff --git a/dbt-athena/tests/unit/fixtures.py b/dbt-athena/tests/unit/fixtures.py new file mode 100644 index 00000000..ba3d048e --- /dev/null +++ b/dbt-athena/tests/unit/fixtures.py @@ -0,0 +1 @@ +seed_data = [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}] diff --git a/dbt-athena/tests/unit/test_adapter.py b/dbt-athena/tests/unit/test_adapter.py new file mode 100644 index 00000000..e3c3b923 --- /dev/null +++ b/dbt-athena/tests/unit/test_adapter.py @@ -0,0 +1,1330 @@ +import datetime +import decimal +from multiprocessing import get_context +from unittest import mock +from unittest.mock import patch + +import agate +import boto3 +import botocore +import pytest +from dbt_common.clients import agate_helper +from dbt_common.exceptions import ConnectionError, DbtRuntimeError +from moto import mock_aws +from moto.core import DEFAULT_ACCOUNT_ID + +from dbt.adapters.athena import AthenaAdapter +from dbt.adapters.athena import Plugin as AthenaPlugin +from dbt.adapters.athena.column import AthenaColumn +from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter +from dbt.adapters.athena.exceptions import S3LocationException +from dbt.adapters.athena.relation import AthenaRelation, TableType +from dbt.adapters.athena.utils import AthenaCatalogType +from dbt.adapters.contracts.connection import ConnectionState +from dbt.adapters.contracts.relation import RelationType + +from .constants import ( + ATHENA_WORKGROUP, + AWS_REGION, + BUCKET, + DATA_CATALOG_NAME, + DATABASE_NAME, + FEDERATED_QUERY_CATALOG_NAME, + S3_STAGING_DIR, + SHARED_DATA_CATALOG_NAME, +) +from .fixtures import seed_data +from .utils import TestAdapterConversions, config_from_parts_or_dicts, inject_adapter + + +class TestAthenaAdapter: + def setup_method(self, _): + self.config = TestAthenaAdapter._config_from_settings() + self._adapter = None + self.used_schemas = frozenset( + { + ("awsdatacatalog", "foo"), + ("awsdatacatalog", "quux"), + ("awsdatacatalog", "baz"), + (SHARED_DATA_CATALOG_NAME, "foo"), + (FEDERATED_QUERY_CATALOG_NAME, "foo"), + } + ) + + @property + def adapter(self): + if self._adapter is None: + self._adapter = AthenaAdapter(self.config, get_context("spawn")) + inject_adapter(self._adapter, AthenaPlugin) + return self._adapter + + @staticmethod + def _config_from_settings(settings={}): + project_cfg = { + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "config-version": 2, + } + + profile_cfg = { + "outputs": { + "test": { + **{ + "type": "athena", + "s3_staging_dir": S3_STAGING_DIR, + "region_name": AWS_REGION, + "database": DATA_CATALOG_NAME, + "work_group": ATHENA_WORKGROUP, + "schema": DATABASE_NAME, + }, + **settings, + } + }, + 
"target": "test", + } + + return config_from_parts_or_dicts(project_cfg, profile_cfg) + + @mock.patch("dbt.adapters.athena.connections.AthenaConnection") + def test_acquire_connection_validations(self, connection_cls): + try: + connection = self.adapter.acquire_connection("dummy") + except DbtRuntimeError as e: + pytest.fail(f"got ValidationException: {e}") + except BaseException as e: + pytest.fail(f"acquiring connection failed with unknown exception: {e}") + + connection_cls.assert_not_called() + connection.handle + connection_cls.assert_called_once() + _, arguments = connection_cls.call_args_list[0] + assert arguments["s3_staging_dir"] == "s3://my-bucket/test-dbt/" + assert arguments["endpoint_url"] is None + assert arguments["schema_name"] == "test_dbt_athena" + assert arguments["work_group"] == "dbt-athena-adapter" + assert arguments["cursor_class"] == AthenaCursor + assert isinstance(arguments["formatter"], AthenaParameterFormatter) + assert arguments["poll_interval"] == 1.0 + assert arguments["retry_config"].attempt == 6 + assert arguments["retry_config"].exceptions == ( + "ThrottlingException", + "TooManyRequestsException", + "InternalServerException", + ) + + @mock.patch("dbt.adapters.athena.connections.AthenaConnection") + def test_acquire_connection(self, connection_cls): + connection = self.adapter.acquire_connection("dummy") + + connection_cls.assert_not_called() + connection.handle + assert connection.state == ConnectionState.OPEN + assert connection.handle is not None + connection_cls.assert_called_once() + + @mock.patch("dbt.adapters.athena.connections.AthenaConnection") + def test_acquire_connection_exc(self, connection_cls, dbt_error_caplog): + connection_cls.side_effect = lambda **_: (_ for _ in ()).throw(Exception("foobar")) + connection = self.adapter.acquire_connection("dummy") + conn_res = None + with pytest.raises(ConnectionError) as exc: + conn_res = connection.handle + + assert conn_res is None + assert connection.state == ConnectionState.FAIL + assert exc.value.__str__() == "foobar" + assert "Got an error when attempting to open a Athena connection due to foobar" in dbt_error_caplog.getvalue() + + @pytest.mark.parametrize( + ( + "s3_data_dir", + "s3_data_naming", + "s3_path_table_part", + "s3_tmp_table_dir", + "external_location", + "is_temporary_table", + "expected", + ), + ( + pytest.param( + None, "table", None, None, None, False, "s3://my-bucket/test-dbt/tables/table", id="table naming" + ), + pytest.param( + None, "unique", None, None, None, False, "s3://my-bucket/test-dbt/tables/uuid", id="unique naming" + ), + pytest.param( + None, + "table_unique", + None, + None, + None, + False, + "s3://my-bucket/test-dbt/tables/table/uuid", + id="table_unique naming", + ), + pytest.param( + None, + "schema_table", + None, + None, + None, + False, + "s3://my-bucket/test-dbt/tables/schema/table", + id="schema_table naming", + ), + pytest.param( + None, + "schema_table_unique", + None, + None, + None, + False, + "s3://my-bucket/test-dbt/tables/schema/table/uuid", + id="schema_table_unique naming", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + None, + None, + False, + "s3://my-data-bucket/schema/table/uuid", + id="data_dir set", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + None, + "s3://path/to/external/", + False, + "s3://path/to/external", + id="external_location set and not temporary", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + "s3://my-bucket/test-dbt-temp/", 
+ "s3://path/to/external/", + True, + "s3://my-bucket/test-dbt-temp/schema/table/uuid", + id="s3_tmp_table_dir set, external_location set and temporary", + ), + pytest.param( + "s3://my-data-bucket/", + "schema_table_unique", + None, + None, + "s3://path/to/external/", + True, + "s3://my-data-bucket/schema/table/uuid", + id="s3_tmp_table_dir is empty, external_location set and temporary", + ), + pytest.param( + None, + "schema_table_unique", + "other_table", + None, + None, + False, + "s3://my-bucket/test-dbt/tables/schema/other_table/uuid", + id="s3_path_table_part set", + ), + ), + ) + @patch("dbt.adapters.athena.impl.uuid4", return_value="uuid") + def test_generate_s3_location( + self, + _, + s3_data_dir, + s3_data_naming, + s3_tmp_table_dir, + external_location, + s3_path_table_part, + is_temporary_table, + expected, + ): + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema="schema", + identifier="table", + s3_path_table_part=s3_path_table_part, + ) + assert expected == self.adapter.generate_s3_location( + relation, s3_data_dir, s3_data_naming, s3_tmp_table_dir, external_location, is_temporary_table + ) + + @mock_aws + def test_get_table_location(self, dbt_debug_caplog, mock_aws_service): + table_name = "test_table" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table(table_name) + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + assert self.adapter.get_glue_table_location(relation) == "s3://test-dbt-athena/tables/test_table" + + @mock_aws + def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service): + table_name = "test_table" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table(table_name, location="") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + with pytest.raises(S3LocationException) as exc: + self.adapter.get_glue_table_location(relation) + assert exc.value.args[0] == ( + 'Relation "awsdatacatalog"."test_dbt_athena"."test_table" is of type \'table\' which requires a ' + "location, but no location returned by Glue." 
+ ) + + @mock_aws + def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service): + view_name = "view" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_view(view_name) + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=view_name, type=RelationType.View + ) + assert self.adapter.get_glue_table_location(relation) is None + + @mock_aws + def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_service): + table_name = "test_table" + self.adapter.acquire_connection("dummy") + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + assert self.adapter.get_glue_table_location(relation) is None + assert f"Table {relation.render()} does not exists - Ignoring" in dbt_debug_caplog.getvalue() + + @mock_aws + def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service): + table_name = "table" + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table(table_name) + mock_aws_service.add_data_in_table(table_name) + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.acquire_connection("dummy") + self.adapter.clean_up_partitions(relation, "dt < '2022-01-03'") + log_records = dbt_debug_caplog.getvalue() + assert ( + "Deleting table data: path=" + "'s3://test-dbt-athena/tables/table/dt=2022-01-01', " + "bucket='test-dbt-athena', " + "prefix='tables/table/dt=2022-01-01/'" in log_records + ) + assert ( + "Deleting table data: path=" + "'s3://test-dbt-athena/tables/table/dt=2022-01-02', " + "bucket='test-dbt-athena', " + "prefix='tables/table/dt=2022-01-02/'" in log_records + ) + s3 = boto3.client("s3", region_name=AWS_REGION) + keys = [obj["Key"] for obj in s3.list_objects_v2(Bucket=BUCKET)["Contents"]] + assert set(keys) == {"tables/table/dt=2022-01-03/data1.parquet", "tables/table/dt=2022-01-03/data2.parquet"} + + @mock_aws + def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="table", + ) + result = self.adapter.clean_up_table(relation) + assert result is None + assert ( + 'Table "awsdatacatalog"."test_dbt_athena"."table" does not exists - Ignoring' in dbt_debug_caplog.getvalue() + ) + + @mock_aws + def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + mock_aws_service.create_view("test_view") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="test_view", + type=RelationType.View, + ) + result = self.adapter.clean_up_table(relation) + assert result is None + + @mock_aws + def test_clean_up_table_delete_table(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("table") + mock_aws_service.add_data_in_table("table") + self.adapter.acquire_connection("dummy") + 
relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="table", + ) + self.adapter.clean_up_table(relation) + assert ( + "Deleting table data: path='s3://test-dbt-athena/tables/table', " + "bucket='test-dbt-athena', " + "prefix='tables/table/'" in dbt_debug_caplog.getvalue() + ) + s3 = boto3.client("s3", region_name=AWS_REGION) + objs = s3.list_objects_v2(Bucket=BUCKET) + assert objs["KeyCount"] == 0 + + @pytest.mark.parametrize( + "column,quote_config,quote_character,expected", + [ + pytest.param("col", False, None, "col"), + pytest.param("col", True, None, '"col"'), + pytest.param("col", False, "`", "col"), + pytest.param("col", True, "`", "`col`"), + ], + ) + def test_quote_seed_column(self, column, quote_config, quote_character, expected): + assert self.adapter.quote_seed_column(column, quote_config, quote_character) == expected + + @mock_aws + def test__get_one_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database("foo") + mock_aws_service.create_database("quux") + mock_aws_service.create_database("baz") + mock_aws_service.create_table(table_name="bar", database_name="foo") + mock_aws_service.create_table(table_name="bar", database_name="quux") + mock_information_schema = mock.MagicMock() + mock_information_schema.database = "awsdatacatalog" + + self.adapter.acquire_connection("dummy") + actual = self.adapter._get_one_catalog(mock_information_schema, {"foo", "quux"}, self.used_schemas) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + expected_rows = [ + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None), + ("awsdatacatalog", "quux", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "quux", "bar", "table", None, "country", 1, "string", None), + ("awsdatacatalog", "quux", "bar", "table", None, "dt", 2, "date", None), + ] + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + + mock_aws_service.create_table_without_type(table_name="qux", database_name="baz") + with pytest.raises(ValueError): + self.adapter._get_one_catalog(mock_information_schema, {"baz"}, self.used_schemas) + + @mock_aws + def test__get_one_catalog_by_relations(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database("foo") + mock_aws_service.create_database("quux") + mock_aws_service.create_table(database_name="foo", table_name="bar") + # we create another relation + mock_aws_service.create_table(table_name="bar", database_name="quux") + + mock_information_schema = mock.MagicMock() + mock_information_schema.database = "awsdatacatalog" + + self.adapter.acquire_connection("dummy") + + rel_1 = self.adapter.Relation.create( + database="awsdatacatalog", + schema="foo", + identifier="bar", + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + + expected_rows = [ + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "foo", "bar", 
"table", None, "country", 1, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None), + ] + + actual = self.adapter._get_one_catalog_by_relations(mock_information_schema, [rel_1], self.used_schemas) + assert actual.column_names == expected_column_names + assert actual.rows == expected_rows + + @mock_aws + def test__get_one_catalog_shared_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog(catalog_name=SHARED_DATA_CATALOG_NAME, catalog_id=SHARED_DATA_CATALOG_NAME) + mock_aws_service.create_database("foo", catalog_id=SHARED_DATA_CATALOG_NAME) + mock_aws_service.create_table(table_name="bar", database_name="foo", catalog_id=SHARED_DATA_CATALOG_NAME) + mock_information_schema = mock.MagicMock() + mock_information_schema.database = SHARED_DATA_CATALOG_NAME + + self.adapter.acquire_connection("dummy") + actual = self.adapter._get_one_catalog( + mock_information_schema, + {"foo"}, + self.used_schemas, + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + expected_rows = [ + ("9876543210", "foo", "bar", "table", None, "id", 0, "string", None), + ("9876543210", "foo", "bar", "table", None, "country", 1, "string", None), + ("9876543210", "foo", "bar", "table", None, "dt", 2, "date", None), + ] + + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + + @mock_aws + def test__get_one_catalog_federated_query_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog( + catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA + ) + mock_information_schema = mock.MagicMock() + mock_information_schema.database = FEDERATED_QUERY_CATALOG_NAME + + # Original botocore _make_api_call function + orig = botocore.client.BaseClient._make_api_call + + # Mocking this as list_table_metadata and creating non-glue tables is not supported by moto. 
+ # Followed this guide: http://docs.getmoto.org/en/latest/docs/services/patching_other_services.html + def mock_athena_list_table_metadata(self, operation_name, kwarg): + if operation_name == "ListTableMetadata": + return { + "TableMetadataList": [ + { + "Name": "bar", + "TableType": "EXTERNAL_TABLE", + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + } + ], + } + # If we don't want to patch the API call + return orig(self, operation_name, kwarg) + + self.adapter.acquire_connection("dummy") + with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata): + actual = self.adapter._get_one_catalog( + mock_information_schema, + {"foo"}, + self.used_schemas, + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + ) + expected_rows = [ + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None), + ] + + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + + @mock_aws + def test__get_data_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + self.adapter.acquire_connection("dummy") + res = self.adapter._get_data_catalog(DATA_CATALOG_NAME) + assert {"Name": "awsdatacatalog", "Type": "GLUE", "Parameters": {"catalog-id": DEFAULT_ACCOUNT_ID}} == res + + def _test_list_relations_without_caching(self, schema_relation): + self.adapter.acquire_connection("dummy") + relations = self.adapter.list_relations_without_caching(schema_relation) + assert len(relations) == 4 + assert all(isinstance(rel, AthenaRelation) for rel in relations) + relations.sort(key=lambda rel: rel.name) + iceberg_table = relations[0] + other = relations[1] + table = relations[2] + view = relations[3] + assert iceberg_table.name == "iceberg" + assert iceberg_table.type == "table" + assert iceberg_table.detailed_table_type == "ICEBERG" + assert other.name == "other" + assert other.type == "table" + assert other.detailed_table_type == "" + assert table.name == "table" + assert table.type == "table" + assert table.detailed_table_type == "" + assert view.name == "view" + assert view.type == "view" + assert view.detailed_table_type == "" + + @mock_aws + def test_list_relations_without_caching_with_awsdatacatalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("table") + mock_aws_service.create_table("other") + mock_aws_service.create_view("view") + mock_aws_service.create_table_without_table_type("without_table_type") + mock_aws_service.create_iceberg_table("iceberg") + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + quote_policy=self.adapter.config.quoting, + ) + self._test_list_relations_without_caching(schema_relation) + + @mock_aws + def test_list_relations_without_caching_with_other_glue_data_catalog(self, mock_aws_service): + data_catalog_name = "other_data_catalog" + mock_aws_service.create_data_catalog(data_catalog_name) + 
mock_aws_service.create_database() + mock_aws_service.create_table("table") + mock_aws_service.create_table("other") + mock_aws_service.create_view("view") + mock_aws_service.create_table_without_table_type("without_table_type") + mock_aws_service.create_iceberg_table("iceberg") + schema_relation = self.adapter.Relation.create( + database=data_catalog_name, + schema=DATABASE_NAME, + quote_policy=self.adapter.config.quoting, + ) + self._test_list_relations_without_caching(schema_relation) + + @mock_aws + def test_list_relations_without_caching_on_unknown_schema(self, mock_aws_service): + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema="unknown_schema", + quote_policy=self.adapter.config.quoting, + ) + self.adapter.acquire_connection("dummy") + relations = self.adapter.list_relations_without_caching(schema_relation) + assert relations == [] + + @mock_aws + @patch("dbt.adapters.athena.impl.SQLAdapter.list_relations_without_caching", return_value=[]) + def test_list_relations_without_caching_with_non_glue_data_catalog( + self, parent_list_relations_without_caching, mock_aws_service + ): + data_catalog_name = "other_data_catalog" + mock_aws_service.create_data_catalog(data_catalog_name, AthenaCatalogType.HIVE) + schema_relation = self.adapter.Relation.create( + database=data_catalog_name, + schema=DATABASE_NAME, + quote_policy=self.adapter.config.quoting, + ) + self.adapter.acquire_connection("dummy") + self.adapter.list_relations_without_caching(schema_relation) + parent_list_relations_without_caching.assert_called_once_with(schema_relation) + + @pytest.mark.parametrize( + "s3_path,expected", + [ + ("s3://my-bucket/test-dbt/tables/schema/table", ("my-bucket", "test-dbt/tables/schema/table/")), + ("s3://my-bucket/test-dbt/tables/schema/table/", ("my-bucket", "test-dbt/tables/schema/table/")), + ], + ) + def test_parse_s3_path(self, s3_path, expected): + assert self.adapter._parse_s3_path(s3_path) == expected + + @mock_aws + def test_swap_table_with_partitions(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + mock_aws_service.create_table(source_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table) + mock_aws_service.create_table(target_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, target_table) + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=target_table, + ) + self.adapter.swap_table(source_relation, target_relation) + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + + @mock_aws + def test_swap_table_without_partitions(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + mock_aws_service.create_table_without_partitions(source_table) + mock_aws_service.create_table_without_partitions(target_table) + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + 
identifier=target_table, + ) + self.adapter.swap_table(source_relation, target_relation) + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + + @mock_aws + def test_swap_table_with_partitions_to_one_without(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + # source table does not have partitions + mock_aws_service.create_table_without_partitions(source_table) + + # the target table has partitions + mock_aws_service.create_table(target_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, target_table) + + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=target_table, + ) + + self.adapter.swap_table(source_relation, target_relation) + glue_client = boto3.client("glue", region_name=AWS_REGION) + + target_table_partitions = glue_client.get_partitions(DatabaseName=DATABASE_NAME, TableName=target_table).get( + "Partitions" + ) + + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + assert len(target_table_partitions) == 0 + + @mock_aws + def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + target_table = "target_table" + source_table = "source_table" + mock_aws_service.create_table(source_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table) + mock_aws_service.create_table_without_partitions(target_table) + glue_client = boto3.client("glue", region_name=AWS_REGION) + target_table_partitions = glue_client.get_partitions(DatabaseName=DATABASE_NAME, TableName=target_table).get( + "Partitions" + ) + assert len(target_table_partitions) == 0 + source_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=source_table, + ) + target_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=target_table, + ) + self.adapter.swap_table(source_relation, target_relation) + target_table_partitions_after = glue_client.get_partitions( + DatabaseName=DATABASE_NAME, TableName=target_table + ).get("Partitions") + + assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" + assert len(target_table_partitions_after) == 26 + + @mock_aws + def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_caplog): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + glue = boto3.client("glue", region_name=AWS_REGION) + table_versions = glue.get_table_versions(DatabaseName=DATABASE_NAME, TableName=table_name).get("TableVersions") + assert len(table_versions) == 4 + version_to_keep = 1 + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + 
identifier=table_name, + ) + versions_to_expire = self.adapter._get_glue_table_versions_to_expire(relation, version_to_keep) + assert len(versions_to_expire) == 3 + assert [v["VersionId"] for v in versions_to_expire] == ["3", "2", "1"] + + @mock_aws + def test_expire_glue_table_versions(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + mock_aws_service.add_table_version(DATABASE_NAME, table_name) + glue = boto3.client("glue", region_name=AWS_REGION) + table_versions = glue.get_table_versions(DatabaseName=DATABASE_NAME, TableName=table_name).get("TableVersions") + assert len(table_versions) == 4 + version_to_keep = 1 + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.expire_glue_table_versions(relation, version_to_keep, False) + # TODO delete_table_version is not implemented in moto + # TODO moto issue https://github.com/getmoto/moto/issues/5952 + # assert len(result) == 3 + + @mock_aws + def test_upload_seed_to_s3(self, mock_aws_service): + seed_table = agate.Table.from_object(seed_data) + self.adapter.acquire_connection("dummy") + + database = "db_seeds" + table = "data" + + s3_client = boto3.client("s3", region_name=AWS_REGION) + s3_client.create_bucket(Bucket=BUCKET, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=database, + identifier=table, + ) + + location = self.adapter.upload_seed_to_s3( + relation, + seed_table, + s3_data_dir=f"s3://{BUCKET}", + s3_data_naming="schema_table", + external_location=None, + ) + + prefix = "db_seeds/data" + objects = s3_client.list_objects(Bucket=BUCKET, Prefix=prefix).get("Contents") + + assert location == f"s3://{BUCKET}/{prefix}" + assert len(objects) == 1 + assert objects[0].get("Key").endswith(".csv") + + @mock_aws + def test_upload_seed_to_s3_external_location(self, mock_aws_service): + seed_table = agate.Table.from_object(seed_data) + self.adapter.acquire_connection("dummy") + + bucket = "my-external-location" + prefix = "seeds/one" + external_location = f"s3://{bucket}/{prefix}" + + s3_client = boto3.client("s3", region_name=AWS_REGION) + s3_client.create_bucket(Bucket=bucket, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema="db_seeds", + identifier="data", + ) + + location = self.adapter.upload_seed_to_s3( + relation, + seed_table, + s3_data_dir=None, + s3_data_naming="schema_table", + external_location=external_location, + ) + + objects = s3_client.list_objects(Bucket=bucket, Prefix=prefix).get("Contents") + + assert location == f"s3://{bucket}/{prefix}" + assert len(objects) == 1 + assert objects[0].get("Key").endswith(".csv") + + @mock_aws + def test_get_work_group_output_location(self, mock_aws_service): + self.adapter.acquire_connection("dummy") + mock_aws_service.create_work_group_with_output_location_enforced(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert work_group_location_enforced + + def test_get_work_group_output_location_if_workgroup_check_is_skipepd(self): + settings = { + 
"skip_workgroup_check": True, + } + + self.config = TestAthenaAdapter._config_from_settings(settings) + self.adapter.acquire_connection("dummy") + + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_aws + def test_get_work_group_output_location_no_location(self, mock_aws_service): + self.adapter.acquire_connection("dummy") + mock_aws_service.create_work_group_no_output_location(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_aws + def test_get_work_group_output_location_not_enforced(self, mock_aws_service): + self.adapter.acquire_connection("dummy") + mock_aws_service.create_work_group_with_output_location_not_enforced(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_aws + def test_persist_docs_to_glue_no_comment(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.persist_docs_to_glue( + schema_relation, + { + "description": """ + A table with str, 123, &^% \" and ' + + and an other paragraph. + """, + "columns": { + "id": { + "meta": {"primary_key": "true"}, + "description": """ + A column with str, 123, &^% \" and ' + + and an other paragraph. + """, + } + }, + }, + False, + False, + ) + glue = boto3.client("glue", region_name=AWS_REGION) + table = glue.get_table(DatabaseName=DATABASE_NAME, Name=table_name).get("Table") + assert not table.get("Description", "") + assert not table["Parameters"].get("comment") + assert all(not col.get("Comment") for col in table["StorageDescriptor"]["Columns"]) + assert all(not col.get("Parameters") for col in table["StorageDescriptor"]["Columns"]) + + @mock_aws + def test_persist_docs_to_glue_comment(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + table_name = "my_table" + mock_aws_service.create_table(table_name) + schema_relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier=table_name, + ) + self.adapter.persist_docs_to_glue( + schema_relation, + { + "description": """ + A table with str, 123, &^% \" and ' + + and an other paragraph. + """, + "columns": { + "id": { + "meta": {"primary_key": True}, + "description": """ + A column with str, 123, &^% \" and ' + + and an other paragraph. + """, + } + }, + }, + True, + True, + ) + glue = boto3.client("glue", region_name=AWS_REGION) + table = glue.get_table(DatabaseName=DATABASE_NAME, Name=table_name).get("Table") + assert table["Description"] == "A table with str, 123, &^% \" and ' and an other paragraph." + assert table["Parameters"]["comment"] == "A table with str, 123, &^% \" and ' and an other paragraph." + col_id = [col for col in table["StorageDescriptor"]["Columns"] if col["Name"] == "id"][0] + assert col_id["Comment"] == "A column with str, 123, &^% \" and ' and an other paragraph." 
+ assert col_id["Parameters"] == {"primary_key": "True"} + + @mock_aws + def test_list_schemas(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database(name="foo") + mock_aws_service.create_database(name="bar") + mock_aws_service.create_database(name="quux") + self.adapter.acquire_connection("dummy") + res = self.adapter.list_schemas("") + assert sorted(res) == ["bar", "foo", "quux"] + + @mock_aws + def test_get_columns_in_relation(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("tbl_name") + self.adapter.acquire_connection("dummy") + columns = self.adapter.get_columns_in_relation( + self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="tbl_name", + ) + ) + assert columns == [ + AthenaColumn(column="id", dtype="string", table_type=TableType.TABLE), + AthenaColumn(column="country", dtype="string", table_type=TableType.TABLE), + AthenaColumn(column="dt", dtype="date", table_type=TableType.TABLE), + ] + + @mock_aws + def test_get_columns_in_relation_not_found_table(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + columns = self.adapter.get_columns_in_relation( + self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="tbl_name", + ) + ) + assert columns == [] + + @mock_aws + def test_delete_from_glue_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("tbl_name") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name") + self.adapter.delete_from_glue_catalog(relation) + glue = boto3.client("glue", region_name=AWS_REGION) + tables_list = glue.get_tables(DatabaseName=DATABASE_NAME).get("TableList") + assert tables_list == [] + + @mock_aws + def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("tbl_name") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_does_not_exist" + ) + delete_table = self.adapter.delete_from_glue_catalog(relation) + assert delete_table is None + error_msg = f"Table {relation.render()} does not exist and will not be deleted, ignoring" + assert error_msg in dbt_debug_caplog.getvalue() + + @mock_aws + def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table("test_table") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="test_table" + ) + table_type = self.adapter.get_glue_table_type(relation) + assert table_type == TableType.TABLE + + @mock_aws + def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_table_without_table_type("test_table") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, 
schema=DATABASE_NAME, identifier="test_table" + ) + with pytest.raises(ValueError): + self.adapter.get_glue_table_type(relation) + + @mock_aws + def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_view("test_view") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="test_view" + ) + table_type = self.adapter.get_glue_table_type(relation) + assert table_type == TableType.VIEW + + @mock_aws + def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service): + mock_aws_service.create_data_catalog() + mock_aws_service.create_database() + mock_aws_service.create_iceberg_table("test_iceberg") + self.adapter.acquire_connection("dummy") + relation = self.adapter.Relation.create( + database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="test_iceberg" + ) + table_type = self.adapter.get_glue_table_type(relation) + assert table_type == TableType.ICEBERG + + @pytest.mark.parametrize( + "column,expected", + [ + pytest.param({"Name": "user_id", "Type": "int", "Parameters": {"iceberg.field.current": "true"}}, True), + pytest.param({"Name": "user_id", "Type": "int", "Parameters": {"iceberg.field.current": "false"}}, False), + pytest.param({"Name": "user_id", "Type": "int"}, True), + ], + ) + def test__is_current_column(self, column, expected): + assert self.adapter._is_current_column(column) == expected + + @pytest.mark.parametrize( + "partition_keys, expected_result", + [ + ( + ["year(date_col)", "bucket(col_name, 10)", "default_partition_key"], + "date_trunc('year', date_col), col_name, default_partition_key", + ), + ], + ) + def test_format_partition_keys(self, partition_keys, expected_result): + assert self.adapter.format_partition_keys(partition_keys) == expected_result + + @pytest.mark.parametrize( + "partition_key, expected_result", + [ + ("month(hidden)", "date_trunc('month', hidden)"), + ("bucket(bucket_col, 10)", "bucket_col"), + ("regular_col", "regular_col"), + ], + ) + def test_format_one_partition_key(self, partition_key, expected_result): + assert self.adapter.format_one_partition_key(partition_key) == expected_result + + def test_murmur3_hash_with_int(self): + bucket_number = self.adapter.murmur3_hash(123, 100) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + assert bucket_number == 54 + + def test_murmur3_hash_with_date(self): + d = datetime.date.today() + bucket_number = self.adapter.murmur3_hash(d, 100) + assert isinstance(d, datetime.date) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + + def test_murmur3_hash_with_datetime(self): + dt = datetime.datetime.now() + bucket_number = self.adapter.murmur3_hash(dt, 100) + assert isinstance(dt, datetime.datetime) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + + def test_murmur3_hash_with_str(self): + bucket_number = self.adapter.murmur3_hash("test_string", 100) + assert isinstance(bucket_number, int) + assert 0 <= bucket_number < 100 + assert bucket_number == 88 + + def test_murmur3_hash_uniqueness(self): + # Ensuring different inputs produce different hashes + hash1 = self.adapter.murmur3_hash("string1", 100) + hash2 = self.adapter.murmur3_hash("string2", 100) + assert hash1 != hash2 + + def test_murmur3_hash_with_unsupported_type(self): + with pytest.raises(TypeError): + self.adapter.murmur3_hash([1, 2, 3], 100) + + 
@pytest.mark.parametrize( + "value, column_type, expected_result", + [ + (None, "integer", ("null", " is ")), + (42, "integer", ("42", "=")), + ("O'Reilly", "string", ("'O''Reilly'", "=")), + ("test", "string", ("'test'", "=")), + ("2021-01-01", "date", ("DATE'2021-01-01'", "=")), + ("2021-01-01 12:00:00", "timestamp", ("TIMESTAMP'2021-01-01 12:00:00'", "=")), + ], + ) + def test_format_value_for_partition(self, value, column_type, expected_result): + assert self.adapter.format_value_for_partition(value, column_type) == expected_result + + def test_format_unsupported_type(self): + with pytest.raises(ValueError): + self.adapter.format_value_for_partition("test", "unsupported_type") + + +class TestAthenaFilterCatalog: + def test__catalog_filter_table(self): + column_names = ["table_name", "table_database", "table_schema", "something"] + rows = [ + ["foo", "a", "b", "1234"], # include + ["foo", "a", "1234", "1234"], # include, w/ table schema as str + ["foo", "c", "B", "1234"], # skip + ["1234", "A", "B", "1234"], # include, w/ table name as str + ] + table = agate.Table(rows, column_names, agate_helper.DEFAULT_TYPE_TESTER) + + result = AthenaAdapter._catalog_filter_table(table, frozenset({("a", "B"), ("a", "1234")})) + assert len(result) == 3 + for row in result.rows: + assert isinstance(row["table_schema"], str) + assert isinstance(row["table_database"], str) + assert isinstance(row["table_name"], str) + assert isinstance(row["something"], decimal.Decimal) + + +class TestAthenaAdapterConversions(TestAdapterConversions): + def test_convert_text_type(self): + rows = [ + ["", "a1", "stringval1"], + ["", "a2", "stringvalasdfasdfasdfa"], + ["", "a3", "stringval3"], + ] + agate_table = self._make_table_of(rows, agate.Text) + expected = ["string", "string", "string"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_text_type(agate_table, col_idx) == expect + + def test_convert_number_type(self): + rows = [ + ["", "23.98", "-1"], + ["", "12.78", "-2"], + ["", "79.41", "-3"], + ] + agate_table = self._make_table_of(rows, agate.Number) + expected = ["integer", "double", "integer"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_number_type(agate_table, col_idx) == expect + + def test_convert_boolean_type(self): + rows = [ + ["", "false", "true"], + ["", "false", "false"], + ["", "false", "true"], + ] + agate_table = self._make_table_of(rows, agate.Boolean) + expected = ["boolean", "boolean", "boolean"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_boolean_type(agate_table, col_idx) == expect + + def test_convert_datetime_type(self): + rows = [ + ["", "20190101T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190102T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190103T01:01:01Z", "2019-01-01 01:01:01"], + ] + agate_table = self._make_table_of(rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime]) + expected = ["timestamp", "timestamp", "timestamp"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_datetime_type(agate_table, col_idx) == expect + + def test_convert_date_type(self): + rows = [ + ["", "2019-01-01", "2019-01-04"], + ["", "2019-01-02", "2019-01-04"], + ["", "2019-01-03", "2019-01-04"], + ] + agate_table = self._make_table_of(rows, agate.Date) + expected = ["date", "date", "date"] + for col_idx, expect in enumerate(expected): + assert AthenaAdapter.convert_date_type(agate_table, col_idx) == expect diff --git a/dbt-athena/tests/unit/test_column.py 
b/dbt-athena/tests/unit/test_column.py new file mode 100644 index 00000000..a2b15bf9 --- /dev/null +++ b/dbt-athena/tests/unit/test_column.py @@ -0,0 +1,108 @@ +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena.column import AthenaColumn +from dbt.adapters.athena.relation import TableType + + +class TestAthenaColumn: + def setup_column(self, **kwargs): + base_kwargs = {"column": "foo", "dtype": "varchar"} + return AthenaColumn(**{**base_kwargs, **kwargs}) + + @pytest.mark.parametrize( + "table_type,expected", + [ + pytest.param(TableType.TABLE, False), + pytest.param(TableType.ICEBERG, True), + ], + ) + def test_is_iceberg(self, table_type, expected): + column = self.setup_column(table_type=table_type) + assert column.is_iceberg() is expected + + @pytest.mark.parametrize( + "dtype,expected_type_func", + [ + pytest.param("varchar", "is_string"), + pytest.param("string", "is_string"), + pytest.param("binary", "is_binary"), + pytest.param("varbinary", "is_binary"), + pytest.param("timestamp", "is_timestamp"), + pytest.param("array<string>", "is_array"), + pytest.param("array(string)", "is_array"), + ], + ) + def test_is_type(self, dtype, expected_type_func): + column = self.setup_column(dtype=dtype) + for type_func in ["is_string", "is_binary", "is_timestamp", "is_array"]: + if type_func == expected_type_func: + assert getattr(column, type_func)() + else: + assert not getattr(column, type_func)() + + @pytest.mark.parametrize("size,expected", [pytest.param(1, "varchar(1)"), pytest.param(0, "varchar")]) + def test_string_type(self, size, expected): + assert AthenaColumn.string_type(size) == expected + + @pytest.mark.parametrize( + "table_type,expected", + [pytest.param(TableType.TABLE, "timestamp"), pytest.param(TableType.ICEBERG, "timestamp(6)")], + ) + def test_timestamp_type(self, table_type, expected): + column = self.setup_column(table_type=table_type) + assert column.timestamp_type() == expected + + def test_array_type(self): + assert AthenaColumn.array_type("varchar") == "array(varchar)" + + @pytest.mark.parametrize( + "dtype,expected", + [ + pytest.param("array<string>", "string"), + pytest.param("array<varchar(10)>", "varchar(10)"), + pytest.param("array<array<int>>", "array<int>"), + pytest.param("array", "array"), + ], + ) + def test_array_inner_type(self, dtype, expected): + column = self.setup_column(dtype=dtype) + assert column.array_inner_type() == expected + + def test_array_inner_type_raises_for_non_array_type(self): + column = self.setup_column(dtype="varchar") + with pytest.raises(DbtRuntimeError, match=r"Called array_inner_type\(\) on non-array field!"): + column.array_inner_type() + + @pytest.mark.parametrize( + "char_size,expected", + [ + pytest.param(10, 10), + pytest.param(None, 0), + ], + ) + def test_string_size(self, char_size, expected): + column = self.setup_column(dtype="varchar", char_size=char_size) + assert column.string_size() == expected + + def test_string_size_raises_for_non_string_type(self): + column = self.setup_column(dtype="int") + with pytest.raises(DbtRuntimeError, match=r"Called string_size\(\) on non-string field!"): + column.string_size() + + @pytest.mark.parametrize( + "dtype,expected", + [ + pytest.param("string", "varchar(10)"), + pytest.param("decimal", "decimal(1,2)"), + pytest.param("binary", "varbinary"), + pytest.param("timestamp", "timestamp(6)"), + pytest.param("array<string>", "array(varchar(10))"), + pytest.param("array<array<string>>", "array(array(varchar(10)))"), + ], + ) + def 
test_data_type(self, dtype, expected): + column = self.setup_column( + table_type=TableType.ICEBERG, dtype=dtype, char_size=10, numeric_precision=1, numeric_scale=2 + ) + assert column.data_type == expected diff --git a/dbt-athena/tests/unit/test_config.py b/dbt-athena/tests/unit/test_config.py new file mode 100644 index 00000000..55a3672e --- /dev/null +++ b/dbt-athena/tests/unit/test_config.py @@ -0,0 +1,114 @@ +import importlib.metadata +from unittest.mock import Mock + +import pytest + +from dbt.adapters.athena.config import AthenaSparkSessionConfig, get_boto3_config + + +class TestConfig: + def test_get_boto3_config(self): + importlib.metadata.version = Mock(return_value="2.4.6") + num_boto3_retries = 5 + get_boto3_config.cache_clear() + config = get_boto3_config(num_retries=num_boto3_retries) + assert config._user_provided_options["user_agent_extra"] == "dbt-athena/2.4.6" + assert config.retries == {"max_attempts": num_boto3_retries, "mode": "standard"} + + +class TestAthenaSparkSessionConfig: + """ + A class to test AthenaSparkSessionConfig + """ + + @pytest.fixture + def spark_config(self, request): + """ + Fixture for providing Spark configuration parameters. + + This fixture returns a dictionary containing the Spark configuration parameters. The parameters can be + customized using the `request.param` object. The default values are: + - `timeout`: 7200 seconds + - `polling_interval`: 5 seconds + - `engine_config`: {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1} + + Args: + self: The test class instance. + request: The pytest request object. + + Returns: + dict: The Spark configuration parameters. + + """ + return { + "timeout": request.param.get("timeout", 7200), + "polling_interval": request.param.get("polling_interval", 5), + "engine_config": request.param.get( + "engine_config", {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1} + ), + } + + @pytest.fixture + def spark_config_helper(self, spark_config): + """Fixture for testing AthenaSparkSessionConfig class. + + Args: + spark_config (dict): Fixture for default spark config. + + Returns: + AthenaSparkSessionConfig: An instance of AthenaSparkSessionConfig class. 
+ """ + return AthenaSparkSessionConfig(spark_config) + + @pytest.mark.parametrize( + "spark_config", + [ + {"timeout": 5}, + {"timeout": 10}, + {"timeout": 20}, + {}, + pytest.param({"timeout": -1}, marks=pytest.mark.xfail), + pytest.param({"timeout": None}, marks=pytest.mark.xfail), + ], + indirect=True, + ) + def test_set_timeout(self, spark_config_helper): + timeout = spark_config_helper.set_timeout() + assert timeout == spark_config_helper.config.get("timeout", 7200) + + @pytest.mark.parametrize( + "spark_config", + [ + {"polling_interval": 5}, + {"polling_interval": 10}, + {"polling_interval": 20}, + {}, + pytest.param({"polling_interval": -1}, marks=pytest.mark.xfail), + ], + indirect=True, + ) + def test_set_polling_interval(self, spark_config_helper): + polling_interval = spark_config_helper.set_polling_interval() + assert polling_interval == spark_config_helper.config.get("polling_interval", 5) + + @pytest.mark.parametrize( + "spark_config", + [ + {"engine_config": {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}}, + {"engine_config": {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 2}}, + {}, + pytest.param({"engine_config": {"CoordinatorDpuSize": 1}}, marks=pytest.mark.xfail), + pytest.param({"engine_config": [1, 1, 1]}, marks=pytest.mark.xfail), + ], + indirect=True, + ) + def test_set_engine_config(self, spark_config_helper): + engine_config = spark_config_helper.set_engine_config() + diff = set(engine_config.keys()) - { + "CoordinatorDpuSize", + "MaxConcurrentDpus", + "DefaultExecutorDpuSize", + "SparkProperties", + "AdditionalConfigs", + } + assert len(diff) == 0 diff --git a/dbt-athena/tests/unit/test_connection_manager.py b/dbt-athena/tests/unit/test_connection_manager.py new file mode 100644 index 00000000..fd48a62f --- /dev/null +++ b/dbt-athena/tests/unit/test_connection_manager.py @@ -0,0 +1,35 @@ +from multiprocessing import get_context +from unittest import mock + +import pytest +from pyathena.model import AthenaQueryExecution + +from dbt.adapters.athena import AthenaConnectionManager +from dbt.adapters.athena.connections import AthenaAdapterResponse + + +class TestAthenaConnectionManager: + @pytest.mark.parametrize( + ("state", "result"), + ( + pytest.param(AthenaQueryExecution.STATE_SUCCEEDED, "OK"), + pytest.param(AthenaQueryExecution.STATE_CANCELLED, "ERROR"), + ), + ) + def test_get_response(self, state, result): + cursor = mock.MagicMock() + cursor.rowcount = 1 + cursor.state = state + cursor.data_scanned_in_bytes = 123 + cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn")) + response = cm.get_response(cursor) + assert isinstance(response, AthenaAdapterResponse) + assert response.code == result + assert response.rows_affected == 1 + assert response.data_scanned_in_bytes == 123 + + def test_data_type_code_to_name(self): + cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn")) + assert cm.data_type_code_to_name("array<string>") == "ARRAY" + assert cm.data_type_code_to_name("map<int, boolean>") == "MAP" + assert cm.data_type_code_to_name("DECIMAL(3, 7)") == "DECIMAL" diff --git a/dbt-athena/tests/unit/test_formatter.py b/dbt-athena/tests/unit/test_formatter.py new file mode 100644 index 00000000..a645635a --- /dev/null +++ b/dbt-athena/tests/unit/test_formatter.py @@ -0,0 +1,118 @@ +import textwrap +from datetime import date, datetime +from decimal import Decimal + +import agate +import pytest +from pyathena.error import ProgrammingError + +from dbt.adapters.athena.connections 
import AthenaParameterFormatter + + +class TestAthenaParameterFormatter: + formatter = AthenaParameterFormatter() + + @pytest.mark.parametrize( + "sql", + [ + "", + """ + + """, + ], + ) + def test_query_none_or_empty(self, sql): + with pytest.raises(ProgrammingError) as exc: + self.formatter.format(sql) + assert exc.value.__str__() == "Query is none or empty." + + def test_query_none_parameters(self): + sql = self.formatter.format( + """ + + ALTER TABLE table ADD partition (dt = '2022-01-01') + """ + ) + assert sql == "ALTER TABLE table ADD partition (dt = '2022-01-01')" + + def test_query_parameters_not_list(self): + with pytest.raises(ProgrammingError) as exc: + self.formatter.format( + """ + SELECT * + FROM table + WHERE country = %(country)s + """, + {"country": "FR"}, + ) + assert exc.value.__str__() == "Unsupported parameter (Support for list only): {'country': 'FR'}" + + def test_query_parameters_unknown_formatter(self): + with pytest.raises(TypeError) as exc: + self.formatter.format( + """ + SELECT * + FROM table + WHERE country = %s + """, + [agate.Table(rows=[("a", 1), ("b", 2)], column_names=["str", "int"])], + ) + assert exc.value.__str__() == "<class 'agate.table.Table'> is not defined formatter." + + def test_query_parameters_list(self): + res = self.formatter.format( + textwrap.dedent( + """ + SELECT * + FROM table + WHERE nullable_field = %s + AND dt = %s + AND dti = %s + AND int_field > %s + AND float_field > %.2f + AND fake_decimal_field < %s + AND str_field = %s + AND str_list_field IN %s + AND int_list_field IN %s + AND float_list_field IN %s + AND dt_list_field IN %s + AND dti_list_field IN %s + AND bool_field = %s + """ + ), + [ + None, + date(2022, 1, 1), + datetime(2022, 1, 1, 0, 0, 0), + 1, + 1.23, + Decimal(2), + "test", + ["a", "b"], + [1, 2, 3, 4], + (1.1, 1.2, 1.3), + (date(2022, 2, 1), date(2022, 3, 1)), + (datetime(2022, 2, 1, 1, 2, 3), datetime(2022, 3, 1, 4, 5, 6)), + True, + ], + ) + expected = textwrap.dedent( + """ + SELECT * + FROM table + WHERE nullable_field = null + AND dt = DATE '2022-01-01' + AND dti = TIMESTAMP '2022-01-01 00:00:00.000' + AND int_field > 1 + AND float_field > 1.23 + AND fake_decimal_field < 2 + AND str_field = 'test' + AND str_list_field IN ('a', 'b') + AND int_list_field IN (1, 2, 3, 4) + AND float_list_field IN (1.100000, 1.200000, 1.300000) + AND dt_list_field IN (DATE '2022-02-01', DATE '2022-03-01') + AND dti_list_field IN (TIMESTAMP '2022-02-01 01:02:03.000', TIMESTAMP '2022-03-01 04:05:06.000') + AND bool_field = True + """ + ).strip() + assert res == expected diff --git a/dbt-athena/tests/unit/test_lakeformation.py b/dbt-athena/tests/unit/test_lakeformation.py new file mode 100644 index 00000000..3a05030c --- /dev/null +++ b/dbt-athena/tests/unit/test_lakeformation.py @@ -0,0 +1,144 @@ +import boto3 +import pytest +from tests.unit.constants import AWS_REGION, DATA_CATALOG_NAME, DATABASE_NAME + +import dbt.adapters.athena.lakeformation as lakeformation +from dbt.adapters.athena.lakeformation import LfTagsConfig, LfTagsManager +from dbt.adapters.athena.relation import AthenaRelation + + +# TODO: add more tests for lakeformation once moto library implements required methods: +# https://docs.getmoto.org/en/4.1.9/docs/services/lakeformation.html +# get_resource_lf_tags +class TestLfTagsManager: + @pytest.mark.parametrize( + "response,identifier,columns,lf_tags,verb,expected", + [ + pytest.param( + { + "Failures": [ + { + "LFTag": {"CatalogId": "test_catalog", "TagKey": "test_key", "TagValues": ["test_values"]}, + "Error": 
{"ErrorCode": "test_code", "ErrorMessage": "test_err_msg"}, + } + ] + }, + "tbl_name", + ["column1", "column2"], + {"tag_key": "tag_value"}, + "add", + None, + id="lf_tag error", + marks=pytest.mark.xfail, + ), + pytest.param( + {"Failures": []}, + "tbl_name", + None, + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", + id="add lf_tag", + ), + pytest.param( + {"Failures": []}, + None, + None, + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena", + id="add lf_tag_to_database", + ), + pytest.param( + {"Failures": []}, + "tbl_name", + None, + {"tag_key": "tag_value"}, + "remove", + "Success: remove LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", + id="remove lf_tag", + ), + pytest.param( + {"Failures": []}, + "tbl_name", + ["c1", "c2"], + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']", + id="lf_tag database table and columns", + ), + ], + ) + def test__parse_lf_response(self, dbt_debug_caplog, response, identifier, columns, lf_tags, verb, expected): + relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=identifier) + lf_client = boto3.client("lakeformation", region_name=AWS_REGION) + manager = LfTagsManager(lf_client, relation, LfTagsConfig()) + manager._parse_and_log_lf_response(response, columns, lf_tags, verb) + assert expected in dbt_debug_caplog.getvalue() + + @pytest.mark.parametrize( + "lf_tags_columns,lf_inherited_tags,expected", + [ + pytest.param( + [{"Name": "my_column", "LFTags": [{"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}]}], + {"inherited"}, + {}, + id="retains-inherited-tag", + ), + pytest.param( + [{"Name": "my_column", "LFTags": [{"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}]}], + {}, + {"not-inherited": {"oh-no-im-not": ["my_column"]}}, + id="removes-non-inherited-tag", + ), + pytest.param( + [ + { + "Name": "my_column", + "LFTags": [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + } + ], + {"inherited"}, + {"not-inherited": {"oh-no-im-not": ["my_column"]}}, + id="removes-non-inherited-tag-among-inherited", + ), + pytest.param([], {}, {}, id="handles-empty"), + ], + ) + def test__column_tags_to_remove(self, lf_tags_columns, lf_inherited_tags, expected): + assert lakeformation.LfTagsManager._column_tags_to_remove(lf_tags_columns, lf_inherited_tags) == expected + + @pytest.mark.parametrize( + "lf_tags_table,lf_tags,lf_inherited_tags,expected", + [ + pytest.param( + [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + {"not-inherited": "some-preexisting-value"}, + {"inherited"}, + {}, + id="retains-being-set-and-inherited", + ), + pytest.param( + [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + {}, + {"inherited"}, + {"not-inherited": ["oh-no-im-not"]}, + id="removes-preexisting-not-being-set", + ), + pytest.param( + [{"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}], {}, {"inherited"}, {}, id="retains-inherited" + ), + pytest.param([], None, {}, {}, id="handles-empty"), + ], + ) + def test__table_tags_to_remove(self, lf_tags_table, lf_tags, lf_inherited_tags, expected): + assert lakeformation.LfTagsManager._table_tags_to_remove(lf_tags_table, lf_tags, 
lf_inherited_tags) == expected diff --git a/dbt-athena/tests/unit/test_python_submissions.py b/dbt-athena/tests/unit/test_python_submissions.py new file mode 100644 index 00000000..65961d60 --- /dev/null +++ b/dbt-athena/tests/unit/test_python_submissions.py @@ -0,0 +1,209 @@ +import time +import uuid +from unittest.mock import Mock, patch + +import pytest + +from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper +from dbt.adapters.athena.session import AthenaSparkSessionManager + +from .constants import DATABASE_NAME + + +@pytest.mark.usefixtures("athena_credentials", "athena_client") +class TestAthenaPythonJobHelper: + """ + A class to test the AthenaPythonJobHelper + """ + + @pytest.fixture + def parsed_model(self, request): + config: dict[str, int] = request.param.get("config", {"timeout": 1, "polling_interval": 5}) + + return { + "alias": "test_model", + "schema": DATABASE_NAME, + "config": { + "timeout": config["timeout"], + "polling_interval": config["polling_interval"], + "engine_config": request.param.get( + "engine_config", {"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1} + ), + }, + } + + @pytest.fixture + def athena_spark_session_manager(self, athena_credentials, parsed_model): + return AthenaSparkSessionManager( + athena_credentials, + timeout=parsed_model["config"]["timeout"], + polling_interval=parsed_model["config"]["polling_interval"], + engine_config=parsed_model["config"]["engine_config"], + ) + + @pytest.fixture + def athena_job_helper( + self, athena_credentials, athena_client, athena_spark_session_manager, parsed_model, monkeypatch + ): + mock_job_helper = AthenaPythonJobHelper(parsed_model, athena_credentials) + monkeypatch.setattr(mock_job_helper, "athena_client", athena_client) + monkeypatch.setattr(mock_job_helper, "spark_connection", athena_spark_session_manager) + return mock_job_helper + + @pytest.mark.parametrize( + "parsed_model, session_status_response, expected_response", + [ + ( + {"config": {"timeout": 5, "polling_interval": 5}}, + { + "State": "IDLE", + }, + None, + ), + pytest.param( + {"config": {"timeout": 5, "polling_interval": 5}}, + { + "State": "FAILED", + }, + None, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 5, "polling_interval": 5}}, + { + "State": "TERMINATED", + }, + None, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "State": "CREATING", + }, + None, + marks=pytest.mark.xfail, + ), + ], + indirect=["parsed_model"], + ) + def test_poll_session_idle( + self, session_status_response, expected_response, athena_job_helper, athena_spark_session_manager, monkeypatch + ): + with patch.multiple( + athena_spark_session_manager, + get_session_status=Mock(return_value=session_status_response), + get_session_id=Mock(return_value="test_session_id"), + ): + + def mock_sleep(_): + pass + + monkeypatch.setattr(time, "sleep", mock_sleep) + poll_response = athena_job_helper.poll_until_session_idle() + assert poll_response == expected_response + + @pytest.mark.parametrize( + "parsed_model, execution_status, expected_response", + [ + ( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "Status": { + "State": "COMPLETED", + } + }, + "COMPLETED", + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "Status": { + "State": "FAILED", + } + }, + None, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + { + "Status": { + "State": "RUNNING", + } + 
}, + "RUNNING", + marks=pytest.mark.xfail, + ), + ], + indirect=["parsed_model"], + ) + def test_poll_execution( + self, + execution_status, + expected_response, + athena_job_helper, + athena_spark_session_manager, + athena_client, + monkeypatch, + ): + with patch.multiple( + athena_spark_session_manager, + get_session_id=Mock(return_value=uuid.uuid4()), + ), patch.multiple( + athena_client, + get_calculation_execution=Mock(return_value=execution_status), + ): + + def mock_sleep(_): + pass + + monkeypatch.setattr(time, "sleep", mock_sleep) + poll_response = athena_job_helper.poll_until_execution_completion("test_calculation_id") + assert poll_response == expected_response + + @pytest.mark.parametrize( + "parsed_model, test_calculation_execution_id, test_calculation_execution", + [ + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + {"CalculationExecutionId": "test_execution_id"}, + { + "Result": {"ResultS3Uri": "test_results_s3_uri"}, + "Status": {"State": "COMPLETED"}, + }, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + {"CalculationExecutionId": "test_execution_id"}, + {"Result": {}, "Status": {"State": "FAILED"}}, + marks=pytest.mark.xfail, + ), + pytest.param( + {"config": {"timeout": 1, "polling_interval": 5}}, + {}, + {"Result": {}, "Status": {"State": "FAILED"}}, + marks=pytest.mark.xfail, + ), + ], + indirect=["parsed_model"], + ) + def test_submission( + self, + test_calculation_execution_id, + test_calculation_execution, + athena_job_helper, + athena_spark_session_manager, + athena_client, + ): + with patch.multiple( + athena_spark_session_manager, get_session_id=Mock(return_value=uuid.uuid4()) + ), patch.multiple( + athena_client, + start_calculation_execution=Mock(return_value=test_calculation_execution_id), + get_calculation_execution=Mock(return_value=test_calculation_execution), + ), patch.multiple( + athena_job_helper, poll_until_session_idle=Mock(return_value="IDLE") + ): + result = athena_job_helper.submit("hello world") + assert result == test_calculation_execution["Result"] diff --git a/dbt-athena/tests/unit/test_query_headers.py b/dbt-athena/tests/unit/test_query_headers.py new file mode 100644 index 00000000..89a99ef5 --- /dev/null +++ b/dbt-athena/tests/unit/test_query_headers.py @@ -0,0 +1,88 @@ +from unittest import mock + +from dbt.adapters.athena.query_headers import AthenaMacroQueryStringSetter +from dbt.context.query_header import generate_query_header_context + +from .constants import AWS_REGION, DATA_CATALOG_NAME, DATABASE_NAME +from .utils import config_from_parts_or_dicts + + +class TestQueryHeaders: + def setup_method(self, _): + config = config_from_parts_or_dicts( + { + "name": "query_headers", + "version": "0.1", + "profile": "test", + "config-version": 2, + }, + { + "outputs": { + "test": { + "type": "athena", + "s3_staging_dir": "s3://my-bucket/test-dbt/", + "region_name": AWS_REGION, + "database": DATA_CATALOG_NAME, + "work_group": "dbt-athena-adapter", + "schema": DATABASE_NAME, + } + }, + "target": "test", + }, + ) + self.query_header = AthenaMacroQueryStringSetter( + config, generate_query_header_context(config, mock.MagicMock(macros={})) + ) + + def test_append_comment_with_semicolon(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = self.query_header.add("SELECT 1;") + assert sql == "SELECT 1\n-- /* executed by dbt */;" + + def test_append_comment_without_semicolon(self): + self.query_header.comment.query_comment = "executed by 
dbt" + self.query_header.comment.append = True + sql = self.query_header.add("SELECT 1") + assert sql == "SELECT 1\n-- /* executed by dbt */" + + def test_comment_multiple_lines(self): + self.query_header.comment.query_comment = """executed by dbt\nfor table table""" + self.query_header.comment.append = False + sql = self.query_header.add("insert into table(id) values (1);") + assert sql == "-- /* executed by dbt for table table */\ninsert into table(id) values (1);" + + def test_disable_query_comment(self): + self.query_header.comment.query_comment = "" + self.query_header.comment.append = True + assert self.query_header.add("SELECT 1;") == "SELECT 1;" + + def test_no_query_comment_on_alter(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "alter table table add column time time;" + assert self.query_header.add(sql) == sql + + def test_no_query_comment_on_vacuum(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "VACUUM table" + assert self.query_header.add(sql) == sql + + def test_no_query_comment_on_msck(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "MSCK REPAIR TABLE" + assert self.query_header.add(sql) == sql + + def test_no_query_comment_on_vacuum_with_leading_whitespaces(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = " VACUUM table" + assert self.query_header.add(sql) == "VACUUM table" + + def test_no_query_comment_on_vacuum_with_lowercase(self): + self.query_header.comment.query_comment = "executed by dbt" + self.query_header.comment.append = True + sql = "vacuum table" + assert self.query_header.add(sql) == sql diff --git a/dbt-athena/tests/unit/test_relation.py b/dbt-athena/tests/unit/test_relation.py new file mode 100644 index 00000000..4e8c61c3 --- /dev/null +++ b/dbt-athena/tests/unit/test_relation.py @@ -0,0 +1,55 @@ +import pytest + +from dbt.adapters.athena.relation import AthenaRelation, TableType, get_table_type + +from .constants import DATA_CATALOG_NAME, DATABASE_NAME + +TABLE_NAME = "test_table" + + +class TestRelation: + @pytest.mark.parametrize( + ("table", "expected"), + [ + ({"Name": "n", "TableType": "table"}, TableType.TABLE), + ({"Name": "n", "TableType": "VIRTUAL_VIEW"}, TableType.VIEW), + ({"Name": "n", "TableType": "EXTERNAL_TABLE", "Parameters": {"table_type": "ICEBERG"}}, TableType.ICEBERG), + ], + ) + def test__get_relation_type(self, table, expected): + assert get_table_type(table) == expected + + def test__get_relation_type_with_no_type(self): + with pytest.raises(ValueError): + get_table_type({"Name": "name"}) + + def test__get_relation_type_with_unknown_type(self): + with pytest.raises(ValueError): + get_table_type({"Name": "name", "TableType": "test"}) + + +class TestAthenaRelation: + def test_render_hive_uses_hive_style_quotation_and_only_schema_and_table_names(self): + relation = AthenaRelation.create( + identifier=TABLE_NAME, + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + ) + assert relation.render_hive() == f"`{DATABASE_NAME}`.`{TABLE_NAME}`" + + def test_render_hive_resets_quote_character_and_include_policy_after_call(self): + relation = AthenaRelation.create( + identifier=TABLE_NAME, + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + ) + relation.render_hive() + assert relation.render() == f'"{DATA_CATALOG_NAME}"."{DATABASE_NAME}"."{TABLE_NAME}"' + + def 
test_render_pure_resets_quote_character_after_call(self): + relation = AthenaRelation.create( + identifier=TABLE_NAME, + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + ) + assert relation.render_pure() == f"{DATA_CATALOG_NAME}.{DATABASE_NAME}.{TABLE_NAME}" diff --git a/dbt-athena/tests/unit/test_session.py b/dbt-athena/tests/unit/test_session.py new file mode 100644 index 00000000..3c1f4324 --- /dev/null +++ b/dbt-athena/tests/unit/test_session.py @@ -0,0 +1,176 @@ +from unittest.mock import Mock, patch +from uuid import UUID + +import botocore.session +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.athena import AthenaCredentials +from dbt.adapters.athena.session import AthenaSparkSessionManager, get_boto3_session +from dbt.adapters.contracts.connection import Connection + + +class TestSession: + @pytest.mark.parametrize( + ("credentials_profile_name", "boto_profile_name"), + ( + pytest.param(None, "default", id="no_profile_in_credentials"), + pytest.param("my_profile", "my_profile", id="profile_in_credentials"), + ), + ) + def test_session_should_be_called_with_correct_parameters( + self, monkeypatch, credentials_profile_name, boto_profile_name + ): + def mock___build_profile_map(_): + return {**{"default": {}}, **({} if not credentials_profile_name else {credentials_profile_name: {}})} + + monkeypatch.setattr(botocore.session.Session, "_build_profile_map", mock___build_profile_map) + connection = Connection( + type="test", + name="test_session", + credentials=AthenaCredentials( + database="db", + schema="schema", + s3_staging_dir="dir", + region_name="eu-west-1", + aws_profile_name=credentials_profile_name, + ), + ) + session = get_boto3_session(connection) + assert session.region_name == "eu-west-1" + assert session.profile_name == boto_profile_name + + +@pytest.mark.usefixtures("athena_credentials", "athena_client") +class TestAthenaSparkSessionManager: + """ + A class to test the AthenaSparkSessionManager + """ + + @pytest.fixture + def spark_session_manager(self, athena_credentials, athena_client, monkeypatch): + """ + Fixture for creating a mock Spark session manager. + + This fixture creates an instance of AthenaSparkSessionManager with the provided Athena credentials, + timeout, polling interval, and engine configuration. It then patches the Athena client of the manager + with the provided `athena_client` object. The fixture returns the mock Spark session manager. + + Args: + self: The test class instance. + athena_credentials: The Athena credentials. + athena_client: The Athena client object. + monkeypatch: The monkeypatch object for mocking. + + Returns: + The mock Spark session manager. 
+ + """ + mock_session_manager = AthenaSparkSessionManager( + athena_credentials, + timeout=10, + polling_interval=5, + engine_config={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}, + ) + monkeypatch.setattr(mock_session_manager, "athena_client", athena_client) + return mock_session_manager + + @pytest.mark.parametrize( + "session_status_response, expected_response", + [ + pytest.param( + {"Status": {"SessionId": "test_session_id", "State": "CREATING"}}, + DbtRuntimeError( + """Session <MagicMock name='client.start_session().__getitem__()' id='140219810489792'> + did not create within 10 seconds.""" + ), + marks=pytest.mark.xfail, + ), + ( + {"Status": {"SessionId": "635c1c6d-766c-408b-8bce-fae8ea7006f7", "State": "IDLE"}}, + UUID("635c1c6d-766c-408b-8bce-fae8ea7006f7"), + ), + pytest.param( + {"Status": {"SessionId": "test_session_id", "State": "TERMINATED"}}, + DbtRuntimeError("Unable to create session: test_session_id. Got status: TERMINATED."), + marks=pytest.mark.xfail, + ), + ], + ) + def test_start_session( + self, session_status_response, expected_response, spark_session_manager, athena_client + ) -> None: + """ + Test the start_session method of the AthenaJobHelper class. + + Args: + session_status_response (dict): A dictionary containing the response from the Athena session + creation status. + expected_response (Union[dict, DbtRuntimeError]): The expected response from the start_session method. + athena_job_helper (AthenaPythonJobHelper): An instance of the AthenaPythonJobHelper class. + athena_client (botocore.client.BaseClient): An instance of the botocore Athena client. + + Returns: + None + """ + with patch.multiple( + spark_session_manager, + poll_until_session_creation=Mock(return_value=session_status_response), + ), patch.multiple( + athena_client, + get_session_status=Mock(return_value=session_status_response), + start_session=Mock(return_value=session_status_response.get("Status")), + ): + response = spark_session_manager.start_session() + assert response == expected_response + + @pytest.mark.parametrize( + "session_status_response, expected_status", + [ + ( + { + "SessionId": "test_session_id", + "Status": { + "State": "CREATING", + }, + }, + { + "State": "CREATING", + }, + ), + ( + { + "SessionId": "test_session_id", + "Status": { + "State": "IDLE", + }, + }, + { + "State": "IDLE", + }, + ), + ], + ) + def test_get_session_status(self, session_status_response, expected_status, spark_session_manager, athena_client): + """ + Test the get_session_status function. + + Args: + self: The test class instance. + session_status_response (dict): The response from get_session_status. + expected_status (dict): The expected session status. + spark_session_manager: The Spark session manager object. + athena_client: The Athena client object. + + Returns: + None + + Raises: + AssertionError: If the retrieved session status is not correct. 
+ """ + with patch.multiple(athena_client, get_session_status=Mock(return_value=session_status_response)): + response = spark_session_manager.get_session_status("test_session_id") + assert response == expected_status + + def test_get_session_id(self): + pass diff --git a/dbt-athena/tests/unit/test_utils.py b/dbt-athena/tests/unit/test_utils.py new file mode 100644 index 00000000..69a2dc9d --- /dev/null +++ b/dbt-athena/tests/unit/test_utils.py @@ -0,0 +1,72 @@ +import pytest + +from dbt.adapters.athena.utils import ( + clean_sql_comment, + ellipsis_comment, + get_chunks, + is_valid_table_parameter_key, + stringify_table_parameter_value, +) + + +def test_clean_comment(): + assert ( + clean_sql_comment( + """ + my long comment + on several lines + with weird spaces and indents. + """ + ) + == "my long comment on several lines with weird spaces and indents." + ) + + +def test_stringify_table_parameter_value(): + class NonStringifiableObject: + def __str__(self): + raise ValueError("Non-stringifiable object") + + assert stringify_table_parameter_value(True) == "True" + assert stringify_table_parameter_value(123) == "123" + assert stringify_table_parameter_value("dbt-athena") == "dbt-athena" + assert stringify_table_parameter_value(["a", "b", 3]) == '["a", "b", 3]' + assert stringify_table_parameter_value({"a": 1, "b": "c"}) == '{"a": 1, "b": "c"}' + assert len(stringify_table_parameter_value("a" * 512001)) == 512000 + assert stringify_table_parameter_value(NonStringifiableObject()) is None + assert stringify_table_parameter_value([NonStringifiableObject()]) is None + + +def test_is_valid_table_parameter_key(): + assert is_valid_table_parameter_key("valid_key") is True + assert is_valid_table_parameter_key("Valid Key 123*!") is True + assert is_valid_table_parameter_key("invalid \n key") is False + assert is_valid_table_parameter_key("long_key" * 100) is False + + +def test_get_chunks_empty(): + assert len(list(get_chunks([], 5))) == 0 + + +def test_get_chunks_uneven(): + chunks = list(get_chunks([1, 2, 3], 2)) + assert chunks[0] == [1, 2] + assert chunks[1] == [3] + assert len(chunks) == 2 + + +def test_get_chunks_more_elements_than_chunk(): + chunks = list(get_chunks([1, 2, 3], 4)) + assert chunks[0] == [1, 2, 3] + assert len(chunks) == 1 + + +@pytest.mark.parametrize( + ("max_len", "expected"), + ( + pytest.param(12, "abc def ghi", id="ok string"), + pytest.param(6, "abc...", id="ellipsis"), + ), +) +def test_ellipsis_comment(max_len, expected): + assert expected == ellipsis_comment("abc def ghi", max_len=max_len) diff --git a/dbt-athena/tests/unit/utils.py b/dbt-athena/tests/unit/utils.py new file mode 100644 index 00000000..320fdc13 --- /dev/null +++ b/dbt-athena/tests/unit/utils.py @@ -0,0 +1,485 @@ +import os +import string +from typing import Optional + +import agate +import boto3 + +from dbt.adapters.athena.utils import AthenaCatalogType +from dbt.config.project import PartialProject + +from .constants import AWS_REGION, BUCKET, CATALOG_ID, DATA_CATALOG_NAME, DATABASE_NAME + + +class Obj: + which = "blah" + single_threaded = False + + +def profile_from_dict(profile, profile_name, cli_vars="{}"): + from dbt.config import Profile + from dbt.config.renderer import ProfileRenderer + from dbt.config.utils import parse_cli_vars + + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + renderer = ProfileRenderer(cli_vars) + + # in order to call dbt's internal profile rendering, we need to set the + # flags global. This is a bit of a hack, but it's the best way to do it. 
+ from argparse import Namespace + + from dbt.flags import set_from_args + + set_from_args(Namespace(), None) + return Profile.from_raw_profile_info( + profile, + profile_name, + renderer, + ) + + +def project_from_dict(project, profile, packages=None, selectors=None, cli_vars="{}"): + from dbt.config.renderer import DbtProjectYamlRenderer + from dbt.config.utils import parse_cli_vars + + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + renderer = DbtProjectYamlRenderer(profile, cli_vars) + + project_root = project.pop("project-root", os.getcwd()) + + partial = PartialProject.from_dicts( + project_root=project_root, + project_dict=project, + packages_dict=packages, + selectors_dict=selectors, + ) + return partial.render(renderer) + + +def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars="{}"): + from copy import deepcopy + + from dbt.config import Profile, Project, RuntimeConfig + from dbt.config.utils import parse_cli_vars + + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + if isinstance(project, Project): + profile_name = project.profile_name + else: + profile_name = project.get("profile") + + if not isinstance(profile, Profile): + profile = profile_from_dict( + deepcopy(profile), + profile_name, + cli_vars, + ) + + if not isinstance(project, Project): + project = project_from_dict( + deepcopy(project), + profile, + packages, + selectors, + cli_vars, + ) + + args = Obj() + args.vars = cli_vars + args.profile_dir = "/dev/null" + return RuntimeConfig.from_parts(project=project, profile=profile, args=args) + + +def inject_plugin(plugin): + from dbt.adapters.factory import FACTORY + + key = plugin.adapter.type() + FACTORY.plugins[key] = plugin + + +def inject_adapter(value, plugin): + """Inject the given adapter into the adapter factory, so your hand-crafted + artisanal adapter will be available from get_adapter() as if dbt loaded it. + """ + inject_plugin(plugin) + from dbt.adapters.factory import FACTORY + + key = value.type() + FACTORY.adapters[key] = value + + +def clear_plugin(plugin): + from dbt.adapters.factory import FACTORY + + key = plugin.adapter.type() + FACTORY.plugins.pop(key, None) + FACTORY.adapters.pop(key, None) + + +class TestAdapterConversions: + def _get_tester_for(self, column_type): + from dbt_common.clients import agate_helper + + if column_type is agate.TimeDelta: # dbt never makes this! 
+ return agate.TimeDelta() + + for instance in agate_helper.DEFAULT_TYPE_TESTER._possible_types: + if isinstance(instance, column_type): # include child types + return instance + + raise ValueError(f"no tester for {column_type}") + + def _make_table_of(self, rows, column_types): + column_names = list(string.ascii_letters[: len(rows[0])]) + if isinstance(column_types, type): + column_types = [self._get_tester_for(column_types) for _ in column_names] + else: + column_types = [self._get_tester_for(typ) for typ in column_types] + table = agate.Table(rows, column_names=column_names, column_types=column_types) + return table + + +class MockAWSService: + def create_data_catalog( + self, + catalog_name: str = DATA_CATALOG_NAME, + catalog_type: AthenaCatalogType = AthenaCatalogType.GLUE, + catalog_id: str = CATALOG_ID, + ): + athena = boto3.client("athena", region_name=AWS_REGION) + parameters = {} + if catalog_type == AthenaCatalogType.GLUE: + parameters = {"catalog-id": catalog_id} + else: + parameters = {"catalog": catalog_name} + athena.create_data_catalog(Name=catalog_name, Type=catalog_type.value, Parameters=parameters) + + def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_database(DatabaseInput={"Name": name}, CatalogId=catalog_id) + + def create_view(self, view_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": view_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": "", + }, + "TableType": "VIRTUAL_VIEW", + "Parameters": { + "TableOwner": "John Doe", + }, + }, + ) + + def create_table( + self, + table_name: str, + database_name: str = DATABASE_NAME, + catalog_id: str = CATALOG_ID, + location: Optional[str] = "auto", + ): + glue = boto3.client("glue", region_name=AWS_REGION) + if location == "auto": + location = f"s3://{BUCKET}/tables/{table_name}" + glue.create_table( + CatalogId=catalog_id, + DatabaseName=database_name, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "Location": location, + }, + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + "TableType": "table", + "Parameters": { + "compressionType": "snappy", + "classification": "parquet", + "projection.enabled": "false", + "typeOfData": "file", + }, + }, + ) + + def create_table_without_type(self, table_name: str, database_name: str = DATABASE_NAME): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=database_name, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}", + }, + "Parameters": { + "compressionType": "snappy", + "classification": "parquet", + "projection.enabled": "false", + "typeOfData": "file", + }, + }, + ) + + def create_table_without_partitions(self, table_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + 
"Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}", + }, + "PartitionKeys": [], + "TableType": "table", + "Parameters": { + "compressionType": "snappy", + "classification": "parquet", + "projection.enabled": "false", + "typeOfData": "file", + }, + }, + ) + + def create_iceberg_table(self, table_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/data/{table_name}", + }, + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + "TableType": "EXTERNAL_TABLE", + "Parameters": { + "metadata_location": f"s3://{BUCKET}/tables/metadata/{table_name}/123.json", + "table_type": "ICEBERG", + }, + }, + ) + + def create_table_without_table_type(self, table_name: str): + glue = boto3.client("glue", region_name=AWS_REGION) + glue.create_table( + DatabaseName=DATABASE_NAME, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}", + }, + "Parameters": { + "TableOwner": "John Doe", + }, + }, + ) + + def create_work_group_with_output_location_enforced(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://pre-configured-output-location/", + }, + "EnforceWorkGroupConfiguration": True, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + + def create_work_group_with_output_location_not_enforced(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://pre-configured-output-location/", + }, + "EnforceWorkGroupConfiguration": False, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + + def create_work_group_no_output_location(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "EnforceWorkGroupConfiguration": True, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + + def add_data_in_table(self, table_name: str): + s3 = boto3.client("s3", region_name=AWS_REGION) + s3.create_bucket(Bucket=BUCKET, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-01/data1.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-01/data2.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-02/data.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, Key=f"tables/{table_name}/dt=2022-01-03/data1.parquet") + s3.put_object(Body=b"{}", Bucket=BUCKET, 
Key=f"tables/{table_name}/dt=2022-01-03/data2.parquet") + partition_input_list = [ + { + "Values": [dt], + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}/dt={dt}", + }, + } + for dt in ["2022-01-01", "2022-01-02", "2022-01-03"] + ] + glue = boto3.client("glue", region_name=AWS_REGION) + glue.batch_create_partition( + DatabaseName="test_dbt_athena", TableName=table_name, PartitionInputList=partition_input_list + ) + + def add_partitions_to_table(self, database, table_name): + partition_input_list = [ + { + "Values": [dt], + "StorageDescriptor": { + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + { + "Name": "dt", + "Type": "date", + }, + ], + "Location": f"s3://{BUCKET}/tables/{table_name}/dt={dt}", + }, + "Parameters": {"compressionType": "snappy", "classification": "parquet"}, + } + for dt in [f"2022-01-{day:02d}" for day in range(1, 27)] + ] + glue = boto3.client("glue", region_name=AWS_REGION) + glue.batch_create_partition( + DatabaseName=database, TableName=table_name, PartitionInputList=partition_input_list + ) + + def add_table_version(self, database, table_name): + glue = boto3.client("glue", region_name=AWS_REGION) + table = glue.get_table(DatabaseName=database, Name=table_name).get("Table") + new_table_version = { + "Name": table_name, + "StorageDescriptor": table["StorageDescriptor"], + "PartitionKeys": table["PartitionKeys"], + "TableType": table["TableType"], + "Parameters": table["Parameters"], + } + glue.update_table(DatabaseName=database, TableInput=new_table_version) diff --git a/scripts/migrate-adapter.sh b/scripts/migrate-adapter.sh index 58d70018..9993476c 100644 --- a/scripts/migrate-adapter.sh +++ b/scripts/migrate-adapter.sh @@ -8,7 +8,7 @@ git remote add old https://github.com/dbt-labs/$repo.git git fetch old # merge the updated branch from the legacy repo into the dbt-adapters repo -git checkout -b $target_branch +git checkout $target_branch git merge old/$source_branch --allow-unrelated-histories # remove the remote that was created by this process diff --git a/static/images/dbt-athena-black.png b/static/images/dbt-athena-black.png new file mode 100644 index 00000000..b7ba90b7 Binary files /dev/null and b/static/images/dbt-athena-black.png differ diff --git a/static/images/dbt-athena-color.png b/static/images/dbt-athena-color.png new file mode 100644 index 00000000..aca1b51b Binary files /dev/null and b/static/images/dbt-athena-color.png differ diff --git a/static/images/dbt-athena-long.png b/static/images/dbt-athena-long.png new file mode 100644 index 00000000..22cb5ccd Binary files /dev/null and b/static/images/dbt-athena-long.png differ