From 43e3d079fb567d2b0e867706655bacc9f7ad9a7d Mon Sep 17 00:00:00 2001 From: Hyejin Yoon <0327jane@gmail.com> Date: Fri, 17 Jan 2025 00:28:15 +0900 Subject: [PATCH] [wip] fix ui bugs --- .../profile/DataProcessInstanceSummary.tsx | 2 +- .../entity/mlModel/profile/MLModelSummary.tsx | 4 + .../mlModelGroup/profile/ModelGroupModels.tsx | 48 +-- datahub-web-react/src/graphql/lineage.graphql | 18 + datahub-web-react/src/graphql/mlModel.graphql | 17 + .../examples/ml/create_mlmodel.py | 385 +++++++++++++++--- .../dataprocess/dataprocess_instance.py | 87 +++- 7 files changed, 481 insertions(+), 80 deletions(-) diff --git a/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx b/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx index 63a387c6935731..30657286e7c350 100644 --- a/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx +++ b/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx @@ -41,7 +41,7 @@ export default function MLModelSummary() { const baseEntity = useBaseEntity(); const dpi = baseEntity?.dataProcessInstance; - print("dpi TRP", dpi?.mlTrainingRunProperties); + console.log("dpi", dpi); const formatDate = (timestamp?: number) => { if (!timestamp) return '-'; diff --git a/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx b/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx index ee8b57c1376fee..56d44254801a65 100644 --- a/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx +++ b/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx @@ -48,6 +48,8 @@ export default function MLModelSummary() { const model = baseEntity?.mlModel; const entityRegistry = useEntityRegistry(); + console.log("model", model); + const propertyTableColumns = [ { title: 'Name', @@ -68,9 +70,11 @@ export default function MLModelSummary() { const renderTrainingJobs = () => { const lineageTrainingJobs = model?.properties?.mlModelLineageInfo?.trainingJobs || []; + console.log("lineageTrainingJobs", model?.properties?.mlModelLineageInfo?.trainingJobs); if (lineageTrainingJobs.length === 0) return '-'; + // TODO: get job name from job URN return lineageTrainingJobs.map((jobUrn, index) => (
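Note on the TODO above ("get job name from job URN"): until the job entity's display name is resolved properly, the UI could fall back to a label derived from the URN itself. A minimal sketch, assuming the URN's last colon-separated segment is an acceptable interim label; the helper name `labelFromJobUrn` is illustrative and not part of this patch:

    // Hypothetical fallback: derive a readable label from a DataProcessInstance URN,
    // e.g. "urn:li:dataProcessInstance:abc123" -> "abc123".
    function labelFromJobUrn(jobUrn: string): string {
        const lastSegment = jobUrn.split(':').pop();
        return lastSegment || jobUrn;
    }

The proper fix remains fetching the DataProcessInstance entity and rendering its name.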
diff --git a/datahub-web-react/src/app/entity/mlModelGroup/profile/ModelGroupModels.tsx b/datahub-web-react/src/app/entity/mlModelGroup/profile/ModelGroupModels.tsx index 48df9cd2eeedad..b99d1f08e3c0df 100644 --- a/datahub-web-react/src/app/entity/mlModelGroup/profile/ModelGroupModels.tsx +++ b/datahub-web-react/src/app/entity/mlModelGroup/profile/ModelGroupModels.tsx @@ -100,8 +100,8 @@ export default function MLGroupModels() { return new Date(milliseconds).toISOString().slice(0, 19).replace('T', ' '); }; + console.log("modelGroup", modelGroup); console.log("models", models); - console.log("model properties", models[0].properties.customProperties[0]); const columns = [ { @@ -127,32 +127,32 @@ export default function MLGroupModels() { ), }, { - title: 'Registered At', - key: 'date', + title: 'Created At', + key: 'createdAt', width: 150, render: (_: any, record: EntityType.Mlmodel) => ( - {formatDate(record.properties?.date)} + {formatDate(record.properties?.created?.time)} ), }, // use versionProperties for aliases - // { - // title: 'Aliases', - // key: 'aliases', - // width: 200, - // render: (_: any, record: EntityType.Mlmodel) => { - // const aliases = record.versionProperties?.aliases?.map(va => va.aliasVersion) || []; - - // return ( - // - // {aliases.map((alias) => ( - // - // {alias} - // - // ))} - // - // ); - // }, - // }, + { + title: 'Aliases', + key: 'aliases', + width: 200, + render: (_: any, record: EntityType.Mlmodel) => { + const aliases = record.versionProperties?.aliases?.map(va => va.aliasVersion) || []; + + return ( + + {aliases.map((alias) => ( + + {alias} + + ))} + + ); + }, + }, { title: 'Tags', key: 'tags', @@ -194,10 +194,10 @@ export default function MLGroupModels() { Model Group Details - {modelGroup?.properties?.created?.time ? formatDate(modelGroup.properties.createdAt) : '-'} + {modelGroup?.properties?.created?.time ? formatDate(modelGroup.properties?.created?.time) : '-'} - {modelGroup?.properties?.lastModified?.time ? formatDate(modelGroup.properties.lastModified) : '-'} + {modelGroup?.properties?.lastModified?.time ? formatDate(modelGroup.properties?.lastModified?.time) : '-'} {modelGroup?.properties?.created?.actor && ( diff --git a/datahub-web-react/src/graphql/lineage.graphql b/datahub-web-react/src/graphql/lineage.graphql index 457936ed62cd2e..8b9e2b7070ecd7 100644 --- a/datahub-web-react/src/graphql/lineage.graphql +++ b/datahub-web-react/src/graphql/lineage.graphql @@ -296,6 +296,9 @@ fragment lineageNodeProperties on EntityWithRelationships { name description origin + tags { + ...globalTagsFields + } platform { ...platformFields } @@ -305,6 +308,21 @@ fragment lineageNodeProperties on EntityWithRelationships { status { removed } + properties { + createdTS: created { + time + actor + } + modelVersion: version + tags + customProperties { + key + value + } + } + editableProperties { + description + } structuredProperties { properties { ...structuredPropertiesFields diff --git a/datahub-web-react/src/graphql/mlModel.graphql b/datahub-web-react/src/graphql/mlModel.graphql index ad97c7c6f530a1..ba10a243e6f9b3 100644 --- a/datahub-web-react/src/graphql/mlModel.graphql +++ b/datahub-web-react/src/graphql/mlModel.graphql @@ -20,6 +20,23 @@ query getMLModel($urn: String!) { } } } + trainedBy: relationships(input: { types: ["TrainedBy"], direction: OUTGOING, start: 0, count: 100 }) { + start + count + total + relationships { + type + direction + entity { + ... 
on DataProcessInstance {
+                    urn
+                    name
+                    type
+                    ...dataProcessInstanceFields
+                }
+            }
+        }
+    }
     privileges {
         ...entityPrivileges
     }
diff --git a/metadata-ingestion/examples/ml/create_mlmodel.py b/metadata-ingestion/examples/ml/create_mlmodel.py
index 6876a3dff85340..72264f653040f1 100644
--- a/metadata-ingestion/examples/ml/create_mlmodel.py
+++ b/metadata-ingestion/examples/ml/create_mlmodel.py
@@ -1,15 +1,218 @@
 import time
-from typing import Iterable
+from dataclasses import dataclass
+from typing import List, Optional
 
 import datahub.metadata.schema_classes as models
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
-from datahub.metadata.urns import MlModelGroupUrn, MlModelUrn
-from datahub.ingestion.graph.client import get_default_graph
+from datahub.metadata.urns import MlModelGroupUrn, MlModelUrn, DatasetUrn
+from datahub.api.entities.dataprocess.dataprocess_instance import (
+    DataProcessInstance,
+    InstanceRunResult,
+)
+from datahub.emitter.mcp_builder import ContainerKey
+from datahub.emitter.rest_emitter import DatahubRestEmitter
+
+
+class ContainerKeyWithId(ContainerKey):
+    id: str
+
+
+@dataclass
+class Container:
+    key: ContainerKeyWithId
+    subtype: str
+    name: Optional[str] = None
+    description: Optional[str] = None
+
+    def generate_mcp(self) -> List[MetadataChangeProposalWrapper]:
+        container_urn = self.key.as_urn()
+        current_time = int(time.time() * 1000)
+
+        # Create container aspects
+        container_subtype = models.SubTypesClass(typeNames=[self.subtype])
+        container_info = models.ContainerPropertiesClass(
+            name=self.name or self.key.id,
+            description=self.description,
+            created=models.TimeStampClass(
+                time=current_time,
+                actor="urn:li:corpuser:datahub"
+            ),
+            lastModified=models.TimeStampClass(
+                time=current_time,
+                actor="urn:li:corpuser:datahub"
+            ),
+            customProperties={},
+        )
+        browse_path = models.BrowsePathsV2Class(path=[])
+        dpi = models.DataPlatformInstanceClass(
+            platform=self.key.platform,
+            instance=self.key.instance,
+        )
+
+        mcps = []
+
+        # Add container aspects
+        mcps.extend([
+            MetadataChangeProposalWrapper(
+                entityType="container",
+                entityUrn=str(container_urn),
+                aspectName="subTypes",
+                aspect=container_subtype,
+                changeType=models.ChangeTypeClass.UPSERT
+            ),
+            MetadataChangeProposalWrapper(
+                entityType="container",
+                entityUrn=str(container_urn),
+                aspectName="containerProperties",
+                aspect=container_info,
+                changeType=models.ChangeTypeClass.UPSERT
+            ),
+            MetadataChangeProposalWrapper(
+                entityType="container",
+                entityUrn=str(container_urn),
+                aspectName="browsePathsV2",
+                aspect=browse_path,
+                changeType=models.ChangeTypeClass.UPSERT
+            ),
+            MetadataChangeProposalWrapper(
+                entityType="container",
+                entityUrn=str(container_urn),
+                aspectName="dataPlatformInstance",
+                aspect=dpi,
+                changeType=models.ChangeTypeClass.UPSERT
+            ),
+            MetadataChangeProposalWrapper(
+                entityType="container",
+                entityUrn=str(container_urn),
+                aspectName="status",
+                aspect=models.StatusClass(removed=False),
+                changeType=models.ChangeTypeClass.UPSERT
+            )
+        ])
+
+        return mcps
+
+
+def create_training_job(
+    experiment_key: ContainerKeyWithId,
+    run_id: str,
+    input_dataset_urn: str
+) -> tuple[DataProcessInstance, List[MetadataChangeProposalWrapper]]:
+    """Create a training job instance"""
+    data_process_instance = DataProcessInstance.from_container(
+        container_key=experiment_key,
+        id=run_id
+    )
+
+    data_process_instance.data_platform = experiment_key.platform
+    data_process_instance.subtype = "Training Run"
+    data_process_instance.inlets = [DatasetUrn.from_string(input_dataset_urn)]
+    data_process_instance.container_urn = experiment_key.as_urn()  # Set container relationship here
+
+    created_at = int(time.time() * 1000)
+
+    # First get base MCPs from the 
instance + mcps = list(data_process_instance.generate_mcp( + created_ts_millis=created_at, + materialize_iolets=True + )) + + # Create and add DPI properties aspect + dpi_props = models.DataProcessInstancePropertiesClass( + name=f"Training {run_id}", + created=models.AuditStampClass( + time=created_at, + actor="urn:li:corpuser:datahub" + ), + externalUrl="http://mlflow:5000", + customProperties={ + "framework": "sklearn", + "python_version": "3.8", + "experiment_id": experiment_key.id, + }, + ) + + # Create training run properties + training_run_props = models.MLTrainingRunPropertiesClass( + customProperties={ + "learning_rate": "0.01", + "batch_size": "64", + }, + externalUrl="http://mlflow:5000", + hyperParams=[ + models.MLHyperParamClass( + name="n_estimators", + value="100", + description="Number of trees" + ), + models.MLHyperParamClass( + name="max_depth", + value="10", + description="Maximum tree depth" + ) + ], + trainingMetrics=[ + models.MLMetricClass( + name="accuracy", + value="0.95", + description="Test accuracy" + ), + models.MLMetricClass( + name="f1_score", + value="0.93", + description="Test F1 score" + ) + ], + outputUrls=["s3://mlflow/outputs"], + id=run_id, + ) + + # Add custom aspects + mcps.extend([ + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + entityType="dataProcessInstance", + aspectName="dataProcessInstanceProperties", + aspect=dpi_props, + changeType=models.ChangeTypeClass.UPSERT + ), + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + entityType="dataProcessInstance", + aspectName="mlTrainingRunProperties", + aspect=training_run_props, + changeType=models.ChangeTypeClass.UPSERT + ) + ]) + + # Add run events + start_time = created_at + end_time = start_time + (45 * 60 * 1000) # 45 minutes duration + + # mcps.extend([ + # MetadataChangeProposalWrapper( + # entityUrn=str(data_process_instance.urn), + # entityType="dataProcessInstance", + # aspectName="dataProcessInstanceRunEvent", + # aspect=models.DataProcessInstanceRunEventClass( + # timestampMillis=start_time, + # eventGranularity="TASK" + # ), + # changeType=models.ChangeTypeClass.UPSERT + # ), + # MetadataChangeProposalWrapper( + # entityUrn=str(data_process_instance.urn), + # entityType="dataProcessInstance", + # aspectName="dataProcessInstanceRunEvent", + # aspect=models.DataProcessInstanceRunEventClass( + # timestampMillis=end_time, + # eventGranularity="TASK", + # result=InstanceRunResult.SUCCESS + # ), + # changeType=models.ChangeTypeClass.UPSERT + # ) + # ]) + + return data_process_instance, mcps + def create_model_group() -> tuple[MlModelGroupUrn, MetadataChangeProposalWrapper]: """Create a model group and return its URN and MCP""" model_group_urn = MlModelGroupUrn(platform="mlflow", name="simple_model_group") current_time = int(time.time() * 1000) - + model_group_info = models.MLModelGroupPropertiesClass( description="Simple ML model group example", customProperties={ @@ -24,56 +227,33 @@ def create_model_group() -> tuple[MlModelGroupUrn, MetadataChangeProposalWrapper time=current_time, actor="urn:li:corpuser:datahub" ), - trainingJobs=[], ) - model_group_mcp = MetadataChangeProposalWrapper( + return model_group_urn, MetadataChangeProposalWrapper( entityUrn=str(model_group_urn), + entityType="mlModelGroup", + aspectName="mlModelGroupProperties", aspect=model_group_info, + changeType=models.ChangeTypeClass.UPSERT ) - - return model_group_urn, model_group_mcp + def create_single_model( - model_name: str, - model_group_urn: str, -) -> 
MetadataChangeProposalWrapper:
+    model_name: str,
+    model_group_urn: str,
+    training_job_urn: str,
+) -> tuple[MlModelUrn, List[MetadataChangeProposalWrapper]]:
     """Create a single ML model and return its MCP"""
     model_urn = MlModelUrn(platform="mlflow", name=model_name)
     current_time = int(time.time() * 1000)
-
-    # Define example metrics and hyperparameters
-    training_metrics = [
-        models.MLMetricClass(
-            name="accuracy",
-            value="0.95",
-            description="Test accuracy"
-        ),
-        models.MLMetricClass(
-            name="f1_score",
-            value="0.93",
-            description="Test F1 score"
-        )
-    ]
-
-    hyper_params = [
-        models.MLHyperParamClass(
-            name="n_estimators",
-            value="100",
-            description="Number of trees"
-        ),
-        models.MLHyperParamClass(
-            name="max_depth",
-            value="10",
-            description="Maximum tree depth"
-        )
-    ]
+
+    mcps = []
 
     model_info = models.MLModelPropertiesClass(
-        name=model_name,
         description="Simple example ML model",
         version=models.VersionTagClass(versionTag="1"),
         groups=[str(model_group_urn)],
+        trainingJobs=[str(training_job_urn)],
         date=current_time,
         lastModified=models.TimeStampClass(
             time=current_time,
             actor="urn:li:corpuser:datahub"
@@ -84,27 +264,124 @@ def create_single_model(
             actor="urn:li:corpuser:datahub"
         ),
         tags=["stage:production", "team:data_science"],
-        trainingMetrics=training_metrics,
-        hyperParams=hyper_params,
-        trainingJobs=[],
-        downstreamJobs=[],
+        trainingMetrics=[
+            models.MLMetricClass(
+                name="accuracy",
+                value="0.95",
+                description="Test accuracy"
+            ),
+            models.MLMetricClass(
+                name="f1_score",
+                value="0.93",
+                description="Test F1 score"
+            )
+        ],
+        hyperParams=[
+            models.MLHyperParamClass(
+                name="n_estimators",
+                value="100",
+                description="Number of trees"
+            ),
+            models.MLHyperParamClass(
+                name="max_depth",
+                value="10",
+                description="Maximum tree depth"
+            )
+        ],
     )
 
-    return MetadataChangeProposalWrapper(
-        entityUrn=str(model_urn),
-        aspect=model_info,
+    # print(str(model_urn))
+    # model_version_info = models.VersionPropertiesClass(
+    #     version=models.VersionTagClass(versionTag="1"),
+    #     versionSet="urn:li:mlModel:(urn:li:dataPlatform:mlflow,simple_model,PROD)",
+    #     aliases=[models.VersionTagClass(versionTag="latest")],
+    #     sortId="",
+    # )
+    #
+    # mcps.append(
+    #     MetadataChangeProposalWrapper(
+    #         entityUrn=str(model_urn),
+    #         entityType="mlModel",
+    #         aspectName="versionProperties",
+    #         aspect=model_version_info,
+    #         changeType=models.ChangeTypeClass.UPSERT
+    #     )
+    # )
+
+    mcps.append(
+        MetadataChangeProposalWrapper(
+            entityUrn=str(model_urn),
+            entityType="mlModel",
+            aspectName="mlModelProperties",
+            aspect=model_info,
+            changeType=models.ChangeTypeClass.UPSERT
+        )
     )
 
+    return model_urn, mcps
+
+
 def main():
-    # Create the model group and model
+    # Create emitter with authentication token.
+    # Placeholder only: generate a personal access token in the DataHub UI
+    # and supply it here (or via an environment variable); never commit real tokens.
+    token = "<your-personal-access-token>"
+    emitter = DatahubRestEmitter(
+        gms_server="http://localhost:8080",
+        extra_headers={
+            "Authorization": f"Bearer {token}"
+        }
+    )
+
+    # Create the model group
     model_group_urn, model_group_mcp = create_model_group()
-    model_mcp = create_single_model("simple_model", str(model_group_urn))
-
-    # Emit the metadata to DataHub
-    with get_default_graph() as graph:
-        graph.emit(model_group_mcp)
-        graph.emit(model_mcp)
-        print("Successfully created model group and model in DataHub")
+
+    # Create experiment container
+    experiment = Container(
+        key=ContainerKeyWithId(
+            platform="urn:li:dataPlatform:mlflow",
+            # instance="prod",
+            id="airline_forecast_experiment",
+        ),
+        subtype="ML Experiment",
+        name="Airline Forecast Experiment",
+        description="Experiment for forecasting airline passengers",
+    )
+
+    # Create training job instance
+    training_job, training_mcps = create_training_job(
+        experiment_key=experiment.key,
+        run_id="run_1",
+        input_dataset_urn="urn:li:dataset:(urn:li:dataPlatform:s3,airline_passengers,PROD)"
+    )
+
+    # Create the model with training job reference
+    model_urn, model_mcps = create_single_model(
+        "simple_model",
+        str(model_group_urn),
+        str(training_job.urn)
+    )
+
+    # Emit all metadata
+    # First, emit model group
+    print("Emitting model group...")
+    emitter.emit(model_group_mcp)
+
+    # Emit experiment container
+    print("Emitting container aspects...")
+    for mcp in experiment.generate_mcp():
+        emitter.emit(mcp)
+
+    # Emit training job properties and events
+    print("Emitting training job aspects...")
+    for mcp in training_mcps:
+        emitter.emit(mcp)
+
+    # Finally emit the model and its aspects
+    print("Emitting model aspects...")
+    for mcp in model_mcps:
+        emitter.emit(mcp)
+
+    print("Successfully created model group, training job, and model in DataHub")
+
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py
index d406fa36e00db6..027f9d77ee1070 100644
--- a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py
+++ b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py
@@ -15,16 +15,21 @@
 )
 from datahub.metadata.schema_classes import (
     AuditStampClass,
+    DataPlatformInstanceClass,
     DataProcessInstanceRunEventClass,
     DataProcessInstanceRunResultClass,
     DataProcessRunStatusClass,
     DataProcessTypeClass,
+    SubTypesClass,
+    ContainerClass,
 )
+from datahub.metadata.urns import DataPlatformInstanceUrn, DataPlatformUrn, ContainerUrn
 from datahub.utilities.str_enum import StrEnum
 from datahub.utilities.urns.data_flow_urn import DataFlowUrn
 from datahub.utilities.urns.data_job_urn import DataJobUrn
 from datahub.utilities.urns.data_process_instance_urn import DataProcessInstanceUrn
 from datahub.utilities.urns.dataset_urn import DatasetUrn
+from datahub.emitter.mcp_builder import ContainerKey
 
 
 class DataProcessInstanceKey(DatahubKey):
@@ -61,7 +66,7 @@ class DataProcessInstance:
     orchestrator: str
     cluster: Optional[str] = None
     type: str = DataProcessTypeClass.BATCH_SCHEDULED
-    template_urn: Optional[Union[DataJobUrn, DataFlowUrn, DatasetUrn]] = None
+    template_urn: Optional[Union[DataJobUrn, DataFlowUrn, DatasetUrn, ContainerUrn]] = None
     parent_instance: Optional[DataProcessInstanceUrn] = None
     properties: Dict[str, str] = field(default_factory=dict)
     url: Optional[str] = None
@@ -71,6 +76,10 @@ class DataProcessInstance:
     _template_object: Optional[Union[DataJob, DataFlow]] = field(
         init=False, default=None, repr=False
    )
+    data_platform: Optional[str] = None
+    data_platform_instance: Optional[str] = None
+    subtype: Optional[str] = None
+    container_urn: Optional[str] = None
 
     def __post_init__(self):
         self.urn = DataProcessInstanceUrn(
@@ -80,6 +89,36 @@ def __post_init__(self):
             id=self.id,
         ).guid()
         )
+        if self.data_platform is None:
+            self.data_platform = self.orchestrator
+
+        try:
+            # We first try to create from string, assuming it's an urn
+            self.data_platform = str(
+                DataPlatformUrn.create_from_string(self.data_platform)
+            )
+        except Exception:
+            # If it fails, we assume it's an id
+            self.data_platform = str(DataPlatformUrn.create_from_id(self.data_platform))
+
+        if self.data_platform_instance is None and self.cluster is not None:
+            self.data_platform_instance = self.cluster
+
+        if self.data_platform_instance is not None:
+            try:
+                # We first try to create from string, assuming it's an urn
+                self.data_platform_instance = str(
+                    DataPlatformInstanceUrn.create_from_string(
+                        self.data_platform_instance
+                    )
+                )
+            except Exception:
+                # If it fails, we assume it's an id
+                self.data_platform_instance = str(
+                    DataPlatformInstanceUrn(
+                        platform=self.data_platform, instance=self.data_platform_instance
+                    )
+                )
 
     def start_event_mcp(
         self, start_timestamp_millis: int, attempt: Optional[int] = None
@@ -269,6 +308,29 @@ def generate_mcp(
             )
             yield mcp
 
+        assert self.data_platform
+
+        mcp = MetadataChangeProposalWrapper(
+            entityUrn=str(self.urn),
+            aspect=DataPlatformInstanceClass(
+                platform=self.data_platform, instance=self.data_platform_instance
+            ),
+        )
+        yield mcp
+
+        if self.subtype:
+            mcp = MetadataChangeProposalWrapper(
+                entityUrn=str(self.urn), aspect=SubTypesClass(typeNames=[self.subtype])
+            )
+            yield mcp
+
+        if self.container_urn:
+            mcp = MetadataChangeProposalWrapper(
+                entityUrn=str(self.urn),
+                aspect=ContainerClass(container=self.container_urn),
+            )
+            yield mcp
+
         yield from self.generate_inlet_outlet_mcp(materialize_iolets=materialize_iolets)
 
     @staticmethod
@@ -331,6 +393,29 @@ def from_datajob(
         dpi.outlets = datajob.outlets
         return dpi
 
+    @staticmethod
+    def from_container(
+        container_key: ContainerKey,
+        id: str,
+    ) -> "DataProcessInstance":
+        """
+        Generates a DataProcessInstance from a Container.
+
+        :param container_key: (ContainerKey) the key of the container to attach the DataProcessInstance to
+        :param id: (str) the id for the DataProcessInstance
+        :return: DataProcessInstance
+        """
+        dpi: DataProcessInstance = DataProcessInstance(
+            id=id,
+            orchestrator=DataPlatformUrn.from_string(container_key.platform).platform_name,
+            template_urn=None,
+            container_urn=container_key.as_urn(),
+        )
+
+        return dpi
+
     @staticmethod
     def from_dataflow(dataflow: DataFlow, id: str) -> "DataProcessInstance":
         """