diff --git a/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx b/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx
index 30657286e7c350..d26c8a9431f280 100644
--- a/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx
+++ b/datahub-web-react/src/app/entity/dataProcessInstance/profile/DataProcessInstanceSummary.tsx
@@ -71,9 +71,6 @@ export default function MLModelSummary() {
{formatDate(dpi?.properties?.created?.time)}
-
- {dpi?.properties?.created?.actor}
-
{formatStatus(dpi?.state)}
@@ -83,6 +80,9 @@ export default function MLModelSummary() {
{dpi?.mlTrainingRunProperties?.id}
+
+ {dpi?.properties?.created?.actor}
+
diff --git a/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx b/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx
index db1bdbaacaca9a..fd279230de310c 100644
--- a/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx
+++ b/datahub-web-react/src/app/entity/mlModel/profile/MLModelSummary.tsx
@@ -104,11 +104,11 @@ export default function MLModelSummary() {
{formatDate(model?.properties?.lastModified?.time)}
-
-
{model?.properties?.created?.actor}
+
+
{model?.versionProperties?.aliases?.map((alias, index) => (
diff --git a/metadata-ingestion/examples/ml/create_ml.py b/metadata-ingestion/examples/ml/create_ml.py
new file mode 100644
index 00000000000000..3ac616cdd730b5
--- /dev/null
+++ b/metadata-ingestion/examples/ml/create_ml.py
@@ -0,0 +1,359 @@
+import os
+import random
+import time
+from dataclasses import dataclass
+from typing import Iterable, List, Optional, Union
+
+import datahub.metadata.schema_classes as models
+from datahub.api.entities.datajob import DataFlow, DataJob
+from datahub.api.entities.dataprocess.dataprocess_instance import (
+    DataProcessInstance,
+    InstanceRunResult,
+)
+from datahub.api.entities.dataset.dataset import Dataset
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.emitter.mcp_builder import ContainerKey
+from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
+from datahub.metadata.urns import (
+    DataPlatformUrn,
+    DatasetUrn,
+    MlModelGroupUrn,
+    MlModelUrn,
+    VersionSetUrn,
+)
+
+ORCHESTRATOR_MLFLOW = "mlflow"
+ORCHESTRATOR_AIRFLOW = "airflow"
+
+
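+# This example models an MLflow experiment as a DataHub container; subclassing
+# ContainerKey with an explicit id yields a stable container URN that the
+# training runs below can be parented under.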
+class ContainerKeyWithId(ContainerKey):
+ id: str
+
+
+@dataclass
+class Container:
+ key: ContainerKeyWithId
+ subtype: str
+ name: Optional[str] = None
+ description: Optional[str] = None
+
+ def generate_mcp(
+ self,
+ ) -> Iterable[
+ Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]
+ ]:
+ container_urn = self.key.as_urn()
+ container_subtype = models.SubTypesClass(typeNames=[self.subtype])
+ container_info = models.ContainerPropertiesClass(
+ name=self.name or self.key.id,
+ description=self.description,
+ customProperties={},
+ )
+ browse_path = models.BrowsePathsV2Class(path=[])
+ dpi = models.DataPlatformInstanceClass(
+ platform=self.key.platform,
+ instance=self.key.instance,
+ )
+
+ yield from MetadataChangeProposalWrapper.construct_many(
+ entityUrn=container_urn,
+ aspects=[container_subtype, container_info, browse_path, dpi],
+ )
+
+
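+# Registers one model version. Three aspects are emitted: versionSetProperties
+# (ties all versions of a model together), versionProperties (this version's
+# tag, sort id, and aliases), and mlModelProperties (description, metrics,
+# hyperparameters, and a link back to the training run that produced it).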
+def create_model(
+ model_name: str,
+ model_group_urn: str,
+ data_process_instance_urn: str,
+ tags: List[str],
+ version_aliases: List[str],
+ index: int,
+ training_metrics: List[models.MLMetricClass],
+ hyper_params: List[models.MLHyperParamClass],
+ model_description: str,
+ created_at: int,
+) -> Iterable[MetadataChangeProposalWrapper]:
+ model_urn = MlModelUrn(platform="mlflow", name=model_name)
+
+ # Create model properties
+ model_info = models.MLModelPropertiesClass(
+ description=model_description,
+ version=models.VersionTagClass(versionTag=f"{index}"),
+ groups=[str(model_group_urn)],
+ trainingJobs=[str(data_process_instance_urn)],
+ date=created_at,
+ tags=tags,
+ trainingMetrics=training_metrics,
+ hyperParams=hyper_params,
+ created=models.TimeStampClass(
+ time=created_at,
+ actor="urn:li:corpuser:datahub"
+ ),
+ lastModified=models.TimeStampClass(
+ time=created_at,
+ actor="urn:li:corpuser:datahub"
+ ),
+ )
+
+ # Create version set
+    version_set_urn = VersionSetUrn(
+        id=f"mlmodel_{model_name}_versions", entity_type="mlModel"
+    )
+ version_entity = models.VersionSetPropertiesClass(
+ latest=str(model_urn),
+ versioningScheme="ALPHANUMERIC_GENERATED_BY_DATAHUB",
+ )
+
+ # Create version properties
+ model_version_info = models.VersionPropertiesClass(
+ version=models.VersionTagClass(versionTag=f"{index}"),
+ versionSet=str(version_set_urn),
+ aliases=[models.VersionTagClass(versionTag=alias) for alias in version_aliases],
+ sortId="AAAAAAAA",
+ )
+
+    # Yield all MCPs. MetadataChangeProposalWrapper infers entityType,
+    # aspectName, and the default UPSERT change type from the URN and aspect.
+    yield MetadataChangeProposalWrapper(
+        entityUrn=str(version_set_urn),
+        aspect=version_entity,
+    )
+
+    yield MetadataChangeProposalWrapper(
+        entityUrn=str(model_urn),
+        aspect=model_version_info,
+    )
+
+    yield MetadataChangeProposalWrapper(
+        entityUrn=str(model_urn),
+        aspect=model_info,
+    )
+
+
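+# Builds one demo pipeline end to end. MLflow runs are parented to an
+# experiment container; Airflow runs are parented to the training DataJob of a
+# DataFlow (see the branch inside the run loop below).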
+def generate_pipeline(
+ pipeline_name: str,
+ orchestrator: str,
+) -> Iterable[Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]]:
+ data_flow = DataFlow(
+ id=pipeline_name,
+ orchestrator=orchestrator,
+ cluster="default",
+ name=pipeline_name,
+ )
+
+ data_job = DataJob(id="training", flow_urn=data_flow.urn, name="Training")
+
+ input_dataset = Dataset(
+ id="airline_passengers",
+ name="Airline Passengers",
+ description="Monthly airline passenger data",
+ properties={},
+ platform="s3",
+ schema=None,
+ )
+
+ if orchestrator == ORCHESTRATOR_MLFLOW:
+ experiment = Container(
+ key=ContainerKeyWithId(
+ platform=str(DataPlatformUrn.create_from_id("mlflow")),
+ id="airline_forecast_experiment",
+ ),
+ subtype="ML Experiment",
+ name="Airline Forecast Experiment",
+ description="Experiment for forecasting airline passengers",
+ )
+
+ yield from experiment.generate_mcp()
+
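+    # Every model version created in the loop below is registered under this
+    # single model group.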
+ model_group_urn = MlModelGroupUrn(platform="mlflow", name="airline_forecast_models")
+ current_time = int(time.time() * 1000)
+ model_group_info = models.MLModelGroupPropertiesClass(
+ description="ML models for airline passenger forecasting",
+ customProperties={
+ "stage": "production",
+ "team": "data_science",
+ },
+ created=models.TimeStampClass(
+ time=current_time,
+ actor="urn:li:corpuser:datahub"
+ ),
+ lastModified=models.TimeStampClass(
+ time=current_time,
+ actor="urn:li:corpuser:datahub"
+ ),
+ )
+
+    yield MetadataChangeProposalWrapper(
+        entityUrn=str(model_group_urn),
+        aspect=model_group_info,
+    )
+
+ model_aliases = ["challenger", "champion", "production", "experimental", "deprecated"]
+ model_tags = ["stage:production", "stage:development", "team:data_science", "team:ml_engineering",
+ "team:analytics"]
+
+ model_dict = {
+ "arima_model_1": "ARIMA model for airline passenger forecasting",
+ "arima_model_2": "Enhanced ARIMA model with seasonal components",
+ "arima_model_3": "ARIMA model optimized for long-term forecasting",
+ "arima_model_4": "ARIMA model with hyperparameter tuning",
+ "arima_model_5": "ARIMA model trained on extended dataset",
+ }
+
+ # Generate run timestamps within the last month
+ end_time = int(time.time() * 1000)
+ start_time = end_time - (30 * 24 * 60 * 60 * 1000)
+ run_timestamps = [
+ start_time + (i * 5 * 24 * 60 * 60 * 1000)
+ for i in range(5)
+ ]
+
+ run_dict = {
+ "run_1": {"start_time": run_timestamps[0], "duration": 45, "result": InstanceRunResult.SUCCESS},
+ "run_2": {"start_time": run_timestamps[1], "duration": 60, "result": InstanceRunResult.FAILURE},
+ "run_3": {"start_time": run_timestamps[2], "duration": 55, "result": InstanceRunResult.SUCCESS},
+ "run_4": {"start_time": run_timestamps[3], "duration": 70, "result": InstanceRunResult.SUCCESS},
+ "run_5": {"start_time": run_timestamps[4], "duration": 50, "result": InstanceRunResult.FAILURE},
+ }
+
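+    # One pass per model: emit a training-run DPI with its input/output
+    # datasets, start/end run events, and a registered model version that
+    # points back at the run.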
+ for i, (model_name, model_description) in enumerate(model_dict.items(), start=1):
+ run_id = f"run_{i}"
+        if orchestrator == ORCHESTRATOR_MLFLOW:
+            # MLflow runs are parented to the experiment container created above.
+            data_process_instance = DataProcessInstance.from_container(
+                container_key=experiment.key, id=run_id
+            )
+        else:
+            # The Airflow path has no experiment container (`experiment` is
+            # undefined here), so parent the run to the training DataJob.
+            data_process_instance = DataProcessInstance.from_datajob(
+                datajob=data_job, id=run_id
+            )
+
+ data_process_instance.subtype = "Training Run"
+ data_process_instance.inlets = [DatasetUrn.from_string(input_dataset.urn)]
+
+ output_dataset = Dataset(
+ id=f"passenger_forecast_24_12_0{i}",
+ name=f"Passenger Forecast 24_12_0{i}",
+ description=f"Forecasted airline passenger numbers for run {i}",
+ properties={},
+ platform="s3",
+ schema=None,
+ )
+ yield from output_dataset.generate_mcp()
+
+ data_process_instance.outlets = [DatasetUrn.from_string(output_dataset.urn)]
+
+ # Training metrics and hyperparameters
+ training_metrics = [
+ models.MLMetricClass(
+ name="accuracy",
+ value=str(random.uniform(0.7, 0.99)),
+ description="Test accuracy"
+ ),
+ models.MLMetricClass(
+ name="f1_score",
+ value=str(random.uniform(0.7, 0.99)),
+ description="Test F1 score"
+ )
+ ]
+ hyper_params = [
+ models.MLHyperParamClass(
+ name="n_estimators",
+ value=str(random.randint(50, 200)),
+ description="Number of trees"
+ ),
+ models.MLHyperParamClass(
+ name="max_depth",
+ value=str(random.randint(5, 15)),
+ description="Maximum tree depth"
+ )
+ ]
+
+ # DPI properties
+ created_at = int(time.time() * 1000)
+ dpi_props = models.DataProcessInstancePropertiesClass(
+ name=f"Training {run_id}",
+ created=models.AuditStampClass(time=created_at, actor="urn:li:corpuser:datahub"),
+ customProperties={
+ "framework": "statsmodels",
+ "python_version": "3.8",
+ },
+ )
+
+ mlrun_props = models.MLTrainingRunPropertiesClass(
+ id="run_id",
+ outputUrls=["s3://mlflow/artifacts"],
+ hyperParams=hyper_params,
+ trainingMetrics=training_metrics,
+ externalUrl="http://mlflow:5000",
+ )
+
+ yield from data_process_instance.generate_mcp(
+ created_ts_millis=created_at, materialize_iolets=True
+ )
+
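+        # Attach the run's display properties and the MLflow-specific training
+        # metadata (metrics, hyperparameters, artifact and tracking URLs).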
+ yield MetadataChangeProposalWrapper(
+ entityUrn=str(data_process_instance.urn),
+ aspect=dpi_props,
+ )
+
+ yield MetadataChangeProposalWrapper(
+ entityUrn=str(data_process_instance.urn),
+ aspect=mlrun_props,
+ )
+
+ # Generate start and end events
+ start_time_millis = run_dict[run_id]["start_time"]
+ duration_minutes = run_dict[run_id]["duration"]
+ end_time_millis = start_time_millis + duration_minutes * 60000
+ result = run_dict[run_id]["result"]
+ result_type = "SUCCESS" if result == InstanceRunResult.SUCCESS else "FAILURE"
+
+ yield from data_process_instance.start_event_mcp(
+ start_timestamp_millis=start_time_millis
+ )
+ yield from data_process_instance.end_event_mcp(
+ end_timestamp_millis=end_time_millis,
+ result=result,
+ result_type=result_type,
+ start_timestamp_millis=start_time_millis,
+ )
+
+ # Model
+ selected_aliases = random.sample(model_aliases, k=random.randint(1, 2))
+ selected_tags = random.sample(model_tags, 2)
+ yield from create_model(
+ model_name=model_name,
+ model_group_urn=str(model_group_urn),
+ data_process_instance_urn=str(data_process_instance.urn),
+ tags=selected_tags,
+ version_aliases=selected_aliases,
+ index=i,
+ training_metrics=training_metrics,
+ hyper_params=hyper_params,
+ model_description=model_description,
+ created_at=created_at,
+ )
+
+ if orchestrator == ORCHESTRATOR_AIRFLOW:
+ yield from data_flow.generate_mcp()
+ yield from data_job.generate_mcp()
+
+ yield from input_dataset.generate_mcp()
+
+
+if __name__ == "__main__":
+ token = "eyJhbGciOiJIUzI1NiJ9.eyJhY3RvclR5cGUiOiJVU0VSIiwiYWN0b3JJZCI6ImRhdGFodWIiLCJ0eXBlIjoiUEVSU09OQUwiLCJ2ZXJzaW9uIjoiMiIsImp0aSI6Ijg3MWEyZjU2LTY2MjUtNGRiMC04OTZhLTAyMzBmNmM0MmRkZCIsInN1YiI6ImRhdGFodWIiLCJleHAiOjE3Mzk2ODcwMDIsImlzcyI6ImRhdGFodWItbWV0YWRhdGEtc2VydmljZSJ9.HDGaXw8iBTXIEqKyIQl-jSlS8BquAXZHELP4hA9thOM"
+ graph_config = DatahubClientConfig(
+ server="http://localhost:8080",
+ token=token,
+ extra_headers={
+ "Authorization": f"Bearer {token}"}
+ )
+ graph = DataHubGraph(graph_config)
+ with graph:
+ for mcp in generate_pipeline(
+ "airline_forecast_pipeline_mlflow", orchestrator=ORCHESTRATOR_MLFLOW
+ ):
+ graph.emit(mcp)
+ for mcp in generate_pipeline(
+ "airline_forecast_pipeline_airflow", orchestrator=ORCHESTRATOR_AIRFLOW
+ ):
+ graph.emit(mcp)
diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl
index d005cd557cf77e..32329f60bfaa70 100644
--- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl
+++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl
@@ -15,7 +15,8 @@ record DataProcessInstanceInput {
@Relationship = {
"/*": {
"name": "Consumes",
- "entityTypes": [ "dataset" ]
+ "entityTypes": [ "dataset" ],
+ "isLineage": true
}
}
@Searchable = {
diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl
index fe782dbe01ca9b..6b41cf9ba63417 100644
--- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl
+++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl
@@ -15,7 +15,9 @@ record DataProcessInstanceOutput {
@Relationship = {
"/*": {
"name": "Produces",
- "entityTypes": [ "dataset", "mlModel" ]
+ "entityTypes": [ "dataset", "mlModel" ],
+ "isLineage": true,
+ "isUpstream": false
}
}
@Searchable = {