Skip to content

Commit

Permalink
Add Redshift Iam Idc token authentication method with an eye towards …
Browse files Browse the repository at this point in the history
…future supported Idps (#970)

* Fix tests and add token authentication method to auth flow

* Add changelog.

* Add token method (>^.^)>

We expect users of this method to provide a YAML-structured set of params including a URI, an authentication string, and whatever parameters might be needed to construct the correct payload, equivalent to the data in a curl request. Under the hood there is an all-important POST, which needs a set of params unique to each identity provider, to generate access tokens for use with IdpTokenAuthPlugin.

* Add unit tests for current codepaths.

* Make test a bit more specific.

* Add skeleton of test case I've been using for hand testing. Can't commit it due to it being based on a refresh token.

* Code review comments and adapt for Entra + future providers.

* Improve comment.

* Better error handling for missing access_token since it could happen on some Idp
  • Loading branch information
VersusFacit authored Jan 16, 2025
1 parent f10d316 commit de078b8
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 11 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241217-181340.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add IdpTokenAuthPlugin authentication method.
time: 2024-12-17T18:13:40.281494-08:00
custom:
Author: versusfacit
Issue: "898"
87 changes: 87 additions & 0 deletions dbt/adapters/redshift/auth_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import requests
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any

from dbt.adapters.exceptions import FailedToConnectError
from dbt_common.exceptions import DbtRuntimeError


class TokenServiceBase(ABC):
    """Base class for clients that request an access token from an identity provider.

    The constructor validates the ``token_endpoint`` profile entry and
    ``handle_request`` performs the POST. Subclasses supply the
    provider-specific HTTP headers through ``build_header_payload``.
    """

    # A tuple (not a set) so that the first missing key reported to the user
    # is the same on every run.
    _REQUIRED_KEYS = ("type", "request_url", "request_data")

    # Seconds to wait for the identity provider. Without a timeout
    # ``requests`` may block indefinitely on an unresponsive endpoint.
    # Subclasses may override this attribute.
    request_timeout: float = 30.0

    def __init__(self, token_endpoint: Dict[str, Any]):
        """Validate and store the ``token_endpoint`` configuration.

        Args:
            token_endpoint: mapping that must contain ``type``, ``request_url``
                and ``request_data``. Any additional entries are kept in
                ``self.other_params`` for subclasses to read.

        Raises:
            FailedToConnectError: if one of the required keys is absent.
        """
        for key in self._REQUIRED_KEYS:
            if key not in token_endpoint:
                raise FailedToConnectError(f"Missing required key in token_endpoint: '{key}'")

        self.type: str = token_endpoint["type"]
        self.url: str = token_endpoint["request_url"]
        self.data: str = token_endpoint["request_data"]

        # Provider-specific extras (e.g. 'idp_auth_credentials' for Okta).
        self.other_params = {
            k: v for k, v in token_endpoint.items() if k not in self._REQUIRED_KEYS
        }

    @abstractmethod
    def build_header_payload(self) -> Dict[str, Any]:
        """Return the HTTP headers to send with the token request."""
        pass

    def handle_request(self) -> requests.Response:
        """
        POST the token request and return the identity provider's response.

        Raises:
            DbtRuntimeError: if the provider answers with HTTP 429 (rate limited).
            requests.exceptions.HTTPError: for any other 4xx/5xx response.
            requests.exceptions.Timeout: if the provider does not answer within
                ``request_timeout`` seconds.
        """
        response = requests.post(
            self.url,
            headers=self.build_header_payload(),
            data=self.data,
            timeout=self.request_timeout,
        )

        # 429 gets a dedicated, actionable message instead of a bare HTTPError.
        if response.status_code == 429:
            raise DbtRuntimeError(
                "Rate limit on identity provider's token dispatch has been reached. "
                "Consider increasing your identity provider's refresh token rate or "
                "lower dbt's maximum concurrent thread count."
            )

        response.raise_for_status()
        return response


class OktaIdpTokenService(TokenServiceBase):
    """Token service client for Okta, authenticated with HTTP Basic client credentials."""

    def build_header_payload(self) -> Dict[str, Any]:
        """Return the request headers, including the Basic authorization header.

        Raises:
            FailedToConnectError: if ``idp_auth_credentials`` is missing or
                empty in the ``token_endpoint`` configuration.
        """
        encoded_idp_client_creds = self.other_params.get("idp_auth_credentials")

        # Guard clause: an absent or empty credential string cannot authenticate.
        if not encoded_idp_client_creds:
            raise FailedToConnectError(
                "Missing 'idp_auth_credentials' from token_endpoint. Please provide client_id:client_secret in base64 encoded format as a profile entry under token_endpoint."
            )

        return {
            "accept": "application/json",
            "authorization": f"Basic {encoded_idp_client_creds}",
            "content-type": "application/x-www-form-urlencoded",
        }


class EntraIdpTokenService(TokenServiceBase):
    """
    Token service client for Microsoft Entra.

    Headers are formatted based on docs: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token
    """

    def build_header_payload(self) -> Dict[str, Any]:
        """Return the request headers; no authorization header is added here."""
        headers: Dict[str, Any] = {
            "accept": "application/json",
            "content-type": "application/x-www-form-urlencoded",
        }
        return headers


class TokenServiceType(Enum):
    """Identity provider types accepted in the ``type`` key of ``token_endpoint``."""

    OKTA = "okta"  # handled by OktaIdpTokenService
    ENTRA = "entra"  # handled by EntraIdpTokenService


def create_token_service_client(token_endpoint: Dict[str, Any]) -> TokenServiceBase:
    """Instantiate the token service client matching ``token_endpoint["type"]``.

    Args:
        token_endpoint: the ``token_endpoint`` profile entry; its ``type`` key
            selects the identity provider ('okta' or 'entra').

    Returns:
        A ``TokenServiceBase`` subclass instance built from ``token_endpoint``.

    Raises:
        FailedToConnectError: if ``type`` is absent (or a required key is
            missing when the client is constructed).
        ValueError: if ``type`` names an unsupported identity provider.
    """
    if (service_type := token_endpoint.get("type")) is None:
        raise FailedToConnectError("Missing required key in token_endpoint: 'type'")

    service_classes = {
        TokenServiceType.OKTA.value: OktaIdpTokenService,
        TokenServiceType.ENTRA.value: EntraIdpTokenService,
    }

    try:
        service_class = service_classes[service_type]
    except (KeyError, TypeError):
        # TypeError covers an unhashable ``type`` value (e.g. a YAML list),
        # which is just as unsupported as an unknown string.
        raise ValueError(
            f"Unsupported identity provider type: {service_type}. Select 'okta' or 'entra'."
        ) from None

    return service_class(token_endpoint)
49 changes: 41 additions & 8 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import re
import redshift_connector
import sqlparse

from multiprocessing import Lock
from contextlib import contextmanager
from typing import Any, Callable, Dict, Tuple, Union, Optional, List, TYPE_CHECKING
from dataclasses import dataclass, field

import sqlparse
import redshift_connector
from dbt.adapters.exceptions import FailedToConnectError
from redshift_connector.utils.oids import get_datatype_name

from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.redshift.auth_providers import create_token_service_client
from dbt_common.contracts.util import Replaceable
from dbt_common.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
from dbt_common.helper_types import Port
Expand All @@ -37,20 +39,16 @@ def get_message(self) -> str:
logger = AdapterLogger("Redshift")


class IdentityCenterTokenType(StrEnum):
    """String constants naming Identity Center token types."""

    # NOTE(review): presumably these correspond to the `token_type` connection
    # kwarg passed to the driver -- confirm against the callers.
    ACCESS_TOKEN = "ACCESS_TOKEN"
    EXT_JWT = "EXT_JWT"


class RedshiftConnectionMethod(StrEnum):
DATABASE = "database"
IAM = "iam"
IAM_ROLE = "iam_role"
IAM_IDENTITY_CENTER_BROWSER = "browser_identity_center"
IAM_IDENTITY_CENTER_TOKEN = "oauth_token_identity_center"

@classmethod
def uses_identity_center(cls, method: str) -> bool:
    """Return True when ``method`` is one of the Identity Center connection methods.

    Both the browser-based and the OAuth-token-based Identity Center methods
    count. A stale earlier ``return`` that only checked the browser method
    made the token method unreachable, so it has been removed.
    """
    return method in (cls.IAM_IDENTITY_CENTER_BROWSER, cls.IAM_IDENTITY_CENTER_TOKEN)

@classmethod
def is_iam(cls, method: str) -> bool:
Expand Down Expand Up @@ -153,6 +151,12 @@ class RedshiftCredentials(Credentials):
idc_client_display_name: Optional[str] = "Amazon Redshift driver"
idp_response_timeout: Optional[int] = None

# token_endpoint
# a field that we expect to be a dictionary of values used to request
# access tokens from an external identity provider that is integrated with
# the IAM Identity Center (IdC) instance of the Redshift AWS org or account
token_endpoint: Optional[Dict[str, str]] = None

_ALIASES = {"dbname": "database", "pass": "password"}

@property
Expand Down Expand Up @@ -323,6 +327,34 @@ def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]:

return __iam_kwargs(credentials) | idc_kwargs

def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]:
    """
    Build the connection kwargs for the 'oauth_token_identity_center' method.

    Accepts a `credentials` object with a `token_endpoint` field that corresponds to
    either Okta or Entra authentication services. The identity provider is asked
    for an access token, which is then handed to the IdpTokenAuthPlugin
    credentials provider.

    We only support token_type=EXT_JWT tokens. token_type=ACCESS_TOKEN has not been
    tested. It can be added with a presenting use-case.

    Raises:
        FailedToConnectError: if `token_endpoint` is not configured, or if the
            identity provider's response does not carry an `access_token`.
    """

    # f-string so the actual method name is logged, not the literal placeholder.
    logger.debug(f"Connecting to Redshift with '{credentials.method}' credentials method")

    __validate_required_fields("oauth_token_identity_center", ("token_endpoint",))

    token_service = create_token_service_client(credentials.token_endpoint)
    response = token_service.handle_request()
    try:
        access_token = response.json()["access_token"]
    except (KeyError, TypeError, ValueError) as err:
        # KeyError: JSON object without 'access_token'.
        # TypeError: JSON body that is not an object (e.g. a list).
        # ValueError: body that is not valid JSON at all.
        raise FailedToConnectError(
            "access_token missing from Idp token request. Please confirm correct configuration of the token_endpoint field in profiles.yml and that your Idp can use a refresh token to obtain an OIDC-compliant access token."
        ) from err

    return __iam_kwargs(credentials) | {
        "credentials_provider": "IdpTokenAuthPlugin",
        "token": access_token,
        "token_type": "EXT_JWT",
    }

#
# Head of function execution
#
Expand All @@ -333,6 +365,7 @@ def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]:
RedshiftConnectionMethod.IAM: __iam_user_kwargs,
RedshiftConnectionMethod.IAM_ROLE: __iam_role_kwargs,
RedshiftConnectionMethod.IAM_IDENTITY_CENTER_BROWSER: __iam_idc_browser_kwargs,
RedshiftConnectionMethod.IAM_IDENTITY_CENTER_TOKEN: __iam_idc_token_kwargs,
}

try:
Expand Down
7 changes: 4 additions & 3 deletions hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ packages = ["dbt"]
dependencies = [
"dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git",
"dbt-common @ git+https://github.com/dbt-labs/dbt-common.git",
"dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter",
"dbt-core @ git+https://github.com/dbt-labs/dbt-core.git#subdirectory=core",
"dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter",
"ddtrace==2.3.0",
"freezegun",
"ipdb~=0.13.13",
"pre-commit==3.7.0",
"freezegun",
"pytest>=7.0,<8.0",
"pytest-csv~=3.0",
"pytest-dotenv",
"pytest-logbook~=1.2",
"pytest-mock",
"pytest-xdist",
"pytest>=7.0,<8.0",
"requests",
]

[envs.default.scripts]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
# installed via dbt-core but referenced directly; don't pin to avoid version conflicts with dbt-core
"sqlparse>=0.5.0,<0.6.0",
"agate",
"requests",
]

[project.urls]
Expand Down
22 changes: 22 additions & 0 deletions tests/functional/test_auth_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,25 @@ def dbt_profile_target(self):
"host": "", # host is a required field in dbt-core
"port": 0, # port is a required field in dbt-core
}


@pytest.mark.skip(
    reason="We need to cut over to new adapters team AWS account which has infra to support this as an automated test. This will include a GHA step that renders a refresh token and loading secrets into Github secrets for the <> delimited placeholder values below"
)
class TestIamIdcAuthProfileOktaIdp(AuthMethod):
    """Functional test profile for the Identity Center token method backed by Okta.

    Skipped for now; the ``<...>`` placeholders below must be replaced with
    real values before the test can run.
    """

    @pytest.fixture(scope="class")
    def dbt_profile_target(self):
        """Return the dbt profile target, including the Okta ``token_endpoint``."""
        okta_token_endpoint = {
            "type": "okta",
            "request_url": "https://<subdomain>.oktapreview.com/oauth2/default/v1/token",
            "idp_auth_credentials": "<base64 creds>",
            "request_data": "grant_type=refresh_token&redirect_uri=<encoded redirect uri>&refresh_token=<a refresh token>",
        }
        profile_target = {
            "type": "redshift",
            "method": "oauth_token_identity_center",
            "host": os.getenv("REDSHIFT_TEST_HOST"),
            "port": 5439,
            "dbname": "dev",
            "threads": 1,
            "token_endpoint": okta_token_endpoint,
        }
        return profile_target
112 changes: 112 additions & 0 deletions tests/unit/test_auth_method.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import requests

from multiprocessing import get_context
from unittest import TestCase, mock
from unittest.mock import MagicMock
Expand Down Expand Up @@ -673,3 +675,113 @@ def test_invalid_adapter_missing_fields(self):
"'idc_region', 'issuer_url' field(s) are required for 'browser_identity_center' credentials method"
in context.exception.msg
)


class TestIAMIdcToken(AuthMethod):
    """Unit tests for the 'oauth_token_identity_center' credentials method.

    The identity provider is never contacted: ``requests.post`` is patched to
    return a canned error response, so the tests are deterministic and do not
    depend on network access or on a third-party service's behaviour.
    """

    @staticmethod
    def _error_response(status_code: int, reason: str, url: str) -> requests.Response:
        """Build an identity-provider error response without any network I/O."""
        response = requests.Response()
        response.status_code = status_code
        response.reason = reason
        response.url = url
        return response

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_profile_idc_token_all_required_fields_okta(self):
        """This test doesn't follow the idiom elsewhere in this file because a
        real test would need a valid refresh token, which would in turn require a
        valid authorization request; neither is possible in automated testing at
        merge. This is a surrogate test.
        """
        request_url = "https://dbtcs.oktapreview.com/oauth2/default/v1/token"
        self.config.credentials = self.config.credentials.replace(
            method="oauth_token_identity_center",
            token_endpoint={
                "type": "okta",
                "request_url": request_url,
                "idp_auth_credentials": "my_auth_creds",
                "request_data": "grant_type=refresh_token&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Flogin%2Foauth2%2Fcode%2Fokta&refresh_token=my_token",
            },
        )
        # An HTTPError shows that we got as far as issuing the token request,
        # which the (simulated) provider rejects because the refresh token and
        # auth creds are invalid.
        with mock.patch(
            "requests.post",
            return_value=self._error_response(401, "Unauthorized", request_url),
        ) as mock_post:
            with self.assertRaises(requests.exceptions.HTTPError) as context:
                connection = self.adapter.acquire_connection("dummy")
                connection.handle

        mock_post.assert_called()
        assert "401 Client Error: Unauthorized for url" in str(context.exception)

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_profile_idc_token_all_required_fields_entra(self):
        """This test doesn't follow the idiom elsewhere in this file because a
        real test would need a valid refresh token, which would in turn require a
        valid authorization request; neither is possible in automated testing at
        merge. This is a surrogate test.
        """
        request_url = "https://login.microsoftonline.com/my_tenant/oauth2/v2.0/token"
        self.config.credentials = self.config.credentials.replace(
            method="oauth_token_identity_center",
            token_endpoint={
                "type": "entra",
                "request_url": request_url,
                "request_data": "my_data",
            },
        )
        # An HTTPError shows that we got as far as issuing the token request,
        # which the (simulated) provider rejects as malformed.
        with mock.patch(
            "requests.post",
            return_value=self._error_response(400, "Bad Request", request_url),
        ) as mock_post:
            with self.assertRaises(requests.exceptions.HTTPError) as context:
                connection = self.adapter.acquire_connection("dummy")
                connection.handle

        mock_post.assert_called()
        assert "400 Client Error: Bad Request for url" in str(context.exception)

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_invalid_idc_token_missing_field(self):
        """Omitting ``token_endpoint`` entirely is rejected before any request is made."""
        self.config.credentials = self.config.credentials.replace(
            method="oauth_token_identity_center",
        )
        with self.assertRaises(FailedToConnectError) as context:
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
        assert (
            "'token_endpoint' field(s) are required for 'oauth_token_identity_center' credentials method"
            in context.exception.msg
        )

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_invalid_idc_token_missing_token_endpoint_subfield_okta(self):
        """An Okta ``token_endpoint`` without ``request_url`` is rejected."""
        self.config.credentials = self.config.credentials.replace(
            method="oauth_token_identity_center",
            token_endpoint={
                "type": "okta",
                "request_data": "my_data",
                "idp_auth_credentials": "my_auth_creds",
            },
        )
        with self.assertRaises(FailedToConnectError) as context:
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
        assert "Missing required key in token_endpoint: 'request_url'" in context.exception.msg

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_invalid_idc_token_missing_token_endpoint_subfield_entra(self):
        """An Entra ``token_endpoint`` without ``request_data`` is rejected."""
        self.config.credentials = self.config.credentials.replace(
            method="oauth_token_identity_center",
            token_endpoint={
                "type": "entra",
                "request_url": "https://dbtcs.oktapreview.com/oauth2/default/v1/token",
            },
        )
        with self.assertRaises(FailedToConnectError) as context:
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
        assert "Missing required key in token_endpoint: 'request_data'" in context.exception.msg

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_invalid_idc_token_missing_token_endpoint_type(self):
        """An empty ``token_endpoint`` is rejected because ``type`` is missing."""
        self.config.credentials = self.config.credentials.replace(
            method="oauth_token_identity_center",
            token_endpoint={},
        )
        with self.assertRaises(FailedToConnectError) as context:
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
        assert "Missing required key in token_endpoint: 'type'" in context.exception.msg

0 comments on commit de078b8

Please sign in to comment.