diff --git a/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx b/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx
index 30657286e7c350..d26c8a9431f280 100644
--- a/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx
+++ b/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx
@@ -71,9 +71,6 @@ export default function MLModelSummary() {
                     {formatDate(dpi?.properties?.created?.time)}
-
-                    {dpi?.properties?.created?.actor}
-
                     {formatStatus(dpi?.state)}
@@ -83,6 +80,9 @@ export default function MLModelSummary() {
                     {dpi?.mlTrainingRunProperties?.id}
+
+                    {dpi?.properties?.created?.actor}
+
diff --git a/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx b/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx
index db1bdbaacaca9a..fd279230de310c 100644
--- a/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx
+++ b/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx
@@ -104,11 +104,11 @@ export default function MLModelSummary() {
                     {formatDate(model?.properties?.lastModified?.time)}
-
-                    {model?.properties?.created?.actor}
+
+
                     {model?.versionProperties?.aliases?.map((alias, index) => (
diff --git a/metadata-ingestion/examples/ml/create_ml.py b/metadata-ingestion/examples/ml/create_ml.py
new file mode 100644
index 00000000000000..3ac616cdd730b5
--- /dev/null
+++ b/metadata-ingestion/examples/ml/create_ml.py
@@ -0,0 +1,359 @@
import random
import time
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union

import datahub.metadata.schema_classes as models
from datahub.api.entities.datajob import DataFlow, DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import (
    DataProcessInstance,
    InstanceRunResult,
)
from datahub.api.entities.dataset.dataset import Dataset
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import ContainerKey
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
from datahub.metadata.urns import (
    DataPlatformUrn,
    DatasetUrn,
    MlModelGroupUrn,
    MlModelUrn,
    VersionSetUrn,
)

ORCHESTRATOR_MLFLOW = "mlflow"
ORCHESTRATOR_AIRFLOW = "airflow"


class ContainerKeyWithId(ContainerKey):
    id: str


@dataclass
class Container:
    key: ContainerKeyWithId
    subtype: str
    name: Optional[str] = None
    description: Optional[str] = None

    def generate_mcp(
        self,
    ) -> Iterable[
        Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]
    ]:
        container_urn = self.key.as_urn()
        container_subtype = models.SubTypesClass(typeNames=[self.subtype])
        container_info = models.ContainerPropertiesClass(
            name=self.name or self.key.id,
            description=self.description,
            customProperties={},
        )
        browse_path = models.BrowsePathsV2Class(path=[])
        dpi = models.DataPlatformInstanceClass(
            platform=self.key.platform,
            instance=self.key.instance,
        )

        yield from MetadataChangeProposalWrapper.construct_many(
            entityUrn=container_urn,
            aspects=[container_subtype, container_info, browse_path, dpi],
        )


def create_model(
    model_name: str,
    model_group_urn: str,
    data_process_instance_urn: str,
    tags: List[str],
    version_aliases: List[str],
    index: int,
    training_metrics: List[models.MLMetricClass],
    hyper_params: List[models.MLHyperParamClass],
    model_description: str,
    created_at: int,
) -> Iterable[MetadataChangeProposalWrapper]:
    model_urn = MlModelUrn(platform="mlflow", name=model_name)

    # Create model properties
    model_info = models.MLModelPropertiesClass(
        description=model_description,
        version=models.VersionTagClass(versionTag=f"{index}"),
        groups=[str(model_group_urn)],
        trainingJobs=[str(data_process_instance_urn)],
        date=created_at,
        tags=tags,
        trainingMetrics=training_metrics,
        hyperParams=hyper_params,
        created=models.TimeStampClass(
            time=created_at, actor="urn:li:corpuser:datahub"
        ),
        lastModified=models.TimeStampClass(
            time=created_at, actor="urn:li:corpuser:datahub"
        ),
    )

    # Create version set
    version_set_urn = VersionSetUrn(
        id=f"mlmodel_{model_name}_versions", entity_type="mlModel"
    )
    version_entity = models.VersionSetPropertiesClass(
        latest=str(model_urn),
        versioningScheme="ALPHANUMERIC_GENERATED_BY_DATAHUB",
    )

    # Create version properties
    model_version_info = models.VersionPropertiesClass(
        version=models.VersionTagClass(versionTag=f"{index}"),
        versionSet=str(version_set_urn),
        aliases=[models.VersionTagClass(versionTag=alias) for alias in version_aliases],
        sortId="AAAAAAAA",
    )

    # Yield all MCPs
    yield MetadataChangeProposalWrapper(
        entityUrn=str(version_set_urn),
        entityType="versionSet",
        aspectName="versionSetProperties",
        aspect=version_entity,
        changeType=models.ChangeTypeClass.UPSERT,
    )

    yield MetadataChangeProposalWrapper(
        entityUrn=str(model_urn),
        entityType="mlModel",
        aspectName="versionProperties",
        aspect=model_version_info,
        changeType=models.ChangeTypeClass.UPSERT,
    )

    yield MetadataChangeProposalWrapper(
        entityUrn=str(model_urn),
        entityType="mlModel",
        aspectName="mlModelProperties",
        aspect=model_info,
        changeType=models.ChangeTypeClass.UPSERT,
    )
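

# NOTE: create_model() emits three aspects per model: the shared version set
# (its "latest" pointer is re-upserted on every call, so the most recently
# created version ends up marked latest), the model's versionProperties tying
# it into that set, and mlModelProperties carrying metrics, hyperparameters,
# and a trainingJobs link back to the run that produced the model.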


def generate_pipeline(
    pipeline_name: str,
    orchestrator: str,
) -> Iterable[
    Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]
]:
    data_flow = DataFlow(
        id=pipeline_name,
        orchestrator=orchestrator,
        cluster="default",
        name=pipeline_name,
    )

    data_job = DataJob(id="training", flow_urn=data_flow.urn, name="Training")

    input_dataset = Dataset(
        id="airline_passengers",
        name="Airline Passengers",
        description="Monthly airline passenger data",
        properties={},
        platform="s3",
        schema=None,
    )

    if orchestrator == ORCHESTRATOR_MLFLOW:
        experiment = Container(
            key=ContainerKeyWithId(
                platform=str(DataPlatformUrn.create_from_id("mlflow")),
                id="airline_forecast_experiment",
            ),
            subtype="ML Experiment",
            name="Airline Forecast Experiment",
            description="Experiment for forecasting airline passengers",
        )

        yield from experiment.generate_mcp()

        model_group_urn = MlModelGroupUrn(
            platform="mlflow", name="airline_forecast_models"
        )
        current_time = int(time.time() * 1000)
        model_group_info = models.MLModelGroupPropertiesClass(
            description="ML models for airline passenger forecasting",
            customProperties={
                "stage": "production",
                "team": "data_science",
            },
            created=models.TimeStampClass(
                time=current_time, actor="urn:li:corpuser:datahub"
            ),
            lastModified=models.TimeStampClass(
                time=current_time, actor="urn:li:corpuser:datahub"
            ),
        )

        yield MetadataChangeProposalWrapper(
            entityUrn=str(model_group_urn),
            entityType="mlModelGroup",
            aspectName="mlModelGroupProperties",
            aspect=model_group_info,
            changeType=models.ChangeTypeClass.UPSERT,
        )

        model_aliases = [
            "challenger",
            "champion",
            "production",
            "experimental",
            "deprecated",
        ]
        model_tags = [
            "stage:production",
            "stage:development",
            "team:data_science",
            "team:ml_engineering",
            "team:analytics",
        ]

        model_dict = {
            "arima_model_1": "ARIMA model for airline passenger forecasting",
            "arima_model_2": "Enhanced ARIMA model with seasonal components",
            "arima_model_3": "ARIMA model optimized for long-term forecasting",
            "arima_model_4": "ARIMA model with hyperparameter tuning",
            "arima_model_5": "ARIMA model trained on extended dataset",
        }

        # Generate five run timestamps spaced five days apart over the last month
        end_time = int(time.time() * 1000)
        start_time = end_time - (30 * 24 * 60 * 60 * 1000)
        run_timestamps = [
            start_time + (i * 5 * 24 * 60 * 60 * 1000) for i in range(5)
        ]

        run_dict = {
            "run_1": {"start_time": run_timestamps[0], "duration": 45, "result": InstanceRunResult.SUCCESS},
            "run_2": {"start_time": run_timestamps[1], "duration": 60, "result": InstanceRunResult.FAILURE},
            "run_3": {"start_time": run_timestamps[2], "duration": 55, "result": InstanceRunResult.SUCCESS},
            "run_4": {"start_time": run_timestamps[3], "duration": 70, "result": InstanceRunResult.SUCCESS},
            "run_5": {"start_time": run_timestamps[4], "duration": 50, "result": InstanceRunResult.FAILURE},
        }
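
        # For each model below: create one training run (a DataProcessInstance
        # under the MLflow experiment container), wire input/output datasets in
        # as lineage, attach metrics and hyperparameters, emit start/end run
        # events using the timings from run_dict, and register the model
        # version produced by the run.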
"deprecated"] + model_tags = ["stage:production", "stage:development", "team:data_science", "team:ml_engineering", + "team:analytics"] + + model_dict = { + "arima_model_1": "ARIMA model for airline passenger forecasting", + "arima_model_2": "Enhanced ARIMA model with seasonal components", + "arima_model_3": "ARIMA model optimized for long-term forecasting", + "arima_model_4": "ARIMA model with hyperparameter tuning", + "arima_model_5": "ARIMA model trained on extended dataset", + } + + # Generate run timestamps within the last month + end_time = int(time.time() * 1000) + start_time = end_time - (30 * 24 * 60 * 60 * 1000) + run_timestamps = [ + start_time + (i * 5 * 24 * 60 * 60 * 1000) + for i in range(5) + ] + + run_dict = { + "run_1": {"start_time": run_timestamps[0], "duration": 45, "result": InstanceRunResult.SUCCESS}, + "run_2": {"start_time": run_timestamps[1], "duration": 60, "result": InstanceRunResult.FAILURE}, + "run_3": {"start_time": run_timestamps[2], "duration": 55, "result": InstanceRunResult.SUCCESS}, + "run_4": {"start_time": run_timestamps[3], "duration": 70, "result": InstanceRunResult.SUCCESS}, + "run_5": {"start_time": run_timestamps[4], "duration": 50, "result": InstanceRunResult.FAILURE}, + } + + for i, (model_name, model_description) in enumerate(model_dict.items(), start=1): + run_id = f"run_{i}" + data_process_instance = DataProcessInstance.from_container( + container_key=experiment.key, id=run_id + ) + + data_process_instance.subtype = "Training Run" + data_process_instance.inlets = [DatasetUrn.from_string(input_dataset.urn)] + + output_dataset = Dataset( + id=f"passenger_forecast_24_12_0{i}", + name=f"Passenger Forecast 24_12_0{i}", + description=f"Forecasted airline passenger numbers for run {i}", + properties={}, + platform="s3", + schema=None, + ) + yield from output_dataset.generate_mcp() + + data_process_instance.outlets = [DatasetUrn.from_string(output_dataset.urn)] + + # Training metrics and hyperparameters + training_metrics = [ + models.MLMetricClass( + name="accuracy", + value=str(random.uniform(0.7, 0.99)), + description="Test accuracy" + ), + models.MLMetricClass( + name="f1_score", + value=str(random.uniform(0.7, 0.99)), + description="Test F1 score" + ) + ] + hyper_params = [ + models.MLHyperParamClass( + name="n_estimators", + value=str(random.randint(50, 200)), + description="Number of trees" + ), + models.MLHyperParamClass( + name="max_depth", + value=str(random.randint(5, 15)), + description="Maximum tree depth" + ) + ] + + # DPI properties + created_at = int(time.time() * 1000) + dpi_props = models.DataProcessInstancePropertiesClass( + name=f"Training {run_id}", + created=models.AuditStampClass(time=created_at, actor="urn:li:corpuser:datahub"), + customProperties={ + "framework": "statsmodels", + "python_version": "3.8", + }, + ) + + mlrun_props = models.MLTrainingRunPropertiesClass( + id="run_id", + outputUrls=["s3://mlflow/artifacts"], + hyperParams=hyper_params, + trainingMetrics=training_metrics, + externalUrl="http://mlflow:5000", + ) + + yield from data_process_instance.generate_mcp( + created_ts_millis=created_at, materialize_iolets=True + ) + + yield MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=dpi_props, + ) + + yield MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=mlrun_props, + ) + + # Generate start and end events + start_time_millis = run_dict[run_id]["start_time"] + duration_minutes = run_dict[run_id]["duration"] + end_time_millis = start_time_millis + 
            # Model
            selected_aliases = random.sample(model_aliases, k=random.randint(1, 2))
            selected_tags = random.sample(model_tags, 2)
            yield from create_model(
                model_name=model_name,
                model_group_urn=str(model_group_urn),
                data_process_instance_urn=str(data_process_instance.urn),
                tags=selected_tags,
                version_aliases=selected_aliases,
                index=i,
                training_metrics=training_metrics,
                hyper_params=hyper_params,
                model_description=model_description,
                created_at=created_at,
            )

    if orchestrator == ORCHESTRATOR_AIRFLOW:
        yield from data_flow.generate_mcp()
        yield from data_job.generate_mcp()

    yield from input_dataset.generate_mcp()


if __name__ == "__main__":
    # Replace with a valid personal access token for your DataHub instance.
    token = "eyJhbGciOiJIUzI1NiJ9.eyJhY3RvclR5cGUiOiJVU0VSIiwiYWN0b3JJZCI6ImRhdGFodWIiLCJ0eXBlIjoiUEVSU09OQUwiLCJ2ZXJzaW9uIjoiMiIsImp0aSI6Ijg3MWEyZjU2LTY2MjUtNGRiMC04OTZhLTAyMzBmNmM0MmRkZCIsInN1YiI6ImRhdGFodWIiLCJleHAiOjE3Mzk2ODcwMDIsImlzcyI6ImRhdGFodWItbWV0YWRhdGEtc2VydmljZSJ9.HDGaXw8iBTXIEqKyIQl-jSlS8BquAXZHELP4hA9thOM"
    graph_config = DatahubClientConfig(
        server="http://localhost:8080",
        token=token,
        extra_headers={"Authorization": f"Bearer {token}"},
    )
    graph = DataHubGraph(graph_config)
    with graph:
        for mcp in generate_pipeline(
            "airline_forecast_pipeline_mlflow", orchestrator=ORCHESTRATOR_MLFLOW
        ):
            graph.emit(mcp)
        for mcp in generate_pipeline(
            "airline_forecast_pipeline_airflow", orchestrator=ORCHESTRATOR_AIRFLOW
        ):
            graph.emit(mcp)
\ No newline at end of file
diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl
index d005cd557cf77e..32329f60bfaa70 100644
--- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl
+++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl
@@ -15,7 +15,8 @@ record DataProcessInstanceInput {
   @Relationship = {
     "/*": {
       "name": "Consumes",
-      "entityTypes": [ "dataset" ]
+      "entityTypes": [ "dataset" ],
+      "isLineage": true
     }
   }
   @Searchable = {
diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl
index fe782dbe01ca9b..6b41cf9ba63417 100644
--- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl
+++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl
@@ -15,7 +15,9 @@ record DataProcessInstanceOutput {
   @Relationship = {
     "/*": {
       "name": "Produces",
-      "entityTypes": [ "dataset", "mlModel" ]
+      "entityTypes": [ "dataset", "mlModel" ],
+      "isLineage": true,
+      "isUpstream": false
     }
   }
   @Searchable = {
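
As I read the @Relationship annotations above, "isLineage": true promotes the Consumes/Produces edges of a dataProcessInstance into DataHub's lineage graph, and "isUpstream": false on Produces orients that edge downstream (the run produces the dataset). A minimal sketch of exercising these aspects directly, assuming a DataHub instance at localhost:8080; the URNs are illustrative placeholders:

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

# Illustrative URNs; any existing run/dataset URNs would work the same way.
run_urn = "urn:li:dataProcessInstance:example_training_run"
input_urn = "urn:li:dataset:(urn:li:dataPlatform:s3,airline_passengers,PROD)"
output_urn = "urn:li:dataset:(urn:li:dataPlatform:s3,passenger_forecast,PROD)"

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
with graph:
    # "Consumes" edge: the run reads the input dataset.
    graph.emit(
        MetadataChangeProposalWrapper(
            entityUrn=run_urn,
            aspect=models.DataProcessInstanceInputClass(inputs=[input_urn]),
        )
    )
    # "Produces" edge: the run writes the output dataset.
    graph.emit(
        MetadataChangeProposalWrapper(
            entityUrn=run_urn,
            aspect=models.DataProcessInstanceOutputClass(outputs=[output_urn]),
        )
    )

With the annotation change deployed, both edges should then appear in the lineage view rather than only in the Relationships tab.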