Skip to content

Commit

Permalink
Added verbosity to sagemaker-related aws calls
Browse files Browse the repository at this point in the history
  • Loading branch information
skrydal committed Aug 29, 2024
1 parent aea9b35 commit 6a7b8bc
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable, List

Expand Down Expand Up @@ -28,6 +29,8 @@
FeatureGroupSummaryTypeDef,
)

logger = logging.getLogger(__name__)


@dataclass
class FeatureGroupProcessor:
Expand All @@ -41,10 +44,13 @@ def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
"""

feature_groups = []

logger.debug("Attempting to get all feature groups")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_feature_groups
paginator = self.sagemaker_client.get_paginator("list_feature_groups")
for page in paginator.paginate():
logger.debug(
"Retrieved %s feature groups", len(page["FeatureGroupSummaries"])
)
feature_groups += page["FeatureGroupSummaries"]

return feature_groups
Expand All @@ -55,7 +61,7 @@ def get_feature_group_details(
"""
Get details of a feature group (including list of component features).
"""

logger.debug("Attempting to describe feature group: %s", feature_group_name)
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_feature_group
feature_group = self.sagemaker_client.describe_feature_group(
FeatureGroupName=feature_group_name
Expand All @@ -66,12 +72,19 @@ def get_feature_group_details(

# paginate over feature group features
while next_token:
logger.debug(
"Iterating over another token to retrieve full feature group description for: %s",
feature_group_name,
)
next_features = self.sagemaker_client.describe_feature_group(
FeatureGroupName=feature_group_name, NextToken=next_token
)
feature_group["FeatureDefinitions"] += next_features["FeatureDefinitions"]
next_token = feature_group.get("NextToken", "")

logger.debug(
"Retrieved full description for feature group: %s", feature_group_name
)
return feature_group

def get_feature_group_wu(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -49,6 +50,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
Expand Down Expand Up @@ -171,9 +174,11 @@ class JobProcessor:

def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
jobs = []
logger.debug("Attempting to retrieve all jobs for type %s", job_type)
paginator = self.sagemaker_client().get_paginator(job_spec.list_command)
for page in paginator.paginate():
page_jobs: List[Any] = page[job_spec.list_key]
logger.debug("Retrieved %s jobs", len(page_jobs))

for job in page_jobs:
job_name = (
Expand Down Expand Up @@ -269,6 +274,11 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
describe_command = job_type_to_info[job_type].describe_command
describe_name_key = job_type_to_info[job_type].describe_name_key

logger.debug(
"Retrieving description for job: %s using command: %s",
job_name,
describe_command,
)
return getattr(self.sagemaker_client(), describe_command)(
**{describe_name_key: job_name}
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
Expand All @@ -15,6 +16,8 @@
ContextSummaryTypeDef,
)

logger = logging.getLogger(__name__)


@dataclass
class LineageInfo:
Expand Down Expand Up @@ -55,10 +58,11 @@ def get_all_actions(self) -> List["ActionSummaryTypeDef"]:
"""

actions = []

logger.debug("Attempting to get all actions")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_actions
paginator = self.sagemaker_client.get_paginator("list_actions")
for page in paginator.paginate():
logger.debug("Retrieved %s actions", len(page["ActionSummaries"]))
actions += page["ActionSummaries"]

return actions
Expand All @@ -69,10 +73,11 @@ def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]:
"""

artifacts = []

logger.debug("Attempting to get all artifacts")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_artifacts
paginator = self.sagemaker_client.get_paginator("list_artifacts")
for page in paginator.paginate():
logger.debug("Retrieved %s artifacts", len(page["ArtifactSummaries"]))
artifacts += page["ArtifactSummaries"]

return artifacts
Expand All @@ -83,10 +88,11 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
"""

contexts = []

logger.debug("Attempting to get all contexts")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_contexts
paginator = self.sagemaker_client.get_paginator("list_contexts")
for page in paginator.paginate():
logger.debug("Retrieved %s contexts", len(page["ContextSummaries"]))
contexts += page["ContextSummaries"]

return contexts
Expand All @@ -97,10 +103,11 @@ def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]
"""

edges = []

logger.debug("Attempting to get all incoming edges")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_associations
paginator = self.sagemaker_client.get_paginator("list_associations")
for page in paginator.paginate(DestinationArn=node_arn):
logger.debug("Retrieved %s edges", len(page["AssociationSummaries"]))
edges += page["AssociationSummaries"]

return edges
Expand All @@ -110,10 +117,11 @@ def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]
Get all outgoing edges for a node in the lineage graph.
"""
edges = []

logger.debug("Attempting to get all outgoing edges")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_associations
paginator = self.sagemaker_client.get_paginator("list_associations")
for page in paginator.paginate(SourceArn=node_arn):
logger.debug("Retrieved %s edges", len(page["AssociationSummaries"]))
edges += page["AssociationSummaries"]

return edges
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
Expand Down Expand Up @@ -53,6 +54,8 @@
ModelSummaryTypeDef,
)

logger = logging.getLogger(__name__)

ENDPOINT_STATUS_MAP: Dict[str, str] = {
"OutOfService": DeploymentStatusClass.OUT_OF_SERVICE,
"Creating": DeploymentStatusClass.CREATING,
Expand Down Expand Up @@ -97,10 +100,11 @@ def get_all_models(self) -> List["ModelSummaryTypeDef"]:
"""

models = []

logger.debug("Attempting to retrieve all models")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_models
paginator = self.sagemaker_client.get_paginator("list_models")
for page in paginator.paginate():
logger.debug("Retrieved %s models", len(page["Models"]))
models += page["Models"]

return models
Expand All @@ -118,10 +122,13 @@ def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
List all model groups in SageMaker.
"""
groups = []

logger.debug("Attempting to retrieve all model groups")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_model_package_groups
paginator = self.sagemaker_client.get_paginator("list_model_package_groups")
for page in paginator.paginate():
logger.debug(
"Retrieved %s model groups", len(page["ModelPackageGroupSummaryList"])
)
groups += page["ModelPackageGroupSummaryList"]

return groups
Expand All @@ -140,11 +147,11 @@ def get_group_details(

def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]:
endpoints = []

logger.debug("Attempting to retrieve all endpoints")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_endpoints
paginator = self.sagemaker_client.get_paginator("list_endpoints")

for page in paginator.paginate():
logger.debug("Retrieved %s endpoints", len(page["Endpoints"]))
endpoints += page["Endpoints"]

return endpoints
Expand Down

0 comments on commit 6a7b8bc

Please sign in to comment.