diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index ba358d2465bbc..6eb02fe80552e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import boto3 @@ -73,6 +74,8 @@ class AwsConnectionConfig(ConfigModel): - dbt source """ + _credentials_expiration: Optional[datetime] = None + aws_access_key_id: Optional[str] = Field( default=None, description=f"AWS access key ID. {AUTODETECT_CREDENTIALS_DOC_LINK}", @@ -115,6 +118,11 @@ class AwsConnectionConfig(ConfigModel): description="Advanced AWS configuration options. These are passed directly to [botocore.config.Config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html).", ) + def allowed_cred_refresh(self) -> bool: + if self._normalized_aws_roles(): + return True + return False + def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]: if not self.aws_role: return [] @@ -153,11 +161,14 @@ def get_session(self) -> Session: } for role in self._normalized_aws_roles(): - credentials = assume_role( - role, - self.aws_region, - credentials=credentials, - ) + if self._should_refresh_credentials(): + credentials = assume_role( + role, + self.aws_region, + credentials=credentials, + ) + if isinstance(credentials["Expiration"], datetime): + self._credentials_expiration = credentials["Expiration"] session = Session( aws_access_key_id=credentials["AccessKeyId"], @@ -168,6 +179,12 @@ def get_session(self) -> Session: return session + def _should_refresh_credentials(self) -> bool: + if self._credentials_expiration is None: + return True + remaining_time = self._credentials_expiration - datetime.now(timezone.utc) + return remaining_time < timedelta(minutes=5) + def get_credentials(self) -> Dict[str, Optional[str]]: credentials = self.get_session().get_credentials() if credentials is not None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py index acbc6eb9a0e44..b63fa57f069b5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import DefaultDict, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.decorators import ( @@ -33,6 +33,9 @@ StatefulIngestionSourceBase, ) +if TYPE_CHECKING: + from mypy_boto3_sagemaker import SageMakerClient + @platform_name("SageMaker") @config_class(SagemakerSourceConfig) @@ -56,6 +59,7 @@ def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext): self.report = SagemakerSourceReport() self.sagemaker_client = config.sagemaker_client self.env = config.env + self.client_factory = ClientFactory(config) @classmethod def create(cls, config_dict, ctx): @@ -92,7 +96,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # extract jobs if specified if self.source_config.extract_jobs is not False: job_processor = JobProcessor( - sagemaker_client=self.sagemaker_client, + sagemaker_client=self.client_factory.get_client, env=self.env, report=self.report, job_type_filter=self.source_config.extract_jobs, @@ -118,3 +122,15 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: def get_report(self): return self.report + + +class ClientFactory: + def __init__(self, config: SagemakerSourceConfig): + self.config = config + self._cached_client = self.config.sagemaker_client + + def get_client(self) -> "SageMakerClient": + if self.config.allowed_cred_refresh(): + # Always fetch the client dynamically with auto-refresh logic + return self.config.sagemaker_client + return self._cached_client diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py index a1a5a00884237..73a83295ec8cb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, DefaultDict, Dict, Iterable, @@ -147,7 +148,7 @@ class JobProcessor: """ # boto3 SageMaker client - sagemaker_client: "SageMakerClient" + sagemaker_client: Callable[[], "SageMakerClient"] env: str report: SagemakerSourceReport # config filter for specific job types to ingest (see metadata-ingestion README) @@ -170,8 +171,7 @@ class JobProcessor: def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]: jobs = [] - - paginator = self.sagemaker_client.get_paginator(job_spec.list_command) + paginator = self.sagemaker_client().get_paginator(job_spec.list_command) for page in paginator.paginate(): page_jobs: List[Any] = page[job_spec.list_key] @@ -269,7 +269,7 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]: describe_command = job_type_to_info[job_type].describe_command describe_name_key = job_type_to_info[job_type].describe_name_key - return getattr(self.sagemaker_client, describe_command)( + return getattr(self.sagemaker_client(), describe_command)( **{describe_name_key: job_name} ) diff --git a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py index 582d16f5d2612..995d176c213b2 100644 --- a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py +++ b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + from botocore.stub import Stubber from freezegun import freeze_time @@ -220,8 +222,17 @@ def test_sagemaker_ingest(tmp_path, pytestconfig): {"ModelName": "the-second-model"}, ) - mce_objects = [wu.metadata for wu in sagemaker_source_instance.get_workunits()] - write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects) + # Patch the client factory's get_client method to return the stubbed client for jobs + with patch.object( + sagemaker_source_instance.client_factory, + "get_client", + return_value=sagemaker_source_instance.sagemaker_client, + ): + # Run the test and generate the MCEs + mce_objects = [ + wu.metadata for wu in sagemaker_source_instance.get_workunits() + ] + write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects) # Verify the output. test_resources_dir = pytestconfig.rootpath / "tests/unit/sagemaker"