diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/DataProcessInstanceType.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/DataProcessInstanceType.java index c6cede662fa9c2..eeaaaa96f51704 100644 --- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/DataProcessInstanceType.java +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/DataProcessInstanceType.java @@ -31,9 +31,15 @@ public class DataProcessInstanceType public static final Set ASPECTS_TO_FETCH = ImmutableSet.of( + DATA_PROCESS_INSTANCE_KEY_ASPECT_NAME, DATA_PLATFORM_INSTANCE_ASPECT_NAME, DATA_PROCESS_INSTANCE_PROPERTIES_ASPECT_NAME, + DATA_PROCESS_INSTANCE_INPUT_ASPECT_NAME, + DATA_PROCESS_INSTANCE_OUTPUT_ASPECT_NAME, + DATA_PROCESS_INSTANCE_RUN_EVENT_ASPECT_NAME, + TEST_RESULTS_ASPECT_NAME, DATA_PROCESS_INSTANCE_RELATIONSHIPS_ASPECT_NAME, + ML_TRAINING_RUN_PROPERTIES_ASPECT_NAME, SUB_TYPES_ASPECT_NAME, CONTAINER_ASPECT_NAME); @@ -90,7 +96,7 @@ public List> batchLoad( .collect(Collectors.toList()); } catch (Exception e) { - throw new RuntimeException("Failed to load schemaField entity", e); + throw new RuntimeException("Failed to load Data Process Instance entity", e); } } } diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java index e3cbcdf709326b..28c9c8936fdbfb 100644 --- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java @@ -102,6 +102,9 @@ private void mapTrainingRunProperties( com.linkedin.datahub.graphql.generated.MLTrainingRunProperties properties = new com.linkedin.datahub.graphql.generated.MLTrainingRunProperties(); + if (trainingProperties.hasId()) { + properties.setId(trainingProperties.getId()); + } if (trainingProperties.hasOutputUrls()) { properties.setOutputUrls( trainingProperties.getOutputUrls().stream() @@ -133,9 +136,12 @@ private void mapDataProcessProperties( @Nonnull Urn entityUrn) { DataProcessInstanceProperties dataProcessInstanceProperties = new DataProcessInstanceProperties(dataMap); - dpi.setName(dataProcessInstanceProperties.getName()); + com.linkedin.datahub.graphql.generated.DataProcessInstanceProperties properties = new com.linkedin.datahub.graphql.generated.DataProcessInstanceProperties(); + + dpi.setName(dataProcessInstanceProperties.getName()); + properties.setName(dataProcessInstanceProperties.getName()); if (dataProcessInstanceProperties.hasExternalUrl()) { dpi.setExternalUrl(dataProcessInstanceProperties.getExternalUrl().toString()); properties.setExternalUrl(dataProcessInstanceProperties.getExternalUrl().toString()); diff --git a/smoke-test/tests/data_process_instance/__init__.py b/smoke-test/tests/data_process_instance/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/smoke-test/tests/data_process_instance/test_data_process_instance.py b/smoke-test/tests/data_process_instance/test_data_process_instance.py new file mode 100644 index 00000000000000..a8aca6034d5be1 --- /dev/null +++ b/smoke-test/tests/data_process_instance/test_data_process_instance.py @@ -0,0 +1,293 @@ +import logging +import os +import tempfile +from random import randint + +import pytest +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.common import PipelineContext, RecordEnvelope +from datahub.ingestion.api.sink import NoopWriteCallback +from datahub.ingestion.sink.file import FileSink, FileSinkConfig +from datahub.metadata.schema_classes import ( + AuditStampClass, + ContainerClass, + ContainerPropertiesClass, + DataPlatformInstanceClass, + DataPlatformInstancePropertiesClass, + DataProcessInstanceKeyClass, + DataProcessInstancePropertiesClass, + DataProcessInstanceRunEventClass, + MLHyperParamClass, + MLMetricClass, + MLTrainingRunPropertiesClass, + SubTypesClass, + TimeWindowSizeClass, +) + +from tests.utils import ( + delete_urns_from_file, + ingest_file_via_rest, + wait_for_writes_to_sync, +) + +logger = logging.getLogger(__name__) + +# Generate unique DPI ID +dpi_id = f"test-pipeline-run-{randint(1000, 9999)}" +dpi_urn = f"urn:li:dataProcessInstance:{dpi_id}" + + +class FileEmitter: + def __init__(self, filename: str) -> None: + self.sink: FileSink = FileSink( + ctx=PipelineContext(run_id="create_test_data"), + config=FileSinkConfig(filename=filename), + ) + + def emit(self, event): + self.sink.write_record_async( + record_envelope=RecordEnvelope(record=event, metadata={}), + write_callback=NoopWriteCallback(), + ) + + def close(self): + self.sink.close() + + +def create_test_data(filename: str): + mcps = [ + # Key aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="dataProcessInstanceKey", + aspect=DataProcessInstanceKeyClass(id=dpi_id), + ), + # Properties aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="dataProcessInstanceProperties", + aspect=DataProcessInstancePropertiesClass( + name="Test Pipeline Run", + type="BATCH_SCHEDULED", + created=AuditStampClass( + time=1640692800000, actor="urn:li:corpuser:datahub" + ), + ), + ), + # Run Event aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="dataProcessInstanceRunEvent", + aspect=DataProcessInstanceRunEventClass( + timestampMillis=1704067200000, + eventGranularity=TimeWindowSizeClass(unit="WEEK", multiple=1), + status="COMPLETE", + ), + ), + # Platform Instance aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="dataPlatformInstance", + aspect=DataPlatformInstanceClass( + platform="urn:li:dataPlatform:airflow", + instance="urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)", + ), + ), + MetadataChangeProposalWrapper( + entityType="dataPlatformInstance", + entityUrn="urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)", + aspectName="dataPlatformInstanceProperties", + aspect=DataPlatformInstancePropertiesClass( + name="my process instance", + ), + ), + # SubTypes aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="subTypes", + aspect=SubTypesClass(typeNames=["TEST", "BATCH_JOB"]), + ), + # Container aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="container", + aspect=ContainerClass(container="urn:li:container:testGroup1"), + ), + MetadataChangeProposalWrapper( + entityType="container", + entityUrn="urn:li:container:testGroup1", + aspectName="containerProperties", + aspect=ContainerPropertiesClass(name="testGroup1"), + ), + # ML Training Run Properties aspect + MetadataChangeProposalWrapper( + entityType="dataProcessInstance", + entityUrn=dpi_urn, + aspectName="mlTrainingRunProperties", + aspect=MLTrainingRunPropertiesClass( + id="test-training-run-123", + trainingMetrics=[ + MLMetricClass( + name="accuracy", + description="accuracy of the model", + value="0.95", + ), + MLMetricClass( + name="loss", + description="accuracy loss of the model", + value="0.05", + ), + ], + hyperParams=[ + MLHyperParamClass( + name="learningRate", + description="rate of learning", + value="0.001", + ), + MLHyperParamClass( + name="batchSize", description="size of the batch", value="32" + ), + ], + outputUrls=["s3://my-bucket/ml/output"], + ), + ), + ] + + file_emitter = FileEmitter(filename) + for mcp in mcps: + file_emitter.emit(mcp) + file_emitter.close() + + +@pytest.fixture(scope="module", autouse=False) +def ingest_cleanup_data(auth_session, graph_client, request): + new_file, filename = tempfile.mkstemp(suffix=".json") + try: + create_test_data(filename) + print("ingesting data process instance test data") + ingest_file_via_rest(auth_session, filename) + wait_for_writes_to_sync() + yield + print("removing data process instance test data") + delete_urns_from_file(graph_client, filename) + wait_for_writes_to_sync() + finally: + os.remove(filename) + + +@pytest.mark.integration +def test_search_dpi(auth_session, ingest_cleanup_data): + """Test DPI search and validation of returned fields using GraphQL.""" + + json = { + "query": """query scrollAcrossEntities($input: ScrollAcrossEntitiesInput!) { + scrollAcrossEntities(input: $input) { + nextScrollId + count + total + searchResults { + entity { + ... on DataProcessInstance { + urn + properties { + name + externalUrl + } + dataPlatformInstance { + platform { + urn + name + } + } + subTypes { + typeNames + } + container { + urn + } + platform { + urn + name + properties { + type + } + } + mlTrainingRunProperties { + id + trainingMetrics { + name + value + } + hyperParams { + name + value + } + outputUrls + } + } + } + } + } + }""", + "variables": { + "input": {"types": ["DATA_PROCESS_INSTANCE"], "query": dpi_id, "count": 10} + }, + } + + response = auth_session.post( + f"{auth_session.frontend_url()}/api/v2/graphql", json=json + ) + response.raise_for_status() + res_data = response.json() + + # Basic response structure validation + assert res_data, "Response should not be empty" + assert "data" in res_data, "Response should contain 'data' field" + print("RESPONSE DATA:" + str(res_data)) + assert ( + "scrollAcrossEntities" in res_data["data"] + ), "Response should contain 'scrollAcrossEntities' field" + + search_results = res_data["data"]["scrollAcrossEntities"] + assert ( + "searchResults" in search_results + ), "Response should contain 'searchResults' field" + + results = search_results["searchResults"] + assert len(results) > 0, "Should find at least one result" + + # Find our test entity + test_entity = None + for result in results: + if result["entity"]["urn"] == dpi_urn: + test_entity = result["entity"] + break + + assert test_entity is not None, f"Should find test entity with URN {dpi_urn}" + + # Validate fields + props = test_entity["properties"] + assert props["name"] == "Test Pipeline Run" + + platform_instance = test_entity["dataPlatformInstance"] + assert platform_instance["platform"]["urn"] == "urn:li:dataPlatform:airflow" + + sub_types = test_entity["subTypes"] + assert set(sub_types["typeNames"]) == {"TEST", "BATCH_JOB"} + + container = test_entity["container"] + assert container["urn"] == "urn:li:container:testGroup1" + + ml_props = test_entity["mlTrainingRunProperties"] + assert ml_props["id"] == "test-training-run-123" + assert ml_props["trainingMetrics"][0] == {"name": "accuracy", "value": "0.95"} + assert ml_props["trainingMetrics"][1] == {"name": "loss", "value": "0.05"} + assert ml_props["hyperParams"][0] == {"name": "learningRate", "value": "0.001"} + assert ml_props["hyperParams"][1] == {"name": "batchSize", "value": "32"} + assert ml_props["outputUrls"][0] == "s3://my-bucket/ml/output" diff --git a/smoke-test/tests/dataprocessinst/test_dataprocessinst.py b/smoke-test/tests/dataprocessinst/test_dataprocessinst.py deleted file mode 100644 index a1fbe28769b8b5..00000000000000 --- a/smoke-test/tests/dataprocessinst/test_dataprocessinst.py +++ /dev/null @@ -1,220 +0,0 @@ -import logging -import os -import tempfile -import time -from random import randint -import pytest -from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.metadata.schema_classes import ( - MLMetricClass, - MLHyperParamClass, - DataProcessInstancePropertiesClass, - DataProcessInstanceKeyClass, - MLTrainingRunPropertiesClass, - AuditStampClass, -) -from tests.consistency_utils import wait_for_writes_to_sync - -logger = logging.getLogger(__name__) - - -def create_sample_dpi(): - """Create a sample DataProcessInstance with realistic ML training properties""" - # Generate timestamps - current_time = int(time.time() * 1000) - run_id = "run_abcde" - dpi_urn = f"urn:li:dataProcessInstance:{run_id}" - - logger.info(f"Creating DPI with URN: {dpi_urn}") - - # Create key aspect - dpi_key = DataProcessInstanceKeyClass( - id=run_id - ) - - hyper_params = [ - MLHyperParamClass( - name="alpha", - value="0.05" - ), - MLHyperParamClass( - name="beta", - value="0.95" - ) - ] - - metrics = [ - MLMetricClass( - name="mse", - value="0.05" - ), - MLMetricClass( - name="r2", - value="0.95" - ) - ] - - # Create DPI properties - dpi_props = DataProcessInstancePropertiesClass( - name=f"Training {run_id}", - type="BATCH_SCHEDULED", - created=AuditStampClass(time=current_time, actor="urn:li:corpuser:datahub"), - externalUrl="http://mlflow:5000", - customProperties={ - "framework": "statsmodels", - "python_version": "3.8", - }, - ) - - dpi_ml_props = MLTrainingRunPropertiesClass( - hyperParams=hyper_params, - trainingMetrics=metrics, - outputUrls=["s3://my-bucket/ml/output"], - ) - - # Create the MCPs - one for the key, one for properties - mcps = [ - # Key aspect - MetadataChangeProposalWrapper( - entityUrn=dpi_urn, - entityType="dataProcessInstance", - aspectName="dataProcessInstanceKey", - changeType="UPSERT", - aspect=dpi_key - ), - # Properties aspect - MetadataChangeProposalWrapper( - entityUrn=dpi_urn, - entityType="dataProcessInstance", - aspectName="dataProcessInstanceProperties", - changeType="UPSERT", - aspect=dpi_props - ), - MetadataChangeProposalWrapper( - entityUrn=dpi_urn, - entityType="dataProcessInstance", - aspectName="mlTrainingRunProperties", - changeType="UPSERT", - aspect=dpi_ml_props - ) - ] - return mcps - - -@pytest.fixture(scope="module") -def ingest_cleanup_data(auth_session, graph_client, request): - """Fixture to handle test data setup and cleanup""" - try: - logger.info("Starting DPI test data creation") - mcps = create_sample_dpi() - - # Emit MCPs directly using graph client - for mcp in mcps: - logger.info(f"Emitting aspect: {mcp.aspect}") - graph_client.emit(mcp) - - wait_for_writes_to_sync() - - # Verify entity exists - dpi_urn = "urn:li:dataProcessInstance:run_abcde" - logger.info(f"Verifying entity exists in graph... {dpi_urn}") - - # Try getting aspect - dpi_props = graph_client.get_aspect( - dpi_urn, - DataProcessInstancePropertiesClass - ) - dpi_key = graph_client.get_aspect( - dpi_urn, - DataProcessInstanceKeyClass - ) - dpi_ml_props = graph_client.get_aspect( - dpi_urn, - MLTrainingRunPropertiesClass - ) - - logger.info(f"DPI properties from graph: {dpi_props}") - logger.info(f"DPI key from graph: {dpi_key}") - logger.info(f"DPI ML properties from graph: {dpi_ml_props}") - - yield - - logger.info("Cleaning up test data") - graph_client.hard_delete_entity(dpi_urn) - wait_for_writes_to_sync() - - except Exception as e: - logger.error(f"Error in test setup/cleanup: {str(e)}") - logger.error(f"Full exception: {e.__class__.__name__}") - raise - - -def test_get_dpi(auth_session, ingest_cleanup_data): - """Test getting a specific DPI entity""" - logger.info("Starting DPI query test") - - json = { - "query": """query dataProcessInstance($urn: String!) { - dataProcessInstance(urn: $urn) { - urn - type - properties { - name - created { - time - actor - } - customProperties { - key - value - } - externalUrl - } - mlTrainingRunProperties { - hyperParams { - name - value - } - trainingMetrics { - name - value - } - outputUrls - } - } - }""", - "variables": { - "urn": "urn:li:dataProcessInstance:run_abcde" - } - } - - # Send GraphQL query - logger.info("Sending GraphQL query") - response = auth_session.post(f"{auth_session.frontend_url()}/api/v2/graphql", json=json) - response.raise_for_status() - res_data = response.json() - - logger.info(f"Response data: {res_data}") - - # Basic response structure validation - assert res_data, "Response should not be empty" - assert "data" in res_data, "Response should contain 'data' field" - assert "dataProcessInstance" in res_data["data"], "Response should contain 'dataProcessInstance' field" - - dpi = res_data["data"]["dataProcessInstance"] - assert dpi, "DPI should not be null" - assert "urn" in dpi, "DPI should have URN" - assert dpi["urn"] == "urn:li:dataProcessInstance:run_abcde", "URN should match expected value" - - # Validate properties if present - if "properties" in dpi and dpi["properties"]: - props = dpi["properties"] - assert "name" in props, "Properties should contain name" - assert "created" in props, "Properties should contain created" - assert "customProperties" in props, "Properties should contain customProperties" - - if "mlTrainingRunProperties" in dpi and dpi["mlTrainingRunProperties"]: - ml_props = dpi["mlTrainingRunProperties"] - assert "hyperParams" in ml_props, "ML properties should contain hyperParams" - assert "trainingMetrics" in ml_props, "ML properties should contain trainingMetrics" - assert "outputUrls" in ml_props, "ML properties should contain outputUrls" \ No newline at end of file diff --git a/smoke-test/tests/ml_models/__init__.py b/smoke-test/tests/ml_models/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/smoke-test/tests/ml_models/test_ml_models.py b/smoke-test/tests/ml_models/test_ml_models.py new file mode 100644 index 00000000000000..59821ab3e3cc41 --- /dev/null +++ b/smoke-test/tests/ml_models/test_ml_models.py @@ -0,0 +1,133 @@ +import logging +import os +import tempfile +from random import randint + +import pytest +from datahub.emitter.mce_builder import make_ml_model_group_urn, make_ml_model_urn +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.common import PipelineContext, RecordEnvelope +from datahub.ingestion.api.sink import NoopWriteCallback +from datahub.ingestion.graph.client import DataHubGraph +from datahub.ingestion.sink.file import FileSink, FileSinkConfig +from datahub.metadata.schema_classes import ( + MLModelGroupPropertiesClass, + MLModelPropertiesClass, +) + +from tests.utils import ( + delete_urns_from_file, + get_sleep_info, + ingest_file_via_rest, + wait_for_writes_to_sync, +) + +logger = logging.getLogger(__name__) + +# Generate unique model names for testing +start_index = randint(10, 10000) +model_names = [f"test_model_{i}" for i in range(start_index, start_index + 3)] +model_group_urn = make_ml_model_group_urn("workbench", "test_group", "DEV") +model_urns = [make_ml_model_urn("workbench", name, "DEV") for name in model_names] + + +class FileEmitter: + def __init__(self, filename: str) -> None: + self.sink: FileSink = FileSink( + ctx=PipelineContext(run_id="create_test_data"), + config=FileSinkConfig(filename=filename), + ) + + def emit(self, event): + self.sink.write_record_async( + record_envelope=RecordEnvelope(record=event, metadata={}), + write_callback=NoopWriteCallback(), + ) + + def close(self): + self.sink.close() + + +def create_test_data(filename: str): + # Create model group + model_group_mcp = MetadataChangeProposalWrapper( + entityUrn=str(model_group_urn), + aspect=MLModelGroupPropertiesClass( + description="Test model group for integration testing", + trainingJobs=["urn:li:dataProcessInstance:test_job"], + ), + ) + + # Create models that belong to the group + model_mcps = [ + MetadataChangeProposalWrapper( + entityUrn=model_urn, + aspect=MLModelPropertiesClass( + name=f"Test Model ({model_urn})", + description=f"Test model {model_urn}", + groups=[str(model_group_urn)], + trainingJobs=["urn:li:dataProcessInstance:test_job"], + ), + ) + for model_urn in model_urns + ] + + file_emitter = FileEmitter(filename) + for mcps in [model_group_mcp] + model_mcps: + file_emitter.emit(mcps) + + file_emitter.close() + + +sleep_sec, sleep_times = get_sleep_info() + + +@pytest.fixture(scope="module", autouse=False) +def ingest_cleanup_data(auth_session, graph_client, request): + new_file, filename = tempfile.mkstemp(suffix=".json") + try: + create_test_data(filename) + print("ingesting ml model test data") + ingest_file_via_rest(auth_session, filename) + wait_for_writes_to_sync() + yield + print("removing ml model test data") + delete_urns_from_file(graph_client, filename) + wait_for_writes_to_sync() + finally: + os.remove(filename) + + +@pytest.mark.integration +def test_create_ml_models(graph_client: DataHubGraph, ingest_cleanup_data): + """Test creation and validation of ML models and model groups.""" + + # Validate model group properties + fetched_group_props = graph_client.get_aspect( + str(model_group_urn), MLModelGroupPropertiesClass + ) + assert fetched_group_props is not None + assert fetched_group_props.description == "Test model group for integration testing" + assert fetched_group_props.trainingJobs == ["urn:li:dataProcessInstance:test_job"] + + # Validate individual models + for model_urn in model_urns: + fetched_model_props = graph_client.get_aspect(model_urn, MLModelPropertiesClass) + assert fetched_model_props is not None + assert fetched_model_props.name == f"Test Model ({model_urn})" + assert fetched_model_props.description == f"Test model {model_urn}" + assert str(model_group_urn) in (fetched_model_props.groups or []) + assert fetched_model_props.trainingJobs == [ + "urn:li:dataProcessInstance:test_job" + ] + + # Validate relationships between models and group + related_models = set() + for e in graph_client.get_related_entities( + str(model_group_urn), + relationship_types=["MemberOf"], + direction=DataHubGraph.RelationshipDirection.INCOMING, + ): + related_models.add(e.urn) + + assert set(model_urns) == related_models