smoke test + fixes
RyanHolstien committed Jan 13, 2025
1 parent 9a60c1d commit 55a77b0
Showing 7 changed files with 440 additions and 222 deletions.
@@ -31,9 +31,15 @@ public class DataProcessInstanceType

public static final Set<String> ASPECTS_TO_FETCH =
ImmutableSet.of(
DATA_PROCESS_INSTANCE_KEY_ASPECT_NAME,
DATA_PLATFORM_INSTANCE_ASPECT_NAME,
DATA_PROCESS_INSTANCE_PROPERTIES_ASPECT_NAME,
DATA_PROCESS_INSTANCE_INPUT_ASPECT_NAME,
DATA_PROCESS_INSTANCE_OUTPUT_ASPECT_NAME,
DATA_PROCESS_INSTANCE_RUN_EVENT_ASPECT_NAME,
TEST_RESULTS_ASPECT_NAME,
DATA_PROCESS_INSTANCE_RELATIONSHIPS_ASPECT_NAME,
ML_TRAINING_RUN_PROPERTIES_ASPECT_NAME,
SUB_TYPES_ASPECT_NAME,
CONTAINER_ASPECT_NAME);

@@ -90,7 +96,7 @@ public List<DataFetcherResult<DataProcessInstance>> batchLoad(
.collect(Collectors.toList());

} catch (Exception e) {
throw new RuntimeException("Failed to load schemaField entity", e);
throw new RuntimeException("Failed to load Data Process Instance entity", e);
}
}
}
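
Note: the widened ASPECTS_TO_FETCH set above is what batchLoad hydrates for each data process instance URN. For reference, any one of these aspects can be read back through the DataHub Python client; a minimal sketch, assuming a GMS at http://localhost:8080 and a hypothetical URN (neither is part of this commit):

# Sketch only: the server address and URN below are assumptions.
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
from datahub.metadata.schema_classes import DataProcessInstancePropertiesClass

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
# get_aspect returns None when the entity has no such aspect.
props = graph.get_aspect(
    entity_urn="urn:li:dataProcessInstance:example-run",
    aspect_type=DataProcessInstancePropertiesClass,
)
if props is not None:
    print(props.name)
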
@@ -102,6 +102,9 @@ private void mapTrainingRunProperties(

com.linkedin.datahub.graphql.generated.MLTrainingRunProperties properties =
new com.linkedin.datahub.graphql.generated.MLTrainingRunProperties();
+ if (trainingProperties.hasId()) {
+ properties.setId(trainingProperties.getId());
+ }
if (trainingProperties.hasOutputUrls()) {
properties.setOutputUrls(
trainingProperties.getOutputUrls().stream()
@@ -133,9 +136,12 @@ private void mapDataProcessProperties(
@Nonnull Urn entityUrn) {
DataProcessInstanceProperties dataProcessInstanceProperties =
new DataProcessInstanceProperties(dataMap);
- dpi.setName(dataProcessInstanceProperties.getName());

+ com.linkedin.datahub.graphql.generated.DataProcessInstanceProperties properties =
+ new com.linkedin.datahub.graphql.generated.DataProcessInstanceProperties();
+
+ dpi.setName(dataProcessInstanceProperties.getName());
+ properties.setName(dataProcessInstanceProperties.getName());
if (dataProcessInstanceProperties.hasExternalUrl()) {
dpi.setExternalUrl(dataProcessInstanceProperties.getExternalUrl().toString());
+ properties.setExternalUrl(dataProcessInstanceProperties.getExternalUrl().toString());
Empty file.
293 changes: 293 additions & 0 deletions smoke-test/tests/data_process_instance/test_data_process_instance.py
@@ -0,0 +1,293 @@
import logging
import os
import tempfile
from random import randint

import pytest
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.sink import NoopWriteCallback
from datahub.ingestion.sink.file import FileSink, FileSinkConfig
from datahub.metadata.schema_classes import (
AuditStampClass,
ContainerClass,
ContainerPropertiesClass,
DataPlatformInstanceClass,
DataPlatformInstancePropertiesClass,
DataProcessInstanceKeyClass,
DataProcessInstancePropertiesClass,
DataProcessInstanceRunEventClass,
MLHyperParamClass,
MLMetricClass,
MLTrainingRunPropertiesClass,
SubTypesClass,
TimeWindowSizeClass,
)

from tests.utils import (
delete_urns_from_file,
ingest_file_via_rest,
wait_for_writes_to_sync,
)

logger = logging.getLogger(__name__)

# Generate unique DPI ID
dpi_id = f"test-pipeline-run-{randint(1000, 9999)}"
dpi_urn = f"urn:li:dataProcessInstance:{dpi_id}"


class FileEmitter:
def __init__(self, filename: str) -> None:
self.sink: FileSink = FileSink(
ctx=PipelineContext(run_id="create_test_data"),
config=FileSinkConfig(filename=filename),
)

def emit(self, event):
self.sink.write_record_async(
record_envelope=RecordEnvelope(record=event, metadata={}),
write_callback=NoopWriteCallback(),
)

def close(self):
self.sink.close()
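
FileEmitter adapts the batch-oriented FileSink to a simple per-event emit/close interface; create_test_data below drives it. A standalone usage sketch (the file path, URN, and aspect here are hypothetical, not part of the test):

# Sketch only: illustrates the helper above with made-up values.
from datahub.metadata.schema_classes import StatusClass

emitter = FileEmitter("/tmp/dpi_mcps.json")
emitter.emit(
    MetadataChangeProposalWrapper(
        entityUrn="urn:li:dataProcessInstance:example-run",
        aspect=StatusClass(removed=False),  # entityType/aspectName are inferred
    )
)
emitter.close()  # flushes the JSON file that ingest_file_via_rest consumes
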


def create_test_data(filename: str):
mcps = [
# Key aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="dataProcessInstanceKey",
aspect=DataProcessInstanceKeyClass(id=dpi_id),
),
# Properties aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="dataProcessInstanceProperties",
aspect=DataProcessInstancePropertiesClass(
name="Test Pipeline Run",
type="BATCH_SCHEDULED",
created=AuditStampClass(
time=1640692800000, actor="urn:li:corpuser:datahub"
),
),
),
# Run Event aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="dataProcessInstanceRunEvent",
aspect=DataProcessInstanceRunEventClass(
timestampMillis=1704067200000,
eventGranularity=TimeWindowSizeClass(unit="WEEK", multiple=1),
status="COMPLETE",
),
),
# Platform Instance aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="dataPlatformInstance",
aspect=DataPlatformInstanceClass(
platform="urn:li:dataPlatform:airflow",
instance="urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)",
),
),
MetadataChangeProposalWrapper(
entityType="dataPlatformInstance",
entityUrn="urn:li:dataPlatformInstance:(urn:li:dataPlatform:airflow,1234567890)",
aspectName="dataPlatformInstanceProperties",
aspect=DataPlatformInstancePropertiesClass(
name="my process instance",
),
),
# SubTypes aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="subTypes",
aspect=SubTypesClass(typeNames=["TEST", "BATCH_JOB"]),
),
# Container aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="container",
aspect=ContainerClass(container="urn:li:container:testGroup1"),
),
MetadataChangeProposalWrapper(
entityType="container",
entityUrn="urn:li:container:testGroup1",
aspectName="containerProperties",
aspect=ContainerPropertiesClass(name="testGroup1"),
),
# ML Training Run Properties aspect
MetadataChangeProposalWrapper(
entityType="dataProcessInstance",
entityUrn=dpi_urn,
aspectName="mlTrainingRunProperties",
aspect=MLTrainingRunPropertiesClass(
id="test-training-run-123",
trainingMetrics=[
MLMetricClass(
name="accuracy",
description="accuracy of the model",
value="0.95",
),
MLMetricClass(
name="loss",
description="accuracy loss of the model",
value="0.05",
),
],
hyperParams=[
MLHyperParamClass(
name="learningRate",
description="rate of learning",
value="0.001",
),
MLHyperParamClass(
name="batchSize", description="size of the batch", value="32"
),
],
outputUrls=["s3://my-bucket/ml/output"],
),
),
]

file_emitter = FileEmitter(filename)
for mcp in mcps:
file_emitter.emit(mcp)
file_emitter.close()
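
create_test_data materializes the MCPs as a JSON file so the fixture can both ingest them over REST and later delete the same URNs from that file. Outside a test harness, the equivalent events could be sent straight to the GMS instead; a minimal sketch, assuming a server at http://localhost:8080:

# Sketch only: direct-emission variant; the endpoint is an assumption,
# and `mcps` stands for the MetadataChangeProposalWrapper list built above.
from datahub.emitter.rest_emitter import DatahubRestEmitter

rest_emitter = DatahubRestEmitter(gms_server="http://localhost:8080")
for mcp in mcps:
    rest_emitter.emit(mcp)
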


@pytest.fixture(scope="module", autouse=False)
def ingest_cleanup_data(auth_session, graph_client, request):
new_file, filename = tempfile.mkstemp(suffix=".json")
try:
create_test_data(filename)
print("ingesting data process instance test data")
ingest_file_via_rest(auth_session, filename)
wait_for_writes_to_sync()
yield
print("removing data process instance test data")
delete_urns_from_file(graph_client, filename)
wait_for_writes_to_sync()
finally:
os.remove(filename)


@pytest.mark.integration
def test_search_dpi(auth_session, ingest_cleanup_data):
"""Test DPI search and validation of returned fields using GraphQL."""

json = {
"query": """query scrollAcrossEntities($input: ScrollAcrossEntitiesInput!) {
scrollAcrossEntities(input: $input) {
nextScrollId
count
total
searchResults {
entity {
... on DataProcessInstance {
urn
properties {
name
externalUrl
}
dataPlatformInstance {
platform {
urn
name
}
}
subTypes {
typeNames
}
container {
urn
}
platform {
urn
name
properties {
type
}
}
mlTrainingRunProperties {
id
trainingMetrics {
name
value
}
hyperParams {
name
value
}
outputUrls
}
}
}
}
}
}""",
"variables": {
"input": {"types": ["DATA_PROCESS_INSTANCE"], "query": dpi_id, "count": 10}
},
}

response = auth_session.post(
f"{auth_session.frontend_url()}/api/v2/graphql", json=json
)
response.raise_for_status()
res_data = response.json()

# Basic response structure validation
assert res_data, "Response should not be empty"
assert "data" in res_data, "Response should contain 'data' field"
print("RESPONSE DATA:" + str(res_data))
assert (
"scrollAcrossEntities" in res_data["data"]
), "Response should contain 'scrollAcrossEntities' field"

search_results = res_data["data"]["scrollAcrossEntities"]
assert (
"searchResults" in search_results
), "Response should contain 'searchResults' field"

results = search_results["searchResults"]
assert len(results) > 0, "Should find at least one result"

# Find our test entity
test_entity = None
for result in results:
if result["entity"]["urn"] == dpi_urn:
test_entity = result["entity"]
break

assert test_entity is not None, f"Should find test entity with URN {dpi_urn}"

# Validate fields
props = test_entity["properties"]
assert props["name"] == "Test Pipeline Run"

platform_instance = test_entity["dataPlatformInstance"]
assert platform_instance["platform"]["urn"] == "urn:li:dataPlatform:airflow"

sub_types = test_entity["subTypes"]
assert set(sub_types["typeNames"]) == {"TEST", "BATCH_JOB"}

container = test_entity["container"]
assert container["urn"] == "urn:li:container:testGroup1"

ml_props = test_entity["mlTrainingRunProperties"]
assert ml_props["id"] == "test-training-run-123"
assert ml_props["trainingMetrics"][0] == {"name": "accuracy", "value": "0.95"}
assert ml_props["trainingMetrics"][1] == {"name": "loss", "value": "0.05"}
assert ml_props["hyperParams"][0] == {"name": "learningRate", "value": "0.001"}
assert ml_props["hyperParams"][1] == {"name": "batchSize", "value": "32"}
assert ml_props["outputUrls"][0] == "s3://my-bucket/ml/output"