Skip to content

Commit

Permalink
Client factory used by all sagemaker sub-steps
Browse files Browse the repository at this point in the history
  • Loading branch information
skrydal committed Aug 29, 2024
1 parent 6a7b8bc commit 07f7dcf
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ class LazyEvaluator:
def __init__(self, callback, *args):
self.callback = callback
self.args = args
self.value_stored = False
self.value = None

def __repr__(self):
return self.callback(*self.args)
if not self.value_stored:
self.value = self.callback(*self.args)
return self.value


class AwsConnectionConfig(ConfigModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# get common lineage graph
lineage_processor = LineageProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
sagemaker_client=self.client_factory.get_client, env=self.env, report=self.report
)
lineage = lineage_processor.get_lineage()

# extract feature groups if specified
if self.source_config.extract_feature_groups:
feature_group_processor = FeatureGroupProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
sagemaker_client=self.client_factory.get_client, env=self.env, report=self.report
)
yield from feature_group_processor.get_workunits()

Expand All @@ -110,7 +110,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# extract models if specified
if self.source_config.extract_models:
model_processor = ModelProcessor(
sagemaker_client=self.sagemaker_client,
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
model_image_to_jobs=model_image_to_jobs,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable, List
from typing import TYPE_CHECKING, Iterable, List, Callable

import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
Expand Down Expand Up @@ -34,7 +34,7 @@

@dataclass
class FeatureGroupProcessor:
sagemaker_client: "SageMakerClient"
sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport

Expand All @@ -46,7 +46,7 @@ def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
feature_groups = []
logger.debug("Attempting to get all feature groups")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_feature_groups
paginator = self.sagemaker_client.get_paginator("list_feature_groups")
paginator = self.sagemaker_client().get_paginator("list_feature_groups")
for page in paginator.paginate():
logger.debug(
"Retrieved %s feature groups", len(page["FeatureGroupSummaries"])
Expand All @@ -63,7 +63,7 @@ def get_feature_group_details(
"""
logger.debug("Attempting to describe feature group: %s", feature_group_name)
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_feature_group
feature_group = self.sagemaker_client.describe_feature_group(
feature_group = self.sagemaker_client().describe_feature_group(
FeatureGroupName=feature_group_name
)

Expand All @@ -76,7 +76,7 @@ def get_feature_group_details(
"Iterating over another token to retrieve full feature group description for: %s",
feature_group_name,
)
next_features = self.sagemaker_client.describe_feature_group(
next_features = self.sagemaker_client().describe_feature_group(
FeatureGroupName=feature_group_name, NextToken=next_token
)
feature_group["FeatureDefinitions"] += next_features["FeatureDefinitions"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set, Callable

from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceReport,
Expand Down Expand Up @@ -46,7 +46,7 @@ class LineageInfo:

@dataclass
class LineageProcessor:
sagemaker_client: "SageMakerClient"
sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport
nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)
Expand All @@ -60,7 +60,7 @@ def get_all_actions(self) -> List["ActionSummaryTypeDef"]:
actions = []
logger.debug("Attempting to get all actions")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_actions
paginator = self.sagemaker_client.get_paginator("list_actions")
paginator = self.sagemaker_client().get_paginator("list_actions")
for page in paginator.paginate():
logger.debug("Retrieved %s actions", len(page["ActionSummaries"]))
actions += page["ActionSummaries"]
Expand All @@ -75,7 +75,7 @@ def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]:
artifacts = []
logger.debug("Attempting to get all artifacts")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_artifacts
paginator = self.sagemaker_client.get_paginator("list_artifacts")
paginator = self.sagemaker_client().get_paginator("list_artifacts")
for page in paginator.paginate():
logger.debug("Retrieved %s artifacts", len(page["ArtifactSummaries"]))
artifacts += page["ArtifactSummaries"]
Expand All @@ -90,7 +90,7 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
contexts = []
logger.debug("Attempting to get all contexts")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_contexts
paginator = self.sagemaker_client.get_paginator("list_contexts")
paginator = self.sagemaker_client().get_paginator("list_contexts")
for page in paginator.paginate():
logger.debug("Retrieved %s contexts", len(page["ContextSummaries"]))
contexts += page["ContextSummaries"]
Expand All @@ -105,7 +105,7 @@ def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]
edges = []
logger.debug("Attempting to get all incoming edges")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_associations
paginator = self.sagemaker_client.get_paginator("list_associations")
paginator = self.sagemaker_client().get_paginator("list_associations")
for page in paginator.paginate(DestinationArn=node_arn):
logger.debug("Retrieved %s edges", len(page["AssociationSummaries"]))
edges += page["AssociationSummaries"]
Expand All @@ -119,7 +119,7 @@ def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]
edges = []
logger.debug("Attempting to get all outgoing edges")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_associations
paginator = self.sagemaker_client.get_paginator("list_associations")
paginator = self.sagemaker_client().get_paginator("list_associations")
for page in paginator.paginate(SourceArn=node_arn):
logger.debug("Retrieved %s edges", len(page["AssociationSummaries"]))
edges += page["AssociationSummaries"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
List,
Optional,
Set,
Tuple,
Tuple, Callable,
)

import datahub.emitter.mce_builder as builder
Expand Down Expand Up @@ -71,7 +71,7 @@

@dataclass
class ModelProcessor:
sagemaker_client: "SageMakerClient"
sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport
lineage: LineageInfo
Expand Down Expand Up @@ -102,7 +102,7 @@ def get_all_models(self) -> List["ModelSummaryTypeDef"]:
models = []
logger.debug("Attempting to retrieve all models")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_models
paginator = self.sagemaker_client.get_paginator("list_models")
paginator = self.sagemaker_client().get_paginator("list_models")
for page in paginator.paginate():
logger.debug("Retrieved %s models", len(page["Models"]))
models += page["Models"]
Expand All @@ -115,7 +115,7 @@ def get_model_details(self, model_name: str) -> "DescribeModelOutputTypeDef":
"""

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model
return self.sagemaker_client.describe_model(ModelName=model_name)
return self.sagemaker_client().describe_model(ModelName=model_name)

def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
"""
Expand All @@ -124,7 +124,7 @@ def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
groups = []
logger.debug("Attempting to retrieve all model groups")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_model_package_groups
paginator = self.sagemaker_client.get_paginator("list_model_package_groups")
paginator = self.sagemaker_client().get_paginator("list_model_package_groups")
for page in paginator.paginate():
logger.debug(
"Retrieved %s model groups", len(page["ModelPackageGroupSummaryList"])
Expand All @@ -141,15 +141,15 @@ def get_group_details(
"""

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model_package_group
return self.sagemaker_client.describe_model_package_group(
return self.sagemaker_client().describe_model_package_group(
ModelPackageGroupName=group_name
)

def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]:
endpoints = []
logger.debug("Attempting to retrieve all endpoints")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_endpoints
paginator = self.sagemaker_client.get_paginator("list_endpoints")
paginator = self.sagemaker_client().get_paginator("list_endpoints")
for page in paginator.paginate():
logger.debug("Retrieved %s endpoints", len(page["Endpoints"]))
endpoints += page["Endpoints"]
Expand All @@ -160,7 +160,7 @@ def get_endpoint_details(
self, endpoint_name: str
) -> "DescribeEndpointOutputTypeDef":
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint
return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
return self.sagemaker_client().describe_endpoint(EndpointName=endpoint_name)

def get_endpoint_status(
self, endpoint_name: str, endpoint_arn: str, sagemaker_status: str
Expand Down

0 comments on commit 07f7dcf

Please sign in to comment.