From a29183cbd1bc5f66ceae272e4ca565315413eb7a Mon Sep 17 00:00:00 2001 From: sagar-salvi-apptware Date: Mon, 20 Jan 2025 13:37:38 +0530 Subject: [PATCH] fix(ingest/ruff): added flake8-comprehensions to ruff --- .../ecommerce/01_snowflake_load.py | 4 +- .../01-operation/marketing/01_send_emails.py | 2 +- .../02-assertion/marketing/02_send_emails.py | 4 +- .../examples/demo_data/enrich.py | 20 +- .../add_mlfeature_to_mlfeature_table.py | 9 +- .../library/add_mlfeature_to_mlmodel.py | 10 +- .../library/add_mlgroup_to_mlmodel.py | 13 +- .../examples/library/create_domain.py | 3 +- .../examples/library/create_form.py | 4 +- .../examples/library/create_mlfeature.py | 4 +- .../library/create_mlfeature_table.py | 11 +- .../examples/library/create_mlmodel.py | 28 +- .../examples/library/create_mlmodel_group.py | 4 +- .../examples/library/create_mlprimarykey.py | 4 +- .../examples/library/dashboard_usage.py | 47 +- .../library/data_quality_mcpw_rest.py | 8 +- .../dataset_add_column_documentation.py | 6 +- .../library/dataset_add_column_tag.py | 3 +- .../library/dataset_add_column_term.py | 9 +- .../library/dataset_add_documentation.py | 11 +- .../dataset_add_glossary_term_patch.py | 4 +- .../examples/library/dataset_add_owner.py | 5 +- .../library/dataset_add_owner_patch.py | 6 +- ...dataset_add_structured_properties_patch.py | 5 +- .../examples/library/dataset_add_tag_patch.py | 4 +- .../examples/library/dataset_add_term.py | 3 +- .../dataset_add_upstream_lineage_patch.py | 17 +- .../dataset_field_add_glossary_term_patch.py | 8 +- .../library/dataset_field_add_tag_patch.py | 6 +- .../library/dataset_replace_properties.py | 9 +- .../library/dataset_report_operation.py | 4 +- .../examples/library/dataset_schema.py | 15 +- .../library/dataset_schema_with_tags_terms.py | 8 +- .../examples/library/delete_assertion.py | 2 +- .../examples/library/delete_dataset.py | 2 +- .../examples/library/delete_form.py | 2 +- .../library/lineage_dataset_job_dataset.py | 18 +- .../lineage_emitter_dataset_finegrained.py | 3 +- ...eage_emitter_dataset_finegrained_sample.py | 6 +- .../examples/library/lineage_emitter_kafka.py | 6 +- .../examples/library/lineage_job_dataflow.py | 9 +- .../lineage_job_dataflow_new_api_simple.py | 18 +- .../lineage_job_dataflow_new_api_verbose.py | 20 +- .../examples/library/programatic_pipeline.py | 2 +- .../library/read_lineage_execute_graphql.py | 8 +- .../library/report_assertion_result.py | 2 +- .../examples/library/run_assertion.py | 4 +- .../examples/library/run_assertions.py | 12 +- .../library/run_assertions_for_asset.py | 13 +- .../examples/library/update_form.py | 16 +- .../library/upsert_custom_assertion.py | 2 +- .../examples/library/upsert_group.py | 5 +- .../examples/library/upsert_user.py | 2 +- .../examples/perf/lineage_perf_example.py | 47 +- .../create_structured_property.py | 2 +- .../update_structured_property.py | 7 +- .../transforms/custom_transform_example.py | 5 +- metadata-ingestion/pyproject.toml | 2 + .../src/datahub/_codegen/aspect.py | 2 +- .../assertion_circuit_breaker.py | 12 +- .../operation_circuit_breaker.py | 2 +- .../api/entities/assertion/assertion.py | 5 +- .../entities/assertion/assertion_operator.py | 14 +- .../entities/assertion/assertion_trigger.py | 7 +- .../entities/assertion/compiler_interface.py | 5 +- .../api/entities/assertion/field_assertion.py | 2 +- .../entities/assertion/freshness_assertion.py | 6 +- .../entities/common/data_platform_instance.py | 3 +- .../api/entities/common/serialized_value.py | 16 +- 
.../api/entities/corpgroup/corpgroup.py | 35 +- .../datahub/api/entities/corpuser/corpuser.py | 8 +- .../datacontract/assertion_operator.py | 11 +- .../datacontract/data_quality_assertion.py | 11 +- .../api/entities/datacontract/datacontract.py | 23 +- .../datacontract/freshness_assertion.py | 8 +- .../entities/datacontract/schema_assertion.py | 6 +- .../datahub/api/entities/datajob/dataflow.py | 13 +- .../datahub/api/entities/datajob/datajob.py | 15 +- .../dataprocess/dataprocess_instance.py | 25 +- .../api/entities/dataproduct/dataproduct.py | 54 +- .../datahub/api/entities/dataset/dataset.py | 55 +- .../src/datahub/api/entities/forms/forms.py | 60 +- .../platformresource/platform_resource.py | 31 +- .../structuredproperties.py | 12 +- .../src/datahub/api/graphql/base.py | 3 +- .../src/datahub/api/graphql/operation.py | 5 +- .../src/datahub/cli/check_cli.py | 17 +- .../src/datahub/cli/cli_utils.py | 41 +- .../src/datahub/cli/config_utils.py | 14 +- .../src/datahub/cli/delete_cli.py | 104 ++- .../src/datahub/cli/docker_check.py | 16 +- .../src/datahub/cli/docker_cli.py | 53 +- metadata-ingestion/src/datahub/cli/get_cli.py | 2 +- .../src/datahub/cli/ingest_cli.py | 73 +- .../src/datahub/cli/json_file.py | 7 +- .../src/datahub/cli/lite_cli.py | 46 +- metadata-ingestion/src/datahub/cli/migrate.py | 64 +- .../src/datahub/cli/migration_utils.py | 6 +- metadata-ingestion/src/datahub/cli/put_cli.py | 17 +- .../src/datahub/cli/quickstart_versioning.py | 23 +- .../datahub/cli/specific/assertions_cli.py | 10 +- .../datahub/cli/specific/datacontract_cli.py | 9 +- .../datahub/cli/specific/dataproduct_cli.py | 43 +- .../src/datahub/cli/specific/dataset_cli.py | 9 +- .../src/datahub/cli/specific/forms_cli.py | 2 +- .../src/datahub/cli/specific/group_cli.py | 8 +- .../cli/specific/structuredproperties_cli.py | 14 +- .../src/datahub/cli/specific/user_cli.py | 6 +- .../src/datahub/cli/timeline_cli.py | 22 +- .../src/datahub/configuration/_config_enum.py | 8 +- .../src/datahub/configuration/common.py | 4 +- .../datahub/configuration/config_loader.py | 6 +- .../configuration/connection_resolver.py | 4 +- .../src/datahub/configuration/datetimes.py | 5 +- .../src/datahub/configuration/git.py | 18 +- .../configuration/kafka_consumer_config.py | 2 +- .../configuration/time_window_config.py | 12 +- .../configuration/validate_field_rename.py | 4 +- .../validate_multiline_string.py | 3 +- .../src/datahub/emitter/kafka_emitter.py | 5 +- .../src/datahub/emitter/mce_builder.py | 55 +- metadata-ingestion/src/datahub/emitter/mcp.py | 25 +- .../src/datahub/emitter/mcp_builder.py | 39 +- .../src/datahub/emitter/mcp_patch_builder.py | 9 +- .../src/datahub/emitter/request_helper.py | 7 +- .../src/datahub/emitter/rest_emitter.py | 37 +- .../datahub/emitter/serialization_helper.py | 12 +- .../datahub/emitter/sql_parsing_builder.py | 14 +- .../emitter/synchronized_file_emitter.py | 4 +- metadata-ingestion/src/datahub/entrypoints.py | 18 +- .../auto_dataset_properties_aspect.py | 8 +- .../auto_ensure_aspect_size.py | 14 +- .../src/datahub/ingestion/api/committable.py | 5 +- .../src/datahub/ingestion/api/common.py | 4 +- .../src/datahub/ingestion/api/decorators.py | 14 +- .../api/incremental_lineage_helper.py | 41 +- .../api/incremental_properties_helper.py | 11 +- .../src/datahub/ingestion/api/registry.py | 29 +- .../src/datahub/ingestion/api/report.py | 2 +- .../src/datahub/ingestion/api/sink.py | 17 +- .../src/datahub/ingestion/api/source.py | 61 +- .../datahub/ingestion/api/source_helpers.py | 35 +- 
.../src/datahub/ingestion/api/transform.py | 3 +- .../src/datahub/ingestion/api/workunit.py | 18 +- .../ingestion/extractor/json_ref_patch.py | 3 +- .../ingestion/extractor/json_schema_util.py | 91 ++- .../ingestion/extractor/mce_extractor.py | 10 +- .../ingestion/extractor/protobuf_util.py | 18 +- .../ingestion/extractor/schema_util.py | 57 +- .../src/datahub/ingestion/fs/s3_fs.py | 12 +- .../glossary/classification_mixin.py | 65 +- .../datahub/ingestion/glossary/classifier.py | 3 +- .../ingestion/glossary/datahub_classifier.py | 6 +- .../src/datahub/ingestion/graph/client.py | 179 +++-- .../src/datahub/ingestion/graph/filters.py | 3 +- .../datahub_ingestion_run_summary_provider.py | 13 +- .../ingestion/reporting/file_reporter.py | 2 +- .../reporting/reporting_provider_registry.py | 2 +- .../src/datahub/ingestion/run/connection.py | 2 +- .../src/datahub/ingestion/run/pipeline.py | 112 ++- .../datahub/ingestion/run/pipeline_config.py | 12 +- .../src/datahub/ingestion/sink/blackhole.py | 4 +- .../src/datahub/ingestion/sink/console.py | 4 +- .../datahub/ingestion/sink/datahub_kafka.py | 8 +- .../datahub/ingestion/sink/datahub_rest.py | 29 +- .../src/datahub/ingestion/sink/file.py | 3 +- .../datahub/ingestion/source/abs/config.py | 35 +- .../source/abs/datalake_profiler_config.py | 9 +- .../datahub/ingestion/source/abs/profiling.py | 31 +- .../datahub/ingestion/source/abs/source.py | 84 +- .../ingestion/source/aws/aws_common.py | 15 +- .../src/datahub/ingestion/source/aws/glue.py | 184 +++-- .../ingestion/source/aws/s3_boto_utils.py | 13 +- .../datahub/ingestion/source/aws/s3_util.py | 6 +- .../datahub/ingestion/source/aws/sagemaker.py | 14 +- .../source/aws/sagemaker_processors/common.py | 9 +- .../sagemaker_processors/feature_groups.py | 36 +- .../source/aws/sagemaker_processors/jobs.py | 78 +- .../aws/sagemaker_processors/lineage.py | 18 +- .../source/aws/sagemaker_processors/models.py | 70 +- .../source/azure/abs_folder_utils.py | 27 +- .../ingestion/source/azure/abs_utils.py | 8 +- .../ingestion/source/azure/azure_common.py | 4 +- .../ingestion/source/bigquery_v2/bigquery.py | 21 +- .../source/bigquery_v2/bigquery_audit.py | 74 +- .../bigquery_v2/bigquery_audit_log_api.py | 8 +- .../source/bigquery_v2/bigquery_config.py | 48 +- .../bigquery_v2/bigquery_data_reader.py | 4 +- .../source/bigquery_v2/bigquery_helper.py | 3 +- .../bigquery_platform_resource_helper.py | 18 +- .../source/bigquery_v2/bigquery_queries.py | 8 +- .../source/bigquery_v2/bigquery_report.py | 30 +- .../source/bigquery_v2/bigquery_schema.py | 75 +- .../source/bigquery_v2/bigquery_schema_gen.py | 200 +++-- .../bigquery_v2/bigquery_test_connection.py | 25 +- .../ingestion/source/bigquery_v2/common.py | 12 +- .../ingestion/source/bigquery_v2/lineage.py | 157 ++-- .../ingestion/source/bigquery_v2/profiler.py | 60 +- .../ingestion/source/bigquery_v2/queries.py | 4 +- .../source/bigquery_v2/queries_extractor.py | 33 +- .../ingestion/source/bigquery_v2/usage.py | 116 ++- .../ingestion/source/cassandra/cassandra.py | 69 +- .../source/cassandra/cassandra_api.py | 30 +- .../source/cassandra/cassandra_config.py | 12 +- .../source/cassandra/cassandra_profiling.py | 27 +- .../source/cassandra/cassandra_utils.py | 16 +- .../ingestion/source/common/data_reader.py | 5 +- .../source/confluent_schema_registry.py | 82 +- .../datahub/ingestion/source/csv_enricher.py | 26 +- .../source/data_lake_common/config.py | 2 +- .../data_lake_common/data_lake_utils.py | 8 +- .../source/data_lake_common/path_spec.py | 48 +- 
.../ingestion/source/datahub/config.py | 5 +- .../source/datahub/datahub_api_reader.py | 2 +- .../source/datahub/datahub_database_reader.py | 25 +- .../source/datahub/datahub_kafka_reader.py | 13 +- .../source/datahub/datahub_source.py | 40 +- .../ingestion/source/datahub/report.py | 2 +- .../datahub/ingestion/source/datahub/state.py | 10 +- .../datahub/ingestion/source/dbt/dbt_cloud.py | 17 +- .../ingestion/source/dbt/dbt_common.py | 223 ++++-- .../datahub/ingestion/source/dbt/dbt_core.py | 59 +- .../datahub/ingestion/source/dbt/dbt_tests.py | 2 +- .../ingestion/source/delta_lake/config.py | 8 +- .../source/delta_lake/delta_lake_utils.py | 4 +- .../ingestion/source/delta_lake/source.py | 24 +- .../ingestion/source/dremio/dremio_api.py | 106 ++- .../ingestion/source/dremio/dremio_aspects.py | 66 +- .../ingestion/source/dremio/dremio_config.py | 5 +- .../dremio/dremio_datahub_source_mapping.py | 7 +- .../source/dremio/dremio_entities.py | 22 +- .../source/dremio/dremio_profiling.py | 38 +- .../source/dremio/dremio_reporting.py | 4 +- .../ingestion/source/dremio/dremio_source.py | 60 +- .../ingestion/source/dynamodb/data_reader.py | 7 +- .../ingestion/source/dynamodb/dynamodb.py | 42 +- .../ingestion/source/elastic_search.py | 82 +- .../src/datahub/ingestion/source/feast.py | 47 +- .../src/datahub/ingestion/source/file.py | 31 +- .../ingestion/source/fivetran/config.py | 23 +- .../ingestion/source/fivetran/fivetran.py | 28 +- .../source/fivetran/fivetran_log_api.py | 46 +- .../datahub/ingestion/source/gc/datahub_gc.py | 41 +- .../source/gc/dataprocess_cleanup.py | 61 +- .../source/gc/execution_request_cleanup.py | 18 +- .../source/gc/soft_deleted_entity_cleanup.py | 28 +- .../ingestion/source/gcs/gcs_source.py | 16 +- .../ingestion/source/ge_data_profiler.py | 209 +++-- .../ingestion/source/ge_profiling_config.py | 11 +- .../ingestion/source/git/git_import.py | 9 +- .../source/grafana/grafana_source.py | 11 +- .../ingestion/source/iceberg/iceberg.py | 78 +- .../source/iceberg/iceberg_common.py | 57 +- .../source/iceberg/iceberg_profiler.py | 29 +- .../ingestion/source/identity/azure_ad.py | 68 +- .../datahub/ingestion/source/identity/okta.py | 78 +- .../datahub/ingestion/source/kafka/kafka.py | 91 ++- .../kafka/kafka_schema_registry_base.py | 9 +- .../ingestion/source/kafka_connect/common.py | 24 +- .../source/kafka_connect/kafka_connect.py | 69 +- .../source/kafka_connect/sink_connectors.py | 20 +- .../source/kafka_connect/source_connectors.py | 30 +- .../src/datahub/ingestion/source/ldap.py | 58 +- .../ingestion/source/looker/looker_common.py | 117 +-- .../ingestion/source/looker/looker_config.py | 34 +- .../source/looker/looker_connection.py | 4 +- .../source/looker/looker_dataclasses.py | 20 +- .../source/looker/looker_file_loader.py | 7 +- .../source/looker/looker_lib_wrapper.py | 40 +- .../ingestion/source/looker/looker_source.py | 249 +++--- .../source/looker/looker_template_language.py | 37 +- .../ingestion/source/looker/looker_usage.py | 121 +-- .../source/looker/looker_view_id_cache.py | 10 +- .../source/looker/lookml_concept_context.py | 20 +- .../ingestion/source/looker/lookml_config.py | 26 +- .../source/looker/lookml_refinement.py | 49 +- .../ingestion/source/looker/lookml_source.py | 113 +-- .../ingestion/source/looker/urn_functions.py | 2 +- .../ingestion/source/looker/view_upstream.py | 64 +- .../src/datahub/ingestion/source/metabase.py | 67 +- .../source/metadata/business_glossary.py | 70 +- .../ingestion/source/metadata/lineage.py | 26 +- 
.../src/datahub/ingestion/source/mlflow.py | 5 +- .../src/datahub/ingestion/source/mode.py | 197 +++-- .../src/datahub/ingestion/source/mongodb.py | 44 +- .../ingestion/source/neo4j/neo4j_source.py | 56 +- .../src/datahub/ingestion/source/nifi.py | 196 +++-- .../src/datahub/ingestion/source/openapi.py | 59 +- .../ingestion/source/openapi_parser.py | 21 +- .../ingestion/source/powerbi/config.py | 57 +- .../powerbi/dataplatform_instance_resolver.py | 20 +- .../source/powerbi/m_query/parser.py | 11 +- .../source/powerbi/m_query/pattern_handler.py | 137 ++-- .../source/powerbi/m_query/resolver.py | 53 +- .../source/powerbi/m_query/tree_function.py | 5 +- .../source/powerbi/m_query/validator.py | 3 +- .../ingestion/source/powerbi/powerbi.py | 150 ++-- .../powerbi/rest_api_wrapper/data_classes.py | 18 +- .../powerbi/rest_api_wrapper/data_resolver.py | 101 ++- .../powerbi/rest_api_wrapper/powerbi_api.py | 82 +- .../powerbi_report_server/report_server.py | 62 +- .../report_server_domain.py | 31 +- .../src/datahub/ingestion/source/preset.py | 11 +- .../ingestion/source/profiling/common.py | 3 +- .../src/datahub/ingestion/source/pulsar.py | 55 +- .../ingestion/source/qlik_sense/config.py | 4 +- .../source/qlik_sense/data_classes.py | 24 +- .../ingestion/source/qlik_sense/qlik_api.py | 54 +- .../ingestion/source/qlik_sense/qlik_sense.py | 96 ++- .../source/qlik_sense/websocket_connection.py | 8 +- .../src/datahub/ingestion/source/redash.py | 85 +- .../ingestion/source/redshift/config.py | 7 +- .../ingestion/source/redshift/exception.py | 3 +- .../ingestion/source/redshift/lineage.py | 92 ++- .../ingestion/source/redshift/lineage_v2.py | 40 +- .../ingestion/source/redshift/profile.py | 5 +- .../ingestion/source/redshift/query.py | 44 +- .../ingestion/source/redshift/redshift.py | 133 ++-- .../source/redshift/redshift_data_reader.py | 9 +- .../source/redshift/redshift_schema.py | 31 +- .../ingestion/source/redshift/report.py | 2 +- .../ingestion/source/redshift/usage.py | 48 +- .../src/datahub/ingestion/source/s3/config.py | 44 +- .../source/s3/datalake_profiler_config.py | 9 +- .../datahub/ingestion/source/s3/profiling.py | 31 +- .../src/datahub/ingestion/source/s3/source.py | 147 ++-- .../src/datahub/ingestion/source/sac/sac.py | 61 +- .../datahub/ingestion/source/salesforce.py | 149 ++-- .../ingestion/source/schema/json_schema.py | 38 +- .../ingestion/source/schema_inference/json.py | 3 +- .../source/schema_inference/object.py | 6 +- .../datahub/ingestion/source/sigma/config.py | 10 +- .../datahub/ingestion/source/sigma/sigma.py | 85 +- .../ingestion/source/sigma/sigma_api.py | 66 +- .../datahub/ingestion/source/slack/slack.py | 32 +- .../source/snowflake/oauth_config.py | 11 +- .../source/snowflake/oauth_generator.py | 11 +- .../source/snowflake/snowflake_assertion.py | 18 +- .../source/snowflake/snowflake_config.py | 42 +- .../source/snowflake/snowflake_connection.py | 34 +- .../source/snowflake/snowflake_data_reader.py | 16 +- .../source/snowflake/snowflake_lineage_v2.py | 63 +- .../source/snowflake/snowflake_profiler.py | 21 +- .../source/snowflake/snowflake_queries.py | 72 +- .../source/snowflake/snowflake_query.py | 27 +- .../source/snowflake/snowflake_report.py | 2 +- .../source/snowflake/snowflake_schema.py | 92 ++- .../source/snowflake/snowflake_schema_gen.py | 260 ++++-- .../source/snowflake/snowflake_shares.py | 32 +- .../source/snowflake/snowflake_summary.py | 16 +- .../source/snowflake/snowflake_tag.py | 28 +- .../source/snowflake/snowflake_usage_v2.py | 87 +- 
.../source/snowflake/snowflake_utils.py | 46 +- .../source/snowflake/snowflake_v2.py | 79 +- .../datahub/ingestion/source/sql/athena.py | 49 +- .../ingestion/source/sql/clickhouse.py | 77 +- .../ingestion/source/sql/cockroachdb.py | 2 +- .../src/datahub/ingestion/source/sql/hive.py | 54 +- .../ingestion/source/sql/hive_metastore.py | 70 +- .../ingestion/source/sql/mssql/job_models.py | 3 +- .../ingestion/source/sql/mssql/source.py | 102 ++- .../sql/mssql/stored_procedure_lineage.py | 4 +- .../src/datahub/ingestion/source/sql/mysql.py | 2 +- .../datahub/ingestion/source/sql/oracle.py | 93 ++- .../datahub/ingestion/source/sql/postgres.py | 31 +- .../datahub/ingestion/source/sql/presto.py | 9 +- .../ingestion/source/sql/sql_common.py | 152 +++- .../ingestion/source/sql/sql_config.py | 24 +- .../ingestion/source/sql/sql_generic.py | 4 +- .../source/sql/sql_generic_profiler.py | 46 +- .../ingestion/source/sql/sql_report.py | 13 +- .../datahub/ingestion/source/sql/sql_utils.py | 26 +- .../source/sql/sqlalchemy_data_reader.py | 12 +- .../source/sql/sqlalchemy_uri_mapper.py | 5 +- .../datahub/ingestion/source/sql/teradata.py | 72 +- .../src/datahub/ingestion/source/sql/trino.py | 59 +- .../source/sql/two_tier_sql_source.py | 7 +- .../datahub/ingestion/source/sql/vertica.py | 154 +++- .../datahub/ingestion/source/sql_queries.py | 11 +- .../ingestion/source/state/checkpoint.py | 29 +- .../source/state/entity_removal_state.py | 21 +- .../source/state/profiling_state_handler.py | 3 +- .../state/redundant_run_skip_handler.py | 35 +- .../state/stale_entity_removal_handler.py | 20 +- .../source/state/stateful_ingestion_base.py | 57 +- ...atahub_ingestion_checkpointing_provider.py | 24 +- .../file_ingestion_checkpointing_provider.py | 20 +- .../state_provider/state_provider_registry.py | 2 +- .../src/datahub/ingestion/source/superset.py | 74 +- .../ingestion/source/tableau/tableau.py | 388 +++++---- .../source/tableau/tableau_common.py | 36 +- .../source/tableau/tableau_validation.py | 2 +- .../source/unity/analyze_profiler.py | 16 +- .../datahub/ingestion/source/unity/config.py | 33 +- .../ingestion/source/unity/connection_test.py | 3 +- .../ingestion/source/unity/ge_profiler.py | 14 +- .../source/unity/hive_metastore_proxy.py | 38 +- .../datahub/ingestion/source/unity/proxy.py | 56 +- .../ingestion/source/unity/proxy_profiling.py | 43 +- .../ingestion/source/unity/proxy_types.py | 10 +- .../datahub/ingestion/source/unity/report.py | 4 +- .../datahub/ingestion/source/unity/source.py | 111 ++- .../datahub/ingestion/source/unity/usage.py | 41 +- .../source/usage/clickhouse_usage.py | 15 +- .../source/usage/starburst_trino_usage.py | 19 +- .../ingestion/source/usage/usage_common.py | 27 +- .../ingestion/source_config/csv_enricher.py | 9 +- .../source_config/operation_config.py | 4 +- .../datahub/ingestion/source_config/pulsar.py | 36 +- .../transformer/add_dataset_browse_path.py | 17 +- .../transformer/add_dataset_dataproduct.py | 26 +- .../transformer/add_dataset_ownership.py | 31 +- .../transformer/add_dataset_properties.py | 22 +- .../transformer/add_dataset_schema_tags.py | 16 +- .../transformer/add_dataset_schema_terms.py | 22 +- .../ingestion/transformer/add_dataset_tags.py | 16 +- .../transformer/add_dataset_terms.py | 21 +- .../transformer/auto_helper_transformer.py | 8 +- .../ingestion/transformer/base_transformer.py | 36 +- .../ingestion/transformer/dataset_domain.py | 47 +- .../dataset_domain_based_on_tags.py | 16 +- .../transformer/dataset_transformer.py | 14 +- 
.../transformer/extract_dataset_tags.py | 14 +- .../extract_ownership_from_tags.py | 18 +- .../transformer/generic_aspect_transformer.py | 26 +- .../transformer/mark_dataset_status.py | 5 +- .../pattern_cleanup_dataset_usage_user.py | 14 +- .../transformer/pattern_cleanup_ownership.py | 11 +- .../transformer/remove_dataset_ownership.py | 9 +- .../transformer/replace_external_url.py | 28 +- .../system_metadata_transformer.py | 8 +- .../ingestion/transformer/tags_to_terms.py | 23 +- .../integrations/assertion/registry.py | 2 +- .../assertion/snowflake/compiler.py | 27 +- .../assertion/snowflake/dmf_generator.py | 12 +- .../snowflake/field_metric_sql_generator.py | 89 ++- .../field_values_metric_sql_generator.py | 14 +- .../snowflake/metric_sql_generator.py | 2 +- .../src/datahub/lite/duckdb_lite.py | 100 ++- .../src/datahub/lite/lite_local.py | 3 +- .../src/datahub/lite/lite_util.py | 17 +- .../datahub/secret/datahub_secret_store.py | 4 +- .../src/datahub/secret/secret_common.py | 12 +- .../aspect_helpers/custom_properties.py | 3 +- .../specific/aspect_helpers/ownership.py | 9 +- .../aspect_helpers/structured_properties.py | 8 +- .../datahub/specific/aspect_helpers/tags.py | 5 +- .../datahub/specific/aspect_helpers/terms.py | 10 +- .../src/datahub/specific/chart.py | 13 +- .../src/datahub/specific/dashboard.py | 22 +- .../src/datahub/specific/datajob.py | 16 +- .../src/datahub/specific/dataproduct.py | 3 +- .../src/datahub/specific/dataset.py | 57 +- .../src/datahub/specific/form.py | 4 +- .../datahub/specific/structured_property.py | 25 +- .../src/datahub/sql_parsing/_sqlglot_patch.py | 3 +- .../src/datahub/sql_parsing/datajob.py | 6 +- .../src/datahub/sql_parsing/query_types.py | 6 +- .../datahub/sql_parsing/schema_resolver.py | 41 +- .../datahub/sql_parsing/split_statements.py | 12 +- .../sql_parsing/sql_parsing_aggregator.py | 190 +++-- .../datahub/sql_parsing/sql_parsing_common.py | 2 +- .../datahub/sql_parsing/sqlglot_lineage.py | 117 ++- .../src/datahub/sql_parsing/sqlglot_utils.py | 48 +- .../sql_parsing/tool_meta_extractor.py | 13 +- .../src/datahub/telemetry/stats.py | 3 +- .../src/datahub/telemetry/telemetry.py | 23 +- .../src/datahub/testing/check_imports.py | 2 +- .../testing/check_sql_parser_result.py | 5 +- .../src/datahub/testing/check_str_enum.py | 2 +- .../datahub/testing/compare_metadata_json.py | 2 +- .../src/datahub/testing/docker_utils.py | 9 +- .../src/datahub/testing/mcp_diff.py | 3 +- .../src/datahub/upgrade/upgrade.py | 49 +- .../utilities/_custom_package_loader.py | 2 +- .../utilities/backpressure_aware_executor.py | 3 +- .../utilities/checkpoint_state_util.py | 9 +- .../datahub/utilities/cooperative_timeout.py | 2 +- .../utilities/file_backed_collections.py | 35 +- .../datahub/utilities/hive_schema_to_avro.py | 36 +- .../src/datahub/utilities/logging_manager.py | 2 +- .../datahub/utilities/lossy_collections.py | 2 +- .../src/datahub/utilities/mapping.py | 60 +- .../src/datahub/utilities/parsing_util.py | 3 +- .../datahub/utilities/partition_executor.py | 22 +- .../datahub/utilities/prefix_batch_builder.py | 12 +- .../utilities/registries/domain_registry.py | 6 +- .../src/datahub/utilities/search_utils.py | 64 +- .../datahub/utilities/serialized_lru_cache.py | 3 +- .../src/datahub/utilities/sql_formatter.py | 6 +- .../utilities/sqlalchemy_query_combiner.py | 35 +- .../utilities/sqlalchemy_type_converter.py | 34 +- .../src/datahub/utilities/sqllineage_patch.py | 2 +- .../datahub/utilities/stats_collections.py | 4 +- .../utilities/threaded_iterator_executor.py | 3 +- 
.../datahub/utilities/threading_timeout.py | 2 +- .../src/datahub/utilities/type_annotations.py | 4 +- .../src/datahub/utilities/unified_diff.py | 10 +- .../src/datahub/utilities/urns/_urn_base.py | 12 +- .../src/datahub/utilities/urns/urn_iter.py | 16 +- metadata-ingestion/tests/conftest.py | 3 +- .../integration/athena/test_athena_source.py | 9 +- .../integration/azure_ad/test_azure_ad.py | 37 +- .../integration/bigquery_v2/test_bigquery.py | 40 +- .../bigquery_v2/test_bigquery_queries.py | 3 +- .../test_business_glossary.py | 14 +- .../integration/cassandra/test_cassandra.py | 5 +- .../circuit_breaker/test_circuit_breaker.py | 18 +- .../integration/clickhouse/test_clickhouse.py | 11 +- .../csv-enricher/test_csv_enricher.py | 4 +- .../tests/integration/dbt/test_dbt.py | 68 +- .../delta_lake/test_delta_lake_minio.py | 5 +- .../tests/integration/dremio/test_dremio.py | 6 +- .../integration/dynamodb/test_dynamodb.py | 16 +- .../feast/test_feast_repository.py | 4 +- .../integration/file/test_file_source.py | 2 +- .../integration/fivetran/test_fivetran.py | 32 +- .../tests/integration/git/test_git_clone.py | 8 +- .../tests/integration/grafana/test_grafana.py | 14 +- .../tests/integration/hana/test_hana.py | 6 +- .../hive-metastore/test_hive_metastore.py | 13 +- .../tests/integration/hive/test_hive.py | 19 +- .../tests/integration/iceberg/setup/create.py | 2 +- .../tests/integration/iceberg/test_iceberg.py | 42 +- .../kafka-connect/test_kafka_connect.py | 55 +- .../kafka/create_key_value_topic.py | 16 +- .../tests/integration/kafka/test_kafka.py | 35 +- .../integration/kafka/test_kafka_state.py | 24 +- .../tests/integration/ldap/test_ldap.py | 20 +- .../integration/ldap/test_ldap_stateful.py | 21 +- .../tests/integration/looker/test_looker.py | 95 ++- .../tests/integration/lookml/test_lookml.py | 70 +- .../integration/metabase/test_metabase.py | 43 +- .../tests/integration/mode/test_mode.py | 12 +- .../tests/integration/mongodb/test_mongodb.py | 7 +- .../tests/integration/mysql/test_mysql.py | 6 +- .../tests/integration/nifi/test_nifi.py | 12 +- .../tests/integration/okta/test_okta.py | 49 +- .../tests/integration/oracle/common.py | 11 +- .../integration/postgres/test_postgres.py | 15 +- .../powerbi/test_admin_only_api.py | 88 +- .../integration/powerbi/test_m_parser.py | 30 +- .../tests/integration/powerbi/test_powerbi.py | 130 +-- .../integration/powerbi/test_profiling.py | 46 +- .../powerbi/test_stateful_ingestion.py | 36 +- .../test_powerbi_report_server.py | 18 +- .../tests/integration/preset/test_preset.py | 23 +- .../integration/qlik_sense/test_qlik_sense.py | 92 ++- .../redshift-usage/test_redshift_usage.py | 16 +- .../tests/integration/remote/test_remote.py | 11 +- .../tests/integration/s3/test_s3.py | 25 +- .../integration/salesforce/test_salesforce.py | 23 +- .../tests/integration/sigma/test_sigma.py | 16 +- .../tests/integration/snowflake/common.py | 70 +- .../integration/snowflake/test_snowflake.py | 55 +- .../test_snowflake_classification.py | 17 +- .../snowflake/test_snowflake_failures.py | 6 +- .../snowflake/test_snowflake_queries.py | 2 +- .../snowflake/test_snowflake_stateful.py | 10 +- .../snowflake/test_snowflake_tag.py | 16 +- .../integration/sql_server/test_sql_server.py | 12 +- .../test_starburst_trino_usage.py | 6 +- .../integration/superset/test_superset.py | 32 +- .../tableau/test_tableau_ingest.py | 79 +- .../tests/integration/trino/test_trino.py | 31 +- .../unity/test_unity_catalog_ingest.py | 22 +- .../tests/integration/vertica/test_vertica.py | 6 +- 
.../performance/bigquery/bigquery_events.py | 20 +- .../bigquery/test_bigquery_usage.py | 6 +- .../tests/performance/data_generation.py | 15 +- .../tests/performance/databricks/generator.py | 11 +- .../performance/databricks/test_unity.py | 6 +- .../databricks/unity_proxy_mock.py | 2 +- .../tests/performance/helpers.py | 6 +- .../performance/snowflake/test_snowflake.py | 2 +- .../performance/sql/test_sql_formatter.py | 4 +- .../tests/test_helpers/click_helpers.py | 4 +- .../tests/test_helpers/graph_helpers.py | 20 +- .../tests/test_helpers/mce_helpers.py | 41 +- .../tests/test_helpers/sink_helpers.py | 4 +- .../tests/test_helpers/state_helpers.py | 17 +- .../test_helpers/test_connection_helpers.py | 6 +- .../test_data_quality_assertion.py | 5 +- .../entities/dataproducts/test_dataproduct.py | 35 +- .../test_platform_resource.py | 26 +- .../test_auto_browse_path_v2.py | 43 +- .../source_helpers/test_ensure_aspect_size.py | 78 +- .../test_incremental_lineage_helper.py | 33 +- .../api/source_helpers/test_source_helpers.py | 43 +- .../tests/unit/api/test_pipeline.py | 71 +- .../tests/unit/api/test_plugin_system.py | 6 +- .../tests/unit/api/test_workunit.py | 5 +- .../unit/bigquery/test_bigquery_lineage.py | 32 +- .../unit/bigquery/test_bigquery_source.py | 332 +++++--- .../unit/bigquery/test_bigquery_usage.py | 113 +-- .../bigquery/test_bigqueryv2_usage_source.py | 18 +- .../bigquery/test_bq_get_partition_range.py | 33 +- .../tests/unit/cli/test_cli_utils.py | 3 +- .../cli/test_quickstart_version_mapping.py | 8 +- .../tests/unit/config/test_config_clean.py | 3 +- .../tests/unit/config/test_config_loader.py | 6 +- .../unit/config/test_connection_resolver.py | 10 +- .../tests/unit/config/test_datetime_parser.py | 78 +- .../unit/config/test_key_value_pattern.py | 3 +- .../unit/config/test_pydantic_validators.py | 2 +- .../unit/config/test_time_window_config.py | 7 +- .../unit/data_lake/test_schema_inference.py | 16 +- .../tests/unit/glue/test_glue_source.py | 31 +- .../tests/unit/glue/test_glue_source_stubs.py | 52 +- .../tests/unit/patch/test_patch_builder.py | 63 +- .../unit/redshift/redshift_query_mocker.py | 4 +- .../unit/redshift/test_redshift_lineage.py | 118 +-- .../test_datahub_ingestion_reporter.py | 6 +- .../tests/unit/s3/test_s3_source.py | 19 +- .../unit/sagemaker/test_sagemaker_source.py | 14 +- .../sagemaker/test_sagemaker_source_stubs.py | 42 +- .../unit/schema/test_json_schema_util.py | 51 +- .../tests/unit/sdk/test_client.py | 2 +- .../tests/unit/sdk/test_kafka_emitter.py | 6 +- .../tests/unit/sdk/test_mcp_builder.py | 10 +- .../tests/unit/sdk/test_rest_emitter.py | 7 +- .../tests/unit/serde/test_codegen.py | 19 +- .../tests/unit/serde/test_serde.py | 61 +- .../tests/unit/serde/test_urn_iterator.py | 14 +- .../unit/snowflake/test_snowflake_shares.py | 52 +- .../unit/snowflake/test_snowflake_source.py | 93 ++- .../unit/sql_parsing/test_schemaresolver.py | 5 +- .../unit/sql_parsing/test_sql_aggregator.py | 149 ++-- .../unit/sql_parsing/test_sqlglot_lineage.py | 3 +- .../unit/sql_parsing/test_sqlglot_patch.py | 4 +- .../unit/sql_parsing/test_sqlglot_utils.py | 25 +- .../sql_parsing/test_tool_meta_extractor.py | 3 +- .../provider/test_provider.py | 35 +- .../state/test_checkpoint.py | 23 +- .../state/test_redundant_run_skip_handler.py | 35 +- .../state/test_sql_common_state.py | 12 +- .../test_stale_entity_removal_handler.py | 9 +- .../state/test_stateful_ingestion.py | 46 +- .../unit/stateful_ingestion/test_configs.py | 2 +- .../stateful_ingestion/test_kafka_state.py | 6 +- 
.../test_structured_properties.py | 17 +- .../tests/unit/test_athena_source.py | 43 +- .../tests/unit/test_aws_common.py | 24 +- .../tests/unit/test_capability_report.py | 16 +- .../tests/unit/test_cassandra_source.py | 8 +- .../tests/unit/test_classification.py | 33 +- .../tests/unit/test_clickhouse_source.py | 12 +- .../tests/unit/test_compare_metadata.py | 14 +- .../unit/test_confluent_schema_registry.py | 25 +- .../tests/unit/test_csv_enricher_source.py | 32 +- .../tests/unit/test_datahub_source.py | 3 +- .../tests/unit/test_dbt_source.py | 49 +- .../tests/unit/test_elasticsearch_source.py | 14 +- metadata-ingestion/tests/unit/test_gc.py | 41 +- .../tests/unit/test_gcs_source.py | 11 +- .../tests/unit/test_ge_profiling_config.py | 6 +- .../unit/test_generic_aspect_transformer.py | 48 +- .../tests/unit/test_hana_source.py | 4 +- .../tests/unit/test_hive_source.py | 2 +- metadata-ingestion/tests/unit/test_iceberg.py | 120 ++- .../tests/unit/test_kafka_sink.py | 17 +- .../tests/unit/test_kafka_source.py | 38 +- .../tests/unit/test_ldap_source.py | 8 +- metadata-ingestion/tests/unit/test_mapping.py | 2 +- .../tests/unit/test_mlflow_source.py | 2 +- .../tests/unit/test_neo4j_source.py | 26 +- .../tests/unit/test_nifi_source.py | 31 +- .../tests/unit/test_oracle_source.py | 6 +- .../tests/unit/test_packaging.py | 2 +- .../tests/unit/test_postgres_source.py | 10 +- .../tests/unit/test_powerbi_parser.py | 2 +- .../tests/unit/test_protobuf_util.py | 17 +- .../tests/unit/test_pulsar_source.py | 17 +- .../tests/unit/test_redash_source.py | 44 +- .../tests/unit/test_rest_sink.py | 44 +- .../tests/unit/test_schema_util.py | 18 +- metadata-ingestion/tests/unit/test_source.py | 4 +- .../tests/unit/test_sql_common.py | 13 +- .../tests/unit/test_sql_utils.py | 8 +- .../tests/unit/test_tableau_source.py | 29 +- .../tests/unit/test_transform_dataset.py | 749 ++++++++++-------- .../tests/unit/test_unity_catalog_config.py | 23 +- .../tests/unit/test_usage_common.py | 12 +- .../tests/unit/test_vertica_source.py | 2 +- .../tests/unit/urns/test_data_job_urn.py | 5 +- .../urns/test_data_process_instance_urn.py | 2 +- .../tests/unit/urns/test_urn.py | 2 +- .../test_backpressure_aware_executor.py | 9 +- .../tests/unit/utilities/test_cli_logging.py | 3 +- .../utilities/test_file_backed_collections.py | 5 +- .../utilities/test_hive_schema_to_avro.py | 6 +- .../unit/utilities/test_lossy_collections.py | 5 +- .../unit/utilities/test_partition_executor.py | 8 +- .../tests/unit/utilities/test_search_utils.py | 2 +- .../test_sqlalchemy_type_converter.py | 11 +- .../test_threaded_iterator_executor.py | 4 +- .../tests/unit/utilities/test_unified_diff.py | 48 +- .../tests/unit/utilities/test_urn_encoder.py | 3 +- .../unit/utilities/test_yaml_sync_utils.py | 4 +- 699 files changed, 13432 insertions(+), 7876 deletions(-) diff --git a/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/ecommerce/01_snowflake_load.py b/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/ecommerce/01_snowflake_load.py index 582002f8d80f16..16854e5e8bdc97 100644 --- a/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/ecommerce/01_snowflake_load.py +++ b/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/ecommerce/01_snowflake_load.py @@ -24,7 +24,9 @@ def report_operation(context): for outlet in task._outlets: print(f"Reporting insert operation for {outlet.urn}") reporter.report_operation( - urn=outlet.urn, 
operation_type="INSERT", num_affected_rows=123 + urn=outlet.urn, + operation_type="INSERT", + num_affected_rows=123, ) diff --git a/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/marketing/01_send_emails.py b/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/marketing/01_send_emails.py index 62bf8962bfa721..5df742f58dd91e 100644 --- a/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/marketing/01_send_emails.py +++ b/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/01-operation/marketing/01_send_emails.py @@ -21,7 +21,7 @@ task_id="pet_profiles_operation_sensor", datahub_rest_conn_id="datahub_longtail", urn=[ - "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.adoption.pet_profiles,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.adoption.pet_profiles,PROD)", ], time_delta=datetime.timedelta(minutes=10), ) diff --git a/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/02-assertion/marketing/02_send_emails.py b/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/02-assertion/marketing/02_send_emails.py index 8f6464206e62a9..257fba5c652fbe 100644 --- a/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/02-assertion/marketing/02_send_emails.py +++ b/metadata-ingestion/examples/airflow/circuit_breaker/long_tail_companion/02-assertion/marketing/02_send_emails.py @@ -24,7 +24,7 @@ task_id="pet_profiles_operation_sensor", datahub_rest_conn_id="datahub_longtail", urn=[ - "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.adoption.pet_profiles,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.adoption.pet_profiles,PROD)", ], time_delta=datetime.timedelta(days=1), ) @@ -36,7 +36,7 @@ task_id="pet_profiles_assertion_circuit_breaker", datahub_rest_conn_id="datahub_longtail", urn=[ - "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.adoption.pet_profiles,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail_companions.adoption.pet_profiles,PROD)", ], check_last_assertion_time=True, ) diff --git a/metadata-ingestion/examples/demo_data/enrich.py b/metadata-ingestion/examples/demo_data/enrich.py index b689a4fda9e994..151fb9534ad6de 100644 --- a/metadata-ingestion/examples/demo_data/enrich.py +++ b/metadata-ingestion/examples/demo_data/enrich.py @@ -88,9 +88,9 @@ def create_owner_entity_mce(owner: str) -> MetadataChangeEventClass: displayName=owner, fullName=owner, email=f"{clean_name}-demo@example.com", - ) + ), ], - ) + ), ) @@ -111,9 +111,9 @@ def create_ownership_aspect_mce(directive: Directive) -> MetadataChangeEventClas time=int(time.time() * 1000), actor="urn:li:corpuser:datahub", ), - ) + ), ], - ) + ), ) @@ -133,10 +133,10 @@ def create_lineage_aspect_mce(directive: Directive) -> MetadataChangeEventClass: ), ) for upstream in directive.depends_on - ] - ) + ], + ), ], - ) + ), ) @@ -145,7 +145,7 @@ def create_global_tags_aspect_mce(directive: Directive) -> MetadataChangeEventCl proposedSnapshot=DatasetSnapshotClass( urn=dataset_name_to_urn(directive.table), aspects=[GlobalTagsClass(tags=[])], - ) + ), ) @@ -166,9 +166,9 @@ def create_editable_schema_info_aspect_mce( actor="urn:li:corpuser:datahub", ), editableSchemaFieldInfo=[], - ) + ), ], - ) + ), ) diff --git a/metadata-ingestion/examples/library/add_mlfeature_to_mlfeature_table.py 
b/metadata-ingestion/examples/library/add_mlfeature_to_mlfeature_table.py index b8ec042bd75aae..dff87d5166d4dd 100644 --- a/metadata-ingestion/examples/library/add_mlfeature_to_mlfeature_table.py +++ b/metadata-ingestion/examples/library/add_mlfeature_to_mlfeature_table.py @@ -10,11 +10,13 @@ emitter = DatahubRestEmitter(gms_server=gms_endpoint, extra_headers={}) feature_table_urn = builder.make_ml_feature_table_urn( - feature_table_name="my-feature-table", platform="feast" + feature_table_name="my-feature-table", + platform="feast", ) feature_urns = [ builder.make_ml_feature_urn( - feature_name="my-feature2", feature_table_name="my-feature-table" + feature_name="my-feature2", + feature_table_name="my-feature-table", ), ] @@ -22,7 +24,8 @@ # If you want to replace all existing features with only the new ones, you can comment out this line. graph = DataHubGraph(DatahubClientConfig(server=gms_endpoint)) feature_table_properties = graph.get_aspect( - entity_urn=feature_table_urn, aspect_type=MLFeatureTablePropertiesClass + entity_urn=feature_table_urn, + aspect_type=MLFeatureTablePropertiesClass, ) if feature_table_properties: current_features = feature_table_properties.mlFeatures diff --git a/metadata-ingestion/examples/library/add_mlfeature_to_mlmodel.py b/metadata-ingestion/examples/library/add_mlfeature_to_mlmodel.py index 3b9c26992e8570..955ecdcfeb70d8 100644 --- a/metadata-ingestion/examples/library/add_mlfeature_to_mlmodel.py +++ b/metadata-ingestion/examples/library/add_mlfeature_to_mlmodel.py @@ -10,11 +10,14 @@ emitter = DatahubRestEmitter(gms_server=gms_endpoint, extra_headers={}) model_urn = builder.make_ml_model_urn( - model_name="my-test-model", platform="science", env="PROD" + model_name="my-test-model", + platform="science", + env="PROD", ) feature_urns = [ builder.make_ml_feature_urn( - feature_name="my-feature3", feature_table_name="my-feature-table" + feature_name="my-feature3", + feature_table_name="my-feature-table", ), ] @@ -22,7 +25,8 @@ # If you want to replace all existing features with only the new ones, you can comment out this line. graph = DataHubGraph(DatahubClientConfig(server=gms_endpoint)) model_properties = graph.get_aspect( - entity_urn=model_urn, aspect_type=MLModelPropertiesClass + entity_urn=model_urn, + aspect_type=MLModelPropertiesClass, ) if model_properties: current_features = model_properties.mlFeatures diff --git a/metadata-ingestion/examples/library/add_mlgroup_to_mlmodel.py b/metadata-ingestion/examples/library/add_mlgroup_to_mlmodel.py index 702080767449c6..f833e34b5425a1 100644 --- a/metadata-ingestion/examples/library/add_mlgroup_to_mlmodel.py +++ b/metadata-ingestion/examples/library/add_mlgroup_to_mlmodel.py @@ -10,11 +10,15 @@ model_group_urns = [ builder.make_ml_model_group_urn( - group_name="my-model-group", platform="science", env="PROD" - ) + group_name="my-model-group", + platform="science", + env="PROD", + ), ] model_urn = builder.make_ml_model_urn( - model_name="science-model", platform="science", env="PROD" + model_name="science-model", + platform="science", + env="PROD", ) # This code concatenates the new features with the existing features in the feature table. 
@@ -22,7 +26,8 @@ graph = DataHubGraph(DatahubClientConfig(server=gms_endpoint)) target_model_properties = graph.get_aspect( - entity_urn=model_urn, aspect_type=models.MLModelPropertiesClass + entity_urn=model_urn, + aspect_type=models.MLModelPropertiesClass, ) if target_model_properties: current_model_groups = target_model_properties.groups diff --git a/metadata-ingestion/examples/library/create_domain.py b/metadata-ingestion/examples/library/create_domain.py index da9cb3e604fae2..bf77835bf072ad 100644 --- a/metadata-ingestion/examples/library/create_domain.py +++ b/metadata-ingestion/examples/library/create_domain.py @@ -10,7 +10,8 @@ domain_urn = make_domain_urn("marketing") domain_properties_aspect = DomainPropertiesClass( - name="Marketing", description="Entities related to the marketing department" + name="Marketing", + description="Entities related to the marketing department", ) event: MetadataChangeProposalWrapper = MetadataChangeProposalWrapper( diff --git a/metadata-ingestion/examples/library/create_form.py b/metadata-ingestion/examples/library/create_form.py index 3ba545ead3ed88..dc23a974ddd5f0 100644 --- a/metadata-ingestion/examples/library/create_form.py +++ b/metadata-ingestion/examples/library/create_form.py @@ -23,7 +23,7 @@ title="First Prompt", type=FormPromptTypeClass.STRUCTURED_PROPERTY, # structured property type prompt structuredPropertyParams=StructuredPropertyParamsClass( - urn="urn:li:structuredProperty:property1" + urn="urn:li:structuredProperty:property1", ), # reference existing structured property required=True, ) @@ -32,7 +32,7 @@ title="Second Prompt", type=FormPromptTypeClass.FIELDS_STRUCTURED_PROPERTY, # structured property prompt on dataset schema fields structuredPropertyParams=StructuredPropertyParamsClass( - urn="urn:li:structuredProperty:property1" + urn="urn:li:structuredProperty:property1", ), required=False, # dataset schema fields prompts should not be required ) diff --git a/metadata-ingestion/examples/library/create_mlfeature.py b/metadata-ingestion/examples/library/create_mlfeature.py index 0f6d146dbf1442..5e7aa4a453d01e 100644 --- a/metadata-ingestion/examples/library/create_mlfeature.py +++ b/metadata-ingestion/examples/library/create_mlfeature.py @@ -7,7 +7,9 @@ emitter = DatahubRestEmitter(gms_server="http://localhost:8080", extra_headers={}) dataset_urn = builder.make_dataset_urn( - name="fct_users_created", platform="hive", env="PROD" + name="fct_users_created", + platform="hive", + env="PROD", ) feature_urn = builder.make_ml_feature_urn( feature_table_name="users_feature_table", diff --git a/metadata-ingestion/examples/library/create_mlfeature_table.py b/metadata-ingestion/examples/library/create_mlfeature_table.py index d579d36a0811ae..b7b7d6d8336f53 100644 --- a/metadata-ingestion/examples/library/create_mlfeature_table.py +++ b/metadata-ingestion/examples/library/create_mlfeature_table.py @@ -7,15 +7,18 @@ emitter = DatahubRestEmitter(gms_server="http://localhost:8080", extra_headers={}) feature_table_urn = builder.make_ml_feature_table_urn( - feature_table_name="users_feature_table", platform="feast" + feature_table_name="users_feature_table", + platform="feast", ) feature_urns = [ builder.make_ml_feature_urn( - feature_name="user_signup_date", feature_table_name="users_feature_table" + feature_name="user_signup_date", + feature_table_name="users_feature_table", ), builder.make_ml_feature_urn( - feature_name="user_last_active_date", feature_table_name="users_feature_table" + feature_name="user_last_active_date", + 
feature_table_name="users_feature_table", ), ] @@ -23,7 +26,7 @@ builder.make_ml_primary_key_urn( feature_table_name="users_feature_table", primary_key_name="user_id", - ) + ), ] feature_table_properties = models.MLFeatureTablePropertiesClass( diff --git a/metadata-ingestion/examples/library/create_mlmodel.py b/metadata-ingestion/examples/library/create_mlmodel.py index 92ca8b93e82088..c94825c5220b31 100644 --- a/metadata-ingestion/examples/library/create_mlmodel.py +++ b/metadata-ingestion/examples/library/create_mlmodel.py @@ -6,19 +6,25 @@ # Create an emitter to DataHub over REST emitter = DatahubRestEmitter(gms_server="http://localhost:8080", extra_headers={}) model_urn = builder.make_ml_model_urn( - model_name="my-recommendations-model-run-1", platform="science", env="PROD" + model_name="my-recommendations-model-run-1", + platform="science", + env="PROD", ) model_group_urns = [ builder.make_ml_model_group_urn( - group_name="my-recommendations-model-group", platform="science", env="PROD" - ) + group_name="my-recommendations-model-group", + platform="science", + env="PROD", + ), ] feature_urns = [ builder.make_ml_feature_urn( - feature_name="user_signup_date", feature_table_name="users_feature_table" + feature_name="user_signup_date", + feature_table_name="users_feature_table", ), builder.make_ml_feature_urn( - feature_name="user_last_active_date", feature_table_name="users_feature_table" + feature_name="user_last_active_date", + feature_table_name="users_feature_table", ), ] @@ -33,13 +39,17 @@ mlFeatures=feature_urns, trainingMetrics=[ models.MLMetricClass( - name="accuracy", description="accuracy of the model", value="1.0" - ) + name="accuracy", + description="accuracy of the model", + value="1.0", + ), ], hyperParams=[ models.MLHyperParamClass( - name="hyper_1", description="hyper_1", value="0.102" - ) + name="hyper_1", + description="hyper_1", + value="0.102", + ), ], ), ) diff --git a/metadata-ingestion/examples/library/create_mlmodel_group.py b/metadata-ingestion/examples/library/create_mlmodel_group.py index e39d26ac0f64e2..2d661537515095 100644 --- a/metadata-ingestion/examples/library/create_mlmodel_group.py +++ b/metadata-ingestion/examples/library/create_mlmodel_group.py @@ -6,7 +6,9 @@ # Create an emitter to DataHub over REST emitter = DatahubRestEmitter(gms_server="http://localhost:8080", extra_headers={}) model_group_urn = builder.make_ml_model_group_urn( - group_name="my-recommendations-model-group", platform="science", env="PROD" + group_name="my-recommendations-model-group", + platform="science", + env="PROD", ) diff --git a/metadata-ingestion/examples/library/create_mlprimarykey.py b/metadata-ingestion/examples/library/create_mlprimarykey.py index 3fb397183a07f2..c11478d5a49158 100644 --- a/metadata-ingestion/examples/library/create_mlprimarykey.py +++ b/metadata-ingestion/examples/library/create_mlprimarykey.py @@ -7,7 +7,9 @@ emitter = DatahubRestEmitter(gms_server="http://localhost:8080", extra_headers={}) dataset_urn = builder.make_dataset_urn( - name="fct_users_created", platform="hive", env="PROD" + name="fct_users_created", + platform="hive", + env="PROD", ) primary_key_urn = builder.make_ml_primary_key_urn( feature_table_name="users_feature_table", diff --git a/metadata-ingestion/examples/library/dashboard_usage.py b/metadata-ingestion/examples/library/dashboard_usage.py index 10edd72a9ea410..39b378694caaaf 100644 --- a/metadata-ingestion/examples/library/dashboard_usage.py +++ b/metadata-ingestion/examples/library/dashboard_usage.py @@ -19,10 +19,14 @@ 
usage_day_1_user_counts: List[DashboardUserUsageCountsClass] = [ DashboardUserUsageCountsClass( - user=make_user_urn("user1"), executionsCount=3, usageCount=3 + user=make_user_urn("user1"), + executionsCount=3, + usageCount=3, ), DashboardUserUsageCountsClass( - user=make_user_urn("user2"), executionsCount=2, usageCount=2 + user=make_user_urn("user2"), + executionsCount=2, + usageCount=2, ), ] @@ -30,7 +34,7 @@ entityUrn=make_dashboard_urn("looker", "dashboards.999999"), aspect=DashboardUsageStatisticsClass( timestampMillis=round( - datetime.strptime("2022-02-09", "%Y-%m-%d").timestamp() * 1000 + datetime.strptime("2022-02-09", "%Y-%m-%d").timestamp() * 1000, ), eventGranularity=TimeWindowSizeClass(unit=CalendarIntervalClass.DAY), uniqueUserCount=2, @@ -44,15 +48,16 @@ entityUrn=make_dashboard_urn("looker", "dashboards.999999"), aspect=DashboardUsageStatisticsClass( timestampMillis=round( - datetime.strptime("2022-02-09", "%Y-%m-%d").timestamp() * 1000 + datetime.strptime("2022-02-09", "%Y-%m-%d").timestamp() * 1000, ), favoritesCount=100, viewsCount=25, lastViewedAt=round( datetime.strptime( - "2022-02-09 04:45:30", "%Y-%m-%d %H:%M:%S" + "2022-02-09 04:45:30", + "%Y-%m-%d %H:%M:%S", ).timestamp() - * 1000 + * 1000, ), ), ) @@ -63,17 +68,21 @@ usage_day_2_user_counts: List[DashboardUserUsageCountsClass] = [ DashboardUserUsageCountsClass( - user=make_user_urn("user1"), executionsCount=4, usageCount=4 + user=make_user_urn("user1"), + executionsCount=4, + usageCount=4, ), DashboardUserUsageCountsClass( - user=make_user_urn("user2"), executionsCount=6, usageCount=6 + user=make_user_urn("user2"), + executionsCount=6, + usageCount=6, ), ] usage_day_2: MetadataChangeProposalWrapper = MetadataChangeProposalWrapper( entityUrn=make_dashboard_urn("looker", "dashboards.999999"), aspect=DashboardUsageStatisticsClass( timestampMillis=round( - datetime.strptime("2022-02-10", "%Y-%m-%d").timestamp() * 1000 + datetime.strptime("2022-02-10", "%Y-%m-%d").timestamp() * 1000, ), eventGranularity=TimeWindowSizeClass(unit=CalendarIntervalClass.DAY), uniqueUserCount=2, @@ -87,15 +96,16 @@ entityUrn=make_dashboard_urn("looker", "dashboards.999999"), aspect=DashboardUsageStatisticsClass( timestampMillis=round( - datetime.strptime("2022-02-10", "%Y-%m-%d").timestamp() * 1000 + datetime.strptime("2022-02-10", "%Y-%m-%d").timestamp() * 1000, ), favoritesCount=100, viewsCount=27, lastViewedAt=round( datetime.strptime( - "2022-02-10 10:45:30", "%Y-%m-%d %H:%M:%S" + "2022-02-10 10:45:30", + "%Y-%m-%d %H:%M:%S", ).timestamp() - * 1000 + * 1000, ), ), ) @@ -106,14 +116,16 @@ usage_day_3_user_counts: List[DashboardUserUsageCountsClass] = [ DashboardUserUsageCountsClass( - user=make_user_urn("user1"), executionsCount=2, usageCount=2 + user=make_user_urn("user1"), + executionsCount=2, + usageCount=2, ), ] usage_day_3: MetadataChangeProposalWrapper = MetadataChangeProposalWrapper( entityUrn=make_dashboard_urn("looker", "dashboards.999999"), aspect=DashboardUsageStatisticsClass( timestampMillis=round( - datetime.strptime("2022-02-11", "%Y-%m-%d").timestamp() * 1000 + datetime.strptime("2022-02-11", "%Y-%m-%d").timestamp() * 1000, ), eventGranularity=TimeWindowSizeClass(unit=CalendarIntervalClass.DAY), uniqueUserCount=1, @@ -127,15 +139,16 @@ entityUrn=make_dashboard_urn("looker", "dashboards.999999"), aspect=DashboardUsageStatisticsClass( timestampMillis=round( - datetime.strptime("2022-02-11", "%Y-%m-%d").timestamp() * 1000 + datetime.strptime("2022-02-11", "%Y-%m-%d").timestamp() * 1000, ), favoritesCount=102, viewsCount=30, 
lastViewedAt=round( datetime.strptime( - "2022-02-11 02:45:30", "%Y-%m-%d %H:%M:%S" + "2022-02-11 02:45:30", + "%Y-%m-%d %H:%M:%S", ).timestamp() - * 1000 + * 1000, ), ), ) diff --git a/metadata-ingestion/examples/library/data_quality_mcpw_rest.py b/metadata-ingestion/examples/library/data_quality_mcpw_rest.py index 35192a7ae07d6e..1d9024496e0ecd 100644 --- a/metadata-ingestion/examples/library/data_quality_mcpw_rest.py +++ b/metadata-ingestion/examples/library/data_quality_mcpw_rest.py @@ -74,10 +74,12 @@ def emitAssertionResult(assertionResult: AssertionRunEvent) -> None: nativeParameters={"max_value": "99", "min_value": "89"}, parameters=AssertionStdParameters( minValue=AssertionStdParameter( - type=AssertionStdParameterType.NUMBER, value="89" + type=AssertionStdParameterType.NUMBER, + value="89", ), maxValue=AssertionStdParameter( - type=AssertionStdParameterType.NUMBER, value="99" + type=AssertionStdParameterType.NUMBER, + value="99", ), ), ), @@ -95,7 +97,7 @@ def emitAssertionResult(assertionResult: AssertionRunEvent) -> None: # Construct an assertion platform object. assertion_dataPlatformInstance = DataPlatformInstance( - platform=builder.make_data_platform_urn("great-expectations") + platform=builder.make_data_platform_urn("great-expectations"), ) # Construct a MetadataChangeProposalWrapper object for assertion platform diff --git a/metadata-ingestion/examples/library/dataset_add_column_documentation.py b/metadata-ingestion/examples/library/dataset_add_column_documentation.py index bf871e2dcdb8e6..2cd24bd61453f8 100644 --- a/metadata-ingestion/examples/library/dataset_add_column_documentation.py +++ b/metadata-ingestion/examples/library/dataset_add_column_documentation.py @@ -27,7 +27,8 @@ dataset_urn = make_dataset_urn(platform="hive", name="fct_users_deleted", env="PROD") column = "user_name" field_info_to_set = EditableSchemaFieldInfoClass( - fieldPath=column, description=documentation_to_add + fieldPath=column, + description=documentation_to_add, ) @@ -77,7 +78,8 @@ current_institutional_memory = graph.get_aspect( - entity_urn=dataset_urn, aspect_type=InstitutionalMemoryClass + entity_urn=dataset_urn, + aspect_type=InstitutionalMemoryClass, ) need_write = False diff --git a/metadata-ingestion/examples/library/dataset_add_column_tag.py b/metadata-ingestion/examples/library/dataset_add_column_tag.py index 94204bc39b8746..5bc0f05b9b038d 100644 --- a/metadata-ingestion/examples/library/dataset_add_column_tag.py +++ b/metadata-ingestion/examples/library/dataset_add_column_tag.py @@ -42,7 +42,8 @@ tag_association_to_add = TagAssociationClass(tag=tag_to_add) tags_aspect_to_set = GlobalTagsClass(tags=[tag_association_to_add]) field_info_to_set = EditableSchemaFieldInfoClass( - fieldPath=column, globalTags=tags_aspect_to_set + fieldPath=column, + globalTags=tags_aspect_to_set, ) diff --git a/metadata-ingestion/examples/library/dataset_add_column_term.py b/metadata-ingestion/examples/library/dataset_add_column_term.py index 9796fa9d5404ce..773aadc7c7df71 100644 --- a/metadata-ingestion/examples/library/dataset_add_column_term.py +++ b/metadata-ingestion/examples/library/dataset_add_column_term.py @@ -33,7 +33,8 @@ current_editable_schema_metadata = graph.get_aspect( - entity_urn=dataset_urn, aspect_type=EditableSchemaMetadataClass + entity_urn=dataset_urn, + aspect_type=EditableSchemaMetadataClass, ) @@ -43,10 +44,12 @@ term_association_to_add = GlossaryTermAssociationClass(urn=term_to_add) term_aspect_to_set = GlossaryTermsClass( - terms=[term_association_to_add], 
auditStamp=current_timestamp + terms=[term_association_to_add], + auditStamp=current_timestamp, ) field_info_to_set = EditableSchemaFieldInfoClass( - fieldPath=column, glossaryTerms=term_aspect_to_set + fieldPath=column, + glossaryTerms=term_aspect_to_set, ) need_write = False diff --git a/metadata-ingestion/examples/library/dataset_add_documentation.py b/metadata-ingestion/examples/library/dataset_add_documentation.py index 15e3577d68748a..c3e40883e09d9f 100644 --- a/metadata-ingestion/examples/library/dataset_add_documentation.py +++ b/metadata-ingestion/examples/library/dataset_add_documentation.py @@ -40,7 +40,8 @@ graph = DataHubGraph(config=DatahubClientConfig(server=gms_endpoint)) current_editable_properties = graph.get_aspect( - entity_urn=dataset_urn, aspect_type=EditableDatasetPropertiesClass + entity_urn=dataset_urn, + aspect_type=EditableDatasetPropertiesClass, ) need_write = False @@ -51,7 +52,8 @@ else: # create a brand new editable dataset properties aspect current_editable_properties = EditableDatasetPropertiesClass( - created=current_timestamp, description=documentation_to_add + created=current_timestamp, + description=documentation_to_add, ) need_write = True @@ -68,7 +70,8 @@ current_institutional_memory = graph.get_aspect( - entity_urn=dataset_urn, aspect_type=InstitutionalMemoryClass + entity_urn=dataset_urn, + aspect_type=InstitutionalMemoryClass, ) need_write = False @@ -80,7 +83,7 @@ else: # create a brand new institutional memory aspect current_institutional_memory = InstitutionalMemoryClass( - elements=[institutional_memory_element] + elements=[institutional_memory_element], ) need_write = True diff --git a/metadata-ingestion/examples/library/dataset_add_glossary_term_patch.py b/metadata-ingestion/examples/library/dataset_add_glossary_term_patch.py index d0b9a866fde615..41482819026294 100644 --- a/metadata-ingestion/examples/library/dataset_add_glossary_term_patch.py +++ b/metadata-ingestion/examples/library/dataset_add_glossary_term_patch.py @@ -8,7 +8,9 @@ # Create Dataset URN dataset_urn = make_dataset_urn( - platform="snowflake", name="fct_users_created", env="PROD" + platform="snowflake", + name="fct_users_created", + env="PROD", ) # Create Dataset Patch to Add + Remove Term for 'profile_id' column diff --git a/metadata-ingestion/examples/library/dataset_add_owner.py b/metadata-ingestion/examples/library/dataset_add_owner.py index 51ae15719ff217..c79f5a6a079d85 100644 --- a/metadata-ingestion/examples/library/dataset_add_owner.py +++ b/metadata-ingestion/examples/library/dataset_add_owner.py @@ -34,7 +34,8 @@ current_owners: Optional[OwnershipClass] = graph.get_aspect( - entity_urn=dataset_urn, aspect_type=OwnershipClass + entity_urn=dataset_urn, + aspect_type=OwnershipClass, ) @@ -58,7 +59,7 @@ ) graph.emit(event) log.info( - f"Owner {owner_to_add}, type {ownership_type} added to dataset {dataset_urn}" + f"Owner {owner_to_add}, type {ownership_type} added to dataset {dataset_urn}", ) else: diff --git a/metadata-ingestion/examples/library/dataset_add_owner_patch.py b/metadata-ingestion/examples/library/dataset_add_owner_patch.py index 8d3130c09c4bbf..8cccf965fff12f 100644 --- a/metadata-ingestion/examples/library/dataset_add_owner_patch.py +++ b/metadata-ingestion/examples/library/dataset_add_owner_patch.py @@ -8,13 +8,15 @@ # Create Dataset URN dataset_urn = make_dataset_urn( - platform="snowflake", name="fct_users_created", env="PROD" + platform="snowflake", + name="fct_users_created", + env="PROD", ) # Create Dataset Patch to Add + Remove Owners 
patch_builder = DatasetPatchBuilder(dataset_urn) patch_builder.add_owner( - OwnerClass(make_user_urn("user-to-add-id"), OwnershipTypeClass.TECHNICAL_OWNER) + OwnerClass(make_user_urn("user-to-add-id"), OwnershipTypeClass.TECHNICAL_OWNER), ) patch_builder.remove_owner(make_group_urn("group-to-remove-id")) patch_mcps = patch_builder.build() diff --git a/metadata-ingestion/examples/library/dataset_add_structured_properties_patch.py b/metadata-ingestion/examples/library/dataset_add_structured_properties_patch.py index ef72ed58a4b82f..bd3ace02cd16cc 100644 --- a/metadata-ingestion/examples/library/dataset_add_structured_properties_patch.py +++ b/metadata-ingestion/examples/library/dataset_add_structured_properties_patch.py @@ -11,10 +11,11 @@ # Create Dataset Patch to Add and Remove Structured Properties patch_builder = DatasetPatchBuilder(dataset_urn) patch_builder.add_structured_property( - "urn:li:structuredProperty:retentionTimeInDays", 12 + "urn:li:structuredProperty:retentionTimeInDays", + 12, ) patch_builder.remove_structured_property( - "urn:li:structuredProperty:customClassification" + "urn:li:structuredProperty:customClassification", ) patch_mcps = patch_builder.build() diff --git a/metadata-ingestion/examples/library/dataset_add_tag_patch.py b/metadata-ingestion/examples/library/dataset_add_tag_patch.py index 0bc644d6865f63..d7cdb302556374 100644 --- a/metadata-ingestion/examples/library/dataset_add_tag_patch.py +++ b/metadata-ingestion/examples/library/dataset_add_tag_patch.py @@ -8,7 +8,9 @@ # Create Dataset URN dataset_urn = make_dataset_urn( - platform="snowflake", name="fct_users_created", env="PROD" + platform="snowflake", + name="fct_users_created", + env="PROD", ) # Create Dataset Patch diff --git a/metadata-ingestion/examples/library/dataset_add_term.py b/metadata-ingestion/examples/library/dataset_add_term.py index 1d97797edae6d9..ba937c34b4cb4b 100644 --- a/metadata-ingestion/examples/library/dataset_add_term.py +++ b/metadata-ingestion/examples/library/dataset_add_term.py @@ -25,7 +25,8 @@ dataset_urn = make_dataset_urn(platform="hive", name="realestate_db.sales", env="PROD") current_terms: Optional[GlossaryTermsClass] = graph.get_aspect( - entity_urn=dataset_urn, aspect_type=GlossaryTermsClass + entity_urn=dataset_urn, + aspect_type=GlossaryTermsClass, ) term_to_add = make_term_urn("Classification.HighlyConfidential") diff --git a/metadata-ingestion/examples/library/dataset_add_upstream_lineage_patch.py b/metadata-ingestion/examples/library/dataset_add_upstream_lineage_patch.py index 0b4e5e39bf627e..e2ea28fa02d008 100644 --- a/metadata-ingestion/examples/library/dataset_add_upstream_lineage_patch.py +++ b/metadata-ingestion/examples/library/dataset_add_upstream_lineage_patch.py @@ -13,10 +13,14 @@ # Create Dataset URN dataset_urn = make_dataset_urn( - platform="snowflake", name="fct_users_created", env="PROD" + platform="snowflake", + name="fct_users_created", + env="PROD", ) upstream_to_remove_urn = make_dataset_urn( - platform="s3", name="fct_users_old", env="PROD" + platform="s3", + name="fct_users_old", + env="PROD", ) upstream_to_add_urn = make_dataset_urn(platform="s3", name="fct_users_new", env="PROD") @@ -24,7 +28,7 @@ patch_builder = DatasetPatchBuilder(dataset_urn) patch_builder.remove_upstream_lineage(upstream_to_remove_urn) patch_builder.add_upstream_lineage( - UpstreamClass(upstream_to_add_urn, DatasetLineageTypeClass.TRANSFORMED) + UpstreamClass(upstream_to_add_urn, DatasetLineageTypeClass.TRANSFORMED), ) # ...And also include schema field lineage @@ -37,11 
+41,12 @@ FineGrainedLineageUpstreamTypeClass.FIELD_SET, [upstream_field_to_add_urn], [downstream_field_to_add_urn], - ) + ), ) upstream_field_to_remove_urn = make_schema_field_urn( - upstream_to_remove_urn, "profile_id" + upstream_to_remove_urn, + "profile_id", ) downstream_field_to_remove_urn = make_schema_field_urn(dataset_urn, "profile_id") @@ -51,7 +56,7 @@ FineGrainedLineageUpstreamTypeClass.FIELD_SET, [upstream_field_to_remove_urn], [downstream_field_to_remove_urn], - ) + ), ) patch_mcps = patch_builder.build() diff --git a/metadata-ingestion/examples/library/dataset_field_add_glossary_term_patch.py b/metadata-ingestion/examples/library/dataset_field_add_glossary_term_patch.py index 3f8da2c143c924..838ed189a9458e 100644 --- a/metadata-ingestion/examples/library/dataset_field_add_glossary_term_patch.py +++ b/metadata-ingestion/examples/library/dataset_field_add_glossary_term_patch.py @@ -8,16 +8,18 @@ # Create Dataset URN dataset_urn = make_dataset_urn( - platform="snowflake", name="fct_users_created", env="PROD" + platform="snowflake", + name="fct_users_created", + env="PROD", ) # Create Dataset Patch to Add + Remove Term for 'profile_id' column patch_builder = DatasetPatchBuilder(dataset_urn) patch_builder.for_field("profile_id").add_term( - GlossaryTermAssociationClass(make_term_urn("term-to-add-id")) + GlossaryTermAssociationClass(make_term_urn("term-to-add-id")), ) patch_builder.for_field("profile_id").remove_term( - "urn:li:glossaryTerm:term-to-remove-id" + "urn:li:glossaryTerm:term-to-remove-id", ) patch_mcps = patch_builder.build() diff --git a/metadata-ingestion/examples/library/dataset_field_add_tag_patch.py b/metadata-ingestion/examples/library/dataset_field_add_tag_patch.py index 3075cac5320ae9..c393e5b5ddbb7c 100644 --- a/metadata-ingestion/examples/library/dataset_field_add_tag_patch.py +++ b/metadata-ingestion/examples/library/dataset_field_add_tag_patch.py @@ -8,13 +8,15 @@ # Create Dataset URN dataset_urn = make_dataset_urn( - platform="snowflake", name="fct_users_created", env="PROD" + platform="snowflake", + name="fct_users_created", + env="PROD", ) # Create Dataset Patch to Add + Remove Tag for 'profile_id' column patch_builder = DatasetPatchBuilder(dataset_urn) patch_builder.for_field("profile_id").add_tag( - TagAssociationClass(make_tag_urn("tag-to-add-id")) + TagAssociationClass(make_tag_urn("tag-to-add-id")), ) patch_builder.for_field("profile_id").remove_tag("urn:li:tag:tag-to-remove-id") patch_mcps = patch_builder.build() diff --git a/metadata-ingestion/examples/library/dataset_replace_properties.py b/metadata-ingestion/examples/library/dataset_replace_properties.py index 8cf1bc89eec9a4..284a73cfec200d 100644 --- a/metadata-ingestion/examples/library/dataset_replace_properties.py +++ b/metadata-ingestion/examples/library/dataset_replace_properties.py @@ -23,9 +23,10 @@ def get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: return DatahubKafkaEmitter( config=KafkaEmitterConfig( connection=KafkaProducerConnectionConfig( - bootstrap=kafka_server, schema_registry_url=schema_registry_url - ) - ) + bootstrap=kafka_server, + schema_registry_url=schema_registry_url, + ), + ), ) @@ -46,5 +47,5 @@ def get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: log.info( - f"Replaced custom properties on dataset {dataset_urn} as {property_map_to_set}" + f"Replaced custom properties on dataset {dataset_urn} as {property_map_to_set}", ) diff --git a/metadata-ingestion/examples/library/dataset_report_operation.py 
b/metadata-ingestion/examples/library/dataset_report_operation.py index 15ebc43dba60a1..8c6aaa7d50abf0 100644 --- a/metadata-ingestion/examples/library/dataset_report_operation.py +++ b/metadata-ingestion/examples/library/dataset_report_operation.py @@ -15,5 +15,7 @@ # Report a change operation for the Dataset. operation_client.report_operation( - urn=dataset_urn, operation_type=operation_type, source_type=source_type + urn=dataset_urn, + operation_type=operation_type, + source_type=source_type, ) diff --git a/metadata-ingestion/examples/library/dataset_schema.py b/metadata-ingestion/examples/library/dataset_schema.py index ed77df9ddd1845..70c9c95aa42490 100644 --- a/metadata-ingestion/examples/library/dataset_schema.py +++ b/metadata-ingestion/examples/library/dataset_schema.py @@ -23,7 +23,8 @@ hash="", # when the source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=OtherSchemaClass(rawSchema="__insert raw schema here__"), lastModified=AuditStampClass( - time=1640692800000, actor="urn:li:corpuser:ingestion" + time=1640692800000, + actor="urn:li:corpuser:ingestion", ), fields=[ SchemaFieldClass( @@ -32,7 +33,8 @@ nativeDataType="VARCHAR(50)", # use this to provide the type of the field in the source system's vernacular description="This is the zipcode of the address. Specified using extended form and limited to addresses in the United States", lastModified=AuditStampClass( - time=1640692800000, actor="urn:li:corpuser:ingestion" + time=1640692800000, + actor="urn:li:corpuser:ingestion", ), ), SchemaFieldClass( @@ -41,7 +43,8 @@ nativeDataType="VARCHAR(100)", description="Street corresponding to the address", lastModified=AuditStampClass( - time=1640692800000, actor="urn:li:corpuser:ingestion" + time=1640692800000, + actor="urn:li:corpuser:ingestion", ), ), SchemaFieldClass( @@ -50,10 +53,12 @@ nativeDataType="Date", description="Date of the last sale date for this property", created=AuditStampClass( - time=1640692800000, actor="urn:li:corpuser:ingestion" + time=1640692800000, + actor="urn:li:corpuser:ingestion", ), lastModified=AuditStampClass( - time=1640692800000, actor="urn:li:corpuser:ingestion" + time=1640692800000, + actor="urn:li:corpuser:ingestion", ), ), ], diff --git a/metadata-ingestion/examples/library/dataset_schema_with_tags_terms.py b/metadata-ingestion/examples/library/dataset_schema_with_tags_terms.py index eb9088844f04ea..aa1ad3fee6e221 100644 --- a/metadata-ingestion/examples/library/dataset_schema_with_tags_terms.py +++ b/metadata-ingestion/examples/library/dataset_schema_with_tags_terms.py @@ -43,22 +43,22 @@ # It is rare to attach tags to fields as part of the technical schema unless you are purely reflecting state that exists in the source system. # For an editable (in UI) version of this, use the editableSchemaMetadata aspect globalTags=GlobalTagsClass( - tags=[TagAssociationClass(tag=make_tag_urn("location"))] + tags=[TagAssociationClass(tag=make_tag_urn("location"))], ), # It is rare to attach glossary terms to fields as part of the technical schema unless you are purely reflecting state that exists in the source system. # For an editable (in UI) version of this, use the editableSchemaMetadata aspect glossaryTerms=GlossaryTermsClass( terms=[ GlossaryTermAssociationClass( - urn=make_term_urn("Classification.PII") - ) + urn=make_term_urn("Classification.PII"), + ), ], auditStamp=AuditStampClass( # represents the time when this term was attached to this field? 
time=0, # time in milliseconds, leave as 0 if no time of association is known actor="urn:li:corpuser:ingestion", # if this is a system provided tag, use a bot user id like ingestion ), ), - ) + ), ], ), ) diff --git a/metadata-ingestion/examples/library/delete_assertion.py b/metadata-ingestion/examples/library/delete_assertion.py index 08df53ba1c3d30..0e7f6217c8b5c8 100644 --- a/metadata-ingestion/examples/library/delete_assertion.py +++ b/metadata-ingestion/examples/library/delete_assertion.py @@ -7,7 +7,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) assertion_urn = "urn:li:assertion:my-assertion" diff --git a/metadata-ingestion/examples/library/delete_dataset.py b/metadata-ingestion/examples/library/delete_dataset.py index 548cc0d850b29d..b2f09e5e624f05 100644 --- a/metadata-ingestion/examples/library/delete_dataset.py +++ b/metadata-ingestion/examples/library/delete_dataset.py @@ -9,7 +9,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) dataset_urn = make_dataset_urn(name="fct_users_created", platform="hive") diff --git a/metadata-ingestion/examples/library/delete_form.py b/metadata-ingestion/examples/library/delete_form.py index 189e80435fcf01..a8e1aa23a4751a 100644 --- a/metadata-ingestion/examples/library/delete_form.py +++ b/metadata-ingestion/examples/library/delete_form.py @@ -9,7 +9,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) form_urn = FormUrn("metadata_initiative_1") diff --git a/metadata-ingestion/examples/library/lineage_dataset_job_dataset.py b/metadata-ingestion/examples/library/lineage_dataset_job_dataset.py index 0c625c8e4a168d..fbdcd74eb7b336 100644 --- a/metadata-ingestion/examples/library/lineage_dataset_job_dataset.py +++ b/metadata-ingestion/examples/library/lineage_dataset_job_dataset.py @@ -13,14 +13,19 @@ output_datasets: List[str] = [ builder.make_dataset_urn( - platform="kafka", name="debezium.topics.librarydb.member_checkout", env="PROD" - ) + platform="kafka", + name="debezium.topics.librarydb.member_checkout", + env="PROD", + ), ] input_data_jobs: List[str] = [ builder.make_data_job_urn( - orchestrator="airflow", flow_id="flow1", job_id="job0", cluster="PROD" - ) + orchestrator="airflow", + flow_id="flow1", + job_id="job0", + cluster="PROD", + ), ] datajob_input_output = DataJobInputOutputClass( @@ -33,7 +38,10 @@ # NOTE: This will overwrite all of the existing lineage information associated with this job. 
datajob_input_output_mcp = MetadataChangeProposalWrapper( entityUrn=builder.make_data_job_urn( - orchestrator="airflow", flow_id="flow1", job_id="job1", cluster="PROD" + orchestrator="airflow", + flow_id="flow1", + job_id="job1", + cluster="PROD", ), aspect=datajob_input_output, ) diff --git a/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained.py b/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained.py index 2d417c1646a4f9..13444bcfeb4034 100644 --- a/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained.py +++ b/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained.py @@ -70,7 +70,8 @@ def fldUrn(tbl, fld): upstream = Upstream(dataset=datasetUrn("bar2"), type=DatasetLineageType.TRANSFORMED) fieldLineages = UpstreamLineage( - upstreams=[upstream], fineGrainedLineages=fineGrainedLineages + upstreams=[upstream], + fineGrainedLineages=fineGrainedLineages, ) lineageMcp = MetadataChangeProposalWrapper( diff --git a/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained_sample.py b/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained_sample.py index 637b24887f35df..ed9627861f18b5 100644 --- a/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained_sample.py +++ b/metadata-ingestion/examples/library/lineage_emitter_dataset_finegrained_sample.py @@ -34,11 +34,13 @@ def fldUrn(tbl, fld): # this is just to check if any conflicts with existing Upstream, particularly the DownstreamOf relationship upstream = Upstream( - dataset=datasetUrn("fct_users_deleted"), type=DatasetLineageType.TRANSFORMED + dataset=datasetUrn("fct_users_deleted"), + type=DatasetLineageType.TRANSFORMED, ) fieldLineages = UpstreamLineage( - upstreams=[upstream], fineGrainedLineages=fineGrainedLineages + upstreams=[upstream], + fineGrainedLineages=fineGrainedLineages, ) lineageMcp = MetadataChangeProposalWrapper( diff --git a/metadata-ingestion/examples/library/lineage_emitter_kafka.py b/metadata-ingestion/examples/library/lineage_emitter_kafka.py index 81a294908eeec6..19f0d5759f92a4 100644 --- a/metadata-ingestion/examples/library/lineage_emitter_kafka.py +++ b/metadata-ingestion/examples/library/lineage_emitter_kafka.py @@ -19,9 +19,9 @@ "bootstrap": "broker:9092", "producer_config": {}, "schema_registry_url": "http://schema-registry:8081", - } - } - ) + }, + }, + ), ) diff --git a/metadata-ingestion/examples/library/lineage_job_dataflow.py b/metadata-ingestion/examples/library/lineage_job_dataflow.py index 6ab631491668fd..f0460b12eda91b 100644 --- a/metadata-ingestion/examples/library/lineage_job_dataflow.py +++ b/metadata-ingestion/examples/library/lineage_job_dataflow.py @@ -6,7 +6,9 @@ # Construct the DataJobInfo aspect with the job -> flow lineage. dataflow_urn = builder.make_data_flow_urn( - orchestrator="airflow", flow_id="flow_old_api", cluster="prod" + orchestrator="airflow", + flow_id="flow_old_api", + cluster="prod", ) dataflow_info = DataFlowInfoClass(name="LowLevelApiFlow") @@ -22,7 +24,10 @@ # NOTE: This will overwrite all of the existing dataJobInfo aspect information associated with this job. 
datajob_info_mcp = MetadataChangeProposalWrapper( entityUrn=builder.make_data_job_urn( - orchestrator="airflow", flow_id="flow_old_api", job_id="job1", cluster="prod" + orchestrator="airflow", + flow_id="flow_old_api", + job_id="job1", + cluster="prod", ), aspect=datajob_info, ) diff --git a/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_simple.py b/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_simple.py index d47cdbce143036..7688fd81803ccf 100644 --- a/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_simple.py +++ b/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_simple.py @@ -34,15 +34,18 @@ dataJob4.emit(emitter) jobFlowRun = DataProcessInstance.from_dataflow( - dataflow=jobFlow, id=f"{jobFlow.id}-{uuid.uuid4()}" + dataflow=jobFlow, + id=f"{jobFlow.id}-{uuid.uuid4()}", ) jobFlowRun.emit_process_start( - emitter, int(datetime.now(timezone.utc).timestamp() * 1000) + emitter, + int(datetime.now(timezone.utc).timestamp() * 1000), ) jobRun = DataProcessInstance.from_datajob( - datajob=dataJob, id=f"{jobFlow.id}-{uuid.uuid4()}" + datajob=dataJob, + id=f"{jobFlow.id}-{uuid.uuid4()}", ) jobRun.emit_process_start(emitter, int(datetime.now(timezone.utc).timestamp() * 1000)) @@ -54,7 +57,8 @@ job2Run = DataProcessInstance.from_datajob( - datajob=dataJob2, id=f"{jobFlow.id}-{uuid.uuid4()}" + datajob=dataJob2, + id=f"{jobFlow.id}-{uuid.uuid4()}", ) job2Run.emit_process_start(emitter, int(datetime.now(timezone.utc).timestamp() * 1000)) @@ -66,7 +70,8 @@ job3Run = DataProcessInstance.from_datajob( - datajob=dataJob3, id=f"{jobFlow.id}-{uuid.uuid4()}" + datajob=dataJob3, + id=f"{jobFlow.id}-{uuid.uuid4()}", ) job3Run.emit_process_start(emitter, int(datetime.now(timezone.utc).timestamp() * 1000)) @@ -78,7 +83,8 @@ job4Run = DataProcessInstance.from_datajob( - datajob=dataJob4, id=f"{jobFlow.id}-{uuid.uuid4()}" + datajob=dataJob4, + id=f"{jobFlow.id}-{uuid.uuid4()}", ) job4Run.emit_process_start(emitter, int(datetime.now(timezone.utc).timestamp() * 1000)) diff --git a/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_verbose.py b/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_verbose.py index 67a8ce2059679a..be9667389351c8 100644 --- a/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_verbose.py +++ b/metadata-ingestion/examples/library/lineage_job_dataflow_new_api_verbose.py @@ -39,7 +39,9 @@ # Hello World jobFlowRun: DataProcessInstance = DataProcessInstance( - orchestrator="airflow", cluster="prod", id=f"{jobFlow.id}-{uuid.uuid4()}" + orchestrator="airflow", + cluster="prod", + id=f"{jobFlow.id}-{uuid.uuid4()}", ) jobRun1: DataProcessInstance = DataProcessInstance( orchestrator="airflow", @@ -49,7 +51,9 @@ jobRun1.parent_instance = jobFlowRun.urn jobRun1.template_urn = dataJob.urn jobRun1.emit_process_start( - emitter=emitter, start_timestamp_millis=int(time.time() * 1000), emit_template=False + emitter=emitter, + start_timestamp_millis=int(time.time() * 1000), + emit_template=False, ) jobRun1.emit_process_end( emitter=emitter, @@ -66,7 +70,9 @@ jobRun2.parent_instance = jobFlowRun.urn jobRun2.upstream_urns = [jobRun1.urn] jobRun2.emit_process_start( - emitter=emitter, start_timestamp_millis=int(time.time() * 1000), emit_template=False + emitter=emitter, + start_timestamp_millis=int(time.time() * 1000), + emit_template=False, ) jobRun2.emit_process_end( emitter=emitter, @@ -84,7 +90,9 @@ jobRun3.template_urn = dataJob3.urn jobRun3.upstream_urns = [jobRun1.urn] 
jobRun3.emit_process_start( - emitter=emitter, start_timestamp_millis=int(time.time() * 1000), emit_template=False + emitter=emitter, + start_timestamp_millis=int(time.time() * 1000), + emit_template=False, ) jobRun3.emit_process_end( emitter=emitter, @@ -101,7 +109,9 @@ jobRun4.template_urn = dataJob4.urn jobRun4.upstream_urns = [jobRun2.urn, jobRun3.urn] jobRun4.emit_process_start( - emitter=emitter, start_timestamp_millis=int(time.time() * 1000), emit_template=False + emitter=emitter, + start_timestamp_millis=int(time.time() * 1000), + emit_template=False, ) jobRun4.emit_process_end( emitter=emitter, diff --git a/metadata-ingestion/examples/library/programatic_pipeline.py b/metadata-ingestion/examples/library/programatic_pipeline.py index a4b5ecc1c650ad..6d4d856ee40f8f 100644 --- a/metadata-ingestion/examples/library/programatic_pipeline.py +++ b/metadata-ingestion/examples/library/programatic_pipeline.py @@ -16,7 +16,7 @@ "type": "datahub-rest", "config": {"server": "http://localhost:8080"}, }, - } + }, ) # Run the pipeline and report the results. diff --git a/metadata-ingestion/examples/library/read_lineage_execute_graphql.py b/metadata-ingestion/examples/library/read_lineage_execute_graphql.py index 7b7f8ef43f4f5e..a421d69111b875 100644 --- a/metadata-ingestion/examples/library/read_lineage_execute_graphql.py +++ b/metadata-ingestion/examples/library/read_lineage_execute_graphql.py @@ -33,11 +33,11 @@ "negated": "false", "field": "degree", "values": ["1", "2", "3+"], - } - ] - } + }, + ], + }, ], - } + }, } result = graph.execute_graphql(query=query, variables=variables) diff --git a/metadata-ingestion/examples/library/report_assertion_result.py b/metadata-ingestion/examples/library/report_assertion_result.py index 17b075b229916f..1f081b18c9ed61 100644 --- a/metadata-ingestion/examples/library/report_assertion_result.py +++ b/metadata-ingestion/examples/library/report_assertion_result.py @@ -8,7 +8,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) existing_assertion_urn = "urn:li:assertion:my-unique-assertion-id" diff --git a/metadata-ingestion/examples/library/run_assertion.py b/metadata-ingestion/examples/library/run_assertion.py index e7c717837eed3c..6341bb79acce0f 100644 --- a/metadata-ingestion/examples/library/run_assertion.py +++ b/metadata-ingestion/examples/library/run_assertion.py @@ -7,7 +7,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) assertion_urn = "urn:li:assertion:6e3f9e09-1483-40f9-b9cd-30e5f182694a" @@ -16,5 +16,5 @@ assertion_result = graph.run_assertion(urn=assertion_urn, save_result=True) log.info( - f"Assertion result (SUCCESS / FAILURE / ERROR): {assertion_result.get('type')}" + f"Assertion result (SUCCESS / FAILURE / ERROR): {assertion_result.get('type')}", ) diff --git a/metadata-ingestion/examples/library/run_assertions.py b/metadata-ingestion/examples/library/run_assertions.py index 6d38d9b5edecd9..eddc0ea4f7b0f5 100644 --- a/metadata-ingestion/examples/library/run_assertions.py +++ b/metadata-ingestion/examples/library/run_assertions.py @@ -7,7 +7,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) assertion_urns = [ @@ -17,21 +17,21 @@ # Run the assertions assertion_results = graph.run_assertions(urns=assertion_urns, save_result=True).get( - "results" + "results", ) if assertion_results is not None: assertion_result_1 = assertion_results.get( - "urn:li:assertion:6e3f9e09-1483-40f9-b9cd-30e5f182694a" + 
"urn:li:assertion:6e3f9e09-1483-40f9-b9cd-30e5f182694a", ) assertion_result_2 = assertion_results.get( - "urn:li:assertion:9e3f9e09-1483-40f9-b9cd-30e5f182694g" + "urn:li:assertion:9e3f9e09-1483-40f9-b9cd-30e5f182694g", ) log.info(f"Assertion results: {assertion_results}") log.info( - f"Assertion result 1 (SUCCESS / FAILURE / ERROR): {assertion_result_1.get('type')}" + f"Assertion result 1 (SUCCESS / FAILURE / ERROR): {assertion_result_1.get('type')}", ) log.info( - f"Assertion result 2 (SUCCESS / FAILURE / ERROR): {assertion_result_2.get('type')}" + f"Assertion result 2 (SUCCESS / FAILURE / ERROR): {assertion_result_2.get('type')}", ) diff --git a/metadata-ingestion/examples/library/run_assertions_for_asset.py b/metadata-ingestion/examples/library/run_assertions_for_asset.py index ab2793c3b5b8a6..0ccb1f9ae282dd 100644 --- a/metadata-ingestion/examples/library/run_assertions_for_asset.py +++ b/metadata-ingestion/examples/library/run_assertions_for_asset.py @@ -7,7 +7,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) dataset_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_snowflake_table,PROD)" @@ -17,22 +17,23 @@ if assertion_results is not None: assertion_result_1 = assertion_results.get( - "urn:li:assertion:6e3f9e09-1483-40f9-b9cd-30e5f182694a" + "urn:li:assertion:6e3f9e09-1483-40f9-b9cd-30e5f182694a", ) assertion_result_2 = assertion_results.get( - "urn:li:assertion:9e3f9e09-1483-40f9-b9cd-30e5f182694g" + "urn:li:assertion:9e3f9e09-1483-40f9-b9cd-30e5f182694g", ) log.info(f"Assertion results: {assertion_results}") log.info( - f"Assertion result 1 (SUCCESS / FAILURE / ERROR): {assertion_result_1.get('type')}" + f"Assertion result 1 (SUCCESS / FAILURE / ERROR): {assertion_result_1.get('type')}", ) log.info( - f"Assertion result 2 (SUCCESS / FAILURE / ERROR): {assertion_result_2.get('type')}" + f"Assertion result 2 (SUCCESS / FAILURE / ERROR): {assertion_result_2.get('type')}", ) # Run a subset of native assertions having a specific tag important_assertion_tag = "urn:li:tag:my-important-assertion-tag" assertion_results = graph.run_assertions_for_asset( - urn=dataset_urn, tag_urns=[important_assertion_tag] + urn=dataset_urn, + tag_urns=[important_assertion_tag], ).get("results") diff --git a/metadata-ingestion/examples/library/update_form.py b/metadata-ingestion/examples/library/update_form.py index b7ae2a3003af98..902857c660b5bb 100644 --- a/metadata-ingestion/examples/library/update_form.py +++ b/metadata-ingestion/examples/library/update_form.py @@ -31,9 +31,10 @@ def get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: return DatahubKafkaEmitter( config=KafkaEmitterConfig( connection=KafkaProducerConnectionConfig( - bootstrap=kafka_server, schema_registry_url=schema_registry_url - ) - ) + bootstrap=kafka_server, + schema_registry_url=schema_registry_url, + ), + ), ) @@ -46,7 +47,7 @@ def get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: title="title", type=FormPromptTypeClass.STRUCTURED_PROPERTY, structuredPropertyParams=StructuredPropertyParamsClass( - "urn:li:structuredProperty:io.acryl.test" + "urn:li:structuredProperty:io.acryl.test", ), required=True, ) @@ -55,7 +56,7 @@ def get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: title="title", type=FormPromptTypeClass.FIELDS_STRUCTURED_PROPERTY, structuredPropertyParams=StructuredPropertyParamsClass( - "urn:li:structuredProperty:io.acryl.test" + "urn:li:structuredProperty:io.acryl.test", ), required=True, ) @@ -65,8 +66,9 @@ def 
get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: FormPatchBuilder(str(form_urn)) .add_owner( OwnerClass( - owner="urn:li:corpuser:jdoe", type=OwnershipTypeClass.TECHNICAL_OWNER - ) + owner="urn:li:corpuser:jdoe", + type=OwnershipTypeClass.TECHNICAL_OWNER, + ), ) .set_name("New Name") .set_description("New description here") diff --git a/metadata-ingestion/examples/library/upsert_custom_assertion.py b/metadata-ingestion/examples/library/upsert_custom_assertion.py index 1878a894c0455c..29ce6fae02d673 100644 --- a/metadata-ingestion/examples/library/upsert_custom_assertion.py +++ b/metadata-ingestion/examples/library/upsert_custom_assertion.py @@ -7,7 +7,7 @@ graph = DataHubGraph( config=DatahubClientConfig( server="http://localhost:8080", - ) + ), ) new_assertion_urn = "urn:li:assertion:my-unique-assertion-id" diff --git a/metadata-ingestion/examples/library/upsert_group.py b/metadata-ingestion/examples/library/upsert_group.py index 84844e142f46c0..e5da5844ec13f5 100644 --- a/metadata-ingestion/examples/library/upsert_group.py +++ b/metadata-ingestion/examples/library/upsert_group.py @@ -29,8 +29,9 @@ for event in group.generate_mcp( generation_config=CorpGroupGenerationConfig( - override_editable=False, datahub_graph=datahub_graph - ) + override_editable=False, + datahub_graph=datahub_graph, + ), ): datahub_graph.emit(event) log.info(f"Upserted group {group.urn}") diff --git a/metadata-ingestion/examples/library/upsert_user.py b/metadata-ingestion/examples/library/upsert_user.py index 8b26bba0f34005..6078a48c85e31c 100644 --- a/metadata-ingestion/examples/library/upsert_user.py +++ b/metadata-ingestion/examples/library/upsert_user.py @@ -21,7 +21,7 @@ # Create graph client datahub_graph = DataHubGraph(DataHubGraphConfig(server="http://localhost:8080")) for event in user.generate_mcp( - generation_config=CorpUserGenerationConfig(override_editable=False) + generation_config=CorpUserGenerationConfig(override_editable=False), ): datahub_graph.emit(event) log.info(f"Upserted user {user.urn}") diff --git a/metadata-ingestion/examples/perf/lineage_perf_example.py b/metadata-ingestion/examples/perf/lineage_perf_example.py index 3ee78bacb268ad..e10ecbbd9ea4ab 100644 --- a/metadata-ingestion/examples/perf/lineage_perf_example.py +++ b/metadata-ingestion/examples/perf/lineage_perf_example.py @@ -16,7 +16,8 @@ def lineage_mcp_generator( - urn: str, upstreams: List[str] + urn: str, + upstreams: List[str], ) -> Iterable[MetadataChangeProposalWrapper]: yield MetadataChangeProposalWrapper( entityUrn=urn, @@ -27,12 +28,13 @@ def lineage_mcp_generator( type=DatasetLineageTypeClass.TRANSFORMED, ) for upstream in upstreams - ] + ], ), ) for upstream in upstreams: yield MetadataChangeProposalWrapper( - entityUrn=upstream, aspect=StatusClass(removed=False) + entityUrn=upstream, + aspect=StatusClass(removed=False), ) for urn_itr in [urn, *upstreams]: yield MetadataChangeProposalWrapper( @@ -42,7 +44,9 @@ def lineage_mcp_generator( def datajob_lineage_mcp_generator( - urn: str, upstreams: List[str], downstreams: List[str] + urn: str, + upstreams: List[str], + downstreams: List[str], ) -> Iterable[MetadataChangeProposalWrapper]: yield MetadataChangeProposalWrapper( entityUrn=urn, @@ -53,11 +57,13 @@ def datajob_lineage_mcp_generator( ) for upstream in upstreams: yield MetadataChangeProposalWrapper( - entityUrn=upstream, aspect=StatusClass(removed=False) + entityUrn=upstream, + aspect=StatusClass(removed=False), ) for downstream in downstreams: yield MetadataChangeProposalWrapper( - 
entityUrn=downstream, aspect=StatusClass(removed=False) + entityUrn=downstream, + aspect=StatusClass(removed=False), ) @@ -76,7 +82,8 @@ def scenario_truncate_basic(): for i in range(10): yield from lineage_mcp_generator( - make_dataset_urn("snowflake", f"{path}.d_{i}"), [root_urn] + make_dataset_urn("snowflake", f"{path}.d_{i}"), + [root_urn], ) @@ -90,7 +97,8 @@ def scenario_truncate_intermediate(): root_urn = make_dataset_urn("snowflake", f"{path}.root") yield from lineage_mcp_generator( - root_urn, [make_dataset_urn("snowflake", f"{path}.u_{i}") for i in range(10)] + root_urn, + [make_dataset_urn("snowflake", f"{path}.u_{i}") for i in range(10)], ) for i in range(3): @@ -101,7 +109,8 @@ def scenario_truncate_intermediate(): for i in range(3): yield from lineage_mcp_generator( - make_dataset_urn("snowflake", f"{path}.d_{i}"), [root_urn] + make_dataset_urn("snowflake", f"{path}.d_{i}"), + [root_urn], ) for j in range(3): yield from lineage_mcp_generator( @@ -139,7 +148,8 @@ def scenario_truncate_complex(): } lvl_e = { (a, b, c, d): make_dataset_urn( - "snowflake", f"{path}.u_0_u_{a}_u_{b}_u_{c}_u_{d}" + "snowflake", + f"{path}.u_0_u_{a}_u_{b}_u_{c}_u_{d}", ) for a in range(2) for b in range(3) @@ -148,7 +158,8 @@ def scenario_truncate_complex(): } lvl_f = { (a, b, c, d, e): make_dataset_urn( - "snowflake", f"{path}.u_0_u_{a}_u_{b}_u_{c}_u_{d}_u_{e}" + "snowflake", + f"{path}.u_0_u_{a}_u_{b}_u_{c}_u_{d}_u_{e}", ) for a in range(2) for b in range(3) @@ -167,7 +178,8 @@ def scenario_truncate_complex(): yield from lineage_mcp_generator(urn, [lvl_e[(a, b, c, d)] for d in range(5)]) for (a, b, c, d), urn in lvl_e.items(): yield from lineage_mcp_generator( - urn, [lvl_f[(a, b, c, d, e)] for e in range(6 if d % 2 == 0 else 1)] + urn, + [lvl_f[(a, b, c, d, e)] for e in range(6 if d % 2 == 0 else 1)], ) @@ -239,7 +251,9 @@ def scenario_skip_intermediate(): ], ) yield from datajob_lineage_mcp_generator( - upstream_airflow_urn, [upstream_dbt_urns[1]], [upstream_dbt_urns[0]] + upstream_airflow_urn, + [upstream_dbt_urns[1]], + [upstream_dbt_urns[0]], ) yield from lineage_mcp_generator( upstream_dbt_urns[1], @@ -301,7 +315,9 @@ def scenario_skip_complex(): ], ) yield from datajob_lineage_mcp_generator( - upstream_airflow_urn, [upstream_dbt_urns[1]], [upstream_dbt_urns[0]] + upstream_airflow_urn, + [upstream_dbt_urns[1]], + [upstream_dbt_urns[0]], ) yield from lineage_mcp_generator( upstream_dbt_urns[1], @@ -370,7 +386,8 @@ def scenario_perf(): } lvl_e = { (a, b, c, d): make_dataset_urn( - "snowflake", f"{path}.u_0_u_{a}_u_{b}_u_{c}_u_{d}" + "snowflake", + f"{path}.u_0_u_{a}_u_{b}_u_{c}_u_{d}", ) for a in range(100) for b in range(30) diff --git a/metadata-ingestion/examples/structured_properties/create_structured_property.py b/metadata-ingestion/examples/structured_properties/create_structured_property.py index 64bc0a67812775..a6b72ffc96e462 100644 --- a/metadata-ingestion/examples/structured_properties/create_structured_property.py +++ b/metadata-ingestion/examples/structured_properties/create_structured_property.py @@ -52,7 +52,7 @@ "allowedTypes": [ "urn:li:entityType:datahub.corpuser", "urn:li:entityType:datahub.corpGroup", - ] + ], }, # this line ensures only user or group urns can be assigned as values ) diff --git a/metadata-ingestion/examples/structured_properties/update_structured_property.py b/metadata-ingestion/examples/structured_properties/update_structured_property.py index 6f4b8b3be20d15..0bc537b16448bb 100644 --- 
a/metadata-ingestion/examples/structured_properties/update_structured_property.py +++ b/metadata-ingestion/examples/structured_properties/update_structured_property.py @@ -23,9 +23,10 @@ def get_emitter() -> Union[DataHubRestEmitter, DatahubKafkaEmitter]: return DatahubKafkaEmitter( config=KafkaEmitterConfig( connection=KafkaProducerConnectionConfig( - bootstrap=kafka_server, schema_registry_url=schema_registry_url - ) - ) + bootstrap=kafka_server, + schema_registry_url=schema_registry_url, + ), + ), ) diff --git a/metadata-ingestion/examples/transforms/custom_transform_example.py b/metadata-ingestion/examples/transforms/custom_transform_example.py index d8639bafe9b83c..d091fc4965012e 100644 --- a/metadata-ingestion/examples/transforms/custom_transform_example.py +++ b/metadata-ingestion/examples/transforms/custom_transform_example.py @@ -53,7 +53,10 @@ def aspect_name(self) -> str: return "ownership" def transform_aspect( # type: ignore - self, entity_urn: str, aspect_name: str, aspect: Optional[OwnershipClass] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[OwnershipClass], ) -> Optional[OwnershipClass]: owners_to_add = self.owners assert aspect is None or isinstance(aspect, OwnershipClass) diff --git a/metadata-ingestion/pyproject.toml b/metadata-ingestion/pyproject.toml index 1d434eb8c3a94f..bd0b6b7d175a73 100644 --- a/metadata-ingestion/pyproject.toml +++ b/metadata-ingestion/pyproject.toml @@ -29,6 +29,7 @@ extend-select = [ "G010", # logging.warn -> logging.warning "I", # isort "TID", # flake8-tidy-imports + "COM", # flake8-commas (trailing commas) ] extend-ignore = [ "E501", # Handled by formatter @@ -38,6 +39,7 @@ extend-ignore = [ "E203", # Ignore whitespace before ':' (matches Black) "B019", # Allow usages of functools.lru_cache "B008", # Allow function call in argument defaults + "COM812", # Avoid conflicts with formatter # TODO: Enable these later "B006", # Mutable args "B017", # Do not assert blind exception diff --git a/metadata-ingestion/src/datahub/_codegen/aspect.py b/metadata-ingestion/src/datahub/_codegen/aspect.py index 28fa3f1536a862..517007873b97ad 100644 --- a/metadata-ingestion/src/datahub/_codegen/aspect.py +++ b/metadata-ingestion/src/datahub/_codegen/aspect.py @@ -19,7 +19,7 @@ def __init__(self): # Ensure that it cannot be instantiated directly, as # per https://stackoverflow.com/a/7989101/5004662. raise TypeError( - "_Aspect is an abstract class, and cannot be instantiated directly." + "_Aspect is an abstract class, and cannot be instantiated directly.", ) super().__init__() diff --git a/metadata-ingestion/src/datahub/api/circuit_breaker/assertion_circuit_breaker.py b/metadata-ingestion/src/datahub/api/circuit_breaker/assertion_circuit_breaker.py index 283cdaa8333338..f746ff15f6c760 100644 --- a/metadata-ingestion/src/datahub/api/circuit_breaker/assertion_circuit_breaker.py +++ b/metadata-ingestion/src/datahub/api/circuit_breaker/assertion_circuit_breaker.py @@ -53,7 +53,9 @@ def get_last_updated(self, urn: str) -> Optional[datetime]: return parse_ts_millis(operations[0]["lastUpdatedTimestamp"]) def _check_if_assertion_failed( - self, assertions: List[Dict[str, Any]], last_updated: Optional[datetime] = None + self, + assertions: List[Dict[str, Any]], + last_updated: Optional[datetime] = None, ) -> bool: @dataclass class AssertionResult: @@ -87,7 +89,7 @@ class AssertionResult: if last_assertion.state == "FAILURE": logger.debug(f"Runevent: {last_assertion.run_event}") logger.info( - f"Assertion {assertion_urn} is failed on dataset. 
Breaking the circuit" + f"Assertion {assertion_urn} is failed on dataset. Breaking the circuit", ) return True elif last_assertion.state == "SUCCESS": @@ -97,7 +99,7 @@ class AssertionResult: last_run = parse_ts_millis(last_assertion.time) if last_updated > last_run: logger.error( - f"Missing assertion run for {assertion_urn}. The dataset was updated on {last_updated} but the last assertion run was at {last_run}" + f"Missing assertion run for {assertion_urn}. The dataset was updated on {last_updated} but the last assertion run was at {last_run}", ) return True return result @@ -114,13 +116,13 @@ def is_circuit_breaker_active(self, urn: str) -> bool: if self.config.verify_after_last_update: last_updated = self.get_last_updated(urn) logger.info( - f"The dataset {urn} was last updated at {last_updated}, using this as min assertion date." + f"The dataset {urn} was last updated at {last_updated}, using this as min assertion date.", ) if not last_updated: last_updated = datetime.now(tz=timezone.utc) - self.config.time_delta logger.info( - f"Dataset {urn} doesn't have last updated or check_last_assertion_time is false, using calculated min assertion date {last_updated}" + f"Dataset {urn} doesn't have last updated or check_last_assertion_time is false, using calculated min assertion date {last_updated}", ) assertions = self.assertion_api.query_assertion( diff --git a/metadata-ingestion/src/datahub/api/circuit_breaker/operation_circuit_breaker.py b/metadata-ingestion/src/datahub/api/circuit_breaker/operation_circuit_breaker.py index 58a4ee37d959b6..da72e8ac491050 100644 --- a/metadata-ingestion/src/datahub/api/circuit_breaker/operation_circuit_breaker.py +++ b/metadata-ingestion/src/datahub/api/circuit_breaker/operation_circuit_breaker.py @@ -61,7 +61,7 @@ def is_circuit_breaker_active( """ start_time_millis: int = int( - (datetime.now() - self.config.time_delta).timestamp() * 1000 + (datetime.now() - self.config.time_delta).timestamp() * 1000, ) operations = self.operation_api.query_operations( urn, diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py index e0975a1c0351c7..f8b125baecb3f0 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion.py @@ -49,9 +49,10 @@ class BaseAssertion(v1_ConfigModel): class BaseEntityAssertion(BaseAssertion): entity: str = v1_Field( - description="The entity urn that the assertion is associated with" + description="The entity urn that the assertion is associated with", ) trigger: Optional[AssertionTrigger] = v1_Field( - description="The trigger schedule for assertion", alias="schedule" + description="The trigger schedule for assertion", + alias="schedule", ) diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py index a05386798495de..b194cca385c8e3 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_operator.py @@ -30,19 +30,22 @@ def _generate_assertion_std_parameter( ) -> AssertionStdParameterClass: if isinstance(value, str): return AssertionStdParameterClass( - value=value, type=AssertionStdParameterTypeClass.STRING + value=value, + type=AssertionStdParameterTypeClass.STRING, ) elif isinstance(value, (int, float)): return AssertionStdParameterClass( - 
value=str(value), type=AssertionStdParameterTypeClass.NUMBER + value=str(value), + type=AssertionStdParameterTypeClass.NUMBER, ) elif isinstance(value, list): return AssertionStdParameterClass( - value=json.dumps(value), type=AssertionStdParameterTypeClass.LIST + value=json.dumps(value), + type=AssertionStdParameterTypeClass.LIST, ) else: raise ValueError( - f"Unsupported assertion parameter {value} of type {type(value)}" + f"Unsupported assertion parameter {value} of type {type(value)}", ) @@ -99,7 +102,8 @@ def id(self) -> str: def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters( - min_value=self.min, max_value=self.max + min_value=self.min, + max_value=self.max, ) diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py index d7809164847447..14df6b1a8ba6af 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/assertion_trigger.py @@ -14,7 +14,7 @@ class CronTrigger(v1_ConfigModel): type: Literal["cron"] cron: str = v1_Field( - description="The cron expression to use. See https://crontab.guru/ for help." + description="The cron expression to use. See https://crontab.guru/ for help.", ) timezone: str = v1_Field( "UTC", @@ -44,7 +44,10 @@ class ManualTrigger(v1_ConfigModel): class AssertionTrigger(v1_ConfigModel): __root__: Union[ - CronTrigger, IntervalTrigger, EntityChangeTrigger, ManualTrigger + CronTrigger, + IntervalTrigger, + EntityChangeTrigger, + ManualTrigger, ] = v1_Field(discriminator="type") @property diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/compiler_interface.py b/metadata-ingestion/src/datahub/api/entities/assertion/compiler_interface.py index 09a2371329c723..9c786149603ee0 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/compiler_interface.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/compiler_interface.py @@ -54,7 +54,7 @@ class AssertionCompilationResult: status: Literal["success", "failure"] report: AssertionCompilationReport = field( - default_factory=AssertionCompilationReport + default_factory=AssertionCompilationReport, ) artifacts: List[CompileResultArtifact] = field(default_factory=list) @@ -72,6 +72,7 @@ def create(cls, output_dir: str, extras: Dict[str, str]) -> "AssertionCompiler": @abstractmethod def compile( - self, assertion_config_spec: AssertionsConfigSpec + self, + assertion_config_spec: AssertionsConfigSpec, ) -> AssertionCompilationResult: pass diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py index ae062c3a8e5cbd..8e571d5c9b9cf6 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/field_assertion.py @@ -56,7 +56,7 @@ class FieldValuesAssertion(BaseEntityAssertion): operator: Operators = v1_Field(discriminator="type", alias="condition") filters: Optional[DatasetFilter] = v1_Field(default=None) failure_threshold: FieldValuesFailThreshold = v1_Field( - default=FieldValuesFailThreshold() + default=FieldValuesFailThreshold(), ) exclude_nulls: bool = v1_Field(default=True) diff --git a/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py b/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py index 
f9e1df7d68f271..52437fbd82fae1 100644 --- a/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/assertion/freshness_assertion.py @@ -34,14 +34,14 @@ class CronFreshnessAssertion(BaseEntityAssertion): type: Literal["freshness"] freshness_type: Literal["cron"] cron: str = v1_Field( - description="The cron expression to use. See https://crontab.guru/ for help." + description="The cron expression to use. See https://crontab.guru/ for help.", ) timezone: str = v1_Field( "UTC", description="The timezone to use for the cron schedule. Defaults to UTC.", ) source_type: FreshnessSourceType = v1_Field( - default=FreshnessSourceType.LAST_MODIFIED_COLUMN + default=FreshnessSourceType.LAST_MODIFIED_COLUMN, ) last_modified_field: str filters: Optional[DatasetFilter] = v1_Field(default=None) @@ -69,7 +69,7 @@ class FixedIntervalFreshnessAssertion(BaseEntityAssertion): lookback_interval: timedelta filters: Optional[DatasetFilter] = v1_Field(default=None) source_type: FreshnessSourceType = v1_Field( - default=FreshnessSourceType.LAST_MODIFIED_COLUMN + default=FreshnessSourceType.LAST_MODIFIED_COLUMN, ) last_modified_field: str diff --git a/metadata-ingestion/src/datahub/api/entities/common/data_platform_instance.py b/metadata-ingestion/src/datahub/api/entities/common/data_platform_instance.py index c48b1f6b1c2e22..6f84a904a1f3aa 100644 --- a/metadata-ingestion/src/datahub/api/entities/common/data_platform_instance.py +++ b/metadata-ingestion/src/datahub/api/entities/common/data_platform_instance.py @@ -28,7 +28,8 @@ def to_data_platform_instance(self) -> models.DataPlatformInstanceClass: platform=make_data_platform_urn(self.platform), instance=( make_dataplatform_instance_urn( - platform=self.platform, instance=self.platform_instance + platform=self.platform, + instance=self.platform_instance, ) if self.platform_instance else None diff --git a/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py b/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py index 0f13ea04ab0753..cb462006af6f50 100644 --- a/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py +++ b/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py @@ -50,10 +50,10 @@ def as_raw_json(self) -> Optional[Dict]: return object_dict else: logger.warning( - f"Unsupported schema type {self.schema_type} for parsing value" + f"Unsupported schema type {self.schema_type} for parsing value", ) raise ValueError( - f"Unsupported schema type {self.schema_type} for parsing value" + f"Unsupported schema type {self.schema_type} for parsing value", ) def as_pegasus_object(self) -> DictWrapper: @@ -79,11 +79,13 @@ def as_pegasus_object(self) -> DictWrapper: return model_type.from_obj(object_dict or {}) else: raise ValueError( - f"Could not find schema ref {self.schema_ref} for parsing value" + f"Could not find schema ref {self.schema_ref} for parsing value", ) def as_pydantic_object( - self, model_type: Type[BaseModel], validate_schema_ref: bool = False + self, + model_type: Type[BaseModel], + validate_schema_ref: bool = False, ) -> BaseModel: """ Parse the blob into a Pydantic-defined Python object based on the schema type and schema @@ -107,7 +109,8 @@ def as_pydantic_object( @classmethod def from_resource_value( - cls, resource_value: models.SerializedValueClass + cls, + resource_value: models.SerializedValueClass, ) -> "SerializedResourceValue": return cls( content_type=resource_value.contentType, @@ -118,7 +121,8 @@ 
def from_resource_value( @classmethod def create( - cls, object: Union[DictWrapper, BaseModel, Dict] + cls, + object: Union[DictWrapper, BaseModel, Dict], ) -> "SerializedResourceValue": if isinstance(object, DictWrapper): return SerializedResourceValue( diff --git a/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py b/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py index bf58d2fbbda913..1934801c7291b7 100644 --- a/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py +++ b/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py @@ -84,7 +84,8 @@ def _needs_editable_aspect(self) -> bool: return bool(self.picture_link) def generate_mcp( - self, generation_config: CorpGroupGenerationConfig = CorpGroupGenerationConfig() + self, + generation_config: CorpGroupGenerationConfig = CorpGroupGenerationConfig(), ) -> Iterable[MetadataChangeProposalWrapper]: urns_created = set() # dedup member creation on the way out members_to_create: List[CorpUser] = ( @@ -109,13 +110,13 @@ def generate_mcp( if m.urn not in urns_created: yield from m.generate_mcp( generation_config=CorpUserGenerationConfig( - override_editable=generation_config.override_editable - ) + override_editable=generation_config.override_editable, + ), ) urns_created.add(m.urn) else: logger.warning( - f"Suppressing emission of member {m.urn} before we already emitted metadata for it" + f"Suppressing emission of member {m.urn} before we already emitted metadata for it", ) aspects: List[_Aspect] = [StatusClass(removed=False)] @@ -126,7 +127,7 @@ def generate_mcp( pictureLink=self.picture_link, slack=self.slack, email=self.email, - ) + ), ) else: aspects.append( @@ -138,7 +139,7 @@ def generate_mcp( email=self.email, description=self.description, slack=self.slack, - ) + ), ) # picture link is only available in the editable aspect, so we have to use it if it is provided if self._needs_editable_aspect(): @@ -148,7 +149,7 @@ def generate_mcp( pictureLink=self.picture_link, slack=self.slack, email=self.email, - ) + ), ) for aspect in aspects: yield MetadataChangeProposalWrapper(entityUrn=self.urn, aspect=aspect) @@ -158,7 +159,7 @@ def generate_mcp( ownership = OwnershipClass(owners=[]) for urn in owner_urns: ownership.owners.append( - OwnerClass(owner=urn, type=OwnershipTypeClass.TECHNICAL_OWNER) + OwnerClass(owner=urn, type=OwnershipTypeClass.TECHNICAL_OWNER), ) yield MetadataChangeProposalWrapper(entityUrn=self.urn, aspect=ownership) @@ -171,25 +172,28 @@ def generate_mcp( # Add group membership to each user. for urn in member_urns: group_membership = datahub_graph.get_aspect( - urn, GroupMembershipClass + urn, + GroupMembershipClass, ) or GroupMembershipClass(groups=[]) if self.urn not in group_membership.groups: group_membership.groups = sorted( - set(group_membership.groups + [self.urn]) + set(group_membership.groups + [self.urn]), ) yield MetadataChangeProposalWrapper( - entityUrn=urn, aspect=group_membership + entityUrn=urn, + aspect=group_membership, ) else: if member_urns: raise ConfigurationError( - "Unable to emit group membership because members is non-empty, and a DataHubGraph instance was not provided." 
+ "Unable to emit group membership because members is non-empty, and a DataHubGraph instance was not provided.", ) # emit status aspects for all user urns referenced (to ensure they get created) for urn in set(owner_urns).union(set(member_urns)): yield MetadataChangeProposalWrapper( - entityUrn=urn, aspect=StatusClass(removed=False) + entityUrn=urn, + aspect=StatusClass(removed=False), ) def emit( @@ -215,7 +219,8 @@ def emit( datahub_graph = emitter.to_graph() for mcp in self.generate_mcp( generation_config=CorpGroupGenerationConfig( - override_editable=self.overrideEditable, datahub_graph=datahub_graph - ) + override_editable=self.overrideEditable, + datahub_graph=datahub_graph, + ), ): emitter.emit(mcp, callback) diff --git a/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py b/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py index 9fe1ebedafca7e..f18e951ef0e1b3 100644 --- a/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py +++ b/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py @@ -86,14 +86,15 @@ def _needs_editable_aspect(self) -> bool: def generate_group_membership_aspect(self) -> Iterable[GroupMembershipClass]: if self.groups is not None: group_membership = GroupMembershipClass( - groups=[builder.make_group_urn(group) for group in self.groups] + groups=[builder.make_group_urn(group) for group in self.groups], ) return [group_membership] else: return [] def generate_mcp( - self, generation_config: CorpUserGenerationConfig = CorpUserGenerationConfig() + self, + generation_config: CorpUserGenerationConfig = CorpUserGenerationConfig(), ) -> Iterable[MetadataChangeProposalWrapper]: if generation_config.override_editable or self._needs_editable_aspect(): mcp = MetadataChangeProposalWrapper( @@ -136,7 +137,8 @@ def generate_mcp( # Finally emit status yield MetadataChangeProposalWrapper( - entityUrn=self.urn, aspect=StatusClass(removed=False) + entityUrn=self.urn, + aspect=StatusClass(removed=False), ) def emit( diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py index 145a6097d7336c..0f24b945496afc 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/assertion_operator.py @@ -29,15 +29,17 @@ def _generate_assertion_std_parameter( ) -> AssertionStdParameterClass: if isinstance(value, str): return AssertionStdParameterClass( - value=value, type=AssertionStdParameterTypeClass.STRING + value=value, + type=AssertionStdParameterTypeClass.STRING, ) elif isinstance(value, (int, float)): return AssertionStdParameterClass( - value=str(value), type=AssertionStdParameterTypeClass.NUMBER + value=str(value), + type=AssertionStdParameterTypeClass.NUMBER, ) else: raise ValueError( - f"Unsupported assertion parameter {value} of type {type(value)}" + f"Unsupported assertion parameter {value} of type {type(value)}", ) @@ -81,7 +83,8 @@ def id(self) -> str: def generate_parameters(self) -> AssertionStdParametersClass: return _generate_assertion_std_parameters( - min_value=self.min, max_value=self.max + min_value=self.min, + max_value=self.max, ) diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py index 975aa359bd2031..776c5ed26a92a5 100644 --- 
a/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/data_quality_assertion.py @@ -77,8 +77,9 @@ def generate_assertion_info(self, entity_urn: str) -> AssertionInfoClass: aggregation=AssertionStdAggregationClass.UNIQUE_PROPOTION, # purposely using the misspelled version to work with gql parameters=AssertionStdParametersClass( value=AssertionStdParameterClass( - value="1", type=AssertionStdParameterTypeClass.NUMBER - ) + value="1", + type=AssertionStdParameterTypeClass.NUMBER, + ), ), ) return AssertionInfoClass( @@ -104,11 +105,13 @@ def id(self) -> str: return self.__root__.type def generate_mcp( - self, assertion_urn: str, entity_urn: str + self, + assertion_urn: str, + entity_urn: str, ) -> List[MetadataChangeProposalWrapper]: return [ MetadataChangeProposalWrapper( entityUrn=assertion_urn, aspect=self.__root__.generate_assertion_info(entity_urn), - ) + ), ] diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py b/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py index 4a68fa6c66adad..424db48fe72d5f 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/datacontract.py @@ -45,7 +45,7 @@ class DataContract(v1_ConfigModel): description="The data contract urn. If not provided, one will be generated.", ) entity: str = v1_Field( - description="The entity urn that the Data Contract is associated with" + description="The entity urn that the Data Contract is associated with", ) # TODO: add support for properties # properties: Optional[Dict[str, str]] = None @@ -61,7 +61,8 @@ class DataContract(v1_ConfigModel): @v1_validator("data_quality") # type: ignore def validate_data_quality( - cls, data_quality: Optional[List[DataQualityAssertion]] + cls, + data_quality: Optional[List[DataQualityAssertion]], ) -> Optional[List[DataQualityAssertion]]: if data_quality: # Raise an error if there are duplicate ids. @@ -70,7 +71,7 @@ def validate_data_quality( if duplicates: raise ValueError( - f"Got multiple data quality tests with the same type or ID: {duplicates}. Set a unique ID for each data quality test." + f"Got multiple data quality tests with the same type or ID: {duplicates}. 
Set a unique ID for each data quality test.", ) return data_quality @@ -87,7 +88,8 @@ def urn(self) -> str: return urn def _generate_freshness_assertion( - self, freshness: FreshnessAssertion + self, + freshness: FreshnessAssertion, ) -> Tuple[str, List[MetadataChangeProposalWrapper]]: guid_dict = { "contract": self.urn, @@ -102,7 +104,8 @@ def _generate_freshness_assertion( ) def _generate_schema_assertion( - self, schema_metadata: SchemaAssertion + self, + schema_metadata: SchemaAssertion, ) -> Tuple[str, List[MetadataChangeProposalWrapper]]: # ingredients for guid -> the contract id, the fact that this is a schema assertion and the entity on which the assertion is made guid_dict = { @@ -118,7 +121,8 @@ def _generate_schema_assertion( ) def _generate_data_quality_assertion( - self, data_quality: DataQualityAssertion + self, + data_quality: DataQualityAssertion, ) -> Tuple[str, List[MetadataChangeProposalWrapper]]: guid_dict = { "contract": self.urn, @@ -133,14 +137,15 @@ def _generate_data_quality_assertion( ) def _generate_dq_assertions( - self, data_quality_spec: List[DataQualityAssertion] + self, + data_quality_spec: List[DataQualityAssertion], ) -> Tuple[List[str], List[MetadataChangeProposalWrapper]]: assertion_urns = [] assertion_mcps = [] for dq_check in data_quality_spec: assertion_urn, assertion_mcp = self._generate_data_quality_assertion( - dq_check + dq_check, ) assertion_urns.append(assertion_urn) @@ -168,7 +173,7 @@ def generate_mcp( yield from sla_assertion_mcps dq_assertions, dq_assertion_mcps = self._generate_dq_assertions( - self.data_quality or [] + self.data_quality or [], ) yield from dq_assertion_mcps diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py index 86942766889676..4159a7dd218ec7 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/freshness_assertion.py @@ -25,7 +25,7 @@ class CronFreshnessAssertion(BaseAssertion): type: Literal["cron"] cron: str = v1_Field( - description="The cron expression to use. See https://crontab.guru/ for help." + description="The cron expression to use. 
See https://crontab.guru/ for help.", ) timezone: str = v1_Field( "UTC", @@ -59,7 +59,7 @@ def generate_freshness_assertion_schedule(self) -> FreshnessAssertionScheduleCla class FreshnessAssertion(v1_ConfigModel): __root__: Union[CronFreshnessAssertion, FixedIntervalFreshnessAssertion] = v1_Field( - discriminator="type" + discriminator="type", ) @property @@ -67,7 +67,9 @@ def id(self): return self.__root__.type def generate_mcp( - self, assertion_urn: str, entity_urn: str + self, + assertion_urn: str, + entity_urn: str, ) -> List[MetadataChangeProposalWrapper]: aspect = AssertionInfoClass( type=AssertionTypeClass.FRESHNESS, diff --git a/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py b/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py index 39297d1a98d026..dde67049d3ed58 100644 --- a/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py +++ b/metadata-ingestion/src/datahub/api/entities/datacontract/schema_assertion.py @@ -60,7 +60,7 @@ def _init_private_attributes(self) -> None: class SchemaAssertion(v1_ConfigModel): __root__: Union[JsonSchemaContract, FieldListSchemaContract] = v1_Field( - discriminator="type" + discriminator="type", ) @property @@ -68,7 +68,9 @@ def id(self): return self.__root__.type def generate_mcp( - self, assertion_urn: str, entity_urn: str + self, + assertion_urn: str, + entity_urn: str, ) -> List[MetadataChangeProposalWrapper]: aspect = AssertionInfoClass( type=AssertionTypeClass.DATA_SCHEMA, diff --git a/metadata-ingestion/src/datahub/api/entities/datajob/dataflow.py b/metadata-ingestion/src/datahub/api/entities/datajob/dataflow.py index e169c07445e969..4696f1c21ce70f 100644 --- a/metadata-ingestion/src/datahub/api/entities/datajob/dataflow.py +++ b/metadata-ingestion/src/datahub/api/entities/datajob/dataflow.py @@ -60,7 +60,7 @@ class DataFlow: def __post_init__(self): if self.env is not None and self.cluster is not None: raise ValueError( - "Cannot provide both env and cluster parameter. Cluster is deprecated in favor of env." + "Cannot provide both env and cluster parameter. Cluster is deprecated in favor of env.", ) if self.env is None and self.cluster is None: @@ -68,7 +68,7 @@ def __post_init__(self): if self.cluster is not None: logger.warning( - "The cluster argument is deprecated. Use env and possibly platform_instance instead." + "The cluster argument is deprecated. Use env and possibly platform_instance instead.", ) self.env = self.cluster @@ -97,7 +97,8 @@ def generate_ownership_aspect(self): for urn in (owners or []) ], lastModified=AuditStampClass( - time=0, actor=builder.make_user_urn(self.orchestrator) + time=0, + actor=builder.make_user_urn(self.orchestrator), ), ) return [ownership] @@ -107,7 +108,7 @@ def generate_tags_aspect(self) -> List[GlobalTagsClass]: tags=[ TagAssociationClass(tag=builder.make_tag_urn(tag)) for tag in (sorted(self.tags) or []) - ] + ], ) return [tags] @@ -117,7 +118,7 @@ def _get_env(self) -> Optional[str]: env = self.env.upper() else: logger.debug( - f"{self.env} is not a valid environment type so Environment filter won't work." 
+ f"{self.env} is not a valid environment type so Environment filter won't work.", ) return env @@ -137,7 +138,7 @@ def generate_mce(self) -> MetadataChangeEventClass: *self.generate_ownership_aspect(), *self.generate_tags_aspect(), ], - ) + ), ) return flow_mce diff --git a/metadata-ingestion/src/datahub/api/entities/datajob/datajob.py b/metadata-ingestion/src/datahub/api/entities/datajob/datajob.py index 4958a68caa95fe..f998a89d3c260a 100644 --- a/metadata-ingestion/src/datahub/api/entities/datajob/datajob.py +++ b/metadata-ingestion/src/datahub/api/entities/datajob/datajob.py @@ -69,7 +69,8 @@ def __post_init__(self): flow_id=self.flow_urn.flow_id, ) self.urn = DataJobUrn.create_from_ids( - data_flow_urn=str(job_flow_urn), job_id=self.id + data_flow_urn=str(job_flow_urn), + job_id=self.id, ) def generate_ownership_aspect(self) -> Iterable[OwnershipClass]: @@ -100,19 +101,20 @@ def generate_tags_aspect(self) -> Iterable[GlobalTagsClass]: tags=[ TagAssociationClass(tag=builder.make_tag_urn(tag)) for tag in (sorted(self.tags) or []) - ] + ], ) return [tags] def generate_mcp( - self, materialize_iolets: bool = True + self, + materialize_iolets: bool = True, ) -> Iterable[MetadataChangeProposalWrapper]: env: Optional[str] = None if self.flow_urn.cluster.upper() in builder.ALL_ENV_TYPES: env = self.flow_urn.cluster.upper() else: logger.debug( - f"cluster {self.flow_urn.cluster} is not a valid environment type so Environment filter won't work." + f"cluster {self.flow_urn.cluster} is not a valid environment type so Environment filter won't work.", ) mcp = MetadataChangeProposalWrapper( entityUrn=str(self.urn), @@ -136,7 +138,7 @@ def generate_mcp( yield mcp yield from self.generate_data_input_output_mcp( - materialize_iolets=materialize_iolets + materialize_iolets=materialize_iolets, ) for owner in self.generate_ownership_aspect(): @@ -169,7 +171,8 @@ def emit( emitter.emit(mcp, callback) def generate_data_input_output_mcp( - self, materialize_iolets: bool + self, + materialize_iolets: bool, ) -> Iterable[MetadataChangeProposalWrapper]: mcp = MetadataChangeProposalWrapper( entityUrn=str(self.urn), diff --git a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py index d406fa36e00db6..02ab81cc6f75cb 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py +++ b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py @@ -69,7 +69,9 @@ class DataProcessInstance: outlets: List[DatasetUrn] = field(default_factory=list) upstream_urns: List[DataProcessInstanceUrn] = field(default_factory=list) _template_object: Optional[Union[DataJob, DataFlow]] = field( - init=False, default=None, repr=False + init=False, + default=None, + repr=False, ) def __post_init__(self): @@ -78,11 +80,13 @@ def __post_init__(self): cluster=self.cluster, orchestrator=self.orchestrator, id=self.id, - ).guid() + ).guid(), ) def start_event_mcp( - self, start_timestamp_millis: int, attempt: Optional[int] = None + self, + start_timestamp_millis: int, + attempt: Optional[int] = None, ) -> Iterable[MetadataChangeProposalWrapper]: """ @@ -143,14 +147,14 @@ def emit_process_start( self._emit_mcp(mcp, emitter, callback) else: raise Exception( - f"Invalid urn type {self.template_urn.__class__.__name__}" + f"Invalid urn type {self.template_urn.__class__.__name__}", ) for upstream in self.upstream_urns: input_datajob_urns.append( DataJobUrn.create_from_ids( 
job_id=upstream.get_dataprocessinstance_id(), data_flow_urn=str(job_flow_urn), - ) + ), ) else: template_object = self._template_object @@ -236,7 +240,9 @@ def emit_process_end( self._emit_mcp(mcp, emitter, callback) def generate_mcp( - self, created_ts_millis: Optional[int], materialize_iolets: bool + self, + created_ts_millis: Optional[int], + materialize_iolets: bool, ) -> Iterable[MetadataChangeProposalWrapper]: """Generates mcps from the object""" @@ -350,13 +356,14 @@ def from_dataflow(dataflow: DataFlow, id: str) -> "DataProcessInstance": return dpi def generate_inlet_outlet_mcp( - self, materialize_iolets: bool + self, + materialize_iolets: bool, ) -> Iterable[MetadataChangeProposalWrapper]: if self.inlets: mcp = MetadataChangeProposalWrapper( entityUrn=str(self.urn), aspect=DataProcessInstanceInput( - inputs=[str(urn) for urn in self.inlets] + inputs=[str(urn) for urn in self.inlets], ), ) yield mcp @@ -365,7 +372,7 @@ def generate_inlet_outlet_mcp( mcp = MetadataChangeProposalWrapper( entityUrn=str(self.urn), aspect=DataProcessInstanceOutput( - outputs=[str(urn) for urn in self.outlets] + outputs=[str(urn) for urn in self.outlets], ), ) yield mcp diff --git a/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py b/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py index d2035d560716ae..cc4ffa15171efd 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py +++ b/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py @@ -155,7 +155,7 @@ def _mint_owner(self, owner: Union[str, Ownership]) -> OwnerClass: else: assert isinstance(owner, Ownership) ownership_type, ownership_type_urn = builder.validate_ownership_type( - owner.type + owner.type, ) return OwnerClass( owner=builder.make_user_urn(owner.id), @@ -164,7 +164,8 @@ def _mint_owner(self, owner: Union[str, Ownership]) -> OwnerClass: ) def _generate_properties_mcp( - self, upsert_mode: bool = False + self, + upsert_mode: bool = False, ) -> Iterable[Union[MetadataChangeProposalWrapper, MetadataChangeProposalClass]]: if upsert_mode: dataproduct_patch_builder = DataProductPatchBuilder(self.urn) @@ -182,14 +183,14 @@ def _generate_properties_mcp( created=self._mint_auditstamp("yaml"), ) for asset in self.assets - ] + ], ) if self.properties is not None: dataproduct_patch_builder.set_custom_properties(self.properties) if self.external_url is not None: dataproduct_patch_builder.set_external_url( - external_url=self.external_url + external_url=self.external_url, ) yield from dataproduct_patch_builder.build() @@ -213,11 +214,12 @@ def _generate_properties_mcp( yield mcp def generate_mcp( - self, upsert: bool + self, + upsert: bool, ) -> Iterable[Union[MetadataChangeProposalWrapper, MetadataChangeProposalClass]]: if self._resolved_domain_urn is None: raise Exception( - f"Unable to generate MCP-s because we were unable to resolve the domain {self.domain} to an urn." 
+ f"Unable to generate MCP-s because we were unable to resolve the domain {self.domain} to an urn.", ) yield from self._generate_properties_mcp(upsert_mode=upsert) @@ -225,7 +227,7 @@ def generate_mcp( mcp = MetadataChangeProposalWrapper( entityUrn=self.urn, aspect=DomainsClass( - domains=[builder.make_domain_urn(self._resolved_domain_urn)] + domains=[builder.make_domain_urn(self._resolved_domain_urn)], ), ) yield mcp @@ -237,7 +239,7 @@ def generate_mcp( tags=[ TagAssociationClass(tag=builder.make_tag_urn(tag)) for tag in self.tags - ] + ], ), ) yield mcp @@ -259,7 +261,7 @@ def generate_mcp( mcp = MetadataChangeProposalWrapper( entityUrn=self.urn, aspect=OwnershipClass( - owners=[self._mint_owner(o) for o in self.owners] + owners=[self._mint_owner(o) for o in self.owners], ), ) yield mcp @@ -275,14 +277,15 @@ def generate_mcp( createStamp=self._mint_auditstamp("yaml"), ) for element in self.institutional_memory.elements - ] + ], ), ) yield mcp # Finally emit status yield MetadataChangeProposalWrapper( - entityUrn=self.urn, aspect=StatusClass(removed=False) + entityUrn=self.urn, + aspect=StatusClass(removed=False), ) def emit( @@ -312,7 +315,8 @@ def from_yaml( parsed_data_product = DataProduct.parse_obj_allow_extras(orig_dictionary) # resolve domains if needed domain_registry: DomainRegistry = DomainRegistry( - cached_domains=[parsed_data_product.domain], graph=graph + cached_domains=[parsed_data_product.domain], + graph=graph, ) domain_urn = domain_registry.get_domain_urn(parsed_data_product.domain) parsed_data_product._resolved_domain_urn = domain_urn @@ -338,7 +342,8 @@ def from_datahub(cls, graph: DataHubGraph, id: str) -> DataProduct: else: yaml_owners.append(Ownership(id=o.owner, type=str(o.type))) glossary_terms: Optional[GlossaryTermsClass] = graph.get_aspect( - id, GlossaryTermsClass + id, + GlossaryTermsClass, ) tags: Optional[GlobalTagsClass] = graph.get_aspect(id, GlobalTagsClass) return DataProduct( @@ -425,7 +430,7 @@ def _patch_ownership( patches_add.append(new_owner) else: patches_add.append( - Ownership(id=new_owner, type=new_owner_type).dict() + Ownership(id=new_owner, type=new_owner_type).dict(), ) mutation_needed = bool(patches_replace or patches_drop or patches_add) @@ -460,7 +465,7 @@ def patch_yaml( this_dataproduct_dict = self.dict() for simple_field in ["display_name", "description", "external_url"]: if original_dataproduct_dict.get(simple_field) != this_dataproduct_dict.get( - simple_field + simple_field, ): update_needed = True orig_dictionary[simple_field] = this_dataproduct_dict.get(simple_field) @@ -488,14 +493,21 @@ def patch_yaml( orig_dictionary["assets"].append(asset_to_add) update_needed = update_needed or patch_list( - original_dataproduct.terms, self.terms, orig_dictionary, "terms" + original_dataproduct.terms, + self.terms, + orig_dictionary, + "terms", ) update_needed = update_needed or patch_list( - original_dataproduct.tags, self.tags, orig_dictionary, "tags" + original_dataproduct.tags, + self.tags, + orig_dictionary, + "tags", ) (ownership_update_needed, new_ownership_list) = self._patch_ownership( - original_dataproduct.owners, orig_dictionary.get("owners") + original_dataproduct.owners, + orig_dictionary.get("owners"), ) if ownership_update_needed: update_needed = True @@ -507,7 +519,7 @@ def patch_yaml( orig_dictionary["owners"] = None if this_dataproduct_dict.get("properties") != original_dataproduct_dict.get( - "properties" + "properties", ): update_needed = True if self.properties is not None and original_dataproduct.properties is None: @@ 
-546,7 +558,9 @@ def get_patch_builder( audit_header: Optional[KafkaAuditHeaderClass] = None, ) -> DataProductPatchBuilder: return DataProductPatchBuilder( - urn=urn, system_metadata=system_metadata, audit_header=audit_header + urn=urn, + system_metadata=system_metadata, + audit_header=audit_header, ) diff --git a/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py b/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py index bf824a11a77b5d..24d6f9b12303d1 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py +++ b/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py @@ -81,12 +81,15 @@ def with_structured_properties( @classmethod def from_schema_field( - cls, schema_field: SchemaFieldClass, parent_urn: str + cls, + schema_field: SchemaFieldClass, + parent_urn: str, ) -> "SchemaFieldSpecification": return SchemaFieldSpecification( id=Dataset._simplify_field_path(schema_field.fieldPath), urn=make_schema_field_urn( - parent_urn, Dataset._simplify_field_path(schema_field.fieldPath) + parent_urn, + Dataset._simplify_field_path(schema_field.fieldPath), ), type=str(schema_field.type), nativeDataType=schema_field.nativeDataType, @@ -259,7 +262,8 @@ def generate_mcp( fields=avro_schema_to_mce_fields(schema_string), ) mcp = MetadataChangeProposalWrapper( - entityUrn=self.urn, aspect=schema_metadata + entityUrn=self.urn, + aspect=schema_metadata, ) yield mcp @@ -278,7 +282,7 @@ def generate_mcp( tags=[ TagAssociationClass(tag=make_tag_urn(tag)) for tag in field.globalTags - ] + ], ), ) yield mcp @@ -289,7 +293,7 @@ def generate_mcp( aspect=GlossaryTermsClass( terms=[ GlossaryTermAssociationClass( - urn=make_term_urn(term) + urn=make_term_urn(term), ) for term in field.glossaryTerms ], @@ -312,7 +316,7 @@ def generate_mcp( ), ) for prop_key, prop_value in field.structured_properties.items() - ] + ], ), ) yield mcp @@ -321,7 +325,7 @@ def generate_mcp( mcp = MetadataChangeProposalWrapper( entityUrn=self.urn, aspect=SubTypesClass( - typeNames=[s for s in [self.subtype] + (self.subtypes or []) if s] + typeNames=[s for s in [self.subtype] + (self.subtypes or []) if s], ), ) yield mcp @@ -332,7 +336,7 @@ def generate_mcp( aspect=GlobalTagsClass( tags=[ TagAssociationClass(tag=make_tag_urn(tag)) for tag in self.tags - ] + ], ), ) yield mcp @@ -354,7 +358,7 @@ def generate_mcp( mcp = MetadataChangeProposalWrapper( entityUrn=self.urn, aspect=OwnershipClass( - owners=[self._mint_owner(o) for o in self.owners] + owners=[self._mint_owner(o) for o in self.owners], ), ) yield mcp @@ -373,7 +377,7 @@ def generate_mcp( ), ) for prop_key, prop_value in self.structured_properties.items() - ] + ], ), ) yield mcp @@ -386,7 +390,7 @@ def generate_mcp( UpstreamClass( dataset=self.urn, type="COPY", - ) + ), ) yield from patch_builder.build() @@ -394,7 +398,9 @@ def generate_mcp( @staticmethod def validate_type( - sp_name: str, sp_value: Union[str, float], allowed_type: str + sp_name: str, + sp_value: Union[str, float], + allowed_type: str, ) -> Tuple[str, Union[str, float]]: if allowed_type == AllowedTypes.NUMBER.value: return (sp_name, float(sp_value)) @@ -427,24 +433,27 @@ def _simplify_field_path(field_path: str) -> str: @staticmethod def _schema_from_schema_metadata( - graph: DataHubGraph, urn: str + graph: DataHubGraph, + urn: str, ) -> Optional[SchemaSpecification]: schema_metadata: Optional[SchemaMetadataClass] = graph.get_aspect( - urn, SchemaMetadataClass + urn, + SchemaMetadataClass, ) if schema_metadata: schema_specification = SchemaSpecification( fields=[ 
SchemaFieldSpecification.from_schema_field( - field, urn + field, + urn, ).with_structured_properties( { sp.propertyUrn: sp.values for sp in structured_props.properties } if structured_props - else None + else None, ) for field, structured_props in [ ( @@ -455,14 +464,15 @@ def _schema_from_schema_metadata( ) or graph.get_aspect( make_schema_field_urn( - urn, Dataset._simplify_field_path(field.fieldPath) + urn, + Dataset._simplify_field_path(field.fieldPath), ), StructuredPropertiesClass, ), ) for field in schema_metadata.fields ] - ] + ], ) return schema_specification else: @@ -487,17 +497,20 @@ def extract_owners_if_exists( @classmethod def from_datahub(cls, graph: DataHubGraph, urn: str) -> "Dataset": dataset_properties: Optional[DatasetPropertiesClass] = graph.get_aspect( - urn, DatasetPropertiesClass + urn, + DatasetPropertiesClass, ) subtypes: Optional[SubTypesClass] = graph.get_aspect(urn, SubTypesClass) tags: Optional[GlobalTagsClass] = graph.get_aspect(urn, GlobalTagsClass) glossary_terms: Optional[GlossaryTermsClass] = graph.get_aspect( - urn, GlossaryTermsClass + urn, + GlossaryTermsClass, ) owners: Optional[OwnershipClass] = graph.get_aspect(urn, OwnershipClass) yaml_owners = Dataset.extract_owners_if_exists(owners) structured_properties: Optional[StructuredPropertiesClass] = graph.get_aspect( - urn, StructuredPropertiesClass + urn, + StructuredPropertiesClass, ) if structured_properties: structured_properties_map: Dict[str, List[Union[str, float]]] = {} diff --git a/metadata-ingestion/src/datahub/api/entities/forms/forms.py b/metadata-ingestion/src/datahub/api/entities/forms/forms.py index fa6453851e0e74..0b299a59849e0e 100644 --- a/metadata-ingestion/src/datahub/api/entities/forms/forms.py +++ b/metadata-ingestion/src/datahub/api/entities/forms/forms.py @@ -142,7 +142,7 @@ def create(file: str) -> None: try: if not FormType.has_value(form.type): logger.error( - f"Form type {form.type} does not exist. Please try again with a valid type." + f"Form type {form.type} does not exist. Please try again with a valid type.", ) mcp = MetadataChangeProposalWrapper( @@ -182,7 +182,7 @@ def validate_prompts(self, emitter: DataHubGraph) -> List[FormPromptClass]: if not prompt.id: prompt.id = str(uuid.uuid4()) logger.warning( - f"Prompt id not provided. Setting prompt id to {prompt.id}" + f"Prompt id not provided. Setting prompt id to {prompt.id}", ) if prompt.structured_property_urn: structured_property_urn = prompt.structured_property_urn @@ -190,7 +190,7 @@ def validate_prompts(self, emitter: DataHubGraph) -> List[FormPromptClass]: prompt.structured_property_urn = structured_property_urn else: raise Exception( - f"Structured property {structured_property_urn} does not exist. Unable to create form." + f"Structured property {structured_property_urn} does not exist. Unable to create form.", ) elif ( prompt.type @@ -201,14 +201,14 @@ def validate_prompts(self, emitter: DataHubGraph) -> List[FormPromptClass]: and not prompt.structured_property_urn ): raise Exception( - f"Prompt type is {prompt.type} but no structured properties exist. Unable to create form." + f"Prompt type is {prompt.type} but no structured properties exist. Unable to create form.", ) if ( prompt.type == PromptType.FIELDS_STRUCTURED_PROPERTY.value and prompt.required ): raise Exception( - "Schema field prompts cannot be marked as required. Ensure these prompts are not required." + "Schema field prompts cannot be marked as required. 
Ensure these prompts are not required.", ) prompts.append( @@ -219,13 +219,13 @@ def validate_prompts(self, emitter: DataHubGraph) -> List[FormPromptClass]: type=prompt.type, structuredPropertyParams=( StructuredPropertyParamsClass( - urn=prompt.structured_property_urn + urn=prompt.structured_property_urn, ) if prompt.structured_property_urn else None ), required=prompt.required, - ) + ), ) else: logger.warning(f"No prompts exist on form {self.urn}. Is that intended?") @@ -233,7 +233,8 @@ def validate_prompts(self, emitter: DataHubGraph) -> List[FormPromptClass]: return prompts def create_form_actors( - self, actors: Optional[Actors] = None + self, + actors: Optional[Actors] = None, ) -> Union[None, FormActorAssignmentClass]: if actors is None: return None @@ -247,16 +248,19 @@ def create_form_actors( groups = Forms.format_groups(actors.groups) return FormActorAssignmentClass( - owners=actors.owners, users=users, groups=groups + owners=actors.owners, + users=users, + groups=groups, ) def upload_entities_for_form(self, emitter: DataHubGraph) -> Union[None, Exception]: if self.entities and self.entities.urns: formatted_entity_urns = ", ".join( - [f'"{value}"' for value in self.entities.urns] + [f'"{value}"' for value in self.entities.urns], ) query = UPLOAD_ENTITIES_FOR_FORMS.format( - form_urn=self.urn, entity_urns=formatted_entity_urns + form_urn=self.urn, + entity_urns=formatted_entity_urns, ) result = emitter.execute_graphql(query=query) if not result: @@ -272,14 +276,15 @@ def create_form_filters(self, emitter: DataHubGraph) -> Union[None, Exception]: if filters.types: filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_TYPES, filters.types) + Forms.format_form_filter(FILTER_CRITERION_TYPES, filters.types), ) if filters.sub_types: filters_raw.append( Forms.format_form_filter( - FILTER_CRITERION_SUB_TYPES, filters.sub_types - ) + FILTER_CRITERION_SUB_TYPES, + filters.sub_types, + ), ) if filters.platforms: @@ -287,25 +292,25 @@ def create_form_filters(self, emitter: DataHubGraph) -> Union[None, Exception]: make_data_platform_urn(platform) for platform in filters.platforms ] filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_PLATFORMS, urns) + Forms.format_form_filter(FILTER_CRITERION_PLATFORMS, urns), ) if filters.platform_instances: urns = [] for platform_instance in filters.platform_instances: platform_instance_urn = Forms.validate_platform_instance_urn( - platform_instance + platform_instance, ) if platform_instance_urn: urns.append(platform_instance_urn) filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_PLATFORM_INSTANCES, urns) + Forms.format_form_filter(FILTER_CRITERION_PLATFORM_INSTANCES, urns), ) if filters.domains: urns = [make_domain_urn(domain) for domain in filters.domains] filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_DOMAINS, urns) + Forms.format_form_filter(FILTER_CRITERION_DOMAINS, urns), ) if filters.containers: @@ -313,37 +318,38 @@ def create_form_filters(self, emitter: DataHubGraph) -> Union[None, Exception]: make_container_urn(container) for container in filters.containers ] filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_CONTAINERS, urns) + Forms.format_form_filter(FILTER_CRITERION_CONTAINERS, urns), ) if filters.owners: urns = [make_user_urn(owner) for owner in filters.owners] filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_OWNERS, urns) + Forms.format_form_filter(FILTER_CRITERION_OWNERS, urns), ) if filters.tags: urns = [make_tag_urn(tag) for tag in filters.tags] 
filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_TAGS, urns) + Forms.format_form_filter(FILTER_CRITERION_TAGS, urns), ) if filters.terms: urns = [make_term_urn(term) for term in filters.terms] filters_raw.append( - Forms.format_form_filter(FILTER_CRITERION_GLOSSARY_TERMS, urns) + Forms.format_form_filter(FILTER_CRITERION_GLOSSARY_TERMS, urns), ) filters_str = ", ".join(item for item in filters_raw) result = emitter.execute_graphql( query=CREATE_DYNAMIC_FORM_ASSIGNMENT.format( - form_urn=self.urn, filters=filters_str - ) + form_urn=self.urn, + filters=filters_str, + ), ) if not result: return Exception( - f"Could not bulk upload urns or filters for form {self.urn}." + f"Could not bulk upload urns or filters for form {self.urn}.", ) return None @@ -381,7 +387,7 @@ def validate_platform_instance_urn(instance: str) -> Union[str, None]: return instance logger.warning( - f"{instance} is not an urn. Unable to create platform instance filter." + f"{instance} is not an urn. Unable to create platform instance filter.", ) return None @@ -403,7 +409,7 @@ def from_datahub(graph: DataHubGraph, urn: str) -> "Forms": if prompt_raw.structuredPropertyParams else None ), - ) + ), ) return Forms( urn=urn, diff --git a/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py b/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py index 0ba43d7b101e5f..7915db0509bd30 100644 --- a/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py +++ b/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py @@ -50,12 +50,13 @@ class PlatformResourceInfo(BaseModel): @classmethod def from_resource_info( - cls, resource_info: models.PlatformResourceInfoClass + cls, + resource_info: models.PlatformResourceInfoClass, ) -> "PlatformResourceInfo": serialized_value: Optional[SerializedResourceValue] = None if resource_info.value: serialized_value = SerializedResourceValue.from_resource_value( - resource_info.value + resource_info.value, ) return cls( primary_key=resource_info.primaryKey, @@ -101,7 +102,8 @@ def __init__(self, field_name: str): @classmethod def from_search_field( - cls, search_field: SearchField + cls, + search_field: SearchField, ) -> "PlatformResourceSearchField": # pretends to be a class method, but just returns the input return search_field # type: ignore @@ -115,13 +117,13 @@ class PlatformResourceSearchFields: UrnSearchField( field_name="platform.keyword", urn_value_extractor=DataPlatformUrn.create_from_id, - ) + ), ) PLATFORM_INSTANCE = PlatformResourceSearchField.from_search_field( UrnSearchField( field_name="platformInstance.keyword", urn_value_extractor=DataPlatformInstanceUrn.from_string, - ) + ), ) @@ -179,7 +181,8 @@ def create( platform=make_data_platform_urn(key.platform), platform_instance=( make_dataplatform_instance_urn( - platform=key.platform, instance=key.platform_instance + platform=key.platform, + instance=key.platform_instance, ) if key.platform_instance else None @@ -208,7 +211,9 @@ def to_datahub(self, graph_client: DataHubGraph) -> None: @classmethod def from_datahub( - cls, graph_client: DataHubGraph, key: Union[PlatformResourceKey, str] + cls, + graph_client: DataHubGraph, + key: Union[PlatformResourceKey, str], ) -> Optional["PlatformResource"]: """ Fetches a PlatformResource from the graph given a key. 
@@ -226,14 +231,14 @@ def from_datahub( id=urn.id, resource_info=( PlatformResourceInfo.from_resource_info( - platform_resource["platformResourceInfo"] + platform_resource["platformResourceInfo"], ) if "platformResourceInfo" in platform_resource else None ), data_platform_instance=( DataPlatformInstance.from_data_platform_instance( - platform_resource["dataPlatformInstance"] + platform_resource["dataPlatformInstance"], ) if "dataPlatformInstance" in platform_resource else None @@ -267,12 +272,16 @@ def search_by_key( ElasticPlatformResourceQuery.create_from() .group(LogicalOperator.OR) .add_field_match( - PlatformResourceSearchFields.PRIMARY_KEY, key, is_exact=is_exact + PlatformResourceSearchFields.PRIMARY_KEY, + key, + is_exact=is_exact, ) ) if not primary: # we expand the search to secondary keys elastic_platform_resource_group.add_field_match( - PlatformResourceSearchFields.SECONDARY_KEYS, key, is_exact=is_exact + PlatformResourceSearchFields.SECONDARY_KEYS, + key, + is_exact=is_exact, ) query = elastic_platform_resource_group.end() openapi_client = OpenAPIGraphClient(graph_client) diff --git a/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py b/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py index b0b434751ad2cc..beb1dabc4e583c 100644 --- a/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py +++ b/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py @@ -52,7 +52,7 @@ def _validate_entity_type_urn(v: str) -> str: urn = Urn.make_entity_type_urn(v) if urn not in VALID_ENTITY_TYPE_URNS: raise ValueError( - f"Input {v} is not a valid entity type urn. {_VALID_ENTITY_TYPES_STRING}" + f"Input {v} is not a valid entity type urn. {_VALID_ENTITY_TYPES_STRING}", ) v = str(urn) return v @@ -62,7 +62,7 @@ class TypeQualifierAllowedTypes(ConfigModel): allowed_types: List[str] _check_allowed_types = validator("allowed_types", each_item=True, allow_reuse=True)( - _validate_entity_type_urn + _validate_entity_type_urn, ) @@ -81,7 +81,7 @@ class StructuredProperties(ConfigModel): immutable: Optional[bool] = False _check_entity_types = validator("entity_types", each_item=True, allow_reuse=True)( - _validate_entity_type_urn + _validate_entity_type_urn, ) @validator("type") @@ -95,7 +95,7 @@ def validate_type(cls, v: str) -> str: # Convert to lowercase if needed v = v.lower() logger.warning( - f"Structured property type should be lowercase. Updated to {v}" + f"Structured property type should be lowercase. Updated to {v}", ) urn = Urn.make_data_type_urn(v) @@ -107,7 +107,7 @@ def validate_type(cls, v: str) -> str: unqualified_data_type = unqualified_data_type[len("datahub.") :] if not AllowedTypes.check_allowed_type(unqualified_data_type): raise ValueError( - f"Type {unqualified_data_type} is not allowed. Allowed types are {AllowedTypes.values()}" + f"Type {unqualified_data_type} is not allowed. Allowed types are {AllowedTypes.values()}", ) return urn @@ -189,7 +189,7 @@ def from_datahub(cls, graph: DataHubGraph, urn: str) -> "StructuredProperties": ) if structured_property is None: raise Exception( - "StructuredPropertyDefinition aspect is None. Unable to create structured property." + "StructuredPropertyDefinition aspect is None. 
Unable to create structured property.", ) return StructuredProperties( urn=urn, diff --git a/metadata-ingestion/src/datahub/api/graphql/base.py b/metadata-ingestion/src/datahub/api/graphql/base.py index c1ea6b71a6d145..2d8730c00cb535 100644 --- a/metadata-ingestion/src/datahub/api/graphql/base.py +++ b/metadata-ingestion/src/datahub/api/graphql/base.py @@ -38,7 +38,8 @@ def __init__( ) def gen_filter( - self, filters: Dict[str, Optional[str]] + self, + filters: Dict[str, Optional[str]], ) -> Optional[Dict[str, List[Dict[str, str]]]]: filter_expression: Optional[Dict[str, List[Dict[str, str]]]] = None if not filters: diff --git a/metadata-ingestion/src/datahub/api/graphql/operation.py b/metadata-ingestion/src/datahub/api/graphql/operation.py index 9cb40ce5815a56..dba740ba7c9ff0 100644 --- a/metadata-ingestion/src/datahub/api/graphql/operation.py +++ b/metadata-ingestion/src/datahub/api/graphql/operation.py @@ -80,7 +80,8 @@ def report_operation( variable_values["customProperties"] = custom_properties result = self.client.execute( - gql(Operation.REPORT_OPERATION_MUTATION), variable_values + gql(Operation.REPORT_OPERATION_MUTATION), + variable_values, ) return result["reportOperation"] @@ -121,7 +122,7 @@ def query_operations( "sourceType": source_type, "operationType": operation_type, "partition": partition, - } + }, ), }, ) diff --git a/metadata-ingestion/src/datahub/cli/check_cli.py b/metadata-ingestion/src/datahub/cli/check_cli.py index fbe07b64f0e154..277f0f41219056 100644 --- a/metadata-ingestion/src/datahub/cli/check_cli.py +++ b/metadata-ingestion/src/datahub/cli/check_cli.py @@ -40,7 +40,10 @@ def check() -> None: help="Rewrite the JSON file to it's canonical form.", ) @click.option( - "--unpack-mces", default=False, is_flag=True, help="Converts MCEs into MCPs" + "--unpack-mces", + default=False, + is_flag=True, + help="Converts MCEs into MCPs", ) @telemetry.with_telemetry() def metadata_file(json_file: str, rewrite: bool, unpack_mces: bool) -> None: @@ -102,7 +105,10 @@ def metadata_file(json_file: str, rewrite: bool, unpack_mces: bool) -> None: ) @telemetry.with_telemetry() def metadata_diff( - actual_file: str, expected_file: str, verbose: bool, ignore_path: List[str] + actual_file: str, + expected_file: str, + verbose: bool, + ignore_path: List[str], ) -> None: """Compare two metadata (MCE or MCP) JSON files. 
@@ -159,7 +165,7 @@ def plugins(source: Optional[str], verbose: bool) -> None: if not verbose: click.echo("For details on why a plugin is disabled, rerun with '--verbose'") click.echo( - f"If a plugin is disabled, try running: pip install '{__package_name__}[]'" + f"If a plugin is disabled, try running: pip install '{__package_name__}[]'", ) @@ -386,7 +392,7 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None: except Exception as e: logger.error( - f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}" + f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}", ) raise e @@ -403,7 +409,8 @@ def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None: shared_connection = ConnectionWrapper(pathlib.Path(query_log_file)) query_log = FileBackedList[LoggedQuery]( - shared_connection=shared_connection, tablename="stored_queries" + shared_connection=shared_connection, + tablename="stored_queries", ) logger.info(f"Extracting {len(query_log)} queries from {query_log_file}") queries = [dataclasses.asdict(query) for query in query_log] diff --git a/metadata-ingestion/src/datahub/cli/cli_utils.py b/metadata-ingestion/src/datahub/cli/cli_utils.py index 1f13391644c6c8..a04eceeca50e28 100644 --- a/metadata-ingestion/src/datahub/cli/cli_utils.py +++ b/metadata-ingestion/src/datahub/cli/cli_utils.py @@ -57,7 +57,7 @@ def parse_run_restli_response(response: requests.Response) -> dict: if not isinstance(response_json, dict): click.echo( - f"Received error, please check your {config_utils.CONDENSED_DATAHUB_CONFIG_PATH}" + f"Received error, please check your {config_utils.CONDENSED_DATAHUB_CONFIG_PATH}", ) click.echo() click.echo(response_json) @@ -66,7 +66,7 @@ def parse_run_restli_response(response: requests.Response) -> dict: summary = response_json.get("value") if not isinstance(summary, dict): click.echo( - f"Received error, please check your {config_utils.CONDENSED_DATAHUB_CONFIG_PATH}" + f"Received error, please check your {config_utils.CONDENSED_DATAHUB_CONFIG_PATH}", ) click.echo() click.echo(response_json) @@ -82,7 +82,7 @@ def format_aspect_summaries(summaries: list) -> typing.List[typing.List[str]]: row.get("urn"), row.get("aspectName"), datetime.fromtimestamp(row.get("timestamp") / 1000).strftime( - "%Y-%m-%d %H:%M:%S" + "%Y-%m-%d %H:%M:%S", ) + f" ({local_timezone})", ] @@ -109,7 +109,7 @@ def post_rollback_endpoint( unsafe_entity_count = summary.get("unsafeEntitiesCount", 0) unsafe_entities = summary.get("unsafeEntities", []) rolled_back_aspects = list( - filter(lambda row: row["runId"] == payload_obj["runId"], rows) + filter(lambda row: row["runId"] == payload_obj["runId"], rows), ) if len(rows) == 0: @@ -140,7 +140,7 @@ def get_entity( encoded_urn = Urn.url_encode(urn) else: raise Exception( - f"urn {urn} does not seem to be a valid raw (starts with urn:) or encoded urn (starts with urn%3A)" + f"urn {urn} does not seem to be a valid raw (starts with urn:) or encoded urn (starts with urn%3A)", ) # TODO: Replace with DataHubGraph.get_entity_raw. 
@@ -183,7 +183,7 @@ def post_entity( if system_metadata is not None: proposal["proposal"]["systemMetadata"] = json.dumps( - pre_json_transform(system_metadata.to_obj()) + pre_json_transform(system_metadata.to_obj()), ) payload = json.dumps(proposal) @@ -242,7 +242,11 @@ def get_aspects_for_entity( # Process non-timeseries aspects non_timeseries_aspects = [a for a in aspects if a not in TIMESERIES_ASPECT_MAP] entity_response = get_entity( - session, gms_host, entity_urn, non_timeseries_aspects, cached_session_host + session, + gms_host, + entity_urn, + non_timeseries_aspects, + cached_session_host, ) aspect_list: Dict[str, dict] = entity_response["aspects"] @@ -250,12 +254,16 @@ def get_aspects_for_entity( timeseries_aspects: List[str] = [a for a in aspects if a in TIMESERIES_ASPECT_MAP] for timeseries_aspect in timeseries_aspects: timeseries_response: Dict = get_latest_timeseries_aspect_values( - session, gms_host, entity_urn, timeseries_aspect, cached_session_host + session, + gms_host, + entity_urn, + timeseries_aspect, + cached_session_host, ) values: List[Dict] = timeseries_response.get("value", {}).get("values", []) if values: aspect_cls: Optional[Type] = _get_pydantic_class_from_aspect_name( - timeseries_aspect + timeseries_aspect, ) if aspect_cls is not None: ts_aspect = values[0]["aspect"] @@ -266,7 +274,7 @@ def get_aspects_for_entity( aspect_map: Dict[str, Union[dict, _Aspect]] = {} for aspect_name, a in aspect_list.items(): aspect_py_class: Optional[Type[Any]] = _get_pydantic_class_from_aspect_name( - aspect_name + aspect_name, ) if details: @@ -315,7 +323,9 @@ def command(ctx: click.Context) -> None: def get_frontend_session_login_as( - username: str, password: str, frontend_url: str + username: str, + password: str, + frontend_url: str, ) -> requests.Session: session = requests.Session() headers = { @@ -399,19 +409,22 @@ def generate_access_token( "actorUrn": f"urn:li:corpuser:{username}", "duration": validity, "name": token_name, - } + }, }, } response = session.post(f"{frontend_url}/api/v2/graphql", json=json) response.raise_for_status() return token_name, response.json().get("data", {}).get("createAccessToken", {}).get( - "accessToken", None + "accessToken", + None, ) def ensure_has_system_metadata( event: Union[ - MetadataChangeProposal, MetadataChangeProposalWrapper, MetadataChangeEvent + MetadataChangeProposal, + MetadataChangeProposalWrapper, + MetadataChangeEvent, ], ) -> None: if event.systemMetadata is None: diff --git a/metadata-ingestion/src/datahub/cli/config_utils.py b/metadata-ingestion/src/datahub/cli/config_utils.py index 5d9604de7836f9..da0751a2722960 100644 --- a/metadata-ingestion/src/datahub/cli/config_utils.py +++ b/metadata-ingestion/src/datahub/cli/config_utils.py @@ -79,7 +79,7 @@ def _get_config_from_env() -> Tuple[Optional[str], Optional[str]]: # If port is not being used we assume someone is using host env var as URL if url is None and host is not None: logger.warning( - f"Do not use {ENV_METADATA_HOST} as URL. Use {ENV_METADATA_HOST_URL} instead" + f"Do not use {ENV_METADATA_HOST} as URL. Use {ENV_METADATA_HOST_URL} instead", ) return url or host, token @@ -99,14 +99,14 @@ def load_client_config() -> DatahubClientConfig: if _should_skip_config(): raise MissingConfigError( - "You have set the skip config flag, but no GMS host or token was provided in env variables." 
+ "You have set the skip config flag, but no GMS host or token was provided in env variables.", ) try: _ensure_datahub_config() client_config_dict = get_raw_client_config() datahub_config: DatahubClientConfig = DatahubConfig.parse_obj( - client_config_dict + client_config_dict, ).gms return datahub_config @@ -120,12 +120,14 @@ def _ensure_datahub_config() -> None: if not os.path.isfile(DATAHUB_CONFIG_PATH): raise MissingConfigError( f"No {CONDENSED_DATAHUB_CONFIG_PATH} file found, and no configuration was found in environment variables. " - f"Run `datahub init` to create a {CONDENSED_DATAHUB_CONFIG_PATH} file." + f"Run `datahub init` to create a {CONDENSED_DATAHUB_CONFIG_PATH} file.", ) def write_gms_config( - host: str, token: Optional[str], merge_with_previous: bool = True + host: str, + token: Optional[str], + merge_with_previous: bool = True, ) -> None: config = DatahubConfig(gms=DatahubClientConfig(server=host, token=token)) if merge_with_previous: @@ -136,7 +138,7 @@ def write_gms_config( # ok to fail on this previous_config = {} logger.debug( - f"Failed to retrieve config from file {DATAHUB_CONFIG_PATH}: {e}. This isn't fatal." + f"Failed to retrieve config from file {DATAHUB_CONFIG_PATH}: {e}. This isn't fatal.", ) config_dict = {**previous_config, **config.dict()} else: diff --git a/metadata-ingestion/src/datahub/cli/delete_cli.py b/metadata-ingestion/src/datahub/cli/delete_cli.py index 8501cf71f0d544..8d7fea9ab580eb 100644 --- a/metadata-ingestion/src/datahub/cli/delete_cli.py +++ b/metadata-ingestion/src/datahub/cli/delete_cli.py @@ -62,16 +62,20 @@ class DeletionResult: def merge(self, another_result: "DeletionResult") -> None: self.num_records = self._sum_handle_unknown( - self.num_records, another_result.num_records + self.num_records, + another_result.num_records, ) self.num_timeseries_records = self._sum_handle_unknown( - self.num_timeseries_records, another_result.num_timeseries_records + self.num_timeseries_records, + another_result.num_timeseries_records, ) self.num_entities = self._sum_handle_unknown( - self.num_entities, another_result.num_entities + self.num_entities, + another_result.num_entities, ) self.num_referenced_entities = self._sum_handle_unknown( - self.num_referenced_entities, another_result.num_referenced_entities + self.num_referenced_entities, + another_result.num_referenced_entities, ) def format_message(self, *, dry_run: bool, soft: bool, time_sec: float) -> str: @@ -104,7 +108,10 @@ def _sum_handle_unknown(cls, value1: int, value2: int) -> int: @delete.command(no_args_is_help=True) @click.option( - "--registry-id", required=True, type=str, help="e.g. mycompany-dq-model:0.0.1" + "--registry-id", + required=True, + type=str, + help="e.g. mycompany-dq-model:0.0.1", ) @click.option( "--soft/--hard", @@ -128,7 +135,7 @@ def by_registry( if soft and not dry_run: raise click.UsageError( - "Soft-deleting with a registry-id is not yet supported. Try --dry-run to see what you will be deleting, before issuing a hard-delete using the --hard flag" + "Soft-deleting with a registry-id is not yet supported. Try --dry-run to see what you will be deleting, before issuing a hard-delete using the --hard flag", ) with PerfTimer() as timer: @@ -152,12 +159,12 @@ def by_registry( click.echo( f"Took {timer.elapsed_seconds()} seconds to {message}" f" {aspects_affected} versioned rows" - f" for {entities_affected} entities." + f" for {entities_affected} entities.", ) else: click.echo( f"{entities_affected} entities with {aspects_affected} rows will be affected. 
" - f"Took {timer.elapsed_seconds()} seconds to evaluate." + f"Took {timer.elapsed_seconds()} seconds to evaluate.", ) if structured_rows: click.echo(tabulate(structured_rows, _RUN_TABLE_COLUMNS, tablefmt="grid")) @@ -167,7 +174,11 @@ def by_registry( @click.option("--urn", required=True, type=str, help="the urn of the entity") @click.option("-n", "--dry-run", required=False, is_flag=True) @click.option( - "-f", "--force", required=False, is_flag=True, help="force the delete if set" + "-f", + "--force", + required=False, + is_flag=True, + help="force the delete if set", ) @telemetry.with_telemetry() def references(urn: str, dry_run: bool, force: bool) -> None: @@ -233,7 +244,9 @@ def references(urn: str, dry_run: bool, force: bool) -> None: "Maximum 10000. Large batch sizes may cause timeouts.", ) def undo_by_filter( - urn: Optional[str], platform: Optional[str], batch_size: int + urn: Optional[str], + platform: Optional[str], + batch_size: int, ) -> None: """ Undo soft deletion by filters @@ -249,7 +262,7 @@ def undo_by_filter( query="*", status=RemovedStatusFilter.ONLY_SOFT_DELETED, batch_size=batch_size, - ) + ), ) logger.info(f"Going to un-soft delete {len(urns)} urns") urns_iter = progressbar.progressbar(urns, redirect_stdout=True) @@ -289,7 +302,11 @@ def undo_by_filter( help="Soft deletion can be undone, while hard deletion removes from the database", ) @click.option( - "-e", "--env", required=False, type=str, help="Environment filter (e.g. PROD)" + "-e", + "--env", + required=False, + type=str, + help="Environment filter (e.g. PROD)", ) @click.option( "-p", @@ -347,7 +364,10 @@ def undo_by_filter( help="Only delete soft-deleted entities, for hard deletion", ) @click.option( - "--workers", type=int, default=1, help="Num of workers to use for deletion." + "--workers", + type=int, + default=1, + help="Num of workers to use for deletion.", ) @upgrade.check_upgrade @telemetry.with_telemetry() @@ -380,7 +400,9 @@ def by_filter( recursive=recursive, ) soft_delete_filter = _validate_user_soft_delete_flags( - soft=soft, aspect=aspect, only_soft_deleted=only_soft_deleted + soft=soft, + aspect=aspect, + only_soft_deleted=only_soft_deleted, ) _validate_user_aspect_flags(aspect=aspect, start_time=start_time, end_time=end_time) _validate_batch_size(batch_size) @@ -419,7 +441,7 @@ def by_filter( platform_instance=urn, status=soft_delete_filter, batch_size=batch_size, - ) + ), ) else: urns.extend( @@ -427,7 +449,7 @@ def by_filter( container=urn, status=soft_delete_filter, batch_size=batch_size, - ) + ), ) else: urns = list( @@ -438,11 +460,11 @@ def by_filter( query=query, status=soft_delete_filter, batch_size=batch_size, - ) + ), ) if len(urns) == 0: click.echo( - "Found no urns to delete. Maybe you want to change your filters to be something different?" + "Found no urns to delete. Maybe you want to change your filters to be something different?", ) return @@ -457,11 +479,11 @@ def by_filter( click.echo("Found urns of multiple entity types") for entity_type, entity_urns in urns_by_type.items(): click.echo( - f"- {len(entity_urns)} {entity_type} urn(s). Sample: {random.sample(entity_urns, k=min(5, len(entity_urns)))}" + f"- {len(entity_urns)} {entity_type} urn(s). Sample: {random.sample(entity_urns, k=min(5, len(entity_urns)))}", ) else: click.echo( - f"Found {len(urns)} {entity_type} urn(s). Sample: {random.sample(urns, k=min(5, len(urns)))}" + f"Found {len(urns)} {entity_type} urn(s). 
Sample: {random.sample(urns, k=min(5, len(urns)))}", ) if not force and not dry_run: @@ -530,8 +552,10 @@ def process_urn(urn): click.echo( deletion_result.format_message( - dry_run=dry_run, soft=soft, time_sec=timer.elapsed_seconds() - ) + dry_run=dry_run, + soft=soft, + time_sec=timer.elapsed_seconds(), + ), ) @@ -547,39 +571,41 @@ def _validate_user_urn_and_filters( if urn: if entity_type or platform or env or query: raise click.UsageError( - "You cannot provide both an urn and a filter rule (entity-type / platform / env / query)." + "You cannot provide both an urn and a filter rule (entity-type / platform / env / query).", ) elif not urn and not (entity_type or platform or env or query): raise click.UsageError( - "You must provide either an urn or at least one filter (entity-type / platform / env / query) in order to delete entities." + "You must provide either an urn or at least one filter (entity-type / platform / env / query) in order to delete entities.", ) elif query: logger.warning( - "Using --query is an advanced feature and can easily delete unintended entities. Please use with caution." + "Using --query is an advanced feature and can easily delete unintended entities. Please use with caution.", ) elif env and not (platform or entity_type): logger.warning( - f"Using --env without other filters will delete all metadata in the {env} environment. Please use with caution." + f"Using --env without other filters will delete all metadata in the {env} environment. Please use with caution.", ) # Check recursive flag. if recursive: if not urn: raise click.UsageError( - "The --recursive flag can only be used with a single urn." + "The --recursive flag can only be used with a single urn.", ) elif guess_entity_type(urn) not in _RECURSIVE_DELETE_TYPES: raise click.UsageError( - f"The --recursive flag can only be used with these entity types: {_RECURSIVE_DELETE_TYPES}." + f"The --recursive flag can only be used with these entity types: {_RECURSIVE_DELETE_TYPES}.", ) elif urn and guess_entity_type(urn) in _RECURSIVE_DELETE_TYPES: logger.warning( - f"This will only delete {urn}. Use --recursive to delete all contained entities." + f"This will only delete {urn}. Use --recursive to delete all contained entities.", ) def _validate_user_soft_delete_flags( - soft: bool, aspect: Optional[str], only_soft_deleted: bool + soft: bool, + aspect: Optional[str], + only_soft_deleted: bool, ) -> RemovedStatusFilter: # Check soft / hard delete flags. # Note: aspect not None ==> hard delete, @@ -588,12 +614,12 @@ def _validate_user_soft_delete_flags( if soft: if aspect: raise click.UsageError( - "You cannot provide an aspect name when performing a soft delete. Use --hard to perform a hard delete." + "You cannot provide an aspect name when performing a soft delete. Use --hard to perform a hard delete.", ) if only_soft_deleted: raise click.UsageError( - "You cannot provide --only-soft-deleted when performing a soft delete. Use --hard to perform a hard delete." + "You cannot provide --only-soft-deleted when performing a soft delete. Use --hard to perform a hard delete.", ) soft_delete_filter = RemovedStatusFilter.NOT_SOFT_DELETED @@ -617,22 +643,22 @@ def _validate_user_aspect_flags( if aspect and aspect not in ASPECT_MAP: logger.info(f"Supported aspects: {list(sorted(ASPECT_MAP.keys()))}") raise click.UsageError( - f"Unknown aspect {aspect}. Ensure the aspect is in the above list." + f"Unknown aspect {aspect}. 
Ensure the aspect is in the above list.", ) # Check that start/end time are set if and only if the aspect is a timeseries aspect. if aspect and aspect in TIMESERIES_ASPECT_MAP: if not start_time or not end_time: raise click.UsageError( - "You must provide both --start-time and --end-time when deleting a timeseries aspect." + "You must provide both --start-time and --end-time when deleting a timeseries aspect.", ) elif start_time or end_time: raise click.UsageError( - "You can only provide --start-time and --end-time when deleting a timeseries aspect." + "You can only provide --start-time and --end-time when deleting a timeseries aspect.", ) elif aspect: raise click.UsageError( - "Aspect-specific deletion is only supported for timeseries aspects. Please delete the full entity or use a rollback instead." + "Aspect-specific deletion is only supported for timeseries aspects. Please delete the full entity or use a rollback instead.", ) @@ -681,7 +707,7 @@ def _delete_one_urn( ) else: logger.info( - f"[Dry-run] Would hard-delete {urn} timeseries aspect {aspect_name}" + f"[Dry-run] Would hard-delete {urn} timeseries aspect {aspect_name}", ) ts_rows_affected = _UNKNOWN_NUM_RECORDS @@ -690,7 +716,7 @@ def _delete_one_urn( # TODO: The backend doesn't support this yet. raise NotImplementedError( - "Delete by aspect is not supported yet for non-timeseries aspects. Please delete the full entity or use rollback instead." + "Delete by aspect is not supported yet for non-timeseries aspects. Please delete the full entity or use rollback instead.", ) else: @@ -714,7 +740,7 @@ def _delete_one_urn( ) if dry_run and referenced_entities_affected > 0: logger.info( - f"[Dry-run] Would remove {referenced_entities_affected} references to {urn}" + f"[Dry-run] Would remove {referenced_entities_affected} references to {urn}", ) return DeletionResult( diff --git a/metadata-ingestion/src/datahub/cli/docker_check.py b/metadata-ingestion/src/datahub/cli/docker_check.py index ff3965455d1633..758816ba7c094b 100644 --- a/metadata-ingestion/src/datahub/cli/docker_check.py +++ b/metadata-ingestion/src/datahub/cli/docker_check.py @@ -16,7 +16,7 @@ DOCKER_COMPOSE_PROJECT_NAME = os.getenv("DATAHUB_COMPOSE_PROJECT_NAME", "datahub") DATAHUB_COMPOSE_PROJECT_FILTER = { - "label": f"com.docker.compose.project={DOCKER_COMPOSE_PROJECT_NAME}" + "label": f"com.docker.compose.project={DOCKER_COMPOSE_PROJECT_NAME}", } DATAHUB_COMPOSE_LEGACY_VOLUME_FILTERS = [ @@ -69,7 +69,7 @@ def get_docker_client() -> Iterator[docker.DockerClient]: raise error except docker.errors.DockerException as error: raise DockerNotRunningError( - "Docker doesn't seem to be running. Did you start it?" + "Docker doesn't seem to be running. Did you start it?", ) from error assert client @@ -78,7 +78,7 @@ def get_docker_client() -> Iterator[docker.DockerClient]: client.ping() except docker.errors.DockerException as error: raise DockerNotRunningError( - "Unable to talk to Docker. Did you start it?" + "Unable to talk to Docker. Did you start it?", ) from error # Yield the client and make sure to close it. @@ -99,7 +99,7 @@ def run_quickstart_preflight_checks(client: docker.DockerClient) -> None: if memory_in_gb(total_mem_configured) < MIN_MEMORY_NEEDED: raise DockerLowMemoryError( f"Total Docker memory configured {memory_in_gb(total_mem_configured):.2f}GB is below the minimum threshold {MIN_MEMORY_NEEDED}GB. " - "You can increase the memory allocated to Docker in the Docker settings." 
+ "You can increase the memory allocated to Docker in the Docker settings.", ) @@ -152,7 +152,9 @@ def needs_up(self) -> bool: ) def to_exception( - self, header: str, footer: Optional[str] = None + self, + header: str, + footer: Optional[str] = None, ) -> QuickstartError: message = f"{header}\n" for error in self.errors(): @@ -205,7 +207,7 @@ def check_docker_quickstart() -> QuickstartStatus: for config_file in config_files: with open(config_file) as config_file: all_containers.update( - yaml.safe_load(config_file).get("services", {}).keys() + yaml.safe_load(config_file).get("services", {}).keys(), ) existing_containers = set() @@ -238,7 +240,7 @@ def check_docker_quickstart() -> QuickstartStatus: missing_containers = set(all_containers) - existing_containers for missing in missing_containers: container_statuses.append( - DockerContainerStatus(missing, ContainerStatus.MISSING) + DockerContainerStatus(missing, ContainerStatus.MISSING), ) return QuickstartStatus(container_statuses) diff --git a/metadata-ingestion/src/datahub/cli/docker_cli.py b/metadata-ingestion/src/datahub/cli/docker_cli.py index b744ac573aed6e..19a94b18fcf66f 100644 --- a/metadata-ingestion/src/datahub/cli/docker_cli.py +++ b/metadata-ingestion/src/datahub/cli/docker_cli.py @@ -143,7 +143,7 @@ def should_use_neo4j_for_graph_service(graph_service_override: Optional[str]) -> if len(client.volumes.list(filters={"name": "datahub_neo4jdata"})) > 0: click.echo( "Datahub Neo4j volume found, starting with neo4j as graph service.\n" - "If you want to run using elastic, run `datahub docker nuke` and re-ingest your data.\n" + "If you want to run using elastic, run `datahub docker nuke` and re-ingest your data.\n", ) return True @@ -151,7 +151,7 @@ def should_use_neo4j_for_graph_service(graph_service_override: Optional[str]) -> "No Datahub Neo4j volume found, starting with elasticsearch as graph service.\n" "To use neo4j as a graph backend, run \n" "`datahub docker quickstart --graph-service-impl neo4j`" - "\nfrom the root of the datahub repo\n" + "\nfrom the root of the datahub repo\n", ) return False @@ -169,7 +169,7 @@ def _set_environment_variables( if version is not None: if not version.startswith("v") and "." in version: logger.warning( - f"Version passed in '{version}' doesn't start with v, substituting with 'v{version}'" + f"Version passed in '{version}' doesn't start with v, substituting with 'v{version}'", ) version = f"v{version}" os.environ["DATAHUB_VERSION"] = version @@ -207,7 +207,8 @@ def _docker_compose_v2() -> List[str]: try: # Check for the docker compose v2 plugin. compose_version = subprocess.check_output( - ["docker", "compose", "version", "--short"], stderr=subprocess.STDOUT + ["docker", "compose", "version", "--short"], + stderr=subprocess.STDOUT, ).decode() assert compose_version.startswith("2.") or compose_version.startswith("v2.") return ["docker", "compose"] @@ -215,7 +216,8 @@ def _docker_compose_v2() -> List[str]: # We'll check for docker-compose as well. try: compose_version = subprocess.check_output( - ["docker-compose", "version", "--short"], stderr=subprocess.STDOUT + ["docker-compose", "version", "--short"], + stderr=subprocess.STDOUT, ).decode() if compose_version.startswith("2.") or compose_version.startswith("v2."): # This will happen if docker compose v2 is installed in standalone mode @@ -225,7 +227,7 @@ def _docker_compose_v2() -> List[str]: raise DockerComposeVersionError( f"You have docker-compose v1 ({compose_version}) installed, but we require Docker Compose v2. 
" "Please upgrade to Docker Compose v2. " - "See https://docs.docker.com/compose/compose-v2/ for more information." + "See https://docs.docker.com/compose/compose-v2/ for more information.", ) except (OSError, subprocess.CalledProcessError): # docker-compose v1 is not installed either. @@ -283,10 +285,10 @@ def _backup(backup_file: str) -> int: "bash", "-c", f"docker exec {DOCKER_COMPOSE_PROJECT_NAME}-mysql-1 mysqldump -u root -pdatahub datahub > {resolved_backup_file}", - ] + ], ) logger.info( - f"Backup written to {resolved_backup_file} with status {result.returncode}" + f"Backup written to {resolved_backup_file} with status {result.returncode}", ) return result.returncode @@ -376,8 +378,8 @@ def _restore( # ELASTICSEARCH_SSL_KEYSTORE_FILE= # ELASTICSEARCH_SSL_KEYSTORE_TYPE= # ELASTICSEARCH_SSL_KEYSTORE_PASSWORD= - """ - ).encode("utf-8") + """, + ).encode("utf-8"), ) env_fp.flush() if logger.isEnabledFor(logging.DEBUG): @@ -404,7 +406,7 @@ def _restore( capture_output=True, ) logger.info( - f"Index restore command finished with status {result.returncode}" + f"Index restore command finished with status {result.returncode}", ) if result.returncode != 0: logger.info(result.stderr) @@ -422,7 +424,7 @@ def detect_quickstart_arch(arch: Optional[str]) -> Architectures: matched_arch = [a for a in Architectures if arch.lower() == a.value] if not matched_arch: click.secho( - f"Failed to match arch {arch} with list of architectures supported {[a.value for a in Architectures]}" + f"Failed to match arch {arch} with list of architectures supported {[a.value for a in Architectures]}", ) quickstart_arch = matched_arch[0] click.secho(f"Using architecture {quickstart_arch}", fg="yellow") @@ -590,7 +592,7 @@ def detect_quickstart_arch(arch: Optional[str]) -> Architectures: "standalone_consumers", "kafka_setup", "arch", - ] + ], ) def quickstart( # noqa: C901 version: Optional[str], @@ -644,7 +646,7 @@ def quickstart( # noqa: C901 quickstart_arch = detect_quickstart_arch(arch) quickstart_versioning = QuickstartVersionMappingConfig.fetch_quickstart_config() quickstart_execution_plan = quickstart_versioning.get_quickstart_execution_plan( - version + version, ) logger.info(f"Using quickstart plan: {quickstart_execution_plan}") @@ -653,7 +655,7 @@ def quickstart( # noqa: C901 run_quickstart_preflight_checks(client) quickstart_compose_file = list( - quickstart_compose_file + quickstart_compose_file, ) # convert to list from tuple auth_resources_folder = Path(DATAHUB_ROOT_FOLDER) / "plugins/auth/resources" @@ -702,7 +704,8 @@ def quickstart( # noqa: C901 if pull_images: click.echo("\nPulling docker images... ") click.secho( - "This may take a while depending on your network bandwidth.", dim=True + "This may take a while depending on your network bandwidth.", + dim=True, ) # docker compose v2 seems to spam the stderr when used in a non-interactive environment. @@ -901,7 +904,7 @@ def download_compose_files( path = pathlib.Path(tmp_file.name) quickstart_compose_file_list.append(path) click.echo( - f"Fetching consumer docker-compose file {consumer_github_file} from GitHub" + f"Fetching consumer docker-compose file {consumer_github_file} from GitHub", ) # Download the quickstart docker-compose file from GitHub. 
quickstart_download_response = request_session.get(consumer_github_file) @@ -923,7 +926,7 @@ def download_compose_files( path = pathlib.Path(tmp_file.name) quickstart_compose_file_list.append(path) click.echo( - f"Fetching consumer docker-compose file {kafka_setup_github_file} from GitHub" + f"Fetching consumer docker-compose file {kafka_setup_github_file} from GitHub", ) # Download the quickstart docker-compose file from GitHub. quickstart_download_response = request_session.get(kafka_setup_github_file) @@ -933,11 +936,14 @@ def download_compose_files( def valid_restore_options( - restore: bool, restore_indices: bool, no_restore_indices: bool + restore: bool, + restore_indices: bool, + no_restore_indices: bool, ) -> bool: if no_restore_indices and not restore: click.secho( - "Using --no-restore-indices without a --restore isn't defined", fg="red" + "Using --no-restore-indices without a --restore isn't defined", + fg="red", ) return False if no_restore_indices and restore_indices: @@ -1012,18 +1018,19 @@ def nuke(keep_data: bool) -> None: with get_docker_client() as client: click.echo(f"Removing containers in the {DOCKER_COMPOSE_PROJECT_NAME} project") for container in client.containers.list( - all=True, filters=DATAHUB_COMPOSE_PROJECT_FILTER + all=True, + filters=DATAHUB_COMPOSE_PROJECT_FILTER, ): container.remove(v=True, force=True) if keep_data: click.echo( - f"Skipping deleting data volumes in the {DOCKER_COMPOSE_PROJECT_NAME} project" + f"Skipping deleting data volumes in the {DOCKER_COMPOSE_PROJECT_NAME} project", ) else: click.echo(f"Removing volumes in the {DOCKER_COMPOSE_PROJECT_NAME} project") for filter in DATAHUB_COMPOSE_LEGACY_VOLUME_FILTERS + [ - DATAHUB_COMPOSE_PROJECT_FILTER + DATAHUB_COMPOSE_PROJECT_FILTER, ]: for volume in client.volumes.list(filters=filter): volume.remove(force=True) diff --git a/metadata-ingestion/src/datahub/cli/get_cli.py b/metadata-ingestion/src/datahub/cli/get_cli.py index ab31323ac52b6f..c82ec50fdb40a3 100644 --- a/metadata-ingestion/src/datahub/cli/get_cli.py +++ b/metadata-ingestion/src/datahub/cli/get_cli.py @@ -72,5 +72,5 @@ def urn(ctx: Any, urn: Optional[str], aspect: List[str], details: bool) -> None: aspect_data, sort_keys=True, indent=2, - ) + ), ) diff --git a/metadata-ingestion/src/datahub/cli/ingest_cli.py b/metadata-ingestion/src/datahub/cli/ingest_cli.py index c9eaccbc65ee21..fe6ff383aebe08 100644 --- a/metadata-ingestion/src/datahub/cli/ingest_cli.py +++ b/metadata-ingestion/src/datahub/cli/ingest_cli.py @@ -93,7 +93,11 @@ def ingest() -> None: help="Turn off default reporting of ingestion results to DataHub", ) @click.option( - "--no-spinner", type=bool, is_flag=True, default=False, help="Turn off spinner" + "--no-spinner", + type=bool, + is_flag=True, + default=False, + help="Turn off spinner", ) @click.option( "--no-progress", @@ -111,7 +115,7 @@ def ingest() -> None: "no_default_report", "no_spinner", "no_progress", - ] + ], ) def run( config: str, @@ -134,10 +138,10 @@ def run_pipeline_to_completion(pipeline: Pipeline) -> int: pipeline.run() except Exception as e: logger.info( - f"Source ({pipeline.source_type}) report:\n{pipeline.source.get_report().as_string()}" + f"Source ({pipeline.source_type}) report:\n{pipeline.source.get_report().as_string()}", ) logger.info( - f"Sink ({pipeline.sink_type}) report:\n{pipeline.sink.get_report().as_string()}" + f"Sink ({pipeline.sink_type}) report:\n{pipeline.sink.get_report().as_string()}", ) raise e else: @@ -183,7 +187,8 @@ def run_pipeline_to_completion(pipeline: Pipeline) -> int: # 
The main ingestion has completed. If it was successful, potentially show an upgrade nudge message. if ret == 0: upgrade.check_upgrade_post( - main_method_runtime=timer.elapsed_seconds(), graph=pipeline.ctx.graph + main_method_runtime=timer.elapsed_seconds(), + graph=pipeline.ctx.graph, ) if ret: @@ -195,7 +200,7 @@ def _make_ingestion_urn(name: str) -> str: guid = datahub_guid( { "name": name, - } + }, ) return f"urn:li:dataHubIngestionSource:deploy-{guid}" @@ -293,7 +298,7 @@ def deploy( else: if not name: raise click.UsageError( - "Either --name must be set or deployment_name specified in the config" + "Either --name must be set or deployment_name specified in the config", ) deploy_options = DeployOptions(name=name) @@ -354,12 +359,14 @@ def deploy( } }) } - """ + """, ) try: response = datahub_graph.execute_graphql( - graphql_query, variables=variables, format_exception=False + graphql_query, + variables=variables, + format_exception=False, ) except GraphError as graph_error: try: @@ -367,12 +374,14 @@ def deploy( click.secho(error[0]["message"], fg="red", err=True) except Exception: click.secho( - f"Could not create ingestion source:\n{graph_error}", fg="red", err=True + f"Could not create ingestion source:\n{graph_error}", + fg="red", + err=True, ) sys.exit(1) click.echo( - f"✅ Successfully wrote data ingestion source metadata for recipe {deploy_options.name}:" + f"✅ Successfully wrote data ingestion source metadata for recipe {deploy_options.name}:", ) click.echo(response) @@ -443,7 +452,10 @@ def mcps(path: str) -> None: @click.argument("page_size", type=int, default=100) @click.option("--urn", type=str, default=None, help="Filter by ingestion source URN.") @click.option( - "--source", type=str, default=None, help="Filter by ingestion source name." 
+ "--source", + type=str, + default=None, + help="Filter by ingestion source name.", ) @upgrade.check_upgrade @telemetry.with_telemetry() @@ -482,7 +494,7 @@ def list_source_runs(page_offset: int, page_size: int, urn: str, source: str) -> "start": page_offset, "count": page_size, "filters": filters, - } + }, } client = get_default_graph() @@ -534,7 +546,7 @@ def list_source_runs(page_offset: int, page_size: int, urn: str, source: str) -> try: start_time = ( datetime.fromtimestamp( - result.get("startTimeMs", 0) / 1000 + result.get("startTimeMs", 0) / 1000, ).strftime("%Y-%m-%d %H:%M:%S") if status != "DUPLICATE" and result.get("startTimeMs") is not None else "N/A" @@ -553,7 +565,7 @@ def list_source_runs(page_offset: int, page_size: int, urn: str, source: str) -> rows, headers=INGEST_SRC_TABLE_COLUMNS, tablefmt="grid", - ) + ), ) @@ -595,7 +607,7 @@ def list_runs(page_offset: int, page_size: int, include_soft_deletes: bool) -> N row.get("runId"), row.get("rows"), datetime.fromtimestamp(row.get("timestamp") / 1000).strftime( - "%Y-%m-%d %H:%M:%S" + "%Y-%m-%d %H:%M:%S", ) + f" ({local_timezone})", ] @@ -619,7 +631,11 @@ def list_runs(page_offset: int, page_size: int, include_soft_deletes: bool) -> N @upgrade.check_upgrade @telemetry.with_telemetry() def show( - run_id: str, start: int, count: int, include_soft_deletes: bool, show_aspect: bool + run_id: str, + start: int, + count: int, + include_soft_deletes: bool, + show_aspect: bool, ) -> None: """Describe a provided ingestion run to datahub""" client = get_default_graph() @@ -647,7 +663,7 @@ def show( cli_utils.format_aspect_summaries(rows), RUN_TABLE_COLUMNS, tablefmt="grid", - ) + ), ) else: for row in rows: @@ -669,7 +685,11 @@ def show( @upgrade.check_upgrade @telemetry.with_telemetry() def rollback( - run_id: str, force: bool, dry_run: bool, safe: bool, report_dir: str + run_id: str, + force: bool, + dry_run: bool, + safe: bool, + report_dir: str, ) -> None: """Rollback a provided ingestion run to datahub""" client = get_default_graph() @@ -689,29 +709,32 @@ def rollback( unsafe_entity_count, unsafe_entities, ) = cli_utils.post_rollback_endpoint( - client._session, client.config.server, payload_obj, "/runs?action=rollback" + client._session, + client.config.server, + payload_obj, + "/runs?action=rollback", ) click.echo( - "Rolling back deletes the entities created by a run and reverts the updated aspects" + "Rolling back deletes the entities created by a run and reverts the updated aspects", ) click.echo( - f"This rollback {'will' if dry_run else ''} {'delete' if dry_run else 'deleted'} {entities_affected} entities and {'will roll' if dry_run else 'rolled'} back {aspects_reverted} aspects" + f"This rollback {'will' if dry_run else ''} {'delete' if dry_run else 'deleted'} {entities_affected} entities and {'will roll' if dry_run else 'rolled'} back {aspects_reverted} aspects", ) click.echo( - f"showing first {len(structured_rows)} of {aspects_reverted} aspects {'that will be ' if dry_run else ''}reverted by this run" + f"showing first {len(structured_rows)} of {aspects_reverted} aspects {'that will be ' if dry_run else ''}reverted by this run", ) click.echo(tabulate(structured_rows, RUN_TABLE_COLUMNS, tablefmt="grid")) if aspects_affected > 0: if safe: click.echo( - f"WARNING: This rollback {'will hide' if dry_run else 'has hidden'} {aspects_affected} aspects related to {unsafe_entity_count} entities being rolled back that are not part ingestion run id." 
+ f"WARNING: This rollback {'will hide' if dry_run else 'has hidden'} {aspects_affected} aspects related to {unsafe_entity_count} entities being rolled back that are not part ingestion run id.", ) else: click.echo( - f"WARNING: This rollback {'will delete' if dry_run else 'has deleted'} {aspects_affected} aspects related to {unsafe_entity_count} entities being rolled back that are not part ingestion run id." + f"WARNING: This rollback {'will delete' if dry_run else 'has deleted'} {aspects_affected} aspects related to {unsafe_entity_count} entities being rolled back that are not part ingestion run id.", ) if unsafe_entity_count > 0: diff --git a/metadata-ingestion/src/datahub/cli/json_file.py b/metadata-ingestion/src/datahub/cli/json_file.py index c2c17d92e51166..a97260578779a0 100644 --- a/metadata-ingestion/src/datahub/cli/json_file.py +++ b/metadata-ingestion/src/datahub/cli/json_file.py @@ -8,19 +8,20 @@ def check_mce_file(filepath: str) -> str: mce_source = GenericFileSource.create( - {"filename": filepath}, PipelineContext(run_id="json-file") + {"filename": filepath}, + PipelineContext(run_id="json-file"), ) for _ in mce_source.get_workunits(): pass if len(mce_source.get_report().failures): # raise the first failure found logger.error( - f"Event file check failed with errors. Raising first error found. Full report {mce_source.get_report().as_string()}" + f"Event file check failed with errors. Raising first error found. Full report {mce_source.get_report().as_string()}", ) for failure in mce_source.get_report().failures: raise Exception(failure.context) raise Exception( - f"Failed to process file due to {mce_source.get_report().failures}" + f"Failed to process file due to {mce_source.get_report().failures}", ) else: return f"{mce_source.get_report().events_produced} MCEs found - all valid" diff --git a/metadata-ingestion/src/datahub/cli/lite_cli.py b/metadata-ingestion/src/datahub/cli/lite_cli.py index 90bbb353deab18..12cb626c76f87b 100644 --- a/metadata-ingestion/src/datahub/cli/lite_cli.py +++ b/metadata-ingestion/src/datahub/cli/lite_cli.py @@ -40,7 +40,8 @@ class DuckDBLiteConfigWrapper(DuckDBLiteConfig): class LiteCliConfig(DatahubConfig): lite: LiteLocalConfig = LiteLocalConfig( - type="duckdb", config=DuckDBLiteConfigWrapper().dict() + type="duckdb", + config=DuckDBLiteConfigWrapper().dict(), ) @@ -86,11 +87,12 @@ def shell_complete(self, ctx, param, incomplete): return [ ( CompletionItem( - browseable.auto_complete.suggested_path, type="plain" + browseable.auto_complete.suggested_path, + type="plain", ) if browseable.auto_complete else CompletionItem( - f"{incomplete}/{browseable.name}".replace("//", "/") + f"{incomplete}/{browseable.name}".replace("//", "/"), ) ) for browseable in completions @@ -127,7 +129,7 @@ def get( if urn is None and path is None: if not ctx.args: raise click.UsageError( - "Nothing for me to get. Maybe provide an urn or a path? Use ls if you want to explore me." + "Nothing for me to get. Maybe provide an urn or a path? 
Use ls if you want to explore me.", ) urn_or_path = ctx.args[0] if urn_or_path.startswith("urn:"): @@ -153,7 +155,7 @@ def get( details=verbose, ), indent=2, - ) + ), ) else: parents.update(browseable.parents or []) @@ -162,10 +164,13 @@ def get( click.echo( json.dumps( lite.get( - id=p, aspects=aspect, as_of=asof_millis, details=verbose + id=p, + aspects=aspect, + as_of=asof_millis, + details=verbose, ), indent=2, - ) + ), ) if urn: @@ -173,7 +178,7 @@ def get( json.dumps( lite.get(id=urn, aspects=aspect, as_of=asof_millis, details=verbose), indent=2, - ) + ), ) end_time = time.time() logger.debug(f"Time taken: {int((end_time - start_time) * 1000.0)} millis") @@ -186,7 +191,7 @@ def nuke(ctx: click.Context) -> None: """Nuke the instance""" lite = _get_datahub_lite() if click.confirm( - f"This will permanently delete the DataHub Lite instance at: {lite.location()}. Are you sure?" + f"This will permanently delete the DataHub Lite instance at: {lite.location()}. Are you sure?", ): lite.destroy() click.echo(f"DataHub Lite at {lite.location()} nuked!") @@ -235,7 +240,7 @@ def ls(path: Optional[str]) -> None: if auto_complete: click.echo( f"Path not found at {auto_complete[0].success_path}/".replace("//", "/") - + click.style(f"{auto_complete[0].failed_token}", fg="red") + + click.style(f"{auto_complete[0].failed_token}", fg="red"), ) click.echo("Did you mean") for completable in auto_complete: @@ -267,7 +272,8 @@ def ls(path: Optional[str]) -> None: "--flavor", required=False, type=click.Choice( - choices=[x.lower() for x in SearchFlavor._member_names_], case_sensitive=False + choices=[x.lower() for x in SearchFlavor._member_names_], + case_sensitive=False, ), default=SearchFlavor.FREE_TEXT.name.lower(), help="the flavor of the query, defaults to free text. Set to exact if you want to pass in a specific query to the database engine.", @@ -295,14 +301,16 @@ def search( search_flavor = SearchFlavor[flavor.upper()] except KeyError: raise click.UsageError( - f"Failed to find a matching query flavor for {flavor}. Valid values are {[x.lower() for x in SearchFlavor._member_names_]}" + f"Failed to find a matching query flavor for {flavor}. Valid values are {[x.lower() for x in SearchFlavor._member_names_]}", ) catalog = _get_datahub_lite(read_only=True) # sanitize query result_ids = set() try: for searchable in catalog.search( - query=query, flavor=search_flavor, aspects=aspect + query=query, + flavor=search_flavor, + aspects=aspect, ): result_str = searchable.id if details: @@ -338,7 +346,7 @@ def init(ctx: click.Context, type: Optional[str], file: Optional[str]) -> None: new_lite_config = LiteLocalConfig.parse_obj(new_lite_config_dict) if lite_config != new_lite_config: if click.confirm( - f"Will replace datahub lite config {lite_config} with {new_lite_config}" + f"Will replace datahub lite config {lite_config} with {new_lite_config}", ): write_lite_config(new_lite_config) @@ -354,7 +362,7 @@ def import_cmd(ctx: click.Context, file: Optional[str]) -> None: if file is None: if not ctx.args: raise click.UsageError( - "Nothing for me to import. Maybe provide a metadata file?" + "Nothing for me to import. 
Maybe provide a metadata file?", ) file = ctx.args[0] @@ -377,20 +385,22 @@ def export(ctx: click.Context, file: str) -> None: current_time = int(time.time() * 1000.0) pipeline_ctx: PipelineContext = PipelineContext( - run_id=f"datahub-lite_{current_time}" + run_id=f"datahub-lite_{current_time}", ) file = os.path.expanduser(file) base_dir = os.path.dirname(file) if base_dir: os.makedirs(base_dir, exist_ok=True) file_sink: FileSink = FileSink( - ctx=pipeline_ctx, config=FileSinkConfig(filename=file) + ctx=pipeline_ctx, + config=FileSinkConfig(filename=file), ) datahub_lite = _get_datahub_lite(read_only=True) num_events = 0 for mcp in datahub_lite.get_all_aspects(): file_sink.write_record_async( - RecordEnvelope(record=mcp, metadata={}), write_callback=NoopWriteCallback() + RecordEnvelope(record=mcp, metadata={}), + write_callback=NoopWriteCallback(), ) num_events += 1 diff --git a/metadata-ingestion/src/datahub/cli/migrate.py b/metadata-ingestion/src/datahub/cli/migrate.py index 3bd1b6fc4dc124..325b480185f37d 100644 --- a/metadata-ingestion/src/datahub/cli/migrate.py +++ b/metadata-ingestion/src/datahub/cli/migrate.py @@ -141,7 +141,7 @@ def dataplatform2instance_func( keep: bool, ) -> None: click.echo( - f"Starting migration: platform:{platform}, instance={instance}, force={force}, dry-run={dry_run}" + f"Starting migration: platform:{platform}, instance={instance}, force={force}, dry-run={dry_run}", ) run_id: str = f"migrate-{uuid.uuid4()}" migration_report = MigrationReport(run_id, dry_run, keep) @@ -163,7 +163,8 @@ def dataplatform2instance_func( ) if "dataPlatformInstance" in response: assert isinstance( - response["dataPlatformInstance"], DataPlatformInstanceClass + response["dataPlatformInstance"], + DataPlatformInstanceClass, ) data_platform_instance: DataPlatformInstanceClass = response[ "dataPlatformInstance" @@ -173,14 +174,15 @@ def dataplatform2instance_func( continue else: log.debug( - f"{src_entity_urn} is not an instance specific urn. {response}" + f"{src_entity_urn} is not an instance specific urn. 
{response}", ) urns_to_migrate.append(src_entity_urn) if not force and not dry_run: # get a confirmation from the operator before proceeding if this is not a dry run sampled_urns_to_migrate = random.sample( - urns_to_migrate, k=min(10, len(urns_to_migrate)) + urns_to_migrate, + k=min(10, len(urns_to_migrate)), ) sampled_new_urns: List[str] = [ make_dataset_urn_with_platform_instance( @@ -193,13 +195,14 @@ def dataplatform2instance_func( if key ] click.echo( - f"Will migrate {len(urns_to_migrate)} urns such as {random.sample(urns_to_migrate, k=min(10, len(urns_to_migrate)))}" + f"Will migrate {len(urns_to_migrate)} urns such as {random.sample(urns_to_migrate, k=min(10, len(urns_to_migrate)))}", ) click.echo(f"New urns will look like {sampled_new_urns}") click.confirm("Ok to proceed?", abort=True) for src_entity_urn in progressbar.progressbar( - urns_to_migrate, redirect_stdout=True + urns_to_migrate, + redirect_stdout=True, ): key = dataset_urn_to_key(src_entity_urn) assert key @@ -232,7 +235,7 @@ def dataplatform2instance_func( instance=make_dataplatform_instance_urn(platform, instance), ), systemMetadata=system_metadata, - ) + ), ) migration_report.on_entity_create(new_urn, "dataPlatformInstance") @@ -241,16 +244,24 @@ def dataplatform2instance_func( entity_type = _get_type_from_urn(target_urn) relationshipType = relationship.relationship_type aspect_name = migration_utils.get_aspect_name_from_relationship( - relationshipType, entity_type + relationshipType, + entity_type, ) aspect_map = cli_utils.get_aspects_for_entity( - graph._session, graph.config.server, target_urn, aspects=[aspect_name] + graph._session, + graph.config.server, + target_urn, + aspects=[aspect_name], ) if aspect_name in aspect_map: aspect = aspect_map[aspect_name] assert isinstance(aspect, DictWrapper) aspect = migration_utils.modify_urn_list_for_aspect( - aspect_name, aspect, relationshipType, src_entity_urn, new_urn + aspect_name, + aspect, + relationshipType, + src_entity_urn, + new_urn, ) # use mcpw mcp = MetadataChangeProposalWrapper( @@ -266,7 +277,10 @@ def dataplatform2instance_func( if not dry_run and not keep: log.info(f"will {'hard' if hard else 'soft'} delete {src_entity_urn}") delete_cli._delete_one_urn( - graph, src_entity_urn, soft=not hard, run_id=run_id + graph, + src_entity_urn, + soft=not hard, + run_id=run_id, ) migration_report.on_entity_migrated(src_entity_urn, "status") # type: ignore @@ -308,7 +322,7 @@ def migrate_containers( platform is not None and customProperties["platform"] != platform ): log.debug( - f"{container['urn']} does not match filter criteria, skipping.. {customProperties} {env} {platform}" + f"{container['urn']} does not match filter criteria, skipping.. 
{customProperties} {env} {platform}", ) continue @@ -332,7 +346,7 @@ def migrate_containers( newKey.instance = instance log.debug( - f"Container key migration: {container['urn']} -> urn:li:container:{newKey.guid()}" + f"Container key migration: {container['urn']} -> urn:li:container:{newKey.guid()}", ) src_urn = container["urn"] @@ -354,7 +368,8 @@ def migrate_containers( assert isinstance(mcp.aspect, ContainerPropertiesClass) containerProperties: ContainerPropertiesClass = mcp.aspect containerProperties.customProperties = newKey.dict( - by_alias=True, exclude_none=True + by_alias=True, + exclude_none=True, ) mcp.aspect = containerProperties elif mcp.aspectName == "containerKey": @@ -378,7 +393,10 @@ def migrate_containers( if not dry_run and not keep: log.info(f"will {'hard' if hard else 'soft'} delete {src_urn}") delete_cli._delete_one_urn( - rest_emitter, src_urn, soft=not hard, run_id=run_id + rest_emitter, + src_urn, + soft=not hard, + run_id=run_id, ) migration_report.on_entity_migrated(src_urn, "status") # type: ignore @@ -388,14 +406,15 @@ def migrate_containers( def get_containers_for_migration(env: str) -> List[Any]: client = get_default_graph() containers_to_migrate = list( - client.get_urns_by_filter(entity_types=["container"], env=env) + client.get_urns_by_filter(entity_types=["container"], env=env), ) containers = [] increment = 20 for i in range(0, len(containers_to_migrate), increment): for container in batch_get_ids( - client, containers_to_migrate[i : i + increment] + client, + containers_to_migrate[i : i + increment], ): log.debug(container) containers.append(container) @@ -443,7 +462,7 @@ def process_container_relationships( rest_emitter: DatahubRestEmitter, ) -> None: relationships: Iterable[RelatedEntity] = migration_utils.get_incoming_relationships( - urn=src_urn + urn=src_urn, ) client = get_default_graph() for relationship in relationships: @@ -457,7 +476,8 @@ def process_container_relationships( entity_type = _get_type_from_urn(target_urn) relationshipType = relationship.relationship_type aspect_name = migration_utils.get_aspect_name_from_relationship( - relationshipType, entity_type + relationshipType, + entity_type, ) aspect_map = cli_utils.get_aspects_for_entity( client._session, @@ -470,7 +490,11 @@ def process_container_relationships( aspect = aspect_map[aspect_name] assert isinstance(aspect, DictWrapper) aspect = migration_utils.modify_urn_list_for_aspect( - aspect_name, aspect, relationshipType, src_urn, dst_urn + aspect_name, + aspect, + relationshipType, + src_urn, + dst_urn, ) # use mcpw mcp = MetadataChangeProposalWrapper( diff --git a/metadata-ingestion/src/datahub/cli/migration_utils.py b/metadata-ingestion/src/datahub/cli/migration_utils.py index a3dfcfe2ac4034..956cb4723b7bd4 100644 --- a/metadata-ingestion/src/datahub/cli/migration_utils.py +++ b/metadata-ingestion/src/datahub/cli/migration_utils.py @@ -81,7 +81,7 @@ def get_aspect_name_from_relationship(relationship_type: str, entity_type: str) return aspect_map[relationship_type][entity_type.lower()] raise Exception( - f"Unable to map aspect name from relationship_type {relationship_type} and entity_type {entity_type}" + f"Unable to map aspect name from relationship_type {relationship_type} and entity_type {entity_type}", ) @@ -109,7 +109,7 @@ def dataJobInputOutput_modifier( return dataJobInputOutput raise Exception( - f"Unable to map aspect_name: dataJobInputOutput, relationship_type {relationship_type}" + f"Unable to map aspect_name: dataJobInputOutput, relationship_type {relationship_type}", 
) @staticmethod @@ -232,7 +232,7 @@ def modify_urn_list_for_aspect( new_urn=new_urn, ) raise Exception( - f"Unable to map aspect_name: {aspect_name}, relationship_type {relationship_type}" + f"Unable to map aspect_name: {aspect_name}, relationship_type {relationship_type}", ) diff --git a/metadata-ingestion/src/datahub/cli/put_cli.py b/metadata-ingestion/src/datahub/cli/put_cli.py index d3a6fb5caaf197..5721aaa680def4 100644 --- a/metadata-ingestion/src/datahub/cli/put_cli.py +++ b/metadata-ingestion/src/datahub/cli/put_cli.py @@ -50,7 +50,10 @@ def aspect(urn: str, aspect: str, aspect_data: str, run_id: Optional[str]) -> No entity_type = guess_entity_type(urn) aspect_obj = load_config_file( - aspect_data, allow_stdin=True, resolve_env_vars=False, process_directives=False + aspect_data, + allow_stdin=True, + resolve_env_vars=False, + process_directives=False, ) client = get_default_graph() @@ -95,10 +98,16 @@ def aspect(urn: str, aspect: str, aspect_data: str, run_id: Optional[str]) -> No required=True, ) @click.option( - "--run-id", type=str, help="Run ID into which we should log the platform." + "--run-id", + type=str, + help="Run ID into which we should log the platform.", ) def platform( - ctx: click.Context, name: str, display_name: Optional[str], logo: str, run_id: str + ctx: click.Context, + name: str, + display_name: Optional[str], + logo: str, + run_id: str, ) -> None: """ Create or update a dataplatform entity in DataHub @@ -126,5 +135,5 @@ def platform( ) datahub_graph.emit(mcp) click.echo( - f"✅ Successfully wrote data platform metadata for {platform_urn} to DataHub ({datahub_graph})" + f"✅ Successfully wrote data platform metadata for {platform_urn} to DataHub ({datahub_graph})", ) diff --git a/metadata-ingestion/src/datahub/cli/quickstart_versioning.py b/metadata-ingestion/src/datahub/cli/quickstart_versioning.py index 9739af5127f4d1..53b29e3cee182b 100644 --- a/metadata-ingestion/src/datahub/cli/quickstart_versioning.py +++ b/metadata-ingestion/src/datahub/cli/quickstart_versioning.py @@ -43,7 +43,7 @@ def _fetch_latest_version(cls) -> str: :return: The latest version. """ response = requests.get( - "https://api.github.com/repos/datahub-project/datahub/releases/latest" + "https://api.github.com/repos/datahub-project/datahub/releases/latest", ) response.raise_for_status() return json.loads(response.text)["tag_name"] @@ -52,7 +52,7 @@ def _fetch_latest_version(cls) -> str: def fetch_quickstart_config(cls) -> "QuickstartVersionMappingConfig": if LOCAL_QUICKSTART_MAPPING_FILE: logger.info( - "LOCAL_QUICKSTART_MAPPING_FILE is set, will try to read from local file." + "LOCAL_QUICKSTART_MAPPING_FILE is set, will try to read from local file.", ) path = os.path.expanduser(LOCAL_QUICKSTART_MAPPING_FILE) with open(path) as f: @@ -66,7 +66,7 @@ def fetch_quickstart_config(cls) -> "QuickstartVersionMappingConfig": config_raw = yaml.safe_load(response.text) except Exception as e: logger.debug( - f"Couldn't connect to github: {e}, will try to read from local file." + f"Couldn't connect to github: {e}, will try to read from local file.", ) try: path = os.path.expanduser(DEFAULT_LOCAL_CONFIG_PATH) @@ -77,14 +77,16 @@ def fetch_quickstart_config(cls) -> "QuickstartVersionMappingConfig": if config_raw is None: logger.info( - "Unable to connect to GitHub, using default quickstart version mapping config." 
+ "Unable to connect to GitHub, using default quickstart version mapping config.", ) return QuickstartVersionMappingConfig( quickstart_version_map={ "default": QuickstartExecutionPlan( - composefile_git_ref="master", docker_tag="head", mysql_tag="8.2" + composefile_git_ref="master", + docker_tag="head", + mysql_tag="8.2", ), - } + }, ) config = cls.parse_obj(config_raw) @@ -94,11 +96,13 @@ def fetch_quickstart_config(cls) -> "QuickstartVersionMappingConfig": try: release = cls._fetch_latest_version() config.quickstart_version_map["stable"] = QuickstartExecutionPlan( - composefile_git_ref=release, docker_tag=release, mysql_tag="8.2" + composefile_git_ref=release, + docker_tag=release, + mysql_tag="8.2", ) except Exception: click.echo( - "Couldn't connect to github. --version stable will not work." + "Couldn't connect to github. --version stable will not work.", ) save_quickstart_config(config) return config @@ -140,7 +144,8 @@ def get_quickstart_execution_plan( def save_quickstart_config( - config: QuickstartVersionMappingConfig, path: str = DEFAULT_LOCAL_CONFIG_PATH + config: QuickstartVersionMappingConfig, + path: str = DEFAULT_LOCAL_CONFIG_PATH, ) -> None: # create directory if it doesn't exist path = os.path.expanduser(path) diff --git a/metadata-ingestion/src/datahub/cli/specific/assertions_cli.py b/metadata-ingestion/src/datahub/cli/specific/assertions_cli.py index c0d93af90ada00..d3ed27e1d566e5 100644 --- a/metadata-ingestion/src/datahub/cli/specific/assertions_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/assertions_cli.py @@ -73,7 +73,10 @@ def upsert(file: str) -> None: @upgrade.check_upgrade @telemetry.with_telemetry() def compile( - file: str, platform: str, output_to: Optional[str], extras: List[str] + file: str, + platform: str, + output_to: Optional[str], + extras: List[str], ) -> None: """Compile a set of assertions for input assertion platform. 
Note that this does not run any code or execute any queries on assertion platform @@ -99,7 +102,8 @@ def compile( try: compiler = ASSERTION_PLATFORMS[platform].create( - output_dir=output_to, extras=extras_list_to_dict(extras) + output_dir=output_to, + extras=extras_list_to_dict(extras), ) result = compiler.compile(assertions_spec) @@ -127,7 +131,7 @@ def write_report_file(output_to: str, result: AssertionCompilationResult) -> Non path=report_path, type=CompileResultArtifactType.COMPILE_REPORT, description="Detailed report about compile status", - ) + ), ) f.write(result.report.as_json()) diff --git a/metadata-ingestion/src/datahub/cli/specific/datacontract_cli.py b/metadata-ingestion/src/datahub/cli/specific/datacontract_cli.py index 3745943c8c96ad..141f8f3205954c 100644 --- a/metadata-ingestion/src/datahub/cli/specific/datacontract_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/datacontract_cli.py @@ -31,7 +31,7 @@ def upsert(file: str) -> None: with get_default_graph() as graph: if not graph.exists(data_contract.entity): raise ValueError( - f"Cannot define a data contract for non-existent entity {data_contract.entity}" + f"Cannot define a data contract for non-existent entity {data_contract.entity}", ) try: @@ -48,7 +48,10 @@ def upsert(file: str) -> None: @datacontract.command() @click.option( - "--urn", required=False, type=str, help="The urn for the data contract to delete" + "--urn", + required=False, + type=str, + help="The urn for the data contract to delete", ) @click.option( "-f", @@ -66,7 +69,7 @@ def delete(urn: Optional[str], file: Optional[str], hard: bool) -> None: if not urn: if not file: raise click.UsageError( - "Must provide either an urn or a file to delete a data contract" + "Must provide either an urn or a file to delete a data contract", ) data_contract = DataContract.from_yaml(file) diff --git a/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py b/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py index 857a6fbb4e18e5..4ccc98a08df5ea 100644 --- a/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/dataproduct_cli.py @@ -36,7 +36,7 @@ def _get_owner_urn(maybe_urn: str) -> str: elif maybe_urn.startswith("urn:li:"): # this looks like an urn, but not a type we recognize raise Exception( - f"Owner urn {maybe_urn} not recognized as one of the supported types (corpuser, corpGroup)" + f"Owner urn {maybe_urn} not recognized as one of the supported types (corpuser, corpGroup)", ) else: # mint a user urn as the default @@ -66,7 +66,7 @@ def _print_diff(orig_file, new_file): new_lines = fp.readlines() sys.stdout.writelines( - difflib.unified_diff(orig_lines, new_lines, orig_file, new_file) + difflib.unified_diff(orig_lines, new_lines, orig_file, new_file), ) @@ -125,7 +125,10 @@ def mutate(file: Path, validate_assets: bool, external_url: str, upsert: bool) - ) @click.option("-f", "--file", required=True, type=click.Path(exists=True)) @click.option( - "--validate-assets/--no-validate-assets", required=False, is_flag=True, default=True + "--validate-assets/--no-validate-assets", + required=False, + is_flag=True, + default=True, ) @click.option("--external-url", required=False, type=str) @upgrade.check_upgrade @@ -141,7 +144,10 @@ def update(file: Path, validate_assets: bool, external_url: str) -> None: ) @click.option("-f", "--file", required=True, type=click.Path(exists=True)) @click.option( - "--validate-assets/--no-validate-assets", required=False, is_flag=True, default=True + 
"--validate-assets/--no-validate-assets", + required=False, + is_flag=True, + default=True, ) @click.option("--external-url", required=False, type=str) @upgrade.check_upgrade @@ -168,7 +174,8 @@ def diff(file: Path, update: bool) -> None: data_product_local: DataProduct = DataProduct.from_yaml(file, emitter) id = data_product_local.id data_product_remote = DataProduct.from_datahub( - emitter, data_product_local.urn + emitter, + data_product_local.urn, ) with NamedTemporaryFile(suffix=".yaml") as temp_fp: update_needed = data_product_remote.patch_yaml( @@ -194,7 +201,10 @@ def diff(file: Path, update: bool) -> None: name="delete", ) @click.option( - "--urn", required=False, type=str, help="The urn for the data product to delete" + "--urn", + required=False, + type=str, + help="The urn for the data product to delete", ) @click.option( "-f", @@ -211,7 +221,8 @@ def delete(urn: str, file: Path, hard: bool) -> None: if not urn and not file: click.secho( - "Must provide either an urn or a file to delete a data product", fg="red" + "Must provide either an urn or a file to delete a data product", + fg="red", ) raise click.Abort() @@ -252,7 +263,7 @@ def get(urn: str, to_file: str) -> None: if graph.exists(urn): dataproduct: DataProduct = DataProduct.from_datahub(graph=graph, id=urn) click.secho( - f"{json.dumps(dataproduct.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(dataproduct.dict(exclude_unset=True, exclude_none=True), indent=2)}", ) if to_file: dataproduct.to_yaml(Path(to_file)) @@ -339,8 +350,10 @@ def add_owner(urn: str, owner: str, owner_type: str) -> None: owner_type, owner_type_urn = validate_ownership_type(owner_type) dataproduct_patcher.add_owner( owner=OwnerClass( - owner=_get_owner_urn(owner), type=owner_type, typeUrn=owner_type_urn - ) + owner=_get_owner_urn(owner), + type=owner_type, + typeUrn=owner_type_urn, + ), ) with get_default_graph() as graph: _abort_if_non_existent_urn(graph, urn, "add owners") @@ -371,7 +384,10 @@ def remove_owner(urn: str, owner_urn: str) -> None: @click.option("--urn", required=True, type=str) @click.option("--asset", required=True, type=str) @click.option( - "--validate-assets/--no-validate-assets", required=False, is_flag=True, default=True + "--validate-assets/--no-validate-assets", + required=False, + is_flag=True, + default=True, ) @upgrade.check_upgrade @telemetry.with_telemetry() @@ -398,7 +414,10 @@ def add_asset(urn: str, asset: str, validate_assets: bool) -> None: @click.option("--urn", required=True, type=str) @click.option("--asset", required=True, type=str) @click.option( - "--validate-assets/--no-validate-assets", required=False, is_flag=True, default=True + "--validate-assets/--no-validate-assets", + required=False, + is_flag=True, + default=True, ) @upgrade.check_upgrade @telemetry.with_telemetry() diff --git a/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py b/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py index 5601d7e716c797..763187f527b960 100644 --- a/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/dataset_cli.py @@ -61,7 +61,7 @@ def get(urn: str, to_file: str) -> None: if graph.exists(urn): dataset: Dataset = Dataset.from_datahub(graph=graph, urn=urn) click.secho( - f"{json.dumps(dataset.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(dataset.dict(exclude_unset=True, exclude_none=True), indent=2)}", ) if to_file: dataset.to_yaml(Path(to_file)) @@ -91,7 +91,10 @@ def add_sibling(urn: str, 
sibling_urns: Tuple[str]) -> None: def _emit_sibling( - graph: DataHubGraph, primary_urn: str, urn: str, all_urns: Set[str] + graph: DataHubGraph, + primary_urn: str, + urn: str, + all_urns: Set[str], ) -> None: siblings = _get_existing_siblings(graph, urn) for sibling_urn in all_urns: @@ -101,7 +104,7 @@ def _emit_sibling( MetadataChangeProposalWrapper( entityUrn=urn, aspect=Siblings(primary=primary_urn == urn, siblings=sorted(siblings)), - ) + ), ) diff --git a/metadata-ingestion/src/datahub/cli/specific/forms_cli.py b/metadata-ingestion/src/datahub/cli/specific/forms_cli.py index a494396909b32d..69a952e01aaafb 100644 --- a/metadata-ingestion/src/datahub/cli/specific/forms_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/forms_cli.py @@ -44,7 +44,7 @@ def get(urn: str, to_file: str) -> None: if graph.exists(urn): form: Forms = Forms.from_datahub(graph=graph, urn=urn) click.secho( - f"{json.dumps(form.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(form.dict(exclude_unset=True, exclude_none=True), indent=2)}", ) if to_file: form.to_yaml(Path(to_file)) diff --git a/metadata-ingestion/src/datahub/cli/specific/group_cli.py b/metadata-ingestion/src/datahub/cli/specific/group_cli.py index e313fce33d4d57..9be2d2e6b1492d 100644 --- a/metadata-ingestion/src/datahub/cli/specific/group_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/group_cli.py @@ -46,12 +46,14 @@ def upsert(file: Path, override_editable: bool) -> None: datahub_group = CorpGroup.parse_obj(group_config) for mcp in datahub_group.generate_mcp( generation_config=CorpGroupGenerationConfig( - override_editable=override_editable, datahub_graph=emitter - ) + override_editable=override_editable, + datahub_graph=emitter, + ), ): emitter.emit(mcp) click.secho( - f"Update succeeded for group {datahub_group.urn}.", fg="green" + f"Update succeeded for group {datahub_group.urn}.", + fg="green", ) except Exception as e: click.secho( diff --git a/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py b/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py index 5cd28516a076d9..0e6fe8ef705350 100644 --- a/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py @@ -54,12 +54,13 @@ def get(urn: str, to_file: str) -> None: StructuredProperties.from_datahub(graph=graph, urn=urn) ) click.secho( - f"{json.dumps(structuredproperties.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(structuredproperties.dict(exclude_unset=True, exclude_none=True), indent=2)}", ) if to_file: structuredproperties.to_yaml(Path(to_file)) click.secho( - f"Structured property yaml written to {to_file}", fg="green" + f"Structured property yaml written to {to_file}", + fg="green", ) else: click.secho(f"Structured property {urn} does not exist") @@ -120,7 +121,7 @@ def to_yaml_list( with get_default_graph() as graph: if details: logger.info( - "Listing structured properties with details. Use --no-details for urns only" + "Listing structured properties with details. 
Use --no-details for urns only", ) structuredproperties = StructuredProperties.list(graph) if to_file: @@ -128,11 +129,11 @@ def to_yaml_list( else: for structuredproperty in structuredproperties: click.secho( - f"{json.dumps(structuredproperty.dict(exclude_unset=True, exclude_none=True), indent=2)}" + f"{json.dumps(structuredproperty.dict(exclude_unset=True, exclude_none=True), indent=2)}", ) else: logger.info( - "Listing structured property urns only, use --details for more information" + "Listing structured property urns only, use --details for more information", ) structured_property_urns = StructuredProperties.list_urns(graph) if to_file: @@ -140,7 +141,8 @@ def to_yaml_list( for urn in structured_property_urns: f.write(f"{urn}\n") click.secho( - f"Structured property urns written to {to_file}", fg="green" + f"Structured property urns written to {to_file}", + fg="green", ) else: for urn in structured_property_urns: diff --git a/metadata-ingestion/src/datahub/cli/specific/user_cli.py b/metadata-ingestion/src/datahub/cli/specific/user_cli.py index 740e870d0f49b2..7a74d816927ead 100644 --- a/metadata-ingestion/src/datahub/cli/specific/user_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/user_cli.py @@ -46,9 +46,9 @@ def upsert(file: Path, override_editable: bool) -> None: emitter.emit_all( datahub_user.generate_mcp( generation_config=CorpUserGenerationConfig( - override_editable=override_editable - ) - ) + override_editable=override_editable, + ), + ), ) click.secho(f"Update succeeded for urn {datahub_user.urn}.", fg="green") except Exception as e: diff --git a/metadata-ingestion/src/datahub/cli/timeline_cli.py b/metadata-ingestion/src/datahub/cli/timeline_cli.py index 174ce63e84ef4c..17d8d8fc3a01e7 100644 --- a/metadata-ingestion/src/datahub/cli/timeline_cli.py +++ b/metadata-ingestion/src/datahub/cli/timeline_cli.py @@ -37,7 +37,7 @@ def pretty_id(id: Optional[str]) -> str: assert id is not None if id.startswith("urn:li:datasetField:") or id.startswith("urn:li:schemaField:"): schema_field_key = schema_field_urn_to_key( - id.replace("urn:li:datasetField", "urn:li:schemaField") + id.replace("urn:li:datasetField", "urn:li:schemaField"), ) if schema_field_key: assert schema_field_key is not None @@ -73,7 +73,7 @@ def get_timeline( encoded_urn = Urn.url_encode(urn) else: raise Exception( - f"urn {urn} does not seem to be a valid raw (starts with urn:) or encoded urn (starts with urn%3A)" + f"urn {urn} does not seem to be a valid raw (starts with urn:) or encoded urn (starts with urn%3A)", ) categories: str = ",".join([c.upper() for c in category]) start_time_param: str = f"&startTime={start_time}" if start_time else "" @@ -124,7 +124,11 @@ def get_timeline( help="The end time for the timeline query in milliseconds. Shorthand form like 7days is also supported. e.g. 
--end 7daysago implies a timestamp 7 days ago", ) @click.option( - "--verbose", "-v", type=bool, is_flag=True, help="Show the underlying http response" + "--verbose", + "-v", + type=bool, + is_flag=True, + help="Show the underlying http response", ) @click.option("--raw", type=bool, is_flag=True, help="Show the raw diff") @click.pass_context @@ -151,7 +155,7 @@ def timeline( for c in category: if c.upper() not in all_categories: raise click.UsageError( - f"category: {c.upper()} is not one of {all_categories}" + f"category: {c.upper()} is not one of {all_categories}", ) if urn is None: @@ -187,7 +191,7 @@ def timeline( if isinstance(timeline, list) and not verbose: for change_txn in timeline: change_instant = str( - datetime.fromtimestamp(change_txn["timestamp"] // 1000) + datetime.fromtimestamp(change_txn["timestamp"] // 1000), ) change_color = ( "green" @@ -196,7 +200,7 @@ def timeline( ) click.echo( - f"{click.style(change_instant, fg='cyan')} - {click.style(change_txn['semVer'], fg=change_color)}" + f"{click.style(change_instant, fg='cyan')} - {click.style(change_txn['semVer'], fg=change_color)}", ) if change_txn["changeEvents"] is not None: for change_event in change_txn["changeEvents"]: @@ -213,10 +217,10 @@ def timeline( target_string = pretty_id( change_event.get("target") or change_event.get("entityUrn") - or "" + or "", ) click.echo( - f"\t{click.style(change_event.get('changeType') or change_event.get('operation'), fg=event_change_color)} {change_event.get('category')} {target_string} {element_string}: {change_event['description']}" + f"\t{click.style(change_event.get('changeType') or change_event.get('operation'), fg=event_change_color)} {change_event.get('category')} {target_string} {element_string}: {change_event['description']}", ) else: click.echo( @@ -224,5 +228,5 @@ def timeline( timeline, sort_keys=True, indent=2, - ) + ), ) diff --git a/metadata-ingestion/src/datahub/configuration/_config_enum.py b/metadata-ingestion/src/datahub/configuration/_config_enum.py index 190a006b077d9f..37f5dc0a964cad 100644 --- a/metadata-ingestion/src/datahub/configuration/_config_enum.py +++ b/metadata-ingestion/src/datahub/configuration/_config_enum.py @@ -11,7 +11,10 @@ class ConfigEnum(Enum): # Ideally we would use @staticmethod here, but some versions of Python don't support it. # See https://github.com/python/mypy/issues/7591. def _generate_next_value_( # type: ignore - name: str, start, count, last_values + name: str, + start, + count, + last_values, ) -> str: # This makes the enum value match the enum option name. # From https://stackoverflow.com/a/44785241/5004662. 
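Note on the recurring pattern in the hunks above: wherever a call or collection literal spans multiple lines, the patch adds a trailing comma after the final argument and, where needed, breaks the arguments onto one line each. The sketch below is illustrative only — the function names are invented and do not appear in the patch, and the specific lint rule driving the change (e.g. ruff's COM812 / flake8-commas family) is an assumption, not something stated in this part of the diff.

import logging

logger = logging.getLogger(__name__)


def log_status_before(name: str, code: int) -> None:
    # Pre-cleanup shape: last argument on the wrapped line, no trailing comma.
    logger.info(
        "finished %s with status %s", name, code
    )


def log_status_after(name: str, code: int) -> None:
    # Post-cleanup shape: one argument per line, each with a trailing comma,
    # matching the style applied throughout the hunks above. Future edits to
    # the argument list then touch only the lines they add or remove.
    logger.info(
        "finished %s with status %s",
        name,
        code,
    )
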
@@ -26,7 +29,8 @@ def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore from pydantic_core import core_schema return core_schema.no_info_before_validator_function( - cls.validate, handler(source_type) + cls.validate, + handler(source_type), ) else: diff --git a/metadata-ingestion/src/datahub/configuration/common.py b/metadata-ingestion/src/datahub/configuration/common.py index 8052de1b0669c4..39d52386ed689e 100644 --- a/metadata-ingestion/src/datahub/configuration/common.py +++ b/metadata-ingestion/src/datahub/configuration/common.py @@ -284,7 +284,7 @@ def get_allowed_list(self) -> List[str]: """Return the list of allowed strings as a list, after taking into account deny patterns, if possible""" if not self.is_fully_specified_allow_list(): raise ValueError( - "allow list must be fully specified to get list of allowed strings" + "allow list must be fully specified to get list of allowed strings", ) return [a for a in self.allow if not self.denied(a)] @@ -316,7 +316,7 @@ def value(self, string: str) -> List[str]: return self.rules[matching_keys[0]] else: return deduplicate_list( - [v for key in matching_keys for v in self.rules[key]] + [v for key in matching_keys for v in self.rules[key]], ) diff --git a/metadata-ingestion/src/datahub/configuration/config_loader.py b/metadata-ingestion/src/datahub/configuration/config_loader.py index 16105f69d584de..2f90aa01d593cf 100644 --- a/metadata-ingestion/src/datahub/configuration/config_loader.py +++ b/metadata-ingestion/src/datahub/configuration/config_loader.py @@ -154,7 +154,7 @@ def load_config_file( config_mech = TomlConfigurationMechanism() else: raise ConfigurationError( - f"Only .toml, .yml, and .json are supported. Cannot process file type {config_file_path.suffix}" + f"Only .toml, .yml, and .json are supported. Cannot process file type {config_file_path.suffix}", ) url_parsed = parse.urlparse(str(config_file)) @@ -168,12 +168,12 @@ def load_config_file( raw_config_file = response.text except Exception as e: raise ConfigurationError( - f"Cannot read remote file {config_file_path}: {e}" + f"Cannot read remote file {config_file_path}: {e}", ) from e else: if not config_file_path.is_file(): raise ConfigurationError( - f"Cannot open config file {config_file_path.resolve()}" + f"Cannot open config file {config_file_path.resolve()}", ) raw_config_file = config_file_path.read_text() diff --git a/metadata-ingestion/src/datahub/configuration/connection_resolver.py b/metadata-ingestion/src/datahub/configuration/connection_resolver.py index a82698cd38cd71..fa9dfde037bdec 100644 --- a/metadata-ingestion/src/datahub/configuration/connection_resolver.py +++ b/metadata-ingestion/src/datahub/configuration/connection_resolver.py @@ -15,13 +15,13 @@ def _resolve_connection(cls: Type, values: dict) -> dict: graph = get_graph_context() if not graph: raise ValueError( - "Fetching connection details from the backend requires a DataHub graph client." + "Fetching connection details from the backend requires a DataHub graph client.", ) conn = graph.get_connection_json(connection_urn) if conn is None: raise ValueError( - f"Connection {connection_urn} not found using {graph}." + f"Connection {connection_urn} not found using {graph}.", ) # TODO: Should this do some additional validation against the config model? 
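The same treatment is applied to function signatures throughout the patch (for example valid_restore_options in docker_cli.py above and ClickDatetime.convert in datetimes.py below): parameter lists that wrap are exploded to one parameter per line with a trailing comma. A minimal sketch of that reflow follows; the function names here are hypothetical and only illustrate the before/after shape.

from typing import Optional


def describe_window_before(
    start_time: Optional[int], end_time: Optional[int], raw: bool = False
) -> str:
    # Pre-cleanup shape: parameters packed onto one wrapped line.
    return f"{start_time}-{end_time}-{raw}"


def describe_window_after(
    start_time: Optional[int],
    end_time: Optional[int],
    raw: bool = False,
) -> str:
    # Post-cleanup shape: one parameter per line, each ending with a comma,
    # mirroring the signature changes made in these hunks.
    return f"{start_time}-{end_time}-{raw}"
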
diff --git a/metadata-ingestion/src/datahub/configuration/datetimes.py b/metadata-ingestion/src/datahub/configuration/datetimes.py index 1520462fa9bf8c..d47bad590a70b7 100644 --- a/metadata-ingestion/src/datahub/configuration/datetimes.py +++ b/metadata-ingestion/src/datahub/configuration/datetimes.py @@ -89,7 +89,10 @@ class ClickDatetime(click.ParamType): name = "datetime" def convert( - self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + self, + value: Any, + param: Optional[click.Parameter], + ctx: Optional[click.Context], ) -> datetime: if isinstance(value, datetime): return value diff --git a/metadata-ingestion/src/datahub/configuration/git.py b/metadata-ingestion/src/datahub/configuration/git.py index 7e68e9f80da4ff..3463e595c63fbb 100644 --- a/metadata-ingestion/src/datahub/configuration/git.py +++ b/metadata-ingestion/src/datahub/configuration/git.py @@ -18,7 +18,7 @@ class GitReference(ConfigModel): """Reference to a hosted Git repository. Used to generate "view source" links.""" repo: str = Field( - description="Name of your Git repo e.g. https://github.com/datahub-project/datahub or https://gitlab.com/gitlab-org/gitlab. If organization/repo is provided, we assume it is a GitHub repo." + description="Name of your Git repo e.g. https://github.com/datahub-project/datahub or https://gitlab.com/gitlab-org/gitlab. If organization/repo is provided, we assume it is a GitHub repo.", ) branch: str = Field( "main", @@ -67,7 +67,7 @@ def infer_url_template(cls, url_template: Optional[str], values: dict) -> str: return _GITLAB_URL_TEMPLATE else: raise ValueError( - "Unable to infer URL template from repo. Please set url_template manually." + "Unable to infer URL template from repo. Please set url_template manually.", ) def get_url_for_file_path(self, file_path: str) -> str: @@ -75,7 +75,9 @@ def get_url_for_file_path(self, file_path: str) -> str: if self.url_subdir: file_path = f"{self.url_subdir}/{file_path}" return self.url_template.format( - repo_url=self.repo, branch=self.branch, file_path=file_path + repo_url=self.repo, + branch=self.branch, + file_path=file_path, ) @@ -102,7 +104,9 @@ class GitInfo(GitReference): @validator("deploy_key", pre=True, always=True) def deploy_key_filled_from_deploy_key_file( - cls, v: Optional[SecretStr], values: Dict[str, Any] + cls, + v: Optional[SecretStr], + values: Dict[str, Any], ) -> Optional[SecretStr]: if v is None: deploy_key_file = values.get("deploy_key_file") @@ -114,7 +118,9 @@ def deploy_key_filled_from_deploy_key_file( @validator("repo_ssh_locator", always=True) def infer_repo_ssh_locator( - cls, repo_ssh_locator: Optional[str], values: dict + cls, + repo_ssh_locator: Optional[str], + values: dict, ) -> str: if repo_ssh_locator is not None: return repo_ssh_locator @@ -126,7 +132,7 @@ def infer_repo_ssh_locator( return f"git@gitlab.com:{repo[len(_GITLAB_PREFIX) :]}.git" else: raise ValueError( - "Unable to infer repo_ssh_locator from repo. Please set repo_ssh_locator manually." + "Unable to infer repo_ssh_locator from repo. 
Please set repo_ssh_locator manually.", ) @property diff --git a/metadata-ingestion/src/datahub/configuration/kafka_consumer_config.py b/metadata-ingestion/src/datahub/configuration/kafka_consumer_config.py index f08c78cadc0b2b..6f8c3bd51941ec 100644 --- a/metadata-ingestion/src/datahub/configuration/kafka_consumer_config.py +++ b/metadata-ingestion/src/datahub/configuration/kafka_consumer_config.py @@ -55,7 +55,7 @@ def _validate_call_back_fn_signature(self, call_back_fn: Any) -> None: inspect.Parameter.POSITIONAL_OR_KEYWORD, ) and param.default == inspect.Parameter.empty - ] + ], ) has_variadic_args = any( diff --git a/metadata-ingestion/src/datahub/configuration/time_window_config.py b/metadata-ingestion/src/datahub/configuration/time_window_config.py index 5fabcf904d3219..88767e0492588f 100644 --- a/metadata-ingestion/src/datahub/configuration/time_window_config.py +++ b/metadata-ingestion/src/datahub/configuration/time_window_config.py @@ -54,7 +54,10 @@ class BaseTimeWindowConfig(ConfigModel): @pydantic.validator("start_time", pre=True, always=True) def default_start_time( - cls, v: Any, values: Dict[str, Any], **kwargs: Any + cls, + v: Any, + values: Dict[str, Any], + **kwargs: Any, ) -> datetime: if v is None: return get_time_bucket( @@ -70,7 +73,7 @@ def default_start_time( "Relative start time should start with minus sign (-) e.g. '-2 days'." ) assert abs(delta) >= get_bucket_duration_delta( - values["bucket_duration"] + values["bucket_duration"], ), ( "Relative start time should be in terms of configured bucket duration. e.g '-2 days' or '-2 hours'." ) @@ -81,7 +84,8 @@ def default_start_time( values["end_time"] = datetime.now(tz=timezone.utc) return get_time_bucket( - values["end_time"] + delta, values["bucket_duration"] + values["end_time"] + delta, + values["bucket_duration"], ) except humanfriendly.InvalidTimespan: # We do not floor start_time to the bucket start time if absolute start time is specified. @@ -94,7 +98,7 @@ def default_start_time( def ensure_timestamps_in_utc(cls, v: datetime) -> datetime: if v.tzinfo is None: raise ValueError( - "Timestamps must be in UTC. Try adding a 'Z' to the value e.g. '2021-07-20T00:00:00Z'" + "Timestamps must be in UTC. Try adding a 'Z' to the value e.g. '2021-07-20T00:00:00Z'", ) # If the timestamp is timezone-aware but not in UTC, convert it to UTC. diff --git a/metadata-ingestion/src/datahub/configuration/validate_field_rename.py b/metadata-ingestion/src/datahub/configuration/validate_field_rename.py index de2a16e9bf247d..9486a2b61cd94f 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_field_rename.py +++ b/metadata-ingestion/src/datahub/configuration/validate_field_rename.py @@ -23,7 +23,7 @@ def _validate_field_rename(cls: Type, values: dict) -> dict: if old_name in values: if new_name in values: raise ValueError( - f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}." + f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}.", ) else: if print_warning: @@ -50,5 +50,5 @@ def _validate_field_rename(cls: Type, values: dict) -> dict: # Given that a renamed field doesn't show up in the fields list, we can't use # the field-level validator, even with a different field name. 
return pydantic.root_validator(pre=True, skip_on_failure=True, allow_reuse=True)( - _validate_field_rename + _validate_field_rename, ) diff --git a/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py b/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py index 0baaf4f0264b99..6f0dafc603ea10 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py +++ b/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py @@ -15,7 +15,8 @@ def pydantic_multiline_string(field: str) -> classmethod: """ def _validate_field( - cls: Type, v: Union[None, str, pydantic.SecretStr] + cls: Type, + v: Union[None, str, pydantic.SecretStr], ) -> Optional[str]: if v is not None: if isinstance(v, pydantic.SecretStr): diff --git a/metadata-ingestion/src/datahub/emitter/kafka_emitter.py b/metadata-ingestion/src/datahub/emitter/kafka_emitter.py index 781930011b78fb..ca10e09f8e523f 100644 --- a/metadata-ingestion/src/datahub/emitter/kafka_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/kafka_emitter.py @@ -33,7 +33,7 @@ class KafkaEmitterConfig(ConfigModel): connection: KafkaProducerConnectionConfig = pydantic.Field( - default_factory=KafkaProducerConnectionConfig + default_factory=KafkaProducerConnectionConfig, ) topic_routes: Dict[str, str] = { MCE_KEY: DEFAULT_MCE_KAFKA_TOPIC, @@ -66,7 +66,8 @@ def __init__(self, config: KafkaEmitterConfig): schema_registry_client = SchemaRegistryClient(schema_registry_conf) def convert_mce_to_dict( - mce: MetadataChangeEvent, ctx: SerializationContext + mce: MetadataChangeEvent, + ctx: SerializationContext, ) -> dict: return mce.to_obj(tuples=True) diff --git a/metadata-ingestion/src/datahub/emitter/mce_builder.py b/metadata-ingestion/src/datahub/emitter/mce_builder.py index f5da90a86c9ef6..7c13df3b6647bb 100644 --- a/metadata-ingestion/src/datahub/emitter/mce_builder.py +++ b/metadata-ingestion/src/datahub/emitter/mce_builder.py @@ -124,7 +124,10 @@ def make_data_platform_urn(platform: str) -> str: def make_dataset_urn(platform: str, name: str, env: str = DEFAULT_ENV) -> str: return make_dataset_urn_with_platform_instance( - platform=platform, name=name, platform_instance=None, env=env + platform=platform, + name=name, + platform_instance=None, + env=env, ) @@ -136,7 +139,10 @@ def make_dataplatform_instance_urn(platform: str, instance: str) -> str: def make_dataset_urn_with_platform_instance( - platform: str, name: str, platform_instance: Optional[str], env: str = DEFAULT_ENV + platform: str, + name: str, + platform_instance: Optional[str], + env: str = DEFAULT_ENV, ) -> str: if DATASET_URN_TO_LOWER: name = name.lower() @@ -146,7 +152,7 @@ def make_dataset_urn_with_platform_instance( table_name=name, env=env, platform_instance=platform_instance, - ) + ), ) @@ -296,7 +302,7 @@ def make_data_flow_urn( flow_id=flow_id, env=cluster, platform_instance=platform_instance, - ) + ), ) @@ -316,12 +322,15 @@ def make_data_job_urn( platform_instance: Optional[str] = None, ) -> str: return make_data_job_urn_with_flow( - make_data_flow_urn(orchestrator, flow_id, cluster, platform_instance), job_id + make_data_flow_urn(orchestrator, flow_id, cluster, platform_instance), + job_id, ) def make_dashboard_urn( - platform: str, name: str, platform_instance: Optional[str] = None + platform: str, + name: str, + platform_instance: Optional[str] = None, ) -> str: # FIXME: dashboards don't currently include data platform urn prefixes. 
if platform_instance: @@ -339,7 +348,9 @@ def dashboard_urn_to_key(dashboard_urn: str) -> Optional[DashboardKeyClass]: def make_chart_urn( - platform: str, name: str, platform_instance: Optional[str] = None + platform: str, + name: str, + platform_instance: Optional[str] = None, ) -> str: # FIXME: charts don't currently include data platform urn prefixes. if platform_instance: @@ -420,16 +431,17 @@ def make_lineage_mce( type=lineage_type, ) for upstream_urn in upstream_urns - ] - ) + ], + ), ], - ) + ), ) return mce def can_add_aspect_to_snapshot( - SnapshotType: Type[DictWrapper], AspectType: Type[Aspect] + SnapshotType: Type[DictWrapper], + AspectType: Type[Aspect], ) -> bool: constructor_annotations = get_type_hints(SnapshotType.__init__) aspect_list_union = typing_inspect.get_args(constructor_annotations["aspects"])[0] @@ -446,16 +458,18 @@ def can_add_aspect(mce: MetadataChangeEventClass, AspectType: Type[Aspect]) -> b def assert_can_add_aspect( - mce: MetadataChangeEventClass, AspectType: Type[Aspect] + mce: MetadataChangeEventClass, + AspectType: Type[Aspect], ) -> None: if not can_add_aspect(mce, AspectType): raise AssertionError( - f"Cannot add aspect {AspectType} to {type(mce.proposedSnapshot)}" + f"Cannot add aspect {AspectType} to {type(mce.proposedSnapshot)}", ) def get_aspect_if_available( - mce: MetadataChangeEventClass, AspectType: Type[Aspect] + mce: MetadataChangeEventClass, + AspectType: Type[Aspect], ) -> Optional[Aspect]: assert_can_add_aspect(mce, AspectType) @@ -466,7 +480,7 @@ def get_aspect_if_available( if len(aspects) > 1: raise ValueError( - f"MCE contains multiple aspects of type {AspectType}: {aspects}" + f"MCE contains multiple aspects of type {AspectType}: {aspects}", ) if aspects: return aspects[0] @@ -474,7 +488,8 @@ def get_aspect_if_available( def remove_aspect_if_available( - mce: MetadataChangeEventClass, aspect_type: Type[Aspect] + mce: MetadataChangeEventClass, + aspect_type: Type[Aspect], ) -> bool: assert_can_add_aspect(mce, aspect_type) # loose type annotations since we checked before @@ -498,7 +513,7 @@ def get_or_add_aspect(mce: MetadataChangeEventClass, default: Aspect) -> Aspect: def make_global_tag_aspect_with_tag_list(tags: List[str]) -> GlobalTagsClass: return GlobalTagsClass( - tags=[TagAssociationClass(make_tag_urn(tag)) for tag in tags] + tags=[TagAssociationClass(make_tag_urn(tag)) for tag in tags], ) @@ -509,7 +524,7 @@ def make_ownership_aspect_from_urn_list( ) -> OwnershipClass: for owner_urn in owner_urns: assert owner_urn.startswith("urn:li:corpuser:") or owner_urn.startswith( - "urn:li:corpGroup:" + "urn:li:corpGroup:", ) ownership_source_type: Union[None, OwnershipSourceClass] = None if source_type: @@ -542,7 +557,9 @@ def make_glossary_terms_aspect_from_urn_list(term_urns: List[str]) -> GlossaryTe def set_aspect( - mce: MetadataChangeEventClass, aspect: Optional[Aspect], aspect_type: Type[Aspect] + mce: MetadataChangeEventClass, + aspect: Optional[Aspect], + aspect_type: Type[Aspect], ) -> None: """Sets the aspect to the provided aspect, overwriting any previous aspect value that might have existed before. 
If passed in aspect is None, then the existing aspect value will be removed""" diff --git a/metadata-ingestion/src/datahub/emitter/mcp.py b/metadata-ingestion/src/datahub/emitter/mcp.py index c6fcfad2e0abaa..56869d37a9bfc1 100644 --- a/metadata-ingestion/src/datahub/emitter/mcp.py +++ b/metadata-ingestion/src/datahub/emitter/mcp.py @@ -95,12 +95,14 @@ def __post_init__(self) -> None: and self.aspectName != self.aspect.get_aspect_name() ): raise ValueError( - f"aspectName {self.aspectName} does not match aspect type {type(self.aspect)} with name {self.aspect.get_aspect_name()}" + f"aspectName {self.aspectName} does not match aspect type {type(self.aspect)} with name {self.aspect.get_aspect_name()}", ) @classmethod def construct_many( - cls, entityUrn: str, aspects: Sequence[Optional[_Aspect]] + cls, + entityUrn: str, + aspects: Sequence[Optional[_Aspect]], ) -> List["MetadataChangeProposalWrapper"]: return [cls(entityUrn=entityUrn, aspect=aspect) for aspect in aspects if aspect] @@ -161,7 +163,9 @@ def to_obj(self, tuples: bool = False, simplified_structure: bool = False) -> di @classmethod def from_obj( - cls, obj: dict, tuples: bool = False + cls, + obj: dict, + tuples: bool = False, ) -> Union["MetadataChangeProposalWrapper", MetadataChangeProposalClass]: """ Attempt to deserialize into an MCPW, but fall back @@ -188,7 +192,8 @@ def from_obj( @classmethod def try_from_mcpc( - cls, mcpc: MetadataChangeProposalClass + cls, + mcpc: MetadataChangeProposalClass, ) -> Optional["MetadataChangeProposalWrapper"]: """Attempts to create a MetadataChangeProposalWrapper from a MetadataChangeProposalClass. Neatly handles unsupported, expected cases, such as unknown aspect types or non-json content type. @@ -217,7 +222,8 @@ def try_from_mcpc( @classmethod def try_from_mcl( - cls, mcl: MetadataChangeLogClass + cls, + mcl: MetadataChangeLogClass, ) -> Union["MetadataChangeProposalWrapper", MetadataChangeProposalClass]: mcpc = MetadataChangeProposalClass( entityUrn=mcl.entityUrn, @@ -233,14 +239,19 @@ def try_from_mcl( @classmethod def from_obj_require_wrapper( - cls, obj: dict, tuples: bool = False + cls, + obj: dict, + tuples: bool = False, ) -> "MetadataChangeProposalWrapper": mcp = cls.from_obj(obj, tuples=tuples) assert isinstance(mcp, cls) return mcp def as_workunit( - self, *, treat_errors_as_warnings: bool = False, is_primary_source: bool = True + self, + *, + treat_errors_as_warnings: bool = False, + is_primary_source: bool = True, ) -> "MetadataWorkUnit": from datahub.ingestion.api.workunit import MetadataWorkUnit diff --git a/metadata-ingestion/src/datahub/emitter/mcp_builder.py b/metadata-ingestion/src/datahub/emitter/mcp_builder.py index 581f903d0eef0d..1b9b4f7efab74b 100644 --- a/metadata-ingestion/src/datahub/emitter/mcp_builder.py +++ b/metadata-ingestion/src/datahub/emitter/mcp_builder.py @@ -45,7 +45,8 @@ # TODO: Once the model change has been deployed for a while, we can remove this. # Probably can do it at the beginning of 2025. 
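
For orientation, the mce_builder URN helpers and MetadataChangeProposalWrapper touched above are typically combined as below. This is an illustrative sketch only, not part of the patch; the platform, table name, and description are hypothetical.

# Illustrative sketch (hypothetical names); shows the helpers reformatted above
# being used together, written in the trailing-comma style this patch adopts.
from datahub.emitter.mce_builder import make_dataset_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import DatasetPropertiesClass

dataset_urn = make_dataset_urn(
    platform="hive",
    name="example_db.example_table",
    env="PROD",
)
mcp = MetadataChangeProposalWrapper(
    entityUrn=dataset_urn,
    aspect=DatasetPropertiesClass(
        description="Example dataset",
    ),
)
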
_INCLUDE_ENV_IN_CONTAINER_PROPERTIES = get_boolean_env_variable( - "DATAHUB_INCLUDE_ENV_IN_CONTAINER_PROPERTIES", default=True + "DATAHUB_INCLUDE_ENV_IN_CONTAINER_PROPERTIES", + default=True, ) @@ -146,7 +147,9 @@ class NotebookKey(DatahubKey): def as_urn(self) -> str: return make_dataset_urn_with_platform_instance( - platform=self.platform, platform_instance=self.instance, name=self.guid() + platform=self.platform, + platform_instance=self.instance, + name=self.guid(), ) @@ -154,7 +157,8 @@ def as_urn(self) -> str: def add_domain_to_entity_wu( - entity_urn: str, domain_urn: str + entity_urn: str, + domain_urn: str, ) -> Iterable[MetadataWorkUnit]: yield MetadataChangeProposalWrapper( entityUrn=f"{entity_urn}", @@ -163,7 +167,9 @@ def add_domain_to_entity_wu( def add_owner_to_entity_wu( - entity_type: str, entity_urn: str, owner_urn: str + entity_type: str, + entity_urn: str, + owner_urn: str, ) -> Iterable[MetadataWorkUnit]: yield MetadataChangeProposalWrapper( entityUrn=f"{entity_urn}", @@ -172,26 +178,29 @@ def add_owner_to_entity_wu( OwnerClass( owner=owner_urn, type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ), ).as_workunit() def add_tags_to_entity_wu( - entity_type: str, entity_urn: str, tags: List[str] + entity_type: str, + entity_urn: str, + tags: List[str], ) -> Iterable[MetadataWorkUnit]: yield MetadataChangeProposalWrapper( entityType=entity_type, entityUrn=f"{entity_urn}", aspect=GlobalTagsClass( - tags=[TagAssociationClass(f"urn:li:tag:{tag}") for tag in tags] + tags=[TagAssociationClass(f"urn:li:tag:{tag}") for tag in tags], ), ).as_workunit() def add_structured_properties_to_entity_wu( - entity_urn: str, structured_properties: Dict[StructuredPropertyUrn, str] + entity_urn: str, + structured_properties: Dict[StructuredPropertyUrn, str], ) -> Iterable[MetadataWorkUnit]: aspect = StructuredPropertiesClass( properties=[ @@ -200,7 +209,7 @@ def add_structured_properties_to_entity_wu( values=[value], ) for urn, value in structured_properties.items() - ] + ], ) yield MetadataChangeProposalWrapper( entityUrn=entity_urn, @@ -306,12 +315,14 @@ def gen_containers( if structured_properties: yield from add_structured_properties_to_entity_wu( - entity_urn=container_urn, structured_properties=structured_properties + entity_urn=container_urn, + structured_properties=structured_properties, ) def add_dataset_to_container( - container_key: KeyType, dataset_urn: str + container_key: KeyType, + dataset_urn: str, ) -> Iterable[MetadataWorkUnit]: container_urn = make_container_urn( guid=container_key.guid(), @@ -324,7 +335,9 @@ def add_dataset_to_container( def add_entity_to_container( - container_key: KeyType, entity_type: str, entity_urn: str + container_key: KeyType, + entity_type: str, + entity_urn: str, ) -> Iterable[MetadataWorkUnit]: container_urn = make_container_urn( guid=container_key.guid(), diff --git a/metadata-ingestion/src/datahub/emitter/mcp_patch_builder.py b/metadata-ingestion/src/datahub/emitter/mcp_patch_builder.py index e51c37d96e90f0..edc2a3b34b5368 100644 --- a/metadata-ingestion/src/datahub/emitter/mcp_patch_builder.py +++ b/metadata-ingestion/src/datahub/emitter/mcp_patch_builder.py @@ -107,7 +107,7 @@ def build(self) -> List[MetadataChangeProposalClass]: aspectName=aspect_name, aspect=GenericAspectClass( value=json.dumps( - pre_json_transform(_recursive_to_obj(patches)) + pre_json_transform(_recursive_to_obj(patches)), ).encode(), contentType=JSON_PATCH_CONTENT_TYPE, ), @@ -136,7 +136,10 @@ def _mint_auditstamp(cls, message: Optional[str] = None) -> AuditStampClass: 
@classmethod def _ensure_urn_type( - cls, entity_type: str, edges: List[EdgeClass], context: str + cls, + entity_type: str, + edges: List[EdgeClass], + context: str, ) -> None: """ Ensures that the destination URNs in the given edges have the specified entity type. @@ -153,5 +156,5 @@ def _ensure_urn_type( urn = Urn.from_string(e.destinationUrn) if not urn.entity_type == entity_type: raise ValueError( - f"{context}: {e.destinationUrn} is not of type {entity_type}" + f"{context}: {e.destinationUrn} is not of type {entity_type}", ) diff --git a/metadata-ingestion/src/datahub/emitter/request_helper.py b/metadata-ingestion/src/datahub/emitter/request_helper.py index 4e1ec026648b8d..a95b50f3233e94 100644 --- a/metadata-ingestion/src/datahub/emitter/request_helper.py +++ b/metadata-ingestion/src/datahub/emitter/request_helper.py @@ -12,7 +12,10 @@ def _format_header(name: str, value: Union[str, bytes]) -> str: def make_curl_command( - session: requests.Session, method: str, url: str, payload: str + session: requests.Session, + method: str, + url: str, + payload: str, ) -> str: fragments: List[str] = [ "curl", @@ -21,7 +24,7 @@ def make_curl_command( ("-X", method), *[("-H", _format_header(k, v)) for (k, v) in session.headers.items()], ("--data", payload), - ] + ], ), url, ] diff --git a/metadata-ingestion/src/datahub/emitter/rest_emitter.py b/metadata-ingestion/src/datahub/emitter/rest_emitter.py index 7271f784bf881e..099abd1cf82da2 100644 --- a/metadata-ingestion/src/datahub/emitter/rest_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/rest_emitter.py @@ -58,7 +58,7 @@ ] _DEFAULT_RETRY_METHODS = ["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"] _DEFAULT_RETRY_MAX_TIMES = int( - os.getenv("DATAHUB_REST_EMITTER_DEFAULT_RETRY_MAX_TIMES", "4") + os.getenv("DATAHUB_REST_EMITTER_DEFAULT_RETRY_MAX_TIMES", "4"), ) _DATAHUB_EMITTER_TRACE = get_boolean_env_variable("DATAHUB_EMITTER_TRACE", False) @@ -73,7 +73,7 @@ # too much to the backend and hitting a timeout, we try to limit # the number of MCPs we send in a batch. BATCH_INGEST_MAX_PAYLOAD_LENGTH = int( - os.getenv("DATAHUB_REST_EMITTER_BATCH_MAX_PAYLOAD_LENGTH", 200) + os.getenv("DATAHUB_REST_EMITTER_BATCH_MAX_PAYLOAD_LENGTH", 200), ) @@ -127,7 +127,9 @@ def build_session(self) -> requests.Session: ) adapter = HTTPAdapter( - pool_connections=100, pool_maxsize=100, max_retries=retry_strategy + pool_connections=100, + pool_maxsize=100, + max_retries=retry_strategy, ) session.mount("http://", adapter) session.mount("https://", adapter) @@ -205,19 +207,20 @@ def __init__( or timeout[1] < _TIMEOUT_LOWER_BOUND_SEC ): logger.warning( - f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is (connect_timeout, read_timeout) = {timeout} seconds" + f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is (connect_timeout, read_timeout) = {timeout} seconds", ) else: timeout = get_or_else(timeout_sec, _DEFAULT_TIMEOUT_SEC) if timeout < _TIMEOUT_LOWER_BOUND_SEC: logger.warning( - f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is timeout = {timeout} seconds" + f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. 
Your configuration is timeout = {timeout} seconds", ) self._session_config = RequestsSessionConfig( timeout=timeout, retry_status_codes=get_or_else( - retry_status_codes, _DEFAULT_RETRY_STATUS_CODES + retry_status_codes, + _DEFAULT_RETRY_STATUS_CODES, ), retry_methods=get_or_else(retry_methods, _DEFAULT_RETRY_METHODS), retry_max_times=get_or_else(retry_max_times, _DEFAULT_RETRY_MAX_TIMES), @@ -242,11 +245,11 @@ def test_connection(self) -> None: raise ConfigurationError( "You seem to have connected to the frontend service instead of the GMS endpoint. " "The rest emitter should connect to DataHub GMS (usually :8080) or Frontend GMS API (usually :9002/api/gms). " - "For Acryl users, the endpoint should be https://.acryl.io/gms" + "For Acryl users, the endpoint should be https://.acryl.io/gms", ) else: logger.debug( - f"Unable to connect to {url} with status_code: {response.status_code}. Response: {response.text}" + f"Unable to connect to {url} with status_code: {response.status_code}. Response: {response.text}", ) if response.status_code == 401: message = f"Unable to connect to {url} - got an authentication error: {response.text}." @@ -279,7 +282,8 @@ def emit( if isinstance(item, UsageAggregation): self.emit_usage(item) elif isinstance( - item, (MetadataChangeProposal, MetadataChangeProposalWrapper) + item, + (MetadataChangeProposal, MetadataChangeProposalWrapper), ): self.emit_mcp(item, async_flag=async_flag) else: @@ -351,7 +355,7 @@ def emit_mcps( mcp_obj_size = len(json.dumps(mcp_obj)) if _DATAHUB_EMITTER_TRACE: logger.debug( - f"Iterating through object with size {mcp_obj_size} (type: {mcp_obj.get('aspectName')}" + f"Iterating through object with size {mcp_obj_size} (type: {mcp_obj.get('aspectName')}", ) if ( @@ -366,7 +370,7 @@ def emit_mcps( current_chunk_size += mcp_obj_size if len(mcp_obj_chunks) > 0: logger.debug( - f"Decided to send {len(mcps)} MCP batch in {len(mcp_obj_chunks)} chunks" + f"Decided to send {len(mcps)} MCP batch in {len(mcp_obj_chunks)} chunks", ) for mcp_obj_chunk in mcp_obj_chunks: @@ -398,7 +402,7 @@ def _emit_generic(self, url: str, payload: str) -> None: if payload_size > INGEST_MAX_PAYLOAD_BYTES: # since we know total payload size here, we could simply avoid sending such payload at all and report a warning, with current approach we are going to cause whole ingestion to fail logger.warning( - f"Apparent payload size exceeded {INGEST_MAX_PAYLOAD_BYTES}, might fail with an exception due to the size" + f"Apparent payload size exceeded {INGEST_MAX_PAYLOAD_BYTES}, might fail with an exception due to the size", ) logger.debug( "Attempting to emit aspect (size: %s) to DataHub GMS; using curl equivalent to:\n%s", @@ -414,7 +418,8 @@ def _emit_generic(self, url: str, payload: str) -> None: if info.get("stackTrace"): logger.debug( - "Full stack trace from DataHub:\n%s", info.get("stackTrace") + "Full stack trace from DataHub:\n%s", + info.get("stackTrace"), ) info.pop("stackTrace", None) @@ -431,11 +436,13 @@ def _emit_generic(self, url: str, payload: str) -> None: except JSONDecodeError: # If we can't parse the JSON, just raise the original error. 
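
The rest_emitter internals reformatted above are normally driven through the public emitter API. A minimal sketch follows, assuming a hypothetical local GMS endpoint and dataset URN; it is illustrative only and not part of the patch.

# Illustrative sketch (hypothetical endpoint and URN): typical use of the
# REST emitter whose retry/timeout handling is reformatted above.
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.metadata.schema_classes import StatusClass

emitter = DatahubRestEmitter(
    gms_server="http://localhost:8080",
)
emitter.test_connection()
emitter.emit(
    MetadataChangeProposalWrapper(
        entityUrn="urn:li:dataset:(urn:li:dataPlatform:hive,example_db.example_table,PROD)",
        aspect=StatusClass(removed=False),
    ),
)
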
raise OperationalError( - "Unable to emit metadata to DataHub GMS", {"message": str(e)} + "Unable to emit metadata to DataHub GMS", + {"message": str(e)}, ) from e except RequestException as e: raise OperationalError( - "Unable to emit metadata to DataHub GMS", {"message": str(e)} + "Unable to emit metadata to DataHub GMS", + {"message": str(e)}, ) from e def __repr__(self) -> str: diff --git a/metadata-ingestion/src/datahub/emitter/serialization_helper.py b/metadata-ingestion/src/datahub/emitter/serialization_helper.py index ab9402ec891887..c114f2c3ae8f9e 100644 --- a/metadata-ingestion/src/datahub/emitter/serialization_helper.py +++ b/metadata-ingestion/src/datahub/emitter/serialization_helper.py @@ -17,7 +17,7 @@ def _pre_handle_union_with_aliases( # On the way out, we need to remove the field discriminator. field = obj["fieldDiscriminator"] return True, { - field: _json_transform(obj[field], from_pattern, to_pattern, pre=True) + field: _json_transform(obj[field], from_pattern, to_pattern, pre=True), } return False, None @@ -59,19 +59,23 @@ def _json_transform(obj: Any, from_pattern: str, to_pattern: str, pre: bool) -> if key.startswith(from_pattern): new_key = key.replace(from_pattern, to_pattern, 1) return { - new_key: _json_transform(value, from_pattern, to_pattern, pre=pre) + new_key: _json_transform(value, from_pattern, to_pattern, pre=pre), } if pre: handled, new_obj = _pre_handle_union_with_aliases( - obj, from_pattern, to_pattern + obj, + from_pattern, + to_pattern, ) if handled: return new_obj if not pre: handled, new_obj = _post_handle_unions_with_aliases( - obj, from_pattern, to_pattern + obj, + from_pattern, + to_pattern, ) if handled: return new_obj diff --git a/metadata-ingestion/src/datahub/emitter/sql_parsing_builder.py b/metadata-ingestion/src/datahub/emitter/sql_parsing_builder.py index a57886a0ba6999..9c65fe2bc27cbe 100644 --- a/metadata-ingestion/src/datahub/emitter/sql_parsing_builder.py +++ b/metadata-ingestion/src/datahub/emitter/sql_parsing_builder.py @@ -70,7 +70,7 @@ def gen_fine_grained_lineage_aspects(self) -> Iterable[FineGrainedLineageClass]: ), downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ - make_schema_field_urn(self.downstream_urn, downstream_col) + make_schema_field_urn(self.downstream_urn, downstream_col), ], ) @@ -87,7 +87,8 @@ class SqlParsingBuilder: # Maps downstream urn -> upstream urn -> LineageEdge # Builds up a single LineageEdge for each upstream -> downstream pair _lineage_map: FileBackedDict[Dict[DatasetUrn, LineageEdge]] = field( - default_factory=FileBackedDict, init=False + default_factory=FileBackedDict, + init=False, ) # TODO: Replace with FileBackedDict approach like in BigQuery usage @@ -213,12 +214,14 @@ def _gen_lineage_mcps(self) -> Iterable[MetadataChangeProposalWrapper]: or None, ) yield MetadataChangeProposalWrapper( - entityUrn=downstream_urn, aspect=upstream_lineage + entityUrn=downstream_urn, + aspect=upstream_lineage, ) def _gen_usage_statistics_workunits(self) -> Iterable[MetadataWorkUnit]: yield from self._usage_aggregator.generate_workunits( - resource_urn_builder=lambda urn: urn, user_urn_builder=lambda urn: urn + resource_urn_builder=lambda urn: urn, + user_urn_builder=lambda urn: urn, ) @@ -302,5 +305,6 @@ def _gen_operation_workunit( customOperationType=custom_operation_type, ) yield MetadataChangeProposalWrapper( - entityUrn=downstream_urn, aspect=aspect + entityUrn=downstream_urn, + aspect=aspect, ).as_workunit() diff --git 
a/metadata-ingestion/src/datahub/emitter/synchronized_file_emitter.py b/metadata-ingestion/src/datahub/emitter/synchronized_file_emitter.py index f82882f1a87cc3..0c3bde602fd148 100644 --- a/metadata-ingestion/src/datahub/emitter/synchronized_file_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/synchronized_file_emitter.py @@ -33,7 +33,9 @@ def __init__(self, filename: str) -> None: def emit( self, item: Union[ - MetadataChangeEvent, MetadataChangeProposal, MetadataChangeProposalWrapper + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, ], callback: Optional[Callable[[Exception, str], None]] = None, ) -> None: diff --git a/metadata-ingestion/src/datahub/entrypoints.py b/metadata-ingestion/src/datahub/entrypoints.py index 0162b476a21931..2bd5914998f4e9 100644 --- a/metadata-ingestion/src/datahub/entrypoints.py +++ b/metadata-ingestion/src/datahub/entrypoints.py @@ -135,10 +135,12 @@ def init(use_password: bool = False) -> None: click.confirm(f"{DATAHUB_CONFIG_PATH} already exists. Overwrite?", abort=True) click.echo( - "Configure which datahub instance to connect to (https://your-instance.acryl.io/gms for Acryl hosted users)" + "Configure which datahub instance to connect to (https://your-instance.acryl.io/gms for Acryl hosted users)", ) host = click.prompt( - "Enter your DataHub host", type=str, default="http://localhost:8080" + "Enter your DataHub host", + type=str, + default="http://localhost:8080", ) host = fixup_gms_url(host) if use_password: @@ -148,7 +150,9 @@ def init(use_password: bool = False) -> None: type=str, ) _, token = generate_access_token( - username=username, password=password, gms_url=host + username=username, + password=password, + gms_url=host, ) else: token = click.prompt( @@ -188,7 +192,7 @@ def init(use_password: bool = False) -> None: except ImportError as e: logger.debug(f"Failed to load datahub lite command: {e}") datahub.add_command( - make_shim_command("lite", "run `pip install 'acryl-datahub[datahub-lite]'`") + make_shim_command("lite", "run `pip install 'acryl-datahub[datahub-lite]'`"), ) try: @@ -198,7 +202,7 @@ def init(use_password: bool = False) -> None: except ImportError as e: logger.debug(f"Failed to load datahub actions framework: {e}") datahub.add_command( - make_shim_command("actions", "run `pip install acryl-datahub-actions`") + make_shim_command("actions", "run `pip install acryl-datahub-actions`"), ) @@ -221,10 +225,10 @@ def main(**kwargs): logger.exception(f"Command failed: {exc}") logger.debug( - f"DataHub CLI version: {datahub_package.__version__} at {datahub_package.__file__}" + f"DataHub CLI version: {datahub_package.__version__} at {datahub_package.__file__}", ) logger.debug( - f"Python version: {sys.version} at {sys.executable} on {platform.platform()}" + f"Python version: {sys.version} at {sys.executable} on {platform.platform()}", ) gms_config = get_gms_config() if gms_config: diff --git a/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_dataset_properties_aspect.py b/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_dataset_properties_aspect.py index fc164c84793658..7169cc6ac50b6f 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_dataset_properties_aspect.py +++ b/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_dataset_properties_aspect.py @@ -40,7 +40,7 @@ def try_aspect_from_metadata_change_proposal_class( # Deserializing `lastModified` as the `auto_patch_last_modified` function relies on this property # to decide if 
a patch aspect for the datasetProperties aspect should be generated return DatasetPropertiesClass( - lastModified=TimeStampClass(time=operation["value"]["time"]) + lastModified=TimeStampClass(time=operation["value"]["time"]), ) return None @@ -67,7 +67,7 @@ def auto_patch_last_modified( continue dataset_properties_aspect = wu.get_aspect_of_type( - DatasetPropertiesClass + DatasetPropertiesClass, ) or try_aspect_from_metadata_change_proposal_class(wu) dataset_operation_aspect = wu.get_aspect_of_type(OperationClass) @@ -120,8 +120,8 @@ def auto_patch_last_modified( dataset_patch_builder.set_last_modified( timestamp=TimeStampClass( - time=timestamp_pair.last_updated_timestamp_dataset_props - ) + time=timestamp_pair.last_updated_timestamp_dataset_props, + ), ) yield from [ diff --git a/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_ensure_aspect_size.py b/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_ensure_aspect_size.py index b63c96b617ff06..159849b0487836 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_ensure_aspect_size.py +++ b/metadata-ingestion/src/datahub/ingestion/api/auto_work_units/auto_ensure_aspect_size.py @@ -19,13 +19,17 @@ class EnsureAspectSizeProcessor: def __init__( - self, report: "SourceReport", payload_constraint: int = INGEST_MAX_PAYLOAD_BYTES + self, + report: "SourceReport", + payload_constraint: int = INGEST_MAX_PAYLOAD_BYTES, ): self.report = report self.payload_constraint = payload_constraint def ensure_dataset_profile_size( - self, dataset_urn: str, profile: DatasetProfileClass + self, + dataset_urn: str, + profile: DatasetProfileClass, ) -> None: """ This is quite arbitrary approach to ensuring dataset profile aspect does not exceed allowed size, might be adjusted @@ -41,7 +45,7 @@ def ensure_dataset_profile_size( if value: values_len += len(value) logger.debug( - f"Field {field.fieldPath} has {len(field.sampleValues)} sample values, taking total bytes {values_len}" + f"Field {field.fieldPath} has {len(field.sampleValues)} sample values, taking total bytes {values_len}", ) if sample_fields_size + values_len > self.payload_constraint: field.sampleValues = [] @@ -56,7 +60,9 @@ def ensure_dataset_profile_size( logger.debug(f"Field {field.fieldPath} has no sample values") def ensure_schema_metadata_size( - self, dataset_urn: str, schema: SchemaMetadataClass + self, + dataset_urn: str, + schema: SchemaMetadataClass, ) -> None: """ This is quite arbitrary approach to ensuring schema metadata aspect does not exceed allowed size, might be adjusted diff --git a/metadata-ingestion/src/datahub/ingestion/api/committable.py b/metadata-ingestion/src/datahub/ingestion/api/committable.py index cc7d74469f2b3a..a321f9de970620 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/committable.py +++ b/metadata-ingestion/src/datahub/ingestion/api/committable.py @@ -28,7 +28,10 @@ class StatefulCommittable( Generic[StateType], ): def __init__( - self, name: str, commit_policy: CommitPolicy, state_to_commit: StateType + self, + name: str, + commit_policy: CommitPolicy, + state_to_commit: StateType, ): super().__init__(name=name, commit_policy=commit_policy) self.committed: bool = False diff --git a/metadata-ingestion/src/datahub/ingestion/api/common.py b/metadata-ingestion/src/datahub/ingestion/api/common.py index 097859939cfea5..006e98e98c9fea 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/common.py +++ b/metadata-ingestion/src/datahub/ingestion/api/common.py @@ -70,7 +70,7 @@ def 
_set_dataset_urn_to_lower_if_needed(self) -> None: def register_checkpointer(self, committable: Committable) -> None: if committable.name in self.checkpointers: raise IndexError( - f"Checkpointing provider {committable.name} already registered." + f"Checkpointing provider {committable.name} already registered.", ) self.checkpointers[committable.name] = committable @@ -81,6 +81,6 @@ def require_graph(self, operation: Optional[str] = None) -> DataHubGraph: if not self.graph: raise ConfigurationError( f"{operation or 'This operation'} requires a graph, but none was provided. " - "To provide one, either use the datahub-rest sink or set the top-level datahub_api config in the recipe." + "To provide one, either use the datahub-rest sink or set the top-level datahub_api config in the recipe.", ) return self.graph diff --git a/metadata-ingestion/src/datahub/ingestion/api/decorators.py b/metadata-ingestion/src/datahub/ingestion/api/decorators.py index d32c0b85ceef4c..3eab4f1758a2dc 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/decorators.py +++ b/metadata-ingestion/src/datahub/ingestion/api/decorators.py @@ -31,7 +31,9 @@ def wrapper(cls: Type) -> Type: def platform_name( - platform_name: str, id: Optional[str] = None, doc_order: Optional[int] = None + platform_name: str, + id: Optional[str] = None, + doc_order: Optional[int] = None, ) -> Callable[[Type], Type]: """Adds a get_platform_name method to the decorated class""" @@ -44,7 +46,7 @@ def wrapper(cls: Type) -> Type: if id and " " in id: raise Exception( - f'Platform id "{id}" contains white-space, please use a platform id without spaces.' + f'Platform id "{id}" contains white-space, please use a platform id without spaces.', ) return wrapper @@ -89,7 +91,9 @@ class CapabilitySetting: def capability( - capability_name: SourceCapability, description: str, supported: bool = True + capability_name: SourceCapability, + description: str, + supported: bool = True, ) -> Callable[[Type], Type]: """ A decorator to mark a source as having a certain capability @@ -111,7 +115,9 @@ def wrapper(cls: Type) -> Type: cls.__capabilities.update(base_caps) cls.__capabilities[capability_name] = CapabilitySetting( - capability=capability_name, description=description, supported=supported + capability=capability_name, + description=description, + supported=supported, ) return cls diff --git a/metadata-ingestion/src/datahub/ingestion/api/incremental_lineage_helper.py b/metadata-ingestion/src/datahub/ingestion/api/incremental_lineage_helper.py index 92ee158661d3d4..46d636d437dac9 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/incremental_lineage_helper.py +++ b/metadata-ingestion/src/datahub/ingestion/api/incremental_lineage_helper.py @@ -37,14 +37,17 @@ def convert_upstream_lineage_to_patch( def convert_chart_info_to_patch( - urn: str, aspect: ChartInfoClass, system_metadata: Optional[SystemMetadataClass] + urn: str, + aspect: ChartInfoClass, + system_metadata: Optional[SystemMetadataClass], ) -> Optional[MetadataWorkUnit]: patch_builder = ChartPatchBuilder(urn, system_metadata) if aspect.customProperties: for key in aspect.customProperties: patch_builder.add_custom_property( - key, str(aspect.customProperties.get(key)) + key, + str(aspect.customProperties.get(key)), ) if aspect.inputEdges: @@ -52,31 +55,35 @@ def convert_chart_info_to_patch( patch_builder.add_input_edge(inputEdge) patch_builder.set_chart_url(aspect.chartUrl).set_external_url( - aspect.externalUrl + aspect.externalUrl, ).set_type(aspect.type).set_title(aspect.title).set_access( 
- aspect.access + aspect.access, ).set_last_modified(aspect.lastModified).set_last_refreshed( - aspect.lastRefreshed + aspect.lastRefreshed, ).set_description(aspect.description).add_inputs(aspect.inputs) values = patch_builder.build() if values: mcp = next(iter(values)) return MetadataWorkUnit( - id=MetadataWorkUnit.generate_workunit_id(mcp), mcp_raw=mcp + id=MetadataWorkUnit.generate_workunit_id(mcp), + mcp_raw=mcp, ) return None def convert_dashboard_info_to_patch( - urn: str, aspect: DashboardInfoClass, system_metadata: Optional[SystemMetadataClass] + urn: str, + aspect: DashboardInfoClass, + system_metadata: Optional[SystemMetadataClass], ) -> Optional[MetadataWorkUnit]: patch_builder = DashboardPatchBuilder(urn, system_metadata) if aspect.customProperties: for key in aspect.customProperties: patch_builder.add_custom_property( - key, str(aspect.customProperties.get(key)) + key, + str(aspect.customProperties.get(key)), ) if aspect.datasetEdges: @@ -115,11 +122,12 @@ def convert_dashboard_info_to_patch( if values: logger.debug( - f"Generating patch DashboardInfo MetadataWorkUnit for dashboard {aspect.title}" + f"Generating patch DashboardInfo MetadataWorkUnit for dashboard {aspect.title}", ) mcp = next(iter(values)) return MetadataWorkUnit( - id=MetadataWorkUnit.generate_workunit_id(mcp), mcp_raw=mcp + id=MetadataWorkUnit.generate_workunit_id(mcp), + mcp_raw=mcp, ) return None @@ -130,7 +138,7 @@ def get_fine_grained_lineage_key(fine_upstream: FineGrainedLineageClass) -> str: "upstreams": sorted(fine_upstream.upstreams or []), "downstreams": sorted(fine_upstream.downstreams or []), "transformOperation": fine_upstream.transformOperation, - } + }, ) @@ -153,15 +161,20 @@ def auto_incremental_lineage( if lineage_aspect and lineage_aspect.upstreams: yield convert_upstream_lineage_to_patch( - urn, lineage_aspect, wu.metadata.systemMetadata + urn, + lineage_aspect, + wu.metadata.systemMetadata, ) elif isinstance(wu.metadata, MetadataChangeProposalWrapper) and isinstance( - wu.metadata.aspect, UpstreamLineageClass + wu.metadata.aspect, + UpstreamLineageClass, ): lineage_aspect = wu.metadata.aspect if lineage_aspect.upstreams: yield convert_upstream_lineage_to_patch( - urn, lineage_aspect, wu.metadata.systemMetadata + urn, + lineage_aspect, + wu.metadata.systemMetadata, ) else: yield wu diff --git a/metadata-ingestion/src/datahub/ingestion/api/incremental_properties_helper.py b/metadata-ingestion/src/datahub/ingestion/api/incremental_properties_helper.py index 151b0c72a6c2de..d5082eeb540e48 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/incremental_properties_helper.py +++ b/metadata-ingestion/src/datahub/ingestion/api/incremental_properties_helper.py @@ -46,15 +46,20 @@ def auto_incremental_properties( if properties_aspect: yield convert_dataset_properties_to_patch( - urn, properties_aspect, wu.metadata.systemMetadata + urn, + properties_aspect, + wu.metadata.systemMetadata, ) elif isinstance(wu.metadata, MetadataChangeProposalWrapper) and isinstance( - wu.metadata.aspect, DatasetPropertiesClass + wu.metadata.aspect, + DatasetPropertiesClass, ): properties_aspect = wu.metadata.aspect if properties_aspect: yield convert_dataset_properties_to_patch( - urn, properties_aspect, wu.metadata.systemMetadata + urn, + properties_aspect, + wu.metadata.systemMetadata, ) else: yield wu diff --git a/metadata-ingestion/src/datahub/ingestion/api/registry.py b/metadata-ingestion/src/datahub/ingestion/api/registry.py index 6b50e0b8a8a5d4..bfd40586afe081 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/api/registry.py +++ b/metadata-ingestion/src/datahub/ingestion/api/registry.py @@ -66,7 +66,8 @@ class PluginRegistry(Generic[T]): _aliases: Dict[str, Tuple[str, Callable[[], None]]] def __init__( - self, extra_cls_check: Optional[Callable[[Type[T]], None]] = None + self, + extra_cls_check: Optional[Callable[[Type[T]], None]] = None, ) -> None: self._entrypoints = [] self._mapping = {} @@ -82,7 +83,7 @@ def _get_registered_type(self) -> Type[T]: def _check_cls(self, cls: Type[T]) -> None: if inspect.isabstract(cls): raise ValueError( - f"cannot register an abstract type in the registry; got {cls}" + f"cannot register an abstract type in the registry; got {cls}", ) super_cls = self._get_registered_type() if not issubclass(cls, super_cls): @@ -91,7 +92,10 @@ def _check_cls(self, cls: Type[T]) -> None: self._extra_cls_check(cls) def _register( - self, key: str, tp: Union[str, Type[T], Exception], override: bool = False + self, + key: str, + tp: Union[str, Type[T], Exception], + override: bool = False, ) -> None: if not override and key in self._mapping: raise KeyError(f"key already in use - {key}") @@ -107,12 +111,18 @@ def register_lazy(self, key: str, import_path: str) -> None: self._register(key, import_path) def register_disabled( - self, key: str, reason: Exception, override: bool = False + self, + key: str, + reason: Exception, + override: bool = False, ) -> None: self._register(key, reason, override=override) def register_alias( - self, alias: str, real_key: str, fn: Callable[[], None] = lambda: None + self, + alias: str, + real_key: str, + fn: Callable[[], None] = lambda: None, ) -> None: self._aliases[alias] = (real_key, fn) @@ -174,11 +184,11 @@ def get(self, key: str) -> Type[T]: tp = self._ensure_not_lazy(key) if isinstance(tp, ModuleNotFoundError): raise ConfigurationError( - f"{key} is disabled; try running: pip install '{__package_name__}[{key}]'" + f"{key} is disabled; try running: pip install '{__package_name__}[{key}]'", ) from tp elif isinstance(tp, Exception): raise ConfigurationError( - f"{key} is disabled due to an error in initialization" + f"{key} is disabled due to an error in initialization", ) from tp else: # If it's not an exception, then it's a registered type. 
@@ -191,7 +201,10 @@ def get_optional(self, key: str) -> Optional[Type[T]]: return None def summary( - self, verbose: bool = True, col_width: int = 15, verbose_col_width: int = 20 + self, + verbose: bool = True, + col_width: int = 15, + verbose_col_width: int = 20, ) -> str: self._materialize_entrypoints() diff --git a/metadata-ingestion/src/datahub/ingestion/api/report.py b/metadata-ingestion/src/datahub/ingestion/api/report.py index 8cfca5782bee40..e7d7f908fe383e 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/report.py +++ b/metadata-ingestion/src/datahub/ingestion/api/report.py @@ -133,5 +133,5 @@ def field(type: str, severity: LogLevel = "DEBUG") -> "EntityFilterReport": """A helper to create a dataclass field.""" return dataclasses.field( - default_factory=lambda: EntityFilterReport(type=type, severity=severity) + default_factory=lambda: EntityFilterReport(type=type, severity=severity), ) diff --git a/metadata-ingestion/src/datahub/ingestion/api/sink.py b/metadata-ingestion/src/datahub/ingestion/api/sink.py index 655e6bb22fa8d1..00d0991ffda9bd 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/sink.py +++ b/metadata-ingestion/src/datahub/ingestion/api/sink.py @@ -37,18 +37,21 @@ def compute_stats(self) -> None: self.current_time = datetime.datetime.now() if self.start_time: self.total_duration_in_seconds = round( - (self.current_time - self.start_time).total_seconds(), 2 + (self.current_time - self.start_time).total_seconds(), + 2, ) if self.total_duration_in_seconds > 0: self.records_written_per_second = int( - self.total_records_written / self.total_duration_in_seconds + self.total_records_written / self.total_duration_in_seconds, ) class WriteCallback(metaclass=ABCMeta): @abstractmethod def on_success( - self, record_envelope: RecordEnvelope, success_metadata: dict + self, + record_envelope: RecordEnvelope, + success_metadata: dict, ) -> None: pass @@ -66,7 +69,9 @@ class NoopWriteCallback(WriteCallback): """Convenience WriteCallback class to support noop""" def on_success( - self, record_envelope: RecordEnvelope, success_metadata: dict + self, + record_envelope: RecordEnvelope, + success_metadata: dict, ) -> None: pass @@ -124,7 +129,9 @@ def handle_work_unit_end(self, workunit: WorkUnit) -> None: @abstractmethod def write_record_async( - self, record_envelope: RecordEnvelope, write_callback: WriteCallback + self, + record_envelope: RecordEnvelope, + write_callback: WriteCallback, ) -> None: # must call callback when done. 
pass diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py index b04ffdb3258934..155b1e157163ae 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/source.py +++ b/metadata-ingestion/src/datahub/ingestion/api/source.py @@ -95,7 +95,7 @@ class StructuredLogs(Report): StructuredLogLevel.ERROR: LossyDict(10), StructuredLogLevel.WARN: LossyDict(10), StructuredLogLevel.INFO: LossyDict(10), - } + }, ) def report_log( @@ -194,10 +194,10 @@ class SourceReport(Report): _urns_seen: Set[str] = field(default_factory=set) entities: Dict[str, list] = field(default_factory=lambda: defaultdict(LossyList)) aspects: Dict[str, Dict[str, int]] = field( - default_factory=lambda: defaultdict(lambda: defaultdict(int)) + default_factory=lambda: defaultdict(lambda: defaultdict(int)), ) aspect_urn_samples: Dict[str, Dict[str, LossyList[str]]] = field( - default_factory=lambda: defaultdict(lambda: defaultdict(LossyList)) + default_factory=lambda: defaultdict(lambda: defaultdict(LossyList)), ) _structured_logs: StructuredLogs = field(default_factory=StructuredLogs) @@ -252,7 +252,12 @@ def report_warning( exc: Optional[BaseException] = None, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.WARN, message, title, context, exc, log=False + StructuredLogLevel.WARN, + message, + title, + context, + exc, + log=False, ) def warning( @@ -263,7 +268,12 @@ def warning( exc: Optional[BaseException] = None, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.WARN, message, title, context, exc, log=True + StructuredLogLevel.WARN, + message, + title, + context, + exc, + log=True, ) def report_failure( @@ -275,7 +285,12 @@ def report_failure( log: bool = True, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.ERROR, message, title, context, exc, log=log + StructuredLogLevel.ERROR, + message, + title, + context, + exc, + log=log, ) def failure( @@ -287,7 +302,12 @@ def failure( log: bool = True, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.ERROR, message, title, context, exc, log=log + StructuredLogLevel.ERROR, + message, + title, + context, + exc, + log=log, ) def info( @@ -299,7 +319,12 @@ def info( log: bool = True, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.INFO, message, title, context, exc, log=log + StructuredLogLevel.INFO, + message, + title, + context, + exc, + log=log, ) @contextlib.contextmanager @@ -317,7 +342,11 @@ def report_exc( yield except Exception as exc: self._structured_logs.report_log( - level, message=message, title=title, context=context, exc=exc + level, + message=message, + title=title, + context=context, + exc=exc, ) def __post_init__(self) -> None: @@ -340,7 +369,7 @@ def compute_stats(self) -> None: workunits_produced = self.events_produced if duration.total_seconds() > 0: self.events_produced_per_sec: int = int( - workunits_produced / duration.total_seconds() + workunits_produced / duration.total_seconds(), ) self.running_time = duration else: @@ -418,7 +447,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: and self.ctx.pipeline_config.flags.generate_browse_path_v2 ): browse_path_processor = self._get_browse_path_processor( - self.ctx.pipeline_config.flags.generate_browse_path_v2_dry_run + self.ctx.pipeline_config.flags.generate_browse_path_v2_dry_run, ) auto_lowercase_dataset_urns: Optional[MetadataWorkUnitProcessor] = None @@ -437,7 +466,7 @@ def get_workunit_processors(self) -> 
List[Optional[MetadataWorkUnitProcessor]]: or ( hasattr(self.ctx.pipeline_config.source.config, "get") and self.ctx.pipeline_config.source.config.get( - "convert_urns_to_lowercase" + "convert_urns_to_lowercase", ) ) ) @@ -449,7 +478,8 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: auto_status_aspect, auto_materialize_referenced_tags_terms, partial( - auto_fix_duplicate_schema_field_paths, platform=self._infer_platform() + auto_fix_duplicate_schema_field_paths, + platform=self._infer_platform(), ), partial(auto_fix_empty_field_paths, platform=self._infer_platform()), browse_path_processor, @@ -470,12 +500,13 @@ def _apply_workunit_processors( def get_workunits(self) -> Iterable[MetadataWorkUnit]: return self._apply_workunit_processors( - self.get_workunit_processors(), self.get_workunits_internal() + self.get_workunit_processors(), + self.get_workunits_internal(), ) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: raise NotImplementedError( - "get_workunits_internal must be implemented if get_workunits is not overriden." + "get_workunits_internal must be implemented if get_workunits is not overriden.", ) def get_config(self) -> Optional[ConfigModel]: diff --git a/metadata-ingestion/src/datahub/ingestion/api/source_helpers.py b/metadata-ingestion/src/datahub/ingestion/api/source_helpers.py index 08af39cd24982a..11464ee981da63 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/source_helpers.py +++ b/metadata-ingestion/src/datahub/ingestion/api/source_helpers.py @@ -190,7 +190,7 @@ def auto_materialize_referenced_tags_terms( ).as_workunit() except InvalidUrnError: logger.info( - f"Source produced an invalid urn, so no key aspect will be generated: {urn}" + f"Source produced an invalid urn, so no key aspect will be generated: {urn}", ) @@ -279,7 +279,8 @@ def auto_browse_path_v2( urn, [ *paths.setdefault( - parent_urn, [] + parent_urn, + [], ), # Guess parent has no parents BrowsePathEntryClass(id=parent_urn, urn=parent_urn), ], @@ -312,8 +313,10 @@ def auto_browse_path_v2( entityUrn=urn, aspect=BrowsePathsV2Class( path=_prepend_platform_instance( - browse_path_v2, platform, platform_instance - ) + browse_path_v2, + platform, + platform_instance, + ), ), ).as_workunit() else: @@ -328,8 +331,10 @@ def auto_browse_path_v2( entityUrn=urn, aspect=BrowsePathsV2Class( path=_prepend_platform_instance( - path, platform, platform_instance - ) + path, + platform, + platform_instance, + ), ), ).as_workunit() elif urn not in emitted_urns and guess_entity_type(urn) == "container": @@ -339,7 +344,11 @@ def auto_browse_path_v2( yield MetadataChangeProposalWrapper( entityUrn=urn, aspect=BrowsePathsV2Class( - path=_prepend_platform_instance([], platform, platform_instance) + path=_prepend_platform_instance( + [], + platform, + platform_instance, + ), ), ).as_workunit() @@ -381,7 +390,7 @@ def auto_fix_duplicate_schema_field_paths( if dropped_fields: logger.info( - f"Fixing duplicate field paths in schema aspect for {wu.get_urn()} by dropping fields: {dropped_fields}" + f"Fixing duplicate field paths in schema aspect for {wu.get_urn()} by dropping fields: {dropped_fields}", ) schema_metadata.fields = updated_fields schemas_with_duplicates += 1 @@ -397,7 +406,8 @@ def auto_fix_duplicate_schema_field_paths( "duplicated_field_paths": duplicated_field_paths, } telemetry.telemetry_instance.ping( - "ingestion_duplicate_schema_field_paths", properties + "ingestion_duplicate_schema_field_paths", + properties, ) @@ -426,7 +436,7 @@ def auto_fix_empty_field_paths( 
if empty_field_paths > 0: logger.info( - f"Fixing empty field paths in schema aspect for {wu.get_urn()} by dropping empty fields" + f"Fixing empty field paths in schema aspect for {wu.get_urn()} by dropping empty fields", ) schema_metadata.fields = updated_fields schemas_with_empty_fields += 1 @@ -441,7 +451,8 @@ def auto_fix_empty_field_paths( "empty_field_paths": empty_field_paths, } telemetry.telemetry_instance.ping( - "ingestion_empty_schema_field_paths", properties + "ingestion_empty_schema_field_paths", + properties, ) @@ -478,7 +489,7 @@ def auto_empty_dataset_usage_statistics( if invalid_timestamps: logger.warning( f"Usage statistics with unexpected timestamps, bucket_duration={config.bucket_duration}:\n" - ", ".join(str(parse_ts_millis(ts)) for ts in invalid_timestamps) + ", ".join(str(parse_ts_millis(ts)) for ts in invalid_timestamps), ) for bucket in bucket_timestamps: diff --git a/metadata-ingestion/src/datahub/ingestion/api/transform.py b/metadata-ingestion/src/datahub/ingestion/api/transform.py index 1754eb9f132dce..a0f56a5efc5b7e 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/transform.py +++ b/metadata-ingestion/src/datahub/ingestion/api/transform.py @@ -7,7 +7,8 @@ class Transformer: @abstractmethod def transform( - self, record_envelopes: Iterable[RecordEnvelope] + self, + record_envelopes: Iterable[RecordEnvelope], ) -> Iterable[RecordEnvelope]: """ Transforms a sequence of records. diff --git a/metadata-ingestion/src/datahub/ingestion/api/workunit.py b/metadata-ingestion/src/datahub/ingestion/api/workunit.py index f203624aced563..964f5cda0204db 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/workunit.py +++ b/metadata-ingestion/src/datahub/ingestion/api/workunit.py @@ -19,7 +19,9 @@ @dataclass class MetadataWorkUnit(WorkUnit): metadata: Union[ - MetadataChangeEvent, MetadataChangeProposal, MetadataChangeProposalWrapper + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, ] # A workunit creator can determine if this workunit is allowed to fail. @@ -34,7 +36,11 @@ class MetadataWorkUnit(WorkUnit): @overload def __init__( - self, id: str, mce: MetadataChangeEvent, *, is_primary_source: bool = True + self, + id: str, + mce: MetadataChangeEvent, + *, + is_primary_source: bool = True, ): # TODO: Force `mce` to be a keyword-only argument. ... 
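
The MetadataWorkUnit and as_workunit() plumbing reformatted above is what sources use to yield work units. A minimal sketch, assuming a hypothetical URN and subtype; illustrative only, not part of the patch.

# Illustrative sketch (hypothetical URN): yielding a MetadataWorkUnit from an
# aspect via MetadataChangeProposalWrapper.as_workunit(), as sources do.
from typing import Iterable

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.metadata.schema_classes import SubTypesClass


def example_workunits() -> Iterable[MetadataWorkUnit]:
    yield MetadataChangeProposalWrapper(
        entityUrn="urn:li:dataset:(urn:li:dataPlatform:hive,example_db.example_table,PROD)",
        aspect=SubTypesClass(
            typeNames=["Table"],
        ),
    ).as_workunit()
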
@@ -100,7 +106,9 @@ def get_urn(self) -> str: def generate_workunit_id( cls, item: Union[ - MetadataChangeEvent, MetadataChangeProposal, MetadataChangeProposalWrapper + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, ], ) -> str: if isinstance(item, MetadataChangeEvent): @@ -160,7 +168,9 @@ def decompose_mce_into_mcps(self) -> Iterable["MetadataWorkUnit"]: def from_metadata( cls, metadata: Union[ - MetadataChangeEvent, MetadataChangeProposal, MetadataChangeProposalWrapper + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, ], id: Optional[str] = None, ) -> "MetadataWorkUnit": diff --git a/metadata-ingestion/src/datahub/ingestion/extractor/json_ref_patch.py b/metadata-ingestion/src/datahub/ingestion/extractor/json_ref_patch.py index 2224a096f53875..a52a8d5935e8aa 100644 --- a/metadata-ingestion/src/datahub/ingestion/extractor/json_ref_patch.py +++ b/metadata-ingestion/src/datahub/ingestion/extractor/json_ref_patch.py @@ -17,7 +17,8 @@ def title_swapping_callback(self: JsonRef) -> dict: except Exception as e: raise self._error(f"{e.__class__.__name__}: {str(e)}", cause=e) from e base_doc = _replace_refs( - base_doc, **{**self._ref_kwargs, "base_uri": uri, "recursing": False} + base_doc, + **{**self._ref_kwargs, "base_uri": uri, "recursing": False}, ) else: base_doc = self.store[uri] diff --git a/metadata-ingestion/src/datahub/ingestion/extractor/json_schema_util.py b/metadata-ingestion/src/datahub/ingestion/extractor/json_schema_util.py index 1c440642e06d8b..fce4b60cc84ea9 100644 --- a/metadata-ingestion/src/datahub/ingestion/extractor/json_schema_util.py +++ b/metadata-ingestion/src/datahub/ingestion/extractor/json_schema_util.py @@ -40,18 +40,18 @@ def as_schema_field_type(self) -> SchemaFieldDataTypeClass: if self.type == UnionTypeClass: return SchemaFieldDataTypeClass( type=UnionTypeClass( - nestedTypes=[self.nested_type] if self.nested_type else None - ) + nestedTypes=[self.nested_type] if self.nested_type else None, + ), ) elif self.type == ArrayTypeClass: return SchemaFieldDataTypeClass( type=ArrayTypeClass( - nestedType=[self.nested_type] if self.nested_type else None - ) + nestedType=[self.nested_type] if self.nested_type else None, + ), ) elif self.type == MapTypeClass: return SchemaFieldDataTypeClass( - type=MapTypeClass(keyType="string", valueType=self.nested_type) + type=MapTypeClass(keyType="string", valueType=self.nested_type), ) raise Exception(f"Unexpected type {self.type}") @@ -176,7 +176,7 @@ def as_string(self) -> str: prefix = [] if self.path: return ".".join( - prefix + [f.as_string(v2_format=v2_format) for f in self.path] + prefix + [f.as_string(v2_format=v2_format) for f in self.path], ) else: # this is a non-field (top-level) schema @@ -223,7 +223,7 @@ def _field_from_primitive( ) -> Iterable[SchemaField]: type_override = field_path._get_type_override() native_type = field_path._get_native_type_override() or str( - schema.get("type") or "" + schema.get("type") or "", ) nullable = JsonSchemaTranslator._is_nullable(schema) if "format" in schema: @@ -244,12 +244,13 @@ def _field_from_primitive( type=type_override or SchemaFieldDataTypeClass(type=datahub_field_type()), description=JsonSchemaTranslator._get_description_from_any_schema( - schema + schema, ), nativeDataType=native_type, nullable=nullable, jsonProps=JsonSchemaTranslator._get_jsonprops_for_any_schema( - schema, required=required + schema, + required=required, ), isPartOfKey=field_path.is_key_schema, ) @@ -263,7 +264,8 @@ def _field_from_primitive( 
description=f"One of: {', '.join(schema_enums)}", nullable=nullable, jsonProps=JsonSchemaTranslator._get_jsonprops_for_any_schema( - schema, required=required + schema, + required=required, ), isPartOfKey=field_path.is_key_schema, ) @@ -295,7 +297,8 @@ def _get_type_from_schema(schema: Dict) -> str: elif schema["type"] != "object": return schema["type"] elif "additionalProperties" in schema and isinstance( - schema["additionalProperties"], dict + schema["additionalProperties"], + dict, ): return "map" @@ -330,7 +333,8 @@ def _get_description_from_any_schema(schema: Dict) -> str: @staticmethod def _get_jsonprops_for_any_schema( - schema: Dict, required: Optional[bool] = None + schema: Dict, + required: Optional[bool] = None, ) -> Optional[str]: json_props = {} if "default" in schema: @@ -367,17 +371,19 @@ def _field_from_complex_type( if recursive_type: yield SchemaField( fieldPath=field_path.expand_type( - recursive_type, schema + recursive_type, + schema, ).as_string(), nativeDataType=native_type_override or recursive_type, type=type_override or SchemaFieldDataTypeClass(type=RecordTypeClass()), nullable=nullable, description=JsonSchemaTranslator._get_description_from_any_schema( - schema + schema, ), jsonProps=JsonSchemaTranslator._get_jsonprops_for_any_schema( - schema, required + schema, + required, ), isPartOfKey=field_path.is_key_schema, recursive=True, @@ -387,7 +393,8 @@ def _field_from_complex_type( # generate a field for the struct if we have a field path yield SchemaField( fieldPath=field_path.expand_type( - discriminated_type, schema + discriminated_type, + schema, ).as_string(), nativeDataType=native_type_override or JsonSchemaTranslator._get_discriminated_type_from_schema(schema), @@ -395,10 +402,11 @@ def _field_from_complex_type( or SchemaFieldDataTypeClass(type=RecordTypeClass()), nullable=nullable, description=JsonSchemaTranslator._get_description_from_any_schema( - schema + schema, ), jsonProps=JsonSchemaTranslator._get_jsonprops_for_any_schema( - schema, required + schema, + required, ), isPartOfKey=field_path.is_key_schema, ) @@ -408,7 +416,7 @@ def _field_from_complex_type( for field_name, field_schema in schema.get("properties", {}).items(): required_field: bool = field_name in schema.get("required", []) inner_field_path = field_path.clone_plus( - FieldElement(type=[], name=field_name, schema_types=[]) + FieldElement(type=[], name=field_name, schema_types=[]), ) yield from JsonSchemaTranslator.get_fields( JsonSchemaTranslator._get_type_from_schema(field_schema), @@ -424,11 +432,12 @@ def _field_from_complex_type( nativeDataType=native_type_override or JsonSchemaTranslator._get_discriminated_type_from_schema(schema), description=JsonSchemaTranslator._get_description_from_any_schema( - schema + schema, ), nullable=nullable, jsonProps=JsonSchemaTranslator._get_jsonprops_for_any_schema( - schema, required=required + schema, + required=required, ), isPartOfKey=field_path.is_key_schema, ) @@ -439,7 +448,7 @@ def _field_from_complex_type( if not field_name: field_name = items_type inner_field_path = field_path.clone_plus( - FieldElement(type=[], name=field_name, schema_types=[]) + FieldElement(type=[], name=field_name, schema_types=[]), ) yield from JsonSchemaTranslator.get_fields( items_type, @@ -452,15 +461,15 @@ def _field_from_complex_type( field_path = field_path.expand_type("map", schema) # When additionalProperties is used alone, without properties, the object essentially functions as a map where T is the type described in the additionalProperties sub-schema. 
Maybe that helps to answer your original question. value_type = JsonSchemaTranslator._get_discriminated_type_from_schema( - schema["additionalProperties"] + schema["additionalProperties"], ) field_path._set_parent_type_if_not_exists( - DataHubType(type=MapTypeClass, nested_type=value_type) + DataHubType(type=MapTypeClass, nested_type=value_type), ) # FIXME: description not set. This is present in schema["description"]. yield from JsonSchemaTranslator.get_fields( JsonSchemaTranslator._get_type_from_schema( - schema["additionalProperties"] + schema["additionalProperties"], ), schema["additionalProperties"], required=required, @@ -486,7 +495,8 @@ def _field_from_complex_type( union_schema = union_category_schema[0] merged_union_schema = ( JsonSchemaTranslator._retain_parent_schema_props_in_union( - union_schema=union_schema, parent_schema=schema + union_schema=union_schema, + parent_schema=schema, ) ) yield from JsonSchemaTranslator.get_fields( @@ -505,16 +515,17 @@ def _field_from_complex_type( nativeDataType=f"union({union_category})", nullable=nullable, description=JsonSchemaTranslator._get_description_from_any_schema( - schema + schema, ), jsonProps=JsonSchemaTranslator._get_jsonprops_for_any_schema( - schema, required + schema, + required, ), isPartOfKey=field_path.is_key_schema, ) for i, union_schema in enumerate(union_category_schema): union_type = JsonSchemaTranslator._get_discriminated_type_from_schema( - union_schema + union_schema, ) if ( union_type == "object" @@ -522,11 +533,12 @@ def _field_from_complex_type( union_type = f"union_{i}" union_field_path = field_path.expand_type("union", schema) union_field_path._set_parent_type_if_not_exists( - DataHubType(type=UnionTypeClass, nested_type=union_type) + DataHubType(type=UnionTypeClass, nested_type=union_type), ) merged_union_schema = ( JsonSchemaTranslator._retain_parent_schema_props_in_union( - union_schema=union_schema, parent_schema=schema + union_schema=union_schema, + parent_schema=schema, ) ) yield from JsonSchemaTranslator.get_fields( @@ -541,7 +553,8 @@ def _field_from_complex_type( @staticmethod def _retain_parent_schema_props_in_union( - union_schema: Dict, parent_schema: Dict + union_schema: Dict, + parent_schema: Dict, ) -> Dict: """Merge the "properties" and the "required" fields from the parent schema into the child union schema.""" @@ -595,10 +608,14 @@ def get_fields( generator = cls.datahub_type_to_converter_mapping.get(datahub_type) if generator is None: raise Exception( - f"Failed to find a mapping for type {datahub_type}, schema was {schema_dict}" + f"Failed to find a mapping for type {datahub_type}, schema was {schema_dict}", ) yield from generator.__get__(cls)( - datahub_type, base_field_path, schema_dict, required, specific_type + datahub_type, + base_field_path, + schema_dict, + required, + specific_type, ) @classmethod @@ -627,7 +644,8 @@ def get_fields_from_schema( except Exception as e: if swallow_exceptions: logger.error( - "Failed to get fields from schema, continuing...", exc_info=e + "Failed to get fields from schema, continuing...", + exc_info=e, ) return else: @@ -646,7 +664,8 @@ def _get_id_from_any_schema(schema_dict: Dict[Any, Any]) -> Optional[str]: def get_enum_description( - authored_description: Optional[str], enum_symbols: List[str] + authored_description: Optional[str], + enum_symbols: List[str], ) -> str: description = authored_description or "" missed_symbols = [symbol for symbol in enum_symbols if symbol not in description] diff --git 
a/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py b/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py index 73e67885dc38fb..3806eb6d33402d 100644 --- a/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py +++ b/metadata-ingestion/src/datahub/ingestion/extractor/mce_extractor.py @@ -28,12 +28,13 @@ class WorkUnitRecordExtractorConfig(ConfigModel): class WorkUnitRecordExtractor( - Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig] + Extractor[MetadataWorkUnit, WorkUnitRecordExtractorConfig], ): """An extractor that simply returns the data inside workunits back as records.""" def get_records( - self, workunit: WorkUnit + self, + workunit: WorkUnit, ) -> Iterable[ RecordEnvelope[ Union[ @@ -45,7 +46,8 @@ def get_records( ]: if isinstance(workunit, MetadataWorkUnit): if self.config.unpack_mces_into_mcps and isinstance( - workunit.metadata, MetadataChangeEvent + workunit.metadata, + MetadataChangeEvent, ): for inner_workunit in workunit.decompose_mce_into_mcps(): yield from self.get_records(inner_workunit) @@ -69,7 +71,7 @@ def get_records( invalid_mce = _try_reformat_with_black(invalid_mce) raise ValueError( - f"source produced an invalid metadata work unit: {invalid_mce}" + f"source produced an invalid metadata work unit: {invalid_mce}", ) yield RecordEnvelope( diff --git a/metadata-ingestion/src/datahub/ingestion/extractor/protobuf_util.py b/metadata-ingestion/src/datahub/ingestion/extractor/protobuf_util.py index e947aff384871d..ea14e53ee0fee5 100644 --- a/metadata-ingestion/src/datahub/ingestion/extractor/protobuf_util.py +++ b/metadata-ingestion/src/datahub/ingestion/extractor/protobuf_util.py @@ -73,7 +73,8 @@ def protobuf_schema_to_mce_fields( :return: The list of MCE compatible SchemaFields. 
""" descriptor: FileDescriptor = _from_protobuf_schema_to_descriptors( - main_schema, imported_schemas + main_schema, + imported_schemas, ) graph: nx.DiGraph = _populate_graph(descriptor) @@ -199,7 +200,10 @@ def _add_message(graph: nx.DiGraph, message: Descriptor, visited: Set[str]) -> N def _add_oneof( - graph: nx.DiGraph, parent_node: str, oneof: OneofDescriptor, visited: Set[str] + graph: nx.DiGraph, + parent_node: str, + oneof: OneofDescriptor, + visited: Set[str], ) -> None: node_name: str = _get_node_name(cast(DescriptorBase, oneof)) node_type: str = _get_type_ascription(cast(DescriptorBase, oneof)) @@ -233,7 +237,8 @@ def _create_schema_field(path: List[str], field: FieldDescriptor) -> _PathAndFie def _from_protobuf_schema_to_descriptors( - main_schema: ProtobufSchema, imported_schemas: Optional[List[ProtobufSchema]] = None + main_schema: ProtobufSchema, + imported_schemas: Optional[List[ProtobufSchema]] = None, ) -> FileDescriptor: if imported_schemas is None: imported_schemas = [] @@ -351,7 +356,8 @@ def _sanitise_type(name: str) -> str: def _schema_fields_from_dag( - graph: nx.DiGraph, is_key_schema: bool + graph: nx.DiGraph, + is_key_schema: bool, ) -> List[SchemaField]: generations: List = list(nx.algorithms.dag.topological_generations(graph)) fields: Dict = {} @@ -386,7 +392,9 @@ def _schema_fields_from_dag( def _traverse_path( - graph: nx.DiGraph, path: List[Tuple[str, str]], stack: List[str] + graph: nx.DiGraph, + path: List[Tuple[str, str]], + stack: List[str], ) -> Generator[_PathAndField, None, None]: if path: src, dst = path[0] diff --git a/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py b/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py index dbb851c74e7e34..c07e3c2a2bf5ed 100644 --- a/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py +++ b/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py @@ -162,7 +162,8 @@ def __init__( @staticmethod def _get_type_name( - avro_schema: SchemaOrField, logical_if_present: bool = False + avro_schema: SchemaOrField, + logical_if_present: bool = False, ) -> str: logical_type_name: Optional[str] = None if logical_if_present: @@ -172,39 +173,45 @@ def _get_type_name( or avro_schema.props.get("logicalType"), ) return logical_type_name or str( - getattr(avro_schema.type, "type", avro_schema.type) + getattr(avro_schema.type, "type", avro_schema.type), ) @staticmethod def _get_column_type( - avro_schema: SchemaOrField, logical_type: Optional[str] + avro_schema: SchemaOrField, + logical_type: Optional[str], ) -> SchemaFieldDataType: type_name: str = AvroToMceSchemaConverter._get_type_name(avro_schema) TypeClass: Optional[Type] = AvroToMceSchemaConverter.field_type_mapping.get( - type_name + type_name, ) if logical_type is not None: TypeClass = AvroToMceSchemaConverter.field_logical_type_mapping.get( - logical_type, TypeClass + logical_type, + TypeClass, ) assert TypeClass is not None dt = SchemaFieldDataType(type=TypeClass()) # Handle Arrays and Maps if isinstance(dt.type, ArrayTypeClass) and isinstance( - avro_schema, avro.schema.ArraySchema + avro_schema, + avro.schema.ArraySchema, ): dt.type.nestedType = [ AvroToMceSchemaConverter._get_type_name( - avro_schema.items, logical_if_present=True - ) + avro_schema.items, + logical_if_present=True, + ), ] elif isinstance(dt.type, MapTypeClass) and isinstance( - avro_schema, avro.schema.MapSchema + avro_schema, + avro.schema.MapSchema, ): # Avro map's key is always a string. 
See: https://avro.apache.org/docs/current/spec.html#Maps dt.type.keyType = "string" dt.type.valueType = AvroToMceSchemaConverter._get_type_name( - avro_schema.values, logical_if_present=True + avro_schema.values, + logical_if_present=True, ) return dt @@ -262,18 +269,21 @@ def _get_type_annotation(schema: SchemaOrField) -> str: @staticmethod @overload def _get_underlying_type_if_option_as_union( - schema: SchemaOrField, default: SchemaOrField + schema: SchemaOrField, + default: SchemaOrField, ) -> SchemaOrField: ... @staticmethod @overload def _get_underlying_type_if_option_as_union( - schema: SchemaOrField, default: Optional[SchemaOrField] = None + schema: SchemaOrField, + default: Optional[SchemaOrField] = None, ) -> Optional[SchemaOrField]: ... @staticmethod def _get_underlying_type_if_option_as_union( - schema: SchemaOrField, default: Optional[SchemaOrField] = None + schema: SchemaOrField, + default: Optional[SchemaOrField] = None, ) -> Optional[SchemaOrField]: if isinstance(schema, avro.schema.UnionSchema) and len(schema.schemas) == 2: (first, second) = schema.schemas @@ -335,7 +345,8 @@ def emit(self) -> Iterable[SchemaField]: schema = schema.type actual_schema = ( self._converter._get_underlying_type_if_option_as_union( - schema, schema + schema, + schema, ) ) @@ -355,7 +366,8 @@ def emit(self) -> Iterable[SchemaField]: slice(len(type_prefix), len(native_data_type) - 1) ] native_data_type = cast( - str, actual_schema.props.get("native_data_type", native_data_type) + str, + actual_schema.props.get("native_data_type", native_data_type), ) field_path = self._converter._get_cur_field_path() @@ -367,7 +379,7 @@ def emit(self) -> Iterable[SchemaField]: meta_aspects: Dict[str, Any] = {} if self._converter._meta_mapping_processor: meta_aspects = self._converter._meta_mapping_processor.process( - merged_props + merged_props, ) tags: List[str] = [] @@ -447,7 +459,7 @@ def gen_items_from_list_tuple_or_scalar( # Union type elif isinstance(schema, avro.schema.UnionSchema): is_option_as_union_type = self._get_underlying_type_if_option_as_union( - schema + schema, ) if is_option_as_union_type is not None: yield is_option_as_union_type @@ -478,7 +490,8 @@ def _gen_nested_schema_from_field( self._fields_stack.pop() def _gen_from_last_field( - self, schema_to_recurse: Optional[AvroNestedSchemas] = None + self, + schema_to_recurse: Optional[AvroNestedSchemas] = None, ) -> Iterable[SchemaField]: """Emits the field most-recent field, optionally triggering sub-schema generation under the field.""" last_field_schema = self._fields_stack[-1] @@ -499,7 +512,8 @@ def _gen_from_last_field( yield from self._to_mce_fields(sub_schema) def _gen_from_non_field_nested_schemas( - self, schema: SchemaOrField + self, + schema: SchemaOrField, ) -> Iterable[SchemaField]: """Handles generation of MCE SchemaFields for all standard AVRO nested types.""" # Handle recursive record definitions @@ -543,7 +557,8 @@ def _gen_from_non_field_nested_schemas( yield from self._to_mce_fields(sub_schema) def _gen_non_nested_to_mce_fields( - self, schema: SchemaOrField + self, + schema: SchemaOrField, ) -> Iterable[SchemaField]: """Handles generation of MCE SchemaFields for non-nested AVRO types.""" with AvroToMceSchemaConverter.SchemaFieldEmissionContextManager( @@ -623,7 +638,7 @@ def avro_schema_to_mce_fields( meta_mapping_processor, schema_tags_field, tag_prefix, - ) + ), ) except Exception: if swallow_exceptions: diff --git a/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py b/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py 
index beec42724529e6..a8b275edb8cf58 100644 --- a/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py +++ b/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py @@ -35,7 +35,11 @@ class S3ListIterator(Iterator): MAX_KEYS = 1000 def __init__( - self, s3_client: Any, bucket: str, prefix: str, max_keys: int = MAX_KEYS + self, + s3_client: Any, + bucket: str, + prefix: str, + max_keys: int = MAX_KEYS, ) -> None: self._s3 = s3_client self._bucket = bucket @@ -68,7 +72,7 @@ def fetch(self): [ FileInfo(f"s3://{response['Name']}/{x['Key']}", x["Size"], is_file=True) for x in response.get("Contents", []) - ] + ], ) self._token = response.get("NextContinuationToken") @@ -89,7 +93,9 @@ def file_status(self, path: str) -> FileInfo: s3_path = parse_s3_path(path) try: response = self.s3.get_object_attributes( - Bucket=s3_path.bucket, Key=s3_path.key, ObjectAttributes=["ObjectSize"] + Bucket=s3_path.bucket, + Key=s3_path.key, + ObjectAttributes=["ObjectSize"], ) assert_ok_status(response) return FileInfo(path, response["ObjectSize"], is_file=True) diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py b/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py index 98c43079a3bc15..b489f6ad51ca9b 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py @@ -40,7 +40,7 @@ class ClassificationReportMixin: num_tables_classification_found: int = 0 info_types_detected: LossyDict[str, LossyList[str]] = field( - default_factory=LossyDict + default_factory=LossyDict, ) @@ -53,7 +53,9 @@ class ClassificationSourceConfigMixin(ConfigModel): class ClassificationHandler: def __init__( - self, config: ClassificationSourceConfigMixin, report: ClassificationReportMixin + self, + config: ClassificationSourceConfigMixin, + report: ClassificationReportMixin, ): self.config = config self.report = report @@ -75,7 +77,9 @@ def is_classification_enabled_for_table(self, dataset_name: str) -> bool: ) def is_classification_enabled_for_column( - self, dataset_name: str, column_name: str + self, + dataset_name: str, + column_name: str, ) -> bool: return ( self.config.classification is not None @@ -83,7 +87,7 @@ def is_classification_enabled_for_column( and len(self.config.classification.classifiers) > 0 and self.config.classification.table_pattern.allowed(dataset_name) and self.config.classification.column_pattern.allowed( - f"{dataset_name}.{column_name}" + f"{dataset_name}.{column_name}", ) ) @@ -95,12 +99,12 @@ def get_classifiers(self) -> List[Classifier]: if classifier_class is None: raise ConfigurationError( f"Cannot find classifier class of type={self.config.classification.classifiers[0].type} " - " in the registry! Please check the type of the classifier in your config." + " in the registry! 
Please check the type of the classifier in your config.", ) classifiers.append( classifier_class.create( config_dict=classifier.config, # type: ignore - ) + ), ) return classifiers @@ -125,7 +129,9 @@ def classify_schema_fields( logger.debug("Error", exc_info=e) column_infos = self.get_columns_to_classify( - dataset_name, schema_metadata, sample_data + dataset_name, + schema_metadata, + sample_data, ) if not column_infos: @@ -142,7 +148,8 @@ def classify_schema_fields( column_infos_with_proposals: Iterable[ColumnInfo] if self.config.classification.max_workers > 1: column_infos_with_proposals = self.async_classify( - classifier, column_infos + classifier, + column_infos, ) else: column_infos_with_proposals = classifier.classify(column_infos) @@ -156,7 +163,7 @@ def classify_schema_fields( finally: time_taken = timer.elapsed_seconds() logger.debug( - f"Finished classification {dataset_name}; took {time_taken:.3f} seconds" + f"Finished classification {dataset_name}; took {time_taken:.3f} seconds", ) if field_terms: @@ -164,20 +171,24 @@ def classify_schema_fields( self.populate_terms_in_schema_metadata(schema_metadata, field_terms) def update_field_terms( - self, field_terms: Dict[str, str], col_info: ColumnInfo + self, + field_terms: Dict[str, str], + col_info: ColumnInfo, ) -> None: term = self.get_terms_for_column(col_info) if term: field_terms[col_info.metadata.name] = term def async_classify( - self, classifier: Classifier, columns: List[ColumnInfo] + self, + classifier: Classifier, + columns: List[ColumnInfo], ) -> Iterable[ColumnInfo]: num_columns = len(columns) BATCH_SIZE = 5 # Number of columns passed to classify api at a time logger.debug( - f"Will Classify {num_columns} column(s) with {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}." 
+ f"Will Classify {num_columns} column(s) with {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}.", ) with concurrent.futures.ProcessPoolExecutor( @@ -196,7 +207,7 @@ def async_classify( return [ column_with_proposal for proposal_future in concurrent.futures.as_completed( - column_info_proposal_futures + column_info_proposal_futures, ) for column_with_proposal in proposal_future.result() ] @@ -211,8 +222,8 @@ def populate_terms_in_schema_metadata( schema_field.glossaryTerms = GlossaryTerms( terms=[ GlossaryTermAssociation( - urn=make_term_urn(field_terms[schema_field.fieldPath]) - ) + urn=make_term_urn(field_terms[schema_field.fieldPath]), + ), ] # Keep existing terms if present + ( @@ -221,7 +232,8 @@ def populate_terms_in_schema_metadata( else [] ), auditStamp=AuditStamp( - time=get_sys_time(), actor=make_user_urn("datahub") + time=get_sys_time(), + actor=make_user_urn("datahub"), ), ) @@ -229,13 +241,16 @@ def get_terms_for_column(self, col_info: ColumnInfo) -> Optional[str]: if not col_info.infotype_proposals: return None infotype_proposal = max( - col_info.infotype_proposals, key=lambda p: p.confidence_level + col_info.infotype_proposals, + key=lambda p: p.confidence_level, ) self.report.info_types_detected.setdefault( - infotype_proposal.infotype, LossyList() + infotype_proposal.infotype, + LossyList(), ).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}") term = self.config.classification.info_type_to_term.get( - infotype_proposal.infotype, infotype_proposal.infotype + infotype_proposal.infotype, + infotype_proposal.infotype, ) return term @@ -250,10 +265,11 @@ def get_columns_to_classify( for schema_field in schema_metadata.fields: if not self.is_classification_enabled_for_column( - dataset_name, schema_field.fieldPath + dataset_name, + schema_field.fieldPath, ): logger.debug( - f"Skipping column {dataset_name}.{schema_field.fieldPath} from classification" + f"Skipping column {dataset_name}.{schema_field.fieldPath} from classification", ) continue @@ -271,14 +287,14 @@ def get_columns_to_classify( "Description": schema_field.description, "DataType": schema_field.nativeDataType, "Dataset_Name": dataset_name, - } + }, ), values=( sample_data[schema_field.fieldPath] if schema_field.fieldPath in sample_data.keys() else [] ), - ) + ), ) return column_infos @@ -326,7 +342,8 @@ def classification_workunit_processor( ), ) yield MetadataChangeProposalWrapper( - aspect=maybe_schema_metadata, entityUrn=wu.get_urn() + aspect=maybe_schema_metadata, + entityUrn=wu.get_urn(), ).as_workunit( is_primary_source=wu.is_primary_source, ) diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/classifier.py b/metadata-ingestion/src/datahub/ingestion/glossary/classifier.py index bdcdcb8990eba7..67a2f7a03185e0 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/classifier.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/classifier.py @@ -33,7 +33,8 @@ class ClassificationConfig(ConfigModel): ) sample_size: int = Field( - default=100, description="Number of sample values used for classification." 
+ default=100, + description="Number of sample values used for classification.", ) max_workers: int = Field( diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py index ba03083854e785..15fa0ceca4a3c9 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py @@ -45,7 +45,8 @@ class ValuesFactorConfig(ConfigModel): description="List of regex patterns the column value follows for the info type", ) library: Optional[List[str]] = Field( - default=None, description="Library used for prediction" + default=None, + description="Library used for prediction", ) @@ -81,7 +82,8 @@ class Config: Name: Optional[NameFactorConfig] = Field(default=None, alias="name") Description: Optional[DescriptionFactorConfig] = Field( - default=None, alias="description" + default=None, + alias="description", ) Datatype: Optional[DataTypeFactorConfig] = Field(default=None, alias="datatype") diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index 48a008536ed1ed..deb1294d44e047 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -154,7 +154,8 @@ def test_connection(self) -> None: return try: client_id: Optional[TelemetryClientIdClass] = self.get_aspect( - "urn:li:telemetry:clientId", TelemetryClientIdClass + "urn:li:telemetry:clientId", + TelemetryClientIdClass, ) self.server_id = client_id.clientId if client_id else _MISSING_SERVER_ID except Exception as e: @@ -197,7 +198,7 @@ def from_emitter(cls, emitter: DatahubRestEmitter) -> "DataHubGraph": disable_ssl_verification=session_config.disable_ssl_verification, ca_certificate_path=session_config.ca_certificate_path, client_certificate_path=session_config.client_certificate_path, - ) + ), ) def _send_restli_request(self, method: str, url: str, **kwargs: Any) -> Dict: @@ -209,12 +210,14 @@ def _send_restli_request(self, method: str, url: str, **kwargs: Any) -> Dict: try: info = response.json() raise OperationalError( - "Unable to get metadata from DataHub", info + "Unable to get metadata from DataHub", + info, ) from e except JSONDecodeError: # If we can't parse the JSON, just raise the original error. 
raise OperationalError( - "Unable to get metadata from DataHub", {"message": str(e)} + "Unable to get metadata from DataHub", + {"message": str(e)}, ) from e def _get_generic(self, url: str, params: Optional[Dict] = None) -> Dict: @@ -224,7 +227,8 @@ def _post_generic(self, url: str, payload_dict: Dict) -> Dict: return self._send_restli_request("POST", url, json=payload_dict) def _make_rest_sink_config( - self, extra_config: Optional[Dict] = None + self, + extra_config: Optional[Dict] = None, ) -> "DatahubRestSinkConfig": from datahub.ingestion.sink.datahub_rest import DatahubRestSinkConfig @@ -249,10 +253,10 @@ def make_rest_sink( yield sink if sink.report.failures: logger.error( - f"Failed to emit {len(sink.report.failures)} records\n{sink.report.as_string()}" + f"Failed to emit {len(sink.report.failures)} records\n{sink.report.as_string()}", ) raise OperationalError( - f"Failed to emit {len(sink.report.failures)} records" + f"Failed to emit {len(sink.report.failures)} records", ) def emit_all( @@ -294,7 +298,7 @@ def get_aspect( aspect = aspect_type.ASPECT_NAME if aspect in TIMESERIES_ASPECT_MAP: raise TypeError( - 'Cannot get a timeseries aspect using "get_aspect". Use "get_latest_timeseries_value" instead.' + 'Cannot get a timeseries aspect using "get_aspect". Use "get_latest_timeseries_value" instead.', ) url: str = f"{self._gms_server}/aspects/{Urn.url_encode(entity_urn)}?aspect={aspect}&version={version}" @@ -317,7 +321,7 @@ def get_aspect( return aspect_type.from_obj(post_json_obj) else: raise GraphError( - f"Failed to find {aspect_type_name} in response {response_json}" + f"Failed to find {aspect_type_name} in response {response_json}", ) @deprecated(reason="Use get_aspect instead which makes aspect string name optional") @@ -350,10 +354,12 @@ def get_domain_properties(self, entity_urn: str) -> Optional[DomainPropertiesCla return self.get_aspect(entity_urn=entity_urn, aspect_type=DomainPropertiesClass) def get_dataset_properties( - self, entity_urn: str + self, + entity_urn: str, ) -> Optional[DatasetPropertiesClass]: return self.get_aspect( - entity_urn=entity_urn, aspect_type=DatasetPropertiesClass + entity_urn=entity_urn, + aspect_type=DatasetPropertiesClass, ) def get_tags(self, entity_urn: str) -> Optional[GlobalTagsClass]: @@ -371,7 +377,10 @@ def get_browse_path(self, entity_urn: str) -> Optional[BrowsePathsClass]: return self.get_aspect(entity_urn=entity_urn, aspect_type=BrowsePathsClass) def get_usage_aspects_from_urn( - self, entity_urn: str, start_timestamp: int, end_timestamp: int + self, + entity_urn: str, + start_timestamp: int, + end_timestamp: int, ) -> Optional[List[DatasetUsageStatisticsClass]]: payload = { "urn": entity_urn, @@ -385,11 +394,13 @@ def get_usage_aspects_from_urn( try: usage_aspects: List[DatasetUsageStatisticsClass] = [] response = self._session.post( - url, data=json.dumps(payload), headers=headers + url, + data=json.dumps(payload), + headers=headers, ) if response.status_code != 200: logger.debug( - f"Non 200 status found while fetching usage aspects - {response.status_code}" + f"Non 200 status found while fetching usage aspects - {response.status_code}", ) return None json_resp = response.json() @@ -398,8 +409,9 @@ def get_usage_aspects_from_urn( if aspect.get("aspect") and aspect.get("aspect").get("value"): usage_aspects.append( DatasetUsageStatisticsClass.from_obj( - json.loads(aspect.get("aspect").get("value")), tuples=True - ) + json.loads(aspect.get("aspect").get("value")), + tuples=True, + ), ) return usage_aspects except Exception as e: 
@@ -407,7 +419,10 @@ def get_usage_aspects_from_urn( return None def list_all_entity_urns( - self, entity_type: str, start: int, count: int + self, + entity_type: str, + start: int, + count: int, ) -> Optional[List[str]]: url = f"{self._gms_server}/entities?action=listUrns" payload = {"entity": entity_type, "start": start, "count": count} @@ -417,11 +432,13 @@ def list_all_entity_urns( } try: response = self._session.post( - url, data=json.dumps(payload), headers=headers + url, + data=json.dumps(payload), + headers=headers, ) if response.status_code != 200: logger.debug( - f"Non 200 status found while fetching entity urns - {response.status_code}" + f"Non 200 status found while fetching entity urns - {response.status_code}", ) return None json_resp = response.json() @@ -443,7 +460,10 @@ def get_latest_timeseries_value( filter = {"or": [{"and": filter_criteria}]} values = self.get_timeseries_values( - entity_urn=entity_urn, aspect_type=aspect_type, filter=filter, limit=1 + entity_urn=entity_urn, + aspect_type=aspect_type, + filter=filter, + limit=1, ) if not values: return None @@ -474,16 +494,18 @@ def get_timeseries_values( aspect_json: str = value.get("aspect", {}).get("value") if aspect_json: aspects.append( - aspect_type.from_obj(json.loads(aspect_json), tuples=False) + aspect_type.from_obj(json.loads(aspect_json), tuples=False), ) else: raise GraphError( - f"Failed to find {aspect_type} in response {aspect_json}" + f"Failed to find {aspect_type} in response {aspect_json}", ) return aspects def get_entity_raw( - self, entity_urn: str, aspects: Optional[List[str]] = None + self, + entity_urn: str, + aspects: Optional[List[str]] = None, ) -> Dict: endpoint: str = f"{self.config.server}/entitiesV2/{Urn.url_encode(entity_urn)}" if aspects is not None: @@ -495,7 +517,7 @@ def get_entity_raw( return response.json() @deprecated( - reason="Use get_aspect for a single aspect or get_entity_semityped for a full entity." + reason="Use get_aspect for a single aspect or get_entity_semityped for a full entity.", ) def get_aspects_for_entity( self, @@ -541,7 +563,9 @@ def get_aspects_for_entity( return result def get_entity_as_mcps( - self, entity_urn: str, aspects: Optional[List[str]] = None + self, + entity_urn: str, + aspects: Optional[List[str]] = None, ) -> List[MetadataChangeProposalWrapper]: """Get all non-timeseries aspects for an entity. @@ -587,7 +611,9 @@ def get_entity_as_mcps( return results def get_entity_semityped( - self, entity_urn: str, aspects: Optional[List[str]] = None + self, + entity_urn: str, + aspects: Optional[List[str]] = None, ) -> AspectBag: """Get (all) non-timeseries aspects for an entity. @@ -637,7 +663,7 @@ def get_domain_urn_by_name(self, domain_name: str) -> Optional[str]: "field": "name", "value": domain_name, "condition": "EQUAL", - } + }, ] filters.append({"and": filter_criteria}) @@ -652,7 +678,7 @@ def get_domain_urn_by_name(self, domain_name: str) -> Optional[str]: num_entities = results.get("value", {}).get("numEntities", 0) if num_entities > 1: logger.warning( - f"Got {num_entities} results for domain name {domain_name}. Will return the first match." + f"Got {num_entities} results for domain name {domain_name}. 
Will return the first match.", ) entities_yielded: int = 0 entities = [] @@ -688,7 +714,7 @@ def get_connection_json(self, urn: str) -> Optional[dict]: connection_type = res["connection"]["details"]["type"] if connection_type != "JSON": logger.error( - f"Expected connection details type to be 'JSON', but got {connection_type}" + f"Expected connection details type to be 'JSON', but got {connection_type}", ) return None @@ -740,7 +766,7 @@ def set_connection_json( assert res["upsertConnection"]["urn"] == urn @deprecated( - reason='Use get_urns_by_filter(entity_types=["container"], ...) instead' + reason='Use get_urns_by_filter(entity_types=["container"], ...) instead', ) def get_container_urns_by_filter( self, @@ -759,7 +785,7 @@ def get_container_urns_by_filter( "field": "customProperties", "value": f"instance={env}", "condition": "EQUAL", - } + }, ) filter_criteria.append( @@ -767,7 +793,7 @@ def get_container_urns_by_filter( "field": "typeNames", "value": container_subtype, "condition": "EQUAL", - } + }, ) container_filters.append({"and": filter_criteria}) search_body = { @@ -808,7 +834,12 @@ def _bulk_fetch_schema_info_by_filter( query = query or "*" orFilters = generate_filter( - platform, platform_instance, env, container, status, extraFilters + platform, + platform_instance, + env, + container, + status, + extraFilters, ) graphql_query = textwrap.dedent( @@ -847,7 +878,7 @@ def _bulk_fetch_schema_info_by_filter( } } } - """ + """, ) variables = { @@ -940,7 +971,7 @@ def get_urns_by_filter( } } } - """ + """, ) variables = { @@ -1045,7 +1076,7 @@ def get_results_by_filter( } } } - """ + """, ) variables = { @@ -1061,7 +1092,9 @@ def get_results_by_filter( yield result def _scroll_across_entities_results( - self, graphql_query: str, variables_orig: dict + self, + graphql_query: str, + variables_orig: dict, ) -> Iterable[dict]: variables = variables_orig.copy() first_iter = True @@ -1081,11 +1114,13 @@ def _scroll_across_entities_results( if scroll_id: logger.debug( - f"Scrolling to next scrollAcrossEntities page: {scroll_id}" + f"Scrolling to next scrollAcrossEntities page: {scroll_id}", ) def _scroll_across_entities( - self, graphql_query: str, variables_orig: dict + self, + graphql_query: str, + variables_orig: dict, ) -> Iterable[dict]: variables = variables_orig.copy() first_iter = True @@ -1105,7 +1140,7 @@ def _scroll_across_entities( if scroll_id: logger.debug( - f"Scrolling to next scrollAcrossEntities page: {scroll_id}" + f"Scrolling to next scrollAcrossEntities page: {scroll_id}", ) def _get_types(self, entity_types: Optional[List[str]]) -> Optional[List[str]]: @@ -1113,14 +1148,16 @@ def _get_types(self, entity_types: Optional[List[str]]) -> Optional[List[str]]: if entity_types is not None: if not entity_types: raise ValueError( - "entity_types cannot be an empty list; use None for all entities" + "entity_types cannot be an empty list; use None for all entities", ) types = [_graphql_entity_type(entity_type) for entity_type in entity_types] return types def get_latest_pipeline_checkpoint( - self, pipeline_name: str, platform: str + self, + pipeline_name: str, + platform: str, ) -> Optional[Checkpoint["GenericCheckpointState"]]: from datahub.ingestion.source.state.entity_removal_state import ( GenericCheckpointState, @@ -1136,7 +1173,8 @@ def get_latest_pipeline_checkpoint( job_name = StaleEntityRemovalHandler.compute_job_id(platform) raw_checkpoint = checkpoint_provider.get_latest_checkpoint( - pipeline_name, job_name + pipeline_name, + job_name, ) if not raw_checkpoint: 
return None @@ -1148,7 +1186,10 @@ def get_latest_pipeline_checkpoint( ) def get_search_results( - self, start: int = 0, count: int = 1, entity: str = "dataset" + self, + start: int = 0, + count: int = 1, + entity: str = "dataset", ) -> Dict: search_body = {"input": "*", "entity": entity, "start": start, "count": count} results: Dict = self._post_generic(self._search_endpoint, search_body) @@ -1179,7 +1220,7 @@ def execute_graphql( body["operationName"] = operation_name logger.debug( - f"Executing {operation_name or ''} graphql query: {query} with variables: {json.dumps(variables)}" + f"Executing {operation_name or ''} graphql query: {query} with variables: {json.dumps(variables)}", ) result = self._post_generic(url, body) if result.get("errors"): @@ -1220,7 +1261,7 @@ def get_related_entities( via=related_entity.get("via"), ) done = response.get("count", 0) == 0 or response.get("count", 0) < len( - response.get("entities", []) + response.get("entities", []), ) start = start + response.get("count", 0) @@ -1233,11 +1274,12 @@ def exists(self, entity_urn: str) -> bool: return result is not None else: raise Exception( - f"Failed to find key class for entity type {entity_urn_parsed.get_type()} for urn {entity_urn}" + f"Failed to find key class for entity type {entity_urn_parsed.get_type()} for urn {entity_urn}", ) except Exception as e: logger.debug( - f"Failed to check for existence of urn {entity_urn}", exc_info=e + f"Failed to check for existence of urn {entity_urn}", + exc_info=e, ) raise @@ -1253,7 +1295,10 @@ def soft_delete_entity( urn: The urn of the entity to soft-delete. """ self.set_soft_delete_status( - urn=urn, run_id=run_id, deletion_timestamp=deletion_timestamp, delete=True + urn=urn, + run_id=run_id, + deletion_timestamp=deletion_timestamp, + delete=True, ) def set_soft_delete_status( @@ -1276,9 +1321,10 @@ def set_soft_delete_status( entityUrn=urn, aspect=StatusClass(removed=delete), systemMetadata=SystemMetadataClass( - runId=run_id, lastObserved=deletion_timestamp + runId=run_id, + lastObserved=deletion_timestamp, ), - ) + ), ) def hard_delete_entity( @@ -1298,7 +1344,8 @@ def hard_delete_entity( payload_obj: Dict = {"urn": urn} summary = self._post_generic( - f"{self._gms_server}/entities?action=delete", payload_obj + f"{self._gms_server}/entities?action=delete", + payload_obj, ).get("value", {}) rows_affected: int = summary.get("rows", 0) @@ -1316,7 +1363,7 @@ def delete_entity(self, urn: str, hard: bool = False) -> None: if hard: rows_affected, timeseries_rows_affected = self.hard_delete_entity(urn) logger.debug( - f"Hard deleted entity {urn} with {rows_affected} rows affected and {timeseries_rows_affected} timeseries rows affected" + f"Hard deleted entity {urn} with {rows_affected} rows affected and {timeseries_rows_affected} timeseries rows affected", ) else: self.soft_delete_entity(urn) @@ -1356,14 +1403,17 @@ def hard_delete_timeseries_aspect( payload_obj["endTimeMillis"] = int(end_time.timestamp() * 1000) summary = self._post_generic( - f"{self._gms_server}/entities?action=delete", payload_obj + f"{self._gms_server}/entities?action=delete", + payload_obj, ).get("value", {}) timeseries_rows_affected: int = summary.get("timeseriesRows", 0) return timeseries_rows_affected def delete_references_to_urn( - self, urn: str, dry_run: bool = False + self, + urn: str, + dry_run: bool = False, ) -> Tuple[int, List[Dict]]: """Delete references to a given entity. 
@@ -1387,7 +1437,8 @@ def delete_references_to_urn( payload_obj = {"urn": urn, "dryRun": dry_run} response = self._post_generic( - f"{self._gms_server}/entities?action=deleteReferences", payload_obj + f"{self._gms_server}/entities?action=deleteReferences", + payload_obj, ).get("value", {}) reference_count = response.get("total", 0) related_aspects = response.get("relatedAspects", []) @@ -1419,7 +1470,10 @@ def initialize_schema_resolver_from_datahub( ) -> "SchemaResolver": logger.info("Initializing schema resolver") schema_resolver = self._make_schema_resolver( - platform, platform_instance, env, include_graph=False + platform, + platform_instance, + env, + include_graph=False, ) logger.info(f"Fetching schemas for platform {platform}, env {env}") @@ -1439,10 +1493,10 @@ def initialize_schema_resolver_from_datahub( if count % 1000 == 0: logger.debug( - f"Loaded {count} schema info in {timer.elapsed_seconds()} seconds" + f"Loaded {count} schema info in {timer.elapsed_seconds()} seconds", ) logger.info( - f"Finished loading total {count} schema info in {timer.elapsed_seconds()} seconds" + f"Finished loading total {count} schema info in {timer.elapsed_seconds()} seconds", ) logger.info("Finished initializing schema resolver") @@ -1463,7 +1517,9 @@ def parse_sql_lineage( # Cache the schema resolver to make bulk parsing faster. schema_resolver = self._make_schema_resolver( - platform=platform, platform_instance=platform_instance, env=env + platform=platform, + platform_instance=platform_instance, + env=env, ) return sqlglot_lineage( @@ -1546,7 +1602,8 @@ def _run_assertion_result_shared(self) -> str: return fragment def _run_assertion_build_params( - self, params: Optional[Dict[str, str]] = {} + self, + params: Optional[Dict[str, str]] = {}, ) -> List[Any]: if params is None: return [] diff --git a/metadata-ingestion/src/datahub/ingestion/graph/filters.py b/metadata-ingestion/src/datahub/ingestion/graph/filters.py index 588090ec567277..036bd4e9288303 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/filters.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/filters.py @@ -160,7 +160,8 @@ def _get_container_filter(container: str) -> SearchFilterRule: def _get_platform_instance_filter( - platform: Optional[str], platform_instance: str + platform: Optional[str], + platform_instance: str, ) -> SearchFilterRule: if platform: # Massage the platform instance into a fully qualified urn, if necessary. 
diff --git a/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py b/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py index c143a8b49f4b7c..1a0c12f9754fba 100644 --- a/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py @@ -82,7 +82,7 @@ def create( sink: Sink, ) -> PipelineRunListener: reporter_config = DatahubIngestionRunSummaryProviderConfig.parse_obj( - config_dict or {} + config_dict or {}, ) if reporter_config.sink: sink_class = sink_registry.get(reporter_config.sink.type) @@ -99,11 +99,11 @@ def create( sink_registry.get_optional("datahub-kafka"), ] if kls - ] + ], ), ): raise IgnorableError( - f"Datahub ingestion reporter will be disabled because sink type {type(sink)} is not supported" + f"Datahub ingestion reporter will be disabled because sink type {type(sink)} is not supported", ) return cls(sink, reporter_config.report_recipe, ctx) @@ -122,7 +122,8 @@ def __init__(self, sink: Sink, report_recipe: bool, ctx: PipelineContext) -> Non ) logger.debug(f"Ingestion source urn = {self.ingestion_source_urn}") self.execution_request_input_urn: Urn = Urn( - entity_type="dataHubExecutionRequest", entity_id=[ctx.run_id] + entity_type="dataHubExecutionRequest", + entity_id=[ctx.run_id], ) self.start_time_ms: int = self.get_cur_time_in_ms() @@ -131,7 +132,7 @@ def __init__(self, sink: Sink, report_recipe: bool, ctx: PipelineContext) -> Non name=self.entity_name, type=ctx.pipeline_config.source.type, platform=make_data_platform_urn( - getattr(ctx.pipeline_config.source, "platform", "unknown") + getattr(ctx.pipeline_config.source, "platform", "unknown"), ), config=DataHubIngestionSourceConfigClass( recipe=self._get_recipe_to_report(ctx), @@ -191,7 +192,7 @@ def _get_recipe_to_report(self, ctx: PipelineContext) -> str: # with a TypeError: Object of type set is not JSON serializable converted_recipe = ( DatahubIngestionRunSummaryProvider._convert_sets_to_lists( - redacted_recipe + redacted_recipe, ) ) return json.dumps(converted_recipe) diff --git a/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py b/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py index 40a95c01bdfc41..a923fcd500feac 100644 --- a/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py +++ b/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py @@ -20,7 +20,7 @@ class FileReporterConfig(ConfigModel): def only_json_supported(cls, v): if v and v.lower() != "json": raise ValueError( - f"Format {v} is not yet supported. Only json is supported at this time" + f"Format {v} is not yet supported. 
Only json is supported at this time", ) return v diff --git a/metadata-ingestion/src/datahub/ingestion/reporting/reporting_provider_registry.py b/metadata-ingestion/src/datahub/ingestion/reporting/reporting_provider_registry.py index 02bc1b8e473f24..867fd99c0002b8 100644 --- a/metadata-ingestion/src/datahub/ingestion/reporting/reporting_provider_registry.py +++ b/metadata-ingestion/src/datahub/ingestion/reporting/reporting_provider_registry.py @@ -3,5 +3,5 @@ reporting_provider_registry = PluginRegistry[PipelineRunListener]() reporting_provider_registry.register_from_entrypoint( - "datahub.ingestion.reporting_provider.plugins" + "datahub.ingestion.reporting_provider.plugins", ) diff --git a/metadata-ingestion/src/datahub/ingestion/run/connection.py b/metadata-ingestion/src/datahub/ingestion/run/connection.py index 54b0ab9f22c65e..c739e691d44505 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/connection.py +++ b/metadata-ingestion/src/datahub/ingestion/run/connection.py @@ -23,7 +23,7 @@ def test_source_connection(self, recipe_config_dict: dict) -> TestConnectionRepo ): # validate that the class overrides the base implementation return source_class.test_connection( - recipe_config_dict.get("source", {}).get("config", {}) + recipe_config_dict.get("source", {}).get("config", {}), ) else: return TestConnectionReport( diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py index 25cbd340c9674b..7c693348e84ecf 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py +++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py @@ -63,10 +63,12 @@ def __init__(self, name: str = "") -> None: self.name = name def on_success( - self, record_envelope: RecordEnvelope, success_metadata: dict + self, + record_envelope: RecordEnvelope, + success_metadata: dict, ) -> None: logger.debug( - f"{self.name} sink wrote workunit {record_envelope.metadata['workunit_id']}" + f"{self.name} sink wrote workunit {record_envelope.metadata['workunit_id']}", ) def on_failure( @@ -91,7 +93,9 @@ def __init__(self, ctx: PipelineContext, config: Optional[FileSinkConfig]) -> No logger.info(f"Failure logging enabled. Will log to {config.filename}.") def on_success( - self, record_envelope: RecordEnvelope, success_metadata: dict + self, + record_envelope: RecordEnvelope, + success_metadata: dict, ) -> None: pass @@ -169,7 +173,7 @@ def compute_stats(self) -> None: if self._peak_memory_usage < mem_usage: self._peak_memory_usage = mem_usage self.peak_memory_usage = humanfriendly.format_size( - self._peak_memory_usage + self._peak_memory_usage, ) self.mem_info = humanfriendly.format_size(mem_usage) except Exception as e: @@ -254,7 +258,7 @@ def __init__( if self.config.sink is None: logger.info( - "No sink configured, attempting to use the default datahub-rest sink." 
+ "No sink configured, attempting to use the default datahub-rest sink.", ) with _add_init_error_context("configure the default rest sink"): self.sink_type = "datahub-rest" @@ -262,7 +266,7 @@ def __init__( else: self.sink_type = self.config.sink.type with _add_init_error_context( - f"find a registered sink for type {self.sink_type}" + f"find a registered sink for type {self.sink_type}", ): sink_class = sink_registry.get(self.sink_type) @@ -284,16 +288,17 @@ def __init__( self._configure_reporting(report_to) with _add_init_error_context( - f"find a registered source for type {self.source_type}" + f"find a registered source for type {self.source_type}", ): source_class = source_registry.get(self.source_type) with _add_init_error_context(f"configure the source ({self.source_type})"): self.source = source_class.create( - self.config.source.dict().get("config", {}), self.ctx + self.config.source.dict().get("config", {}), + self.ctx, ) logger.debug( - f"Source type {self.source_type} ({source_class}) configured" + f"Source type {self.source_type} ({source_class}) configured", ) logger.info("Source configured successfully.") @@ -301,7 +306,8 @@ def __init__( with _add_init_error_context(f"configure the extractor ({extractor_type})"): extractor_class = extractor_registry.get(extractor_type) self.extractor = extractor_class( - self.config.source.extractor_config, self.ctx + self.config.source.extractor_config, + self.ctx, ) with _add_init_error_context("configure transformers"): @@ -319,10 +325,10 @@ def _configure_transforms(self) -> None: transformer_class = transform_registry.get(transformer_type) transformer_config = transformer.dict().get("config", {}) self.transformers.append( - transformer_class.create(transformer_config, self.ctx) + transformer_class.create(transformer_config, self.ctx), ) logger.debug( - f"Transformer type:{transformer_type},{transformer_class} configured" + f"Transformer type:{transformer_type},{transformer_class} configured", ) # Add the system metadata transformer at the end of the list. @@ -342,14 +348,14 @@ def _configure_reporting(self, report_to: Optional[str]) -> None: reporter.type for reporter in self.config.reporting ]: self.config.reporting.append( - ReporterConfig.parse_obj({"type": "datahub"}) + ReporterConfig.parse_obj({"type": "datahub"}), ) elif report_to: # we assume this is a file name, and add the file reporter self.config.reporting.append( ReporterConfig.parse_obj( - {"type": "file", "config": {"filename": report_to}} - ) + {"type": "file", "config": {"filename": report_to}}, + ), ) for reporter in self.config.reporting: @@ -362,10 +368,10 @@ def _configure_reporting(self, report_to: Optional[str]) -> None: config_dict=reporter_config_dict, ctx=self.ctx, sink=self.sink, - ) + ), ) logger.debug( - f"Reporter type:{reporter_type},{reporter_class} configured." 
+ f"Reporter type:{reporter_type},{reporter_class} configured.", ) except Exception as e: if reporter.required: @@ -374,7 +380,8 @@ def _configure_reporting(self, report_to: Optional[str]) -> None: logger.debug(f"Reporter type {reporter_type} is disabled: {e}") else: logger.warning( - f"Failed to configure reporter: {reporter_type}", exc_info=e + f"Failed to configure reporter: {reporter_type}", + exc_info=e, ) def _notify_reporters_on_ingestion_start(self) -> None: @@ -446,8 +453,8 @@ def run(self) -> None: # noqa: C901 stack.enter_context( memray.Tracker( - f"{self.config.flags.generate_memory_profiles}/{self.config.run_id}.bin" - ) + f"{self.config.flags.generate_memory_profiles}/{self.config.run_id}.bin", + ), ) stack.enter_context(self.sink) @@ -460,7 +467,8 @@ def run(self) -> None: # noqa: C901 LoggingCallback() if not self.config.failure_log.enabled else DeadLetterQueueCallback( - self.ctx, self.config.failure_log.log_config + self.ctx, + self.config.failure_log.log_config, ) ) for wu in itertools.islice( @@ -483,7 +491,9 @@ def run(self) -> None: # noqa: C901 record_envelopes = list(self.extractor.get_records(wu)) except Exception as e: self.source.get_report().failure( - "Source produced bad metadata", context=wu.id, exc=e + "Source produced bad metadata", + context=wu.id, + exc=e, ) continue try: @@ -491,12 +501,13 @@ def run(self) -> None: # noqa: C901 if not self.dry_run: try: self.sink.write_record_async( - record_envelope, callback + record_envelope, + callback, ) except Exception as e: # In case the sink's error handling is bad, we still want to report the error. self.sink.report.report_failure( - f"Failed to write record: {e}" + f"Failed to write record: {e}", ) except (RuntimeError, SystemExit): @@ -518,11 +529,12 @@ def run(self) -> None: # noqa: C901 RecordEnvelope( record=EndOfStream(), metadata={"workunit_id": "end-of-stream"}, - ) - ] + ), + ], ): if not self.dry_run and not isinstance( - record_envelope.record, EndOfStream + record_envelope.record, + EndOfStream, ): # TODO: propagate EndOfStream and other control events to sinks, to allow them to flush etc. self.sink.write_record_async(record_envelope, callback) @@ -566,7 +578,7 @@ def process_commits(self) -> None: else False ) has_warnings: bool = bool( - self.source.get_report().warnings or self.sink.get_report().warnings + self.source.get_report().warnings or self.sink.get_report().warnings, ) for name, committable in self.ctx.get_committables(): @@ -574,7 +586,7 @@ def process_commits(self) -> None: logger.info( f"Processing commit request for {name}. Commit policy = {commit_policy}," - f" has_errors={has_errors}, has_warnings={has_warnings}" + f" has_errors={has_errors}, has_warnings={has_warnings}", ) if ( @@ -582,7 +594,7 @@ def process_commits(self) -> None: and (has_errors or has_warnings) ) or (commit_policy == CommitPolicy.ON_NO_ERRORS and has_errors): logger.warning( - f"Skipping commit request for {name} since policy requirements are not met." 
+ f"Skipping commit request for {name} since policy requirements are not met.", ) continue @@ -597,18 +609,21 @@ def process_commits(self) -> None: def raise_from_status(self, raise_warnings: bool = False) -> None: if self.source.get_report().failures: raise PipelineExecutionError( - "Source reported errors", self.source.get_report() + "Source reported errors", + self.source.get_report(), ) if self.sink.get_report().failures: raise PipelineExecutionError("Sink reported errors", self.sink.get_report()) if raise_warnings: if self.source.get_report().warnings: raise PipelineExecutionError( - "Source reported warnings", self.source.get_report() + "Source reported warnings", + self.source.get_report(), ) if self.sink.get_report().warnings: raise PipelineExecutionError( - "Sink reported warnings", self.sink.get_report() + "Sink reported warnings", + self.sink.get_report(), ) def log_ingestion_stats(self) -> None: @@ -627,7 +642,7 @@ def log_ingestion_stats(self) -> None: transformer.type for transformer in self.config.transformers or [] ], "records_written": stats.discretize( - self.sink.get_report().total_records_written + self.sink.get_report().total_records_written, ), "source_failures": stats.discretize(source_failures), "source_warnings": stats.discretize(source_warnings), @@ -636,7 +651,7 @@ def log_ingestion_stats(self) -> None: "global_warnings": global_warnings, "failures": stats.discretize(source_failures + sink_failures), "warnings": stats.discretize( - source_warnings + sink_warnings + global_warnings + source_warnings + sink_warnings + global_warnings, ), "has_pipeline_name": bool(self.config.pipeline_name), }, @@ -658,11 +673,13 @@ def _get_text_color(self, running: bool, failures: bool, warnings: bool) -> str: def has_failures(self) -> bool: return bool( - self.source.get_report().failures or self.sink.get_report().failures + self.source.get_report().failures or self.sink.get_report().failures, ) def pretty_print_summary( - self, warnings_as_failure: bool = False, currently_running: bool = False + self, + warnings_as_failure: bool = False, + currently_running: bool = False, ) -> int: workunits_produced = self.sink.get_report().total_records_written @@ -696,15 +713,17 @@ def pretty_print_summary( if self.source.get_report().failures or self.sink.get_report().failures: num_failures_source = self._approx_all_vals( - self.source.get_report().failures + self.source.get_report().failures, ) num_failures_sink = len(self.sink.get_report().failures) click.secho( message_template.format( - status=f"with at least {num_failures_source + num_failures_sink} failures" + status=f"with at least {num_failures_source + num_failures_sink} failures", ), fg=self._get_text_color( - running=currently_running, failures=True, warnings=False + running=currently_running, + failures=True, + warnings=False, ), bold=True, ) @@ -719,10 +738,12 @@ def pretty_print_summary( num_warn_global = len(global_warnings) click.secho( message_template.format( - status=f"with at least {num_warn_source + num_warn_sink + num_warn_global} warnings" + status=f"with at least {num_warn_source + num_warn_sink + num_warn_global} warnings", ), fg=self._get_text_color( - running=currently_running, failures=False, warnings=True + running=currently_running, + failures=False, + warnings=True, ), bold=True, ) @@ -731,7 +752,9 @@ def pretty_print_summary( click.secho( message_template.format(status="successfully"), fg=self._get_text_color( - running=currently_running, failures=False, warnings=False + running=currently_running, + 
failures=False, + warnings=False, ), bold=True, ) @@ -739,7 +762,8 @@ def pretty_print_summary( def _handle_uncaught_pipeline_exception(self, exc: Exception) -> None: logger.exception( - f"Ingestion pipeline threw an uncaught exception: {exc}", stacklevel=2 + f"Ingestion pipeline threw an uncaught exception: {exc}", + stacklevel=2, ) self.source.get_report().report_failure( title="Pipeline Error", diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py index 53e31aa2ea96e1..8810f2e48eac50 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py +++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py @@ -64,7 +64,8 @@ class FlagsConfig(ConfigModel): ) set_system_metadata: bool = Field( - True, description="Set system metadata on entities." + True, + description="Set system metadata on entities.", ) set_system_metadata_pipeline_name: bool = Field( True, @@ -98,7 +99,10 @@ class PipelineConfig(ConfigModel): @validator("run_id", pre=True, always=True) def run_id_should_be_semantic( - cls, v: Optional[str], values: Dict[str, Any], **kwargs: Any + cls, + v: Optional[str], + values: Dict[str, Any], + **kwargs: Any, ) -> str: if v == DEFAULT_RUN_ID: source_type = None @@ -112,7 +116,9 @@ def run_id_should_be_semantic( @classmethod def from_dict( - cls, resolved_dict: dict, raw_dict: Optional[dict] = None + cls, + resolved_dict: dict, + raw_dict: Optional[dict] = None, ) -> "PipelineConfig": config = cls.parse_obj(resolved_dict) config._raw_dict = raw_dict diff --git a/metadata-ingestion/src/datahub/ingestion/sink/blackhole.py b/metadata-ingestion/src/datahub/ingestion/sink/blackhole.py index 38b95373d47bde..63cfe7e17047a2 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/blackhole.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/blackhole.py @@ -9,7 +9,9 @@ class BlackHoleSink(Sink[ConfigModel, SinkReport]): def write_record_async( - self, record_envelope: RecordEnvelope, write_callback: WriteCallback + self, + record_envelope: RecordEnvelope, + write_callback: WriteCallback, ) -> None: if write_callback: self.report.report_record_written(record_envelope) diff --git a/metadata-ingestion/src/datahub/ingestion/sink/console.py b/metadata-ingestion/src/datahub/ingestion/sink/console.py index ee7bae565fac87..a2fa815534f05d 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/console.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/console.py @@ -9,7 +9,9 @@ class ConsoleSink(Sink[ConfigModel, SinkReport]): def write_record_async( - self, record_envelope: RecordEnvelope, write_callback: WriteCallback + self, + record_envelope: RecordEnvelope, + write_callback: WriteCallback, ) -> None: print(f"{record_envelope}") if write_callback: diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_kafka.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_kafka.py index 38ddadaafc862c..f492841cdf4e6b 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_kafka.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_kafka.py @@ -27,7 +27,9 @@ def kafka_callback(self, err: Optional[Exception], msg: str) -> None: if err is not None: self.reporter.report_failure(err) self.write_callback.on_failure( - self.record_envelope, err, {"error": err, "msg": msg} + self.record_envelope, + err, + {"error": err, "msg": msg}, ) else: self.reporter.report_record_written(self.record_envelope) @@ -58,7 +60,9 @@ def write_record_async( write_callback: 
WriteCallback, ) -> None: callback = _KafkaCallback( - self.report, record_envelope, write_callback + self.report, + record_envelope, + write_callback, ).kafka_callback try: record = record_envelope.record diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py index 209efbbb90febc..4409dd16230663 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) _DEFAULT_REST_SINK_MAX_THREADS = int( - os.getenv("DATAHUB_REST_SINK_DEFAULT_MAX_THREADS", 15) + os.getenv("DATAHUB_REST_SINK_DEFAULT_MAX_THREADS", 15), ) @@ -60,7 +60,8 @@ class RestSinkMode(ConfigEnum): _DEFAULT_REST_SINK_MODE = pydantic.parse_obj_as( - RestSinkMode, os.getenv("DATAHUB_REST_SINK_DEFAULT_MODE", RestSinkMode.ASYNC_BATCH) + RestSinkMode, + os.getenv("DATAHUB_REST_SINK_DEFAULT_MODE", RestSinkMode.ASYNC_BATCH), ) @@ -78,7 +79,7 @@ class DatahubRestSinkConfig(DatahubClientConfig): def validate_max_per_batch(cls, v): if v > BATCH_INGEST_MAX_PAYLOAD_LENGTH: raise ValueError( - f"max_per_batch must be less than or equal to {BATCH_INGEST_MAX_PAYLOAD_LENGTH}" + f"max_per_batch must be less than or equal to {BATCH_INGEST_MAX_PAYLOAD_LENGTH}", ) return v @@ -131,7 +132,7 @@ def __post_init__(self) -> None: gms_config = self.emitter.get_server_config() except Exception as exc: raise ConfigurationError( - f"💥 Failed to connect to DataHub with {repr(self.emitter)}" + f"💥 Failed to connect to DataHub with {repr(self.emitter)}", ) from exc self.report.gms_version = ( @@ -202,7 +203,9 @@ def _write_done_callback( if future.cancelled(): self.report.report_failure({"error": "future was cancelled"}) write_callback.on_failure( - record_envelope, OperationalError("future was cancelled"), {} + record_envelope, + OperationalError("future was cancelled"), + {}, ) elif future.done(): e = future.exception() @@ -215,7 +218,7 @@ def _write_done_callback( if "stackTrace" in e.info: with contextlib.suppress(Exception): e.info["stackTrace"] = "\n".join( - e.info["stackTrace"].split("\n")[:3] + e.info["stackTrace"].split("\n")[:3], ) e.info["message"] = e.info.get("message", "").split("\n")[0][ :200 @@ -277,7 +280,7 @@ def _emit_batch_wrapper( self.report.async_batches_split += chunks logger.info( f"In async_batch mode, the payload was split into {chunks} batches. " - "If there's many of these issues, consider decreasing `max_per_batch`." 
+ "If there's many of these issues, consider decreasing `max_per_batch`.", ) def write_record_async( @@ -303,7 +306,9 @@ def write_record_async( self._emit_wrapper, record, done_callback=functools.partial( - self._write_done_callback, record_envelope, write_callback + self._write_done_callback, + record_envelope, + write_callback, ), ) self.report.pending_requests += 1 @@ -314,7 +319,9 @@ def write_record_async( partition_key, record, done_callback=functools.partial( - self._write_done_callback, record_envelope, write_callback + self._write_done_callback, + record_envelope, + write_callback, ), ) self.report.pending_requests += 1 @@ -329,7 +336,9 @@ def write_record_async( def emit_async( self, item: Union[ - MetadataChangeEvent, MetadataChangeProposal, MetadataChangeProposalWrapper + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, ], ) -> None: return self.write_record_async( diff --git a/metadata-ingestion/src/datahub/ingestion/sink/file.py b/metadata-ingestion/src/datahub/ingestion/sink/file.py index c4f34d780f7c87..757203f984e494 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/file.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/file.py @@ -65,7 +65,8 @@ def write_record_async( ) -> None: record = record_envelope.record obj = _to_obj_for_file( - record, simplified_structure=not self.config.legacy_nested_json_string + record, + simplified_structure=not self.config.legacy_nested_json_string, ) if self.wrote_something: diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/config.py b/metadata-ingestion/src/datahub/ingestion/source/abs/config.py index c62239527a1200..3c28621838086a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/config.py @@ -26,7 +26,9 @@ class DataLakeSourceConfig( - StatefulIngestionConfigBase, DatasetSourceConfigMixin, PathSpecsConfigMixin + StatefulIngestionConfigBase, + DatasetSourceConfigMixin, + PathSpecsConfigMixin, ): platform: str = Field( default="", @@ -35,7 +37,8 @@ class DataLakeSourceConfig( ) azure_config: Optional[AzureConnectionConfig] = Field( - default=None, description="Azure configuration" + default=None, + description="Azure configuration", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None @@ -66,11 +69,13 @@ class DataLakeSourceConfig( description="regex patterns for tables to profile ", ) profiling: DataLakeProfilerConfig = Field( - default=DataLakeProfilerConfig(), description="Data profiling configuration" + default=DataLakeProfilerConfig(), + description="Data profiling configuration", ) spark_driver_memory: str = Field( - default="4g", description="Max amount of memory to grant Spark." 
+ default="4g", + description="Max amount of memory to grant Spark.", ) spark_config: Dict[str, Any] = Field( @@ -97,17 +102,21 @@ class DataLakeSourceConfig( ) _rename_path_spec_to_plural = pydantic_renamed_field( - "path_spec", "path_specs", lambda path_spec: [path_spec] + "path_spec", + "path_specs", + lambda path_spec: [path_spec], ) def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) @pydantic.validator("path_specs", always=True) def check_path_specs_and_infer_platform( - cls, path_specs: List[PathSpec], values: Dict + cls, + path_specs: List[PathSpec], + values: Dict, ) -> List[PathSpec]: if len(path_specs) == 0: raise ValueError("path_specs must not be empty") @@ -118,7 +127,7 @@ def check_path_specs_and_infer_platform( ) if len(guessed_platforms) > 1: raise ValueError( - f"Cannot have multiple platforms in path_specs: {guessed_platforms}" + f"Cannot have multiple platforms in path_specs: {guessed_platforms}", ) guessed_platform = guessed_platforms.pop() @@ -129,13 +138,13 @@ def check_path_specs_and_infer_platform( or values.get("use_abs_blob_properties") ): raise ValueError( - "Cannot grab abs blob/container tags when platform is not abs. Remove the flag or use abs." + "Cannot grab abs blob/container tags when platform is not abs. Remove the flag or use abs.", ) # Infer platform if not specified. if values.get("platform") and values["platform"] != guessed_platform: raise ValueError( - f"All path_specs belong to {guessed_platform} platform, but platform is set to {values['platform']}" + f"All path_specs belong to {guessed_platform} platform, but platform is set to {values['platform']}", ) else: logger.debug(f'Setting config "platform": {guessed_platform}') @@ -146,7 +155,8 @@ def check_path_specs_and_infer_platform( @pydantic.validator("platform", always=True) def platform_not_empty(cls, platform: str, values: dict) -> str: inferred_platform = values.get( - "platform", None + "platform", + None, ) # we may have inferred it above platform = platform or inferred_platform if not platform: @@ -155,7 +165,8 @@ def platform_not_empty(cls, platform: str, values: dict) -> str: @pydantic.root_validator() def ensure_profiling_pattern_is_passed_to_profiling( - cls, values: Dict[str, Any] + cls, + values: Dict[str, Any], ) -> Dict[str, Any]: profiling: Optional[DataLakeProfilerConfig] = values.get("profiling") if profiling is not None and profiling.enabled: diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py b/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py index d12ff7415faefc..ab56b6aa6bc583 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py @@ -10,7 +10,8 @@ class DataLakeProfilerConfig(ConfigModel): enabled: bool = Field( - default=False, description="Whether profiling should be done." + default=False, + description="Whether profiling should be done.", ) operation_config: OperationConfig = Field( default_factory=OperationConfig, @@ -61,7 +62,8 @@ class DataLakeProfilerConfig(ConfigModel): description="Whether to profile for the quantiles of numeric columns.", ) include_field_distinct_value_frequencies: bool = Field( - default=True, description="Whether to profile for distinct value frequencies." 
+ default=True, + description="Whether to profile for distinct value frequencies.", ) include_field_histogram: bool = Field( default=True, @@ -74,7 +76,8 @@ class DataLakeProfilerConfig(ConfigModel): @pydantic.root_validator() def ensure_field_level_settings_are_normalized( - cls: "DataLakeProfilerConfig", values: Dict[str, Any] + cls: "DataLakeProfilerConfig", + values: Dict[str, Any], ) -> Dict[str, Any]: max_num_fields_to_profile_key = "max_number_of_fields_to_profile" max_num_fields_to_profile = values.get(max_num_fields_to_profile_key) diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/profiling.py b/metadata-ingestion/src/datahub/ingestion/source/abs/profiling.py index c969b229989e84..d8a4f0069fcdfb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/profiling.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/profiling.py @@ -136,12 +136,13 @@ def __init__( ] self.report.report_file_dropped( - f"The max_number_of_fields_to_profile={self.profiling_config.max_number_of_fields_to_profile} reached. Profile of columns {self.file_path}({', '.join(sorted(columns_being_dropped))})" + f"The max_number_of_fields_to_profile={self.profiling_config.max_number_of_fields_to_profile} reached. Profile of columns {self.file_path}({', '.join(sorted(columns_being_dropped))})", ) analysis_result = self.analyzer.run() analysis_metrics = AnalyzerContext.successMetricsAsJson( - self.spark, analysis_result + self.spark, + analysis_result, ) # reshape distinct counts into dictionary @@ -156,7 +157,7 @@ def __init__( when( isnan(c) | col(c).isNull(), c, - ) + ), ).alias(c) for c in self.columns_to_profile if column_types[column] in [DoubleType, FloatType] @@ -168,14 +169,14 @@ def __init__( when( col(c).isNull(), c, - ) + ), ).alias(c) for c in self.columns_to_profile if column_types[column] not in [DoubleType, FloatType] ] null_counts = dataframe.select( - select_numeric_null_counts + select_nonnumeric_null_counts + select_numeric_null_counts + select_nonnumeric_null_counts, ) column_null_counts = null_counts.toPandas().T[0].to_dict() column_null_fractions = { @@ -215,7 +216,7 @@ def __init__( column_profile.nullProportion = column_null_fractions.get(column) if self.profiling_config.include_field_sample_values: column_profile.sampleValues = sorted( - [str(x[column]) for x in rdd_sample] + [str(x[column]) for x in rdd_sample], ) column_spec.type_ = column_types[column] @@ -373,7 +374,7 @@ def extract_table_profiles( # resolve histogram types for grouping column_metrics["kind"] = column_metrics["name"].apply( - lambda x: "Histogram" if x.startswith("Histogram.") else x + lambda x: "Histogram" if x.startswith("Histogram.") else x, ) column_histogram_metrics = column_metrics[column_metrics["kind"] == "Histogram"] @@ -387,12 +388,12 @@ def extract_table_profiles( # we only want the absolute counts for each histogram for now column_histogram_metrics = column_histogram_metrics[ column_histogram_metrics["name"].apply( - lambda x: x.startswith("Histogram.abs.") + lambda x: x.startswith("Histogram.abs."), ) ] # get the histogram bins by chopping off the "Histogram.abs." 
prefix column_histogram_metrics["bin"] = column_histogram_metrics["name"].apply( - lambda x: x[14:] + lambda x: x[14:], ) # reshape histogram counts for easier access @@ -407,7 +408,7 @@ def extract_table_profiles( if len(column_nonhistogram_metrics) > 0: # reshape other metrics for easier access nonhistogram_metrics = column_nonhistogram_metrics.set_index( - ["instance", "name"] + ["instance", "name"], )["value"] profiled_columns = set(nonhistogram_metrics.index.get_level_values(0)) @@ -428,10 +429,10 @@ def extract_table_profiles( column_profile.max = null_str(deequ_column_profile.get("Maximum")) column_profile.mean = null_str(deequ_column_profile.get("Mean")) column_profile.median = null_str( - deequ_column_profile.get("ApproxQuantiles-0.5") + deequ_column_profile.get("ApproxQuantiles-0.5"), ) column_profile.stdev = null_str( - deequ_column_profile.get("StandardDeviation") + deequ_column_profile.get("StandardDeviation"), ) if all( deequ_column_profile.get(f"ApproxQuantiles-{quantile}") is not None @@ -453,13 +454,15 @@ def extract_table_profiles( if column_spec.histogram_distinct: column_profile.distinctValueFrequencies = [ ValueFrequencyClass( - value=value, frequency=int(column_histogram.loc[value]) + value=value, + frequency=int(column_histogram.loc[value]), ) for value in column_histogram.index ] # sort so output is deterministic column_profile.distinctValueFrequencies = sorted( - column_profile.distinctValueFrequencies, key=lambda x: x.value + column_profile.distinctValueFrequencies, + key=lambda x: x.value, ) else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/source.py b/metadata-ingestion/src/datahub/ingestion/source/abs/source.py index 586e7a3af3bcd1..7d248ac9f1ad30 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/source.py @@ -193,17 +193,18 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List: fields = parquet.ParquetInferrer().infer_schema(file) elif extension == ".csv": fields = csv_tsv.CsvInferrer( - max_rows=self.source_config.max_rows + max_rows=self.source_config.max_rows, ).infer_schema(file) elif extension == ".tsv": fields = csv_tsv.TsvInferrer( - max_rows=self.source_config.max_rows + max_rows=self.source_config.max_rows, ).infer_schema(file) elif extension == ".json": fields = json.JsonInferrer().infer_schema(file) elif extension == ".jsonl": fields = json.JsonInferrer( - max_rows=self.source_config.max_rows, format="jsonl" + max_rows=self.source_config.max_rows, + format="jsonl", ).infer_schema(file) elif extension == ".avro": fields = avro.AvroInferrer().infer_schema(file) @@ -224,13 +225,18 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List: if self.source_config.add_partition_columns_to_schema: self.add_partition_columns_to_schema( - fields=fields, path_spec=path_spec, full_path=table_data.full_path + fields=fields, + path_spec=path_spec, + full_path=table_data.full_path, ) return fields def add_partition_columns_to_schema( - self, path_spec: PathSpec, full_path: str, fields: List[SchemaField] + self, + path_spec: PathSpec, + full_path: str, + fields: List[SchemaField], ) -> None: vars = path_spec.get_named_vars(full_path) if vars is not None and "partition" in vars: @@ -238,7 +244,7 @@ def add_partition_columns_to_schema( partition_arr = partition.split("=") if len(partition_arr) != 2: logger.debug( - f"Could not derive partition key from partition field {partition}" + f"Could not derive partition key from 
partition field {partition}", ) continue partition_key = partition_arr[0] @@ -250,7 +256,7 @@ def add_partition_columns_to_schema( isPartitioningKey=True, nullable=True, recursive=False, - ) + ), ) def _create_table_operation_aspect(self, table_data: TableData) -> OperationClass: @@ -265,7 +271,9 @@ def _create_table_operation_aspect(self, table_data: TableData) -> OperationClas return operation def ingest_table( - self, table_data: TableData, path_spec: PathSpec + self, + table_data: TableData, + path_spec: PathSpec, ) -> Iterable[MetadataWorkUnit]: aspects: List[Optional[_Aspect]] = [] @@ -289,7 +297,8 @@ def ingest_table( data_platform_instance = DataPlatformInstanceClass( platform=data_platform_urn, instance=make_dataplatform_instance_urn( - self.source_config.platform, self.source_config.platform_instance + self.source_config.platform, + self.source_config.platform_instance, ), ) aspects.append(data_platform_instance) @@ -333,11 +342,11 @@ def ingest_table( aspects.append(schema_metadata) except Exception as e: logger.error( - f"Failed to extract schema from file {table_data.full_path}. The error was:{e}" + f"Failed to extract schema from file {table_data.full_path}. The error was:{e}", ) else: logger.info( - f"Skipping schema extraction for empty file {table_data.full_path}" + f"Skipping schema extraction for empty file {table_data.full_path}", ) if ( @@ -364,7 +373,8 @@ def ingest_table( yield mcp.as_workunit() yield from self.container_WU_creator.create_container_hierarchy( - table_data.table_path, dataset_urn + table_data.table_path, + dataset_urn, ) def get_prefix(self, relative_path: str) -> str: @@ -403,7 +413,9 @@ def extract_table_data( return table_data def resolve_templated_folders( - self, container_name: str, prefix: str + self, + container_name: str, + prefix: str, ) -> Iterable[str]: folder_split: List[str] = prefix.split("*", 1) # If the len of split is 1 it means we don't have * in the prefix @@ -412,11 +424,14 @@ def resolve_templated_folders( return folders: Iterable[str] = list_folders( - container_name, folder_split[0], self.source_config.azure_config + container_name, + folder_split[0], + self.source_config.azure_config, ) for folder in folders: yield from self.resolve_templated_folders( - container_name, f"{folder}{folder_split[1]}" + container_name, + f"{folder}{folder_split[1]}", ) def get_dir_to_process( @@ -451,7 +466,9 @@ def get_dir_to_process( return folder def abs_browser( - self, path_spec: PathSpec, sample_size: int + self, + path_spec: PathSpec, + sample_size: int, ) -> Iterable[Tuple[str, str, datetime, int]]: if self.source_config.azure_config is None: raise ValueError("azure_config not set. 
Cannot browse Azure Blob Storage") @@ -459,7 +476,7 @@ def abs_browser( self.source_config.azure_config.get_blob_service_client() ) container_client = abs_blob_service_client.get_container_client( - self.source_config.azure_config.container_name + self.source_config.azure_config.container_name, ) container_name = self.source_config.azure_config.container_name @@ -490,7 +507,9 @@ def abs_browser( ): try: for f in list_folders( - container_name, f"{folder}", self.source_config.azure_config + container_name, + f"{folder}", + self.source_config.azure_config, ): logger.info(f"Processing folder: {f}") protocol = ContainerWUCreator.get_protocol(path_spec.include) @@ -520,20 +539,23 @@ def abs_browser( # https://github.com/boto/boto3/issues/1195 if "NoSuchBucket" in repr(e): logger.debug( - f"Got NoSuchBucket exception for {container_name}", e + f"Got NoSuchBucket exception for {container_name}", + e, ) self.get_report().report_warning( - "Missing bucket", f"No bucket found {container_name}" + "Missing bucket", + f"No bucket found {container_name}", ) else: raise e else: logger.debug( - "No template in the pathspec can't do sampling, fallbacking to do full scan" + "No template in the pathspec can't do sampling, fallbacking to do full scan", ) path_spec.sample_files = False for obj in container_client.list_blobs( - prefix=f"{prefix}", results_per_page=PAGE_SIZE + prefix=f"{prefix}", + results_per_page=PAGE_SIZE, ): abs_path = self.create_abs_path(obj.name) logger.debug(f"Path: {abs_path}") @@ -551,7 +573,8 @@ def create_abs_path(self, key: str) -> str: return "" def local_browser( - self, path_spec: PathSpec + self, + path_spec: PathSpec, ) -> Iterable[Tuple[str, str, datetime, int]]: prefix = self.get_prefix(path_spec.include) if os.path.isfile(prefix): @@ -571,7 +594,7 @@ def local_browser( for file in sorted(files): # We need to make sure the path is in posix style which is not true on windows full_path = PurePath( - os.path.normpath(os.path.join(root, file)) + os.path.normpath(os.path.join(root, file)), ).as_posix() yield ( full_path, @@ -591,7 +614,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for path_spec in self.source_config.path_specs: file_browser = ( self.abs_browser( - path_spec, self.source_config.number_of_files_to_sample + path_spec, + self.source_config.number_of_files_to_sample, ) if self.is_abs_platform() else self.local_browser(path_spec) @@ -601,7 +625,11 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if not path_spec.allowed(file): continue table_data = self.extract_table_data( - path_spec, file, name, timestamp, size + path_spec, + file, + name, + timestamp, + size, ) if table_data.table_path not in table_dict: table_dict[table_data.table_path] = table_data @@ -631,7 +659,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index b76eb95def1ede..c0c954e350f067 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -114,7 +114,7 @@ def detect_aws_environment() -> AwsEnvironment: # Check ECS if os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv( - "ECS_CONTAINER_METADATA_URI" + 
"ECS_CONTAINER_METADATA_URI", ): return AwsEnvironment.ECS @@ -160,7 +160,7 @@ def get_lambda_role_arn() -> Optional[str]: lambda_client = boto3.client("lambda") function_config = lambda_client.get_function_configuration( - FunctionName=function_name + FunctionName=function_name, ) return function_config.get("Role") except Exception as e: @@ -194,7 +194,7 @@ def get_current_identity() -> Tuple[Optional[str], Optional[str]]: elif env == AwsEnvironment.ECS: try: metadata_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv( - "ECS_CONTAINER_METADATA_URI" + "ECS_CONTAINER_METADATA_URI", ) if metadata_uri: response = requests.get(f"{metadata_uri}/task", timeout=1) @@ -339,7 +339,8 @@ def get_session(self) -> Session: elif self.aws_profile: # Named profile is second priority session = Session( - region_name=self.aws_region, profile_name=self.aws_profile + region_name=self.aws_region, + profile_name=self.aws_profile, ) else: # Use boto3's credential autodetection @@ -419,7 +420,8 @@ def _aws_config(self) -> Config: ) def get_s3_client( - self, verify_ssl: Optional[Union[bool, str]] = None + self, + verify_ssl: Optional[Union[bool, str]] = None, ) -> "S3Client": return self.get_session().client( "s3", @@ -429,7 +431,8 @@ def get_s3_client( ) def get_s3_resource( - self, verify_ssl: Optional[Union[bool, str]] = None + self, + verify_ssl: Optional[Union[bool, str]] = None, ) -> "S3ServiceResource": resource = self.get_session().resource( "s3", diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py index 2509927854d4a0..c86c4c1bf705f0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py @@ -121,7 +121,9 @@ class GlueSourceConfig( - StatefulIngestionConfigBase, DatasetSourceConfigMixin, AwsSourceConfig + StatefulIngestionConfigBase, + DatasetSourceConfigMixin, + AwsSourceConfig, ): platform: str = Field( default=DEFAULT_PLATFORM, @@ -134,14 +136,16 @@ class GlueSourceConfig( description="When enabled, extracts ownership from Glue table property and overwrites existing owners (DATAOWNER). When disabled, ownership is left empty for datasets. Expects a corpGroup urn, a corpuser urn or only the identifier part for the latter. Not used in the normal course of AWS Glue operations.", ) extract_transforms: Optional[bool] = Field( - default=True, description="Whether to extract Glue transform jobs." + default=True, + description="Whether to extract Glue transform jobs.", ) ignore_unsupported_connectors: Optional[bool] = Field( default=True, description="Whether to ignore unsupported connectors. If disabled, an error will be raised.", ) emit_s3_lineage: bool = Field( - default=False, description="Whether to emit S3-to-Glue lineage." 
+ default=False, + description="Whether to emit S3-to-Glue lineage.", ) glue_s3_lineage_direction: str = Field( default="upstream", @@ -173,7 +177,8 @@ class GlueSourceConfig( ) # Custom Stateful Ingestion settings stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="" + default=None, + description="", ) extract_delta_schema_from_parameters: Optional[bool] = Field( default=False, @@ -187,7 +192,7 @@ class GlueSourceConfig( def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) @property @@ -202,7 +207,7 @@ def s3_client(self): def check_direction(cls, v: str) -> str: if v.lower() not in ["upstream", "downstream"]: raise ValueError( - "glue_s3_lineage_direction must be either upstream or downstream" + "glue_s3_lineage_direction must be either upstream or downstream", ) return v.lower() @@ -212,7 +217,7 @@ def platform_validator(cls, v: str) -> str: return v else: raise ValueError( - f"'platform' can only take following values: {VALID_PLATFORMS}" + f"'platform' can only take following values: {VALID_PLATFORMS}", ) @@ -249,7 +254,8 @@ def report_table_dropped(self, table: str) -> None: ) @capability(SourceCapability.LINEAGE_COARSE, "Enabled by default") @capability( - SourceCapability.LINEAGE_FINE, "Support via the `emit_s3_lineage` config field" + SourceCapability.LINEAGE_FINE, + "Support via the `emit_s3_lineage` config field", ) class GlueSource(StatefulIngestionSourceBase): """ @@ -321,7 +327,10 @@ def __init__(self, config: GlueSourceConfig, ctx: PipelineContext): self.env = config.env def get_glue_arn( - self, account_id: str, database: str, table: Optional[str] = None + self, + account_id: str, + database: str, + table: Optional[str] = None, ) -> str: prefix = f"arn:aws:glue:{self.source_config.aws_region}:{account_id}" if table: @@ -352,7 +361,9 @@ def get_all_jobs(self): return jobs def get_dataflow_graph( - self, script_path: str, flow_urn: str + self, + script_path: str, + flow_urn: str, ) -> Optional[Dict[str, Any]]: """ Get the DAG of transforms and data sources/sinks for a job. 
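The GlueSourceConfig hunks above lean on pydantic field declarations plus small `@validator` hooks (`check_direction`, `platform_validator`) that normalize or reject raw config values before ingestion starts. Below is a minimal, self-contained sketch of that pattern, assuming pydantic v1-style validators as used in this module; `LineageConfigSketch` and its fields are hypothetical names, and the call layout deliberately uses the exploded, trailing-comma style this patch applies throughout.

from typing import Optional

from pydantic import BaseModel, Field, validator


class LineageConfigSketch(BaseModel):
    """Hypothetical stand-in for GlueSourceConfig; only the validation pattern matters."""

    glue_s3_lineage_direction: str = Field(
        default="upstream",
        description="Whether lineage points upstream or downstream of the Glue table.",
    )
    catalog_id: Optional[str] = Field(
        default=None,
        description="Optional AWS catalog id (illustrative only).",
    )

    @validator("glue_s3_lineage_direction")
    def check_direction(cls, v: str) -> str:
        # Normalize case and reject anything outside the two supported values,
        # mirroring the check_direction validator in the hunk above.
        if v.lower() not in ("upstream", "downstream"):
            raise ValueError(
                "glue_s3_lineage_direction must be either upstream or downstream",
            )
        return v.lower()


# Invalid values fail at config-parse time; valid ones are normalized.
config = LineageConfigSketch(glue_s3_lineage_direction="Upstream")
assert config.glue_s3_lineage_direction == "upstream"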
@@ -420,7 +431,8 @@ def get_s3_uri(self, node_args): return s3_uri def get_dataflow_s3_names( - self, dataflow_graph: Dict[str, Any] + self, + dataflow_graph: Dict[str, Any], ) -> Iterator[Tuple[str, Optional[str]]]: # iterate through each node to populate processed nodes for node in dataflow_graph["DagNodes"]: @@ -500,11 +512,11 @@ def process_dataflow_node( DatasetPropertiesClass( customProperties={k: str(v) for k, v in node_args.items()}, tags=[], - ) + ), ) new_dataset_mces.append( - MetadataChangeEvent(proposedSnapshot=dataset_snapshot) + MetadataChangeEvent(proposedSnapshot=dataset_snapshot), ) new_dataset_ids.append(f"{node['NodeType']}-{node['Id']}") @@ -521,7 +533,8 @@ def process_dataflow_node( # otherwise, a node represents a transformation else: node_urn = mce_builder.make_data_job_urn_with_flow( - flow_urn, job_id=f"{node['NodeType']}-{node['Id']}" + flow_urn, + job_id=f"{node['NodeType']}-{node['Id']}", ) return { @@ -559,7 +572,11 @@ def process_dataflow_graph( # iterate through each node to populate processed nodes for node in dataflow_graph["DagNodes"]: processed_node = self.process_dataflow_node( - node, flow_urn, new_dataset_ids, new_dataset_mces, s3_formats + node, + flow_urn, + new_dataset_ids, + new_dataset_mces, + s3_formats, ) if processed_node is not None: @@ -636,7 +653,7 @@ def get_dataflow_wu(self, flow_urn: str, job: Dict[str, Any]) -> MetadataWorkUni customProperties=custom_props, ), ], - ) + ), ) return MetadataWorkUnit(id=job["Name"], mce=mce) @@ -676,7 +693,7 @@ def get_datajob_wu(self, node: Dict[str, Any], job_name: str) -> MetadataWorkUni inputDatajobs=node["inputDatajobs"], ), ], - ) + ), ) return MetadataWorkUnit(id=f"{job_name}-{node['Id']}", mce=mce) @@ -688,7 +705,7 @@ def get_all_databases(self) -> Iterable[Mapping[str, Any]]: if self.source_config.catalog_id: paginator_response = paginator.paginate( - CatalogId=self.source_config.catalog_id + CatalogId=self.source_config.catalog_id, ) else: paginator_response = paginator.paginate() @@ -717,7 +734,8 @@ def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict if self.source_config.catalog_id: paginator_response = paginator.paginate( - DatabaseName=database_name, CatalogId=self.source_config.catalog_id + DatabaseName=database_name, + CatalogId=self.source_config.catalog_id, ) else: paginator_response = paginator.paginate(DatabaseName=database_name) @@ -746,7 +764,8 @@ def get_all_databases_and_tables( return all_databases, all_tables def get_lineage_if_enabled( - self, mce: MetadataChangeEventClass + self, + mce: MetadataChangeEventClass, ) -> Optional[MetadataWorkUnit]: if self.source_config.emit_s3_lineage: # extract dataset properties aspect @@ -762,7 +781,8 @@ def get_lineage_if_enabled( location = dataset_properties.customProperties["Location"] if is_s3_uri(location): s3_dataset_urn = make_s3_urn_for_lineage( - location, self.source_config.env + location, + self.source_config.env, ) assert self.ctx.graph schema_metadata_for_s3: Optional[SchemaMetadataClass] = ( @@ -787,7 +807,7 @@ def get_lineage_if_enabled( UpstreamClass( dataset=s3_dataset_urn, type=DatasetLineageTypeClass.COPY, - ) + ), ], fineGrainedLineages=fine_grained_lineages or None, ) @@ -802,8 +822,8 @@ def get_lineage_if_enabled( UpstreamClass( dataset=mce.proposedSnapshot.urn, type=DatasetLineageTypeClass.COPY, - ) - ] + ), + ], ) return MetadataChangeProposalWrapper( entityUrn=s3_dataset_urn, @@ -839,17 +859,18 @@ def simplify_field_path(field_path): 
downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ mce_builder.make_schema_field_urn( - dataset_urn, field_path_v1 - ) + dataset_urn, + field_path_v1, + ), ], upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET, upstreams=[ mce_builder.make_schema_field_urn( s3_dataset_urn, simplify_field_path(matching_s3_field.fieldPath), - ) + ), ], - ) + ), ) return fine_grained_lineages return None @@ -869,11 +890,11 @@ def _create_profile_mcp( # Inject table level stats if self.source_config.profiling.row_count in table_stats: dataset_profile.rowCount = int( - float(table_stats[self.source_config.profiling.row_count]) + float(table_stats[self.source_config.profiling.row_count]), ) if self.source_config.profiling.column_count in table_stats: dataset_profile.columnCount = int( - float(table_stats[self.source_config.profiling.column_count]) + float(table_stats[self.source_config.profiling.column_count]), ) # inject column level stats @@ -894,19 +915,19 @@ def _create_profile_mcp( if not self.source_config.profiling.profile_table_level_only: if self.source_config.profiling.unique_count in column_params: column_profile.uniqueCount = int( - float(column_params[self.source_config.profiling.unique_count]) + float(column_params[self.source_config.profiling.unique_count]), ) if self.source_config.profiling.unique_proportion in column_params: column_profile.uniqueProportion = float( - column_params[self.source_config.profiling.unique_proportion] + column_params[self.source_config.profiling.unique_proportion], ) if self.source_config.profiling.null_count in column_params: column_profile.nullCount = int( - float(column_params[self.source_config.profiling.null_count]) + float(column_params[self.source_config.profiling.null_count]), ) if self.source_config.profiling.null_proportion in column_params: column_profile.nullProportion = float( - column_params[self.source_config.profiling.null_proportion] + column_params[self.source_config.profiling.null_proportion], ) if self.source_config.profiling.min in column_params: column_profile.min = column_params[self.source_config.profiling.min] @@ -941,7 +962,10 @@ def _create_profile_mcp( return mcp def get_profile_if_enabled( - self, mce: MetadataChangeEventClass, database_name: str, table_name: str + self, + mce: MetadataChangeEventClass, + database_name: str, + table_name: str, ) -> Iterable[MetadataWorkUnit]: if self.source_config.is_profiling_enabled(): # for cross-account ingestion @@ -951,7 +975,7 @@ def get_profile_if_enabled( CatalogId=self.source_config.catalog_id, ) response = self.glue_client.get_table( - **{k: v for k, v in kwargs.items() if v} + **{k: v for k, v in kwargs.items() if v}, ) # check if this table is partitioned @@ -965,7 +989,7 @@ def get_profile_if_enabled( CatalogId=self.source_config.catalog_id, ) response = self.glue_client.get_partitions( - **{k: v for k, v in kwargs.items() if v} + **{k: v for k, v in kwargs.items() if v}, ) partitions = response["Partitions"] @@ -979,10 +1003,13 @@ def get_profile_if_enabled( partition_spec = str({partition_keys[0]: p["Values"][0]}) if self.source_config.profiling.partition_patterns.allowed( - partition_spec + partition_spec, ): yield self._create_profile_mcp( - mce, table_stats, column_stats, partition_spec + mce, + table_stats, + column_stats, + partition_spec, ).as_workunit() else: continue @@ -991,7 +1018,9 @@ def get_profile_if_enabled( table_stats = response["Table"]["Parameters"] column_stats = response["Table"]["StorageDescriptor"]["Columns"] yield 
self._create_profile_mcp( - mce, table_stats, column_stats + mce, + table_stats, + column_stats, ).as_workunit() def gen_database_key(self, database: str) -> DatabaseKey: @@ -1004,7 +1033,8 @@ def gen_database_key(self, database: str) -> DatabaseKey: ) def gen_database_containers( - self, database: Mapping[str, Any] + self, + database: Mapping[str, Any], ) -> Iterable[MetadataWorkUnit]: domain_urn = self._gen_domain_urn(database["Name"]) database_container_key = self.gen_database_key(database["Name"]) @@ -1021,13 +1051,16 @@ def gen_database_containers( domain_urn=domain_urn, description=database.get("Description"), qualified_name=self.get_glue_arn( - account_id=database["CatalogId"], database=database["Name"] + account_id=database["CatalogId"], + database=database["Name"], ), extra_properties=parameters, ) def add_table_to_database_container( - self, dataset_urn: str, db_name: str + self, + dataset_urn: str, + db_name: str, ) -> Iterable[MetadataWorkUnit]: database_container_key = self.gen_database_key(db_name) yield from add_dataset_to_container( @@ -1043,7 +1076,9 @@ def _gen_domain_urn(self, dataset_name: str) -> Optional[str]: return None def _get_domain_wu( - self, dataset_name: str, entity_urn: str + self, + dataset_name: str, + entity_urn: str, ) -> Iterable[MetadataWorkUnit]: domain_urn = self._gen_domain_urn(dataset_name) if domain_urn: @@ -1056,7 +1091,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] @@ -1085,7 +1122,7 @@ def _gen_table_wu(self, table: Dict) -> Iterable[MetadataWorkUnit]: full_table_name = f"{database_name}.{table_name}" self.report.report_table_scanned() if not self.source_config.database_pattern.allowed( - database_name + database_name, ) or not self.source_config.table_pattern.allowed(full_table_name): self.report.report_table_dropped(full_table_name) return @@ -1112,7 +1149,8 @@ def _gen_table_wu(self, table: Dict) -> Iterable[MetadataWorkUnit]: entity_urn=dataset_urn, ) yield from self.add_table_to_database_container( - dataset_urn=dataset_urn, db_name=database_name + dataset_urn=dataset_urn, + db_name=database_name, ) wu = self.get_lineage_if_enabled(mce) @@ -1133,7 +1171,9 @@ def _transform_extraction(self) -> Iterable[MetadataWorkUnit]: flow_names: Dict[str, str] = {} for job in self.get_all_jobs(): flow_urn = mce_builder.make_data_flow_urn( - self.platform, job["Name"], self.env + self.platform, + job["Name"], + self.env, ) yield self.get_dataflow_wu(flow_urn, job) @@ -1164,7 +1204,9 @@ def _transform_extraction(self) -> Iterable[MetadataWorkUnit]: continue nodes, new_dataset_ids, new_dataset_mces = self.process_dataflow_graph( - dag, flow_urn, s3_formats + dag, + flow_urn, + s3_formats, ) if not nodes: @@ -1184,10 +1226,13 @@ def _transform_extraction(self) -> Iterable[MetadataWorkUnit]: # flake8: noqa: C901 def _extract_record( - self, dataset_urn: str, table: Dict, table_name: str + self, + dataset_urn: str, + table: Dict, + table_name: str, ) -> MetadataChangeEvent: logger.debug( - f"extract record from table={table_name} for dataset={dataset_urn}" + f"extract record from table={table_name} for dataset={dataset_urn}", ) def get_dataset_properties() -> DatasetPropertiesClass: @@ -1217,7 +1262,7 @@ def get_s3_tags() -> Optional[GlobalTagsClass]: if table.get("StorageDescriptor", {}).get("Location") is None: return None bucket_name = 
s3_util.get_bucket_name( - table["StorageDescriptor"]["Location"] + table["StorageDescriptor"]["Location"], ) tags_to_add = [] if self.source_config.use_s3_bucket_tags: @@ -1227,16 +1272,17 @@ def get_s3_tags() -> Optional[GlobalTagsClass]: [ make_tag_urn(f"""{tag["Key"]}:{tag["Value"]}""") for tag in bucket_tags["TagSet"] - ] + ], ) except self.s3_client.exceptions.ClientError: logger.warning(f"No tags found for bucket={bucket_name}") if self.source_config.use_s3_object_tags: key_prefix = s3_util.get_key_prefix( - table["StorageDescriptor"]["Location"] + table["StorageDescriptor"]["Location"], ) object_tagging = self.s3_client.get_object_tagging( - Bucket=bucket_name, Key=key_prefix + Bucket=bucket_name, + Key=key_prefix, ) tag_set = object_tagging["TagSet"] if tag_set: @@ -1244,19 +1290,19 @@ def get_s3_tags() -> Optional[GlobalTagsClass]: [ make_tag_urn(f"""{tag["Key"]}:{tag["Value"]}""") for tag in tag_set - ] + ], ) else: # Unlike bucket tags, if an object does not have tags, it will just return an empty array # as opposed to an exception. logger.warning( - f"No tags found for bucket={bucket_name} key={key_prefix}" + f"No tags found for bucket={bucket_name} key={key_prefix}", ) if len(tags_to_add) == 0: return None if self.ctx.graph is not None: logger.debug( - "Connected to DatahubApi, grabbing current tags to maintain." + "Connected to DatahubApi, grabbing current tags to maintain.", ) current_tags: Optional[GlobalTagsClass] = self.ctx.graph.get_aspect( entity_urn=dataset_urn, @@ -1264,21 +1310,23 @@ def get_s3_tags() -> Optional[GlobalTagsClass]: ) if current_tags: tags_to_add.extend( - [current_tag.tag for current_tag in current_tags.tags] + [current_tag.tag for current_tag in current_tags.tags], ) else: logger.warning( - "Could not connect to DatahubApi. No current tags to maintain" + "Could not connect to DatahubApi. 
No current tags to maintain", ) # Remove duplicate tags tags_to_add = sorted(list(set(tags_to_add))) new_tags = GlobalTagsClass( - tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add] + tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add], ) return new_tags def _is_delta_schema( - provider: str, num_parts: int, columns: Optional[List[Mapping[str, Any]]] + provider: str, + num_parts: int, + columns: Optional[List[Mapping[str, Any]]], ) -> bool: return ( (self.source_config.extract_delta_schema_from_parameters is True) @@ -1297,8 +1345,9 @@ def get_schema_metadata() -> Optional[SchemaMetadata]: provider = table.get("Parameters", {}).get("spark.sql.sources.provider", "") num_parts = int( table.get("Parameters", {}).get( - "spark.sql.sources.schema.numParts", "0" - ) + "spark.sql.sources.schema.numParts", + "0", + ), ) columns = table.get("StorageDescriptor", {}).get("Columns", [{}]) @@ -1356,7 +1405,7 @@ def _get_delta_schema_metadata() -> Optional[SchemaMetadata]: [ table["Parameters"][f"spark.sql.sources.schema.part.{i}"] for i in range(numParts) - ] + ], ) schema_json = json.loads(schema_str) fields: List[SchemaField] = [] @@ -1394,7 +1443,8 @@ def get_data_platform_instance() -> DataPlatformInstanceClass: platform=make_data_platform_urn(self.platform), instance=( make_dataplatform_instance_urn( - self.platform, self.source_config.platform_instance + self.platform, + self.source_config.platform_instance, ) if self.source_config.platform_instance else None @@ -1408,7 +1458,7 @@ def _get_ownership(owner: str) -> Optional[OwnershipClass]: OwnerClass( owner=mce_builder.make_user_urn(owner), type=OwnershipTypeClass.DATAOWNER, - ) + ), ] return OwnershipClass( owners=owners, diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/s3_boto_utils.py b/metadata-ingestion/src/datahub/ingestion/source/aws/s3_boto_utils.py index 87a6f8a5baf2e3..e9448e63cd640c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/s3_boto_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/s3_boto_utils.py @@ -37,7 +37,7 @@ def get_s3_tags( [ make_tag_urn(f"""{tag["Key"]}:{tag["Value"]}""") for tag in bucket.Tagging().tag_set - ] + ], ) except s3.meta.client.exceptions.ClientError: logger.warning(f"No tags found for bucket={bucket_name}") @@ -48,7 +48,7 @@ def get_s3_tags( tag_set = object_tagging["TagSet"] if tag_set: tags_to_add.extend( - [make_tag_urn(f"""{tag["Key"]}:{tag["Value"]}""") for tag in tag_set] + [make_tag_urn(f"""{tag["Key"]}:{tag["Value"]}""") for tag in tag_set], ) else: # Unlike bucket tags, if an object does not have tags, it will just return an empty array @@ -69,13 +69,14 @@ def get_s3_tags( # Remove duplicate tags tags_to_add = sorted(list(set(tags_to_add))) new_tags = GlobalTagsClass( - tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add] + tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add], ) return new_tags def list_folders_path( - s3_uri: str, aws_config: Optional[AwsConnectionConfig] + s3_uri: str, + aws_config: Optional[AwsConnectionConfig], ) -> Iterable[str]: if not is_s3_uri(s3_uri): raise ValueError("Not a s3 URI: " + s3_uri) @@ -87,7 +88,9 @@ def list_folders_path( def list_folders( - bucket_name: str, prefix: str, aws_config: Optional[AwsConnectionConfig] + bucket_name: str, + prefix: str, + aws_config: Optional[AwsConnectionConfig], ) -> Iterable[str]: if aws_config is None: raise ValueError("aws_config not set. 
Cannot browse s3") diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/s3_util.py b/metadata-ingestion/src/datahub/ingestion/source/aws/s3_util.py index 360f18aa448f27..e339c592d4e5bb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/s3_util.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/s3_util.py @@ -29,7 +29,7 @@ def strip_s3_prefix(s3_uri: str) -> str: s3_prefix = get_s3_prefix(s3_uri) if not s3_prefix: raise ValueError( - f"Not an S3 URI. Must start with one of the following prefixes: {str(S3_PREFIXES)}" + f"Not an S3 URI. Must start with one of the following prefixes: {str(S3_PREFIXES)}", ) return s3_uri[len(s3_prefix) :] @@ -62,7 +62,7 @@ def make_s3_urn_for_lineage(s3_uri: str, env: str) -> str: def get_bucket_name(s3_uri: str) -> str: if not is_s3_uri(s3_uri): raise ValueError( - f"Not an S3 URI. Must start with one of the following prefixes: {str(S3_PREFIXES)}" + f"Not an S3 URI. Must start with one of the following prefixes: {str(S3_PREFIXES)}", ) return strip_s3_prefix(s3_uri).split("/")[0] @@ -70,7 +70,7 @@ def get_bucket_name(s3_uri: str) -> str: def get_key_prefix(s3_uri: str) -> str: if not is_s3_uri(s3_uri): raise ValueError( - f"Not an S3 URI. Must start with one of the following prefixes: {str(S3_PREFIXES)}" + f"Not an S3 URI. Must start with one of the following prefixes: {str(S3_PREFIXES)}", ) return strip_s3_prefix(s3_uri).split("/", maxsplit=1)[1] diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py index 55b8f4d889072d..a985e78e868a6b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py @@ -73,7 +73,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] @@ -81,7 +83,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: logger.info("Starting SageMaker ingestion...") # get common lineage graph lineage_processor = LineageProcessor( - sagemaker_client=self.sagemaker_client, env=self.env, report=self.report + sagemaker_client=self.sagemaker_client, + env=self.env, + report=self.report, ) lineage = lineage_processor.get_lineage() @@ -89,12 +93,14 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.source_config.extract_feature_groups: logger.info("Extracting feature groups...") feature_group_processor = FeatureGroupProcessor( - sagemaker_client=self.sagemaker_client, env=self.env, report=self.report + sagemaker_client=self.sagemaker_client, + env=self.env, + report=self.report, ) yield from feature_group_processor.get_workunits() model_image_to_jobs: DefaultDict[str, Dict[JobKey, ModelJob]] = defaultdict( - dict + dict, ) model_name_to_jobs: DefaultDict[str, Dict[JobKey, ModelJob]] = defaultdict(dict) diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py index 73d8d33dd11be7..0f88be6025e467 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py @@ -16,13 +16,16 @@ class SagemakerSourceConfig( StatefulIngestionConfigBase, ): extract_feature_groups: 
Optional[bool] = Field( - default=True, description="Whether to extract feature groups." + default=True, + description="Whether to extract feature groups.", ) extract_models: Optional[bool] = Field( - default=True, description="Whether to extract models." + default=True, + description="Whether to extract models.", ) extract_jobs: Optional[Union[Dict[str, str], bool]] = Field( - default=True, description="Whether to extract AutoML jobs." + default=True, + description="Whether to extract AutoML jobs.", ) # Custom Stateful Ingestion settings stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py index d46d1c099383fe..49ab7139490b14 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py @@ -54,7 +54,8 @@ def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]: return feature_groups def get_feature_group_details( - self, feature_group_name: str + self, + feature_group_name: str, ) -> "DescribeFeatureGroupResponseTypeDef": """ Get details of a feature group (including list of component features). @@ -62,7 +63,7 @@ def get_feature_group_details( # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_feature_group feature_group = self.sagemaker_client.describe_feature_group( - FeatureGroupName=feature_group_name + FeatureGroupName=feature_group_name, ) # use falsy fallback since AWS stubs require this to be a string in tests @@ -71,7 +72,8 @@ def get_feature_group_details( # paginate over feature group features while next_token: next_features = self.sagemaker_client.describe_feature_group( - FeatureGroupName=feature_group_name, NextToken=next_token + FeatureGroupName=feature_group_name, + NextToken=next_token, ) feature_group["FeatureDefinitions"] += next_features["FeatureDefinitions"] next_token = feature_group.get("NextToken", "") @@ -79,7 +81,8 @@ def get_feature_group_details( return feature_group def get_feature_group_wu( - self, feature_group_details: "DescribeFeatureGroupResponseTypeDef" + self, + feature_group_details: "DescribeFeatureGroupResponseTypeDef", ) -> MetadataWorkUnit: """ Generate an MLFeatureTable workunit for a SageMaker feature group. @@ -116,7 +119,7 @@ def get_feature_group_wu( builder.make_ml_primary_key_urn( feature_group_name, feature_group_details["RecordIdentifierFeatureName"], - ) + ), ], # additional metadata customProperties={ @@ -124,7 +127,7 @@ def get_feature_group_wu( "creation_time": str(feature_group_details["CreationTime"]), "status": feature_group_details["FeatureGroupStatus"], }, - ) + ), ) # make the MCE and workunit @@ -142,7 +145,8 @@ def get_feature_type(self, aws_type: str, feature_name: str) -> str: if mapped_type is None: self.report.report_warning( - feature_name, f"unable to map type {aws_type} to metadata schema" + feature_name, + f"unable to map type {aws_type} to metadata schema", ) mapped_type = MLFeatureDataType.UNKNOWN @@ -187,7 +191,7 @@ def get_feature_wu( "s3", s3_name, self.env, - ) + ), ) if "DataCatalogConfig" in feature_group_details["OfflineStoreConfig"]: @@ -205,12 +209,12 @@ def get_feature_wu( textwrap.dedent( f"""Note: table {full_table_name} is an AWS Glue object. 
This source does not ingest all metadata for Glue tables. To view full table metadata, run Glue ingestion - (see https://datahubproject.io/docs/generated/ingestion/sources/glue)""" - ) + (see https://datahubproject.io/docs/generated/ingestion/sources/glue)""", + ), ) feature_sources.append( - f"urn:li:dataset:(urn:li:dataPlatform:glue,{full_table_name},{self.env})" + f"urn:li:dataset:(urn:li:dataPlatform:glue,{full_table_name},{self.env})", ) # note that there's also an OnlineStoreConfig field, but this @@ -227,7 +231,8 @@ def get_feature_wu( aspects=[ MLPrimaryKeyPropertiesClass( dataType=self.get_feature_type( - feature["FeatureType"], feature["FeatureName"] + feature["FeatureType"], + feature["FeatureName"], ), sources=feature_sources, ), @@ -246,10 +251,11 @@ def get_feature_wu( aspects=[ MLFeaturePropertiesClass( dataType=self.get_feature_type( - feature["FeatureType"], feature["FeatureName"] + feature["FeatureType"], + feature["FeatureName"], ), sources=feature_sources, - ) + ), ], ) @@ -266,7 +272,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: for feature_group in feature_groups: feature_group_details = self.get_feature_group_details( - feature_group["FeatureGroupName"] + feature_group["FeatureGroupName"], ) for feature in feature_group_details["FeatureDefinitions"]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py index be0a99c6d32346..7a11965125a0bd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py @@ -89,7 +89,9 @@ class JobType(Enum): def make_sagemaker_flow_urn(job_type: str, job_name: str, env: str) -> str: return mce_builder.make_data_flow_urn( - orchestrator="sagemaker", flow_id=f"{job_type}:{job_name}", cluster=env + orchestrator="sagemaker", + flow_id=f"{job_type}:{job_name}", + cluster=env, ) @@ -164,12 +166,12 @@ class JobProcessor: # map from model image file path to jobs referencing the model model_image_to_jobs: DefaultDict[str, Dict[JobKey, ModelJob]] = field( - default_factory=lambda: defaultdict(dict) + default_factory=lambda: defaultdict(dict), ) # map from model name to jobs referencing the model model_name_to_jobs: DefaultDict[str, Dict[JobKey, ModelJob]] = field( - default_factory=lambda: defaultdict(dict) + default_factory=lambda: defaultdict(dict), ) def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]: @@ -273,7 +275,7 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]: describe_name_key = job_type_to_info[job_type].describe_name_key return getattr(self.sagemaker_client(), describe_command)( - **{describe_name_key: job_name} + **{describe_name_key: job_name}, ) def get_workunits(self) -> Iterable[MetadataWorkUnit]: @@ -300,7 +302,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: # - move output jobs to inputs # - aggregate i/o datasets logger.info( - "second pass: move output jobs to inputs and aggregate i/o datasets" + "second pass: move output jobs to inputs and aggregate i/o datasets", ) for job_urn in sorted(processed_jobs): processed_job = processed_jobs[job_urn] @@ -322,7 +324,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: DatasetPropertiesClass( customProperties={k: str(v) for k, v in dataset.items()}, tags=[], - ) + ), ) dataset_mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) yield MetadataWorkUnit( @@ 
-338,7 +340,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: job_snapshot = processed_job.job_snapshot flow_urn = make_sagemaker_flow_urn( - processed_job.job_type.value, processed_job.job_name, self.env + processed_job.job_type.value, + processed_job.job_name, + self.env, ) # create flow for each job @@ -350,7 +354,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: name=processed_job.job_name, ), ], - ) + ), ) yield MetadataWorkUnit( id=flow_urn, @@ -362,7 +366,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: inputDatasets=sorted(list(processed_job.input_datasets.keys())), outputDatasets=sorted(list(processed_job.output_datasets.keys())), inputDatajobs=sorted(list(processed_job.input_jobs)), - ) + ), ) job_mce = MetadataChangeEvent(proposedSnapshot=job_snapshot) @@ -558,7 +562,7 @@ def process_hyper_parameter_tuning_job( training_job["DefinitionName"], self.name_to_arn[full_job_name], self.env, - ) + ), ) else: self.report.report_warning( @@ -612,7 +616,7 @@ def process_labeling_job(self, job: Dict[str, Any]) -> SageMakerJob: output_datasets = {} output_s3_uri: Optional[str] = job.get("LabelingJobOutput", {}).get( - "OutputDatasetS3Uri" + "OutputDatasetS3Uri", ) if output_s3_uri is not None: output_datasets[make_s3_urn(output_s3_uri, self.env)] = { @@ -620,7 +624,7 @@ def process_labeling_job(self, job: Dict[str, Any]) -> SageMakerJob: "uri": output_s3_uri, } output_config_s3_uri: Optional[str] = job.get("OutputConfig", {}).get( - "S3OutputPath" + "S3OutputPath", ) if output_config_s3_uri is not None: output_datasets[make_s3_urn(output_config_s3_uri, self.env)] = { @@ -663,19 +667,26 @@ def process_processing_job(self, job: Dict[str, Any]) -> SageMakerJob: if auto_ml_type is not None and auto_ml_name is not None: input_jobs.add( make_sagemaker_job_urn( - auto_ml_type, auto_ml_name, auto_ml_arn, self.env - ) + auto_ml_type, + auto_ml_name, + auto_ml_arn, + self.env, + ), ) if training_arn is not None: training_type, training_name = self.arn_to_name.get( - training_arn, (None, None) + training_arn, + (None, None), ) if training_type is not None and training_name is not None: input_jobs.add( make_sagemaker_job_urn( - training_type, training_name, training_arn, self.env - ) + training_type, + training_name, + training_arn, + self.env, + ), ) input_datasets = {} @@ -712,7 +723,8 @@ def process_processing_job(self, job: Dict[str, Any]) -> SageMakerJob: # ) outputs: List[Dict[str, Any]] = job.get("ProcessingOutputConfig", {}).get( - "Outputs", [] + "Outputs", + [], ) output_datasets = {} @@ -729,12 +741,13 @@ def process_processing_job(self, job: Dict[str, Any]) -> SageMakerJob: } output_feature_group = output.get("FeatureStoreOutput", {}).get( - "FeatureGroupName" + "FeatureGroupName", ) if output_feature_group is not None: output_datasets[ mce_builder.make_ml_feature_table_urn( - "sagemaker", output_feature_group + "sagemaker", + output_feature_group, ) ] = { "dataset_type": "sagemaker_feature_group", @@ -788,7 +801,7 @@ def process_training_job(self, job: Dict[str, Any]) -> SageMakerJob: checkpoint_s3_uri = job.get("CheckpointConfig", {}).get("S3Uri") debug_s3_path = job.get("DebugHookConfig", {}).get("S3OutputPath") tensorboard_output_path = job.get("TensorBoardOutputConfig", {}).get( - "S3OutputPath" + "S3OutputPath", ) profiler_output_path = job.get("ProfilerConfig", {}).get("S3OutputPath") @@ -830,7 +843,9 @@ def process_training_job(self, job: Dict[str, Any]) -> SageMakerJob: job_metrics = job.get("FinalMetricDataList", []) # sort first by metric name, then 
from latest -> earliest sorted_metrics = sorted( - job_metrics, key=lambda x: (x["MetricName"], x["Timestamp"]), reverse=True + job_metrics, + key=lambda x: (x["MetricName"], x["Timestamp"]), + reverse=True, ) # extract the last recorded metric values latest_metrics = [] @@ -844,7 +859,7 @@ def process_training_job(self, job: Dict[str, Any]) -> SageMakerJob: zip( [metric["MetricName"] for metric in latest_metrics], [metric["Value"] for metric in latest_metrics], - ) + ), ) if model_data_url is not None: @@ -908,14 +923,18 @@ def process_transform_job(self, job: Dict[str, Any]) -> SageMakerJob: if labeling_arn is not None: labeling_type, labeling_name = self.arn_to_name.get( - labeling_arn, (None, None) + labeling_arn, + (None, None), ) if labeling_type is not None and labeling_name is not None: input_jobs.add( make_sagemaker_job_urn( - labeling_type, labeling_name, labeling_arn, self.env - ) + labeling_type, + labeling_name, + labeling_arn, + self.env, + ), ) if auto_ml_arn is not None: @@ -924,8 +943,11 @@ def process_transform_job(self, job: Dict[str, Any]) -> SageMakerJob: if auto_ml_type is not None and auto_ml_name is not None: input_jobs.add( make_sagemaker_job_urn( - auto_ml_type, auto_ml_name, auto_ml_arn, self.env - ) + auto_ml_type, + auto_ml_name, + auto_ml_arn, + self.env, + ), ) job_snapshot, job_name, job_arn = self.create_common_job_snapshot( diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py index 24e5497269c738..85f9ed680181c0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py @@ -27,20 +27,20 @@ class LineageInfo: # map from model URIs to deployed endpoints model_uri_endpoints: DefaultDict[str, Set[str]] = field( - default_factory=lambda: defaultdict(set) + default_factory=lambda: defaultdict(set), ) # map from model images to deployed endpoints model_image_endpoints: DefaultDict[str, Set[str]] = field( - default_factory=lambda: defaultdict(set) + default_factory=lambda: defaultdict(set), ) # map from group ARNs to model URIs model_uri_to_groups: DefaultDict[str, Set[str]] = field( - default_factory=lambda: defaultdict(set) + default_factory=lambda: defaultdict(set), ) # map from group ARNs to model images model_image_to_groups: DefaultDict[str, Set[str]] = field( - default_factory=lambda: defaultdict(set) + default_factory=lambda: defaultdict(set), ) @@ -173,7 +173,9 @@ def get_model_deployment_lineage(self, deployment_node_arn: str) -> None: self.lineage_info.model_image_endpoints[model_image] |= model_endpoints def get_model_group_lineage( - self, model_group_node_arn: str, node: Dict[str, Any] + self, + model_group_node_arn: str, + node: Dict[str, Any], ) -> None: """ Get the lineage of a model group (models part of the group). 
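The LineageInfo hunks above repeatedly declare dataclass fields as field(default_factory=lambda: defaultdict(set)) so that each processor instance gets its own mutable index instead of sharing one dict across instances. A minimal sketch of that idea follows; LineageIndexSketch is a hypothetical reduction of LineageInfo, not the DataHub class itself.

from collections import defaultdict
from dataclasses import dataclass, field
from typing import DefaultDict, Set


@dataclass
class LineageIndexSketch:
    # Mutable defaults need a default_factory; a bare defaultdict(set) here
    # would be evaluated once and shared by every instance.
    model_uri_endpoints: DefaultDict[str, Set[str]] = field(
        default_factory=lambda: defaultdict(set),
    )
    model_image_endpoints: DefaultDict[str, Set[str]] = field(
        default_factory=lambda: defaultdict(set),
    )


index = LineageIndexSketch()
# defaultdict(set) lets callers union endpoints in without checking whether the
# key exists first, which is how the lineage processor accumulates its mappings.
index.model_uri_endpoints["s3://bucket/model.tar.gz"] |= {"endpoint-a", "endpoint-b"}
assert index.model_uri_endpoints["s3://bucket/model.tar.gz"] == {"endpoint-a", "endpoint-b"}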
@@ -193,7 +195,7 @@ def get_model_group_lineage( # if edge is a model package, then look for models in its source edges if edge["SourceType"] == "Model": model_package_incoming_edges = self.get_incoming_edges( - edge["SourceArn"] + edge["SourceArn"], ) # check incoming edges for models under the model package @@ -211,7 +213,7 @@ def get_model_group_lineage( and source_uri is not None ): self.lineage_info.model_uri_to_groups[source_uri].add( - model_group_arn + model_group_arn, ) # add model_group_arn -> model_image mapping @@ -220,7 +222,7 @@ def get_model_group_lineage( and source_uri is not None ): self.lineage_info.model_image_to_groups[source_uri].add( - model_group_arn + model_group_arn, ) def get_lineage(self) -> LineageInfo: diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py index f1374117af775f..4a5b3b71c6ebe7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py @@ -79,12 +79,12 @@ class ModelProcessor: # map from model image file path to jobs referencing the model model_image_to_jobs: DefaultDict[str, Dict[JobKey, ModelJob]] = field( - default_factory=lambda: defaultdict(dict) + default_factory=lambda: defaultdict(dict), ) # map from model name to jobs referencing the model model_name_to_jobs: DefaultDict[str, Dict[JobKey, ModelJob]] = field( - default_factory=lambda: defaultdict(dict) + default_factory=lambda: defaultdict(dict), ) # map from model uri to model name @@ -130,7 +130,8 @@ def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]: return groups def get_group_details( - self, group_name: str + self, + group_name: str, ) -> "DescribeModelPackageGroupOutputTypeDef": """ Get details of a model group. @@ -138,7 +139,7 @@ def get_group_details( # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model_package_group return self.sagemaker_client.describe_model_package_group( - ModelPackageGroupName=group_name + ModelPackageGroupName=group_name, ) def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]: @@ -153,13 +154,17 @@ def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]: return endpoints def get_endpoint_details( - self, endpoint_name: str + self, + endpoint_name: str, ) -> "DescribeEndpointOutputTypeDef": # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name) def get_endpoint_status( - self, endpoint_name: str, endpoint_arn: str, sagemaker_status: str + self, + endpoint_name: str, + endpoint_arn: str, + sagemaker_status: str, ) -> str: endpoint_status = ENDPOINT_STATUS_MAP.get(sagemaker_status) @@ -174,7 +179,8 @@ def get_endpoint_status( return endpoint_status def get_endpoint_wu( - self, endpoint_details: "DescribeEndpointOutputTypeDef" + self, + endpoint_details: "DescribeEndpointOutputTypeDef", ) -> MetadataWorkUnit: """a Get a workunit for an endpoint. 
@@ -185,13 +191,15 @@ def get_endpoint_wu( endpoint_snapshot = MLModelDeploymentSnapshot( urn=builder.make_ml_model_deployment_urn( - "sagemaker", endpoint_details["EndpointName"], self.env + "sagemaker", + endpoint_details["EndpointName"], + self.env, ), aspects=[ MLModelDeploymentPropertiesClass( createdAt=int( endpoint_details.get("CreationTime", datetime.now()).timestamp() - * 1000 + * 1000, ), status=self.get_endpoint_status( endpoint_details["EndpointArn"], @@ -204,7 +212,7 @@ def get_endpoint_wu( for key, value in endpoint_details.items() if key not in redundant_fields }, - ) + ), ], ) @@ -241,13 +249,14 @@ def get_model_endpoints( # sort endpoints and groups for consistency model_endpoints_sorted = sorted( - [x for x in model_endpoints if x in endpoint_arn_to_name] + [x for x in model_endpoints if x in endpoint_arn_to_name], ) return model_endpoints_sorted def get_group_wu( - self, group_details: "DescribeModelPackageGroupOutputTypeDef" + self, + group_details: "DescribeModelPackageGroupOutputTypeDef", ) -> MetadataWorkUnit: """ Get a workunit for a model group. @@ -268,7 +277,7 @@ def get_group_wu( OwnerClass( owner=f"urn:li:corpuser:{group_details['CreatedBy']['UserProfileName']}", type=OwnershipTypeClass.DATAOWNER, - ) + ), ) group_snapshot = MLModelGroupSnapshot( @@ -277,7 +286,7 @@ def get_group_wu( MLModelGroupPropertiesClass( createdAt=int( group_details.get("CreationTime", datetime.now()).timestamp() - * 1000 + * 1000, ), description=group_details.get("ModelPackageGroupDescription"), customProperties={ @@ -297,7 +306,8 @@ def get_group_wu( return MetadataWorkUnit(id=group_name, mce=mce) def match_model_jobs( - self, model_details: "DescribeModelOutputTypeDef" + self, + model_details: "DescribeModelOutputTypeDef", ) -> Tuple[Set[str], Set[str], List[MLHyperParamClass], List[MLMetricClass]]: model_training_jobs: Set[str] = set() model_downstream_jobs: Set[str] = set() @@ -325,7 +335,7 @@ def match_model_jobs( job_urn for job_urn, job_direction in data_url_matched_jobs.keys() if job_direction == JobDirection.TRAINING - } + }, ) # extend set of downstream jobs model_downstream_jobs = model_downstream_jobs.union( @@ -333,7 +343,7 @@ def match_model_jobs( job_urn for job_urn, job_direction in data_url_matched_jobs.keys() if job_direction == JobDirection.DOWNSTREAM - } + }, ) for job_key, job_info in data_url_matched_jobs.items(): @@ -370,7 +380,7 @@ def strip_quotes(string: str) -> str: job_urn for job_urn, job_direction in name_matched_jobs.keys() if job_direction == JobDirection.TRAINING - } + }, ) # extend set of downstream jobs model_downstream_jobs = model_downstream_jobs.union( @@ -378,7 +388,7 @@ def strip_quotes(string: str) -> str: job_urn for job_urn, job_direction in name_matched_jobs.keys() if job_direction == JobDirection.DOWNSTREAM - } + }, ) return ( @@ -404,7 +414,7 @@ def get_group_name_from_arn(arn: str) -> str: 'my-model-group' """ logger.debug( - f"Extracting group name from ARN: {arn} because group was not seen before" + f"Extracting group name from ARN: {arn} because group was not seen before", ) return arn.split("/")[-1] @@ -424,7 +434,10 @@ def get_model_wu( model_uri = model_details.get("PrimaryContainer", {}).get("ModelDataUrl") model_endpoints_sorted = self.get_model_endpoints( - model_details, endpoint_arn_to_name, model_image, model_uri + model_details, + endpoint_arn_to_name, + model_image, + model_uri, ) ( @@ -442,7 +455,8 @@ def get_model_wu( model_image_groups: Set[str] = set() if model_image is not None: model_image_groups = 
self.lineage.model_image_to_groups.get( - model_image, set() + model_image, + set(), ) model_group_arns = model_uri_groups | model_image_groups @@ -453,7 +467,7 @@ def get_model_wu( if x in self.group_arn_to_name else self.get_group_name_from_arn(x) for x in model_group_arns - ] + ], ) model_group_urns = [ @@ -469,17 +483,21 @@ def get_model_wu( model_snapshot = MLModelSnapshot( urn=builder.make_ml_model_urn( - "sagemaker", model_details["ModelName"], self.env + "sagemaker", + model_details["ModelName"], + self.env, ), aspects=[ MLModelPropertiesClass( date=int( model_details.get("CreationTime", datetime.now()).timestamp() - * 1000 + * 1000, ), deployments=[ builder.make_ml_model_deployment_urn( - "sagemaker", endpoint_name, self.env + "sagemaker", + endpoint_name, + self.env, ) for endpoint_name in model_endpoints_sorted ], diff --git a/metadata-ingestion/src/datahub/ingestion/source/azure/abs_folder_utils.py b/metadata-ingestion/src/datahub/ingestion/source/azure/abs_folder_utils.py index ce166f2942dac5..9bf15789b48cf3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/azure/abs_folder_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/azure/abs_folder_utils.py @@ -25,12 +25,12 @@ def get_abs_properties( ) -> Dict[str, str]: if azure_config is None: raise ValueError( - "Azure configuration is not provided. Cannot retrieve container client." + "Azure configuration is not provided. Cannot retrieve container client.", ) blob_service_client = azure_config.get_blob_service_client() container_client = blob_service_client.get_container_client( - container=container_name + container=container_name, ) custom_properties = {"schema_inferred_from": full_path} @@ -39,7 +39,7 @@ def get_abs_properties( { "number_of_files": str(number_of_files), "size_in_bytes": str(size_in_bytes), - } + }, ) if use_abs_blob_properties and blob_name is not None: @@ -61,7 +61,7 @@ def get_abs_properties( ) else: logger.warning( - f"No blob properties found for container={container_name}, blob={blob_name}." + f"No blob properties found for container={container_name}, blob={blob_name}.", ) if use_abs_container_properties: @@ -76,14 +76,17 @@ def get_abs_properties( ) else: logger.warning( - f"No container properties found for container={container_name}." + f"No container properties found for container={container_name}.", ) return custom_properties def add_property( - key: str, value: str, custom_properties: Dict[str, str], resource_name: str + key: str, + value: str, + custom_properties: Dict[str, str], + resource_name: str, ) -> Dict[str, str]: if key in custom_properties: key = f"{key}_{resource_name}" @@ -124,7 +127,7 @@ def create_properties( ) except Exception as exception: logger.debug( - f"Could not create property {key} value {value}, from resource {resource_name}: {exception}." + f"Could not create property {key} value {value}, from resource {resource_name}: {exception}.", ) @@ -139,7 +142,7 @@ def get_abs_tags( # Todo add the service_client, when building out this get_abs_tags if azure_config is None: raise ValueError( - "Azure configuration is not provided. Cannot retrieve container client." + "Azure configuration is not provided. 
Cannot retrieve container client.", ) tags_to_add: List[str] = [] @@ -176,17 +179,19 @@ def get_abs_tags( tags_to_add = sorted(list(set(tags_to_add))) # Remove duplicate tags new_tags = GlobalTagsClass( - tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add] + tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add], ) return new_tags def list_folders( - container_name: str, prefix: str, azure_config: Optional[AzureConnectionConfig] + container_name: str, + prefix: str, + azure_config: Optional[AzureConnectionConfig], ) -> Iterable[str]: if azure_config is None: raise ValueError( - "Azure configuration is not provided. Cannot retrieve container client." + "Azure configuration is not provided. Cannot retrieve container client.", ) abs_blob_service_client = azure_config.get_blob_service_client() diff --git a/metadata-ingestion/src/datahub/ingestion/source/azure/abs_utils.py b/metadata-ingestion/src/datahub/ingestion/source/azure/abs_utils.py index 042e1b4ef921fb..9b80377b9daa20 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/azure/abs_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/azure/abs_utils.py @@ -5,7 +5,7 @@ # This file should not import any abs spectific modules as we import it in path_spec.py in datat_lake_common.py ABS_PREFIXES_REGEX = re.compile( - r"(http[s]?://[a-z0-9]{3,24}\.blob\.core\.windows\.net/)" + r"(http[s]?://[a-z0-9]{3,24}\.blob\.core\.windows\.net/)", ) @@ -25,7 +25,7 @@ def strip_abs_prefix(abs_uri: str) -> str: abs_prefix = get_abs_prefix(abs_uri) if not abs_prefix: raise ValueError( - f"Not an Azure Blob Storage URI. Must match the following regular expression: {str(ABS_PREFIXES_REGEX)}" + f"Not an Azure Blob Storage URI. Must match the following regular expression: {str(ABS_PREFIXES_REGEX)}", ) length_abs_prefix = len(abs_prefix) return abs_uri[length_abs_prefix:] @@ -49,7 +49,7 @@ def make_abs_urn(abs_uri: str, env: str) -> str: def get_container_name(abs_uri: str) -> str: if not is_abs_uri(abs_uri): raise ValueError( - f"Not an Azure Blob Storage URI. Must match the following regular expression: {str(ABS_PREFIXES_REGEX)}" + f"Not an Azure Blob Storage URI. Must match the following regular expression: {str(ABS_PREFIXES_REGEX)}", ) return strip_abs_prefix(abs_uri).split("/")[0] @@ -57,7 +57,7 @@ def get_container_name(abs_uri: str) -> str: def get_key_prefix(abs_uri: str) -> str: if not is_abs_uri(abs_uri): raise ValueError( - f"Not an Azure Blob Storage URI. Must match the following regular expression: {str(ABS_PREFIXES_REGEX)}" + f"Not an Azure Blob Storage URI. 
Must match the following regular expression: {str(ABS_PREFIXES_REGEX)}", ) return strip_abs_prefix(abs_uri).split("/", maxsplit=1)[1] diff --git a/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py b/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py index 46de4e09d7ee5b..f7d3ec17e19c5b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py @@ -55,7 +55,7 @@ def get_abfss_url(self, folder_path: str = "") -> str: # TODO DEX-1010 def get_filesystem_client(self) -> FileSystemClient: return self.get_data_lake_service_client().get_file_system_client( - file_system=self.container_name + file_system=self.container_name, ) def get_blob_service_client(self): @@ -94,5 +94,5 @@ def _check_credential_values(cls, values: Dict) -> Dict: ): return values raise ConfigurationError( - "credentials missing, requires one combination of account_key or sas_token or (client_id and client_secret and tenant_id)" + "credentials missing, requires one combination of account_key or sas_token or (client_id and client_secret and tenant_id)", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py index ceb010a7f0675f..d0d3b7623399f4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py @@ -109,7 +109,8 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): self.domain_registry: Optional[DomainRegistry] = None if self.config.domain: self.domain_registry = DomainRegistry( - cached_domains=[k for k in self.config.domain], graph=self.ctx.graph + cached_domains=[k for k in self.config.domain], + graph=self.ctx.graph, ) BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX = ( @@ -178,7 +179,9 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): run_id=self.ctx.run_id, ) self.profiler = BigqueryProfiler( - config, self.report, self.profiling_state_handler + config, + self.report, + self.profiling_state_handler, ) self.bq_schema_extractor = BigQuerySchemaGenerator( @@ -234,10 +237,13 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), functools.partial( - auto_incremental_lineage, self.config.incremental_lineage + auto_incremental_lineage, + self.config.incremental_lineage, ), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -296,7 +302,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: else: if self.config.include_usage_statistics: yield from self.usage_extractor.get_usage_workunits( - [p.id for p in projects], self.bq_schema_extractor.table_refs + [p.id for p in projects], + self.bq_schema_extractor.table_refs, ) if self.config.include_table_lineage: @@ -312,7 +319,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ): for dataset_urn, table in self.bq_schema_extractor.external_tables.items(): yield from self.lineage_extractor.gen_lineage_workunits_for_external_table( - dataset_urn, table.ddl, graph=self.ctx.graph + dataset_urn, + table.ddl, + graph=self.ctx.graph, ) def get_report(self) -> BigQueryV2Report: diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py 
b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py index d35c5265878c03..dae7ac4a374712 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit.py @@ -102,13 +102,13 @@ def get_table_display_name(self) -> str: if matches: shortened_table_name = matches.group(1) logger.debug( - f"Found table snapshot. Using {shortened_table_name} as the table name." + f"Found table snapshot. Using {shortened_table_name} as the table name.", ) if "$" in shortened_table_name: shortened_table_name = shortened_table_name.split("$", maxsplit=1)[0] logger.debug( - f"Found partitioned table. Using {shortened_table_name} as the table name." + f"Found partitioned table. Using {shortened_table_name} as the table name.", ) table_name, _ = self.get_table_and_shard(shortened_table_name) @@ -151,7 +151,7 @@ class BigQueryTableRef: # Handle table time travel. See https://cloud.google.com/bigquery/docs/time-travel # See https://cloud.google.com/bigquery/docs/table-decorators#time_decorators SNAPSHOT_TABLE_REGEX: ClassVar[Pattern[str]] = re.compile( - "^(.+)@(-?\\d{1,13})(-(-?\\d{1,13})?)?$" + "^(.+)@(-?\\d{1,13})(-(-?\\d{1,13})?)?$", ) table_identifier: BigqueryTableIdentifier @@ -159,7 +159,7 @@ class BigQueryTableRef: @classmethod def from_bigquery_table(cls, table: BigqueryTableIdentifier) -> "BigQueryTableRef": return cls( - BigqueryTableIdentifier(table.project_id, table.dataset, table.table) + BigqueryTableIdentifier(table.project_id, table.dataset, table.table), ) @classmethod @@ -171,8 +171,10 @@ def from_spec_obj(cls, spec: dict) -> "BigQueryTableRef": return cls( # spec dict always has to have projectId, datasetId, tableId otherwise it is an invalid spec BigqueryTableIdentifier( - spec["projectId"], spec["datasetId"], spec["tableId"] - ) + spec["projectId"], + spec["datasetId"], + spec["tableId"], + ), ) @classmethod @@ -209,7 +211,7 @@ def get_sanitized_table_ref(self) -> "BigQueryTableRef": sanitized_table = self.table_identifier.get_table_name() # Handle partitioned and sharded tables. 
return BigQueryTableRef( - BigqueryTableIdentifier.from_string_name(sanitized_table) + BigqueryTableIdentifier.from_string_name(sanitized_table), ) def __str__(self) -> str: @@ -246,13 +248,15 @@ class QueryEvent: @staticmethod def get_missing_key_entry(entry: AuditLogEntry) -> Optional[str]: return get_first_missing_key( - inp_dict=entry.payload, keys=["serviceData", "jobCompletedEvent", "job"] + inp_dict=entry.payload, + keys=["serviceData", "jobCompletedEvent", "job"], ) @staticmethod def get_missing_key_entry_v2(entry: AuditLogEntry) -> Optional[str]: return get_first_missing_key( - inp_dict=entry.payload, keys=["metadata", "jobChange", "job"] + inp_dict=entry.payload, + keys=["metadata", "jobChange", "job"], ) @staticmethod @@ -272,7 +276,9 @@ def _get_project_id_from_job_name(job_name: str) -> str: @classmethod def from_entry( - cls, entry: AuditLogEntry, debug_include_full_payloads: bool = False + cls, + entry: AuditLogEntry, + debug_include_full_payloads: bool = False, ) -> "QueryEvent": job: Dict = entry.payload["serviceData"]["jobCompletedEvent"]["job"] job_query_conf: Dict = job["jobConfiguration"]["query"] @@ -313,7 +319,7 @@ def from_entry( raw_dest_table = job_query_conf.get("destinationTable") if raw_dest_table: query_event.destinationTable = BigQueryTableRef.from_spec_obj( - raw_dest_table + raw_dest_table, ).get_sanitized_table_ref() # statementType # referencedTables @@ -341,7 +347,7 @@ def from_entry( if not query_event.job_name: logger.debug( "jobName from query events is absent. " - "Auditlog entry - {logEntry}".format(logEntry=entry) + "Auditlog entry - {logEntry}".format(logEntry=entry), ) return query_event @@ -351,17 +357,21 @@ def get_missing_key_exported_bigquery_audit_metadata( row: BigQueryAuditMetadata, ) -> Optional[str]: missing_key = get_first_missing_key_any( - row._xxx_field_to_index, ["timestamp", "protoPayload", "metadata"] + row._xxx_field_to_index, + ["timestamp", "protoPayload", "metadata"], ) if not missing_key: missing_key = get_first_missing_key_any( - json.loads(row["metadata"]), ["jobChange"] + json.loads(row["metadata"]), + ["jobChange"], ) return missing_key @classmethod def from_exported_bigquery_audit_metadata( - cls, row: BigQueryAuditMetadata, debug_include_full_payloads: bool = False + cls, + row: BigQueryAuditMetadata, + debug_include_full_payloads: bool = False, ) -> "QueryEvent": payload: Dict = row["protoPayload"] metadata: Dict = json.loads(row["metadata"]) @@ -404,7 +414,7 @@ def from_exported_bigquery_audit_metadata( raw_dest_table = query_config.get("destinationTable") if raw_dest_table: query_event.destinationTable = BigQueryTableRef.from_string_name( - raw_dest_table + raw_dest_table, ).get_sanitized_table_ref() # referencedTables raw_ref_tables = query_stats.get("referencedTables") @@ -428,7 +438,7 @@ def from_exported_bigquery_audit_metadata( if not query_event.job_name: logger.debug( "jobName from query events is absent. 
" - "BigQueryAuditMetadata entry - {logEntry}".format(logEntry=row) + "BigQueryAuditMetadata entry - {logEntry}".format(logEntry=row), ) if query_stats.get("totalBilledBytes"): @@ -438,7 +448,9 @@ def from_exported_bigquery_audit_metadata( @classmethod def from_entry_v2( - cls, row: BigQueryAuditMetadata, debug_include_full_payloads: bool = False + cls, + row: BigQueryAuditMetadata, + debug_include_full_payloads: bool = False, ) -> "QueryEvent": payload: Dict = row.payload metadata: Dict = payload["metadata"] @@ -480,7 +492,7 @@ def from_entry_v2( raw_dest_table = query_config.get("destinationTable") if raw_dest_table: query_event.destinationTable = BigQueryTableRef.from_string_name( - raw_dest_table + raw_dest_table, ).get_sanitized_table_ref() # statementType # referencedTables @@ -505,7 +517,7 @@ def from_entry_v2( if not query_event.job_name: logger.debug( "jobName from query events is absent. " - "BigQueryAuditMetadata entry - {logEntry}".format(logEntry=row) + "BigQueryAuditMetadata entry - {logEntry}".format(logEntry=row), ) if query_stats.get("totalBilledBytes"): @@ -541,10 +553,12 @@ class ReadEvent: def get_missing_key_entry(cls, entry: AuditLogEntry) -> Optional[str]: return ( get_first_missing_key( - inp_dict=entry.payload, keys=["metadata", "tableDataRead"] + inp_dict=entry.payload, + keys=["metadata", "tableDataRead"], ) or get_first_missing_key( - inp_dict=entry.payload, keys=["authenticationInfo", "principalEmail"] + inp_dict=entry.payload, + keys=["authenticationInfo", "principalEmail"], ) or get_first_missing_key(inp_dict=entry.payload, keys=["resourceName"]) ) @@ -562,7 +576,9 @@ def get_missing_key_exported_bigquery_audit_metadata( @classmethod def from_entry( - cls, entry: AuditLogEntry, debug_include_full_payloads: bool = False + cls, + entry: AuditLogEntry, + debug_include_full_payloads: bool = False, ) -> "ReadEvent": user = entry.payload["authenticationInfo"]["principalEmail"] resourceName = entry.payload["resourceName"] @@ -577,7 +593,7 @@ def from_entry( jobName = readInfo.get("jobName") resource = BigQueryTableRef.from_string_name( - resourceName + resourceName, ).get_sanitized_table_ref() readEvent = ReadEvent( @@ -592,7 +608,7 @@ def from_entry( if readReason == "JOB" and not jobName: logger.debug( "jobName from read events is absent when readReason is JOB. " - "Auditlog entry - {logEntry}".format(logEntry=entry) + "Auditlog entry - {logEntry}".format(logEntry=entry), ) return readEvent @@ -616,7 +632,9 @@ def from_query_event( @classmethod def from_exported_bigquery_audit_metadata( - cls, row: BigQueryAuditMetadata, debug_include_full_payloads: bool = False + cls, + row: BigQueryAuditMetadata, + debug_include_full_payloads: bool = False, ) -> "ReadEvent": payload = row["protoPayload"] user = payload["authenticationInfo"]["principalEmail"] @@ -633,7 +651,7 @@ def from_exported_bigquery_audit_metadata( jobName = readInfo.get("jobName") resource = BigQueryTableRef.from_string_name( - resourceName + resourceName, ).get_sanitized_table_ref() readEvent = ReadEvent( @@ -648,7 +666,7 @@ def from_exported_bigquery_audit_metadata( if readReason == "JOB" and not jobName: logger.debug( "jobName from read events is absent when readReason is JOB. 
" - "Auditlog entry - {logEntry}".format(logEntry=row) + "Auditlog entry - {logEntry}".format(logEntry=row), ) return readEvent diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit_log_api.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit_log_api.py index 7d2f8ee0e1fd8d..e8d1037acb304a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit_log_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_audit_log_api.py @@ -69,7 +69,7 @@ def get_exported_bigquery_audit_metadata( self.report.num_get_exported_log_entries_api_requests += 1 for dataset in bigquery_audit_metadata_datasets: logger.info( - f"Start loading log entries from BigQueryAuditMetadata in {dataset}" + f"Start loading log entries from BigQueryAuditMetadata in {dataset}", ) query = bigquery_audit_metadata_query_template( @@ -85,7 +85,7 @@ def get_exported_bigquery_audit_metadata( query_job = bigquery_client.query(query) logger.info( - f"Finished loading log entries from BigQueryAuditMetadata in {dataset}" + f"Finished loading log entries from BigQueryAuditMetadata in {dataset}", ) for entry in query_job: @@ -126,7 +126,7 @@ def get_bigquery_log_entries_via_gcp_logging( for i, entry in enumerate(list_entries): if i > 0 and i % 1000 == 0: logger.info( - f"Loaded {i} log entries from GCP Log for {client.project}" + f"Loaded {i} log entries from GCP Log for {client.project}", ) with current_timer.pause(): @@ -137,5 +137,5 @@ def get_bigquery_log_entries_via_gcp_logging( yield entry logger.info( - f"Finished loading log entries from GCP Log for {client.project}" + f"Finished loading log entries from GCP Log for {client.project}", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py index 57bfa2e3090d31..2854616a01be42 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) DEFAULT_BQ_SCHEMA_PARALLELISM = int( - os.getenv("DATAHUB_BIGQUERY_SCHEMA_PARALLELISM", 20) + os.getenv("DATAHUB_BIGQUERY_SCHEMA_PARALLELISM", 20), ) # Regexp for sharded tables. @@ -47,7 +47,8 @@ class BigQueryBaseConfig(ConfigModel): rate_limit: bool = Field( - default=False, description="Should we rate limit requests made to API." + default=False, + description="Should we rate limit requests made to API.", ) requests_per_min: int = Field( default=60, @@ -71,7 +72,7 @@ def sharded_table_pattern_is_a_valid_regexp(cls, v): re.compile(v) except Exception as e: raise ValueError( - "sharded_table_pattern configuration pattern is invalid." + "sharded_table_pattern configuration pattern is invalid.", ) from e return v @@ -84,7 +85,7 @@ def project_id_backward_compatibility_configs_set(cls, values: Dict) -> Dict: values["project_ids"] = [project_id] elif project_ids and project_id: logging.warning( - "Please use `project_ids` config. Config `project_id` will be ignored." + "Please use `project_ids` config. 
Config `project_id` will be ignored.", ) return values @@ -111,7 +112,7 @@ class BigQueryCredential(ConfigModel): project_id: str = Field(description="Project id to set the credentials") private_key_id: str = Field(description="Private key id") private_key: str = Field( - description="Private key in a form of '-----BEGIN PRIVATE KEY-----\\nprivate-key\\n-----END PRIVATE KEY-----\\n'" + description="Private key in a form of '-----BEGIN PRIVATE KEY-----\\nprivate-key\\n-----END PRIVATE KEY-----\\n'", ) client_email: str = Field(description="Client email") client_id: str = Field(description="Client Id") @@ -120,7 +121,8 @@ class BigQueryCredential(ConfigModel): description="Authentication uri", ) token_uri: str = Field( - default="https://oauth2.googleapis.com/token", description="Token uri" + default="https://oauth2.googleapis.com/token", + description="Token uri", ) auth_provider_x509_cert_url: str = Field( default="https://www.googleapis.com/oauth2/v1/certs", @@ -151,7 +153,8 @@ def create_credential_temp_file(self) -> str: class BigQueryConnectionConfig(ConfigModel): credential: Optional[BigQueryCredential] = Field( - default=None, description="BigQuery credential informations" + default=None, + description="BigQuery credential informations", ) _credentials_path: Optional[str] = PrivateAttr(None) @@ -172,7 +175,7 @@ def __init__(self, **data: Any): if self.credential: self._credentials_path = self.credential.create_credential_temp_file() logger.debug( - f"Creating temporary credential file at {self._credentials_path}" + f"Creating temporary credential file at {self._credentials_path}", ) os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path @@ -187,7 +190,8 @@ def get_policy_tag_manager_client(self) -> datacatalog_v1.PolicyTagManagerClient return datacatalog_v1.PolicyTagManagerClient() def make_gcp_logging_client( - self, project_id: Optional[str] = None + self, + project_id: Optional[str] = None, ) -> GCPLoggingClient: # See https://github.com/googleapis/google-cloud-python/issues/2674 for # why we disable gRPC here. @@ -299,7 +303,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: ): logging.warning( "dataset_pattern is not set but schema_pattern is set, using schema_pattern as dataset_pattern. " - "schema_pattern will be deprecated, please use dataset_pattern instead." + "schema_pattern will be deprecated, please use dataset_pattern instead.", ) values["dataset_pattern"] = schema_pattern dataset_pattern = schema_pattern @@ -309,7 +313,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: ): logging.warning( "schema_pattern will be ignored in favour of dataset_pattern. schema_pattern will be deprecated," - " please use dataset_pattern only." + " please use dataset_pattern only.", ) match_fully_qualified_names = values.get("match_fully_qualified_names") @@ -324,7 +328,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: "Please update `dataset_pattern` to match against fully qualified schema name " "`.` and set config `match_fully_qualified_names : True`." "The config option `match_fully_qualified_names` is deprecated and will be " - "removed in a future release." + "removed in a future release.", ) elif match_fully_qualified_names and dataset_pattern is not None: adjusted = False @@ -339,14 +343,16 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: if adjusted: logger.warning( "`dataset_pattern` was adjusted to match against fully qualified schema names," - " of the form `.`." 
+ " of the form `.`.", ) return values class BigQueryIdentifierConfig( - PlatformInstanceConfigMixin, EnvConfigMixin, LowerCaseDatasetUrnConfigMixin + PlatformInstanceConfigMixin, + EnvConfigMixin, + LowerCaseDatasetUrnConfigMixin, ): include_data_platform_instance: bool = Field( default=False, @@ -380,7 +386,8 @@ class BigQueryV2Config( ) usage: BigQueryUsageConfig = Field( - default=BigQueryUsageConfig(), description="Usage related configs" + default=BigQueryUsageConfig(), + description="Usage related configs", ) include_usage_statistics: bool = Field( @@ -414,7 +421,8 @@ class BigQueryV2Config( ) include_table_snapshots: Optional[bool] = Field( - default=True, description="Whether table snapshots should be ingested." + default=True, + description="Whether table snapshots should be ingested.", ) debug_include_full_payloads: bool = Field( @@ -593,7 +601,7 @@ def set_include_schema_metadata(cls, values: Dict) -> Dict: values["include_table_snapshots"] = False logger.info( "include_tables and include_views are both set to False." - " Disabling schema metadata ingestion for tables, views, and snapshots." + " Disabling schema metadata ingestion for tables, views, and snapshots.", ) return values @@ -608,7 +616,9 @@ def profile_default_settings(cls, values: Dict) -> Dict: @validator("bigquery_audit_metadata_datasets") def validate_bigquery_audit_metadata_datasets( - cls, v: Optional[List[str]], values: Dict + cls, + v: Optional[List[str]], + values: Dict, ) -> Optional[List[str]]: if values.get("use_exported_bigquery_audit_metadata"): assert v and len(v) > 0, ( @@ -621,5 +631,5 @@ def get_table_pattern(self, pattern: List[str]) -> str: return "|".join(pattern) if pattern else "" platform_instance_not_supported_for_bigquery = pydantic_removed_field( - "platform_instance" + "platform_instance", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py index cc9f6d47656588..cfc9424c5290c9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py @@ -49,7 +49,7 @@ def get_sample_data_for_table( # additional filter clause (e.g. where condition on partition) is available. 
logger.debug( - f"Collecting sample values for table {project}.{dataset}.{table_name}" + f"Collecting sample values for table {project}.{dataset}.{table_name}", ) with PerfTimer() as timer: sample_pc = sample_size_percent * 100 @@ -63,7 +63,7 @@ def get_sample_data_for_table( time_taken = timer.elapsed_seconds() logger.debug( f"Finished collecting sample values for table {project}.{dataset}.{table_name};" - f"{df.shape[0]} rows; took {time_taken:.3f} seconds" + f"{df.shape[0]} rows; took {time_taken:.3f} seconds", ) return df.to_dict(orient="list") diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_helper.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_helper.py index 507e1d917d2066..7049c1a9437866 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_helper.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_helper.py @@ -25,7 +25,8 @@ def unquote_and_decode_unicode_escape_seq( # Replace the Unicode escape sequence with the decoded character try: string = string.replace( - unicode_seq, unicode_seq.encode("utf-8").decode("unicode-escape") + unicode_seq, + unicode_seq.encode("utf-8").decode("unicode-escape"), ) except UnicodeDecodeError: # Skip decoding if is not possible to decode the Unicode escape sequence diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_platform_resource_helper.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_platform_resource_helper.py index 7dc0e4195d5dc9..93285d3a814b48 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_platform_resource_helper.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_platform_resource_helper.py @@ -75,7 +75,8 @@ def __init__( platform_resource_cache: cachetools.LRUCache = cachetools.LRUCache(maxsize=500) def get_platform_resource( - self, platform_resource_key: PlatformResourceKey + self, + platform_resource_key: PlatformResourceKey, ) -> Optional[PlatformResource]: # if graph is not available we always create a new PlatformResource if not self.graph: @@ -84,7 +85,8 @@ def get_platform_resource( return self.platform_resource_cache.get(platform_resource_key.primary_key) platform_resource = PlatformResource.from_datahub( - key=platform_resource_key, graph_client=self.graph + key=platform_resource_key, + graph_client=self.graph, ) if platform_resource: self.platform_resource_cache[platform_resource_key.primary_key] = ( @@ -107,7 +109,7 @@ def generate_label_platform_resource( ) platform_resource = self.get_platform_resource( - new_platform_resource.platform_resource_key() + new_platform_resource.platform_resource_key(), ) if platform_resource: if ( @@ -117,12 +119,12 @@ def generate_label_platform_resource( try: existing_info: Optional[BigQueryLabelInfo] = ( platform_resource.resource_info.value.as_pydantic_object( # type: ignore - BigQueryLabelInfo + BigQueryLabelInfo, ) ) except ValidationError as e: logger.error( - f"Error converting existing value to BigQueryLabelInfo: {e}. Creating new one. Maybe this is because of a non backward compatible schema change." + f"Error converting existing value to BigQueryLabelInfo: {e}. Creating new one. Maybe this is because of a non backward compatible schema change.", ) existing_info = None @@ -134,15 +136,15 @@ def generate_label_platform_resource( return platform_resource else: raise ValueError( - f"Datahub URN mismatch for platform resources. 
Old (existing) platform resource: {platform_resource} and new platform resource: {new_platform_resource}" + f"Datahub URN mismatch for platform resources. Old (existing) platform resource: {platform_resource} and new platform resource: {new_platform_resource}", ) logger.info(f"Created platform resource {new_platform_resource}") self.platform_resource_cache.update( { - new_platform_resource.platform_resource_key().primary_key: new_platform_resource.platform_resource() - } + new_platform_resource.platform_resource_key().primary_key: new_platform_resource.platform_resource(), + }, ) return new_platform_resource.platform_resource() diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py index 47f21c9f32353a..5b3dec717212d4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py @@ -36,15 +36,17 @@ class BigQueryQueriesSourceReport(SourceReport): window: Optional[BaseTimeWindowConfig] = None queries_extractor: Optional[BigQueryQueriesExtractorReport] = None schema_api_perf: BigQuerySchemaApiPerfReport = field( - default_factory=BigQuerySchemaApiPerfReport + default_factory=BigQuerySchemaApiPerfReport, ) class BigQueryQueriesSourceConfig( - BigQueryQueriesExtractorConfig, BigQueryFilterConfig, BigQueryIdentifierConfig + BigQueryQueriesExtractorConfig, + BigQueryFilterConfig, + BigQueryIdentifierConfig, ): connection: BigQueryConnectionConfig = Field( - default_factory=BigQueryConnectionConfig + default_factory=BigQueryConnectionConfig, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py index 8e55d81aac5fe3..f8a147e7fcd031 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py @@ -84,25 +84,25 @@ class BigQueryV2Report( ): num_total_lineage_entries: TopKDict[str, int] = field(default_factory=TopKDict) num_skipped_lineage_entries_missing_data: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_skipped_lineage_entries_not_allowed: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_lineage_entries_sql_parser_failure: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_skipped_lineage_entries_other: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_lineage_total_log_entries: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_lineage_parsed_log_entries: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_lineage_log_parse_failures: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) bigquery_audit_metadata_datasets_missing: Optional[bool] = None lineage_failed_extraction: LossyList[str] = field(default_factory=LossyList) @@ -111,10 +111,10 @@ class BigQueryV2Report( lineage_extraction_sec: Dict[str, float] = field(default_factory=TopKDict) usage_extraction_sec: Dict[str, float] = field(default_factory=TopKDict) num_usage_total_log_entries: TopKDict[str, int] = field( - 
default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_usage_parsed_log_entries: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) num_usage_resources_dropped: int = 0 @@ -136,13 +136,13 @@ class BigQueryV2Report( profile_table_selection_criteria: Dict[str, str] = field(default_factory=TopKDict) selected_profile_tables: Dict[str, List[str]] = field(default_factory=TopKDict) profiling_skipped_invalid_partition_ids: Dict[str, str] = field( - default_factory=TopKDict + default_factory=TopKDict, ) profiling_skipped_invalid_partition_type: Dict[str, str] = field( - default_factory=TopKDict + default_factory=TopKDict, ) profiling_skipped_partition_profiling_disabled: List[str] = field( - default_factory=LossyList + default_factory=LossyList, ) allow_pattern: Optional[str] = None deny_pattern: Optional[str] = None @@ -171,13 +171,13 @@ class BigQueryV2Report( init_schema_resolver_timer: PerfTimer = field(default_factory=PerfTimer) schema_api_perf: BigQuerySchemaApiPerfReport = field( - default_factory=BigQuerySchemaApiPerfReport + default_factory=BigQuerySchemaApiPerfReport, ) audit_log_api_perf: BigQueryAuditLogApiPerfReport = field( - default_factory=BigQueryAuditLogApiPerfReport + default_factory=BigQueryAuditLogApiPerfReport, ) processing_perf: BigQueryProcessingPerfReport = field( - default_factory=BigQueryProcessingPerfReport + default_factory=BigQueryProcessingPerfReport, ) lineage_start_time: Optional[datetime] = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema.py index cbe1f6eb978247..feea94ca03b950 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema.py @@ -72,7 +72,8 @@ class PartitionInfo: # TimePartitioning field doesn't provide data_type so we have to add it afterwards @classmethod def from_time_partitioning( - cls, time_partitioning: TimePartitioning + cls, + time_partitioning: TimePartitioning, ) -> "PartitionInfo": return cls( field=time_partitioning.field or "_PARTITIONTIME", @@ -83,7 +84,8 @@ def from_time_partitioning( @classmethod def from_range_partitioning( - cls, range_partitioning: Dict[str, Any] + cls, + range_partitioning: Dict[str, Any], ) -> Optional["PartitionInfo"]: field: Optional[str] = range_partitioning.get("field") if not field: @@ -102,7 +104,7 @@ def from_table_info(cls, table_info: TableListItem) -> Optional["PartitionInfo"] return PartitionInfo.from_time_partitioning(table_info.time_partitioning) elif RANGE_PARTITIONING_KEY in table_info._properties: return PartitionInfo.from_range_partitioning( - table_info._properties[RANGE_PARTITIONING_KEY] + table_info._properties[RANGE_PARTITIONING_KEY], ) else: return None @@ -158,7 +160,7 @@ class BigqueryDataset: # Omni Locations - https://cloud.google.com/bigquery/docs/omni-introduction#locations def is_biglake_dataset(self) -> bool: return self.location is not None and self.location.lower().startswith( - ("aws-", "azure-") + ("aws-", "azure-"), ) def supports_table_constraints(self) -> bool: @@ -212,7 +214,7 @@ def _should_retry(exc: BaseException) -> bool: def get_projects(self, max_results_per_page: int = 100) -> List[BigqueryProject]: def _should_retry(exc: BaseException) -> bool: logger.debug( - f"Exception occurred for project.list api. Reason: {exc}. Retrying api request..." 
+ f"Exception occurred for project.list api. Reason: {exc}. Retrying api request...", ) self.report.num_list_projects_retry_request += 1 return True @@ -265,20 +267,24 @@ def get_projects_with_labels(self, labels: FrozenSet[str]) -> List[BigqueryProje for project in self.projects_client.search_projects(query=labels_query): projects.append( BigqueryProject( - id=project.project_id, name=project.display_name - ) + id=project.project_id, + name=project.display_name, + ), ) return projects except Exception as e: logger.error( - f"Error getting projects with labels: {labels}. {e}", exc_info=True + f"Error getting projects with labels: {labels}. {e}", + exc_info=True, ) return [] def get_datasets_for_project_id( - self, project_id: str, maxResults: Optional[int] = None + self, + project_id: str, + maxResults: Optional[int] = None, ) -> List[BigqueryDataset]: with self.report.list_datasets_timer: self.report.num_list_datasets_api_requests += 1 @@ -298,7 +304,8 @@ def get_datasets_for_project_id( # This is not used anywhere def get_datasets_for_project_id_with_information_schema( - self, project_id: str + self, + project_id: str, ) -> List[BigqueryDataset]: """ This method is not used as of now, due to below limitation. @@ -321,7 +328,9 @@ def get_datasets_for_project_id_with_information_schema( ] def list_tables( - self, dataset_name: str, project_id: str + self, + dataset_name: str, + project_id: str, ) -> Iterator[TableListItem]: with PerfTimer() as current_timer: for table in self.bq_client.list_tables(f"{project_id}.{dataset_name}"): @@ -364,7 +373,8 @@ def get_tables_for_dataset( try: with current_timer.pause(): yield BigQuerySchemaApi._make_bigquery_table( - table, tables.get(table.table_name) + table, + tables.get(table.table_name), ) except Exception as e: table_name = f"{project_id}.{dataset_name}.{table.table_name}" @@ -379,7 +389,8 @@ def get_tables_for_dataset( @staticmethod def _make_bigquery_table( - table: bigquery.Row, table_basic: Optional[TableListItem] + table: bigquery.Row, + table_basic: Optional[TableListItem], ) -> BigqueryTable: # Some properties we want to capture are only available from the TableListItem # we get from an earlier query of the list of tables. 
@@ -425,13 +436,15 @@ def get_views_for_dataset( # If profiling is enabled cur = self.get_query_result( BigqueryQuery.views_for_dataset.format( - project_id=project_id, dataset_name=dataset_name + project_id=project_id, + dataset_name=dataset_name, ), ) else: cur = self.get_query_result( BigqueryQuery.views_for_dataset_without_data_read.format( - project_id=project_id, dataset_name=dataset_name + project_id=project_id, + dataset_name=dataset_name, ), ) @@ -492,11 +505,11 @@ def get_policy_tags_for_column( if rate_limiter: with rate_limiter: policy_tag = self.datacatalog_client.get_policy_tag( - name=policy_tag_name + name=policy_tag_name, ) else: policy_tag = self.datacatalog_client.get_policy_tag( - name=policy_tag_name + name=policy_tag_name, ) yield policy_tag.display_name except Exception as e: @@ -525,8 +538,9 @@ def get_table_constraints_for_dataset( try: cur = self.get_query_result( BigqueryQuery.constraints_for_table.format( - project_id=project_id, dataset_name=dataset_name - ) + project_id=project_id, + dataset_name=dataset_name, + ), ) except Exception as e: report.warning( @@ -566,7 +580,7 @@ def get_table_constraints_for_dataset( if constraint.constraint_type == "FOREIGN KEY" else None ), - ) + ), ) self.report.num_get_table_constraints_for_dataset_api_requests += 1 self.report.get_table_constraints_for_dataset_sec += timer.elapsed_seconds() @@ -589,7 +603,8 @@ def get_columns_for_dataset( cur = self.get_query_result( ( BigqueryQuery.columns_for_dataset.format( - project_id=project_id, dataset_name=dataset_name + project_id=project_id, + dataset_name=dataset_name, ) if not run_optimized_column_query else BigqueryQuery.optimized_columns_for_dataset.format( @@ -618,7 +633,7 @@ def get_columns_for_dataset( ): if last_seen_table != column.table_name: logger.warning( - f"{project_id}.{dataset_name}.{column.table_name} contains more than {column_limit} columns, only processing {column_limit} columns" + f"{project_id}.{dataset_name}.{column.table_name} contains more than {column_limit} columns, only processing {column_limit} columns", ) last_seen_table = column.table_name else: @@ -642,12 +657,12 @@ def get_columns_for_dataset( column.column_name, report, rate_limiter, - ) + ), ) if extract_policy_tags_from_catalog else [] ), - ) + ), ) self.report.num_get_columns_for_dataset_api_requests += 1 self.report.get_columns_for_dataset_sec += timer.elapsed_seconds() @@ -666,13 +681,15 @@ def get_snapshots_for_dataset( # If profiling is enabled cur = self.get_query_result( BigqueryQuery.snapshots_for_dataset.format( - project_id=project_id, dataset_name=dataset_name + project_id=project_id, + dataset_name=dataset_name, ), ) else: cur = self.get_query_result( BigqueryQuery.snapshots_for_dataset_without_data_read.format( - project_id=project_id, dataset_name=dataset_name + project_id=project_id, + dataset_name=dataset_name, ), ) @@ -738,7 +755,7 @@ def query_project_list( yield project else: logger.debug( - f"Ignoring project {project.id} as it's not allowed by project_id_pattern" + f"Ignoring project {project.id} as it's not allowed by project_id_pattern", ) @@ -765,7 +782,7 @@ def query_project_list_from_labels( filters: BigQueryFilter, ) -> Iterable[BigqueryProject]: projects = schema_api.get_projects_with_labels( - frozenset(filters.filter_config.project_labels) + frozenset(filters.filter_config.project_labels), ) if not projects: # Report failure on exception and if empty list is returned @@ -781,5 +798,5 @@ def query_project_list_from_labels( yield project else: logger.debug( - 
f"Ignoring project {project.id} as it's not allowed by project_id_pattern" + f"Ignoring project {project.id} as it's not allowed by project_id_pattern", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema_gen.py index ebfbbf0639c38c..b1008aeadc8dc9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema_gen.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_schema_gen.py @@ -190,7 +190,7 @@ def __init__( self.data_reader: Optional[BigQueryDataReader] = None if self.classification_handler.is_classification_enabled(): self.data_reader = BigQueryDataReader.create( - self.config.get_bigquery_client() + self.config.get_bigquery_client(), ) # Global store of table identifiers for lineage filtering @@ -246,21 +246,25 @@ def modified_base32decode(self, text_to_decode: str) -> str: return text def get_project_workunits( - self, project: BigqueryProject + self, + project: BigqueryProject, ) -> Iterable[MetadataWorkUnit]: with self.report.new_stage(f"{project.id}: {METADATA_EXTRACTION}"): logger.info(f"Processing project: {project.id}") yield from self._process_project(project) def get_dataplatform_instance_aspect( - self, dataset_urn: str, project_id: str + self, + dataset_urn: str, + project_id: str, ) -> MetadataWorkUnit: aspect = DataPlatformInstanceClass( platform=self.identifiers.make_data_platform_urn(), instance=self.identifiers.make_dataplatform_instance_urn(project_id), ) return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=aspect + entityUrn=dataset_urn, + aspect=aspect, ).as_workunit() def gen_dataset_key(self, db_name: str, schema: str) -> ContainerKey: @@ -309,11 +313,13 @@ def gen_dataset_containers( label = BigQueryLabel(key=k, value=v) try: platform_resource: PlatformResource = self.platform_resource_helper.generate_label_platform_resource( - label, tag_urn, managed_by_datahub=False + label, + tag_urn, + managed_by_datahub=False, ) label_info: BigQueryLabelInfo = ( platform_resource.resource_info.value.as_pydantic_object( # type: ignore - BigQueryLabelInfo + BigQueryLabelInfo, ) ) tag_urn = TagUrn.from_string(label_info.datahub_urn) @@ -322,7 +328,7 @@ def gen_dataset_containers( yield mcpw.as_workunit() except ValueError as e: logger.warning( - f"Failed to generate platform resource for label {k}:{v}: {e}" + f"Failed to generate platform resource for label {k}:{v}: {e}", ) tags_joined.append(tag_urn.name) @@ -338,7 +344,8 @@ def gen_dataset_containers( database_container_key=database_container_key, external_url=( BQ_EXTERNAL_DATASET_URL_TEMPLATE.format( - project=project_id, dataset=dataset + project=project_id, + dataset=dataset, ) if self.config.include_external_url else None @@ -348,14 +355,15 @@ def gen_dataset_containers( ) def _process_project( - self, bigquery_project: BigqueryProject + self, + bigquery_project: BigqueryProject, ) -> Iterable[MetadataWorkUnit]: db_tables: Dict[str, List[BigqueryTable]] = {} project_id = bigquery_project.id try: bigquery_project.datasets = self.schema_api.get_datasets_for_project_id( - project_id + project_id, ) except Exception as e: if self.config.project_ids and "not enabled BigQuery." in str(e): @@ -385,7 +393,7 @@ def _process_project( if self.config.exclude_empty_projects: self.report.report_dropped(project_id) logger.info( - f"Excluded project '{project_id}' since no datasets were found. 
{action_message}" + f"Excluded project '{project_id}' since no datasets were found. {action_message}", ) else: if self.config.include_schema_metadata: @@ -401,7 +409,7 @@ def _process_project( yield from self.gen_project_id_containers(project_id) self.report.num_project_datasets_to_scan[project_id] = len( - bigquery_project.datasets + bigquery_project.datasets, ) yield from self._process_project_datasets(bigquery_project, db_tables) @@ -436,7 +444,11 @@ def _process_schema_worker( try: # db_tables, db_views, and db_snapshots are populated in the this method for wu in self._process_schema( - project_id, bigquery_dataset, db_tables, db_views, db_snapshots + project_id, + bigquery_dataset, + db_tables, + db_views, + db_snapshots, ): yield wu except Exception as e: @@ -487,7 +499,8 @@ def _process_schema( rate_limiter: Optional[RateLimiter] = None if self.config.rate_limit: rate_limiter = RateLimiter( - max_calls=self.config.requests_per_min, period=60 + max_calls=self.config.requests_per_min, + period=60, ) if self.config.include_schema_metadata: @@ -505,7 +518,9 @@ def _process_schema( and bigquery_dataset.supports_table_constraints() ): constraints = self.schema_api.get_table_constraints_for_dataset( - project_id=project_id, dataset_name=dataset_name, report=self.report + project_id=project_id, + dataset_name=dataset_name, + report=self.report, ) elif self.store_table_refs: # Need table_refs to calculate lineage and usage @@ -520,17 +535,17 @@ def _process_schema( continue try: self.table_refs.add( - str(BigQueryTableRef(identifier).get_sanitized_table_ref()) + str(BigQueryTableRef(identifier).get_sanitized_table_ref()), ) except Exception as e: logger.warning( - f"Could not create table ref for {table_item.path}: {e}" + f"Could not create table ref for {table_item.path}: {e}", ) return if self.config.include_tables: db_tables[dataset_name] = list( - self.get_tables_for_dataset(project_id, bigquery_dataset) + self.get_tables_for_dataset(project_id, bigquery_dataset), ) for table in db_tables[dataset_name]: @@ -558,7 +573,7 @@ def _process_schema( / table.rows_count if table.rows_count else None - ) + ), ), ) @@ -569,7 +584,7 @@ def _process_schema( dataset_name, self.config.is_profiling_enabled(), self.report, - ) + ), ) for view in db_views[dataset_name]: @@ -588,7 +603,7 @@ def _process_schema( dataset_name, self.config.is_profiling_enabled(), self.report, - ) + ), ) for snapshot in db_snapshots[dataset_name]: @@ -617,13 +632,13 @@ def _process_table( if self.store_table_refs: self.table_refs.add( - str(BigQueryTableRef(table_identifier).get_sanitized_table_ref()) + str(BigQueryTableRef(table_identifier).get_sanitized_table_ref()), ) table.column_count = len(columns) if not table.column_count: logger.warning( - f"Table doesn't have any column or unable to get columns for table: {table_identifier}" + f"Table doesn't have any column or unable to get columns for table: {table_identifier}", ) # If table has time partitioning, set the data type of the partitioning field @@ -637,7 +652,10 @@ def _process_table( None, ) yield from self.gen_table_dataset_workunits( - table, columns, project_id, dataset_name + table, + columns, + project_id, + dataset_name, ) def _process_view( @@ -664,7 +682,7 @@ def _process_view( view.column_count = len(columns) if not view.column_count: logger.warning( - f"View doesn't have any column or unable to get columns for view: {table_identifier}" + f"View doesn't have any column or unable to get columns for view: {table_identifier}", ) yield from 
self.gen_view_dataset_workunits( @@ -682,13 +700,15 @@ def _process_snapshot( dataset_name: str, ) -> Iterable[MetadataWorkUnit]: table_identifier = BigqueryTableIdentifier( - project_id, dataset_name, snapshot.name + project_id, + dataset_name, + snapshot.name, ) self.report.snapshots_scanned += 1 if not self.config.table_snapshot_pattern.allowed( - table_identifier.raw_table_name() + table_identifier.raw_table_name(), ): self.report.report_dropped(table_identifier.raw_table_name()) return @@ -697,7 +717,7 @@ def _process_snapshot( snapshot.column_count = len(columns) if not snapshot.column_count: logger.warning( - f"Snapshot doesn't have any column or unable to get columns for snapshot: {table_identifier}" + f"Snapshot doesn't have any column or unable to get columns for snapshot: {table_identifier}", ) table_ref = str(BigQueryTableRef(table_identifier).get_sanitized_table_ref()) @@ -728,7 +748,7 @@ def gen_foreign_keys( ) -> Iterable[ForeignKeyConstraint]: table_id = f"{project_id}.{dataset_name}.{table.name}" foreign_keys: List[BigqueryTableConstraint] = list( - filter(lambda x: x.type == "FOREIGN KEY", table.constraints) + filter(lambda x: x.type == "FOREIGN KEY", table.constraints), ) for key, group in groupby( foreign_keys, @@ -752,11 +772,13 @@ def gen_foreign_keys( for item in group: source_field = make_schema_field_urn( - parent_urn=dataset_urn, field_path=item.field_path + parent_urn=dataset_urn, + field_path=item.field_path, ) assert item.referenced_column_name referenced_field = make_schema_field_urn( - parent_urn=foreign_dataset, field_path=item.referenced_column_name + parent_urn=foreign_dataset, + field_path=item.referenced_column_name, ) source_fields.append(source_field) @@ -790,12 +812,12 @@ def gen_table_dataset_workunits( if table.active_billable_bytes: custom_properties["billable_bytes_active"] = str( - table.active_billable_bytes + table.active_billable_bytes, ) if table.long_term_billable_bytes: custom_properties["billable_bytes_long_term"] = str( - table.long_term_billable_bytes + table.long_term_billable_bytes, ) if table.max_partition_id: @@ -820,11 +842,13 @@ def gen_table_dataset_workunits( try: label = BigQueryLabel(key=k, value=v) platform_resource: PlatformResource = self.platform_resource_helper.generate_label_platform_resource( - label, tag_urn, managed_by_datahub=False + label, + tag_urn, + managed_by_datahub=False, ) label_info: BigQueryLabelInfo = ( platform_resource.resource_info.value.as_pydantic_object( # type: ignore - BigQueryLabelInfo + BigQueryLabelInfo, ) ) tag_urn = TagUrn.from_string(label_info.datahub_urn) @@ -833,7 +857,7 @@ def gen_table_dataset_workunits( yield mcpw.as_workunit() except ValueError as e: logger.warning( - f"Failed to generate platform resource for label {k}:{v}: {e}" + f"Failed to generate platform resource for label {k}:{v}: {e}", ) tags_to_add.append(tag_urn.urn()) @@ -862,11 +886,13 @@ def gen_view_dataset_workunits( try: label = BigQueryLabel(key=k, value=v) platform_resource: PlatformResource = self.platform_resource_helper.generate_label_platform_resource( - label, tag_urn, managed_by_datahub=False + label, + tag_urn, + managed_by_datahub=False, ) label_info: BigQueryLabelInfo = ( platform_resource.resource_info.value.as_pydantic_object( # type: ignore - BigQueryLabelInfo + BigQueryLabelInfo, ) ) tag_urn = TagUrn.from_string(label_info.datahub_urn) @@ -875,7 +901,7 @@ def gen_view_dataset_workunits( yield mcpw.as_workunit() except ValueError as e: logger.warning( - f"Failed to generate platform resource for label 
{k}:{v}: {e}" + f"Failed to generate platform resource for label {k}:{v}: {e}", ) tags_to_add.append(tag_urn.urn()) @@ -897,7 +923,9 @@ def gen_view_dataset_workunits( ) yield MetadataChangeProposalWrapper( entityUrn=self.identifiers.gen_dataset_urn( - project_id, dataset_name, table.name + project_id, + dataset_name, + table.name, ), aspect=view_properties_aspect, ).as_workunit() @@ -938,7 +966,9 @@ def gen_dataset_workunits( custom_properties: Optional[Dict[str, str]] = None, ) -> Iterable[MetadataWorkUnit]: dataset_urn = self.identifiers.gen_dataset_urn( - project_id, dataset_name, table.name + project_id, + dataset_name, + table.name, ) # Added for bigquery to gcs lineage extraction @@ -952,15 +982,21 @@ def gen_dataset_workunits( status = Status(removed=False) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=status + entityUrn=dataset_urn, + aspect=status, ).as_workunit() datahub_dataset_name = BigqueryTableIdentifier( - project_id, dataset_name, table.name + project_id, + dataset_name, + table.name, ) yield self.gen_schema_metadata( - dataset_urn, table, columns, datahub_dataset_name + dataset_urn, + table, + columns, + datahub_dataset_name, ) dataset_properties = DatasetProperties( @@ -983,7 +1019,9 @@ def gen_dataset_workunits( ), externalUrl=( BQ_EXTERNAL_TABLE_URL_TEMPLATE.format( - project=project_id, dataset=dataset_name, table=table.name + project=project_id, + dataset=dataset_name, + table=table.name, ) if self.config.include_external_url else None @@ -993,7 +1031,8 @@ def gen_dataset_workunits( dataset_properties.customProperties.update(custom_properties) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties + entityUrn=dataset_urn, + aspect=dataset_properties, ).as_workunit() if tags_to_add: @@ -1004,12 +1043,14 @@ def gen_dataset_workunits( parent_container_key=self.gen_dataset_key(project_id, dataset_name), ) yield self.get_dataplatform_instance_aspect( - dataset_urn=dataset_urn, project_id=project_id + dataset_urn=dataset_urn, + project_id=project_id, ) subTypes = SubTypes(typeNames=sub_types) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=subTypes + entityUrn=dataset_urn, + aspect=subTypes, ).as_workunit() if self.domain_registry: @@ -1021,17 +1062,22 @@ def gen_dataset_workunits( ) def gen_tags_aspect_workunit( - self, dataset_urn: str, tags_to_add: List[str] + self, + dataset_urn: str, + tags_to_add: List[str], ) -> MetadataWorkUnit: tags = GlobalTagsClass( - tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add] + tags=[TagAssociationClass(tag_to_add) for tag_to_add in tags_to_add], ) return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=tags + entityUrn=dataset_urn, + aspect=tags, ).as_workunit() def is_primary_key( - self, field_path: str, constraints: List[BigqueryTableConstraint] + self, + field_path: str, + constraints: List[BigqueryTableConstraint], ) -> bool: for constraint in constraints: if constraint.field_path == field_path and constraint.type == "PRIMARY KEY": @@ -1039,7 +1085,9 @@ def is_primary_key( return False def gen_schema_fields( - self, columns: List[BigqueryColumn], constraints: List[BigqueryTableConstraint] + self, + columns: List[BigqueryColumn], + constraints: List[BigqueryTableConstraint], ) -> List[SchemaField]: schema_fields: List[SchemaField] = [] @@ -1060,8 +1108,10 @@ def gen_schema_fields( if last_id != col.ordinal_position: schema_fields.extend( get_schema_fields_for_hive_column( - col.name, col.data_type.lower(), 
description=col.comment - ) + col.name, + col.data_type.lower(), + description=col.comment, + ), ) # We have to add complex type comments to the correct level @@ -1087,9 +1137,9 @@ def gen_schema_fields( tags.append( TagAssociationClass( make_tag_urn( - f"{CLUSTERING_COLUMN_TAG}_{col.cluster_column_position}" - ) - ) + f"{CLUSTERING_COLUMN_TAG}_{col.cluster_column_position}", + ), + ), ) if col.policy_tags: @@ -1098,7 +1148,10 @@ def gen_schema_fields( field = SchemaField( fieldPath=col.name, type=SchemaFieldDataType( - self.BIGQUERY_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)() + self.BIGQUERY_FIELD_TYPE_MAPPINGS.get( + col.data_type, + NullType, + )(), ), isPartitioningKey=col.is_partition_column, isPartOfKey=self.is_primary_key(col.field_path, constraints), @@ -1126,8 +1179,10 @@ def gen_schema_metadata( if isinstance(table, BigqueryTable): foreign_keys = list( self.gen_foreign_keys( - table, dataset_name.dataset, dataset_name.project_id - ) + table, + dataset_name.dataset, + dataset_name.project_id, + ), ) schema_metadata = SchemaMetadata( @@ -1150,11 +1205,13 @@ def gen_schema_metadata( if self.config.lineage_use_sql_parser: self.sql_parser_schema_resolver.add_schema_metadata( - dataset_urn, schema_metadata + dataset_urn, + schema_metadata, ) return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=schema_metadata + entityUrn=dataset_urn, + aspect=schema_metadata, ).as_workunit() def get_tables_for_dataset( @@ -1184,7 +1241,9 @@ def get_tables_for_dataset( # We get the list of tables in the dataset to get core table properties and to be able to process the tables in batches # We collect only the latest shards from sharded tables (tables with _YYYYMMDD suffix) and ignore temporary tables table_items = self.get_core_table_details( - dataset.name, project_id, self.config.temp_table_dataset_prefix + dataset.name, + project_id, + self.config.temp_table_dataset_prefix, ) items_to_get: Dict[str, TableListItem] = {} @@ -1214,7 +1273,10 @@ def get_tables_for_dataset( ) def get_core_table_details( - self, dataset_name: str, project_id: str, temp_table_dataset_prefix: str + self, + dataset_name: str, + project_id: str, + temp_table_dataset_prefix: str, ) -> Dict[str, TableListItem]: table_items: Dict[str, TableListItem] = {} # Dict to store sharded table and the last seen max shard id @@ -1231,20 +1293,20 @@ def get_core_table_details( if ( not self.config.include_views or not self.config.view_pattern.allowed( - table_identifier.raw_table_name() + table_identifier.raw_table_name(), ) ): self.report.report_dropped(table_identifier.raw_table_name()) continue else: if not self.config.table_pattern.allowed( - table_identifier.raw_table_name() + table_identifier.raw_table_name(), ): self.report.report_dropped(table_identifier.raw_table_name()) continue _, shard = BigqueryTableIdentifier.get_table_and_shard( - table_identifier.table + table_identifier.table, ) table_name = table_identifier.get_table_name().split(".")[-1] @@ -1266,7 +1328,7 @@ def get_core_table_details( table=sharded_tables[table_name].table_id, ) _, stored_shard = BigqueryTableIdentifier.get_table_and_shard( - stored_table_identifier.table + stored_table_identifier.table, ) # When table is none, we use dataset_name as table_name assert stored_shard diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_test_connection.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_test_connection.py index fe64eeeb841399..dad1f4dcdbaac4 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_test_connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_test_connection.py @@ -32,7 +32,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: assert client test_report.basic_connectivity = BigQueryTestConnection.connectivity_test( - client + client, ) connection_conf.start_time = datetime.now() @@ -48,7 +48,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: metadata_read_capability = ( BigQueryTestConnection.metadata_read_capability_test( - project_ids, connection_conf + project_ids, + connection_conf, ) ) if SourceCapability.SCHEMA_METADATA not in _report: @@ -56,14 +57,18 @@ def test_connection(config_dict: dict) -> TestConnectionReport: if connection_conf.include_table_lineage: lineage_capability = BigQueryTestConnection.lineage_capability_test( - connection_conf, project_ids, report + connection_conf, + project_ids, + report, ) if SourceCapability.LINEAGE_COARSE not in _report: _report[SourceCapability.LINEAGE_COARSE] = lineage_capability if connection_conf.include_usage_statistics: usage_capability = BigQueryTestConnection.usage_capability_test( - connection_conf, project_ids, report + connection_conf, + project_ids, + report, ) if SourceCapability.USAGE_STATS not in _report: _report[SourceCapability.USAGE_STATS] = usage_capability @@ -73,7 +78,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=f"{e}" + capable=False, + failure_reason=f"{e}", ) return test_report @@ -82,14 +88,16 @@ def connectivity_test(client: bigquery.Client) -> CapabilityReport: ret = client.query("select 1") if ret.error_result: return CapabilityReport( - capable=False, failure_reason=f"{ret.error_result['message']}" + capable=False, + failure_reason=f"{ret.error_result['message']}", ) else: return CapabilityReport(capable=True) @staticmethod def metadata_read_capability_test( - project_ids: List[str], config: BigQueryV2Config + project_ids: List[str], + config: BigQueryV2Config, ) -> CapabilityReport: for project_id in project_ids: try: @@ -102,7 +110,8 @@ def metadata_read_capability_test( client=client, ) result = bigquery_data_dictionary.get_datasets_for_project_id( - project_id, 10 + project_id, + 10, ) if len(result) == 0: return CapabilityReport( diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/common.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/common.py index 83484e3a6a3d90..d65d5d389765f7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/common.py @@ -41,7 +41,11 @@ def __init__( self.structured_reporter = structured_reporter def gen_dataset_urn( - self, project_id: str, dataset_name: str, table: str, use_raw_name: bool = False + self, + project_id: str, + dataset_name: str, + table: str, + use_raw_name: bool = False, ) -> str: datahub_dataset_name = BigqueryTableIdentifier(project_id, dataset_name, table) return make_dataset_urn( @@ -85,14 +89,16 @@ def standardize_identifier_case(self, table_ref_str: str) -> str: class BigQueryFilter: def __init__( - self, filter_config: BigQueryFilterConfig, structured_reporter: SourceReport + self, + filter_config: BigQueryFilterConfig, + structured_reporter: SourceReport, ) -> None: self.filter_config = filter_config self.structured_reporter = structured_reporter 
def is_allowed(self, table_id: BigqueryTableIdentifier) -> bool: return AllowDenyPattern(deny=BQ_SYSTEM_TABLES_PATTERN).allowed( - str(table_id) + str(table_id), ) and ( self.is_project_allowed(table_id.project_id) and is_schema_allowed( diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py index da82c6a06f0395..fa7370766a962d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py @@ -100,7 +100,8 @@ class LineageEdge: def _merge_lineage_edge_columns( - a: Optional[LineageEdge], b: LineageEdge + a: Optional[LineageEdge], + b: LineageEdge, ) -> LineageEdge: if a is None: return b @@ -164,7 +165,9 @@ def _follow_column_lineage( def make_lineage_edges_from_parsing_result( - sql_lineage: SqlParsingResult, audit_stamp: datetime, lineage_type: str + sql_lineage: SqlParsingResult, + audit_stamp: datetime, + lineage_type: str, ) -> List[LineageEdge]: # Note: This ignores the out_tables section of the sql parsing result. audit_stamp = datetime.now(timezone.utc) @@ -189,9 +192,9 @@ def make_lineage_edges_from_parsing_result( table_name = str( BigQueryTableRef.from_bigquery_table( BigqueryTableIdentifier.from_string_name( - DatasetUrn.from_string(table_urn).name - ) - ) + DatasetUrn.from_string(table_urn).name, + ), + ), ) except IndexError as e: logger.debug(f"Unable to parse table urn {table_urn}: {e}") @@ -260,7 +263,8 @@ def __init__( def get_time_window(self) -> Tuple[datetime, datetime]: if self.redundant_run_skip_handler: return self.redundant_run_skip_handler.suggest_run_time_window( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) else: return self.config.start_time, self.config.end_time @@ -295,7 +299,7 @@ def get_lineage_workunits_for_views_and_snapshots( self.datasets_skip_audit_log_lineage.add(view) self.aggregator.add_view_definition( view_urn=self.identifiers.gen_dataset_urn_from_raw_ref( - BigQueryTableRef.from_string_name(view) + BigQueryTableRef.from_string_name(view), ), view_definition=view_definitions[view], default_db=project, @@ -307,13 +311,14 @@ def get_lineage_workunits_for_views_and_snapshots( continue self.datasets_skip_audit_log_lineage.add(snapshot_ref) snapshot_urn = self.identifiers.gen_dataset_urn_from_raw_ref( - BigQueryTableRef.from_string_name(snapshot_ref) + BigQueryTableRef.from_string_name(snapshot_ref), ) base_table_urn = self.identifiers.gen_dataset_urn_from_raw_ref( - BigQueryTableRef(snapshot.base_table_identifier) + BigQueryTableRef(snapshot.base_table_identifier), ) self.aggregator.add_known_lineage_mapping( - upstream_urn=base_table_urn, downstream_urn=snapshot_urn + upstream_urn=base_table_urn, + downstream_urn=snapshot_urn, ) yield from auto_workunit(self.aggregator.gen_metadata()) @@ -339,7 +344,8 @@ def get_lineage_workunits( if self.redundant_run_skip_handler: # Update the checkpoint state for this run. 
self.redundant_run_skip_handler.update_state( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) def generate_lineage( @@ -369,10 +375,10 @@ def generate_lineage( logger.info(f"Built lineage map containing {len(lineage)} entries.") logger.debug(f"lineage metadata is {lineage}") self.report.lineage_extraction_sec[project_id] = timer.elapsed_seconds( - digits=2 + digits=2, ) self.report.lineage_mem_size[project_id] = humanfriendly.format_size( - memory_footprint.total_size(lineage) + memory_footprint.total_size(lineage), ) for lineage_key in lineage.keys(): @@ -385,11 +391,14 @@ def generate_lineage( continue yield from self.gen_lineage_workunits_for_table( - lineage, BigQueryTableRef.from_string_name(lineage_key) + lineage, + BigQueryTableRef.from_string_name(lineage_key), ) def gen_lineage_workunits_for_table( - self, lineage: Dict[str, Set[LineageEdge]], table_ref: BigQueryTableRef + self, + lineage: Dict[str, Set[LineageEdge]], + table_ref: BigQueryTableRef, ) -> Iterable[MetadataWorkUnit]: dataset_urn = self.identifiers.gen_dataset_urn_from_raw_ref(table_ref) @@ -416,12 +425,14 @@ def gen_lineage( # Incremental lineage is handled by the auto_incremental_lineage helper. yield from [ MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=upstream_lineage - ).as_workunit() + entityUrn=dataset_urn, + aspect=upstream_lineage, + ).as_workunit(), ] def lineage_via_catalog_lineage_api( - self, project_id: str + self, + project_id: str, ) -> Dict[str, Set[LineageEdge]]: """ Uses Data Catalog API to request lineage metadata. Please take a look at the API documentation for more details. @@ -458,7 +469,7 @@ def lineage_via_catalog_lineage_api( table for table in data_dictionary.list_tables(dataset.name, project_id) if table.table_type in ["TABLE", "VIEW", "MATERIALIZED_VIEW"] - ] + ], ) lineage_map: Dict[str, Set[LineageEdge]] = {} @@ -493,17 +504,18 @@ def lineage_via_catalog_lineage_api( upstreams.update( [ str(lineage.source.fully_qualified_name).replace( - "bigquery:", "" + "bigquery:", + "", ) for lineage in response - ] + ], ) # Downstream table identifier destination_table_str = str( BigQueryTableRef( - table_identifier=BigqueryTableIdentifier(*table.split(".")) - ) + table_identifier=BigqueryTableIdentifier(*table.split(".")), + ), ) # Only builds lineage map when the table has upstreams @@ -514,9 +526,9 @@ def lineage_via_catalog_lineage_api( table=str( BigQueryTableRef( table_identifier=BigqueryTableIdentifier.from_string_name( - source_table - ) - ) + source_table, + ), + ), ), column_mapping=frozenset(), auditStamp=curr_date, @@ -538,12 +550,15 @@ def _get_parsed_audit_log_events(self, project_id: str) -> Iterable[QueryEvent]: parse_fn: Callable[[Any], Optional[Union[ReadEvent, QueryEvent]]] if self.config.use_exported_bigquery_audit_metadata: entries = self.get_exported_log_entries( - corrected_start_time, corrected_end_time + corrected_start_time, + corrected_end_time, ) parse_fn = self._parse_exported_bigquery_audit_metadata else: entries = self.get_log_entries_via_gcp_logging( - project_id, corrected_start_time, corrected_end_time + project_id, + corrected_start_time, + corrected_end_time, ) parse_fn = self._parse_bigquery_log_entries @@ -559,7 +574,10 @@ def _get_parsed_audit_log_events(self, project_id: str) -> Iterable[QueryEvent]: self.report.num_lineage_log_parse_failures[project_id] += 1 def get_exported_log_entries( - self, corrected_start_time, corrected_end_time, limit=None + self, + corrected_start_time, + 
corrected_end_time, + limit=None, ): logger.info("Populating lineage info via exported GCP audit logs") bq_client = self.config.get_bigquery_client() @@ -575,14 +593,17 @@ def get_exported_log_entries( return entries def get_log_entries_via_gcp_logging( - self, project_id, corrected_start_time, corrected_end_time + self, + project_id, + corrected_start_time, + corrected_end_time, ): logger.info("Populating lineage info via exported GCP audit logs") logging_client = self.config.make_gcp_logging_client(project_id) logger.info( f"Start loading log entries from BigQuery for {project_id} " - f"with start_time={corrected_start_time} and end_time={corrected_end_time}" + f"with start_time={corrected_start_time} and end_time={corrected_end_time}", ) entries = self.audit_log_api.get_bigquery_log_entries_via_gcp_logging( logging_client, @@ -612,7 +633,8 @@ def _parse_bigquery_log_entries( missing_entry_v2 = QueryEvent.get_missing_key_entry_v2(entry=entry) if event is None and missing_entry_v2 is None: event = QueryEvent.from_entry_v2( - entry, self.config.debug_include_full_payloads + entry, + self.config.debug_include_full_payloads, ) if event is None: @@ -624,7 +646,8 @@ def _parse_bigquery_log_entries( return event def _parse_exported_bigquery_audit_metadata( - self, audit_metadata: BigQueryAuditMetadata + self, + audit_metadata: BigQueryAuditMetadata, ) -> Optional[QueryEvent]: event: Optional[QueryEvent] = None @@ -634,7 +657,8 @@ def _parse_exported_bigquery_audit_metadata( if missing_exported_audit is None: event = QueryEvent.from_exported_bigquery_audit_metadata( - audit_metadata, self.config.debug_include_full_payloads + audit_metadata, + self.config.debug_include_full_payloads, ) if event is None: @@ -671,7 +695,7 @@ def _create_lineage_map( destination_table.table_identifier.project_id, self.config.match_fully_qualified_names, ) or not self.config.table_pattern.allowed( - destination_table.table_identifier.get_table_name() + destination_table.table_identifier.get_table_name(), ): self.report.num_skipped_lineage_entries_not_allowed[e.project_id] += 1 continue @@ -687,7 +711,7 @@ def _create_lineage_map( # Try the sql parser first. if self.config.lineage_use_sql_parser: logger.debug( - f"Using sql parser for lineage extraction for destination table: {destination_table.table_identifier.get_table_name()}, queryType: {e.statementType}, query: {e.query}" + f"Using sql parser for lineage extraction for destination table: {destination_table.table_identifier.get_table_name()}, queryType: {e.statementType}, query: {e.query}", ) if e.statementType == "SELECT": # We wrap select statements in a CTE to make them parseable as insert statement. @@ -704,7 +728,7 @@ def _create_lineage_map( except Exception: logger.debug( f"Failed to parse select-based lineage query {e.query} for table {destination_table}." - "Sql parsing will likely fail for this query, which will result in a fallback to audit log." + "Sql parsing will likely fail for this query, which will result in a fallback to audit log.", ) query = e.query else: @@ -715,13 +739,13 @@ def _create_lineage_map( default_db=e.project_id, ) logger.debug( - f"Input tables: {raw_lineage.in_tables}, Output tables: {raw_lineage.out_tables}" + f"Input tables: {raw_lineage.in_tables}, Output tables: {raw_lineage.out_tables}", ) if raw_lineage.debug_info.table_error: logger.debug( f"Sql Parser failed on query: {e.query}. It won't cause any major issues, but " f"queries referencing views may contain extra lineage for the tables underlying those views. 
" - f"The error was {raw_lineage.debug_info.table_error}." + f"The error was {raw_lineage.debug_info.table_error}.", ) self.report.num_lineage_entries_sql_parser_failure[ e.project_id @@ -732,7 +756,7 @@ def _create_lineage_map( raw_lineage, audit_stamp=ts, lineage_type=DatasetLineageTypeClass.TRANSFORMED, - ) + ), ) if not lineage_from_event: @@ -751,7 +775,7 @@ def _create_lineage_map( auditStamp=ts, column_mapping=frozenset(), column_confidence=0.1, - ) + ), ) if not lineage_from_event: @@ -781,12 +805,12 @@ def get_upstream_tables( continue if upstream_table.is_temporary_table( - [self.config.temp_table_dataset_prefix] + [self.config.temp_table_dataset_prefix], ): # making sure we don't process a table twice and not get into a recursive loop if upstream_lineage in edges_seen: logger.debug( - f"Skipping table {upstream_lineage} because it was seen already" + f"Skipping table {upstream_lineage} because it was seen already", ) continue edges_seen.add(upstream_lineage) @@ -806,7 +830,8 @@ def get_upstream_tables( # Replace `bq_table -> upstream_table -> temp_table_upstream` # with `bq_table -> temp_table_upstream`, merging the column lineage. collapsed_lineage = _follow_column_lineage( - upstream_lineage, temp_table_upstream + upstream_lineage, + temp_table_upstream, ) upstreams[ref_temp_table_upstream] = ( @@ -836,7 +861,7 @@ def get_lineage_for_table( for upstream in sorted(self.get_upstream_tables(bq_table, lineage_metadata)): upstream_table = BigQueryTableRef.from_string_name(upstream.table) upstream_table_urn = self.identifiers.gen_dataset_urn_from_raw_ref( - upstream_table + upstream_table, ) # Generate table-level lineage. @@ -850,7 +875,8 @@ def get_lineage_for_table( ) if self.config.upstream_lineage_in_report: current_lineage_map: Set = self.report.upstream_lineage.get( - str(bq_table), set() + str(bq_table), + set(), ) current_lineage_map.add(str(upstream_table)) self.report.upstream_lineage[str(bq_table)] = current_lineage_map @@ -863,13 +889,15 @@ def get_lineage_for_table( downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ mce_builder.make_schema_field_urn( - bq_table_urn, col_lineage_edge.out_column - ) + bq_table_urn, + col_lineage_edge.out_column, + ), ], upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET, upstreams=[ mce_builder.make_schema_field_urn( - upstream_table_urn, upstream_col + upstream_table_urn, + upstream_col, ) for upstream_col in col_lineage_edge.in_columns ], @@ -894,11 +922,11 @@ def test_capability(self, project_id: str) -> None: limit=1, ): logger.debug( - f"Connection test got one exported_bigquery_audit_metadata {entry}" + f"Connection test got one exported_bigquery_audit_metadata {entry}", ) else: gcp_logging_client: GCPLoggingClient = self.config.make_gcp_logging_client( - project_id + project_id, ) for entry in self.audit_log_api.get_bigquery_log_entries_via_gcp_logging( gcp_logging_client, @@ -937,7 +965,7 @@ def gen_lineage_workunits_for_external_table( except json.JSONDecodeError as e: self.report.num_skipped_external_table_lineage += 1 logger.warning( - f"Json load failed on loading source uri with error: {e}. The field value was: {uris_str}" + f"Json load failed on loading source uri with error: {e}. 
The field value was: {uris_str}", ) return @@ -949,7 +977,8 @@ def gen_lineage_workunits_for_external_table( if lineage_info: yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=lineage_info + entityUrn=dataset_urn, + aspect=lineage_info, ).as_workunit() def get_lineage_for_external_table( @@ -992,7 +1021,7 @@ def get_lineage_for_external_table( type=DatasetLineageTypeClass.COPY, ) for source_dataset_urn in gcs_urns - ] + ], ) if not upstreams_list: @@ -1001,7 +1030,7 @@ def get_lineage_for_external_table( if self.config.include_column_lineage_with_gcs: assert graph schema_metadata: Optional[SchemaMetadataClass] = graph.get_schema_metadata( - dataset_urn + dataset_urn, ) for gcs_dataset_urn in gcs_urns: schema_metadata_for_gcs: Optional[SchemaMetadataClass] = ( @@ -1017,14 +1046,15 @@ def get_lineage_for_external_table( if not fine_grained_lineage: logger.warning( f"Failed to retrieve fine-grained lineage for dataset {dataset_urn} and GCS {gcs_dataset_urn}. " - f"Check schema metadata: {schema_metadata} and GCS metadata: {schema_metadata_for_gcs}." + f"Check schema metadata: {schema_metadata} and GCS metadata: {schema_metadata_for_gcs}.", ) continue fine_grained_lineages.extend(fine_grained_lineage) upstream_lineage = UpstreamLineageClass( - upstreams=upstreams_list, fineGrainedLineages=fine_grained_lineages or None + upstreams=upstreams_list, + fineGrainedLineages=fine_grained_lineages or None, ) return upstream_lineage @@ -1033,7 +1063,7 @@ def _get_gcs_path(self, path: str) -> Optional[str]: for path_spec in self.config.gcs_lineage_config.path_specs: if not path_spec.allowed(path): logger.debug( - f"Skipping gcs path {path} as it does not match any path spec." + f"Skipping gcs path {path} as it does not match any path spec.", ) self.report.num_lineage_dropped_gcs_path += 1 continue @@ -1047,7 +1077,7 @@ def _get_gcs_path(self, path: str) -> Optional[str]: ): self.report.num_lineage_dropped_gcs_path += 1 logger.debug( - f"Skipping gcs path {path} as it does not match any path spec." 
+ f"Skipping gcs path {path} as it does not match any path spec.", ) return None @@ -1085,17 +1115,18 @@ def simplify_field_path(field_path): downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ mce_builder.make_schema_field_urn( - dataset_urn, field_path_v1 - ) + dataset_urn, + field_path_v1, + ), ], upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET, upstreams=[ mce_builder.make_schema_field_urn( gcs_dataset_urn, simplify_field_path(matching_gcs_field.fieldPath), - ) + ), ], - ) + ), ) return fine_grained_lineages return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py index 182ae2265cb162..8620d73eb8a77e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py @@ -38,7 +38,8 @@ def __init__( @staticmethod def get_partition_range_from_partition_id( - partition_id: str, partition_datetime: Optional[datetime] + partition_id: str, + partition_datetime: Optional[datetime], ) -> Tuple[datetime, datetime]: partition_range_map: Dict[int, Tuple[relativedelta, str]] = { 4: (relativedelta(years=1), "%Y"), @@ -55,12 +56,13 @@ def get_partition_range_from_partition_id( partition_datetime = datetime.strptime(partition_id, format) else: partition_datetime = datetime.strptime( - partition_datetime.strftime(format), format + partition_datetime.strftime(format), + format, ) else: raise ValueError( - f"check your partition_id {partition_id}. It must be yearly/monthly/daily/hourly." + f"check your partition_id {partition_id}. It must be yearly/monthly/daily/hourly.", ) upper_bound_partition_datetime = partition_datetime + duration return partition_datetime, upper_bound_partition_datetime @@ -78,7 +80,7 @@ def generate_partition_profiler_query( See more about partitioned tables at https://cloud.google.com/bigquery/docs/partitioned-tables """ logger.debug( - f"generate partition profiler query for project: {project} schema: {schema} and table {table.name}, partition_datetime: {partition_datetime}" + f"generate partition profiler query for project: {project} schema: {schema} and table {table.name}, partition_datetime: {partition_datetime}", ) partition = table.max_partition_id if table.partition_info and partition: @@ -91,7 +93,7 @@ def generate_partition_profiler_query( ) else: logger.warning( - f"Partitioned table {table.name} without partition column" + f"Partitioned table {table.name} without partition column", ) self.report.profiling_skipped_invalid_partition_ids[ f"{project}.{schema}.{table.name}" @@ -99,18 +101,19 @@ def generate_partition_profiler_query( return None, None else: logger.debug( - f"{table.name} is partitioned and partition column is {partition}" + f"{table.name} is partitioned and partition column is {partition}", ) try: ( partition_datetime, upper_bound_partition_datetime, ) = self.get_partition_range_from_partition_id( - partition, partition_datetime + partition, + partition_datetime, ) except ValueError as e: logger.error( - f"Unable to get partition range for partition id: {partition} it failed with exception {e}" + f"Unable to get partition range for partition id: {partition} it failed with exception {e}", ) self.report.profiling_skipped_invalid_partition_ids[ f"{project}.{schema}.{table.name}" @@ -129,7 +132,7 @@ def generate_partition_profiler_query( partition_where_clause = f"`{partition_column_name}` BETWEEN 
{partition_data_type}('{partition_datetime}') AND {partition_data_type}('{upper_bound_partition_datetime}')" else: logger.warning( - f"Not supported partition type {table.partition_info.type}" + f"Not supported partition type {table.partition_info.type}", ) self.report.profiling_skipped_invalid_partition_type[ f"{project}.{schema}.{table.name}" @@ -157,26 +160,30 @@ def generate_partition_profiler_query( return None, None def get_workunits( - self, project_id: str, tables: Dict[str, List[BigqueryTable]] + self, + project_id: str, + tables: Dict[str, List[BigqueryTable]], ) -> Iterable[MetadataWorkUnit]: profile_requests: List[TableProfilerRequest] = [] for dataset in tables: for table in tables[dataset]: normalized_table_name = BigqueryTableIdentifier( - project_id=project_id, dataset=dataset, table=table.name + project_id=project_id, + dataset=dataset, + table=table.name, ).get_table_name() if table.external and not self.config.profiling.profile_external_tables: self.report.profiling_skipped_other[f"{project_id}.{dataset}"] += 1 logger.info( - f"Skipping profiling of external table {project_id}.{dataset}.{table.name}" + f"Skipping profiling of external table {project_id}.{dataset}.{table.name}", ) continue # Emit the profile work unit logger.debug( - f"Creating profile request for table {normalized_table_name}" + f"Creating profile request for table {normalized_table_name}", ) profile_request = self.get_profile_request(table, dataset, project_id) if profile_request is not None: @@ -184,7 +191,7 @@ def get_workunits( profile_requests.append(profile_request) else: logger.debug( - f"Table {normalized_table_name} was not eliagible for profiling." + f"Table {normalized_table_name} was not eliagible for profiling.", ) if len(profile_requests) == 0: @@ -198,11 +205,16 @@ def get_workunits( def get_dataset_name(self, table_name: str, schema_name: str, db_name: str) -> str: return BigqueryTableIdentifier( - project_id=db_name, dataset=schema_name, table=table_name + project_id=db_name, + dataset=schema_name, + table=table_name, ).get_table_name() def get_batch_kwargs( - self, table: BaseTable, schema_name: str, db_name: str + self, + table: BaseTable, + schema_name: str, + db_name: str, ) -> dict: return dict( schema=db_name, # @@ -210,7 +222,10 @@ def get_batch_kwargs( ) def get_profile_request( - self, table: BaseTable, schema_name: str, db_name: str + self, + table: BaseTable, + schema_name: str, + db_name: str, ) -> Optional[TableProfilerRequest]: profile_request = super().get_profile_request(table, schema_name, db_name) @@ -223,7 +238,10 @@ def get_profile_request( bq_table = cast(BigqueryTable, table) (partition, custom_sql) = self.generate_partition_profiler_query( - db_name, schema_name, bq_table, self.config.profiling.partition_datetime + db_name, + schema_name, + bq_table, + self.config.profiling.partition_datetime, ) if partition is None and bq_table.partition_info: @@ -238,10 +256,10 @@ def get_profile_request( and not self.config.profiling.partition_profiling_enabled ): logger.debug( - f"{profile_request.pretty_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled" + f"{profile_request.pretty_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled", ) self.report.profiling_skipped_partition_profiling_disabled.append( - profile_request.pretty_name + profile_request.pretty_name, ) return None @@ -251,7 +269,7 @@ def get_profile_request( dict( custom_sql=custom_sql, 
partition=partition, - ) + ), ) return profile_request diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries.py index 8a558d7736a389..c117976a717909 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries.py @@ -391,7 +391,9 @@ class BigqueryQuery: def bigquery_audit_metadata_query_template_lineage( - dataset: str, use_date_sharded_tables: bool, limit: Optional[int] = None + dataset: str, + use_date_sharded_tables: bool, + limit: Optional[int] = None, ) -> str: """ Receives a dataset (with project specified) and returns a query template that is used to query exported diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py index 0f9471219c6590..67b257a87ab5c8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py @@ -100,7 +100,8 @@ class BigQueryQueriesExtractorConfig(BigQueryBaseConfig): ) top_n_queries: PositiveInt = Field( - default=10, description="Number of top queries to save to each table." + default=10, + description="Number of top queries to save to each table.", ) include_lineage: bool = True @@ -152,7 +153,7 @@ def __init__( map( self.identifiers.standardize_identifier_case, discovered_tables, - ) + ), ) if discovered_tables else None @@ -259,7 +260,9 @@ def get_workunits_internal( with self.report.query_log_fetch_timer: for project in get_projects( - self.schema_api, self.structured_report, self.filters + self.schema_api, + self.structured_report, + self.filters, ): for entry in self.fetch_query_log(project): self.report.num_queries_by_project[project.id] += 1 @@ -285,7 +288,7 @@ def get_workunits_internal( for query in query_instances.values(): if log_timer.should_report(): logger.info( - f"Added {i} deduplicated query log entries to SQL aggregator" + f"Added {i} deduplicated query log entries to SQL aggregator", ) if report_timer.should_report() and self.report.sql_aggregator: @@ -301,7 +304,8 @@ def get_workunits_internal( audit_log_file.unlink(missing_ok=True) def deduplicate_queries( - self, queries: FileBackedList[ObservedQuery] + self, + queries: FileBackedList[ObservedQuery], ) -> FileBackedDict[Dict[int, ObservedQuery]]: # This fingerprint based deduplication is done here to reduce performance hit due to # repetitive sql parsing while adding observed query to aggregator that would otherwise @@ -320,12 +324,17 @@ def deduplicate_queries( time_bucket = 0 if query.timestamp: time_bucket = datetime_to_ts_millis( - get_time_bucket(query.timestamp, self.config.window.bucket_duration) + get_time_bucket( + query.timestamp, + self.config.window.bucket_duration, + ), ) # Not using original BQ query hash as it's not always present query.query_hash = get_query_fingerprint( - query.query, self.identifiers.platform, fast=True + query.query, + self.identifiers.platform, + fast=True, ) query_instances = queries_deduped.setdefault(query.query_hash, {}) @@ -345,12 +354,14 @@ def fetch_query_log(self, project: BigqueryProject) -> Iterable[ObservedQuery]: for region in regions: with self.structured_report.report_exc( - f"Error fetching query log from BQ Project {project.id} for {region}" + f"Error fetching query log from BQ Project {project.id} for 
{region}", ): yield from self.fetch_region_query_log(project, region) def fetch_region_query_log( - self, project: BigqueryProject, region: str + self, + project: BigqueryProject, + region: str, ) -> Iterable[ObservedQuery]: # Each region needs to be a different query query_log_query = _build_enriched_query_log_query( @@ -396,7 +407,7 @@ def _parse_audit_log_row(self, row: BigQueryJob) -> ObservedQuery: timestamp=row["creation_time"], user=( CorpUserUrn.from_string( - self.identifiers.gen_user_urn(row["user_email"]) + self.identifiers.gen_user_urn(row["user_email"]), ) if row["user_email"] else None @@ -472,7 +483,7 @@ def _build_enriched_query_log_query( ] unsupported_statement_types = ",".join( - [f"'{statement_type}'" for statement_type in UNSUPPORTED_STATEMENT_TYPES] + [f"'{statement_type}'" for statement_type in UNSUPPORTED_STATEMENT_TYPES], ) # NOTE the use of partition column creation_time as timestamp here. diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py index c2b849e58fc6dc..e72aab2f4ed5f3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py @@ -122,7 +122,8 @@ def __init__(self, config: BigQueryV2Config): "resource": lambda e: str(e.resource), "name": lambda e: e.jobName, "timestamp": lambda e: get_time_bucket( - e.timestamp, config.bucket_duration + e.timestamp, + config.bucket_duration, ), "user": lambda e: e.actor_email, "from_query": lambda e: int(e.from_query), @@ -260,7 +261,8 @@ def usage_statistics(self, top_n: int) -> Iterator[UsageStatistic]: query = self.usage_statistics_query(top_n) rows = self.read_events.sql_query_iterator( - query, refs=[self.query_events, self.column_accesses] + query, + refs=[self.query_events, self.column_accesses], ) for row in rows: yield self.UsageStatistic( @@ -291,9 +293,9 @@ def report_disk_usage(self, report: BigQueryV2Report) -> None: { "main": humanfriendly.format_size(os.path.getsize(self.conn.filename)), "queries": humanfriendly.format_size( - os.path.getsize(self.queries._conn.filename) + os.path.getsize(self.queries._conn.filename), ), - } + }, ) @@ -334,7 +336,8 @@ def __init__( def get_time_window(self) -> Tuple[datetime, datetime]: if self.redundant_run_skip_handler: return self.redundant_run_skip_handler.suggest_run_time_window( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) else: return self.config.start_time, self.config.end_time @@ -369,7 +372,9 @@ def _should_ingest_usage(self) -> bool: return True def get_usage_workunits( - self, projects: Iterable[str], table_refs: Collection[str] + self, + projects: Iterable[str], + table_refs: Collection[str], ) -> Iterable[MetadataWorkUnit]: if not self._should_ingest_usage(): return @@ -385,7 +390,9 @@ def get_usage_workunits( ) def _get_workunits_internal( - self, events: Iterable[AuditEvent], table_refs: Collection[str] + self, + events: Iterable[AuditEvent], + table_refs: Collection[str], ) -> Iterable[MetadataWorkUnit]: try: with BigQueryUsageState(self.config) as usage_state: @@ -395,7 +402,8 @@ def _get_workunits_internal( if self.config.usage.include_operational_stats: yield from self._generate_operational_workunits( - usage_state, table_refs + usage_state, + table_refs, ) yield from auto_empty_dataset_usage_statistics( @@ -407,7 +415,7 @@ def _get_workunits_internal( ), dataset_urns={ 
self.identifiers.gen_dataset_urn_from_raw_ref( - BigQueryTableRef.from_string_name(ref) + BigQueryTableRef.from_string_name(ref), ) for ref in table_refs }, @@ -418,7 +426,8 @@ def _get_workunits_internal( self.report_status("usage-ingestion", False) def generate_read_events_from_query( - self, query_event_on_view: QueryEvent + self, + query_event_on_view: QueryEvent, ) -> Iterable[AuditEvent]: try: tables = self.get_tables_from_query( @@ -429,12 +438,12 @@ def generate_read_events_from_query( assert len(tables) != 0 for table in tables: yield AuditEvent.create( - ReadEvent.from_query_event(table, query_event_on_view) + ReadEvent.from_query_event(table, query_event_on_view), ) except Exception as ex: logger.debug( f"Generating read events failed for this query on view: {query_event_on_view.query}. " - f"Usage won't be added. The error was {ex}." + f"Usage won't be added. The error was {ex}.", ) self.report.num_view_query_events_failed_sql_parsing += 1 @@ -472,11 +481,15 @@ def _ingest_events( for new_event in self.generate_read_events_from_query(query_event): with self.report.processing_perf.store_usage_event_sec: num_generated += self._store_usage_event( - new_event, usage_state, table_refs + new_event, + usage_state, + table_refs, ) with self.report.processing_perf.store_usage_event_sec: num_aggregated += self._store_usage_event( - audit_event, usage_state, table_refs + audit_event, + usage_state, + table_refs, ) except Exception as e: @@ -493,13 +506,16 @@ def _ingest_events( usage_state.delete_original_read_events_for_view_query_events() def _generate_operational_workunits( - self, usage_state: BigQueryUsageState, table_refs: Collection[str] + self, + usage_state: BigQueryUsageState, + table_refs: Collection[str], ) -> Iterable[MetadataWorkUnit]: with self.report.new_stage(f"*: {USAGE_EXTRACTION_OPERATIONAL_STATS}"): for audit_event in usage_state.standalone_events(): try: operational_wu = self._create_operation_workunit( - audit_event, table_refs + audit_event, + table_refs, ) if operational_wu: yield operational_wu @@ -512,7 +528,8 @@ def _generate_operational_workunits( ) def _generate_usage_workunits( - self, usage_state: BigQueryUsageState + self, + usage_state: BigQueryUsageState, ) -> Iterable[MetadataWorkUnit]: with self.report.new_stage(f"*: {USAGE_EXTRACTION_USAGE_AGGREGATION}"): top_n = ( @@ -525,7 +542,8 @@ def _generate_usage_workunits( query_freq = [ ( self.uuid_to_query.get( - query_hash, usage_state.queries[query_hash] + query_hash, + usage_state.queries[query_hash], ), count, ) @@ -560,7 +578,7 @@ def _get_usage_events(self, projects: Iterable[str]) -> Iterable[AuditEvent]: with PerfTimer() as timer: try: with self.report.new_stage( - f"{project_id}: {USAGE_EXTRACTION_INGESTION}" + f"{project_id}: {USAGE_EXTRACTION_INGESTION}", ): yield from self._get_parsed_bigquery_log_events(project_id) except Exception as e: @@ -573,7 +591,7 @@ def _get_usage_events(self, projects: Iterable[str]) -> Iterable[AuditEvent]: self.report_status(f"usage-extraction-{project_id}", False) self.report.usage_extraction_sec[project_id] = timer.elapsed_seconds( - digits=2 + digits=2, ) def _store_usage_event( @@ -637,7 +655,8 @@ def _get_destination_table(event: AuditEvent) -> Optional[BigQueryTableRef]: return None def _extract_operational_meta( - self, event: AuditEvent + self, + event: AuditEvent, ) -> Optional[OperationalDataMeta]: # If we don't have Query object that means this is a queryless read operation or a read operation which was not executed as JOB # 
https://cloud.google.com/bigquery/docs/reference/auditlogs/rest/Shared.Types/BigQueryAuditMetadata.TableDataRead.Reason/ @@ -646,7 +665,7 @@ def _extract_operational_meta( statement_type=OperationTypeClass.CUSTOM, custom_type="CUSTOM_READ", last_updated_timestamp=int( - event.read_event.timestamp.timestamp() * 1000 + event.read_event.timestamp.timestamp() * 1000, ), actor_email=event.read_event.actor_email, ) @@ -674,7 +693,7 @@ def _extract_operational_meta( statement_type=statement_type, custom_type=custom_type, last_updated_timestamp=int( - event.query_event.timestamp.timestamp() * 1000 + event.query_event.timestamp.timestamp() * 1000, ), actor_email=event.query_event.actor_email, ) @@ -682,7 +701,9 @@ def _extract_operational_meta( return None def _create_operation_workunit( - self, event: AuditEvent, table_refs: Collection[str] + self, + event: AuditEvent, + table_refs: Collection[str], ) -> Optional[MetadataWorkUnit]: if not event.read_event and not event.query_event: return None @@ -696,7 +717,7 @@ def _create_operation_workunit( or str(destination_table) not in table_refs ): logger.debug( - f"Filtering out operation {event.query_event}: invalid destination {destination_table}." + f"Filtering out operation {event.query_event}: invalid destination {destination_table}.", ) self.report.num_usage_operations_dropped += 1 return None @@ -715,7 +736,7 @@ def _create_operation_workunit( if event.query_event and event.query_event.referencedTables: for table in event.query_event.referencedTables: affected_datasets.append( - self.identifiers.gen_dataset_urn_from_raw_ref(table) + self.identifiers.gen_dataset_urn_from_raw_ref(table), ) operation_aspect = OperationClass( @@ -740,7 +761,8 @@ def _create_operation_workunit( ).as_workunit() def _create_operational_custom_properties( - self, event: AuditEvent + self, + event: AuditEvent, ) -> Dict[str, str]: custom_properties: Dict[str, str] = {} # This only needs for backward compatibility reason. 
To make sure we generate the same operational metadata than before @@ -749,7 +771,7 @@ def _create_operational_custom_properties( if event.query_event.end_time and event.query_event.start_time: custom_properties["millisecondsTaken"] = str( int(event.query_event.end_time.timestamp() * 1000) - - int(event.query_event.start_time.timestamp() * 1000) + - int(event.query_event.start_time.timestamp() * 1000), ) if event.query_event.job_name: @@ -759,7 +781,7 @@ def _create_operational_custom_properties( if event.query_event.billed_bytes: custom_properties["bytesProcessed"] = str( - event.query_event.billed_bytes + event.query_event.billed_bytes, ) if event.query_event.default_dataset: @@ -772,13 +794,14 @@ def _create_operational_custom_properties( if event.read_event.fieldsRead: custom_properties["fieldsRead"] = ",".join( - event.read_event.fieldsRead + event.read_event.fieldsRead, ) return custom_properties def _parse_bigquery_log_entry( - self, entry: Union[AuditLogEntry, BigQueryAuditMetadata] + self, + entry: Union[AuditLogEntry, BigQueryAuditMetadata], ) -> Optional[AuditEvent]: event: Optional[Union[ReadEvent, QueryEvent]] = None missing_read_entry = ReadEvent.get_missing_key_entry(entry) @@ -801,30 +824,33 @@ def _parse_bigquery_log_entry( if event is None and missing_query_entry_v2 is None: event = QueryEvent.from_entry_v2( - entry, self.config.debug_include_full_payloads + entry, + self.config.debug_include_full_payloads, ) self.report.num_query_events += 1 if event is None: logger.warning( f"Unable to parse {type(entry)} missing read {missing_read_entry}, " - f"missing query {missing_query_entry} missing v2 {missing_query_entry_v2} for {entry}" + f"missing query {missing_query_entry} missing v2 {missing_query_entry_v2} for {entry}", ) return None return AuditEvent.create(event) def _parse_exported_bigquery_audit_metadata( - self, audit_metadata: BigQueryAuditMetadata + self, + audit_metadata: BigQueryAuditMetadata, ) -> Optional[AuditEvent]: event: Optional[Union[ReadEvent, QueryEvent]] = None missing_read_event = ReadEvent.get_missing_key_exported_bigquery_audit_metadata( - audit_metadata + audit_metadata, ) if missing_read_event is None: event = ReadEvent.from_exported_bigquery_audit_metadata( - audit_metadata, self.config.debug_include_full_payloads + audit_metadata, + self.config.debug_include_full_payloads, ) if not self._is_table_allowed(event.resource): self.report.num_filtered_read_events += 1 @@ -838,7 +864,8 @@ def _parse_exported_bigquery_audit_metadata( ) if event is None and missing_query_event is None: event = QueryEvent.from_exported_bigquery_audit_metadata( - audit_metadata, self.config.debug_include_full_payloads + audit_metadata, + self.config.debug_include_full_payloads, ) self.report.num_query_events += 1 @@ -846,14 +873,16 @@ def _parse_exported_bigquery_audit_metadata( logger.warning( f"{audit_metadata['logName']}-{audit_metadata['insertId']} " f"Unable to parse audit metadata missing QueryEvent keys:{str(missing_query_event)} " - f"ReadEvent keys: {str(missing_read_event)} for {audit_metadata}" + f"ReadEvent keys: {str(missing_read_event)} for {audit_metadata}", ) return None return AuditEvent.create(event) def _get_parsed_bigquery_log_events( - self, project_id: str, limit: Optional[int] = None + self, + project_id: str, + limit: Optional[int] = None, ) -> Iterable[AuditEvent]: audit_log_api = BigQueryAuditLogApi( self.report.audit_log_api_perf, @@ -887,7 +916,7 @@ def _get_parsed_bigquery_log_events( logging_client = 
self.config.make_gcp_logging_client(project_id) logger.info( f"Start loading log entries from BigQuery for {project_id} " - f"with start_time={corrected_start_time} and end_time={corrected_end_time}" + f"with start_time={corrected_start_time} and end_time={corrected_end_time}", ) entries = audit_log_api.get_bigquery_log_entries_via_gcp_logging( logging_client, @@ -918,7 +947,10 @@ def _generate_filter(self, corrected_start_time, corrected_end_time): ) def get_tables_from_query( - self, query: str, default_project: str, default_dataset: Optional[str] = None + self, + query: str, + default_project: str, + default_dataset: Optional[str] = None, ) -> List[BigQueryTableRef]: """ This method attempts to parse bigquery objects read in the query @@ -936,7 +968,7 @@ def get_tables_from_query( ) except Exception: logger.debug( - f"Sql parsing failed on this query on view: {query}. Usage won't be added." + f"Sql parsing failed on this query on view: {query}. Usage won't be added.", ) logger.debug(result.debug_info) return [] diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py index 062c64d45767fc..3ae25271f8e0a1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py @@ -126,7 +126,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -171,10 +173,11 @@ def get_workunits_internal( yield from self.profiler.get_workunits(self.cassandra_data) def _generate_keyspace_container( - self, keyspace: CassandraKeyspace + self, + keyspace: CassandraKeyspace, ) -> Iterable[MetadataWorkUnit]: keyspace_container_key = self._generate_keyspace_container_key( - keyspace.keyspace_name + keyspace.keyspace_name, ) yield from gen_containers( container_key=keyspace_container_key, @@ -197,7 +200,8 @@ def _generate_keyspace_container_key(self, keyspace_name: str) -> ContainerKey: # get all tables for a given keyspace, iterate over them to extract column metadata def _extract_tables_from_keyspace( - self, keyspace_name: str + self, + keyspace_name: str, ) -> Iterable[MetadataWorkUnit]: self.cassandra_data.keyspaces.append(keyspace_name) tables: List[CassandraTable] = self.cassandra_api.get_tables(keyspace_name) @@ -223,7 +227,9 @@ def _extract_tables_from_keyspace( # 1. Extract columns from table, then construct and emit the schemaMetadata aspect. 
try: yield from self._extract_columns_from_table( - keyspace_name, table_name, dataset_urn + keyspace_name, + table_name, + dataset_urn, ) except Exception as e: self.report.failure( @@ -242,7 +248,7 @@ def _extract_tables_from_keyspace( aspect=SubTypesClass( typeNames=[ DatasetSubTypes.TABLE, - ] + ], ), ).as_workunit() @@ -259,7 +265,7 @@ def _extract_tables_from_keyspace( "compression": json.dumps(table.compression), "crc_check_chance": str(table.crc_check_chance), "dclocal_read_repair_chance": str( - table.dclocal_read_repair_chance + table.dclocal_read_repair_chance, ), "default_time_to_live": str(table.default_time_to_live), "extensions": json.dumps(table.extensions), @@ -267,7 +273,7 @@ def _extract_tables_from_keyspace( "max_index_interval": str(table.max_index_interval), "min_index_interval": str(table.min_index_interval), "memtable_flush_period_in_ms": str( - table.memtable_flush_period_in_ms + table.memtable_flush_period_in_ms, ), "read_repair_chance": str(table.read_repair_chance), "speculative_retry": str(table.speculative_retry), @@ -286,24 +292,30 @@ def _extract_tables_from_keyspace( aspect=DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ), ).as_workunit() # get all columns for a given table, iterate over them to extract column metadata def _extract_columns_from_table( - self, keyspace_name: str, table_name: str, dataset_urn: str + self, + keyspace_name: str, + table_name: str, + dataset_urn: str, ) -> Iterable[MetadataWorkUnit]: column_infos: List[CassandraColumn] = self.cassandra_api.get_columns( - keyspace_name, table_name + keyspace_name, + table_name, ) schema_fields: List[SchemaField] = list( - CassandraToSchemaFieldConverter.get_schema_fields(column_infos) + CassandraToSchemaFieldConverter.get_schema_fields(column_infos), ) if not schema_fields: self.report.report_warning( - message="Table has no columns, skipping", context=table_name + message="Table has no columns, skipping", + context=table_name, ) return @@ -318,7 +330,7 @@ def _extract_columns_from_table( version=0, hash="", platformSchema=OtherSchemaClass( - rawSchema=json.dumps(jsonable_column_infos) + rawSchema=json.dumps(jsonable_column_infos), ), fields=schema_fields, ) @@ -329,7 +341,8 @@ def _extract_columns_from_table( ).as_workunit() def _extract_views_from_keyspace( - self, keyspace_name: str + self, + keyspace_name: str, ) -> Iterable[MetadataWorkUnit]: views: List[CassandraView] = self.cassandra_api.get_views(keyspace_name) for view in views: @@ -353,7 +366,7 @@ def _extract_views_from_keyspace( aspect=SubTypesClass( typeNames=[ DatasetSubTypes.VIEW, - ] + ], ), ).as_workunit() @@ -380,7 +393,7 @@ def _extract_views_from_keyspace( "crc_check_chance": str(view.crc_check_chance), "include_all_columns": str(view.include_all_columns), "dclocal_read_repair_chance": str( - view.dclocal_read_repair_chance + view.dclocal_read_repair_chance, ), "default_time_to_live": str(view.default_time_to_live), "extensions": json.dumps(view.extensions), @@ -388,7 +401,7 @@ def _extract_views_from_keyspace( "max_index_interval": str(view.max_index_interval), "min_index_interval": str(view.min_index_interval), "memtable_flush_period_in_ms": str( - view.memtable_flush_period_in_ms + view.memtable_flush_period_in_ms, ), "read_repair_chance": str(view.read_repair_chance), "speculative_retry": str(view.speculative_retry), @@ -398,7 +411,9 @@ def 
_extract_views_from_keyspace( try: yield from self._extract_columns_from_table( - keyspace_name, view_name, dataset_urn + keyspace_name, + view_name, + dataset_urn, ) except Exception as e: self.report.failure( @@ -416,7 +431,9 @@ def _extract_views_from_keyspace( platform_instance=self.config.platform_instance, ) fineGrainedLineages = self.get_upstream_fields_of_field_in_datasource( - view_name, dataset_urn, upstream_urn + view_name, + dataset_urn, + upstream_urn, ) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, @@ -425,7 +442,7 @@ def _extract_views_from_keyspace( UpstreamClass( dataset=upstream_urn, type=DatasetLineageTypeClass.VIEW, - ) + ), ], fineGrainedLineages=fineGrainedLineages, ), @@ -442,13 +459,17 @@ def _extract_views_from_keyspace( aspect=DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ), ).as_workunit() def get_upstream_fields_of_field_in_datasource( - self, table_name: str, dataset_urn: str, upstream_urn: str + self, + table_name: str, + dataset_urn: str, + upstream_urn: str, ) -> List[FineGrainedLineageClass]: column_infos = self.cassandra_data.columns.get(table_name, []) # Collect column-level lineage @@ -462,7 +483,7 @@ def get_upstream_fields_of_field_in_datasource( downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[make_schema_field_urn(dataset_urn, source_column)], upstreams=[make_schema_field_urn(upstream_urn, source_column)], - ) + ), ) return fine_grained_lineages diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py index 4cf0613762aab8..566e9bca7f71d4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py @@ -65,10 +65,10 @@ class CassandraView(CassandraTable): class CassandraEntities: keyspaces: List[str] = field(default_factory=list) tables: Dict[str, List[str]] = field( - default_factory=dict + default_factory=dict, ) # Maps keyspace -> tables columns: Dict[str, List[CassandraColumn]] = field( - default_factory=dict + default_factory=dict, ) # Maps tables -> columns @@ -123,7 +123,8 @@ def authenticate(self) -> bool: return True if self.config.username and self.config.password: auth_provider = PlainTextAuthProvider( - username=self.config.username, password=self.config.password + username=self.config.username, + password=self.config.password, ) cluster = Cluster( [self.config.contact_point], @@ -142,7 +143,9 @@ def authenticate(self) -> bool: return True except OperationTimedOut as e: self.report.failure( - message="Failed to Authenticate", context=f"{str(e.errors)}", exc=e + message="Failed to Authenticate", + context=f"{str(e.errors)}", + exc=e, ) return False except DriverException as e: @@ -174,7 +177,9 @@ def get_keyspaces(self) -> List[CassandraKeyspace]: return keyspace_list except DriverException as e: self.report.warning( - message="Failed to fetch keyspaces", context=f"{str(e)}", exc=e + message="Failed to fetch keyspaces", + context=f"{str(e)}", + exc=e, ) return [] except Exception as e: @@ -227,7 +232,8 @@ def get_columns(self, keyspace_name: str, table_name: str) -> List[CassandraColu """Fetch all columns for a given table.""" try: column_infos = self.get( - CassandraQueries.GET_COLUMNS_QUERY, [keyspace_name, table_name] + 
CassandraQueries.GET_COLUMNS_QUERY, + [keyspace_name, table_name], ) column_list = [ CassandraColumn( @@ -244,7 +250,9 @@ def get_columns(self, keyspace_name: str, table_name: str) -> List[CassandraColu return column_list except DriverException as e: self.report.warning( - message="Failed to fetch columns for table", context=f"{str(e)}", exc=e + message="Failed to fetch columns for table", + context=f"{str(e)}", + exc=e, ) return [] except Exception as e: @@ -287,7 +295,9 @@ def get_views(self, keyspace_name: str) -> List[CassandraView]: return view_list except DriverException as e: self.report.warning( - message="Failed to fetch views for keyspace", context=f"{str(e)}", exc=e + message="Failed to fetch views for keyspace", + context=f"{str(e)}", + exc=e, ) return [] except Exception as e: @@ -309,7 +319,9 @@ def execute(self, query: str, limit: Optional[int] = None) -> List: return result_set except DriverException as e: self.report.warning( - message="Failed to fetch stats for keyspace", context=str(e), exc=e + message="Failed to fetch stats for keyspace", + context=str(e), + exc=e, ) return [] except Exception: diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_config.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_config.py index 340bdb68aa4585..0680b44a8abff5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_config.py @@ -44,12 +44,15 @@ class CassandraCloudConfig(ConfigModel): ) request_timeout: int = Field( - default=600, description="Timeout in seconds for individual Cassandra requests." + default=600, + description="Timeout in seconds for individual Cassandra requests.", ) class CassandraSourceConfig( - PlatformInstanceConfigMixin, StatefulIngestionConfigBase, EnvConfigMixin + PlatformInstanceConfigMixin, + StatefulIngestionConfigBase, + EnvConfigMixin, ): """ Configuration for connecting to a Cassandra or DataStax Astra DB source. @@ -61,7 +64,8 @@ class CassandraSourceConfig( ) port: int = Field( - default=9042, description="Port number to connect to the Cassandra instance." 
+ default=9042, + description="Port number to connect to the Cassandra instance.", ) username: Optional[str] = Field( @@ -107,5 +111,5 @@ class CassandraSourceConfig( def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py index 7bf1d66f618a4b..e60adebb77e4c1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py @@ -66,13 +66,14 @@ def __init__( self.report = report def get_workunits( - self, cassandra_data: CassandraEntities + self, + cassandra_data: CassandraEntities, ) -> Iterable[MetadataWorkUnit]: for keyspace_name in cassandra_data.keyspaces: tables = cassandra_data.tables.get(keyspace_name, []) with self.report.new_stage(f"{keyspace_name}: {PROFILING}"): with ThreadPoolExecutor( - max_workers=self.config.profiling.max_workers + max_workers=self.config.profiling.max_workers, ) as executor: future_to_dataset = { executor.submit( @@ -120,7 +121,7 @@ def generate_profile( if not self.config.profile_pattern.allowed(f"{keyspace_name}.{table_name}"): self.report.profiling_skipped_table_profile_pattern[keyspace_name] += 1 logger.info( - f"Table {table_name} in {keyspace_name}, not allowed for profiling" + f"Table {table_name} in {keyspace_name}, not allowed for profiling", ) return @@ -139,7 +140,8 @@ def generate_profile( if profile_aspect: self.report.report_entity_profiled(table_name) mcp = MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=profile_aspect + entityUrn=dataset_urn, + aspect=profile_aspect, ) yield mcp.as_workunit() @@ -156,7 +158,9 @@ def populate_profile_aspect(self, profile_data: ProfileData) -> DatasetProfileCl ) def _create_field_profile( - self, field_name: str, field_stats: ColumnMetric + self, + field_name: str, + field_stats: ColumnMetric, ) -> DatasetFieldProfileClass: quantiles = field_stats.quantiles return DatasetFieldProfileClass( @@ -180,12 +184,15 @@ def _create_field_profile( ) def profile_table( - self, keyspace_name: str, table_name: str, columns: List[CassandraColumn] + self, + keyspace_name: str, + table_name: str, + columns: List[CassandraColumn], ) -> ProfileData: profile_data = ProfileData() resp = self.api.execute( - CassandraQueries.ROW_COUNT.format(keyspace_name, table_name) + CassandraQueries.ROW_COUNT.format(keyspace_name, table_name), ) if resp: profile_data.row_count = resp[0].row_count @@ -194,7 +201,7 @@ def profile_table( if not self.config.profiling.profile_table_level_only: resp = self.api.execute( - f'SELECT {", ".join([col.column_name for col in columns])} FROM {keyspace_name}."{table_name}"' + f'SELECT {", ".join([col.column_name for col in columns])} FROM {keyspace_name}."{table_name}"', ) profile_data.column_metrics = self._collect_column_data(resp, columns) @@ -215,7 +222,9 @@ def _parse_profile_results(self, profile_data: ProfileData) -> ProfileData: return profile_data def _collect_column_data( - self, rows: List[Any], columns: List[CassandraColumn] + self, + rows: List[Any], + columns: List[CassandraColumn], ) -> Dict[str, ColumnMetric]: metrics = {column.column_name: ColumnMetric() for column in columns} diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py 
b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py index b467ca0aca6be4..3aa6c7371cd24d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py @@ -30,7 +30,7 @@ # we always skip over ingesting metadata about these keyspaces SYSTEM_KEYSPACE_LIST = set( - ["system", "system_auth", "system_schema", "system_distributed", "system_traces"] + ["system", "system_auth", "system_schema", "system_distributed", "system_traces"], ) @@ -57,7 +57,7 @@ def report_entity_scanned(self, name: str, ent_type: str = "View") -> None: # TODO Need to create seperate common config for profiling report profiling_skipped_other: TopKDict[str, int] = field(default_factory=int_top_k_dict) profiling_skipped_table_profile_pattern: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) def report_entity_profiled(self, name: str) -> None: @@ -109,19 +109,20 @@ class CassandraToSchemaFieldConverter: def get_column_type(cassandra_column_type: str) -> SchemaFieldDataType: type_class: Optional[Type] = ( CassandraToSchemaFieldConverter._field_type_to_schema_field_type.get( - cassandra_column_type + cassandra_column_type, ) ) if type_class is None: logger.warning( - f"Cannot map {cassandra_column_type!r} to SchemaFieldDataType, using NullTypeClass." + f"Cannot map {cassandra_column_type!r} to SchemaFieldDataType, using NullTypeClass.", ) type_class = NullTypeClass return SchemaFieldDataType(type=type_class()) def _get_schema_fields( - self, cassandra_column_infos: List[CassandraColumn] + self, + cassandra_column_infos: List[CassandraColumn], ) -> Generator[SchemaField, None, None]: # append each schema field (sort so output is consistent) for column_info in cassandra_column_infos: @@ -129,7 +130,7 @@ def _get_schema_fields( cassandra_type: str = column_info.type schema_field_data_type: SchemaFieldDataType = self.get_column_type( - cassandra_type + cassandra_type, ) schema_field: SchemaField = SchemaField( fieldPath=column_name, @@ -143,7 +144,8 @@ def _get_schema_fields( @classmethod def get_schema_fields( - cls, cassandra_column_infos: List[CassandraColumn] + cls, + cassandra_column_infos: List[CassandraColumn], ) -> Generator[SchemaField, None, None]: converter = cls() yield from converter._get_schema_fields(cassandra_column_infos) diff --git a/metadata-ingestion/src/datahub/ingestion/source/common/data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/common/data_reader.py index 8a0a492ca5d333..6b8af3793adcc2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/common/data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/common/data_reader.py @@ -6,7 +6,10 @@ class DataReader(Closeable): def get_sample_data_for_column( - self, table_id: List[str], column_name: str, sample_size: int + self, + table_id: List[str], + column_name: str, + sample_size: int, ) -> list: raise NotImplementedError() diff --git a/metadata-ingestion/src/datahub/ingestion/source/confluent_schema_registry.py b/metadata-ingestion/src/datahub/ingestion/source/confluent_schema_registry.py index 5ba4dd13fb2ac9..05c51cf774e646 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/confluent_schema_registry.py +++ b/metadata-ingestion/src/datahub/ingestion/source/confluent_schema_registry.py @@ -46,7 +46,9 @@ class ConfluentSchemaRegistry(KafkaSchemaRegistryBase): """ def __init__( - self, source_config: KafkaSourceConfig, report: 
KafkaSourceReport + self, + source_config: KafkaSourceConfig, + report: KafkaSourceReport, ) -> None: self.source_config: KafkaSourceConfig = source_config self.report: KafkaSourceReport = report @@ -54,12 +56,12 @@ def __init__( { "url": source_config.connection.schema_registry_url, **source_config.connection.schema_registry_config, - } + }, ) self.known_schema_registry_subjects: List[str] = [] try: self.known_schema_registry_subjects.extend( - self.schema_registry_client.get_subjects() + self.schema_registry_client.get_subjects(), ) except Exception as e: logger.warning(f"Failed to get subjects from schema registry: {e}") @@ -74,7 +76,9 @@ def __init__( @classmethod def create( - cls, source_config: KafkaSourceConfig, report: KafkaSourceReport + cls, + source_config: KafkaSourceConfig, + report: KafkaSourceReport, ) -> "ConfluentSchemaRegistry": return cls(source_config, report) @@ -117,7 +121,9 @@ def _compact_schema(schema_str: str) -> str: return json.dumps(json.loads(schema_str), separators=(",", ":")) def get_schema_str_replace_confluent_ref_avro( - self, schema: Schema, schema_seen: Optional[set] = None + self, + schema: Schema, + schema_seen: Optional[set] = None, ) -> str: if not schema.references: return self._compact_schema(schema.schema_str) @@ -132,7 +138,7 @@ def get_schema_str_replace_confluent_ref_avro( if ref_subject not in self.known_schema_registry_subjects: logger.warning( - f"{ref_subject} is not present in the list of registered subjects with schema registry!" + f"{ref_subject} is not present in the list of registered subjects with schema registry!", ) reference_schema = self.schema_registry_client.get_latest_version( @@ -140,7 +146,7 @@ def get_schema_str_replace_confluent_ref_avro( ) schema_seen.add(ref_subject) logger.debug( - f"ref for {ref_subject} is {reference_schema.schema.schema_str}" + f"ref for {ref_subject} is {reference_schema.schema.schema_str}", ) # Replace only external type references with the reference schema recursively. # NOTE: The type pattern is dependent on _compact_schema. 
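# For orientation only: a minimal, self-contained sketch of the compact-and-substitute
# idea that the get_schema_str_replace_confluent_ref_avro hunks above are reformatting.
# The schema bodies and the referenced subject name below are hypothetical examples,
# not taken from the patch.
import json

def compact_schema(schema_str: str) -> str:
    # Strip whitespace so a reference like "type":"<subject>" can be located textually.
    return json.dumps(json.loads(schema_str), separators=(",", ":"))

outer = '{"type": "record", "name": "Outer", "fields": [{"name": "inner", "type": "com.acme.Inner"}]}'
inner = '{"type": "record", "name": "Inner", "fields": [{"name": "x", "type": "int"}]}'

schema_str = compact_schema(outer)
# Inline the referenced schema in place of the external type reference.
schema_str = schema_str.replace('"type":"com.acme.Inner"', f'"type":{compact_schema(inner)}')
print(schema_str)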
@@ -153,11 +159,11 @@ def get_schema_str_replace_confluent_ref_avro( pattern_to_replace = f'{avro_type_kwd}:"{ref_subject}"' if pattern_to_replace not in schema_str: logger.warning( - f"Not match for external schema type: {{name:{ref_name}, subject:{ref_subject}}} in schema:{schema_str}" + f"Not match for external schema type: {{name:{ref_name}, subject:{ref_subject}}} in schema:{schema_str}", ) else: logger.debug( - f"External schema matches by subject, {pattern_to_replace}" + f"External schema matches by subject, {pattern_to_replace}", ) else: logger.debug(f"External schema matches by name, {pattern_to_replace}") @@ -168,7 +174,9 @@ def get_schema_str_replace_confluent_ref_avro( return schema_str def get_schemas_from_confluent_ref_protobuf( - self, schema: Schema, schema_seen: Optional[Set[str]] = None + self, + schema: Schema, + schema_seen: Optional[Set[str]] = None, ) -> List[ProtobufSchema]: all_schemas: List[ProtobufSchema] = [] @@ -186,8 +194,9 @@ def get_schemas_from_confluent_ref_protobuf( schema_seen.add(ref_subject) all_schemas.append( ProtobufSchema( - name=schema_ref.name, content=reference_schema.schema.schema_str - ) + name=schema_ref.name, + content=reference_schema.schema.schema_str, + ), ) return all_schemas @@ -210,7 +219,8 @@ def get_schemas_from_confluent_ref_json( continue reference_schema: RegisteredSchema = ( self.schema_registry_client.get_version( - subject_name=ref_subject, version=schema_ref.version + subject_name=ref_subject, + version=schema_ref.version, ) ) schema_seen.add(ref_subject) @@ -220,7 +230,7 @@ def get_schemas_from_confluent_ref_json( name=schema_ref.name, subject=ref_subject, schema_seen=schema_seen, - ) + ), ) all_schemas.append( JsonSchemaWrapper( @@ -228,12 +238,15 @@ def get_schemas_from_confluent_ref_json( subject=subject, content=schema.schema_str, references=schema.references, - ) + ), ) return all_schemas def _get_schema_and_fields( - self, topic: str, is_key_schema: bool, is_subject: bool + self, + topic: str, + is_key_schema: bool, + is_subject: bool, ) -> Tuple[Optional[Schema], List[SchemaField]]: schema: Optional[Schema] = None kafka_entity = "subject" if is_subject else "topic" @@ -244,18 +257,19 @@ def _get_schema_and_fields( if not is_subject: schema_type_str = "key" if is_key_schema else "value" topic_subject = self._get_subject_for_topic( - topic=topic, is_key_schema=is_key_schema + topic=topic, + is_key_schema=is_key_schema, ) else: topic_subject = topic if topic_subject is not None: logger.debug( - f"The {schema_type_str} schema subject:'{topic_subject}' is found for {kafka_entity}: '{topic}'." + f"The {schema_type_str} schema subject:'{topic_subject}' is found for {kafka_entity}: '{topic}'.", ) try: registered_schema = self.schema_registry_client.get_latest_version( - subject_name=topic_subject + subject_name=topic_subject, ) schema = registered_schema.schema except Exception as e: @@ -269,7 +283,7 @@ def _get_schema_and_fields( ) else: logger.debug( - f"For {kafka_entity}: {topic}, the schema registry subject for the {schema_type_str} schema is not found." + f"For {kafka_entity}: {topic}, the schema registry subject for the {schema_type_str} schema is not found.", ) if not is_key_schema: # Value schema is always expected. Report a warning. 
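# The lookup performed in the hunk above, in isolation: resolve a topic to its
# schema-registry subject and fetch the latest registered schema. This sketch assumes
# the default TopicNameStrategy ("<topic>-key" / "<topic>-value"); the registry URL
# and topic name are hypothetical.
from confluent_kafka.schema_registry import SchemaRegistryClient

client = SchemaRegistryClient({"url": "http://localhost:8081"})
topic = "purchases"
subject = f"{topic}-value"  # use "-key" for the key schema
registered = client.get_latest_version(subject_name=subject)
print(registered.schema.schema_type, registered.schema.schema_str)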
@@ -291,7 +305,10 @@ def _get_schema_and_fields( return (schema, fields) def _load_json_schema_with_resolved_references( - self, schema: Schema, name: str, subject: str + self, + schema: Schema, + name: str, + subject: str, ) -> dict: imported_json_schemas: List[JsonSchemaWrapper] = ( self.get_schemas_from_confluent_ref_json(schema, name=name, subject=subject) @@ -305,12 +322,16 @@ def _load_json_schema_with_resolved_references( reference_map[imported_schema.name] = reference_schema jsonref_schema = jsonref.loads( - json.dumps(schema_dict), loader=lambda x: reference_map.get(x) + json.dumps(schema_dict), + loader=lambda x: reference_map.get(x), ) return jsonref_schema def _get_schema_fields( - self, topic: str, schema: Schema, is_key_schema: bool + self, + topic: str, + schema: Schema, + is_key_schema: bool, ) -> List[SchemaField]: # Parse the schema and convert it to SchemaFields. fields: List[SchemaField] = [] @@ -360,8 +381,9 @@ def _get_schema_fields( ) fields = list( JsonSchemaTranslator.get_fields_from_schema( - jsonref_schema, is_key_schema=is_key_schema - ) + jsonref_schema, + is_key_schema=is_key_schema, + ), ) elif not self.source_config.ignore_warnings_on_schema_type: self.report.report_warning( @@ -371,7 +393,10 @@ def _get_schema_fields( return fields def _get_schema_metadata( - self, topic: str, platform_urn: str, is_subject: bool + self, + topic: str, + platform_urn: str, + is_subject: bool, ) -> Optional[SchemaMetadata]: # Process the value schema schema, fields = self._get_schema_and_fields( @@ -411,7 +436,10 @@ def _get_schema_metadata( return None def get_schema_metadata( - self, topic: str, platform_urn: str, is_subject: bool + self, + topic: str, + platform_urn: str, + is_subject: bool, ) -> Optional[SchemaMetadata]: logger.debug(f"Inside get_schema_metadata {topic} {platform_urn}") diff --git a/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py b/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py index 8ebb7b9ef7fbdf..1e41787799d04b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py +++ b/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py @@ -318,7 +318,7 @@ def get_resource_description_work_unit( if not entityClass: raise ValueError( - f"Entity Type {entityType} cannot be operated on using csv-enricher" + f"Entity Type {entityType} cannot be operated on using csv-enricher", ) current_editable_properties: Optional[ Union[ @@ -475,7 +475,7 @@ def process_sub_resource_row( ] if len(term_associations_filtered) > 0: field_info.glossaryTerms.terms.extend( - term_associations_filtered + term_associations_filtered, ) needs_write = True else: @@ -506,7 +506,7 @@ def process_sub_resource_row( if not field_match: # this field isn't present in the editable schema metadata aspect, add it current_editable_schema_metadata.editableSchemaFieldInfo.append( - field_info_to_set + field_info_to_set, ) needs_write = True return current_editable_schema_metadata, needs_write @@ -541,7 +541,9 @@ def get_sub_resource_work_units(self) -> Iterable[MetadataWorkUnit]: current_editable_schema_metadata, needs_write, ) = self.process_sub_resource_row( - sub_resource_row, current_editable_schema_metadata, needs_write + sub_resource_row, + current_editable_schema_metadata, + needs_write, ) # Write an MCPW if needed. 
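# The "write an MCPW if needed" step referenced by the comment above, reduced to its
# core: wrap an aspect in a MetadataChangeProposalWrapper keyed by the entity urn and
# turn it into a work unit. The urn and description are hypothetical placeholders.
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import DatasetPropertiesClass

dataset_urn = "urn:li:dataset:(urn:li:dataPlatform:hive,db.table,PROD)"
mcp = MetadataChangeProposalWrapper(
    entityUrn=dataset_urn,
    aspect=DatasetPropertiesClass(description="Enriched from CSV"),
)
workunit = mcp.as_workunit()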
@@ -553,7 +555,8 @@ def get_sub_resource_work_units(self) -> Iterable[MetadataWorkUnit]: ).as_workunit() def maybe_extract_glossary_terms( - self, row: Dict[str, str] + self, + row: Dict[str, str], ) -> List[GlossaryTermAssociationClass]: if not row["glossary_terms"]: return [] @@ -583,7 +586,9 @@ def maybe_extract_tags(self, row: Dict[str, str]) -> List[TagAssociationClass]: return tag_associations def maybe_extract_owners( - self, row: Dict[str, str], is_resource_row: bool + self, + row: Dict[str, str], + is_resource_row: bool, ) -> List[OwnerClass]: if not is_resource_row: return [] @@ -635,12 +640,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: decoded_content = resp.content.decode("utf-8-sig") rows = list( csv.DictReader( - decoded_content.splitlines(), delimiter=self.config.delimiter - ) + decoded_content.splitlines(), + delimiter=self.config.delimiter, + ), ) except Exception as e: raise ConfigurationError( - f"Cannot read remote file {self.config.filename}, error:{e}" + f"Cannot read remote file {self.config.filename}, error:{e}", ) else: with open(pathlib.Path(self.config.filename), encoding="utf-8-sig") as f: @@ -694,7 +700,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: tag_associations=tag_associations, description=description, domain=domain, - ) + ), ) yield from self.get_sub_resource_work_units() diff --git a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/config.py b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/config.py index ede7d3c3c56959..3339de6c33221d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/config.py @@ -8,5 +8,5 @@ class PathSpecsConfigMixin(ConfigModel): path_specs: List[PathSpec] = Field( - description="List of PathSpec. See [below](#path-spec) the details about PathSpec" + description="List of PathSpec. 
See [below](#path-spec) the details about PathSpec", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/data_lake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/data_lake_utils.py index 35e40ec2d46cb4..8c850c1a3b8c63 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/data_lake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/data_lake_utils.py @@ -99,7 +99,7 @@ def get_protocol(path: str) -> str: return protocol else: raise ValueError( - f"Unable to get protocol or invalid protocol from path: {path}" + f"Unable to get protocol or invalid protocol from path: {path}", ) @staticmethod @@ -129,7 +129,9 @@ def get_base_full_path(self, path: str) -> str: raise ValueError(f"Unable to get base full path from path: {path}") def create_container_hierarchy( - self, path: str, dataset_urn: str + self, + path: str, + dataset_urn: str, ) -> Iterable[MetadataWorkUnit]: logger.debug(f"Creating containers for {dataset_urn}") base_full_path = path @@ -156,7 +158,7 @@ def create_container_hierarchy( # Dataset is in the root folder if not parent_folder_path and parent_key is None: logger.warning( - f"Failed to associate Dataset ({dataset_urn}) with container" + f"Failed to associate Dataset ({dataset_urn}) with container", ) return diff --git a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py index 10dd9e9e7e029a..20fd0e400c6d15 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py +++ b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py @@ -53,7 +53,7 @@ def __str__(self): class SortKey(ConfigModel): key: str = Field( - description="The key to sort on. This can be a compound key based on the path_spec variables." + description="The key to sort on. This can be a compound key based on the path_spec variables.", ) type: SortKeyType = Field( default=SortKeyType.STRING, @@ -87,7 +87,7 @@ class Config: arbitrary_types_allowed = True include: str = Field( - description="Path to table. Name variable `{table}` is used to mark the folder with dataset. In absence of `{table}`, file level dataset will be created. Check below examples for more details." + description="Path to table. Name variable `{table}` is used to mark the folder with dataset. In absence of `{table}`, file level dataset will be created. 
Check below examples for more details.", ) exclude: Optional[List[str]] = Field( default=[], @@ -166,14 +166,16 @@ def allowed(self, path: str, ignore_ext: bool = False) -> bool: return False if not pathlib.PurePath(path).globmatch( - self.glob_include, flags=pathlib.GLOBSTAR + self.glob_include, + flags=pathlib.GLOBSTAR, ): return False logger.debug(f"{path} matched include ") if self.exclude: for exclude_path in self.exclude: if pathlib.PurePath(path).globmatch( - exclude_path, flags=pathlib.GLOBSTAR + exclude_path, + flags=pathlib.GLOBSTAR, ): return False logger.debug(f"{path} is not excluded") @@ -215,7 +217,8 @@ def dir_allowed(self, path: str) -> bool: if self.exclude: for exclude_path in self.exclude: if pathlib.PurePath(path.rstrip("/")).globmatch( - exclude_path.rstrip("/"), flags=pathlib.GLOBSTAR + exclude_path.rstrip("/"), + flags=pathlib.GLOBSTAR, ): return False return True @@ -243,7 +246,8 @@ def get_named_vars(self, path: str) -> Union[None, parse.Result, parse.Match]: return self.compiled_include.parse(path) def get_folder_named_vars( - self, path: str + self, + path: str, ) -> Union[None, parse.Result, parse.Match]: return self.compiled_folder_include.parse(path) @@ -268,7 +272,7 @@ def validate_file_types(cls, v: Optional[List[str]]) -> List[str]: for file_type in v: if file_type not in SUPPORTED_FILE_TYPES: raise ValueError( - f"file type {file_type} not in supported file types. Please specify one from {SUPPORTED_FILE_TYPES}" + f"file type {file_type} not in supported file types. Please specify one from {SUPPORTED_FILE_TYPES}", ) return v @@ -276,7 +280,7 @@ def validate_file_types(cls, v: Optional[List[str]]) -> List[str]: def validate_default_extension(cls, v): if v is not None and v not in SUPPORTED_FILE_TYPES: raise ValueError( - f"default extension {v} not in supported default file extension. Please specify one from {SUPPORTED_FILE_TYPES}" + f"default extension {v} not in supported default file extension. Please specify one from {SUPPORTED_FILE_TYPES}", ) return v @@ -294,7 +298,7 @@ def turn_off_sampling_for_non_s3(cls, v, values): def no_named_fields_in_exclude(cls, v: str) -> str: if len(parse.compile(v).named_fields) != 0: raise ValueError( - f"path_spec.exclude {v} should not contain any named variables" + f"path_spec.exclude {v} should not contain any named variables", ) return v @@ -317,7 +321,7 @@ def table_name_in_include(cls, v, values): for x in parse.compile(v).named_fields ): raise ValueError( - f"Not all named variables used in path_spec.table_name {v} are specified in path_spec.include {values['include']}" + f"Not all named variables used in path_spec.table_name {v} are specified in path_spec.include {values['include']}", ) return v @@ -344,7 +348,8 @@ def compiled_include(self): @cached_property def compiled_folder_include(self): parsable_folder_include = PathSpec.get_parsable_include(self.include).rsplit( - "/", 1 + "/", + 1, )[0] logger.debug(f"parsable_folder_include: {parsable_folder_include}") compiled_folder_include = parse.compile(parsable_folder_include) @@ -401,13 +406,13 @@ def get_partition_from_path(self, path: str) -> Optional[List[Tuple[str, str]]]: if "partition_value" in named_vars.named else named_vars.named["partition"][key] ), - ) + ), ) return partition_keys else: # TODO: Fix this message logger.debug( - "Partition key or value not found. Fallbacking another mechanism to get partition keys" + "Partition key or value not found. 
Fallbacking another mechanism to get partition keys", ) partition_vars = self.extract_variable_names @@ -425,11 +430,11 @@ def get_partition_from_path(self, path: str) -> Optional[List[Tuple[str, str]]]: if pkey in named_vars.named: if index and index in named_vars.named[pkey]: partition_keys.append( - (f"{pkey}_{index}", named_vars.named[pkey][index]) + (f"{pkey}_{index}", named_vars.named[pkey][index]), ) else: partition_keys.append( - (partition_key, named_vars.named[partition_key]) + (partition_key, named_vars.named[partition_key]), ) return partition_keys @@ -475,7 +480,7 @@ def validate_path_spec(cls, values: Dict) -> Dict[str, Any]: for f in required_fields: if f not in values: logger.debug( - f"Failed to validate because {f} wasn't populated correctly" + f"Failed to validate because {f} wasn't populated correctly", ) return values @@ -501,7 +506,7 @@ def validate_path_spec(cls, values: Dict) -> Dict[str, Any]: ): raise ValueError( f"file type specified ({include_ext}) in path_spec.include is not in specified file " - f'types. Please select one from {values.get("file_types")} or specify ".*" to allow all types' + f'types. Please select one from {values.get("file_types")} or specify ".*" to allow all types', ) return values @@ -513,7 +518,9 @@ def _extract_table_name(self, named_vars: dict) -> str: # TODO: Add support to sort partition folders by the defined partition key pattern. This is not implemented yet. def extract_datetime_partition( - self, path: str, is_folder: bool = False + self, + path: str, + is_folder: bool = False, ) -> Optional[datetime.datetime]: if self.sort_key is None: return None @@ -542,12 +549,13 @@ def extract_datetime_partition( for key in var: template_key = var_key + f"[{key}]" partition_format = partition_format.replace( - f"{{{template_key}}}", var[key] + f"{{{template_key}}}", + var[key], ) else: partition_format.replace(f"{{{var_key}}}", var) return datetime.datetime.strptime(partition_format, datetime_format).replace( - tzinfo=datetime.timezone.utc + tzinfo=datetime.timezone.utc, ) def extract_table_name_and_path(self, path: str) -> Tuple[str, str]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py index 37ccce8f8b9657..afdfbf2dd41b70 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py @@ -127,13 +127,14 @@ def check_ingesting_data(cls, values): ): raise ValueError( "Your current config will not ingest any data." - " Please specify at least one of `database_connection` or `kafka_connection`, ideally both." 
+ " Please specify at least one of `database_connection` or `kafka_connection`, ideally both.", ) return values @pydantic.validator("database_connection") def validate_mysql_scheme( - cls, v: SQLAlchemyConnectionConfig + cls, + v: SQLAlchemyConnectionConfig, ) -> SQLAlchemyConnectionConfig: if "mysql" in v.scheme: if v.scheme != "mysql+pymysql": diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py index 382a0d548e38db..aab660f51c488d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py @@ -39,7 +39,7 @@ def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]: urns = self.get_urns() tasks: List[futures.Future[Iterable[MetadataChangeProposalWrapper]]] = [] with futures.ThreadPoolExecutor( - max_workers=self.config.max_workers + max_workers=self.config.max_workers, ) as executor: for urn in urns: tasks.append(executor.submit(self._get_aspects_for_urn, urn)) diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py index 51a25829d21dba..b0b3316aa1c5d4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py @@ -148,7 +148,9 @@ def query(self) -> str: """ def execute_server_cursor( - self, query: str, params: Dict[str, Any] + self, + query: str, + params: Dict[str, Any], ) -> Iterable[Dict[str, Any]]: with self.engine.connect() as conn: if self.engine.dialect.name in ["postgresql", "mysql", "mariadb"]: @@ -168,7 +170,9 @@ def execute_server_cursor( raise ValueError(f"Unsupported dialect: {self.engine.dialect.name}") def _get_rows( - self, from_createdon: datetime, stop_time: datetime + self, + from_createdon: datetime, + stop_time: datetime, ) -> Iterable[Dict[str, Any]]: params = { "exclude_aspects": list(self.config.exclude_aspects), @@ -177,10 +181,12 @@ def _get_rows( yield from self.execute_server_cursor(self.query, params) def get_aspects( - self, from_createdon: datetime, stop_time: datetime + self, + from_createdon: datetime, + stop_time: datetime, ) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]: orderer = VersionOrderer[Dict[str, Any]]( - enabled=self.config.include_all_versions + enabled=self.config.include_all_versions, ) rows = self._get_rows(from_createdon=from_createdon, stop_time=stop_time) for row in orderer(rows): @@ -208,12 +214,13 @@ def get_soft_deleted_rows(self) -> Iterable[Dict[str, Any]]: yield dict(zip(columns, row)) def _parse_row( - self, row: Dict[str, Any] + self, + row: Dict[str, Any], ) -> Optional[MetadataChangeProposalWrapper]: try: json_aspect = post_json_transform(json.loads(row["metadata"])) json_metadata = post_json_transform( - json.loads(row["systemmetadata"] or "{}") + json.loads(row["systemmetadata"] or "{}"), ) system_metadata = SystemMetadataClass.from_obj(json_metadata) return MetadataChangeProposalWrapper( @@ -224,10 +231,12 @@ def _parse_row( ) except Exception as e: logger.warning( - f"Failed to parse metadata for {row['urn']}: {e}", exc_info=True + f"Failed to parse metadata for {row['urn']}: {e}", + exc_info=True, ) self.report.num_database_parse_errors += 1 self.report.database_parse_errors.setdefault( - str(e), LossyDict() + str(e), + 
LossyDict(), ).setdefault(row["aspect"], LossyList()).append(row["urn"]) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py index ba073533eccfb5..6c53511c9d28c3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py @@ -49,16 +49,18 @@ def __enter__(self) -> "DataHubKafkaReader": "enable.auto.commit": False, "value.deserializer": AvroDeserializer( schema_registry_client=SchemaRegistryClient( - {"url": self.connection_config.schema_registry_url} + {"url": self.connection_config.schema_registry_url}, ), return_record_name=True, ), - } + }, ) return self def get_mcls( - self, from_offsets: Dict[int, int], stop_time: datetime + self, + from_offsets: Dict[int, int], + stop_time: datetime, ) -> Iterable[Tuple[MetadataChangeLogClass, PartitionOffset]]: # Based on https://github.com/confluentinc/confluent-kafka-python/issues/145#issuecomment-284843254 def on_assign(consumer: Consumer, partitions: List[TopicPartition]) -> None: @@ -74,7 +76,8 @@ def on_assign(consumer: Consumer, partitions: List[TopicPartition]) -> None: self.consumer.unsubscribe() def _poll_partition( - self, stop_time: datetime + self, + stop_time: datetime, ) -> Iterable[Tuple[MetadataChangeLogClass, PartitionOffset]]: while True: msg = self.consumer.poll(10) @@ -93,7 +96,7 @@ def _poll_partition( if mcl.created and mcl.created.time > stop_time.timestamp() * 1000: logger.info( f"Stopped reading from kafka, reached MCL " - f"with audit stamp {parse_ts_millis(mcl.created.time)}" + f"with audit stamp {parse_ts_millis(mcl.created.time)}", ) break diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py index 472abd0a97ec70..4d7f036a0e14b6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py @@ -80,16 +80,19 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.config.database_connection is not None: database_reader = DataHubDatabaseReader( - self.config, self.config.database_connection, self.report + self.config, + self.config.database_connection, + self.report, ) yield from self._get_database_workunits( - from_createdon=state.database_createdon_datetime, reader=database_reader + from_createdon=state.database_createdon_datetime, + reader=database_reader, ) self._commit_progress() else: logger.info( - "Skipping ingestion of versioned aspects as no database_connection provided" + "Skipping ingestion of versioned aspects as no database_connection provided", ) if self.config.kafka_connection is not None: @@ -97,23 +100,26 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if not self.config.include_soft_deleted_entities: if database_reader is None: raise ValueError( - "Cannot exclude soft deleted entities without a database connection" + "Cannot exclude soft deleted entities without a database connection", ) soft_deleted_urns = [ row["urn"] for row in database_reader.get_soft_deleted_rows() ] yield from self._get_kafka_workunits( - from_offsets=state.kafka_offsets, soft_deleted_urns=soft_deleted_urns + from_offsets=state.kafka_offsets, + soft_deleted_urns=soft_deleted_urns, ) self._commit_progress() else: logger.info( - "Skipping 
ingestion of timeseries aspects as no kafka_connection provided" + "Skipping ingestion of timeseries aspects as no kafka_connection provided", ) def _get_database_workunits( - self, from_createdon: datetime, reader: DataHubDatabaseReader + self, + from_createdon: datetime, + reader: DataHubDatabaseReader, ) -> Iterable[MetadataWorkUnit]: logger.info(f"Fetching database aspects starting from {from_createdon}") progress = ProgressTimer(report_every=timedelta(seconds=60)) @@ -124,7 +130,7 @@ def _get_database_workunits( if progress.should_report(): logger.info( - f"Ingested {i} database aspects so far, currently at {createdon}" + f"Ingested {i} database aspects so far, currently at {createdon}", ) yield mcp.as_workunit() @@ -135,12 +141,14 @@ def _get_database_workunits( or not self.report.num_database_parse_errors ): self.stateful_ingestion_handler.update_checkpoint( - last_createdon=createdon + last_createdon=createdon, ) self._commit_progress(i) def _get_kafka_workunits( - self, from_offsets: Dict[int, int], soft_deleted_urns: List[str] + self, + from_offsets: Dict[int, int], + soft_deleted_urns: List[str], ) -> Iterable[MetadataWorkUnit]: if self.config.kafka_connection is None: return @@ -153,20 +161,21 @@ def _get_kafka_workunits( self.ctx, ) as reader: mcls = reader.get_mcls( - from_offsets=from_offsets, stop_time=self.report.stop_time + from_offsets=from_offsets, + stop_time=self.report.stop_time, ) for i, (mcl, offset) in enumerate(mcls): mcp = MetadataChangeProposalWrapper.try_from_mcl(mcl) if mcp.entityUrn in soft_deleted_urns: self.report.num_timeseries_soft_deleted_aspects_dropped += 1 logger.debug( - f"Dropping soft-deleted aspect of {mcp.aspectName} on {mcp.entityUrn}" + f"Dropping soft-deleted aspect of {mcp.aspectName} on {mcp.entityUrn}", ) continue if mcp.changeType == ChangeTypeClass.DELETE: self.report.num_timeseries_deletions_dropped += 1 logger.debug( - f"Dropping timeseries deletion of {mcp.aspectName} on {mcp.entityUrn}" + f"Dropping timeseries deletion of {mcp.aspectName} on {mcp.entityUrn}", ) continue @@ -177,7 +186,8 @@ def _get_kafka_workunits( yield mcp.as_workunit() else: yield MetadataWorkUnit( - id=f"{mcp.entityUrn}-{mcp.aspectName}-{i}", mcp_raw=mcp + id=f"{mcp.entityUrn}-{mcp.aspectName}-{i}", + mcp_raw=mcp, ) self.report.num_kafka_aspects_ingested += 1 @@ -186,7 +196,7 @@ def _get_kafka_workunits( or not self.report.num_kafka_parse_errors ): self.stateful_ingestion_handler.update_checkpoint( - last_offset=offset + last_offset=offset, ) self._commit_progress(i) diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py index 721fc879894423..cfba4ff838ac09 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py @@ -15,7 +15,7 @@ class DataHubSourceReport(StatefulIngestionReport): num_database_parse_errors: int = 0 # error -> aspect -> [urn] database_parse_errors: LossyDict[str, LossyDict[str, LossyList[str]]] = field( - default_factory=LossyDict + default_factory=LossyDict, ) num_kafka_aspects_ingested: int = 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/state.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/state.py index 4bedd331a9aea2..c32190ddd8390f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/state.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/state.py @@ -24,7 +24,8 @@ class 
DataHubIngestionState(CheckpointStateBase): @property def database_createdon_datetime(self) -> datetime: return datetime.fromtimestamp( - self.database_createdon_ts / 1000, tz=timezone.utc + self.database_createdon_ts / 1000, + tz=timezone.utc, ) @@ -34,7 +35,7 @@ class PartitionOffset(NamedTuple): class StatefulDataHubIngestionHandler( - StatefulIngestionUsecaseHandlerBase[DataHubIngestionState] + StatefulIngestionUsecaseHandlerBase[DataHubIngestionState], ): def __init__(self, source: "DataHubSource"): self.state_provider = source.state_provider @@ -50,7 +51,8 @@ def is_checkpointing_enabled(self) -> bool: def get_last_run_state(self) -> DataHubIngestionState: if self.is_checkpointing_enabled() and not self.config.ignore_old_state: last_checkpoint = self.state_provider.get_last_checkpoint( - self.job_id, DataHubIngestionState + self.job_id, + DataHubIngestionState, ) if last_checkpoint and last_checkpoint.state: return last_checkpoint.state @@ -63,7 +65,7 @@ def create_checkpoint(self) -> Optional[Checkpoint[DataHubIngestionState]]: if self.pipeline_name is None: raise ValueError( - "Pipeline name must be set to use stateful datahub ingestion" + "Pipeline name must be set to use stateful datahub ingestion", ) return Checkpoint( diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py index 41b59a9c8b892c..d090ec19776a94 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py @@ -73,7 +73,7 @@ def set_metadata_endpoint(cls, values: dict) -> dict: metadata_endpoint = infer_metadata_endpoint(values["access_url"]) if metadata_endpoint is None: raise ValueError( - "Unable to infer the metadata endpoint from the access URL. Please provide a metadata endpoint." + "Unable to infer the metadata endpoint from the access URL. Please provide a metadata endpoint.", ) logger.info(f"Inferred metadata endpoint: {metadata_endpoint}") values["metadata_endpoint"] = metadata_endpoint @@ -290,13 +290,17 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @staticmethod def _send_graphql_query( - metadata_endpoint: str, token: str, query: str, variables: Dict + metadata_endpoint: str, + token: str, + query: str, + variables: Dict, ) -> Dict: logger.debug(f"Sending GraphQL query to dbt Cloud: {query}") response = requests.post( @@ -315,7 +319,7 @@ def _send_graphql_query( res = response.json() if "errors" in res: raise ValueError( - f"Unable to fetch metadata from dbt Cloud: {res['errors']}" + f"Unable to fetch metadata from dbt Cloud: {res['errors']}", ) data = res["data"] except JSONDecodeError as e: @@ -406,7 +410,8 @@ def _parse_into_dbt_node(self, node: Dict) -> DBTNode: status = node["status"] if status is None and materialization != "ephemeral": self.report.report_warning( - key, "node is missing a status, schema metadata will be incomplete" + key, + "node is missing a status, schema metadata will be incomplete", ) # The code fields are new in dbt 1.3, and replace the sql ones. 
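# For reference, the round trip that _send_graphql_query (reformatted above) performs,
# in a standalone form. The endpoint, token, and Authorization header format are
# illustrative assumptions; the "errors" handling mirrors the hunk.
import requests

def send_graphql_query(endpoint: str, token: str, query: str, variables: dict) -> dict:
    response = requests.post(
        endpoint,
        json={"query": query, "variables": variables},
        headers={"Authorization": f"Bearer {token}"},
    )
    response.raise_for_status()
    res = response.json()
    if "errors" in res:
        raise ValueError(f"Unable to fetch metadata from dbt Cloud: {res['errors']}")
    return res["data"]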
@@ -433,7 +438,7 @@ def _parse_into_dbt_node(self, node: Dict) -> DBTNode: sorted( [self._parse_into_dbt_column(column) for column in node["columns"]], key=lambda c: c.index, - ) + ), ) test_info = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py index fa85308b325979..56fa0fe900dcbf 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py @@ -133,7 +133,7 @@ class DBTSourceReport(StaleEntityRemovalSourceReport): sql_parser_parse_failures_list: LossyList[str] = field(default_factory=LossyList) sql_parser_detach_ctes_failures_list: LossyList[str] = field( - default_factory=LossyList + default_factory=LossyList, ) nodes_filtered: LossyList[str] = field(default_factory=LossyList) @@ -188,7 +188,7 @@ def process_only_directive(cls, values): only_values = [k for k in values if values.get(k) == EmitDirective.ONLY] if len(only_values) > 1: raise ValueError( - f"Cannot have more than 1 type of entity emission set to ONLY. Found {only_values}" + f"Cannot have more than 1 type of entity emission set to ONLY. Found {only_values}", ) if len(only_values) == 1: @@ -255,7 +255,8 @@ class DBTCommonConfig( description="Use model identifier instead of model name if defined (if not, default to model name).", ) _deprecate_use_identifiers = pydantic_field_deprecated( - "use_identifiers", warn_if_value_is_not=False + "use_identifiers", + warn_if_value_is_not=False, ) entities_enabled: DBTEntitiesEnabled = Field( @@ -275,7 +276,8 @@ class DBTCommonConfig( "This is mainly useful when you have multiple, interdependent dbt projects. ", ) tag_prefix: str = Field( - default=f"{DBT_PLATFORM}:", description="Prefix added to tags during ingestion." + default=f"{DBT_PLATFORM}:", + description="Prefix added to tags during ingestion.", ) node_name_pattern: AllowDenyPattern = Field( default=AllowDenyPattern.allow_all(), @@ -325,7 +327,8 @@ class DBTCommonConfig( "that are only distinguished by env, then you should set this flag to True.", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( - default=None, description="DBT Stateful Ingestion Config." + default=None, + description="DBT Stateful Ingestion Config.", ) convert_column_urns_to_lowercase: bool = Field( default=False, @@ -363,13 +366,14 @@ def validate_target_platform_value(cls, target_platform: str) -> str: if target_platform.lower() == DBT_PLATFORM: raise ValueError( "target_platform cannot be dbt. It should be the platform which dbt is operating on top of. For e.g " - "postgres." + "postgres.", ) return target_platform @root_validator(pre=True) def set_convert_column_urns_to_lowercase_default_for_snowflake( - cls, values: dict + cls, + values: dict, ) -> dict: if values.get("target_platform", "").lower() == "snowflake": values.setdefault("convert_column_urns_to_lowercase", True) @@ -381,22 +385,25 @@ def validate_write_semantics(cls, write_semantics: str) -> str: raise ValueError( "write_semantics cannot be any other value than PATCH or OVERRIDE. Default value is PATCH. 
" "For PATCH semantics consider using the datahub-rest sink or " - "provide a datahub_api: configuration on your ingestion recipe" + "provide a datahub_api: configuration on your ingestion recipe", ) return write_semantics @validator("meta_mapping") def meta_mapping_validator( - cls, meta_mapping: Dict[str, Any], values: Dict, **kwargs: Any + cls, + meta_mapping: Dict[str, Any], + values: Dict, + **kwargs: Any, ) -> Dict[str, Any]: for k, v in meta_mapping.items(): if "match" not in v: raise ValueError( - f"meta_mapping section {k} doesn't have a match clause." + f"meta_mapping section {k} doesn't have a match clause.", ) if "operation" not in v: raise ValueError( - f"meta_mapping section {k} doesn't have an operation clause." + f"meta_mapping section {k} doesn't have an operation clause.", ) if v["operation"] == "add_owner": owner_category = v["config"].get("owner_category") @@ -406,27 +413,31 @@ def meta_mapping_validator( @validator("include_column_lineage") def validate_include_column_lineage( - cls, include_column_lineage: bool, values: Dict + cls, + include_column_lineage: bool, + values: Dict, ) -> bool: if include_column_lineage and not values.get("infer_dbt_schemas"): raise ValueError( - "`infer_dbt_schemas` must be enabled to use `include_column_lineage`" + "`infer_dbt_schemas` must be enabled to use `include_column_lineage`", ) return include_column_lineage @validator("skip_sources_in_lineage", always=True) def validate_skip_sources_in_lineage( - cls, skip_sources_in_lineage: bool, values: Dict + cls, + skip_sources_in_lineage: bool, + values: Dict, ) -> bool: entities_enabled: Optional[DBTEntitiesEnabled] = values.get("entities_enabled") prefer_sql_parser_lineage: Optional[bool] = values.get( - "prefer_sql_parser_lineage" + "prefer_sql_parser_lineage", ) if prefer_sql_parser_lineage and not skip_sources_in_lineage: raise ValueError( - "`prefer_sql_parser_lineage` requires that `skip_sources_in_lineage` is enabled." + "`prefer_sql_parser_lineage` requires that `skip_sources_in_lineage` is enabled.", ) if ( @@ -438,7 +449,7 @@ def validate_skip_sources_in_lineage( and not prefer_sql_parser_lineage ): raise ValueError( - "When `skip_sources_in_lineage` is enabled, `entities_enabled.sources` must be set to NO." + "When `skip_sources_in_lineage` is enabled, `entities_enabled.sources` must be set to NO.", ) return skip_sources_in_lineage @@ -568,7 +579,7 @@ def get_fake_ephemeral_table_name(self) -> str: # Similar to get_db_fqn. db_fqn = self._join_parts( - [self.database, self.schema, f"__datahub__dbt__ephemeral__{self.name}"] + [self.database, self.schema, f"__datahub__dbt__ephemeral__{self.name}"], ) db_fqn = db_fqn.lower() return db_fqn.replace('"', "") @@ -699,7 +710,7 @@ def get_upstreams( for upstream in sorted(upstreams): if upstream not in all_nodes: logger.debug( - f"Upstream node - {upstream} not found in all manifest entities." + f"Upstream node - {upstream} not found in all manifest entities.", ) continue @@ -713,7 +724,7 @@ def get_upstreams( target_platform_instance=target_platform_instance, env=environment, skip_sources_in_lineage=skip_sources_in_lineage, - ) + ), ) return upstream_urns @@ -729,7 +740,7 @@ def get_upstreams_for_test( for upstream in test_node.upstream_nodes: if upstream not in all_nodes_map: logger.debug( - f"Upstream node of test {upstream} not found in all manifest entities." 
+ f"Upstream node of test {upstream} not found in all manifest entities.", ) continue @@ -765,13 +776,13 @@ def make_mapping_upstream_lineage( FineGrainedLineage( upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=[ - mce_builder.make_schema_field_urn(upstream_urn, field_name) + mce_builder.make_schema_field_urn(upstream_urn, field_name), ], downstreamType=FineGrainedLineageDownstreamType.FIELD, downstreams=[ - mce_builder.make_schema_field_urn(downstream_urn, field_name) + mce_builder.make_schema_field_urn(downstream_urn, field_name), ], - ) + ), ) return UpstreamLineageClass( @@ -780,9 +791,10 @@ def make_mapping_upstream_lineage( dataset=upstream_urn, type=DatasetLineageTypeClass.COPY, auditStamp=AuditStamp( - time=mce_builder.get_sys_time(), actor=_DEFAULT_ACTOR + time=mce_builder.get_sys_time(), + actor=_DEFAULT_ACTOR, ), - ) + ), ], fineGrainedLineages=cll or None, ) @@ -832,11 +844,13 @@ def __init__(self, config: DBTCommonConfig, ctx: PipelineContext, platform: str) self.compiled_owner_extraction_pattern: Optional[Any] = None if self.config.owner_extraction_pattern: self.compiled_owner_extraction_pattern = re.compile( - self.config.owner_extraction_pattern + self.config.owner_extraction_pattern, ) # Create and register the stateful ingestion use-case handler. self.stale_entity_removal_handler = StaleEntityRemovalHandler.create( - self, self.config, ctx + self, + self.config, + ctx, ) def create_test_entity_mcps( @@ -883,8 +897,8 @@ def create_test_entity_mcps( **guid_upstream_part, }.items() if v is not None - } - ) + }, + ), ) custom_props = { @@ -917,7 +931,7 @@ def create_test_entity_mcps( ) else: logger.debug( - f"Skipping test result {node.name} ({test_result.invocation_id}) emission since it is turned off." + f"Skipping test result {node.name} ({test_result.invocation_id}) emission since it is turned off.", ) @abstractmethod @@ -1005,7 +1019,8 @@ def _to_schema_info(schema_fields: List[SchemaField]) -> SchemaInfo: return {column.fieldPath: column.nativeDataType for column in schema_fields} def _determine_cll_required_nodes( - self, all_nodes_map: Dict[str, DBTNode] + self, + all_nodes_map: Dict[str, DBTNode], ) -> Tuple[Set[str], Set[str]]: # Based on the filter patterns, we only need to do schema inference and CLL # for a subset of nodes. @@ -1035,7 +1050,8 @@ def add_node_to_cll_list(dbt_name: str) -> None: return schema_nodes, cll_nodes def _infer_schemas_and_update_cll( # noqa: C901 - self, all_nodes_map: Dict[str, DBTNode] + self, + all_nodes_map: Dict[str, DBTNode], ) -> None: """Annotate the DBTNode objects with schema information and column-level lineage. 
@@ -1056,7 +1072,7 @@ def _infer_schemas_and_update_cll( # noqa: C901 if not self.config.infer_dbt_schemas: if self.config.include_column_lineage: raise ConfigurationError( - "`infer_dbt_schemas` must be enabled to use `include_column_lineage`" + "`infer_dbt_schemas` must be enabled to use `include_column_lineage`", ) return @@ -1082,13 +1098,13 @@ def _infer_schemas_and_update_cll( # noqa: C901 ), ) schema_required_nodes, cll_required_nodes = self._determine_cll_required_nodes( - all_nodes_map + all_nodes_map, ) for dbt_name in all_node_order: if dbt_name not in schema_required_nodes: logger.debug( - f"Skipping {dbt_name} because it is filtered out by patterns" + f"Skipping {dbt_name} because it is filtered out by patterns", ) continue @@ -1155,7 +1171,8 @@ def _infer_schemas_and_update_cll( # noqa: C901 added_to_schema_resolver = False if target_node_urn and schema_fields: schema_resolver.add_raw_schema_info( - target_node_urn, self._to_schema_info(schema_fields) + target_node_urn, + self._to_schema_info(schema_fields), ) added_to_schema_resolver = True @@ -1168,7 +1185,7 @@ def _infer_schemas_and_update_cll( # noqa: C901 pass elif node.dbt_name not in cll_required_nodes: logger.debug( - f"Not generating CLL for {node.dbt_name} because we don't need it." + f"Not generating CLL for {node.dbt_name} because we don't need it.", ) elif node.compiled_code: # Add CTE stops based on the upstreams list. @@ -1181,7 +1198,8 @@ def _infer_schemas_and_update_cll( # noqa: C901 ] if upstream_node.is_ephemeral_model() for cte_name in _get_dbt_cte_names( - upstream_node.name, schema_resolver.platform + upstream_node.name, + schema_resolver.platform, ) } if cte_mapping: @@ -1230,7 +1248,8 @@ def _infer_schemas_and_update_cll( # noqa: C901 and inferred_schema_fields ): schema_resolver.add_raw_schema_info( - target_node_urn, self._to_schema_info(inferred_schema_fields) + target_node_urn, + self._to_schema_info(inferred_schema_fields), ) # Save the inferred schema fields into the dbt node. @@ -1252,7 +1271,7 @@ def _parse_cll( ) except Exception as e: logger.debug( - f"Failed to parse compiled code. {node.dbt_name} will not have column lineage." + f"Failed to parse compiled code. {node.dbt_name} will not have column lineage.", ) self.report.sql_parser_parse_failures += 1 self.report.sql_parser_parse_failures_list.append(node.dbt_name) @@ -1268,7 +1287,7 @@ def _parse_cll( self.report.sql_parser_detach_ctes_failures += 1 self.report.sql_parser_detach_ctes_failures_list.append(node.dbt_name) logger.debug( - f"Failed to detach CTEs from compiled code. {node.dbt_name} will not have column lineage." + f"Failed to detach CTEs from compiled code. 
{node.dbt_name} will not have column lineage.", ) return SqlParsingResult.make_from_error(e) @@ -1276,12 +1295,12 @@ def _parse_cll( if sql_result.debug_info.table_error: self.report.sql_parser_table_errors += 1 logger.info( - f"Failed to generate any CLL lineage for {node.dbt_name}: {sql_result.debug_info.error}" + f"Failed to generate any CLL lineage for {node.dbt_name}: {sql_result.debug_info.error}", ) elif sql_result.debug_info.column_error: self.report.sql_parser_column_errors += 1 logger.info( - f"Failed to generate CLL for {node.dbt_name}: {sql_result.debug_info.column_error}" + f"Failed to generate CLL for {node.dbt_name}: {sql_result.debug_info.column_error}", ) else: self.report.sql_parser_successes += 1 @@ -1321,16 +1340,22 @@ def create_dbt_platform_mces( if self.config.enable_query_tag_mapping and node.query_tag: self.extract_query_tag_aspects( - action_processor_tag, meta_aspects, node + action_processor_tag, + meta_aspects, + node, ) # mutates meta_aspects aspects = self._generate_base_dbt_aspects( - node, additional_custom_props_filtered, DBT_PLATFORM, meta_aspects + node, + additional_custom_props_filtered, + DBT_PLATFORM, + meta_aspects, ) # Upstream lineage. upstream_lineage_class = self._create_lineage_aspect_for_dbt_node( - node, all_nodes_map + node, + all_nodes_map, ) if upstream_lineage_class: aspects.append(upstream_lineage_class) @@ -1356,7 +1381,8 @@ def create_dbt_platform_mces( standalone_aspects, snapshot_aspects = more_itertools.partition( ( lambda aspect: mce_builder.can_add_aspect_to_snapshot( - DatasetSnapshot, type(aspect) + DatasetSnapshot, + type(aspect), ) ), aspects, @@ -1369,7 +1395,8 @@ def create_dbt_platform_mces( ).as_workunit() dataset_snapshot = DatasetSnapshot( - urn=node_datahub_urn, aspects=list(snapshot_aspects) + urn=node_datahub_urn, + aspects=list(snapshot_aspects), ) mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) if self.config.write_semantics == "PATCH": @@ -1377,13 +1404,16 @@ def create_dbt_platform_mces( yield MetadataWorkUnit(id=dataset_snapshot.urn, mce=mce) else: logger.debug( - f"Skipping emission of node {node_datahub_urn} because node_type {node.node_type} is disabled" + f"Skipping emission of node {node_datahub_urn} because node_type {node.node_type} is disabled", ) # Model performance. 
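# For orientation: the snapshot/standalone split performed above, isolated. Aspects the
# DatasetSnapshot union accepts go into the MCE snapshot; the rest are emitted as
# individual MCPs. The urn and the two sample aspects are hypothetical.
import more_itertools
import datahub.emitter.mce_builder as mce_builder
from datahub.metadata.schema_classes import (
    DatasetSnapshotClass,
    StatusClass,
    SubTypesClass,
)

node_urn = mce_builder.make_dataset_urn("dbt", "analytics.stg_orders")
aspects = [StatusClass(removed=False), SubTypesClass(typeNames=["Model"])]

standalone_aspects, snapshot_aspects = more_itertools.partition(
    lambda aspect: mce_builder.can_add_aspect_to_snapshot(DatasetSnapshotClass, type(aspect)),
    aspects,
)
snapshot = DatasetSnapshotClass(urn=node_urn, aspects=list(snapshot_aspects))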
if self.config.entities_enabled.can_emit_model_performance: yield from auto_workunit( - self._create_dataprocess_instance_mcps(node, upstream_lineage_class) + self._create_dataprocess_instance_mcps( + node, + upstream_lineage_class, + ), ) def _create_dataprocess_instance_mcps( @@ -1433,13 +1463,13 @@ def _create_dataprocess_instance_mcps( yield from data_process_instance.start_event_mcp( start_timestamp_millis=datetime_to_ts_millis( - model_performance.start_time + model_performance.start_time, ), ) yield from data_process_instance.end_event_mcp( end_timestamp_millis=datetime_to_ts_millis(model_performance.end_time), start_timestamp_millis=datetime_to_ts_millis( - model_performance.start_time + model_performance.start_time, ), result=( InstanceRunResult.SUCCESS @@ -1466,7 +1496,7 @@ def create_target_platform_mces( ) if not self.config.entities_enabled.can_emit_node_type(node.node_type): logger.debug( - f"Skipping emission of node {node_datahub_urn} because node_type {node.node_type} is disabled" + f"Skipping emission of node {node_datahub_urn} because node_type {node.node_type} is disabled", ) continue @@ -1519,7 +1549,9 @@ def extract_query_tag_aspects( meta_aspects["add_tag"] = query_tag_aspects["add_tag"] def get_aspect_from_dataset( - self, dataset_snapshot: DatasetSnapshot, aspect_type: type + self, + dataset_snapshot: DatasetSnapshot, + aspect_type: type, ) -> Any: for aspect in dataset_snapshot.aspects: if isinstance(aspect, aspect_type): @@ -1528,7 +1560,8 @@ def get_aspect_from_dataset( def get_patched_mce(self, mce): owner_aspect = self.get_aspect_from_dataset( - mce.proposedSnapshot, OwnershipClass + mce.proposedSnapshot, + OwnershipClass, ) if owner_aspect: transformed_owner_list = self.get_transformed_owners_by_source_type( @@ -1548,17 +1581,21 @@ def get_patched_mce(self, mce): tag_aspect.tags = transformed_tag_list term_aspect: GlossaryTermsClass = self.get_aspect_from_dataset( - mce.proposedSnapshot, GlossaryTermsClass + mce.proposedSnapshot, + GlossaryTermsClass, ) if term_aspect: transformed_terms = self.get_transformed_terms( - term_aspect.terms, mce.proposedSnapshot.urn + term_aspect.terms, + mce.proposedSnapshot.urn, ) term_aspect.terms = transformed_terms return mce def _create_dataset_properties_aspect( - self, node: DBTNode, additional_custom_props_filtered: Dict[str, str] + self, + node: DBTNode, + additional_custom_props_filtered: Dict[str, str], ) -> DatasetPropertiesClass: description = node.description @@ -1581,7 +1618,8 @@ def get_external_url(self, node: DBTNode) -> Optional[str]: pass def _create_view_properties_aspect( - self, node: DBTNode + self, + node: DBTNode, ) -> Optional[ViewPropertiesClass]: if node.language != "sql" or not node.raw_code: return None @@ -1589,7 +1627,8 @@ def _create_view_properties_aspect( compiled_code = None if self.config.include_compiled_code and node.compiled_code: compiled_code = try_format_query( - node.compiled_code, platform=self.config.target_platform + node.compiled_code, + platform=self.config.target_platform, ) materialized = node.materialization in {"table", "incremental", "snapshot"} @@ -1618,7 +1657,8 @@ def _generate_base_dbt_aspects( # add dataset properties aspect dbt_properties = self._create_dataset_properties_aspect( - node, additional_custom_props_filtered + node, + additional_custom_props_filtered, ) aspects.append(dbt_properties) @@ -1638,7 +1678,7 @@ def _generate_base_dbt_aspects( aggregated_tags = self._aggregate_tags(node, meta_tags_aspect) if aggregated_tags: aspects.append( - 
mce_builder.make_global_tag_aspect_with_tag_list(aggregated_tags) + mce_builder.make_global_tag_aspect_with_tag_list(aggregated_tags), ) # add meta term aspects @@ -1664,7 +1704,10 @@ def _generate_base_dbt_aspects( return aspects def get_schema_metadata( - self, report: DBTSourceReport, node: DBTNode, platform: str + self, + report: DBTSourceReport, + node: DBTNode, + platform: str, ) -> SchemaMetadata: action_processor = OperationProcessor( self.config.column_meta_mapping, @@ -1696,7 +1739,7 @@ def get_schema_metadata( logger.warning("The add_owner operation is not supported for columns.") meta_tags: Optional[GlobalTagsClass] = meta_aspects.get( - Constants.ADD_TAG_OPERATION + Constants.ADD_TAG_OPERATION, ) globalTags = None if meta_tags or column.tags: @@ -1706,7 +1749,7 @@ def get_schema_metadata( + [ TagAssociationClass(mce_builder.make_tag_urn(tag)) for tag in column.tags - ] + ], ) glossaryTerms = None @@ -1722,7 +1765,10 @@ def get_schema_metadata( nativeDataType=column.data_type, type=column.datahub_data_type or get_column_type( - report, node.dbt_name, column.data_type, node.dbt_adapter + report, + node.dbt_name, + column.data_type, + node.dbt_adapter, ), description=description, nullable=False, # TODO: actually autodetect this @@ -1752,7 +1798,9 @@ def get_schema_metadata( ) def _aggregate_owners( - self, node: DBTNode, meta_owner_aspects: Any + self, + node: DBTNode, + meta_owner_aspects: Any, ) -> List[OwnerClass]: owner_list: List[OwnerClass] = [] if meta_owner_aspects and self.config.enable_meta_mapping: @@ -1762,12 +1810,13 @@ def _aggregate_owners( owner: str = node.owner if self.compiled_owner_extraction_pattern: match: Optional[Any] = re.match( - self.compiled_owner_extraction_pattern, owner + self.compiled_owner_extraction_pattern, + owner, ) if match: owner = match.group("owner") logger.debug( - f"Owner after applying owner extraction pattern:'{self.config.owner_extraction_pattern}' is '{owner}'." 
+ f"Owner after applying owner extraction pattern:'{self.config.owner_extraction_pattern}' is '{owner}'.", ) if isinstance(owner, list): owners = owner @@ -1782,7 +1831,7 @@ def _aggregate_owners( OwnerClass( owner=mce_builder.make_user_urn(owner), type=OwnershipTypeClass.DATAOWNER, - ) + ), ) owner_list = sorted(owner_list, key=lambda x: x.owner) @@ -1800,7 +1849,9 @@ def _aggregate_tags(self, node: DBTNode, meta_tag_aspect: Any) -> List[str]: return sorted(tags_list) def _create_subType_wu( - self, node: DBTNode, node_datahub_urn: str + self, + node: DBTNode, + node_datahub_urn: str, ) -> Optional[MetadataWorkUnit]: if not node.node_type: return None @@ -1883,17 +1934,19 @@ def _translate_dbt_name_to_upstream_urn(dbt_name: str) -> str: downstreamType=FineGrainedLineageDownstreamType.FIELD, upstreams=[ mce_builder.make_schema_field_urn( - upstream.table, upstream.column + upstream.table, + upstream.column, ) for upstream in column_lineage.upstreams ], downstreams=[ mce_builder.make_schema_field_urn( - node_urn, column_lineage.downstream.column - ) + node_urn, + column_lineage.downstream.column, + ), ], confidenceScore=sql_parsing_result.debug_info.confidence, - ) + ), ) else: @@ -1914,14 +1967,14 @@ def _translate_dbt_name_to_upstream_urn(dbt_name: str) -> str: upstreams=[ mce_builder.make_schema_field_urn( _translate_dbt_name_to_upstream_urn( - upstream_column.upstream_dbt_name + upstream_column.upstream_dbt_name, ), upstream_column.upstream_col, ) for upstream_column in upstreams ], downstreams=[ - mce_builder.make_schema_field_urn(node_urn, downstream) + mce_builder.make_schema_field_urn(node_urn, downstream), ], confidenceScore=( node.cll_debug_info.confidence @@ -1930,7 +1983,8 @@ def _translate_dbt_name_to_upstream_urn(dbt_name: str) -> str: ), ) for downstream, upstreams in itertools.groupby( - node.upstream_cll, lambda x: x.downstream_col + node.upstream_cll, + lambda x: x.downstream_col, ) ] @@ -1965,7 +2019,10 @@ def _translate_dbt_name_to_upstream_urn(dbt_name: str) -> str: # From the existing owners it will remove the owners that are of the source_type_filter and # then add all the new owners to that list. def get_transformed_owners_by_source_type( - self, owners: List[OwnerClass], entity_urn: str, source_type_filter: str + self, + owners: List[OwnerClass], + entity_urn: str, + source_type_filter: str, ) -> List[OwnerClass]: transformed_owners: List[OwnerClass] = [] if owners: @@ -2002,7 +2059,7 @@ def get_transformed_tags_by_prefix( if existing_tags_class and existing_tags_class.tags: for existing_tag in existing_tags_class.tags: if tag_prefix and existing_tag.tag.startswith( - mce_builder.make_tag_urn(tag_prefix) + mce_builder.make_tag_urn(tag_prefix), ): continue tag_set.add(existing_tag.tag) @@ -2011,7 +2068,9 @@ def get_transformed_tags_by_prefix( # This method attempts to read-modify and return the glossary terms of a dataset. # This will combine all new and existing terms and return the final deduped list. 
def get_transformed_terms( - self, new_terms: List[GlossaryTermAssociation], entity_urn: str + self, + new_terms: List[GlossaryTermAssociation], + entity_urn: str, ) -> List[GlossaryTermAssociation]: term_id_set = {term.urn for term in new_terms} if self.ctx.graph: diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py index 04de763370c951..6585b08ae7972b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py @@ -43,11 +43,11 @@ class DBTCoreConfig(DBTCommonConfig): manifest_path: str = Field( description="Path to dbt manifest JSON. See https://docs.getdbt.com/reference/artifacts/manifest-json Note " - "this can be a local file or a URI." + "this can be a local file or a URI.", ) catalog_path: str = Field( description="Path to dbt catalog JSON. See https://docs.getdbt.com/reference/artifacts/catalog-json Note this " - "can be a local file or a URI." + "can be a local file or a URI.", ) sources_path: Optional[str] = Field( default=None, @@ -70,10 +70,14 @@ class DBTCoreConfig(DBTCommonConfig): # Because we now also collect model performance metadata, the "test_results" field was renamed to "run_results". _convert_test_results_path = pydantic_renamed_field( - "test_results_path", "run_results_paths", transform=lambda x: [x] if x else [] + "test_results_path", + "run_results_paths", + transform=lambda x: [x] if x else [], ) _convert_run_result_path_singular = pydantic_renamed_field( - "run_results_path", "run_results_paths", transform=lambda x: [x] if x else [] + "run_results_path", + "run_results_paths", + transform=lambda x: [x] if x else [], ) aws_connection: Optional[AwsConnectionConfig] = Field( @@ -90,7 +94,10 @@ class DBTCoreConfig(DBTCommonConfig): @validator("aws_connection", always=True) def aws_connection_needed_if_s3_uris_present( - cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any + cls, + aws_connection: Optional[AwsConnectionConfig], + values: Dict, + **kwargs: Any, ) -> Optional[AwsConnectionConfig]: # first check if there are fields that contain s3 uris uris = [ @@ -105,7 +112,7 @@ def aws_connection_needed_if_s3_uris_present( if s3_uris and aws_connection is None: raise ValueError( - f"Please provide aws_connection configuration, since s3 uris have been provided {s3_uris}" + f"Please provide aws_connection configuration, since s3 uris have been provided {s3_uris}", ) return aws_connection @@ -138,7 +145,8 @@ def get_columns( columns = [] for key, catalog_column in catalog_columns.items(): manifest_column = manifest_columns.get( - key, manifest_columns_lower.get(key.lower(), {}) + key, + manifest_columns_lower.get(key.lower(), {}), ) meta = manifest_column.get("meta", {}) @@ -190,7 +198,7 @@ def extract_dbt_entities( comment = "" if key in all_catalog_entities and all_catalog_entities[key]["metadata"].get( - "comment" + "comment", ): comment = all_catalog_entities[key]["metadata"]["comment"] @@ -277,10 +285,12 @@ def extract_dbt_entities( comment=comment, description=manifest_node.get("description", ""), raw_code=manifest_node.get( - "raw_code", manifest_node.get("raw_sql") + "raw_code", + manifest_node.get("raw_sql"), ), # Backward compatibility dbt <=v1.2 language=manifest_node.get( - "language", "sql" + "language", + "sql", ), # Backward compatibility dbt <=v1.2 upstream_nodes=upstream_nodes, materialization=materialization, @@ -291,7 +301,8 @@ def 
extract_dbt_entities( tags=tags, owner=owner, compiled_code=manifest_node.get( - "compiled_code", manifest_node.get("compiled_sql") + "compiled_code", + manifest_node.get("compiled_sql"), ), # Backward compatibility dbt <=v1.2 test_info=test_info, ) @@ -410,7 +421,7 @@ def load_run_results( if test_results_json.get("args", {}).get("which") == "generate": logger.warning( "The run results file is from a `dbt docs generate` command, " - "instead of a build/run/test command. Skipping this file." + "instead of a build/run/test command. Skipping this file.", ) return all_nodes @@ -470,21 +481,25 @@ def test_connection(config_dict: dict) -> TestConnectionReport: try: source_config = DBTCoreConfig.parse_obj_allow_extras(config_dict) DBTCoreSource.load_file_as_json( - source_config.manifest_path, source_config.aws_connection + source_config.manifest_path, + source_config.aws_connection, ) DBTCoreSource.load_file_as_json( - source_config.catalog_path, source_config.aws_connection + source_config.catalog_path, + source_config.aws_connection, ) test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @staticmethod def load_file_as_json( - uri: str, aws_connection: Optional[AwsConnectionConfig] + uri: str, + aws_connection: Optional[AwsConnectionConfig], ) -> Dict: if re.match("^https?://", uri): return json.loads(requests.get(uri).text) @@ -492,7 +507,8 @@ def load_file_as_json( u = urlparse(uri) assert aws_connection response = aws_connection.get_s3_client().get_object( - Bucket=u.netloc, Key=u.path.lstrip("/") + Bucket=u.netloc, + Key=u.path.lstrip("/"), ) return json.loads(response["Body"].read().decode("utf-8")) else: @@ -510,16 +526,19 @@ def loadManifestAndCatalog( Optional[str], ]: dbt_manifest_json = self.load_file_as_json( - self.config.manifest_path, self.config.aws_connection + self.config.manifest_path, + self.config.aws_connection, ) dbt_catalog_json = self.load_file_as_json( - self.config.catalog_path, self.config.aws_connection + self.config.catalog_path, + self.config.aws_connection, ) if self.config.sources_path is not None: dbt_sources_json = self.load_file_as_json( - self.config.sources_path, self.config.aws_connection + self.config.sources_path, + self.config.aws_connection, ) sources_results = dbt_sources_json["results"] else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py index dc4e5d426fe42f..5142becb92fa11 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py @@ -92,7 +92,7 @@ class AssertionParams: value=AssertionStdParameterClass( value="1.0", type=AssertionStdParameterTypeClass.NUMBER, - ) + ), ), ), "accepted_values": AssertionParams( diff --git a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py index d2b4a576953daf..39eaddb7ed52f3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py @@ -21,12 +21,14 @@ class S3(ConfigModel): aws_config: Optional[AwsConnectionConfig] = Field( - default=None, description="AWS configuration" + default=None, + description="AWS configuration", ) # Whether or not to create in datahub 
from the s3 bucket use_s3_bucket_tags: Optional[bool] = Field( - False, description="Whether or not to create tags in datahub from the s3 bucket" + False, + description="Whether or not to create tags in datahub from the s3 bucket", ) # Whether or not to create in datahub from the s3 object use_s3_object_tags: Optional[bool] = Field( @@ -38,7 +40,7 @@ class S3(ConfigModel): class DeltaLakeSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): base_path: str = Field( description="Path to table (s3 or local file system). If path is not a delta table path " - "then all subfolders will be scanned to detect and ingest delta tables." + "then all subfolders will be scanned to detect and ingest delta tables.", ) relative_path: Optional[str] = Field( default=None, diff --git a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/delta_lake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/delta_lake_utils.py index 071dacae4a8c54..cf931672f776df 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/delta_lake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/delta_lake_utils.py @@ -17,7 +17,9 @@ def read_delta_table( - path: str, opts: Dict[str, str], delta_lake_config: DeltaLakeSourceConfig + path: str, + opts: Dict[str, str], + delta_lake_config: DeltaLakeSourceConfig, ) -> Optional[DeltaTable]: if not delta_lake_config.is_s3 and not pathlib.Path(path).exists(): # The DeltaTable() constructor will create the path if it doesn't exist. diff --git a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py index 9df3905437b3b2..13682ae2819632 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/source.py @@ -142,16 +142,19 @@ def get_fields(self, delta_table: DeltaTable) -> List[SchemaField]: return fields def _create_operation_aspect_wu( - self, delta_table: DeltaTable, dataset_urn: str + self, + delta_table: DeltaTable, + dataset_urn: str, ) -> Iterable[MetadataWorkUnit]: for hist in delta_table.history( - limit=self.source_config.version_history_lookback + limit=self.source_config.version_history_lookback, ): # History schema picked up from https://docs.delta.io/latest/delta-utility.html#retrieve-delta-table-history reported_time: int = int(time.time() * 1000) last_updated_timestamp: int = hist["timestamp"] statement_type = OPERATION_STATEMENT_TYPES.get( - hist.get("operation", "UNKNOWN"), OperationTypeClass.CUSTOM + hist.get("operation", "UNKNOWN"), + OperationTypeClass.CUSTOM, ) custom_type = ( hist.get("operation") @@ -184,7 +187,9 @@ def _create_operation_aspect_wu( ).as_workunit() def ingest_table( - self, delta_table: DeltaTable, path: str + self, + delta_table: DeltaTable, + path: str, ) -> Iterable[MetadataWorkUnit]: table_name = ( delta_table.metadata().name @@ -193,7 +198,7 @@ def ingest_table( ) if not self.source_config.table_pattern.allowed(table_name): logger.debug( - f"Skipping table ({table_name}) present at location {path} as table pattern does not match" + f"Skipping table ({table_name}) present at location {path} as table pattern does not match", ) logger.debug(f"Ingesting table {table_name} from location {path}") @@ -270,7 +275,8 @@ def ingest_table( yield MetadataWorkUnit(id=str(delta_table.metadata().id), mce=mce) yield from self.container_WU_creator.create_container_hierarchy( - browse_path, dataset_urn + browse_path, + dataset_urn, 
) yield from self._create_operation_aspect_wu(delta_table, dataset_urn) @@ -318,7 +324,9 @@ def get_folders(self, path: str) -> Iterable[str]: def s3_get_folders(self, path: str) -> Iterable[str]: parse_result = urlparse(path) for page in self.s3_client.get_paginator("list_objects_v2").paginate( - Bucket=parse_result.netloc, Prefix=parse_result.path[1:], Delimiter="/" + Bucket=parse_result.netloc, + Prefix=parse_result.path[1:], + Delimiter="/", ): for o in page.get("CommonPrefixes", []): yield f"{parse_result.scheme}://{parse_result.netloc}/{o.get('Prefix')}" @@ -326,7 +334,7 @@ def s3_get_folders(self, path: str) -> Iterable[str]: def local_get_folders(self, path: str) -> Iterable[str]: if not os.path.isdir(path): raise FileNotFoundError( - f"{path} does not exist or is not a directory. Please check base_path configuration." + f"{path} does not exist or is not a directory. Please check base_path configuration.", ) for folder in os.listdir(path): yield os.path.join(path, folder) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py index cf2d9670400ca5..4c295fbfb52b14 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_api.py @@ -47,7 +47,9 @@ class DremioAPIOperations: _timeout: int = 1800 def __init__( - self, connection_args: "DremioSourceConfig", report: "DremioSourceReport" + self, + connection_args: "DremioSourceConfig", + report: "DremioSourceReport", ) -> None: self.dremio_to_datahub_source_mapper = DremioToDataHubSourceTypeMapping() self.allow_schema_pattern: List[str] = connection_args.schema_pattern.allow @@ -84,10 +86,10 @@ def _get_cloud_base_url(self, connection_args: "DremioSourceConfig") -> str: project_id = connection_args.dremio_cloud_project_id else: self.report.failure( - "Project ID must be provided for Dremio Cloud environments." + "Project ID must be provided for Dremio Cloud environments.", ) raise DremioAPIException( - "Project ID must be provided for Dremio Cloud environments." + "Project ID must be provided for Dremio Cloud environments.", ) if connection_args.dremio_cloud_region == "US": @@ -101,10 +103,10 @@ def _get_on_prem_base_url(self, connection_args: "DremioSourceConfig") -> str: protocol = "https" if connection_args.tls else "http" if not host: self.report.failure( - "Hostname must be provided for on-premises Dremio instances." + "Hostname must be provided for on-premises Dremio instances.", ) raise DremioAPIException( - "Hostname must be provided for on-premises Dremio instances." + "Hostname must be provided for on-premises Dremio instances.", ) return f"{protocol}://{host}:{port}/api/v3" @@ -115,10 +117,10 @@ def _get_ui_url(self, connection_args: "DremioSourceConfig") -> str: project_id = connection_args.dremio_cloud_project_id else: self.report.failure( - "Project ID must be provided for Dremio Cloud environments." + "Project ID must be provided for Dremio Cloud environments.", ) raise DremioAPIException( - "Project ID must be provided for Dremio Cloud environments." + "Project ID must be provided for Dremio Cloud environments.", ) cloud_region = connection_args.dremio_cloud_region if cloud_region == "US": @@ -131,10 +133,10 @@ def _get_ui_url(self, connection_args: "DremioSourceConfig") -> str: protocol = "https" if connection_args.tls else "http" if not host: self.report.failure( - "Hostname must be provided for on-premises Dremio instances." 
+ "Hostname must be provided for on-premises Dremio instances.", ) raise DremioAPIException( - "Hostname must be provided for on-premises Dremio instances." + "Hostname must be provided for on-premises Dremio instances.", ) return f"{protocol}://{host}:{port}" @@ -170,13 +172,13 @@ def authenticate(self, connection_args: "DremioSourceConfig") -> None: if self.is_dremio_cloud: if not connection_args.password: self.report.failure( - "Personal Access Token (PAT) is missing for cloud authentication." + "Personal Access Token (PAT) is missing for cloud authentication.", ) raise DremioAPIException( - "Personal Access Token (PAT) is missing for cloud authentication." + "Personal Access Token (PAT) is missing for cloud authentication.", ) self.session.headers.update( - {"Authorization": f"Bearer {connection_args.password}"} + {"Authorization": f"Bearer {connection_args.password}"}, ) return @@ -187,7 +189,7 @@ def authenticate(self, connection_args: "DremioSourceConfig") -> None: self.session.headers.update( { "Authorization": f"Bearer {connection_args.password}", - } + }, ) return else: @@ -204,7 +206,7 @@ def authenticate(self, connection_args: "DremioSourceConfig") -> None: { "userName": connection_args.username, "password": connection_args.password, - } + }, ), verify=self._verify, timeout=self._timeout, @@ -213,7 +215,7 @@ def authenticate(self, connection_args: "DremioSourceConfig") -> None: token = response.json().get("token") if token: self.session.headers.update( - {"Authorization": f"_dremio{token}"} + {"Authorization": f"_dremio{token}"}, ) return @@ -225,10 +227,10 @@ def authenticate(self, connection_args: "DremioSourceConfig") -> None: sleep(1) # Optional: exponential backoff self.report.failure( - "Credentials cannot be refreshed. Please check your username and password." + "Credentials cannot be refreshed. Please check your username and password.", ) raise DremioAPIException( - "Credentials cannot be refreshed. Please check your username and password." + "Credentials cannot be refreshed. 
Please check your username and password.", ) def get(self, url: str) -> Dict: @@ -257,7 +259,8 @@ def execute_query(self, query: str, timeout: int = 3600) -> List[Dict[str, Any]] if "errorMessage" in response: self.report.failure( - message="SQL Error", context=f"{response['errorMessage']}" + message="SQL Error", + context=f"{response['errorMessage']}", ) raise DremioAPIException(f"SQL Error: {response['errorMessage']}") @@ -270,7 +273,7 @@ def execute_query(self, query: str, timeout: int = 3600) -> List[Dict[str, Any]] except concurrent.futures.TimeoutError: self.cancel_query(job_id) raise DremioAPIException( - f"Query execution timed out after {timeout} seconds" + f"Query execution timed out after {timeout} seconds", ) except RuntimeError as e: raise DremioAPIException(f"{str(e)}") @@ -329,7 +332,10 @@ def get_job_status(self, job_id: str) -> Dict[str, Any]: ) def get_job_result( - self, job_id: str, offset: int = 0, limit: int = 500 + self, + job_id: str, + offset: int = 0, + limit: int = 500, ) -> Dict[str, Any]: """Get job results in batches""" return self.get( @@ -363,7 +369,8 @@ def get_dataset_id(self, schema: str, dataset: str) -> Optional[str]: return dataset_id def community_get_formatted_tables( - self, tables_and_columns: List[Dict[str, Any]] + self, + tables_and_columns: List[Dict[str, Any]], ) -> List[Dict[str, Any]]: schema_list = [] schema_dict_lookup = [] @@ -383,12 +390,13 @@ def community_get_formatted_tables( { "name": record["COLUMN_NAME"], "ordinal_position": record.get( - "ORDINAL_POSITION", ordinal_position + "ORDINAL_POSITION", + ordinal_position, ), "is_nullable": record["IS_NULLABLE"], "data_type": record["DATA_TYPE"], "column_size": record["COLUMN_SIZE"], - } + }, ) ordinal_position += 1 @@ -409,7 +417,7 @@ def community_get_formatted_tables( if key in dictionary ): dictionary for dictionary in tables_and_columns - }.values() + }.values(), ) for schema in schema_list: @@ -421,12 +429,12 @@ def community_get_formatted_tables( { "TABLE_SCHEMA": "[" + ", ".join( - schemas.get("formatted_path") + [table.get("TABLE_NAME")] + schemas.get("formatted_path") + [table.get("TABLE_NAME")], ) + "]", "TABLE_NAME": table.get("TABLE_NAME"), "COLUMNS": column_dictionary.get( - table.get("FULL_TABLE_PATH", "") + table.get("FULL_TABLE_PATH", ""), ), "VIEW_DEFINITION": table.get("VIEW_DEFINITION"), "RESOURCE_ID": self.get_dataset_id( @@ -437,13 +445,16 @@ def community_get_formatted_tables( schema=".".join(schemas.get("formatted_path")), dataset="", ), - } + }, ) return dataset_list def get_pattern_condition( - self, patterns: Union[str, List[str]], field: str, allow: bool = True + self, + patterns: Union[str, List[str]], + field: str, + allow: bool = True, ) -> str: if not patterns: return "" @@ -473,10 +484,13 @@ def get_all_tables_and_columns(self, containers: Deque) -> List[Dict]: schema_field = "CONCAT(REPLACE(REPLACE(REPLACE(UPPER(TABLE_SCHEMA), ', ', '.'), '[', ''), ']', ''))" schema_condition = self.get_pattern_condition( - self.allow_schema_pattern, schema_field + self.allow_schema_pattern, + schema_field, ) deny_schema_condition = self.get_pattern_condition( - self.deny_schema_pattern, schema_field, allow=False + self.deny_schema_pattern, + schema_field, + allow=False, ) all_tables_and_columns = [] @@ -492,7 +506,7 @@ def get_all_tables_and_columns(self, containers: Deque) -> List[Dict]: all_tables_and_columns.extend( self.execute_query( query=formatted_query, - ) + ), ) except DremioAPIException as e: self.report.warning( @@ -524,7 +538,7 @@ def 
get_all_tables_and_columns(self, containers: Deque) -> List[Dict]: "is_nullable": record["IS_NULLABLE"], "data_type": record["DATA_TYPE"], "column_size": record["COLUMN_SIZE"], - } + }, ) distinct_tables_list = list( @@ -545,7 +559,7 @@ def get_all_tables_and_columns(self, containers: Deque) -> List[Dict]: if key in dictionary ): dictionary for dictionary in all_tables_and_columns - }.values() + }.values(), ) for table in distinct_tables_list: @@ -561,7 +575,7 @@ def get_all_tables_and_columns(self, containers: Deque) -> List[Dict]: "OWNER_TYPE": table.get("OWNER_TYPE"), "CREATED": table.get("CREATED"), "FORMAT_TYPE": table.get("FORMAT_TYPE"), - } + }, ) return tables @@ -569,7 +583,7 @@ def get_all_tables_and_columns(self, containers: Deque) -> List[Dict]: def validate_schema_format(self, schema): if "." in schema: schema_path = self.get( - url=f"/catalog/{self.get_dataset_id(schema=schema, dataset='')}" + url=f"/catalog/{self.get_dataset_id(schema=schema, dataset='')}", ).get("path") return {"original_path": schema, "formatted_path": schema_path} return {"original_path": schema, "formatted_path": [schema]} @@ -625,7 +639,7 @@ def get_tags_for_resource(self, resource_id: str) -> Optional[List[str]]: "Resource ID {} has no tags: {}".format( resource_id, exc, - ) + ), ) return None @@ -644,7 +658,7 @@ def get_description_for_resource(self, resource_id: str) -> Optional[str]: "Resource ID {} has no wiki entry: {}".format( resource_id, exc, - ) + ), ) return None @@ -668,7 +682,8 @@ def _check_pattern_match( # Already has start anchor regex_pattern = pattern.replace(".", r"\.") # Escape dots regex_pattern = regex_pattern.replace( - r"\.*", ".*" + r"\.*", + ".*", ) # Convert .* to wildcard else: # Add start anchor and handle dots @@ -771,7 +786,8 @@ def process_source(source): source_config = source_resp.get("config", {}) db = source_config.get( - "database", source_config.get("databaseName", "") + "database", + source_config.get("databaseName", ""), ) if self.should_include_container([], source.get("path")[0]): @@ -809,7 +825,7 @@ def process_source_and_containers(source): # Use ThreadPoolExecutor to parallelize the processing of sources with concurrent.futures.ThreadPoolExecutor( - max_workers=self._max_workers + max_workers=self._max_workers, ) as executor: future_to_source = { executor.submit(process_source_and_containers, source): source @@ -842,7 +858,9 @@ def get_context_for_vds(self, resource_id: str) -> str: return "" def get_containers_for_location( - self, resource_id: str, path: List[str] + self, + resource_id: str, + path: List[str], ) -> List[Dict[str, str]]: containers = [] @@ -866,7 +884,7 @@ def traverse_path(location_id: str, entity_path: List[str]) -> List: "name": folder_name, "path": folder_path, "container_type": DremioEntityContainerType.FOLDER, - } + }, ) # Recursively process child containers @@ -880,8 +898,8 @@ def traverse_path(location_id: str, entity_path: List[str]) -> List: except Exception as exc: logging.info( "Location {} contains no tables or views. 
Skipping...".format( - location_id - ) + location_id, + ), ) self.report.warning( message="Failed to get tables or views", diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_aspects.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_aspects.py index d9d85edbf4f7a0..5b356ae11fe4b0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_aspects.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_aspects.py @@ -99,7 +99,9 @@ class SchemaFieldTypeMapper: @classmethod def get_field_type( - cls, data_type: str, data_size: Optional[int] = None + cls, + data_type: str, + data_size: Optional[int] = None, ) -> Tuple["SchemaFieldDataTypeClass", str]: """ Maps a Dremio data type and size to a DataHub SchemaFieldDataTypeClass and native data type string. @@ -125,11 +127,11 @@ def get_field_type( schema_field_type = SchemaFieldDataTypeClass(type=type_class()) logger.debug( f"Mapped data_type '{data_type}' with size '{data_size}' to type class " - f"'{type_class.__name__}' and native data type '{native_data_type}'." + f"'{type_class.__name__}' and native data type '{native_data_type}'.", ) except Exception as e: logger.error( - f"Error initializing SchemaFieldDataTypeClass with type '{type_class.__name__}': {e}" + f"Error initializing SchemaFieldDataTypeClass with type '{type_class.__name__}': {e}", ) schema_field_type = SchemaFieldDataTypeClass(type=NullTypeClass()) @@ -154,7 +156,9 @@ def __init__( self.ingest_owner = ingest_owner def get_container_key( - self, name: Optional[str], path: Optional[List[str]] + self, + name: Optional[str], + path: Optional[List[str]], ) -> DremioContainerKey: key = name if path: @@ -168,7 +172,9 @@ def get_container_key( ) def get_container_urn( - self, name: Optional[str] = None, path: Optional[List[str]] = [] + self, + name: Optional[str] = None, + path: Optional[List[str]] = [], ) -> str: container_key = self.get_container_key(name, path) return container_key.as_urn() @@ -181,13 +187,15 @@ def create_domain_aspect(self) -> Optional[_Aspect]: domains=[ make_domain_urn( str(uuid.uuid5(namespace, self.domain)), - ) - ] + ), + ], ) return None def populate_container_mcp( - self, container_urn: str, container: DremioContainer + self, + container_urn: str, + container: DremioContainer, ) -> Iterable[MetadataWorkUnit]: # Container Properties container_properties = self._create_container_properties(container) @@ -242,7 +250,9 @@ def populate_container_mcp( yield mcp.as_workunit() def populate_dataset_mcp( - self, dataset_urn: str, dataset: DremioDataset + self, + dataset_urn: str, + dataset: DremioDataset, ) -> Iterable[MetadataWorkUnit]: # Dataset Properties dataset_properties = self._create_dataset_properties(dataset) @@ -315,10 +325,10 @@ def populate_dataset_mcp( else: logger.warning( - f"Dataset {dataset.path}.{dataset.resource_name} has not been queried in Dremio" + f"Dataset {dataset.path}.{dataset.resource_name} has not been queried in Dremio", ) logger.warning( - f"Dataset {dataset.path}.{dataset.resource_name} will have a null schema" + f"Dataset {dataset.path}.{dataset.resource_name} will have a null schema", ) # Status @@ -330,7 +340,8 @@ def populate_dataset_mcp( yield mcp.as_workunit() def populate_glossary_term_mcp( - self, glossary_term: DremioGlossaryTerm + self, + glossary_term: DremioGlossaryTerm, ) -> Iterable[MetadataWorkUnit]: glossary_term_info = self._create_glossary_term_info(glossary_term) mcp = MetadataChangeProposalWrapper( @@ -352,7 +363,8 @@ def 
populate_profile_aspect(self, profile_data: Dict) -> DatasetProfileClass: ) def _create_container_properties( - self, container: DremioContainer + self, + container: DremioContainer, ) -> ContainerPropertiesClass: return ContainerPropertiesClass( name=container.container_name, @@ -362,7 +374,8 @@ def _create_container_properties( ) def _create_browse_paths_containers( - self, entity: DremioContainer + self, + entity: DremioContainer, ) -> Optional[BrowsePathsV2Class]: paths = [] @@ -375,7 +388,8 @@ def _create_browse_paths_containers( return None def _create_container_class( - self, entity: Union[DremioContainer, DremioDataset] + self, + entity: Union[DremioContainer, DremioDataset], ) -> Optional[ContainerClass]: if entity.path: return ContainerClass(container=self.get_container_urn(path=entity.path)) @@ -392,7 +406,8 @@ def _create_data_platform_instance(self) -> DataPlatformInstanceClass: ) def _create_dataset_properties( - self, dataset: DremioDataset + self, + dataset: DremioDataset, ) -> DatasetPropertiesClass: return DatasetPropertiesClass( name=dataset.resource_name, @@ -402,9 +417,10 @@ def _create_dataset_properties( created=TimeStampClass( time=round( datetime.strptime( - dataset.created, "%Y-%m-%d %H:%M:%S.%f" + dataset.created, + "%Y-%m-%d %H:%M:%S.%f", ).timestamp() - * 1000 + * 1000, ) if hasattr(dataset, "created") else 0, @@ -439,8 +455,8 @@ def _create_ownership(self, dataset: DremioDataset) -> Optional[OwnershipClass]: OwnerClass( owner=owner_urn, type=OwnershipTypeClass.TECHNICAL_OWNER, - ) - ] + ), + ], ) return ownership @@ -481,7 +497,8 @@ def _create_schema_field(self, column: DremioDatasetColumn) -> SchemaFieldClass: ) def _create_view_properties( - self, dataset: DremioDataset + self, + dataset: DremioDataset, ) -> Optional[ViewPropertiesClass]: if not dataset.sql_definition: return None @@ -492,7 +509,8 @@ def _create_view_properties( ) def _create_glossary_term_info( - self, glossary_term: DremioGlossaryTerm + self, + glossary_term: DremioGlossaryTerm, ) -> GlossaryTermInfoClass: return GlossaryTermInfoClass( definition="", @@ -501,7 +519,9 @@ def _create_glossary_term_info( ) def _create_field_profile( - self, field_name: str, field_stats: Dict + self, + field_name: str, + field_stats: Dict, ) -> DatasetFieldProfileClass: quantiles = field_stats.get("quantiles") return DatasetFieldProfileClass( diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py index b3f2107a1dfaa7..41985be949ecac 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py @@ -97,7 +97,8 @@ def validate_password(cls, value, values): class ProfileConfig(GEProfilingBaseConfig): query_timeout: int = Field( - default=300, description="Time before cancelling Dremio profiling query" + default=300, + description="Time before cancelling Dremio profiling query", ) include_field_median_value: bool = Field( default=False, @@ -161,7 +162,7 @@ class DremioSourceConfig( def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) # Advanced Configs diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_datahub_source_mapping.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_datahub_source_mapping.py index 482647f8d77da1..66fe40dbee1ee4 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_datahub_source_mapping.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_datahub_source_mapping.py @@ -71,7 +71,8 @@ def get_datahub_source_type(dremio_source_type: str) -> str: Return the DataHub source type. """ return DremioToDataHubSourceTypeMapping.SOURCE_TYPE_MAPPING.get( - dremio_source_type.upper(), dremio_source_type.lower() + dremio_source_type.upper(), + dremio_source_type.lower(), ) @staticmethod @@ -108,9 +109,9 @@ def add_mapping( if category: if category.lower() == "file_object_storage": DremioToDataHubSourceTypeMapping.FILE_OBJECT_STORAGE_TYPES.add( - dremio_source_type + dremio_source_type, ) else: DremioToDataHubSourceTypeMapping.DATABASE_SOURCE_TYPES.add( - dremio_source_type + dremio_source_type, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py index b80d7b8e0f9123..47a81f898421b0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_entities.py @@ -173,7 +173,7 @@ def _get_query_subtype(self) -> str: def _get_queried_datasets(self, queried_datasets: str) -> List[str]: return list( - {dataset.strip() for dataset in queried_datasets.strip("[]").split(",")} + {dataset.strip() for dataset in queried_datasets.strip("[]").split(",")}, ) def _get_affected_tables(self) -> str: @@ -237,7 +237,7 @@ def __init__( if self.sql_definition: self.dataset_type = DremioDatasetType.VIEW self.default_schema = api_operations.get_context_for_vds( - resource_id=self.resource_id + resource_id=self.resource_id, ) else: self.dataset_type = DremioDatasetType.TABLE @@ -253,16 +253,16 @@ def __init__( self.format_type = dataset_details.get("FORMAT_TYPE") self.description = api_operations.get_description_for_resource( - resource_id=self.resource_id + resource_id=self.resource_id, ) glossary_terms = api_operations.get_tags_for_resource( - resource_id=self.resource_id + resource_id=self.resource_id, ) if glossary_terms is not None: for glossary_term in glossary_terms: self.glossary_terms.append( - DremioGlossaryTerm(glossary_term=glossary_term) + DremioGlossaryTerm(glossary_term=glossary_term), ) if self.sql_definition and api_operations.edition == DremioEdition.ENTERPRISE: @@ -356,7 +356,7 @@ def set_datasets(self) -> None: containers.extend(self.sources) # Add DremioSource elements for dataset_details in self.dremio_api.get_all_tables_and_columns( - containers=containers + containers=containers, ): dremio_dataset = DremioDataset( dataset_details=dataset_details, @@ -388,7 +388,7 @@ def set_containers(self) -> None: dremio_source_type=container.get("source_type"), root_path=container.get("root_path"), database_name=container.get("database_name"), - ) + ), ) elif container_type == DremioEntityContainerType.SPACE: self.spaces.append( @@ -397,7 +397,7 @@ def set_containers(self) -> None: location_id=container.get("id"), path=[], api_operations=self.dremio_api, - ) + ), ) elif container_type == DremioEntityContainerType.FOLDER: self.folders.append( @@ -406,7 +406,7 @@ def set_containers(self) -> None: location_id=container.get("id"), path=container.get("path"), api_operations=self.dremio_api, - ) + ), ) else: self.spaces.append( @@ -415,7 +415,7 @@ def set_containers(self) -> None: location_id=container.get("id"), path=[], api_operations=self.dremio_api, - ) + ), ) logging.info("Containers retrieved from 
source") @@ -456,7 +456,7 @@ def get_queries(self) -> Deque[DremioQuery]: submitted_ts=query["submitted_ts"], query=query["query"], queried_datasets=query["queried_datasets"], - ) + ), ) self.queries_populated = True return self.queries diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py index 5332597ffce9e2..d7ab3e23dab0c9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_profiling.py @@ -40,7 +40,9 @@ def __init__( ) # 5 minutes timeout for each query def get_workunits( - self, dataset: DremioDataset, dataset_urn: str + self, + dataset: DremioDataset, + dataset_urn: str, ) -> Iterable[MetadataWorkUnit]: if not dataset.columns: self.report.warning( @@ -70,7 +72,8 @@ def get_workunits( if profile_aspect: self.report.report_entity_profiled(dataset.resource_name) mcp = MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=profile_aspect + entityUrn=dataset_urn, + aspect=profile_aspect, ) yield mcp.as_workunit() @@ -87,7 +90,9 @@ def populate_profile_aspect(self, profile_data: Dict) -> DatasetProfileClass: ) def _create_field_profile( - self, field_name: str, field_stats: Dict + self, + field_name: str, + field_stats: Dict, ) -> DatasetFieldProfileClass: quantiles = field_stats.get("quantiles") return DatasetFieldProfileClass( @@ -137,7 +142,8 @@ def _profile_chunk(self, table_name: str, columns: List[Tuple[str, str]]) -> Dic raise e def _chunk_columns( - self, columns: List[Tuple[str, str]] + self, + columns: List[Tuple[str, str]], ) -> List[List[Tuple[str, str]]]: return [ columns[i : i + self.MAX_COLUMNS_PER_QUERY] @@ -145,7 +151,9 @@ def _chunk_columns( ] def _build_profile_sql( - self, table_name: str, columns: List[Tuple[str, str]] + self, + table_name: str, + columns: List[Tuple[str, str]], ) -> str: metrics = [] @@ -158,7 +166,7 @@ def _build_profile_sql( metrics.extend(self._get_column_metrics(column_name, data_type)) except Exception as e: logger.warning( - f"Error building metrics for column {column_name}: {str(e)}" + f"Error building metrics for column {column_name}: {str(e)}", ) # Skip this column and continue with others @@ -183,12 +191,12 @@ def _get_column_metrics(self, column_name: str, data_type: str) -> List[str]: if self.config.profiling.include_field_distinct_count: metrics.append( - f"COUNT(DISTINCT {quoted_column_name}) AS {safe_column_name}_distinct_count" + f"COUNT(DISTINCT {quoted_column_name}) AS {safe_column_name}_distinct_count", ) if self.config.profiling.include_field_null_count: metrics.append( - f"SUM(CASE WHEN {quoted_column_name} IS NULL THEN 1 ELSE 0 END) AS {safe_column_name}_null_count" + f"SUM(CASE WHEN {quoted_column_name} IS NULL THEN 1 ELSE 0 END) AS {safe_column_name}_null_count", ) if self.config.profiling.include_field_min_value: @@ -207,31 +215,33 @@ def _get_column_metrics(self, column_name: str, data_type: str) -> List[str]: ]: if self.config.profiling.include_field_mean_value: metrics.append( - f"AVG(CAST({quoted_column_name} AS DOUBLE)) AS {safe_column_name}_mean" + f"AVG(CAST({quoted_column_name} AS DOUBLE)) AS {safe_column_name}_mean", ) if self.config.profiling.include_field_stddev_value: metrics.append( - f"STDDEV(CAST({quoted_column_name} AS DOUBLE)) AS {safe_column_name}_stdev" + f"STDDEV(CAST({quoted_column_name} AS DOUBLE)) AS {safe_column_name}_stdev", ) if self.config.profiling.include_field_median_value: 
metrics.append( - f"MEDIAN({quoted_column_name}) AS {safe_column_name}_median" + f"MEDIAN({quoted_column_name}) AS {safe_column_name}_median", ) if self.config.profiling.include_field_quantiles: metrics.append( - f"PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY {quoted_column_name}) AS {safe_column_name}_25th_percentile" + f"PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY {quoted_column_name}) AS {safe_column_name}_25th_percentile", ) metrics.append( - f"PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY {quoted_column_name}) AS {safe_column_name}_75th_percentile" + f"PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY {quoted_column_name}) AS {safe_column_name}_75th_percentile", ) return metrics def _parse_profile_results( - self, results: List[Dict], columns: List[Tuple[str, str]] + self, + results: List[Dict], + columns: List[Tuple[str, str]], ) -> Dict: profile: Dict[str, Any] = {"column_stats": {}} result = results[0] if results else {} # We expect only one row of results diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_reporting.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_reporting.py index 9712d4ddc67998..64b1c7a6ab8a64 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_reporting.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_reporting.py @@ -10,7 +10,9 @@ @dataclass class DremioSourceReport( - SQLSourceReport, StaleEntityRemovalSourceReport, IngestionStageReport + SQLSourceReport, + StaleEntityRemovalSourceReport, + IngestionStageReport, ): num_containers_failed: int = 0 num_datasets_failed: int = 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py index 6d34e86be6282e..2d20c15d5f6f0a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_source.py @@ -205,7 +205,7 @@ def _build_source_map(self) -> Dict[str, Dict]: datahub_source_type = ( DremioToDataHubSourceTypeMapping.get_datahub_source_type( - source_type + source_type, ) ) @@ -233,11 +233,12 @@ def _build_source_map(self) -> Dict[str, Dict]: except Exception as exc: logger.info( f"Source {source_type} is not a standard Dremio source type. " - f"Adding source_type {source_type} to mapping as database. Error: {exc}" + f"Adding source_type {source_type} to mapping as database. Error: {exc}", ) DremioToDataHubSourceTypeMapping.add_mapping( - source_type, source_name + source_type, + source_name, ) dremio_source_type = ( DremioToDataHubSourceTypeMapping.get_category(source_type) @@ -251,10 +252,10 @@ def _build_source_map(self) -> Dict[str, Dict]: else: logger.error( - f'Source "{source.container_name}" is broken. Containers will not be created for source.' + f'Source "{source.container_name}" is broken. Containers will not be created for source.', ) logger.error( - f'No new cross-platform lineage will be emitted for source "{source.container_name}".' 
+ f'No new cross-platform lineage will be emitted for source "{source.container_name}".', ) logger.error("Fix this source in Dremio to fix this issue.") @@ -264,7 +265,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -281,7 +284,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: try: yield from self.process_container(container) logger.info( - f"Dremio container {container.container_name} emitted successfully" + f"Dremio container {container.container_name} emitted successfully", ) except Exception as exc: self.report.num_containers_failed += 1 # Increment failed containers @@ -298,7 +301,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: try: yield from self.process_dataset(dataset_info) logger.info( - f"Dremio dataset {'.'.join(dataset_info.path)}.{dataset_info.resource_name} emitted successfully" + f"Dremio dataset {'.'.join(dataset_info.path)}.{dataset_info.resource_name} emitted successfully", ) except Exception as exc: self.report.num_datasets_failed += 1 # Increment failed datasets @@ -333,7 +336,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # Profiling if self.config.is_profiling_enabled(): with ThreadPoolExecutor( - max_workers=self.config.profiling.max_workers + max_workers=self.config.profiling.max_workers, ) as executor: future_to_dataset = { executor.submit(self.generate_profiles, dataset): dataset @@ -355,21 +358,25 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) def process_container( - self, container_info: DremioContainer + self, + container_info: DremioContainer, ) -> Iterable[MetadataWorkUnit]: """ Process a Dremio container and generate metadata workunits. """ container_urn = self.dremio_aspects.get_container_urn( - path=container_info.path, name=container_info.container_name + path=container_info.path, + name=container_info.container_name, ) yield from self.dremio_aspects.populate_container_mcp( - container_urn, container_info + container_urn, + container_info, ) def process_dataset( - self, dataset_info: DremioDataset + self, + dataset_info: DremioDataset, ) -> Iterable[MetadataWorkUnit]: """ Process a Dremio dataset and generate metadata workunits. @@ -392,12 +399,14 @@ def process_dataset( ) for dremio_mcp in self.dremio_aspects.populate_dataset_mcp( - dataset_urn, dataset_info + dataset_urn, + dataset_info, ): yield dremio_mcp # Check if the emitted aspect is SchemaMetadataClass if isinstance( - dremio_mcp.metadata, MetadataChangeProposalWrapper + dremio_mcp.metadata, + MetadataChangeProposalWrapper, ) and isinstance(dremio_mcp.metadata.aspect, SchemaMetadataClass): self.sql_parsing_aggregator.register_schema( urn=dataset_urn, @@ -438,8 +447,8 @@ def process_dataset( UpstreamClass( dataset=upstream_urn, type=DatasetLineageTypeClass.COPY, - ) - ] + ), + ], ) mcp = MetadataChangeProposalWrapper( entityUrn=dataset_urn, @@ -453,7 +462,8 @@ def process_dataset( ) def process_glossary_term( - self, glossary_term_info: DremioGlossaryTerm + self, + glossary_term_info: DremioGlossaryTerm, ) -> Iterable[MetadataWorkUnit]: """ Process a Dremio container and generate metadata workunits. 
@@ -462,7 +472,8 @@ def process_glossary_term( yield from self.dremio_aspects.populate_glossary_term_mcp(glossary_term_info) def generate_profiles( - self, dataset_info: DremioDataset + self, + dataset_info: DremioDataset, ) -> Iterable[MetadataWorkUnit]: schema_str = ".".join(dataset_info.path) dataset_name = f"{schema_str}.{dataset_info.resource_name}".lower() @@ -476,7 +487,9 @@ def generate_profiles( yield from self.profiler.get_workunits(dataset_info, dataset_urn) def generate_view_lineage( - self, dataset_urn: str, parents: List[str] + self, + dataset_urn: str, + parents: List[str], ) -> Iterable[MetadataWorkUnit]: """ Generate lineage information for views. @@ -498,7 +511,7 @@ def generate_view_lineage( type=DatasetLineageTypeClass.VIEW, ) for upstream_urn in upstream_urns - ] + ], ) mcp = MetadataChangeProposalWrapper( entityType="dataset", @@ -580,7 +593,7 @@ def process_query(self, query: DremioQuery) -> None: timestamp=query.submitted_ts, user=CorpUserUrn(username=query.username), default_db=self.default_db, - ) + ), ) def _map_dremio_dataset_to_urn( @@ -601,7 +614,8 @@ def _map_dremio_dataset_to_urn( return None platform_instance = mapping.get( - "platform_instance", self.config.platform_instance + "platform_instance", + self.config.platform_instance, ) env = mapping.get("env", self.config.env) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dynamodb/data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/dynamodb/data_reader.py index 13804d703ac93c..2a1257b7c1a132 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dynamodb/data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dynamodb/data_reader.py @@ -24,7 +24,10 @@ def __init__(self, client: "DynamoDBClient") -> None: self.client = client def get_sample_data_for_table( - self, table_id: List[str], sample_size: int, **kwargs: Any + self, + table_id: List[str], + sample_size: int, + **kwargs: Any, ) -> Dict[str, list]: """ For dynamoDB, table_id should be in formation ( table-name ) or (region, table-name ) @@ -50,7 +53,7 @@ def get_sample_data_for_table( # for complex data types - L (list) or M (map) - we will recursively process the value into json-like form for attribute_name, attribute_value in item.items(): column_values[attribute_name].append( - self._get_value(attribute_value) + self._get_value(attribute_value), ) # Note: Consider including items configured via `include_table_item` in sample data ? 
diff --git a/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py b/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py index cb3f0dd9cf29f4..ac1644f2a137b5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py @@ -199,7 +199,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -217,7 +219,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: continue table_wu_generator = self._process_table( - region, dynamodb_client, table_name, dataset_name + region, + dynamodb_client, + table_name, + dataset_name, ) yield from classification_workunit_processor( table_wu_generator, @@ -252,7 +257,9 @@ def _process_table( primary_key_dict = self.extract_primary_key_from_key_schema(table_info) table_schema = self.construct_schema_from_dynamodb( - dynamodb_client, region, table_name + dynamodb_client, + region, + table_name, ) schema_metadata = self.construct_schema_metadata( @@ -361,20 +368,20 @@ def include_table_item_to_schema( assert isinstance(primary_key_list, List) if len(primary_key_list) > MAX_PRIMARY_KEYS_SIZE: logger.info( - f"the provided primary keys list size exceeded the max size for table {dataset_name}, we'll only process the first {MAX_PRIMARY_KEYS_SIZE} items" + f"the provided primary keys list size exceeded the max size for table {dataset_name}, we'll only process the first {MAX_PRIMARY_KEYS_SIZE} items", ) primary_key_list = primary_key_list[0:MAX_PRIMARY_KEYS_SIZE] items: List[Dict[str, "AttributeValueTypeDef"]] = [] response = dynamodb_client.batch_get_item( - RequestItems={table_name: {"Keys": primary_key_list}} + RequestItems={table_name: {"Keys": primary_key_list}}, ).get("Responses") if response is None: logger.error( - f"failed to retrieve item from table {table_name} by the given key {primary_key_list}" + f"failed to retrieve item from table {table_name} by the given key {primary_key_list}", ) return logger.debug( - f"successfully retrieved {len(primary_key_list)} items based on supplied primary key list" + f"successfully retrieved {len(primary_key_list)} items based on supplied primary key list", ) items = response.get(table_name) or [] @@ -411,7 +418,7 @@ def append_schema( # Handle nested maps by recursive calls if data_type == "M": logger.debug( - f"expanding nested fields for map, current_field_path: {current_field_path}" + f"expanding nested fields for map, current_field_path: {current_field_path}", ) self.append_schema(schema, attribute_value, current_field_path) @@ -435,7 +442,7 @@ def append_schema( ) # Mark as nullable if null encountered types = schema[current_field_path]["types"] logger.debug( - f"append schema with field_path: {current_field_path} and type: {types}" + f"append schema with field_path: {current_field_path} and type: {types}", ) def construct_schema_metadata( @@ -510,7 +517,8 @@ def construct_schema_metadata( return schema_metadata def extract_primary_key_from_key_schema( - self, table_info: "TableDescriptionTypeDef" + self, + table_info: "TableDescriptionTypeDef", ) -> Dict[str, str]: key_schema = table_info.get("KeySchema") primary_key_dict: Dict[str, str] = {} @@ -524,7 +532,7 @@ def extract_primary_key_from_key_schema( def get_native_type(self, attribute_type: Union[type, str], table_name: str) -> 
str: assert isinstance(attribute_type, str) type_string: Optional[str] = _attribute_type_to_native_type_mapping.get( - attribute_type + attribute_type, ) if type_string is None: self.report.report_warning( @@ -536,11 +544,13 @@ def get_native_type(self, attribute_type: Union[type, str], table_name: str) -> return type_string def get_field_type( - self, attribute_type: Union[type, str], table_name: str + self, + attribute_type: Union[type, str], + table_name: str, ) -> SchemaFieldDataType: assert isinstance(attribute_type, str) type_class: Optional[type] = _attribute_type_to_field_type_mapping.get( - attribute_type + attribute_type, ) if type_class is None: @@ -556,13 +566,15 @@ def get_report(self) -> DynamoDBSourceReport: return self.report def _get_domain_wu( - self, dataset_name: str, entity_urn: str + self, + dataset_name: str, + entity_urn: str, ) -> Iterable[MetadataWorkUnit]: domain_urn = None for domain, pattern in self.config.domain.items(): if pattern.allowed(dataset_name): domain_urn = make_domain_urn( - self.domain_registry.get_domain_urn(domain) + self.domain_registry.get_domain_urn(domain), ) break diff --git a/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py b/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py index ce1c60dcafdd46..8f0d7d756ec92b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py +++ b/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py @@ -113,12 +113,12 @@ class ElasticToSchemaFieldConverter: def get_column_type(elastic_column_type: str) -> SchemaFieldDataType: type_class: Optional[Type] = ( ElasticToSchemaFieldConverter._field_type_to_schema_field_type.get( - elastic_column_type + elastic_column_type, ) ) if type_class is None: logger.warning( - f"Cannot map {elastic_column_type!r} to SchemaFieldDataType, using NullTypeClass." + f"Cannot map {elastic_column_type!r} to SchemaFieldDataType, using NullTypeClass.", ) type_class = NullTypeClass @@ -131,7 +131,8 @@ def _get_cur_field_path(self) -> str: return ".".join(self._prefix_name_stack) def _get_schema_fields( - self, elastic_schema_dict: Dict[str, Any] + self, + elastic_schema_dict: Dict[str, Any], ) -> Generator[SchemaField, None, None]: # append each schema field (sort so output is consistent) PROPERTIES: str = "properties" @@ -168,19 +169,20 @@ def _get_schema_fields( # Unexpected! Log a warning. logger.warning( f"Elastic schema does not have either 'type' or 'properties'!" - f" Schema={json.dumps(elastic_schema_dict)}" + f" Schema={json.dumps(elastic_schema_dict)}", ) continue @classmethod def get_schema_fields( - cls, elastic_mappings: Dict[str, Any] + cls, + elastic_mappings: Dict[str, Any], ) -> Generator[SchemaField, None, None]: converter = cls() properties = elastic_mappings.get("properties") if not properties: logger.warning( - f"Missing 'properties' in elastic search mappings={json.dumps(elastic_mappings)}!" + f"Missing 'properties' in elastic search mappings={json.dumps(elastic_mappings)}!", ) return yield from converter._get_schema_fields(properties) @@ -235,19 +237,22 @@ def collapse_urn(urn: str, collapse_urns: CollapseUrns) -> str: platform_id=data_platform_urn.get_entity_id_as_string(), table_name=name, env=urn_obj.get_env(), - ) + ), ) class ElasticsearchSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): host: str = Field( - default="localhost:9200", description="The elastic search host URI." 
+ default="localhost:9200", + description="The elastic search host URI.", ) username: Optional[str] = Field( - default=None, description="The username credential." + default=None, + description="The username credential.", ) password: Optional[str] = Field( - default=None, description="The password credential." + default=None, + description="The password credential.", ) api_key: Optional[Union[Any, str]] = Field( default=None, @@ -255,15 +260,18 @@ class ElasticsearchSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): ) use_ssl: bool = Field( - default=False, description="Whether to use SSL for the connection or not." + default=False, + description="Whether to use SSL for the connection or not.", ) verify_certs: bool = Field( - default=False, description="Whether to verify SSL certificates." + default=False, + description="Whether to verify SSL certificates.", ) ca_certs: Optional[str] = Field( - default=None, description="Path to a certificate authority (CA) certificate." + default=None, + description="Path to a certificate authority (CA) certificate.", ) client_cert: Optional[str] = Field( @@ -277,7 +285,8 @@ class ElasticsearchSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): ) ssl_assert_hostname: bool = Field( - default=False, description="Use hostname verification if not False." + default=False, + description="Use hostname verification if not False.", ) ssl_assert_fingerprint: Optional[str] = Field( @@ -294,7 +303,8 @@ class ElasticsearchSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): description="regex patterns for indexes to filter in ingestion.", ) ingest_index_templates: bool = Field( - default=False, description="Ingests ES index templates if enabled." + default=False, + description="Ingests ES index templates if enabled.", ) index_template_pattern: AllowDenyPattern = Field( default=AllowDenyPattern(allow=[".*"], deny=["^_.*"]), @@ -314,7 +324,7 @@ class ElasticsearchSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) @validator("host") @@ -367,7 +377,9 @@ def __init__(self, config: ElasticsearchSourceConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: Dict[str, Any], ctx: PipelineContext + cls, + config_dict: Dict[str, Any], + ctx: PipelineContext, ) -> "ElasticsearchSource": config = ElasticsearchSourceConfig.parse_obj(config_dict) return cls(config, ctx) @@ -403,17 +415,20 @@ def _get_data_stream_index_count_mcps( platform_instance=self.source_config.platform_instance, ) dataset_urn = collapse_urn( - urn=dataset_urn, collapse_urns=self.source_config.collapse_urns + urn=dataset_urn, + collapse_urns=self.source_config.collapse_urns, ) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=DatasetPropertiesClass( - customProperties={"numPartitions": str(count)} + customProperties={"numPartitions": str(count)}, ), ) def _extract_mcps( - self, index: str, is_index: bool = True + self, + index: str, + is_index: bool = True, ) -> Iterable[MetadataChangeProposalWrapper]: logger.debug(f"index='{index}', is_index={is_index}") @@ -433,7 +448,8 @@ def _extract_mcps( raw_index = self.client.indices.get_template(name=index) raw_index_metadata = raw_index[index] collapsed_index_name = collapse_name( - name=index, collapse_urns=self.source_config.collapse_urns + name=index, + collapse_urns=self.source_config.collapse_urns, ) # 1. 
Construct and emit the schemaMetadata aspect @@ -442,7 +458,7 @@ def _extract_mcps( index_mappings_json_str: str = json.dumps(index_mappings) md5_hash = md5(index_mappings_json_str.encode()).hexdigest() schema_fields = list( - ElasticToSchemaFieldConverter.get_schema_fields(index_mappings) + ElasticToSchemaFieldConverter.get_schema_fields(index_mappings), ) if not schema_fields: return @@ -465,7 +481,8 @@ def _extract_mcps( env=self.source_config.env, ) dataset_urn = collapse_urn( - urn=dataset_urn, collapse_urns=self.source_config.collapse_urns + urn=dataset_urn, + collapse_urns=self.source_config.collapse_urns, ) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, @@ -491,8 +508,8 @@ def _extract_mcps( if not data_stream else DatasetSubTypes.ELASTIC_DATASTREAM ) - ) - ] + ), + ], ), ) @@ -509,7 +526,8 @@ def _extract_mcps( # 4.3 number_of_shards index_settings: Dict[str, Any] = raw_index_metadata.get("settings", {}).get( - "index", {} + "index", + {}, ) num_shards: str = index_settings.get("number_of_shards", "") if num_shards: @@ -531,7 +549,8 @@ def _extract_mcps( aspect=DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - self.platform, self.source_config.platform_instance + self.platform, + self.source_config.platform_instance, ), ), ) @@ -543,7 +562,7 @@ def _extract_mcps( "format": "json", "bytes": "b", "h": "index,docs.count,store.size", - } + }, ) if self.cat_response is None: return @@ -554,13 +573,14 @@ def _extract_mcps( ) profile_info_current = list( - filter(lambda x: x["index"] == collapsed_index_name, self.cat_response) + filter(lambda x: x["index"] == collapsed_index_name, self.cat_response), ) if len(profile_info_current) > 0: self.cat_response = list( filter( - lambda x: x["index"] != collapsed_index_name, self.cat_response - ) + lambda x: x["index"] != collapsed_index_name, + self.cat_response, + ), ) row_count = 0 size_in_bytes = 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/feast.py b/metadata-ingestion/src/datahub/ingestion/source/feast.py index 6330fe0291660d..27f9422e2fb236 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/feast.py +++ b/metadata-ingestion/src/datahub/ingestion/source/feast.py @@ -93,7 +93,8 @@ class FeastRepositorySourceConfig(ConfigModel): description="Path to the `feature_store.yaml` file used to configure the feature store", ) environment: str = Field( - default=DEFAULT_ENV, description="Environment to use when constructing URNs" + default=DEFAULT_ENV, + description="Environment to use when constructing URNs", ) # owner_mappings example: # This must be added to the recipe in order to extract owners, otherwise NO owners will be extracted @@ -102,7 +103,8 @@ class FeastRepositorySourceConfig(ConfigModel): # datahub_owner_urn: "urn:li:corpGroup:" # datahub_ownership_type: "BUSINESS_OWNER" owner_mappings: Optional[List[Dict[str, str]]] = Field( - default=None, description="Mapping of owner names to owner types" + default=None, + description="Mapping of owner names to owner types", ) enable_owner_extraction: bool = Field( default=False, @@ -148,7 +150,9 @@ def __init__(self, config: FeastRepositorySourceConfig, ctx: PipelineContext): ) def _get_field_type( - self, field_type: Union[ValueType, feast.types.FeastType], parent_name: str + self, + field_type: Union[ValueType, feast.types.FeastType], + parent_name: str, ) -> str: """ Maps types encountered in Feast to corresponding schema types. 
@@ -158,7 +162,8 @@ def _get_field_type( if ml_feature_data_type is None: self.report.report_warning( - parent_name, f"unable to map type {field_type} to metadata schema" + parent_name, + f"unable to map type {field_type} to metadata schema", ) ml_feature_data_type = MLFeatureDataType.UNKNOWN @@ -205,32 +210,34 @@ def _get_data_sources(self, feature_view: FeatureView) -> List[str]: if feature_view.batch_source is not None: batch_source_platform, batch_source_name = self._get_data_source_details( - feature_view.batch_source + feature_view.batch_source, ) sources.append( builder.make_dataset_urn( batch_source_platform, batch_source_name, self.source_config.environment, - ) + ), ) if feature_view.stream_source is not None: stream_source_platform, stream_source_name = self._get_data_source_details( - feature_view.stream_source + feature_view.stream_source, ) sources.append( builder.make_dataset_urn( stream_source_platform, stream_source_name, self.source_config.environment, - ) + ), ) return sources def _get_entity_workunit( - self, feature_view: FeatureView, entity: Entity + self, + feature_view: FeatureView, + entity: Entity, ) -> MetadataWorkUnit: """ Generate an MLPrimaryKey work unit for a Feast entity. @@ -253,7 +260,7 @@ def _get_entity_workunit( description=entity.description, dataType=self._get_field_type(entity.value_type, entity.name), sources=self._get_data_sources(feature_view), - ) + ), ) mce = MetadataChangeEvent(proposedSnapshot=entity_snapshot) @@ -284,7 +291,7 @@ def _get_feature_workunit( if feature_view.source_request_sources is not None: for request_source in feature_view.source_request_sources.values(): source_platform, source_name = self._get_data_source_details( - request_source + request_source, ) feature_sources.append( @@ -292,7 +299,7 @@ def _get_feature_workunit( source_platform, source_name, self.source_config.environment, - ) + ), ) if feature_view.source_feature_view_projections is not None: @@ -300,7 +307,7 @@ def _get_feature_workunit( feature_view_projection ) in feature_view.source_feature_view_projections.values(): feature_view_source = self.feature_store.get_feature_view( - feature_view_projection.name + feature_view_projection.name, ) feature_sources.extend(self._get_data_sources(feature_view_source)) @@ -310,7 +317,7 @@ def _get_feature_workunit( description=field.tags.get("description"), dataType=self._get_field_type(field.dtype, field.name), sources=feature_sources, - ) + ), ) mce = MetadataChangeEvent(proposedSnapshot=feature_snapshot) @@ -350,7 +357,7 @@ def _get_feature_view_workunit(self, feature_view: FeatureView) -> MetadataWorkU builder.make_ml_primary_key_urn(feature_view_name, entity_name) for entity_name in feature_view.entities ], - ) + ), ) mce = MetadataChangeEvent(proposedSnapshot=feature_view_snapshot) @@ -358,7 +365,8 @@ def _get_feature_view_workunit(self, feature_view: FeatureView) -> MetadataWorkU return MetadataWorkUnit(id=feature_view_name, mce=mce) def _get_on_demand_feature_view_workunit( - self, on_demand_feature_view: OnDemandFeatureView + self, + on_demand_feature_view: OnDemandFeatureView, ) -> MetadataWorkUnit: """ Generate an MLFeatureTable work unit for a Feast on-demand feature view. 
@@ -386,7 +394,7 @@ def _get_on_demand_feature_view_workunit( for feature in on_demand_feature_view.features ], mlPrimaryKeys=[], - ) + ), ) mce = MetadataChangeEvent(proposedSnapshot=on_demand_feature_view_snapshot) @@ -406,7 +414,7 @@ def _get_tags(self, obj: Union[Entity, FeatureView, FeastField]) -> list: if obj.tags.get("name"): tag_name: str = obj.tags["name"] tag_association = TagAssociationClass( - tag=builder.make_tag_urn(tag_name) + tag=builder.make_tag_urn(tag_name), ) global_tags_aspect = GlobalTagsClass(tags=[tag_association]) aspects.append(global_tags_aspect) @@ -441,7 +449,8 @@ def _create_owner_association(self, owner: str) -> Optional[OwnerClass]: for mapping in self.source_config.owner_mappings: if mapping["feast_owner_name"] == owner: ownership_type_class: str = mapping.get( - "datahub_ownership_type", "TECHNICAL_OWNER" + "datahub_ownership_type", + "TECHNICAL_OWNER", ) datahub_owner_urn = mapping.get("datahub_owner_urn") if datahub_owner_urn: diff --git a/metadata-ingestion/src/datahub/ingestion/source/file.py b/metadata-ingestion/src/datahub/ingestion/source/file.py index 0b1665d579191e..5d249109e6b187 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/file.py +++ b/metadata-ingestion/src/datahub/ingestion/source/file.py @@ -69,7 +69,7 @@ class FileSourceConfig(StatefulIngestionConfigBase): description=( "File path to folder or file to ingest, or URL to a remote file. " "If pointed to a folder, all files with extension {file_extension} (default json) within that folder will be processed." - ) + ), ) file_extension: str = Field( ".json", @@ -96,7 +96,9 @@ class FileSourceConfig(StatefulIngestionConfigBase): ) _filename_populates_path_if_present = pydantic_renamed_field( - "filename", "path", print_warning=False + "filename", + "path", + print_warning=False, ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None @@ -163,7 +165,7 @@ def compute_stats(self) -> None: ): current_files_bytes_read = int( (self.current_file_elements_read / self.current_file_num_elements) - * self.current_file_size + * self.current_file_size, ) total_bytes_read += current_files_bytes_read percentage_completion = ( @@ -179,7 +181,7 @@ def compute_stats(self) -> None: * (100 - percentage_completion) / percentage_completion ) - / 60 + / 60, ) self.percentage_completion = f"{percentage_completion:.2f}%" @@ -212,7 +214,7 @@ def get_filenames(self) -> Iterable[FileInfo]: fs = fs_class.create() for file_info in fs.list(path_str): if file_info.is_file and file_info.path.endswith( - self.config.file_extension + self.config.file_extension, ): yield file_info @@ -222,7 +224,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: partial(auto_workunit_reporter, self.report), auto_status_aspect if self.config.stateful_ingestion else None, StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -233,7 +237,8 @@ def get_workunits_internal( for i, obj in self.iterate_generic_file(f): id = f"{f.path}:{i}" if isinstance( - obj, (MetadataChangeProposalWrapper, MetadataChangeProposal) + obj, + (MetadataChangeProposalWrapper, MetadataChangeProposal), ): if ( self.config.aspect is not None @@ -329,7 +334,8 @@ def iterate_mce_file(self, path: str) -> Iterator[MetadataChangeEvent]: yield mce def iterate_generic_file( - self, file_status: FileInfo + self, + file_status: FileInfo, ) -> Iterator[ Tuple[ int, @@ -362,7 +368,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: 
basic_connectivity=CapabilityReport( capable=False, failure_reason=f"{config.path} doesn't appear to be a valid file or directory.", - ) + ), ) is_dir = os.path.isdir(config.path) failure_message = None @@ -377,12 +383,13 @@ def test_connection(config_dict: dict) -> TestConnectionReport: if failure_message: return TestConnectionReport( basic_connectivity=CapabilityReport( - capable=False, failure_reason=failure_message - ) + capable=False, + failure_reason=failure_message, + ), ) else: return TestConnectionReport( - basic_connectivity=CapabilityReport(capable=True) + basic_connectivity=CapabilityReport(capable=True), ) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py b/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py index 86826ae7bedc09..f9c8de7f3953c6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py @@ -98,7 +98,8 @@ class FivetranLogConfig(ConfigModel): description="If destination platform is 'bigquery', provide bigquery configuration.", ) _rename_destination_config = pydantic_renamed_field( - "destination_config", "snowflake_destination_config" + "destination_config", + "snowflake_destination_config", ) @root_validator(pre=True) @@ -107,16 +108,16 @@ def validate_destination_platfrom_and_config(cls, values: Dict) -> Dict: if destination_platform == "snowflake": if "snowflake_destination_config" not in values: raise ValueError( - "If destination platform is 'snowflake', user must provide snowflake destination configuration in the recipe." + "If destination platform is 'snowflake', user must provide snowflake destination configuration in the recipe.", ) elif destination_platform == "bigquery": if "bigquery_destination_config" not in values: raise ValueError( - "If destination platform is 'bigquery', user must provide bigquery destination configuration in the recipe." + "If destination platform is 'bigquery', user must provide bigquery destination configuration in the recipe.", ) else: raise ValueError( - f"Destination platform '{destination_platform}' is not yet supported." + f"Destination platform '{destination_platform}' is not yet supported.", ) return values @@ -124,13 +125,13 @@ def validate_destination_platfrom_and_config(cls, values: Dict) -> Dict: @dataclasses.dataclass class MetadataExtractionPerfReport(Report): connectors_metadata_extraction_sec: PerfTimer = dataclasses.field( - default_factory=PerfTimer + default_factory=PerfTimer, ) connectors_lineage_extraction_sec: PerfTimer = dataclasses.field( - default_factory=PerfTimer + default_factory=PerfTimer, ) connectors_jobs_extraction_sec: PerfTimer = dataclasses.field( - default_factory=PerfTimer + default_factory=PerfTimer, ) @@ -139,7 +140,7 @@ class FivetranSourceReport(StaleEntityRemovalSourceReport): connectors_scanned: int = 0 filtered_connectors: LossyList[str] = dataclasses.field(default_factory=LossyList) metadata_extraction_perf: MetadataExtractionPerfReport = dataclasses.field( - default_factory=MetadataExtractionPerfReport + default_factory=MetadataExtractionPerfReport, ) def report_connectors_scanned(self, count: int = 1) -> None: @@ -190,7 +191,8 @@ class FivetranSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( - default=None, description="Airbyte Stateful Ingestion Config." 
+ default=None, + description="Airbyte Stateful Ingestion Config.", ) # Fivetran connector all sources to platform instance mapping @@ -218,7 +220,8 @@ def compat_sources_to_database(cls, values: Dict) -> Dict: for key, value in mapping.items(): values["sources_to_platform_instance"].setdefault(key, {}) values["sources_to_platform_instance"][key].setdefault( - "database", value + "database", + value, ) return values diff --git a/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py b/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py index d8ebbe5b63d1ae..b4ae94d6618413 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py +++ b/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py @@ -83,7 +83,8 @@ def _extend_lineage(self, connector: Connector, datajob: DataJob) -> Dict[str, s # Get platform details for connector source source_details = self.config.sources_to_platform_instance.get( - connector.connector_id, PlatformDetail() + connector.connector_id, + PlatformDetail(), ) if source_details.platform is None: if connector.connector_type in KNOWN_DATA_PLATFORM_MAPPING: @@ -101,7 +102,8 @@ def _extend_lineage(self, connector: Connector, datajob: DataJob) -> Dict[str, s # Get platform details for destination destination_details = self.config.destination_to_platform_instance.get( - connector.destination_id, PlatformDetail() + connector.destination_id, + PlatformDetail(), ) if destination_details.platform is None: destination_details.platform = ( @@ -149,7 +151,7 @@ def _extend_lineage(self, connector: Connector, datajob: DataJob) -> Dict[str, s builder.make_schema_field_urn( str(input_dataset_urn), column_lineage.source_column, - ) + ), ] if input_dataset_urn else [] @@ -160,12 +162,12 @@ def _extend_lineage(self, connector: Connector, datajob: DataJob) -> Dict[str, s builder.make_schema_field_urn( str(output_dataset_urn), column_lineage.destination_column, - ) + ), ] if output_dataset_urn else [] ), - ) + ), ) datajob.inlets.extend(input_dataset_urn_list) @@ -237,7 +239,9 @@ def _generate_dpi_from_job(self, job: Job, datajob: DataJob) -> DataProcessInsta ) def _get_dpi_workunits( - self, job: Job, dpi: DataProcessInstance + self, + job: Job, + dpi: DataProcessInstance, ) -> Iterable[MetadataWorkUnit]: status_result_map: Dict[str, InstanceRunResult] = { Constant.SUCCESSFUL: InstanceRunResult.SUCCESS, @@ -247,13 +251,14 @@ def _get_dpi_workunits( if job.status not in status_result_map: logger.debug( f"Status should be either SUCCESSFUL, FAILURE_WITH_TASK or CANCELED and it was " - f"{job.status}" + f"{job.status}", ) return result = status_result_map[job.status] start_timestamp_millis = job.start_time * 1000 for mcp in dpi.generate_mcp( - created_ts_millis=start_timestamp_millis, materialize_iolets=False + created_ts_millis=start_timestamp_millis, + materialize_iolets=False, ): yield mcp.as_workunit() for mcp in dpi.start_event_mcp(start_timestamp_millis): @@ -266,7 +271,8 @@ def _get_dpi_workunits( yield mcp.as_workunit() def _get_connector_workunits( - self, connector: Connector + self, + connector: Connector, ) -> Iterable[MetadataWorkUnit]: self.report.report_connectors_scanned() # Create dataflow entity with same name as connector name @@ -295,7 +301,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] diff --git 
a/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran_log_api.py b/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran_log_api.py index 529002270cdd9c..5e196280f48d29 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran_log_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran_log_api.py @@ -52,7 +52,7 @@ def _initialize_fivetran_variables( engine.execute( fivetran_log_query.use_database( snowflake_destination_config.database, - ) + ), ) fivetran_log_query.set_db( snowflake_destination_config.log_schema, @@ -70,7 +70,7 @@ def _initialize_fivetran_variables( fivetran_log_database = bigquery_destination_config.dataset else: raise ConfigurationError( - f"Destination platform '{destination_platform}' is not yet supported." + f"Destination platform '{destination_platform}' is not yet supported.", ) return ( engine, @@ -82,14 +82,16 @@ def _query(self, query: str) -> List[Dict]: # Automatically transpile snowflake query syntax to the target dialect. if self.fivetran_log_config.destination_platform != "snowflake": query = sqlglot.parse_one(query, dialect="snowflake").sql( - dialect=self.fivetran_log_config.destination_platform, pretty=True + dialect=self.fivetran_log_config.destination_platform, + pretty=True, ) logger.info(f"Executing query: {query}") resp = self.engine.execute(query) return [row for row in resp] def _get_column_lineage_metadata( - self, connector_ids: List[str] + self, + connector_ids: List[str], ) -> Dict[Tuple[str, str], List]: """ Returns dict of column lineage metadata with key as (, ) @@ -97,8 +99,8 @@ def _get_column_lineage_metadata( all_column_lineage = defaultdict(list) column_lineage_result = self._query( self.fivetran_log_query.get_column_lineage_query( - connector_ids=connector_ids - ) + connector_ids=connector_ids, + ), ) for column_lineage in column_lineage_result: key = ( @@ -114,7 +116,9 @@ def _get_table_lineage_metadata(self, connector_ids: List[str]) -> Dict[str, Lis """ connectors_table_lineage_metadata = defaultdict(list) table_lineage_result = self._query( - self.fivetran_log_query.get_table_lineage_query(connector_ids=connector_ids) + self.fivetran_log_query.get_table_lineage_query( + connector_ids=connector_ids, + ), ) for table_lineage in table_lineage_result: connectors_table_lineage_metadata[ @@ -136,7 +140,7 @@ def _extract_connector_lineage( ( table_lineage[Constant.SOURCE_TABLE_ID], table_lineage[Constant.DESTINATION_TABLE_ID], - ) + ), ) column_lineage_list: List[ColumnLineage] = [] if column_lineage_result: @@ -155,13 +159,15 @@ def _extract_connector_lineage( source_table=f"{table_lineage[Constant.SOURCE_SCHEMA_NAME]}.{table_lineage[Constant.SOURCE_TABLE_NAME]}", destination_table=f"{table_lineage[Constant.DESTINATION_SCHEMA_NAME]}.{table_lineage[Constant.DESTINATION_TABLE_NAME]}", column_lineage=column_lineage_list, - ) + ), ) return table_lineage_list def _get_all_connector_sync_logs( - self, syncs_interval: int, connector_ids: List[str] + self, + syncs_interval: int, + connector_ids: List[str], ) -> Dict[str, Dict[str, Dict[str, Tuple[float, Optional[str]]]]]: sync_logs: Dict[str, Dict[str, Dict[str, Tuple[float, Optional[str]]]]] = {} @@ -185,7 +191,8 @@ def _get_all_connector_sync_logs( return sync_logs def _get_jobs_list( - self, connector_sync_log: Optional[Dict[str, Dict]] + self, + connector_sync_log: Optional[Dict[str, Dict]], ) -> List[Job]: jobs: List[Job] = [] if connector_sync_log is None: @@ -211,7 +218,7 @@ def _get_jobs_list( 
start_time=round(connector_sync_log[sync_id]["sync_start"][0]), end_time=round(connector_sync_log[sync_id]["sync_end"][0]), status=message_data[Constant.STATUS], - ) + ), ) return jobs @@ -238,11 +245,14 @@ def _fill_connectors_lineage(self, connectors: List[Connector]) -> None: ) def _fill_connectors_jobs( - self, connectors: List[Connector], syncs_interval: int + self, + connectors: List[Connector], + syncs_interval: int, ) -> None: connector_ids = [connector.connector_id for connector in connectors] sync_logs = self._get_all_connector_sync_logs( - syncs_interval, connector_ids=connector_ids + syncs_interval, + connector_ids=connector_ids, ) for connector in connectors: connector.jobs = self._get_jobs_list(sync_logs.get(connector.connector_id)) @@ -263,14 +273,14 @@ def get_allowed_connectors_list( connector_name = connector[Constant.CONNECTOR_NAME] if not connector_patterns.allowed(connector_name): report.report_connectors_dropped( - f"{connector_name} (connector_id: {connector_id}, dropped due to filter pattern)" + f"{connector_name} (connector_id: {connector_id}, dropped due to filter pattern)", ) continue if not destination_patterns.allowed( - destination_id := connector[Constant.DESTINATION_ID] + destination_id := connector[Constant.DESTINATION_ID], ): report.report_connectors_dropped( - f"{connector_name} (connector_id: {connector_id}, destination_id: {destination_id})" + f"{connector_name} (connector_id: {connector_id}, destination_id: {destination_id})", ) continue connectors.append( @@ -284,7 +294,7 @@ def get_allowed_connectors_list( user_id=connector[Constant.CONNECTING_USER_ID], lineage=[], # filled later jobs=[], # filled later - ) + ), ) if not connectors: diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py index 5c03b873c5505e..d62bb671dc44b7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py @@ -113,7 +113,10 @@ def __init__(self, ctx: PipelineContext, config: DataHubGcSourceConfig): self.report.event_not_produced_warn = False self.graph = ctx.require_graph("The DataHubGc source") self.dataprocess_cleanup = DataProcessCleanup( - ctx, self.config.dataprocess_cleanup, self.report, self.config.dry_run + ctx, + self.config.dataprocess_cleanup, + self.report, + self.config.dry_run, ) self.soft_deleted_entities_cleanup = SoftDeletedEntitiesCleanup( ctx, @@ -157,7 +160,8 @@ def get_workunits_internal( self.soft_deleted_entities_cleanup.cleanup_soft_deleted_entities() except Exception as e: self.report.failure( - "While trying to cleanup soft deleted entities ", exc=e + "While trying to cleanup soft deleted entities ", + exc=e, ) if self.config.dataprocess_cleanup.enabled: try: @@ -176,28 +180,39 @@ def get_workunits_internal( def truncate_indices(self) -> None: self._truncate_timeseries_helper(aspect_name="operation", entity_type="dataset") self._truncate_timeseries_helper( - aspect_name="datasetusagestatistics", entity_type="dataset" + aspect_name="datasetusagestatistics", + entity_type="dataset", ) self._truncate_timeseries_helper( - aspect_name="chartUsageStatistics", entity_type="chart" + aspect_name="chartUsageStatistics", + entity_type="chart", ) self._truncate_timeseries_helper( - aspect_name="dashboardUsageStatistics", entity_type="dashboard" + aspect_name="dashboardUsageStatistics", + entity_type="dashboard", ) self._truncate_timeseries_helper( - aspect_name="queryusagestatistics", 
entity_type="query" + aspect_name="queryusagestatistics", + entity_type="query", ) def _truncate_timeseries_helper(self, aspect_name: str, entity_type: str) -> None: self._truncate_timeseries_with_watch_optional( - aspect_name=aspect_name, entity_type=entity_type, watch=False + aspect_name=aspect_name, + entity_type=entity_type, + watch=False, ) self._truncate_timeseries_with_watch_optional( - aspect_name=aspect_name, entity_type=entity_type, watch=True + aspect_name=aspect_name, + entity_type=entity_type, + watch=True, ) def _truncate_timeseries_with_watch_optional( - self, aspect_name: str, entity_type: str, watch: bool + self, + aspect_name: str, + entity_type: str, + watch: bool, ) -> None: graph = self.graph assert graph is not None @@ -238,7 +253,7 @@ def _truncate_timeseries_with_watch_optional( def x_days_ago_millis(self, days: int) -> int: x_days_ago_datetime = datetime.datetime.now( - datetime.timezone.utc + datetime.timezone.utc, ) - datetime.timedelta(days=days) return int(x_days_ago_datetime.timestamp() * 1000) @@ -255,7 +270,7 @@ def truncate_timeseries_util( gms_url = graph._gms_server if not dry_run: logger.info( - f"Going to truncate timeseries for {aspect} for {gms_url} older than {days_ago} days" + f"Going to truncate timeseries for {aspect} for {gms_url} older than {days_ago} days", ) days_ago_millis = self.x_days_ago_millis(days_ago) url = f"{gms_url}/operations?action=truncateTimeseriesAspect" @@ -342,9 +357,9 @@ def _get_expired_tokens(self) -> dict: "field": "expiresAt", "values": [str(int(time.time() * 1000))], "condition": "LESS_THAN", - } + }, ], - } + }, }, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/dataprocess_cleanup.py b/metadata-ingestion/src/datahub/ingestion/source/gc/dataprocess_cleanup.py index 64c1a0ad0bfbad..2767bcc60d0ad4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/dataprocess_cleanup.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/dataprocess_cleanup.py @@ -99,7 +99,8 @@ class DataProcessCleanupConfig(ConfigModel): enabled: bool = Field( - default=True, description="Whether to do data process cleanup." 
+ default=True, + description="Whether to do data process cleanup.", ) retention_days: Optional[int] = Field( 10, @@ -117,11 +118,13 @@ class DataProcessCleanupConfig(ConfigModel): ) delete_empty_data_jobs: bool = Field( - False, description="Whether to delete Data Jobs without runs" + False, + description="Whether to delete Data Jobs without runs", ) delete_empty_data_flows: bool = Field( - False, description="Whether to delete Data Flows without runs" + False, + description="Whether to delete Data Flows without runs", ) hard_delete_entities: bool = Field( @@ -168,7 +171,7 @@ class DataProcessCleanupReport(SourceReport): num_aspects_removed: int = 0 num_aspect_removed_by_type: TopKDict[str, int] = field(default_factory=TopKDict) sample_soft_deleted_aspects_by_type: TopKDict[str, LossyList[str]] = field( - default_factory=TopKDict + default_factory=TopKDict, ) num_data_flows_found: int = 0 num_data_jobs_found: int = 0 @@ -238,13 +241,17 @@ def fetch_dpis(self, job_urn: str, batch_size: int) -> List[dict]: break except Exception as e: self.report.failure( - f"Exception while fetching DPIs for job {job_urn}:", exc=e + f"Exception while fetching DPIs for job {job_urn}:", + exc=e, ) break return dpis def keep_last_n_dpi( - self, dpis: List[Dict], job: DataJobEntity, executor: ThreadPoolExecutor + self, + dpis: List[Dict], + job: DataJobEntity, + executor: ThreadPoolExecutor, ) -> None: if not self.config.keep_last_n: return @@ -254,7 +261,9 @@ def keep_last_n_dpi( futures = {} for dpi in dpis[self.config.keep_last_n :]: future = executor.submit( - self.delete_entity, dpi["urn"], "dataprocessInstance" + self.delete_entity, + dpi["urn"], + "dataprocessInstance", ) futures[future] = dpi @@ -265,7 +274,8 @@ def keep_last_n_dpi( futures[future]["deleted"] = True except Exception as e: self.report.report_failure( - f"Exception while deleting DPI: {e}", exc=e + f"Exception while deleting DPI: {e}", + exc=e, ) if ( deleted_count_last_n % self.config.batch_size == 0 @@ -292,7 +302,7 @@ def delete_entity(self, urn: str, type: str) -> None: if self.dry_run: logger.info( - f"Dry run is on otherwise it would have deleted {urn} with hard deletion is {self.config.hard_delete_entities}" + f"Dry run is on otherwise it would have deleted {urn} with hard deletion is {self.config.hard_delete_entities}", ) return @@ -318,12 +328,18 @@ def delete_dpi_from_datajobs(self, job: DataJobEntity) -> None: job.total_runs = len( list( - filter(lambda dpi: "deleted" not in dpi or not dpi.get("deleted"), dpis) - ) + filter( + lambda dpi: "deleted" not in dpi or not dpi.get("deleted"), + dpis, + ), + ), ) def remove_old_dpis( - self, dpis: List[Dict], job: DataJobEntity, executor: ThreadPoolExecutor + self, + dpis: List[Dict], + job: DataJobEntity, + executor: ThreadPoolExecutor, ) -> None: if self.config.retention_days is None: return @@ -345,7 +361,9 @@ def remove_old_dpis( or dpi["created"]["time"] < retention_time * 1000 ): future = executor.submit( - self.delete_entity, dpi["urn"], "dataprocessInstance" + self.delete_entity, + dpi["urn"], + "dataprocessInstance", ) futures[future] = dpi @@ -362,7 +380,7 @@ def remove_old_dpis( and deleted_count_retention > 0 ): logger.info( - f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention" + f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention", ) if self.config.delay: @@ -371,7 +389,7 @@ def remove_old_dpis( if deleted_count_retention > 0: logger.info( - f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention" + f"Deleted 
{deleted_count_retention} DPIs from {job.urn} due to retention", ) else: logger.debug(f"No DPIs to delete from {job.urn} due to retention") @@ -395,7 +413,8 @@ def get_data_flows(self) -> Iterable[DataFlowEntity]: ) except Exception as e: self.report.failure( - f"While trying to get dataflows with {scroll_id}", exc=e + f"While trying to get dataflows with {scroll_id}", + exc=e, ) break @@ -447,7 +466,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) except Exception as e: self.report.failure( - f"While trying to get data jobs with {scroll_id}", exc=e + f"While trying to get data jobs with {scroll_id}", + exc=e, ) break scrollAcrossEntities = result.get("scrollAcrossEntities") @@ -472,14 +492,15 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.delete_dpi_from_datajobs(datajob_entity) except Exception as e: self.report.failure( - f"While trying to delete {datajob_entity} ", exc=e + f"While trying to delete {datajob_entity} ", + exc=e, ) if ( datajob_entity.total_runs == 0 and self.config.delete_empty_data_jobs ): logger.info( - f"Deleting datajob {datajob_entity.urn} because there are no runs" + f"Deleting datajob {datajob_entity.urn} because there are no runs", ) self.delete_entity(datajob_entity.urn, "dataJob") deleted_jobs += 1 @@ -501,7 +522,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for key in dataFlows.keys(): if not dataJobs.get(key) or len(dataJobs[key]) == 0: logger.info( - f"Deleting dataflow {key} because there are not datajobs" + f"Deleting dataflow {key} because there are not datajobs", ) self.delete_entity(key, "dataFlow") deleted_data_flows += 1 diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py b/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py index c1763b16f3670f..3fc863034dc205 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py @@ -49,7 +49,8 @@ class DatahubExecutionRequestCleanupConfig(ConfigModel): ) limit_entities_delete: Optional[int] = Field( - 10000, description="Max number of execution requests to hard delete." + 10000, + description="Max number of execution requests to hard delete.", ) max_read_errors: int = Field( @@ -130,7 +131,8 @@ def _to_cleanup_record(self, entry: Dict) -> CleanupRecord: ) def _scroll_execution_requests( - self, overrides: Dict[str, Any] = {} + self, + overrides: Dict[str, Any] = {}, ) -> Iterator[CleanupRecord]: headers: Dict[str, Any] = { "Accept": "application/json", @@ -196,7 +198,7 @@ def _scroll_garbage_records(self): # Always delete corrupted records if not key: logger.warning( - f"ergc({self.instance_id}): will delete corrupted entry with missing source key: {entry}" + f"ergc({self.instance_id}): will delete corrupted entry with missing source key: {entry}", ) yield entry continue @@ -238,7 +240,7 @@ def _scroll_garbage_records(self): f"source count: {state[key]['count']}; " f"source cutoff: {state[key]['cutoffTimestamp']}; " f"record timestamp: {entry.requested_at}." - ) + ), ) yield entry @@ -275,7 +277,7 @@ def _reached_delete_limit(self) -> bool: and self.report.ergc_records_deleted >= self.config.limit_entities_delete ): logger.info( - f"ergc({self.instance_id}): max delete limit reached: {self.config.limit_entities_delete}." 
+ f"ergc({self.instance_id}): max delete limit reached: {self.config.limit_entities_delete}.", ) self.report.ergc_delete_limit_reached = True return True @@ -284,7 +286,7 @@ def _reached_delete_limit(self) -> bool: def run(self) -> None: if not self.config.enabled: logger.info( - f"ergc({self.instance_id}): ExecutionRequest cleaner is disabled." + f"ergc({self.instance_id}): ExecutionRequest cleaner is disabled.", ) return self.report.ergc_start_time = datetime.datetime.now() @@ -295,7 +297,7 @@ def run(self) -> None: f"max days: {self.config.keep_history_max_days}, " f"min records: {self.config.keep_history_min_count}, " f"max records: {self.config.keep_history_max_count}." - ) + ), ) for entry in self._scroll_garbage_records(): @@ -305,5 +307,5 @@ def run(self) -> None: self.report.ergc_end_time = datetime.datetime.now() logger.info( - f"ergc({self.instance_id}): Finished cleanup of ExecutionRequest records." + f"ergc({self.instance_id}): Finished cleanup of ExecutionRequest records.", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/soft_deleted_entity_cleanup.py b/metadata-ingestion/src/datahub/ingestion/source/gc/soft_deleted_entity_cleanup.py index ffcd9218a2103c..d08eaf45edbf86 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/soft_deleted_entity_cleanup.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/soft_deleted_entity_cleanup.py @@ -41,7 +41,8 @@ class SoftDeletedEntitiesCleanupConfig(ConfigModel): enabled: bool = Field( - default=True, description="Whether to do soft deletion cleanup." + default=True, + description="Whether to do soft deletion cleanup.", ) retention_days: int = Field( 10, @@ -84,11 +85,13 @@ class SoftDeletedEntitiesCleanupConfig(ConfigModel): ) limit_entities_delete: Optional[int] = Field( - 25000, description="Max number of entities to delete." + 25000, + description="Max number of entities to delete.", ) futures_max_at_time: int = Field( - 1000, description="Max number of futures to have at a time." 
+ 1000, + description="Max number of futures to have at a time.", ) runtime_limit_seconds: int = Field( @@ -107,7 +110,7 @@ class SoftDeletedEntitiesReport(SourceReport): num_hard_deleted: int = 0 num_hard_deleted_by_type: TopKDict[str, int] = field(default_factory=TopKDict) sample_hard_deleted_aspects_by_type: TopKDict[str, LossyList[str]] = field( - default_factory=TopKDict + default_factory=TopKDict, ) runtime_limit_reached: bool = False deletion_limit_reached: bool = False @@ -166,7 +169,7 @@ def delete_entity(self, urn: str) -> None: entity_urn = Urn.from_string(urn) if self.dry_run: logger.info( - f"Dry run is on otherwise it would have deleted {urn} with hard deletion" + f"Dry run is on otherwise it would have deleted {urn} with hard deletion", ) return if self._deletion_limit_reached() or self._times_up(): @@ -224,7 +227,7 @@ def _process_futures(self, futures: Dict[Future, str]) -> Dict[Future, str]: ): if self.config.delay: logger.debug( - f"Sleeping for {self.config.delay} seconds before further processing batch" + f"Sleeping for {self.config.delay} seconds before further processing batch", ) time.sleep(self.config.delay) return futures @@ -263,16 +266,17 @@ def _get_soft_deleted(self, graphql_query: str, entity_type: str) -> Iterable[st "field": "removed", "values": ["true"], "condition": "EQUAL", - } - ] - } + }, + ], + }, ], - } + }, }, ) except Exception as e: self.report.failure( - f"While trying to get {entity_type} with {scroll_id}", exc=e + f"While trying to get {entity_type} with {scroll_id}", + exc=e, ) break scroll_across_entities = result.get("scrollAcrossEntities") @@ -292,7 +296,7 @@ def _get_soft_deleted(self, graphql_query: str, entity_type: str) -> Iterable[st if entity_type not in self.report.num_entities_found: self.report.num_entities_found[entity_type] = 0 self.report.num_entities_found[entity_type] += scroll_across_entities.get( - "count" + "count", ) for query in search_results: yield query["entity"]["urn"] diff --git a/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py b/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py index 5196c8ec5b998b..0a1e7e0ca249c6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py @@ -41,7 +41,9 @@ class HMACKey(ConfigModel): class GCSSourceConfig( - StatefulIngestionConfigBase, DatasetSourceConfigMixin, PathSpecsConfigMixin + StatefulIngestionConfigBase, + DatasetSourceConfigMixin, + PathSpecsConfigMixin, ): credential: HMACKey = Field( description="Google cloud storage [HMAC keys](https://cloud.google.com/storage/docs/authentication/hmackeys)", @@ -61,7 +63,9 @@ class GCSSourceConfig( @validator("path_specs", always=True) def check_path_specs_and_infer_platform( - cls, path_specs: List[PathSpec], values: Dict + cls, + path_specs: List[PathSpec], + values: Dict, ) -> List[PathSpec]: if len(path_specs) == 0: raise ValueError("path_specs must not be empty") @@ -128,7 +132,7 @@ def create_equivalent_s3_path_specs(self): table_name=path_spec.table_name, enable_compression=path_spec.enable_compression, sample_files=path_spec.sample_files, - ) + ), ) return s3_path_specs @@ -142,7 +146,7 @@ def s3_source_overrides(self, source: S3Source) -> S3Source: source.is_s3_platform = lambda: True # type: ignore source.create_s3_path = lambda bucket_name, key: unquote( # type: ignore - f"s3://{bucket_name}/{key}" + f"s3://{bucket_name}/{key}", ) return source @@ -150,7 +154,9 @@ def get_workunit_processors(self) -> 
List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] diff --git a/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py index bde26f97bf271f..91e70033bec5be 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py @@ -137,7 +137,9 @@ def _inject_connection_into_datasource(conn: Connection) -> Iterator[None]: underlying_datasource_init = SqlAlchemyDatasource.__init__ def sqlalchemy_datasource_init( - self: SqlAlchemyDatasource, *args: Any, **kwargs: Any + self: SqlAlchemyDatasource, + *args: Any, + **kwargs: Any, ) -> None: underlying_datasource_init(self, *args, **kwargs, engine=conn) self.drivername = conn.dialect.name @@ -164,46 +166,50 @@ def get_column_unique_count_dh_patch(self: SqlAlchemyDataset, column: str) -> in # We use coalesce here to force SQL Alchemy to see this # as a column expression. sa.func.coalesce( - sa.text(f'APPROXIMATE count(distinct "{column}")') + sa.text(f'APPROXIMATE count(distinct "{column}")'), ), - ] - ).select_from(self._table) + ], + ).select_from(self._table), ) return convert_to_json_serializable(element_values.fetchone()[0]) elif self.engine.dialect.name.lower() == BIGQUERY: element_values = self.engine.execute( sa.select(sa.func.APPROX_COUNT_DISTINCT(sa.column(column))).select_from( - self._table - ) + self._table, + ), ) return convert_to_json_serializable(element_values.fetchone()[0]) elif self.engine.dialect.name.lower() == SNOWFLAKE: element_values = self.engine.execute( sa.select(sa.func.APPROX_COUNT_DISTINCT(sa.column(column))).select_from( - self._table - ) + self._table, + ), ) return convert_to_json_serializable(element_values.fetchone()[0]) return convert_to_json_serializable( self.engine.execute( sa.select([sa.func.count(sa.func.distinct(sa.column(column)))]).select_from( - self._table - ) - ).scalar() + self._table, + ), + ).scalar(), ) def _get_column_quantiles_bigquery_patch( # type:ignore - self, column: str, quantiles: Iterable + self, + column: str, + quantiles: Iterable, ) -> list: quantile_queries = list() for quantile in quantiles: quantile_queries.append( - sa.text(f"approx_quantiles({column}, 100) OFFSET [{round(quantile * 100)}]") + sa.text( + f"approx_quantiles({column}, 100) OFFSET [{round(quantile * 100)}]", + ), ) quantiles_query = sa.select(quantile_queries).select_from( # type:ignore - self._table + self._table, ) try: quantiles_results = self.engine.execute(quantiles_query).fetchone() @@ -216,12 +222,14 @@ def _get_column_quantiles_bigquery_patch( # type:ignore def _get_column_quantiles_awsathena_patch( # type:ignore - self, column: str, quantiles: Iterable + self, + column: str, + quantiles: Iterable, ) -> list: import ast table_name = ".".join( - [f'"{table_part}"' for table_part in str(self._table).split(".")] + [f'"{table_part}"' for table_part in str(self._table).split(".")], ) quantiles_list = list(quantiles) @@ -246,10 +254,10 @@ def _get_column_median_patch(self, column): or self.sql_engine_dialect.name.lower() == GXSqlDialect.TRINO ): table_name = ".".join( - [f'"{table_part}"' for table_part in str(self._table).split(".")] + [f'"{table_part}"' for table_part in str(self._table).split(".")], ) element_values = self.engine.execute( - f"SELECT approx_percentile({column}, 0.5) FROM 
{table_name}" + f"SELECT approx_percentile({column}, 0.5) FROM {table_name}", ) return convert_to_json_serializable(element_values.fetchone()[0]) else: @@ -331,7 +339,9 @@ def _run_with_query_combiner( ) -> Callable[Concatenate["_SingleDatasetProfiler", P], None]: @functools.wraps(method) def inner( - self: "_SingleDatasetProfiler", *args: P.args, **kwargs: P.kwargs + self: "_SingleDatasetProfiler", + *args: P.args, + **kwargs: P.kwargs, ) -> None: return self.query_combiner.run(lambda: method(self, *args, **kwargs)) @@ -382,7 +392,7 @@ def _get_columns_to_profile(self) -> List[str]: self.column_types[col] = str(col_dict["type"]) # We expect the allow/deny patterns to specify '.' if not self.config._allow_deny_patterns.allowed( - f"{self.dataset_name}.{col}" + f"{self.dataset_name}.{col}", ): ignored_columns_by_pattern.append(col) # We try to ignore nested columns as well @@ -395,11 +405,11 @@ def _get_columns_to_profile(self) -> List[str]: if ignored_columns_by_pattern: self.report.report_dropped( - f"The profile of columns by pattern {self.dataset_name}({', '.join(sorted(ignored_columns_by_pattern))})" + f"The profile of columns by pattern {self.dataset_name}({', '.join(sorted(ignored_columns_by_pattern))})", ) if ignored_columns_by_type: self.report.report_dropped( - f"The profile of columns by type {self.dataset_name}({', '.join(sorted(ignored_columns_by_type))})" + f"The profile of columns by type {self.dataset_name}({', '.join(sorted(ignored_columns_by_type))})", ) if self.config.max_number_of_fields_to_profile is not None: @@ -412,7 +422,7 @@ def _get_columns_to_profile(self) -> List[str]: ] if self.config.report_dropped_profiles: self.report.report_dropped( - f"The max_number_of_fields_to_profile={self.config.max_number_of_fields_to_profile} reached. Profile of columns {self.dataset_name}({', '.join(sorted(columns_being_dropped))})" + f"The max_number_of_fields_to_profile={self.config.max_number_of_fields_to_profile} reached. Profile of columns {self.dataset_name}({', '.join(sorted(columns_being_dropped))})", ) return columns_to_profile @@ -433,17 +443,19 @@ def _should_ignore_column(self, sqlalchemy_type: sa.types.TypeEngine) -> bool: @_run_with_query_combiner def _get_column_type(self, column_spec: _SingleColumnSpec, column: str) -> None: column_spec.type_ = BasicDatasetProfilerBase._get_column_type( - self.dataset, column + self.dataset, + column, ) if column_spec.type_ == ProfilerDataType.UNKNOWN: try: datahub_field_type = resolve_sql_type( - self.column_types[column], self.dataset.engine.dialect.name.lower() + self.column_types[column], + self.dataset.engine.dialect.name.lower(), ) except Exception as e: logger.debug( - f"Error resolving sql type {self.column_types[column]}: {e}" + f"Error resolving sql type {self.column_types[column]}: {e}", ) datahub_field_type = None if datahub_field_type is None: @@ -453,14 +465,16 @@ def _get_column_type(self, column_spec: _SingleColumnSpec, column: str) -> None: @_run_with_query_combiner def _get_column_cardinality( - self, column_spec: _SingleColumnSpec, column: str + self, + column_spec: _SingleColumnSpec, + column: str, ) -> None: try: nonnull_count = self.dataset.get_column_nonnull_count(column) column_spec.nonnull_count = nonnull_count except Exception as e: logger.debug( - f"Caught exception while attempting to get column cardinality for column {column}. {e}" + f"Caught exception while attempting to get column cardinality for column {column}. 
{e}", ) self.report.report_warning( @@ -479,7 +493,7 @@ def _get_column_cardinality( pct_unique = float(unique_count) / nonnull_count except Exception: logger.exception( - f"Failed to get unique count for column {self.dataset_name}.{column}" + f"Failed to get unique count for column {self.dataset_name}.{column}", ) column_spec.unique_count = unique_count @@ -494,30 +508,30 @@ def _get_dataset_rows(self, dataset_profile: DatasetProfileClass) -> None: schema_name = self.dataset_name.split(".")[1] table_name = self.dataset_name.split(".")[2] logger.debug( - f"Getting estimated rowcounts for table:{self.dataset_name}, schema:{schema_name}, table:{table_name}" + f"Getting estimated rowcounts for table:{self.dataset_name}, schema:{schema_name}, table:{table_name}", ) get_estimate_script = sa.text( - f"SELECT c.reltuples AS estimate FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = '{table_name}' AND n.nspname = '{schema_name}'" + f"SELECT c.reltuples AS estimate FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = '{table_name}' AND n.nspname = '{schema_name}'", ) elif dialect_name == MYSQL: schema_name = self.dataset_name.split(".")[0] table_name = self.dataset_name.split(".")[1] logger.debug( - f"Getting estimated rowcounts for table:{self.dataset_name}, schema:{schema_name}, table:{table_name}" + f"Getting estimated rowcounts for table:{self.dataset_name}, schema:{schema_name}, table:{table_name}", ) get_estimate_script = sa.text( - f"SELECT table_rows AS estimate FROM information_schema.tables WHERE table_schema = '{schema_name}' AND table_name = '{table_name}'" + f"SELECT table_rows AS estimate FROM information_schema.tables WHERE table_schema = '{schema_name}' AND table_name = '{table_name}'", ) else: logger.debug( f"Dialect {dialect_name} not supported for feature " - f"profile_table_row_count_estimate_only. Proceeding with full row count." + f"profile_table_row_count_estimate_only. Proceeding with full row count.", ) dataset_profile.rowCount = self.dataset.get_row_count() return dataset_profile.rowCount = int( - self.dataset.engine.execute(get_estimate_script).scalar() + self.dataset.engine.execute(get_estimate_script).scalar(), ) else: # If the configuration is not set to 'estimate only' mode, we directly obtain the row count from the @@ -533,14 +547,16 @@ def _get_dataset_rows(self, dataset_profile: DatasetProfileClass) -> None: # We don't want limit and offset to get applied to the row count # This is kinda hacky way to do it, but every other way would require major refactoring dataset_profile.rowCount = self.dataset.get_row_count( - self.dataset_name.split(".")[-1] + self.dataset_name.split(".")[-1], ) else: dataset_profile.rowCount = self.dataset.get_row_count() @_run_with_query_combiner def _get_dataset_column_min( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_min_value: return @@ -548,7 +564,7 @@ def _get_dataset_column_min( column_profile.min = str(self.dataset.get_column_min(column)) except Exception as e: logger.debug( - f"Caught exception while attempting to get column min for column {column}. {e}" + f"Caught exception while attempting to get column min for column {column}. 
{e}", ) self.report.report_warning( @@ -560,7 +576,9 @@ def _get_dataset_column_min( @_run_with_query_combiner def _get_dataset_column_max( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_max_value: return @@ -568,7 +586,7 @@ def _get_dataset_column_max( column_profile.max = str(self.dataset.get_column_max(column)) except Exception as e: logger.debug( - f"Caught exception while attempting to get column max for column {column}. {e}" + f"Caught exception while attempting to get column max for column {column}. {e}", ) self.report.report_warning( @@ -580,7 +598,9 @@ def _get_dataset_column_max( @_run_with_query_combiner def _get_dataset_column_mean( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_mean_value: return @@ -588,7 +608,7 @@ def _get_dataset_column_mean( column_profile.mean = str(self.dataset.get_column_mean(column)) except Exception as e: logger.debug( - f"Caught exception while attempting to get column mean for column {column}. {e}" + f"Caught exception while attempting to get column mean for column {column}. {e}", ) self.report.report_warning( @@ -600,7 +620,9 @@ def _get_dataset_column_mean( @_run_with_query_combiner def _get_dataset_column_median( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_median_value: return @@ -609,23 +631,23 @@ def _get_dataset_column_median( column_profile.median = str( self.dataset.engine.execute( sa.select([sa.func.median(sa.column(column))]).select_from( - self.dataset._table - ) - ).scalar() + self.dataset._table, + ), + ).scalar(), ) elif self.dataset.engine.dialect.name.lower() == BIGQUERY: column_profile.median = str( self.dataset.engine.execute( sa.select( - sa.text(f"approx_quantiles(`{column}`, 2) [OFFSET (1)]") - ).select_from(self.dataset._table) - ).scalar() + sa.text(f"approx_quantiles(`{column}`, 2) [OFFSET (1)]"), + ).select_from(self.dataset._table), + ).scalar(), ) else: column_profile.median = str(self.dataset.get_column_median(column)) except Exception as e: logger.debug( - f"Caught exception while attempting to get column median for column {column}. {e}" + f"Caught exception while attempting to get column median for column {column}. {e}", ) self.report.report_warning( @@ -637,7 +659,9 @@ def _get_dataset_column_median( @_run_with_query_combiner def _get_dataset_column_stdev( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_stddev_value: return @@ -645,7 +669,7 @@ def _get_dataset_column_stdev( column_profile.stdev = str(self.dataset.get_column_stdev(column)) except Exception as e: logger.debug( - f"Caught exception while attempting to get column stddev for column {column}. {e}" + f"Caught exception while attempting to get column stddev for column {column}. 
{e}", ) self.report.report_warning( title="Profiling: Unable to Calculate Standard Deviation", @@ -656,7 +680,9 @@ def _get_dataset_column_stdev( @_run_with_query_combiner def _get_dataset_column_quantiles( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_quantiles: return @@ -687,7 +713,7 @@ def _get_dataset_column_quantiles( ] except Exception as e: logger.debug( - f"Caught exception while attempting to get column quantiles for column {column}. {e}" + f"Caught exception while attempting to get column quantiles for column {column}. {e}", ) self.report.report_warning( @@ -699,7 +725,9 @@ def _get_dataset_column_quantiles( @_run_with_query_combiner def _get_dataset_column_distinct_value_frequencies( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if self.config.include_field_distinct_value_frequencies: column_profile.distinctValueFrequencies = [ @@ -709,7 +737,9 @@ def _get_dataset_column_distinct_value_frequencies( @_run_with_query_combiner def _get_dataset_column_histogram( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_histogram: return @@ -734,7 +764,7 @@ def _get_dataset_column_histogram( ) except Exception as e: logger.debug( - f"Caught exception while attempting to get column histogram for column {column}. {e}" + f"Caught exception while attempting to get column histogram for column {column}. {e}", ) self.report.report_warning( @@ -746,7 +776,9 @@ def _get_dataset_column_histogram( @_run_with_query_combiner def _get_dataset_column_sample_values( - self, column_profile: DatasetFieldProfileClass, column: str + self, + column_profile: DatasetFieldProfileClass, + column: str, ) -> None: if not self.config.include_field_sample_values: return @@ -769,7 +801,7 @@ def _get_dataset_column_sample_values( ] except Exception as e: logger.debug( - f"Caught exception while attempting to get sample values for column {column}. {e}" + f"Caught exception while attempting to get sample values for column {column}. 
{e}", ) self.report.report_warning( @@ -783,7 +815,8 @@ def generate_dataset_profile( # noqa: C901 (complexity) self, ) -> DatasetProfileClass: self.dataset.set_default_expectation_argument( - "catch_exceptions", self.config.catch_exceptions + "catch_exceptions", + self.config.catch_exceptions, ) profile = self.init_profile() @@ -870,7 +903,8 @@ def generate_dataset_profile( # noqa: C901 (complexity) if non_null_count is not None and non_null_count > 0: # Sometimes this value is bigger than 1 because of the approx queries column_profile.uniqueProportion = min( - 1, unique_count / non_null_count + 1, + unique_count / non_null_count, ) if not profile.rowCount: @@ -967,12 +1001,13 @@ def init_profile(self): profile.partitionSpec = PartitionSpecClass( type=PartitionTypeClass.QUERY, partition=json.dumps( - dict(limit=self.config.limit, offset=self.config.offset) + dict(limit=self.config.limit, offset=self.config.offset), ), ) elif self.custom_sql: profile.partitionSpec = PartitionSpecClass( - type=PartitionTypeClass.QUERY, partition="SAMPLE" + type=PartitionTypeClass.QUERY, + partition="SAMPLE", ) return profile @@ -1024,7 +1059,8 @@ def update_dataset_batch_use_sampling(self, profile: DatasetProfileClass) -> Non and profile.partitionSpec.type == PartitionTypeClass.FULL_TABLE ): profile.partitionSpec = PartitionSpecClass( - type=PartitionTypeClass.QUERY, partition="SAMPLE" + type=PartitionTypeClass.QUERY, + partition="SAMPLE", ) elif ( profile.partitionSpec @@ -1103,7 +1139,7 @@ def _ge_context(self) -> Iterator[GEContext]: "enabled": False, # "data_context_id": , }, - ) + ), ) datasource_name = f"{self._datasource_name_base}-{uuid.uuid4()}" @@ -1137,7 +1173,7 @@ def generate_profiles( ) -> Iterable[Tuple[GEProfilerRequest, Optional[DatasetProfileClass]]]: max_workers = min(max_workers, len(requests)) logger.info( - f"Will profile {len(requests)} table(s) with {max_workers} worker(s) - this may take a while" + f"Will profile {len(requests)} table(s) with {max_workers} worker(s) - this may take a while", ) with PerfTimer() as timer, unittest.mock.patch( @@ -1153,7 +1189,7 @@ def generate_profiles( "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median", _get_column_median_patch, ), concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers + max_workers=max_workers, ) as async_executor, SQLAlchemyQueryCombiner( enabled=self.config.query_combiner_enabled, catch_exceptions=self.config.catch_exceptions, @@ -1181,7 +1217,7 @@ def generate_profiles( total_time_taken = timer.elapsed_seconds() logger.info( - f"Profiling {len(requests)} table(s) finished in {total_time_taken:.3f} seconds" + f"Profiling {len(requests)} table(s) finished in {total_time_taken:.3f} seconds", ) time_percentiles: Dict[str, float] = {} @@ -1189,12 +1225,13 @@ def generate_profiles( if len(self.times_taken) > 0: percentiles = [50, 75, 95, 99] percentile_values = stats.calculate_percentiles( - self.times_taken, percentiles + self.times_taken, + percentiles, ) time_percentiles = { f"table_time_taken_p{percentile}": stats.discretize( - percentile_values[percentile] + percentile_values[percentile], ) for percentile in percentiles } @@ -1252,7 +1289,7 @@ def _generate_single_profile( **kwargs: Any, ) -> Optional[DatasetProfileClass]: logger.debug( - f"Received single profile request for {pretty_name} for {schema}, {table}, {custom_sql}" + f"Received single profile request for {pretty_name} for {schema}, {table}, {custom_sql}", ) ge_config = { @@ -1269,7 +1306,10 @@ def _generate_single_profile( if 
custom_sql is not None: # Note that limit and offset are not supported for custom SQL. temp_view = create_athena_temp_table( - self, custom_sql, pretty_name, self.base_engine.raw_connection() + self, + custom_sql, + pretty_name, + self.base_engine.raw_connection(), ) ge_config["table"] = temp_view ge_config["schema"] = None @@ -1290,7 +1330,10 @@ def _generate_single_profile( if self.config.offset: bq_sql += f" OFFSET {self.config.offset}" bigquery_temp_table = create_bigquery_temp_table( - self, bq_sql, pretty_name, self.base_engine.raw_connection() + self, + bq_sql, + pretty_name, + self.base_engine.raw_connection(), ) if platform == BIGQUERY: @@ -1334,7 +1377,7 @@ def _generate_single_profile( time_taken = timer.elapsed_seconds() logger.info( - f"Finished profiling {pretty_name}; took {time_taken:.3f} seconds" + f"Finished profiling {pretty_name}; took {time_taken:.3f} seconds", ) self.times_taken.append(time_taken) if profile.rowCount is not None: @@ -1414,7 +1457,7 @@ def _get_ge_dataset( name_parts = pretty_name.split(".") if len(name_parts) != 3: logger.error( - f"Unexpected {pretty_name} while profiling. Should have 3 parts but has {len(name_parts)} parts." + f"Unexpected {pretty_name} while profiling. Should have 3 parts but has {len(name_parts)} parts.", ) # If we only have two parts that means the project_id is missing from the table name and we add it # Temp tables has 3 parts while normal tables only has 2 parts @@ -1450,7 +1493,7 @@ def create_athena_temp_table( if "." in table_pretty_name: schema_part = table_pretty_name.split(".")[-1] schema_part_quoted = ".".join( - [f'"{part}"' for part in str(schema_part).split(".")] + [f'"{part}"' for part in str(schema_part).split(".")], ) temp_view = f"{schema_part_quoted}_{temp_view}" @@ -1491,7 +1534,7 @@ def create_bigquery_temp_table( if not instance.config.catch_exceptions: raise e logger.exception( - f"Encountered exception while profiling {table_pretty_name}" + f"Encountered exception while profiling {table_pretty_name}", ) instance.report.report_warning( table_pretty_name, @@ -1548,7 +1591,10 @@ def create_bigquery_temp_table( def _get_columns_to_ignore_sampling( - dataset_name: str, tags_to_ignore: Optional[List[str]], platform: str, env: str + dataset_name: str, + tags_to_ignore: Optional[List[str]], + platform: str, + env: str, ) -> Tuple[bool, List[str]]: logger.debug("Collecting columns to ignore for sampling") @@ -1559,7 +1605,9 @@ def _get_columns_to_ignore_sampling( return ignore_table, columns_to_ignore dataset_urn = mce_builder.make_dataset_urn( - name=dataset_name, platform=platform, env=env + name=dataset_name, + platform=platform, + env=env, ) datahub_graph = get_default_graph() @@ -1573,7 +1621,8 @@ def _get_columns_to_ignore_sampling( if not ignore_table: metadata = datahub_graph.get_aspect( - entity_urn=dataset_urn, aspect_type=EditableSchemaMetadata + entity_urn=dataset_urn, + aspect_type=EditableSchemaMetadata, ) if metadata: diff --git a/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py b/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py index 93142a347ca0e6..d46edd76c7af0e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py @@ -21,7 +21,8 @@ class GEProfilingBaseConfig(ConfigModel): enabled: bool = Field( - default=False, description="Whether profiling should be done." 
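# Editor's note (illustrative, not part of the patch): _generate_single_profile() above wraps a
# user-supplied SQL statement and appends LIMIT/OFFSET before materializing a temp table for
# sampling. A minimal sketch of that wrapping (names are hypothetical; quoting is the caller's job):
from typing import Optional

def build_sampled_sql(custom_sql: str, limit: Optional[int], offset: Optional[int]) -> str:
    """Wrap a query so only a slice of its result set is profiled."""
    sql = f"SELECT * FROM ({custom_sql})"
    if limit:
        sql += f" LIMIT {limit}"
    if offset:
        sql += f" OFFSET {offset}"
    return sql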
+ default=False, + description="Whether profiling should be done.", ) operation_config: OperationConfig = Field( default_factory=OperationConfig, @@ -73,7 +74,8 @@ class GEProfilingBaseConfig(ConfigModel): description="Whether to profile for the quantiles of numeric columns.", ) include_field_distinct_value_frequencies: bool = Field( - default=False, description="Whether to profile for distinct value frequencies." + default=False, + description="Whether to profile for distinct value frequencies.", ) include_field_histogram: bool = Field( default=False, @@ -205,7 +207,8 @@ def deprecate_bigquery_temp_table_schema(cls, values): @pydantic.root_validator(pre=True) def ensure_field_level_settings_are_normalized( - cls: "GEProfilingConfig", values: Dict[str, Any] + cls: "GEProfilingConfig", + values: Dict[str, Any], ) -> Dict[str, Any]: max_num_fields_to_profile_key = "max_number_of_fields_to_profile" max_num_fields_to_profile = values.get(max_num_fields_to_profile_key) @@ -216,7 +219,7 @@ def ensure_field_level_settings_are_normalized( if field_level_metric.startswith("include_field_"): if values.get(field_level_metric): raise ValueError( - "Cannot enable field-level metrics if profile_table_level_only is set" + "Cannot enable field-level metrics if profile_table_level_only is set", ) values[field_level_metric] = False diff --git a/metadata-ingestion/src/datahub/ingestion/source/git/git_import.py b/metadata-ingestion/src/datahub/ingestion/source/git/git_import.py index 6440ec8e5d7877..9167f029ef43a3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/git/git_import.py +++ b/metadata-ingestion/src/datahub/ingestion/source/git/git_import.py @@ -19,7 +19,10 @@ def __init__(self, tmp_dir: str, skip_known_host_verification: bool = True): self.last_repo_cloned: Optional[git.Repo] = None def clone( - self, ssh_key: Optional[SecretStr], repo_url: str, branch: Optional[str] = None + self, + ssh_key: Optional[SecretStr], + repo_url: str, + branch: Optional[str] = None, ) -> Path: # Note: this does a shallow clone. @@ -59,7 +62,7 @@ def clone( if branch is None: logger.info( - f"⏳ Cloning repo '{self.sanitize_repo_url(repo_url)}' (default branch), this can take some time..." + f"⏳ Cloning repo '{self.sanitize_repo_url(repo_url)}' (default branch), this can take some time...", ) self.last_repo_cloned = git.Repo.clone_from( repo_url, @@ -72,7 +75,7 @@ def clone( # we can't just use the --branch flag of Git clone. Doing a blobless clone allows # us to quickly checkout the right commit. logger.info( - f"⏳ Cloning repo '{self.sanitize_repo_url(repo_url)}' (branch: {branch}), this can take some time..." 
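# Editor's note (illustrative, not part of the patch): the root validator above normalizes
# field-level settings when table-level-only profiling is requested. A stripped-down pydantic v1
# sketch of the same idea (the model and its fields are simplified assumptions):
import pydantic

class ProfilingConfigSketch(pydantic.BaseModel):
    profile_table_level_only: bool = False
    include_field_min_value: bool = False
    include_field_max_value: bool = False

    @pydantic.root_validator(pre=True)
    def _normalize_field_level_settings(cls, values):
        if values.get("profile_table_level_only"):
            for key in list(values):
                if key.startswith("include_field_"):
                    if values[key]:
                        raise ValueError(
                            "Cannot enable field-level metrics if profile_table_level_only is set",
                        )
                    values[key] = False
        return values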
+ f"⏳ Cloning repo '{self.sanitize_repo_url(repo_url)}' (branch: {branch}), this can take some time...", ) self.last_repo_cloned = git.Repo.clone_from( repo_url, diff --git a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py index 53f71046c25c0d..91a4709c60e779 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py @@ -35,7 +35,7 @@ class GrafanaSourceConfig(StatefulIngestionConfigBase, PlatformInstanceConfigMix description="Grafana URL in the format http://your-grafana-instance with no trailing slash", ) service_account_token: SecretStr = Field( - description="Service account token for Grafana" + description="Service account token for Grafana", ) @@ -67,7 +67,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] @@ -81,7 +83,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: } try: response = requests.get( - f"{self.source_config.url}/api/search", headers=headers + f"{self.source_config.url}/api/search", + headers=headers, ) response.raise_for_status() except requests.exceptions.RequestException as e: @@ -127,5 +130,5 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ), StatusClass(removed=False), ], - ) + ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py index 9a62ee2dab52f4..e438fe16db7412 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py @@ -87,7 +87,7 @@ LOGGER = logging.getLogger(__name__) logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel( - logging.WARNING + logging.WARNING, ) @@ -101,7 +101,9 @@ @capability(SourceCapability.DOMAINS, "Currently not supported.", supported=False) @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration.") @capability( - SourceCapability.PARTITION_SUPPORT, "Currently not supported.", supported=False + SourceCapability.PARTITION_SUPPORT, + "Currently not supported.", + supported=False, ) @capability(SourceCapability.DESCRIPTIONS, "Enabled by default.") @capability( @@ -134,14 +136,16 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] def _get_datasets(self, catalog: Catalog) -> Iterable[Identifier]: namespaces = catalog.list_namespaces() LOGGER.debug( - f"Retrieved {len(namespaces)} namespaces, first 10: {namespaces[:10]}" + f"Retrieved {len(namespaces)} namespaces, first 10: {namespaces[:10]}", ) self.report.report_no_listed_namespaces(len(namespaces)) tables_count = 0 @@ -149,7 +153,7 @@ def _get_datasets(self, catalog: Catalog) -> Iterable[Identifier]: namespace_repr = ".".join(namespace) if not self.config.namespace_pattern.allowed(namespace_repr): LOGGER.info( - f"Namespace {namespace_repr} is not allowed by config pattern, skipping" + f"Namespace {namespace_repr} is not allowed by config pattern, skipping", ) 
self.report.report_dropped(f"{namespace_repr}.*") continue @@ -157,10 +161,11 @@ def _get_datasets(self, catalog: Catalog) -> Iterable[Identifier]: tables = catalog.list_tables(namespace) tables_count += len(tables) LOGGER.debug( - f"Retrieved {len(tables)} tables for namespace: {namespace}, in total retrieved {tables_count}, first 10: {tables[:10]}" + f"Retrieved {len(tables)} tables for namespace: {namespace}, in total retrieved {tables_count}, first 10: {tables[:10]}", ) self.report.report_listed_tables_for_namespace( - ".".join(namespace), len(tables) + ".".join(namespace), + len(tables), ) yield from tables except NoSuchNamespaceError: @@ -177,7 +182,7 @@ def _get_datasets(self, catalog: Catalog) -> Iterable[Identifier]: f"Couldn't list tables for namespace {namespace} due to {e}", ) LOGGER.exception( - f"Unexpected exception while trying to get list of tables for namespace {namespace}, skipping it" + f"Unexpected exception while trying to get list of tables for namespace {namespace}, skipping it", ) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: @@ -190,13 +195,13 @@ def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]: # Dataset name is rejected by pattern, report as dropped. self.report.report_dropped(dataset_name) LOGGER.debug( - f"Skipping table {dataset_name} due to not being allowed by the config pattern" + f"Skipping table {dataset_name} due to not being allowed by the config pattern", ) return try: if not hasattr(thread_local, "local_catalog"): LOGGER.debug( - f"Didn't find local_catalog in thread_local ({thread_local}), initializing new catalog" + f"Didn't find local_catalog in thread_local ({thread_local}), initializing new catalog", ) thread_local.local_catalog = self.config.get_catalog() @@ -204,7 +209,9 @@ def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]: table = thread_local.local_catalog.load_table(dataset_path) time_taken = timer.elapsed_seconds() self.report.report_table_load_time( - time_taken, dataset_name, table.metadata_location + time_taken, + dataset_name, + table.metadata_location, ) LOGGER.debug(f"Loaded table: {table.name()}, time taken: {time_taken}") yield from self._create_iceberg_workunit(dataset_name, table) @@ -238,7 +245,7 @@ def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]: f"Encountered FileNotFoundError when trying to read manifest file for {dataset_name}. {e}", ) LOGGER.warning( - f"FileNotFoundError while processing table {dataset_path}, skipping it." + f"FileNotFoundError while processing table {dataset_path}, skipping it.", ) except ServerError as e: self.report.report_warning( @@ -246,7 +253,7 @@ def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]: f"Iceberg Rest Catalog returned 500 status due to an unhandled exception for {dataset_name}. Exception: {e}", ) LOGGER.warning( - f"Iceberg Rest Catalog server error (500 status) encountered when processing table {dataset_path}, skipping it." 
+ f"Iceberg Rest Catalog server error (500 status) encountered when processing table {dataset_path}, skipping it.", ) except Exception as e: self.report.report_failure( @@ -271,7 +278,9 @@ def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]: yield wu def _create_iceberg_workunit( - self, dataset_name: str, table: Table + self, + dataset_name: str, + table: Table, ) -> Iterable[MetadataWorkUnit]: with PerfTimer() as timer: self.report.report_table_scanned(dataset_name) @@ -294,7 +303,7 @@ def _create_iceberg_workunit( custom_properties["partition-spec"] = str(self._get_partition_aspect(table)) if table.current_snapshot(): custom_properties["snapshot-id"] = str( - table.current_snapshot().snapshot_id + table.current_snapshot().snapshot_id, ) custom_properties["manifest-list"] = ( table.current_snapshot().manifest_list @@ -309,7 +318,7 @@ def _create_iceberg_workunit( dataset_ownership = self._get_ownership_aspect(table) if dataset_ownership: LOGGER.debug( - f"Adding ownership: {dataset_ownership} to the dataset {dataset_name}" + f"Adding ownership: {dataset_ownership} to the dataset {dataset_name}", ) dataset_snapshot.aspects.append(dataset_ownership) @@ -318,7 +327,9 @@ def _create_iceberg_workunit( mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) self.report.report_table_processing_time( - timer.elapsed_seconds(), dataset_name, table.metadata_location + timer.elapsed_seconds(), + dataset_name, + table.metadata_location, ) yield MetadataWorkUnit(id=dataset_name, mce=mce) @@ -355,16 +366,16 @@ def _get_partition_aspect(self, table: Table) -> Optional[str]: "name": partition.name, "transform": str(partition.transform), "source": str( - table.schema().find_column_name(partition.source_id) + table.schema().find_column_name(partition.source_id), ), "source-id": partition.source_id, "source-type": str( - table.schema().find_type(partition.source_id) + table.schema().find_type(partition.source_id), ), "field-id": partition.field_id, } for partition in table.spec().fields - ] + ], ) except Exception as e: self.report.report_warning( @@ -385,7 +396,7 @@ def _get_ownership_aspect(self, table: Table) -> Optional[OwnershipClass]: owner=make_user_urn(user_owner), type=OwnershipTypeClass.TECHNICAL_OWNER, source=None, - ) + ), ) if self.config.group_ownership_property: if self.config.group_ownership_property in table.metadata.properties: @@ -397,12 +408,13 @@ def _get_ownership_aspect(self, table: Table) -> Optional[OwnershipClass]: owner=make_group_urn(group_owner), type=OwnershipTypeClass.TECHNICAL_OWNER, source=None, - ) + ), ) return OwnershipClass(owners=owners) if owners else None def _get_dataplatform_instance_aspect( - self, dataset_urn: str + self, + dataset_urn: str, ) -> Optional[MetadataWorkUnit]: # If we are a platform instance based source, emit the instance aspect if self.config.platform_instance: @@ -411,7 +423,8 @@ def _get_dataplatform_instance_aspect( aspect=DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ), ).as_workunit() @@ -419,7 +432,9 @@ def _get_dataplatform_instance_aspect( return None def _create_schema_metadata( - self, dataset_name: str, table: Table + self, + dataset_name: str, + table: Table, ) -> SchemaMetadata: schema_fields = self._get_schema_fields_for_schema(table.schema()) schema_metadata = SchemaMetadata( @@ -438,7 +453,8 @@ def _get_schema_fields_for_schema( ) -> 
List[SchemaField]: avro_schema = visit(schema, ToAvroSchemaIcebergVisitor()) schema_fields = schema_util.avro_schema_to_mce_fields( - json.dumps(avro_schema), default_nullable=False + json.dumps(avro_schema), + default_nullable=False, ) return schema_fields @@ -457,7 +473,9 @@ def schema(self, schema: Schema, struct_result: Dict[str, Any]) -> Dict[str, Any return struct_result def struct( - self, struct: StructType, field_results: List[Dict[str, Any]] + self, + struct: StructType, + field_results: List[Dict[str, Any]], ) -> Dict[str, Any]: nullable = True return { @@ -477,7 +495,9 @@ def field(self, field: NestedField, field_result: Dict[str, Any]) -> Dict[str, A } def list( - self, list_type: ListType, element_result: Dict[str, Any] + self, + list_type: ListType, + element_result: Dict[str, Any], ) -> Dict[str, Any]: return { "type": "array", @@ -532,7 +552,7 @@ def visit_decimal(self, decimal_type: DecimalType) -> Dict[str, Any]: # "type": "bytes", # when using bytes, avro drops _nullable attribute and others. See unit test. "type": "fixed", # to fix avro bug ^ resolved by using a fixed type "name": self._gen_name( - "__fixed_" + "__fixed_", ), # to fix avro bug ^ resolved by using a fixed type "size": 1, # to fix avro bug ^ resolved by using a fixed type "logicalType": "decimal", diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py index 83fe3d1c079f17..8b10d573e68127 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py @@ -59,7 +59,8 @@ class IcebergProfilingConfig(ConfigModel): class IcebergSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): # Override the stateful_ingestion config param with the Iceberg custom stateful ingestion config in the IcebergSourceConfig stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="Iceberg Stateful Ingestion Config." + default=None, + description="Iceberg Stateful Ingestion Config.", ) # The catalog configuration is using a dictionary to be open and flexible. All the keys and values are handled by pyiceberg. This will future-proof any configuration change done by pyiceberg. catalog: Dict[str, Dict[str, Any]] = Field( @@ -83,7 +84,8 @@ class IcebergSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin) ) profiling: IcebergProfilingConfig = IcebergProfilingConfig() processing_threads: int = Field( - default=1, description="How many threads will be processing tables" + default=1, + description="How many threads will be processing tables", ) @validator("catalog", pre=True, always=True) @@ -103,7 +105,7 @@ def handle_deprecated_catalog_format(cls, value): catalog_type = value["type"] catalog_config = value["config"] new_catalog_config = { - catalog_name: {"type": catalog_type, **catalog_config} + catalog_name: {"type": catalog_type, **catalog_config}, } return new_catalog_config # In case the input is already the new format or is invalid @@ -120,14 +122,14 @@ def validate_catalog_size(cls, value): # Check if that dict is not empty if not catalog_config or not isinstance(catalog_config, dict): raise ValueError( - f"The catalog configuration for '{catalog_name}' must not be empty and should be a dictionary with at least one key-value pair." 
+ f"The catalog configuration for '{catalog_name}' must not be empty and should be a dictionary with at least one key-value pair.", ) return value def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) def get_catalog(self) -> Catalog: @@ -142,7 +144,9 @@ def get_catalog(self) -> Catalog: # Retrieve the dict associated with the one catalog entry catalog_name, catalog_config = next(iter(self.catalog.items())) logger.debug( - "Initializing the catalog %s with config: %s", catalog_name, catalog_config + "Initializing the catalog %s with config: %s", + catalog_name, + catalog_config, ) return load_catalog(name=catalog_name, **catalog_config) @@ -190,7 +194,7 @@ def __str__(self) -> str: "max_time": format_timespan(self.times[-1], detailed=True, max_units=3), # total_time does not provide correct information in case we run in more than 1 thread "total_time": format_timespan(total, detailed=True, max_units=3), - } + }, ) @@ -208,11 +212,13 @@ class IcebergSourceReport(StaleEntityRemovalSourceReport): listed_namespaces: int = 0 total_listed_tables: int = 0 tables_listed_per_namespace: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) def report_listed_tables_for_namespace( - self, namespace: str, no_tables: int + self, + namespace: str, + no_tables: int, ) -> None: self.tables_listed_per_namespace[namespace] = no_tables self.total_listed_tables += no_tables @@ -227,25 +233,46 @@ def report_dropped(self, ent_name: str) -> None: self.filtered.append(ent_name) def report_table_load_time( - self, t: float, table_name: str, table_metadata_location: str + self, + t: float, + table_name: str, + table_metadata_location: str, ) -> None: self.load_table_timings.add_timing(t) self.tables_load_timings.add( - {"table": table_name, "timing": t, "metadata_file": table_metadata_location} + { + "table": table_name, + "timing": t, + "metadata_file": table_metadata_location, + }, ) def report_table_processing_time( - self, t: float, table_name: str, table_metadata_location: str + self, + t: float, + table_name: str, + table_metadata_location: str, ) -> None: self.processing_table_timings.add_timing(t) self.tables_process_timings.add( - {"table": table_name, "timing": t, "metadata_file": table_metadata_location} + { + "table": table_name, + "timing": t, + "metadata_file": table_metadata_location, + }, ) def report_table_profiling_time( - self, t: float, table_name: str, table_metadata_location: str + self, + t: float, + table_name: str, + table_metadata_location: str, ) -> None: self.profiling_table_timings.add_timing(t) self.tables_profile_timings.add( - {"table": table_name, "timing": t, "metadata_file": table_metadata_location} + { + "table": table_name, + "timing": t, + "metadata_file": table_metadata_location, + }, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py index 7642cabbd1404c..7cdf10232cb1e6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py @@ -130,7 +130,7 @@ def profile_table( field.field_id for field in table.schema().fields if field.field_type.is_primitive - ] + ], ) dataset_profile = DatasetProfileClass( timestampMillis=get_sys_time(), @@ -149,7 +149,8 @@ def profile_table( data_file = 
manifest_entry.data_file if self.config.include_field_null_count: null_counts = self._aggregate_counts( - null_counts, data_file.null_value_counts + null_counts, + data_file.null_value_counts, ) if self.config.include_field_min_value: self._aggregate_bounds( @@ -180,16 +181,19 @@ def profile_table( column_profile = DatasetFieldProfileClass(fieldPath=field_path) if self.config.include_field_null_count: column_profile.nullCount = cast( - int, null_counts.get(field_id, 0) + int, + null_counts.get(field_id, 0), ) column_profile.nullProportion = float( - column_profile.nullCount / row_count + column_profile.nullCount / row_count, ) if self.config.include_field_min_value: column_profile.min = ( self._render_value( - dataset_name, field.field_type, min_bounds.get(field_id) + dataset_name, + field.field_type, + min_bounds.get(field_id), ) if field_id in min_bounds else None @@ -197,7 +201,9 @@ def profile_table( if self.config.include_field_max_value: column_profile.max = ( self._render_value( - dataset_name, field.field_type, max_bounds.get(field_id) + dataset_name, + field.field_type, + max_bounds.get(field_id), ) if field_id in max_bounds else None @@ -205,10 +211,12 @@ def profile_table( dataset_profile.fieldProfiles.append(column_profile) time_taken = timer.elapsed_seconds() self.report.report_table_profiling_time( - time_taken, dataset_name, table.metadata_location + time_taken, + dataset_name, + table.metadata_location, ) LOGGER.debug( - f"Finished profiling of dataset: {dataset_name} in {time_taken}" + f"Finished profiling of dataset: {dataset_name} in {time_taken}", ) yield MetadataChangeProposalWrapper( @@ -217,7 +225,10 @@ def profile_table( ).as_workunit() def _render_value( - self, dataset_name: str, value_type: IcebergType, value: Any + self, + dataset_name: str, + value_type: IcebergType, + value: Any, ) -> Union[str, None]: try: if isinstance(value_type, TimestampType): diff --git a/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py b/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py index edb9b7b8bd5264..0fd5e795844f8e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py +++ b/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py @@ -60,19 +60,19 @@ class AzureADConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): # Required client_id: str = Field( - description="Application ID. Found in your app registration on Azure AD Portal" + description="Application ID. Found in your app registration on Azure AD Portal", ) tenant_id: str = Field( - description="Directory ID. Found in your app registration on Azure AD Portal" + description="Directory ID. Found in your app registration on Azure AD Portal", ) client_secret: str = Field( - description="Client secret. Found in your app registration on Azure AD Portal" + description="Client secret. Found in your app registration on Azure AD Portal", ) authority: str = Field( - description="The authority (https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-client-application-configuration) is a URL that indicates a directory that MSAL can request tokens from." + description="The authority (https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-client-application-configuration) is a URL that indicates a directory that MSAL can request tokens from.", ) token_url: str = Field( - description="The token URL that acquires a token from Azure AD for authorizing requests. This source will only work with v1.0 endpoint." 
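# Editor's note (illustrative, not part of the patch): profile_table() above folds per-data-file
# statistics (e.g. null_value_counts keyed by field id) into table-level totals. The aggregation
# itself is just a dict merge:
from collections import Counter
from typing import Dict, Iterable

def aggregate_counts(per_file_counts: Iterable[Dict[int, int]]) -> Dict[int, int]:
    """Sum counts per field id across all data files of a table."""
    totals: Counter = Counter()
    for counts in per_file_counts:
        totals.update(counts)
    return dict(totals)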
+ description="The token URL that acquires a token from Azure AD for authorizing requests. This source will only work with v1.0 endpoint.", ) # Optional: URLs for redirect and hitting the Graph API redirect: str = Field( @@ -109,10 +109,12 @@ class AzureADConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): # Optional: to ingest users, groups or both ingest_users: bool = Field( - default=True, description="Whether users should be ingested into DataHub." + default=True, + description="Whether users should be ingested into DataHub.", ) ingest_groups: bool = Field( - default=True, description="Whether groups should be ingested into DataHub." + default=True, + description="Whether groups should be ingested into DataHub.", ) ingest_group_membership: bool = Field( default=True, @@ -150,7 +152,8 @@ class AzureADConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="Azure AD Stateful Ingestion Config." + default=None, + description="Azure AD Stateful Ingestion Config.", ) @@ -173,7 +176,8 @@ def report_filtered(self, name: str) -> None: @config_class(AzureADConfig) @support_status(SupportStatus.CERTIFIED) @capability( - SourceCapability.DELETION_DETECTION, "Optionally enabled via stateful_ingestion" + SourceCapability.DELETION_DETECTION, + "Optionally enabled via stateful_ingestion", ) class AzureADSource(StatefulIngestionSourceBase): """ @@ -267,11 +271,13 @@ def __init__(self, config: AzureADConfig, ctx: PipelineContext): super().__init__(config, ctx) self.config = config self.report = AzureADSourceReport( - filtered_tracking=self.config.filtered_tracking + filtered_tracking=self.config.filtered_tracking, ) session = requests.Session() retries = Retry( - total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504] + total=5, + backoff_factor=1, + status_forcelist=[429, 500, 502, 503, 504], ) adapter = HTTPAdapter(max_retries=retries) session.mount("http://", adapter) @@ -308,7 +314,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -324,13 +332,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for azure_ad_groups in self._get_azure_ad_groups(): logger.info("Processing another groups batch...") datahub_corp_group_snapshots = self._map_azure_ad_groups( - azure_ad_groups + azure_ad_groups, ) for group_count, datahub_corp_group_snapshot in enumerate( - datahub_corp_group_snapshots + datahub_corp_group_snapshots, ): mce = MetadataChangeEvent( - proposedSnapshot=datahub_corp_group_snapshot + proposedSnapshot=datahub_corp_group_snapshot, ) wu_id = ( f"group-{group_count + 1}" @@ -383,10 +391,11 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # 3) the users # getting infos about the users belonging to the found groups datahub_corp_user_snapshots = self._map_azure_ad_users( - self.azure_ad_groups_users + self.azure_ad_groups_users, ) yield from self.ingest_ad_users( - datahub_corp_user_snapshots, datahub_corp_user_urn_to_group_membership + datahub_corp_user_snapshots, + datahub_corp_user_urn_to_group_membership, ) # Create MetadataWorkUnits for CorpUsers @@ -429,7 +438,7 @@ def _add_group_members_to_group_membership( else: # Unless told otherwise, we only care about users and groups. 
Silently skip other object types. logger.warning( - f"Unsupported @odata.type '{odata_type}' found in Azure group member. Skipping...." + f"Unsupported @odata.type '{odata_type}' found in Azure group member. Skipping....", ) def _add_user_to_group_membership( @@ -454,7 +463,7 @@ def ingest_ad_users( datahub_corp_user_urn_to_group_membership: dict, ) -> Generator[MetadataWorkUnit, Any, None]: for user_count, datahub_corp_user_snapshot in enumerate( - datahub_corp_user_snapshots + datahub_corp_user_snapshots, ): # TODO: Refactor common code between this and Okta to a common base class or utils # Add group membership aspect @@ -529,11 +538,16 @@ def _map_identity_to_urn(self, func, id_to_extract, mapping_identifier, id_type) result = func(id_to_extract) except Exception as e: error_str = "Failed to extract DataHub {} from Azure AD {} with name {} due to '{}'".format( - id_type, id_type, id_to_extract.get("displayName"), repr(e) + id_type, + id_type, + id_to_extract.get("displayName"), + repr(e), ) if not result: error_str = "Failed to extract DataHub {} from Azure AD {} with name {} due to unknown reason".format( - id_type, id_type, id_to_extract.get("displayName") + id_type, + id_type, + id_to_extract.get("displayName"), ) if error_str is not None: logger.error(error_str) @@ -603,7 +617,10 @@ def _map_azure_ad_group_to_group_name(self, azure_ad_group): def _map_azure_ad_users(self, azure_ad_users): for user in azure_ad_users: corp_user_urn, error_str = self._map_identity_to_urn( - self._map_azure_ad_user_to_urn, user, "azure_ad_user_mapping", "user" + self._map_azure_ad_user_to_urn, + user, + "azure_ad_user_mapping", + "user", ) if error_str is not None: continue @@ -650,7 +667,10 @@ def _map_azure_ad_user_to_corp_user(self, azure_ad_user): ) def _extract_regex_match_from_dict_value( - self, str_dict: Dict[str, str], key: str, pattern: str + self, + str_dict: Dict[str, str], + key: str, + pattern: str, ) -> str: raw_value = str_dict.get(key) if raw_value is None: @@ -658,6 +678,6 @@ def _extract_regex_match_from_dict_value( match = re.search(pattern, raw_value) if match is None: raise ValueError( - f"Unable to extract a name from {raw_value} with the pattern {pattern}" + f"Unable to extract a name from {raw_value} with the pattern {pattern}", ) return match.group() diff --git a/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py b/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py index 5452fbcd3f053b..8179f509c7ca9e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py +++ b/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py @@ -67,10 +67,12 @@ class OktaConfig(StatefulIngestionConfigBase, ConfigModel): # Optional: Whether to ingest users, groups, or both. ingest_users: bool = Field( - default=True, description="Whether users should be ingested into DataHub." + default=True, + description="Whether users should be ingested into DataHub.", ) ingest_groups: bool = Field( - default=True, description="Whether groups should be ingested into DataHub." + default=True, + description="Whether groups should be ingested into DataHub.", ) ingest_group_membership: bool = Field( default=True, @@ -147,7 +149,8 @@ class OktaConfig(StatefulIngestionConfigBase, ConfigModel): # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="Okta Stateful Ingestion Config." 
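# Editor's note (illustrative, not part of the patch): both the Azure AD and Okta sources use a
# small helper that applies a regex to one attribute of a user/group profile dict to derive a
# DataHub name. A standalone sketch of that helper (error messages are paraphrased):
import re
from typing import Dict

def extract_regex_match(str_dict: Dict[str, str], key: str, pattern: str) -> str:
    raw_value = str_dict.get(key)
    if raw_value is None:
        raise ValueError(f"Unable to find the key {key} in the profile")
    match = re.search(pattern, raw_value)
    if match is None:
        raise ValueError(f"Unable to extract a name from {raw_value} with the pattern {pattern}")
    return match.group()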
+ default=None, + description="Okta Stateful Ingestion Config.", ) # Optional: Whether to mask sensitive information from workunit ID's. On by default. @@ -158,7 +161,7 @@ class OktaConfig(StatefulIngestionConfigBase, ConfigModel): def okta_users_one_of_filter_or_search(cls, v, values): if v and values["okta_users_filter"]: raise ValueError( - "Only one of okta_users_filter or okta_users_search can be set" + "Only one of okta_users_filter or okta_users_search can be set", ) return v @@ -166,7 +169,7 @@ def okta_users_one_of_filter_or_search(cls, v, values): def okta_groups_one_of_filter_or_search(cls, v, values): if v and values["okta_groups_filter"]: raise ValueError( - "Only one of okta_groups_filter or okta_groups_search can be set" + "Only one of okta_groups_filter or okta_groups_search can be set", ) return v @@ -198,7 +201,8 @@ def report_filtered(self, name: str) -> None: @support_status(SupportStatus.CERTIFIED) @capability(SourceCapability.DESCRIPTIONS, "Optionally enabled via configuration") @capability( - SourceCapability.DELETION_DETECTION, "Optionally enabled via stateful_ingestion" + SourceCapability.DELETION_DETECTION, + "Optionally enabled via stateful_ingestion", ) class OktaSource(StatefulIngestionSourceBase): """ @@ -298,7 +302,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -321,7 +327,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: okta_groups = list(self._get_okta_groups(event_loop)) datahub_corp_group_snapshots = self._map_okta_groups(okta_groups) for group_count, datahub_corp_group_snapshot in enumerate( - datahub_corp_group_snapshots + datahub_corp_group_snapshots, ): mce = MetadataChangeEvent(proposedSnapshot=datahub_corp_group_snapshot) wu_id = f"group-snapshot-{group_count + 1 if self.config.mask_group_id else datahub_corp_group_snapshot.urn}" @@ -351,7 +357,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # Fetch membership for each group. for okta_group in okta_groups: datahub_corp_group_urn = self._map_okta_group_profile_to_urn( - okta_group.profile + okta_group.profile, ) if datahub_corp_group_urn is None: error_str = f"Failed to extract DataHub Group Name from Okta Group: Invalid regex pattern provided or missing profile attribute for group named {okta_group.profile.name}. Skipping..." @@ -363,7 +369,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: okta_group_users = self._get_okta_group_users(okta_group, event_loop) for okta_user in okta_group_users: datahub_corp_user_urn = self._map_okta_user_profile_to_urn( - okta_user.profile + okta_user.profile, ) if datahub_corp_user_urn is None: error_str = f"Failed to extract DataHub Username from Okta User: Invalid regex pattern provided or missing profile attribute for User with login {okta_user.profile.login}. Skipping..." @@ -382,7 +388,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: filtered_okta_users = filter(self._filter_okta_user, okta_users) datahub_corp_user_snapshots = self._map_okta_users(filtered_okta_users) for user_count, datahub_corp_user_snapshot in enumerate( - datahub_corp_user_snapshots + datahub_corp_user_snapshots, ): # TODO: Refactor common code between this and Okta to a common base class or utils # Add GroupMembership aspect populated in Step 2. 
@@ -396,7 +402,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: and len(datahub_group_membership.groups) == 0 ): logger.debug( - f"Filtering {datahub_corp_user_snapshot.urn} due to group filter" + f"Filtering {datahub_corp_user_snapshot.urn} due to group filter", ) self.report.report_filtered(datahub_corp_user_snapshot.urn) continue @@ -440,7 +446,8 @@ def _create_okta_client(self): # Retrieves all Okta Group Objects in batches. def _get_okta_groups( - self, event_loop: asyncio.AbstractEventLoop + self, + event_loop: asyncio.AbstractEventLoop, ) -> Iterable[Group]: logger.debug("Extracting all Okta groups") @@ -453,16 +460,18 @@ def _get_okta_groups( groups = resp = err = None try: groups, resp, err = event_loop.run_until_complete( - self.okta_client.list_groups(query_parameters) + self.okta_client.list_groups(query_parameters), ) except OktaAPIException as api_err: self.report.report_failure( - "okta_groups", f"Failed to fetch Groups from Okta API: {api_err}" + "okta_groups", + f"Failed to fetch Groups from Okta API: {api_err}", ) while True: if err: self.report.report_failure( - "okta_groups", f"Failed to fetch Groups from Okta API: {err}" + "okta_groups", + f"Failed to fetch Groups from Okta API: {err}", ) if groups: yield from groups @@ -480,7 +489,9 @@ def _get_okta_groups( # Retrieves Okta User Objects in a particular Okta Group in batches. def _get_okta_group_users( - self, group: Group, event_loop: asyncio.AbstractEventLoop + self, + group: Group, + event_loop: asyncio.AbstractEventLoop, ) -> Iterable[User]: logger.debug(f"Extracting users from Okta group named {group.profile.name}") @@ -489,7 +500,7 @@ def _get_okta_group_users( users = resp = err = None try: users, resp, err = event_loop.run_until_complete( - self.okta_client.list_group_users(group.id, query_parameters) + self.okta_client.list_group_users(group.id, query_parameters), ) except OktaAPIException as api_err: self.report.report_failure( @@ -528,16 +539,18 @@ def _get_okta_users(self, event_loop: asyncio.AbstractEventLoop) -> Iterable[Use users = resp = err = None try: users, resp, err = event_loop.run_until_complete( - self.okta_client.list_users(query_parameters) + self.okta_client.list_users(query_parameters), ) except OktaAPIException as api_err: self.report.report_failure( - "okta_users", f"Failed to fetch Users from Okta API: {api_err}" + "okta_users", + f"Failed to fetch Users from Okta API: {api_err}", ) while True: if err: self.report.report_failure( - "okta_users", f"Failed to fetch Users from Okta API: {err}" + "okta_users", + f"Failed to fetch Users from Okta API: {err}", ) if users: yield from users @@ -547,7 +560,8 @@ def _get_okta_users(self, event_loop: asyncio.AbstractEventLoop) -> Iterable[Use users, err = event_loop.run_until_complete(resp.next()) except OktaAPIException as api_err: self.report.report_failure( - "okta_users", f"Failed to fetch Users from Okta API: {api_err}" + "okta_users", + f"Failed to fetch Users from Okta API: {api_err}", ) else: break @@ -568,7 +582,8 @@ def _filter_okta_user(self, okta_user: User) -> bool: # Converts Okta Group Objects into DataHub CorpGroupSnapshots. def _map_okta_groups( - self, okta_groups: Iterable[Group] + self, + okta_groups: Iterable[Group], ) -> Iterable[CorpGroupSnapshot]: for okta_group in okta_groups: corp_group_urn = self._map_okta_group_profile_to_urn(okta_group.profile) @@ -587,7 +602,8 @@ def _map_okta_groups( # Creates DataHub CorpGroup Urn from Okta Group Object. 
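# Editor's note (illustrative, not part of the patch): _get_okta_groups()/_get_okta_users() above
# drain the async okta-sdk pagination from synchronous code. A condensed sketch of that loop,
# with error reporting dropped and the client assumed to be constructed per the okta SDK docs:
import asyncio
from typing import Iterable

def iter_okta_groups(okta_client, query_parameters: dict) -> Iterable:
    event_loop = asyncio.new_event_loop()
    groups, resp, err = event_loop.run_until_complete(okta_client.list_groups(query_parameters))
    while True:
        if groups:
            yield from groups
        if resp and resp.has_next():
            # resp.next() fetches the next page and returns (items, err), mirroring the source
            groups, err = event_loop.run_until_complete(resp.next())
        else:
            break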
def _map_okta_group_profile_to_urn( - self, okta_group_profile: GroupProfile + self, + okta_group_profile: GroupProfile, ) -> Union[str, None]: # Profile is a required field as per https://developer.okta.com/docs/reference/api/groups/#group-attributes group_name = self._map_okta_group_profile_to_group_name(okta_group_profile) @@ -610,7 +626,8 @@ def _map_okta_group_profile(self, profile: GroupProfile) -> CorpGroupInfoClass: # Converts Okta Group Profile Object into a DataHub Group Name. def _map_okta_group_profile_to_group_name( - self, okta_group_profile: GroupProfile + self, + okta_group_profile: GroupProfile, ) -> Union[str, None]: # Profile is a required field as per https://developer.okta.com/docs/reference/api/groups/#group-attributes return self._extract_regex_match_from_dict_value( @@ -638,7 +655,8 @@ def _map_okta_users(self, okta_users: Iterable[User]) -> Iterable[CorpUserSnapsh # Creates DataHub CorpUser Urn from Okta User Profile def _map_okta_user_profile_to_urn( - self, okta_user_profile: UserProfile + self, + okta_user_profile: UserProfile, ) -> Union[str, None]: # Profile is a required field as per https://developer.okta.com/docs/reference/api/users/#user-attributes username = self._map_okta_user_profile_to_username(okta_user_profile) @@ -648,7 +666,8 @@ def _map_okta_user_profile_to_urn( # Converts Okta User Profile Object into a DataHub User name. def _map_okta_user_profile_to_username( - self, okta_user_profile: UserProfile + self, + okta_user_profile: UserProfile, ) -> Union[str, None]: # Profile is a required field as per https://developer.okta.com/docs/reference/api/users/#user-attributes return self._extract_regex_match_from_dict_value( @@ -683,7 +702,10 @@ def _make_corp_user_urn(self, username: str) -> str: return f"urn:li:corpuser:{username}" def _extract_regex_match_from_dict_value( - self, str_dict: Dict[str, str], key: str, pattern: str + self, + str_dict: Dict[str, str], + key: str, + pattern: str, ) -> Union[str, None]: raw_value = str_dict.get(key) if raw_value is None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py index 9f15eda1501f11..bf7a9aa49c8a6d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py @@ -131,7 +131,8 @@ class KafkaSourceConfig( description="Whether or not to strip email id while adding owners using meta mappings.", ) tag_prefix: str = pydantic.Field( - default="", description="Prefix added to tags during ingestion." 
+ default="", + description="Prefix added to tags during ingestion.", ) ignore_warnings_on_schema_type: bool = pydantic.Field( default=False, @@ -155,7 +156,7 @@ def get_kafka_consumer( "group.id": "datahub-kafka-ingestion", "bootstrap.servers": connection.bootstrap, **connection.consumer_config, - } + }, ) if CallableConsumerConfig.is_callable_config(connection.consumer_config): @@ -176,7 +177,7 @@ def get_kafka_admin_client( "group.id": "datahub-kafka-ingestion", "bootstrap.servers": connection.bootstrap, **connection.consumer_config, - } + }, ) if CallableConsumerConfig.is_callable_config(connection.consumer_config): # As per documentation, we need to explicitly call the poll method to make sure OAuth callback gets executed @@ -204,7 +205,7 @@ def __init__(self, config_dict: dict): self.config = KafkaSourceConfig.parse_obj_allow_extras(config_dict) self.report = KafkaSourceReport() self.consumer: confluent_kafka.Consumer = get_kafka_consumer( - self.config.connection + self.config.connection, ) def get_connection_test(self) -> TestConnectionReport: @@ -231,7 +232,7 @@ def schema_registry_connectivity(self) -> CapabilityReport: { "url": self.config.connection.schema_registry_url, **self.config.connection.schema_registry_config, - } + }, ).get_subjects() return CapabilityReport(capable=True) except Exception as e: @@ -264,7 +265,9 @@ class KafkaSource(StatefulIngestionSourceBase, TestableSource): @classmethod def create_schema_registry( - cls, config: KafkaSourceConfig, report: KafkaSourceReport + cls, + config: KafkaSourceConfig, + report: KafkaSourceReport, ) -> KafkaSchemaRegistryBase: try: schema_registry_class: Type = import_path(config.schema_registry_class) @@ -277,7 +280,7 @@ def __init__(self, config: KafkaSourceConfig, ctx: PipelineContext): super().__init__(config, ctx) self.source_config: KafkaSourceConfig = config self.consumer: confluent_kafka.Consumer = get_kafka_consumer( - self.source_config.connection + self.source_config.connection, ) self.init_kafka_admin_client() self.report: KafkaSourceReport = KafkaSourceReport() @@ -322,13 +325,15 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: topics = self.consumer.list_topics( - timeout=self.source_config.connection.client_timeout_seconds + timeout=self.source_config.connection.client_timeout_seconds, ).topics extra_topic_details = self.fetch_extra_topic_details(topics.keys()) @@ -337,12 +342,16 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.source_config.topic_patterns.allowed(topic): try: yield from self._extract_record( - topic, False, topic_detail, extra_topic_details.get(topic) + topic, + False, + topic_detail, + extra_topic_details.get(topic), ) except Exception as e: logger.warning(f"Failed to extract topic {topic}", exc_info=True) self.report.report_warning( - "topic", f"Exception while extracting topic {topic}: {e}" + "topic", + f"Exception while extracting topic {topic}: {e}", ) else: self.report.report_dropped(topic) @@ -352,14 +361,19 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for subject in self.schema_registry_client.get_subjects(): try: yield from self._extract_record( - subject, True, topic_detail=None, extra_topic_config=None + subject, + True, + topic_detail=None, + 
extra_topic_config=None, ) except Exception as e: logger.warning( - f"Failed to extract subject {subject}", exc_info=True + f"Failed to extract subject {subject}", + exc_info=True, ) self.report.report_warning( - "subject", f"Exception while extracting topic {subject}: {e}" + "subject", + f"Exception while extracting topic {subject}: {e}", ) def _extract_record( @@ -379,7 +393,9 @@ def _extract_record( # 1. Create schemaMetadata aspect (pass control to SchemaRegistry) schema_metadata = self.schema_registry_client.get_schema_metadata( - topic, platform_urn, is_subject + topic, + platform_urn, + is_subject, ) # topic can have no associated subject, but still it can be ingested without schema @@ -417,11 +433,14 @@ def _extract_record( custom_props: Dict[str, str] = {} if not is_subject: custom_props = self.build_custom_properties( - topic, topic_detail, extra_topic_config + topic, + topic_detail, + extra_topic_config, ) schema_name: Optional[str] = ( self.schema_registry_client._get_subject_for_topic( - topic, is_key_schema=False + topic, + is_key_schema=False, ) ) if schema_name is not None: @@ -439,7 +458,7 @@ def _extract_record( # DataHub Dataset "description" field is mapped to documentSchema's "doc" field. avro_schema = avro.schema.parse( - schema_metadata.platformSchema.documentSchema + schema_metadata.platformSchema.documentSchema, ) description = getattr(avro_schema, "doc", None) # set the tags @@ -448,7 +467,8 @@ def _extract_record( schema_tags = cast( Iterable[str], avro_schema.other_props.get( - self.source_config.schema_tags_field, [] + self.source_config.schema_tags_field, + [], ), ) for tag in schema_tags: @@ -477,11 +497,13 @@ def _extract_record( if all_tags: dataset_snapshot.aspects.append( - mce_builder.make_global_tag_aspect_with_tag_list(all_tags) + mce_builder.make_global_tag_aspect_with_tag_list(all_tags), ) dataset_properties = DatasetPropertiesClass( - name=dataset_name, customProperties=custom_props, description=description + name=dataset_name, + customProperties=custom_props, + description=description, ) dataset_snapshot.aspects.append(dataset_properties) @@ -491,9 +513,10 @@ def _extract_record( DataPlatformInstanceClass( platform=platform_urn, instance=make_dataplatform_instance_urn( - self.platform, self.source_config.platform_instance + self.platform, + self.source_config.platform_instance, ), - ) + ), ) # 6. 
Emit the datasetSnapshot MCE @@ -513,7 +536,7 @@ def _extract_record( for domain, pattern in self.source_config.domain.items(): if pattern.allowed(dataset_name): domain_urn = make_domain_urn( - self.domain_registry.get_domain_urn(domain) + self.domain_registry.get_domain_urn(domain), ) if domain_urn: @@ -531,7 +554,9 @@ def build_custom_properties( custom_props: Dict[str, str] = {} self.update_custom_props_with_topic_details(topic, topic_detail, custom_props) self.update_custom_props_with_topic_config( - topic, extra_topic_config, custom_props + topic, + extra_topic_config, + custom_props, ) return custom_props @@ -543,7 +568,7 @@ def update_custom_props_with_topic_details( ) -> None: if topic_detail is None or topic_detail.partitions is None: logger.info( - f"Partitions and Replication Factor not available for topic {topic}" + f"Partitions and Replication Factor not available for topic {topic}", ) return @@ -589,7 +614,9 @@ def close(self) -> None: super().close() def _get_config_value_if_present( - self, config_dict: Dict[str, ConfigEntry], key: str + self, + config_dict: Dict[str, ConfigEntry], + key: str, ) -> Any: return @@ -598,7 +625,7 @@ def fetch_extra_topic_details(self, topics: List[str]) -> Dict[str, dict]: if not hasattr(self, "admin_client"): logger.debug( - "Kafka Admin Client missing. Not fetching config details for topics." + "Kafka Admin Client missing. Not fetching config details for topics.", ) else: try: @@ -625,7 +652,9 @@ def fetch_topic_configurations(self, topics: List[str]) -> Dict[str, dict]: topic_configurations: Dict[str, dict] = {} for config_resource, config_result_future in configs.items(): self.process_topic_config_result( - config_resource, config_result_future, topic_configurations + config_resource, + config_result_future, + topic_configurations, ) return topic_configurations @@ -639,9 +668,9 @@ def process_topic_config_result( topic_configurations[config_resource.name] = config_result_future.result() except Exception as e: logger.warning( - f"Config details for topic {config_resource.name} not fetched due to error {e}" + f"Config details for topic {config_resource.name} not fetched due to error {e}", ) else: logger.info( - f"Config details for topic {config_resource.name} fetched successfully" + f"Config details for topic {config_resource.name} fetched successfully", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka_schema_registry_base.py b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka_schema_registry_base.py index 59f174a9a50458..f0ee19c3488cc7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka_schema_registry_base.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka_schema_registry_base.py @@ -7,7 +7,10 @@ class KafkaSchemaRegistryBase(ABC): @abstractmethod def get_schema_metadata( - self, topic: str, platform_urn: str, is_subject: bool + self, + topic: str, + platform_urn: str, + is_subject: bool, ) -> Optional[SchemaMetadata]: pass @@ -17,6 +20,8 @@ def get_subjects(self) -> List[str]: @abstractmethod def _get_subject_for_topic( - self, dataset_subtype: str, is_key_schema: bool + self, + dataset_subtype: str, + is_key_schema: bool, ) -> Optional[str]: pass diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/common.py b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/common.py index 36f6a96c0d4080..3eade1c3f5db7d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/common.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/common.py @@ -45,12 +45,14 @@ class KafkaConnectSourceConfig( # See the Connect REST Interface for details # https://docs.confluent.io/platform/current/connect/references/restapi.html# connect_uri: str = Field( - default="http://localhost:8083/", description="URI to connect to." + default="http://localhost:8083/", + description="URI to connect to.", ) username: Optional[str] = Field(default=None, description="Kafka Connect username.") password: Optional[str] = Field(default=None, description="Kafka Connect password.") cluster_name: Optional[str] = Field( - default="connect-cluster", description="Cluster to ingest from." + default="connect-cluster", + description="Cluster to ingest from.", ) # convert lineage dataset's urns to lowercase convert_lineage_urns_to_lowercase: bool = Field( @@ -62,7 +64,8 @@ class KafkaConnectSourceConfig( description="regex patterns for connectors to filter for ingestion.", ) provided_configs: Optional[List[ProvidedConfig]] = Field( - default=None, description="Provided Configurations" + default=None, + description="Provided Configurations", ) connect_to_platform_map: Optional[Dict[str, Dict[str, str]]] = Field( default=None, @@ -125,7 +128,9 @@ def remove_prefix(text: str, prefix: str) -> str: def unquote( - string: str, leading_quote: str = '"', trailing_quote: Optional[str] = None + string: str, + leading_quote: str = '"', + trailing_quote: Optional[str] = None, ) -> str: """ If string starts and ends with a quote, unquote it @@ -149,7 +154,9 @@ def get_dataset_name( def get_platform_instance( - config: KafkaConnectSourceConfig, connector_name: str, platform: str + config: KafkaConnectSourceConfig, + connector_name: str, + platform: str, ) -> Optional[str]: instance_name = None if ( @@ -161,18 +168,19 @@ def get_platform_instance( if config.platform_instance_map and config.platform_instance_map.get(platform): logger.warning( f"Same source platform {platform} configured in both platform_instance_map and connect_to_platform_map." - "Will prefer connector specific platform instance from connect_to_platform_map." 
+ "Will prefer connector specific platform instance from connect_to_platform_map.", ) elif config.platform_instance_map and config.platform_instance_map.get(platform): instance_name = config.platform_instance_map[platform] logger.info( - f"Instance name assigned is: {instance_name} for Connector Name {connector_name} and platform {platform}" + f"Instance name assigned is: {instance_name} for Connector Name {connector_name} and platform {platform}", ) return instance_name def transform_connector_config( - connector_config: Dict, provided_configs: List[ProvidedConfig] + connector_config: Dict, + provided_configs: List[ProvidedConfig], ) -> None: """This method will update provided configs in connector config values, if any""" lookupsByProvider = {} diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/kafka_connect.py b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/kafka_connect.py index 9edfce5855f430..32a3f9a8b308cf 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/kafka_connect.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/kafka_connect.py @@ -78,13 +78,13 @@ def __init__(self, config: KafkaConnectSourceConfig, ctx: PipelineContext): { "Accept": "application/json", "Content-Type": "application/json", - } + }, ) # Test the connection if self.config.username is not None and self.config.password is not None: logger.info( - f"Connecting to {self.config.connect_uri} with Authentication..." + f"Connecting to {self.config.connect_uri} with Authentication...", ) self.session.auth = (self.config.username, self.config.password) @@ -108,17 +108,19 @@ def get_connectors_manifest(self) -> Iterable[ConnectorManifest]: for connector_name in payload: connector_url = f"{self.config.connect_uri}/connectors/{connector_name}" connector_manifest = self._get_connector_manifest( - connector_name, connector_url + connector_name, + connector_url, ) if connector_manifest is None or not self.config.connector_patterns.allowed( - connector_manifest.name + connector_manifest.name, ): self.report.report_dropped(connector_name) continue if self.config.provided_configs: transform_connector_config( - connector_manifest.config, self.config.provided_configs + connector_manifest.config, + self.config.provided_configs, ) connector_manifest.url = connector_url connector_manifest.topic_names = self._get_connector_topics(connector_name) @@ -141,7 +143,7 @@ def get_connectors_manifest(self) -> Iterable[ConnectorManifest]: [ connector.connector_name == connector_manifest.name for connector in self.config.generic_connectors - ] + ], ): class_type = ConfigDrivenSourceConnector else: @@ -175,14 +177,18 @@ def get_connectors_manifest(self) -> Iterable[ConnectorManifest]: yield connector_manifest def _get_connector_manifest( - self, connector_name: str, connector_url: str + self, + connector_name: str, + connector_url: str, ) -> Optional[ConnectorManifest]: try: connector_response = self.session.get(connector_url) connector_response.raise_for_status() except Exception as e: self.report.warning( - "Failed to get connector details", connector_name, exc=e + "Failed to get connector details", + connector_name, + exc=e, ) return None manifest = connector_response.json() @@ -197,7 +203,9 @@ def _get_connector_tasks(self, connector_name: str) -> dict: response.raise_for_status() except Exception as e: self.report.warning( - "Error getting connector tasks", context=connector_name, exc=e + "Error getting connector tasks", + context=connector_name, + exc=e, ) 
return {} @@ -211,7 +219,9 @@ def _get_connector_topics(self, connector_name: str) -> List[str]: response.raise_for_status() except Exception as e: self.report.warning( - "Error getting connector topics", context=connector_name, exc=e + "Error getting connector topics", + context=connector_name, + exc=e, ) return [] @@ -241,7 +251,8 @@ def construct_flow_workunit(self, connector: ConnectorManifest) -> MetadataWorkU ).as_workunit() def construct_job_workunits( - self, connector: ConnectorManifest + self, + connector: ConnectorManifest, ) -> Iterable[MetadataWorkUnit]: connector_name = connector.name flow_urn = builder.make_data_flow_urn( @@ -261,10 +272,14 @@ def construct_job_workunits( job_property_bag = lineage.job_property_bag source_platform_instance = get_platform_instance( - self.config, connector_name, source_platform + self.config, + connector_name, + source_platform, ) target_platform_instance = get_platform_instance( - self.config, connector_name, target_platform + self.config, + connector_name, + target_platform, ) job_id = self.get_job_id(lineage, connector, self.config) @@ -273,16 +288,20 @@ def construct_job_workunits( inlets = ( [ self.make_lineage_dataset_urn( - source_platform, source_dataset, source_platform_instance - ) + source_platform, + source_dataset, + source_platform_instance, + ), ] if source_dataset else [] ) outlets = [ self.make_lineage_dataset_urn( - target_platform, target_dataset, target_platform_instance - ) + target_platform, + target_dataset, + target_platform_instance, + ), ] yield MetadataChangeProposalWrapper( @@ -322,7 +341,7 @@ def get_job_id( and config.connect_to_platform_map and config.connect_to_platform_map.get(connector.name) and config.connect_to_platform_map[connector.name].get( - lineage.source_platform + lineage.source_platform, ) ): return f"{config.connect_to_platform_map[connector.name][lineage.source_platform]}.{lineage.source_dataset}" @@ -337,7 +356,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -351,11 +372,17 @@ def get_report(self) -> KafkaConnectSourceReport: return self.report def make_lineage_dataset_urn( - self, platform: str, name: str, platform_instance: Optional[str] + self, + platform: str, + name: str, + platform_instance: Optional[str], ) -> str: if self.config.convert_lineage_urns_to_lowercase: name = name.lower() return builder.make_dataset_urn_with_platform_instance( - platform, name, platform_instance, self.config.env + platform, + name, + platform_instance, + self.config.env, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/sink_connectors.py b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/sink_connectors.py index 10255ed544b812..16e37dc5a7f0f7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/sink_connectors.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/sink_connectors.py @@ -24,7 +24,7 @@ def _get_parser(self, connector_manifest: ConnectorManifest) -> S3SinkParser: bucket = connector_manifest.config.get("s3.bucket.name") if not bucket: raise ValueError( - "Could not find 's3.bucket.name' in connector configuration" + "Could not find 's3.bucket.name' in connector configuration", ) # https://docs.confluent.io/kafka-connectors/s3-sink/current/configuration_options.html#storage @@ -66,7 +66,7 @@ def 
extract_lineages(self) -> List[KafkaConnectLineage]: source_platform="kafka", target_dataset=target_dataset, target_platform=parser.target_platform, - ) + ), ) return lineages except Exception as e: @@ -113,7 +113,7 @@ def get_parser( provided_topics_to_tables: Dict[str, str] = {} if connector_manifest.config.get("snowflake.topic2table.map"): for each in connector_manifest.config["snowflake.topic2table.map"].split( - "," + ",", ): topic, table = each.split(":") provided_topics_to_tables[topic.strip()] = table.strip() @@ -163,7 +163,7 @@ def extract_lineages(self) -> List[KafkaConnectLineage]: source_platform=KAFKA, target_dataset=target_dataset, target_platform="snowflake", - ) + ), ) return lineages @@ -251,7 +251,9 @@ def sanitize_table_name(self, table_name): return table_name def get_dataset_table_for_topic( - self, topic: str, parser: BQParser + self, + topic: str, + parser: BQParser, ) -> Optional[str]: if parser.version == "v2": dataset = parser.defaultDataset @@ -269,7 +271,7 @@ def get_dataset_table_for_topic( table = topic if parser.topicsToTables: topicregex_table_map: Dict[str, str] = dict( - self.get_list(parser.topicsToTables) # type: ignore + self.get_list(parser.topicsToTables), # type: ignore ) from java.util.regex import Pattern @@ -284,7 +286,9 @@ def get_dataset_table_for_topic( return f"{dataset}.{table}" def apply_transformations( - self, topic: str, transforms: List[Dict[str, str]] + self, + topic: str, + transforms: List[Dict[str, str]], ) -> str: for transform in transforms: if transform["type"] == "org.apache.kafka.connect.transforms.RegexRouter": @@ -331,7 +335,7 @@ def extract_lineages(self) -> List[KafkaConnectLineage]: source_platform=KAFKA, target_dataset=target_dataset, target_platform=target_platform, - ) + ), ) return lineages diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/source_connectors.py b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/source_connectors.py index 5e64d4e161e3ea..46091a28c88f60 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/source_connectors.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect/source_connectors.py @@ -99,7 +99,8 @@ def get_parser( connector_manifest: ConnectorManifest, ) -> JdbcParser: url = remove_prefix( - str(connector_manifest.config.get("connection.url")), "jdbc:" + str(connector_manifest.config.get("connection.url")), + "jdbc:", ) url_instance = make_url(url) source_platform = get_platform_from_sqlalchemy_uri(str(url_instance)) @@ -189,11 +190,12 @@ def get_table_names(self) -> List[Tuple]: [ task["config"].get("tables") for task in self.connector_manifest.tasks - ] + ], ) ).split(",") quote_method = self.connector_manifest.config.get( - "quote.sql.identifiers", "always" + "quote.sql.identifiers", + "always", ) if ( quote_method == "always" @@ -212,13 +214,17 @@ def get_table_names(self) -> List[Tuple]: ( ( unquote( - table_id.split(sep)[-2], leading_quote_char, trailing_quote_char + table_id.split(sep)[-2], + leading_quote_char, + trailing_quote_char, ) if len(table_id.split(sep)) > 1 else "" ), unquote( - table_id.split(sep)[-1], leading_quote_char, trailing_quote_char + table_id.split(sep)[-1], + leading_quote_char, + trailing_quote_char, ), ) for table_id in table_ids @@ -234,7 +240,7 @@ def extract_flow_property_bag(self) -> Dict[str, str]: # Mask/Remove properties that may reveal credentials flow_property_bag["connection.url"] = self.get_parser( - self.connector_manifest + self.connector_manifest, ).db_connection_url 
return flow_property_bag @@ -249,7 +255,7 @@ def extract_lineages(self) -> List[KafkaConnectLineage]: transforms = parser.transforms logging.debug( - f"Extracting source platform: {source_platform} and database name: {database_name} from connection url " + f"Extracting source platform: {source_platform} and database name: {database_name} from connection url ", ) if not self.connector_manifest.topic_names: @@ -282,13 +288,13 @@ def extract_lineages(self) -> List[KafkaConnectLineage]: not in self.KNOWN_TOPICROUTING_TRANSFORMS + self.KNOWN_NONTOPICROUTING_TRANSFORMS for transform in transforms - ] + ], ) ALL_TRANSFORMS_NON_TOPICROUTING = all( [ transform["type"] in self.KNOWN_NONTOPICROUTING_TRANSFORMS for transform in transforms - ] + ], ) if NO_TRANSFORM or ALL_TRANSFORMS_NON_TOPICROUTING: @@ -342,7 +348,7 @@ def extract_lineages(self) -> List[KafkaConnectLineage]: topic_prefix=topic_prefix, topic_names=topic_names, include_source_dataset=False, - ) + ), ) self.report.warning( "Could not find input dataset for connector topics", @@ -479,12 +485,12 @@ def get_parser( ) elif connector_class == "io.debezium.connector.sqlserver.SqlServerConnector": database_name = connector_manifest.config.get( - "database.names" + "database.names", ) or connector_manifest.config.get("database.dbname") if "," in str(database_name): raise Exception( - f"Only one database is supported for Debezium's SQL Server connector. Found: {database_name}" + f"Only one database is supported for Debezium's SQL Server connector. Found: {database_name}", ) parser = self.DebeziumParser( diff --git a/metadata-ingestion/src/datahub/ingestion/source/ldap.py b/metadata-ingestion/src/datahub/ingestion/source/ldap.py index 236e91a86700c3..09eda4a02ae3f2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ldap.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ldap.py @@ -113,7 +113,8 @@ class LDAPSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): base_dn: str = Field(description="LDAP DN.") filter: str = Field(default="(objectClass=*)", description="LDAP extractor filter.") attrs_list: Optional[List[str]] = Field( - default=None, description="Retrieved attributes list" + default=None, + description="Retrieved attributes list", ) custom_props_list: Optional[List[str]] = Field( @@ -128,7 +129,8 @@ class LDAPSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): ) page_size: int = Field( - default=20, description="Size of each page to fetch when extracting metadata." 
+ default=20, + description="Size of each page to fetch when extracting metadata.", ) manager_filter_enabled: bool = Field( @@ -141,7 +143,8 @@ class LDAPSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): description="[deprecated] Use pagination_enabled ", ) _deprecate_manager_pagination_enabled = pydantic_renamed_field( - "manager_pagination_enabled", "pagination_enabled" + "manager_pagination_enabled", + "pagination_enabled", ) pagination_enabled: bool = Field( default=True, @@ -167,7 +170,9 @@ def report_dropped(self, dn: str) -> None: def guess_person_ldap( - attrs: Dict[str, Any], config: LDAPSourceConfig, report: LDAPSourceReport + attrs: Dict[str, Any], + config: LDAPSourceConfig, + report: LDAPSourceReport, ) -> Optional[str]: """Determine the user's LDAP based on the DN and attributes.""" if config.user_attrs_map["urn"] in attrs: @@ -228,7 +233,8 @@ def __init__(self, ctx: PipelineContext, config: LDAPSourceConfig): try: self.ldap_client.simple_bind_s( - self.config.ldap_user, self.config.ldap_password + self.config.ldap_user, + self.config.ldap_password, ) except ldap.LDAPError as e: raise ConfigurationError("LDAP connection failed") from e @@ -248,7 +254,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -301,7 +309,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: pctrls = get_pctrls(serverctrls) if not pctrls: self.report.report_failure( - "ldap-control", "Server ignores RFC 2696 control." + "ldap-control", + "Server ignores RFC 2696 control.", ) break cookie = set_cookie(self.lc, pctrls) @@ -335,7 +344,8 @@ def handle_user(self, dn: str, attrs: Dict[str, Any]) -> Iterable[MetadataWorkUn manager_ldap = guess_person_ldap(m_attrs, self.config, self.report) m_email = get_attr_or_none( - m_attrs, self.config.user_attrs_map["email"] + m_attrs, + self.config.user_attrs_map["email"], ) make_manager_urn = ( m_email @@ -352,7 +362,9 @@ def handle_user(self, dn: str, attrs: Dict[str, Any]) -> Iterable[MetadataWorkUn self.report.report_dropped(dn) def handle_group( - self, dn: str, attrs: Dict[str, Any] + self, + dn: str, + attrs: Dict[str, Any], ) -> Iterable[MetadataWorkUnit]: """Creates a workunit for LDAP groups.""" @@ -363,7 +375,10 @@ def handle_group( self.report.report_dropped(dn) def build_corp_user_mce( - self, dn: str, attrs: dict, manager_ldap: Optional[str] + self, + dn: str, + attrs: dict, + manager_ldap: Optional[str], ) -> Optional[MetadataChangeEvent]: """ Create the MetadataChangeEvent via DN and attributes. 
@@ -382,17 +397,22 @@ def build_corp_user_mce( email = get_attr_or_none(attrs, self.config.user_attrs_map["email"]) display_name = get_attr_or_none( - attrs, self.config.user_attrs_map["displayName"], full_name + attrs, + self.config.user_attrs_map["displayName"], + full_name, ) title = get_attr_or_none(attrs, self.config.user_attrs_map["title"]) department_id_str = get_attr_or_none( - attrs, self.config.user_attrs_map["departmentId"] + attrs, + self.config.user_attrs_map["departmentId"], ) department_name = get_attr_or_none( - attrs, self.config.user_attrs_map["departmentName"] + attrs, + self.config.user_attrs_map["departmentName"], ) country_code = get_attr_or_none( - attrs, self.config.user_attrs_map["countryCode"] + attrs, + self.config.user_attrs_map["countryCode"], ) department_id = None with contextlib.suppress(ValueError): @@ -445,10 +465,12 @@ def build_corp_group_mce(self, attrs: dict) -> Optional[MetadataChangeEvent]: email = get_attr_or_none(attrs, self.config.group_attrs_map["email"]) description = get_attr_or_none( - attrs, self.config.group_attrs_map["description"] + attrs, + self.config.group_attrs_map["description"], ) displayName = get_attr_or_none( - attrs, self.config.group_attrs_map["displayName"] + attrs, + self.config.group_attrs_map["displayName"], ) make_group_urn = ( @@ -513,6 +535,8 @@ def parse_ldap_dn(input_clean: bytes) -> str: def get_attr_or_none( - attrs: Dict[str, Any], key: str, default: Optional[str] = None + attrs: Dict[str, Any], + key: str, + default: Optional[str] = None, ) -> str: return attrs[key][0].decode() if attrs.get(key) else default diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py index abe9b5684f8f1f..6ba2e5f95e2d8c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py @@ -251,7 +251,7 @@ def get_urn(self, config: LookerCommonConfig) -> str: def get_browse_path(self, config: LookerCommonConfig) -> str: browse_path = config.view_browse_pattern.replace_variables( - self.get_mapping(config) + self.get_mapping(config), ) return browse_path @@ -463,7 +463,8 @@ def create_view_project_map(view_fields: List[ViewField]) -> Dict[str, str]: def get_view_file_path( - lkml_fields: List[LookmlModelExploreField], view_name: str + lkml_fields: List[LookmlModelExploreField], + view_name: str, ) -> Optional[str]: """ Search for the view file path on field, if found then return the file path @@ -482,7 +483,8 @@ def get_view_file_path( def create_upstream_views_file_path_map( - view_names: Set[str], lkml_fields: List[LookmlModelExploreField] + view_names: Set[str], + lkml_fields: List[LookmlModelExploreField], ) -> Dict[str, Optional[str]]: """ Create a map of view-name v/s view file path, so that later we can fetch view's file path via view-name @@ -492,7 +494,8 @@ def create_upstream_views_file_path_map( for view_name in view_names: file_path: Optional[str] = get_view_file_path( - lkml_fields=lkml_fields, view_name=view_name + lkml_fields=lkml_fields, + view_name=view_name, ) upstream_views_file_path[view_name] = file_path @@ -659,7 +662,8 @@ def _get_schema( if not view_fields: return None fields, primary_keys = LookerUtil._get_fields_and_primary_keys( - view_fields=view_fields, reporter=reporter + view_fields=view_fields, + reporter=reporter, ) schema_metadata = SchemaMetadata( schemaName=schema_name, @@ -706,13 +710,15 @@ def 
_get_tag_mce_for_urn(tag_urn: str) -> MetadataChangeEvent: assert tag_urn in LookerUtil.tag_definitions return MetadataChangeEvent( proposedSnapshot=TagSnapshotClass( - urn=tag_urn, aspects=[LookerUtil.tag_definitions[tag_urn]] - ) + urn=tag_urn, + aspects=[LookerUtil.tag_definitions[tag_urn]], + ), ) @staticmethod def _get_tags_from_field_type( - field: ViewField, reporter: SourceReport + field: ViewField, + reporter: SourceReport, ) -> Optional[GlobalTagsClass]: schema_field_tags: List[TagAssociationClass] = [ TagAssociationClass(tag=builder.make_tag_urn(tag_name)) @@ -724,7 +730,7 @@ def _get_tags_from_field_type( [ TagAssociationClass(tag=tag_name) for tag_name in LookerUtil.type_to_tag_map[field.field_type] - ] + ], ) else: reporter.report_warning( @@ -779,7 +785,9 @@ def _get_fields_and_primary_keys( fields = [] for field in view_fields: schema_field = LookerUtil.view_field_to_schema_field( - field, reporter, tag_measures_and_dimensions + field, + reporter, + tag_measures_and_dimensions, ) fields.append(schema_field) if field.is_primary_key: @@ -819,7 +827,7 @@ class LookerExplore: None # captures the view name(s) this explore is derived from ) upstream_views_file_path: Dict[str, Optional[str]] = dataclasses_field( - default_factory=dict + default_factory=dict, ) # view_name is key and file_path is value. A single file may contains multiple views joins: Optional[List[str]] = None fields: Optional[List[ViewField]] = None # the fields exposed in this explore @@ -872,8 +880,8 @@ def from_dict( # create the list of extended explores extends = list( itertools.chain.from_iterable( - dict.get("extends", dict.get("extends__all", [])) - ) + dict.get("extends", dict.get("extends__all", [])), + ), ) if extends: for extended_explore in extends: @@ -889,7 +897,7 @@ def from_dict( upstream_views.extend(parsed_explore.upstream_views or []) else: logger.warning( - f"Could not find extended explore {extended_explore} for explore {dict['name']} in model {model_name}" + f"Could not find extended explore {extended_explore} for explore {dict['name']} in model {model_name}", ) else: # we only fallback to the view_names list if this is not an extended explore @@ -903,11 +911,11 @@ def from_dict( ) if not info: logger.warning( - f"Could not resolve view {view_name} for explore {dict['name']} in model {model_name}" + f"Could not resolve view {view_name} for explore {dict['name']} in model {model_name}", ) else: upstream_views.append( - ProjectInclude(project=info[0].project, include=view_name) + ProjectInclude(project=info[0].project, include=view_name), ) return LookerExplore( @@ -1016,10 +1024,10 @@ def from_api( # noqa: C901 else ViewFieldType.DIMENSION ), project_name=LookerUtil.extract_project_name_from_source_file( - dim_field.source_file + dim_field.source_file, ), view_name=LookerUtil.extract_view_name_from_lookml_model_explore_field( - dim_field + dim_field, ), is_primary_key=( dim_field.primary_key @@ -1027,7 +1035,7 @@ def from_api( # noqa: C901 else False ), upstream_fields=[], - ) + ), ) if explore.fields.measures is not None: for measure_field in explore.fields.measures: @@ -1054,10 +1062,10 @@ def from_api( # noqa: C901 ), field_type=ViewFieldType.MEASURE, project_name=LookerUtil.extract_project_name_from_source_file( - measure_field.source_file + measure_field.source_file, ), view_name=LookerUtil.extract_view_name_from_lookml_model_explore_field( - measure_field + measure_field, ), is_primary_key=( measure_field.primary_key @@ -1065,7 +1073,7 @@ def from_api( # noqa: C901 else False ), 
upstream_fields=[], - ) + ), ) view_project_map: Dict[str, str] = create_view_project_map(view_fields) @@ -1124,11 +1132,11 @@ def from_api( # noqa: C901 except SDKError as e: if "Looker Not Found (404)" in str(e): logger.info( - f"Explore {explore_name} in model {model} is referred to, but missing. Continuing..." + f"Explore {explore_name} in model {model} is referred to, but missing. Continuing...", ) else: logger.warning( - f"Failed to extract explore {explore_name} from model {model}: {e}" + f"Failed to extract explore {explore_name} from model {model}: {e}", ) except DeserializeError as e: reporter.warning( @@ -1159,7 +1167,7 @@ def get_mapping(self, config: LookerCommonConfig) -> NamingPatternMapping: def get_explore_urn(self, config: LookerCommonConfig) -> str: dataset_name = config.explore_naming_pattern.replace_variables( - self.get_mapping(config) + self.get_mapping(config), ) return builder.make_dataset_urn_with_platform_instance( @@ -1171,7 +1179,7 @@ def get_explore_urn(self, config: LookerCommonConfig) -> str: def get_explore_browse_path(self, config: LookerCommonConfig) -> str: browse_path = config.explore_browse_pattern.replace_variables( - self.get_mapping(config) + self.get_mapping(config), ) return browse_path @@ -1254,7 +1262,7 @@ def _to_metadata_events( # noqa: C901 time=int(observed_lineage_ts.timestamp() * 1000), actor=CORPUSER_DATAHUB, ), - ) + ), ) view_name_to_urn_map[view_ref.include] = view_urn @@ -1270,18 +1278,20 @@ def _to_metadata_events( # noqa: C901 builder.make_schema_field_urn( upstream_column_ref.table, upstream_column_ref.column, - ) + ), ], downstreams=[ builder.make_schema_field_urn( - self.get_explore_urn(config), field.name - ) + self.get_explore_urn(config), + field.name, + ), ], - ) + ), ) upstream_lineage = UpstreamLineage( - upstreams=upstreams, fineGrainedLineages=fine_grained_lineages or None + upstreams=upstreams, + fineGrainedLineages=fine_grained_lineages or None, ) dataset_snapshot.aspects.append(upstream_lineage) if self.fields is not None: @@ -1315,7 +1325,8 @@ def _to_metadata_events( # noqa: C901 # If extracting embeds is enabled, produce an MCP for embed URL. 
if extract_embed_urls: embed_mcp = create_embed_mcp( - dataset_snapshot.urn, self._get_embed_url(base_url) + dataset_snapshot.urn, + self._get_embed_url(base_url), ) proposals.append(embed_mcp) @@ -1323,7 +1334,7 @@ def _to_metadata_events( # noqa: C901 MetadataChangeProposalWrapper( entityUrn=dataset_snapshot.urn, aspect=container, - ) + ), ) return proposals @@ -1406,15 +1417,15 @@ class LookerDashboardSourceReport(StaleEntityRemovalSourceReport): charts_with_activity: LossySet[str] = dataclasses_field(default_factory=LossySet) accessed_dashboards: int = 0 dashboards_with_activity: LossySet[str] = dataclasses_field( - default_factory=LossySet + default_factory=LossySet, ) # Entities that don't seem to exist, so we don't emit usage aspects for them despite having usage data dashboards_skipped_for_usage: LossySet[str] = dataclasses_field( - default_factory=LossySet + default_factory=LossySet, ) charts_skipped_for_usage: LossySet[str] = dataclasses_field( - default_factory=LossySet + default_factory=LossySet, ) stage_latency: List[StageLatency] = dataclasses_field(default_factory=list) @@ -1428,10 +1439,10 @@ class LookerDashboardSourceReport(StaleEntityRemovalSourceReport): _looker_api: Optional[LookerAPI] = None query_latency: Dict[str, datetime.timedelta] = dataclasses_field( - default_factory=dict + default_factory=dict, ) user_resolution_latency: Dict[str, datetime.timedelta] = dataclasses_field( - default_factory=dict + default_factory=dict, ) def report_total_dashboards(self, total_dashboards: int) -> None: @@ -1456,25 +1467,31 @@ def report_charts_scanned_for_usage(self, num_charts: int) -> None: self.charts_scanned_for_usage += num_charts def report_upstream_latency( - self, start_time: datetime.datetime, end_time: datetime.datetime + self, + start_time: datetime.datetime, + end_time: datetime.datetime, ) -> None: # recording total combined latency is not very useful, keeping this method as a placeholder # for future implementation of min / max / percentiles etc. 
pass def report_query_latency( - self, query_type: str, latency: datetime.timedelta + self, + query_type: str, + latency: datetime.timedelta, ) -> None: self.query_latency[query_type] = latency def report_user_resolution_latency( - self, generator_type: str, latency: datetime.timedelta + self, + generator_type: str, + latency: datetime.timedelta, ) -> None: self.user_resolution_latency[generator_type] = latency def report_stage_start(self, stage_name: str) -> None: self.stage_latency.append( - StageLatency(name=stage_name, start_time=datetime.datetime.now()) + StageLatency(name=stage_name, start_time=datetime.datetime.now()), ) def report_stage_end(self, stage_name: str) -> None: @@ -1492,14 +1509,16 @@ def report_stage(self, stage_name: str) -> Iterator[None]: def compute_stats(self) -> None: if self.total_dashboards: self.dashboard_process_percentage_completion = round( - 100 * self.dashboards_scanned / self.total_dashboards, 2 + 100 * self.dashboards_scanned / self.total_dashboards, + 2, ) if self._looker_explore_registry: self.explore_registry_stats = self._looker_explore_registry.compute_stats() if self.total_explores: self.explores_process_percentage_completion = round( - 100 * self.explores_scanned / self.total_explores, 2 + 100 * self.explores_scanned / self.total_explores, + 2, ) if self._looker_api: @@ -1637,7 +1656,7 @@ def __init__(self, looker_api: LookerAPI, report: LookerDashboardSourceReport): def _initialize_user_cache(self) -> None: raw_users: Sequence[User] = self.looker_api_wrapper.all_users( - user_fields=self.fields + user_fields=self.fields, ) for raw_user in raw_users: @@ -1654,7 +1673,8 @@ def get_by_id(self, id_: str) -> Optional[LookerUser]: return self._user_cache.get(str(id_)) raw_user: Optional[User] = self.looker_api_wrapper.get_user( - str(id_), user_fields=self.fields + str(id_), + user_fields=self.fields, ) if raw_user is None: return None @@ -1663,7 +1683,8 @@ def get_by_id(self, id_: str) -> Optional[LookerUser]: return looker_user def to_platform_resource( - self, platform_instance: Optional[str] + self, + platform_instance: Optional[str], ) -> Iterable[MetadataChangeProposalWrapper]: try: platform_resource_key = PlatformResourceKey( diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py index 3ed3186399588e..0a7e81b4f8c16d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py @@ -61,11 +61,11 @@ def validate_pattern(self, at_least_one: bool) -> bool: for v in variables: if v not in self.ALLOWED_VARS: raise ValueError( - f"Failed to find {v} in allowed_variables {self.ALLOWED_VARS}" + f"Failed to find {v} in allowed_variables {self.ALLOWED_VARS}", ) if at_least_one and len(variables) == 0: raise ValueError( - f"Failed to find any variable assigned to pattern {self.pattern}. Must have at least one. {self.allowed_docstring()}" + f"Failed to find any variable assigned to pattern {self.pattern}. Must have at least one. 
{self.allowed_docstring()}", ) return True @@ -123,7 +123,7 @@ class LookerCommonConfig(EnvConfigMixin, PlatformInstanceConfigMixin): ) _deprecate_explore_browse_pattern = pydantic_field_deprecated( - "explore_browse_pattern" + "explore_browse_pattern", ) _deprecate_view_browse_pattern = pydantic_field_deprecated("view_browse_pattern") @@ -155,7 +155,8 @@ def _get_bigquery_definition( def _get_generic_definition( - looker_connection: DBConnection, platform: Optional[str] = None + looker_connection: DBConnection, + platform: Optional[str] = None, ) -> Tuple[str, Optional[str], Optional[str]]: if platform is None: # We extract the platform from the dialect name @@ -199,7 +200,8 @@ def lower_everything(cls, v): @classmethod def from_looker_connection( - cls, looker_connection: DBConnection + cls, + looker_connection: DBConnection, ) -> "LookerConnectionDefinition": """Dialect definitions are here: https://docs.looker.com/setup-and-management/database-config""" extractors: Dict[str, Any] = { @@ -209,14 +211,14 @@ def from_looker_connection( if looker_connection.dialect_name is None: raise ConfigurationError( - f"Unable to fetch a fully filled out connection for {looker_connection.name}. Please check your API permissions." + f"Unable to fetch a fully filled out connection for {looker_connection.name}. Please check your API permissions.", ) for extractor_pattern, extracting_function in extractors.items(): if re.match(extractor_pattern, looker_connection.dialect_name): (platform, db, schema) = extracting_function(looker_connection) return cls(platform=platform, default_db=db, default_schema=schema) raise ConfigurationError( - f"Could not find an appropriate platform for looker_connection: {looker_connection.name} with dialect: {looker_connection.dialect_name}" + f"Could not find an appropriate platform for looker_connection: {looker_connection.name} with dialect: {looker_connection.dialect_name}", ) @@ -282,7 +284,8 @@ class LookerDashboardSourceConfig( "enabled inside of Looker to use this feature.", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="" + default=None, + description="", ) extract_independent_looks: bool = Field( False, @@ -315,16 +318,25 @@ class LookerDashboardSourceConfig( @validator("external_base_url", pre=True, always=True) def external_url_defaults_to_api_config_base_url( - cls, v: Optional[str], *, values: Dict[str, Any], **kwargs: Dict[str, Any] + cls, + v: Optional[str], + *, + values: Dict[str, Any], + **kwargs: Dict[str, Any], ) -> Optional[str]: return v or values.get("base_url") @validator("extract_independent_looks", always=True) def stateful_ingestion_should_be_enabled( - cls, v: Optional[bool], *, values: Dict[str, Any], **kwargs: Dict[str, Any] + cls, + v: Optional[bool], + *, + values: Dict[str, Any], + **kwargs: Dict[str, Any], ) -> Optional[bool]: stateful_ingestion: StatefulStaleMetadataRemovalConfig = cast( - StatefulStaleMetadataRemovalConfig, values.get("stateful_ingestion") + StatefulStaleMetadataRemovalConfig, + values.get("stateful_ingestion"), ) if v is True and ( stateful_ingestion is None or stateful_ingestion.enabled is False diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_connection.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_connection.py index 2b7ce6f6da026d..c0caa181732f30 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_connection.py @@ 
-36,12 +36,12 @@ def get_connection_def_based_on_connection_string( except SDKError: logger.error( f"Failed to retrieve connection {connection} from Looker. This usually happens when the " - f"credentials provided are not admin credentials." + f"credentials provided are not admin credentials.", ) else: try: connection_def = LookerConnectionDefinition.from_looker_connection( - looker_connection + looker_connection, ) # Populate the cache (using the config map) to avoid calling looker again for this connection diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py index d771821a14d88d..9cc061834cc4f7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py @@ -137,7 +137,7 @@ def resolve_includes( resolved_project_name = remote_project else: logger.warning( - f"Resolving {inc} failed. Could not find a locally checked out reference for {remote_project}" + f"Resolving {inc} failed. Could not find a locally checked out reference for {remote_project}", ) continue elif inc.startswith("/"): @@ -155,12 +155,12 @@ def resolve_includes( if project_name == BASE_PROJECT_NAME and root_project_name is not None: non_base_project_name = root_project_name if non_base_project_name != BASE_PROJECT_NAME and inc.startswith( - f"/{non_base_project_name}/" + f"/{non_base_project_name}/", ): # This might be a local include. Let's make sure that '/{project_name}' doesn't # exist as normal include in the project. if not pathlib.Path( - f"{resolved_project_folder}/{non_base_project_name}" + f"{resolved_project_folder}/{non_base_project_name}", ).exists(): path_within_project = pathlib.Path(*pathlib.Path(inc).parts[2:]) glob_expr = f"{resolved_project_folder}/{path_within_project}" @@ -175,7 +175,7 @@ def resolve_includes( pathlib.Path(p) for p in sorted( glob.glob(glob_expr, recursive=True) - + glob.glob(f"{glob_expr}.lkml", recursive=True) + + glob.glob(f"{glob_expr}.lkml", recursive=True), ) ] # We don't want to match directories. 
The '**' glob can be used to @@ -183,7 +183,7 @@ def resolve_includes( if p.is_file() ] logger.debug( - f"traversal_path={traversal_path}, included_files = {included_files}, seen_so_far: {seen_so_far}" + f"traversal_path={traversal_path}, included_files = {included_files}, seen_so_far: {seen_so_far}", ) if "*" not in inc and not included_files: reporter.warning( @@ -207,12 +207,12 @@ def resolve_includes( or included_file.endswith(".dashboard.lkml") ): logger.debug( - f"include '{included_file}' is a dashboard, skipping it" + f"include '{included_file}' is a dashboard, skipping it", ) continue logger.debug( - f"Will be loading {included_file}, traversed here via {traversal_path}" + f"Will be loading {included_file}, traversed here via {traversal_path}", ) try: parsed = load_and_preprocess_file( @@ -232,7 +232,7 @@ def resolve_includes( reporter, seen_so_far, traversal_path=f"{traversal_path} -> {pathlib.Path(included_file).stem}", - ) + ), ) except Exception as e: reporter.report_warning( @@ -247,7 +247,7 @@ def resolve_includes( [ ProjectInclude(project=resolved_project_name, include=f) for f in included_files - ] + ], ) return resolved @@ -289,7 +289,7 @@ def from_looker_dict( seen_so_far=seen_so_far, ) logger.debug( - f"resolved_includes for {absolute_file_path} is {resolved_includes}" + f"resolved_includes for {absolute_file_path} is {resolved_includes}", ) views = looker_view_file_dict.get("views", []) diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_file_loader.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_file_loader.py index 9fac0b52fde0dd..153fa2450db4f4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_file_loader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_file_loader.py @@ -38,7 +38,10 @@ def __init__( self.source_config = source_config def _load_viewfile( - self, project_name: str, path: str, reporter: LookMLSourceReport + self, + project_name: str, + path: str, + reporter: LookMLSourceReport, ) -> Optional[LookerViewFile]: # always fully resolve paths to simplify de-dup path = str(pathlib.Path(path).resolve()) @@ -49,7 +52,7 @@ def _load_viewfile( if not matched_any_extension: # not a view file logger.debug( - f"Skipping file {path} because it doesn't appear to be a view file. Matched extensions {allowed_extensions}" + f"Skipping file {path} because it doesn't appear to be a view file. Matched extensions {allowed_extensions}", ) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_lib_wrapper.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_lib_wrapper.py index c3f2a110136c45..4e9b4dc6ba2364 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_lib_wrapper.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_lib_wrapper.py @@ -43,7 +43,7 @@ class LookerAPIConfig(ConfigModel): client_id: str = Field(description="Looker API client id.") client_secret: str = Field(description="Looker API client secret.") base_url: str = Field( - description="Url to your Looker instance: `https://company.looker.com:19999` or `https://looker.company.com`, or similar. Used for making API calls to Looker and constructing clickable dashboard and chart urls." + description="Url to your Looker instance: `https://company.looker.com:19999` or `https://looker.company.com`, or similar. 
Used for making API calls to Looker and constructing clickable dashboard and chart urls.", ) transport_options: Optional[TransportOptionsConfig] = Field( None, @@ -86,7 +86,8 @@ def __init__(self, config: LookerAPIConfig) -> None: # Somewhat hacky mechanism for enabling retries on the Looker SDK. # Unfortunately, it doesn't expose a cleaner way to do this. if isinstance( - self.client.transport, looker_requests_transport.RequestsTransport + self.client.transport, + looker_requests_transport.RequestsTransport, ): adapter = HTTPAdapter( max_retries=self.config.max_retries, @@ -109,11 +110,11 @@ def __init__(self, config: LookerAPIConfig) -> None: self.transport_options if config.transport_options is not None else None - ) + ), ) except SDKError as e: raise ConfigurationError( - f"Failed to connect/authenticate with looker - check your configuration: {e}" + f"Failed to connect/authenticate with looker - check your configuration: {e}", ) from e self.client_stats = LookerAPIStats() @@ -198,7 +199,9 @@ def all_lookml_models(self) -> Sequence[LookmlModel]: def lookml_model_explore(self, model: str, explore_name: str) -> LookmlModelExplore: self.client_stats.explore_calls += 1 return self.client.lookml_model_explore( - model, explore_name, transport_options=self.transport_options + model, + explore_name, + transport_options=self.transport_options, ) @lru_cache(maxsize=1000) @@ -218,11 +221,11 @@ def folder_ancestors( if "Looker Not Found (404)" in str(e): # Folder ancestors not found logger.info( - f"Could not find ancestors for folder with id {folder_id}: 404 error" + f"Could not find ancestors for folder with id {folder_id}: 404 error", ) else: logger.warning( - f"Could not find ancestors for folder with id {folder_id}" + f"Could not find ancestors for folder with id {folder_id}", ) logger.warning(f"Failure was {e}") # Folder ancestors not found @@ -235,11 +238,14 @@ def all_connections(self): def connection(self, connection_name: str) -> DBConnection: self.client_stats.connection_calls += 1 return self.client.connection( - connection_name, transport_options=self.transport_options + connection_name, + transport_options=self.transport_options, ) def lookml_model( - self, model_name: str, fields: Union[str, List[str]] + self, + model_name: str, + fields: Union[str, List[str]], ) -> LookmlModel: self.client_stats.lookml_model_calls += 1 return self.client.lookml_model( @@ -263,14 +269,16 @@ def all_dashboards(self, fields: Union[str, List[str]]) -> Sequence[DashboardBas ) def all_looks( - self, fields: Union[str, List[str]], soft_deleted: bool + self, + fields: Union[str, List[str]], + soft_deleted: bool, ) -> List[Look]: self.client_stats.all_looks_calls += 1 looks: List[Look] = list( self.client.all_looks( fields=self.__fields_mapper(fields), transport_options=self.transport_options, - ) + ), ) if soft_deleted: @@ -296,7 +304,9 @@ def get_look(self, look_id: str, fields: Union[str, List[str]]) -> LookWithQuery ) def search_dashboards( - self, fields: Union[str, List[str]], deleted: str + self, + fields: Union[str, List[str]], + deleted: str, ) -> Sequence[Dashboard]: self.client_stats.search_dashboards_calls += 1 return self.client.search_dashboards( @@ -306,7 +316,9 @@ def search_dashboards( ) def search_looks( - self, fields: Union[str, List[str]], deleted: Optional[bool] + self, + fields: Union[str, List[str]], + deleted: Optional[bool], ) -> List[Look]: self.client_stats.search_looks_calls += 1 return list( @@ -314,5 +326,5 @@ def search_looks( fields=self.__fields_mapper(fields), 
deleted=deleted, transport_options=self.transport_options, - ) + ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py index 2f1fcd378d40fb..ca2e1fec85b464 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py @@ -115,7 +115,8 @@ @capability(SourceCapability.DESCRIPTIONS, "Enabled by default") @capability(SourceCapability.PLATFORM_INSTANCE, "Use the `platform_instance` field") @capability( - SourceCapability.OWNERSHIP, "Enabled by default, configured using `extract_owners`" + SourceCapability.OWNERSHIP, + "Enabled by default, configured using `extract_owners`", ) @capability(SourceCapability.LINEAGE_COARSE, "Supported by default") @capability( @@ -147,10 +148,13 @@ def __init__(self, config: LookerDashboardSourceConfig, ctx: PipelineContext): self.reporter: LookerDashboardSourceReport = LookerDashboardSourceReport() self.looker_api: LookerAPI = LookerAPI(self.source_config) self.user_registry: LookerUserRegistry = LookerUserRegistry( - self.looker_api, self.reporter + self.looker_api, + self.reporter, ) self.explore_registry: LookerExploreRegistry = LookerExploreRegistry( - self.looker_api, self.reporter, self.source_config + self.looker_api, + self.reporter, + self.source_config, ) self.reporter._looker_explore_registry = self.explore_registry self.reporter._looker_api = self.looker_api @@ -225,7 +229,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: if test_report.basic_connectivity is None: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=f"{e}" + capable=False, + failure_reason=f"{e}", ) return test_report @@ -273,7 +278,8 @@ def _get_views_from_fields(self, fields: List[str]) -> List[str]: return list(views) def _get_input_fields_from_query( - self, query: Optional[Query] + self, + query: Optional[Query], ) -> List[InputFieldElement]: if query is None: return [] @@ -285,11 +291,11 @@ def _get_input_fields_from_query( # - looker custom dimensions: https://docs.looker.com/exploring-data/adding-fields/custom-measure#creating_a_custom_dimension_using_a_looker_expression try: dynamic_fields = json.loads( - query.dynamic_fields if query.dynamic_fields is not None else "[]" + query.dynamic_fields if query.dynamic_fields is not None else "[]", ) except JSONDecodeError as e: logger.warning( - f"Json load failed on loading dynamic field with error: {e}. The field value was: {query.dynamic_fields}" + f"Json load failed on loading dynamic field with error: {e}. 
The field value was: {query.dynamic_fields}", ) dynamic_fields = [] @@ -305,7 +311,7 @@ def _get_input_fields_from_query( type="string", description="", ), - ) + ), ) if "measure" in field: # for measure, we can also make sure to index the underlying field that the measure uses @@ -317,7 +323,7 @@ def _get_input_fields_from_query( view_field=None, model=query.model, explore=query.view, - ) + ), ) result.append( InputFieldElement( @@ -329,7 +335,7 @@ def _get_input_fields_from_query( type="string", description="", ), - ) + ), ) if "dimension" in field: result.append( @@ -342,7 +348,7 @@ def _get_input_fields_from_query( type="string", description="", ), - ) + ), ) # A query uses fields defined in explores, find the metadata about that field @@ -355,8 +361,11 @@ def _get_input_fields_from_query( # later to fetch this result.append( InputFieldElement( - name=field, view_field=None, model=query.model, explore=query.view - ) + name=field, + view_field=None, + model=query.model, + explore=query.view, + ), ) # A query uses fields for filtering, and those fields are defined in views, find the views those fields use @@ -371,8 +380,11 @@ def _get_input_fields_from_query( # later to fetch this result.append( InputFieldElement( - name=field, view_field=None, model=query.model, explore=query.view - ) + name=field, + view_field=None, + model=query.model, + explore=query.view, + ), ) return result @@ -384,7 +396,8 @@ def add_reachable_explore(self, model: str, explore: str, via: str) -> None: self.reachable_explores[(model, explore)].append(via) def _get_looker_dashboard_element( # noqa: C901 - self, element: DashboardElement + self, + element: DashboardElement, ) -> Optional[LookerDashboardElement]: # Dashboard elements can use raw usage_queries against explores explores: List[str] @@ -399,7 +412,7 @@ def _get_looker_dashboard_element( # noqa: C901 # Get the explore from the view directly explores = [element.query.view] if element.query.view is not None else [] logger.debug( - f"Element {element.title}: Explores added via query: {explores}" + f"Element {element.title}: Explores added via query: {explores}", ) for exp in explores: self.add_reachable_explore( @@ -486,11 +499,11 @@ def _get_looker_dashboard_element( # noqa: C901 if element.result_maker.query.view is not None: explores.append(element.result_maker.query.view) input_fields = self._get_input_fields_from_query( - element.result_maker.query + element.result_maker.query, ) logger.debug( - f"Element {element.title}: Explores added via result_maker: {explores}" + f"Element {element.title}: Explores added via result_maker: {explores}", ) for exp in explores: @@ -522,7 +535,7 @@ def _get_looker_dashboard_element( # noqa: C901 view_field=None, model=query.model if query is not None else "", explore=query.view if query is not None else "", - ) + ), ) explores = sorted(list(set(explores))) # dedup the list of views @@ -550,7 +563,8 @@ def _get_looker_dashboard_element( # noqa: C901 return None def _get_chart_type( - self, dashboard_element: LookerDashboardElement + self, + dashboard_element: LookerDashboardElement, ) -> Optional[str]: type_mapping = { "looker_column": ChartTypeClass.BAR, @@ -596,7 +610,9 @@ def _get_chart_type( return chart_type def _get_folder_browse_path_v2_entries( - self, folder: LookerFolder, include_current_folder: bool = True + self, + folder: LookerFolder, + include_current_folder: bool = True, ) -> Iterable[BrowsePathEntryClass]: for ancestor in self.looker_api.folder_ancestors(folder_id=folder.id): assert ancestor.id @@ -644,7 
+660,7 @@ def _make_chart_metadata_events( ], # dashboard will be None if this is a standalone look ) -> List[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]]: chart_urn = self._make_chart_urn( - element_id=dashboard_element.get_urn_element_id() + element_id=dashboard_element.get_urn_element_id(), ) self.chart_urns.add(chart_urn) chart_snapshot = ChartSnapshot( @@ -664,11 +680,13 @@ def _make_chart_metadata_events( customProperties={ "upstream_fields": ( ",".join( - sorted({field.name for field in dashboard_element.input_fields}) + sorted( + {field.name for field in dashboard_element.input_fields}, + ), ) if dashboard_element.input_fields else "" - ) + ), }, ) chart_snapshot.aspects.append(chart_info) @@ -679,7 +697,7 @@ def _make_chart_metadata_events( and dashboard.folder is not None ): browse_path = BrowsePathsClass( - paths=[f"/Folders/{dashboard.folder_path}/{dashboard.title}"] + paths=[f"/Folders/{dashboard.folder_path}/{dashboard.title}"], ) chart_snapshot.aspects.append(browse_path) @@ -697,7 +715,7 @@ def _make_chart_metadata_events( and dashboard_element.folder is not None ): # independent look browse_path = BrowsePathsClass( - paths=[f"/Folders/{dashboard_element.folder_path}"] + paths=[f"/Folders/{dashboard_element.folder_path}"], ) chart_snapshot.aspects.append(browse_path) browse_path_v2 = BrowsePathsV2Class( @@ -740,14 +758,14 @@ def _make_chart_metadata_events( and self.source_config.external_base_url ): maybe_embed_url = dashboard_element.embed_url( - self.source_config.external_base_url + self.source_config.external_base_url, ) if maybe_embed_url: proposals.append( create_embed_mcp( chart_snapshot.urn, maybe_embed_url, - ) + ), ) if dashboard is None and dashboard_element.folder: @@ -755,20 +773,23 @@ def _make_chart_metadata_events( container=self._gen_folder_key(dashboard_element.folder.id).as_urn(), ) proposals.append( - MetadataChangeProposalWrapper(entityUrn=chart_urn, aspect=container) + MetadataChangeProposalWrapper(entityUrn=chart_urn, aspect=container), ) if browse_path_v2: proposals.append( MetadataChangeProposalWrapper( - entityUrn=chart_urn, aspect=browse_path_v2 - ) + entityUrn=chart_urn, + aspect=browse_path_v2, + ), ) return proposals def _make_dashboard_metadata_events( - self, looker_dashboard: LookerDashboard, chart_urns: List[str] + self, + looker_dashboard: LookerDashboard, + chart_urns: List[str], ) -> List[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]]: dashboard_urn = self.make_dashboard_urn(looker_dashboard) dashboard_snapshot = DashboardSnapshot( @@ -790,7 +811,7 @@ def _make_dashboard_metadata_events( and looker_dashboard.folder is not None ): browse_path = BrowsePathsClass( - paths=[f"/Folders/{looker_dashboard.folder_path}"] + paths=[f"/Folders/{looker_dashboard.folder_path}"], ) browse_path_v2 = BrowsePathsV2Class( path=[ @@ -809,7 +830,7 @@ def _make_dashboard_metadata_events( dashboard_mce = MetadataChangeEvent(proposedSnapshot=dashboard_snapshot) proposals: List[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]] = [ - dashboard_mce + dashboard_mce, ] if looker_dashboard.folder is not None: @@ -817,14 +838,18 @@ def _make_dashboard_metadata_events( container=self._gen_folder_key(looker_dashboard.folder.id).as_urn(), ) proposals.append( - MetadataChangeProposalWrapper(entityUrn=dashboard_urn, aspect=container) + MetadataChangeProposalWrapper( + entityUrn=dashboard_urn, + aspect=container, + ), ) if browse_path_v2: proposals.append( MetadataChangeProposalWrapper( - entityUrn=dashboard_urn, aspect=browse_path_v2 - 
) + entityUrn=dashboard_urn, + aspect=browse_path_v2, + ), ) # If extracting embeds is enabled, produce an MCP for embed URL. @@ -836,7 +861,7 @@ def _make_dashboard_metadata_events( create_embed_mcp( dashboard_snapshot.urn, looker_dashboard.embed_url(self.source_config.external_base_url), - ) + ), ) if self.source_config.include_platform_instance_in_urns: @@ -844,7 +869,7 @@ def _make_dashboard_metadata_events( MetadataChangeProposalWrapper( entityUrn=dashboard_urn, aspect=self._create_platform_instance_aspect(), - ) + ), ) return proposals @@ -915,7 +940,7 @@ def _make_explore_metadata_events( yield from events self.reporter.report_upstream_latency(start_time, end_time) logger.debug( - f"Running time of fetch_one_explore for {explore_id}: {(end_time - start_time).total_seconds()}" + f"Running time of fetch_one_explore for {explore_id}: {(end_time - start_time).total_seconds()}", ) def list_all_explores(self) -> Iterable[Tuple[Optional[str], str, str]]: @@ -930,7 +955,9 @@ def list_all_explores(self) -> Iterable[Tuple[Optional[str], str, str]]: yield (model.project_name, model.name, explore.name) def fetch_one_explore( - self, model: str, explore: str + self, + model: str, + explore: str, ) -> Tuple[ List[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]], str, @@ -954,7 +981,8 @@ def fetch_one_explore( return events, f"{model}:{explore}", start_time, datetime.datetime.now() def _extract_event_urn( - self, event: Union[MetadataChangeEvent, MetadataChangeProposalWrapper] + self, + event: Union[MetadataChangeEvent, MetadataChangeProposalWrapper], ) -> Optional[str]: if isinstance(event, MetadataChangeEvent): return event.proposedSnapshot.urn @@ -962,7 +990,8 @@ def _extract_event_urn( return event.entityUrn def _emit_folder_as_container( - self, folder: LookerFolder + self, + folder: LookerFolder, ) -> Iterable[MetadataWorkUnit]: if folder.id not in self.processed_folders: yield from gen_containers( @@ -987,7 +1016,8 @@ def _emit_folder_as_container( path=[ BrowsePathEntryClass("Folders"), *self._get_folder_browse_path_v2_entries( - folder, include_current_folder=False + folder, + include_current_folder=False, ), ], ), @@ -1003,14 +1033,15 @@ def _gen_folder_key(self, folder_id: str) -> LookerFolderKey: ) def _make_dashboard_and_chart_mces( - self, looker_dashboard: LookerDashboard + self, + looker_dashboard: LookerDashboard, ) -> Iterable[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]]: # Step 1: Emit metadata for each Chart inside the Dashboard. 
chart_events = [] for element in looker_dashboard.dashboard_elements: if element.type == "vis": chart_events.extend( - self._make_chart_metadata_events(element, looker_dashboard) + self._make_chart_metadata_events(element, looker_dashboard), ) yield from chart_events @@ -1025,16 +1056,18 @@ def _make_dashboard_and_chart_mces( chart_urns.add(chart_event_urn) dashboard_events = self._make_dashboard_metadata_events( - looker_dashboard, list(chart_urns) + looker_dashboard, + list(chart_urns), ) yield from dashboard_events def get_ownership( - self, looker_dashboard_look: Union[LookerDashboard, LookerDashboardElement] + self, + looker_dashboard_look: Union[LookerDashboard, LookerDashboardElement], ) -> Optional[OwnershipClass]: if looker_dashboard_look.owner is not None: owner_urn = looker_dashboard_look.owner.get_urn( - self.source_config.strip_user_ids_from_email + self.source_config.strip_user_ids_from_email, ) if owner_urn is not None: ownership: OwnershipClass = OwnershipClass( @@ -1042,33 +1075,34 @@ def get_ownership( OwnerClass( owner=owner_urn, type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return ownership return None def _get_change_audit_stamps( - self, looker_dashboard: LookerDashboard + self, + looker_dashboard: LookerDashboard, ) -> ChangeAuditStamps: change_audit_stamp: ChangeAuditStamps = ChangeAuditStamps() if looker_dashboard.created_at is not None: change_audit_stamp.created.time = round( - looker_dashboard.created_at.timestamp() * 1000 + looker_dashboard.created_at.timestamp() * 1000, ) if looker_dashboard.owner is not None: owner_urn = looker_dashboard.owner.get_urn( - self.source_config.strip_user_ids_from_email + self.source_config.strip_user_ids_from_email, ) if owner_urn: change_audit_stamp.created.actor = owner_urn if looker_dashboard.last_updated_at is not None: change_audit_stamp.lastModified.time = round( - looker_dashboard.last_updated_at.timestamp() * 1000 + looker_dashboard.last_updated_at.timestamp() * 1000, ) if looker_dashboard.last_updated_by is not None: updated_by_urn = looker_dashboard.last_updated_by.get_urn( - self.source_config.strip_user_ids_from_email + self.source_config.strip_user_ids_from_email, ) if updated_by_urn: change_audit_stamp.lastModified.actor = updated_by_urn @@ -1078,7 +1112,7 @@ def _get_change_audit_stamps( and looker_dashboard.deleted_at is not None ): deleter_urn = looker_dashboard.deleted_by.get_urn( - self.source_config.strip_user_ids_from_email + self.source_config.strip_user_ids_from_email, ) if deleter_urn: change_audit_stamp.deleted = AuditStamp( @@ -1115,7 +1149,7 @@ def _get_looker_dashboard(self, dashboard: Dashboard) -> LookerDashboard: for element in elements: self.reporter.report_charts_scanned() if element.id is not None and not self.source_config.chart_pattern.allowed( - element.id + element.id, ): self.reporter.report_charts_dropped(element.id) continue @@ -1174,7 +1208,8 @@ def _get_looker_user(self, user_id: Optional[str]) -> Optional[LookerUser]: return user def process_metrics_dimensions_and_fields_for_dashboard( - self, dashboard: LookerDashboard + self, + dashboard: LookerDashboard, ) -> List[MetadataWorkUnit]: chart_mcps = [ self._make_metrics_dimensions_chart_mcp(element) @@ -1190,7 +1225,8 @@ def process_metrics_dimensions_and_fields_for_dashboard( return workunits def _input_fields_from_dashboard_element( - self, dashboard_element: LookerDashboardElement + self, + dashboard_element: LookerDashboardElement, ) -> List[InputFieldClass]: input_fields = ( dashboard_element.input_fields @@ -1203,18 
+1239,21 @@ def _input_fields_from_dashboard_element( # enrich the input_fields with the fully hydrated ViewField from the now fetched explores for input_field in input_fields: entity_urn = self._make_chart_urn( - element_id=dashboard_element.get_urn_element_id() + element_id=dashboard_element.get_urn_element_id(), ) view_field_for_reference = input_field.view_field if input_field.view_field is None: explore = self.explore_registry.get_explore( - input_field.model, input_field.explore + input_field.model, + input_field.explore, ) if explore is not None: # add this to the list of explores to finally generate metadata for self.add_reachable_explore( - input_field.model, input_field.explore, entity_urn + input_field.model, + input_field.explore, + entity_urn, ) entity_urn = explore.get_explore_urn(self.source_config) explore_fields = ( @@ -1236,26 +1275,28 @@ def _input_fields_from_dashboard_element( fields_for_mcp.append( InputFieldClass( schemaFieldUrn=builder.make_schema_field_urn( - entity_urn, view_field_for_reference.name + entity_urn, + view_field_for_reference.name, ), schemaField=LookerUtil.view_field_to_schema_field( view_field_for_reference, self.reporter, self.source_config.tag_measures_and_dimensions, ), - ) + ), ) return fields_for_mcp def _make_metrics_dimensions_dashboard_mcp( - self, dashboard: LookerDashboard + self, + dashboard: LookerDashboard, ) -> MetadataChangeProposalWrapper: dashboard_urn = self.make_dashboard_urn(dashboard) all_fields = [] for dashboard_element in dashboard.dashboard_elements: all_fields.extend( - self._input_fields_from_dashboard_element(dashboard_element) + self._input_fields_from_dashboard_element(dashboard_element), ) input_fields_aspect = InputFieldsClass(fields=all_fields) @@ -1266,13 +1307,14 @@ def _make_metrics_dimensions_dashboard_mcp( ) def _make_metrics_dimensions_chart_mcp( - self, dashboard_element: LookerDashboardElement + self, + dashboard_element: LookerDashboardElement, ) -> MetadataChangeProposalWrapper: chart_urn = self._make_chart_urn( - element_id=dashboard_element.get_urn_element_id() + element_id=dashboard_element.get_urn_element_id(), ) input_fields_aspect = InputFieldsClass( - fields=self._input_fields_from_dashboard_element(dashboard_element) + fields=self._input_fields_from_dashboard_element(dashboard_element), ) return MetadataChangeProposalWrapper( @@ -1281,7 +1323,9 @@ def _make_metrics_dimensions_chart_mcp( ) def process_dashboard( - self, dashboard_id: str, fields: List[str] + self, + dashboard_id: str, + fields: List[str], ) -> Tuple[ List[MetadataWorkUnit], Optional[looker_usage.LookerDashboardForUsage], @@ -1328,17 +1372,17 @@ def process_dashboard( if ( looker_dashboard.folder_path is not None and not self.source_config.folder_path_pattern.allowed( - looker_dashboard.folder_path + looker_dashboard.folder_path, ) ): logger.debug( - f"Folder path {looker_dashboard.folder_path} is denied in folder_path_pattern" + f"Folder path {looker_dashboard.folder_path} is denied in folder_path_pattern", ) return [], None, dashboard_id, start_time, datetime.datetime.now() if looker_dashboard.folder: workunits += list( - self._get_folder_and_ancestors_workunits(looker_dashboard.folder) + self._get_folder_and_ancestors_workunits(looker_dashboard.folder), ) mces = self._make_dashboard_and_chart_mces(looker_dashboard) @@ -1347,7 +1391,8 @@ def process_dashboard( MetadataWorkUnit(id=f"looker-{mce.proposedSnapshot.urn}", mce=mce) if isinstance(mce, MetadataChangeEvent) else MetadataWorkUnit( - 
id=f"looker-{mce.aspectName}-{mce.entityUrn}", mcp=mce + id=f"looker-{mce.aspectName}-{mce.entityUrn}", + mcp=mce, ) ) for mce in mces @@ -1355,7 +1400,7 @@ def process_dashboard( # add on metrics, dimensions, fields events metric_dim_workunits = self.process_metrics_dimensions_and_fields_for_dashboard( - looker_dashboard + looker_dashboard, ) workunits.extend(metric_dim_workunits) @@ -1364,7 +1409,7 @@ def process_dashboard( # generate usage tracking object dashboard_usage = looker_usage.LookerDashboardForUsage.from_dashboard( - dashboard_object + dashboard_object, ) return ( @@ -1376,11 +1421,12 @@ def process_dashboard( ) def _get_folder_and_ancestors_workunits( - self, folder: LookerFolder + self, + folder: LookerFolder, ) -> Iterable[MetadataWorkUnit]: for ancestor_folder in self.looker_api.folder_ancestors(folder.id): yield from self._emit_folder_as_container( - self._get_looker_folder(ancestor_folder) + self._get_looker_folder(ancestor_folder), ) yield from self._emit_folder_as_container(folder) @@ -1444,31 +1490,34 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] def emit_independent_looks_mcp( - self, dashboard_element: LookerDashboardElement + self, + dashboard_element: LookerDashboardElement, ) -> Iterable[MetadataWorkUnit]: if dashboard_element.folder: # independent look yield from self._get_folder_and_ancestors_workunits( - dashboard_element.folder + dashboard_element.folder, ) yield from auto_workunit( stream=self._make_chart_metadata_events( dashboard_element=dashboard_element, dashboard=None, - ) + ), ) yield from auto_workunit( [ self._make_metrics_dimensions_chart_mcp( dashboard_element, - ) - ] + ), + ], ) def extract_independent_looks(self) -> Iterable[MetadataWorkUnit]: @@ -1500,7 +1549,8 @@ def extract_independent_looks(self) -> Iterable[MetadataWorkUnit]: ] all_looks: List[Look] = self.looker_api.all_looks( - fields=look_fields, soft_deleted=self.source_config.include_deleted + fields=look_fields, + soft_deleted=self.source_config.include_deleted, ) for look in all_looks: if look.id in self.reachable_look_registry: @@ -1527,7 +1577,8 @@ def extract_independent_looks(self) -> Iterable[MetadataWorkUnit]: if look.id is not None: query: Optional[Query] = self.looker_api.get_look( - look.id, fields=["query"] + look.id, + fields=["query"], ).query # Only include fields that are in the query_fields list query = Query( @@ -1535,7 +1586,7 @@ def extract_independent_looks(self) -> Iterable[MetadataWorkUnit]: key: getattr(query, key) for key in query_fields if hasattr(query, key) - } + }, ) dashboard_element: Optional[LookerDashboardElement] = ( @@ -1548,7 +1599,9 @@ def extract_independent_looks(self) -> Iterable[MetadataWorkUnit]: look_id=look.id, dashboard_id=None, # As this is an independent look look=LookWithQuery( - query=query, folder=look.folder, user_id=look.user_id + query=query, + folder=look.folder, + user_id=look.user_id, ), ), ) @@ -1557,7 +1610,7 @@ def extract_independent_looks(self) -> Iterable[MetadataWorkUnit]: if dashboard_element is not None: logger.debug(f"Emitting MCPS for look {look.title}({look.id})") yield from self.emit_independent_looks_mcp( - dashboard_element=dashboard_element + dashboard_element=dashboard_element, ) self.reporter.report_stage_end("extract_independent_looks") @@ -1575,7 +1628,7 @@ def get_workunits_internal(self) -> 
Iterable[MetadataWorkUnit]: dashboard_ids = [dashboard_base.id for dashboard_base in dashboards] dashboard_ids.extend( - [deleted_dashboard.id for deleted_dashboard in deleted_dashboards] + [deleted_dashboard.id for deleted_dashboard in deleted_dashboards], ) selected_dashboard_ids: List[Optional[str]] = [] for id in dashboard_ids: @@ -1633,7 +1686,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: end_time, ) = job.result() logger.debug( - f"Running time of process_dashboard for {dashboard_id} = {(end_time - start_time).total_seconds()}" + f"Running time of process_dashboard for {dashboard_id} = {(end_time - start_time).total_seconds()}", ) self.reporter.report_upstream_latency(start_time, end_time) @@ -1660,7 +1713,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for event in self._make_explore_metadata_events(): if isinstance(event, MetadataChangeEvent): yield MetadataWorkUnit( - id=f"looker-{event.proposedSnapshot.urn}", mce=event + id=f"looker-{event.proposedSnapshot.urn}", + mce=event, ) elif isinstance(event, MetadataChangeProposalWrapper): yield event.as_workunit() @@ -1685,7 +1739,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.source_config.extract_usage_history: self.reporter.report_stage_start("usage_extraction") usage_mcps: List[MetadataChangeProposalWrapper] = self.extract_usage_stat( - looker_dashboards_for_usage, self.chart_urns + looker_dashboards_for_usage, + self.chart_urns, ) for usage_mcp in usage_mcps: yield usage_mcp.as_workunit() @@ -1696,8 +1751,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.reporter.report_stage_start("user_resource_extraction") yield from auto_workunit( self.user_registry.to_platform_resource( - self.source_config.platform_instance - ) + self.source_config.platform_instance, + ), ) def get_report(self) -> SourceReport: diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_template_language.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_template_language.py index 2bcae4d46b8d52..e42342e2e8c1d9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_template_language.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_template_language.py @@ -49,7 +49,7 @@ def _create_new_liquid_variables_with_default( for variable in variables: keys = variable.split( - "." + ".", ) # variable is defined as view._is_selected or view.field_name._is_selected current_dict: dict = new_dict @@ -72,7 +72,7 @@ def liquid_variable_with_default(self, text: str) -> dict: [ text[m.start() : m.end()] for m in re.finditer(SpecialVariable.SPECIAL_VARIABLE_PATTERN, text) - ] + ], ) # if set is empty then no special variables are found. 
@@ -91,7 +91,7 @@ def resolve_liquid_variable(text: str, liquid_variable: Dict[Any, Any]) -> str: # https://cloud.google.com/looker/docs/liquid-variable-reference#usage_of_in_query_is_selected_and_is_filtered # update in liquid_variable with there default values liquid_variable = SpecialVariable(liquid_variable).liquid_variable_with_default( - text + text, ) # Resolve liquid template return create_template(text).render(liquid_variable) @@ -204,7 +204,8 @@ def transform(self, view: dict) -> dict: if SQL_TABLE_NAME in view and self.is_attribute_supported(SQL_TABLE_NAME): # Give precedence to already processed transformed view.sql_table_name to apply more transformation value_to_transform = view.get( - DATAHUB_TRANSFORMED_SQL_TABLE_NAME, view[SQL_TABLE_NAME] + DATAHUB_TRANSFORMED_SQL_TABLE_NAME, + view[SQL_TABLE_NAME], ) if ( @@ -214,7 +215,8 @@ def transform(self, view: dict) -> dict: ): # Give precedence to already processed transformed view.derived.sql to apply more transformation value_to_transform = view[DERIVED_TABLE].get( - DATAHUB_TRANSFORMED_SQL, view[DERIVED_TABLE][SQL] + DATAHUB_TRANSFORMED_SQL, + view[DERIVED_TABLE][SQL], ) if value_to_transform is None: @@ -223,7 +225,8 @@ def transform(self, view: dict) -> dict: logger.debug(f"value to transform = {value_to_transform}") transformed_value: str = self._apply_transformation( - value=value_to_transform, view=view + value=value_to_transform, + view=view, ) logger.debug(f"transformed value = {transformed_value}") @@ -313,17 +316,20 @@ def __init__(self, source_config: LookMLSourceConfig): # This regx will keep whatever after -- if looker_environment -- self.evaluate_to_true_regx = r"-- if {} --".format( - self.source_config.looker_environment + self.source_config.looker_environment, ) # It will remove all other lines starts with -- if ... -- self.remove_if_comment_line_regx = r"-- if {} --.*?(?=\n|-- if|$)".format( - dev if self.source_config.looker_environment.lower() == prod else prod + dev if self.source_config.looker_environment.lower() == prod else prod, ) def _apply_regx(self, value: str) -> str: result: str = re.sub( - self.remove_if_comment_line_regx, "", value, flags=re.IGNORECASE | re.DOTALL + self.remove_if_comment_line_regx, + "", + value, + flags=re.IGNORECASE | re.DOTALL, ) # Remove '-- if prod --' but keep the rest of the line @@ -381,7 +387,8 @@ def view(self) -> dict: logger.debug(f"Applying transformer {transformer.__class__.__name__}") self.transformed_dict = always_merger.merge( - self.transformed_dict, transformer.transform(self.transformed_dict) + self.transformed_dict, + transformer.transform(self.transformed_dict), ) return self.transformed_dict @@ -396,16 +403,16 @@ def process_lookml_template_language( transformers: List[LookMLViewTransformer] = [ LookMlIfCommentTransformer( - source_config=source_config + source_config=source_config, ), # First evaluate the -- if -- comments. 
Looker does the same LiquidVariableTransformer( - source_config=source_config + source_config=source_config, ), # Now resolve liquid variables DropDerivedViewPatternTransformer( - source_config=source_config + source_config=source_config, ), # Remove any ${} symbol IncompleteSqlTransformer( - source_config=source_config + source_config=source_config, ), # complete any incomplete sql ] @@ -413,7 +420,7 @@ def process_lookml_template_language( for view in view_lkml_file_dict["views"]: transformed_views.append( - TransformedLookMlView(transformers=transformers, view_dict=view).view() + TransformedLookMlView(transformers=transformers, view_dict=view).view(), ) view_lkml_file_dict["views"] = transformed_views diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_usage.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_usage.py index 05806840b5c954..71d926a0c8ff03 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_usage.py @@ -196,7 +196,8 @@ def get_filter(self) -> Dict[ViewField, str]: @abstractmethod def to_entity_absolute_stat_aspect( - self, looker_object: ModelForUsage + self, + looker_object: ModelForUsage, ) -> AspectAbstract: pass @@ -214,7 +215,10 @@ def _get_urn(self, model: ModelForUsage) -> str: @abstractmethod def append_user_stat( - self, entity_stat_aspect: Aspect, user: LookerUser, row: Dict + self, + entity_stat_aspect: Aspect, + user: LookerUser, + row: Dict, ) -> None: pass @@ -231,7 +235,9 @@ def report_skip_set(self) -> LossySet[str]: pass def create_mcp( - self, model: ModelForUsage, aspect: Aspect + self, + model: ModelForUsage, + aspect: Aspect, ) -> MetadataChangeProposalWrapper: return MetadataChangeProposalWrapper( entityUrn=self._get_urn(model=model), @@ -243,14 +249,15 @@ def _round_time(self, date_time: str) -> int: datetime.datetime.strptime(date_time, "%Y-%m-%d") .replace(tzinfo=datetime.timezone.utc) .timestamp() - * 1000 + * 1000, ) def _get_user_identifier(self, row: Dict) -> str: return row[UserViewField.USER_ID] def _process_entity_timeseries_rows( - self, rows: List[Dict] + self, + rows: List[Dict], ) -> Dict[Tuple[str, str], AspectAbstract]: # Convert Looker entity stat i.e. 
rows to DataHub stat aspect entity_stat_aspect: Dict[Tuple[str, str], AspectAbstract] = {} @@ -274,11 +281,12 @@ def _fill_user_stat_aspect( user_ids = {self._get_user_identifier(row) for row in user_wise_rows} start_time = datetime.datetime.now() with concurrent.futures.ThreadPoolExecutor( - max_workers=self.config.max_threads + max_workers=self.config.max_threads, ) as async_executor: user_futures = { async_executor.submit( - self.config.looker_user_registry.get_by_id, user_id + self.config.looker_user_registry.get_by_id, + user_id, ): 1 for user_id in user_ids } @@ -288,10 +296,11 @@ def _fill_user_stat_aspect( user_resolution_latency = datetime.datetime.now() - start_time logger.debug( - f"Resolved {len(user_ids)} in {user_resolution_latency.total_seconds()} seconds" + f"Resolved {len(user_ids)} in {user_resolution_latency.total_seconds()} seconds", ) self.report.report_user_resolution_latency( - self.get_stats_generator_name(), user_resolution_latency + self.get_stats_generator_name(), + user_resolution_latency, ) for row in user_wise_rows: @@ -300,33 +309,33 @@ def _fill_user_stat_aspect( if looker_object is None: logger.warning( "Looker object with id({}) was not register with stat generator".format( - self.get_id_from_row(row) - ) + self.get_id_from_row(row), + ), ) continue # Confirm we do have user with user id user: Optional[LookerUser] = self.config.looker_user_registry.get_by_id( - self._get_user_identifier(row) + self._get_user_identifier(row), ) if user is None: logger.warning( - f"Unable to resolve user with id {self._get_user_identifier(row)}, skipping" + f"Unable to resolve user with id {self._get_user_identifier(row)}, skipping", ) continue # Confirm for the user row (entity + time) the entity stat is present for the same time entity_stat_aspect: Optional[Aspect] = entity_usage_stat.get( - self.get_entity_stat_key(row) + self.get_entity_stat_key(row), ) if entity_stat_aspect is None: logger.warning( - f"entity stat is not found for the row = {row}, in entity_stat = {entity_usage_stat}" + f"entity stat is not found for the row = {row}, in entity_stat = {entity_usage_stat}", ) logger.warning( "entity stat is not found for the user stat key = {}".format( - self.get_entity_stat_key(row) - ) + self.get_entity_stat_key(row), + ), ) continue self.append_user_stat(entity_stat_aspect, user, row) @@ -341,15 +350,16 @@ def _execute_query(self, query: LookerQuery, query_name: str) -> List[Dict]: try: start_time = datetime.datetime.now() rows = self.config.looker_api_wrapper.execute_query( - write_query=query.to_write_query() + write_query=query.to_write_query(), ) end_time = datetime.datetime.now() logger.debug( - f"{self.get_stats_generator_name()}: Retrieved {len(rows)} rows in {(end_time - start_time).total_seconds()} seconds" + f"{self.get_stats_generator_name()}: Retrieved {len(rows)} rows in {(end_time - start_time).total_seconds()} seconds", ) self.report.report_query_latency( - f"{self.get_stats_generator_name()}:{query_name}", end_time - start_time + f"{self.get_stats_generator_name()}:{query_name}", + end_time - start_time, ) if self.post_filter: logger.debug("post filtering") @@ -362,7 +372,7 @@ def _execute_query(self, query: LookerQuery, query_name: str) -> List[Dict]: def _append_filters(self, query: LookerQuery) -> LookerQuery: query.filters.update( - {HistoryViewField.HISTORY_CREATED_DATE: self.config.interval} + {HistoryViewField.HISTORY_CREATED_DATE: self.config.interval}, ) if not self.post_filter: query.filters.update(self.get_filter()) @@ -380,22 +390,24 @@ def 
generate_usage_stat_mcps(self) -> Iterable[MetadataChangeProposalWrapper]: # Execute query and process the raw json which contains stat information entity_query_with_filters: LookerQuery = self._append_filters( - self.get_entity_timeseries_query() + self.get_entity_timeseries_query(), ) entity_rows: List[Dict] = self._execute_query( - entity_query_with_filters, "entity_query" + entity_query_with_filters, + "entity_query", ) entity_usage_stat: Dict[Tuple[str, str], Any] = ( self._process_entity_timeseries_rows(entity_rows) ) # Any type to pass mypy unbound Aspect type error user_wise_query_with_filters: LookerQuery = self._append_filters( - self.get_entity_user_timeseries_query() + self.get_entity_user_timeseries_query(), ) user_wise_rows = self._execute_query(user_wise_query_with_filters, "user_query") # yield absolute stat for entity for object_id, aspect in self._fill_user_stat_aspect( - entity_usage_stat, user_wise_rows + entity_usage_stat, + user_wise_rows, ): if object_id in self.id_to_model: yield self.create_mcp(self.id_to_model[object_id], aspect) @@ -433,8 +445,8 @@ def get_filter(self) -> Dict[ViewField, str]: looker_dashboard.id for looker_dashboard in self.looker_models if looker_dashboard.id is not None - ] - ) + ], + ), } def get_id(self, looker_object: ModelForUsage) -> str: @@ -447,7 +459,7 @@ def get_id_from_row(self, row: dict) -> str: def get_entity_stat_key(self, row: Dict) -> Tuple[str, str]: self.report.dashboards_with_activity.add( - row[HistoryViewField.HISTORY_DASHBOARD_ID] + row[HistoryViewField.HISTORY_DASHBOARD_ID], ) return ( row[HistoryViewField.HISTORY_DASHBOARD_ID], @@ -461,10 +473,12 @@ def _get_urn(self, model: ModelForUsage) -> str: return self.urn_builder(looker_common.get_urn_looker_dashboard_id(model.id)) def to_entity_absolute_stat_aspect( - self, looker_object: ModelForUsage + self, + looker_object: ModelForUsage, ) -> DashboardUsageStatisticsClass: looker_dashboard: LookerDashboardForUsage = cast( - LookerDashboardForUsage, looker_object + LookerDashboardForUsage, + looker_object, ) if looker_dashboard.view_count: self.report.dashboards_with_activity.add(str(looker_dashboard.id)) @@ -482,14 +496,15 @@ def get_entity_user_timeseries_query(self) -> LookerQuery: return query_collection[QueryId.DASHBOARD_PER_USER_PER_DAY_USAGE_STAT] def to_entity_timeseries_stat_aspect( - self, row: dict + self, + row: dict, ) -> DashboardUsageStatisticsClass: self.report.dashboards_with_activity.add( - row[HistoryViewField.HISTORY_DASHBOARD_ID] + row[HistoryViewField.HISTORY_DASHBOARD_ID], ) return DashboardUsageStatisticsClass( timestampMillis=self._round_time( - row[HistoryViewField.HISTORY_CREATED_DATE] + row[HistoryViewField.HISTORY_CREATED_DATE], ), eventGranularity=TimeWindowSizeClass(unit=CalendarIntervalClass.DAY), uniqueUserCount=row[HistoryViewField.HISTORY_DASHBOARD_USER], @@ -497,10 +512,14 @@ def to_entity_timeseries_stat_aspect( ) def append_user_stat( - self, entity_stat_aspect: Aspect, user: LookerUser, row: Dict + self, + entity_stat_aspect: Aspect, + user: LookerUser, + row: Dict, ) -> None: dashboard_stat_aspect: DashboardUsageStatisticsClass = cast( - DashboardUsageStatisticsClass, entity_stat_aspect + DashboardUsageStatisticsClass, + entity_stat_aspect, ) if dashboard_stat_aspect.userCounts is None: @@ -518,7 +537,7 @@ def append_user_stat( executionsCount=row[HistoryViewField.HISTORY_DASHBOARD_RUN_COUNT], usageCount=row[HistoryViewField.HISTORY_DASHBOARD_RUN_COUNT], userEmail=user.email, - ) + ), ) @@ -548,8 +567,8 @@ def report_skip_set(self) -> 
LossySet[str]: def get_filter(self) -> Dict[ViewField, str]: return { LookViewField.LOOK_ID: ",".join( - [look.id for look in self.looker_models if look.id is not None] - ) + [look.id for look in self.looker_models if look.id is not None], + ), } def get_id(self, looker_object: ModelForUsage) -> str: @@ -576,7 +595,8 @@ def _get_urn(self, model: ModelForUsage) -> str: return self.urn_builder(looker_common.get_urn_looker_element_id(str(model.id))) def to_entity_absolute_stat_aspect( - self, looker_object: ModelForUsage + self, + looker_object: ModelForUsage, ) -> ChartUsageStatisticsClass: looker_look: LookerChartForUsage = cast(LookerChartForUsage, looker_object) assert looker_look.id @@ -599,17 +619,21 @@ def to_entity_timeseries_stat_aspect(self, row: dict) -> ChartUsageStatisticsCla return ChartUsageStatisticsClass( timestampMillis=self._round_time( - row[HistoryViewField.HISTORY_CREATED_DATE] + row[HistoryViewField.HISTORY_CREATED_DATE], ), eventGranularity=TimeWindowSizeClass(unit=CalendarIntervalClass.DAY), viewsCount=row[HistoryViewField.HISTORY_COUNT], ) def append_user_stat( - self, entity_stat_aspect: Aspect, user: LookerUser, row: Dict + self, + entity_stat_aspect: Aspect, + user: LookerUser, + row: Dict, ) -> None: chart_stat_aspect: ChartUsageStatisticsClass = cast( - ChartUsageStatisticsClass, entity_stat_aspect + ChartUsageStatisticsClass, + entity_stat_aspect, ) if chart_stat_aspect.userCounts is None: @@ -625,7 +649,7 @@ def append_user_stat( ChartUserUsageCountsClass( user=user_urn, viewsCount=row[HistoryViewField.HISTORY_COUNT], - ) + ), ) @@ -637,8 +661,8 @@ def create_dashboard_stat_generator( ) -> DashboardStatGenerator: logger.debug( "Number of dashboard received for stat processing = {}".format( - len(looker_dashboards) - ) + len(looker_dashboards), + ), ) return DashboardStatGenerator( config=config, @@ -655,8 +679,11 @@ def create_chart_stat_generator( looker_looks: Sequence[LookerChartForUsage], ) -> LookStatGenerator: logger.debug( - "Number of looks received for stat processing = {}".format(len(looker_looks)) + "Number of looks received for stat processing = {}".format(len(looker_looks)), ) return LookStatGenerator( - config=config, looker_looks=looker_looks, report=report, urn_builder=urn_builder + config=config, + looker_looks=looker_looks, + report=report, + urn_builder=urn_builder, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_view_id_cache.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_view_id_cache.py index 562c7863b31343..156dad05ed4ea3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_view_id_cache.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_view_id_cache.py @@ -18,7 +18,7 @@ def determine_view_file_path(base_folder_path: str, absolute_file_path: str) -> splits: List[str] = absolute_file_path.split(base_folder_path, 1) if len(splits) != 2: logger.debug( - f"base_folder_path({base_folder_path}) and absolute_file_path({absolute_file_path}) not matching" + f"base_folder_path({base_folder_path}) and absolute_file_path({absolute_file_path}) not matching", ) return ViewFieldValue.NOT_AVAILABLE.value @@ -26,7 +26,7 @@ def determine_view_file_path(base_folder_path: str, absolute_file_path: str) -> logger.debug(f"file_path={file_path}") return file_path.strip( - "/" + "/", ) # strip / from path to make it equivalent to source_file attribute of LookerModelExplore API @@ -57,7 +57,8 @@ class LookerViewIdCache: model_name: str reporter: LookMLSourceReport 
looker_view_id_cache: Dict[ - str, LookerViewId + str, + LookerViewId, ] # Map of view-name as key, and LookerViewId instance as value def __init__( @@ -98,7 +99,8 @@ def get_looker_view_id( for view in included_looker_viewfile.views: if view[NAME] == view_name: file_path = determine_view_file_path( - base_folder_path, included_looker_viewfile.absolute_file_path + base_folder_path, + included_looker_viewfile.absolute_file_path, ) current_project_name: str = ( diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_concept_context.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_concept_context.py index 4e38165bb56286..04ce9a668664a3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_concept_context.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_concept_context.py @@ -29,7 +29,8 @@ def merge_parent_and_child_fields( - child_fields: List[dict], parent_fields: List[dict] + child_fields: List[dict], + parent_fields: List[dict], ) -> List[Dict]: # Fetch the fields from the parent view, i.e., the view name mentioned in view.extends, and include those # fields in child_fields. This inclusion will resolve the fields according to the precedence rules mentioned @@ -281,7 +282,7 @@ def resolve_extends_view_name( return self.looker_refinement_resolver.apply_view_refinement(view[1]) else: logger.warning( - f"failed to resolve view {target_view_name} included from {self.view_file.absolute_file_path}" + f"failed to resolve view {target_view_name} included from {self.view_file.absolute_file_path}", ) return None @@ -294,8 +295,8 @@ def _get_parent_attribute( """ extends = list( itertools.chain.from_iterable( - self.raw_view.get("extends", self.raw_view.get("extends__all", [])) - ) + self.raw_view.get("extends", self.raw_view.get("extends__all", [])), + ), ) # Following Looker's precedence rules. 
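(Aside on the hunk above: the call being re-wrapped in `_get_parent_attribute` flattens Looker's `extends`/`extends__all` entries, which the lkml parser hands back as a list of lists. A minimal standalone sketch of that flattening pattern follows; the sample `raw_view` dict and view names are made up for illustration and are not taken from the source.)

    import itertools

    # Hypothetical parsed view: lkml returns extends__all as a list of lists.
    raw_view = {"extends__all": [["base_view"], ["audit_view", "soft_delete_view"]]}

    # Same shape as the wrapped call in the diff: prefer "extends", fall back to
    # "extends__all", then flatten one level of nesting into a flat name list.
    extends = list(
        itertools.chain.from_iterable(
            raw_view.get("extends", raw_view.get("extends__all", [])),
        ),
    )

    print(extends)  # ['base_view', 'audit_view', 'soft_delete_view']
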
@@ -308,7 +309,7 @@ def _get_parent_attribute( if not extend_view: raise NameError( f"failed to resolve extends view {extend} in view {self.raw_view[NAME]} of" - f" file {self.view_file.absolute_file_path}" + f" file {self.view_file.absolute_file_path}", ) if attribute_name in extend_view: return extend_view[attribute_name] @@ -364,7 +365,7 @@ def sql_table_name(self) -> str: def datahub_transformed_sql_table_name(self) -> str: # This field might be present in parent view of current view table_name: Optional[str] = self.get_including_extends( - field="datahub_transformed_sql_table_name" + field="datahub_transformed_sql_table_name", ) if not table_name: @@ -419,12 +420,13 @@ def name(self) -> str: def view_file_name(self) -> str: splits: List[str] = self.view_file.absolute_file_path.split( - self.base_folder_path, 1 + self.base_folder_path, + 1, ) if len(splits) != 2: logger.debug( f"base_folder_path({self.base_folder_path}) and absolute_file_path({self.view_file.absolute_file_path})" - f" not matching" + f" not matching", ) return ViewFieldValue.NOT_AVAILABLE.value @@ -432,7 +434,7 @@ def view_file_name(self) -> str: logger.debug(f"file_path={file_name}") return file_name.strip( - "/" + "/", ) # strip / from path to make it equivalent to source_file attribute of LookerModelExplore API def _get_list_dict(self, attribute_name: str) -> List[Dict]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py index 7ffb895349ed29..ac69889e2e58c0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py @@ -79,7 +79,9 @@ def compute_stats(self) -> None: class LookMLSourceConfig( - LookerCommonConfig, StatefulIngestionConfigBase, EnvConfigMixin + LookerCommonConfig, + StatefulIngestionConfigBase, + EnvConfigMixin, ): git_info: Optional[GitInfo] = Field( None, @@ -151,7 +153,8 @@ class LookMLSourceConfig( description="When enabled, sql parsing will be executed in a separate process to prevent memory leaks.", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="" + default=None, + description="", ) process_refinements: bool = Field( False, @@ -189,10 +192,12 @@ def convert_string_to_connection_def(cls, conn_map): else: logger.warning( f"Connection map for {key} provides platform {platform} but does not provide a default " - f"database name. This might result in failed resolution" + f"database name. This might result in failed resolution", ) conn_map[key] = LookerConnectionDefinition( - platform=platform, default_db="", default_schema="" + platform=platform, + default_db="", + default_schema="", ) return conn_map @@ -200,11 +205,12 @@ def convert_string_to_connection_def(cls, conn_map): def check_either_connection_map_or_connection_provided(cls, values): """Validate that we must either have a connection map or an api credential""" if not values.get("connection_to_platform_map", {}) and not values.get( - "api", {} + "api", + {}, ): raise ValueError( "Neither api not connection_to_platform_map config was found. 
LookML source requires either api " - "credentials for Looker or a map of connection names to platform identifiers to work correctly" + "credentials for Looker or a map of connection names to platform identifiers to work correctly", ) return values @@ -214,13 +220,15 @@ def check_either_project_name_or_api_provided(cls, values): if not values.get("project_name") and not values.get("api"): raise ValueError( "Neither project_name not an API credential was found. LookML source requires either api credentials " - "for Looker or a project_name to accurately name views and models." + "for Looker or a project_name to accurately name views and models.", ) return values @validator("base_folder", always=True) def check_base_folder_if_not_provided( - cls, v: Optional[pydantic.DirectoryPath], values: Dict[str, Any] + cls, + v: Optional[pydantic.DirectoryPath], + values: Dict[str, Any], ) -> Optional[pydantic.DirectoryPath]: if v is None: git_info: Optional[GitInfo] = values.get("git_info") @@ -228,7 +236,7 @@ def check_base_folder_if_not_provided( if not git_info.deploy_key: logger.warning( "git_info is provided, but no SSH key is present. If the repo is not public, we'll fail to " - "clone it." + "clone it.", ) else: raise ValueError("Neither base_folder nor git_info has been provided.") diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_refinement.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_refinement.py index 6933d9d69394bc..630ca872b32a38 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_refinement.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_refinement.py @@ -38,10 +38,12 @@ class LookerRefinementResolver: source_config: LookMLSourceConfig reporter: LookMLSourceReport view_refinement_cache: Dict[ - str, dict + str, + dict, ] # Map of view-name as key, and it is raw view dictionary after applying refinement process explore_refinement_cache: Dict[ - str, dict + str, + dict, ] # Map of explore-name as key, and it is raw view dictionary after applying refinement process def __init__( @@ -66,7 +68,9 @@ def is_refinement(view_name: str) -> bool: @staticmethod def merge_column( - original_dict: dict, refinement_dict: dict, key: str + original_dict: dict, + refinement_dict: dict, + key: str, ) -> List[dict]: """ Merge a dimension/measure/other column with one from a refinement. 
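(Nearly every hunk in this section applies the same mechanical rewrite: a wrapped call site gets one argument per line plus a trailing comma. A small runnable before/after illustration of that style is below; `combine` and its arguments are toy stand-ins invented for the example, not the DataHub helpers being patched.)

    def combine(original: dict, refinement: dict, key: str) -> dict:
        """Toy helper, only here so the snippet runs; not part of the DataHub code."""
        return {**original.get(key, {}), **refinement.get(key, {})}

    original_dict = {"dimensions": {"id": 1}}
    refinement_dict = {"dimensions": {"label": "ID"}}

    # Style before this patch: arguments packed onto one wrapped line, no trailing comma.
    merged = combine(
        original_dict, refinement_dict, "dimensions"
    )

    # Style after this patch: one argument per line and a trailing comma, so a later
    # change that adds or removes an argument touches exactly one line of the diff.
    merged = combine(
        original_dict,
        refinement_dict,
        "dimensions",
    )
    print(merged)  # {'id': 1, 'label': 'ID'}
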
@@ -95,10 +99,14 @@ def merge_column( @staticmethod def merge_and_set_column( - new_raw_view: dict, refinement_view: dict, key: str + new_raw_view: dict, + refinement_view: dict, + key: str, ) -> None: merged_column = LookerRefinementResolver.merge_column( - new_raw_view, refinement_view, key + new_raw_view, + refinement_view, + key, ) if merged_column: new_raw_view[key] = merged_column @@ -118,15 +126,21 @@ def merge_refinements(raw_view: dict, refinement_views: List[dict]) -> dict: # Merge Dimension LookerRefinementResolver.merge_and_set_column( - new_raw_view, refinement_view, LookerRefinementResolver.DIMENSIONS + new_raw_view, + refinement_view, + LookerRefinementResolver.DIMENSIONS, ) # Merge Measure LookerRefinementResolver.merge_and_set_column( - new_raw_view, refinement_view, LookerRefinementResolver.MEASURES + new_raw_view, + refinement_view, + LookerRefinementResolver.MEASURES, ) # Merge Dimension Group LookerRefinementResolver.merge_and_set_column( - new_raw_view, refinement_view, LookerRefinementResolver.DIMENSION_GROUPS + new_raw_view, + refinement_view, + LookerRefinementResolver.DIMENSION_GROUPS, ) return new_raw_view @@ -160,7 +174,7 @@ def get_refinement_from_model_includes(self, view_name: str) -> List[dict]: continue refined_views.extend( - self.get_refinements(included_looker_viewfile.views, view_name) + self.get_refinements(included_looker_viewfile.views, view_name), ) return refined_views @@ -192,18 +206,20 @@ def apply_view_refinement(self, raw_view: dict) -> dict: logger.debug(f"Processing refinement for view {raw_view_name}") refinement_views: List[dict] = self.get_refinement_from_model_includes( - raw_view_name + raw_view_name, ) self.view_refinement_cache[raw_view_name] = self.merge_refinements( - raw_view, refinement_views + raw_view, + refinement_views, ) return self.view_refinement_cache[raw_view_name] @staticmethod def add_extended_explore( - raw_explore: dict, refinement_explores: List[Dict] + raw_explore: dict, + refinement_explores: List[Dict], ) -> None: extended_explores: Set[str] = set() for view in refinement_explores: @@ -212,8 +228,8 @@ def add_extended_explore( view.get( LookerRefinementResolver.EXTENDS, view.get(LookerRefinementResolver.EXTENDS_ALL, []), - ) - ) + ), + ), ) extended_explores.update(extends) @@ -234,14 +250,15 @@ def apply_explore_refinement(self, raw_view: dict) -> dict: if raw_view_name in self.explore_refinement_cache: logger.debug( - f"Returning applied refined explore {raw_view_name} from cache" + f"Returning applied refined explore {raw_view_name} from cache", ) return self.explore_refinement_cache[raw_view_name] logger.debug(f"Processing refinement for explore {raw_view_name}") refinement_explore: List[dict] = self.get_refinements( - self.looker_model.explores, raw_view_name + self.looker_model.explores, + raw_view_name, ) self.add_extended_explore(raw_view, refinement_explore) diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py index a8575c84b510d5..6cb5a08f529667 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py @@ -114,12 +114,14 @@ class LookerView: @classmethod def determine_view_file_path( - cls, base_folder_path: str, absolute_file_path: str + cls, + base_folder_path: str, + absolute_file_path: str, ) -> str: splits: List[str] = absolute_file_path.split(base_folder_path, 1) if len(splits) != 2: 
logger.debug( - f"base_folder_path({base_folder_path}) and absolute_file_path({absolute_file_path}) not matching" + f"base_folder_path({base_folder_path}) and absolute_file_path({absolute_file_path}) not matching", ) return ViewFieldValue.NOT_AVAILABLE.value @@ -127,7 +129,7 @@ def determine_view_file_path( logger.debug(f"file_path={file_path}") return file_path.strip( - "/" + "/", ) # strip / from path to make it equivalent to source_file attribute of LookerModelExplore API @classmethod @@ -168,7 +170,7 @@ def from_looker_dict( ViewFieldType.DIMENSION: view_context.dimensions(), ViewFieldType.DIMENSION_GROUP: view_context.dimension_groups(), ViewFieldType.MEASURE: view_context.measures(), - } + }, ) # in order to maintain order in golden file view_fields: List[ViewField] = [] @@ -178,7 +180,7 @@ def from_looker_dict( upstream_column_ref: List[ColumnRef] = [] if extract_col_level_lineage: upstream_column_ref = view_upstream.get_upstream_column_ref( - field_context=LookerFieldContext(raw_field=field) + field_context=LookerFieldContext(raw_field=field), ) view_fields.append( @@ -187,7 +189,7 @@ def from_looker_dict( upstream_column_ref=upstream_column_ref, type_cls=field_type, populate_sql_logic_in_descriptions=populate_sql_logic_in_descriptions, - ) + ), ) # special case where view is defined as derived sql, however fields are not defined @@ -221,7 +223,9 @@ def from_looker_dict( materialized = view_context.is_materialized_derived_view() view_details = ViewProperties( - materialized=materialized, viewLogic=view_logic, viewLanguage=view_lang + materialized=materialized, + viewLogic=view_logic, + viewLanguage=view_lang, ) else: view_details = ViewProperties( @@ -306,7 +310,7 @@ def __init__(self, config: LookMLSourceConfig, ctx: PipelineContext): except SDKError as err: raise ValueError( "Failed to retrieve connections from looker client. Please check to ensure that you have " - "manage_models permission enabled on this API key." + "manage_models permission enabled on this API key.", ) from err def _load_model(self, path: str) -> LookerModel: @@ -329,7 +333,8 @@ def _load_model(self, path: str) -> LookerModel: return looker_model def _get_upstream_lineage( - self, looker_view: LookerView + self, + looker_view: LookerView, ) -> Optional[UpstreamLineage]: upstream_dataset_urns = looker_view.upstream_dataset_urns @@ -362,14 +367,15 @@ def _get_upstream_lineage( make_schema_field_urn( looker_view.id.get_urn(self.source_config), field.name, - ) + ), ], - ) + ), ) if upstreams: return UpstreamLineage( - upstreams=upstreams, fineGrainedLineages=fine_grained_lineages or None + upstreams=upstreams, + fineGrainedLineages=fine_grained_lineages or None, ) else: return None @@ -377,18 +383,19 @@ def _get_upstream_lineage( def _get_custom_properties(self, looker_view: LookerView) -> DatasetPropertiesClass: assert self.source_config.base_folder # this is always filled out base_folder = self.base_projects_folder.get( - looker_view.id.project_name, self.source_config.base_folder + looker_view.id.project_name, + self.source_config.base_folder, ) try: file_path = str( pathlib.Path(looker_view.absolute_file_path).relative_to( - base_folder.resolve() - ) + base_folder.resolve(), + ), ) except Exception: file_path = None logger.warning( - f"Failed to resolve relative path for file {looker_view.absolute_file_path} w.r.t. folder {self.source_config.base_folder}" + f"Failed to resolve relative path for file {looker_view.absolute_file_path} w.r.t. 
folder {self.source_config.base_folder}", ) custom_properties = { @@ -396,7 +403,8 @@ def _get_custom_properties(self, looker_view: LookerView) -> DatasetPropertiesCl "looker.model": looker_view.id.model_name, } dataset_props = DatasetPropertiesClass( - name=looker_view.id.view_name, customProperties=custom_properties + name=looker_view.id.view_name, + customProperties=custom_properties, ) maybe_git_info = self.source_config.project_dependencies.get( @@ -415,7 +423,8 @@ def _get_custom_properties(self, looker_view: LookerView) -> DatasetPropertiesCl return dataset_props def _build_dataset_mcps( - self, looker_view: LookerView + self, + looker_view: LookerView, ) -> List[MetadataChangeProposalWrapper]: view_urn = looker_view.id.get_urn(self.source_config) @@ -435,14 +444,14 @@ def _build_dataset_mcps( container = ContainerClass(container=project_key.as_urn()) events.append( - MetadataChangeProposalWrapper(entityUrn=view_urn, aspect=container) + MetadataChangeProposalWrapper(entityUrn=view_urn, aspect=container), ) events.append( MetadataChangeProposalWrapper( entityUrn=view_urn, aspect=looker_view.id.get_browse_path_v2(self.source_config), - ) + ), ) return events @@ -458,7 +467,7 @@ def _build_dataset_mce(self, looker_view: LookerView) -> MetadataChangeEvent: aspects=[], # we append to this list later on ) browse_paths = BrowsePaths( - paths=[looker_view.id.get_browse_path(self.source_config)] + paths=[looker_view.id.get_browse_path(self.source_config)], ) dataset_snapshot.aspects.append(browse_paths) @@ -494,14 +503,15 @@ def get_project_name(self, model_name: str) -> str: except SDKError: raise ValueError( f"Could not locate a project name for model {model_name}. Consider configuring a static project name " - f"in your config file" + f"in your config file", ) def get_manifest_if_present(self, folder: pathlib.Path) -> Optional[LookerManifest]: manifest_file = folder / "manifest.lkml" if manifest_file.exists(): manifest_dict = load_and_preprocess_file( - path=manifest_file, source_config=self.source_config + path=manifest_file, + source_config=self.source_config, ) manifest = LookerManifest( @@ -511,7 +521,9 @@ def get_manifest_if_present(self, folder: pathlib.Path) -> Optional[LookerManife ], remote_dependencies=[ LookerRemoteDependency( - name=x["name"], url=x["url"], ref=x.get("ref") + name=x["name"], + url=x["url"], + ref=x.get("ref"), ) for x in manifest_dict.get("remote_dependencys", []) ], @@ -524,7 +536,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] @@ -574,7 +588,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.base_projects_folder[project] = p_ref self._recursively_check_manifests( - tmp_dir, BASE_PROJECT_NAME, visited_projects + tmp_dir, + BASE_PROJECT_NAME, + visited_projects, ) yield from self.get_internal_workunits() @@ -587,7 +603,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) def _recursively_check_manifests( - self, tmp_dir: str, project_name: str, project_visited: Set[str] + self, + tmp_dir: str, + project_name: str, + project_visited: Set[str], ) -> None: if project_name in project_visited: return @@ -613,7 +632,7 @@ def _recursively_check_manifests( logger.warning( f"The project name in the manifest file '{manifest.project_name}'" f"does not match the configured project name 
'{self.source_config.project_name}'. " - "This can lead to failures in LookML include resolution and lineage generation." + "This can lead to failures in LookML include resolution and lineage generation.", ) elif self.source_config.project_name is None: self.source_config.project_name = manifest.project_name @@ -663,7 +682,9 @@ def _recursively_check_manifests( project_visited.add(project_name) else: self._recursively_check_manifests( - tmp_dir, remote_project.name, project_visited + tmp_dir, + remote_project.name, + project_visited, ) for project in manifest.local_dependencies: @@ -694,7 +715,7 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 # The ** means "this directory and all subdirectories", and hence should # include all the files we want. model_files = sorted( - self.source_config.base_folder.glob(f"**/*{MODEL_FILE_EXTENSION}") + self.source_config.base_folder.glob(f"**/*{MODEL_FILE_EXTENSION}"), ) model_suffix_len = len(".model") @@ -754,7 +775,7 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 explore_dict = ( looker_refinement_resolver.apply_explore_refinement( - explore_dict + explore_dict, ) ) explore: LookerExplore = LookerExplore.from_dict( @@ -778,7 +799,8 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 logger.debug("Failed to process explore", exc_info=e) processed_view_files = processed_view_map.setdefault( - model.connection, set() + model.connection, + set(), ) project_name = self.get_project_name(model_name) @@ -817,7 +839,7 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 and raw_view_name not in explore_reachable_views ): logger.debug( - f"view {raw_view_name} is not reachable from an explore, skipping.." 
+ f"view {raw_view_name} is not reachable from an explore, skipping..", ) self.reporter.report_unreachable_view_dropped(raw_view_name) continue @@ -840,7 +862,7 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 self.base_projects_folder.get( current_project_name, self.base_projects_folder[BASE_PROJECT_NAME], - ) + ), ) view_context: LookerViewContext = LookerViewContext( @@ -879,13 +901,13 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 if maybe_looker_view: if self.source_config.view_pattern.allowed( - maybe_looker_view.id.view_name + maybe_looker_view.id.view_name, ): view_urn = maybe_looker_view.id.get_urn( - self.source_config + self.source_config, ) view_connection_mapping = view_connection_map.get( - view_urn + view_urn, ) if not view_connection_mapping: view_connection_map[view_urn] = ( @@ -894,7 +916,7 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 ) # first time we are discovering this view logger.debug( - f"Generating MCP for view {raw_view['name']}" + f"Generating MCP for view {raw_view['name']}", ) if ( @@ -902,15 +924,15 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 not in self.processed_projects ): yield from self.gen_project_workunits( - maybe_looker_view.id.project_name + maybe_looker_view.id.project_name, ) self.processed_projects.append( - maybe_looker_view.id.project_name + maybe_looker_view.id.project_name, ) for mcp in self._build_dataset_mcps( - maybe_looker_view + maybe_looker_view, ): yield mcp.as_workunit() mce = self._build_dataset_mce(maybe_looker_view) @@ -928,15 +950,15 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 # this view has previously been discovered and emitted using a different # connection logger.warning( - f"view {maybe_looker_view.id.view_name} from model {model_name}, connection {model.connection} was previously processed via model {prev_model_name}, connection {prev_model_connection} and will likely lead to incorrect lineage to the underlying tables" + f"view {maybe_looker_view.id.view_name} from model {model_name}, connection {model.connection} was previously processed via model {prev_model_name}, connection {prev_model_connection} and will likely lead to incorrect lineage to the underlying tables", ) if not self.source_config.emit_reachable_views_only: logger.warning( - "Consider enabling the `emit_reachable_views_only` flag to handle this case." 
+ "Consider enabling the `emit_reachable_views_only` flag to handle this case.", ) else: self.reporter.report_views_dropped( - str(maybe_looker_view.id) + str(maybe_looker_view.id), ) if ( @@ -946,7 +968,8 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 # Emit tag MCEs for measures and dimensions: for tag_mce in LookerUtil.get_tag_mces(): yield MetadataWorkUnit( - id=f"tag-{tag_mce.proposedSnapshot.urn}", mce=tag_mce + id=f"tag-{tag_mce.proposedSnapshot.urn}", + mce=tag_mce, ) def gen_project_workunits(self, project_name: str) -> Iterable[MetadataWorkUnit]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/urn_functions.py b/metadata-ingestion/src/datahub/ingestion/source/looker/urn_functions.py index 7286beb1f977a9..b5e0df01b7e229 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/urn_functions.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/urn_functions.py @@ -3,7 +3,7 @@ def get_qualified_table_name(urn: str) -> str: if len(part.split(".")) >= 4: return ".".join( - part.split(".")[-3:] + part.split(".")[-3:], ) # return only db.schema.table skip platform instance as higher code is # failing if encounter platform-instance in qualified table name else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/view_upstream.py b/metadata-ingestion/src/datahub/ingestion/source/looker/view_upstream.py index f77eebb3cdd8cb..a9e65e723bf4e7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/view_upstream.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/view_upstream.py @@ -56,7 +56,9 @@ def get_derived_looker_view_id( # 3) employee_income_source.sql_table_name # In any of the form we need the text coming before ".sql_table_name" and after last "." parts: List[str] = re.split( - DERIVED_VIEW_SUFFIX, qualified_table_name, flags=re.IGNORECASE + DERIVED_VIEW_SUFFIX, + qualified_table_name, + flags=re.IGNORECASE, ) view_name: str = parts[0].split(".")[-1] @@ -85,7 +87,7 @@ def resolve_derived_view_urn_of_col_ref( ) if not new_urns: logger.warning( - f"Not able to resolve to derived view looker id for {col_ref.table}" + f"Not able to resolve to derived view looker id for {col_ref.table}", ) continue @@ -114,7 +116,7 @@ def fix_derived_view_urn( if looker_view_id is None: logger.warning( - f"Not able to resolve to derived view looker id for {urn}" + f"Not able to resolve to derived view looker id for {urn}", ) continue @@ -227,7 +229,8 @@ def __init__( @abstractmethod def get_upstream_column_ref( - self, field_context: LookerFieldContext + self, + field_context: LookerFieldContext, ) -> List[ColumnRef]: pass @@ -239,7 +242,9 @@ def create_fields(self) -> List[ViewField]: return [] # it is for the special case def create_upstream_column_refs( - self, upstream_urn: str, downstream_looker_columns: List[str] + self, + upstream_urn: str, + downstream_looker_columns: List[str], ) -> List[ColumnRef]: """ - **`upstream_urn`**: The URN of the upstream dataset. @@ -259,11 +264,12 @@ def create_upstream_column_refs( if schema_info: actual_columns = match_columns_to_schema( - schema_info, downstream_looker_columns + schema_info, + downstream_looker_columns, ) else: logger.info( - f"schema_info not found for dataset {urn} in GMS. Using expected_columns to form ColumnRef" + f"schema_info not found for dataset {urn} in GMS. 
Using expected_columns to form ColumnRef", ) actual_columns = [column.lower() for column in downstream_looker_columns] @@ -274,7 +280,7 @@ def create_upstream_column_refs( ColumnRef( column=column, table=upstream_urn, - ) + ), ) return upstream_column_refs @@ -297,7 +303,7 @@ def __init__( # These are the function where we need to catch the response once calculated self._get_spr = lru_cache(maxsize=1)(self.__get_spr) self._get_upstream_dataset_urn = lru_cache(maxsize=1)( - self.__get_upstream_dataset_urn + self.__get_upstream_dataset_urn, ) def __get_spr(self) -> Optional[SqlParsingResult]: @@ -324,7 +330,7 @@ def __get_upstream_dataset_urn(self) -> List[Urn]: if sql_parsing_result.debug_info.table_error is not None: logger.debug( - f"view-name={self.view_context.name()}, sql_query={self.get_sql_query()}" + f"view-name={self.view_context.name()}, sql_query={self.get_sql_query()}", ) self.reporter.report_warning( title="Table Level Lineage Missing", @@ -382,13 +388,14 @@ def create_fields(self) -> List[ViewField]: description="", field_type=ViewFieldType.UNKNOWN, upstream_fields=_drop_hive_dot_from_upstream(cll.upstreams), - ) + ), ) return fields def get_upstream_column_ref( - self, field_context: LookerFieldContext + self, + field_context: LookerFieldContext, ) -> List[ColumnRef]: sql_parsing_result: Optional[SqlParsingResult] = self._get_spr() @@ -477,10 +484,10 @@ def __init__( super().__init__(view_context, looker_view_id_cache, config, reporter, ctx) self._get_upstream_dataset_urn = lru_cache(maxsize=1)( - self.__get_upstream_dataset_urn + self.__get_upstream_dataset_urn, ) self._get_explore_column_mapping = lru_cache(maxsize=1)( - self.__get_explore_column_mapping + self.__get_explore_column_mapping, ) def __get_upstream_dataset_urn(self) -> List[str]: @@ -499,7 +506,7 @@ def __get_upstream_dataset_urn(self) -> List[str]: LookerExplore( name=self.view_context.explore_source()[NAME], model_name=current_view_id.model_name, - ).get_explore_urn(self.config) + ).get_explore_urn(self.config), ] return upstream_dataset_urns @@ -515,14 +522,15 @@ def __get_explore_column_mapping(self) -> Dict: return explore_column_mapping def get_upstream_column_ref( - self, field_context: LookerFieldContext + self, + field_context: LookerFieldContext, ) -> List[ColumnRef]: upstream_column_refs: List[ColumnRef] = [] if not self._get_upstream_dataset_urn(): # No upstream explore dataset found logging.debug( - f"upstream explore not found for field {field_context.name()} of view {self.view_context.name()}" + f"upstream explore not found for field {field_context.name()} of view {self.view_context.name()}", ) return upstream_column_refs @@ -533,11 +541,12 @@ def get_upstream_column_ref( if column in self._get_explore_column_mapping(): explore_column: Dict = self._get_explore_column_mapping()[column] expected_columns.append( - explore_column.get("field", explore_column[NAME]) + explore_column.get("field", explore_column[NAME]), ) return self.create_upstream_column_refs( - upstream_urn=explore_urn, downstream_looker_columns=expected_columns + upstream_urn=explore_urn, + downstream_looker_columns=expected_columns, ) def get_upstream_dataset_urn(self) -> List[Urn]: @@ -563,7 +572,7 @@ def __init__( self.upstream_dataset_urn = None self._get_upstream_dataset_urn = lru_cache(maxsize=1)( - self.__get_upstream_dataset_urn + self.__get_upstream_dataset_urn, ) def __get_upstream_dataset_urn(self) -> Urn: @@ -586,7 +595,8 @@ def __get_upstream_dataset_urn(self) -> Urn: return self.upstream_dataset_urn def 
get_upstream_column_ref( - self, field_context: LookerFieldContext + self, + field_context: LookerFieldContext, ) -> List[ColumnRef]: return self.create_upstream_column_refs( upstream_urn=self._get_upstream_dataset_urn(), @@ -616,7 +626,7 @@ def __init__( self.upstream_dataset_urn = [] self._get_upstream_dataset_urn = lru_cache(maxsize=1)( - self.__get_upstream_dataset_urn + self.__get_upstream_dataset_urn, ) def __get_upstream_dataset_urn(self) -> List[Urn]: @@ -636,13 +646,14 @@ def __get_upstream_dataset_urn(self) -> List[Urn]: self.upstream_dataset_urn = [ looker_view_id.get_urn( config=self.config, - ) + ), ] return self.upstream_dataset_urn def get_upstream_column_ref( - self, field_context: LookerFieldContext + self, + field_context: LookerFieldContext, ) -> List[ColumnRef]: upstream_column_ref: List[ColumnRef] = [] @@ -660,7 +671,8 @@ def get_upstream_dataset_urn(self) -> List[Urn]: class EmptyImplementation(AbstractViewUpstream): def get_upstream_column_ref( - self, field_context: LookerFieldContext + self, + field_context: LookerFieldContext, ) -> List[ColumnRef]: return [] @@ -697,7 +709,7 @@ def create_view_upstream( [ view_context.is_sql_based_derived_case(), view_context.is_sql_based_derived_view_without_fields_case(), - ] + ], ): return DerivedQueryUpstreamSource( view_context=view_context, diff --git a/metadata-ingestion/src/datahub/ingestion/source/metabase.py b/metadata-ingestion/src/datahub/ingestion/source/metabase.py index ef16dc0a49a223..c9dd83cd391af6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metabase.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metabase.py @@ -71,7 +71,8 @@ class MetabaseConfig(DatasetLineageProviderConfigBase, StatefulIngestionConfigBa ) username: Optional[str] = Field(default=None, description="Metabase username.") password: Optional[pydantic.SecretStr] = Field( - default=None, description="Metabase password." + default=None, + description="Metabase password.", ) # TODO: Check and remove this if no longer needed. # Config database_alias is removed from sql sources. 
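(The constructors touched in the view_upstream.py hunks above repeatedly use the same per-instance memoization trick, e.g. `self._get_upstream_dataset_urn = lru_cache(maxsize=1)(self.__get_upstream_dataset_urn)`. A self-contained sketch of that pattern follows; the class name, counter, and URN string are invented for illustration and the "expensive" step is simulated.)

    from functools import lru_cache


    class UpstreamResolver:
        def __init__(self) -> None:
            self.calls = 0
            # Wrap the bound zero-argument method so the expensive lookup runs at
            # most once per instance; later calls return the cached result. This
            # mirrors the lru_cache(maxsize=1)(...) assignments in the diff.
            self._get_upstream_dataset_urn = lru_cache(maxsize=1)(
                self.__get_upstream_dataset_urn,
            )

        def __get_upstream_dataset_urn(self) -> list:
            self.calls += 1  # stand-in for a slow SQL-parsing / API step
            return ["urn:li:dataset:(urn:li:dataPlatform:hive,db.table,PROD)"]


    resolver = UpstreamResolver()
    resolver._get_upstream_dataset_urn()
    resolver._get_upstream_dataset_urn()
    print(resolver.calls)  # 1 -- the underlying method ran only once
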
@@ -200,13 +201,13 @@ def setup_session(self) -> None: "X-Metabase-Session": f"{self.access_token}", "Content-Type": "application/json", "Accept": "*/*", - } + }, ) # Test the connection try: test_response = self.session.get( - f"{self.config.connect_uri}/api/user/current" + f"{self.config.connect_uri}/api/user/current", ) test_response.raise_for_status() except HTTPError as e: @@ -232,14 +233,14 @@ def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: try: collections_response = self.session.get( f"{self.config.connect_uri}/api/collection/" - f"?exclude-other-user-collections={json.dumps(self.config.exclude_other_user_collections)}" + f"?exclude-other-user-collections={json.dumps(self.config.exclude_other_user_collections)}", ) collections_response.raise_for_status() collections = collections_response.json() for collection in collections: collection_dashboards_response = self.session.get( - f"{self.config.connect_uri}/api/collection/{collection['id']}/items?models=dashboard" + f"{self.config.connect_uri}/api/collection/{collection['id']}/items?models=dashboard", ) collection_dashboards_response.raise_for_status() collection_dashboards = collection_dashboards_response.json() @@ -249,7 +250,7 @@ def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: for dashboard_info in collection_dashboards.get("data"): dashboard_snapshot = self.construct_dashboard_from_api_data( - dashboard_info + dashboard_info, ) if dashboard_snapshot is not None: mce = MetadataChangeEvent(proposedSnapshot=dashboard_snapshot) @@ -274,7 +275,8 @@ def get_timestamp_millis_from_ts_string(ts_str: str) -> int: return int(datetime.now(timezone.utc).timestamp() * 1000) def construct_dashboard_from_api_data( - self, dashboard_info: dict + self, + dashboard_info: dict, ) -> Optional[DashboardSnapshot]: dashboard_id = dashboard_info.get("id", "") dashboard_url = f"{self.config.connect_uri}/api/dashboard/{dashboard_id}" @@ -291,7 +293,8 @@ def construct_dashboard_from_api_data( return None dashboard_urn = builder.make_dashboard_urn( - self.platform, dashboard_details.get("id", "") + self.platform, + dashboard_details.get("id", ""), ) dashboard_snapshot = DashboardSnapshot( urn=dashboard_urn, @@ -300,7 +303,7 @@ def construct_dashboard_from_api_data( last_edit_by = dashboard_details.get("last-edit-info") or {} modified_actor = builder.make_user_urn(last_edit_by.get("email", "unknown")) modified_ts = self.get_timestamp_millis_from_ts_string( - f"{last_edit_by.get('timestamp')}" + f"{last_edit_by.get('timestamp')}", ) title = dashboard_details.get("name", "") or "" description = dashboard_details.get("description", "") or "" @@ -368,8 +371,8 @@ def _get_ownership(self, creator_id: int) -> Optional[OwnershipClass]: OwnerClass( owner=owner_urn, type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return ownership @@ -446,7 +449,7 @@ def construct_card_from_api_data(self, card_data: dict) -> Optional[ChartSnapsho last_edit_by = card_details.get("last-edit-info") or {} modified_actor = builder.make_user_urn(last_edit_by.get("email", "unknown")) modified_ts = self.get_timestamp_millis_from_ts_string( - f"{last_edit_by.get('timestamp')}" + f"{last_edit_by.get('timestamp')}", ) last_modified = ChangeAuditStamps( created=None, @@ -537,7 +540,8 @@ def construct_card_custom_properties(self, card_details: dict) -> Dict: ) filters = (card_details.get("dataset_query", {}).get("query", {})).get( - "filter", [] + "filter", + [], ) custom_properties = { @@ -549,7 +553,9 @@ def construct_card_custom_properties(self, card_details: 
dict) -> Dict: return custom_properties def get_datasource_urn( - self, card_details: dict, recursion_depth: int = 0 + self, + card_details: dict, + recursion_depth: int = 0, ) -> Optional[List]: if recursion_depth > DATASOURCE_URN_RECURSION_LIMIT: self.report.report_warning( @@ -588,7 +594,7 @@ def get_datasource_urn( # trying to get source table from source question. Recursion depth is limited return self.get_datasource_urn( card_details=self.get_card_details_by_id( - source_table_id.replace("card__", "") + source_table_id.replace("card__", ""), ), recursion_depth=recursion_depth + 1, ) @@ -603,11 +609,13 @@ def get_datasource_urn( name=".".join([v for v in name_components if v]), platform_instance=platform_instance, env=self.config.env, - ) + ), ] else: raw_query_stripped = self.strip_template_expressions( - card_details.get("dataset_query", {}).get("native", {}).get("query", "") + card_details.get("dataset_query", {}) + .get("native", {}) + .get("query", ""), ) result = create_lineage_sql_parsed_result( @@ -622,7 +630,7 @@ def get_datasource_urn( if result.debug_info.table_error: logger.info( f"Failed to parse lineage from query {raw_query_stripped}: " - f"{result.debug_info.table_error}" + f"{result.debug_info.table_error}", ) self.report.report_warning( title="Failed to Extract Lineage", @@ -653,11 +661,12 @@ def strip_template_expressions(raw_query: str) -> str: @lru_cache(maxsize=None) def get_source_table_from_id( - self, table_id: Union[int, str] + self, + table_id: Union[int, str], ) -> Tuple[Optional[str], Optional[str]]: try: dataset_response = self.session.get( - f"{self.config.connect_uri}/api/table/{table_id}" + f"{self.config.connect_uri}/api/table/{table_id}", ) dataset_response.raise_for_status() dataset_json = dataset_response.json() @@ -676,7 +685,9 @@ def get_source_table_from_id( @lru_cache(maxsize=None) def get_platform_instance( - self, platform: Optional[str] = None, datasource_id: Optional[int] = None + self, + platform: Optional[str] = None, + datasource_id: Optional[int] = None, ) -> Optional[str]: """ Method will attempt to detect `platform_instance` by checking @@ -694,7 +705,7 @@ def get_platform_instance( # For cases when metabase has several platform instances (e.g. 
several individual ClickHouse clusters) if datasource_id is not None and self.config.database_id_to_instance_map: platform_instance = self.config.database_id_to_instance_map.get( - str(datasource_id) + str(datasource_id), ) # If Metabase datasource ID is not mapped to platform instace, fall back to platform mapping @@ -706,11 +717,12 @@ def get_platform_instance( @lru_cache(maxsize=None) def get_datasource_from_id( - self, datasource_id: Union[int, str] + self, + datasource_id: Union[int, str], ) -> Tuple[str, Optional[str], Optional[str], Optional[str]]: try: dataset_response = self.session.get( - f"{self.config.connect_uri}/api/database/{datasource_id}" + f"{self.config.connect_uri}/api/database/{datasource_id}", ) dataset_response.raise_for_status() dataset_json = dataset_response.json() @@ -751,7 +763,8 @@ def get_datasource_from_id( ) platform_instance = self.get_platform_instance( - platform, dataset_json.get("id", None) + platform, + dataset_json.get("id", None), ) field_for_dbname_mapping = { @@ -793,7 +806,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] diff --git a/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py b/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py index 26a0331e1e5767..6486107c9d8955 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py @@ -34,7 +34,9 @@ logger = logging.getLogger(__name__) GlossaryNodeInterface = TypeVar( - "GlossaryNodeInterface", "GlossaryNodeConfig", "BusinessGlossaryConfig" + "GlossaryNodeInterface", + "GlossaryNodeConfig", + "BusinessGlossaryConfig", ) @@ -98,7 +100,7 @@ class DefaultConfig(ConfigModel): class BusinessGlossarySourceConfig(ConfigModel): file: Union[str, pathlib.Path] = Field( - description="File path or URL to business glossary file to ingest." + description="File path or URL to business glossary file to ingest.", ) enable_auto_id: bool = Field( description="Generate guid urns instead of a plaintext path urn with the node/term's hierarchy.", @@ -133,11 +135,13 @@ def create_id(path: List[str], default_id: Optional[str], enable_auto_id: bool) def make_glossary_node_urn( - path: List[str], default_id: Optional[str], enable_auto_id: bool + path: List[str], + default_id: Optional[str], + enable_auto_id: bool, ) -> str: if default_id is not None and default_id.startswith("urn:li:glossaryNode:"): logger.debug( - f"node's default_id({default_id}) is in urn format for path {path}. Returning same as urn" + f"node's default_id({default_id}) is in urn format for path {path}. Returning same as urn", ) return default_id @@ -145,11 +149,13 @@ def make_glossary_node_urn( def make_glossary_term_urn( - path: List[str], default_id: Optional[str], enable_auto_id: bool + path: List[str], + default_id: Optional[str], + enable_auto_id: bool, ) -> str: if default_id is not None and default_id.startswith("urn:li:glossaryTerm:"): logger.debug( - f"term's default_id({default_id}) is in urn format for path {path}. Returning same as urn" + f"term's default_id({default_id}) is in urn format for path {path}. 
Returning same as urn", ) return default_id @@ -234,7 +240,8 @@ def get_mce_from_snapshot(snapshot: Any) -> models.MetadataChangeEventClass: def make_institutional_memory_mcp( - urn: str, knowledge_cards: List[KnowledgeCard] + urn: str, + knowledge_cards: List[KnowledgeCard], ) -> Optional[MetadataChangeProposalWrapper]: elements: List[models.InstitutionalMemoryMetadataClass] = [] @@ -249,7 +256,7 @@ def make_institutional_memory_mcp( actor="urn:li:corpuser:datahub", message="ingestion bot", ), - ) + ), ) if elements: @@ -262,7 +269,8 @@ def make_institutional_memory_mcp( def make_domain_mcp( - term_urn: str, domain_aspect: models.DomainsClass + term_urn: str, + domain_aspect: models.DomainsClass, ) -> MetadataChangeProposalWrapper: return MetadataChangeProposalWrapper(entityUrn=term_urn, aspect=domain_aspect) @@ -297,7 +305,8 @@ def get_mces_from_node( if glossaryNode.knowledge_links is not None: mcp: Optional[MetadataChangeProposalWrapper] = make_institutional_memory_mcp( - node_urn, glossaryNode.knowledge_links + node_urn, + glossaryNode.knowledge_links, ) if mcp is not None: yield mcp @@ -328,17 +337,19 @@ def get_mces_from_node( def get_domain_class( - graph: Optional[DataHubGraph], domains: List[str] + graph: Optional[DataHubGraph], + domains: List[str], ) -> models.DomainsClass: # FIXME: In the ideal case, the domain registry would be an instance variable so that it # preserves its cache across calls to this function. However, the current implementation # requires the full list of domains to be passed in at instantiation time, so we can't # actually do that. domain_registry: DomainRegistry = DomainRegistry( - cached_domains=[k for k in domains], graph=graph + cached_domains=[k for k in domains], + graph=graph, ) domain_class = models.DomainsClass( - domains=[domain_registry.get_domain_urn(domain) for domain in domains] + domains=[domain_registry.get_domain_urn(domain) for domain in domains], ) return domain_class @@ -448,7 +459,8 @@ def get_mces_from_term( if glossaryTerm.domain is not None: yield make_domain_mcp( - term_urn, get_domain_class(ctx.graph, [glossaryTerm.domain]) + term_urn, + get_domain_class(ctx.graph, [glossaryTerm.domain]), ) term_snapshot: models.GlossaryTermSnapshotClass = models.GlossaryTermSnapshotClass( @@ -459,28 +471,35 @@ def get_mces_from_term( if glossaryTerm.knowledge_links: mcp: Optional[MetadataChangeProposalWrapper] = make_institutional_memory_mcp( - term_urn, glossaryTerm.knowledge_links + term_urn, + glossaryTerm.knowledge_links, ) if mcp is not None: yield mcp def materialize_all_node_urns( - glossary: BusinessGlossaryConfig, enable_auto_id: bool + glossary: BusinessGlossaryConfig, + enable_auto_id: bool, ) -> None: """After this runs, all nodes will have an id value that is a valid urn.""" def _process_child_terms( - parent_node: GlossaryNodeInterface, path: List[str] + parent_node: GlossaryNodeInterface, + path: List[str], ) -> None: for term in parent_node.terms or []: term._urn = make_glossary_term_urn( - path + [term.name], term.id, enable_auto_id + path + [term.name], + term.id, + enable_auto_id, ) for node in parent_node.nodes or []: node._urn = make_glossary_node_urn( - path + [node.name], node.id, enable_auto_id + path + [node.name], + node.id, + enable_auto_id, ) _process_child_terms(node, path + [node.name]) @@ -493,7 +512,8 @@ def populate_path_vs_id(glossary: BusinessGlossaryConfig) -> Dict[str, str]: path_vs_id: Dict[str, str] = {} def _process_child_terms( - parent_node: GlossaryNodeInterface, path: List[str] + parent_node: 
GlossaryNodeInterface, + path: List[str], ) -> None: for term in parent_node.terms or []: path_vs_id[".".join(path + [term.name])] = term._urn @@ -526,7 +546,8 @@ def create(cls, config_dict, ctx): @classmethod def load_glossary_config( - cls, file_name: Union[str, pathlib.Path] + cls, + file_name: Union[str, pathlib.Path], ) -> BusinessGlossaryConfig: config = load_config_file(file_name, resolve_env_vars=True) glossary_cfg = BusinessGlossaryConfig.parse_obj(config) @@ -542,8 +563,11 @@ def get_workunits_internal( yield from auto_workunit( get_mces( - glossary_config, path_vs_id, ingestion_config=self.config, ctx=self.ctx - ) + glossary_config, + path_vs_id, + ingestion_config=self.config, + ctx=self.ctx, + ), ) def get_report(self): diff --git a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py index 9f96f837eb9b3a..8f36b810abedb4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py @@ -55,7 +55,7 @@ def type_must_be_supported(cls, v: str) -> str: allowed_types = ["dataset"] if v not in allowed_types: raise ValueError( - f"Type must be one of {allowed_types}, {v} is not yet supported." + f"Type must be one of {allowed_types}, {v} is not yet supported.", ) return v @@ -63,7 +63,7 @@ def type_must_be_supported(cls, v: str) -> str: def validate_name(cls, v: str) -> str: if v.startswith("urn:li:"): raise ValueError( - "Name should not start with urn:li: - use a plain name, not an urn" + "Name should not start with urn:li: - use a plain name, not an urn", ) return v @@ -85,7 +85,7 @@ def upstream_type_must_be_supported(cls, v: str) -> str: ] if v not in allowed_types: raise ValueError( - f"Upstream Type must be one of {allowed_types}, {v} is not yet supported." + f"Upstream Type must be one of {allowed_types}, {v} is not yet supported.", ) return v @@ -97,7 +97,7 @@ def downstream_type_must_be_supported(cls, v: str) -> str: ] if v not in allowed_types: raise ValueError( - f"Downstream Type must be one of {allowed_types}, {v} is not yet supported." + f"Downstream Type must be one of {allowed_types}, {v} is not yet supported.", ) return v @@ -145,7 +145,9 @@ class LineageFileSource(Source): @classmethod def create( - cls, config_dict: Dict[str, Any], ctx: PipelineContext + cls, + config_dict: Dict[str, Any], + ctx: PipelineContext, ) -> "LineageFileSource": config = LineageFileSourceConfig.parse_obj(config_dict) return cls(ctx, config) @@ -189,7 +191,8 @@ def _get_entity_urn(entity_config: EntityConfig) -> Optional[str]: def _get_lineage_mcp( - entity_node: EntityNodeConfig, preserve_upstream: bool + entity_node: EntityNodeConfig, + preserve_upstream: bool, ) -> Optional[MetadataChangeProposalWrapper]: new_upstreams: List[models.UpstreamClass] = [] new_fine_grained_lineages: List[models.FineGrainedLineageClass] = [] @@ -204,7 +207,7 @@ def _get_lineage_mcp( if not entity_urn: logger.warning( f"Entity type: {entity.type} is unsupported. 
" - f"Entity node {entity.name} and its upstream lineages will be skipped" + f"Entity node {entity.name} and its upstream lineages will be skipped", ) return None @@ -234,14 +237,15 @@ def _get_lineage_mcp( dataset=upstream_entity_urn, type=models.DatasetLineageTypeClass.TRANSFORMED, auditStamp=models.AuditStampClass( - time=get_sys_time(), actor="urn:li:corpUser:ingestion" + time=get_sys_time(), + actor="urn:li:corpUser:ingestion", ), - ) + ), ) else: logger.warning( f"Entity type: {upstream_entity.type} is unsupported. " - f"Upstream lineage will be skipped for {upstream_entity.name}->{entity.name}" + f"Upstream lineage will be skipped for {upstream_entity.name}->{entity.name}", ) for fine_grained_lineage in entity_node.fineGrainedLineages or []: new_fine_grained_lineages.append( @@ -252,7 +256,7 @@ def _get_lineage_mcp( downstreamType=fine_grained_lineage.downstreamType, confidenceScore=fine_grained_lineage.confidenceScore, transformOperation=fine_grained_lineage.transformOperation, - ) + ), ) return MetadataChangeProposalWrapper( diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index 02125db83d2582..0baa38f5b94391 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -295,7 +295,8 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str: def _get_base_external_url_from_tracking_uri(self) -> Optional[str]: if isinstance( - self.client.tracking_uri, str + self.client.tracking_uri, + str, ) and self.client.tracking_uri.startswith("http"): return self.client.tracking_uri else: @@ -326,7 +327,7 @@ def _get_global_tags_workunit( TagAssociationClass( tag=self._make_stage_tag_urn(model_version.current_stage), ), - ] + ], ) wu = self._create_workunit( urn=self._make_ml_model_urn(model_version), diff --git a/metadata-ingestion/src/datahub/ingestion/source/mode.py b/metadata-ingestion/src/datahub/ingestion/source/mode.py index bf0a33e423446a..8205dcbfc7fa10 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mode.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mode.py @@ -125,10 +125,12 @@ class ModeAPIConfig(ConfigModel): description="Multiplier for exponential backoff when waiting to retry", ) max_retry_interval: Union[int, float] = Field( - default=10, description="Maximum interval to wait when retrying" + default=10, + description="Maximum interval to wait when retrying", ) max_attempts: int = Field( - default=5, description="Maximum number of attempts to retry before failing" + default=5, + description="Maximum number of attempts to retry before failing", ) timeout: int = Field( default=40, @@ -140,20 +142,22 @@ class ModeConfig(StatefulIngestionConfigBase, DatasetLineageProviderConfigBase): # See https://mode.com/developer/api-reference/authentication/ # for authentication connect_uri: str = Field( - default="https://app.mode.com", description="Mode host URL." + default="https://app.mode.com", + description="Mode host URL.", ) token: str = Field( - description="When creating workspace API key this is the 'Key ID'." + description="When creating workspace API key this is the 'Key ID'.", ) password: pydantic.SecretStr = Field( - description="When creating workspace API key this is the 'Secret'." 
+ description="When creating workspace API key this is the 'Secret'.", ) exclude_restricted: bool = Field( - default=False, description="Exclude restricted collections" + default=False, + description="Exclude restricted collections", ) workspace: str = Field( - description="The Mode workspace name. Find it in Settings > Workspace > Details." + description="The Mode workspace name. Find it in Settings > Workspace > Details.", ) default_schema: str = Field( default="public", @@ -168,7 +172,8 @@ class ModeConfig(StatefulIngestionConfigBase, DatasetLineageProviderConfigBase): ) owner_username_instead_of_email: Optional[bool] = Field( - default=True, description="Use username for owner URN instead of Email" + default=True, + description="Use username for owner URN instead of Email", ) api_options: ModeAPIConfig = Field( default=ModeAPIConfig(), @@ -176,13 +181,15 @@ class ModeConfig(StatefulIngestionConfigBase, DatasetLineageProviderConfigBase): ) ingest_embed_url: bool = Field( - default=True, description="Whether to Ingest embed URL for Reports" + default=True, + description="Whether to Ingest embed URL for Reports", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None tag_measures_and_dimensions: Optional[bool] = Field( - default=True, description="Tag measures and dimensions in the schema" + default=True, + description="Tag measures and dimensions in the schema", ) @validator("connect_uri") @@ -326,7 +333,7 @@ def __init__(self, ctx: PipelineContext, config: ModeConfig): { "Content-Type": "application/json", "Accept": "application/hal+json", - } + }, ) # Test the connection @@ -356,7 +363,9 @@ def _browse_path_dashboard(self, space_token: str) -> List[BrowsePathEntryClass] ] def _browse_path_query( - self, space_token: str, report_info: dict + self, + space_token: str, + report_info: dict, ) -> List[BrowsePathEntryClass]: dashboard_urn = self._dashboard_urn(report_info) return [ @@ -365,7 +374,10 @@ def _browse_path_query( ] def _browse_path_chart( - self, space_token: str, report_info: dict, query_info: dict + self, + space_token: str, + report_info: dict, + query_info: dict, ) -> List[BrowsePathEntryClass]: query_urn = self.get_dataset_urn_from_query(query_info) return [ @@ -388,7 +400,9 @@ def _parse_last_run_at(self, report_info: dict) -> Optional[int]: return last_refreshed_ts def construct_dashboard( - self, space_token: str, report_info: dict + self, + space_token: str, + report_info: dict, ) -> Optional[Tuple[DashboardSnapshot, MetadataChangeProposalWrapper]]: report_token = report_info.get("token", "") # logger.debug(f"Processing report {report_info.get('name', '')}: {report_info}") @@ -419,12 +433,12 @@ def construct_dashboard( # Creator + created ts. creator = self._get_creator( - report_info.get("_links", {}).get("creator", {}).get("href", "") + report_info.get("_links", {}).get("creator", {}).get("href", ""), ) if creator: creator_actor = builder.make_user_urn(creator) created_ts = int( - dp.parse(f"{report_info.get('created_at', 'now')}").timestamp() * 1000 + dp.parse(f"{report_info.get('created_at', 'now')}").timestamp() * 1000, ) last_modified.created = AuditStamp(time=created_ts, actor=creator_actor) @@ -437,7 +451,8 @@ def construct_dashboard( if last_modified_ts_str: modified_ts = int(dp.parse(last_modified_ts_str).timestamp() * 1000) last_modified.lastModified = AuditStamp( - time=modified_ts, actor="urn:li:corpuser:unknown" + time=modified_ts, + actor="urn:li:corpuser:unknown", ) # Last refreshed ts. 
@@ -447,7 +462,7 @@ def construct_dashboard( datasets = [] for imported_dataset_name in report_info.get("imported_datasets", {}): mode_dataset = self._get_request_json( - f"{self.workspace_uri}/reports/{imported_dataset_name.get('token')}" + f"{self.workspace_uri}/reports/{imported_dataset_name.get('token')}", ) dataset_urn = builder.make_dataset_urn_with_platform_instance( self.platform, @@ -475,13 +490,13 @@ def construct_dashboard( paths=[ f"/mode/{self.config.workspace}/" f"{space_name}/" - f"{title if title else report_info.get('id', '')}" - ] + f"{title if title else report_info.get('id', '')}", + ], ) dashboard_snapshot.aspects.append(browse_path) browse_path_v2 = BrowsePathsV2Class( - path=self._browse_path_dashboard(space_token) + path=self._browse_path_dashboard(space_token), ) browse_mcp = MetadataChangeProposalWrapper( entityUrn=dashboard_urn, @@ -491,8 +506,8 @@ def construct_dashboard( # Ownership ownership = self._get_ownership( self._get_creator( - report_info.get("_links", {}).get("creator", {}).get("href", "") - ) + report_info.get("_links", {}).get("creator", {}).get("href", ""), + ), ) if ownership is not None: dashboard_snapshot.aspects.append(ownership) @@ -508,8 +523,8 @@ def _get_ownership(self, user: str) -> Optional[OwnershipClass]: OwnerClass( owner=owner_urn, type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return ownership @@ -542,7 +557,8 @@ def _get_chart_urns(self, report_token: str) -> list: for chart in charts: logger.debug(f"Chart: {chart.get('token')}") chart_urn = builder.make_chart_urn( - self.platform, chart.get("token", "") + self.platform, + chart.get("token", ""), ) chart_urns.append(chart_urn) @@ -555,7 +571,7 @@ def _get_space_name_and_tokens(self) -> dict: payload = self._get_request_json(f"{self.workspace_uri}/spaces?filter=all") spaces = payload.get("_embedded", {}).get("spaces", {}) logger.debug( - f"Got {len(spaces)} spaces from workspace {self.workspace_uri}" + f"Got {len(spaces)} spaces from workspace {self.workspace_uri}", ) for s in spaces: logger.debug(f"Space: {s.get('name')}") @@ -567,7 +583,7 @@ def _get_space_name_and_tokens(self) -> dict: s.get("restricted") or s.get("default_access_level") == "restricted" ): logging.debug( - f"Skipping space {space_name} due to exclude restricted" + f"Skipping space {space_name} due to exclude restricted", ) continue if not self.config.space_pattern.allowed(space_name): @@ -627,7 +643,9 @@ def _get_chart_type(self, token: str, display_type: str) -> Optional[str]: return chart_type def construct_chart_custom_properties( - self, chart_detail: dict, chart_type: str + self, + chart_detail: dict, + chart_type: str, ) -> Dict: custom_properties = { "ChartType": chart_type, @@ -643,7 +661,7 @@ def construct_chart_custom_properties( { "Columns": str_columns, "Filters": filters[1:-1] if len(filters) else "", - } + }, ) elif chart_type == "pivotTable": @@ -659,12 +677,12 @@ def construct_chart_custom_properties( "Rows": ", ".join(rows) if len(rows) else "", "Metrics": ", ".join(values) if len(values) else "", "Filters": ", ".join(filters) if len(filters) else "", - } + }, ) # list filters in their own row for filter in filters: custom_properties[f"Filter: {filter}"] = ", ".join( - pivot_table.get("filterValues", {}).get(filter, "") + pivot_table.get("filterValues", {}).get(filter, ""), ) # Chart else: @@ -683,7 +701,7 @@ def construct_chart_custom_properties( "Y2": y2[0].get("formula", "") if len(y2) else "", "Metrics": value[0].get("formula", "") if len(value) else "", "Filters": 
filters[0].get("formula", "") if len(filters) else "", - } + }, ) return custom_properties @@ -736,7 +754,8 @@ def _get_data_sources(self) -> List[dict]: @lru_cache(maxsize=None) def _get_platform_and_dbname( - self, data_source_id: int + self, + data_source_id: int, ) -> Union[Tuple[str, str], Tuple[None, None]]: data_sources = self._get_data_sources() @@ -751,7 +770,8 @@ def _get_platform_and_dbname( for data_source in data_sources: if data_source.get("id", -1) == data_source_id: platform = self._get_datahub_friendly_platform( - data_source.get("adapter", ""), data_source.get("name", "") + data_source.get("adapter", ""), + data_source.get("name", ""), ) database = data_source.get("database", "") # This is hacky but on bigquery we want to change the database if its default @@ -772,17 +792,19 @@ def _replace_definitions(self, raw_query: str) -> str: definitions = re.findall(r"({{(?:\s+)?@[^}{]+}})", raw_query) for definition_variable in definitions: definition_name, definition_alias = self._parse_definition_name( - definition_variable + definition_variable, ) definition_query = self._get_definition(definition_name) # if unable to retrieve definition, then replace the {{}} so that it doesn't get picked up again in recursive call if definition_query is not None: query = query.replace( - definition_variable, f"({definition_query}) as {definition_alias}" + definition_variable, + f"({definition_query}) as {definition_alias}", ) else: query = query.replace( - definition_variable, f"{definition_name} as {definition_alias}" + definition_variable, + f"{definition_name} as {definition_alias}", ) query = self._replace_definitions(query) query = query.replace("\\n", "\n") @@ -797,7 +819,8 @@ def _parse_definition_name(self, definition_variable: str) -> Tuple[str, str]: if len(name_match): name = name_match[0][1:] alias_match = re.findall( - r"as\s+\S+\w+", definition_variable + r"as\s+\S+\w+", + definition_variable, ) # i.e ['as alias_name'] if len(alias_match): alias_match = alias_match[0].split(" ") @@ -809,7 +832,7 @@ def _parse_definition_name(self, definition_variable: str) -> Tuple[str, str]: def _get_definition(self, definition_name): try: definition_json = self._get_request_json( - f"{self.workspace_uri}/definitions" + f"{self.workspace_uri}/definitions", ) definitions = definition_json.get("_embedded", {}).get("definitions", []) for definition in definitions: @@ -980,9 +1003,9 @@ def construct_query_or_dataset( [ BIAssetSubTypes.MODE_DATASET if is_mode_dataset - else BIAssetSubTypes.MODE_QUERY + else BIAssetSubTypes.MODE_QUERY, ] - ) + ), ) yield ( MetadataChangeProposalWrapper( @@ -996,7 +1019,7 @@ def construct_query_or_dataset( aspect=BrowsePathsV2Class( path=self._browse_path_dashboard(space_token) if is_mode_dataset - else self._browse_path_query(space_token, report_info) + else self._browse_path_query(space_token, report_info), ), ).as_workunit() @@ -1053,13 +1076,13 @@ def construct_query_or_dataset( self.report.num_sql_parser_table_error += 1 self.report.num_sql_parser_failures += 1 logger.info( - f"Failed to parse compiled code for report: {report_token} query: {query_token} {parsed_query_object.debug_info.error} the query was [{query_to_parse}]" + f"Failed to parse compiled code for report: {report_token} query: {query_token} {parsed_query_object.debug_info.error} the query was [{query_to_parse}]", ) elif parsed_query_object.debug_info.column_error: self.report.num_sql_parser_column_error += 1 self.report.num_sql_parser_failures += 1 logger.info( - f"Failed to generate CLL for 
report: {report_token} query: {query_token}: {parsed_query_object.debug_info.column_error} the query was [{query_to_parse}]" + f"Failed to generate CLL for report: {report_token} query: {query_token}: {parsed_query_object.debug_info.column_error} the query was [{query_to_parse}]", ) else: self.report.num_sql_parser_success += 1 @@ -1085,13 +1108,15 @@ def construct_query_or_dataset( ) yield from self.get_upstream_lineage_for_parsed_sql( - query_urn, query_data, parsed_query_object + query_urn, + query_data, + parsed_query_object, ) operation = OperationClass( operationType=OperationTypeClass.UPDATE, lastUpdatedTimestamp=int( - dp.parse(query_data.get("updated_at", "now")).timestamp() * 1000 + dp.parse(query_data.get("updated_at", "now")).timestamp() * 1000, ), timestampMillis=int(datetime.now(tz=timezone.utc).timestamp() * 1000), ) @@ -1102,17 +1127,17 @@ def construct_query_or_dataset( ).as_workunit() creator = self._get_creator( - query_data.get("_links", {}).get("creator", {}).get("href", "") + query_data.get("_links", {}).get("creator", {}).get("href", ""), ) modified_actor = builder.make_user_urn( - creator if creator is not None else "unknown" + creator if creator is not None else "unknown", ) created_ts = int( - dp.parse(query_data.get("created_at", "now")).timestamp() * 1000 + dp.parse(query_data.get("created_at", "now")).timestamp() * 1000, ) modified_ts = int( - dp.parse(query_data.get("updated_at", "now")).timestamp() * 1000 + dp.parse(query_data.get("updated_at", "now")).timestamp() * 1000, ) query_instance_urn = self.get_query_instance_urn_from_query(query_data) @@ -1134,18 +1159,21 @@ def construct_query_or_dataset( ).as_workunit() def get_upstream_lineage_for_parsed_sql( - self, query_urn: str, query_data: dict, parsed_query_object: SqlParsingResult + self, + query_urn: str, + query_data: dict, + parsed_query_object: SqlParsingResult, ) -> List[MetadataWorkUnit]: wu = [] if parsed_query_object is None: logger.info( - f"Failed to extract column level lineage from datasource {query_urn}" + f"Failed to extract column level lineage from datasource {query_urn}", ) return [] if parsed_query_object.debug_info.error: logger.info( - f"Failed to extract column level lineage from datasource {query_urn}: {parsed_query_object.debug_info.error}" + f"Failed to extract column level lineage from datasource {query_urn}: {parsed_query_object.debug_info.error}", ) return [] @@ -1181,7 +1209,7 @@ def get_upstream_lineage_for_parsed_sql( downstreams=downstream, upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET, upstreams=upstreams, - ) + ), ) upstream_lineage = UpstreamLineageClass( @@ -1200,13 +1228,15 @@ def get_upstream_lineage_for_parsed_sql( MetadataChangeProposalWrapper( entityUrn=query_urn, aspect=upstream_lineage, - ).as_workunit() + ).as_workunit(), ) return wu def get_formula_columns( - self, node: Dict, columns: Optional[Set[str]] = None + self, + node: Dict, + columns: Optional[Set[str]] = None, ) -> Set[str]: columns = columns if columns is not None else set() if isinstance(node, dict): @@ -1273,15 +1303,15 @@ def construct_chart_from_api_data( last_modified = ChangeAuditStamps() creator = self._get_creator( - chart_data.get("_links", {}).get("creator", {}).get("href", "") + chart_data.get("_links", {}).get("creator", {}).get("href", ""), ) if creator is not None: modified_actor = builder.make_user_urn(creator) created_ts = int( - dp.parse(chart_data.get("created_at", "now")).timestamp() * 1000 + dp.parse(chart_data.get("created_at", "now")).timestamp() * 1000, ) modified_ts 
= int( - dp.parse(chart_data.get("updated_at", "now")).timestamp() * 1000 + dp.parse(chart_data.get("updated_at", "now")).timestamp() * 1000, ) last_modified = ChangeAuditStamps( created=AuditStamp(time=created_ts, actor=modified_actor), @@ -1298,7 +1328,8 @@ def construct_chart_from_api_data( ) mode_chart_type = chart_detail.get("chartType", "") or chart_detail.get( - "selectedChart", "" + "selectedChart", + "", ) chart_type = self._get_chart_type(chart_data.get("token", ""), mode_chart_type) description = ( @@ -1315,7 +1346,8 @@ def construct_chart_from_api_data( # create datasource urn custom_properties = self.construct_chart_custom_properties( - chart_detail, mode_chart_type + chart_detail, + mode_chart_type, ) query_urn = self.get_dataset_urn_from_query(query) @@ -1369,8 +1401,8 @@ def construct_chart_from_api_data( # Ownership ownership = self._get_ownership( self._get_creator( - chart_data.get("_links", {}).get("creator", {}).get("href", "") - ) + chart_data.get("_links", {}).get("creator", {}).get("href", ""), + ), ) if ownership is not None: chart_snapshot.aspects.append(ownership) @@ -1383,7 +1415,7 @@ def _get_reports(self, space_token: str) -> List[dict]: reports = [] try: reports_json = self._get_request_json( - f"{self.workspace_uri}/spaces/{space_token}/reports" + f"{self.workspace_uri}/spaces/{space_token}/reports", ) reports = reports_json.get("_embedded", {}).get("reports", {}) except ModeRequestError as e: @@ -1417,7 +1449,7 @@ def _get_queries(self, report_token: str) -> list: queries = [] try: queries_json = self._get_request_json( - f"{self.workspace_uri}/reports/{report_token}/queries" + f"{self.workspace_uri}/reports/{report_token}/queries", ) queries = queries_json.get("_embedded", {}).get("queries", {}) except ModeRequestError as e: @@ -1430,11 +1462,14 @@ def _get_queries(self, report_token: str) -> list: @lru_cache(maxsize=None) def _get_last_query_run( - self, report_token: str, report_run_id: str, query_run_id: str + self, + report_token: str, + report_run_id: str, + query_run_id: str, ) -> Dict: try: queries_json = self._get_request_json( - f"{self.workspace_uri}/reports/{report_token}/runs/{report_run_id}/query_runs{query_run_id}" + f"{self.workspace_uri}/reports/{report_token}/runs/{report_run_id}/query_runs{query_run_id}", ) queries = queries_json.get("_embedded", {}).get("queries", {}) except ModeRequestError as e: @@ -1452,7 +1487,7 @@ def _get_charts(self, report_token: str, query_token: str) -> list: try: charts_json = self._get_request_json( f"{self.workspace_uri}/reports/{report_token}" - f"/queries/{query_token}/charts" + f"/queries/{query_token}/charts", ) charts = charts_json.get("_embedded", {}).get("charts", {}) except ModeRequestError as e: @@ -1479,7 +1514,8 @@ def _get_request_json(self, url: str) -> Dict: def get_request(): try: response = self.session.get( - url, timeout=self.config.api_options.timeout + url, + timeout=self.config.api_options.timeout, ) if response.status_code == 204: # No content, don't parse json return {} @@ -1499,7 +1535,8 @@ def get_request(): @staticmethod def create_embed_aspect_mcp( - entity_urn: str, embed_url: str + entity_urn: str, + embed_url: str, ) -> MetadataChangeProposalWrapper: return MetadataChangeProposalWrapper( entityUrn=entity_urn, @@ -1510,7 +1547,9 @@ def gen_space_key(self, space_token: str) -> SpaceKey: return SpaceKey(platform=self.platform, space_token=space_token) def construct_space_container( - self, space_token: str, space_name: str + self, + space_token: str, + space_name: str, ) -> 
Iterable[MetadataWorkUnit]: key = self.gen_space_key(space_token) yield from gen_containers( @@ -1535,10 +1574,11 @@ def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: reports = self._get_reports(space_token) for report in reports: logger.debug( - f"Report: name: {report.get('name')} token: {report.get('token')}" + f"Report: name: {report.get('name')} token: {report.get('token')}", ) dashboard_tuple_from_report = self.construct_dashboard( - space_token=space_token, report_info=report + space_token=space_token, + report_info=report, ) if dashboard_tuple_from_report is None: @@ -1549,7 +1589,7 @@ def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: ) = dashboard_tuple_from_report mce = MetadataChangeEvent( - proposedSnapshot=dashboard_snapshot_from_report + proposedSnapshot=dashboard_snapshot_from_report, ) mcpw = MetadataChangeProposalWrapper( @@ -1600,7 +1640,8 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]: chart_fields: Dict[str, SchemaFieldClass] = {} for wu in query_mcps: if isinstance( - wu.metadata, MetadataChangeProposalWrapper + wu.metadata, + MetadataChangeProposalWrapper, ) and isinstance(wu.metadata.aspect, SchemaMetadataClass): schema_metadata = wu.metadata.aspect for field in schema_metadata.fields: @@ -1651,7 +1692,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] diff --git a/metadata-ingestion/src/datahub/ingestion/source/mongodb.py b/metadata-ingestion/src/datahub/ingestion/source/mongodb.py index ad8487c1a759ec..ae10fa088a31b4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mongodb.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mongodb.py @@ -85,23 +85,29 @@ class HostingEnvironment(Enum): class MongoDBConfig( - PlatformInstanceConfigMixin, EnvConfigMixin, StatefulIngestionConfigBase + PlatformInstanceConfigMixin, + EnvConfigMixin, + StatefulIngestionConfigBase, ): # See the MongoDB authentication docs for details and examples. # https://pymongo.readthedocs.io/en/stable/examples/authentication.html connect_uri: str = Field( - default="mongodb://localhost", description="MongoDB connection URI." + default="mongodb://localhost", + description="MongoDB connection URI.", ) username: Optional[str] = Field(default=None, description="MongoDB username.") password: Optional[str] = Field(default=None, description="MongoDB password.") authMechanism: Optional[str] = Field( - default=None, description="MongoDB authentication mechanism." + default=None, + description="MongoDB authentication mechanism.", ) options: dict = Field( - default={}, description="Additional options to pass to `pymongo.MongoClient()`." + default={}, + description="Additional options to pass to `pymongo.MongoClient()`.", ) enableSchemaInference: bool = Field( - default=True, description="Whether to infer schemas. " + default=True, + description="Whether to infer schemas. ", ) schemaSamplingSize: Optional[PositiveInt] = Field( default=1000, @@ -112,7 +118,8 @@ class MongoDBConfig( description="If documents for schema inference should be randomly selected. If `False`, documents will be selected from start.", ) maxSchemaSize: Optional[PositiveInt] = Field( - default=300, description="Maximum number of fields to include in the schema." 
+ default=300, + description="Maximum number of fields to include in the schema.", ) # mongodb only supports 16MB as max size for documents. However, if we try to retrieve a larger document it # errors out with "16793600" as the maximum size supported. @@ -306,12 +313,16 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] def get_pymongo_type_string( - self, field_type: Union[Type, str], collection_name: str + self, + field_type: Union[Type, str], + collection_name: str, ) -> str: """ Return Mongo type string from a Python type @@ -336,7 +347,9 @@ def get_pymongo_type_string( return type_string def get_field_type( - self, field_type: Union[Type, str], collection_name: str + self, + field_type: Union[Type, str], + collection_name: str, ) -> SchemaFieldDataType: """ Maps types encountered in PyMongo to corresponding schema types. @@ -405,7 +418,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: data_platform_instance = DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ) @@ -460,7 +474,8 @@ def _infer_schema_metadata( max_schema_size = self.config.maxSchemaSize collection_schema_size = len(collection_schema.values()) collection_fields: Union[ - List[SchemaDescription], ValuesView[SchemaDescription] + List[SchemaDescription], + ValuesView[SchemaDescription], ] = collection_schema.values() assert max_schema_size is not None if collection_schema_size > max_schema_size: @@ -488,7 +503,8 @@ def _infer_schema_metadata( field = SchemaField( fieldPath=schema_field["delimited_name"], nativeDataType=self.get_pymongo_type_string( - schema_field["type"], dataset_urn.name + schema_field["type"], + dataset_urn.name, ), type=self.get_field_type(schema_field["type"], dataset_urn.name), description=None, @@ -512,12 +528,12 @@ def is_server_version_gte_4_4(self) -> bool: server_version = self.mongo_client.server_info().get("versionArray") if server_version: logger.info( - f"Mongodb version for current connection - {server_version}" + f"Mongodb version for current connection - {server_version}", ) server_version_str_list = [str(i) for i in server_version] required_version = "4.4" return version.parse( - ".".join(server_version_str_list) + ".".join(server_version_str_list), ) >= version.parse(required_version) except Exception as e: logger.error("Error while getting version of the mongodb server %s", e) diff --git a/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py b/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py index 8cdd4b17733e01..b99e092404aa95 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py @@ -88,7 +88,10 @@ def get_field_type(self, attribute_type: Union[type, str]) -> SchemaFieldDataTyp return SchemaFieldDataType(type=type_class()) def get_schema_field_class( - self, col_name: str, col_type: str, **kwargs: Any + self, + col_name: str, + col_type: str, + **kwargs: Any, ) -> SchemaFieldClass: if kwargs["obj_type"] == self.NODE and col_type == self.RELATIONSHIP: col_type = self.NODE @@ -102,7 +105,8 @@ def get_schema_field_class( if col_type in (self.NODE, self.RELATIONSHIP) else 
col_type, lastModified=AuditStampClass( - time=round(time.time() * 1000), actor="urn:li:corpuser:ingestion" + time=round(time.time() * 1000), + actor="urn:li:corpuser:ingestion", ), ) @@ -118,13 +122,18 @@ def add_properties( ) return MetadataChangeProposalWrapper( entityUrn=make_dataset_urn( - platform=self.PLATFORM, name=dataset, env=self.config.env + platform=self.PLATFORM, + name=dataset, + env=self.config.env, ), aspect=dataset_properties, ) def generate_neo4j_object( - self, dataset: str, columns: list, obj_type: Optional[str] = None + self, + dataset: str, + columns: list, + obj_type: Optional[str] = None, ) -> MetadataChangeProposalWrapper: try: fields = [ @@ -134,7 +143,9 @@ def generate_neo4j_object( ] mcp = MetadataChangeProposalWrapper( entityUrn=make_dataset_urn( - platform=self.PLATFORM, name=dataset, env=self.config.env + platform=self.PLATFORM, + name=dataset, + env=self.config.env, ), aspect=SchemaMetadataClass( schemaName=dataset, @@ -157,7 +168,8 @@ def generate_neo4j_object( def get_neo4j_metadata(self, query: str) -> pd.DataFrame: driver = GraphDatabase.driver( - self.config.uri, auth=(self.config.username, self.config.password) + self.config.uri, + auth=(self.config.username, self.config.password), ) """ This process retrieves the metadata for Neo4j objects using an APOC query, which returns a dictionary @@ -204,19 +216,20 @@ def process_nodes(self, data: list) -> pd.DataFrame: columns=["key", "value"], ) node_df["obj_type"] = node_df["value"].apply( - lambda record: self.get_obj_type(record) + lambda record: self.get_obj_type(record), ) node_df["relationships"] = node_df["value"].apply( - lambda record: self.get_relationships(record) + lambda record: self.get_relationships(record), ) node_df["properties"] = node_df["value"].apply( - lambda record: self.get_properties(record) + lambda record: self.get_properties(record), ) node_df["property_data_types"] = node_df["properties"].apply( - lambda record: self.get_property_data_types(record) + lambda record: self.get_property_data_types(record), ) node_df["description"] = node_df.apply( - lambda record: self.get_node_description(record, node_df), axis=1 + lambda record: self.get_node_description(record, node_df), + axis=1, ) return node_df @@ -226,16 +239,17 @@ def process_relationships(self, data: list, node_df: pd.DataFrame) -> pd.DataFra ] rel_df = pd.DataFrame(rels, columns=["key", "value"]) rel_df["obj_type"] = rel_df["value"].apply( - lambda record: self.get_obj_type(record) + lambda record: self.get_obj_type(record), ) rel_df["properties"] = rel_df["value"].apply( - lambda record: self.get_properties(record) + lambda record: self.get_properties(record), ) rel_df["property_data_types"] = rel_df["properties"].apply( - lambda record: self.get_property_data_types(record) + lambda record: self.get_property_data_types(record), ) rel_df["description"] = rel_df.apply( - lambda record: self.get_rel_descriptions(record, node_df), axis=1 + lambda record: self.get_rel_descriptions(record, node_df), + axis=1, ) return rel_df @@ -251,7 +265,7 @@ def get_rel_descriptions(self, record: dict, df: pd.DataFrame) -> str: if props["direction"] == "in": for prop in props["labels"]: descriptions.append( - f"({row['key']})-[{record['key']}]->({prop})" + f"({row['key']})-[{record['key']}]->({prop})", ) return "\n".join(descriptions) @@ -264,11 +278,11 @@ def get_node_description(self, record: dict, df: pd.DataFrame) -> str: for node in set(props["labels"]): if direction == "in": descriptions.append( - 
f"({row['key']})<-[{relationship}]-({node})" + f"({row['key']})<-[{relationship}]-({node})", ) elif direction == "out": descriptions.append( - f"({row['key']})-[{relationship}]->({node})" + f"({row['key']})-[{relationship}]->({node})", ) return "\n".join(descriptions) @@ -284,7 +298,7 @@ def get_relationships(self, record: dict) -> dict: def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: df = self.get_neo4j_metadata( - "CALL apoc.meta.schema() YIELD value UNWIND keys(value) AS key RETURN key, value[key] AS value;" + "CALL apoc.meta.schema() YIELD value UNWIND keys(value) AS key RETURN key, value[key] AS value;", ) for _, row in df.iterrows(): try: @@ -309,8 +323,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: typeNames=[ DatasetSubTypes.NEO4J_NODE if row["obj_type"] == self.NODE - else DatasetSubTypes.NEO4J_RELATIONSHIP - ] + else DatasetSubTypes.NEO4J_RELATIONSHIP, + ], ), ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/nifi.py b/metadata-ingestion/src/datahub/ingestion/source/nifi.py index 52b1386e21d85a..65cde7d02ec568 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/nifi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/nifi.py @@ -59,7 +59,9 @@ class SSLAdapter(HTTPAdapter): def __init__(self, certfile, keyfile, password=None): self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.context.load_cert_chain( - certfile=certfile, keyfile=keyfile, password=password + certfile=certfile, + keyfile=keyfile, + password=password, ) super().__init__() @@ -82,7 +84,7 @@ class ProcessGroupKey(ContainerKey): class NifiSourceConfig(EnvConfigMixin): site_url: str = Field( - description="URL for Nifi, ending with /nifi/. e.g. https://mynifi.domain/nifi/" + description="URL for Nifi, ending with /nifi/. e.g. 
https://mynifi.domain/nifi/", ) auth: NifiAuthType = Field( @@ -111,10 +113,12 @@ class NifiSourceConfig(EnvConfigMixin): # Required to be set if auth is of type SINGLE_USER username: Optional[str] = Field( - default=None, description='Nifi username, must be set for auth = "SINGLE_USER"' + default=None, + description='Nifi username, must be set for auth = "SINGLE_USER"', ) password: Optional[str] = Field( - default=None, description='Nifi password, must be set for auth = "SINGLE_USER"' + default=None, + description='Nifi password, must be set for auth = "SINGLE_USER"', ) # Required to be set if auth is of type CLIENT_CERT @@ -123,10 +127,12 @@ class NifiSourceConfig(EnvConfigMixin): description='Path to PEM file containing the public certificates for the user/client identity, must be set for auth = "CLIENT_CERT"', ) client_key_file: Optional[str] = Field( - default=None, description="Path to PEM file containing the client’s secret key" + default=None, + description="Path to PEM file containing the client’s secret key", ) client_key_password: Optional[str] = Field( - default=None, description="The password to decrypt the client_key_file" + default=None, + description="The password to decrypt the client_key_file", ) # Required to be set if nifi server certificate is not signed by @@ -153,17 +159,17 @@ class NifiSourceConfig(EnvConfigMixin): @root_validator(skip_on_failure=True) def validate_auth_params(cla, values): if values.get("auth") is NifiAuthType.CLIENT_CERT and not values.get( - "client_cert_file" + "client_cert_file", ): raise ValueError( - "Config `client_cert_file` is required for CLIENT_CERT auth" + "Config `client_cert_file` is required for CLIENT_CERT auth", ) elif values.get("auth") in ( NifiAuthType.SINGLE_USER, NifiAuthType.BASIC_AUTH, ) and (not values.get("username") or not values.get("password")): raise ValueError( - f"Config `username` and `password` is required for {values.get('auth').value} auth" + f"Config `username` and `password` is required for {values.get('auth').value} auth", ) return values @@ -215,7 +221,7 @@ def add_connection(self, from_component: str, to_component: str) -> None: if outgoing_duplicated or incoming_duplicated: logger.warning( - f"Somehow we attempted to add a connection between 2 components which already existed! Duplicated incoming: {incoming_duplicated}, duplicated outgoing: {outgoing_duplicated}. Connection from component: {from_component} to component: {to_component}" + f"Somehow we attempted to add a connection between 2 components which already existed! Duplicated incoming: {incoming_duplicated}, duplicated outgoing: {outgoing_duplicated}. 
Connection from component: {from_component} to component: {to_component}", ) def remove_connection(self, from_component: str, to_component: str) -> None: @@ -233,11 +239,11 @@ def delete_component(self, component: str) -> None: logger.debug(f"Deleting component with id: {component}") incoming = self._incoming[component] logger.debug( - f"Recognized {len(incoming)} incoming connections to the component" + f"Recognized {len(incoming)} incoming connections to the component", ) outgoing = self._outgoing[component] logger.debug( - f"Recognized {len(outgoing)} outgoing connections from the component" + f"Recognized {len(outgoing)} outgoing connections from the component", ) for i in incoming: @@ -252,7 +258,7 @@ def delete_component(self, component: str) -> None: added_connections_cnt = len(incoming) * len(outgoing) deleted_connections_cnt = len(incoming) + len(outgoing) logger.debug( - f"Deleted {deleted_connections_cnt} connections and added {added_connections_cnt}" + f"Deleted {deleted_connections_cnt} connections and added {added_connections_cnt}", ) del self._outgoing[component] @@ -320,7 +326,8 @@ class NifiProcessorProvenanceEventAnalyzer: def __init__(self) -> None: # Map of Nifi processor type to the provenance event analyzer to find lineage self.provenance_event_to_lineage_map: Dict[ - str, Callable[[Dict], ExternalDataset] + str, + Callable[[Dict], ExternalDataset], ] = { NifiProcessorType.ListS3: self.process_s3_provenance_event, NifiProcessorType.FetchS3Object: self.process_s3_provenance_event, @@ -338,7 +345,7 @@ def process_s3_provenance_event(self, event): s3_key = get_attribute_value(attributes, "s3.key") if not s3_key: logger.debug( - "s3.key not present in the list of attributes, trying to use filename attribute instead" + "s3.key not present in the list of attributes, trying to use filename attribute instead", ) s3_key = get_attribute_value(attributes, "filename") @@ -436,7 +443,7 @@ class NifiFlow: components: Dict[str, NifiComponent] = field(default_factory=dict) remotely_accessible_ports: Dict[str, NifiComponent] = field(default_factory=dict) connections: BidirectionalComponentGraph = field( - default_factory=BidirectionalComponentGraph + default_factory=BidirectionalComponentGraph, ) processGroups: Dict[str, NifiProcessGroup] = field(default_factory=dict) remoteProcessGroups: Dict[str, NifiRemoteProcessGroup] = field(default_factory=dict) @@ -492,7 +499,7 @@ def update_flow(self, pg_flow_dto: Dict, recursion_level: int = 0) -> None: # n Update self.nifi_flow with contents of the input process group `pg_flow_dto` """ logger.debug( - f"Updating flow with pg_flow_dto {pg_flow_dto.get('breadcrumb', {}).get('breadcrumb', {}).get('id')}, recursion level: {recursion_level}" + f"Updating flow with pg_flow_dto {pg_flow_dto.get('breadcrumb', {}).get('breadcrumb', {}).get('id')}, recursion level: {recursion_level}", ) breadcrumb_dto = pg_flow_dto.get("breadcrumb", {}).get("breadcrumb", {}) nifi_pg = NifiProcessGroup( @@ -539,7 +546,8 @@ def update_flow(self, pg_flow_dto: Dict, recursion_level: int = 0) -> None: # n # Exclude self - recursive relationships if connection.get("sourceId") != connection.get("destinationId"): self.nifi_flow.connections.add_connection( - connection.get("sourceId"), connection.get("destinationId") + connection.get("sourceId"), + connection.get("destinationId"), ) logger.debug(f"Processing {len(flow_dto.get('inputPorts', []))} inputPorts") @@ -599,7 +607,7 @@ def update_flow(self, pg_flow_dto: Dict, recursion_level: int = 0) -> None: # n 
logger.debug(f"Adding report port {component.get('id')}") logger.debug( - f"Processing {len(flow_dto.get('remoteProcessGroups', []))} remoteProcessGroups" + f"Processing {len(flow_dto.get('remoteProcessGroups', []))} remoteProcessGroups", ) for rpg in flow_dto.get("remoteProcessGroups", []): rpg_component = rpg.get("component", {}) @@ -647,14 +655,14 @@ def update_flow(self, pg_flow_dto: Dict, recursion_level: int = 0) -> None: # n self.nifi_flow.remoteProcessGroups[nifi_rpg.id] = nifi_rpg logger.debug( - f"Processing {len(flow_dto.get('processGroups', []))} processGroups" + f"Processing {len(flow_dto.get('processGroups', []))} processGroups", ) for pg in flow_dto.get("processGroups", []): logger.debug( - f"Retrieving process group: {pg.get('id')} while updating flow for {pg_flow_dto.get('breadcrumb', {}).get('breadcrumb', {}).get('id')}" + f"Retrieving process group: {pg.get('id')} while updating flow for {pg_flow_dto.get('breadcrumb', {}).get('breadcrumb', {}).get('id')}", ) pg_response = self.session.get( - url=urljoin(self.rest_api_base_url, PG_ENDPOINT) + pg.get("id") + url=urljoin(self.rest_api_base_url, PG_ENDPOINT) + pg.get("id"), ) if not pg_response.ok: @@ -672,17 +680,17 @@ def update_flow_keep_only_ingress_egress(self): components_to_del: List[NifiComponent] = [] components = self.nifi_flow.components.values() logger.debug( - f"Processing {len(components)} components for keep only ingress/egress" + f"Processing {len(components)} components for keep only ingress/egress", ) logger.debug( - f"All the connections recognized: {len(self.nifi_flow.connections)}" + f"All the connections recognized: {len(self.nifi_flow.connections)}", ) for index, component in enumerate(components, start=1): logger.debug( - f"Processing {index}th component for ingress/egress pruning. Component id: {component.id}, name: {component.name}, type: {component.type}" + f"Processing {index}th component for ingress/egress pruning. Component id: {component.id}, name: {component.name}, type: {component.type}", ) logger.debug( - f"Current amount of connections: {len(self.nifi_flow.connections)}" + f"Current amount of connections: {len(self.nifi_flow.connections)}", ) if ( component.nifi_type is NifiType.PROCESSOR @@ -698,7 +706,7 @@ def update_flow_keep_only_ingress_egress(self): for component in components_to_del: if component.nifi_type is NifiType.PROCESSOR and component.name.startswith( - ("Get", "List", "Fetch", "Put") + ("Get", "List", "Fetch", "Put"), ): self.report.warning( f"Dropping NiFi Processor of type {component.type}, id {component.id}, name {component.name} from lineage view. \ @@ -708,7 +716,7 @@ def update_flow_keep_only_ingress_egress(self): ) else: logger.debug( - f"Dropping NiFi Component of type {component.type}, id {component.id}, name {component.name} from lineage view." 
+ f"Dropping NiFi Component of type {component.type}, id {component.id}, name {component.name} from lineage view.", ) del self.nifi_flow.components[component.id] @@ -716,7 +724,7 @@ def update_flow_keep_only_ingress_egress(self): def create_nifi_flow(self): logger.debug(f"Retrieving NIFI info from {ABOUT_ENDPOINT}") about_response = self.session.get( - url=urljoin(self.rest_api_base_url, ABOUT_ENDPOINT) + url=urljoin(self.rest_api_base_url, ABOUT_ENDPOINT), ) nifi_version: Optional[str] = None if about_response.ok: @@ -724,14 +732,14 @@ def create_nifi_flow(self): nifi_version = about_response.json().get("about", {}).get("version") except Exception as e: logger.error( - f"Unable to parse about response from Nifi: {about_response} due to {e}" + f"Unable to parse about response from Nifi: {about_response} due to {e}", ) else: logger.warning("Failed to fetch version for nifi") logger.debug(f"Retrieved nifi version: {nifi_version}") logger.debug(f"Retrieving cluster info from {CLUSTER_ENDPOINT}") cluster_response = self.session.get( - url=urljoin(self.rest_api_base_url, CLUSTER_ENDPOINT) + url=urljoin(self.rest_api_base_url, CLUSTER_ENDPOINT), ) clustered: Optional[bool] = None if cluster_response.ok: @@ -743,7 +751,7 @@ def create_nifi_flow(self): logger.warning("Failed to fetch cluster summary for flow") logger.debug("Retrieving ROOT Process Group") pg_response = self.session.get( - url=urljoin(self.rest_api_base_url, PG_ENDPOINT) + "root" + url=urljoin(self.rest_api_base_url, PG_ENDPOINT) + "root", ) pg_response.raise_for_status() @@ -771,11 +779,14 @@ def fetch_provenance_events( ) -> Iterable[Dict]: logger.debug( f"Fetching {eventType} provenance events for {processor.id}\ - of processor type {processor.type}, Start date: {startDate}, End date: {endDate}" + of processor type {processor.type}, Start date: {startDate}, End date: {endDate}", ) provenance_response = self.submit_provenance_query( - processor, eventType, startDate, endDate + processor, + eventType, + startDate, + endDate, ) if provenance_response.ok: @@ -789,7 +800,7 @@ def fetch_provenance_events( attempts = 5 # wait for at most 5 attempts 5*1= 5 seconds while (not provenance.get("finished", False)) and attempts > 0: logger.warning( - f"Provenance query not completed, attempts left : {attempts}" + f"Provenance query not completed, attempts left : {attempts}", ) # wait until the uri returns percentcomplete 100 time.sleep(1) @@ -824,7 +835,10 @@ def fetch_provenance_events( if total != str(totalCount): logger.debug("Trying to retrieve more events for the same processor") yield from self.fetch_provenance_events( - processor, eventType, startDate, oldest_event_time + processor, + eventType, + startDate, + oldest_event_time, ) else: self.report.warning( @@ -837,7 +851,7 @@ def fetch_provenance_events( def submit_provenance_query(self, processor, eventType, startDate, endDate): older_version: bool = self.nifi_flow.version is not None and version.parse( - self.nifi_flow.version + self.nifi_flow.version, ) < version.parse("1.13.0") if older_version: @@ -864,9 +878,9 @@ def submit_provenance_query(self, processor, eventType, startDate, endDate): if endDate else None ), - } - } - } + }, + }, + }, ) logger.debug(payload) self.session.headers.update({}) @@ -882,7 +896,7 @@ def submit_provenance_query(self, processor, eventType, startDate, endDate): self.session.headers.update( { "Content-Type": "application/x-www-form-urlencoded", - } + }, ) return provenance_response @@ -903,12 +917,15 @@ def construct_workunits(self) -> 
Iterable[MetadataWorkUnit]: # noqa: C901 if self.nifi_flow.version is not None: flow_properties["version"] = str(self.nifi_flow.version) yield self.construct_flow_workunits( - flow_urn, flow_name, self.make_external_url(rootpg.id), flow_properties + flow_urn, + flow_name, + self.make_external_url(rootpg.id), + flow_properties, ) for component in self.nifi_flow.components.values(): logger.debug( - f"Beginng construction of workunits for component {component.id} of type {component.type} and name {component.name}" + f"Beginng construction of workunits for component {component.id} of type {component.type} and name {component.name}", ) logger.debug(f"Inlets of the component: {component.inlets.keys()}") logger.debug(f"Outlets of the component: {component.outlets.keys()}") @@ -933,14 +950,14 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 ] } jobProperties["properties"] = json.dumps( - component.config.get("properties") # type: ignore + component.config.get("properties"), # type: ignore ) if component.last_event_time is not None: jobProperties["last_event_time"] = component.last_event_time for dataset in component.inlets.values(): logger.debug( - f"Yielding dataset workunits for {dataset.dataset_urn} (inlet)" + f"Yielding dataset workunits for {dataset.dataset_urn} (inlet)", ) yield from self.construct_dataset_workunits( dataset.platform, @@ -951,7 +968,7 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 for dataset in component.outlets.values(): logger.debug( - f"Yielding dataset workunits for {dataset.dataset_urn} (outlet)" + f"Yielding dataset workunits for {dataset.dataset_urn} (outlet)", ) yield from self.construct_dataset_workunits( dataset.platform, @@ -964,7 +981,9 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 if incoming_from in self.nifi_flow.remotely_accessible_ports.keys(): dataset_name = f"{self.config.site_name}.{self.nifi_flow.remotely_accessible_ports[incoming_from].name}" dataset_urn = builder.make_dataset_urn( - NIFI, dataset_name, self.config.env + NIFI, + dataset_name, + self.config.env, ) component.inlets[dataset_urn] = ExternalDataset( NIFI, @@ -974,14 +993,16 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 ) else: inputJobs.add( - builder.make_data_job_urn_with_flow(flow_urn, incoming_from) + builder.make_data_job_urn_with_flow(flow_urn, incoming_from), ) for outgoing_to in outgoing: if outgoing_to in self.nifi_flow.remotely_accessible_ports.keys(): dataset_name = f"{self.config.site_name}.{self.nifi_flow.remotely_accessible_ports[outgoing_to].name}" dataset_urn = builder.make_dataset_urn( - NIFI, dataset_name, self.config.env + NIFI, + dataset_name, + self.config.env, ) component.outlets[dataset_urn] = ExternalDataset( NIFI, @@ -1005,10 +1026,15 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 site_name = self.config.site_url_to_site_name[site_url] dataset_name = f"{site_name}.{component.name}" dataset_urn = builder.make_dataset_urn( - NIFI, dataset_name, self.config.env + NIFI, + dataset_name, + self.config.env, ) component.outlets[dataset_urn] = ExternalDataset( - NIFI, dataset_name, dict(nifi_uri=site_url), dataset_urn + NIFI, + dataset_name, + dict(nifi_uri=site_url), + dataset_urn, ) break @@ -1027,17 +1053,22 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 dataset_name = f"{site_name}.{component.name}" dataset_urn = builder.make_dataset_urn( - NIFI, dataset_name, self.config.env + NIFI, + 
dataset_name, + self.config.env, ) component.inlets[dataset_urn] = ExternalDataset( - NIFI, dataset_name, dict(nifi_uri=site_url), dataset_urn + NIFI, + dataset_name, + dict(nifi_uri=site_url), + dataset_urn, ) break if self.config.emit_process_group_as_container: # We emit process groups only for all nifi components qualifying as datajobs yield from self.construct_process_group_workunits( - component.parent_group_id + component.parent_group_id, ) yield from self.construct_job_workunits( @@ -1045,7 +1076,9 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 job_name, component.parent_group_id, external_url=self.make_external_url( - component.parent_group_id, component.id, component.parent_rpg_id + component.parent_group_id, + component.id, + component.parent_rpg_id, ), job_type=NIFI.upper() + "_" + component.nifi_type.value, description=component.comments, @@ -1067,13 +1100,15 @@ def construct_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901 def make_flow_urn(self) -> str: return builder.make_data_flow_urn( - NIFI, self.nifi_flow.root_process_group.id, self.config.env + NIFI, + self.nifi_flow.root_process_group.id, + self.config.env, ) def process_provenance_events(self): logger.debug("Starting processing of provenance events") startDate = datetime.now(timezone.utc) - timedelta( - days=self.config.provenance_days + days=self.config.provenance_days, ) eventAnalyzer = NifiProcessorProvenanceEventAnalyzer() @@ -1082,7 +1117,7 @@ def process_provenance_events(self): logger.debug(f"Processing {len(components)} components") for component in components: logger.debug( - f"Processing provenance events for component id: {component.id} name: {component.name}" + f"Processing provenance events for component id: {component.id} name: {component.name}", ) if component.nifi_type is NifiType.PROCESSOR: eventType = eventAnalyzer.KNOWN_INGRESS_EGRESS_PROCESORS[component.type] @@ -1109,12 +1144,13 @@ def authenticate(self): assert self.config.username is not None assert self.config.password is not None self.session.auth = HTTPBasicAuth( - self.config.username, self.config.password + self.config.username, + self.config.password, ) self.session.headers.update( { "Content-Type": "application/x-www-form-urlencoded", - } + }, ) return @@ -1162,7 +1198,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.create_nifi_flow() except Exception as e: self.report.failure( - "Failed to get root process group flow", self.config.site_url, exc=e + "Failed to get root process group flow", + self.config.site_url, + exc=e, ) return @@ -1254,7 +1292,8 @@ def construct_job_workunits( for patch_mcp in patch_builder.build(): logger.debug(f"Preparing Patch MCP: {patch_mcp}") yield MetadataWorkUnit( - id=f"{job_urn}-{patch_mcp.aspectName}", mcp_raw=patch_mcp + id=f"{job_urn}-{patch_mcp.aspectName}", + mcp_raw=patch_mcp, ) else: yield MetadataChangeProposalWrapper( @@ -1267,7 +1306,9 @@ def construct_job_workunits( ).as_workunit() def gen_browse_path_v2_workunit( - self, entity_urn: str, process_group_id: str + self, + entity_urn: str, + process_group_id: str, ) -> MetadataWorkUnit: flow_urn = self.make_flow_urn() return MetadataChangeProposalWrapper( @@ -1276,12 +1317,13 @@ def gen_browse_path_v2_workunit( path=[ BrowsePathEntryClass(id=flow_urn, urn=flow_urn), *self._get_browse_path_v2_entries(process_group_id), - ] + ], ), ).as_workunit() def _get_browse_path_v2_entries( - self, process_group_id: str + self, + process_group_id: str, ) -> List[BrowsePathEntryClass]: """Browse 
path entries till current process group""" if self._is_root_process_group(process_group_id): @@ -1292,13 +1334,14 @@ def _get_browse_path_v2_entries( current_process_group.parent_group_id ) # always present for non-root process group parent_browse_path = self._get_browse_path_v2_entries( - current_process_group.parent_group_id + current_process_group.parent_group_id, ) if self.config.emit_process_group_as_container: container_urn = self.gen_process_group_key(process_group_id).as_urn() current_browse_entry = BrowsePathEntryClass( - id=container_urn, urn=container_urn + id=container_urn, + urn=container_urn, ) else: current_browse_entry = BrowsePathEntryClass(id=current_process_group.name) @@ -1308,7 +1351,8 @@ def _is_root_process_group(self, process_group_id: str) -> bool: return self.nifi_flow.root_process_group.id == process_group_id def construct_process_group_workunits( - self, process_group_id: str + self, + process_group_id: str, ) -> Iterable[MetadataWorkUnit]: if ( self._is_root_process_group(process_group_id) @@ -1336,12 +1380,15 @@ def construct_process_group_workunits( if self._is_root_process_group(pg.parent_group_id): yield self.gen_browse_path_v2_workunit( - container_key.as_urn(), pg.parent_group_id + container_key.as_urn(), + pg.parent_group_id, ) def gen_process_group_key(self, process_group_id: str) -> ProcessGroupKey: return ProcessGroupKey( - process_group_id=process_group_id, platform=NIFI, env=self.config.env + process_group_id=process_group_id, + platform=NIFI, + env=self.config.env, ) def construct_dataset_workunits( @@ -1354,19 +1401,22 @@ def construct_dataset_workunits( ) -> Iterable[MetadataWorkUnit]: if not dataset_urn: dataset_urn = builder.make_dataset_urn( - dataset_platform, dataset_name, self.config.env + dataset_platform, + dataset_name, + self.config.env, ) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=DataPlatformInstanceClass( - platform=builder.make_data_platform_urn(dataset_platform) + platform=builder.make_data_platform_urn(dataset_platform), ), ).as_workunit() yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=DatasetPropertiesClass( - externalUrl=external_url, customProperties=datasetProperties + externalUrl=external_url, + customProperties=datasetProperties, ), ).as_workunit() diff --git a/metadata-ingestion/src/datahub/ingestion/source/openapi.py b/metadata-ingestion/src/datahub/ingestion/source/openapi.py index 2075e999ea1d0e..7e0aa98a2535bc 100755 --- a/metadata-ingestion/src/datahub/ingestion/source/openapi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/openapi.py @@ -50,16 +50,19 @@ class OpenApiConfig(ConfigModel): name: str = Field(description="Name of ingestion.") url: str = Field(description="Endpoint URL. e.g. https://example.com") swagger_file: str = Field( - description="Route for access to the swagger file. e.g. openapi.json" + description="Route for access to the swagger file. e.g. openapi.json", ) ignore_endpoints: list = Field( - default=[], description="List of endpoints to ignore during ingestion." + default=[], + description="List of endpoints to ignore during ingestion.", ) username: str = Field( - default="", description="Username used for basic HTTP authentication." + default="", + description="Username used for basic HTTP authentication.", ) password: str = Field( - default="", description="Password used for basic HTTP authentication." 
+ default="", + description="Password used for basic HTTP authentication.", ) proxies: Optional[dict] = Field( default=None, @@ -73,18 +76,23 @@ class OpenApiConfig(ConfigModel): description="If no example is provided for a route, it is possible to create one using forced_example.", ) token: Optional[str] = Field( - default=None, description="Token for endpoint authentication." + default=None, + description="Token for endpoint authentication.", ) bearer_token: Optional[str] = Field( - default=None, description="Bearer token for endpoint authentication." + default=None, + description="Bearer token for endpoint authentication.", ) get_token: dict = Field( - default={}, description="Retrieving a token from the endpoint." + default={}, + description="Retrieving a token from the endpoint.", ) @validator("bearer_token", always=True) def ensure_only_one_token( - cls, bearer_token: Optional[str], values: Dict + cls, + bearer_token: Optional[str], + values: Dict, ) -> Optional[str]: if bearer_token is not None and values.get("token") is not None: raise ValueError("Unable to use 'token' and 'bearer_token' together.") @@ -112,14 +120,15 @@ def get_swagger(self) -> Dict: "we expect the keyword {password} to be present in the url" ) url4req = self.get_token["url_complement"].replace( - "{username}", self.username + "{username}", + self.username, ) url4req = url4req.replace("{password}", self.password) elif self.get_token["request_type"] == "post": url4req = self.get_token["url_complement"] else: raise KeyError( - "This tool accepts only 'get' and 'post' as method for getting tokens" + "This tool accepts only 'get' and 'post' as method for getting tokens", ) self.token = get_tok( url=self.url, @@ -217,11 +226,13 @@ def report_bad_responses(self, status_code: int, type: str) -> None: ) else: raise Exception( - f"Unable to retrieve endpoint, response code {status_code}, key {type}" + f"Unable to retrieve endpoint, response code {status_code}, key {type}", ) def init_dataset( - self, endpoint_k: str, endpoint_dets: dict + self, + endpoint_k: str, + endpoint_dets: dict, ) -> Tuple[DatasetSnapshot, str]: config = self.config @@ -240,7 +251,8 @@ def init_dataset( # adding description dataset_properties = DatasetPropertiesClass( - description=endpoint_dets["description"], customProperties={} + description=endpoint_dets["description"], + customProperties={}, ) dataset_snapshot.aspects.append(dataset_properties) @@ -254,10 +266,14 @@ def init_dataset( link_url = clean_url(config.url + self.url_basepath + endpoint_k) link_description = "Link to call for the dataset." 
creation = AuditStampClass( - time=int(time.time()), actor="urn:li:corpuser:etl", impersonator=None + time=int(time.time()), + actor="urn:li:corpuser:etl", + impersonator=None, ) link_metadata = InstitutionalMemoryMetadataClass( - url=link_url, description=link_description, createStamp=creation + url=link_url, + description=link_description, + createStamp=creation, ) inst_memory = InstitutionalMemoryClass([link_metadata]) dataset_snapshot.aspects.append(inst_memory) @@ -265,7 +281,9 @@ def init_dataset( return dataset_snapshot, dataset_name def build_wu( - self, dataset_snapshot: DatasetSnapshot, dataset_name: str + self, + dataset_snapshot: DatasetSnapshot, + dataset_name: str, ) -> ApiWorkUnit: mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) return ApiWorkUnit(id=dataset_name, mce=mce) @@ -295,7 +313,8 @@ def get_workunits_internal(self) -> Iterable[ApiWorkUnit]: # noqa: C901 continue dataset_snapshot, dataset_name = self.init_dataset( - endpoint_k, endpoint_dets + endpoint_k, + endpoint_dets, ) # adding dataset fields @@ -330,7 +349,8 @@ def get_workunits_internal(self) -> Iterable[ApiWorkUnit]: # noqa: C901 ) if response.status_code == 200: fields2add, root_dataset_samples[dataset_name] = extract_fields( - response, dataset_name + response, + dataset_name, ) if not fields2add: self.report.info( @@ -376,7 +396,8 @@ def get_workunits_internal(self) -> Iterable[ApiWorkUnit]: # noqa: C901 self.report_bad_responses(response.status_code, type=endpoint_k) else: composed_url = compose_url_attr( - raw_url=endpoint_k, attr_list=config.forced_examples[endpoint_k] + raw_url=endpoint_k, + attr_list=config.forced_examples[endpoint_k], ) tot_url = clean_url(config.url + self.url_basepath + composed_url) if config.token: diff --git a/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py b/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py index 5bacafaa3f5885..389c031bb99556 100755 --- a/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py +++ b/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py @@ -56,7 +56,9 @@ def request_call( headers = {"accept": "application/json"} if username is not None and password is not None: return requests.get( - url, headers=headers, auth=HTTPBasicAuth(username, password) + url, + headers=headers, + auth=HTTPBasicAuth(username, password), ) elif token is not None: headers["Authorization"] = f"{token}" @@ -75,7 +77,11 @@ def get_swag_json( ) -> Dict: tot_url = url + swagger_file response = request_call( - url=tot_url, token=token, username=username, password=password, proxies=proxies + url=tot_url, + token=token, + username=username, + password=password, + proxies=proxies, ) if response.status_code != 200: @@ -107,7 +113,7 @@ def check_sw_version(sw_dict: dict) -> None: if version[0] == 3 and version[1] > 0: logger.warning( - "This plugin has not been fully tested with Swagger version >3.0" + "This plugin has not been fully tested with Swagger version >3.0", ) @@ -177,7 +183,7 @@ def check_for_api_example_data(base_res: dict, key: str) -> dict: data = res_cont["application/json"][ex_field][0] else: logger.warning( - f"Field in swagger file does not give consistent data --- {key}" + f"Field in swagger file does not give consistent data --- {key}", ) elif "text/csv" in res_cont.keys(): data = res_cont["text/csv"]["schema"] @@ -294,7 +300,8 @@ def clean_url(url: str) -> str: def extract_fields( - response: requests.Response, dataset_name: str + response: requests.Response, + dataset_name: str, ) -> 
Tuple[List[Any], Dict[Any, Any]]: """ Given a URL, this function will extract the fields contained in the @@ -381,7 +388,9 @@ def get_tok( def set_metadata( - dataset_name: str, fields: List, platform: str = "api" + dataset_name: str, + fields: List, + platform: str = "api", ) -> SchemaMetadata: canonical_schema: List[SchemaField] = [] diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py index 14beab6bc9391e..548aaa05b03fc5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py @@ -148,19 +148,23 @@ class PowerBIPlatformDetail: class SupportedDataPlatform(Enum): POSTGRES_SQL = DataPlatformPair( - powerbi_data_platform_name="PostgreSQL", datahub_data_platform_name="postgres" + powerbi_data_platform_name="PostgreSQL", + datahub_data_platform_name="postgres", ) ORACLE = DataPlatformPair( - powerbi_data_platform_name="Oracle", datahub_data_platform_name="oracle" + powerbi_data_platform_name="Oracle", + datahub_data_platform_name="oracle", ) SNOWFLAKE = DataPlatformPair( - powerbi_data_platform_name="Snowflake", datahub_data_platform_name="snowflake" + powerbi_data_platform_name="Snowflake", + datahub_data_platform_name="snowflake", ) MS_SQL = DataPlatformPair( - powerbi_data_platform_name="Sql", datahub_data_platform_name="mssql" + powerbi_data_platform_name="Sql", + datahub_data_platform_name="mssql", ) GOOGLE_BIGQUERY = DataPlatformPair( @@ -174,7 +178,8 @@ class SupportedDataPlatform(Enum): ) DATABRICKS_SQL = DataPlatformPair( - powerbi_data_platform_name="Databricks", datahub_data_platform_name="databricks" + powerbi_data_platform_name="Databricks", + datahub_data_platform_name="databricks", ) DatabricksMultiCloud_SQL = DataPlatformPair( @@ -187,10 +192,10 @@ class SupportedDataPlatform(Enum): class PowerBiDashboardSourceReport(StaleEntityRemovalSourceReport): all_workspace_count: int = 0 filtered_workspace_names: LossyList[str] = dataclass_field( - default_factory=LossyList + default_factory=LossyList, ) filtered_workspace_types: LossyList[str] = dataclass_field( - default_factory=LossyList + default_factory=LossyList, ) dashboards_scanned: int = 0 @@ -244,7 +249,8 @@ class DataBricksPlatformDetail(PlatformDetail): class OwnershipMapping(ConfigModel): create_corp_user: bool = pydantic.Field( - default=True, description="Whether ingest PowerBI user as Datahub Corpuser" + default=True, + description="Whether ingest PowerBI user as Datahub Corpuser", ) use_powerbi_email: bool = pydantic.Field( # TODO: Deprecate and remove this config, since the non-email format @@ -274,10 +280,12 @@ class PowerBiProfilingConfig(ConfigModel): class PowerBiDashboardSourceConfig( - StatefulIngestionConfigBase, DatasetSourceConfigMixin + StatefulIngestionConfigBase, + DatasetSourceConfigMixin, ): platform_name: str = pydantic.Field( - default=Constant.PLATFORM_NAME, hidden_from_docs=True + default=Constant.PLATFORM_NAME, + hidden_from_docs=True, ) platform_urn: str = pydantic.Field( @@ -315,7 +323,8 @@ class PowerBiDashboardSourceConfig( ) # PowerBI datasource's server to platform instance mapping server_to_platform_instance: Dict[ - str, Union[PlatformDetail, DataBricksPlatformDetail] + str, + Union[PlatformDetail, DataBricksPlatformDetail], ] = pydantic.Field( default={}, description="A mapping of PowerBI datasource's server i.e host[:port] to Data platform instance." 
@@ -334,7 +343,8 @@ class PowerBiDashboardSourceConfig( client_secret: str = pydantic.Field(description="Azure app client secret") # timeout for meta-data scanning scan_timeout: int = pydantic.Field( - default=60, description="timeout for PowerBI metadata scanning" + default=60, + description="timeout for PowerBI metadata scanning", ) scan_batch_size: int = pydantic.Field( default=1, @@ -355,7 +365,8 @@ class PowerBiDashboardSourceConfig( ) # Enable/Disable extracting report information extract_reports: bool = pydantic.Field( - default=True, description="Whether reports should be ingested" + default=True, + description="Whether reports should be ingested", ) # Configure ingestion of ownership ownership: OwnershipMapping = pydantic.Field( @@ -393,7 +404,8 @@ class PowerBiDashboardSourceConfig( ) # Enable/Disable extracting workspace information to DataHub containers extract_workspaces_to_containers: bool = pydantic.Field( - default=True, description="Extract workspaces to DataHub containers" + default=True, + description="Extract workspaces to DataHub containers", ) # Enable/Disable grouping PBI dataset tables into Datahub container (PBI Dataset) extract_datasets_to_containers: bool = pydantic.Field( @@ -419,7 +431,8 @@ class PowerBiDashboardSourceConfig( ) # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( - default=None, description="PowerBI Stateful Ingestion Config." + default=None, + description="PowerBI Stateful Ingestion Config.", ) # Retrieve PowerBI Metadata using Admin API only admin_apis_only: bool = pydantic.Field( @@ -471,7 +484,11 @@ class PowerBiDashboardSourceConfig( workspace_type_filter: List[ Literal[ - "Workspace", "PersonalGroup", "Personal", "AdminWorkspace", "AdminInsights" + "Workspace", + "PersonalGroup", + "Personal", + "AdminWorkspace", + "AdminInsights", ] ] = pydantic.Field( default=["Workspace"], @@ -550,15 +567,15 @@ def workspace_id_backward_compatibility(cls, values: Dict) -> Dict: if workspace_id_pattern == AllowDenyPattern.allow_all() and workspace_id: logger.warning( "workspace_id_pattern is not set but workspace_id is set, setting workspace_id as " - "workspace_id_pattern. workspace_id will be deprecated, please use workspace_id_pattern instead." + "workspace_id_pattern. workspace_id will be deprecated, please use workspace_id_pattern instead.", ) values["workspace_id_pattern"] = AllowDenyPattern( - allow=[f"^{workspace_id}$"] + allow=[f"^{workspace_id}$"], ) elif workspace_id_pattern != AllowDenyPattern.allow_all() and workspace_id: logger.warning( "workspace_id will be ignored in favour of workspace_id_pattern. workspace_id will be deprecated, " - "please use workspace_id_pattern only." + "please use workspace_id_pattern only.", ) values.pop("workspace_id") return values @@ -570,7 +587,7 @@ def raise_error_for_dataset_type_mapping(cls, values: Dict) -> Dict: and values.get("server_to_platform_instance") is not None ): raise ValueError( - "dataset_type_mapping is deprecated. Use server_to_platform_instance only." + "dataset_type_mapping is deprecated. 
Use server_to_platform_instance only.", ) return values diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py index 6d51e853a2fb06..b2ebea5a120178 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py @@ -14,13 +14,15 @@ class AbstractDataPlatformInstanceResolver(ABC): @abstractmethod def get_platform_instance( - self, data_platform_detail: PowerBIPlatformDetail + self, + data_platform_detail: PowerBIPlatformDetail, ) -> PlatformDetail: pass class BaseAbstractDataPlatformInstanceResolver( - AbstractDataPlatformInstanceResolver, ABC + AbstractDataPlatformInstanceResolver, + ABC, ): config: PowerBiDashboardSourceConfig @@ -29,10 +31,11 @@ def __init__(self, config): class ResolvePlatformInstanceFromDatasetTypeMapping( - BaseAbstractDataPlatformInstanceResolver + BaseAbstractDataPlatformInstanceResolver, ): def get_platform_instance( - self, data_platform_detail: PowerBIPlatformDetail + self, + data_platform_detail: PowerBIPlatformDetail, ) -> PlatformDetail: platform: Union[str, PlatformDetail] = self.config.dataset_type_mapping[ data_platform_detail.data_platform_pair.powerbi_data_platform_name @@ -45,10 +48,11 @@ def get_platform_instance( class ResolvePlatformInstanceFromServerToPlatformInstance( - BaseAbstractDataPlatformInstanceResolver + BaseAbstractDataPlatformInstanceResolver, ): def get_platform_instance( - self, data_platform_detail: PowerBIPlatformDetail + self, + data_platform_detail: PowerBIPlatformDetail, ) -> PlatformDetail: return ( self.config.server_to_platform_instance[ @@ -65,11 +69,11 @@ def create_dataplatform_instance_resolver( ) -> AbstractDataPlatformInstanceResolver: if config.server_to_platform_instance: logger.debug( - "Creating resolver to resolve platform instance from server_to_platform_instance" + "Creating resolver to resolve platform instance from server_to_platform_instance", ) return ResolvePlatformInstanceFromServerToPlatformInstance(config) logger.debug( - "Creating resolver to resolve platform instance from dataset_type_mapping" + "Creating resolver to resolve platform instance from dataset_type_mapping", ) return ResolvePlatformInstanceFromDatasetTypeMapping(config) diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/parser.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/parser.py index 759fc6d7dadfba..5d65eb8c257df0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/parser.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/parser.py @@ -32,7 +32,8 @@ def get_lark_parser() -> Lark: # Read lexical grammar as text grammar: str = pkg_resource.read_text( - "datahub.ingestion.source.powerbi", "powerbi-lexical-grammar.rule" + "datahub.ingestion.source.powerbi", + "powerbi-lexical-grammar.rule", ) # Create lark parser for the grammar text return Lark(grammar, start="let_expression", regex=True) @@ -74,12 +75,13 @@ def get_upstream_tables( parameters = parameters or {} logger.debug( - f"Processing {table.full_name} m-query expression for lineage extraction. Expression = {table.expression}" + f"Processing {table.full_name} m-query expression for lineage extraction. 
Expression = {table.expression}", ) try: valid, message = validator.validate_parse_tree( - table.expression, native_query_enabled=config.native_query_parsing + table.expression, + native_query_enabled=config.native_query_parsing, ) if valid is False: assert message is not None @@ -95,7 +97,8 @@ def get_upstream_tables( with reporter.m_query_parse_timer: reporter.m_query_parse_attempts += 1 parse_tree: Tree = _parse_expression( - table.expression, parse_timeout=config.m_query_parse_timeout + table.expression, + parse_timeout=config.m_query_parse_timeout, ) except KeyboardInterrupt: diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/pattern_handler.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/pattern_handler.py index 54b810650f5854..672c38a3bc5400 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/pattern_handler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/pattern_handler.py @@ -63,7 +63,7 @@ def make_urn( PowerBIPlatformDetail( data_platform_pair=data_platform_pair, data_platform_server=server, - ) + ), ) return builder.make_dataset_urn_with_platform_instance( @@ -71,7 +71,8 @@ def make_urn( platform_instance=platform_detail.platform_instance, env=platform_detail.env, name=urn_to_lowercase( - qualified_table_name, config.convert_lineage_urns_to_lowercase + qualified_table_name, + config.convert_lineage_urns_to_lowercase, ), ) @@ -132,7 +133,8 @@ def __init__( @abstractmethod def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: pass @@ -146,7 +148,7 @@ def get_db_detail_from_argument( ) -> Tuple[Optional[str], Optional[str]]: arguments: List[str] = tree_function.strip_char_from_list( values=tree_function.remove_whitespaces_from_list( - tree_function.token_values(arg_list) + tree_function.token_values(arg_list), ), ) @@ -163,7 +165,7 @@ def create_reference_table( ) -> Optional[ReferencedTable]: arguments: List[str] = tree_function.strip_char_from_list( values=tree_function.remove_whitespaces_from_list( - tree_function.token_values(arg_list) + tree_function.token_values(arg_list), ), ) @@ -197,7 +199,11 @@ def create_reference_table( return None def parse_custom_sql( - self, query: str, server: str, database: Optional[str], schema: Optional[str] + self, + query: str, + server: str, + database: Optional[str], + schema: Optional[str], ) -> Lineage: dataplatform_tables: List[DataPlatformTable] = [] @@ -206,12 +212,12 @@ def parse_custom_sql( PowerBIPlatformDetail( data_platform_pair=self.get_platform_pair(), data_platform_server=server, - ) + ), ) ) query = native_sql_parser.remove_drop_statement( - native_sql_parser.remove_special_characters(query) + native_sql_parser.remove_special_characters(query), ) parsed_result: Optional["SqlParsingResult"] = ( @@ -247,7 +253,7 @@ def parse_custom_sql( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ) logger.debug(f"Native Query parsed result={parsed_result}") @@ -268,20 +274,22 @@ def get_platform_pair(self) -> DataPlatformPair: return SupportedDataPlatform.AMAZON_REDSHIFT.value def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: logger.debug( - f"Processing AmazonRedshift data-access function detail {data_access_func_detail}" + f"Processing AmazonRedshift data-access function detail {data_access_func_detail}", ) server, 
db_name = self.get_db_detail_from_argument( - data_access_func_detail.arg_list + data_access_func_detail.arg_list, ) if db_name is None or server is None: return Lineage.empty() # Return an empty list schema_name: str = cast( - IdentifierAccessor, data_access_func_detail.identifier_accessor + IdentifierAccessor, + data_access_func_detail.identifier_accessor, ).items["Name"] table_name: str = cast( @@ -304,7 +312,7 @@ def create_lineage( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ], column_lineage=[], ) @@ -330,14 +338,15 @@ def _get_server_and_db_name(value: str) -> Tuple[Optional[str], Optional[str]]: return tree_function.strip_char_from_list([splitter_result[0]])[0], db_name def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: logger.debug( - f"Processing Oracle data-access function detail {data_access_func_detail}" + f"Processing Oracle data-access function detail {data_access_func_detail}", ) arguments: List[str] = tree_function.remove_whitespaces_from_list( - tree_function.token_values(data_access_func_detail.arg_list) + tree_function.token_values(data_access_func_detail.arg_list), ) server, db_name = self._get_server_and_db_name(arguments[0]) @@ -346,7 +355,8 @@ def create_lineage( return Lineage.empty() schema_name: str = cast( - IdentifierAccessor, data_access_func_detail.identifier_accessor + IdentifierAccessor, + data_access_func_detail.identifier_accessor, ).items["Schema"] table_name: str = cast( @@ -369,7 +379,7 @@ def create_lineage( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ], column_lineage=[], ) @@ -386,7 +396,7 @@ def form_qualified_table_name( PowerBIPlatformDetail( data_platform_pair=data_platform_pair, data_platform_server=table_reference.warehouse, - ) + ), ) ) @@ -403,10 +413,11 @@ def form_qualified_table_name( return qualified_table_name def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: logger.debug( - f"Processing Databrick data-access function detail {data_access_func_detail}" + f"Processing Databrick data-access function detail {data_access_func_detail}", ) table_detail: Dict[str, str] = {} temp_accessor: Optional[IdentifierAccessor] = ( @@ -454,7 +465,7 @@ def create_lineage( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ], column_lineage=[], ) @@ -476,30 +487,33 @@ class TwoStepDataAccessPattern(AbstractLineage, ABC): """ def two_level_access_pattern( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: logger.debug( - f"Processing {self.get_platform_pair().powerbi_data_platform_name} data-access function detail {data_access_func_detail}" + f"Processing {self.get_platform_pair().powerbi_data_platform_name} data-access function detail {data_access_func_detail}", ) server, db_name = self.get_db_detail_from_argument( - data_access_func_detail.arg_list + data_access_func_detail.arg_list, ) if server is None or db_name is None: return Lineage.empty() # Return an empty list schema_name: str = cast( - IdentifierAccessor, data_access_func_detail.identifier_accessor + IdentifierAccessor, + data_access_func_detail.identifier_accessor, ).items["Schema"] table_name: str = cast( - IdentifierAccessor, data_access_func_detail.identifier_accessor + IdentifierAccessor, + 
data_access_func_detail.identifier_accessor, ).items["Item"] qualified_table_name: str = f"{db_name}.{schema_name}.{table_name}" logger.debug( - f"Platform({self.get_platform_pair().datahub_data_platform_name}) qualified_table_name= {qualified_table_name}" + f"Platform({self.get_platform_pair().datahub_data_platform_name}) qualified_table_name= {qualified_table_name}", ) urn = make_urn( @@ -514,7 +528,7 @@ def two_level_access_pattern( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ], column_lineage=[], ) @@ -522,7 +536,8 @@ def two_level_access_pattern( class PostgresLineage(TwoStepDataAccessPattern): def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: return self.two_level_access_pattern(data_access_func_detail) @@ -538,7 +553,10 @@ def get_platform_pair(self) -> DataPlatformPair: return SupportedDataPlatform.MS_SQL.value def create_urn_using_old_parser( - self, query: str, db_name: str, server: str + self, + query: str, + db_name: str, + server: str, ) -> List[DataPlatformTable]: dataplatform_tables: List[DataPlatformTable] = [] @@ -576,7 +594,7 @@ def create_urn_using_old_parser( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ) logger.debug(f"Generated upstream tables = {dataplatform_tables}") @@ -584,16 +602,17 @@ def create_urn_using_old_parser( return dataplatform_tables def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: arguments: List[str] = tree_function.strip_char_from_list( values=tree_function.remove_whitespaces_from_list( - tree_function.token_values(data_access_func_detail.arg_list) + tree_function.token_values(data_access_func_detail.arg_list), ), ) server, database = self.get_db_detail_from_argument( - data_access_func_detail.arg_list + data_access_func_detail.arg_list, ) if server is None or database is None: return Lineage.empty() # Return an empty list @@ -628,19 +647,22 @@ def create_lineage( class ThreeStepDataAccessPattern(AbstractLineage, ABC): def get_datasource_server( - self, arguments: List[str], data_access_func_detail: DataAccessFunctionDetail + self, + arguments: List[str], + data_access_func_detail: DataAccessFunctionDetail, ) -> str: return tree_function.strip_char_from_list([arguments[0]])[0] def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: logger.debug( - f"Processing {self.get_platform_pair().datahub_data_platform_name} function detail {data_access_func_detail}" + f"Processing {self.get_platform_pair().datahub_data_platform_name} function detail {data_access_func_detail}", ) arguments: List[str] = tree_function.remove_whitespaces_from_list( - tree_function.token_values(data_access_func_detail.arg_list) + tree_function.token_values(data_access_func_detail.arg_list), ) # First is database name db_name: str = data_access_func_detail.identifier_accessor.items["Name"] # type: ignore @@ -658,7 +680,7 @@ def create_lineage( qualified_table_name: str = f"{db_name}.{schema_name}.{table_name}" logger.debug( - f"{self.get_platform_pair().datahub_data_platform_name} qualified_table_name {qualified_table_name}" + f"{self.get_platform_pair().datahub_data_platform_name} qualified_table_name {qualified_table_name}", ) server: str = self.get_datasource_server(arguments, data_access_func_detail) @@ 
-676,7 +698,7 @@ def create_lineage( DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ], column_lineage=[], ) @@ -692,7 +714,9 @@ def get_platform_pair(self) -> DataPlatformPair: return SupportedDataPlatform.GOOGLE_BIGQUERY.value def get_datasource_server( - self, arguments: List[str], data_access_func_detail: DataAccessFunctionDetail + self, + arguments: List[str], + data_access_func_detail: DataAccessFunctionDetail, ) -> str: # In Google BigQuery server is project-name # condition to silent lint, it is not going to be None @@ -729,7 +753,7 @@ def create_urn_using_old_parser(self, query: str, server: str) -> Lineage: for qualified_table_name in tables: if len(qualified_table_name.split(".")) != 3: logger.debug( - f"Skipping table {qualified_table_name} as it is not as per qualified_table_name format" + f"Skipping table {qualified_table_name} as it is not as per qualified_table_name format", ) continue @@ -745,7 +769,7 @@ def create_urn_using_old_parser(self, query: str, server: str) -> Lineage: DataPlatformTable( data_platform_pair=self.get_platform_pair(), urn=urn, - ) + ), ) logger.debug(f"Generated dataplatform_tables {dataplatform_tables}") @@ -771,39 +795,42 @@ def get_db_name(self, data_access_tokens: List[str]) -> Optional[str]: return ( get_next_item( # database name is set in Name argument - data_access_tokens, "Name" + data_access_tokens, + "Name", ) or get_next_item( # If both above arguments are not available, then try Catalog - data_access_tokens, "Catalog" + data_access_tokens, + "Catalog", ) ) def create_lineage( - self, data_access_func_detail: DataAccessFunctionDetail + self, + data_access_func_detail: DataAccessFunctionDetail, ) -> Lineage: t1: Optional[Tree] = tree_function.first_arg_list_func( - data_access_func_detail.arg_list + data_access_func_detail.arg_list, ) assert t1 is not None flat_argument_list: List[Tree] = tree_function.flat_argument_list(t1) if len(flat_argument_list) != 2: logger.debug( - f"Expecting 2 argument, actual argument count is {len(flat_argument_list)}" + f"Expecting 2 argument, actual argument count is {len(flat_argument_list)}", ) logger.debug(f"Flat argument list = {flat_argument_list}") return Lineage.empty() data_access_tokens: List[str] = tree_function.remove_whitespaces_from_list( - tree_function.token_values(flat_argument_list[0]) + tree_function.token_values(flat_argument_list[0]), ) if not self.is_native_parsing_supported(data_access_tokens[0]): logger.debug( - f"Unsupported native-query data-platform = {data_access_tokens[0]}" + f"Unsupported native-query data-platform = {data_access_tokens[0]}", ) logger.debug( - f"NativeQuery is supported only for {self.SUPPORTED_NATIVE_QUERY_DATA_PLATFORM}" + f"NativeQuery is supported only for {self.SUPPORTED_NATIVE_QUERY_DATA_PLATFORM}", ) return Lineage.empty() @@ -811,7 +838,7 @@ def create_lineage( if len(data_access_tokens[0]) < 3: logger.debug( f"Server is not available in argument list for data-platform {data_access_tokens[0]}. 
Returning empty " - "list" + "list", ) return Lineage.empty() @@ -821,7 +848,7 @@ def create_lineage( # The First argument is the query sql_query: str = tree_function.strip_char_from_list( values=tree_function.remove_whitespaces_from_list( - tree_function.token_values(flat_argument_list[1]) + tree_function.token_values(flat_argument_list[1]), ), )[0] # Remove any whitespaces and double quotes character diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/resolver.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/resolver.py index 42963c08d992d1..47ef3f81279492 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/resolver.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/resolver.py @@ -77,7 +77,7 @@ def get_item_selector_tokens( expression_tree: Tree, ) -> Tuple[Optional[str], Optional[Dict[str, str]]]: item_selector: Optional[Tree] = tree_function.first_item_selector_func( - expression_tree + expression_tree, ) if item_selector is None: logger.debug("Item Selector not found in tree") @@ -85,7 +85,7 @@ def get_item_selector_tokens( return None, None identifier_tree: Optional[Tree] = tree_function.first_identifier_func( - expression_tree + expression_tree, ) if identifier_tree is None: logger.debug("Identifier not found in tree") @@ -95,11 +95,12 @@ def get_item_selector_tokens( # remove whitespaces and quotes from token tokens: List[str] = tree_function.strip_char_from_list( tree_function.remove_whitespaces_from_list( - tree_function.token_values(item_selector, parameters=self.parameters) + tree_function.token_values(item_selector, parameters=self.parameters), ), ) identifier: List[str] = tree_function.token_values( - identifier_tree, parameters={} + identifier_tree, + parameters={}, ) # convert tokens to dict @@ -110,7 +111,7 @@ def get_item_selector_tokens( @staticmethod def get_argument_list(invoke_expression: Tree) -> Optional[Tree]: argument_list: Optional[Tree] = tree_function.first_arg_list_func( - invoke_expression + invoke_expression, ) if argument_list is None: logger.debug("First argument-list rule not found in input tree") @@ -124,7 +125,7 @@ def take_first_argument(self, expression: Tree) -> Optional[Tree]: if first_arg_tree is None: logger.debug( - f"Function invocation without argument in expression = {expression.pretty()}" + f"Function invocation without argument in expression = {expression.pretty()}", ) self.reporter.report_warning( f"{self.table.full_name}-variable-statement", @@ -134,7 +135,8 @@ def take_first_argument(self, expression: Tree) -> Optional[Tree]: return first_arg_tree def _process_invoke_expression( - self, invoke_expression: Tree + self, + invoke_expression: Tree, ) -> Union[DataAccessFunctionDetail, List[str], None]: letter_tree: Tree = invoke_expression.children[0] data_access_func: str = tree_function.make_function_name(letter_tree) @@ -145,7 +147,7 @@ def _process_invoke_expression( if data_access_func in self.data_access_functions: arg_list: Optional[Tree] = MQueryResolver.get_argument_list( - invoke_expression + invoke_expression, ) if arg_list is None: self.reporter.report_warning( @@ -206,7 +208,7 @@ def _process_invoke_expression( first_argument = flat_arg_list[0] # take first argument only expression: Optional[Tree] = tree_function.first_list_expression_func( - first_argument + first_argument, ) if TRACE_POWERBI_MQUERY_PARSER: @@ -217,7 +219,7 @@ def _process_invoke_expression( expression = tree_function.first_type_expression_func(first_argument) if 
expression is None: logger.debug( - f"Either list_expression or type_expression is not found = {invoke_expression.pretty()}" + f"Either list_expression or type_expression is not found = {invoke_expression.pretty()}", ) self.reporter.report_warning( title="M-Query Resolver Error", @@ -227,14 +229,15 @@ def _process_invoke_expression( return None tokens: List[str] = tree_function.remove_whitespaces_from_list( - tree_function.token_values(expression) + tree_function.token_values(expression), ) logger.debug(f"Tokens in invoke expression are {tokens}") return tokens def _process_item_selector_expression( - self, rh_tree: Tree + self, + rh_tree: Tree, ) -> Tuple[Optional[str], Optional[Dict[str, str]]]: first_expression: Optional[Tree] = tree_function.first_expression_func(rh_tree) assert first_expression is not None @@ -251,17 +254,22 @@ def _create_or_update_identifier_accessor( # It is first identifier_accessor if identifier_accessor is None: return IdentifierAccessor( - identifier=new_identifier, items=key_vs_value, next=None + identifier=new_identifier, + items=key_vs_value, + next=None, ) new_identifier_accessor: IdentifierAccessor = IdentifierAccessor( - identifier=new_identifier, items=key_vs_value, next=identifier_accessor + identifier=new_identifier, + items=key_vs_value, + next=identifier_accessor, ) return new_identifier_accessor def create_data_access_functional_detail( - self, identifier: str + self, + identifier: str, ) -> List[DataAccessFunctionDetail]: table_links: List[DataAccessFunctionDetail] = [] @@ -287,7 +295,8 @@ def internal( # Examples: Source = PostgreSql.Database() # public_order_date = Source{[Schema="public",Item="order_date"]}[Data] v_statement: Optional[Tree] = tree_function.get_variable_statement( - self.parse_tree, current_identifier + self.parse_tree, + current_identifier, ) if v_statement is None: self.reporter.report_warning( @@ -330,14 +339,16 @@ def internal( else: new_identifier, key_vs_value = self._process_item_selector_expression( - rh_tree + rh_tree, ) if new_identifier is None or key_vs_value is None: logger.debug("Required information not found in rh_tree") return None new_identifier_accessor: IdentifierAccessor = ( self._create_or_update_identifier_accessor( - identifier_accessor, new_identifier, key_vs_value + identifier_accessor, + new_identifier, + key_vs_value, ) ) @@ -357,7 +368,7 @@ def resolve_to_lineage( # Find out output variable as we are doing backtracking in M-Query output_variable: Optional[str] = tree_function.get_output_variable( - self.parse_tree + self.parse_tree, ) if output_variable is None: @@ -376,11 +387,11 @@ def resolve_to_lineage( for f_detail in table_links: # Get & Check if we support data-access-function available in M-Query supported_resolver = SupportedPattern.get_pattern_handler( - f_detail.data_access_function_name + f_detail.data_access_function_name, ) if supported_resolver is None: logger.debug( - f"Resolver not found for the data-access-function {f_detail.data_access_function_name}" + f"Resolver not found for the data-access-function {f_detail.data_access_function_name}", ) self.reporter.report_warning( f"{self.table.full_name}-data-access-function", diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/tree_function.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/tree_function.py index d48e251bd00906..72e1054ae93e49 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/tree_function.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/tree_function.py @@ -20,8 +20,9 @@ def get_output_variable(root: Tree) -> Optional[str]: # Remove any spaces return "".join( strip_char_from_list( - remove_whitespaces_from_list(token_values(in_expression_tree)), " " - ) + remove_whitespaces_from_list(token_values(in_expression_tree)), + " ", + ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/validator.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/validator.py index b52977aaa41fbe..afb861291b21d1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/validator.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/validator.py @@ -7,7 +7,8 @@ def validate_parse_tree( - expression: str, native_query_enabled: bool = True + expression: str, + native_query_enabled: bool = True, ) -> Tuple[bool, Optional[str]]: """ :param expression: M-Query expression to check if supported data-function is present in expression diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py index 5e5636f2d50fe3..e09bd070ef3a64 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py @@ -140,7 +140,8 @@ def urn_to_lowercase(value: str, flag: bool) -> str: def lineage_urn_to_lowercase(self, value): return Mapper.urn_to_lowercase( - value, self.__config.convert_lineage_urns_to_lowercase + value, + self.__config.convert_lineage_urns_to_lowercase, ) def assets_urn_to_lowercase(self, value): @@ -162,7 +163,8 @@ def new_mcp( ) def _to_work_unit( - self, mcp: MetadataChangeProposalWrapper + self, + mcp: MetadataChangeProposalWrapper, ) -> EquableMetadataWorkUnit: return Mapper.EquableMetadataWorkUnit( id="{PLATFORM}-{ENTITY_URN}-{ASPECT_NAME}".format( @@ -174,7 +176,9 @@ def _to_work_unit( ) def extract_dataset_schema( - self, table: powerbi_data_classes.Table, ds_urn: str + self, + table: powerbi_data_classes.Table, + ds_urn: str, ) -> List[MetadataChangeProposalWrapper]: schema_metadata = self.to_datahub_schema(table) schema_mcp = self.new_mcp( @@ -222,13 +226,15 @@ def make_fine_grained_lineage_class( downstreams=downstream, upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=upstreams, - ) + ), ) return fine_grained_lineages def extract_lineage( - self, table: powerbi_data_classes.Table, ds_urn: str + self, + table: powerbi_data_classes.Table, + ds_urn: str, ) -> List[MetadataChangeProposalWrapper]: mcps: List[MetadataChangeProposalWrapper] = [] @@ -250,7 +256,7 @@ def extract_lineage( ) logger.debug( - f"PowerBI virtual table {table.full_name} and it's upstream dataplatform tables = {upstream_lineage}" + f"PowerBI virtual table {table.full_name} and it's upstream dataplatform tables = {upstream_lineage}", ) for lineage in upstream_lineage: @@ -276,7 +282,7 @@ def extract_lineage( self.make_fine_grained_lineage_class( lineage=lineage, dataset_urn=ds_urn, - ) + ), ) if len(upstream) > 0: @@ -379,15 +385,15 @@ def to_datahub_dataset( [ self.__config.filter_dataset_endorsements.allowed(tag) for tag in (dataset.tags or [""]) - ] + ], ): logger.debug( - "Returning empty dataset_mcps as no dataset tag matched with filter_dataset_endorsements" + "Returning empty dataset_mcps as no dataset tag matched with filter_dataset_endorsements", ) return dataset_mcps logger.debug( - f"Mapping dataset={dataset.name}(id={dataset.id}) 
to datahub dataset" + f"Mapping dataset={dataset.name}(id={dataset.id}) to datahub dataset", ) if self.__config.extract_datasets_to_containers: @@ -401,7 +407,7 @@ def to_datahub_dataset( name=table.full_name, platform_instance=self.__config.platform_instance, env=self.__config.env, - ) + ), ) logger.debug(f"dataset_urn={ds_urn}") @@ -444,7 +450,7 @@ def to_datahub_dataset( aspect=SubTypesClass( typeNames=[ BIContainerSubTypes.POWERBI_DATASET_TABLE, - ] + ], ), ) # normally, the person who configures the dataset will be the most accurate person for ownership @@ -498,10 +504,10 @@ def extract_profile( return if not self.__config.profile_pattern.allowed( - f"{workspace.name}.{dataset.name}.{table.name}" + f"{workspace.name}.{dataset.name}.{table.name}", ): logger.info( - f"Table {table.name} in {dataset.name}, not allowed for profiling" + f"Table {table.name} in {dataset.name}, not allowed for profiling", ) return logger.debug(f"Profiling table: {table.name}") @@ -515,7 +521,7 @@ def extract_profile( ] = [*(table.columns or []), *(table.measures or [])] for column in columns: allowed_column = self.__config.profile_pattern.allowed( - f"{workspace.name}.{dataset.name}.{table.name}.{column.name}" + f"{workspace.name}.{dataset.name}.{table.name}.{column.name}", ) if column.isHidden or not allowed_column: logger.info(f"Column {column.name} not allowed for profiling") @@ -545,7 +551,7 @@ def transform_tags(tags: List[str]) -> GlobalTagsClass: tags=[ TagAssociationClass(builder.make_tag_urn(tag_to_add)) for tag_to_add in tags - ] + ], ) def to_datahub_chart_mcp( @@ -568,7 +574,7 @@ def to_datahub_chart_mcp( logger.info(f"{Constant.CHART_URN}={chart_urn}") ds_input: List[str] = self.to_urn_set( - [x for x in ds_mcps if x.entityType == Constant.DATASET] + [x for x in ds_mcps if x.entityType == Constant.DATASET], ) def tile_custom_properties(tile: powerbi_data_classes.Tile) -> dict: @@ -654,7 +660,7 @@ def to_urn_set(self, mcps: List[MetadataChangeProposalWrapper]) -> List[str]: mcp.entityUrn for mcp in mcps if mcp is not None and mcp.entityUrn is not None - ] + ], ) def to_datahub_dashboard_mcp( @@ -735,7 +741,7 @@ def chart_custom_properties(dashboard: powerbi_data_classes.Dashboard) -> dict: # Dashboard browsePaths browse_path = BrowsePathsClass( - paths=[f"/{Constant.PLATFORM_NAME}/{dashboard.workspace_name}"] + paths=[f"/{Constant.PLATFORM_NAME}/{dashboard.workspace_name}"], ) browse_path_mcp = self.new_mcp( entity_urn=dashboard_urn, @@ -773,7 +779,8 @@ def append_container_mcp( dataset: Optional[powerbi_data_classes.PowerBIDataset] = None, ) -> None: if self.__config.extract_datasets_to_containers and isinstance( - dataset, powerbi_data_classes.PowerBIDataset + dataset, + powerbi_data_classes.PowerBIDataset, ): container_key = dataset.get_dataset_key(self.__config.platform_name) elif self.__config.extract_workspaces_to_containers and self.workspace_key: @@ -792,7 +799,8 @@ def append_container_mcp( list_of_mcps.append(mcp) def generate_container_for_workspace( - self, workspace: powerbi_data_classes.Workspace + self, + workspace: powerbi_data_classes.Workspace, ) -> Iterable[MetadataWorkUnit]: self.workspace_key = workspace.get_workspace_key( platform_name=self.__config.platform_name, @@ -812,7 +820,8 @@ def generate_container_for_workspace( return container_work_units def generate_container_for_dataset( - self, dataset: powerbi_data_classes.PowerBIDataset + self, + dataset: powerbi_data_classes.PowerBIDataset, ) -> Iterable[MetadataChangeProposalWrapper]: dataset_key = 
dataset.get_dataset_key(self.__config.platform_name) container_work_units = gen_containers( @@ -844,7 +853,8 @@ def append_tag_mcp( list_of_mcps.append(tags_mcp) def to_datahub_user( - self, user: powerbi_data_classes.User + self, + user: powerbi_data_classes.User, ) -> List[MetadataChangeProposalWrapper]: """ Map PowerBi user to datahub user @@ -868,7 +878,8 @@ def to_datahub_user( return [user_key_mcp] def to_datahub_users( - self, users: List[powerbi_data_classes.User] + self, + users: List[powerbi_data_classes.User], ) -> List[MetadataChangeProposalWrapper]: user_mcps = [] @@ -884,7 +895,7 @@ def to_datahub_users( user.principalType == "User" and self.__config.ownership.owner_criteria and len( - set(user_rights) & set(self.__config.ownership.owner_criteria) + set(user_rights) & set(self.__config.ownership.owner_criteria), ) > 0 ): @@ -901,7 +912,8 @@ def to_datahub_chart( tiles: List[powerbi_data_classes.Tile], workspace: powerbi_data_classes.Workspace, ) -> Tuple[ - List[MetadataChangeProposalWrapper], List[MetadataChangeProposalWrapper] + List[MetadataChangeProposalWrapper], + List[MetadataChangeProposalWrapper], ]: ds_mcps = [] chart_mcps = [] @@ -935,12 +947,12 @@ def to_datahub_work_units( mcps = [] logger.info( - f"Converting dashboard={dashboard.displayName} to datahub dashboard" + f"Converting dashboard={dashboard.displayName} to datahub dashboard", ) # Convert user to CorpUser user_mcps: List[MetadataChangeProposalWrapper] = self.to_datahub_users( - dashboard.users + dashboard.users, ) # Convert tiles to charts ds_mcps, chart_mcps = self.to_datahub_chart(dashboard.tiles, workspace) @@ -990,7 +1002,7 @@ def to_chart_mcps( logger.debug(f"{Constant.CHART_URN}={chart_urn}") ds_input: List[str] = self.to_urn_set( - [x for x in ds_mcps if x.entityType == Constant.DATASET] + [x for x in ds_mcps if x.entityType == Constant.DATASET], ) # Create chartInfo mcp @@ -1023,7 +1035,7 @@ def to_chart_mcps( # Browse path browse_path = BrowsePathsClass( - paths=[f"/{Constant.PLATFORM_NAME}/{workspace.name}"] + paths=[f"/{Constant.PLATFORM_NAME}/{workspace.name}"], ) browse_path_mcp = self.new_mcp( entity_urn=chart_urn, @@ -1116,7 +1128,7 @@ def report_to_dashboard( # Report browsePaths browse_path = BrowsePathsClass( - paths=[f"/{Constant.PLATFORM_NAME}/{workspace.name}"] + paths=[f"/{Constant.PLATFORM_NAME}/{workspace.name}"], ) browse_path_mcp = self.new_mcp( entity_urn=dashboard_urn, @@ -1224,7 +1236,7 @@ def __init__(self, config: PowerBiDashboardSourceConfig, ctx: PipelineContext): self.source_config = config self.reporter = PowerBiDashboardSourceReport() self.dataplatform_instance_resolver = create_dataplatform_instance_resolver( - self.source_config + self.source_config, ) try: self.powerbi_client = PowerBiAPI( @@ -1234,17 +1246,22 @@ def __init__(self, config: PowerBiDashboardSourceConfig, ctx: PipelineContext): except Exception as e: logger.warning(e) exit( - 1 + 1, ) # Exit pipeline as we are not able to connect to PowerBI API Service. This exit will avoid raising # unwanted stacktrace on console self.mapper = Mapper( - ctx, config, self.reporter, self.dataplatform_instance_resolver + ctx, + config, + self.reporter, + self.dataplatform_instance_resolver, ) # Create and register the stateful ingestion use-case handler. 
self.stale_entity_removal_handler = StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ) @staticmethod @@ -1258,7 +1275,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @@ -1272,19 +1290,19 @@ def get_allowed_workspaces(self) -> List[powerbi_data_classes.Workspace]: logger.info(f"Number of workspaces = {len(all_workspaces)}") self.reporter.all_workspace_count = len(all_workspaces) logger.debug( - f"All workspaces: {[workspace.format_name_for_logger() for workspace in all_workspaces]}" + f"All workspaces: {[workspace.format_name_for_logger() for workspace in all_workspaces]}", ) allowed_workspaces = [] for workspace in all_workspaces: if not self.source_config.workspace_id_pattern.allowed(workspace.id): self.reporter.filtered_workspace_names.append( - f"{workspace.id} - {workspace.name}" + f"{workspace.id} - {workspace.name}", ) continue elif workspace.type not in self.source_config.workspace_type_filter: self.reporter.filtered_workspace_types.append( - f"{workspace.id} - {workspace.name} (type = {workspace.type})" + f"{workspace.id} - {workspace.name} (type = {workspace.type})", ) continue else: @@ -1292,7 +1310,7 @@ def get_allowed_workspaces(self) -> List[powerbi_data_classes.Workspace]: logger.info(f"Number of allowed workspaces = {len(allowed_workspaces)}") logger.debug( - f"Allowed workspaces: {[workspace.format_name_for_logger() for workspace in allowed_workspaces]}" + f"Allowed workspaces: {[workspace.format_name_for_logger() for workspace in allowed_workspaces]}", ) return allowed_workspaces @@ -1308,11 +1326,12 @@ def validate_dataset_type_mapping(self): raise ValueError(f"PowerBI DataPlatform {key} is not supported") logger.debug( - f"Dataset lineage would get ingested for data-platform = {self.source_config.dataset_type_mapping}" + f"Dataset lineage would get ingested for data-platform = {self.source_config.dataset_type_mapping}", ) def extract_independent_datasets( - self, workspace: powerbi_data_classes.Workspace + self, + workspace: powerbi_data_classes.Workspace, ) -> Iterable[MetadataWorkUnit]: if self.source_config.extract_independent_datasets is False: if workspace.independent_datasets: @@ -1324,7 +1343,7 @@ def extract_independent_datasets( dataset.name for dataset in workspace.independent_datasets if dataset.name - ] + ], ), ) return @@ -1334,11 +1353,12 @@ def extract_independent_datasets( stream=self.mapper.to_datahub_dataset( dataset=dataset, workspace=workspace, - ) + ), ) def emit_app( - self, workspace: powerbi_data_classes.Workspace + self, + workspace: powerbi_data_classes.Workspace, ) -> Iterable[MetadataChangeProposalWrapper]: if workspace.app is None: return @@ -1357,9 +1377,9 @@ def emit_app( platform=self.source_config.platform_name, platform_instance=self.source_config.platform_instance, name=powerbi_data_classes.Dashboard.get_urn_part_by_id( - app_dashboard.original_dashboard_id + app_dashboard.original_dashboard_id, ), - ) + ), ) for app_dashboard in workspace.app.dashboards ] @@ -1371,17 +1391,17 @@ def emit_app( platform=self.source_config.platform_name, platform_instance=self.source_config.platform_instance, name=powerbi_data_classes.Report.get_urn_part_by_id( - app_report.original_report_id + app_report.original_report_id, ), - ) + ), 
) for app_report in workspace.app.reports - ] + ], ) if assets_within_app: logger.debug( - f"Emitting metadata-workunits for app {workspace.app.name}({workspace.app.id})" + f"Emitting metadata-workunits for app {workspace.app.name}({workspace.app.id})", ) app_urn: str = builder.make_dashboard_urn( @@ -1401,19 +1421,20 @@ def emit_app( actor="urn:li:corpuser:unknown", time=int( datetime.strptime( - workspace.app.last_update, "%Y-%m-%dT%H:%M:%S.%fZ" - ).timestamp() + workspace.app.last_update, + "%Y-%m-%dT%H:%M:%S.%fZ", + ).timestamp(), ), ) if workspace.app.last_update - else None + else None, ), dashboards=assets_within_app, ) # Browse path browse_path: BrowsePathsClass = BrowsePathsClass( - paths=[f"/powerbi/{workspace.name}"] + paths=[f"/powerbi/{workspace.name}"], ) yield from MetadataChangeProposalWrapper.construct_many( @@ -1427,11 +1448,12 @@ def emit_app( ) def get_workspace_workunit( - self, workspace: powerbi_data_classes.Workspace + self, + workspace: powerbi_data_classes.Workspace, ) -> Iterable[MetadataWorkUnit]: if self.source_config.extract_workspaces_to_containers: workspace_workunits = self.mapper.generate_container_for_workspace( - workspace + workspace, ) for workunit in workspace_workunits: @@ -1461,7 +1483,8 @@ def get_workspace_workunit( for report in workspace.reports: for work_unit in self.mapper.report_to_datahub_work_units( - report, workspace + report, + workspace, ): wu = self._get_dashboard_patch_work_unit(work_unit) if wu is not None: @@ -1470,7 +1493,8 @@ def get_workspace_workunit( yield from self.extract_independent_datasets(workspace) def _get_dashboard_patch_work_unit( - self, work_unit: MetadataWorkUnit + self, + work_unit: MetadataWorkUnit, ) -> Optional[MetadataWorkUnit]: dashboard_info_aspect: Optional[DashboardInfoClass] = ( work_unit.get_aspect_of_type(DashboardInfoClass) @@ -1508,11 +1532,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: allowed_workspaces = self.get_allowed_workspaces() batches = more_itertools.chunked( - allowed_workspaces, self.source_config.scan_batch_size + allowed_workspaces, + self.source_config.scan_batch_size, ) for batch_workspaces in batches: for workspace in self.powerbi_client.fill_workspaces( - batch_workspaces, self.reporter + batch_workspaces, + self.reporter, ): logger.info(f"Processing workspace id: {workspace.id}") @@ -1522,7 +1548,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # Refer to https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py#L390 self.stale_entity_removal_handler.set_job_id(workspace.id) self.state_provider.register_stateful_ingestion_usecase_handler( - self.stale_entity_removal_handler + self.stale_entity_removal_handler, ) yield from self._apply_workunit_processors( diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_classes.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_classes.py index 70e4ee68f53515..d6d9f2deb84800 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_classes.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_classes.py @@ -15,7 +15,11 @@ FIELD_TYPE_MAPPING: Dict[ str, Union[ - BooleanTypeClass, DateTypeClass, NullTypeClass, NumberTypeClass, StringTypeClass + BooleanTypeClass, + DateTypeClass, + NullTypeClass, + NumberTypeClass, + StringTypeClass, ], ] = { "Int64": NumberTypeClass(), @@ -133,7 +137,11 @@ 
class Column: dataType: str isHidden: bool datahubDataType: Union[ - BooleanTypeClass, DateTypeClass, NullTypeClass, NumberTypeClass, StringTypeClass + BooleanTypeClass, + DateTypeClass, + NullTypeClass, + NumberTypeClass, + StringTypeClass, ] columnType: Optional[str] = None expression: Optional[str] = None @@ -148,7 +156,11 @@ class Measure: isHidden: bool dataType: str = "measure" datahubDataType: Union[ - BooleanTypeClass, DateTypeClass, NullTypeClass, NumberTypeClass, StringTypeClass + BooleanTypeClass, + DateTypeClass, + NullTypeClass, + NumberTypeClass, + StringTypeClass, ] = dataclasses.field(default_factory=NullTypeClass) description: Optional[str] = None measure_profile: Optional[MeasureProfile] = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py index 927840c44bf0b0..9e862ca0191085 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py @@ -111,7 +111,7 @@ def __init__( backoff_factor=1, allowed_methods=None, status_forcelist=[429, 500, 502, 503, 504], - ) + ), ), ) @@ -147,13 +147,17 @@ def profile_dataset( @abstractmethod def get_dataset( - self, workspace: Workspace, dataset_id: str + self, + workspace: Workspace, + dataset_id: str, ) -> Optional[PowerBIDataset]: pass @abstractmethod def get_dataset_parameters( - self, workspace_id: str, dataset_id: str + self, + workspace_id: str, + dataset_id: str, ) -> Dict[str, str]: pass @@ -181,27 +185,27 @@ def get_access_token(self) -> str: logger.info("Generating PowerBi access token") auth_response = self._msal_client.acquire_token_for_client( - scopes=[DataResolverBase.SCOPE] + scopes=[DataResolverBase.SCOPE], ) if not auth_response.get(Constant.ACCESS_TOKEN): logger.warning( - "Failed to generate the PowerBi access token. Please check input configuration" + "Failed to generate the PowerBi access token. Please check input configuration", ) raise ConfigurationError( - "Failed to retrieve access token for PowerBI principal. Please verify your configuration values" + "Failed to retrieve access token for PowerBI principal. 
Please verify your configuration values", ) logger.info("Generated PowerBi access token") self._access_token = "Bearer {}".format( - auth_response.get(Constant.ACCESS_TOKEN) + auth_response.get(Constant.ACCESS_TOKEN), ) safety_gap = 300 self._access_token_expiry_time = datetime.now() + timedelta( seconds=( max(auth_response.get(Constant.ACCESS_TOKEN_EXPIRY, 0) - safety_gap, 0) - ) + ), ) logger.debug(f"{Constant.PBIAccessToken}={self._access_token}") @@ -273,7 +277,9 @@ def get_groups(self, filter_: Dict) -> List[dict]: return output def get_reports( - self, workspace: Workspace, _filter: Optional[str] = None + self, + workspace: Workspace, + _filter: Optional[str] = None, ) -> List[Report]: reports_endpoint = self.get_reports_endpoint(workspace) # Hit PowerBi @@ -302,7 +308,8 @@ def fetch_reports(): embedUrl=raw_instance.get(Constant.EMBED_URL), description=raw_instance.get(Constant.DESCRIPTION, ""), pages=self._get_pages_by_report( - workspace=workspace, report_id=raw_instance[Constant.ID] + workspace=workspace, + report_id=raw_instance[Constant.ID], ), dataset_id=raw_instance.get(Constant.DATASET_ID), users=[], # It will be fetched using Admin Fetcher based on condition @@ -321,7 +328,8 @@ def fetch_reports(): def get_report(self, workspace: Workspace, report_id: str) -> Optional[Report]: reports: List[Report] = self.get_reports( - workspace, _filter=f"id eq '{report_id}'" + workspace, + _filter=f"id eq '{report_id}'", ) if len(reports) == 0: @@ -373,7 +381,8 @@ def new_dataset_or_report(tile_instance: Any) -> dict: return report_fields tile_list_endpoint: str = self.get_tiles_endpoint( - workspace, dashboard_id=dashboard.id + workspace, + dashboard_id=dashboard.id, ) # Hit PowerBi logger.debug(f"Request to URL={tile_list_endpoint}") @@ -482,7 +491,9 @@ class RegularAPIResolver(DataResolverBase): } def get_dataset( - self, workspace: Workspace, dataset_id: str + self, + workspace: Workspace, + dataset_id: str, ) -> Optional[PowerBIDataset]: """ Fetch the dataset from PowerBi for the given dataset identifier @@ -517,7 +528,9 @@ def get_dataset( return new_powerbi_dataset(workspace, response_dict) def get_dataset_parameters( - self, workspace_id: str, dataset_id: str + self, + workspace_id: str, + dataset_id: str, ) -> Dict[str, str]: dataset_get_endpoint: str = RegularAPIResolver.API_ENDPOINTS[ Constant.DATASET_GET @@ -555,13 +568,15 @@ def get_dashboards_endpoint(self, workspace: Workspace) -> str: ] # Replace place holders return dashboards_endpoint.format( - POWERBI_BASE_URL=DataResolverBase.BASE_URL, WORKSPACE_ID=workspace.id + POWERBI_BASE_URL=DataResolverBase.BASE_URL, + WORKSPACE_ID=workspace.id, ) def get_reports_endpoint(self, workspace: Workspace) -> str: reports_endpoint: str = self.API_ENDPOINTS[Constant.REPORT_LIST] return reports_endpoint.format( - POWERBI_BASE_URL=DataResolverBase.BASE_URL, WORKSPACE_ID=workspace.id + POWERBI_BASE_URL=DataResolverBase.BASE_URL, + WORKSPACE_ID=workspace.id, ) def get_tiles_endpoint(self, workspace: Workspace, dashboard_id: str) -> str: @@ -595,7 +610,8 @@ def _get_pages_by_report(self, workspace: Workspace, report_id: str) -> List[Pag return [ Page( id="{}.{}".format( - report_id, raw_instance[Constant.NAME].replace(" ", "_") + report_id, + raw_instance[Constant.NAME].replace(" ", "_"), ), name=raw_instance[Constant.NAME], displayName=raw_instance.get(Constant.DISPLAY_NAME), @@ -626,7 +642,7 @@ def _execute_profiling_query(self, dataset: PowerBIDataset, query: str) -> dict: "queries": [ { "query": query, - } + }, ], "serializerSettings": { 
"includeNulls": True, @@ -654,7 +670,7 @@ def _get_row_count(self, dataset: PowerBIDataset, table: Table) -> int: ) except (KeyError, IndexError) as ex: logger.warning( - f"Profiling failed for getting row count for dataset {dataset.id}, with {ex}" + f"Profiling failed for getting row count for dataset {dataset.id}, with {ex}", ) return 0 @@ -670,12 +686,15 @@ def _get_data_sample(self, dataset: PowerBIDataset, table: Table) -> dict: ) except (KeyError, IndexError) as ex: logger.warning( - f"Getting sample with TopN failed for dataset {dataset.id}, with {ex}" + f"Getting sample with TopN failed for dataset {dataset.id}, with {ex}", ) return {} def _get_column_data( - self, dataset: PowerBIDataset, table: Table, column: Union[Column, Measure] + self, + dataset: PowerBIDataset, + table: Table, + column: Union[Column, Measure], ) -> dict: try: logger.debug(f"Column data query for {dataset.name}, {column.name}") @@ -689,7 +708,7 @@ def _get_column_data( ) except (KeyError, IndexError) as ex: logger.warning( - f"Getting column statistics failed for dataset {dataset.name}, {column.name}, with {ex}" + f"Getting column statistics failed for dataset {dataset.name}, {column.name}, with {ex}", ) return {} @@ -706,7 +725,7 @@ def profile_dataset( if not profile_pattern.allowed(f"{workspace_name}.{dataset.name}.{table.name}"): logger.info( - f"Table {table.name} in {dataset.name}, not allowed for profiling" + f"Table {table.name} in {dataset.name}, not allowed for profiling", ) return @@ -729,7 +748,8 @@ def profile_dataset( column_stats = self._get_column_data(dataset, table, column) column.measure_profile = MeasureProfile( - sample_values=column_sample, **column_stats + sample_values=column_sample, + **column_stats, ) column_count += 1 @@ -768,7 +788,7 @@ def create_scan_job(self, workspace_ids: List[str]) -> str: scan_create_endpoint = AdminAPIResolver.API_ENDPOINTS[Constant.SCAN_CREATE] scan_create_endpoint = scan_create_endpoint.format( - POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL + POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL, ) logger.debug( @@ -800,7 +820,7 @@ def create_scan_job(self, workspace_ids: List[str]) -> str: def _calculate_max_retry(minimum_sleep: int, timeout: int) -> int: if timeout < minimum_sleep: logger.info( - f"Setting timeout to minimum_sleep time {minimum_sleep} seconds" + f"Setting timeout to minimum_sleep time {minimum_sleep} seconds", ) timeout = minimum_sleep @@ -833,12 +853,12 @@ def _is_scan_result_ready( if retry == max_retry: logger.warning( "Max retry reached when polling for scan job (lineage) result. Scan job is not " - "available! Try increasing your max retry using config option scan_timeout" + "available! Try increasing your max retry using config option scan_timeout", ) break logger.debug( - f"Waiting to check for scan job completion for {minimum_sleep_seconds} seconds." 
+ f"Waiting to check for scan job completion for {minimum_sleep_seconds} seconds.", ) sleep(minimum_sleep_seconds) retry += 1 @@ -851,13 +871,15 @@ def wait_for_scan_to_complete(self, scan_id: str, timeout: int) -> Any: """ minimum_sleep_seconds = 3 max_retry: int = AdminAPIResolver._calculate_max_retry( - minimum_sleep_seconds, timeout + minimum_sleep_seconds, + timeout, ) # logger.info(f"Max trial {max_retry}") scan_get_endpoint = AdminAPIResolver.API_ENDPOINTS[Constant.SCAN_GET] scan_get_endpoint = scan_get_endpoint.format( - POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL, SCAN_ID=scan_id + POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL, + SCAN_ID=scan_id, ) return self._is_scan_result_ready( @@ -904,7 +926,7 @@ def get_users(self, workspace_id: str, entity: str, entity_id: str) -> List[User datasetUserAccessRight=instance.get(Constant.DATASET_USER_ACCESS_RIGHT), reportUserAccessRight=instance.get(Constant.REPORT_USER_ACCESS_RIGHT), dashboardUserAccessRight=instance.get( - Constant.DASHBOARD_USER_ACCESS_RIGHT + Constant.DASHBOARD_USER_ACCESS_RIGHT, ), groupUserAccessRight=instance.get(Constant.GROUP_USER_ACCESS_RIGHT), ) @@ -920,7 +942,8 @@ def get_scan_result(self, scan_id: str) -> Optional[dict]: Constant.SCAN_RESULT_GET ] scan_result_get_endpoint = scan_result_get_endpoint.format( - POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL, SCAN_ID=scan_id + POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL, + SCAN_ID=scan_id, ) logger.debug(f"Hitting URL={scan_result_get_endpoint}") @@ -938,7 +961,7 @@ def get_scan_result(self, scan_id: str) -> Optional[dict]: or len(res.json().get("workspaces")) == 0 ): logger.warning( - f"Scan result is not available for scan identifier = {scan_id}" + f"Scan result is not available for scan identifier = {scan_id}", ) return None @@ -971,7 +994,9 @@ def get_tiles_endpoint(self, workspace: Workspace, dashboard_id: str) -> str: ) def get_dataset( - self, workspace: Workspace, dataset_id: str + self, + workspace: Workspace, + dataset_id: str, ) -> Optional[PowerBIDataset]: datasets_endpoint = self.API_ENDPOINTS[Constant.DATASET_LIST].format( POWERBI_ADMIN_BASE_URL=DataResolverBase.ADMIN_BASE_URL, @@ -1029,11 +1054,11 @@ def get_modified_workspaces(self, modified_since: str) -> List[str]: and error_msg_json["error"]["code"] == "InvalidRequest" ): raise ConfigurationError( - "Please check if modified_since is within last 30 days." 
+ "Please check if modified_since is within last 30 days.", ) else: raise ConfigurationError( - f"Please resolve the following error: {res.text}" + f"Please resolve the following error: {res.text}", ) res.raise_for_status() @@ -1043,7 +1068,9 @@ def get_modified_workspaces(self, modified_since: str) -> List[str]: return workspace_ids def get_dataset_parameters( - self, workspace_id: str, dataset_id: str + self, + workspace_id: str, + dataset_id: str, ) -> Dict[str, str]: logger.debug("Get dataset parameter is unsupported in Admin API") return {} diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/powerbi_api.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/powerbi_api.py index b41be19d0de53e..1c4619ed5720e4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/powerbi_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/powerbi_api.py @@ -41,7 +41,8 @@ def form_full_table_name( table_name: str, ) -> str: full_table_name: str = "{}.{}".format( - dataset_name.replace(" ", "_"), table_name.replace(" ", "_") + dataset_name.replace(" ", "_"), + table_name.replace(" ", "_"), ) if config.include_workspace_name_in_dataset_urn: @@ -105,7 +106,8 @@ def log_http_error(self, message: str) -> Any: return e def _get_dashboard_endorsements( - self, scan_result: Optional[dict] + self, + scan_result: Optional[dict], ) -> Dict[str, List[str]]: """ Store saved dashboard endorsements into a dict with dashboard id as key and @@ -119,14 +121,15 @@ def _get_dashboard_endorsements( # Iterate through response and create a list of PowerBiAPI.Dashboard dashboard_id = scanned_dashboard.get("id") tags = self._parse_endorsement( - scanned_dashboard.get(Constant.ENDORSEMENT_DETAIL, None) + scanned_dashboard.get(Constant.ENDORSEMENT_DETAIL, None), ) results[dashboard_id] = tags return results def _get_report_endorsements( - self, scan_result: Optional[dict] + self, + scan_result: Optional[dict], ) -> Dict[str, List[str]]: results: Dict[str, List[str]] = {} @@ -139,11 +142,11 @@ def _get_report_endorsements( report_id = report.get(Constant.ID, None) if report_id is None: logger.warning( - f"Report id is none. Skipping endorsement tag for report instance {report}" + f"Report id is none. Skipping endorsement tag for report instance {report}", ) continue endorsements = self._parse_endorsement( - report.get(Constant.ENDORSEMENT_DETAIL, None) + report.get(Constant.ENDORSEMENT_DETAIL, None), ) results[report_id] = endorsements @@ -155,7 +158,10 @@ def _get_resolver(self): return self.__regular_api_resolver def _get_entity_users( - self, workspace_id: str, entity_name: str, entity_id: str + self, + workspace_id: str, + entity_name: str, + entity_id: str, ) -> List[User]: """ Return list of dashboard users @@ -163,7 +169,7 @@ def _get_entity_users( users: List[User] = [] if self.__config.extract_ownership is False: logger.info( - "Extract ownership capabilities is disabled from configuration and hence returning empty users list" + "Extract ownership capabilities is disabled from configuration and hence returning empty users list", ) return users @@ -175,7 +181,7 @@ def _get_entity_users( ) except Exception: e = self.log_http_error( - message=f"Unable to fetch users for {entity_name}({entity_id})." 
+ message=f"Unable to fetch users for {entity_name}({entity_id}).", ) if data_resolver.is_permission_error(cast(Exception, e)): logger.warning( @@ -187,7 +193,9 @@ def _get_entity_users( def get_dashboard_users(self, dashboard: Dashboard) -> List[User]: return self._get_entity_users( - dashboard.workspace_id, Constant.DASHBOARDS, dashboard.id + dashboard.workspace_id, + Constant.DASHBOARDS, + dashboard.id, ) def get_report_users(self, workspace_id: str, report_id: str) -> List[User]: @@ -212,25 +220,26 @@ def get_reports(self, workspace: Workspace) -> List[Report]: ) except Exception: self.log_http_error( - message=f"Unable to fetch reports for workspace {workspace.name}" + message=f"Unable to fetch reports for workspace {workspace.name}", ) def fill_ownership() -> None: if self.__config.extract_ownership is False: logger.info( - "Skipping user retrieval for report as extract_ownership is set to false" + "Skipping user retrieval for report as extract_ownership is set to false", ) return for report in reports: report.users = self.get_report_users( - workspace_id=workspace.id, report_id=report.id + workspace_id=workspace.id, + report_id=report.id, ) def fill_tags() -> None: if self.__config.extract_endorsements_to_tags is False: logger.info( - "Skipping endorsements tags retrieval for report as extract_endorsements_to_tags is set to false" + "Skipping endorsements tags retrieval for report as extract_endorsements_to_tags is set to false", ) return @@ -290,7 +299,7 @@ def get_modified_workspaces(self) -> List[str]: try: modified_workspace_ids = self.__admin_api_resolver.get_modified_workspaces( - self.__config.modified_since + self.__config.modified_since, ) except Exception: self.log_http_error(message="Unable to fetch list of modified workspaces.") @@ -301,27 +310,28 @@ def _get_scan_result(self, workspace_ids: List[str]) -> Any: scan_id: Optional[str] = None try: scan_id = self.__admin_api_resolver.create_scan_job( - workspace_ids=workspace_ids + workspace_ids=workspace_ids, ) except Exception: e = self.log_http_error(message="Unable to fetch get scan result.") if data_resolver.is_permission_error(cast(Exception, e)): logger.warning( "Dataset lineage can not be ingestion because this user does not have access to the PowerBI Admin " - "API. " + "API. ", ) return None logger.debug("Waiting for scan to complete") if ( self.__admin_api_resolver.wait_for_scan_to_complete( - scan_id=scan_id, timeout=self.__config.scan_timeout + scan_id=scan_id, + timeout=self.__config.scan_timeout, ) is False ): raise ValueError( "Workspace detail is not available. Please increase the scan_timeout configuration value to wait " - "longer for the scan job to complete." 
+ "longer for the scan job to complete.", ) # Scan is complete lets take the result @@ -355,7 +365,7 @@ def _get_workspace_datasets(self, workspace: Workspace) -> dict: datasets: Optional[Any] = scan_result.get(Constant.DATASETS) if datasets is None or len(datasets) == 0: logger.warning( - f"Workspace {scan_result[Constant.NAME]}({scan_result[Constant.ID]}) does not have datasets" + f"Workspace {scan_result[Constant.NAME]}({scan_result[Constant.ID]}) does not have datasets", ) logger.info("Returning empty datasets") @@ -393,7 +403,7 @@ def _get_workspace_datasets(self, workspace: Workspace) -> dict: if self.__config.extract_endorsements_to_tags: dataset_instance.tags = self._parse_endorsement( - dataset_dict.get(Constant.ENDORSEMENT_DETAIL, None) + dataset_dict.get(Constant.ENDORSEMENT_DETAIL, None), ) dataset_map[dataset_instance.id] = dataset_instance @@ -424,7 +434,8 @@ def _get_workspace_datasets(self, workspace: Workspace) -> dict: Column( **column, datahubDataType=FIELD_TYPE_MAPPING.get( - column["dataType"], FIELD_TYPE_MAPPING["Null"] + column["dataType"], + FIELD_TYPE_MAPPING["Null"], ), ) for column in table.get("columns", []) @@ -455,7 +466,9 @@ def get_app( ) def _populate_app_details( - self, workspace: Workspace, workspace_metadata: Dict + self, + workspace: Workspace, + workspace_metadata: Dict, ) -> None: # App_id is not present at the root level of workspace_metadata. # It can be found in the workspace_metadata.dashboards or workspace_metadata.reports lists. @@ -474,7 +487,7 @@ def _populate_app_details( AppReport( id=report[Constant.ID], original_report_id=report[Constant.ORIGINAL_REPORT_OBJECT_ID], - ) + ), ) if app_id is None: # In PowerBI one workspace can have one app app_id = report.get(Constant.APP_ID) @@ -520,7 +533,7 @@ def _populate_app_details( Constant.ID ], original_dashboard_id=dashboard[Constant.ID], - ) + ), ) app.reports = app_reports @@ -576,14 +589,14 @@ def _fill_metadata_from_scan_result( # Fetch endorsement tag if it is enabled from configuration if self.__config.extract_endorsements_to_tags: cur_workspace.dashboard_endorsements = self._get_dashboard_endorsements( - cur_workspace.scan_result + cur_workspace.scan_result, ) cur_workspace.report_endorsements = self._get_report_endorsements( - cur_workspace.scan_result + cur_workspace.scan_result, ) else: logger.info( - "Skipping endorsements tag as extract_endorsements_to_tags is not enabled" + "Skipping endorsements tag as extract_endorsements_to_tags is not enabled", ) self._populate_app_details( @@ -617,7 +630,8 @@ def fill_dashboards() -> None: # set tiles of Dashboard for dashboard in workspace.dashboards: dashboard.tiles = self._get_resolver().get_tiles( - workspace, dashboard=dashboard + workspace, + dashboard=dashboard, ) # set the dataset for tiles for tile in dashboard.tiles: @@ -633,7 +647,7 @@ def fill_dashboards() -> None: def fill_reports() -> None: if self.__config.extract_reports is False: logger.info( - "Skipping report retrieval as extract_reports is set to false" + "Skipping report retrieval as extract_reports is set to false", ) return workspace.reports = self.get_reports(workspace) @@ -641,7 +655,7 @@ def fill_reports() -> None: def fill_dashboard_tags() -> None: if self.__config.extract_endorsements_to_tags is False: logger.info( - "Skipping tag retrieval for dashboard as extract_endorsements_to_tags is set to false" + "Skipping tag retrieval for dashboard as extract_endorsements_to_tags is set to false", ) return for dashboard in workspace.dashboards: @@ -656,10 +670,12 @@ def 
fill_dashboard_tags() -> None: # flake8: noqa: C901 def fill_workspaces( - self, workspaces: List[Workspace], reporter: PowerBiDashboardSourceReport + self, + workspaces: List[Workspace], + reporter: PowerBiDashboardSourceReport, ) -> Iterable[Workspace]: logger.info( - f"Fetching initial metadata for workspaces: {[workspace.format_name_for_logger() for workspace in workspaces]}" + f"Fetching initial metadata for workspaces: {[workspace.format_name_for_logger() for workspace in workspaces]}", ) workspaces = self._fill_metadata_from_scan_result(workspaces=workspaces) diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py index 4764400215e12a..bfd08b14c274d7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py @@ -61,23 +61,27 @@ class PowerBiReportServerAPIConfig(EnvConfigMixin): username: str = pydantic.Field(description="Windows account username") password: str = pydantic.Field(description="Windows account password") workstation_name: str = pydantic.Field( - default="localhost", description="Workstation name" + default="localhost", + description="Workstation name", ) host_port: str = pydantic.Field(description="Power BI Report Server host URL") server_alias: str = pydantic.Field( - default="", description="Alias for Power BI Report Server host URL" + default="", + description="Alias for Power BI Report Server host URL", ) graphql_url: Optional[str] = pydantic.Field( - default=None, description="[deprecated] Not used" + default=None, + description="[deprecated] Not used", ) report_virtual_directory_name: str = pydantic.Field( - description="Report Virtual Directory URL name" + description="Report Virtual Directory URL name", ) report_server_virtual_directory_name: str = pydantic.Field( - description="Report Server Virtual Directory URL name" + description="Report Server Virtual Directory URL name", ) extract_ownership: bool = pydantic.Field( - default=True, description="Whether ownership should be ingested" + default=True, + description="Whether ownership should be ingested", ) ownership_type: str = pydantic.Field( default=OwnershipTypeClass.NONE, @@ -87,19 +91,22 @@ class PowerBiReportServerAPIConfig(EnvConfigMixin): @property def get_base_api_http_url(self): return "http://{}/{}/api/v2.0".format( - self.host_port, self.report_virtual_directory_name + self.host_port, + self.report_virtual_directory_name, ) @property def get_base_api_https_url(self): return "https://{}/{}/api/v2.0".format( - self.host_port, self.report_virtual_directory_name + self.host_port, + self.report_virtual_directory_name, ) @property def get_base_url(self): return "http://{}/{}/".format( - self.host_port, self.report_virtual_directory_name + self.host_port, + self.report_virtual_directory_name, ) @property @@ -250,7 +257,8 @@ def new_mcp( ) def __to_work_unit( - self, mcp: MetadataChangeProposalWrapper + self, + mcp: MetadataChangeProposalWrapper, ) -> EquableMetadataWorkUnit: return Mapper.EquableMetadataWorkUnit( id="{PLATFORM}-{ENTITY_URN}-{ASPECT_NAME}".format( @@ -268,7 +276,7 @@ def to_urn_set(mcps: List[MetadataChangeProposalWrapper]) -> List[str]: mcp.entityUrn for mcp in mcps if mcp is not None and mcp.entityUrn is not None - ] + ], ) def to_ownership_set( @@ -282,7 +290,7 @@ def to_ownership_set( for mcp in mcps: if mcp is not None and 
mcp.entityUrn is not None: ownership.append( - Owner(owner=mcp.entityUrn, type=self.__config.ownership_type) + Owner(owner=mcp.entityUrn, type=self.__config.ownership_type), ) return deduplicate_list(ownership) @@ -296,12 +304,14 @@ def __to_datahub_dashboard( Map PowerBI Report Server report to Datahub Dashboard """ dashboard_urn = builder.make_dashboard_urn( - self.__config.platform_name, report.get_urn_part() + self.__config.platform_name, + report.get_urn_part(), ) chart_urn_list: List[str] = self.to_urn_set(chart_mcps) user_urn_list: List[Owner] = self.to_ownership_set( - mcps=user_mcps, existing_owners=report.user_info.existing_owners + mcps=user_mcps, + existing_owners=report.user_info.existing_owners, ) def custom_properties( @@ -381,8 +391,8 @@ def custom_properties( self.__config.host, self.__config.env, self.__config.report_virtual_directory_name, - ) - ] + ), + ], ) browse_path_mcp = self.new_mcp( entity_type=Constant.DASHBOARD, @@ -517,7 +527,9 @@ class PowerBiReportServerDashboardSource(Source): accessed_dashboards: int = 0 def __init__( - self, config: PowerBiReportServerDashboardSourceConfig, ctx: PipelineContext + self, + config: PowerBiReportServerDashboardSourceConfig, + ctx: PipelineContext, ): super().__init__(ctx) self.source_config = config @@ -545,7 +557,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: report.user_info = self.get_user_info(report) except pydantic.ValidationError as e: message = "Error ({}) occurred while loading User {}(id={})".format( - e, report.name, report.id + e, + report.name, + report.id, ) LOGGER.exception(message, e) self.report.report_warning(report.id, message) @@ -561,7 +575,8 @@ def get_user_info(self, report: Any) -> OwnershipData: if not self.source_config.extract_ownership: return OwnershipData(existing_owners=[], owner_to_add=None) dashboard_urn = builder.make_dashboard_urn( - self.source_config.platform_name, report.get_urn_part() + self.source_config.platform_name, + report.get_urn_part(), ) user_urn = builder.make_user_urn(report.display_name) @@ -570,10 +585,12 @@ def get_user_info(self, report: Any) -> OwnershipData: if ownership: existing_ownership = ownership.owners if self.ctx.graph.get_aspect_v2( - entity_urn=user_urn, aspect="corpUserInfo", aspect_type=CorpUserInfoClass + entity_urn=user_urn, + aspect="corpUserInfo", + aspect_type=CorpUserInfoClass, ): existing_ownership.append( - OwnerClass(owner=user_urn, type=self.source_config.ownership_type) + OwnerClass(owner=user_urn, type=self.source_config.ownership_type), ) return OwnershipData(existing_owners=existing_ownership) user_data = dict( @@ -584,7 +601,8 @@ def get_user_info(self, report: Any) -> OwnershipData: ) owner_to_add = CorpUser(**user_data) return OwnershipData( - existing_owners=existing_ownership, owner_to_add=owner_to_add + existing_owners=existing_ownership, + owner_to_add=owner_to_add, ) def get_report(self) -> SourceReport: diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py index b65ae5cd2994cc..9b91dc6b0a0122 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py @@ -29,7 +29,8 @@ class CatalogItem(BaseModel): display_name: Optional[str] = Field(None, alias="DisplayName") has_data_sources: bool = Field(default=False, 
alias="HasDataSources") data_sources: Optional[List["DataSource"]] = Field( - default_factory=list, alias="DataSources" + default_factory=list, + alias="DataSources", ) @validator("display_name", always=True) @@ -45,10 +46,18 @@ def get_web_url(self, base_reports_url: str) -> str: return f"{base_reports_url}powerbi{self.path}" def get_browse_path( - self, base_folder: str, workspace: str, env: str, report_directory: str + self, + base_folder: str, + workspace: str, + env: str, + report_directory: str, ) -> str: return "/{}/{}/{}/{}{}".format( - base_folder, env.lower(), workspace, report_directory, self.path + base_folder, + env.lower(), + workspace, + report_directory, + self.path, ) @@ -118,7 +127,7 @@ class Subscription(BaseModel): extension_settings: ExtensionSettings = Field(alias="ExtensionSettings") delivery_extension: str = Field(alias="DeliveryExtension") localized_delivery_extension_name: str = Field( - alias="LocalizedDeliveryExtensionName" + alias="LocalizedDeliveryExtensionName", ) modified_by: str = Field(alias="ModifiedBy") modified_date: datetime = Field(alias="ModifiedDate") @@ -135,19 +144,22 @@ class DataSource(CatalogItem): is_enabled: bool = Field(alias="IsEnabled") connection_string: str = Field(alias="ConnectionString") data_model_data_source: Optional[DataModelDataSource] = Field( - None, alias="DataModelDataSource" + None, + alias="DataModelDataSource", ) data_source_sub_type: Optional[str] = Field(None, alias="DataSourceSubType") data_source_type: Optional[str] = Field(None, alias="DataSourceType") is_original_connection_string_expression_based: bool = Field( - alias="IsOriginalConnectionStringExpressionBased" + alias="IsOriginalConnectionStringExpressionBased", ) is_connection_string_overridden: bool = Field(alias="IsConnectionStringOverridden") credentials_by_user: Optional[CredentialsByUser] = Field( - None, alias="CredentialsByUser" + None, + alias="CredentialsByUser", ) credentials_in_server: Optional[CredentialsInServer] = Field( - None, alias="CredentialsInServer" + None, + alias="CredentialsInServer", ) is_reference: bool = Field(alias="IsReference") subscriptions: Optional[Subscription] = Field(None, alias="Subscriptions") @@ -330,7 +342,8 @@ class CorpUser(BaseModel): username: str properties: CorpUserProperties editable_properties: Optional[CorpUserEditableProperties] = Field( - None, alias="editableProperties" + None, + alias="editableProperties", ) status: Optional[CorpUserStatus] = None tags: Optional[GlobalTags] = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/preset.py b/metadata-ingestion/src/datahub/ingestion/source/preset.py index 7b0bc89648c529..05cd008115fa3e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/preset.py +++ b/metadata-ingestion/src/datahub/ingestion/source/preset.py @@ -27,7 +27,8 @@ class PresetConfig(SupersetConfig): manager_uri: str = Field( - default="https://api.app.preset.io", description="Preset.io API URL" + default="https://api.app.preset.io", + description="Preset.io API URL", ) connect_uri: str = Field(default="", description="Preset workspace URL.") display_uri: Optional[str] = Field( @@ -39,7 +40,8 @@ class PresetConfig(SupersetConfig): # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="Preset Stateful Ingestion Config." 
+ default=None, + description="Preset Stateful Ingestion Config.", ) options: Dict = Field(default={}, description="") @@ -68,7 +70,8 @@ def default_display_uri_to_connect_uri(cls, values): @config_class(PresetConfig) @support_status(SupportStatus.TESTING) @capability( - SourceCapability.DELETION_DETECTION, "Optionally enabled via stateful_ingestion" + SourceCapability.DELETION_DETECTION, + "Optionally enabled via stateful_ingestion", ) class PresetSource(SupersetSource): """ @@ -106,7 +109,7 @@ def login(self): "Authorization": f"Bearer {self.access_token}", "Content-Type": "application/json", "Accept": "*/*", - } + }, ) # Test the connection test_response = requests_session.get(f"{self.config.connect_uri}/version") diff --git a/metadata-ingestion/src/datahub/ingestion/source/profiling/common.py b/metadata-ingestion/src/datahub/ingestion/source/profiling/common.py index b54f0e02fc1c87..41634ac7ed6133 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/profiling/common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/profiling/common.py @@ -14,7 +14,8 @@ class Cardinality(Enum): def convert_to_cardinality( - unique_count: Optional[int], pct_unique: Optional[float] + unique_count: Optional[int], + pct_unique: Optional[float], ) -> Optional[Cardinality]: """ Resolve the cardinality of a column based on the unique count and the percentage of unique values. diff --git a/metadata-ingestion/src/datahub/ingestion/source/pulsar.py b/metadata-ingestion/src/datahub/ingestion/source/pulsar.py index f71949b9eb27f7..6dc05ef2d56e80 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/pulsar.py +++ b/metadata-ingestion/src/datahub/ingestion/source/pulsar.py @@ -127,7 +127,7 @@ def __init__(self, config: PulsarSourceConfig, ctx: PipelineContext): self.session.headers.update( { "Content-Type": "application/json", - } + }, ) if self._is_oauth_authentication_configured(): @@ -136,20 +136,22 @@ def __init__(self, config: PulsarSourceConfig, ctx: PipelineContext): f"{self.config.issuer_url}/.well-known/openid-configuration" ) oid_config_response = requests.get( - oid_config_url, verify=self.session.verify, allow_redirects=False + oid_config_url, + verify=self.session.verify, + allow_redirects=False, ) if oid_config_response: self.config.oid_config.update(oid_config_response.json()) else: logger.error( - f"Unexpected response while getting discovery document using {oid_config_url} : {oid_config_response}" + f"Unexpected response while getting discovery document using {oid_config_url} : {oid_config_response}", ) if "token_endpoint" not in self.config.oid_config: raise Exception( "The token_endpoint is not set, please verify the configured issuer_url or" - " set oid_config.token_endpoint manually in the configuration file." + " set oid_config.token_endpoint manually in the configuration file.", ) # Authentication configured @@ -159,7 +161,7 @@ def __init__(self, config: PulsarSourceConfig, ctx: PipelineContext): ): # Update session header with Bearer token self.session.headers.update( - {"Authorization": f"Bearer {self.get_access_token()}"} + {"Authorization": f"Bearer {self.get_access_token()}"}, ) def get_access_token(self) -> str: @@ -199,7 +201,7 @@ def get_access_token(self) -> str: # Failed to get an access token, raise ConfigurationError( f"Failed to get the Pulsar access token from token_endpoint {self.config.oid_config.get('token_endpoint')}." - f" Please check your input configuration." 
+ f" Please check your input configuration.", ) def _get_pulsar_metadata(self, url): @@ -229,7 +231,7 @@ def _get_pulsar_metadata(self, url): self.report.report_warning("HTTPError", message) except requests.exceptions.RequestException as e: raise Exception( - f"An ambiguous exception occurred while handling the request: {e}" + f"An ambiguous exception occurred while handling the request: {e}", ) @classmethod @@ -246,7 +248,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -271,8 +275,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # Report the Pulsar broker version we are communicating with self.report.report_pulsar_version( self.session.get( - f"{self.base_url}/brokers/version", timeout=self.config.timeout - ).text + f"{self.base_url}/brokers/version", + timeout=self.config.timeout, + ).text, ) # If no tenants are provided, request all tenants from cluster using /admin/v2/tenants endpoint. @@ -311,7 +316,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # Create a mesh of topics with partitioned values, the # partitioned info is added as a custom properties later topics.update( - {topic: partitioned for topic in pulsar_topics} + {topic: partitioned for topic in pulsar_topics}, ) # For all allowed topics get the metadata @@ -333,7 +338,9 @@ def _is_oauth_authentication_configured(self) -> bool: return self.config.issuer_url is not None def _get_schema_and_fields( - self, pulsar_topic: PulsarTopic, is_key_schema: bool + self, + pulsar_topic: PulsarTopic, + is_key_schema: bool, ) -> Tuple[Optional[PulsarSchema], List[SchemaField]]: pulsar_schema: Optional[PulsarSchema] = None @@ -360,14 +367,18 @@ def _get_schema_and_fields( return pulsar_schema, fields def _get_schema_fields( - self, pulsar_topic: PulsarTopic, schema: PulsarSchema, is_key_schema: bool + self, + pulsar_topic: PulsarTopic, + schema: PulsarSchema, + is_key_schema: bool, ) -> List[SchemaField]: # Parse the schema and convert it to SchemaFields. fields: List[SchemaField] = [] if schema.schema_type in ["AVRO", "JSON"]: # Extract fields from schema and get the FQN for the schema fields = schema_util.avro_schema_to_mce_fields( - schema.schema_str, is_key_schema=is_key_schema + schema.schema_str, + is_key_schema=is_key_schema, ) else: self.report.report_warning( @@ -377,11 +388,14 @@ def _get_schema_fields( return fields def _get_schema_metadata( - self, pulsar_topic: PulsarTopic, platform_urn: str + self, + pulsar_topic: PulsarTopic, + platform_urn: str, ) -> Tuple[Optional[PulsarSchema], Optional[SchemaMetadata]]: # FIXME: Type annotations are not working for this function. schema, fields = self._get_schema_and_fields( - pulsar_topic=pulsar_topic, is_key_schema=False + pulsar_topic=pulsar_topic, + is_key_schema=False, ) # type: Tuple[Optional[PulsarSchema], List[SchemaField]] # Create the schemaMetadata aspect. 
@@ -404,7 +418,9 @@ def _get_schema_metadata( return None, None def _extract_record( - self, topic: str, partitioned: bool + self, + topic: str, + partitioned: bool, ) -> Iterable[MetadataWorkUnit]: logger.info(f"topic = {topic}") @@ -465,7 +481,7 @@ def _extract_record( yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=BrowsePathsClass( - [f"/{self.config.env.lower()}/{self.platform}/{browse_path_suffix}"] + [f"/{self.config.env.lower()}/{self.platform}/{browse_path_suffix}"], ), ).as_workunit() @@ -476,7 +492,8 @@ def _extract_record( aspect=DataPlatformInstanceClass( platform=platform_urn, instance=make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ), ).as_workunit() diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/config.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/config.py index b5547e8a8ae9ec..e8a9d40ba8572a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/config.py @@ -110,7 +110,9 @@ class PlatformDetail(PlatformInstanceConfigMixin, EnvConfigMixin): class QlikSourceConfig( - StatefulIngestionConfigBase, PlatformInstanceConfigMixin, EnvConfigMixin + StatefulIngestionConfigBase, + PlatformInstanceConfigMixin, + EnvConfigMixin, ): tenant_hostname: str = pydantic.Field(description="Qlik Tenant hostname") api_key: str = pydantic.Field(description="Qlik API Key") diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py index a43f5f32493f2d..2606ecd1892eec 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py @@ -90,10 +90,12 @@ class Space(BaseModel): @root_validator(pre=True) def update_values(cls, values: Dict) -> Dict: values[Constant.CREATEDAT] = datetime.strptime( - values[Constant.CREATEDAT], QLIK_DATETIME_FORMAT + values[Constant.CREATEDAT], + QLIK_DATETIME_FORMAT, ) values[Constant.UPDATEDAT] = datetime.strptime( - values[Constant.UPDATEDAT], QLIK_DATETIME_FORMAT + values[Constant.UPDATEDAT], + QLIK_DATETIME_FORMAT, ) return values @@ -132,10 +134,12 @@ class QlikDataset(Item): def update_values(cls, values: Dict) -> Dict: # Update str time to datetime values[Constant.CREATEDAT] = datetime.strptime( - values[Constant.CREATEDTIME], QLIK_DATETIME_FORMAT + values[Constant.CREATEDTIME], + QLIK_DATETIME_FORMAT, ) values[Constant.UPDATEDAT] = datetime.strptime( - values[Constant.LASTMODIFIEDTIME], QLIK_DATETIME_FORMAT + values[Constant.LASTMODIFIEDTIME], + QLIK_DATETIME_FORMAT, ) if not values.get(Constant.SPACEID): # spaceId none indicates dataset present in personal space @@ -182,10 +186,12 @@ class Sheet(BaseModel): @root_validator(pre=True) def update_values(cls, values: Dict) -> Dict: values[Constant.CREATEDAT] = datetime.strptime( - values[Constant.CREATEDDATE], QLIK_DATETIME_FORMAT + values[Constant.CREATEDDATE], + QLIK_DATETIME_FORMAT, ) values[Constant.UPDATEDAT] = datetime.strptime( - values[Constant.MODIFIEDDATE], QLIK_DATETIME_FORMAT + values[Constant.MODIFIEDDATE], + QLIK_DATETIME_FORMAT, ) return values @@ -224,10 +230,12 @@ class App(Item): @root_validator(pre=True) def update_values(cls, values: Dict) -> Dict: values[Constant.CREATEDAT] = datetime.strptime( - values[Constant.CREATEDDATE], QLIK_DATETIME_FORMAT + 
values[Constant.CREATEDDATE], + QLIK_DATETIME_FORMAT, ) values[Constant.UPDATEDAT] = datetime.strptime( - values[Constant.MODIFIEDDATE], QLIK_DATETIME_FORMAT + values[Constant.MODIFIEDDATE], + QLIK_DATETIME_FORMAT, ) if not values.get(Constant.SPACEID): # spaceId none indicates app present in personal space diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py index 10b062c98c147f..728b025795717c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py @@ -32,7 +32,7 @@ def __init__(self, config: QlikSourceConfig) -> None: { "Authorization": f"Bearer {self.config.api_key}", "Content-Type": "application/json", - } + }, ) self.rest_api_url = f"https://{self.config.tenant_hostname}/api/v1" # Test connection by fetching list of api keys @@ -81,7 +81,7 @@ def _get_dataset(self, dataset_id: str, item_id: str) -> Optional[QlikDataset]: return QlikDataset.parse_obj(response_dict) except Exception as e: self._log_http_error( - message=f"Unable to fetch dataset with id {dataset_id}. Exception: {e}" + message=f"Unable to fetch dataset with id {dataset_id}. Exception: {e}", ) return None @@ -98,7 +98,7 @@ def get_user_name(self, user_id: str) -> Optional[str]: return user_name except Exception as e: self._log_http_error( - message=f"Unable to fetch user with id {user_id}. Exception: {e}" + message=f"Unable to fetch user with id {user_id}. Exception: {e}", ) return None @@ -110,19 +110,20 @@ def _get_chart( ) -> Optional[Chart]: try: websocket_connection.websocket_send_request( - method="GetChild", params={"qId": chart_id} + method="GetChild", + params={"qId": chart_id}, ) response = websocket_connection.websocket_send_request(method="GetLayout") q_layout = response[Constant.QLAYOUT] if Constant.HYPERCUBE not in q_layout: logger.warning( - f"Chart with id {chart_id} of sheet {sheet_id} does not have hypercube. q_layout: {q_layout}" + f"Chart with id {chart_id} of sheet {sheet_id} does not have hypercube. q_layout: {q_layout}", ) return None return Chart.parse_obj(q_layout) except Exception as e: self._log_http_error( - message=f"Unable to fetch chart {chart_id} of sheet {sheet_id}. Exception: {e}" + message=f"Unable to fetch chart {chart_id} of sheet {sheet_id}. Exception: {e}", ) return None @@ -133,7 +134,8 @@ def _get_sheet( ) -> Optional[Sheet]: try: websocket_connection.websocket_send_request( - method="GetObject", params={"qId": sheet_id} + method="GetObject", + params={"qId": sheet_id}, ) response = websocket_connection.websocket_send_request(method="GetLayout") sheet_dict = response[Constant.QLAYOUT] @@ -143,11 +145,11 @@ def _get_sheet( sheet = Sheet.parse_obj(sheet_dict[Constant.QMETA]) if Constant.QCHILDLIST not in sheet_dict: logger.warning( - f"Sheet {sheet.title} with id {sheet_id} does not have any charts. sheet_dict: {sheet_dict}" + f"Sheet {sheet.title} with id {sheet_id} does not have any charts. sheet_dict: {sheet_dict}", ) return sheet for i, chart_dict in enumerate( - sheet_dict[Constant.QCHILDLIST][Constant.QITEMS] + sheet_dict[Constant.QCHILDLIST][Constant.QITEMS], ): chart = self._get_chart( websocket_connection, @@ -162,7 +164,7 @@ def _get_sheet( return sheet except Exception as e: self._log_http_error( - message=f"Unable to fetch sheet with id {sheet_id}. Exception: {e}" + message=f"Unable to fetch sheet with id {sheet_id}. 
Exception: {e}", ) return None @@ -171,17 +173,17 @@ def _add_qri_of_tables(self, tables: List[QlikTable], app_id: str) -> None: app_qri = quote(f"qri:app:sense://{app_id}", safe="") try: response = self.session.get( - f"{self.rest_api_url}/lineage-graphs/nodes/{app_qri}/actions/expand?node={app_qri}&level=TABLE" + f"{self.rest_api_url}/lineage-graphs/nodes/{app_qri}/actions/expand?node={app_qri}&level=TABLE", ) response.raise_for_status() for table_node_qri in response.json()[Constant.GRAPH][Constant.NODES]: table_node_qri = quote(table_node_qri, safe="") response = self.session.get( - f"{self.rest_api_url}/lineage-graphs/nodes/{app_qri}/actions/expand?node={table_node_qri}&level=FIELD" + f"{self.rest_api_url}/lineage-graphs/nodes/{app_qri}/actions/expand?node={table_node_qri}&level=FIELD", ) response.raise_for_status() field_nodes_qris = list( - response.json()[Constant.GRAPH][Constant.NODES].keys() + response.json()[Constant.GRAPH][Constant.NODES].keys(), ) for field_node_qri in field_nodes_qris: response = self.session.post( @@ -206,11 +208,13 @@ def _add_qri_of_tables(self, tables: List[QlikTable], app_id: str) -> None: table.tableQri = table_qri_dict[table.tableName] except Exception as e: self._log_http_error( - message=f"Unable to add QRI for tables of app {app_id}. Exception: {e}" + message=f"Unable to add QRI for tables of app {app_id}. Exception: {e}", ) def _get_app_used_tables( - self, websocket_connection: WebsocketConnection, app_id: str + self, + websocket_connection: WebsocketConnection, + app_id: str, ) -> List[QlikTable]: tables: List[QlikTable] = [] try: @@ -227,12 +231,14 @@ def _get_app_used_tables( self._add_qri_of_tables(tables, app_id) except Exception as e: self._log_http_error( - message=f"Unable to fetch tables used by app {app_id}. Exception: {e}" + message=f"Unable to fetch tables used by app {app_id}. Exception: {e}", ) return tables def _get_app_sheets( - self, websocket_connection: WebsocketConnection, app_id: str + self, + websocket_connection: WebsocketConnection, + app_id: str, ) -> List[Sheet]: sheets: List[Sheet] = [] try: @@ -241,7 +247,7 @@ def _get_app_sheets( params={ "qOptions": { "qTypes": ["sheet"], - } + }, }, ) for sheet_dict in response[Constant.QLIST]: @@ -254,21 +260,23 @@ def _get_app_sheets( websocket_connection.handle.pop() except Exception as e: self._log_http_error( - message=f"Unable to fetch sheets for app {app_id}. Exception: {e}" + message=f"Unable to fetch sheets for app {app_id}. Exception: {e}", ) return sheets def _get_app(self, app_id: str) -> Optional[App]: try: websocket_connection = WebsocketConnection( - self.config.tenant_hostname, self.config.api_key, app_id + self.config.tenant_hostname, + self.config.api_key, + app_id, ) websocket_connection.websocket_send_request( method="OpenDoc", params={"qDocName": app_id}, ) response = websocket_connection.websocket_send_request( - method="GetAppLayout" + method="GetAppLayout", ) app = App.parse_obj(response[Constant.QLAYOUT]) app.sheets = self._get_app_sheets(websocket_connection, app_id) @@ -277,7 +285,7 @@ def _get_app(self, app_id: str) -> Optional[App]: return app except Exception as e: self._log_http_error( - message=f"Unable to fetch app with id {app_id}. Exception: {e}" + message=f"Unable to fetch app with id {app_id}. 
Exception: {e}", ) return None @@ -294,7 +302,7 @@ def get_items(self) -> List[Item]: if not item.get(Constant.SPACEID): item[Constant.SPACEID] = Constant.PERSONAL_SPACE_ID if self.config.space_pattern.allowed( - self.spaces[item[Constant.SPACEID]] + self.spaces[item[Constant.SPACEID]], ): resource_type = item[Constant.RESOURCETYPE] if resource_type == Constant.APP: diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py index 8d6eaa4bf10474..e8ff07b0a3ad77 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py @@ -130,7 +130,7 @@ def __init__(self, config: QlikSourceConfig, ctx: PipelineContext): except Exception as e: logger.warning(e) exit( - 1 + 1, ) # Exit pipeline as we are not able to connect to Qlik Client Service. @staticmethod @@ -141,7 +141,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @@ -207,19 +208,22 @@ def _gen_space_workunit(self, space: Space) -> Iterable[MetadataWorkUnit]: def _gen_entity_status_aspect(self, entity_urn: str) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( - entityUrn=entity_urn, aspect=Status(removed=False) + entityUrn=entity_urn, + aspect=Status(removed=False), ).as_workunit() def _gen_entity_owner_aspect( - self, entity_urn: str, user_name: str + self, + entity_urn: str, + user_name: str, ) -> MetadataWorkUnit: aspect = OwnershipClass( owners=[ OwnerClass( owner=builder.make_user_urn(user_name), type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return MetadataChangeProposalWrapper( entityUrn=entity_urn, @@ -234,7 +238,9 @@ def _gen_dashboard_urn(self, dashboard_identifier: str) -> str: ) def _gen_dashboard_info_workunit( - self, sheet: Sheet, app_id: str + self, + sheet: Sheet, + app_id: str, ) -> MetadataWorkUnit: dashboard_urn = self._gen_dashboard_urn(sheet.id) custom_properties: Dict[str, str] = {"chartCount": str(len(sheet.charts))} @@ -257,11 +263,15 @@ def _gen_dashboard_info_workunit( dashboardUrl=f"https://{self.config.tenant_hostname}/sense/app/{app_id}/sheet/{sheet.id}/state/analysis", ) return MetadataChangeProposalWrapper( - entityUrn=dashboard_urn, aspect=dashboard_info_cls + entityUrn=dashboard_urn, + aspect=dashboard_info_cls, ).as_workunit() def _gen_charts_workunit( - self, charts: List[Chart], input_tables: List[QlikTable], app_id: str + self, + charts: List[Chart], + input_tables: List[QlikTable], + app_id: str, ) -> Iterable[MetadataWorkUnit]: """ Map Qlik Chart to Datahub Chart @@ -331,7 +341,9 @@ def _gen_sheets_workunit(self, app: App) -> Iterable[MetadataWorkUnit]: yield from self._gen_charts_workunit(sheet.charts, app.tables, app.id) def _gen_app_table_upstream_lineage( - self, dataset_urn: str, table: QlikTable + self, + dataset_urn: str, + table: QlikTable, ) -> Optional[MetadataWorkUnit]: upstream_dataset_urn: Optional[str] = None if table.type == BoxType.BLACKBOX: @@ -350,14 +362,15 @@ def _gen_app_table_upstream_lineage( query=table.selectStatement.strip(), default_db=None, platform=KNOWN_DATA_PLATFORM_MAPPING.get( - table.dataconnectorPlatform, table.dataconnectorPlatform + table.dataconnectorPlatform, + table.dataconnectorPlatform, ), 
env=upstream_dataset_platform_detail.env, platform_instance=upstream_dataset_platform_detail.platform_instance, ).in_tables[0] elif table.type == BoxType.LOADFILE: upstream_dataset_urn = self._gen_qlik_dataset_urn( - f"{table.spaceId}.{table.databaseName}".lower() + f"{table.spaceId}.{table.databaseName}".lower(), ) if upstream_dataset_urn: @@ -366,11 +379,11 @@ def _gen_app_table_upstream_lineage( FineGrainedLineage( upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=[ - builder.make_schema_field_urn(upstream_dataset_urn, field.name) + builder.make_schema_field_urn(upstream_dataset_urn, field.name), ], downstreamType=FineGrainedLineageDownstreamType.FIELD, downstreams=[ - builder.make_schema_field_urn(dataset_urn, field.name) + builder.make_schema_field_urn(dataset_urn, field.name), ], ) for field in table.datasetSchema @@ -380,8 +393,9 @@ def _gen_app_table_upstream_lineage( aspect=UpstreamLineage( upstreams=[ Upstream( - dataset=upstream_dataset_urn, type=DatasetLineageType.COPY - ) + dataset=upstream_dataset_urn, + type=DatasetLineageType.COPY, + ), ], fineGrainedLineages=fine_grained_lineages, ), @@ -390,7 +404,9 @@ def _gen_app_table_upstream_lineage( return None def _gen_app_table_properties( - self, dataset_urn: str, table: QlikTable + self, + dataset_urn: str, + table: QlikTable, ) -> MetadataWorkUnit: dataset_properties = DatasetProperties( name=table.tableName, @@ -401,10 +417,11 @@ def _gen_app_table_properties( Constant.TYPE: "Qlik Table", Constant.DATACONNECTORID: table.dataconnectorid, Constant.DATACONNECTORNAME: table.dataconnectorName, - } + }, ) return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties + entityUrn=dataset_urn, + aspect=dataset_properties, ).as_workunit() def _get_app_table_identifier(self, table: QlikTable) -> Optional[str]: @@ -413,7 +430,9 @@ def _get_app_table_identifier(self, table: QlikTable) -> Optional[str]: return None def _gen_app_tables_workunit( - self, tables: List[QlikTable], app_id: str + self, + tables: List[QlikTable], + app_id: str, ) -> Iterable[MetadataWorkUnit]: for table in tables: table_identifier = self._get_app_table_identifier(table) @@ -443,7 +462,8 @@ def _gen_app_tables_workunit( ).as_workunit() upstream_lineage_workunit = self._gen_app_table_upstream_lineage( - dataset_urn, table + dataset_urn, + table, ) if upstream_lineage_workunit: yield upstream_lineage_workunit @@ -481,23 +501,27 @@ def _gen_qlik_dataset_urn(self, dataset_identifier: str) -> str: ) def _gen_dataplatform_instance_aspect( - self, entity_urn: str + self, + entity_urn: str, ) -> Optional[MetadataWorkUnit]: if self.config.platform_instance: aspect = DataPlatformInstanceClass( platform=builder.make_data_platform_urn(self.platform), instance=builder.make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ) return MetadataChangeProposalWrapper( - entityUrn=entity_urn, aspect=aspect + entityUrn=entity_urn, + aspect=aspect, ).as_workunit() else: return None def _gen_schema_fields( - self, schema: List[QlikDatasetSchemaField] + self, + schema: List[QlikDatasetSchemaField], ) -> List[SchemaField]: schema_fields: List[SchemaField] = [] for field in schema: @@ -508,7 +532,7 @@ def _gen_schema_fields( FIELD_TYPE_MAPPING.get(field.dataType, NullType)() if field.dataType else NullType() - ) + ), ), nativeDataType=field.dataType if field.dataType else "", nullable=field.nullable, @@ -518,7 +542,9 @@ def _gen_schema_fields( return schema_fields def 
_gen_schema_metadata( - self, dataset_identifier: str, dataset_schema: List[QlikDatasetSchemaField] + self, + dataset_identifier: str, + dataset_schema: List[QlikDatasetSchemaField], ) -> MetadataWorkUnit: dataset_urn = self._gen_qlik_dataset_urn(dataset_identifier) @@ -532,11 +558,14 @@ def _gen_schema_metadata( ) return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=schema_metadata + entityUrn=dataset_urn, + aspect=schema_metadata, ).as_workunit() def _gen_dataset_properties( - self, dataset_urn: str, dataset: QlikDataset + self, + dataset_urn: str, + dataset: QlikDataset, ) -> MetadataWorkUnit: dataset_properties = DatasetProperties( name=dataset.name, @@ -554,11 +583,12 @@ def _gen_dataset_properties( Constant.DATASETTYPE: dataset.type, Constant.SIZE: str(dataset.size), Constant.ROWCOUNT: str(dataset.rowCount), - } + }, ) return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties + entityUrn=dataset_urn, + aspect=dataset_properties, ).as_workunit() def _get_qlik_dataset_identifier(self, dataset: QlikDataset) -> str: @@ -597,7 +627,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/websocket_connection.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/websocket_connection.py index 01ca9415f886a8..93b7b4b5078871 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/websocket_connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/websocket_connection.py @@ -17,7 +17,9 @@ def __init__(self, tenant_hostname: str, api_key: str, app_id: str) -> None: self.handle = [-1] def _build_websocket_request_dict( - self, method: str, params: Union[Dict, List] = {} + self, + method: str, + params: Union[Dict, List] = {}, ) -> Dict: return { "jsonrpc": "2.0", @@ -37,7 +39,9 @@ def _send_request(self, request: Dict) -> Dict: return {} def websocket_send_request( - self, method: str, params: Union[Dict, List] = {} + self, + method: str, + params: Union[Dict, List] = {}, ) -> Dict: """ Method to send request to websocket diff --git a/metadata-ingestion/src/datahub/ingestion/source/redash.py b/metadata-ingestion/src/datahub/ingestion/source/redash.py index 666cc8c63aa9ed..f439e783d35ee5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redash.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redash.py @@ -149,7 +149,7 @@ def get_full_qualified_name(self, database_name: str, table_name: str) -> str: class PostgresQualifiedNameParser(QualifiedNameParser): split_char: str = "." names: List = field( - default_factory=lambda: ["database_name", "schema_name", "table_name"] + default_factory=lambda: ["database_name", "schema_name", "table_name"], ) default_schema: Optional[str] = "public" @@ -158,7 +158,7 @@ class PostgresQualifiedNameParser(QualifiedNameParser): class MssqlQualifiedNameParser(QualifiedNameParser): split_char: str = "." names: List = field( - default_factory=lambda: ["database_name", "schema_name", "table_name"] + default_factory=lambda: ["database_name", "schema_name", "table_name"], ) default_schema: Optional[str] = "dbo" @@ -192,7 +192,7 @@ def get_full_qualified_name(self, database_name: str, table_name: str) -> str: class BigqueryQualifiedNameParser(QualifiedNameParser): split_char: str = "." 
names: List = field( - default_factory=lambda: ["database_name", "schema_name", "table_name"] + default_factory=lambda: ["database_name", "schema_name", "table_name"], ) def get_full_qualified_name(self, database_name: str, table_name: str) -> str: @@ -208,27 +208,32 @@ def get_full_qualified_name(self, database_name: str, table_name: str) -> str: def get_full_qualified_name(platform: str, database_name: str, table_name: str) -> str: if platform == "athena": return AthenaQualifiedNameParser().get_full_qualified_name( - database_name, table_name + database_name, + table_name, ) elif platform == "bigquery": return BigqueryQualifiedNameParser().get_full_qualified_name( - database_name, table_name + database_name, + table_name, ) elif platform == "mssql": return MssqlQualifiedNameParser().get_full_qualified_name( - database_name, table_name + database_name, + table_name, ) elif platform == "mysql": return MysqlQualifiedNameParser().get_full_qualified_name( - database_name, table_name + database_name, + table_name, ) elif platform == "postgres": return PostgresQualifiedNameParser().get_full_qualified_name( - database_name, table_name + database_name, + table_name, ) else: @@ -239,7 +244,8 @@ class RedashConfig(ConfigModel): # See the Redash API for details # https://redash.io/help/user-guide/integrations-and-api/api connect_uri: str = Field( - default="http://localhost:5000", description="Redash base URL." + default="http://localhost:5000", + description="Redash base URL.", ) api_key: str = Field(default="REDASH_API_KEY", description="Redash user API key.") @@ -253,10 +259,12 @@ class RedashConfig(ConfigModel): description="regex patterns for charts to filter for ingestion.", ) skip_draft: bool = Field( - default=True, description="Only ingest published dashboards and charts." + default=True, + description="Only ingest published dashboards and charts.", ) page_size: int = Field( - default=25, description="Limit on number of items to be queried at once." + default=25, + description="Limit on number of items to be queried at once.", ) api_page_limit: int = Field( default=sys.maxsize, @@ -267,7 +275,8 @@ class RedashConfig(ConfigModel): description="Parallelism to use while processing.", ) parse_table_names_from_sql: bool = Field( - default=False, description="See note below." 
+ default=False, + description="See note below.", ) env: str = Field( @@ -328,7 +337,7 @@ def __init__(self, ctx: PipelineContext, config: RedashConfig): { "Content-Type": "application/json", "Accept": "application/json", - } + }, ) # Handling retry and backoff @@ -351,7 +360,7 @@ def __init__(self, ctx: PipelineContext, config: RedashConfig): self.parse_table_names_from_sql = self.config.parse_table_names_from_sql logger.info( - f"Running Redash ingestion with parse_table_names_from_sql={self.parse_table_names_from_sql}" + f"Running Redash ingestion with parse_table_names_from_sql={self.parse_table_names_from_sql}", ) def error(self, log: logging.Logger, key: str, reason: str) -> None: @@ -379,13 +388,15 @@ def _get_platform_based_on_datasource(self, data_source: Dict) -> str: data_source_type = data_source.get("type") if data_source_type: map = REDASH_DATA_SOURCE_TO_DATAHUB_MAP.get( - data_source_type, {"platform": DEFAULT_DATA_SOURCE_PLATFORM} + data_source_type, + {"platform": DEFAULT_DATA_SOURCE_PLATFORM}, ) return map.get("platform", DEFAULT_DATA_SOURCE_PLATFORM) return DEFAULT_DATA_SOURCE_PLATFORM def _get_database_name_based_on_datasource( - self, data_source: Dict + self, + data_source: Dict, ) -> Optional[str]: data_source_type = data_source.get("type", "external") data_source_name = data_source.get("name") @@ -395,18 +406,22 @@ def _get_database_name_based_on_datasource( database_name = data_source_name else: map = REDASH_DATA_SOURCE_TO_DATAHUB_MAP.get( - data_source_type, {"platform": DEFAULT_DATA_SOURCE_PLATFORM} + data_source_type, + {"platform": DEFAULT_DATA_SOURCE_PLATFORM}, ) database_name_key = map.get("db_name_key", "db") database_name = data_source_options.get( - database_name_key, DEFAULT_DATA_BASE_NAME + database_name_key, + DEFAULT_DATA_BASE_NAME, ) return database_name def _get_datasource_urns( - self, data_source: Dict, sql_query_data: Dict = {} + self, + data_source: Dict, + sql_query_data: Dict = {}, ) -> Optional[List[str]]: platform = self._get_platform_based_on_datasource(data_source) database_name = self._get_database_name_based_on_datasource(data_source) @@ -441,13 +456,14 @@ def _get_datasource_urns( else: return [ - builder.make_dataset_urn(platform, database_name, self.config.env) + builder.make_dataset_urn(platform, database_name, self.config.env), ] return None def _get_dashboard_description_from_widgets( - self, dashboard_widgets: List[Dict] + self, + dashboard_widgets: List[Dict], ) -> str: description = "" @@ -472,7 +488,8 @@ def _get_dashboard_description_from_widgets( return description def _get_dashboard_chart_urns_from_widgets( - self, dashboard_widgets: List[Dict] + self, + dashboard_widgets: List[Dict], ) -> List[str]: chart_urns = [] for widget in dashboard_widgets: @@ -482,7 +499,7 @@ def _get_dashboard_chart_urns_from_widgets( visualization_id = visualization.get("id", None) if visualization_id is not None: chart_urns.append( - f"urn:li:chart:({self.platform},{visualization_id})" + f"urn:li:chart:({self.platform},{visualization_id})", ) return chart_urns @@ -497,7 +514,7 @@ def _get_dashboard_snapshot(self, dashboard_data, redash_version): modified_actor = f"urn:li:corpuser:{dashboard_data.get('changed_by', {}).get('username', 'unknown')}" modified_ts = int( - dp.parse(dashboard_data.get("updated_at", "now")).timestamp() * 1000 + dp.parse(dashboard_data.get("updated_at", "now")).timestamp() * 1000, ) title = dashboard_data.get("name", "") @@ -532,7 +549,8 @@ def _get_dashboard_snapshot(self, dashboard_data, redash_version): return 
dashboard_snapshot def _process_dashboard_response( - self, current_page: int + self, + current_page: int, ) -> Iterable[MetadataWorkUnit]: logger.info(f"Starting processing dashboard for page {current_page}") if current_page > self.api_page_limit: @@ -540,7 +558,8 @@ def _process_dashboard_response( return with PerfTimer() as timer: dashboards_response = self.client.dashboards( - page=current_page, page_size=self.config.page_size + page=current_page, + page_size=self.config.page_size, ) for dashboard_response in dashboards_response["results"]: dashboard_name = dashboard_response["name"] @@ -557,7 +576,7 @@ def _process_dashboard_response( # Tested the same with a Redash instance dashboard_id = dashboard_response["id"] dashboard_data = self.client._get( - f"api/dashboards/{dashboard_id}" + f"api/dashboards/{dashboard_id}", ).json() except Exception: # This does not work in our testing but keeping for now because @@ -571,13 +590,14 @@ def _process_dashboard_response( logger.debug(dashboard_data) dashboard_snapshot = self._get_dashboard_snapshot( - dashboard_data, redash_version + dashboard_data, + redash_version, ) mce = MetadataChangeEvent(proposedSnapshot=dashboard_snapshot) yield MetadataWorkUnit(id=dashboard_snapshot.urn, mce=mce) self.report.timing[f"dashboard-{current_page}"] = int( - timer.elapsed_seconds() + timer.elapsed_seconds(), ) def _emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: @@ -586,7 +606,7 @@ def _emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: total_dashboards = dashboards_response["count"] max_page = math.ceil(total_dashboards / self.config.page_size) logger.info( - f"/api/dashboards total count {total_dashboards} and max page {max_page}" + f"/api/dashboards total count {total_dashboards} and max page {max_page}", ) self.report.total_dashboards = total_dashboards self.report.max_page_dashboards = max_page @@ -636,7 +656,7 @@ def _get_chart_snapshot(self, query_data: Dict, viz_data: Dict) -> ChartSnapshot modified_actor = f"urn:li:corpuser:{viz_data.get('changed_by', {}).get('username', 'unknown')}" modified_ts = int( - dp.parse(viz_data.get("updated_at", "now")).timestamp() * 1000 + dp.parse(viz_data.get("updated_at", "now")).timestamp() * 1000, ) title = f"{query_data.get('name')} {viz_data.get('name', '')}" @@ -683,7 +703,8 @@ def _process_query_response(self, current_page: int) -> Iterable[MetadataWorkUni return with PerfTimer() as timer: queries_response = self.client.queries( - page=current_page, page_size=self.config.page_size + page=current_page, + page_size=self.config.page_size, ) for query_response in queries_response["results"]: chart_name = query_response["name"] diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py index 932ada0a908b28..90301c110640c1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py @@ -122,7 +122,8 @@ class RedshiftConfig( ) include_table_lineage: bool = Field( - default=True, description="Whether table lineage should be ingested." + default=True, + description="Whether table lineage should be ingested.", ) include_copy_lineage: bool = Field( default=True, @@ -203,7 +204,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: logger.warning( "Please update `schema_pattern` to match against fully qualified schema name `.` and set config `match_fully_qualified_names : True`." 
"Current default `match_fully_qualified_names: False` is only to maintain backward compatibility. " - "The config option `match_fully_qualified_names` will be deprecated in future and the default behavior will assume `match_fully_qualified_names: True`." + "The config option `match_fully_qualified_names` will be deprecated in future and the default behavior will assume `match_fully_qualified_names: True`.", ) return values @@ -215,7 +216,7 @@ def connection_config_compatibility_set(cls, values: Dict) -> Dict: and len(values["extra_client_options"]) > 0 ): raise ValueError( - "Cannot set both `connect_args` and `extra_client_options` in the config. Please use `extra_client_options` only." + "Cannot set both `connect_args` and `extra_client_options` in the config. Please use `extra_client_options` only.", ) if "options" in values and "connect_args" in values["options"]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/exception.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/exception.py index ed0856fc1e2924..2e0dc7f5d05676 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/exception.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/exception.py @@ -35,7 +35,8 @@ def handle_redshift_exceptions_yield( def report_redshift_failure( - report: RedshiftReport, e: redshift_connector.Error + report: RedshiftReport, + e: redshift_connector.Error, ) -> None: error_message = str(e).lower() if "permission denied" in error_message: diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py index 4b3d238a13261c..73cfdf69ad4acc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py @@ -116,7 +116,7 @@ def merge_lineage( if c.downstream.column in existing_cll: # Merge using upstream + column name as the merge key. existing_cll[c.downstream.column].upstreams = deduplicate_list( - [*existing_cll[c.downstream.column].upstreams, *c.upstreams] + [*existing_cll[c.downstream.column].upstreams, *c.upstreams], ) else: # New output column, just add it as is. 
@@ -177,7 +177,9 @@ def __init__( self.temp_tables: Dict[str, TempTableRow] = {} def _init_temp_table_schema( - self, database: str, temp_tables: List[TempTableRow] + self, + database: str, + temp_tables: List[TempTableRow], ) -> None: if self.context.graph is None: # to silent lint return @@ -192,7 +194,7 @@ def _init_temp_table_schema( # prepare dataset_urn vs List of schema fields for table in temp_tables: logger.debug( - f"Processing temp table: {table.create_command} with query text {table.query_text}" + f"Processing temp table: {table.create_command} with query text {table.query_text}", ) result = sqlglot_l.create_lineage_sql_parsed_result( platform=LineageDatasetPlatform.REDSHIFT.value, @@ -232,7 +234,7 @@ def _init_temp_table_schema( if downstream_urn not in dataset_vs_columns: dataset_vs_columns[downstream_urn] = [] dataset_vs_columns[downstream_urn].extend( - sqlglot_l.infer_output_schema(table.parsed_result) or [] + sqlglot_l.infer_output_schema(table.parsed_result) or [], ) # Add datasets, and it's respective fields in schema_resolver, so that later schema_resolver would be able @@ -244,7 +246,7 @@ def _init_temp_table_schema( schema_metadata=SchemaMetadata( schemaName=table_name, platform=builder.make_data_platform_urn( - LineageDatasetPlatform.REDSHIFT.value + LineageDatasetPlatform.REDSHIFT.value, ), version=0, hash="", @@ -257,7 +259,8 @@ def get_time_window(self) -> Tuple[datetime, datetime]: if self.redundant_run_skip_handler: self.report.stateful_lineage_ingestion_enabled = True return self.redundant_run_skip_handler.suggest_run_time_window( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) else: return self.config.start_time, self.config.end_time @@ -279,7 +282,7 @@ def _get_s3_path(self, path: str) -> Optional[str]: ): self.report.num_lineage_dropped_s3_path += 1 logger.debug( - f"Skipping s3 path {path} as it does not match any path spec." + f"Skipping s3 path {path} as it does not match any path spec.", ) return None @@ -313,7 +316,7 @@ def _get_sources_from_query( return sources, None elif parsed_result.debug_info.table_error: logger.debug( - f"native query parsing failed for {query} with error: {parsed_result.debug_info.table_error}" + f"native query parsing failed for {query} with error: {parsed_result.debug_info.table_error}", ) return sources, None @@ -339,7 +342,7 @@ def _build_s3_path_from_row(self, filename: str) -> Optional[str]: path = filename.strip() if urlparse(path).scheme != "s3": raise ValueError( - f"Only s3 source supported with copy/unload. The source was: {path}" + f"Only s3 source supported with copy/unload. The source was: {path}", ) s3_path = self._get_s3_path(path) return strip_s3_prefix(s3_path) if s3_path else None @@ -369,7 +372,7 @@ def _get_sources( sources, cll = self._get_sources_from_query(db_name=db_name, query=ddl) except Exception as e: logger.warning( - f"Error parsing query {ddl} for getting lineage. Error was {e}." + f"Error parsing query {ddl} for getting lineage. Error was {e}.", ) self.report.num_lineage_dropped_query_parser += 1 else: @@ -378,7 +381,7 @@ def _get_sources( path = filename.strip() if urlparse(path).scheme != "s3": logger.warning( - "Only s3 source supported with copy. The source was: {path}." + "Only s3 source supported with copy. 
The source was: {path}.", ) self.report.num_lineage_dropped_not_support_copy_path += 1 return [], None @@ -413,7 +416,7 @@ def _get_sources( LineageDataset( platform=platform, urn=urn, - ) + ), ] return sources, cll @@ -448,7 +451,8 @@ def _populate_lineage_map( alias_db_name = self.config.database for lineage_row in RedshiftDataDictionary.get_lineage_rows( - conn=connection, query=query + conn=connection, + query=query, ): target = self._get_target_lineage( alias_db_name, @@ -460,7 +464,7 @@ def _populate_lineage_map( continue logger.debug( - f"Processing {lineage_type.name} lineage row: {lineage_row}" + f"Processing {lineage_type.name} lineage row: {lineage_row}", ) sources, cll = self._get_sources( @@ -481,20 +485,21 @@ def _populate_lineage_map( alias_db_name=alias_db_name, raw_db_name=raw_db_name, connection=connection, - ) + ), ) target.cll = cll # Merging upstreams if dataset already exists and has upstreams if target.dataset.urn in self._lineage_map: self._lineage_map[target.dataset.urn].merge_lineage( - upstreams=target.upstreams, cll=target.cll + upstreams=target.upstreams, + cll=target.cll, ) else: self._lineage_map[target.dataset.urn] = target logger.debug( - f"Lineage[{target}]:{self._lineage_map[target.dataset.urn]}" + f"Lineage[{target}]:{self._lineage_map[target.dataset.urn]}", ) except Exception as e: self.warn( @@ -505,7 +510,8 @@ def _populate_lineage_map( self.report_status(f"extract-{lineage_type.name}", False) def _update_lineage_map_for_table_renames( - self, table_renames: Dict[str, TableRename] + self, + table_renames: Dict[str, TableRename], ) -> None: if not table_renames: return @@ -517,7 +523,7 @@ def _update_lineage_map_for_table_renames( prev_table_lineage = self._lineage_map.get(entry.original_urn) if prev_table_lineage: logger.debug( - f"including lineage for {entry.original_urn} in {entry.new_urn} due to table rename" + f"including lineage for {entry.original_urn} in {entry.new_urn} due to table rename", ) self._lineage_map[entry.new_urn].merge_lineage( upstreams=prev_table_lineage.upstreams, @@ -539,7 +545,7 @@ def _get_target_lineage( if ( not self.config.schema_pattern.allowed(lineage_row.target_schema) or not self.config.table_pattern.allowed( - f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}" + f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}", ) ) and not ( # We also check the all_tables_set, since this might be a renamed table @@ -649,7 +655,7 @@ def _get_upstream_lineages( logger.debug( f"Number of permanent datasets found for {target_table} = {number_of_permanent_dataset_found} in " - f"temp tables {probable_temp_tables}" + f"temp tables {probable_temp_tables}", ) return target_source @@ -743,11 +749,13 @@ def populate_lineage( self._update_lineage_map_for_table_renames(table_renames=table_renames) self.report.lineage_mem_size[self.config.database] = humanfriendly.format_size( - memory_footprint.total_size(self._lineage_map) + memory_footprint.total_size(self._lineage_map), ) def make_fine_grained_lineage_class( - self, lineage_item: LineageItem, dataset_urn: str + self, + lineage_item: LineageItem, + dataset_urn: str, ) -> List[FineGrainedLineage]: fine_grained_lineages: List[FineGrainedLineage] = [] @@ -781,7 +789,7 @@ def make_fine_grained_lineage_class( downstreams=downstream, upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=upstreams, - ) + ), ) logger.debug(f"Created fine_grained_lineage for {dataset_urn}") @@ -865,7 +873,8 @@ def _process_table_renames( ) for rename_row in 
RedshiftDataDictionary.get_alter_table_commands( - connection, query + connection, + query, ): # Redshift's system table has some issues where it encodes newlines as \n instead a proper # newline character. This can cause issues in our parser. @@ -895,7 +904,10 @@ def _process_table_renames( ) table_renames[new_urn] = TableRename( - prev_urn, new_urn, query_text, timestamp=rename_row.start_time + prev_urn, + new_urn, + query_text, + timestamp=rename_row.start_time, ) # We want to generate lineage for the previous name too. @@ -905,7 +917,8 @@ def _process_table_renames( return table_renames, all_tables def get_temp_tables( - self, connection: redshift_connector.Connection + self, + connection: redshift_connector.Connection, ) -> Iterable[TempTableRow]: ddl_query: str = self.queries.temp_table_ddl_query( start_time=self.config.start_time, @@ -921,14 +934,16 @@ def get_temp_tables( yield row def find_temp_tables( - self, temp_table_rows: List[TempTableRow], temp_table_names: List[str] + self, + temp_table_rows: List[TempTableRow], + temp_table_names: List[str], ) -> List[TempTableRow]: matched_temp_tables: List[TempTableRow] = [] for table_name in temp_table_names: prefixes = self.queries.get_temp_table_clause(table_name) prefixes.extend( - self.queries.get_temp_table_clause(table_name.split(".")[-1]) + self.queries.get_temp_table_clause(table_name.split(".")[-1]), ) for row in temp_table_rows: @@ -940,7 +955,9 @@ def find_temp_tables( return matched_temp_tables def resolve_column_refs( - self, column_refs: List[sqlglot_l.ColumnRef], depth: int = 0 + self, + column_refs: List[sqlglot_l.ColumnRef], + depth: int = 0, ) -> List[sqlglot_l.ColumnRef]: """ This method resolves the column reference to the original column reference. @@ -955,7 +972,7 @@ def resolve_column_refs( if depth >= max_depth: logger.warning( - f"Max depth reached for resolving temporary columns: {column_refs}" + f"Max depth reached for resolving temporary columns: {column_refs}", ) self.report.num_unresolved_temp_columns += 1 return column_refs @@ -972,19 +989,20 @@ def resolve_column_refs( ): resolved_column_refs.extend( self.resolve_column_refs( - column_lineage.upstreams, depth=depth + 1 - ) + column_lineage.upstreams, + depth=depth + 1, + ), ) resolved = True break # If we reach here, it means that we were not able to resolve the column reference. if resolved is False: logger.warning( - f"Unable to resolve column reference {ref} to a permanent table" + f"Unable to resolve column reference {ref} to a permanent table", ) else: logger.debug( - f"Resolved column reference {ref} is not resolved because referenced table {ref.table} is not a temp table or not found. Adding reference as non-temp table. This is normal." + f"Resolved column reference {ref} is not resolved because referenced table {ref.table} is not a temp table or not found. Adding reference as non-temp table. 
This is normal.", ) resolved_column_refs.append(ref) return resolved_column_refs @@ -1011,7 +1029,7 @@ def _update_target_dataset_cll( == target_column_ref.column ): resolved_columns = self.resolve_column_refs( - source_column_lineage.upstreams + source_column_lineage.upstreams, ) # Add all upstream of above temporary column into upstream of target column upstreams.extend(resolved_columns) @@ -1036,7 +1054,7 @@ def _add_permanent_datasets_recursively( for temp_table in temp_table_rows: logger.debug( - f"Processing temp table with transaction id: {temp_table.transaction_id} and query text {temp_table.query_text}" + f"Processing temp table with transaction id: {temp_table.transaction_id} and query text {temp_table.query_text}", ) intermediate_l_datasets, cll = self._get_sources_from_query( diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py index 4b84c25965a994..73d1a553bbff4c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py @@ -145,7 +145,7 @@ def build( database=self.database, connection=connection, all_tables=collections.defaultdict( - lambda: collections.defaultdict(set) + lambda: collections.defaultdict(set), ), ) for entry in table_renames.values(): @@ -166,7 +166,7 @@ def build( LineageCollectorType.QUERY_SQL_PARSER, query, self._process_sql_parser_lineage, - ) + ), ) if self.config.table_lineage_mode in { LineageMode.STL_SCAN_BASED, @@ -179,14 +179,18 @@ def build( self.end_time, ) populate_calls.append( - (LineageCollectorType.QUERY_SCAN, query, self._process_stl_scan_lineage) + ( + LineageCollectorType.QUERY_SCAN, + query, + self._process_stl_scan_lineage, + ), ) if self.config.include_views and self.config.include_view_lineage: # Populate lineage for views query = self.queries.view_lineage_query() populate_calls.append( - (LineageCollectorType.VIEW, query, self._process_view_lineage) + (LineageCollectorType.VIEW, query, self._process_view_lineage), ) # Populate lineage for late binding views @@ -196,7 +200,7 @@ def build( LineageCollectorType.VIEW_DDL_SQL_PARSING, query, self._process_view_lineage, - ) + ), ) if self.config.include_copy_lineage: @@ -207,7 +211,7 @@ def build( end_time=self.end_time, ) populate_calls.append( - (LineageCollectorType.COPY, query, self._process_copy_command) + (LineageCollectorType.COPY, query, self._process_copy_command), ) if self.config.include_unload_lineage: @@ -218,7 +222,7 @@ def build( end_time=self.end_time, ) populate_calls.append( - (LineageCollectorType.UNLOAD, query, self._process_unload_command) + (LineageCollectorType.UNLOAD, query, self._process_unload_command), ) for lineage_type, query, processor in populate_calls: @@ -244,11 +248,13 @@ def _populate_lineage_agg( logger.debug(f"Processing {lineage_type.name} lineage query: {query}") timer = self.report.lineage_phases_timer.setdefault( - lineage_type.name, PerfTimer() + lineage_type.name, + PerfTimer(), ) with timer: for lineage_row in RedshiftDataDictionary.get_lineage_rows( - conn=connection, query=query + conn=connection, + query=query, ): processor(lineage_row) except Exception as e: @@ -274,7 +280,7 @@ def _process_sql_parser_lineage(self, lineage_row: LineageRow) -> None: default_schema=self.config.default_schema, timestamp=lineage_row.timestamp, session_id=lineage_row.session_id, - ) + ), ) def _make_filtered_target(self, lineage_row: LineageRow) -> 
Optional[DatasetUrn]: @@ -286,7 +292,7 @@ def _make_filtered_target(self, lineage_row: LineageRow) -> Optional[DatasetUrn] ) if target.urn() not in self.known_urns: logger.debug( - f"Skipping lineage for {target.urn()} as it is not in known_urns" + f"Skipping lineage for {target.urn()} as it is not in known_urns", ) return None @@ -306,7 +312,7 @@ def _process_stl_scan_lineage(self, lineage_row: LineageRow) -> None: if lineage_row.ddl is None: logger.warning( - f"stl scan entry is missing query text for {lineage_row.source_schema}.{lineage_row.source_table}" + f"stl scan entry is missing query text for {lineage_row.source_schema}.{lineage_row.source_table}", ) return self.aggregator.add_known_query_lineage( @@ -354,7 +360,7 @@ def _process_copy_command(self, lineage_row: LineageRow) -> None: logger.debug(f"Recognized s3 dataset urn: {s3_urn}") if not lineage_row.target_schema or not lineage_row.target_table: logger.debug( - f"Didn't find target schema (found: {lineage_row.target_schema}) or target table (found: {lineage_row.target_table})" + f"Didn't find target schema (found: {lineage_row.target_schema}) or target table (found: {lineage_row.target_table})", ) return target = self._make_filtered_target(lineage_row) @@ -362,7 +368,8 @@ def _process_copy_command(self, lineage_row: LineageRow) -> None: return self.aggregator.add_known_lineage_mapping( - upstream_urn=s3_urn, downstream_urn=target.urn() + upstream_urn=s3_urn, + downstream_urn=target.urn(), ) def _process_unload_command(self, lineage_row: LineageRow) -> None: @@ -386,12 +393,13 @@ def _process_unload_command(self, lineage_row: LineageRow) -> None: ) if source.urn() not in self.known_urns: logger.debug( - f"Skipping unload lineage for {source.urn()} as it is not in known_urns" + f"Skipping unload lineage for {source.urn()} as it is not in known_urns", ) return self.aggregator.add_known_lineage_mapping( - upstream_urn=source.urn(), downstream_urn=output_urn + upstream_urn=source.urn(), + downstream_urn=output_urn, ) def _process_external_tables( diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/profile.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/profile.py index 6f611fa6741879..e2284273b59dcb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/profile.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/profile.py @@ -39,7 +39,8 @@ def get_workunits( # https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow if self.config.is_profiling_enabled(): self.config.options.setdefault( - "max_overflow", self.config.profiling.max_workers + "max_overflow", + self.config.profiling.max_workers, ) for db in tables.keys(): @@ -53,7 +54,7 @@ def get_workunits( # Case 1: If user did not tell us to profile external tables, simply log this. self.report.profiling_skipped_other[schema] += 1 logger.info( - f"Skipping profiling of external table {db}.{schema}.{table.name}" + f"Skipping profiling of external table {db}.{schema}.{table.name}", ) # Continue, since we should not profile this table. 
continue diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py index 71a20890d35e88..f1d37de3ace92b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/query.py @@ -272,7 +272,9 @@ def list_late_view_ddls_query() -> str: @staticmethod def alter_table_rename_query( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: start_time_str: str = start_time.strftime(redshift_datetime_format) end_time_str: str = end_time.strftime(redshift_datetime_format) @@ -293,7 +295,9 @@ def alter_table_rename_query( @staticmethod def list_copy_commands_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """\ SELECT DISTINCT @@ -342,13 +346,17 @@ def operation_aspect_query(start_time: str, end_time: str) -> str: @staticmethod def stl_scan_based_lineage_query( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: raise NotImplementedError @staticmethod def list_unload_commands_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: raise NotImplementedError @@ -358,7 +366,9 @@ def temp_table_ddl_query(start_time: datetime, end_time: datetime) -> str: @staticmethod def list_insert_create_queries_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: raise NotImplementedError @@ -400,7 +410,9 @@ def additional_table_metadata_query() -> str: @staticmethod def stl_scan_based_lineage_query( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """ select @@ -470,7 +482,9 @@ def stl_scan_based_lineage_query( @staticmethod def list_unload_commands_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """ select @@ -502,7 +516,9 @@ def list_unload_commands_sql( @staticmethod def list_insert_create_queries_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """\ with query_txt as ( @@ -794,7 +810,9 @@ def additional_table_metadata_query() -> str: @staticmethod def stl_scan_based_lineage_query( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """ SELECT @@ -861,7 +879,9 @@ def stl_scan_based_lineage_query( @staticmethod def list_unload_commands_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """ SELECT @@ -893,7 +913,9 @@ def list_unload_commands_sql( # * querytxt do not contain newlines (to be confirmed it is not a problem) @staticmethod def list_insert_create_queries_sql( - db_name: str, start_time: datetime, end_time: datetime + db_name: str, + start_time: datetime, + end_time: datetime, ) -> str: return """ SELECT diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index cce282c71056a2..91370d3b59eabb 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -315,7 +315,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @@ -334,7 +335,8 @@ def __init__(self, config: RedshiftConfig, ctx: PipelineContext): self.domain_registry = None if self.config.domain: self.domain_registry = DomainRegistry( - cached_domains=list(self.config.domain.keys()), graph=self.ctx.graph + cached_domains=list(self.config.domain.keys()), + graph=self.ctx.graph, ) self.redundant_lineage_run_skip_handler: Optional[ @@ -358,7 +360,7 @@ def __init__(self, config: RedshiftConfig, ctx: PipelineContext): ) self.data_dictionary = RedshiftDataDictionary( - is_serverless=self.config.is_serverless + is_serverless=self.config.is_serverless, ) self.db_tables: Dict[str, Dict[str, List[RedshiftTable]]] = {} @@ -408,10 +410,13 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), functools.partial( - auto_incremental_lineage, self.config.incremental_lineage + auto_incremental_lineage, + self.config.incremental_lineage, ), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -432,11 +437,16 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit # TODO: Ideally, we'd push down exception handling to the place where the connection is used, as opposed to keeping # this fallback. For now, this gets us broad coverage quickly. yield from handle_redshift_exceptions_yield( - self.report, self._extract_metadata, connection, database + self.report, + self._extract_metadata, + connection, + database, ) def _extract_metadata( - self, connection: redshift_connector.Connection, database: str + self, + connection: redshift_connector.Connection, + database: str, ) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]: yield from self.gen_database_container( database=database, @@ -445,10 +455,10 @@ def _extract_metadata( self.cache_tables_and_views(connection, database) self.report.tables_in_mem_size[database] = humanfriendly.format_size( - memory_footprint.total_size(self.db_tables) + memory_footprint.total_size(self.db_tables), ) self.report.views_in_mem_size[database] = humanfriendly.format_size( - memory_footprint.total_size(self.db_views) + memory_footprint.total_size(self.db_views), ) if self.config.use_lineage_v2: @@ -460,7 +470,7 @@ def _extract_metadata( redundant_run_skip_handler=self.redundant_lineage_run_skip_handler, ) as lineage_extractor: yield from lineage_extractor.aggregator.register_schemas_from_stream( - self.process_schemas(connection, database) + self.process_schemas(connection, database), ) with self.report.new_stage(LINEAGE_EXTRACTION): @@ -483,13 +493,17 @@ def _extract_metadata( ): with self.report.new_stage(LINEAGE_EXTRACTION): yield from self.extract_lineage( - connection=connection, all_tables=all_tables, database=database + connection=connection, + all_tables=all_tables, + database=database, ) if self.config.include_usage_statistics: with self.report.new_stage(USAGE_EXTRACTION_INGESTION): yield from self.extract_usage( - connection=connection, all_tables=all_tables, database=database + connection=connection, + all_tables=all_tables, + database=database, ) if 
self.config.is_profiling_enabled(): @@ -503,7 +517,8 @@ def _extract_metadata( def process_schemas(self, connection, database): for schema in self.data_dictionary.get_schemas( - conn=connection, database=database + conn=connection, + database=database, ): if not is_schema_allowed( self.config.schema_pattern, @@ -563,7 +578,8 @@ def process_schema( schema_columns: Dict[str, Dict[str, List[RedshiftColumn]]] = {} schema_columns[schema.name] = self.data_dictionary.get_columns_for_schema( - conn=connection, schema=schema + conn=connection, + schema=schema, ) if self.config.include_tables: @@ -577,7 +593,8 @@ def process_schema( table.columns = schema_columns[schema.name].get(table.name, []) table.column_count = len(table.columns) table_wu_generator = self._process_table( - table, database=database + table, + database=database, ) yield from classification_workunit_processor( table_wu_generator, @@ -587,12 +604,13 @@ def process_schema( ) self.report.table_processed[report_key] = ( self.report.table_processed.get( - f"{database}.{schema.name}", 0 + f"{database}.{schema.name}", + 0, ) + 1 ) logger.debug( - f"Table processed: {schema.database}.{schema.name}.{table.name}" + f"Table processed: {schema.database}.{schema.name}.{table.name}", ) else: self.report.info( @@ -613,17 +631,20 @@ def process_schema( view.columns = schema_columns[schema.name].get(view.name, []) view.column_count = len(view.columns) yield from self._process_view( - table=view, database=database, schema=schema + table=view, + database=database, + schema=schema, ) self.report.view_processed[report_key] = ( self.report.view_processed.get( - f"{database}.{schema.name}", 0 + f"{database}.{schema.name}", + 0, ) + 1 ) logger.debug( - f"Table processed: {schema.database}.{schema.name}.{view.name}" + f"Table processed: {schema.database}.{schema.name}.{view.name}", ) else: self.report.info( @@ -635,7 +656,7 @@ def process_schema( logger.info("View processing disabled, skipping") self.report.metadata_extraction_sec[report_key] = timer.elapsed_seconds( - digits=2 + digits=2, ) def _process_table( @@ -652,11 +673,16 @@ def _process_table( return yield from self.gen_table_dataset_workunits( - table, database=database, dataset_name=datahub_dataset_name + table, + database=database, + dataset_name=datahub_dataset_name, ) def _process_view( - self, table: RedshiftView, database: str, schema: RedshiftSchema + self, + table: RedshiftView, + database: str, + schema: RedshiftSchema, ) -> Iterable[MetadataWorkUnit]: datahub_dataset_name = f"{database}.{schema.name}.{table.name}" @@ -735,7 +761,8 @@ def gen_view_dataset_workunits( viewLogic=view.ddl, ) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=view_properties_aspect + entityUrn=dataset_urn, + aspect=view_properties_aspect, ).as_workunit() # TODO: Remove to common? 
@@ -788,7 +815,8 @@ def gen_schema_metadata( fields=self.gen_schema_fields(table.columns), ) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=schema_metadata + entityUrn=dataset_urn, + aspect=schema_metadata, ).as_workunit() def gen_dataset_urn(self, datahub_dataset_name: str) -> str: @@ -812,7 +840,9 @@ def gen_dataset_workunits( dataset_urn = self.gen_dataset_urn(datahub_dataset_name) yield from self.gen_schema_metadata( - dataset_urn, table, str(datahub_dataset_name) + dataset_urn, + table, + str(datahub_dataset_name), ) dataset_properties = DatasetProperties( @@ -835,15 +865,18 @@ def gen_dataset_workunits( # TODO: use auto_incremental_properties workunit processor instead # Deprecate use of patch_custom_properties patch_builder = create_dataset_props_patch_builder( - dataset_urn, dataset_properties + dataset_urn, + dataset_properties, ) for patch_mcp in patch_builder.build(): yield MetadataWorkUnit( - id=f"{dataset_urn}-{patch_mcp.aspectName}", mcp_raw=patch_mcp + id=f"{dataset_urn}-{patch_mcp.aspectName}", + mcp_raw=patch_mcp, ) else: yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties + entityUrn=dataset_urn, + aspect=dataset_properties, ).as_workunit() # TODO: Check if needed @@ -872,7 +905,8 @@ def gen_dataset_workunits( subTypes = SubTypes(typeNames=[sub_type]) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=subTypes + entityUrn=dataset_urn, + aspect=subTypes, ).as_workunit() if self.domain_registry: @@ -896,14 +930,14 @@ def cache_tables_and_views(self, connection, database): self.config.match_fully_qualified_names, ): logger.debug( - f"Not caching table for schema {database}.{schema} which is not allowed by schema_pattern" + f"Not caching table for schema {database}.{schema} which is not allowed by schema_pattern", ) continue self.db_tables[database][schema] = [] for table in tables[schema]: if self.config.table_pattern.allowed( - f"{database}.{schema}.{table.name}" + f"{database}.{schema}.{table.name}", ): self.db_tables[database][schema].append(table) self.report.table_cached[f"{database}.{schema}"] = ( @@ -911,7 +945,7 @@ def cache_tables_and_views(self, connection, database): ) else: logger.debug( - f"Table {database}.{schema}.{table.name} is filtered by table_pattern" + f"Table {database}.{schema}.{table.name} is filtered by table_pattern", ) self.report.table_filtered[f"{database}.{schema}"] = ( self.report.table_filtered.get(f"{database}.{schema}", 0) + 1 @@ -925,7 +959,7 @@ def cache_tables_and_views(self, connection, database): self.config.match_fully_qualified_names, ): logger.debug( - f"Not caching views for schema {database}.{schema} which is not allowed by schema_pattern" + f"Not caching views for schema {database}.{schema} which is not allowed by schema_pattern", ) continue @@ -938,7 +972,7 @@ def cache_tables_and_views(self, connection, database): ) else: logger.debug( - f"View {database}.{schema}.{view.name} is filtered by view_pattern" + f"View {database}.{schema}.{view.name} is filtered by view_pattern", ) self.report.view_filtered[f"{database}.{schema}"] = ( self.report.view_filtered.get(f"{database}.{schema}", 0) + 1 @@ -1007,20 +1041,24 @@ def extract_lineage( with PerfTimer() as timer: lineage_extractor.populate_lineage( - database=database, connection=connection, all_tables=all_tables + database=database, + connection=connection, + all_tables=all_tables, ) self.report.lineage_extraction_sec[f"{database}"] = timer.elapsed_seconds( - digits=2 + digits=2, ) yield from 
self.generate_lineage( - database, lineage_extractor=lineage_extractor + database, + lineage_extractor=lineage_extractor, ) if self.redundant_lineage_run_skip_handler: # Update the checkpoint state for this run. self.redundant_lineage_run_skip_handler.update_state( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) def extract_lineage_v2( @@ -1036,19 +1074,22 @@ def extract_lineage_v2( all_tables = self.get_all_tables() lineage_extractor.build( - connection=connection, all_tables=all_tables, db_schemas=self.db_schemas + connection=connection, + all_tables=all_tables, + db_schemas=self.db_schemas, ) yield from lineage_extractor.generate() self.report.lineage_extraction_sec[f"{database}"] = timer.elapsed_seconds( - digits=2 + digits=2, ) if self.redundant_lineage_run_skip_handler: # Update the checkpoint state for this run. self.redundant_lineage_run_skip_handler.update_state( - lineage_extractor.start_time, lineage_extractor.end_time + lineage_extractor.start_time, + lineage_extractor.end_time, ) def _should_ingest_lineage(self) -> bool: @@ -1069,18 +1110,20 @@ def _should_ingest_lineage(self) -> bool: return True def generate_lineage( - self, database: str, lineage_extractor: RedshiftLineageExtractor + self, + database: str, + lineage_extractor: RedshiftLineageExtractor, ) -> Iterable[MetadataWorkUnit]: logger.info(f"Generate lineage for {database}") for schema in deduplicate_list( - itertools.chain(self.db_tables[database], self.db_views[database]) + itertools.chain(self.db_tables[database], self.db_views[database]), ): if ( database not in self.db_schemas or schema not in self.db_schemas[database] ): logger.warning( - f"Either database {database} or {schema} exists in the lineage but was not discovered earlier. Something went wrong." + f"Either database {database} or {schema} exists in the lineage but was not discovered earlier. 
Something went wrong.", ) continue diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py index 8af436ae49979f..e6c95774af0a72 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py @@ -19,7 +19,10 @@ def __init__(self, conn: redshift_connector.Connection) -> None: self.conn = conn def get_sample_data_for_table( - self, table_id: List[str], sample_size: int, **kwargs: Any + self, + table_id: List[str], + sample_size: int, + **kwargs: Any, ) -> Dict[str, list]: """ For redshift, table_id should be in form (db_name, schema_name, table_name) @@ -30,7 +33,7 @@ def get_sample_data_for_table( table_name = table_id[2] logger.debug( - f"Collecting sample values for table {db_name}.{schema_name}.{table_name}" + f"Collecting sample values for table {db_name}.{schema_name}.{table_name}", ) with PerfTimer() as timer, self.conn.cursor() as cursor: sql = f"select * from {db_name}.{schema_name}.{table_name} limit {sample_size};" @@ -40,7 +43,7 @@ def get_sample_data_for_table( time_taken = timer.elapsed_seconds() logger.debug( f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};" - f"{df.shape[0]} rows; took {time_taken:.3f} seconds" + f"{df.shape[0]} rows; took {time_taken:.3f} seconds", ) return df.to_dict(orient="list") diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py index 594f88dd521ad5..07649b62ebf27f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_schema.py @@ -122,7 +122,8 @@ def __init__(self, is_serverless): @staticmethod def get_query_result( - conn: redshift_connector.Connection, query: str + conn: redshift_connector.Connection, + query: str, ) -> redshift_connector.Cursor: cursor: redshift_connector.Cursor = conn.cursor() @@ -143,7 +144,8 @@ def get_databases(conn: redshift_connector.Connection) -> List[str]: @staticmethod def get_schemas( - conn: redshift_connector.Connection, database: str + conn: redshift_connector.Connection, + database: str, ) -> List[RedshiftSchema]: cursor = RedshiftDataDictionary.get_query_result( conn, @@ -172,7 +174,8 @@ def enrich_tables( # Warning: This table enrichment will not return anything for # external tables (spectrum) and for tables that have never been queried / written to. 
cur = RedshiftDataDictionary.get_query_result( - conn, self.queries.additional_table_metadata_query() + conn, + self.queries.additional_table_metadata_query(), ) field_names = [i[0] for i in cur.description] db_table_metadata = cur.fetchall() @@ -236,7 +239,10 @@ def get_tables_and_views( rows_count, size_in_bytes, ) = RedshiftDataDictionary.get_table_stats( - enriched_tables, field_names, schema, table + enriched_tables, + field_names, + schema, + table, ) tables[schema].append( @@ -255,7 +261,7 @@ def get_tables_and_views( output_parameters=table[field_names.index("output_format")], serde_parameters=table[field_names.index("serde_parameters")], comment=table[field_names.index("table_description")], - ) + ), ) else: if schema not in views: @@ -288,12 +294,12 @@ def get_tables_and_views( size_in_bytes=size_in_bytes, rows_count=rows_count, materialized=materialized, - ) + ), ) for schema_key, schema_tables in tables.items(): logger.info( - f"In schema: {schema_key} discovered {len(schema_tables)} tables" + f"In schema: {schema_key} discovered {len(schema_tables)} tables", ) for schema_key, schema_views in views.items(): logger.info(f"In schema: {schema_key} discovered {len(schema_views)} views") @@ -307,7 +313,7 @@ def get_table_stats(enriched_tables, field_names, schema, table): creation_time: Optional[datetime] = None if table[field_names.index("creation_time")]: creation_time = table[field_names.index("creation_time")].replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ) last_altered: Optional[datetime] = None size_in_bytes: Optional[int] = None @@ -358,7 +364,8 @@ def get_schema_fields_for_column( @staticmethod def get_columns_for_schema( - conn: redshift_connector.Connection, schema: RedshiftSchema + conn: redshift_connector.Connection, + schema: RedshiftSchema, ) -> Dict[str, List[RedshiftColumn]]: cursor = RedshiftDataDictionary.get_query_result( conn, @@ -473,7 +480,8 @@ def get_temporary_rows( # See https://docs.aws.amazon.com/redshift/latest/dg/r_STL_QUERYTEXT.html # for why we need to replace the \n with a newline. query_text=row[field_names.index("query_text")].replace( - r"\n", "\n" + r"\n", + "\n", ), create_command=row[field_names.index("create_command")], start_time=row[field_names.index("start_time")], @@ -503,7 +511,8 @@ def get_alter_table_commands( # See https://docs.aws.amazon.com/redshift/latest/dg/r_STL_QUERYTEXT.html # for why we need to replace the \n with a newline. 
query_text=row[field_names.index("query_text")].replace( - r"\n", "\n" + r"\n", + "\n", ), start_time=row[field_names.index("start_time")], ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py index 2748f2a588a930..c686ffe577c5aa 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py @@ -32,7 +32,7 @@ class RedshiftReport( view_cached: TopKDict[str, int] = field(default_factory=TopKDict) metadata_extraction_sec: TopKDict[str, float] = field(default_factory=TopKDict) operational_metadata_extraction_sec: TopKDict[str, float] = field( - default_factory=TopKDict + default_factory=TopKDict, ) lineage_mem_size: Dict[str, str] = field(default_factory=TopKDict) tables_in_mem_size: Dict[str, str] = field(default_factory=TopKDict) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py index a5758bdd825702..cef08ae2ff9cb7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py @@ -124,7 +124,8 @@ def __init__( def get_time_window(self) -> Tuple[datetime, datetime]: if self.redundant_run_skip_handler: return self.redundant_run_skip_handler.suggest_run_time_window( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) else: return self.config.start_time, self.config.end_time @@ -147,7 +148,8 @@ def _should_ingest_usage(self): return True def get_usage_workunits( - self, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]] + self, + all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], ) -> Iterable[MetadataWorkUnit]: if not self._should_ingest_usage(): return @@ -175,7 +177,8 @@ def get_usage_workunits( ) def _get_workunits_internal( - self, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]] + self, + all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], ) -> Iterable[MetadataWorkUnit]: self.report.num_usage_workunits_emitted = 0 self.report.num_usage_stat_skipped = 0 @@ -186,7 +189,8 @@ def _get_workunits_internal( with PerfTimer() as timer: # Generate operation aspect workunits yield from self._gen_operation_aspect_workunits( - self.connection, all_tables + self.connection, + all_tables, ) self.report.operational_metadata_extraction_sec[ self.config.database @@ -201,12 +205,14 @@ def _get_workunits_internal( ) access_events_iterable: Iterable[RedshiftAccessEvent] = ( self._gen_access_events_from_history_query( - query, connection=self.connection, all_tables=all_tables + query, + connection=self.connection, + all_tables=all_tables, ) ) aggregated_events: AggregatedAccessEvents = self._aggregate_access_events( - access_events_iterable + access_events_iterable, ) # Generate usage workunits from aggregated events. 
for time_bucket in aggregated_events.values(): @@ -227,7 +233,9 @@ def _gen_operation_aspect_workunits( ) access_events_iterable: Iterable[RedshiftAccessEvent] = ( self._gen_access_events_from_history_query( - query, connection, all_tables=all_tables + query, + connection, + all_tables=all_tables, ) ) @@ -236,8 +244,9 @@ def _gen_operation_aspect_workunits( mcpw.as_workunit() for mcpw in self._drop_repeated_operations( self._gen_operation_aspect_workunits_from_access_events( - access_events_iterable, all_tables=all_tables - ) + access_events_iterable, + all_tables=all_tables, + ), ) ) @@ -291,7 +300,7 @@ def _gen_access_events_from_history_query( ) except pydantic.error_wrappers.ValidationError as e: logging.warning( - f"Validation error on access event creation from row {row}. The error was: {e} Skipping ...." + f"Validation error on access event creation from row {row}. The error was: {e} Skipping ....", ) self.report.num_usage_stat_skipped += 1 continue @@ -304,7 +313,8 @@ def _gen_access_events_from_history_query( results = cursor.fetchmany() def _drop_repeated_operations( - self, events: Iterable[MetadataChangeProposalWrapper] + self, + events: Iterable[MetadataChangeProposalWrapper], ) -> Iterable[MetadataChangeProposalWrapper]: """Drop repeated operations on the same entity. @@ -329,14 +339,17 @@ def timer(): # dict of entity urn -> (last event's actor, operation type) # TODO: Remove the type ignore and use TTLCache[key_type, value_type] directly once that's supported in Python 3.9. last_events: Dict[str, Tuple[Optional[str], str]] = cachetools.TTLCache( # type: ignore[assignment] - maxsize=OPERATION_CACHE_MAXSIZE, ttl=DROP_WINDOW_SEC * 1000, timer=timer + maxsize=OPERATION_CACHE_MAXSIZE, + ttl=DROP_WINDOW_SEC * 1000, + timer=timer, ) for event in events: assert isinstance(event.aspect, OperationClass) timestamp_low_watermark = min( - timestamp_low_watermark, event.aspect.lastUpdatedTimestamp + timestamp_low_watermark, + event.aspect.lastUpdatedTimestamp, ) urn = event.entityUrn @@ -392,17 +405,20 @@ def _gen_operation_aspect_workunits_from_access_events( resource: str = f"{event.database}.{event.schema_}.{event.table}".lower() yield MetadataChangeProposalWrapper( - entityUrn=self.dataset_urn_builder(resource), aspect=operation_aspect + entityUrn=self.dataset_urn_builder(resource), + aspect=operation_aspect, ) self.report.num_operational_stats_workunits_emitted += 1 def _aggregate_access_events( - self, events_iterable: Iterable[RedshiftAccessEvent] + self, + events_iterable: Iterable[RedshiftAccessEvent], ) -> AggregatedAccessEvents: datasets: AggregatedAccessEvents = collections.defaultdict(dict) for event in events_iterable: floored_ts: datetime = get_time_bucket( - event.starttime, self.config.bucket_duration + event.starttime, + self.config.bucket_duration, ) resource: str = f"{event.database}.{event.schema_}.{event.table}".lower() # Get a reference to the bucket value(or initialize not yet in dict) and update it. 
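The usage.py hunks above reformat _drop_repeated_operations, which deduplicates operation aspects with a cachetools.TTLCache keyed by entity URN and clocked by an event-time watermark rather than wall time. The following is a minimal, self-contained sketch of that pattern, not the DataHub implementation: the Event dataclass, the concrete DROP_WINDOW_SEC and OPERATION_CACHE_MAXSIZE values, and the max-based watermark (which assumes events arrive in ascending timestamp order, whereas the original tracks a low watermark) are illustrative assumptions.

import dataclasses
from typing import Dict, Iterable, Iterator, Optional, Tuple

import cachetools

DROP_WINDOW_SEC = 10          # placeholder window, in seconds
OPERATION_CACHE_MAXSIZE = 1000  # placeholder cache bound


@dataclasses.dataclass
class Event:
    # Illustrative stand-in for a wrapped Operation aspect.
    urn: str
    actor: Optional[str]
    operation_type: str
    timestamp_ms: int  # analogous to OperationClass.lastUpdatedTimestamp


def drop_repeated_operations(events: Iterable[Event]) -> Iterator[Event]:
    # Event-time clock for the cache: an entry expires once the watermark has
    # advanced more than DROP_WINDOW_SEC (in ms) past its insertion time.
    timestamp_watermark = 0

    def timer() -> int:
        return timestamp_watermark

    # dict of entity urn -> (last event's actor, operation type)
    last_events: Dict[str, Tuple[Optional[str], str]] = cachetools.TTLCache(  # type: ignore[assignment]
        maxsize=OPERATION_CACHE_MAXSIZE,
        ttl=DROP_WINDOW_SEC * 1000,
        timer=timer,
    )

    for event in events:
        # Assumes ascending event time; keeps the cache clock monotonic.
        timestamp_watermark = max(timestamp_watermark, event.timestamp_ms)

        value = (event.actor, event.operation_type)
        if last_events.get(event.urn) == value:
            # Same actor, same operation type, same entity inside the window:
            # treat it as a repeat and drop it.
            continue
        last_events[event.urn] = value
        yield event


if __name__ == "__main__":
    sample = [
        Event("urn:li:dataset:a", "alice", "INSERT", 1_000),
        Event("urn:li:dataset:a", "alice", "INSERT", 2_000),   # dropped: repeat within window
        Event("urn:li:dataset:a", "alice", "INSERT", 20_000),  # kept: window expired
    ]
    print([e.timestamp_ms for e in drop_repeated_operations(sample)])  # [1000, 20000]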
diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/config.py b/metadata-ingestion/src/datahub/ingestion/source/s3/config.py index 3069c62e3a240f..773ac21434b048 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/config.py @@ -26,7 +26,9 @@ class DataLakeSourceConfig( - StatefulIngestionConfigBase, DatasetSourceConfigMixin, PathSpecsConfigMixin + StatefulIngestionConfigBase, + DatasetSourceConfigMixin, + PathSpecsConfigMixin, ): platform: str = Field( default="", @@ -34,13 +36,15 @@ class DataLakeSourceConfig( "If not specified, the platform will be inferred from the path_specs.", ) aws_config: Optional[AwsConnectionConfig] = Field( - default=None, description="AWS configuration" + default=None, + description="AWS configuration", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None # Whether or not to create in datahub from the s3 bucket use_s3_bucket_tags: Optional[bool] = Field( - None, description="Whether or not to create tags in datahub from the s3 bucket" + None, + description="Whether or not to create tags in datahub from the s3 bucket", ) # Whether or not to create in datahub from the s3 object use_s3_object_tags: Optional[bool] = Field( @@ -66,11 +70,13 @@ class DataLakeSourceConfig( description="regex patterns for tables to profile ", ) profiling: DataLakeProfilerConfig = Field( - default=DataLakeProfilerConfig(), description="Data profiling configuration" + default=DataLakeProfilerConfig(), + description="Data profiling configuration", ) spark_driver_memory: str = Field( - default="4g", description="Max amount of memory to grant Spark." + default="4g", + description="Max amount of memory to grant Spark.", ) spark_config: Dict[str, Any] = Field( @@ -97,7 +103,9 @@ class DataLakeSourceConfig( ) _rename_path_spec_to_plural = pydantic_renamed_field( - "path_spec", "path_specs", lambda path_spec: [path_spec] + "path_spec", + "path_specs", + lambda path_spec: [path_spec], ) sort_schema_fields: bool = Field( @@ -112,12 +120,14 @@ class DataLakeSourceConfig( def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) @pydantic.validator("path_specs", always=True) def check_path_specs_and_infer_platform( - cls, path_specs: List[PathSpec], values: Dict + cls, + path_specs: List[PathSpec], + values: Dict, ) -> List[PathSpec]: if len(path_specs) == 0: raise ValueError("path_specs must not be empty") @@ -128,7 +138,7 @@ def check_path_specs_and_infer_platform( } if len(guessed_platforms) > 1: raise ValueError( - f"Cannot have multiple platforms in path_specs: {guessed_platforms}" + f"Cannot have multiple platforms in path_specs: {guessed_platforms}", ) guessed_platform = guessed_platforms.pop() @@ -137,13 +147,13 @@ def check_path_specs_and_infer_platform( values.get("use_s3_object_tags") or values.get("use_s3_bucket_tags") ): raise ValueError( - "Cannot grab s3 object/bucket tags when platform is not s3. Remove the flag or use s3." + "Cannot grab s3 object/bucket tags when platform is not s3. Remove the flag or use s3.", ) # Infer platform if not specified. 
if values.get("platform") and values["platform"] != guessed_platform: raise ValueError( - f"All path_specs belong to {guessed_platform} platform, but platform is set to {values['platform']}" + f"All path_specs belong to {guessed_platform} platform, but platform is set to {values['platform']}", ) else: logger.debug(f'Setting config "platform": {guessed_platform}') @@ -154,7 +164,8 @@ def check_path_specs_and_infer_platform( @pydantic.validator("platform", always=True) def platform_valid(cls, platform: str, values: dict) -> str: inferred_platform = values.get( - "platform", None + "platform", + None, ) # we may have inferred it above platform = platform or inferred_platform if not platform: @@ -162,22 +173,23 @@ def platform_valid(cls, platform: str, values: dict) -> str: if platform != "s3" and values.get("use_s3_bucket_tags"): raise ValueError( - "Cannot grab s3 bucket tags when platform is not s3. Remove the flag or ingest from s3." + "Cannot grab s3 bucket tags when platform is not s3. Remove the flag or ingest from s3.", ) if platform != "s3" and values.get("use_s3_object_tags"): raise ValueError( - "Cannot grab s3 object tags when platform is not s3. Remove the flag or ingest from s3." + "Cannot grab s3 object tags when platform is not s3. Remove the flag or ingest from s3.", ) if platform != "s3" and values.get("use_s3_content_type"): raise ValueError( - "Cannot grab s3 object content type when platform is not s3. Remove the flag or ingest from s3." + "Cannot grab s3 object content type when platform is not s3. Remove the flag or ingest from s3.", ) return platform @pydantic.root_validator(skip_on_failure=True) def ensure_profiling_pattern_is_passed_to_profiling( - cls, values: Dict[str, Any] + cls, + values: Dict[str, Any], ) -> Dict[str, Any]: profiling: Optional[DataLakeProfilerConfig] = values.get("profiling") if profiling is not None and profiling.enabled: diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py b/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py index 58e930eb6e809c..25e7afaab8d213 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py @@ -10,7 +10,8 @@ class DataLakeProfilerConfig(ConfigModel): enabled: bool = Field( - default=False, description="Whether profiling should be done." + default=False, + description="Whether profiling should be done.", ) operation_config: OperationConfig = Field( default_factory=OperationConfig, @@ -61,7 +62,8 @@ class DataLakeProfilerConfig(ConfigModel): description="Whether to profile for the quantiles of numeric columns.", ) include_field_distinct_value_frequencies: bool = Field( - default=True, description="Whether to profile for distinct value frequencies." 
+ default=True, + description="Whether to profile for distinct value frequencies.", ) include_field_histogram: bool = Field( default=True, @@ -74,7 +76,8 @@ class DataLakeProfilerConfig(ConfigModel): @pydantic.root_validator(skip_on_failure=True) def ensure_field_level_settings_are_normalized( - cls: "DataLakeProfilerConfig", values: Dict[str, Any] + cls: "DataLakeProfilerConfig", + values: Dict[str, Any], ) -> Dict[str, Any]: max_num_fields_to_profile_key = "max_number_of_fields_to_profile" max_num_fields_to_profile = values.get(max_num_fields_to_profile_key) diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/profiling.py b/metadata-ingestion/src/datahub/ingestion/source/s3/profiling.py index c969b229989e84..d8a4f0069fcdfb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/profiling.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/profiling.py @@ -136,12 +136,13 @@ def __init__( ] self.report.report_file_dropped( - f"The max_number_of_fields_to_profile={self.profiling_config.max_number_of_fields_to_profile} reached. Profile of columns {self.file_path}({', '.join(sorted(columns_being_dropped))})" + f"The max_number_of_fields_to_profile={self.profiling_config.max_number_of_fields_to_profile} reached. Profile of columns {self.file_path}({', '.join(sorted(columns_being_dropped))})", ) analysis_result = self.analyzer.run() analysis_metrics = AnalyzerContext.successMetricsAsJson( - self.spark, analysis_result + self.spark, + analysis_result, ) # reshape distinct counts into dictionary @@ -156,7 +157,7 @@ def __init__( when( isnan(c) | col(c).isNull(), c, - ) + ), ).alias(c) for c in self.columns_to_profile if column_types[column] in [DoubleType, FloatType] @@ -168,14 +169,14 @@ def __init__( when( col(c).isNull(), c, - ) + ), ).alias(c) for c in self.columns_to_profile if column_types[column] not in [DoubleType, FloatType] ] null_counts = dataframe.select( - select_numeric_null_counts + select_nonnumeric_null_counts + select_numeric_null_counts + select_nonnumeric_null_counts, ) column_null_counts = null_counts.toPandas().T[0].to_dict() column_null_fractions = { @@ -215,7 +216,7 @@ def __init__( column_profile.nullProportion = column_null_fractions.get(column) if self.profiling_config.include_field_sample_values: column_profile.sampleValues = sorted( - [str(x[column]) for x in rdd_sample] + [str(x[column]) for x in rdd_sample], ) column_spec.type_ = column_types[column] @@ -373,7 +374,7 @@ def extract_table_profiles( # resolve histogram types for grouping column_metrics["kind"] = column_metrics["name"].apply( - lambda x: "Histogram" if x.startswith("Histogram.") else x + lambda x: "Histogram" if x.startswith("Histogram.") else x, ) column_histogram_metrics = column_metrics[column_metrics["kind"] == "Histogram"] @@ -387,12 +388,12 @@ def extract_table_profiles( # we only want the absolute counts for each histogram for now column_histogram_metrics = column_histogram_metrics[ column_histogram_metrics["name"].apply( - lambda x: x.startswith("Histogram.abs.") + lambda x: x.startswith("Histogram.abs."), ) ] # get the histogram bins by chopping off the "Histogram.abs." 
prefix column_histogram_metrics["bin"] = column_histogram_metrics["name"].apply( - lambda x: x[14:] + lambda x: x[14:], ) # reshape histogram counts for easier access @@ -407,7 +408,7 @@ def extract_table_profiles( if len(column_nonhistogram_metrics) > 0: # reshape other metrics for easier access nonhistogram_metrics = column_nonhistogram_metrics.set_index( - ["instance", "name"] + ["instance", "name"], )["value"] profiled_columns = set(nonhistogram_metrics.index.get_level_values(0)) @@ -428,10 +429,10 @@ def extract_table_profiles( column_profile.max = null_str(deequ_column_profile.get("Maximum")) column_profile.mean = null_str(deequ_column_profile.get("Mean")) column_profile.median = null_str( - deequ_column_profile.get("ApproxQuantiles-0.5") + deequ_column_profile.get("ApproxQuantiles-0.5"), ) column_profile.stdev = null_str( - deequ_column_profile.get("StandardDeviation") + deequ_column_profile.get("StandardDeviation"), ) if all( deequ_column_profile.get(f"ApproxQuantiles-{quantile}") is not None @@ -453,13 +454,15 @@ def extract_table_profiles( if column_spec.histogram_distinct: column_profile.distinctValueFrequencies = [ ValueFrequencyClass( - value=value, frequency=int(column_histogram.loc[value]) + value=value, + frequency=int(column_histogram.loc[value]), ) for value in column_histogram.index ] # sort so output is deterministic column_profile.distinctValueFrequencies = sorted( - column_profile.distinctValueFrequencies, key=lambda x: x.value + column_profile.distinctValueFrequencies, + key=lambda x: x.value, ) else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/source.py b/metadata-ingestion/src/datahub/ingestion/source/s3/source.py index 3173423f86a2ea..c4e1911169adb3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/source.py @@ -189,7 +189,8 @@ class TableData: @capability(SourceCapability.CONTAINERS, "Enabled by default") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") @capability( - SourceCapability.SCHEMA_METADATA, "Can infer schema from supported file types" + SourceCapability.SCHEMA_METADATA, + "Can infer schema from supported file types", ) @capability(SourceCapability.TAGS, "Can extract S3 object/bucket tags if enabled") class S3Source(StatefulIngestionSourceBase): @@ -244,7 +245,7 @@ def init_spark(self): # Spark's avro version needs to be matched with the Spark version f"org.apache.spark:spark-avro_2.12:{spark_version}{'.0' if spark_version.count('.') == 1 else ''}", pydeequ.deequ_maven_coord, - ] + ], ), ) @@ -296,11 +297,13 @@ def init_spark(self): if self.source_config.aws_config.aws_endpoint_url is not None: conf.set( - "fs.s3a.endpoint", self.source_config.aws_config.aws_endpoint_url + "fs.s3a.endpoint", + self.source_config.aws_config.aws_endpoint_url, ) if self.source_config.aws_config.aws_region is not None: conf.set( - "fs.s3a.endpoint.region", self.source_config.aws_config.aws_region + "fs.s3a.endpoint.region", + self.source_config.aws_config.aws_region, ) conf.set("spark.jars.excludes", pydeequ.f2j_maven_coord) @@ -373,11 +376,13 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List: raise ValueError("AWS config is required for S3 file sources") s3_client = self.source_config.aws_config.get_s3_client( - self.source_config.verify_ssl + self.source_config.verify_ssl, ) file = smart_open( - table_data.full_path, "rb", transport_params={"client": s3_client} + table_data.full_path, + "rb", + 
transport_params={"client": s3_client}, ) else: # We still use smart_open here to take advantage of the compression @@ -418,13 +423,17 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List: if self.source_config.add_partition_columns_to_schema and table_data.partitions: self.add_partition_columns_to_schema( - fields=fields, path_spec=path_spec, full_path=table_data.full_path + fields=fields, + path_spec=path_spec, + full_path=table_data.full_path, ) return fields def _get_inferrer( - self, extension: str, content_type: Optional[str] + self, + extension: str, + content_type: Optional[str], ) -> Optional[SchemaInferenceBase]: if content_type == "application/vnd.apache.parquet": return parquet.ParquetInferrer() @@ -444,7 +453,8 @@ def _get_inferrer( return csv_tsv.TsvInferrer(max_rows=self.source_config.max_rows) elif extension == ".jsonl": return json.JsonInferrer( - max_rows=self.source_config.max_rows, format="jsonl" + max_rows=self.source_config.max_rows, + format="jsonl", ) elif extension == ".json": return json.JsonInferrer() @@ -454,7 +464,10 @@ def _get_inferrer( return None def add_partition_columns_to_schema( - self, path_spec: PathSpec, full_path: str, fields: List[SchemaField] + self, + path_spec: PathSpec, + full_path: str, + fields: List[SchemaField], ) -> None: is_fieldpath_v2 = False for field in fields: @@ -478,11 +491,13 @@ def add_partition_columns_to_schema( isPartitioningKey=True, nullable=True, recursive=False, - ) + ), ) def get_table_profile( - self, table_data: TableData, dataset_urn: str + self, + table_data: TableData, + dataset_urn: str, ) -> Iterable[MetadataWorkUnit]: # Importing here to avoid Deequ dependency for non profiling use cases # Deequ fails if Spark is not available which is not needed for non profiling use cases @@ -495,11 +510,13 @@ def get_table_profile( try: if table_data.partitions: table = self.read_file_spark( - table_data.table_path, os.path.splitext(table_data.full_path)[1] + table_data.table_path, + os.path.splitext(table_data.full_path)[1], ) else: table = self.read_file_spark( - table_data.full_path, os.path.splitext(table_data.full_path)[1] + table_data.full_path, + os.path.splitext(table_data.full_path)[1], ) except Exception as e: logger.error(e) @@ -515,7 +532,7 @@ def get_table_profile( with PerfTimer() as timer: # init PySpark analysis object logger.debug( - f"Profiling {table_data.full_path}: reading file and computing nulls+uniqueness {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}" + f"Profiling {table_data.full_path}: reading file and computing nulls+uniqueness {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}", ) table_profiler = _SingleTableProfiler( table, @@ -526,7 +543,7 @@ def get_table_profile( ) logger.debug( - f"Profiling {table_data.full_path}: preparing profilers to run {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}" + f"Profiling {table_data.full_path}: preparing profilers to run {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}", ) # instead of computing each profile individually, we run them all in a single analyzer.run() call # we use a single call because the analyzer optimizes the number of calls to the underlying profiler @@ -535,22 +552,23 @@ def get_table_profile( # compute the profiles logger.debug( - f"Profiling {table_data.full_path}: computing profiles {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}" + f"Profiling {table_data.full_path}: computing profiles {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}", ) analysis_result = table_profiler.analyzer.run() analysis_metrics = 
AnalyzerContext.successMetricsAsDataFrame( - self.spark, analysis_result + self.spark, + analysis_result, ) logger.debug( - f"Profiling {table_data.full_path}: extracting profiles {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}" + f"Profiling {table_data.full_path}: extracting profiles {datetime.now().strftime('%d/%m/%Y %H:%M:%S')}", ) table_profiler.extract_table_profiles(analysis_metrics) time_taken = timer.elapsed_seconds() logger.info( - f"Finished profiling {table_data.full_path}; took {time_taken:.3f} seconds" + f"Finished profiling {table_data.full_path}; took {time_taken:.3f} seconds", ) self.profiling_times_taken.append(time_taken) @@ -573,7 +591,8 @@ def _create_table_operation_aspect(self, table_data: TableData) -> OperationClas return operation def __create_partition_summary_aspect( - self, partitions: List[Folder] + self, + partitions: List[Folder], ) -> Optional[PartitionsSummaryClass]: min_partition = min(partitions, key=lambda x: x.creation_time) max_partition = max(partitions, key=lambda x: x.creation_time) @@ -586,7 +605,7 @@ def __create_partition_summary_aspect( partition=max_partition_id, createdTime=int(max_partition.creation_time.timestamp() * 1000), lastModifiedTime=int( - max_partition.modification_time.timestamp() * 1000 + max_partition.modification_time.timestamp() * 1000, ), ) @@ -597,16 +616,19 @@ def __create_partition_summary_aspect( partition=min_partition_id, createdTime=int(min_partition.creation_time.timestamp() * 1000), lastModifiedTime=int( - min_partition.modification_time.timestamp() * 1000 + min_partition.modification_time.timestamp() * 1000, ), ) return PartitionsSummaryClass( - maxPartition=max_partition_summary, minPartition=min_partition_summary + maxPartition=max_partition_summary, + minPartition=min_partition_summary, ) def ingest_table( - self, table_data: TableData, path_spec: PathSpec + self, + table_data: TableData, + path_spec: PathSpec, ) -> Iterable[MetadataWorkUnit]: aspects: List[Optional[_Aspect]] = [] @@ -630,7 +652,8 @@ def ingest_table( data_platform_instance = DataPlatformInstanceClass( platform=data_platform_urn, instance=make_dataplatform_instance_urn( - self.source_config.platform, self.source_config.platform_instance + self.source_config.platform, + self.source_config.platform_instance, ), ) aspects.append(data_platform_instance) @@ -648,16 +671,16 @@ def ingest_table( { "number_of_files": str(table_data.number_of_files), "size_in_bytes": str(table_data.size_in_bytes), - } + }, ) else: if table_data.partitions: customProperties.update( { "number_of_partitions": str( - len(table_data.partitions) if table_data.partitions else 0 + len(table_data.partitions) if table_data.partitions else 0, ), - } + }, ) dataset_properties = DatasetPropertiesClass( @@ -690,11 +713,11 @@ def ingest_table( aspects.append(schema_metadata) except Exception as e: logger.error( - f"Failed to extract schema from file {table_data.full_path}. The error was:{e}" + f"Failed to extract schema from file {table_data.full_path}. 
The error was:{e}", ) else: logger.info( - f"Skipping schema extraction for empty file {table_data.full_path}" + f"Skipping schema extraction for empty file {table_data.full_path}", ) if ( @@ -725,7 +748,7 @@ def ingest_table( if table_data.partitions and self.source_config.generate_partition_aspects: aspects.append( - self.__create_partition_summary_aspect(table_data.partitions) + self.__create_partition_summary_aspect(table_data.partitions), ) for mcp in MetadataChangeProposalWrapper.construct_many( @@ -735,7 +758,8 @@ def ingest_table( yield mcp.as_workunit() yield from self.container_WU_creator.create_container_hierarchy( - table_data.table_path, dataset_urn + table_data.table_path, + dataset_urn, ) if self.source_config.is_profiling_enabled(): @@ -779,7 +803,7 @@ def extract_table_data( [ partition.size if partition.size else 0 for partition in partitions - ] + ], ) ), content_type=browse_path.content_type, @@ -793,11 +817,14 @@ def resolve_templated_folders(self, bucket_name: str, prefix: str) -> Iterable[s return folders: Iterable[str] = list_folders( - bucket_name, folder_split[0], self.source_config.aws_config + bucket_name, + folder_split[0], + self.source_config.aws_config, ) for folder in folders: yield from self.resolve_templated_folders( - bucket_name, f"{folder}{folder_split[1]}" + bucket_name, + f"{folder}{folder_split[1]}", ) def get_dir_to_process( @@ -887,12 +914,12 @@ def get_folder_info( if modification_time is None: logger.warning( - f"Unable to find any files in the folder {key}. Skipping..." + f"Unable to find any files in the folder {key}. Skipping...", ) continue id = path_spec.get_partition_from_path( - self.create_s3_path(max_file.bucket_name, max_file.key) + self.create_s3_path(max_file.bucket_name, max_file.key), ) # If id is None, it means the folder is not a partition @@ -904,7 +931,7 @@ def get_folder_info( modification_time=modification_time, sample_file=self.create_s3_path(max_file.bucket_name, max_file.key), size=file_size, - ) + ), ) return partitions @@ -913,7 +940,7 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePa if self.source_config.aws_config is None: raise ValueError("aws_config not set. 
Cannot browse s3") s3 = self.source_config.aws_config.get_s3_resource( - self.source_config.verify_ssl + self.source_config.verify_ssl, ) bucket_name = get_bucket_name(path_spec.include) logger.debug(f"Scanning bucket: {bucket_name}") @@ -947,11 +974,14 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePa table_index = include.find(max_match) for folder in self.resolve_templated_folders( - bucket_name, get_bucket_relative_path(include[:table_index]) + bucket_name, + get_bucket_relative_path(include[:table_index]), ): try: for f in list_folders( - bucket_name, f"{folder}", self.source_config.aws_config + bucket_name, + f"{folder}", + self.source_config.aws_config, ): dirs_to_process = [] logger.info(f"Processing folder: {f}") @@ -965,7 +995,7 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePa == FolderTraversalMethod.MAX ): protocol = ContainerWUCreator.get_protocol( - path_spec.include + path_spec.include, ) dirs_to_process_max = self.get_dir_to_process( bucket_name=bucket_name, @@ -994,15 +1024,17 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePa folders.extend( self.get_folder_info( - path_spec, bucket, prefix_to_process - ) + path_spec, + bucket, + prefix_to_process, + ), ) max_folder = None if folders: max_folder = max(folders, key=lambda x: x.modification_time) if not max_folder: logger.warning( - f"Unable to find any files in the folder {dir}. Skipping..." + f"Unable to find any files in the folder {dir}. Skipping...", ) continue @@ -1021,13 +1053,14 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePa if "NoSuchBucket" in repr(e): logger.debug(f"Got NoSuchBucket exception for {bucket_name}", e) self.get_report().report_warning( - "Missing bucket", f"No bucket found {bucket_name}" + "Missing bucket", + f"No bucket found {bucket_name}", ) else: raise e else: logger.debug( - "No template in the pathspec can't do sampling, fallbacking to do full scan" + "No template in the pathspec can't do sampling, fallbacking to do full scan", ) path_spec.sample_files = False for obj in bucket.objects.filter(Prefix=prefix).page_size(PAGE_SIZE): @@ -1067,12 +1100,12 @@ def local_browser(self, path_spec: PathSpec) -> Iterable[BrowsePath]: for file in sorted(files): # We need to make sure the path is in posix style which is not true on windows full_path = PurePath( - os.path.normpath(os.path.join(root, file)) + os.path.normpath(os.path.join(root, file)), ).as_posix() yield BrowsePath( file=full_path, timestamp=datetime.utcfromtimestamp( - os.path.getmtime(full_path) + os.path.getmtime(full_path), ), size=os.path.getsize(full_path), partitions=[], @@ -1089,7 +1122,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for path_spec in self.source_config.path_specs: file_browser = ( self.s3_browser( - path_spec, self.source_config.number_of_files_to_sample + path_spec, + self.source_config.number_of_files_to_sample, ) if self.is_s3_platform() else self.local_browser(path_spec) @@ -1133,7 +1167,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: total_time_taken = timer.elapsed_seconds() logger.info( - f"Profiling {len(self.profiling_times_taken)} table(s) finished in {total_time_taken:.3f} seconds" + f"Profiling {len(self.profiling_times_taken)} table(s) finished in {total_time_taken:.3f} seconds", ) time_percentiles: Dict[str, float] = {} @@ -1141,12 +1175,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if 
len(self.profiling_times_taken) > 0: percentiles = [50, 75, 95, 99] percentile_values = stats.calculate_percentiles( - self.profiling_times_taken, percentiles + self.profiling_times_taken, + percentiles, ) time_percentiles = { f"table_time_taken_p{percentile}": stats.discretize( - percentile_values[percentile] + percentile_values[percentile], ) for percentile in percentiles } @@ -1166,7 +1201,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ).workunit_processor, ] diff --git a/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py b/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py index b75f15c0ce770e..c3abb70a19dda3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py @@ -89,7 +89,8 @@ class ConnectionMappingConfig(EnvConfigMixin): platform: Optional[str] = Field( - default=None, description="The platform that this connection mapping belongs to" + default=None, + description="The platform that this connection mapping belongs to", ) platform_instance: Optional[str] = Field( @@ -104,7 +105,9 @@ class ConnectionMappingConfig(EnvConfigMixin): class SACSourceConfig( - StatefulIngestionConfigBase, DatasetSourceConfigMixin, IncrementalLineageConfigMixin + StatefulIngestionConfigBase, + DatasetSourceConfigMixin, + IncrementalLineageConfigMixin, ): stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( default=None, @@ -113,11 +116,11 @@ class SACSourceConfig( tenant_url: str = Field(description="URL of the SAP Analytics Cloud tenant") token_url: str = Field( - description="URL of the OAuth token endpoint of the SAP Analytics Cloud tenant" + description="URL of the OAuth token endpoint of the SAP Analytics Cloud tenant", ) client_id: str = Field(description="Client ID for the OAuth authentication") client_secret: SecretStr = Field( - description="Client secret for the OAuth authentication" + description="Client secret for the OAuth authentication", ) ingest_stories: bool = Field( @@ -151,7 +154,8 @@ class SACSourceConfig( ) connection_mapping: Dict[str, ConnectionMappingConfig] = Field( - default={}, description="Custom mappings for connections" + default={}, + description="Custom mappings for connections", ) query_name_template: Optional[str] = Field( @@ -229,7 +233,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=f"{e}" + capable=False, + failure_reason=f"{e}", ) return test_report @@ -242,7 +247,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: self.config.incremental_lineage, ), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -277,7 +284,9 @@ def get_report(self) -> SACSourceReport: return self.report def get_resource_workunits( - self, resource: Resource, datasets: List[str] + self, + resource: Resource, + datasets: List[str], ) -> Iterable[MetadataWorkUnit]: dashboard_urn = make_dashboard_urn( platform=self.platform, @@ -315,7 +324,8 @@ def get_resource_workunits( aspect=DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - 
self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ), ) @@ -377,7 +387,9 @@ def get_resource_workunits( yield mcp.as_workunit() def get_model_workunits( - self, dataset_urn: str, model: ResourceModel + self, + dataset_urn: str, + model: ResourceModel, ) -> Iterable[MetadataWorkUnit]: mcp = MetadataChangeProposalWrapper( entityUrn=dataset_urn, @@ -433,12 +445,12 @@ def get_model_workunits( upstream_dataset_name: Optional[str] = None if model.system_type == "BW" and model.external_id.startswith( - "query:" + "query:", ): # query:[][][query] query = model.external_id[11:-1] upstream_dataset_name = self.get_query_name(query) elif model.system_type == "HANA" and model.external_id.startswith( - "view:" + "view:", ): # view:[schema][schema.namespace][view] schema, namespace_with_schema, view = model.external_id.split("][", 2) schema = schema[6:] @@ -464,7 +476,7 @@ def get_model_workunits( env = DEFAULT_ENV logger.info( - f"No connection mapping found for connection with id {model.connection_id}, connection id will be used as platform instance" + f"No connection mapping found for connection with id {model.connection_id}, connection id will be used as platform instance", ) upstream_dataset_urn = make_dataset_urn_with_platform_instance( @@ -579,7 +591,8 @@ def get_sac_connection( session.mount("https://", adapter) session.register_compliance_hook( - "protected_request", _add_sap_sac_custom_auth_header + "protected_request", + _add_sap_sac_custom_auth_header, ) session.fetch_token() @@ -621,7 +634,7 @@ def get_resources(self) -> Iterable[Resource]: ) if not self.config.resource_id_pattern.allowed( - resource_id + resource_id, ) or not self.config.resource_name_pattern.allowed(name): continue @@ -672,7 +685,7 @@ def get_resources(self) -> Iterable[Resource]: connection_id=nav_entity.connectionId, external_id=nav_entity.externalId, # query:[][][query] or view:[schema][schema.namespace][view] is_import=model_id in import_data_model_ids, - ) + ), ) created_by: Optional[str] = entity.createdBy @@ -704,7 +717,7 @@ def get_resources(self) -> Iterable[Resource]: def get_import_data_model_ids(self) -> Set[str]: response = self.session.get( - url=f"{self.config.tenant_url}/api/v1/dataimport/models" + url=f"{self.config.tenant_url}/api/v1/dataimport/models", ) response.raise_for_status() @@ -714,10 +727,11 @@ def get_import_data_model_ids(self) -> Set[str]: return import_data_model_ids def get_import_data_model_columns( - self, model_id: str + self, + model_id: str, ) -> List[ImportDataModelColumn]: response = self.session.get( - url=f"{self.config.tenant_url}/api/v1/dataimport/models/{model_id}/metadata" + url=f"{self.config.tenant_url}/api/v1/dataimport/models/{model_id}/metadata", ) response.raise_for_status() @@ -739,7 +753,7 @@ def get_import_data_model_columns( precision=column.get("precision"), scale=column.get("scale"), is_key=column["isKey"], - ) + ), ) return columns @@ -760,7 +774,8 @@ def get_view_name(self, schema: str, namespace: Optional[str], view: str) -> str return f"{schema}.{view}" def get_schema_field_data_type( - self, column: ImportDataModelColumn + self, + column: ImportDataModelColumn, ) -> SchemaFieldDataTypeClass: if column.property_type == "DATE": return SchemaFieldDataTypeClass(type=DateTypeClass()) @@ -790,7 +805,9 @@ def get_schema_field_native_data_type(self, column: ImportDataModelColumn) -> st def _add_sap_sac_custom_auth_header( - url: str, headers: Dict[str, str], body: Any + url: str, + headers: Dict[str, str], + body: Any, ) 
-> Tuple[str, Dict[str, str], Any]: headers["x-sap-sac-custom-auth"] = "true" return url, headers, body diff --git a/metadata-ingestion/src/datahub/ingestion/source/salesforce.py b/metadata-ingestion/src/datahub/ingestion/source/salesforce.py index 66e0e6b741d1ff..a1c95fcf606064 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/salesforce.py +++ b/metadata-ingestion/src/datahub/ingestion/source/salesforce.py @@ -92,23 +92,24 @@ class SalesforceConfig(DatasetSourceConfigMixin): username: Optional[str] = Field(description="Salesforce username") password: Optional[str] = Field(description="Password for Salesforce user") consumer_key: Optional[str] = Field( - description="Consumer key for Salesforce JSON web token access" + description="Consumer key for Salesforce JSON web token access", ) private_key: Optional[str] = Field( - description="Private key as a string for Salesforce JSON web token access" + description="Private key as a string for Salesforce JSON web token access", ) security_token: Optional[str] = Field( - description="Security token for Salesforce username" + description="Security token for Salesforce username", ) # client_id, client_secret not required # Direct - Instance URL, Access Token Auth instance_url: Optional[str] = Field( - description="Salesforce instance url. e.g. https://MyDomainName.my.salesforce.com" + description="Salesforce instance url. e.g. https://MyDomainName.my.salesforce.com", ) # Flag to indicate whether the instance is production or sandbox is_sandbox: bool = Field( - default=False, description="Connect to Sandbox instance of your Salesforce" + default=False, + description="Connect to Sandbox instance of your Salesforce", ) access_token: Optional[str] = Field(description="Access token for instance url") @@ -126,7 +127,7 @@ class SalesforceConfig(DatasetSourceConfigMixin): description='Regex patterns for tables/schemas to describe domain_key domain key (domain_key can be any string like "sales".) There can be multiple domain keys specified.', ) api_version: Optional[str] = Field( - description="If specified, overrides default version used by the Salesforce package. Example value: '59.0'" + description="If specified, overrides default version used by the Salesforce package. Example value: '59.0'", ) profiling: SalesforceProfilingConfig = SalesforceProfilingConfig() @@ -138,7 +139,7 @@ class SalesforceConfig(DatasetSourceConfigMixin): def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) @validator("instance_url") @@ -316,13 +317,14 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None: self.sf.sf_version = version self.base_url = "https://{instance}/services/data/v{sf_version}/".format( - instance=self.sf.sf_instance, sf_version=self.sf.sf_version + instance=self.sf.sf_instance, + sf_version=self.sf.sf_version, ) logger.debug( "Using Salesforce REST API version: {version}".format( - version=self.sf.sf_version - ) + version=self.sf.sf_version, + ), ) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: @@ -333,7 +335,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # https://developer.salesforce.com/docs/atlas.en-us.api_tooling.meta/api_tooling/tooling_api_objects_entitydefinition.htm raise ConfigurationError( "Salesforce EntityDefinition query failed. " - "Please verify if user has 'View Setup and Configuration' permission." 
+ "Please verify if user has 'View Setup and Configuration' permission.", ) from e raise e else: @@ -341,7 +343,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield from self.get_salesforce_object_workunits(sObject) def get_salesforce_object_workunits( - self, sObject: dict + self, + sObject: dict, ) -> Iterable[MetadataWorkUnit]: sObjectName = sObject["QualifiedApiName"] @@ -349,8 +352,8 @@ def get_salesforce_object_workunits( self.report.report_dropped(sObjectName) logger.debug( "Skipping {sObject}, as it is not allowed by object_pattern".format( - sObject=sObjectName - ) + sObject=sObjectName, + ), ) return @@ -371,7 +374,10 @@ def get_salesforce_object_workunits( yield self.get_properties_workunit(sObject, customObject, datasetUrn) yield from self.get_schema_metadata_workunit( - sObjectName, sObject, customObject, datasetUrn + sObjectName, + sObject, + customObject, + datasetUrn, ) yield self.get_subtypes_workunit(sObjectName, datasetUrn) @@ -383,7 +389,7 @@ def get_salesforce_object_workunits( yield from self.get_domain_workunit(sObjectName, datasetUrn) if self.config.is_profiling_enabled() and self.config.profile_pattern.allowed( - sObjectName + sObjectName, ): yield from self.get_profile_workunit(sObjectName, datasetUrn) @@ -416,13 +422,15 @@ def get_salesforce_objects(self) -> List: entities_response = self.sf._call_salesforce("GET", query_url).json() logger.debug( "Salesforce EntityDefinition query returned {count} sObjects".format( - count=len(entities_response["records"]) - ) + count=len(entities_response["records"]), + ), ) return entities_response["records"] def get_domain_workunit( - self, dataset_name: str, datasetUrn: str + self, + dataset_name: str, + datasetUrn: str, ) -> Iterable[MetadataWorkUnit]: domain_urn: Optional[str] = None @@ -432,7 +440,8 @@ def get_domain_workunit( if domain_urn: yield from add_domain_to_entity_wu( - domain_urn=domain_urn, entity_urn=datasetUrn + domain_urn=domain_urn, + entity_urn=datasetUrn, ) def get_platform_instance_workunit(self, datasetUrn: str) -> MetadataWorkUnit: @@ -445,17 +454,20 @@ def get_platform_instance_workunit(self, datasetUrn: str) -> MetadataWorkUnit: ) return MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=dataPlatformInstance + entityUrn=datasetUrn, + aspect=dataPlatformInstance, ).as_workunit() def get_operation_workunit( - self, customObject: dict, datasetUrn: str + self, + customObject: dict, + datasetUrn: str, ) -> Iterable[MetadataWorkUnit]: reported_time: int = int(time.time() * 1000) if customObject.get("CreatedBy") and customObject.get("CreatedDate"): timestamp = self.get_time_from_salesforce_timestamp( - customObject["CreatedDate"] + customObject["CreatedDate"], ) operation = OperationClass( timestampMillis=reported_time, @@ -465,7 +477,8 @@ def get_operation_workunit( ) yield MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=operation + entityUrn=datasetUrn, + aspect=operation, ).as_workunit() # Note - Object Level LastModified captures changes at table level metadata e.g. 
table @@ -473,30 +486,34 @@ def get_operation_workunit( # field updated if customObject.get("LastModifiedBy") and customObject.get( - "LastModifiedDate" + "LastModifiedDate", ): timestamp = self.get_time_from_salesforce_timestamp( - customObject["LastModifiedDate"] + customObject["LastModifiedDate"], ) operation = OperationClass( timestampMillis=reported_time, operationType=OperationTypeClass.ALTER, lastUpdatedTimestamp=timestamp, actor=builder.make_user_urn( - customObject["LastModifiedBy"]["Username"] + customObject["LastModifiedBy"]["Username"], ), ) yield MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=operation + entityUrn=datasetUrn, + aspect=operation, ).as_workunit() def get_time_from_salesforce_timestamp(self, date: str) -> int: return round( - datetime.strptime(date, "%Y-%m-%dT%H:%M:%S.%f%z").timestamp() * 1000 + datetime.strptime(date, "%Y-%m-%dT%H:%M:%S.%f%z").timestamp() * 1000, ) def get_properties_workunit( - self, sObject: dict, customObject: Dict[str, str], datasetUrn: str + self, + sObject: dict, + customObject: Dict[str, str], + datasetUrn: str, ) -> MetadataWorkUnit: propertyLabels = { # from EntityDefinition @@ -522,7 +539,7 @@ def get_properties_workunit( propertyLabels[k]: str(v) for k, v in customObject.items() if k in propertyLabels and v is not None - } + }, ) datasetProperties = DatasetPropertiesClass( @@ -531,11 +548,14 @@ def get_properties_workunit( customProperties=sObjectProperties, ) return MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=datasetProperties + entityUrn=datasetUrn, + aspect=datasetProperties, ).as_workunit() def get_subtypes_workunit( - self, sObjectName: str, datasetUrn: str + self, + sObjectName: str, + datasetUrn: str, ) -> MetadataWorkUnit: subtypes: List[str] = [] if sObjectName.endswith("__c"): @@ -544,11 +564,14 @@ def get_subtypes_workunit( subtypes.append(DatasetSubTypes.SALESFORCE_STANDARD_OBJECT) return MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=SubTypesClass(typeNames=subtypes) + entityUrn=datasetUrn, + aspect=SubTypesClass(typeNames=subtypes), ).as_workunit() def get_profile_workunit( - self, sObjectName: str, datasetUrn: str + self, + sObjectName: str, + datasetUrn: str, ) -> Iterable[MetadataWorkUnit]: # Here approximate record counts as returned by recordCount API are used as rowCount # In future, count() SOQL query may be used instead, if required, might be more expensive @@ -557,13 +580,14 @@ def get_profile_workunit( ) sObject_record_count_response = self.sf._call_salesforce( - "GET", sObject_records_count_url + "GET", + sObject_records_count_url, ).json() logger.debug( "Received Salesforce {sObject} record count response".format( - sObject=sObjectName - ) + sObject=sObjectName, + ), ) for entry in sObject_record_count_response.get("sObjects", []): @@ -573,7 +597,8 @@ def get_profile_workunit( columnCount=self.fieldCounts[sObjectName], ) yield MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=datasetProfile + entityUrn=datasetUrn, + aspect=datasetProfile, ).as_workunit() # Here field description is created from label, description and inlineHelpText @@ -643,7 +668,8 @@ def _get_schema_field( # Created and LastModified Date and Actor are available for Custom Fields only if customField.get("CreatedDate") and customField.get("CreatedBy"): schemaField.created = self.get_audit_stamp( - customField["CreatedDate"], customField["CreatedBy"]["Username"] + customField["CreatedDate"], + customField["CreatedBy"]["Username"], ) if customField.get("LastModifiedDate") and 
customField.get("LastModifiedBy"): schemaField.lastModified = self.get_audit_stamp( @@ -676,7 +702,7 @@ def get_field_tags(self, fieldName: str, field: dict) -> List[str]: if field["FieldDefinition"]["ComplianceGroup"] is not None: # CCPA, COPPA, GDPR, HIPAA, PCI, PersonalInfo, PII fieldTags.extend( - iter(field["FieldDefinition"]["ComplianceGroup"].split(";")) + iter(field["FieldDefinition"]["ComplianceGroup"].split(";")), ) return fieldTags @@ -687,7 +713,11 @@ def get_audit_stamp(self, date: str, username: str) -> AuditStampClass: ) def get_schema_metadata_workunit( - self, sObjectName: str, sObject: dict, customObject: dict, datasetUrn: str + self, + sObjectName: str, + sObject: dict, + customObject: dict, + datasetUrn: str, ) -> Iterable[MetadataWorkUnit]: sObject_fields_query_url = ( self.base_url @@ -698,12 +728,13 @@ def get_schema_metadata_workunit( + "IsCompound, IsComponent, ReferenceTo, FieldDefinition.ComplianceGroup," + "RelationshipName, IsNillable, FieldDefinition.Description, InlineHelpText " + "FROM EntityParticle WHERE EntityDefinitionId='{}'".format( - sObject["DurableId"] + sObject["DurableId"], ) ) sObject_fields_response = self.sf._call_salesforce( - "GET", sObject_fields_query_url + "GET", + sObject_fields_query_url, ).json() logger.debug(f"Received Salesforce {sObjectName} fields response") @@ -714,20 +745,21 @@ def get_schema_metadata_workunit( + "DeveloperName,CreatedDate,CreatedBy.Username,InlineHelpText," + "LastModifiedDate,LastModifiedBy.Username " + "FROM CustomField WHERE EntityDefinitionId='{}'".format( - sObject["DurableId"] + sObject["DurableId"], ) ) customFields: Dict[str, Dict] = {} try: sObject_custom_fields_response = self.sf._call_salesforce( - "GET", sObject_custom_fields_query_url + "GET", + sObject_custom_fields_query_url, ).json() logger.debug( "Received Salesforce {sObject} custom fields response".format( - sObject=sObjectName - ) + sObject=sObjectName, + ), ) except Exception as e: @@ -758,7 +790,11 @@ def get_schema_metadata_workunit( continue schemaField: SchemaFieldClass = self._get_schema_field( - sObjectName, fieldName, fieldType, field, customField + sObjectName, + fieldName, + fieldType, + field, + customField, ) fields.append(schemaField) @@ -770,7 +806,9 @@ def get_schema_metadata_workunit( and field["ReferenceTo"]["referenceTo"] is not None ): foreignKeys.extend( - list(self.get_foreign_keys_from_field(fieldName, field, datasetUrn)) + list( + self.get_foreign_keys_from_field(fieldName, field, datasetUrn), + ), ) schemaMetadata = SchemaMetadataClass( @@ -787,16 +825,21 @@ def get_schema_metadata_workunit( # Created Date and Actor are available for Custom Object only if customObject.get("CreatedDate") and customObject.get("CreatedBy"): schemaMetadata.created = self.get_audit_stamp( - customObject["CreatedDate"], customObject["CreatedBy"]["Username"] + customObject["CreatedDate"], + customObject["CreatedBy"]["Username"], ) self.fieldCounts[sObjectName] = len(fields) yield MetadataChangeProposalWrapper( - entityUrn=datasetUrn, aspect=schemaMetadata + entityUrn=datasetUrn, + aspect=schemaMetadata, ).as_workunit() def get_foreign_keys_from_field( - self, fieldName: str, field: dict, datasetUrn: str + self, + fieldName: str, + field: dict, + datasetUrn: str, ) -> Iterable[ForeignKeyConstraintClass]: # https://developer.salesforce.com/docs/atlas.en-us.object_reference.meta/object_reference/field_types.htm#i1435823 foreignDatasets = [ diff --git a/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py 
b/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py index a50e99393fdc27..62370b16032e8d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py @@ -75,10 +75,10 @@ class URIReplacePattern(ConfigModel): class JsonSchemaSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): path: Union[FilePath, DirectoryPath, AnyHttpUrl] = Field( - description="Set this to a single file-path or a directory-path (for recursive traversal) or a remote url. e.g. https://json.schemastore.org/petstore-v1.0.json" + description="Set this to a single file-path or a directory-path (for recursive traversal) or a remote url. e.g. https://json.schemastore.org/petstore-v1.0.json", ) platform: str = Field( - description="Set this to a platform that you want all schemas to live under. e.g. schemaregistry / schemarepo etc." + description="Set this to a platform that you want all schemas to live under. e.g. schemaregistry / schemarepo etc.", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None use_id_as_base_uri: bool = Field( @@ -99,14 +99,15 @@ def download_http_url_to_temp_file(v): if not JsonSchemaTranslator._get_id_from_any_schema(schema_dict): schema_dict["$id"] = str(v) with tempfile.NamedTemporaryFile( - mode="w", delete=False + mode="w", + delete=False, ) as tmp_file: tmp_file.write(json.dumps(schema_dict)) tmp_file.flush() return tmp_file.name except Exception as e: logger.error( - f"Failed to localize url {v} due to {e}. Run with --debug to get full stacktrace" + f"Failed to localize url {v} due to {e}. Run with --debug to get full stacktrace", ) logger.debug(f"Failed to localize url {v} due to {e}", exc_info=e) raise @@ -248,7 +249,11 @@ def __init__(self, ctx: PipelineContext, config: JsonSchemaSourceConfig): self.report = StaleEntityRemovalSourceReport() def _load_one_file( - self, ref_loader: Any, browse_prefix: str, root_dir: Path, file_name: str + self, + ref_loader: Any, + browse_prefix: str, + root_dir: Path, + file_name: str, ) -> Iterable[MetadataWorkUnit]: with unittest.mock.patch("jsonref.JsonRef.callback", title_swapping_callback): (schema_dict, schema_string) = self._load_json_schema( @@ -285,11 +290,13 @@ def _load_one_file( env=self.config.env, ) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=meta + entityUrn=dataset_urn, + aspect=meta, ).as_workunit() yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=models.StatusClass(removed=False) + entityUrn=dataset_urn, + aspect=models.StatusClass(removed=False), ).as_workunit() external_url = JsonSchemaTranslator._get_id_from_any_schema(schema_dict) @@ -302,7 +309,7 @@ def _load_one_file( externalUrl=external_url, name=dataset_simple_name, description=JsonSchemaTranslator._get_description_from_any_schema( - schema_dict + schema_dict, ), ), ).as_workunit() @@ -323,7 +330,7 @@ def _load_one_file( entityUrn=dataset_urn, aspect=models.DataPlatformInstanceClass( platform=str( - DataPlatformUrn.create_from_id(self.config.platform) + DataPlatformUrn.create_from_id(self.config.platform), ), instance=make_dataplatform_instance_urn( platform=self.config.platform, @@ -336,7 +343,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -365,10 +374,12 @@ def 
get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) except Exception as e: self.report.report_failure( - f"{root}/{file_name}", f"Failed to process due to {e}" + f"{root}/{file_name}", + f"Failed to process due to {e}", ) logger.error( - f"Failed to process file {root}/{file_name}", exc_info=e + f"Failed to process file {root}/{file_name}", + exc_info=e, ) else: @@ -381,7 +392,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) except Exception as e: self.report.report_failure( - str(self.config.path), f"Failed to process due to {e}" + str(self.config.path), + f"Failed to process due to {e}", ) logger.error(f"Failed to process file {self.config.path}", exc_info=e) diff --git a/metadata-ingestion/src/datahub/ingestion/source/schema_inference/json.py b/metadata-ingestion/src/datahub/ingestion/source/schema_inference/json.py index 1659aaf6fa2020..0a077171fe5bbf 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/schema_inference/json.py +++ b/metadata-ingestion/src/datahub/ingestion/source/schema_inference/json.py @@ -45,7 +45,8 @@ def infer_schema(self, file: IO[bytes]) -> List[SchemaField]: datastore = [ obj for obj in itertools.islice( - reader.iter(type=dict, skip_invalid=True), self.max_rows + reader.iter(type=dict, skip_invalid=True), + self.max_rows, ) ] else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/schema_inference/object.py b/metadata-ingestion/src/datahub/ingestion/source/schema_inference/object.py index bbcb114ee40c35..3cc44573004a2e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/schema_inference/object.py +++ b/metadata-ingestion/src/datahub/ingestion/source/schema_inference/object.py @@ -67,7 +67,8 @@ def is_field_nullable(doc: Dict[str, Any], field_path: Tuple[str, ...]) -> bool: def is_nullable_collection( - collection: Sequence[Dict[str, Any]], field_path: Tuple + collection: Sequence[Dict[str, Any]], + field_path: Tuple, ) -> bool: """ Check if a nested field is nullable in a collection. @@ -84,7 +85,8 @@ def is_nullable_collection( def construct_schema( - collection: Sequence[Dict[str, Any]], delimiter: str + collection: Sequence[Dict[str, Any]], + delimiter: str, ) -> Dict[Tuple[str, ...], SchemaDescription]: """ Construct (infer) a schema from a collection of documents. diff --git a/metadata-ingestion/src/datahub/ingestion/source/sigma/config.py b/metadata-ingestion/src/datahub/ingestion/source/sigma/config.py index 6a47884e1b139a..66bf1fb874f633 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sigma/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sigma/config.py @@ -72,10 +72,13 @@ class PlatformDetail(PlatformInstanceConfigMixin, EnvConfigMixin): class SigmaSourceConfig( - StatefulIngestionConfigBase, PlatformInstanceConfigMixin, EnvConfigMixin + StatefulIngestionConfigBase, + PlatformInstanceConfigMixin, + EnvConfigMixin, ): api_url: str = pydantic.Field( - default=Constant.DEFAULT_API_URL, description="Sigma API hosted URL." 
+ default=Constant.DEFAULT_API_URL, + description="Sigma API hosted URL.", ) client_id: str = pydantic.Field(description="Sigma Client ID") client_secret: str = pydantic.Field(description="Sigma Client Secret") @@ -107,5 +110,6 @@ class SigmaSourceConfig( description="A mapping of the sigma workspace/workbook/chart folder path to all chart's data sources platform details present inside that folder path.", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( - default=None, description="Sigma Stateful Ingestion Config." + default=None, + description="Sigma Stateful Ingestion Config.", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py index 0468792f44aabb..cf52f7340181f8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py @@ -137,7 +137,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @@ -173,7 +174,8 @@ def _get_allowed_workspaces(self) -> List[Workspace]: return allowed_workspaces def _gen_workspace_workunit( - self, workspace: Workspace + self, + workspace: Workspace, ) -> Iterable[MetadataWorkUnit]: """ Map Sigma workspace to Datahub container @@ -202,11 +204,14 @@ def _gen_sigma_dataset_urn(self, dataset_identifier: str) -> str: def _gen_entity_status_aspect(self, entity_urn: str) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( - entityUrn=entity_urn, aspect=Status(removed=False) + entityUrn=entity_urn, + aspect=Status(removed=False), ).as_workunit() def _gen_dataset_properties( - self, dataset_urn: str, dataset: SigmaDataset + self, + dataset_urn: str, + dataset: SigmaDataset, ) -> MetadataWorkUnit: dataset_properties = DatasetProperties( name=dataset.name, @@ -221,35 +226,41 @@ def _gen_dataset_properties( if dataset.path: dataset_properties.customProperties["path"] = dataset.path return MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties + entityUrn=dataset_urn, + aspect=dataset_properties, ).as_workunit() def _gen_dataplatform_instance_aspect( - self, entity_urn: str + self, + entity_urn: str, ) -> Optional[MetadataWorkUnit]: if self.config.platform_instance: aspect = DataPlatformInstanceClass( platform=builder.make_data_platform_urn(self.platform), instance=builder.make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ) return MetadataChangeProposalWrapper( - entityUrn=entity_urn, aspect=aspect + entityUrn=entity_urn, + aspect=aspect, ).as_workunit() else: return None def _gen_entity_owner_aspect( - self, entity_urn: str, user_name: str + self, + entity_urn: str, + user_name: str, ) -> MetadataWorkUnit: aspect = OwnershipClass( owners=[ OwnerClass( owner=builder.make_user_urn(user_name), type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return MetadataChangeProposalWrapper( entityUrn=entity_urn, @@ -263,7 +274,7 @@ def _gen_entity_browsepath_aspect( paths: List[str], ) -> MetadataWorkUnit: entries = [ - BrowsePathEntryClass(id=parent_entity_urn, urn=parent_entity_urn) + BrowsePathEntryClass(id=parent_entity_urn, urn=parent_entity_urn), ] + [BrowsePathEntryClass(id=path) for path in paths] 
return MetadataChangeProposalWrapper( entityUrn=entity_urn, @@ -271,7 +282,8 @@ def _gen_entity_browsepath_aspect( ).as_workunit() def _gen_dataset_workunit( - self, dataset: SigmaDataset + self, + dataset: SigmaDataset, ) -> Iterable[MetadataWorkUnit]: dataset_urn = self._gen_sigma_dataset_urn(dataset.get_urn_part()) @@ -305,7 +317,7 @@ def _gen_dataset_workunit( yield self._gen_entity_browsepath_aspect( entity_urn=dataset_urn, parent_entity_urn=builder.make_container_urn( - self._gen_workspace_key(dataset.workspaceId) + self._gen_workspace_key(dataset.workspaceId), ), paths=paths, ) @@ -314,7 +326,7 @@ def _gen_dataset_workunit( yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=GlobalTagsClass( - tags=[TagAssociationClass(builder.make_tag_urn(dataset.badge))] + tags=[TagAssociationClass(builder.make_tag_urn(dataset.badge))], ), ).as_workunit() @@ -322,7 +334,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -350,11 +364,13 @@ def _gen_dashboard_info_workunit(self, page: Page) -> MetadataWorkUnit: customProperties={"ElementsCount": str(len(page.elements))}, ) return MetadataChangeProposalWrapper( - entityUrn=dashboard_urn, aspect=dashboard_info_cls + entityUrn=dashboard_urn, + aspect=dashboard_info_cls, ).as_workunit() def _get_element_data_source_platform_details( - self, full_path: str + self, + full_path: str, ) -> Optional[PlatformDetail]: data_source_platform_details: Optional[PlatformDetail] = None while full_path != "": @@ -376,7 +392,9 @@ def _get_element_data_source_platform_details( return data_source_platform_details def _get_element_input_details( - self, element: Element, workbook: Workbook + self, + element: Element, + workbook: Workbook, ) -> Dict[str, List[str]]: """ Returns dict with keys as the all element input dataset urn and values as their all upstream dataset urns @@ -385,7 +403,7 @@ def _get_element_input_details( sql_parser_in_tables: List[str] = [] data_source_platform_details = self._get_element_data_source_platform_details( - f"{workbook.path}/{workbook.name}/{element.name}" + f"{workbook.path}/{workbook.name}/{element.name}", ) if element.query and data_source_platform_details: @@ -447,7 +465,8 @@ def _gen_elements_workunit( yield self._gen_entity_status_aspect(chart_urn) inputs: Dict[str, List[str]] = self._get_element_input_details( - element, workbook + element, + workbook, ) yield MetadataChangeProposalWrapper( @@ -466,7 +485,7 @@ def _gen_elements_workunit( yield self._gen_entity_browsepath_aspect( entity_urn=chart_urn, parent_entity_urn=builder.make_container_urn( - self._gen_workspace_key(workbook.workspaceId) + self._gen_workspace_key(workbook.workspaceId), ), paths=paths + [workbook.name], ) @@ -501,7 +520,9 @@ def _gen_elements_workunit( all_input_fields.extend(element_input_fields) def _gen_pages_workunit( - self, workbook: Workbook, paths: List[str] + self, + workbook: Workbook, + paths: List[str], ) -> Iterable[MetadataWorkUnit]: """ Map Sigma workbook page to Datahub dashboard @@ -523,13 +544,16 @@ def _gen_pages_workunit( yield self._gen_entity_browsepath_aspect( entity_urn=dashboard_urn, parent_entity_urn=builder.make_container_urn( - self._gen_workspace_key(workbook.workspaceId) + self._gen_workspace_key(workbook.workspaceId), ), paths=paths + [workbook.name], ) yield from self._gen_elements_workunit( - page.elements, 
workbook, all_input_fields, paths + page.elements, + workbook, + all_input_fields, + paths, ) yield MetadataChangeProposalWrapper( @@ -568,7 +592,8 @@ def _gen_workbook_workunit(self, workbook: Workbook) -> Iterable[MetadataWorkUni ], externalUrl=workbook.url, lastModified=ChangeAuditStampsClass( - created=created, lastModified=lastModified + created=created, + lastModified=lastModified, ), customProperties={ "path": workbook.path, @@ -576,7 +601,8 @@ def _gen_workbook_workunit(self, workbook: Workbook) -> Iterable[MetadataWorkUni }, ) yield MetadataChangeProposalWrapper( - entityUrn=dashboard_urn, aspect=dashboard_info_cls + entityUrn=dashboard_urn, + aspect=dashboard_info_cls, ).as_workunit() # Set subtype @@ -612,7 +638,7 @@ def _gen_workbook_workunit(self, workbook: Workbook) -> Iterable[MetadataWorkUni yield self._gen_entity_browsepath_aspect( entity_urn=dashboard_urn, parent_entity_urn=builder.make_container_urn( - self._gen_workspace_key(workbook.workspaceId) + self._gen_workspace_key(workbook.workspaceId), ), paths=paths + [workbook.name], ) @@ -638,7 +664,8 @@ def _gen_sigma_dataset_upstream_lineage_workunit( aspect=UpstreamLineage( upstreams=[ Upstream( - dataset=upstream_dataset_urn, type=DatasetLineageType.COPY + dataset=upstream_dataset_urn, + type=DatasetLineageType.COPY, ) for upstream_dataset_urn in upstream_dataset_urns ], diff --git a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py index 6762302ebe57c7..e87d1ea41f6858 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py @@ -49,7 +49,7 @@ def _generate_token(self): { "Authorization": f"Bearer {response_dict[Constant.ACCESS_TOKEN]}", "Content-Type": "application/json", - } + }, ) def _log_http_error(self, message: str) -> Any: @@ -79,11 +79,11 @@ def _refresh_access_token(self): { "Authorization": f"Bearer {response_dict[Constant.ACCESS_TOKEN]}", "Content-Type": "application/json", - } + }, ) except Exception as e: self._log_http_error( - message=f"Unable to refresh access token. Exception: {e}" + message=f"Unable to refresh access token. Exception: {e}", ) def _get_api_call(self, url: str) -> requests.Response: @@ -101,7 +101,7 @@ def get_workspace(self, workspace_id: str) -> Optional[Workspace]: return self.workspaces[workspace_id] else: response = self._get_api_call( - f"{self.config.api_url}/workspaces/{workspace_id}" + f"{self.config.api_url}/workspaces/{workspace_id}", ) if response.status_code == 403: logger.debug(f"Workspace {workspace_id} not accessible.") @@ -113,7 +113,7 @@ def get_workspace(self, workspace_id: str) -> Optional[Workspace]: return workspace except Exception as e: self._log_http_error( - message=f"Unable to fetch workspace '{workspace_id}'. Exception: {e}" + message=f"Unable to fetch workspace '{workspace_id}'. Exception: {e}", ) return None @@ -157,7 +157,7 @@ def _get_users(self) -> Dict[str, str]: return users except Exception as e: self._log_http_error( - message=f"Unable to fetch users details. Exception: {e}" + message=f"Unable to fetch users details. 
Exception: {e}", ) return {} @@ -166,13 +166,15 @@ def get_user_name(self, user_id: str) -> Optional[str]: @functools.lru_cache() def get_workspace_id_from_file_path( - self, parent_id: str, path: str + self, + parent_id: str, + path: str, ) -> Optional[str]: try: path_list = path.split("/") while len(path_list) != 1: # means current parent id is folder's id response = self._get_api_call( - f"{self.config.api_url}/files/{parent_id}" + f"{self.config.api_url}/files/{parent_id}", ) response.raise_for_status() parent_id = response.json()[Constant.PARENTID] @@ -180,7 +182,7 @@ def get_workspace_id_from_file_path( return parent_id except Exception as e: logger.error( - f"Unable to find workspace id using file path '{path}'. Exception: {e}" + f"Unable to find workspace id using file path '{path}'. Exception: {e}", ) return None @@ -197,7 +199,8 @@ def _get_files_metadata(self, file_type: str) -> Dict[str, File]: for file_dict in response_dict[Constant.ENTRIES]: file = File.parse_obj(file_dict) file.workspaceId = self.get_workspace_id_from_file_path( - file.parentId, file.path + file.parentId, + file.path, ) files_metadata[file_dict[Constant.ID]] = file if response_dict[Constant.NEXTPAGE]: @@ -208,7 +211,7 @@ def _get_files_metadata(self, file_type: str) -> Dict[str, File]: return files_metadata except Exception as e: self._log_http_error( - message=f"Unable to fetch files metadata. Exception: {e}" + message=f"Unable to fetch files metadata. Exception: {e}", ) return {} @@ -237,7 +240,7 @@ def get_sigma_datasets(self) -> List[SigmaDataset]: workspace = self.get_workspace(dataset.workspaceId) if workspace: if self.config.workspace_pattern.allowed( - workspace.name + workspace.name, ): datasets.append(dataset) elif self.config.ingest_shared_entities: @@ -253,12 +256,14 @@ def get_sigma_datasets(self) -> List[SigmaDataset]: return datasets except Exception as e: self._log_http_error( - message=f"Unable to fetch sigma datasets. Exception: {e}" + message=f"Unable to fetch sigma datasets. Exception: {e}", ) return [] def _get_element_upstream_sources( - self, element: Element, workbook: Workbook + self, + element: Element, + workbook: Workbook, ) -> Dict[str, str]: """ Returns upstream dataset sources with keys as id and values as name of that dataset @@ -266,16 +271,16 @@ def _get_element_upstream_sources( try: upstream_sources: Dict[str, str] = {} response = self._get_api_call( - f"{self.config.api_url}/workbooks/{workbook.workbookId}/lineage/elements/{element.elementId}" + f"{self.config.api_url}/workbooks/{workbook.workbookId}/lineage/elements/{element.elementId}", ) if response.status_code == 500: logger.debug( - f"Lineage metadata not present for element {element.name} of workbook '{workbook.name}'" + f"Lineage metadata not present for element {element.name} of workbook '{workbook.name}'", ) return upstream_sources if response.status_code == 403: logger.debug( - f"Lineage metadata not accessible for element {element.name} of workbook '{workbook.name}'" + f"Lineage metadata not accessible for element {element.name} of workbook '{workbook.name}'", ) return upstream_sources @@ -292,20 +297,22 @@ def _get_element_upstream_sources( return upstream_sources except Exception as e: self._log_http_error( - message=f"Unable to fetch lineage for element {element.name} of workbook '{workbook.name}'. Exception: {e}" + message=f"Unable to fetch lineage for element {element.name} of workbook '{workbook.name}'. 
Exception: {e}", ) return {} def _get_element_sql_query( - self, element: Element, workbook: Workbook + self, + element: Element, + workbook: Workbook, ) -> Optional[str]: try: response = self._get_api_call( - f"{self.config.api_url}/workbooks/{workbook.workbookId}/elements/{element.elementId}/query" + f"{self.config.api_url}/workbooks/{workbook.workbookId}/elements/{element.elementId}/query", ) if response.status_code == 404: logger.debug( - f"Query not present for element {element.name} of workbook '{workbook.name}'" + f"Query not present for element {element.name} of workbook '{workbook.name}'", ) return None response.raise_for_status() @@ -314,7 +321,7 @@ def _get_element_sql_query( return response_dict["sql"] except Exception as e: self._log_http_error( - message=f"Unable to fetch sql query for element {element.name} of workbook '{workbook.name}'. Exception: {e}" + message=f"Unable to fetch sql query for element {element.name} of workbook '{workbook.name}'. Exception: {e}", ) return None @@ -322,7 +329,7 @@ def get_page_elements(self, workbook: Workbook, page: Page) -> List[Element]: try: elements: List[Element] = [] response = self._get_api_call( - f"{self.config.api_url}/workbooks/{workbook.workbookId}/pages/{page.pageId}/elements" + f"{self.config.api_url}/workbooks/{workbook.workbookId}/pages/{page.pageId}/elements", ) response.raise_for_status() for i, element_dict in enumerate(response.json()[Constant.ENTRIES]): @@ -339,14 +346,15 @@ def get_page_elements(self, workbook: Workbook, page: Page) -> List[Element]: and self.config.workbook_lineage_pattern.allowed(workbook.name) ): element.upstream_sources = self._get_element_upstream_sources( - element, workbook + element, + workbook, ) element.query = self._get_element_sql_query(element, workbook) elements.append(element) return elements except Exception as e: self._log_http_error( - message=f"Unable to fetch elements of page '{page.name}', workbook '{workbook.name}'. Exception: {e}" + message=f"Unable to fetch elements of page '{page.name}', workbook '{workbook.name}'. Exception: {e}", ) return [] @@ -354,7 +362,7 @@ def get_workbook_pages(self, workbook: Workbook) -> List[Page]: try: pages: List[Page] = [] response = self._get_api_call( - f"{self.config.api_url}/workbooks/{workbook.workbookId}/pages" + f"{self.config.api_url}/workbooks/{workbook.workbookId}/pages", ) response.raise_for_status() for page_dict in response.json()[Constant.ENTRIES]: @@ -364,7 +372,7 @@ def get_workbook_pages(self, workbook: Workbook) -> List[Page]: return pages except Exception as e: self._log_http_error( - message=f"Unable to fetch pages of workbook '{workbook.name}'. Exception: {e}" + message=f"Unable to fetch pages of workbook '{workbook.name}'. Exception: {e}", ) return [] @@ -394,7 +402,7 @@ def get_sigma_workbooks(self) -> List[Workbook]: workspace = self.get_workspace(workbook.workspaceId) if workspace: if self.config.workspace_pattern.allowed( - workspace.name + workspace.name, ): workbook.pages = self.get_workbook_pages(workbook) workbooks.append(workbook) @@ -412,6 +420,6 @@ def get_sigma_workbooks(self) -> List[Workbook]: return workbooks except Exception as e: self._log_http_error( - message=f"Unable to fetch sigma workbooks. Exception: {e}" + message=f"Unable to fetch sigma workbooks. 
Exception: {e}", ) return [] diff --git a/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py b/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py index 00b783a89774cc..d9e8ac5712346e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py +++ b/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py @@ -97,7 +97,8 @@ def __init__(self, ctx: PipelineContext, config: SlackSourceConfig): self.report = SlackSourceReport() self.workspace_base_url: Optional[str] = None self.rate_limiter = RateLimiter( - max_calls=self.config.api_requests_per_min, period=60 + max_calls=self.config.api_requests_per_min, + period=60, ) self._use_users_info = False @@ -135,7 +136,8 @@ def get_user_info(self) -> Iterable[MetadataWorkUnit]: logger.info(f"User: {user_obj}") corpuser_editable_info = ( self.ctx.graph.get_aspect( - entity_urn=user_obj.urn, aspect_type=CorpUserEditableInfoClass + entity_urn=user_obj.urn, + aspect_type=CorpUserEditableInfoClass, ) or CorpUserEditableInfoClass() ) @@ -161,7 +163,8 @@ def get_user_info(self) -> Iterable[MetadataWorkUnit]: ) def _get_channel_info( - self, cursor: Optional[str] + self, + cursor: Optional[str], ) -> Tuple[List[MetadataWorkUnit], Optional[str]]: result_channels: List[MetadataWorkUnit] = [] with self.rate_limiter: @@ -173,7 +176,8 @@ def _get_channel_info( assert isinstance(response.data, dict) if not response.data["ok"]: self.report.report_failure( - "public_channel", "Failed to fetch public channels" + "public_channel", + "Failed to fetch public channels", ) return result_channels, None for channel in response.data["channels"]: @@ -182,7 +186,8 @@ def _get_channel_info( continue channel_id = channel["id"] urn_channel = builder.make_dataset_urn( - platform=PLATFORM_NAME, name=channel_id + platform=PLATFORM_NAME, + name=channel_id, ) name = channel["name"] is_archived = channel.get("is_archived", False) @@ -202,7 +207,7 @@ def _get_channel_info( actor="urn:li:corpuser:datahub", ), ), - ) + ), ) topic = channel.get("topic", {}).get("value") @@ -220,7 +225,7 @@ def _get_channel_info( description=f"Topic: {topic}\nPurpose: {purpose}", ), ), - ) + ), ) result_channels.append( MetadataWorkUnit( @@ -231,7 +236,7 @@ def _get_channel_info( typeNames=["Slack Channel"], ), ), - ) + ), ) cursor = str(response.data["response_metadata"]["next_cursor"]) return result_channels, cursor @@ -255,12 +260,12 @@ def populate_user_profile(self, user_obj: CorpUser) -> None: with self.rate_limiter: if self._use_users_info: user_profile_res = self.get_slack_client().users_info( - user=user_obj.slack_id + user=user_obj.slack_id, ) user_profile_res = user_profile_res.get("user", {}) else: user_profile_res = self.get_slack_client().users_profile_get( - user=user_obj.slack_id + user=user_obj.slack_id, ) logger.debug(f"User profile: {user_profile_res}") user_profile = user_profile_res.get("profile", {}) @@ -285,7 +290,7 @@ def populate_slack_id_from_email(self, user_obj: CorpUser) -> None: # https://api.slack.com/methods/users.lookupByEmail with self.rate_limiter: user_info_res = self.get_slack_client().users_lookupByEmail( - email=user_obj.email + email=user_obj.email, ) user_info = user_info_res.get("user", {}) user_obj.slack_id = user_info.get("id") @@ -309,7 +314,7 @@ def get_user_to_be_updated(self) -> Iterable[CorpUser]: } } } - """ + """, ) start = 0 count = 10 @@ -320,7 +325,8 @@ def get_user_to_be_updated(self) -> Iterable[CorpUser]: while start < total: variables = {"input": {"start": start, "count": count}} response = 
self.ctx.graph.execute_graphql( - query=graphql_query, variables=variables + query=graphql_query, + variables=variables, ) list_users = response.get("listUsers", {}) total = list_users.get("total", 0) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_config.py index 61a06580299db6..00809f6b919149 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_config.py @@ -14,7 +14,7 @@ class OAuthIdentityProvider(Enum): class OAuthConfiguration(ConfigModel): provider: OAuthIdentityProvider = Field( description="Identity provider for oauth." - "Supported providers are microsoft and okta." + "Supported providers are microsoft and okta.", ) authority_url: str = Field(description="Authority url of your identity provider") client_id: str = Field(description="client id of your registered application") @@ -24,11 +24,14 @@ class OAuthConfiguration(ConfigModel): default=False, ) client_secret: Optional[SecretStr] = Field( - None, description="client secret of the application if use_certificate = false" + None, + description="client secret of the application if use_certificate = false", ) encoded_oauth_public_key: Optional[str] = Field( - None, description="base64 encoded certificate content if use_certificate = true" + None, + description="base64 encoded certificate content if use_certificate = true", ) encoded_oauth_private_key: Optional[str] = Field( - None, description="base64 encoded private key content if use_certificate = true" + None, + description="base64 encoded private key content if use_certificate = true", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_generator.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_generator.py index a2dc0118b39782..c1ec22b1196379 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_generator.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/oauth_generator.py @@ -41,7 +41,9 @@ def _get_token( def _get_microsoft_token(self, credentials, scopes, check_cache): app = msal.ConfidentialClientApplication( - self.client_id, authority=self.authority_url, client_credential=credentials + self.client_id, + authority=self.authority_url, + client_credential=credentials, ) _token = None if check_cache: @@ -85,7 +87,7 @@ def get_token_with_certificate( decoded_private_key_content = base64.b64decode(private_key_content) decoded_public_key_content = base64.b64decode(public_key_content) public_cert_thumbprint = self.get_public_certificate_thumbprint( - str(decoded_public_key_content, "UTF-8") + str(decoded_public_key_content, "UTF-8"), ) CLIENT_CREDENTIAL = { @@ -95,6 +97,9 @@ def get_token_with_certificate( return self._get_token(CLIENT_CREDENTIAL, scopes, check_cache) def get_token_with_secret( - self, secret: str, scopes: Optional[List[str]], check_cache: bool = False + self, + secret: str, + scopes: Optional[List[str]], + check_cache: bool = False, ) -> Any: return self._get_token(secret, scopes, check_cache) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py index a7c008d932a713..c5fec333cf22d2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py @@ -54,13 +54,14 @@ def __init__( self._urns_processed: List[str] = [] def get_assertion_workunits( - self, discovered_datasets: List[str] + self, + discovered_datasets: List[str], ) -> Iterable[MetadataWorkUnit]: cur = self.connection.query( SnowflakeQuery.dmf_assertion_results( datetime_to_ts_millis(self.config.start_time), datetime_to_ts_millis(self.config.end_time), - ) + ), ) for db_row in cur: mcp = self._process_result_row(db_row, discovered_datasets) @@ -79,7 +80,8 @@ def _gen_platform_instance_wu(self, urn: str) -> MetadataWorkUnit: platform=make_data_platform_urn(self.identifiers.platform), instance=( make_dataplatform_instance_urn( - self.identifiers.platform, self.config.platform_instance + self.identifiers.platform, + self.config.platform_instance, ) if self.config.platform_instance else None @@ -88,14 +90,18 @@ def _gen_platform_instance_wu(self, urn: str) -> MetadataWorkUnit: ).as_workunit(is_primary_source=False) def _process_result_row( - self, result_row: dict, discovered_datasets: List[str] + self, + result_row: dict, + discovered_datasets: List[str], ) -> Optional[MetadataChangeProposalWrapper]: try: result = DataQualityMonitoringResult.parse_obj(result_row) assertion_guid = result.METRIC_NAME.split("__")[-1].lower() status = bool(result.VALUE) # 1 if PASS, 0 if FAIL assertee = self.identifiers.get_dataset_identifier( - result.TABLE_NAME, result.TABLE_SCHEMA, result.TABLE_DATABASE + result.TABLE_NAME, + result.TABLE_SCHEMA, + result.TABLE_DATABASE, ) if assertee in discovered_datasets: return MetadataChangeProposalWrapper( @@ -111,7 +117,7 @@ def _process_result_row( AssertionResultType.SUCCESS if status else AssertionResultType.FAILURE - ) + ), ), ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 5f732e2621656f..1dab0144fcde4b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -60,7 +60,7 @@ class TagOption(StrEnum): @dataclass(frozen=True) class DatabaseId: database: str = Field( - description="Database created from share in consumer account." + description="Database created from share in consumer account.", ) platform_instance: Optional[str] = Field( default=None, @@ -76,7 +76,7 @@ class SnowflakeShareConfig(ConfigModel): ) consumers: Set[DatabaseId] = Field( - description="List of databases created in consumer accounts." + description="List of databases created in consumer accounts.", ) @property @@ -117,7 +117,7 @@ def validate_legacy_schema_pattern(cls, values: Dict) -> Dict: logger.warning( "Please update `schema_pattern` to match against fully qualified schema name `.` and set config `match_fully_qualified_names : True`." "Current default `match_fully_qualified_names: False` is only to maintain backward compatibility. " - "The config option `match_fully_qualified_names` will be deprecated in future and the default behavior will assume `match_fully_qualified_names: True`." 
+ "The config option `match_fully_qualified_names` will be deprecated in future and the default behavior will assume `match_fully_qualified_names: True`.", ) # Always exclude reporting metadata for INFORMATION_SCHEMA schema @@ -130,7 +130,9 @@ def validate_legacy_schema_pattern(cls, values: Dict) -> Dict: class SnowflakeIdentifierConfig( - PlatformInstanceConfigMixin, EnvConfigMixin, LowerCaseDatasetUrnConfigMixin + PlatformInstanceConfigMixin, + EnvConfigMixin, + LowerCaseDatasetUrnConfigMixin, ): # Changing default value here. convert_urns_to_lowercase: bool = Field( @@ -255,7 +257,7 @@ class SnowflakeV2Config( ) _use_legacy_lineage_method_removed = pydantic_removed_field( - "use_legacy_lineage_method" + "use_legacy_lineage_method", ) validate_upstreams_against_patterns: bool = Field( @@ -285,7 +287,8 @@ class SnowflakeV2Config( ) rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field( - "upstreams_deny_pattern", "temporary_tables_pattern" + "upstreams_deny_pattern", + "temporary_tables_pattern", ) shares: Optional[Dict[str, SnowflakeShareConfig]] = Field( @@ -306,7 +309,7 @@ class SnowflakeV2Config( def validate_convert_urns_to_lowercase(cls, v): if not v: add_global_warning( - "Please use `convert_urns_to_lowercase: True`, otherwise lineage to other sources may not work correctly." + "Please use `convert_urns_to_lowercase: True`, otherwise lineage to other sources may not work correctly.", ) return v @@ -315,7 +318,7 @@ def validate_convert_urns_to_lowercase(cls, v): def validate_include_column_lineage(cls, v, values): if not values.get("include_table_lineage") and v: raise ValueError( - "include_table_lineage must be True for include_column_lineage to be set." + "include_table_lineage must be True for include_column_lineage to be set.", ) return v @@ -340,10 +343,10 @@ def validate_unsupported_configs(cls, values: Dict) -> Dict: # TODO: Allow profiling irrespective of basic schema extraction, # as it seems possible with some refactor if not include_technical_schema and any( - [include_profiles, delete_detection_enabled] + [include_profiles, delete_detection_enabled], ): raise ValueError( - "Cannot perform Deletion Detection or Profiling without extracting snowflake technical schema. Set `include_technical_schema` to True or disable Deletion Detection and Profiling." + "Cannot perform Deletion Detection or Profiling without extracting snowflake technical schema. Set `include_technical_schema` to True or disable Deletion Detection and Profiling.", ) return values @@ -356,12 +359,18 @@ def get_sql_alchemy_url( role: Optional[str] = None, ) -> str: return SnowflakeConnectionConfig.get_sql_alchemy_url( - self, database=database, username=username, password=password, role=role + self, + database=database, + username=username, + password=password, + role=role, ) @validator("shares") def validate_shares( - cls, shares: Optional[Dict[str, SnowflakeShareConfig]], values: Dict + cls, + shares: Optional[Dict[str, SnowflakeShareConfig]], + values: Dict, ) -> Optional[Dict[str, SnowflakeShareConfig]]: current_platform_instance = values.get("platform_instance") @@ -370,7 +379,7 @@ def validate_shares( if current_platform_instance is None: logger.info( "It is advisable to use `platform_instance` when ingesting from multiple snowflake accounts, if they contain databases with same name. " - "Setting `platform_instance` allows distinguishing such databases without conflict and correctly ingest their metadata." 
+ "Setting `platform_instance` allows distinguishing such databases without conflict and correctly ingest their metadata.", ) databases_included_in_share: List[DatabaseId] = [] @@ -378,7 +387,8 @@ def validate_shares( for share_details in shares.values(): shared_db = DatabaseId( - share_details.database, share_details.platform_instance + share_details.database, + share_details.platform_instance, ) if current_platform_instance: assert all( @@ -411,7 +421,7 @@ def outbounds(self) -> Dict[str, Set[DatabaseId]]: for share_name, share_details in self.shares.items(): if share_details.platform_instance == self.platform_instance: logger.debug( - f"database {share_details.database} is included in outbound share(s) {share_name}." + f"database {share_details.database} is included in outbound share(s) {share_name}.", ) outbounds[share_details.database].update(share_details.consumers) return outbounds @@ -427,7 +437,7 @@ def inbounds(self) -> Dict[str, DatabaseId]: for consumer in share_details.consumers: if consumer.platform_instance == self.platform_instance: logger.debug( - f"database {consumer.database} is created from inbound share {share_name}." + f"database {consumer.database} is created from inbound share {share_name}.", ) inbounds[consumer.database] = share_details.source_database if self.platform_instance: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py index 2854a99198d62b..4af20882d4aad3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py @@ -65,10 +65,13 @@ class SnowflakeConnectionConfig(ConfigModel): scheme: str = "snowflake" username: Optional[str] = pydantic.Field( - default=None, description="Snowflake username." + default=None, + description="Snowflake username.", ) password: Optional[pydantic.SecretStr] = pydantic.Field( - default=None, exclude=True, description="Snowflake password." + default=None, + exclude=True, + description="Snowflake password.", ) private_key: Optional[str] = pydantic.Field( default=None, @@ -97,7 +100,8 @@ class SnowflakeConnectionConfig(ConfigModel): description="Snowflake account identifier. e.g. xy12345, xy12345.us-east-2.aws, xy12345.us-central1.gcp, xy12345.central-us.azure, xy12345.us-west-2.privatelink. Refer [Account Identifiers](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#format-2-legacy-account-locator-in-a-region) for more details.", ) warehouse: Optional[str] = pydantic.Field( - default=None, description="Snowflake warehouse." + default=None, + description="Snowflake warehouse.", ) role: Optional[str] = pydantic.Field(default=None, description="Snowflake role.") connect_args: Optional[Dict[str, Any]] = pydantic.Field( @@ -128,7 +132,7 @@ def authenticator_type_is_valid(cls, v, values): if v not in _VALID_AUTH_TYPES.keys(): raise ValueError( f"unsupported authenticator type '{v}' was provided," - f" use one of {list(_VALID_AUTH_TYPES.keys())}" + f" use one of {list(_VALID_AUTH_TYPES.keys())}", ) if ( values.get("private_key") is not None @@ -136,7 +140,7 @@ def authenticator_type_is_valid(cls, v, values): ) and v != "KEY_PAIR_AUTHENTICATOR": raise ValueError( f"Either `private_key` and `private_key_path` is set but `authentication_type` is {v}. 
" - f"Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication" + f"Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication", ) if v == "KEY_PAIR_AUTHENTICATOR": # If we are using key pair auth, we need the private key path and password to be set @@ -146,7 +150,7 @@ def authenticator_type_is_valid(cls, v, values): ): raise ValueError( f"Both `private_key` and `private_key_path` are none. " - f"At least one should be set when using {v} authentication" + f"At least one should be set when using {v} authentication", ) elif v == "OAUTH_AUTHENTICATOR": cls._check_oauth_config(values.get("oauth_config")) @@ -161,7 +165,7 @@ def validate_token_oauth_config(cls, v, values): raise ValueError("Token required for OAUTH_AUTHENTICATOR_TOKEN.") elif v is not None: raise ValueError( - "Token can only be provided when using OAUTH_AUTHENTICATOR_TOKEN" + "Token can only be provided when using OAUTH_AUTHENTICATOR_TOKEN", ) return v @@ -169,27 +173,27 @@ def validate_token_oauth_config(cls, v, values): def _check_oauth_config(oauth_config: Optional[OAuthConfiguration]) -> None: if oauth_config is None: raise ValueError( - "'oauth_config' is none but should be set when using OAUTH_AUTHENTICATOR authentication" + "'oauth_config' is none but should be set when using OAUTH_AUTHENTICATOR authentication", ) if oauth_config.use_certificate is True: if oauth_config.provider == OAuthIdentityProvider.OKTA: raise ValueError( - "Certificate authentication is not supported for Okta." + "Certificate authentication is not supported for Okta.", ) if oauth_config.encoded_oauth_private_key is None: raise ValueError( "'base64_encoded_oauth_private_key' was none " - "but should be set when using certificate for oauth_config" + "but should be set when using certificate for oauth_config", ) if oauth_config.encoded_oauth_public_key is None: raise ValueError( "'base64_encoded_oauth_public_key' was none" - "but should be set when using use_certificate true for oauth_config" + "but should be set when using use_certificate true for oauth_config", ) elif oauth_config.client_secret is None: raise ValueError( "'oauth_config.client_secret' was none " - "but should be set when using use_certificate false for oauth_config" + "but should be set when using use_certificate false for oauth_config", ) def get_sql_alchemy_url( @@ -311,7 +315,7 @@ def get_oauth_connection(self) -> NativeSnowflakeConnection: except KeyError: raise ValueError( f"access_token not found in response {response}. " - "Please check your OAuth configuration." 
+ "Please check your OAuth configuration.", ) connect_args = self.get_options()["connect_args"] return snowflake.connector.connect( @@ -388,11 +392,11 @@ def get_connection(self) -> "SnowflakeConnection": if "not granted to this user" in str(e): raise SnowflakePermissionError( - f"Permissions error when connecting to snowflake: {e}" + f"Permissions error when connecting to snowflake: {e}", ) from e raise ConfigurationError( - f"Failed to connect to snowflake instance: {e}" + f"Failed to connect to snowflake instance: {e}", ) from e diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py index c9615bb498fe48..4a2ef143a465c2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py @@ -13,19 +13,25 @@ class SnowflakeDataReader(DataReader): @staticmethod def create( - conn: SnowflakeConnection, col_name_preprocessor: Callable[[str], str] + conn: SnowflakeConnection, + col_name_preprocessor: Callable[[str], str], ) -> "SnowflakeDataReader": return SnowflakeDataReader(conn, col_name_preprocessor) def __init__( - self, conn: SnowflakeConnection, col_name_preprocessor: Callable[[str], str] + self, + conn: SnowflakeConnection, + col_name_preprocessor: Callable[[str], str], ) -> None: # The lifecycle of this connection is managed externally self.conn = conn self.col_name_preprocessor = col_name_preprocessor def get_sample_data_for_table( - self, table_id: List[str], sample_size: int, **kwargs: Any + self, + table_id: List[str], + sample_size: int, + **kwargs: Any, ) -> Dict[str, list]: """ For snowflake, table_id should be in form (db_name, schema_name, table_name) @@ -37,7 +43,7 @@ def get_sample_data_for_table( table_name = table_id[2] logger.debug( - f"Collecting sample values for table {db_name}.{schema_name}.{table_name}" + f"Collecting sample values for table {db_name}.{schema_name}.{table_name}", ) with PerfTimer() as timer, self.conn.native_connection().cursor() as cursor: sql = f'select * from "{db_name}"."{schema_name}"."{table_name}" sample ({sample_size} rows);' @@ -49,7 +55,7 @@ def get_sample_data_for_table( time_taken = timer.elapsed_seconds() logger.debug( f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};" - f"{df.shape[0]} rows; took {time_taken:.3f} seconds" + f"{df.shape[0]} rows; took {time_taken:.3f} seconds", ) return df.to_dict(orient="list") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index e93ecf30171f65..cc34da7b3c2cc1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -194,7 +194,7 @@ def populate_table_upstreams(self, discovered_tables: List[str]) -> None: # TODO: use sql_aggregator.add_observed_query to report queries from # snowflake.account_usage.query_history and let Datahub generate lineage, usage and operations logger.info( - "Snowflake Account is Standard Edition. Table to Table and View to Table Lineage Feature is not supported." + "Snowflake Account is Standard Edition. 
Table to Table and View to Table Lineage Feature is not supported.", ) # See Edition Note above for why else: with PerfTimer() as timer: @@ -216,14 +216,16 @@ def populate_known_query_lineage( ) -> None: for db_row in results: dataset_name = self.identifiers.get_dataset_identifier_from_qualified_name( - db_row.DOWNSTREAM_TABLE_NAME + db_row.DOWNSTREAM_TABLE_NAME, ) if dataset_name not in discovered_assets or not db_row.QUERIES: continue for query in db_row.QUERIES: known_lineage = self.get_known_query_lineage( - query, dataset_name, db_row + query, + dataset_name, + db_row, ) if known_lineage and known_lineage.upstreams: self.report.num_tables_with_known_upstreams += 1 @@ -232,7 +234,10 @@ def populate_known_query_lineage( logger.debug(f"No lineage found for {dataset_name}") def get_known_query_lineage( - self, query: Query, dataset_name: str, db_row: UpstreamLineageEdge + self, + query: Query, + dataset_name: str, + db_row: UpstreamLineageEdge, ) -> Optional[KnownQueryLineageInfo]: if not db_row.UPSTREAM_TABLES: return None @@ -241,12 +246,15 @@ def get_known_query_lineage( known_lineage = KnownQueryLineageInfo( query_id=get_query_fingerprint( - query.query_text, self.identifiers.platform, fast=True + query.query_text, + self.identifiers.platform, + fast=True, ), query_text=query.query_text, downstream=downstream_table_urn, upstreams=self.map_query_result_upstreams( - db_row.UPSTREAM_TABLES, query.query_id + db_row.UPSTREAM_TABLES, + query.query_id, ), column_lineage=( self.map_query_result_fine_upstreams( @@ -277,7 +285,8 @@ def _populate_external_upstreams(self, discovered_tables: List[str]) -> None: # Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english'credentials=(aws_key_id='...' aws_secret_key='...') pattern='.*.csv'; # NOTE: Snowflake does not log this information to the access_history table. 
def _get_copy_history_lineage( - self, discovered_tables: List[str] + self, + discovered_tables: List[str], ) -> Iterable[KnownLineageMapping]: query: str = SnowflakeQuery.copy_lineage_history( start_time_millis=int(self.start_time.timestamp() * 1000), @@ -288,7 +297,9 @@ def _get_copy_history_lineage( try: for db_row in self.connection.query(query): known_lineage_mapping = self._process_external_lineage_result_row( - db_row, discovered_tables, identifiers=self.identifiers + db_row, + discovered_tables, + identifiers=self.identifiers, ) if known_lineage_mapping: self.report.num_external_table_edges_scanned += 1 @@ -313,7 +324,7 @@ def _process_external_lineage_result_row( ) -> Optional[KnownLineageMapping]: # key is the down-stream table name key: str = identifiers.get_dataset_identifier_from_qualified_name( - db_row["DOWNSTREAM_TABLE_NAME"] + db_row["DOWNSTREAM_TABLE_NAME"], ) if discovered_tables is not None and key not in discovered_tables: return None @@ -326,7 +337,8 @@ def _process_external_lineage_result_row( if loc.startswith("s3://"): return KnownLineageMapping( upstream_urn=make_s3_urn_for_lineage( - loc, identifiers.identifier_config.env + loc, + identifiers.identifier_config.env, ), downstream_urn=identifiers.gen_dataset_urn(key), ) @@ -357,7 +369,8 @@ def _fetch_upstream_lineages_for_tables(self) -> Iterable[UpstreamLineageEdge]: self.report_status(TABLE_LINEAGE, False) def _process_upstream_lineage_row( - self, db_row: dict + self, + db_row: dict, ) -> Optional[UpstreamLineageEdge]: try: return UpstreamLineageEdge.parse_obj(db_row) @@ -376,7 +389,9 @@ def _process_upstream_lineage_row( return None def map_query_result_upstreams( - self, upstream_tables: Optional[List[UpstreamTableNode]], query_id: str + self, + upstream_tables: Optional[List[UpstreamTableNode]], + query_id: str, ) -> List[UrnStr]: if not upstream_tables: return [] @@ -386,7 +401,7 @@ def map_query_result_upstreams( try: upstream_name = ( self.identifiers.get_dataset_identifier_from_qualified_name( - upstream_table.upstream_object_name + upstream_table.upstream_object_name, ) ) if upstream_name and ( @@ -397,7 +412,7 @@ def map_query_result_upstreams( ) ): upstreams.append( - self.identifiers.gen_dataset_urn(upstream_name) + self.identifiers.gen_dataset_urn(upstream_name), ) except Exception as e: logger.debug(e, exc_info=e) @@ -416,7 +431,10 @@ def map_query_result_fine_upstreams( if column_with_upstreams: try: self._process_add_single_column_upstream( - dataset_urn, fine_upstreams, column_with_upstreams, query_id + dataset_urn, + fine_upstreams, + column_with_upstreams, + query_id, ) except Exception as e: logger.debug(e, exc_info=e) @@ -462,7 +480,8 @@ def build_finegrained_lineage( return None column_lineage = ColumnLineageInfo( downstream=DownstreamColumnRef( - table=dataset_urn, column=self.identifiers.snowflake_identifier(col) + table=dataset_urn, + column=self.identifiers.snowflake_identifier(col), ), upstreams=sorted(column_upstreams), ) @@ -470,7 +489,8 @@ def build_finegrained_lineage( return column_lineage def build_finegrained_lineage_upstreams( - self, upstream_columms: Set[SnowflakeColumnId] + self, + upstream_columms: Set[SnowflakeColumnId], ) -> List[ColumnRef]: column_upstreams = [] for upstream_col in upstream_columms: @@ -487,16 +507,16 @@ def build_finegrained_lineage_upstreams( ): upstream_dataset_name = ( self.identifiers.get_dataset_identifier_from_qualified_name( - upstream_col.object_name + upstream_col.object_name, ) ) column_upstreams.append( ColumnRef( 
table=self.identifiers.gen_dataset_urn(upstream_dataset_name), column=self.identifiers.snowflake_identifier( - upstream_col.column_name + upstream_col.column_name, ), - ) + ), ) return column_upstreams @@ -507,7 +527,8 @@ def get_external_upstreams(self, external_lineage: Set[str]) -> List[UpstreamCla if external_lineage_entry.startswith("s3://"): external_upstream_table = UpstreamClass( dataset=make_s3_urn_for_lineage( - external_lineage_entry, self.config.env + external_lineage_entry, + self.config.env, ), type=DatasetLineageTypeClass.COPY, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py index 422bda5284dbc5..c5b612a27be583 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py @@ -40,13 +40,16 @@ def __init__( self.database_default_schema: Dict[str, str] = dict() def get_workunits( - self, database: SnowflakeDatabase, db_tables: Dict[str, List[SnowflakeTable]] + self, + database: SnowflakeDatabase, + db_tables: Dict[str, List[SnowflakeTable]], ) -> Iterable[MetadataWorkUnit]: # Extra default SQLAlchemy option for better connection pooling and threading. # https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow if self.config.is_profiling_enabled(): self.config.options.setdefault( - "max_overflow", self.config.profiling.max_workers + "max_overflow", + self.config.profiling.max_workers, ) if PUBLIC_SCHEMA not in db_tables: @@ -61,13 +64,15 @@ def get_workunits( and table.type == "EXTERNAL TABLE" ): logger.info( - f"Skipping profiling of external table {database.name}.{schema.name}.{table.name}" + f"Skipping profiling of external table {database.name}.{schema.name}.{table.name}", ) self.report.profiling_skipped_other[schema.name] += 1 continue profile_request = self.get_profile_request( - table, schema.name, database.name + table, + schema.name, + database.name, ) if profile_request is not None: self.report.report_entity_profiled(profile_request.pretty_name) @@ -88,7 +93,10 @@ def get_dataset_name(self, table_name: str, schema_name: str, db_name: str) -> s return self.identifiers.get_dataset_identifier(table_name, schema_name, db_name) def get_batch_kwargs( - self, table: BaseTable, schema_name: str, db_name: str + self, + table: BaseTable, + schema_name: str, + db_name: str, ) -> dict: custom_sql = None if ( @@ -131,7 +139,8 @@ def get_batch_kwargs( } def get_profiler_instance( - self, db_name: Optional[str] = None + self, + db_name: Optional[str] = None, ) -> "DatahubGEProfiler": assert db_name diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 69d0b62a8edfdf..b44ae733b6c1a6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -111,7 +111,9 @@ class SnowflakeQueriesExtractorConfig(ConfigModel): class SnowflakeQueriesSourceConfig( - SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig + SnowflakeQueriesExtractorConfig, + SnowflakeIdentifierConfig, + SnowflakeFilterConfig, ): connection: SnowflakeConnectionConfig @@ -183,7 +185,7 @@ def __init__( is_temp_table=self.is_temp_table, is_allowed_table=self.is_allowed_table, 
format_queries=False, - ) + ), ) self.report.sql_aggregator = self.aggregator.report @@ -227,7 +229,8 @@ def is_allowed_table(self, name: str) -> bool: return False return self.filters.is_dataset_pattern_allowed( - name, SnowflakeObjectDomain.TABLE + name, + SnowflakeObjectDomain.TABLE, ) def get_workunits_internal( @@ -303,7 +306,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: ) with self.structured_reporter.report_exc( - "Error fetching copy history from Snowflake" + "Error fetching copy history from Snowflake", ): logger.info("Fetching copy history from Snowflake") resp = self.connection.query(query) @@ -328,7 +331,8 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: yield result def fetch_query_log( - self, users: UsersMapping + self, + users: UsersMapping, ) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]: query_log_query = _build_enriched_query_log_query( start_time=self.config.window.start_time, @@ -338,7 +342,7 @@ def fetch_query_log( ) with self.structured_reporter.report_exc( - "Error fetching query log from Snowflake" + "Error fetching query log from Snowflake", ): logger.info("Fetching query log from Snowflake") resp = self.connection.query(query_log_query) @@ -361,7 +365,9 @@ def fetch_query_log( yield entry def _parse_audit_log_row( - self, row: Dict[str, Any], users: UsersMapping + self, + row: Dict[str, Any], + users: UsersMapping, ) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]: json_fields = { "DIRECT_OBJECTS_ACCESSED", @@ -383,7 +389,7 @@ def _parse_audit_log_row( if object_modified_by_ddl and not objects_modified: known_ddl_entry: Optional[Union[TableRename, TableSwap]] = None with self.structured_reporter.report_exc( - "Error fetching ddl lineage from Snowflake" + "Error fetching ddl lineage from Snowflake", ): known_ddl_entry = self.parse_ddl_query( res["query_text"], @@ -404,14 +410,16 @@ def _parse_audit_log_row( for obj in direct_objects_accessed: dataset = self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - obj["objectName"] - ) + obj["objectName"], + ), ) columns = set() for modified_column in obj["columns"]: columns.add( - self.identifiers.snowflake_identifier(modified_column["columnName"]) + self.identifiers.snowflake_identifier( + modified_column["columnName"], + ), ) upstreams.append(dataset) @@ -429,8 +437,8 @@ def _parse_audit_log_row( downstream = self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - obj["objectName"] - ) + obj["objectName"], + ), ) column_lineage = [] for modified_column in obj["columns"]: @@ -439,31 +447,32 @@ def _parse_audit_log_row( downstream=DownstreamColumnRef( dataset=downstream, column=self.identifiers.snowflake_identifier( - modified_column["columnName"] + modified_column["columnName"], ), ), upstreams=[ ColumnRef( table=self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - upstream["objectName"] - ) + upstream["objectName"], + ), ), column=self.identifiers.snowflake_identifier( - upstream["columnName"] + upstream["columnName"], ), ) for upstream in modified_column["directSources"] if upstream["objectDomain"] in SnowflakeQuery.ACCESS_HISTORY_TABLE_VIEW_DOMAINS ], - ) + ), ) user = CorpUserUrn( self.identifiers.get_user_identifier( - res["user_name"], users.get(res["user_name"]) - ) + res["user_name"], + users.get(res["user_name"]), + ), ) timestamp: datetime = res["query_start_time"] @@ -471,7 +480,8 @@ def _parse_audit_log_row( # TODO need 
to map snowflake query types to ours query_type = SNOWFLAKE_QUERY_TYPE_MAPPING.get( - res["query_type"], QueryType.UNKNOWN + res["query_type"], + QueryType.UNKNOWN, ) entry = PreparsedQuery( @@ -479,7 +489,9 @@ def _parse_audit_log_row( # job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint # here query_id=get_query_fingerprint( - res["query_text"], self.identifiers.platform, fast=True + res["query_text"], + self.identifiers.platform, + fast=True, ), query_text=res["query_text"], upstreams=upstreams, @@ -509,14 +521,14 @@ def parse_ddl_query( ] == "ALTER" and object_modified_by_ddl["properties"].get("swapTargetName"): urn1 = self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - object_modified_by_ddl["objectName"] - ) + object_modified_by_ddl["objectName"], + ), ) urn2 = self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - object_modified_by_ddl["properties"]["swapTargetName"]["value"] - ) + object_modified_by_ddl["properties"]["swapTargetName"]["value"], + ), ) return TableSwap(urn1, urn2, query, session_id, timestamp) @@ -525,14 +537,14 @@ def parse_ddl_query( ] == "RENAME_TABLE" and object_modified_by_ddl["properties"].get("objectName"): original_un = self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - object_modified_by_ddl["objectName"] - ) + object_modified_by_ddl["objectName"], + ), ) new_urn = self.identifiers.gen_dataset_urn( self.identifiers.get_dataset_identifier_from_qualified_name( - object_modified_by_ddl["properties"]["objectName"]["value"] - ) + object_modified_by_ddl["properties"]["objectName"]["value"], + ), ) return TableRename(original_un, new_urn, query, session_id, timestamp) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py index 40bcfb514efd23..3c43b296e470bd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py @@ -12,7 +12,8 @@ def create_deny_regex_sql_filter( - deny_pattern: List[str], filter_cols: List[str] + deny_pattern: List[str], + filter_cols: List[str], ) -> str: upstream_sql_filter = ( " AND ".join( @@ -20,7 +21,7 @@ def create_deny_regex_sql_filter( (f"NOT RLIKE({col_name},'{regexp}','i')") for col_name in filter_cols for regexp in deny_pattern - ] + ], ) if deny_pattern else "" @@ -39,7 +40,7 @@ class SnowflakeQuery: } ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER = "({})".format( - ",".join(f"'{domain}'" for domain in ACCESS_HISTORY_TABLE_VIEW_DOMAINS) + ",".join(f"'{domain}'" for domain in ACCESS_HISTORY_TABLE_VIEW_DOMAINS), ) ACCESS_HISTORY_TABLE_DOMAINS_FILTER = ( "(" @@ -161,7 +162,9 @@ def tables_for_schema(schema_name: str, db_name: Optional[str]) -> str: @staticmethod def get_all_tags_on_object_with_propagation( - db_name: str, quoted_identifier: str, domain: str + db_name: str, + quoted_identifier: str, + domain: str, ) -> str: # https://docs.snowflake.com/en/sql-reference/functions/tag_references.html return f""" @@ -202,7 +205,8 @@ def get_all_tags_in_database_without_propagation(db_name: str) -> str: @staticmethod def get_tags_on_columns_with_propagation( - db_name: str, quoted_table_identifier: str + db_name: str, + quoted_table_identifier: str, ) -> str: # https://docs.snowflake.com/en/sql-reference/functions/tag_references_all_columns.html 
return f""" @@ -287,8 +291,10 @@ def columns_for_schema( selects.append( columns_template.format( - db_name=db_name, schema_name=schema_name, extra_clause=extra_clause - ) + db_name=db_name, + schema_name=schema_name, + extra_clause=extra_clause, + ), ) return ( @@ -308,7 +314,8 @@ def show_foreign_keys_for_schema(schema_name: str, db_name: str) -> str: @staticmethod def operational_data_for_time_window( - start_time_millis: int, end_time_millis: int + start_time_millis: int, + end_time_millis: int, ) -> str: return f""" SELECT @@ -650,7 +657,7 @@ def gen_email_filter_query(email_filter: AllowDenyPattern) -> str: else: for allow_pattern in email_filter.allow: allow_filters.append( - f"rlike(user_name, '{allow_pattern}','{'i' if email_filter.ignoreCase else 'c'}')" + f"rlike(user_name, '{allow_pattern}','{'i' if email_filter.ignoreCase else 'c'}')", ) if allow_filters: allow_filter = " OR ".join(allow_filters) @@ -659,7 +666,7 @@ def gen_email_filter_query(email_filter: AllowDenyPattern) -> str: deny_filter = "" for deny_pattern in email_filter.deny: deny_filters.append( - f"rlike(user_name, '{deny_pattern}','{'i' if email_filter.ignoreCase else 'c'}')" + f"rlike(user_name, '{deny_pattern}','{'i' if email_filter.ignoreCase else 'c'}')", ) if deny_filters: deny_filter = " OR ".join(deny_filters) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py index b24471f8666afa..ba6c86feb09e97 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py @@ -54,7 +54,7 @@ class SnowflakeUsageReport: stateful_usage_ingestion_enabled: bool = False usage_aggregation: SnowflakeUsageAggregationReport = field( - default_factory=SnowflakeUsageAggregationReport + default_factory=SnowflakeUsageAggregationReport, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index 173024aec0cf38..11865a81fe0488 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -138,7 +138,7 @@ def __init__(self) -> None: # self._schema_tags[][] = list of tags applied to schema self._schema_tags: Dict[str, Dict[str, List[SnowflakeTag]]] = defaultdict( - lambda: defaultdict(list) + lambda: defaultdict(list), ) # self._table_tags[][][] = list of tags applied to table @@ -148,9 +148,10 @@ def __init__(self) -> None: # self._column_tags[][][][] = list of tags applied to column self._column_tags: Dict[ - str, Dict[str, Dict[str, Dict[str, List[SnowflakeTag]]]] + str, + Dict[str, Dict[str, Dict[str, List[SnowflakeTag]]]], ] = defaultdict( - lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))), ) def add_database_tag(self, db_name: str, tag: SnowflakeTag) -> None: @@ -166,12 +167,19 @@ def get_schema_tags(self, schema_name: str, db_name: str) -> List[SnowflakeTag]: return self._schema_tags.get(db_name, {}).get(schema_name, []) def add_table_tag( - self, table_name: str, schema_name: str, db_name: str, tag: SnowflakeTag + self, + table_name: str, + schema_name: str, + db_name: str, + tag: SnowflakeTag, ) -> None: self._table_tags[db_name][schema_name][table_name].append(tag) def 
get_table_tags( - self, table_name: str, schema_name: str, db_name: str + self, + table_name: str, + schema_name: str, + db_name: str, ) -> List[SnowflakeTag]: return self._table_tags[db_name][schema_name][table_name] @@ -186,7 +194,10 @@ def add_column_tag( self._column_tags[db_name][schema_name][table_name][column_name].append(tag) def get_column_tags_for_table( - self, table_name: str, schema_name: str, db_name: str + self, + table_name: str, + schema_name: str, + db_name: str, ) -> Dict[str, List[SnowflakeTag]]: return ( self._column_tags.get(db_name, {}).get(schema_name, {}).get(table_name, {}) @@ -272,7 +283,7 @@ def get_schemas_for_database(self, db_name: str) -> List[SnowflakeSchema]: @serialized_lru_cache(maxsize=1) def get_secure_view_definitions(self) -> Dict[str, Dict[str, Dict[str, str]]]: secure_view_definitions: Dict[str, Dict[str, Dict[str, str]]] = defaultdict( - lambda: defaultdict(lambda: defaultdict()) + lambda: defaultdict(lambda: defaultdict()), ) cur = self.connection.query(SnowflakeQuery.get_secure_view_definitions()) for view in cur: @@ -287,7 +298,8 @@ def get_secure_view_definitions(self) -> Dict[str, Dict[str, Dict[str, str]]]: @serialized_lru_cache(maxsize=1) def get_tables_for_database( - self, db_name: str + self, + db_name: str, ) -> Optional[Dict[str, List[SnowflakeTable]]]: tables: Dict[str, List[SnowflakeTable]] = {} try: @@ -296,7 +308,8 @@ def get_tables_for_database( ) except Exception as e: logger.debug( - f"Failed to get all tables for database - {db_name}", exc_info=e + f"Failed to get all tables for database - {db_name}", + exc_info=e, ) # Error - Information schema query returned too much data. Please repeat query with more selective predicates. return None @@ -317,12 +330,14 @@ def get_tables_for_database( clustering_key=table["CLUSTERING_KEY"], is_dynamic=table.get("IS_DYNAMIC", "NO").upper() == "YES", is_iceberg=table.get("IS_ICEBERG", "NO").upper() == "YES", - ) + ), ) return tables def get_tables_for_schema( - self, schema_name: str, db_name: str + self, + schema_name: str, + db_name: str, ) -> List[SnowflakeTable]: tables: List[SnowflakeTable] = [] @@ -343,7 +358,7 @@ def get_tables_for_schema( clustering_key=table["CLUSTERING_KEY"], is_dynamic=table.get("IS_DYNAMIC", "NO").upper() == "YES", is_iceberg=table.get("IS_ICEBERG", "NO").upper() == "YES", - ) + ), ) return tables @@ -361,7 +376,7 @@ def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]] db_name, limit=page_limit, view_pagination_marker=view_pagination_marker, - ) + ), ) first_iteration = False @@ -387,13 +402,13 @@ def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]] view.get("is_materialized", "false").lower() == "true" ), is_secure=(view.get("is_secure", "false").lower() == "true"), - ) + ), ) if result_set_size >= page_limit: # If we hit the limit, we need to send another request to get the next page. 
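The pagination comment above is the heart of the SHOW VIEWS loop: pages are requested until one comes back smaller than the page limit, with the last view name carried forward as the marker. A minimal sketch of that marker-based pattern, where fetch_page is a hypothetical stand-in for the query issued through the Snowflake connection:

from typing import Callable, Iterator, List, Optional

def paginate_views(
    fetch_page: Callable[[Optional[str], int], List[dict]],  # hypothetical fetcher
    page_limit: int = 1000,
) -> Iterator[dict]:
    # Keep requesting pages until a short page arrives, passing the last
    # seen view name as the pagination marker for the next request.
    marker: Optional[str] = None
    while True:
        page = fetch_page(marker, page_limit)
        yield from page
        if len(page) < page_limit:
            break
        marker = page[-1]["name"]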
logger.info( - f"Fetching next page of views for {db_name} - after {view_name}" + f"Fetching next page of views for {db_name} - after {view_name}", ) view_pagination_marker = view_name @@ -415,15 +430,19 @@ def get_columns_for_schema( columns = FileBackedDict() object_batches = build_prefix_batches( - all_objects, max_batch_size=10000, max_groups_in_batch=5 + all_objects, + max_batch_size=10000, + max_groups_in_batch=5, ) for batch_index, object_batch in enumerate(object_batches): if batch_index > 0: logger.info( - f"Still fetching columns for {db_name}.{schema_name} - batch {batch_index + 1} of {len(object_batches)}" + f"Still fetching columns for {db_name}.{schema_name} - batch {batch_index + 1} of {len(object_batches)}", ) query = SnowflakeQuery.columns_for_schema( - schema_name, db_name, object_batch + schema_name, + db_name, + object_batch, ) cur = self.connection.query(query) @@ -441,13 +460,15 @@ def get_columns_for_schema( character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"], numeric_precision=column["NUMERIC_PRECISION"], numeric_scale=column["NUMERIC_SCALE"], - ) + ), ) return columns @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) def get_pk_constraints_for_schema( - self, schema_name: str, db_name: str + self, + schema_name: str, + db_name: str, ) -> Dict[str, SnowflakePK]: constraints: Dict[str, SnowflakePK] = {} cur = self.connection.query( @@ -457,14 +478,17 @@ def get_pk_constraints_for_schema( for row in cur: if row["table_name"] not in constraints: constraints[row["table_name"]] = SnowflakePK( - name=row["constraint_name"], column_names=[] + name=row["constraint_name"], + column_names=[], ) constraints[row["table_name"]].column_names.append(row["column_name"]) return constraints @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) def get_fk_constraints_for_schema( - self, schema_name: str, db_name: str + self, + schema_name: str, + db_name: str, ) -> Dict[str, List[SnowflakeFK]]: constraints: Dict[str, List[SnowflakeFK]] = {} fk_constraints_map: Dict[str, SnowflakeFK] = {} @@ -488,10 +512,10 @@ def get_fk_constraints_for_schema( constraints[row["fk_table_name"]] = [] fk_constraints_map[row["fk_name"]].column_names.append( - row["fk_column_name"] + row["fk_column_name"], ) fk_constraints_map[row["fk_name"]].referred_column_names.append( - row["pk_column_name"] + row["pk_column_name"], ) constraints[row["fk_table_name"]].append(fk_constraints_map[row["fk_name"]]) @@ -502,7 +526,7 @@ def get_tags_for_database_without_propagation( db_name: str, ) -> _SnowflakeTagCache: cur = self.connection.query( - SnowflakeQuery.get_all_tags_in_database_without_propagation(db_name) + SnowflakeQuery.get_all_tags_in_database_without_propagation(db_name), ) tags = _SnowflakeTagCache() @@ -530,7 +554,10 @@ def get_tags_for_database_without_propagation( tags.add_schema_tag(object_name, object_database, snowflake_tag) elif domain == SnowflakeObjectDomain.TABLE: # including views tags.add_table_tag( - object_name, object_schema, object_database, snowflake_tag + object_name, + object_schema, + object_database, + snowflake_tag, ) elif domain == SnowflakeObjectDomain.COLUMN: column_name = tag["COLUMN_NAME"] @@ -558,7 +585,9 @@ def get_tags_for_object_with_propagation( cur = self.connection.query( SnowflakeQuery.get_all_tags_on_object_with_propagation( - db_name, quoted_identifier, domain + db_name, + quoted_identifier, + domain, ), ) @@ -569,17 +598,20 @@ def get_tags_for_object_with_propagation( schema=tag["TAG_SCHEMA"], name=tag["TAG_NAME"], value=tag["TAG_VALUE"], - ) + ), ) return tags def 
get_tags_on_columns_for_table( - self, quoted_table_name: str, db_name: str + self, + quoted_table_name: str, + db_name: str, ) -> Dict[str, List[SnowflakeTag]]: tags: Dict[str, List[SnowflakeTag]] = defaultdict(list) cur = self.connection.query( SnowflakeQuery.get_tags_on_columns_with_propagation( - db_name, quoted_table_name + db_name, + quoted_table_name, ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py index a2d69d9e552916..7efe19b3a78316 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py @@ -184,14 +184,16 @@ def __init__( self.identifiers: SnowflakeIdentifierBuilder = identifiers self.data_dictionary: SnowflakeDataDictionary = SnowflakeDataDictionary( - connection=self.connection + connection=self.connection, ) self.report.data_dictionary_cache = self.data_dictionary self.domain_registry: Optional[DomainRegistry] = domain_registry self.classification_handler = ClassificationHandler(self.config, self.report) self.tag_extractor = SnowflakeTagExtractor( - config, self.data_dictionary, self.report + config, + self.data_dictionary, + self.report, ) self.profiler: Optional[SnowflakeProfiler] = profiler self.snowsight_url_builder: Optional[SnowsightUrlBuilder] = ( @@ -231,14 +233,16 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: try: for snowflake_db in self.databases: with self.report.new_stage( - f"{snowflake_db.name}: {METADATA_EXTRACTION}" + f"{snowflake_db.name}: {METADATA_EXTRACTION}", ): yield from self._process_database(snowflake_db) with self.report.new_stage(f"*: {EXTERNAL_TABLE_DDL_LINEAGE}"): discovered_tables: List[str] = [ self.identifiers.get_dataset_identifier( - table_name, schema.name, db.name + table_name, + schema.name, + db.name, ) for db in self.databases for schema in db.schemas @@ -279,7 +283,8 @@ def get_databases(self) -> Optional[List[SnowflakeDatabase]]: return ischema_databases def get_databases_from_ischema( - self, databases: List[SnowflakeDatabase] + self, + databases: List[SnowflakeDatabase], ) -> List[SnowflakeDatabase]: ischema_databases: List[SnowflakeDatabase] = [] for database in databases: @@ -291,19 +296,20 @@ def get_databases_from_ischema( # This is okay, because `show databases` query lists all databases irrespective of permission, # if role has `MANAGE GRANTS` privilege. (not advisable) logger.debug( - f"Failed to list databases {database.name} information_schema" + f"Failed to list databases {database.name} information_schema", ) # SNOWFLAKE database always shows up even if permissions are missing if database == SNOWFLAKE_DATABASE: continue logger.info( - f"The role {self.report.role} has `MANAGE GRANTS` privilege. This is not advisable and also not required." + f"The role {self.report.role} has `MANAGE GRANTS` privilege. 
This is not advisable and also not required.", ) return ischema_databases def _process_database( - self, snowflake_db: SnowflakeDatabase + self, + snowflake_db: SnowflakeDatabase, ) -> Iterable[MetadataWorkUnit]: db_name = snowflake_db.name @@ -325,13 +331,16 @@ def _process_database( exc_info=e, ) self.structured_reporter.warning( - "Failed to get schemas for database", db_name, exc=e + "Failed to get schemas for database", + db_name, + exc=e, ) return if self.config.extract_tags != TagOption.skip: snowflake_db.tags = self.tag_extractor.get_tags_on_object( - domain="database", db_name=db_name + domain="database", + db_name=db_name, ) if self.config.include_technical_schema: @@ -360,7 +369,9 @@ def _process_schema_worker( snowflake_schema: SnowflakeSchema, ) -> Iterable[MetadataWorkUnit]: for wu in self._process_schema( - snowflake_schema, snowflake_db.name, db_tables + snowflake_schema, + snowflake_db.name, + db_tables, ): yield wu @@ -374,7 +385,9 @@ def _process_schema_worker( yield wu def fetch_schemas_for_database( - self, snowflake_db: SnowflakeDatabase, db_name: str + self, + snowflake_db: SnowflakeDatabase, + db_name: str, ) -> None: schemas: List[SnowflakeSchema] = [] try: @@ -419,7 +432,9 @@ def _process_schema( if self.config.extract_tags != TagOption.skip: snowflake_schema.tags = self.tag_extractor.get_tags_on_object( - schema_name=schema_name, db_name=db_name, domain="schema" + schema_name=schema_name, + db_name=db_name, + domain="schema", ) if self.config.include_technical_schema: @@ -428,7 +443,9 @@ def _process_schema( # We need to do this first so that we can use it when fetching columns. if self.config.include_tables: tables = self.fetch_tables_for_schema( - snowflake_schema, db_name, schema_name + snowflake_schema, + db_name, + schema_name, ) if self.config.include_views: views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name) @@ -440,7 +457,9 @@ def _process_schema( data_reader = self.make_data_reader() for table in tables: table_wu_generator = self._process_table( - table, snowflake_schema, db_name + table, + snowflake_schema, + db_name, ) yield from classification_workunit_processor( @@ -454,11 +473,15 @@ def _process_schema( if self.aggregator: for view in views: view_identifier = self.identifiers.get_dataset_identifier( - view.name, schema_name, db_name + view.name, + schema_name, + db_name, ) if view.is_secure and not view.view_definition: view.view_definition = self.fetch_secure_view_definition( - view.name, schema_name, db_name + view.name, + schema_name, + db_name, ) if view.view_definition: self.aggregator.add_view_definition( @@ -486,7 +509,10 @@ def _process_schema( ) def fetch_secure_view_definition( - self, table_name: str, schema_name: str, db_name: str + self, + table_name: str, + schema_name: str, + db_name: str, ) -> Optional[str]: try: view_definitions = self.data_dictionary.get_secure_view_definitions() @@ -505,13 +531,18 @@ def fetch_secure_view_definition( return None def fetch_views_for_schema( - self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str + self, + snowflake_schema: SnowflakeSchema, + db_name: str, + schema_name: str, ) -> List[SnowflakeView]: try: views: List[SnowflakeView] = [] for view in self.get_views_for_schema(schema_name, db_name): view_name = self.identifiers.get_dataset_identifier( - view.name, schema_name, db_name + view.name, + schema_name, + db_name, ) self.report.report_entity_scanned(view_name, "view") @@ -537,17 +568,22 @@ def fetch_views_for_schema( return [] def 
fetch_tables_for_schema( - self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str + self, + snowflake_schema: SnowflakeSchema, + db_name: str, + schema_name: str, ) -> List[SnowflakeTable]: try: tables: List[SnowflakeTable] = [] for table in self.get_tables_for_schema(schema_name, db_name): table_identifier = self.identifiers.get_dataset_identifier( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) self.report.report_entity_scanned(table_identifier) if not self.filters.filter_config.table_pattern.allowed( - table_identifier + table_identifier, ): self.report.report_dropped(table_identifier) else: @@ -570,7 +606,8 @@ def fetch_tables_for_schema( def make_data_reader(self) -> Optional[SnowflakeDataReader]: if self.classification_handler.is_classification_enabled() and self.connection: return SnowflakeDataReader.create( - self.connection, self.snowflake_identifier + self.connection, + self.snowflake_identifier, ) return None @@ -583,21 +620,29 @@ def _process_table( ) -> Iterable[MetadataWorkUnit]: schema_name = snowflake_schema.name table_identifier = self.identifiers.get_dataset_identifier( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) try: table.columns = self.get_columns_for_table( - table.name, snowflake_schema, db_name + table.name, + snowflake_schema, + db_name, ) table.column_count = len(table.columns) if self.config.extract_tags != TagOption.skip: table.column_tags = self.tag_extractor.get_column_tags_for_table( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) except Exception as e: self.structured_reporter.warning( - "Failed to get columns for table", table_identifier, exc=e + "Failed to get columns for table", + table_identifier, + exc=e, ) if self.config.extract_tags != TagOption.skip: @@ -614,7 +659,10 @@ def _process_table( if self.config.include_foreign_keys: self.fetch_foreign_keys_for_table( - table, schema_name, db_name, table_identifier + table, + schema_name, + db_name, + table_identifier, ) yield from self.gen_dataset_workunits(table, schema_name, db_name) @@ -628,11 +676,15 @@ def fetch_foreign_keys_for_table( ) -> None: try: table.foreign_keys = self.get_fk_constraints_for_table( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) except Exception as e: self.structured_reporter.warning( - "Failed to get foreign keys for table", table_identifier, exc=e + "Failed to get foreign keys for table", + table_identifier, + exc=e, ) def fetch_pk_for_table( @@ -644,11 +696,15 @@ def fetch_pk_for_table( ) -> None: try: table.pk = self.get_pk_constraints_for_table( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) except Exception as e: self.structured_reporter.warning( - "Failed to get primary key for table", table_identifier, exc=e + "Failed to get primary key for table", + table_identifier, + exc=e, ) def _process_view( @@ -659,20 +715,28 @@ def _process_view( ) -> Iterable[MetadataWorkUnit]: schema_name = snowflake_schema.name view_name = self.identifiers.get_dataset_identifier( - view.name, schema_name, db_name + view.name, + schema_name, + db_name, ) try: view.columns = self.get_columns_for_table( - view.name, snowflake_schema, db_name + view.name, + snowflake_schema, + db_name, ) if self.config.extract_tags != TagOption.skip: view.column_tags = self.tag_extractor.get_column_tags_for_table( - view.name, schema_name, db_name + view.name, + schema_name, + db_name, ) except Exception as e: self.structured_reporter.warning( - 
"Failed to get columns for view", view_name, exc=e + "Failed to get columns for view", + view_name, + exc=e, ) if self.config.extract_tags != TagOption.skip: @@ -704,11 +768,12 @@ def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: yield from self.gen_tag_workunits(tag) def _format_tags_as_structured_properties( - self, tags: List[SnowflakeTag] + self, + tags: List[SnowflakeTag], ) -> Dict[StructuredPropertyUrn, str]: return { StructuredPropertyUrn( - self.snowflake_identifier(tag.structured_property_identifier()) + self.snowflake_identifier(tag.structured_property_identifier()), ): tag.value for tag in tags } @@ -727,25 +792,30 @@ def gen_dataset_workunits( yield from self._process_tag(tag) dataset_name = self.identifiers.get_dataset_identifier( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) dataset_urn = self.identifiers.gen_dataset_urn(dataset_name) status = Status(removed=False) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=status + entityUrn=dataset_urn, + aspect=status, ).as_workunit() schema_metadata = self.gen_schema_metadata(table, schema_name, db_name) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=schema_metadata + entityUrn=dataset_urn, + aspect=schema_metadata, ).as_workunit() dataset_properties = self.get_dataset_properties(table, schema_name, db_name) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties + entityUrn=dataset_urn, + aspect=dataset_properties, ).as_workunit() schema_container_key = gen_schema_key( @@ -776,11 +846,12 @@ def gen_dataset_workunits( [DatasetSubTypes.VIEW] if isinstance(table, SnowflakeView) else [DatasetSubTypes.TABLE] - ) + ), ) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=subTypes + entityUrn=dataset_urn, + aspect=subTypes, ).as_workunit() if self.domain_registry: @@ -801,14 +872,15 @@ def gen_dataset_workunits( tag_associations = [ TagAssociation( tag=make_tag_urn( - self.snowflake_identifier(tag.tag_identifier()) - ) + self.snowflake_identifier(tag.tag_identifier()), + ), ) for tag in table.tags ] global_tags = GlobalTags(tag_associations) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=global_tags + entityUrn=dataset_urn, + aspect=global_tags, ).as_workunit() if isinstance(table, SnowflakeView) and table.view_definition is not None: @@ -823,7 +895,8 @@ def gen_dataset_workunits( ) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=view_properties_aspect + entityUrn=dataset_urn, + aspect=view_properties_aspect, ).as_workunit() def get_dataset_properties( @@ -890,11 +963,13 @@ def gen_tag_workunits(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: ) yield MetadataChangeProposalWrapper( - entityUrn=tag_urn, aspect=tag_properties_aspect + entityUrn=tag_urn, + aspect=tag_properties_aspect, ).as_workunit() def gen_tag_as_structured_property_workunits( - self, tag: SnowflakeTag + self, + tag: SnowflakeTag, ) -> Iterable[MetadataWorkUnit]: identifier = self.snowflake_identifier(tag.structured_property_identifier()) urn = StructuredPropertyUrn(identifier).urn() @@ -908,7 +983,8 @@ def gen_tag_as_structured_property_workunits( EntityTypeUrn(f"datahub.{SchemaFieldUrn.ENTITY_TYPE}").urn(), ], lastModified=AuditStamp( - time=get_sys_time(), actor="urn:li:corpuser:datahub" + time=get_sys_time(), + actor="urn:li:corpuser:datahub", ), ) yield MetadataChangeProposalWrapper( @@ -917,14 +993,16 @@ def gen_tag_as_structured_property_workunits( ).as_workunit() def 
gen_column_tags_as_structured_properties( - self, dataset_urn: str, table: Union[SnowflakeTable, SnowflakeView] + self, + dataset_urn: str, + table: Union[SnowflakeTable, SnowflakeView], ) -> Iterable[MetadataWorkUnit]: for column_name in table.column_tags: schema_field_urn = SchemaFieldUrn(dataset_urn, column_name).urn() yield from add_structured_properties_to_entity_wu( schema_field_urn, self._format_tags_as_structured_properties( - table.column_tags[column_name] + table.column_tags[column_name], ), ) @@ -935,7 +1013,9 @@ def gen_schema_metadata( db_name: str, ) -> SchemaMetadata: dataset_name = self.identifiers.get_dataset_identifier( - table.name, schema_name, db_name + table.name, + schema_name, + db_name, ) dataset_urn = self.identifiers.gen_dataset_urn(dataset_name) @@ -953,7 +1033,7 @@ def gen_schema_metadata( SchemaField( fieldPath=self.snowflake_identifier(col.name), type=SchemaFieldDataType( - SNOWFLAKE_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)() + SNOWFLAKE_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)(), ), # NOTE: nativeDataType will not be in sync with older connector nativeDataType=col.get_precise_native_type(), @@ -969,11 +1049,11 @@ def gen_schema_metadata( [ TagAssociation( make_tag_urn( - self.snowflake_identifier(tag.tag_identifier()) - ) + self.snowflake_identifier(tag.tag_identifier()), + ), ) for tag in table.column_tags[col.name] - ] + ], ) if col.name in table.column_tags and not self.config.extract_tags_as_structured_properties @@ -991,14 +1071,18 @@ def gen_schema_metadata( return schema_metadata def build_foreign_keys( - self, table: SnowflakeTable, dataset_urn: str + self, + table: SnowflakeTable, + dataset_urn: str, ) -> List[ForeignKeyConstraint]: foreign_keys = [] for fk in table.foreign_keys: foreign_dataset = make_dataset_urn_with_platform_instance( platform=self.platform, name=self.identifiers.get_dataset_identifier( - fk.referred_table, fk.referred_schema, fk.referred_database + fk.referred_table, + fk.referred_schema, + fk.referred_database, ), env=self.config.env, platform_instance=self.config.platform_instance, @@ -1021,12 +1105,13 @@ def build_foreign_keys( ) for col in fk.column_names ], - ) + ), ) return foreign_keys def gen_database_containers( - self, database: SnowflakeDatabase + self, + database: SnowflakeDatabase, ) -> Iterable[MetadataWorkUnit]: database_container_key = gen_database_key( self.snowflake_identifier(database.name), @@ -1079,7 +1164,9 @@ def gen_database_containers( ) def gen_schema_containers( - self, schema: SnowflakeSchema, db_name: str + self, + schema: SnowflakeSchema, + db_name: str, ) -> Iterable[MetadataWorkUnit]: schema_name = self.snowflake_identifier(schema.name) database_container_key = gen_database_key( @@ -1109,7 +1196,8 @@ def gen_schema_containers( description=schema.comment, external_url=( self.snowsight_url_builder.get_external_url_for_schema( - schema.name, db_name + schema.name, + db_name, ) if self.snowsight_url_builder else None @@ -1137,7 +1225,9 @@ def gen_schema_containers( ) def get_tables_for_schema( - self, schema_name: str, db_name: str + self, + schema_name: str, + db_name: str, ) -> List[SnowflakeTable]: tables = self.data_dictionary.get_tables_for_database(db_name) @@ -1151,7 +1241,9 @@ def get_tables_for_schema( return tables.get(schema_name, []) def get_views_for_schema( - self, schema_name: str, db_name: str + self, + schema_name: str, + db_name: str, ) -> List[SnowflakeView]: views = self.data_dictionary.get_views_for_database(db_name) @@ -1159,14 +1251,18 @@ def 
get_views_for_schema( return views.get(schema_name, []) def get_columns_for_table( - self, table_name: str, snowflake_schema: SnowflakeSchema, db_name: str + self, + table_name: str, + snowflake_schema: SnowflakeSchema, + db_name: str, ) -> List[SnowflakeColumn]: schema_name = snowflake_schema.name columns = self.data_dictionary.get_columns_for_schema( schema_name, db_name, cache_exclude_all_objects=itertools.chain( - snowflake_schema.tables, snowflake_schema.views + snowflake_schema.tables, + snowflake_schema.views, ), ) @@ -1174,20 +1270,28 @@ def get_columns_for_table( return columns.get(table_name, []) def get_pk_constraints_for_table( - self, table_name: str, schema_name: str, db_name: str + self, + table_name: str, + schema_name: str, + db_name: str, ) -> Optional[SnowflakePK]: constraints = self.data_dictionary.get_pk_constraints_for_schema( - schema_name, db_name + schema_name, + db_name, ) # Access to table but none of its constraints - is this possible ? return constraints.get(table_name) def get_fk_constraints_for_table( - self, table_name: str, schema_name: str, db_name: str + self, + table_name: str, + schema_name: str, + db_name: str, ) -> List[SnowflakeFK]: constraints = self.data_dictionary.get_fk_constraints_for_schema( - schema_name, db_name + schema_name, + db_name, ) # Access to table but none of its constraints - is this possible ? @@ -1196,13 +1300,16 @@ def get_fk_constraints_for_table( # Handles the case for explicitly created external tables. # NOTE: Snowflake does not log this information to the access_history table. def _external_tables_ddl_lineage( - self, discovered_tables: List[str] + self, + discovered_tables: List[str], ) -> Iterable[KnownLineageMapping]: external_tables_query: str = SnowflakeQuery.show_external_tables() try: for db_row in self.connection.query(external_tables_query): key = self.identifiers.get_dataset_identifier( - db_row["name"], db_row["schema_name"], db_row["database_name"] + db_row["name"], + db_row["schema_name"], + db_row["database_name"], ) if key not in discovered_tables: @@ -1210,7 +1317,8 @@ def _external_tables_ddl_lineage( if db_row["location"].startswith("s3://"): yield KnownLineageMapping( upstream_urn=make_s3_urn_for_lineage( - db_row["location"], self.config.env + db_row["location"], + self.config.env, ), downstream_urn=self.identifiers.gen_dataset_urn(key), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 606acd53dc3324..e877e3637f3c2d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -31,7 +31,8 @@ def __init__( self.report = report def get_shares_workunits( - self, databases: List[SnowflakeDatabase] + self, + databases: List[SnowflakeDatabase], ) -> Iterable[MetadataWorkUnit]: inbounds = self.config.inbounds() outbounds = self.config.outbounds() @@ -74,11 +75,16 @@ def get_shares_workunits( # hence this lineage code is not written in SnowflakeLineageExtractor # also this is not governed by configs include_table_lineage yield self.get_upstream_lineage_with_primary_sibling( - db.name, schema.name, table_name, sibling_dbs[0] + db.name, + schema.name, + table_name, + sibling_dbs[0], ) self.report_missing_databases( - databases, list(inbounds.keys()), list(outbounds.keys()) + databases, + list(inbounds.keys()), + list(outbounds.keys()), ) def report_missing_databases( @@ 
-112,7 +118,9 @@ def gen_siblings( if not sibling_databases: return dataset_identifier = self.identifiers.get_dataset_identifier( - table_name, schema_name, database_name + table_name, + schema_name, + database_name, ) urn = self.identifiers.gen_dataset_urn(dataset_identifier) @@ -120,7 +128,9 @@ def gen_siblings( make_dataset_urn_with_platform_instance( self.identifiers.platform, self.identifiers.get_dataset_identifier( - table_name, schema_name, sibling_db.database + table_name, + schema_name, + sibling_db.database, ), sibling_db.platform_instance, ) @@ -140,14 +150,18 @@ def get_upstream_lineage_with_primary_sibling( primary_sibling_db: DatabaseId, ) -> MetadataWorkUnit: dataset_identifier = self.identifiers.get_dataset_identifier( - table_name, schema_name, database_name + table_name, + schema_name, + database_name, ) urn = self.identifiers.gen_dataset_urn(dataset_identifier) upstream_urn = make_dataset_urn_with_platform_instance( self.identifiers.platform, self.identifiers.get_dataset_identifier( - table_name, schema_name, primary_sibling_db.database + table_name, + schema_name, + primary_sibling_db.database, ), primary_sibling_db.platform_instance, ) @@ -155,6 +169,8 @@ def get_upstream_lineage_with_primary_sibling( return MetadataChangeProposalWrapper( entityUrn=urn, aspect=UpstreamLineage( - upstreams=[Upstream(dataset=upstream_urn, type=DatasetLineageType.COPY)] + upstreams=[ + Upstream(dataset=upstream_urn, type=DatasetLineageType.COPY), + ], ), ).as_workunit() diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py index 72952f6b76e8b4..7ca08188e24dcd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py @@ -44,7 +44,7 @@ class SnowflakeSummaryReport(SourceReport, BaseTimeWindowReport): schema_counters: Dict[str, int] = dataclasses.field(default_factory=dict) object_counters: Dict[str, Dict[str, int]] = dataclasses.field( - default_factory=lambda: defaultdict(lambda: defaultdict(int)) + default_factory=lambda: defaultdict(lambda: defaultdict(int)), ) num_snowflake_queries: Optional[int] = None @@ -101,14 +101,18 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for schema in database.schemas: # Tables/views. tables = schema_generator.fetch_tables_for_schema( - schema, database.name, schema.name + schema, + database.name, + schema.name, ) views = schema_generator.fetch_views_for_schema( - schema, database.name, schema.name + schema, + database.name, + schema.name, ) self.report.object_counters[database.name][schema.name] = len( - tables + tables, ) + len(views) # Queries for usage. 
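The object_counters field in the summary report above relies on a two-level defaultdict so per-database, per-schema counts can be written without initialising intermediate dictionaries. A small illustration of that pattern; the database and schema names here are made up:

from collections import defaultdict
from typing import Dict

object_counters: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))

# The summary source above stores len(tables) + len(views) per (database, schema).
object_counters["ANALYTICS"]["PUBLIC"] = 12 + 3
assert object_counters["ANALYTICS"]["PUBLIC"] == 15
assert object_counters["RAW"]["STAGING"] == 0  # unseen keys read back as 0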
@@ -120,7 +124,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: FROM snowflake.account_usage.query_history WHERE query_history.start_time >= to_timestamp_ltz({start_time_millis}, 3) AND query_history.start_time < to_timestamp_ltz({end_time_millis}, 3) -""" +""", ): self.report.num_snowflake_queries = row["CNT"] @@ -134,7 +138,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: AND query_start_time < to_timestamp_ltz({end_time_millis}, 3) AND access_history.objects_modified is not null AND ARRAY_SIZE(access_history.objects_modified) > 0 -""" +""", ): self.report.num_snowflake_mutations = row["CNT"] diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py index 597e7bee4d4cc0..4d492f650fcbc6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py @@ -53,7 +53,9 @@ def _get_tags_on_object_without_propagation( assert schema_name is not None assert table_name is not None tags = self.tag_cache[db_name].get_table_tags( - table_name, schema_name, db_name + table_name, + schema_name, + db_name, ) else: raise ValueError(f"Unknown domain {domain}") @@ -72,7 +74,8 @@ def _get_tags_on_object_with_propagation( elif domain == SnowflakeObjectDomain.SCHEMA: assert schema_name is not None identifier = self.identifiers.get_quoted_identifier_for_schema( - db_name, schema_name + db_name, + schema_name, ) elif ( domain == SnowflakeObjectDomain.TABLE @@ -80,7 +83,9 @@ def _get_tags_on_object_with_propagation( assert schema_name is not None assert table_name is not None identifier = self.identifiers.get_quoted_identifier_for_table( - db_name, schema_name, table_name + db_name, + schema_name, + table_name, ) else: raise ValueError(f"Unknown domain {domain}") @@ -88,7 +93,9 @@ def _get_tags_on_object_with_propagation( self.report.num_get_tags_for_object_queries += 1 tags = self.data_dictionary.get_tags_for_object_with_propagation( - domain=domain, quoted_identifier=identifier, db_name=db_name + domain=domain, + quoted_identifier=identifier, + db_name=db_name, ) return tags @@ -132,17 +139,21 @@ def get_column_tags_for_table( if db_name not in self.tag_cache: self.tag_cache[db_name] = ( self.data_dictionary.get_tags_for_database_without_propagation( - db_name + db_name, ) ) temp_column_tags = self.tag_cache[db_name].get_column_tags_for_table( - table_name, schema_name, db_name + table_name, + schema_name, + db_name, ) elif self.config.extract_tags == TagOption.with_lineage: self.report.num_get_tags_on_columns_for_table_queries += 1 temp_column_tags = self.data_dictionary.get_tags_on_columns_for_table( quoted_table_name=self.identifiers.get_quoted_identifier_for_table( - db_name, schema_name, table_name + db_name, + schema_name, + table_name, ), db_name=db_name, ) @@ -158,7 +169,8 @@ def get_column_tags_for_table( return column_tags def _filter_tags( - self, tags: Optional[List[SnowflakeTag]] + self, + tags: Optional[List[SnowflakeTag]], ) -> Optional[List[SnowflakeTag]]: if tags is None: return tags diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index edd13ee48326bb..b86b16c8357199 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -135,13 +135,15 @@ def __init__( def get_time_window(self) -> Tuple[datetime, datetime]: if self.redundant_run_skip_handler: return self.redundant_run_skip_handler.suggest_run_time_window( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) else: return self.config.start_time, self.config.end_time def get_usage_workunits( - self, discovered_datasets: List[str] + self, + discovered_datasets: List[str], ) -> Iterable[MetadataWorkUnit]: if not self._should_ingest_usage(): return @@ -149,7 +151,7 @@ def get_usage_workunits( with self.report.new_stage(f"*: {USAGE_EXTRACTION_USAGE_AGGREGATION}"): if self.report.edition == SnowflakeEdition.STANDARD.value: logger.info( - "Snowflake Account is Standard Edition. Usage and Operation History Feature is not supported." + "Snowflake Account is Standard Edition. Usage and Operation History Feature is not supported.", ) return @@ -187,7 +189,8 @@ def get_usage_workunits( access_events = self._get_snowflake_history() for event in access_events: yield from self._get_operation_aspect_work_unit( - event, discovered_datasets + event, + discovered_datasets, ) if self.redundant_run_skip_handler: @@ -199,7 +202,8 @@ def get_usage_workunits( ) def _get_workunits_internal( - self, discovered_datasets: List[str] + self, + discovered_datasets: List[str], ) -> Iterable[MetadataWorkUnit]: with PerfTimer() as timer: logger.info("Getting aggregated usage statistics") @@ -241,36 +245,40 @@ def _get_workunits_internal( row["OBJECT_DOMAIN"], ): logger.debug( - f"Skipping usage for {row['OBJECT_DOMAIN']} {row['OBJECT_NAME']}, as table is not allowed by recipe." + f"Skipping usage for {row['OBJECT_DOMAIN']} {row['OBJECT_NAME']}, as table is not allowed by recipe.", ) continue dataset_identifier = ( self.identifiers.get_dataset_identifier_from_qualified_name( - row["OBJECT_NAME"] + row["OBJECT_NAME"], ) ) if dataset_identifier not in discovered_datasets: logger.debug( - f"Skipping usage for {row['OBJECT_DOMAIN']} {dataset_identifier}, as table is not accessible." 
+ f"Skipping usage for {row['OBJECT_DOMAIN']} {dataset_identifier}, as table is not accessible.", ) continue with skip_timer.pause(), self.report.usage_aggregation.result_map_timer as map_timer: wu = self.build_usage_statistics_for_dataset( - dataset_identifier, row + dataset_identifier, + row, ) if wu: with map_timer.pause(): yield wu def build_usage_statistics_for_dataset( - self, dataset_identifier: str, row: dict + self, + dataset_identifier: str, + row: dict, ) -> Optional[MetadataWorkUnit]: try: stats = DatasetUsageStatistics( timestampMillis=int(row["BUCKET_START_TIME"].timestamp() * 1000), eventGranularity=TimeWindowSize( - unit=self.config.bucket_duration, multiple=1 + unit=self.config.bucket_duration, + multiple=1, ), totalSqlQueries=row["TOTAL_QUERIES"], uniqueUserCount=row["TOTAL_USERS"], @@ -292,7 +300,8 @@ def build_usage_statistics_for_dataset( exc_info=e, ) self.report.warning( - "Failed to parse usage statistics for dataset", dataset_identifier + "Failed to parse usage statistics for dataset", + dataset_identifier, ) return None @@ -301,19 +310,20 @@ def _map_top_sql_queries(self, top_sql_queries_str: str) -> List[str]: with self.report.usage_aggregation.queries_map_timer: top_sql_queries = json.loads(top_sql_queries_str) budget_per_query: int = int( - self.config.queries_character_limit / self.config.top_n_queries + self.config.queries_character_limit / self.config.top_n_queries, ) return sorted( [ ( trim_query( - try_format_query(query, self.platform), budget_per_query + try_format_query(query, self.platform), + budget_per_query, ) if self.config.format_sql_queries else trim_query(query, budget_per_query) ) for query in top_sql_queries - ] + ], ) def _map_user_counts( @@ -331,10 +341,11 @@ def _map_user_counts( and user_count["user_name"] ): user_email = "{}@{}".format( - user_count["user_name"], self.config.email_domain + user_count["user_name"], + self.config.email_domain, ).lower() if not user_email or not self.config.user_email_pattern.allowed( - user_email + user_email, ): continue @@ -344,13 +355,13 @@ def _map_user_counts( self.identifiers.get_user_identifier( user_count["user_name"], user_email, - ) + ), ), count=user_count["total"], # NOTE: Generated emails may be incorrect, as email may be different than # username@email_domain userEmail=user_email, - ) + ), ) return sorted(filtered_user_counts, key=lambda v: v.user) @@ -361,7 +372,7 @@ def _map_field_counts(self, field_counts_str: str) -> List[DatasetFieldUsageCoun [ DatasetFieldUsageCounts( fieldPath=self.identifiers.snowflake_identifier( - field_count["col"] + field_count["col"], ), count=field_count["total"], ) @@ -400,13 +411,14 @@ def _check_usage_date_ranges(self) -> None: try: assert self.connection is not None results = self.connection.query( - SnowflakeQuery.get_access_history_date_range() + SnowflakeQuery.get_access_history_date_range(), ) except Exception as e: if isinstance(e, SnowflakePermissionError): error_msg = "Failed to get usage. Please grant imported privileges on SNOWFLAKE database. 
" self.warn_if_stateful_else_error( - "usage-permission-error", error_msg + "usage-permission-error", + error_msg, ) else: logger.debug(e, exc_info=e) @@ -428,17 +440,19 @@ def _check_usage_date_ranges(self) -> None: ) break self.report.min_access_history_time = db_row["MIN_TIME"].astimezone( - tz=timezone.utc + tz=timezone.utc, ) self.report.max_access_history_time = db_row["MAX_TIME"].astimezone( - tz=timezone.utc + tz=timezone.utc, ) self.report.access_history_range_query_secs = timer.elapsed_seconds( - digits=2 + digits=2, ) def _get_operation_aspect_work_unit( - self, event: SnowflakeJoinedAccessEvent, discovered_datasets: List[str] + self, + event: SnowflakeJoinedAccessEvent, + discovered_datasets: List[str], ) -> Iterable[MetadataWorkUnit]: if event.query_start_time and event.query_type: start_time = event.query_start_time @@ -446,12 +460,13 @@ def _get_operation_aspect_work_unit( user_email = event.email user_name = event.user_name operation_type = OPERATION_STATEMENT_TYPES.get( - query_type, OperationTypeClass.CUSTOM + query_type, + OperationTypeClass.CUSTOM, ) reported_time: int = int(time.time() * 1000) last_updated_timestamp: int = int(start_time.timestamp() * 1000) user_urn = make_user_urn( - self.identifiers.get_user_identifier(user_name, user_email) + self.identifiers.get_user_identifier(user_name, user_email), ) # NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect @@ -460,13 +475,13 @@ def _get_operation_aspect_work_unit( dataset_identifier = ( self.identifiers.get_dataset_identifier_from_qualified_name( - resource + resource, ) ) if dataset_identifier not in discovered_datasets: logger.debug( - f"Skipping operations for table {dataset_identifier}, as table schema is not accessible" + f"Skipping operations for table {dataset_identifier}, as table schema is not accessible", ) continue @@ -492,7 +507,8 @@ def _get_operation_aspect_work_unit( yield wu def _process_snowflake_history_row( - self, event_dict: dict + self, + event_dict: dict, ) -> Iterable[SnowflakeJoinedAccessEvent]: try: # big hammer try block to ensure we don't fail on parsing events self.report.rows_processed += 1 @@ -503,7 +519,7 @@ def _process_snowflake_history_row( return self.parse_event_objects(event_dict) event = SnowflakeJoinedAccessEvent( - **{k.lower(): v for k, v in event_dict.items()} + **{k.lower(): v for k, v in event_dict.items()}, ) yield event except Exception as e: @@ -539,7 +555,7 @@ def parse_event_objects(self, event_dict: Dict) -> None: self.report.rows_zero_objects_modified += 1 event_dict["QUERY_START_TIME"] = (event_dict["QUERY_START_TIME"]).astimezone( - tz=timezone.utc + tz=timezone.utc, ) if ( @@ -566,9 +582,10 @@ def _is_unsupported_object_accessed(self, obj: Dict[str, Any]) -> bool: def _is_object_valid(self, obj: Dict[str, Any]) -> bool: if self._is_unsupported_object_accessed( - obj + obj, ) or not self.filter.is_dataset_pattern_allowed( - obj.get("objectName"), obj.get("objectDomain") + obj.get("objectName"), + obj.get("objectDomain"), ): return False return True diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index 030edfde4ca1da..a541b80189dd53 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -39,7 +39,10 @@ class SnowsightUrlBuilder: def __init__(self, account_locator: str, region: str, 
privatelink: bool = False): cloud, cloud_region_id = self.get_cloud_region_from_snowflake_region_id(region) self.snowsight_base_url = self.create_snowsight_base_url( - account_locator, cloud_region_id, cloud, privatelink + account_locator, + cloud_region_id, + cloud, + privatelink, ) @staticmethod @@ -94,7 +97,9 @@ def get_external_url_for_table( return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/{domain}/{table_name}/" def get_external_url_for_schema( - self, schema_name: str, db_name: str + self, + schema_name: str, + db_name: str, ) -> Optional[str]: return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/" @@ -104,7 +109,9 @@ def get_external_url_for_database(self, db_name: str) -> Optional[str]: class SnowflakeFilter: def __init__( - self, filter_config: SnowflakeFilterConfig, structured_reporter: SourceReport + self, + filter_config: SnowflakeFilterConfig, + structured_reporter: SourceReport, ) -> None: self.filter_config = filter_config self.structured_reporter = structured_reporter @@ -141,7 +148,7 @@ def is_dataset_pattern_allowed( if ( len(dataset_params) >= 1 and not self.filter_config.database_pattern.allowed( - dataset_params[0].strip('"') + dataset_params[0].strip('"'), ) ) or ( len(dataset_params) >= 2 @@ -155,9 +162,9 @@ def is_dataset_pattern_allowed( return False if dataset_type.lower() in { - SnowflakeObjectDomain.TABLE + SnowflakeObjectDomain.TABLE, } and not self.filter_config.table_pattern.allowed( - _cleanup_qualified_name(dataset_name, self.structured_reporter) + _cleanup_qualified_name(dataset_name, self.structured_reporter), ): return False @@ -165,7 +172,7 @@ def is_dataset_pattern_allowed( SnowflakeObjectDomain.VIEW, SnowflakeObjectDomain.MATERIALIZED_VIEW, } and not self.filter_config.view_pattern.allowed( - _cleanup_qualified_name(dataset_name, self.structured_reporter) + _cleanup_qualified_name(dataset_name, self.structured_reporter), ): return False @@ -173,7 +180,10 @@ def is_dataset_pattern_allowed( def _combine_identifier_parts( - *, table_name: str, schema_name: str, db_name: str + *, + table_name: str, + schema_name: str, + db_name: str, ) -> str: return f"{db_name}.{schema_name}.{table_name}" @@ -229,7 +239,8 @@ def _split_qualified_name(qualified_name: str) -> List[str]: # and also unavailability of utility function to identify whether current table/schema/database # name should be quoted in above method get_dataset_identifier def _cleanup_qualified_name( - qualified_name: str, structured_reporter: SourceReport + qualified_name: str, + structured_reporter: SourceReport, ) -> str: name_parts = _split_qualified_name(qualified_name) if len(name_parts) != 3: @@ -266,12 +277,17 @@ def snowflake_identifier(self, identifier: str) -> str: return identifier def get_dataset_identifier( - self, table_name: str, schema_name: str, db_name: str + self, + table_name: str, + schema_name: str, + db_name: str, ) -> str: return self.snowflake_identifier( _combine_identifier_parts( - table_name=table_name, schema_name=schema_name, db_name=db_name - ) + table_name=table_name, + schema_name=schema_name, + db_name=db_name, + ), ) def gen_dataset_urn(self, dataset_identifier: str) -> str: @@ -284,7 +300,7 @@ def gen_dataset_urn(self, dataset_identifier: str) -> str: def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: return self.snowflake_identifier( - _cleanup_qualified_name(qualified_name, self.structured_reporter) + _cleanup_qualified_name(qualified_name, self.structured_reporter), ) 
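The identifier helpers in the hunks above all reduce to joining database, schema, and table names with dots and stripping the quotes Snowflake wraps around each part of a fully qualified name. A rough sketch under a simplifying assumption: the split here is a plain dot-split, whereas the real helpers also handle quoted parts that themselves contain dots and report malformed names instead of passing them through.

def combine_identifier_parts(*, table_name: str, schema_name: str, db_name: str) -> str:
    # Same shape as the helper above: the dataset identifier is the three
    # name parts joined with dots.
    return f"{db_name}.{schema_name}.{table_name}"

def cleanup_qualified_name(qualified_name: str) -> str:
    # Illustrative only: strip surrounding double quotes from each part.
    parts = qualified_name.split(".")
    if len(parts) != 3:
        return qualified_name
    return combine_identifier_parts(
        db_name=parts[0].strip('"'),
        schema_name=parts[1].strip('"'),
        table_name=parts[2].strip('"'),
    )

# cleanup_qualified_name('"ANALYTICS"."PUBLIC"."ORDERS"') -> 'ANALYTICS.PUBLIC.ORDERS'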
@staticmethod @@ -312,13 +328,13 @@ def get_user_identifier( return self.snowflake_identifier( user_email if self.identifier_config.email_as_user_identifier is True - else user_email.split("@")[0] + else user_email.split("@")[0], ) return self.snowflake_identifier( f"{user_name}@{self.identifier_config.email_domain}" if self.identifier_config.email_as_user_identifier is True and self.identifier_config.email_domain is not None - else user_name + else user_name, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index b4ef2180d71d45..51c8e7bc116663 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -139,23 +139,26 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.report: SnowflakeV2Report = SnowflakeV2Report() self.filters = SnowflakeFilter( - filter_config=self.config, structured_reporter=self.report + filter_config=self.config, + structured_reporter=self.report, ) self.identifiers = SnowflakeIdentifierBuilder( - identifier_config=self.config, structured_reporter=self.report + identifier_config=self.config, + structured_reporter=self.report, ) self.domain_registry: Optional[DomainRegistry] = None if self.config.domain: self.domain_registry = DomainRegistry( - cached_domains=[k for k in self.config.domain], graph=self.ctx.graph + cached_domains=[k for k in self.config.domain], + graph=self.ctx.graph, ) # The exit stack helps ensure that we close all the resources we open. self._exit_stack = contextlib.ExitStack() self.connection: SnowflakeConnection = self._exit_stack.enter_context( - self.config.get_connection() + self.config.get_connection(), ) # For database, schema, tables, views, etc @@ -182,7 +185,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): generate_usage_statistics=False, generate_operations=False, format_queries=self.config.format_sql_queries, - ) + ), ) self.report.sql_aggregator = self.aggregator.report @@ -206,7 +209,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): identifiers=self.identifiers, redundant_run_skip_handler=redundant_lineage_run_skip_handler, sql_aggregator=self.aggregator, - ) + ), ) self.usage_extractor: Optional[SnowflakeUsageExtractor] = None @@ -229,7 +232,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): filter=self.filters, identifiers=self.identifiers, redundant_run_skip_handler=redundant_usage_run_skip_handler, - ) + ), ) self.profiling_state_handler: Optional[ProfilingHandler] = None @@ -245,7 +248,9 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.profiler: Optional[SnowflakeProfiler] = None if config.is_profiling_enabled(): self.profiler = SnowflakeProfiler( - config, self.report, self.profiling_state_handler + config, + self.report, + self.profiling_state_handler, ) self.add_config_to_report() @@ -256,7 +261,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: try: connection_conf = SnowflakeConnectionConfig.parse_obj_allow_extras( - config_dict + config_dict, ) connection: SnowflakeConnection = connection_conf.get_connection() @@ -265,7 +270,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) test_report.capability_report = SnowflakeV2Source.check_capabilities( - connection, connection_conf + 
connection, + connection_conf, ) connection.close() @@ -273,7 +279,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: logger.error(f"Failed to test connection due to {e}", exc_info=e) if test_report.basic_connectivity is None: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=f"{e}" + capable=False, + failure_reason=f"{e}", ) else: test_report.internal_failure = True @@ -283,7 +290,8 @@ def test_connection(config_dict: dict) -> TestConnectionReport: @staticmethod def check_capabilities( - conn: SnowflakeConnection, connection_conf: SnowflakeConnectionConfig + conn: SnowflakeConnection, + connection_conf: SnowflakeConnectionConfig, ) -> Dict[Union[SourceCapability, str], CapabilityReport]: # Currently only overall capabilities are reported. # Resource level variations in capabilities are not considered. @@ -312,7 +320,7 @@ class SnowflakePrivilege: cur = conn.query("select current_secondary_roles()") secondary_roles_str = json.loads( - [row["CURRENT_SECONDARY_ROLES()"] for row in cur][0] + [row["CURRENT_SECONDARY_ROLES()"] for row in cur][0], )["roles"] secondary_roles = ( [] if secondary_roles_str == "" else secondary_roles_str.split(",") @@ -343,7 +351,7 @@ class SnowflakePrivilege: "SCHEMA", ) and privilege.privilege in ("OWNERSHIP", "USAGE"): _report[SourceCapability.CONTAINERS] = CapabilityReport( - capable=True + capable=True, ) _report[SourceCapability.TAGS] = CapabilityReport(capable=True) elif privilege.object_type in ( @@ -352,19 +360,19 @@ class SnowflakePrivilege: "MATERIALIZED_VIEW", ): _report[SourceCapability.SCHEMA_METADATA] = CapabilityReport( - capable=True + capable=True, ) _report[SourceCapability.DESCRIPTIONS] = CapabilityReport( - capable=True + capable=True, ) # Table level profiling is supported without SELECT access # if privilege.privilege in ("SELECT", "OWNERSHIP"): _report[SourceCapability.DATA_PROFILING] = CapabilityReport( - capable=True + capable=True, ) _report[SourceCapability.CLASSIFICATION] = CapabilityReport( - capable=True + capable=True, ) if privilege.object_name.startswith("SNOWFLAKE.ACCOUNT_USAGE."): @@ -372,15 +380,15 @@ class SnowflakePrivilege: # Finer access control is not yet supported for shares # https://community.snowflake.com/s/article/Error-Granting-individual-privileges-on-imported-database-is-not-allowed-Use-GRANT-IMPORTED-PRIVILEGES-instead _report[SourceCapability.LINEAGE_COARSE] = CapabilityReport( - capable=True + capable=True, ) _report[SourceCapability.LINEAGE_FINE] = CapabilityReport( - capable=True + capable=True, ) _report[SourceCapability.USAGE_STATS] = CapabilityReport( - capable=True + capable=True, ) _report[SourceCapability.TAGS] = CapabilityReport(capable=True) @@ -448,13 +456,17 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), functools.partial( - auto_incremental_lineage, self.config.incremental_lineage + auto_incremental_lineage, + self.config.incremental_lineage, ), functools.partial( - auto_incremental_properties, self.config.incremental_properties + auto_incremental_properties, + self.config.incremental_properties, ), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -490,7 +502,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.config.shares: yield from SnowflakeSharesHandler( - self.config, self.report + self.config, + self.report, ).get_shares_workunits(databases) discovered_tables: 
List[str] = [ @@ -572,7 +585,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.config.include_assertion_results: yield from SnowflakeAssertionsHandler( - self.config, self.report, self.connection, self.identifiers + self.config, + self.report, + self.connection, + self.identifiers, ).get_assertion_workunits(discovered_datasets) self.connection.close() @@ -717,7 +733,11 @@ def _snowflake_clear_ocsp_cache(self) -> None: plat = platform.system().lower() if plat == "darwin": file_path = os.path.join( - "~", "Library", "Caches", "Snowflake", "ocsp_response_validation_cache" + "~", + "Library", + "Caches", + "Snowflake", + "ocsp_response_validation_cache", ) elif plat == "windows": file_path = os.path.join( @@ -731,7 +751,10 @@ def _snowflake_clear_ocsp_cache(self) -> None: else: # linux is the default fallback for snowflake file_path = os.path.join( - "~", ".cache", "snowflake", "ocsp_response_validation_cache" + "~", + ".cache", + "snowflake", + "ocsp_response_validation_cache", ) file_path = os.path.expanduser(file_path) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index cfc43454b51fad..f131e613967303 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -177,7 +177,8 @@ def _get_column_type(self, type_: Union[str, Dict[str, Any]]) -> TypeEngine: # # `get_avro_schema_for_hive_column` accepts a DDL description as column type and # returns the parsed data types in form of a dictionary schema = get_avro_schema_for_hive_column( - hive_column_name=type_name, hive_column_type=type_ + hive_column_name=type_name, + hive_column_type=type_, ) # the actual type description needs to be extracted @@ -197,12 +198,12 @@ def _get_column_type(self, type_: Union[str, Dict[str, Any]]) -> TypeEngine: # struct_field["name"], ( self._get_column_type( - struct_field["type"]["native_data_type"] + struct_field["type"]["native_data_type"], ) if struct_field["type"]["type"] not in ["record", "array"] else self._get_column_type(struct_field["type"]) ), - ) + ), ) args = struct_args @@ -240,14 +241,15 @@ class AthenaConfig(SQLCommonConfig): description="Username credential. If not specified, detected with boto3 rules. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html", ) password: Optional[pydantic.SecretStr] = pydantic.Field( - default=None, description="Same detection scheme as username" + default=None, + description="Same detection scheme as username", ) database: Optional[str] = pydantic.Field( default=None, description="The athena database to ingest from. 
If not set it will be autodetected", ) aws_region: str = pydantic.Field( - description="Aws region where your Athena database is located" + description="Aws region where your Athena database is located", ) aws_role_arn: Optional[str] = pydantic.Field( default=None, @@ -263,7 +265,7 @@ class AthenaConfig(SQLCommonConfig): description="[deprecated in favor of `query_result_location`] S3 query location", ) work_group: str = pydantic.Field( - description="The name of your Amazon Athena Workgroups" + description="The name of your Amazon Athena Workgroups", ) catalog_name: str = pydantic.Field( default="awsdatacatalog", @@ -272,7 +274,7 @@ class AthenaConfig(SQLCommonConfig): query_result_location: str = pydantic.Field( description="S3 path to the [query result bucket](https://docs.aws.amazon.com/athena/latest/ug/querying.html#query-results-specify-location) which should be used by AWS Athena to store results of the" - "queries executed by DataHub." + "queries executed by DataHub.", ) extract_partitions: bool = pydantic.Field( @@ -362,14 +364,18 @@ def get_db_schema(self, dataset_identifier: str) -> Tuple[Optional[str], str]: return None, schema def get_table_properties( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: if not self.cursor: self.cursor = cast(BaseCursor, inspector.engine.raw_connection().cursor()) assert self.cursor metadata: AthenaTableMetadata = self.cursor.get_table_metadata( - table_name=table, schema_name=schema + table_name=table, + schema_name=schema, ) description = metadata.comment custom_properties: Dict[str, str] = {} @@ -381,7 +387,7 @@ def get_table_properties( "comment": partition.comment if partition.comment else "", } for partition in metadata.partition_keys - ] + ], ) for key, value in metadata.parameters.items(): custom_properties[key] = value if value else "" @@ -402,7 +408,7 @@ def get_table_properties( location = make_s3_urn(location, self.config.env) else: logging.debug( - f"Only s3 url supported for location. Skipping {location}" + f"Only s3 url supported for location. 
Skipping {location}", ) location = None @@ -423,7 +429,8 @@ def gen_schema_containers( extra_properties: Optional[Dict[str, Any]] = None, ) -> Iterable[MetadataWorkUnit]: database_container_key = self.get_database_container_key( - db_name=database, schema=schema + db_name=database, + schema=schema, ) yield from gen_database_container( @@ -480,7 +487,10 @@ def _casted_partition_key(cls, key: str) -> str: @override def get_partitions( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Optional[List[str]]: if not self.config.extract_partitions: return None @@ -489,7 +499,8 @@ def get_partitions( return None metadata: AthenaTableMetadata = self.cursor.get_table_metadata( - table_name=table, schema_name=schema + table_name=table, + schema_name=schema, ) partitions = [] @@ -559,20 +570,24 @@ def get_schema_fields_for_column( return fields def generate_partition_profiler_query( - self, schema: str, table: str, partition_datetime: Optional[datetime.datetime] + self, + schema: str, + table: str, + partition_datetime: Optional[datetime.datetime], ) -> Tuple[Optional[str], Optional[str]]: if not self.config.profiling.partition_profiling_enabled: return None, None partition: Optional[Partitionitem] = self.table_partition_cache.get( - schema, {} + schema, + {}, ).get(table, None) if partition and partition.max_partition: max_partition_filters = [] for key, value in partition.max_partition.items(): max_partition_filters.append( - f"{self._casted_partition_key(key)} = '{value}'" + f"{self._casted_partition_key(key)} = '{value}'", ) max_partition = str(partition.max_partition) return ( diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py b/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py index a8208ca807ed02..8ffbf417315379 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py @@ -124,13 +124,16 @@ def __post_init__(self): class ClickHouseConfig( - TwoTierSQLAlchemyConfig, BaseTimeWindowConfig, DatasetLineageProviderConfigBase + TwoTierSQLAlchemyConfig, + BaseTimeWindowConfig, + DatasetLineageProviderConfigBase, ): # defaults host_port: str = Field(default="localhost:8123", description="ClickHouse host URL.") scheme: str = Field(default="clickhouse", description="", hidden_from_docs=True) password: pydantic.SecretStr = Field( - default=pydantic.SecretStr(""), description="password" + default=pydantic.SecretStr(""), + description="password", ) secure: Optional[bool] = Field(default=None, description="") protocol: Optional[str] = Field(default=None, description="") @@ -142,18 +145,19 @@ class ClickHouseConfig( description="The part of the URI and it's used to provide additional configuration options or parameters for the database connection.", ) include_table_lineage: Optional[bool] = Field( - default=True, description="Whether table lineage should be ingested." 
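Illustrative aside (not part of the diff): the Athena hunk above builds per-key partition filters of the form `_casted_partition_key(key) = '<value>'` from the cached max partition. The sketch below shows how such filters would typically be combined into a profiling query; the " AND " join and the SELECT shape are assumptions (they fall outside the visible hunk), and build_partition_profiler_query is a hypothetical helper, not the source's method.

from typing import Dict, Optional, Tuple


def build_partition_profiler_query(
    schema: str,
    table: str,
    max_partition: Optional[Dict[str, str]],
) -> Tuple[Optional[str], Optional[str]]:
    # No cached max partition -> no partition-scoped profiling query.
    if not max_partition:
        return None, None
    filters = [
        # Mirrors the f"{self._casted_partition_key(key)} = '{value}'" filters
        # built in the hunk above, minus the CAST wrapper (simplification).
        f"{key} = '{value}'"
        for key, value in max_partition.items()
    ]
    # Assumption: filters are ANDed into a plain SELECT; not shown in the diff.
    custom_sql = f'SELECT * FROM "{schema}"."{table}" WHERE ' + " AND ".join(filters)
    return str(max_partition), custom_sql


if __name__ == "__main__":
    print(build_partition_profiler_query("sales", "orders", {"year": "2024", "month": "06"}))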
+ default=True, + description="Whether table lineage should be ingested.", ) include_materialized_views: Optional[bool] = Field(default=True, description="") def get_sql_alchemy_url(self, current_db=None): url = make_url( - super().get_sql_alchemy_url(uri_opts=self.uri_opts, current_db=current_db) + super().get_sql_alchemy_url(uri_opts=self.uri_opts, current_db=current_db), ) if url.drivername == "clickhouse+native" and url.query.get("protocol"): logger.debug(f"driver = {url.drivername}, query = {url.query}") raise Exception( - "You cannot use a schema clickhouse+native and clickhouse+http at the same time" + "You cannot use a schema clickhouse+native and clickhouse+http at the same time", ) # We can setup clickhouse ingestion in sqlalchemy_uri form and config form. @@ -175,10 +179,10 @@ def projects_backward_compatibility(cls, values: Dict) -> Dict: logger.warning( "uri_opts is not set but protocol or secure option is set." " secure and protocol options is deprecated, please use " - "uri_opts instead." + "uri_opts instead.", ) logger.info( - "Initializing uri_opts from deprecated secure or protocol options" + "Initializing uri_opts from deprecated secure or protocol options", ) values["uri_opts"] = {} if secure: @@ -188,7 +192,7 @@ def projects_backward_compatibility(cls, values: Dict) -> Dict: logger.debug(f"uri_opts: {uri_opts}") elif (secure or protocol) and uri_opts: raise ValueError( - "secure and protocol options is deprecated. Please use uri_opts only." + "secure and protocol options is deprecated. Please use uri_opts only.", ) return values @@ -206,7 +210,7 @@ def projects_backward_compatibility(cls, values: Dict) -> Dict: def _get_all_table_comments_and_properties(self, connection, **kw): properties_clause = ( "formatRow('JSONEachRow', {properties_columns})".format( - properties_columns=PROPERTIES_COLUMNS + properties_columns=PROPERTIES_COLUMNS, ) if PROPERTIES_COLUMNS else "null" @@ -218,7 +222,7 @@ def _get_all_table_comments_and_properties(self, connection, **kw): , comment , {properties_clause} AS properties FROM system.tables - WHERE name NOT LIKE '.inner%'""".format(properties_clause=properties_clause) + WHERE name NOT LIKE '.inner%'""".format(properties_clause=properties_clause), ) all_table_comments: Dict[Tuple[str, str], Dict[str, Any]] = {} @@ -252,9 +256,9 @@ def _get_all_relation_info(self, connection, **kw): , if(engine LIKE '%View', 'v', 'r') AS relkind , name AS relname FROM system.tables - WHERE name NOT LIKE '.inner%'""" - ) - ) + WHERE name NOT LIKE '.inner%'""", + ), + ), ) relations = {} for rel in result: @@ -299,9 +303,9 @@ def _get_schema_column_info(self, connection, schema=None, **kw): , comment FROM system.columns WHERE {schema_clause} - ORDER BY database, table, position""".format(schema_clause=schema_clause) - ) - ) + ORDER BY database, table, position""".format(schema_clause=schema_clause), + ), + ), ) for col in result: key = (col.database, col.table_name) @@ -312,7 +316,9 @@ def _get_schema_column_info(self, connection, schema=None, **kw): def _get_clickhouse_columns(self, connection, table_name, schema=None, **kw): info_cache = kw.get("info_cache") all_schema_columns = self._get_schema_column_info( - connection, schema, info_cache=info_cache + connection, + schema, + info_cache=info_cache, ) key = (schema, table_name) return all_schema_columns[key] @@ -419,7 +425,7 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit assert dataset_snapshot lineage_mcp, lineage_properties_aspect = self.get_lineage_mcp( - 
wu.metadata.proposedSnapshot.urn + wu.metadata.proposedSnapshot.urn, ) if lineage_mcp is not None: @@ -461,7 +467,7 @@ def _get_all_tables(self) -> Set[str]: """\ SELECT database, name AS table_name FROM system.tables - WHERE name NOT LIKE '.inner%'""" + WHERE name NOT LIKE '.inner%'""", ) all_tables_set = set() @@ -475,7 +481,9 @@ def _get_all_tables(self) -> Set[str]: return all_tables_set def _populate_lineage_map( - self, query: str, lineage_type: LineageCollectorType + self, + query: str, + lineage_type: LineageCollectorType, ) -> None: """ This method generate table level lineage based with the given query. @@ -501,7 +509,7 @@ def _populate_lineage_map( for db_row in engine.execute(text(query)): dataset_name = f"{db_row['target_schema']}.{db_row['target_table']}" if not self.config.database_pattern.allowed( - db_row["target_schema"] + db_row["target_schema"], ) or not self.config.table_pattern.allowed(dataset_name): self.report.report_dropped(dataset_name) continue @@ -513,7 +521,8 @@ def _populate_lineage_map( ) target = LineageItem( dataset=LineageDataset( - platform=LineageDatasetPlatform.CLICKHOUSE, path=target_path + platform=LineageDatasetPlatform.CLICKHOUSE, + path=target_path, ), upstreams=set(), collector_type=lineage_type, @@ -527,7 +536,7 @@ def _populate_lineage_map( LineageDataset( platform=platform, path=path, - ) + ), ] for source in sources: @@ -554,13 +563,13 @@ def _populate_lineage_map( self._lineage_map[target.dataset.path] = target logger.info( - f"Lineage[{target}]:{self._lineage_map[target.dataset.path]}" + f"Lineage[{target}]:{self._lineage_map[target.dataset.path]}", ) except Exception as e: logger.warning( f"Extracting {lineage_type.name} lineage from ClickHouse failed." - f"Continuing...\nError was {e}." + f"Continuing...\nError was {e}.", ) def _populate_lineage(self) -> None: @@ -581,7 +590,7 @@ def _populate_lineage(self) -> None: FROM system.tables WHERE engine IN ('Dictionary') AND create_table_query LIKE '%SOURCE(CLICKHOUSE(%' - ORDER BY target_schema, target_table, source_schema, source_table""" + ORDER BY target_schema, target_table, source_schema, source_table""", ) view_lineage_query = textwrap.dedent( @@ -598,7 +607,7 @@ def _populate_lineage(self) -> None: ARRAY JOIN arrayIntersect(splitByRegexp('[\\s()'']+', create_table_query), tables) AS source WHERE engine IN ('View') AND NOT (source_schema = target_schema AND source_table = target_table) - ORDER BY target_schema, target_table, source_schema, source_table""" + ORDER BY target_schema, target_table, source_schema, source_table""", ) # get materialized view downstream and upstream @@ -629,7 +638,7 @@ def _populate_lineage(self) -> None: FROM system.tables WHERE engine IN ('MaterializedView') AND extract_to <> '') - ORDER BY target_schema, target_table, source_schema, source_table""" + ORDER BY target_schema, target_table, source_schema, source_table""", ) if not self._lineage_map: @@ -638,13 +647,15 @@ def _populate_lineage(self) -> None: if self.config.include_tables: # Populate table level lineage for dictionaries and distributed tables self._populate_lineage_map( - query=table_lineage_query, lineage_type=LineageCollectorType.TABLE + query=table_lineage_query, + lineage_type=LineageCollectorType.TABLE, ) if self.config.include_views: # Populate table level lineage for views self._populate_lineage_map( - query=view_lineage_query, lineage_type=LineageCollectorType.VIEW + query=view_lineage_query, + lineage_type=LineageCollectorType.VIEW, ) if self.config.include_materialized_views: @@ -655,9 
+666,11 @@ def _populate_lineage(self) -> None: ) def get_lineage_mcp( - self, dataset_urn: str + self, + dataset_urn: str, ) -> Tuple[ - Optional[MetadataChangeProposalWrapper], Optional[DatasetPropertiesClass] + Optional[MetadataChangeProposalWrapper], + Optional[DatasetPropertiesClass], ]: dataset_key = mce_builder.dataset_urn_to_key(dataset_urn) if dataset_key is None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py b/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py index 76b72d8e37f74b..078a9ec7d6dfd4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py @@ -16,7 +16,7 @@ class CockroachDBConfig(PostgresConfig): scheme = Field(default="cockroachdb+psycopg2", description="database scheme") schema_pattern = Field( - default=AllowDenyPattern(deny=["information_schema", "crdb_internal"]) + default=AllowDenyPattern(deny=["information_schema", "crdb_internal"]), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py index 6d67ab29b3a3d8..135bc3e729abde 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py @@ -204,7 +204,7 @@ def __init__( ): if hive_storage_lineage_direction.lower() not in ["upstream", "downstream"]: raise ValueError( - "hive_storage_lineage_direction must be either upstream or downstream" + "hive_storage_lineage_direction must be either upstream or downstream", ) self.emit_storage_lineage = emit_storage_lineage @@ -324,14 +324,14 @@ def _get_fine_grained_lineages( make_schema_field_urn( parent_urn=storage_urn, field_path=matching_field.fieldPath, - ) + ), ], downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ make_schema_field_urn( parent_urn=dataset_urn, field_path=dataset_path, - ) + ), ], ) else: @@ -341,14 +341,14 @@ def _get_fine_grained_lineages( make_schema_field_urn( parent_urn=dataset_urn, field_path=dataset_path, - ) + ), ], downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ make_schema_field_urn( parent_urn=storage_urn, field_path=matching_field.fieldPath, - ) + ), ], ) @@ -366,7 +366,7 @@ def _create_lineage_mcp( upstream_lineage = UpstreamLineageClass( upstreams=[ - UpstreamClass(dataset=source_urn, type=DatasetLineageTypeClass.COPY) + UpstreamClass(dataset=source_urn, type=DatasetLineageTypeClass.COPY), ], fineGrainedLineages=lineages_list, ) @@ -374,7 +374,8 @@ def _create_lineage_mcp( yield MetadataWorkUnit( id=f"{source_urn}-{target_urn}-lineage", mcp=MetadataChangeProposalWrapper( - entityUrn=target_urn, aspect=upstream_lineage + entityUrn=target_urn, + aspect=upstream_lineage, ), ) @@ -430,7 +431,8 @@ def get_storage_dataset_mcp( yield MetadataWorkUnit( id=f"storage-{storage_urn}-platform", mcp=MetadataChangeProposalWrapper( - entityUrn=storage_urn, aspect=platform_instance_aspect + entityUrn=storage_urn, + aspect=platform_instance_aspect, ), ) @@ -447,13 +449,14 @@ def get_storage_dataset_mcp( yield MetadataWorkUnit( id=f"storage-{storage_urn}-schema", mcp=MetadataChangeProposalWrapper( - entityUrn=storage_urn, aspect=storage_schema + entityUrn=storage_urn, + aspect=storage_schema, ), ) except Exception as e: logger.error( - f"Failed to create storage dataset MCPs for {storage_location}: {e}" + f"Failed to create storage dataset MCPs for {storage_location}: {e}", ) return @@ -516,7 +519,10 @@ def 
get_lineage_mcp( None if not (dataset_schema and storage_schema) else self._get_fine_grained_lineages( - dataset_urn, storage_urn, dataset_schema, storage_schema + dataset_urn, + storage_urn, + dataset_schema, + storage_schema, ) ) @@ -598,8 +604,9 @@ def dbapi_get_columns_patched(self, connection, table_name, schema=None, **kw): except KeyError: util.warn( "Did not recognize type '{}' of column '{}'".format( - col_type, col_name - ) + col_type, + col_name, + ), ) coltype = types.NullType # type: ignore result.append( @@ -610,7 +617,7 @@ def dbapi_get_columns_patched(self, connection, table_name, schema=None, **kw): "default": None, "full_type": orig_col_type, # pass it through "comment": _comment, - } + }, ) return result @@ -678,7 +685,7 @@ def _validate_direction(cls, v: str) -> str: """Validate the lineage direction.""" if v.lower() not in ["upstream", "downstream"]: raise ValueError( - "storage_lineage_direction must be either upstream or downstream" + "storage_lineage_direction must be either upstream or downstream", ) return v.lower() @@ -744,8 +751,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if dataset_props and dataset_props.customProperties: table = { "StorageDescriptor": { - "Location": dataset_props.customProperties.get("Location") - } + "Location": dataset_props.customProperties.get("Location"), + }, } if table.get("StorageDescriptor", {}).get("Location"): @@ -780,17 +787,20 @@ def get_schema_fields_for_column( ) if self._COMPLEX_TYPE.match(fields[0].nativeDataType) and isinstance( - fields[0].type.type, NullTypeClass + fields[0].type.type, + NullTypeClass, ): assert len(fields) == 1 field = fields[0] # Get avro schema for subfields along with parent complex field avro_schema = get_avro_schema_for_hive_column( - column["name"], field.nativeDataType + column["name"], + field.nativeDataType, ) new_fields = schema_util.avro_schema_to_mce_fields( - json.dumps(avro_schema), default_nullable=True + json.dumps(avro_schema), + default_nullable=True, ) # First field is the parent complex field @@ -832,7 +842,9 @@ def _process_view( if view_definition: view_properties_aspect = ViewPropertiesClass( - materialized=False, viewLanguage="SQL", viewLogic=view_definition + materialized=False, + viewLanguage="SQL", + viewLogic=view_definition, ) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py index 60ecbaf38838a6..998bdc4f46dcc8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py @@ -124,7 +124,9 @@ class HiveMetastore(BasicSQLAlchemyConfig): ) include_view_lineage: bool = Field( - default=False, description="", hidden_from_docs=True + default=False, + description="", + hidden_from_docs=True, ) include_catalog_name_in_ids: bool = Field( @@ -143,7 +145,9 @@ class HiveMetastore(BasicSQLAlchemyConfig): ) def get_sql_alchemy_url( - self, uri_opts: Optional[Dict[str, Any]] = None, database: Optional[str] = None + self, + uri_opts: Optional[Dict[str, Any]] = None, + database: Optional[str] = None, ) -> str: if not ((self.host_port and self.scheme) or self.sqlalchemy_uri): raise ValueError("host_port and schema or connect_uri required.") @@ -165,7 +169,9 @@ def get_sql_alchemy_url( @capability(SourceCapability.DATA_PROFILING, "Not Supported", False) @capability(SourceCapability.CLASSIFICATION, "Not 
Supported", False) @capability( - SourceCapability.LINEAGE_COARSE, "View lineage is not supported", supported=False + SourceCapability.LINEAGE_COARSE, + "View lineage is not supported", + supported=False, ) class HiveMetastoreSource(SQLAlchemySource): """ @@ -383,11 +389,11 @@ def gen_schema_containers( statement: str = ( HiveMetastoreSource._SCHEMAS_POSTGRES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) if "postgresql" in self.config.scheme else HiveMetastoreSource._SCHEMAS_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) ) @@ -430,15 +436,18 @@ def get_default_ingestion_job_id(self) -> JobId: return JobId(self.config.ingestion_job_id) def _get_table_properties( - self, db_name: str, scheme: str, where_clause_suffix: str + self, + db_name: str, + scheme: str, + where_clause_suffix: str, ) -> Dict[str, Dict[str, str]]: statement: str = ( HiveMetastoreSource._HIVE_PROPERTIES_POSTGRES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) if "postgresql" in scheme else HiveMetastoreSource._HIVE_PROPERTIES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) ) iter_res = self._alchemy_client.execute_query(statement) @@ -472,11 +481,11 @@ def loop_tables( where_clause_suffix = f"{sql_config.tables_where_clause_suffix} {self._get_db_filter_where_clause()}" statement: str = ( HiveMetastoreSource._TABLES_POSTGRES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) if "postgresql" in sql_config.scheme else HiveMetastoreSource._TABLES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) ) @@ -498,7 +507,9 @@ def loop_tables( ) dataset_name = self.get_identifier( - schema=schema_name, entity=key.table, inspector=inspector + schema=schema_name, + entity=key.table, + inspector=inspector, ) self.report.report_entity_scanned(dataset_name, ent_type="table") @@ -550,21 +561,23 @@ def loop_tables( properties["create_date"] = str(columns[-1]["create_date"] or "") par_columns: str = ", ".join( - [c["col_name"] for c in columns if c["is_partition_col"]] + [c["col_name"] for c in columns if c["is_partition_col"]], ) if par_columns != "": properties["partitioned_columns"] = par_columns table_description = properties.get("comment") yield from self.add_hive_dataset_to_container( - dataset_urn=dataset_urn, inspector=inspector, schema=key.schema + dataset_urn=dataset_urn, + inspector=inspector, + schema=key.schema, ) if self.config.enable_properties_merge: from datahub.specific.dataset import DatasetPatchBuilder patch_builder: DatasetPatchBuilder = DatasetPatchBuilder( - urn=dataset_snapshot.urn + urn=dataset_snapshot.urn, ) patch_builder.set_display_name(key.table) @@ -616,7 +629,10 @@ def loop_tables( ) def add_hive_dataset_to_container( - self, dataset_urn: str, inspector: Inspector, schema: str + self, + dataset_urn: str, + inspector: Inspector, + schema: str, ) -> Iterable[MetadataWorkUnit]: db_name = self.get_db_name(inspector) schema_container_key = gen_schema_key( @@ -638,11 +654,11 @@ def get_hive_view_columns(self, inspector: Inspector) -> Iterable[ViewDataset]: statement: str = ( HiveMetastoreSource._HIVE_VIEWS_POSTGRES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) if "postgresql" in self.config.scheme 
else HiveMetastoreSource._HIVE_VIEWS_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) ) @@ -657,7 +673,9 @@ def get_hive_view_columns(self, inspector: Inspector) -> Iterable[ViewDataset]: ) dataset_name = self.get_identifier( - schema=schema_name, entity=key.table, inspector=inspector + schema=schema_name, + entity=key.table, + inspector=inspector, ) if not self.config.database_pattern.allowed(key.schema): @@ -683,11 +701,11 @@ def get_presto_view_columns(self, inspector: Inspector) -> Iterable[ViewDataset] statement: str = ( HiveMetastoreSource._VIEWS_POSTGRES_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) if "postgresql" in self.config.scheme else HiveMetastoreSource._VIEWS_SQL_STATEMENT.format( - where_clause_suffix=where_clause_suffix + where_clause_suffix=where_clause_suffix, ) ) @@ -706,7 +724,7 @@ def get_presto_view_columns(self, inspector: Inspector) -> Iterable[ViewDataset] ) columns, view_definition = self._get_presto_view_column_metadata( - row["view_original_text"] + row["view_original_text"], ) if len(columns) == 0: @@ -795,7 +813,9 @@ def loop_views( dataset_snapshot.aspects.append(view_properties) yield from self.add_hive_dataset_to_container( - dataset_urn=dataset_urn, inspector=inspector, schema=dataset.schema_name + dataset_urn=dataset_urn, + inspector=inspector, + schema=dataset.schema_name, ) # construct mce @@ -853,7 +873,8 @@ def _get_table_key(self, row: Dict[str, Any]) -> TableKey: return TableKey(schema=row["schema_name"], table=row["table_name"]) def _get_presto_view_column_metadata( - self, view_original_text: str + self, + view_original_text: str, ) -> Tuple[List[Dict], str]: """ Get Column Metadata from VIEW_ORIGINAL_TEXT from TBLS table for Presto Views. @@ -863,7 +884,8 @@ def _get_presto_view_column_metadata( """ # remove encoded Presto View data prefix and suffix encoded_view_info = view_original_text.split( - HiveMetastoreSource._PRESTO_VIEW_PREFIX, 1 + HiveMetastoreSource._PRESTO_VIEW_PREFIX, + 1, )[-1].rsplit(HiveMetastoreSource._PRESTO_VIEW_SUFFIX, 1)[0] # view_original_text is b64 encoded: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/job_models.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/job_models.py index d3941e7add0fd0..af2762c380d7a8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/job_models.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/job_models.py @@ -263,7 +263,8 @@ def as_maybe_platform_instance_aspect(self) -> Optional[DataPlatformInstanceClas return DataPlatformInstanceClass( platform=make_data_platform_urn(self.entity.orchestrator), instance=make_dataplatform_instance_urn( - self.entity.orchestrator, self.entity.platform_instance + self.entity.orchestrator, + self.entity.platform_instance, ), ) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py index a2338f14196d77..ef3c643e46458f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py @@ -75,7 +75,8 @@ class SQLServerConfig(BasicSQLAlchemyConfig): description="Include ingest of stored procedures. Requires access to the 'sys' schema.", ) include_stored_procedures_code: bool = Field( - default=True, description="Include information about object code." 
+ default=True, + description="Include information about object code.", ) procedure_pattern: AllowDenyPattern = Field( default=AllowDenyPattern.allow_all(), @@ -87,7 +88,8 @@ class SQLServerConfig(BasicSQLAlchemyConfig): description="Include ingest of MSSQL Jobs. Requires access to the 'msdb' and 'sys' schema.", ) include_descriptions: bool = Field( - default=True, description="Include table descriptions information." + default=True, + description="Include table descriptions information.", ) use_odbc: bool = Field( default=False, @@ -201,7 +203,7 @@ def handle_sql_variant_as_string(value): conn.connection.add_output_converter(-150, handle_sql_variant_as_string) except AttributeError as e: logger.debug( - f"Failed to mount output converter for MSSQL data type -150 due to {e}" + f"Failed to mount output converter for MSSQL data type -150 due to {e}", ) def _populate_table_descriptions(self, conn: Connection, db_name: str) -> None: @@ -219,7 +221,7 @@ def _populate_table_descriptions(self, conn: Connection, db_name: str) -> None: AND EP.MINOR_ID = 0 AND EP.NAME = 'MS_Description' AND EP.CLASS = 1 - """ + """, ) for row in table_metadata: self.table_descriptions[ @@ -242,7 +244,7 @@ def _populate_column_descriptions(self, conn: Connection, db_name: str) -> None: AND EP.MINOR_ID = C.COLUMN_ID AND EP.NAME = 'MS_Description' AND EP.CLASS = 1 - """ + """, ) for row in column_metadata: self.column_descriptions[ @@ -256,24 +258,37 @@ def create(cls, config_dict: Dict, ctx: PipelineContext) -> "SQLServerSource": # override to get table descriptions def get_table_properties( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: description, properties, location_urn = super().get_table_properties( - inspector, schema, table + inspector, + schema, + table, ) # Update description if available. db_name: str = self.get_db_name(inspector) description = self.table_descriptions.get( - f"{db_name}.{schema}.{table}", description + f"{db_name}.{schema}.{table}", + description, ) return description, properties, location_urn # override to get column descriptions def _get_columns( - self, dataset_name: str, inspector: Inspector, schema: str, table: str + self, + dataset_name: str, + inspector: Inspector, + schema: str, + table: str, ) -> List[Dict]: columns: List[Dict] = super()._get_columns( - dataset_name, inspector, schema, table + dataset_name, + inspector, + schema, + table, ) # Update column description if available. 
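Illustrative aside (not part of the diff): the SQL Server hunks above read MS_Description entries from sys.extended_properties into dictionaries keyed by "db.schema.table" and then prefer those over the inspector-supplied description. Below is a dependency-free sketch of that lookup-with-fallback pattern; the class name and sample data are made up, only the keying and fallback semantics come from the surrounding code.

from typing import Dict, Optional


class DescriptionCache:
    """Illustrative stand-in for the table_descriptions dict on the MSSQL source."""

    def __init__(self) -> None:
        # In the real source this is filled by _populate_table_descriptions()
        # from sys.extended_properties rows where NAME = 'MS_Description'.
        self.table_descriptions: Dict[str, str] = {}

    def describe(
        self, db: str, schema: str, table: str, fallback: Optional[str]
    ) -> Optional[str]:
        # Same fallback semantics as the hunk: keep the inspector-provided
        # description unless an MS_Description entry exists for this table.
        return self.table_descriptions.get(f"{db}.{schema}.{table}", fallback)


if __name__ == "__main__":
    cache = DescriptionCache()
    cache.table_descriptions["AdventureWorks.dbo.Orders"] = "Sales orders"  # made-up sample
    print(cache.describe("AdventureWorks", "dbo", "Orders", None))
    print(cache.describe("AdventureWorks", "dbo", "Customers", "from inspector"))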
db_name: str = self.get_db_name(inspector) @@ -344,7 +359,7 @@ def _get_jobs(self, conn: Connection, db_name: str) -> Dict[str, Dict[str, Any]] ON job.job_id = steps.job_id where database_name = '{db_name}' - """ + """, ) jobs: Dict[str, Dict[str, Any]] = {} for row in jobs_data: @@ -389,7 +404,9 @@ def loop_jobs( yield from self.loop_job_steps(job, job_steps) def loop_job_steps( - self, job: MSSQLJob, job_steps: Dict[str, Any] + self, + job: MSSQLJob, + job_steps: Dict[str, Any], ) -> Iterable[MetadataWorkUnit]: for _step_id, step_data in job_steps.items(): step = JobStep( @@ -429,7 +446,7 @@ def loop_stored_procedures( # noqa: C901 self.report.report_dropped(procedure_full_name) continue procedures.append( - StoredProcedure(flow=mssql_default_job, **procedure_data) + StoredProcedure(flow=mssql_default_job, **procedure_data), ) if procedures: @@ -438,7 +455,9 @@ def loop_stored_procedures( # noqa: C901 yield from self._process_stored_procedure(conn, procedure) def _process_stored_procedure( - self, conn: Connection, procedure: StoredProcedure + self, + conn: Connection, + procedure: StoredProcedure, ) -> Iterable[MetadataWorkUnit]: upstream = self._get_procedure_upstream(conn, procedure) downstream = self._get_procedure_downstream(conn, procedure) @@ -459,7 +478,8 @@ def _process_stored_procedure( procedure_inputs = self._get_procedure_inputs(conn, procedure) properties = self._get_procedure_properties(conn, procedure) data_job.add_property( - "input parameters", str([param.name for param in procedure_inputs]) + "input parameters", + str([param.name for param in procedure_inputs]), ) for param in procedure_inputs: data_job.add_property(f"parameter {param.name}", str(param.properties)) @@ -476,7 +496,8 @@ def _process_stored_procedure( @staticmethod def _get_procedure_downstream( - conn: Connection, procedure: StoredProcedure + conn: Connection, + procedure: StoredProcedure, ) -> ProcedureLineageStream: downstream_data = conn.execute( f""" @@ -488,7 +509,7 @@ def _get_procedure_downstream( left join sys.objects o1 on sed.referenced_id = o1.object_id WHERE referenced_id = OBJECT_ID(N'{procedure.escape_full_name}') AND o.type_desc in ('TABLE_TYPE', 'VIEW', 'USER_TABLE') - """ + """, ) downstream_dependencies = [] for row in downstream_data: @@ -500,13 +521,14 @@ def _get_procedure_downstream( type=row["type"], env=procedure.flow.env, server=procedure.flow.platform_instance, - ) + ), ) return ProcedureLineageStream(dependencies=downstream_dependencies) @staticmethod def _get_procedure_upstream( - conn: Connection, procedure: StoredProcedure + conn: Connection, + procedure: StoredProcedure, ) -> ProcedureLineageStream: upstream_data = conn.execute( f""" @@ -521,7 +543,7 @@ def _get_procedure_upstream( WHERE referencing_id = OBJECT_ID(N'{procedure.escape_full_name}') AND referenced_schema_name is not null AND o1.type_desc in ('TABLE_TYPE', 'VIEW', 'SQL_STORED_PROCEDURE', 'USER_TABLE') - """ + """, ) upstream_dependencies = [] for row in upstream_data: @@ -533,13 +555,14 @@ def _get_procedure_upstream( type=row["type"], env=procedure.flow.env, server=procedure.flow.platform_instance, - ) + ), ) return ProcedureLineageStream(dependencies=upstream_dependencies) @staticmethod def _get_procedure_inputs( - conn: Connection, procedure: StoredProcedure + conn: Connection, + procedure: StoredProcedure, ) -> List[ProcedureParameter]: inputs_data = conn.execute( f""" @@ -548,7 +571,7 @@ def _get_procedure_inputs( type_name(user_type_id) AS 'type' FROM sys.parameters WHERE object_id = 
object_id('{procedure.escape_full_name}') - """ + """, ) inputs_list = [] for row in inputs_data: @@ -557,7 +580,8 @@ def _get_procedure_inputs( @staticmethod def _get_procedure_code( - conn: Connection, procedure: StoredProcedure + conn: Connection, + procedure: StoredProcedure, ) -> Tuple[Optional[str], Optional[str]]: query = f"EXEC [{procedure.db}].dbo.sp_helptext '{procedure.escape_full_name}'" try: @@ -588,7 +612,8 @@ def _get_procedure_code( @staticmethod def _get_procedure_properties( - conn: Connection, procedure: StoredProcedure + conn: Connection, + procedure: StoredProcedure, ) -> Dict[str, Any]: properties_data = conn.execute( f""" @@ -597,18 +622,21 @@ def _get_procedure_properties( modify_date as date_modified FROM sys.procedures WHERE object_id = object_id('{procedure.escape_full_name}') - """ + """, ) properties = {} for row in properties_data: properties = dict( - date_created=row["date_created"], date_modified=row["date_modified"] + date_created=row["date_created"], + date_modified=row["date_modified"], ) return properties @staticmethod def _get_stored_procedures( - conn: Connection, db_name: str, schema: str + conn: Connection, + db_name: str, + schema: str, ) -> List[Dict[str, str]]: stored_procedures_data = conn.execute( f""" @@ -620,12 +648,12 @@ def _get_stored_procedures( INNER JOIN [{db_name}].[sys].[schemas] s ON pr.schema_id = s.schema_id where s.name = '{schema}' - """ + """, ) procedures_list = [] for row in stored_procedures_data: procedures_list.append( - dict(db=db_name, schema=row["schema_name"], name=row["procedure_name"]) + dict(db=db_name, schema=row["schema_name"], name=row["procedure_name"]), ) return procedures_list @@ -684,20 +712,26 @@ def get_inspectors(self) -> Iterable[Inspector]: databases = conn.execute( "SELECT name FROM master.sys.databases WHERE name NOT IN \ ('master', 'model', 'msdb', 'tempdb', 'Resource', \ - 'distribution' , 'reportserver', 'reportservertempdb'); " + 'distribution' , 'reportserver', 'reportservertempdb'); ", ) for db in databases: if self.config.database_pattern.allowed(db["name"]): url = self.config.get_sql_alchemy_url(current_db=db["name"]) with create_engine( - url, **self.config.options + url, + **self.config.options, ).connect() as conn: inspector = inspect(conn) self.current_database = db["name"] yield inspector def get_identifier( - self, *, schema: str, entity: str, inspector: Inspector, **kwargs: Any + self, + *, + schema: str, + entity: str, + inspector: Inspector, + **kwargs: Any, ) -> str: regular = f"{schema}.{entity}" qualified_table_name = regular @@ -728,7 +762,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: procedure=procedure, procedure_job_urn=MSSQLDataJob(entity=procedure).urn, is_temp_table=self.is_temp_table, - ) + ), ) def is_temp_table(self, name: str) -> bool: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/stored_procedure_lineage.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/stored_procedure_lineage.py index b979a270a55282..6005fb3f6d2a37 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/stored_procedure_lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/stored_procedure_lineage.py @@ -43,12 +43,12 @@ def parse_procedure_code( default_db=default_db, default_schema=default_schema, query=query, - ) + ), ) if aggregator.report.num_observed_queries_failed and raise_: logger.info(aggregator.report.as_string()) raise ValueError( - f"Failed to parse {aggregator.report.num_observed_queries_failed} 
queries." + f"Failed to parse {aggregator.report.num_observed_queries_failed} queries.", ) mcps = list(aggregator.gen_metadata()) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py index 12ce271e9f5ef7..0a4c391036a91d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py @@ -91,7 +91,7 @@ def add_profile_metadata(self, inspector: Inspector) -> None: return with inspector.engine.connect() as conn: for row in conn.execute( - "SELECT table_schema, table_name, data_length from information_schema.tables" + "SELECT table_schema, table_name, data_length from information_schema.tables", ): self.profile_metadata_info.dataset_name_to_storage_bytes[ f"{row.TABLE_SCHEMA}.{row.TABLE_NAME}" diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py index ac568c58af6c68..58421fab7b3d33 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py @@ -70,10 +70,12 @@ class OracleConfig(BasicSQLAlchemyConfig): description="Will be set automatically to default value.", ) service_name: Optional[str] = Field( - default=None, description="Oracle service name. If using, omit `database`." + default=None, + description="Oracle service name. If using, omit `database`.", ) database: Optional[str] = Field( - default=None, description="If using, omit `service_name`." + default=None, + description="If using, omit `service_name`.", ) add_database_name_to_urn: Optional[bool] = Field( default=False, @@ -90,7 +92,7 @@ class OracleConfig(BasicSQLAlchemyConfig): def check_service_name(cls, v, values): if values.get("database") and v: raise ValueError( - "specify one of 'database' and 'service_name', but not both" + "specify one of 'database' and 'service_name', but not both", ) return v @@ -132,7 +134,7 @@ def get_db_name(self) -> str: try: # Try to retrieve current DB name by executing query db_name = self._inspector_instance.bind.execute( - sql.text("select sys_context('USERENV','DB_NAME') from dual") + sql.text("select sys_context('USERENV','DB_NAME') from dual"), ).scalar() return str(db_name) except sqlalchemy.exc.DatabaseError as e: @@ -141,7 +143,7 @@ def get_db_name(self) -> str: def get_schema_names(self) -> List[str]: cursor = self._inspector_instance.bind.execute( - sql.text("SELECT username FROM dba_users ORDER BY username") + sql.text("SELECT username FROM dba_users ORDER BY username"), ) return [ @@ -155,7 +157,7 @@ def get_table_names(self, schema: Optional[str] = None) -> List[str]: skip order_by, we are not using order_by """ schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -180,7 +182,7 @@ def get_table_names(self, schema: Optional[str] = None) -> List[str]: def get_view_names(self, schema: Optional[str] = None) -> List[str]: schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -198,15 +200,18 @@ def get_view_names(self, schema: Optional[str] = None) -> List[str]: ] def get_columns( - self, table_name: str, schema: Optional[str] = None, dblink: str = "" + self, + table_name: str, + schema: Optional[str] = None, + dblink: str = "", ) -> List[dict]: 
denormalized_table_name = self._inspector_instance.dialect.denormalize_name( - table_name + table_name, ) assert denormalized_table_name schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -304,7 +309,7 @@ def get_columns( coltype = ischema_names[coltype]() except KeyError: logger.warning( - f"Did not recognize type {coltype} of column {colname}" + f"Did not recognize type {coltype} of column {colname}", ) coltype = sqltypes.NULLTYPE @@ -316,7 +321,8 @@ def get_columns( if identity_options is not None: identity = self._inspector_instance.dialect._parse_identity_options( # type: ignore - identity_options, default_on_nul + identity_options, + default_on_nul, ) default = None else: @@ -342,12 +348,12 @@ def get_columns( def get_table_comment(self, table_name: str, schema: Optional[str] = None) -> Dict: denormalized_table_name = self._inspector_instance.dialect.denormalize_name( - table_name + table_name, ) assert denormalized_table_name schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -368,7 +374,10 @@ def get_table_comment(self, table_name: str, schema: Optional[str] = None) -> Di return {"text": c.scalar()} def _get_constraint_data( - self, table_name: str, schema: Optional[str] = None, dblink: str = "" + self, + table_name: str, + schema: Optional[str] = None, + dblink: str = "", ) -> List[sqlalchemy.engine.Row]: params = {"table_name": table_name} @@ -410,15 +419,18 @@ def _get_constraint_data( return constraint_data def get_pk_constraint( - self, table_name: str, schema: Optional[str] = None, dblink: str = "" + self, + table_name: str, + schema: Optional[str] = None, + dblink: str = "", ) -> Dict: denormalized_table_name = self._inspector_instance.dialect.denormalize_name( - table_name + table_name, ) assert denormalized_table_name schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -427,7 +439,9 @@ def get_pk_constraint( pkeys = [] constraint_name = None constraint_data = self._get_constraint_data( - denormalized_table_name, schema, dblink + denormalized_table_name, + schema, + dblink, ) for row in constraint_data: @@ -439,27 +453,30 @@ def get_pk_constraint( remote_column, remote_owner, ) = row[0:2] + tuple( - [self._inspector_instance.dialect.normalize_name(x) for x in row[2:6]] + [self._inspector_instance.dialect.normalize_name(x) for x in row[2:6]], ) if cons_type == "P": if constraint_name is None: constraint_name = self._inspector_instance.dialect.normalize_name( - cons_name + cons_name, ) pkeys.append(local_column) return {"constrained_columns": pkeys, "name": constraint_name} def get_foreign_keys( - self, table_name: str, schema: Optional[str] = None, dblink: str = "" + self, + table_name: str, + schema: Optional[str] = None, + dblink: str = "", ) -> List: denormalized_table_name = self._inspector_instance.dialect.denormalize_name( - table_name + table_name, ) assert denormalized_table_name schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -468,7 +485,9 @@ def get_foreign_keys( requested_schema = schema # to check later on constraint_data = self._get_constraint_data( - denormalized_table_name, schema, dblink + denormalized_table_name, + schema, + dblink, 
) def fkey_rec(): @@ -492,7 +511,7 @@ def fkey_rec(): remote_column, remote_owner, ) = row[0:2] + tuple( - [self._inspector_instance.dialect.normalize_name(x) for x in row[2:6]] + [self._inspector_instance.dialect.normalize_name(x) for x in row[2:6]], ) cons_name = self._inspector_instance.dialect.normalize_name(cons_name) @@ -502,7 +521,7 @@ def fkey_rec(): logger.warning( "Got 'None' querying 'table_name' from " f"dba_cons_columns{dblink} - does the user have " - "proper rights to the table?" + "proper rights to the table?", ) rec = fkeys[cons_name] @@ -517,7 +536,7 @@ def fkey_rec(): if ( requested_schema is not None or self._inspector_instance.dialect.denormalize_name( - remote_owner + remote_owner, ) != schema ): @@ -532,15 +551,17 @@ def fkey_rec(): return list(fkeys.values()) def get_view_definition( - self, view_name: str, schema: Optional[str] = None + self, + view_name: str, + schema: Optional[str] = None, ) -> Union[str, None]: denormalized_view_name = self._inspector_instance.dialect.denormalize_name( - view_name + view_name, ) assert denormalized_view_name schema = self._inspector_instance.dialect.denormalize_name( - schema or self.default_schema_name + schema or self.default_schema_name, ) if schema is None: @@ -610,10 +631,12 @@ def get_db_name(self, inspector: Inspector) -> str: def get_inspectors(self) -> Iterable[Inspector]: for inspector in super().get_inspectors(): event.listen( - inspector.engine, "before_cursor_execute", before_cursor_execute + inspector.engine, + "before_cursor_execute", + before_cursor_execute, ) logger.info( - f'Data dictionary mode is: "{self.config.data_dictionary_mode}".' + f'Data dictionary mode is: "{self.config.data_dictionary_mode}".', ) # Sqlalchemy inspector uses ALL_* tables as per oracle dialect implementation. # OracleInspectorObjectWrapper provides alternate implementation using DBA_* tables. 
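Illustrative aside (not part of the diff): the Oracle source above registers a "before_cursor_execute" listener on the inspector's engine. A minimal, self-contained illustration of that SQLAlchemy event hook, using an in-memory SQLite engine instead of Oracle; the listener body here only logs and is not the source's real callback.

from sqlalchemy import create_engine, event, text


def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    # Standard SQLAlchemy listener signature; this body is illustrative only.
    print(f"about to execute: {statement}")


engine = create_engine("sqlite://")  # in-memory SQLite stands in for Oracle
event.listen(engine, "before_cursor_execute", before_cursor_execute)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))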
@@ -658,7 +681,7 @@ def generate_profile_candidates( WHERE t.OWNER = :owner AND (t.NUM_ROWS < :table_row_limit OR t.NUM_ROWS IS NULL) AND COALESCE(t.NUM_ROWS * t.AVG_ROW_LEN, 0) / (1024 * 1024 * 1024) < :table_size_limit - """ + """, ), dict( owner=inspector.dialect.denormalize_name(schema), @@ -673,7 +696,7 @@ def generate_profile_candidates( schema=schema, entity=inspector.dialect.normalize_name(row[TABLE_NAME_COL_LOC]) or _raise_err( - ValueError(f"Invalid table name: {row[TABLE_NAME_COL_LOC]}") + ValueError(f"Invalid table name: {row[TABLE_NAME_COL_LOC]}"), ), inspector=inspector, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py index 12c98ef11a654d..4cb4fabb260ba8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py @@ -100,7 +100,7 @@ class ViewLineageEntry(BaseModel): class BasePostgresConfig(BasicSQLAlchemyConfig): scheme: str = Field(default="postgresql+psycopg2", description="database scheme") schema_pattern: AllowDenyPattern = Field( - default=AllowDenyPattern(deny=["information_schema"]) + default=AllowDenyPattern(deny=["information_schema"]), ) @@ -158,7 +158,7 @@ def create(cls, config_dict, ctx): def get_inspectors(self) -> Iterable[Inspector]: # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database url = self.config.get_sql_alchemy_url( - database=self.config.database or self.config.initial_database + database=self.config.database or self.config.initial_database, ) logger.debug(f"sql_alchemy_url={url}") engine = create_engine(url, **self.config.options) @@ -170,7 +170,7 @@ def get_inspectors(self) -> Iterable[Inspector]: # pg_database catalog - https://www.postgresql.org/docs/current/catalog-pg-database.html # exclude template databases - https://www.postgresql.org/docs/current/manage-ag-templatedbs.html databases = conn.execute( - "SELECT datname from pg_database where datname not in ('template0', 'template1')" + "SELECT datname from pg_database where datname not in ('template0', 'template1')", ) for db in databases: if not self.config.database_pattern.allowed(db["datname"]): @@ -189,7 +189,8 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit yield from self._get_view_lineage_workunits(inspector) def _get_view_lineage_elements( - self, inspector: Inspector + self, + inspector: Inspector, ) -> Dict[Tuple[str, str], List[str]]: data: List[ViewLineageEntry] = [] with inspector.engine.connect() as conn: @@ -205,13 +206,13 @@ def _get_view_lineage_elements( for lineage in data: if not self.config.view_pattern.allowed(lineage.dependent_view): self.report.report_dropped( - f"{lineage.dependent_schema}.{lineage.dependent_view}" + f"{lineage.dependent_schema}.{lineage.dependent_view}", ) continue if not self.config.schema_pattern.allowed(lineage.dependent_schema): self.report.report_dropped( - f"{lineage.dependent_schema}.{lineage.dependent_view}" + f"{lineage.dependent_schema}.{lineage.dependent_view}", ) continue @@ -227,13 +228,14 @@ def _get_view_lineage_elements( ), platform_instance=self.config.platform_instance, env=self.config.env, - ) + ), ) return lineage_elements def _get_view_lineage_workunits( - self, inspector: Inspector + self, + inspector: Inspector, ) -> Iterable[MetadataWorkUnit]: lineage_elements = self._get_view_lineage_elements(inspector) @@ -247,7 +249,9 @@ def _get_view_lineage_workunits( # Construct a lineage 
object. view_identifier = self.get_identifier( - schema=dependent_schema, entity=dependent_view, inspector=inspector + schema=dependent_schema, + entity=dependent_view, + inspector=inspector, ) if view_identifier not in self.views_failed_parsing: return @@ -269,7 +273,12 @@ def _get_view_lineage_workunits( yield item.as_workunit() def get_identifier( - self, *, schema: str, entity: str, inspector: Inspector, **kwargs: Any + self, + *, + schema: str, + entity: str, + inspector: Inspector, + **kwargs: Any, ) -> str: regular = f"{schema}.{entity}" if self.config.database: @@ -281,7 +290,7 @@ def add_profile_metadata(self, inspector: Inspector) -> None: try: with inspector.engine.connect() as conn: for row in conn.execute( - """SELECT table_catalog, table_schema, table_name, pg_table_size('"' || table_catalog || '"."' || table_schema || '"."' || table_name || '"') AS table_size FROM information_schema.TABLES""" + """SELECT table_catalog, table_schema, table_name, pg_table_size('"' || table_catalog || '"."' || table_schema || '"."' || table_name || '"') AS table_size FROM information_schema.TABLES""", ): self.profile_metadata_info.dataset_name_to_storage_bytes[ self.get_identifier( diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py b/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py index 9333c6edd1fa5d..a8da4c8b32372b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py @@ -35,7 +35,7 @@ def get_view_names(self, connection, schema: str = None, **kw): # type: ignore SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = :schema and "table_type" = 'VIEW' - """ + """, ).strip() res = connection.execute(sql.text(query), schema=schema) return [row.table_name for row in res] @@ -54,14 +54,17 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): query = dedent( f""" SHOW CREATE VIEW "{schema}"."{view_name}" - """ + """, ).strip() res = connection.execute(sql.text(query)) return next(res)[0] def _get_full_table( # type: ignore - self, table_name: str, schema: Optional[str] = None, quote: bool = True + self, + table_name: str, + schema: Optional[str] = None, + quote: bool = True, ) -> str: table_part = ( self.identifier_preparer.quote_identifier(table_name) if quote else table_name diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py index a0bd9ce0760bd1..25c5b1cc0bd33a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py @@ -197,7 +197,9 @@ def make_sqlalchemy_type(name: str) -> Type[TypeEngine]: def get_column_type( - sql_report: SQLSourceReport, dataset_name: str, column_type: Any + sql_report: SQLSourceReport, + dataset_name: str, + column_type: Any, ) -> SchemaFieldDataType: """ Maps SQLAlchemy types (https://docs.sqlalchemy.org/en/13/core/type_basics.html) to corresponding schema types @@ -334,7 +336,8 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str) self.domain_registry: Optional[DomainRegistry] = None if self.config.domain: self.domain_registry = DomainRegistry( - cached_domains=[k for k in self.config.domain], graph=self.ctx.graph + cached_domains=[k for k in self.config.domain], + graph=self.ctx.graph, ) self.views_failed_parsing: Set[str] = set() @@ -364,7 +367,8 @@ def test_connection(cls, 
config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) except Exception as e: test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @@ -503,10 +507,13 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), functools.partial( - auto_incremental_lineage, self.config.incremental_lineage + auto_incremental_lineage, + self.config.incremental_lineage, ), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -523,7 +530,8 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit # https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow if sql_config.is_profiling_enabled(): sql_config.options.setdefault( - "max_overflow", sql_config.profiling.max_workers + "max_overflow", + sql_config.profiling.max_workers, ) for inspector in self.get_inspectors(): @@ -557,12 +565,14 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit if profiler: profile_requests += list( - self.loop_profiler_requests(inspector, schema, sql_config) + self.loop_profiler_requests(inspector, schema, sql_config), ) if profiler and profile_requests: yield from self.loop_profiler( - profile_requests, profiler, platform=self.platform + profile_requests, + profiler, + platform=self.platform, ) # Generate workunit for aggregated SQL parsing results @@ -570,7 +580,12 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit yield mcp.as_workunit() def get_identifier( - self, *, schema: str, entity: str, inspector: Inspector, **kwargs: Any + self, + *, + schema: str, + entity: str, + inspector: Inspector, + **kwargs: Any, ) -> str: # Many SQLAlchemy dialects have three-level hierarchies. This method, which # subclasses can override, enables them to modify the identifiers as needed. @@ -614,7 +629,10 @@ def get_foreign_key_metadata( ] return ForeignKeyConstraint( - fk_dict["name"], foreign_fields, source_fields, foreign_dataset + fk_dict["name"], + foreign_fields, + source_fields, + foreign_dataset, ) def make_data_reader(self, inspector: Inspector) -> Optional[DataReader]: @@ -643,14 +661,16 @@ def loop_tables( # noqa: C901 try: for table in inspector.get_table_names(schema): dataset_name = self.get_identifier( - schema=schema, entity=table, inspector=inspector + schema=schema, + entity=table, + inspector=inspector, ) if dataset_name not in tables_seen: tables_seen.add(dataset_name) else: logger.debug( - f"{dataset_name} has already been seen, skipping..." 
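Illustrative aside (not part of the diff): the base get_identifier above notes that many SQLAlchemy dialects have three-level hierarchies and lets subclasses rework identifiers (the Postgres and MSSQL overrides in earlier hunks prepend the database name). A toy sketch of that override pattern; the class names below are illustrative, not the real source classes.

from typing import Optional


class TwoLevelSource:
    """Base behaviour: schema.table identifiers."""

    def get_identifier(self, *, schema: str, entity: str) -> str:
        return f"{schema}.{entity}"


class ThreeLevelSource(TwoLevelSource):
    """Override for dialects with database.schema.table hierarchies."""

    def __init__(self, database: Optional[str]) -> None:
        self.database = database

    def get_identifier(self, *, schema: str, entity: str) -> str:
        regular = super().get_identifier(schema=schema, entity=entity)
        return f"{self.database}.{regular}" if self.database else regular


if __name__ == "__main__":
    print(ThreeLevelSource("analytics").get_identifier(schema="public", entity="orders"))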
+ f"{dataset_name} has already been seen, skipping...", ) continue @@ -685,12 +705,18 @@ def add_information_for_schema(self, inspector: Inspector, schema: str) -> None: pass def get_extra_tags( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Optional[Dict[str, List[str]]]: return None def get_partitions( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Optional[List[str]]: return None @@ -716,7 +742,9 @@ def _process_table( ) description, properties, location_urn = self.get_table_properties( - inspector, schema, table + inspector, + schema, + table, ) dataset_properties = DatasetPropertiesClass( @@ -756,7 +784,9 @@ def _process_table( db_name = self.get_db_name(inspector) yield from self.add_table_to_schema_container( - dataset_urn=dataset_urn, db_name=db_name, schema=schema + dataset_urn=dataset_urn, + db_name=db_name, + schema=schema, ) mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) yield SqlWorkUnit(id=dataset_name, mce=mce) @@ -798,7 +828,7 @@ def _classify( try: if ( self.classification_handler.is_classification_enabled_for_table( - dataset_name + dataset_name, ) and data_reader and schema_metadata.fields @@ -811,7 +841,7 @@ def _classify( [schema, table], int( self.config.classification.sample_size - * SAMPLE_SIZE_MULTIPLIER + * SAMPLE_SIZE_MULTIPLIER, ), ), ) @@ -826,17 +856,25 @@ def _classify( ) def get_database_properties( - self, inspector: Inspector, database: str + self, + inspector: Inspector, + database: str, ) -> Optional[Dict[str, str]]: return None def get_schema_properties( - self, inspector: Inspector, database: str, schema: str + self, + inspector: Inspector, + database: str, + schema: str, ) -> Optional[Dict[str, str]]: return None def get_table_properties( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: description: Optional[str] = None properties: Dict[str, str] = {} @@ -869,7 +907,8 @@ def get_table_properties( return description, properties, location def get_dataplatform_instance_aspect( - self, dataset_urn: str + self, + dataset_urn: str, ) -> Optional[MetadataWorkUnit]: # If we are a platform instance based source, emit the instance aspect if self.config.platform_instance: @@ -878,7 +917,8 @@ def get_dataplatform_instance_aspect( aspect=DataPlatformInstanceClass( platform=make_data_platform_urn(self.platform), instance=make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ), ), ).as_workunit() @@ -886,7 +926,11 @@ def get_dataplatform_instance_aspect( return None def _get_columns( - self, dataset_name: str, inspector: Inspector, schema: str, table: str + self, + dataset_name: str, + inspector: Inspector, + schema: str, + table: str, ) -> List[dict]: columns = [] try: @@ -903,7 +947,11 @@ def _get_columns( return columns def _get_foreign_keys( - self, dataset_urn: str, inspector: Inspector, schema: str, table: str + self, + dataset_urn: str, + inspector: Inspector, + schema: str, + table: str, ) -> List[ForeignKeyConstraint]: try: foreign_keys = [ @@ -913,7 +961,7 @@ def _get_foreign_keys( except KeyError: # certain databases like MySQL cause issues due to lower-case/upper-case irregularities logger.debug( - f"{dataset_urn}: failure in foreign key extraction... 
skipping" + f"{dataset_urn}: failure in foreign key extraction... skipping", ) foreign_keys = [] return foreign_keys @@ -995,7 +1043,9 @@ def loop_views( try: for view in inspector.get_view_names(schema): dataset_name = self.get_identifier( - schema=schema, entity=view, inspector=inspector + schema=schema, + entity=view, + inspector=inspector, ) self.report.report_entity_scanned(dataset_name, ent_type="view") @@ -1124,7 +1174,9 @@ def _process_view( ).as_workunit() view_properties_aspect = ViewPropertiesClass( - materialized=False, viewLanguage="SQL", viewLogic=view_definition + materialized=False, + viewLanguage="SQL", + viewLogic=view_definition, ) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, @@ -1169,12 +1221,18 @@ def get_profile_args(self) -> Dict: # Override if needed def generate_partition_profiler_query( - self, schema: str, table: str, partition_datetime: Optional[datetime.datetime] + self, + schema: str, + table: str, + partition_datetime: Optional[datetime.datetime], ) -> Tuple[Optional[str], Optional[str]]: return None, None def is_table_partitioned( - self, database: Optional[str], schema: str, table: str + self, + database: Optional[str], + schema: str, + table: str, ) -> Optional[bool]: return None @@ -1227,22 +1285,29 @@ def loop_profiler_requests( threshold_time: Optional[datetime.datetime] = None if sql_config.profiling.profile_if_updated_since_days is not None: threshold_time = datetime.datetime.now( - datetime.timezone.utc + datetime.timezone.utc, ) - datetime.timedelta( - sql_config.profiling.profile_if_updated_since_days + sql_config.profiling.profile_if_updated_since_days, ) profile_candidates = self.generate_profile_candidates( - inspector, threshold_time, schema + inspector, + threshold_time, + schema, ) except NotImplementedError: logger.debug("Source does not support generating profile candidates.") for table in inspector.get_table_names(schema): dataset_name = self.get_identifier( - schema=schema, entity=table, inspector=inspector + schema=schema, + entity=table, + inspector=inspector, ) if not self.is_dataset_eligible_for_profiling( - dataset_name, schema, inspector, profile_candidates + dataset_name, + schema, + inspector, + profile_candidates, ): self.report.num_tables_not_eligible_profiling[schema] += 1 if self.config.profiling.report_dropped_profiles: @@ -1256,11 +1321,15 @@ def loop_profiler_requests( continue (partition, custom_sql) = self.generate_partition_profiler_query( - schema, table, self.config.profiling.partition_datetime + schema, + table, + self.config.profiling.partition_datetime, ) if partition is None and self.is_table_partitioned( - database=None, schema=schema, table=table + database=None, + schema=schema, + table=table, ): self.warn( logger, @@ -1274,13 +1343,13 @@ def loop_profiler_requests( and not self.config.profiling.partition_profiling_enabled ): logger.debug( - f"{dataset_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled" + f"{dataset_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled", ) continue self.report.report_entity_profiled(dataset_name) logger.debug( - f"Preparing profiling request for {schema}, {table}, {partition}" + f"Preparing profiling request for {schema}, {table}, {partition}", ) yield GEProfilerRequest( pretty_name=dataset_name, @@ -1344,7 +1413,10 @@ def prepare_profiler_args( custom_sql: Optional[str] = None, ) -> dict: return dict( - schema=schema, table=table, 
partition=partition, custom_sql=custom_sql + schema=schema, + table=table, + partition=partition, + custom_sql=custom_sql, ) def get_schema_resolver(self) -> SchemaResolver: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py index 7d82d99412ffe8..6bec67e3682270 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py @@ -54,7 +54,8 @@ class SQLFilterConfig(ConfigModel): @pydantic.root_validator(pre=True) def view_pattern_is_table_pattern_unless_specified( - cls, values: Dict[str, Any] + cls, + values: Dict[str, Any], ) -> Dict[str, Any]: view_pattern = values.get("view_pattern") table_pattern = values.get("table_pattern") @@ -87,10 +88,12 @@ class SQLCommonConfig( ) include_views: bool = Field( - default=True, description="Whether views should be ingested." + default=True, + description="Whether views should be ingested.", ) include_tables: bool = Field( - default=True, description="Whether tables should be ingested." + default=True, + description="Whether tables should be ingested.", ) include_table_location_lineage: bool = Field( @@ -127,12 +130,13 @@ class SQLCommonConfig( ) def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) @pydantic.root_validator(skip_on_failure=True) def ensure_profiling_pattern_is_passed_to_profiling( - cls, values: Dict[str, Any] + cls, + values: Dict[str, Any], ) -> Dict[str, Any]: profiling: Optional[GEProfilingConfig] = values.get("profiling") # Note: isinstance() check is required here as unity-catalog source reuses @@ -153,7 +157,9 @@ def get_sql_alchemy_url(self): class SQLAlchemyConnectionConfig(ConfigModel): username: Optional[str] = Field(default=None, description="username") password: Optional[pydantic.SecretStr] = Field( - default=None, exclude=True, description="password" + default=None, + exclude=True, + description="password", ) host_port: str = Field(description="host URL") database: Optional[str] = Field(default=None, description="database (catalog)") @@ -177,7 +183,9 @@ class SQLAlchemyConnectionConfig(ConfigModel): _database_alias_removed = pydantic_removed_field("database_alias") def get_sql_alchemy_url( - self, uri_opts: Optional[Dict[str, Any]] = None, database: Optional[str] = None + self, + uri_opts: Optional[Dict[str, Any]] = None, + database: Optional[str] = None, ) -> str: if not ((self.host_port and self.scheme) or self.sqlalchemy_uri): raise ValueError("host_port and schema or connect_uri required.") @@ -225,5 +233,5 @@ def make_sqlalchemy_uri( port=port, database=db, query=uri_opts or {}, - ) + ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py index 78b0dcf9b7be82..031d2aa886edc6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py @@ -52,10 +52,10 @@ class BaseView: class SQLAlchemyGenericConfig(SQLCommonConfig): platform: str = Field( - description="Name of platform being ingested, used in constructing URNs." + description="Name of platform being ingested, used in constructing URNs.", ) connect_uri: str = Field( - description="URI of database to connect to. 
See https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls" + description="URI of database to connect to. See https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls", ) def get_sql_alchemy_url(self): diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic_profiler.py index 664735053f1852..ed5f91877baab6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic_profiler.py @@ -72,7 +72,7 @@ def generate_profile_workunits( and request.table.size_in_bytes is None ): logger.warning( - f"Table {request.pretty_name} has no column count, rows count, or size in bytes. Skipping emitting table level profile." + f"Table {request.pretty_name} has no column count, rows count, or size in bytes. Skipping emitting table level profile.", ) else: table_level_profile = DatasetProfile( @@ -83,7 +83,8 @@ def generate_profile_workunits( ) dataset_urn = self.dataset_urn_builder(request.pretty_name) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=table_level_profile + entityUrn=dataset_urn, + aspect=table_level_profile, ).as_workunit() if not ge_profile_requests: @@ -93,7 +94,10 @@ def generate_profile_workunits( ge_profiler = self.get_profiler_instance(db_name) for ge_profiler_request, profile in ge_profiler.generate_profiles( - ge_profile_requests, max_workers, platform, profiler_args + ge_profile_requests, + max_workers, + platform, + profiler_args, ): if profile is None: continue @@ -116,10 +120,12 @@ def generate_profile_workunits( # We don't add to the profiler state if we only do table level profiling as it always happens if self.state_handler: self.state_handler.add_to_state( - dataset_urn, int(datetime.now().timestamp() * 1000) + dataset_urn, + int(datetime.now().timestamp() * 1000), ) yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=profile + entityUrn=dataset_urn, + aspect=profile, ).as_workunit() def dataset_urn_builder(self, dataset_name: str) -> str: @@ -135,7 +141,10 @@ def get_dataset_name(self, table_name: str, schema_name: str, db_name: str) -> s pass def get_profile_request( - self, table: BaseTable, schema_name: str, db_name: str + self, + table: BaseTable, + schema_name: str, + db_name: str, ) -> Optional[TableProfilerRequest]: skip_profiling = False profile_table_level_only = self.config.profiling.profile_table_level_only @@ -147,7 +156,7 @@ def get_profile_request( rows_count=table.rows_count, ): logger.debug( - f"Dataset {dataset_name} was not eligible for profiling due to last_altered, size in bytes or count of rows limit" + f"Dataset {dataset_name} was not eligible for profiling due to last_altered, size in bytes or count of rows limit", ) # Profile only table level if dataset is filtered from profiling # due to size limits alone @@ -182,7 +191,10 @@ def get_profile_request( return profile_request def get_batch_kwargs( - self, table: BaseTable, schema_name: str, db_name: str + self, + table: BaseTable, + schema_name: str, + db_name: str, ) -> dict: return dict( schema=schema_name, @@ -201,7 +213,8 @@ def get_inspectors(self) -> Iterable[Inspector]: yield inspector def get_profiler_instance( - self, db_name: Optional[str] = None + self, + db_name: Optional[str] = None, ) -> "DatahubGEProfiler": logger.debug(f"Getting profiler instance from {self.platform}") url = self.config.get_sql_alchemy_url() @@ -237,7 +250,7 @@ def 
is_dataset_eligible_for_profiling( if not self.config.table_pattern.allowed(dataset_name): logger.debug( - f"Table {dataset_name} is not allowed for profiling due to table pattern" + f"Table {dataset_name} is not allowed for profiling due to table pattern", ) return False @@ -254,16 +267,17 @@ def is_dataset_eligible_for_profiling( and self.config.profiling.profile_if_updated_since_days is not None ): threshold_time = datetime.now(timezone.utc) - timedelta( - self.config.profiling.profile_if_updated_since_days + self.config.profiling.profile_if_updated_since_days, ) schema_name = dataset_name.rsplit(".", 1)[0] if not check_table_with_profile_pattern( - self.config.profile_pattern, dataset_name + self.config.profile_pattern, + dataset_name, ): self.report.profiling_skipped_table_profile_pattern[schema_name] += 1 logger.debug( - f"Table {dataset_name} is not allowed for profiling due to profile pattern" + f"Table {dataset_name} is not allowed for profiling due to profile pattern", ) return False @@ -272,7 +286,7 @@ def is_dataset_eligible_for_profiling( ): self.report.profiling_skipped_not_updated[schema_name] += 1 logger.debug( - f"Table {dataset_name} was skipped because it was not updated recently enough" + f"Table {dataset_name} was skipped because it was not updated recently enough", ) return False @@ -282,7 +296,7 @@ def is_dataset_eligible_for_profiling( ): self.report.profiling_skipped_size_limit[schema_name] += 1 logger.debug( - f"Table {dataset_name} is not allowed for profiling due to size limit" + f"Table {dataset_name} is not allowed for profiling due to size limit", ) return False @@ -292,7 +306,7 @@ def is_dataset_eligible_for_profiling( ): self.report.profiling_skipped_row_limit[schema_name] += 1 logger.debug( - f"Table {dataset_name} is not allowed for profiling due to row limit" + f"Table {dataset_name} is not allowed for profiling due to row limit", ) return False diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_report.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_report.py index 785972b88a49d7..32229b9c846886 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_report.py @@ -14,24 +14,24 @@ @dataclass class DetailedProfilerReportMixin: profiling_skipped_not_updated: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) profiling_skipped_size_limit: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) profiling_skipped_row_limit: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) profiling_skipped_table_profile_pattern: TopKDict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) profiling_skipped_other: TopKDict[str, int] = field(default_factory=int_top_k_dict) num_tables_not_eligible_profiling: Dict[str, int] = field( - default_factory=int_top_k_dict + default_factory=int_top_k_dict, ) @@ -73,6 +73,7 @@ def report_dropped(self, ent_name: str) -> None: self.filtered.append(ent_name) def report_from_query_combiner( - self, query_combiner_report: SQLAlchemyQueryCombinerReport + self, + query_combiner_report: SQLAlchemyQueryCombinerReport, ) -> None: self.query_combiner = query_combiner_report diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_utils.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_utils.py index 1545de0fff796f..b9752b8d12e09a 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_utils.py @@ -48,7 +48,10 @@ def gen_schema_key( def gen_database_key( - database: str, platform: str, platform_instance: Optional[str], env: Optional[str] + database: str, + platform: str, + platform_instance: Optional[str], + env: Optional[str], ) -> DatabaseKey: return DatabaseKey( database=database, @@ -142,7 +145,9 @@ def gen_database_container( if domain_registry: assert domain_config domain_urn = gen_domain_urn( - database, domain_config=domain_config, domain_registry=domain_registry + database, + domain_config=domain_config, + domain_registry=domain_registry, ) yield from gen_containers( @@ -187,7 +192,9 @@ def get_domain_wu( def get_dataplatform_instance_aspect( - dataset_urn: str, platform: str, platform_instance: Optional[str] + dataset_urn: str, + platform: str, + platform_instance: Optional[str], ) -> Optional[MetadataWorkUnit]: # If we are a platform instance based source, emit the instance aspect if platform_instance: @@ -211,8 +218,9 @@ def gen_lineage( if upstream_lineage is not None: lineage_workunits = [ MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=upstream_lineage - ).as_workunit() + entityUrn=dataset_urn, + aspect=upstream_lineage, + ).as_workunit(), ] yield from lineage_workunits @@ -221,7 +229,7 @@ def gen_lineage( # downgrade a schema field def downgrade_schema_field_from_v2(field: SchemaField) -> SchemaField: field.fieldPath = DatasetUrn.get_simple_field_path_from_v2_field_path( - field.fieldPath + field.fieldPath, ) return field @@ -246,7 +254,8 @@ def schema_requires_v2(canonical_schema: List[SchemaField]) -> bool: def check_table_with_profile_pattern( - profile_pattern: AllowDenyPattern, table_name: str + profile_pattern: AllowDenyPattern, + table_name: str, ) -> bool: parts = len(table_name.split(".")) allow_list: List[str] = [] @@ -261,6 +270,7 @@ def check_table_with_profile_pattern( allow_list.append(pattern) table_allow_deny_pattern = AllowDenyPattern( - allow=allow_list, deny=profile_pattern.deny + allow=allow_list, + deny=profile_pattern.deny, ) return table_allow_deny_pattern.allowed(table_name) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_data_reader.py index 0765eee57bf80d..7d432f1a87d935 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_data_reader.py @@ -31,7 +31,10 @@ def _table(self, table_id: List[str]) -> sa.Table: ) def get_sample_data_for_table( - self, table_id: List[str], sample_size: int, **kwargs: Any + self, + table_id: List[str], + sample_size: int, + **kwargs: Any, ) -> Dict[str, list]: """ For sqlalchemy, table_id should be in form (schema_name, table_name) @@ -51,8 +54,9 @@ def get_sample_data_for_table( query = str( raw_query.compile( - self.connection, compile_kwargs={"literal_binds": True} - ) + self.connection, + compile_kwargs={"literal_binds": True}, + ), ) query += "\nAND ROWNUM <= %d" % sample_size else: @@ -66,7 +70,7 @@ def get_sample_data_for_table( time_taken = timer.elapsed_seconds() logger.debug( f"Finished collecting sample values for table {'.'.join(table_id)};" - f"took {time_taken:.3f} seconds" + f"took {time_taken:.3f} seconds", ) return column_values diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_uri_mapper.py 
b/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_uri_mapper.py index b6a463837228db..607de4b683de86 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_uri_mapper.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sqlalchemy_uri_mapper.py @@ -3,7 +3,8 @@ def _platform_alchemy_uri_tester_gen( - platform: str, opt_starts_with: Optional[str] = None + platform: str, + opt_starts_with: Optional[str] = None, ) -> Tuple[str, Callable[[str], bool]]: return platform, lambda x: x.startswith(opt_starts_with or platform) @@ -36,7 +37,7 @@ def _platform_alchemy_uri_tester_gen( _platform_alchemy_uri_tester_gen("sqlite"), _platform_alchemy_uri_tester_gen("trino"), _platform_alchemy_uri_tester_gen("vertica"), - ] + ], ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py b/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py index 84b65d6635e9d4..405e0c72dfcc30 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py @@ -101,7 +101,10 @@ class TeradataTable: # lru cache is set to 1 which work only in single threaded environment but it keeps the memory footprint lower @lru_cache(maxsize=1) def get_schema_columns( - self: Any, connection: Connection, dbc_columns: str, schema: str + self: Any, + connection: Connection, + dbc_columns: str, + schema: str, ) -> Dict[str, List[Any]]: columns: Dict[str, List[Any]] = {} columns_query = f"select * from dbc.{dbc_columns} where DatabaseName (NOT CASESPECIFIC) = '{schema}' (NOT CASESPECIFIC) order by TableName, ColumnId" @@ -119,7 +122,9 @@ def get_schema_columns( # lru cache is set to 1 which work only in single threaded environment but it keeps the memory footprint lower @lru_cache(maxsize=1) def get_schema_pk_constraints( - self: Any, connection: Connection, schema: str + self: Any, + connection: Connection, + schema: str, ) -> Dict[str, List[Any]]: dbc_indices = "IndicesV" + "X" if configure.usexviews else "IndicesV" primary_keys: Dict[str, List[Any]] = {} @@ -168,7 +173,7 @@ def optimized_get_pk_constraint( for index_column in res: index_columns.append(self.normalize_name(index_column.ColumnName)) index_name = self.normalize_name( - index_column.IndexName + index_column.IndexName, ) # There should be just one IndexName return {"constrained_columns": index_columns, "name": index_name} @@ -199,7 +204,7 @@ def optimized_get_columns( if td_table is None: logger.warning( - f"Table {table_name} not found in cache for schema {schema}, not getting columns" + f"Table {table_name} not found in cache for schema {schema}, not getting columns", ) return [] @@ -223,7 +228,8 @@ def optimized_get_columns( dbc_columns = "columnsQV" if use_qvci else "columnsV" dbc_columns = dbc_columns + "X" if configure.usexviews else dbc_columns res = self.get_schema_columns(connection, dbc_columns, schema).get( - table_name, [] + table_name, + [], ) final_column_info = [] @@ -245,7 +251,9 @@ def optimized_get_columns( # lru cache is set to 1 which work only in single threaded environment but it keeps the memory footprint lower @lru_cache(maxsize=1) def get_schema_foreign_keys( - self: Any, connection: Connection, schema: str + self: Any, + connection: Connection, + schema: str, ) -> Dict[str, List[Any]]: dbc_child_parent_table = ( "All_RI_ChildrenV" + "X" if configure.usexviews else "All_RI_ChildrenV" @@ -297,10 +305,10 @@ def grouper(fk_row): for constraint_col in constraint_cols: fk_dict["constrained_columns"].append( - 
self.normalize_name(constraint_col.ChildKeyColumn) + self.normalize_name(constraint_col.ChildKeyColumn), ) fk_dict["referred_columns"].append( - self.normalize_name(constraint_col.ParentKeyColumn) + self.normalize_name(constraint_col.ParentKeyColumn), ) fk_dicts.append(fk_dict) @@ -395,7 +403,7 @@ class TeradataConfig(BaseTeradataConfig, BaseTimeWindowConfig): "tdwm", "val", "dbc", - ] + ], ), description="Regex patterns for databases to filter in ingestion.", ) @@ -592,7 +600,9 @@ def __init__(self, config: TeradataConfig, ctx: PipelineContext): setattr(self, "loop_tables", self.cached_loop_tables) # noqa: B010 setattr(self, "loop_views", self.cached_loop_views) # noqa: B010 setattr( # noqa: B010 - self, "get_table_properties", self.cached_get_table_properties + self, + "get_table_properties", + self.cached_get_table_properties, ) tables_cache = self._tables_cache @@ -623,7 +633,11 @@ def __init__(self, config: TeradataConfig, ctx: PipelineContext): table_name, schema=None, **kw: optimized_get_pk_constraint( - self, connection, table_name, schema, **kw + self, + connection, + table_name, + schema, + **kw, ), ) @@ -635,7 +649,11 @@ def __init__(self, config: TeradataConfig, ctx: PipelineContext): table_name, schema=None, **kw: optimized_get_foreign_keys( - self, connection, table_name, schema, **kw + self, + connection, + table_name, + schema, + **kw, ), ) @@ -643,7 +661,10 @@ def __init__(self, config: TeradataConfig, ctx: PipelineContext): TeradataDialect, "get_schema_columns", lambda self, connection, dbc_columns, schema: get_schema_columns( - self, connection, dbc_columns, schema + self, + connection, + dbc_columns, + schema, ), ) @@ -660,7 +681,9 @@ def __init__(self, config: TeradataConfig, ctx: PipelineContext): TeradataDialect, "get_schema_pk_constraints", lambda self, connection, schema: get_schema_pk_constraints( - self, connection, schema + self, + connection, + schema, ), ) @@ -668,12 +691,14 @@ def __init__(self, config: TeradataConfig, ctx: PipelineContext): TeradataDialect, "get_schema_foreign_keys", lambda self, connection, schema: get_schema_foreign_keys( - self, connection, schema + self, + connection, + schema, ), ) else: logger.info( - "Disabling stale entity removal as tables and views are disabled" + "Disabling stale entity removal as tables and views are disabled", ) if self.config.stateful_ingestion: self.config.stateful_ingestion.remove_stale_metadata = False @@ -746,14 +771,18 @@ def cached_loop_tables( # noqa: C901 lambda schema: [ i.name for i in filter( - lambda t: t.object_type != "View", self._tables_cache[schema] + lambda t: t.object_type != "View", + self._tables_cache[schema], ) ], ) yield from super().loop_tables(inspector, schema, sql_config) def cached_get_table_properties( - self, inspector: Inspector, schema: str, table: str + self, + inspector: Inspector, + schema: str, + table: str, ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: description: Optional[str] = None properties: Dict[str, str] = {} @@ -782,7 +811,8 @@ def cached_loop_views( # noqa: C901 lambda schema: [ i.name for i in filter( - lambda t: t.object_type == "View", self._tables_cache[schema] + lambda t: t.object_type == "View", + self._tables_cache[schema], ) ], ) @@ -830,7 +860,7 @@ def _make_lineage_query(self) -> str: "" if not self.config.databases else "and default_database in ({databases})".format( - databases=",".join([f"'{db}'" for db in self.config.databases]) + databases=",".join([f"'{db}'" for db in self.config.databases]), ) ) @@ -861,7 +891,7 @@ def 
gen_lineage_from_query( ) if result.debug_info.table_error: logger.debug( - f"Error parsing table lineage ({view_urn}):\n{result.debug_info.table_error}" + f"Error parsing table lineage ({view_urn}):\n{result.debug_info.table_error}", ) self.report.num_table_parse_failures += 1 else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py index b6fa51dd70e18d..befeb506d9f769 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py @@ -92,7 +92,7 @@ def gen_catalog_connector_dict(engine: Engine) -> Dict[str, str]: """ SELECT * FROM "system"."metadata"."catalogs" - """ + """, ).strip() res = engine.execute(sql.text(query)) return {row.catalog_name: row.connector_name for row in res} @@ -114,7 +114,7 @@ def get_table_names(self, connection, schema: str = None, **kw): # type: ignore SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = :schema and "table_type" != 'VIEW' - """ + """, ).strip() res = connection.execute(sql.text(query), schema=schema) return [row.table_name for row in res] @@ -165,7 +165,7 @@ def _get_columns(self, connection, table_name, schema: str = None, **kw): # typ WHERE "table_schema" = :schema AND "table_name" = :table ORDER BY "ordinal_position" ASC - """ + """, ).strip() res = connection.execute(sql.text(query), schema=schema, table=table_name) columns = [] @@ -240,7 +240,10 @@ class TrinoSource(SQLAlchemySource): config: TrinoConfig def __init__( - self, config: TrinoConfig, ctx: PipelineContext, platform: str = "trino" + self, + config: TrinoConfig, + ctx: PipelineContext, + platform: str = "trino", ): super().__init__(config, ctx, platform) @@ -262,10 +265,11 @@ def _get_source_dataset_urn( if not connector_name: return None connector_details = self.config.catalog_to_connector_details.get( - catalog_name, ConnectorDetail() + catalog_name, + ConnectorDetail(), ) connector_platform_name = KNOWN_CONNECTOR_PLATFORM_MAPPING.get( - connector_details.connector_platform or connector_name + connector_details.connector_platform or connector_name, ) if not connector_platform_name: logging.debug(f"Platform '{connector_platform_name}' is not yet supported.") @@ -300,14 +304,16 @@ def gen_siblings_workunit( yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=Siblings( - primary=self.config.trino_as_primary, siblings=[source_dataset_urn] + primary=self.config.trino_as_primary, + siblings=[source_dataset_urn], ), ).as_workunit() yield MetadataChangeProposalWrapper( entityUrn=source_dataset_urn, aspect=Siblings( - primary=not self.config.trino_as_primary, siblings=[dataset_urn] + primary=not self.config.trino_as_primary, + siblings=[dataset_urn], ), ).as_workunit() @@ -323,8 +329,8 @@ def gen_lineage_workunit( entityUrn=dataset_urn, aspect=UpstreamLineage( upstreams=[ - Upstream(dataset=source_dataset_urn, type=DatasetLineageType.VIEW) - ] + Upstream(dataset=source_dataset_urn, type=DatasetLineageType.VIEW), + ], ), ).as_workunit() @@ -338,7 +344,12 @@ def _process_table( data_reader: Optional[DataReader], ) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]: yield from super()._process_table( - dataset_name, inspector, schema, table, sql_config, data_reader + dataset_name, + inspector, + schema, + table, + sql_config, + data_reader, ) if self.config.ingest_lineage_to_connectors: dataset_urn = make_dataset_urn_with_platform_instance( @@ -348,7 +359,10 @@ def _process_table( 
self.config.env, ) source_dataset_urn = self._get_source_dataset_urn( - dataset_name, inspector, schema, table + dataset_name, + inspector, + schema, + table, ) if source_dataset_urn: yield from self.gen_siblings_workunit(dataset_urn, source_dataset_urn) @@ -363,7 +377,11 @@ def _process_view( sql_config: SQLCommonConfig, ) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]: yield from super()._process_view( - dataset_name, inspector, schema, view, sql_config + dataset_name, + inspector, + schema, + view, + sql_config, ) if self.config.ingest_lineage_to_connectors: dataset_urn = make_dataset_urn_with_platform_instance( @@ -373,7 +391,10 @@ def _process_view( self.config.env, ) source_dataset_urn = self._get_source_dataset_urn( - dataset_name, inspector, schema, view + dataset_name, + inspector, + schema, + view, ) if source_dataset_urn: yield from self.gen_siblings_workunit(dataset_urn, source_dataset_urn) @@ -404,11 +425,13 @@ def get_schema_fields_for_column( field = fields[0] # Get avro schema for subfields along with parent complex field avro_schema = self.get_avro_schema_from_data_type( - column["type"], column["name"] + column["type"], + column["name"], ) newfields = schema_util.avro_schema_to_mce_fields( - json.dumps(avro_schema), default_nullable=True + json.dumps(avro_schema), + default_nullable=True, ) # First field is the parent complex field @@ -420,7 +443,9 @@ def get_schema_fields_for_column( return fields def get_avro_schema_from_data_type( - self, column_type: TypeEngine, column_name: str + self, + column_type: TypeEngine, + column_name: str, ) -> Dict[str, Any]: # Below Record structure represents the dataset level # Inner fields represent the complex field (struct/array/map/union) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/two_tier_sql_source.py b/metadata-ingestion/src/datahub/ingestion/source/sql/two_tier_sql_source.py index 98ad2f6027dfdf..36ae226d4242d0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/two_tier_sql_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/two_tier_sql_source.py @@ -36,7 +36,8 @@ class TwoTierSQLAlchemyConfig(BasicSQLAlchemyConfig): ) _schema_pattern_deprecated = pydantic_renamed_field( - "schema_pattern", "database_pattern" + "schema_pattern", + "database_pattern", ) def get_sql_alchemy_url( @@ -100,7 +101,9 @@ def add_table_to_schema_container( ) def get_allowed_schemas( - self, inspector: Inspector, db_name: str + self, + inspector: Inspector, + db_name: str, ) -> typing.Iterable[str]: # This method returns schema names but for 2 tier databases there is no schema layer at all hence passing # dbName itself as an allowed schema diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py index 92487d48b99e63..5fdee0d21d7982 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py @@ -80,10 +80,12 @@ class VerticaConfig(BasicSQLAlchemyConfig): description="Regex patterns for ml models to filter in ingestion. ", ) include_projections: Optional[bool] = pydantic.Field( - default=True, description="Whether projections should be ingested." + default=True, + description="Whether projections should be ingested.", ) include_models: Optional[bool] = pydantic.Field( - default=True, description="Whether Models should be ingested." 
+ default=True, + description="Whether Models should be ingested.", ) include_view_lineage: bool = pydantic.Field( @@ -145,7 +147,8 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit yield from self.gen_database_containers( database=db_name, extra_properties=self.get_database_properties( - inspector=inspector, database=db_name + inspector=inspector, + database=db_name, ), ) @@ -156,7 +159,9 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit schema=schema, database=db_name, extra_properties=self.get_schema_properties( - inspector=inspector, schema=schema, database=db_name + inspector=inspector, + schema=schema, + database=db_name, ), ) @@ -167,16 +172,23 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit if profiler: profile_requests += list( - self.loop_profiler_requests(inspector, schema, sql_config) + self.loop_profiler_requests(inspector, schema, sql_config), ) if profiler and profile_requests: yield from self.loop_profiler( - profile_requests, profiler, platform=self.platform + profile_requests, + profiler, + platform=self.platform, ) def get_identifier( - self, *, schema: str, entity: str, inspector: VerticaInspector, **kwargs: Any + self, + *, + schema: str, + entity: str, + inspector: VerticaInspector, + **kwargs: Any, ) -> str: regular = f"{schema}.{entity}" if self.config.database: @@ -185,7 +197,9 @@ def get_identifier( return f"{current_database}.{regular}" def get_database_properties( - self, inspector: VerticaInspector, database: str + self, + inspector: VerticaInspector, + database: str, ) -> Optional[Dict[str, str]]: try: custom_properties = inspector._get_database_properties(database) @@ -193,19 +207,24 @@ def get_database_properties( except Exception as ex: self.report.report_failure( - f"{database}", f"unable to get extra_properties : {ex}" + f"{database}", + f"unable to get extra_properties : {ex}", ) return None def get_schema_properties( - self, inspector: VerticaInspector, database: str, schema: str + self, + inspector: VerticaInspector, + database: str, + schema: str, ) -> Optional[Dict[str, str]]: try: custom_properties = inspector._get_schema_properties(schema) return custom_properties except Exception as ex: self.report.report_failure( - f"{database}.{schema}", f"unable to get extra_properties : {ex}" + f"{database}.{schema}", + f"unable to get extra_properties : {ex}", ) return None @@ -231,7 +250,12 @@ def _process_table( owner_urn=f"urn:li:corpuser:{table_owner}", ) yield from super()._process_table( - dataset_name, inspector, schema, table, sql_config, data_reader + dataset_name, + inspector, + schema, + table, + sql_config, + data_reader, ) def loop_views( @@ -243,7 +267,9 @@ def loop_views( try: for view in inspector.get_view_names(schema): dataset_name = self.get_identifier( - schema=schema, entity=view, inspector=inspector + schema=schema, + entity=view, + inspector=inspector, ) self.report.report_entity_scanned(dataset_name, ent_type="view") @@ -262,10 +288,11 @@ def loop_views( ) except Exception as e: logger.warning( - f"Unable to ingest view {schema}.{view} due to an exception.\n {traceback.format_exc()}" + f"Unable to ingest view {schema}.{view} due to an exception.\n {traceback.format_exc()}", ) self.report.report_warning( - f"{schema}.{view}", f"Ingestion error: {e}" + f"{schema}.{view}", + f"Ingestion error: {e}", ) if self.config.include_view_lineage: try: @@ -277,7 +304,8 @@ def loop_views( ) dataset_snapshot = DatasetSnapshot( - urn=dataset_urn, 
aspects=[StatusClass(removed=False)] + urn=dataset_urn, + aspects=[StatusClass(removed=False)], ) lineage_info = self._get_upstream_lineage_info( @@ -296,10 +324,11 @@ def loop_views( except Exception as e: logger.warning( - f"Unable to get lineage of view {schema}.{view} due to an exception.\n {traceback.format_exc()}" + f"Unable to get lineage of view {schema}.{view} due to an exception.\n {traceback.format_exc()}", ) self.report.report_warning( - f"{schema}.{view}", f"Ingestion error: {e}" + f"{schema}.{view}", + f"Ingestion error: {e}", ) except Exception as e: self.report.report_failure(f"{schema}", f"Views error: {e}") @@ -345,7 +374,11 @@ def _process_view( ) yield from super()._process_view( - dataset_name, inspector, schema, view, sql_config + dataset_name, + inspector, + schema, + view, + sql_config, ) def loop_projections( @@ -375,7 +408,9 @@ def loop_projections( try: for projection in inspector.get_projection_names(schema): dataset_name = self.get_identifier( - schema=schema, entity=projection, inspector=inspector + schema=schema, + entity=projection, + inspector=inspector, ) if dataset_name not in projections_seen: projections_seen.add(dataset_name) @@ -388,7 +423,11 @@ def loop_projections( continue try: yield from self._process_projections( - dataset_name, inspector, schema, projection, sql_config + dataset_name, + inspector, + schema, + projection, + sql_config, ) except Exception as ex: logger.warning( @@ -396,7 +435,8 @@ def loop_projections( ex, ) self.report.report_warning( - f"{schema}.{projection}", f"Ingestion error: {ex}" + f"{schema}.{projection}", + f"Ingestion error: {ex}", ) if self.config.include_projection_lineage: try: @@ -408,21 +448,26 @@ def loop_projections( ) dataset_snapshot = DatasetSnapshot( - urn=dataset_urn, aspects=[StatusClass(removed=False)] + urn=dataset_urn, + aspects=[StatusClass(removed=False)], ) lineage_info = self._get_upstream_lineage_info_projection( - dataset_urn, inspector, projection, schema + dataset_urn, + inspector, + projection, + schema, ) if lineage_info is not None: yield MetadataChangeProposalWrapper( - entityUrn=dataset_snapshot.urn, aspect=lineage_info + entityUrn=dataset_snapshot.urn, + aspect=lineage_info, ).as_workunit() except Exception as e: logger.warning( - f"Unable to get lineage of projection {projection} due to an exception.\n {traceback.format_exc()}" + f"Unable to get lineage of projection {projection} due to an exception.\n {traceback.format_exc()}", ) self.report.report_warning(f"{schema}", f"Ingestion error: {e}") except Exception as ex: @@ -448,7 +493,9 @@ def _process_projections( aspects=[StatusClass(removed=False)], ) description, properties, location_urn = self.get_projection_properties( - inspector, schema, projection + inspector, + schema, + projection, ) dataset_properties = DatasetPropertiesClass( @@ -467,7 +514,10 @@ def _process_projections( # extra_tags = self.get_extra_tags(inspector, schema, projection) pk_constraints: dict = inspector.get_pk_constraint(projection, schema) foreign_keys = self._get_foreign_keys( - dataset_urn, inspector, schema, projection + dataset_urn, + inspector, + schema, + projection, ) schema_fields = self.get_schema_fields( dataset_name, @@ -532,11 +582,16 @@ def loop_profiler_requests( for projection in inspector.get_projection_names(schema): dataset_name = self.get_identifier( - schema=schema, entity=projection, inspector=inspector + schema=schema, + entity=projection, + inspector=inspector, ) if not self.is_dataset_eligible_for_profiling( - dataset_name, schema, 
inspector, profile_candidates + dataset_name, + schema, + inspector, + profile_candidates, ): if self.config.profiling.report_dropped_profiles: self.report.report_dropped(f"profile of {dataset_name}") @@ -548,10 +603,14 @@ def loop_profiler_requests( continue (partition, custom_sql) = self.generate_partition_profiler_query( - schema, projection, self.config.profiling.partition_datetime + schema, + projection, + self.config.profiling.partition_datetime, ) if partition is None and self.is_table_partitioned( - database=None, schema=schema, table=projection + database=None, + schema=schema, + table=projection, ): self.report.report_warning( "profile skipped as partitioned table is empty or partition id was invalid", @@ -563,12 +622,12 @@ def loop_profiler_requests( and not self.config.profiling.partition_profiling_enabled ): logger.debug( - f"{dataset_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled" + f"{dataset_name} and partition {partition} is skipped because profiling.partition_profiling_enabled property is disabled", ) continue self.report.report_entity_profiled(dataset_name) logger.debug( - f"Preparing profiling request for {schema}, {projection}, {partition}" + f"Preparing profiling request for {schema}, {projection}, {partition}", ) yield GEProfilerRequest( @@ -606,7 +665,9 @@ def loop_models( try: for models in inspector.get_models_names(schema): dataset_name = self.get_identifier( - schema="Entities", entity=models, inspector=inspector + schema="Entities", + entity=models, + inspector=inspector, ) if dataset_name not in models_seen: @@ -628,10 +689,11 @@ def loop_models( ) except Exception as error: logger.warning( - f"Unable to ingest {schema}.{models} due to an exception. %s {traceback.format_exc()}" + f"Unable to ingest {schema}.{models} due to an exception. 
%s {traceback.format_exc()}", ) self.report.report_warning( - f"{schema}.{models}", f"Ingestion error: {error}" + f"{schema}.{models}", + f"Ingestion error: {error}", ) except Exception as error: self.report.report_failure(f"{schema}", f"Model error: {error}") @@ -668,7 +730,9 @@ def _process_models( aspects=[StatusClass(removed=False)], ) description, properties, location = self.get_model_properties( - inspector, schema, table + inspector, + schema, + table, ) dataset_properties = DatasetPropertiesClass( @@ -711,7 +775,10 @@ def _process_models( ) def get_projection_properties( - self, inspector: VerticaInspector, schema: str, projection: str + self, + inspector: VerticaInspector, + schema: str, + projection: str, ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: """ Returns projection related metadata information to show in properties tab @@ -745,7 +812,10 @@ def get_projection_properties( return description, properties, location def get_model_properties( - self, inspector: VerticaInspector, schema: str, model: str + self, + inspector: VerticaInspector, + schema: str, + model: str, ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: """ Returns ml models related metadata information to show in properties tab @@ -818,7 +888,7 @@ def _get_upstream_lineage_info( if upstream_tables: logger.debug( - f" lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}" + f" lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}", ) return UpstreamLineage(upstreams=upstream_tables) @@ -863,7 +933,7 @@ def _get_upstream_lineage_info_projection( if upstream_tables: logger.debug( - f" lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}" + f" lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}", ) return UpstreamLineage(upstreams=upstream_tables) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py index f3e8e774e4388f..21937bf9c097ab 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py @@ -45,7 +45,7 @@ class SqlQueriesSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin): query_file: str = Field(description="Path to file to ingest") platform: str = Field( - description="The platform for which to generate data, e.g. snowflake" + description="The platform for which to generate data, e.g. 
snowflake", ) usage: BaseUsageConfig = Field( @@ -120,7 +120,7 @@ class SqlQueriesSource(Source): def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig): if not ctx.graph: raise ValueError( - "SqlQueriesSource needs a datahub_api from which to pull schema metadata" + "SqlQueriesSource needs a datahub_api from which to pull schema metadata", ) self.graph: DataHubGraph = ctx.graph @@ -196,7 +196,7 @@ def _process_query(self, entry: "QueryEntry") -> Iterable[MetadataWorkUnit]: return elif result.debug_info.column_error: logger.debug( - f"Error parsing column lineage, {result.debug_info.column_error}" + f"Error parsing column lineage, {result.debug_info.column_error}", ) self.report.num_column_parse_failures += 1 @@ -221,7 +221,10 @@ class QueryEntry: @classmethod def create( - cls, entry_dict: dict, *, config: SqlQueriesSourceConfig + cls, + entry_dict: dict, + *, + config: SqlQueriesSourceConfig, ) -> "QueryEntry": return cls( query=entry_dict["query"], diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py index 2c7a4a8b6c137d..0043b802ae3dd6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py @@ -37,7 +37,8 @@ class CheckpointStateBase(ConfigModel): def to_bytes( self, compressor: Callable[[bytes], bytes] = functools.partial( - bz2.compress, compresslevel=9 + bz2.compress, + compresslevel=9, ), max_allowed_state_size: int = DEFAULT_MAX_STATE_SIZE, ) -> bytes: @@ -54,7 +55,7 @@ def to_bytes( # The original base85 implementation used pickle, which would cause # issues with deserialization if we ever changed the state class definition. raise ValueError( - "Cannot write base85 encoded bytes. Use base85-bz2-json instead." + "Cannot write base85 encoded bytes. 
Use base85-bz2-json instead.", ) elif self.serde == "base85-bz2-json": encoded_bytes = CheckpointStateBase._to_bytes_base85_json(self, compressor) @@ -63,7 +64,7 @@ def to_bytes( if len(encoded_bytes) > max_allowed_state_size: raise ValueError( - f"The state size has exceeded the max_allowed_state_size of {max_allowed_state_size}" + f"The state size has exceeded the max_allowed_state_size of {max_allowed_state_size}", ) return encoded_bytes @@ -74,7 +75,8 @@ def _to_bytes_utf8(model: ConfigModel) -> bytes: @staticmethod def _to_bytes_base85_json( - model: ConfigModel, compressor: Callable[[bytes], bytes] + model: ConfigModel, + compressor: Callable[[bytes], bytes], ) -> bytes: return base64.b85encode(compressor(CheckpointStateBase._to_bytes_utf8(model))) @@ -114,7 +116,8 @@ def create_from_checkpoint_aspect( try: if checkpoint_aspect.state.serde == "utf-8": state_obj = Checkpoint._from_utf8_bytes( - checkpoint_aspect, state_class + checkpoint_aspect, + state_class, ) elif checkpoint_aspect.state.serde == "base85": state_obj = Checkpoint._from_base85_bytes( @@ -132,7 +135,7 @@ def create_from_checkpoint_aspect( raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}") except Exception as e: logger.error( - f"Failed to construct checkpoint class from checkpoint aspect: {e}" + f"Failed to construct checkpoint class from checkpoint aspect: {e}", ) raise e else: @@ -145,7 +148,7 @@ def create_from_checkpoint_aspect( ) logger.info( f"Successfully constructed last checkpoint state for job {job_name} " - f"with timestamp {parse_ts_millis(checkpoint_aspect.timestampMillis)}" + f"with timestamp {parse_ts_millis(checkpoint_aspect.timestampMillis)}", ) return checkpoint return None @@ -171,7 +174,7 @@ def _from_base85_bytes( state_class: Type[StateType], ) -> StateType: state: StateType = pickle.loads( - decompressor(base64.b85decode(checkpoint_aspect.state.payload)) # type: ignore + decompressor(base64.b85decode(checkpoint_aspect.state.payload)), # type: ignore ) with contextlib.suppress(Exception): @@ -195,7 +198,7 @@ def _from_base85_json_bytes( state_uncompressed = decompressor( base64.b85decode(checkpoint_aspect.state.payload) if checkpoint_aspect.state.payload is not None - else b"{}" + else b"{}", ) state_as_dict = json.loads(state_uncompressed.decode("utf-8")) state_as_dict["version"] = checkpoint_aspect.state.formatVersion @@ -203,14 +206,15 @@ def _from_base85_json_bytes( return state_class.parse_obj(state_as_dict) def to_checkpoint_aspect( - self, max_allowed_state_size: int + self, + max_allowed_state_size: int, ) -> Optional[DatahubIngestionCheckpointClass]: try: checkpoint_state = IngestionCheckpointStateClass( formatVersion=self.state.version, serde=self.state.serde, payload=self.state.to_bytes( - max_allowed_state_size=max_allowed_state_size + max_allowed_state_size=max_allowed_state_size, ), ) checkpoint_aspect = DatahubIngestionCheckpointClass( @@ -224,7 +228,8 @@ def to_checkpoint_aspect( return checkpoint_aspect except Exception as e: logger.error( - "Failed to construct the checkpoint aspect from checkpoint object", e + "Failed to construct the checkpoint aspect from checkpoint object", + e, ) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py b/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py index 2b10ca1fa57ed8..734be40249db93 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py @@ -80,7 +80,7 @@ class GenericCheckpointState(CheckpointStateBase): # From dbt: "encoded_node_urns": "dataset", # "encoded_assertion_urns": "assertion", # already handled from SQL - } + }, ) def __init__(self, **data: Any): # type: ignore @@ -102,7 +102,9 @@ def add_checkpoint_urn(self, type: str, urn: str) -> None: self._urns_set.add(urn) def get_urns_not_in( - self, type: str, other_checkpoint_state: "GenericCheckpointState" + self, + type: str, + other_checkpoint_state: "GenericCheckpointState", ) -> Iterable[str]: """ Gets the urns present in this checkpoint but not the other_checkpoint for the given type. @@ -124,7 +126,8 @@ def get_urns_not_in( yield from (urn for urn in diff if guess_entity_type(urn) == type) def get_percent_entities_changed( - self, old_checkpoint_state: "GenericCheckpointState" + self, + old_checkpoint_state: "GenericCheckpointState", ) -> float: """ Returns the percentage of entities that have changed relative to `old_checkpoint_state`. @@ -136,7 +139,8 @@ def get_percent_entities_changed( old_urns_filtered = filter_ignored_entity_types(old_checkpoint_state.urns) return compute_percent_entities_changed( - new_entities=self.urns, old_entities=old_urns_filtered + new_entities=self.urns, + old_entities=old_urns_filtered, ) def urn_count(self) -> int: @@ -144,14 +148,16 @@ def urn_count(self) -> int: def compute_percent_entities_changed( - new_entities: List[str], old_entities: List[str] + new_entities: List[str], + old_entities: List[str], ) -> float: ( overlap_count, old_count, _, ) = _get_entity_overlap_and_cardinalities( - new_entities=new_entities, old_entities=old_entities + new_entities=new_entities, + old_entities=old_entities, ) if old_count: @@ -160,7 +166,8 @@ def compute_percent_entities_changed( def _get_entity_overlap_and_cardinalities( - new_entities: List[str], old_entities: List[str] + new_entities: List[str], + old_entities: List[str], ) -> Tuple[int, int, int]: new_set = set(new_entities) old_set = set(old_entities) diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py index 6080ddadb65e40..74d808a155ae22 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/profiling_state_handler.py @@ -111,7 +111,8 @@ def get_last_state(self) -> Optional[ProfilingCheckpointState]: if not self.is_checkpointing_enabled() or self._ignore_old_state(): return None last_checkpoint = self.state_provider.get_last_checkpoint( - self.job_id, ProfilingCheckpointState + self.job_id, + ProfilingCheckpointState, ) if last_checkpoint and last_checkpoint.state: return cast(ProfilingCheckpointState, last_checkpoint.state) diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py index e4a2646f6ccd3c..0f0dd95d68a145 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py @@ -76,7 +76,7 @@ def _init_job_id(self) -> JobId: return JobId( f"{platform}_skip_redundant_run{job_name_suffix}" if platform - else f"skip_redundant_run{job_name_suffix}" + else f"skip_redundant_run{job_name_suffix}", ) @abstractmethod @@ -130,12 +130,15 @@ def 
get_current_checkpoint( return cur_checkpoint def should_skip_this_run( - self, cur_start_time: datetime, cur_end_time: datetime + self, + cur_start_time: datetime, + cur_end_time: datetime, ) -> bool: skip: bool = False last_checkpoint = self.state_provider.get_last_checkpoint( - self.job_id, BaseTimeWindowCheckpointState + self.job_id, + BaseTimeWindowCheckpointState, ) if last_checkpoint: @@ -145,7 +148,7 @@ def should_skip_this_run( ) logger.debug( - f"{self.job_id} : Last run start, end times:({last_run_time_window})" + f"{self.job_id} : Last run start, end times:({last_run_time_window})", ) # If current run's time window is subset of last run's time window, then skip. @@ -166,10 +169,12 @@ def suggest_run_time_window( # as part of stateful ingestion configuration. It is likely that they may cause # more confusion than help to most users hence not added to start with. last_checkpoint = self.state_provider.get_last_checkpoint( - self.job_id, BaseTimeWindowCheckpointState + self.job_id, + BaseTimeWindowCheckpointState, ) if (last_checkpoint is None) or self.should_skip_this_run( - cur_start_time, cur_end_time + cur_start_time, + cur_end_time, ): return cur_start_time, cur_end_time @@ -185,33 +190,34 @@ def suggest_run_time_window( if allow_expand: suggested_start_time = last_run.end_time self.log( - f"Expanding time window. Updating start time to {suggested_start_time}." + f"Expanding time window. Updating start time to {suggested_start_time}.", ) else: self.log( - f"Observed gap in last run end time({last_run.end_time}) and current run start time({cur_start_time})." + f"Observed gap in last run end time({last_run.end_time}) and current run start time({cur_start_time}).", ) elif allow_reduce and cur_run.left_intersects(last_run): # scenario of scheduled ingestions with default start, end times suggested_start_time = last_run.end_time self.log( - f"Reducing time window. Updating start time to {suggested_start_time}." + f"Reducing time window. Updating start time to {suggested_start_time}.", ) elif allow_reduce and cur_run.right_intersects(last_run): # a manual backdated run suggested_end_time = last_run.start_time self.log( - f"Reducing time window. Updating end time to {suggested_end_time}." + f"Reducing time window. 
Updating end time to {suggested_end_time}.", ) # make sure to consider complete time bucket for usage if last_checkpoint.state.bucket_duration: suggested_start_time = get_time_bucket( - suggested_start_time, last_checkpoint.state.bucket_duration + suggested_start_time, + last_checkpoint.state.bucket_duration, ) self.log( - f"Adjusted start, end times: ({suggested_start_time}, {suggested_end_time})" + f"Adjusted start, end times: ({suggested_start_time}, {suggested_end_time})", ) return (suggested_start_time, suggested_end_time) @@ -236,7 +242,10 @@ def get_job_name_suffix(self): return "_usage" def update_state( - self, start_time: datetime, end_time: datetime, bucket_duration: BucketDuration + self, + start_time: datetime, + end_time: datetime, + bucket_duration: BucketDuration, ) -> None: cur_checkpoint = self.get_current_checkpoint() if cur_checkpoint: diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py index 017d78bc1abf8d..43748f2cf41bb3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/stale_entity_removal_handler.py @@ -89,7 +89,7 @@ def auto_stale_entity_removal( class StaleEntityRemovalHandler( - StatefulIngestionUsecaseHandlerBase["GenericCheckpointState"] + StatefulIngestionUsecaseHandlerBase["GenericCheckpointState"], ): """ The stateful ingestion helper class that handles stale entity removal. @@ -143,7 +143,9 @@ def workunit_processor(self): @classmethod def compute_job_id( - cls, platform: Optional[str], unique_id: Optional[str] = None + cls, + platform: Optional[str], + unique_id: Optional[str] = None, ) -> JobId: # Handle backward-compatibility for existing sources. backward_comp_platform_to_job_name: Dict[str, str] = { @@ -166,7 +168,7 @@ def compute_job_id( return JobId( f"{platform}_{job_name_suffix}{unique_suffix}" if platform - else job_name_suffix + else job_name_suffix, ) def _init_job_id(self, unique_id: Optional[str] = None) -> JobId: @@ -247,7 +249,8 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]: return logger.debug("Checking for stale entity removal.") last_checkpoint = self.state_provider.get_last_checkpoint( - self.job_id, self.state_type_class + self.job_id, + self.state_type_class, ) if not last_checkpoint: return @@ -283,7 +286,7 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]: # Check if the entity delta is below the fail-safe threshold. entity_difference_percent = cur_checkpoint_state.get_percent_entities_changed( - last_checkpoint_state + last_checkpoint_state, ) if not copy_previous_state_and_exit and ( entity_difference_percent @@ -312,7 +315,7 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]: if copy_previous_state_and_exit: logger.info( f"Copying urns from last state (size {len(last_checkpoint_state.urns)}) to current state (size {len(cur_checkpoint_state.urns)}) " - "to ensure stale entities from previous runs are deleted on the next successful run." 
+ "to ensure stale entities from previous runs are deleted on the next successful run.", ) for urn in last_checkpoint_state.urns: self.add_entity_to_state("", urn) @@ -323,7 +326,8 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]: # Everything looks good, emit the soft-deletion workunits for urn in last_checkpoint_state.get_urns_not_in( - type="*", other_checkpoint_state=cur_checkpoint_state + type="*", + other_checkpoint_state=cur_checkpoint_state, ): entity_type = guess_entity_type(urn) if ( @@ -336,7 +340,7 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]: if urn in self._urns_to_skip: report.report_last_state_non_deletable_entities(urn) logger.debug( - f"Not soft-deleting entity {urn} since it is in urns_to_skip" + f"Not soft-deleting entity {urn} since it is in urns_to_skip", ) continue yield self._create_soft_delete_workunit(urn) diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py index 4e9e1425a9ae06..db0a7e25853ab4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py @@ -81,7 +81,8 @@ def validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: if values.get("enabled"): if values.get("state_provider") is None: values["state_provider"] = DynamicTypedStateProviderConfig( - type="datahub", config={} + type="datahub", + config={}, ) return values @@ -95,7 +96,8 @@ class StatefulIngestionConfigBase(GenericModel, Generic[CustomConfig]): """ stateful_ingestion: Optional[CustomConfig] = Field( - default=None, description="Stateful Ingestion Config" + default=None, + description="Stateful Ingestion Config", ) @@ -108,7 +110,8 @@ class StatefulLineageConfigMixin(ConfigModel): ) _store_last_lineage_extraction_timestamp = pydantic_renamed_field( - "store_last_lineage_extraction_timestamp", "enable_stateful_lineage_ingestion" + "store_last_lineage_extraction_timestamp", + "enable_stateful_lineage_ingestion", ) @root_validator(skip_on_failure=True) @@ -117,7 +120,7 @@ def lineage_stateful_option_validator(cls, values: Dict) -> Dict: if not sti or not sti.enabled: if values.get("enable_stateful_lineage_ingestion"): logger.warning( - "Stateful ingestion is disabled, disabling enable_stateful_lineage_ingestion config option as well" + "Stateful ingestion is disabled, disabling enable_stateful_lineage_ingestion config option as well", ) values["enable_stateful_lineage_ingestion"] = False @@ -133,7 +136,8 @@ class StatefulProfilingConfigMixin(ConfigModel): ) _store_last_profiling_timestamps = pydantic_renamed_field( - "store_last_profiling_timestamps", "enable_stateful_profiling" + "store_last_profiling_timestamps", + "enable_stateful_profiling", ) @root_validator(skip_on_failure=True) @@ -142,7 +146,7 @@ def profiling_stateful_option_validator(cls, values: Dict) -> Dict: if not sti or not sti.enabled: if values.get("enable_stateful_profiling"): logger.warning( - "Stateful ingestion is disabled, disabling enable_stateful_profiling config option as well" + "Stateful ingestion is disabled, disabling enable_stateful_profiling config option as well", ) values["enable_stateful_profiling"] = False return values @@ -157,7 +161,8 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig): ) _store_last_usage_extraction_timestamp = pydantic_renamed_field( - "store_last_usage_extraction_timestamp", 
"enable_stateful_usage_ingestion" + "store_last_usage_extraction_timestamp", + "enable_stateful_usage_ingestion", ) @root_validator(skip_on_failure=True) @@ -166,7 +171,7 @@ def last_usage_extraction_stateful_option_validator(cls, values: Dict) -> Dict: if not sti or not sti.enabled: if values.get("enable_stateful_usage_ingestion"): logger.warning( - "Stateful ingestion is disabled, disabling enable_stateful_usage_ingestion config option as well" + "Stateful ingestion is disabled, disabling enable_stateful_usage_ingestion config option as well", ) values["enable_stateful_usage_ingestion"] = False return values @@ -239,7 +244,7 @@ def _initialize_checkpointing_state_provider(self) -> None: and self.ctx.pipeline_name ): logger.info( - "Stateful ingestion will be automatically enabled, as datahub-rest sink is used or `datahub_api` is specified" + "Stateful ingestion will be automatically enabled, as datahub-rest sink is used or `datahub_api` is specified", ) self.stateful_ingestion_config = StatefulIngestionConfig( enabled=True, @@ -253,17 +258,17 @@ def _initialize_checkpointing_state_provider(self) -> None: ): if self.ctx.pipeline_name is None: raise ConfigurationError( - "pipeline_name must be provided if stateful ingestion is enabled." + "pipeline_name must be provided if stateful ingestion is enabled.", ) checkpointing_state_provider_class = ( ingestion_checkpoint_provider_registry.get( - self.stateful_ingestion_config.state_provider.type + self.stateful_ingestion_config.state_provider.type, ) ) if checkpointing_state_provider_class is None: raise ConfigurationError( f"Cannot find checkpoint provider class of type={self.stateful_ingestion_config.state_provider.type} " - " in the registry! Please check the type of the checkpointing provider in your config." + " in the registry! Please check the type of the checkpointing provider in your config.", ) self.ingestion_checkpointing_state_provider = ( checkpointing_state_provider_class.create( @@ -274,21 +279,22 @@ def _initialize_checkpointing_state_provider(self) -> None: assert self.ingestion_checkpointing_state_provider if self.stateful_ingestion_config.ignore_old_state: logger.warning( - "The 'ignore_old_state' config is True. The old checkpoint state will not be provided." + "The 'ignore_old_state' config is True. The old checkpoint state will not be provided.", ) if self.stateful_ingestion_config.ignore_new_state: logger.warning( - "The 'ignore_new_state' config is True. The new checkpoint state will not be created." + "The 'ignore_new_state' config is True. The new checkpoint state will not be created.", ) # Add the checkpoint state provide to the platform context. self.ctx.register_checkpointer(self.ingestion_checkpointing_state_provider) logger.debug( - f"Successfully created {self.stateful_ingestion_config.state_provider.type} state provider." + f"Successfully created {self.stateful_ingestion_config.state_provider.type} state provider.", ) def register_stateful_ingestion_usecase_handler( - self, usecase_handler: StatefulIngestionUsecaseHandlerBase + self, + usecase_handler: StatefulIngestionUsecaseHandlerBase, ) -> None: """ Registers a use-case handler with the common-base class. 
@@ -327,7 +333,9 @@ def is_checkpointing_enabled(self, job_id: JobId) -> bool: return self._usecase_handlers[job_id].is_checkpointing_enabled() def _get_last_checkpoint( - self, job_id: JobId, checkpoint_state_class: Type[StateType] + self, + job_id: JobId, + checkpoint_state_class: Type[StateType], ) -> Optional[Checkpoint]: """ This is a template method implementation for querying the last checkpoint state. @@ -354,7 +362,9 @@ def _get_last_checkpoint( # Base-class implementations for common state management tasks. def get_last_checkpoint( - self, job_id: JobId, checkpoint_state_class: Type[StateType] + self, + job_id: JobId, + checkpoint_state_class: Type[StateType], ) -> Optional[Checkpoint[StateType]]: if not self.is_stateful_ingestion_configured() or ( self.stateful_ingestion_config @@ -364,7 +374,8 @@ def get_last_checkpoint( if job_id not in self.last_checkpoints: self.last_checkpoints[job_id] = self._get_last_checkpoint( - job_id, checkpoint_state_class + job_id, + checkpoint_state_class, ) return self.last_checkpoints[job_id] @@ -391,13 +402,13 @@ def _prepare_checkpoint_states_for_commit(self) -> None: and self.stateful_ingestion_config.ignore_new_state ): logger.info( - "The `ignore_new_state` config is True. Not committing current checkpoint." + "The `ignore_new_state` config is True. Not committing current checkpoint.", ) return None if self.ctx.dry_run_mode or self.ctx.preview_mode: logger.warning( f"Will not be committing checkpoints in dry_run_mode(={self.ctx.dry_run_mode})" - f" or preview_mode(={self.ctx.preview_mode})." + f" or preview_mode(={self.ctx.preview_mode}).", ) return None @@ -409,7 +420,7 @@ def _prepare_checkpoint_states_for_commit(self) -> None: job_checkpoint.prepare_for_commit() try: checkpoint_aspect = job_checkpoint.to_checkpoint_aspect( - self.stateful_ingestion_config.max_checkpoint_state_size + self.stateful_ingestion_config.max_checkpoint_state_size, ) except Exception as e: logger.error( @@ -423,7 +434,7 @@ def _prepare_checkpoint_states_for_commit(self) -> None: # Set the state to commit in the provider. assert self.ingestion_checkpointing_state_provider self.ingestion_checkpointing_state_provider.state_to_commit.update( - job_checkpoint_aspects + job_checkpoint_aspects, ) def prepare_for_commit(self) -> None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py b/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py index 1f5a651fc64a79..62c8039f340328 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py @@ -32,12 +32,14 @@ def __init__( if not self._is_server_stateful_ingestion_capable(): raise ConfigurationError( "Datahub server is not capable of supporting stateful ingestion. " - "Please consider upgrading to the latest server version to use this feature." + "Please consider upgrading to the latest server version to use this feature.", ) @classmethod def create( - cls, config_dict: Dict[str, Any], ctx: PipelineContext + cls, + config_dict: Dict[str, Any], + ctx: PipelineContext, ) -> "DatahubIngestionCheckpointingProvider": config = DatahubIngestionStateProviderConfig.parse_obj(config_dict) if config.datahub_api is not None: @@ -47,7 +49,7 @@ def create( return cls(ctx.graph) else: raise ValueError( - "A graph instance is required. 
Either pass one in the pipeline context, or set it explicitly in the stateful ingestion provider config." + "A graph instance is required. Either pass one in the pipeline context, or set it explicitly in the stateful ingestion provider config.", ) def _is_server_stateful_ingestion_capable(self) -> bool: @@ -63,11 +65,13 @@ def get_latest_checkpoint( ) -> Optional[DatahubIngestionCheckpointClass]: logger.debug( f"Querying for the latest ingestion checkpoint for pipelineName:'{pipeline_name}'," - f" job_name:'{job_name}'" + f" job_name:'{job_name}'", ) data_job_urn = self.get_data_job_urn( - self.orchestrator_name, pipeline_name, job_name + self.orchestrator_name, + pipeline_name, + job_name, ) latest_checkpoint: Optional[DatahubIngestionCheckpointClass] = ( @@ -83,13 +87,13 @@ def get_latest_checkpoint( logger.debug( f"The last committed ingestion checkpoint for pipelineName:'{pipeline_name}'," f" job_name:'{job_name}' found with start_time:" - f" {datetime.utcfromtimestamp(latest_checkpoint.timestampMillis / 1000)}" + f" {datetime.utcfromtimestamp(latest_checkpoint.timestampMillis / 1000)}", ) return latest_checkpoint else: logger.debug( f"No committed ingestion checkpoint for pipelineName:'{pipeline_name}'," - f" job_name:'{job_name}' found" + f" job_name:'{job_name}' found", ) return None @@ -103,7 +107,7 @@ def commit(self) -> None: # Emit the ingestion state for each job logger.debug( f"Committing ingestion checkpoint for pipeline:'{checkpoint.pipelineName}', " - f"job:'{job_name}'" + f"job:'{job_name}'", ) self.committed = False @@ -124,12 +128,12 @@ def commit(self) -> None: MetadataChangeProposalWrapper( entityUrn=datajob_urn, aspect=checkpoint, - ) + ), ) self.committed = True logger.debug( f"Committed ingestion checkpoint for pipeline:'{checkpoint.pipelineName}', " - f"job:'{job_name}'" + f"job:'{job_name}'", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py b/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py index 55f0903b9c91c7..a80a22afb47bab 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py @@ -30,7 +30,9 @@ def __init__(self, config: FileIngestionStateProviderConfig): @classmethod def create( - cls, config_dict: Dict[str, Any], ctx: PipelineContext + cls, + config_dict: Dict[str, Any], + ctx: PipelineContext, ) -> "FileIngestionCheckpointingProvider": config = FileIngestionStateProviderConfig.parse_obj(config_dict) return cls(config) @@ -42,11 +44,13 @@ def get_latest_checkpoint( ) -> Optional[DatahubIngestionCheckpointClass]: logger.debug( f"Querying for the latest ingestion checkpoint for pipelineName:'{pipeline_name}'," - f" job_name:'{job_name}'" + f" job_name:'{job_name}'", ) data_job_urn = self.get_data_job_urn( - self.orchestrator_name, pipeline_name, job_name + self.orchestrator_name, + pipeline_name, + job_name, ) latest_checkpoint: Optional[DatahubIngestionCheckpointClass] = None try: @@ -67,13 +71,13 @@ def get_latest_checkpoint( logger.debug( f"The last committed ingestion checkpoint for pipelineName:'{pipeline_name}'," f" job_name:'{job_name}' found with start_time:" - f" {datetime.utcfromtimestamp(latest_checkpoint.timestampMillis / 1000)}" + f" {datetime.utcfromtimestamp(latest_checkpoint.timestampMillis / 1000)}", ) return latest_checkpoint else: logger.debug( f"No 
committed ingestion checkpoint for pipelineName:'{pipeline_name}'," - f" job_name:'{job_name}' found" + f" job_name:'{job_name}' found", ) return None @@ -88,7 +92,7 @@ def commit(self) -> None: # Emit the ingestion state for each job logger.debug( f"Committing ingestion checkpoint for pipeline:'{checkpoint.pipelineName}', " - f"job:'{job_name}'" + f"job:'{job_name}'", ) datajob_urn = self.get_data_job_urn( self.orchestrator_name, @@ -99,10 +103,10 @@ def commit(self) -> None: MetadataChangeProposalWrapper( entityUrn=datajob_urn, aspect=checkpoint, - ) + ), ) write_metadata_file(pathlib.Path(self.config.filename), checkpoint_workunits) self.committed = True logger.debug( - f"Committed all ingestion checkpoints for pipeline:'{checkpoint.pipelineName}'" + f"Committed all ingestion checkpoints for pipeline:'{checkpoint.pipelineName}'", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/state_provider/state_provider_registry.py b/metadata-ingestion/src/datahub/ingestion/source/state_provider/state_provider_registry.py index 2817302e97f775..869cbe4707072d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state_provider/state_provider_registry.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state_provider/state_provider_registry.py @@ -7,5 +7,5 @@ IngestionCheckpointingProviderBase ]() ingestion_checkpoint_provider_registry.register_from_entrypoint( - "datahub.ingestion.checkpointing_provider.plugins" + "datahub.ingestion.checkpointing_provider.plugins", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/superset.py b/metadata-ingestion/src/datahub/ingestion/source/superset.py index a8b328f6e17739..f239f7f928bf8c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/superset.py +++ b/metadata-ingestion/src/datahub/ingestion/source/superset.py @@ -121,12 +121,15 @@ def modified_ts(self) -> Optional[int]: class SupersetConfig( - StatefulIngestionConfigBase, EnvConfigMixin, PlatformInstanceConfigMixin + StatefulIngestionConfigBase, + EnvConfigMixin, + PlatformInstanceConfigMixin, ): # See the Superset /security/login endpoint for details # https://superset.apache.org/docs/rest-api connect_uri: str = Field( - default="http://localhost:8088", description="Superset host URL." + default="http://localhost:8088", + description="Superset host URL.", ) display_uri: Optional[str] = Field( default=None, @@ -140,14 +143,17 @@ class SupersetConfig( password: Optional[str] = Field(default=None, description="Superset password.") # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="Superset Stateful Ingestion Config." + default=None, + description="Superset Stateful Ingestion Config.", ) ingest_dashboards: bool = Field( - default=True, description="Enable to ingest dashboards." + default=True, + description="Enable to ingest dashboards.", ) ingest_charts: bool = Field(default=True, description="Enable to ingest charts.") ingest_datasets: bool = Field( - default=False, description="Enable to ingest datasets." 
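Both checkpointing providers above ultimately wrap their state into a MetadataChangeProposalWrapper and hand it to an emitter or file sink. A short, self-contained sketch of that emit path using the public Python emitter; the GMS URL, platform, and dataset name are placeholders, and a simple DatasetProperties aspect stands in for the checkpoint aspect so the example stays minimal:

# Sketch: emitting an aspect with MetadataChangeProposalWrapper and the REST emitter.
from datahub.emitter.mce_builder import make_dataset_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.metadata.schema_classes import DatasetPropertiesClass

emitter = DatahubRestEmitter(gms_server="http://localhost:8080")  # placeholder endpoint

mcp = MetadataChangeProposalWrapper(
    entityUrn=make_dataset_urn(platform="hive", name="db.example_table", env="PROD"),
    aspect=DatasetPropertiesClass(description="Emitted by a small example script."),
)
emitter.emit(mcp)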
+ default=False, + description="Enable to ingest datasets.", ) provider: str = Field(default="db", description="Superset provider.") @@ -203,7 +209,8 @@ def get_filter_name(filter_obj): @config_class(SupersetConfig) @support_status(SupportStatus.CERTIFIED) @capability( - SourceCapability.DELETION_DETECTION, "Optionally enabled via stateful_ingestion" + SourceCapability.DELETION_DETECTION, + "Optionally enabled via stateful_ingestion", ) @capability(SourceCapability.DOMAINS, "Enabled by `domain` config to assign domain_key") @capability(SourceCapability.LINEAGE_COARSE, "Supported by default") @@ -253,12 +260,12 @@ def login(self) -> requests.Session: "Authorization": f"Bearer {self.access_token}", "Content-Type": "application/json", "Accept": "*/*", - } + }, ) # Test the connection test_response = requests_session.get( - f"{self.config.connect_uri}/api/v1/dashboard/" + f"{self.config.connect_uri}/api/v1/dashboard/", ) if test_response.status_code == 200: pass @@ -290,12 +297,13 @@ def paginate_entity_api_results(self, entity_type, page_size=100): @lru_cache(maxsize=None) def get_platform_from_database_id(self, database_id): database_response = self.session.get( - f"{self.config.connect_uri}/api/v1/database/{database_id}" + f"{self.config.connect_uri}/api/v1/database/{database_id}", ).json() sqlalchemy_uri = database_response.get("result", {}).get("sqlalchemy_uri") if sqlalchemy_uri is None: platform_name = database_response.get("result", {}).get( - "backend", "external" + "backend", + "external", ) else: platform_name = get_platform_from_sqlalchemy_uri(sqlalchemy_uri) @@ -318,7 +326,9 @@ def get_dataset_info(self, dataset_id: int) -> dict: return dataset_response.json() def get_datasource_urn_from_id( - self, dataset_response: dict, platform_instance: str + self, + dataset_response: dict, + platform_instance: str, ) -> str: schema_name = dataset_response.get("result", {}).get("schema") table_name = dataset_response.get("result", {}).get("table_name") @@ -351,7 +361,8 @@ def get_datasource_urn_from_id( raise ValueError("Could not construct dataset URN") def construct_dashboard_from_api_data( - self, dashboard_data: dict + self, + dashboard_data: dict, ) -> DashboardSnapshot: dashboard_urn = make_dashboard_urn( platform=self.platform, @@ -365,7 +376,7 @@ def construct_dashboard_from_api_data( modified_actor = f"urn:li:corpuser:{(dashboard_data.get('changed_by') or {}).get('username', 'unknown')}" modified_ts = int( - dp.parse(dashboard_data.get("changed_on_utc", "now")).timestamp() * 1000 + dp.parse(dashboard_data.get("changed_on_utc", "now")).timestamp() * 1000, ) title = dashboard_data.get("dashboard_title", "") # note: the API does not currently supply created_by usernames due to a bug @@ -388,7 +399,7 @@ def construct_dashboard_from_api_data( platform=self.platform, name=value.get("meta", {}).get("chartId", "unknown"), platform_instance=self.config.platform_instance, - ) + ), ) # Build properties @@ -399,17 +410,17 @@ def construct_dashboard_from_api_data( map( lambda owner: owner.get("username", "unknown"), dashboard_data.get("owners", []), - ) + ), ), "IsCertified": str( - True if dashboard_data.get("certified_by") else False + True if dashboard_data.get("certified_by") else False, ).lower(), } if dashboard_data.get("certified_by"): custom_properties["CertifiedBy"] = dashboard_data.get("certified_by", "") custom_properties["CertificationDetails"] = str( - dashboard_data.get("certification_details") + dashboard_data.get("certification_details"), ) # Create DashboardInfo object @@ -428,11 
+439,11 @@ def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]: for dashboard_data in self.paginate_entity_api_results("dashboard", PAGE_SIZE): try: dashboard_snapshot = self.construct_dashboard_from_api_data( - dashboard_data + dashboard_data, ) except Exception as e: self.report.warning( - f"Failed to construct dashboard snapshot. Dashboard name: {dashboard_data.get('dashboard_title')}. Error: \n{e}" + f"Failed to construct dashboard snapshot. Dashboard name: {dashboard_data.get('dashboard_title')}. Error: \n{e}", ) continue # Emit the dashboard @@ -456,7 +467,7 @@ def construct_chart_from_chart_data(self, chart_data: dict) -> ChartSnapshot: modified_actor = f"urn:li:corpuser:{(chart_data.get('changed_by') or {}).get('username', 'unknown')}" modified_ts = int( - dp.parse(chart_data.get("changed_on_utc", "now")).timestamp() * 1000 + dp.parse(chart_data.get("changed_on_utc", "now")).timestamp() * 1000, ) title = chart_data.get("slice_name", "") @@ -471,7 +482,8 @@ def construct_chart_from_chart_data(self, chart_data: dict) -> ChartSnapshot: datasource_id = chart_data.get("datasource_id") dataset_response = self.get_dataset_info(datasource_id) datasource_urn = self.get_datasource_urn_from_id( - dataset_response, self.platform + dataset_response, + self.platform, ) params = json.loads(chart_data.get("params", "{}")) @@ -531,7 +543,7 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]: mce = MetadataChangeEvent(proposedSnapshot=chart_snapshot) except Exception as e: self.report.warning( - f"Failed to construct chart snapshot. Chart name: {chart_data.get('table_name')}. Error: \n{e}" + f"Failed to construct chart snapshot. Chart name: {chart_data.get('table_name')}. Error: \n{e}", ) continue # Emit the chart @@ -584,12 +596,14 @@ def gen_dataset_urn(self, datahub_dataset_name: str) -> str: ) def construct_dataset_from_dataset_data( - self, dataset_data: dict + self, + dataset_data: dict, ) -> DatasetSnapshot: dataset_response = self.get_dataset_info(dataset_data.get("id")) dataset = SupersetDataset(**dataset_response["result"]) datasource_urn = self.get_datasource_urn_from_id( - dataset_response, self.platform + dataset_response, + self.platform, ) dataset_url = f"{self.config.display_uri}{dataset.explore_url or ''}" @@ -607,7 +621,7 @@ def construct_dataset_from_dataset_data( [ self.gen_schema_metadata(dataset_response), dataset_info, - ] + ], ) dataset_snapshot = DatasetSnapshot( @@ -620,12 +634,12 @@ def emit_dataset_mces(self) -> Iterable[MetadataWorkUnit]: for dataset_data in self.paginate_entity_api_results("dataset", PAGE_SIZE): try: dataset_snapshot = self.construct_dataset_from_dataset_data( - dataset_data + dataset_data, ) mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) except Exception as e: self.report.warning( - f"Failed to construct dataset snapshot. Dataset name: {dataset_data.get('table_name')}. Error: \n{e}" + f"Failed to construct dataset snapshot. Dataset name: {dataset_data.get('table_name')}. 
Error: \n{e}", ) continue # Emit the dataset @@ -647,7 +661,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -659,7 +675,7 @@ def _get_domain_wu(self, title: str, entity_urn: str) -> Iterable[MetadataWorkUn for domain, pattern in self.config.domain.items(): if pattern.allowed(title): domain_urn = make_domain_urn( - self.domain_registry.get_domain_urn(domain) + self.domain_registry.get_domain_urn(domain), ) break diff --git a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py index f961bd8ecba604..914e5b6fdbdb63 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py +++ b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py @@ -260,7 +260,8 @@ def remove_trailing_slash(cls, v): return config_clean.remove_trailing_slashes(v) def get_tableau_auth( - self, site: str + self, + site: str, ) -> Union[TableauAuth, PersonalAccessTokenAuth]: # https://tableau.github.io/server-client-python/docs/api-ref#authentication authentication: Union[TableauAuth, PersonalAccessTokenAuth] @@ -272,11 +273,13 @@ def get_tableau_auth( ) elif self.token_name and self.token_value: authentication = PersonalAccessTokenAuth( - self.token_name, self.token_value, site + self.token_name, + self.token_value, + site, ) else: raise ConfigurationError( - "Tableau Source: Either username/password or token_name/token_value must be set" + "Tableau Source: Either username/password or token_name/token_value must be set", ) return authentication @@ -307,7 +310,7 @@ def make_tableau_client(self, site: str) -> Server: total=self.max_retries, backoff_factor=1, status_forcelist=RETRIABLE_ERROR_CODES, - ) + ), ) server._session.mount("http://", adapter) server._session.mount("https://", adapter) @@ -323,7 +326,7 @@ def make_tableau_client(self, site: str) -> Server: raise ValueError(message) from e except Exception as e: raise ValueError( - f"Unable to login (check your Tableau connection and credentials): {str(e)}" + f"Unable to login (check your Tableau connection and credentials): {str(e)}", ) from e @@ -512,7 +515,8 @@ class TableauConfig( ) default_schema_map: Dict[str, str] = Field( - default={}, description="Default schema to use when schema is not found." + default={}, + description="Default schema to use when schema is not found.", ) ingest_tags: Optional[bool] = Field( default=False, @@ -548,7 +552,8 @@ class TableauConfig( ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="" + default=None, + description="", ) ingest_embed_url: Optional[bool] = Field( @@ -635,19 +640,19 @@ def projects_backward_compatibility(cls, values: Dict) -> Dict: project_path_pattern = values.get("project_path_pattern") if project_pattern is None and project_path_pattern is None and projects: logger.warning( - "projects is deprecated, please use project_path_pattern instead." + "projects is deprecated, please use project_path_pattern instead.", ) logger.info("Initializing project_pattern from projects") values["project_pattern"] = AllowDenyPattern( - allow=[f"^{prj}$" for prj in projects] + allow=[f"^{prj}$" for prj in projects], ) elif (project_pattern or project_path_pattern) and projects: raise ValueError( - "projects is deprecated. Please use project_path_pattern only." 
+ "projects is deprecated. Please use project_path_pattern only.", ) elif project_path_pattern and project_pattern: raise ValueError( - "project_pattern is deprecated. Please use project_path_pattern only." + "project_pattern is deprecated. Please use project_path_pattern only.", ) return values @@ -662,7 +667,7 @@ def validate_config_values(cls, values: Dict) -> Dict: and len(tags_for_hidden_assets) > 0 ): raise ValueError( - "tags_for_hidden_assets is only allowed with ingest_tags enabled. Be aware that this will overwrite tags entered from the UI." + "tags_for_hidden_assets is only allowed with ingest_tags enabled. Be aware that this will overwrite tags entered from the UI.", ) return values @@ -764,29 +769,29 @@ class TableauSourceReport( num_table_field_skipped_no_name: int = 0 # timers extract_usage_stats_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) fetch_groups_timer: Dict[str, float] = dataclass_field(default_factory=TopKDict) populate_database_server_hostname_map_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) populate_projects_registry_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) emit_workbooks_timer: Dict[str, float] = dataclass_field(default_factory=TopKDict) emit_sheets_timer: Dict[str, float] = dataclass_field(default_factory=TopKDict) emit_dashboards_timer: Dict[str, float] = dataclass_field(default_factory=TopKDict) emit_embedded_datasources_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) emit_published_datasources_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) emit_custom_sql_datasources_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) emit_upstream_tables_timer: Dict[str, float] = dataclass_field( - default_factory=TopKDict + default_factory=TopKDict, ) # lineage num_tables_with_upstream_lineage: int = 0 @@ -805,7 +810,7 @@ class TableauSourceReport( num_expected_tableau_metadata_queries: int = 0 num_actual_tableau_metadata_queries: int = 0 tableau_server_error_stats: Dict[str, int] = dataclass_field( - default_factory=(lambda: defaultdict(int)) + default_factory=(lambda: defaultdict(int)), ) # Counters for tracking the number of queries made to get_connection_objects method @@ -816,13 +821,13 @@ class TableauSourceReport( # These counters are useful to understand the impact of changing the page size. 
num_queries_by_connection_type: Dict[str, int] = dataclass_field( - default_factory=(lambda: defaultdict(int)) + default_factory=(lambda: defaultdict(int)), ) num_filter_queries_by_connection_type: Dict[str, int] = dataclass_field( - default_factory=(lambda: defaultdict(int)) + default_factory=(lambda: defaultdict(int)), ) num_paginated_queries_by_connection_type: Dict[str, int] = dataclass_field( - default_factory=(lambda: defaultdict(int)) + default_factory=(lambda: defaultdict(int)), ) @@ -915,13 +920,14 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report.basic_connectivity = CapabilityReport(capable=True) test_report.capability_report = check_user_role( - logged_in_user=UserInfo.from_server(server=server) + logged_in_user=UserInfo.from_server(server=server), ) except Exception as e: logger.warning(f"{e}", exc_info=e) test_report.basic_connectivity = CapabilityReport( - capable=False, failure_reason=str(e) + capable=False, + failure_reason=str(e), ) return test_report @@ -932,7 +938,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -947,7 +955,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: or not self.config.site_name_pattern.allowed(site.name) ): logger.info( - f"Skip site '{site.name}' as it's excluded in site_name_pattern or inactive." + f"Skip site '{site.name}' as it's excluded in site_name_pattern or inactive.", ) continue self.server.auth.switch_site(site) @@ -1186,7 +1194,7 @@ def _is_allowed_project(self, project: TableauProject) -> bool: ) and self.config.project_path_pattern.allowed(self._get_project_path(project)) if is_allowed is False: logger.info( - f"Project ({project.name}) is not allowed as per project_pattern or project_path_pattern" + f"Project ({project.name}) is not allowed as per project_pattern or project_path_pattern", ) return is_allowed @@ -1219,7 +1227,7 @@ def _is_denied_project(self, project: TableauProject) -> bool: # TableauConfig.projects_backward_compatibility ensures that at least one of these properties is configured. return self.config.project_pattern.denied( - project.name + project.name, ) or self.config.project_path_pattern.denied(self._get_project_path(project)) def _init_tableau_project_registry(self, all_project_map: dict) -> None: @@ -1238,11 +1246,11 @@ def _init_tableau_project_registry(self, all_project_map: dict) -> None: if self.config.extract_project_hierarchy is False: logger.debug( "Skipping project hierarchy processing as configuration extract_project_hierarchy is " - "disabled" + "disabled", ) else: logger.debug( - "Reevaluating projects as extract_project_hierarchy is enabled" + "Reevaluating projects as extract_project_hierarchy is enabled", ) for project in list_of_skip_projects: @@ -1251,14 +1259,14 @@ def _init_tableau_project_registry(self, all_project_map: dict) -> None: and not self._is_denied_project(project) ): logger.debug( - f"Project {project.name} is added in project registry as it's a child project and not explicitly denied in `deny` list" + f"Project {project.name} is added in project registry as it's a child project and not explicitly denied in `deny` list", ) projects_to_ingest[project.id] = project # We rely on automatic browse paths (v2) when creating containers. That's why we need to sort the projects here. 
# Otherwise, nested projects will not have the correct browse paths if not created in correct order / hierarchy. self.tableau_project_registry = OrderedDict( - sorted(projects_to_ingest.items(), key=lambda item: len(item[1].path)) + sorted(projects_to_ingest.items(), key=lambda item: len(item[1].path)), ) def _init_datasource_registry(self) -> None: @@ -1270,7 +1278,7 @@ def _init_datasource_registry(self) -> None: if ds.project_id not in self.tableau_project_registry: logger.debug( f"project id ({ds.project_id}) of datasource {ds.name} is not present in project " - f"registry" + f"registry", ) continue self.datasource_project_map[ds.id] = ds.project_id @@ -1290,7 +1298,7 @@ def _init_workbook_registry(self) -> None: if wb.project_id not in self.tableau_project_registry: logger.debug( f"project id ({wb.project_id}) of workbook {wb.name} is not present in project " - f"registry" + f"registry", ) continue self.workbook_project_map[wb.id] = wb.project_id @@ -1322,7 +1330,8 @@ def get_data_platform_instance(self) -> DataPlatformInstanceClass: platform=builder.make_data_platform_urn(self.platform), instance=( builder.make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.platform, + self.config.platform_instance, ) if self.config.platform_instance else None @@ -1359,7 +1368,7 @@ def get_connection_object_page( logger.debug( f"Query {connection_type} to get {fetch_size} objects with cursor {current_cursor}" - f" and filter {query_filter}" + f" and filter {query_filter}", ) try: assert self.server is not None @@ -1512,11 +1521,12 @@ def get_connection_object_page( # This is a pretty dumb backoff mechanism, but it's good enough for now. backoff_time = min( - (self.config.max_retries - retries_remaining + 1) ** 2, 60 + (self.config.max_retries - retries_remaining + 1) ** 2, + 60, ) logger.info( f"Query {connection_type} received a 30 second timeout error - will retry in {backoff_time} seconds. " - f"Retries remaining: {retries_remaining}" + f"Retries remaining: {retries_remaining}", ) time.sleep(backoff_time) return self.get_connection_object_page( @@ -1533,7 +1543,8 @@ def get_connection_object_page( connection_object = query_data.get(c.DATA, {}).get(connection_type, {}) has_next_page = connection_object.get(c.PAGE_INFO, {}).get( - c.HAS_NEXT_PAGE, False + c.HAS_NEXT_PAGE, + False, ) next_cursor = connection_object.get(c.PAGE_INFO, {}).get( @@ -1558,7 +1569,7 @@ def get_connection_objects( filter_pages = get_filter_pages(query_filter, page_size) self.report.num_queries_by_connection_type[connection_type] += 1 self.report.num_filter_queries_by_connection_type[connection_type] += len( - filter_pages + filter_pages, ) for filter_page in filter_pages: @@ -1664,7 +1675,7 @@ def _create_upstream_table_lineage( logger.debug( f"Embedded datasource {datasource.get(c.ID)} has upstreamDatasources.\ Setting only upstreamDatasources lineage. The upstreamTables lineage \ - will be set via upstream published datasource." 
+ will be set via upstream published datasource.", ) else: # This adds an edge to upstream DatabaseTables using `upstreamTables` @@ -1685,7 +1696,7 @@ def _create_upstream_table_lineage( table_id_to_urn.update(csql_id_to_urn) logger.debug( - f"A total of {len(upstream_tables)} upstream table edges found for datasource {datasource[c.ID]}" + f"A total of {len(upstream_tables)} upstream table edges found for datasource {datasource[c.ID]}", ) datasource_urn = builder.make_dataset_urn_with_platform_instance( @@ -1712,7 +1723,9 @@ def _create_upstream_table_lineage( and column.get(c.TABLE, {}).get(c.ID) } fine_grained_lineages = self.get_upstream_columns_of_fields_in_datasource( - datasource, datasource_urn, table_id_to_urn + datasource, + datasource_urn, + table_id_to_urn, ) upstream_tables = [ Upstream(dataset=table_urn, type=DatasetLineageType.TRANSFORMED) @@ -1724,7 +1737,8 @@ def _create_upstream_table_lineage( # Find fine grained lineage for datasource column to datasource column edge, # upstream columns may be from same datasource upstream_fields = self.get_upstream_fields_of_field_in_datasource( - datasource, datasource_urn + datasource, + datasource_urn, ) fine_grained_lineages.extend(upstream_fields) @@ -1737,7 +1751,7 @@ def _create_upstream_table_lineage( fine_grained_lineages.extend(upstream_columns) logger.debug( - f"A total of {len(fine_grained_lineages)} upstream column edges found for datasource {datasource[c.ID]}" + f"A total of {len(fine_grained_lineages)} upstream column edges found for datasource {datasource[c.ID]}", ) return upstream_tables, fine_grained_lineages @@ -1762,7 +1776,8 @@ def get_upstream_datasources(self, datasource: dict) -> List[Upstream]: return upstream_tables def get_upstream_csql_tables( - self, fields: List[dict] + self, + fields: List[dict], ) -> Tuple[List[Upstream], Dict[str, str]]: upstream_csql_urns = set() csql_id_to_urn = {} @@ -1812,19 +1827,20 @@ def get_upstream_tables( if not is_custom_sql and not num_tbl_cols: self.report.num_upstream_table_skipped_no_columns += 1 logger.warning( - f"Skipping upstream table with id {table[c.ID]}, no columns: {table}" + f"Skipping upstream table with id {table[c.ID]}, no columns: {table}", ) continue elif table[c.NAME] is None: self.report.num_upstream_table_skipped_no_name += 1 logger.warning( - f"Skipping upstream table {table[c.ID]} from lineage since its name is none: {table}" + f"Skipping upstream table {table[c.ID]} from lineage since its name is none: {table}", ) continue try: ref = TableauUpstreamReference.create( - table, default_schema_map=self.config.default_schema_map + table, + default_schema_map=self.config.default_schema_map, ) except Exception as e: self.report.num_upstream_table_failed_generate_reference += 1 @@ -1864,7 +1880,9 @@ def get_upstream_tables( ) else: self.database_tables[table_urn].update_table( - table[c.ID], num_tbl_cols, table_path + table[c.ID], + num_tbl_cols, + table_path, ) return upstream_tables, table_id_to_urn @@ -1917,7 +1935,7 @@ def get_upstream_columns_of_fields_in_datasource( builder.make_schema_field_urn( parent_urn=parent_dataset_urn, field_path=name, - ) + ), ) if input_columns: @@ -1925,11 +1943,11 @@ def get_upstream_columns_of_fields_in_datasource( FineGrainedLineage( downstreamType=FineGrainedLineageDownstreamType.FIELD, downstreams=sorted( - [builder.make_schema_field_urn(datasource_urn, field_name)] + [builder.make_schema_field_urn(datasource_urn, field_name)], ), upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=sorted(input_columns), - ) + 
), ) return fine_grained_lineages @@ -1941,7 +1959,9 @@ def is_snowflake_urn(self, urn: str) -> bool: ) def get_upstream_fields_of_field_in_datasource( - self, datasource: dict, datasource_urn: str + self, + datasource: dict, + datasource_urn: str, ) -> List[FineGrainedLineage]: fine_grained_lineages = [] for field in datasource.get(c.FIELDS) or []: @@ -1971,24 +1991,26 @@ def get_upstream_fields_of_field_in_datasource( self.config.env, ), field_path=name, - ) + ), ) if input_fields: fine_grained_lineages.append( FineGrainedLineage( downstreamType=FineGrainedLineageDownstreamType.FIELD, downstreams=[ - builder.make_schema_field_urn(datasource_urn, field_name) + builder.make_schema_field_urn(datasource_urn, field_name), ], upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=input_fields, transformOperation=self.get_transform_operation(field), - ) + ), ) return fine_grained_lineages def get_upstream_fields_from_custom_sql( - self, datasource: dict, datasource_urn: str + self, + datasource: dict, + datasource_urn: str, ) -> List[FineGrainedLineage]: parsed_result = self.parse_custom_sql( datasource=datasource, @@ -2013,8 +2035,9 @@ def get_upstream_fields_from_custom_sql( downstream = ( [ builder.make_schema_field_urn( - datasource_urn, cll_info.downstream.column - ) + datasource_urn, + cll_info.downstream.column, + ), ] if cll_info.downstream is not None and cll_info.downstream.column is not None @@ -2030,7 +2053,7 @@ def get_upstream_fields_from_custom_sql( downstreams=downstream, upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=upstreams, - ) + ), ) return fine_grained_lineages @@ -2060,7 +2083,7 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: connection_type=c.CUSTOM_SQL_TABLE_CONNECTION, query_filter=custom_sql_filter, page_size=self.config.effective_custom_sql_table_page_size, - ) + ), ) unique_custom_sql = get_unique_custom_sql(custom_sql_connection) @@ -2086,13 +2109,13 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: if len(csql[c.DATA_SOURCES]) > 0: # CustomSQLTable id owned by exactly one tableau data source logger.debug( - f"Number of datasources referencing CustomSQLTable: {len(csql[c.DATA_SOURCES])}" + f"Number of datasources referencing CustomSQLTable: {len(csql[c.DATA_SOURCES])}", ) datasource = csql[c.DATA_SOURCES][0] datasource_name = datasource.get(c.NAME) if datasource.get( - c.TYPE_NAME + c.TYPE_NAME, ) == c.EMBEDDED_DATA_SOURCE and datasource.get(c.WORKBOOK): workbook = datasource.get(c.WORKBOOK) datasource_name = ( @@ -2101,7 +2124,7 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: else None ) logger.debug( - f"Adding datasource {datasource_name}({datasource.get('id')}) to workbook container" + f"Adding datasource {datasource_name}({datasource.get('id')}) to workbook container", ) yield from add_entity_to_container( self.gen_workbook_key(workbook[c.ID]), @@ -2112,7 +2135,7 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: project_luid = self._get_datasource_project_luid(datasource) if project_luid: logger.debug( - f"Adding datasource {datasource_name}({datasource.get('id')}) to project {project_luid} container" + f"Adding datasource {datasource_name}({datasource.get('id')}) to project {project_luid} container", ) # TODO: Technically, we should have another layer of hierarchy with the datasource name here. # Same with the workbook name above. 
However, in practice most projects/workbooks have a single @@ -2124,7 +2147,7 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: ) else: logger.debug( - f"Datasource {datasource_name}({datasource.get('id')}) project_luid not found" + f"Datasource {datasource_name}({datasource.get('id')}) project_luid not found", ) project = self._get_project_browse_path_name(datasource) @@ -2143,7 +2166,9 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: if self.config.force_extraction_of_lineage_from_custom_sql_queries: logger.debug("Extracting TLL & CLL from custom sql (forced)") yield from self._create_lineage_from_unsupported_csql( - csql_urn, csql, columns + csql_urn, + csql, + columns, ) else: tables = csql.get(c.TABLES, []) @@ -2151,7 +2176,9 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: if tables: # lineage from custom sql -> datasets/tables # yield from self._create_lineage_to_upstream_tables( - csql_urn, tables, datasource + csql_urn, + tables, + datasource, ) elif ( self.config.extract_lineage_from_unsupported_custom_sql_queries @@ -2160,7 +2187,9 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: # custom sql tables may contain unsupported sql, causing incomplete lineage # we extract the lineage from the raw queries yield from self._create_lineage_from_unsupported_csql( - csql_urn, csql, columns + csql_urn, + csql, + columns, ) # Schema Metadata @@ -2171,7 +2200,7 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: # Browse path if project and datasource_name: browse_paths = BrowsePathsClass( - paths=[f"{self.dataset_browse_prefix}/{project}/{datasource_name}"] + paths=[f"{self.dataset_browse_prefix}/{project}/{datasource_name}"], ) dataset_snapshot.aspects.append(browse_paths) else: @@ -2200,7 +2229,8 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]: ) def get_schema_metadata_for_custom_sql( - self, columns: List[dict] + self, + columns: List[dict], ) -> Optional[SchemaMetadata]: fields = [] schema_metadata = None @@ -2210,7 +2240,7 @@ def get_schema_metadata_for_custom_sql( if field.get(c.NAME) is None: self.report.num_csql_field_skipped_no_name += 1 logger.warning( - f"Skipping field {field[c.ID]} from schema since its name is none" + f"Skipping field {field[c.ID]} from schema since its name is none", ) continue nativeDataType = field.get(c.REMOTE_TYPE, c.UNKNOWN) @@ -2243,7 +2273,7 @@ def _get_published_datasource_project_luid(self, ds: dict) -> Optional[str]: ): logger.debug( f"published datasource {ds.get(c.NAME)} project_luid not found." 
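The retry path in get_connection_object_page above backs off quadratically and caps the sleep at 60 seconds. A standalone sketch of that backoff shape, with a generic fetch callable standing in for the real GraphQL page query (all names here are illustrative):

# Sketch of a capped quadratic backoff for retried fetches.
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(fetch: Callable[[], T], max_retries: int = 5, cap_seconds: int = 60) -> T:
    retries_remaining = max_retries
    while True:
        try:
            return fetch()
        except TimeoutError:
            if retries_remaining <= 0:
                raise
            # Sleep 1s, 4s, 9s, ... capped at cap_seconds, matching min(attempt ** 2, 60).
            attempt = max_retries - retries_remaining + 1
            backoff_time = min(attempt**2, cap_seconds)
            time.sleep(backoff_time)
            retries_remaining -= 1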
- f" Running get datasource query for {ds[c.LUID]}" + f" Running get datasource query for {ds[c.LUID]}", ) # Query and update self.datasource_project_map with luid self._query_published_datasource_for_project_luid(ds[c.LUID]) @@ -2268,7 +2298,7 @@ def _query_published_datasource_for_project_luid(self, ds_luid: str) -> None: if ds_result.project_id not in self.tableau_project_registry: logger.debug( f"project id ({ds_result.project_id}) of datasource {ds_result.name} is not present in project " - f"registry" + f"registry", ) else: self.datasource_project_map[ds_result.id] = ds_result.project_id @@ -2292,7 +2322,7 @@ def _get_workbook_project_luid(self, wb: dict) -> Optional[str]: def _get_embedded_datasource_project_luid(self, ds: dict) -> Optional[str]: if ds.get(c.WORKBOOK): project_luid: Optional[str] = self._get_workbook_project_luid( - ds[c.WORKBOOK] + ds[c.WORKBOOK], ) if project_luid and project_luid in self.tableau_project_registry: @@ -2310,7 +2340,7 @@ def _get_datasource_project_luid(self, ds: dict) -> Optional[str]: c.EMBEDDED_DATA_SOURCE, ): logger.debug( - f"datasource {ds.get(c.NAME)} type {ds.get(c.TYPE_NAME)} is unsupported" + f"datasource {ds.get(c.NAME)} type {ds.get(c.TYPE_NAME)} is unsupported", ) return None @@ -2338,7 +2368,7 @@ def _get_project_browse_path_name(self, ds: dict) -> Optional[str]: project_luid = self._get_datasource_project_luid(ds) if project_luid is None: logger.warning( - f"Could not load project hierarchy for datasource {ds.get(c.NAME)}. Please check permissions." + f"Could not load project hierarchy for datasource {ds.get(c.NAME)}. Please check permissions.", ) logger.debug(f"datasource = {ds}") return None @@ -2346,7 +2376,10 @@ def _get_project_browse_path_name(self, ds: dict) -> Optional[str]: return self._project_luid_to_browse_path_name(project_luid=project_luid) def _create_lineage_to_upstream_tables( - self, csql_urn: str, tables: List[dict], datasource: dict + self, + csql_urn: str, + tables: List[dict], + datasource: dict, ) -> Iterable[MetadataWorkUnit]: # This adds an edge to upstream DatabaseTables using `upstreamTables` upstream_tables, _ = self.get_upstream_tables( @@ -2432,14 +2465,14 @@ def parse_custom_sql( or database_info.get(c.CONNECTION_TYPE) is None ): logger.debug( - f"database information is missing from datasource {datasource_urn}" + f"database information is missing from datasource {datasource_urn}", ) return None query = datasource.get(c.QUERY) if query is None: logger.debug( - f"raw sql query is not available for datasource {datasource_urn}" + f"raw sql query is not available for datasource {datasource_urn}", ) return None query = self._clean_tableau_query_parameters(query) @@ -2461,7 +2494,7 @@ def parse_custom_sql( ) logger.debug( - f"Overridden info upstream_db={upstream_db}, platform_instance={platform_instance}, platform={platform}" + f"Overridden info upstream_db={upstream_db}, platform_instance={platform_instance}, platform={platform}", ) parsed_result = create_lineage_sql_parsed_result( @@ -2478,19 +2511,20 @@ def parse_custom_sql( if parsed_result.debug_info.table_error: logger.warning( - f"Failed to extract table lineage from datasource {datasource_urn}: {parsed_result.debug_info.table_error}" + f"Failed to extract table lineage from datasource {datasource_urn}: {parsed_result.debug_info.table_error}", ) self.report.num_upstream_table_lineage_failed_parse_sql += 1 elif parsed_result.debug_info.column_error: logger.warning( - f"Failed to extract column level lineage from datasource {datasource_urn}: 
{parsed_result.debug_info.column_error}" + f"Failed to extract column level lineage from datasource {datasource_urn}: {parsed_result.debug_info.column_error}", ) self.report.num_upstream_fine_grained_lineage_failed_parse_sql += 1 return parsed_result def _enrich_database_tables_with_parsed_schemas( - self, parsing_result: SqlParsingResult + self, + parsing_result: SqlParsingResult, ) -> None: in_tables_schemas: Dict[str, Set[str]] = ( transform_parsing_result_to_in_tables_schemas(parsing_result) @@ -2503,15 +2537,20 @@ def _enrich_database_tables_with_parsed_schemas( for table_urn, columns in in_tables_schemas.items(): if table_urn in self.database_tables: self.database_tables[table_urn].update_table( - table_urn, parsed_columns=columns + table_urn, + parsed_columns=columns, ) else: self.database_tables[table_urn] = DatabaseTable( - urn=table_urn, parsed_columns=columns + urn=table_urn, + parsed_columns=columns, ) def _create_lineage_from_unsupported_csql( - self, csql_urn: str, csql: dict, out_columns: List[Dict[Any, Any]] + self, + csql_urn: str, + csql: dict, + out_columns: List[Dict[Any, Any]], ) -> Iterable[MetadataWorkUnit]: parsed_result = self.parse_custom_sql( datasource=csql, @@ -2535,7 +2574,9 @@ def _create_lineage_from_unsupported_csql( if self.config.extract_column_level_lineage: logger.debug("Extracting CLL from custom sql") fine_grained_lineages = make_fine_grained_lineage_class( - parsed_result, csql_urn, out_columns + parsed_result, + csql_urn, + out_columns, ) upstream_lineage = UpstreamLineage( @@ -2552,7 +2593,8 @@ def _create_lineage_from_unsupported_csql( self.report.num_upstream_fine_grained_lineage += len(fine_grained_lineages) def _get_schema_metadata_for_datasource( - self, datasource_fields: List[dict] + self, + datasource_fields: List[dict], ) -> Optional[SchemaMetadata]: fields = [] for field in datasource_fields: @@ -2561,7 +2603,7 @@ def _get_schema_metadata_for_datasource( if field.get(c.NAME) is None: self.report.num_upstream_table_skipped_no_name += 1 logger.warning( - f"Skipping field {field[c.ID]} from schema since its name is none" + f"Skipping field {field[c.ID]} from schema since its name is none", ) continue @@ -2582,7 +2624,8 @@ def _get_schema_metadata_for_datasource( ) def get_metadata_change_event( - self, snap_shot: Union["DatasetSnapshot", "DashboardSnapshot", "ChartSnapshot"] + self, + snap_shot: Union["DatasetSnapshot", "DashboardSnapshot", "ChartSnapshot"], ) -> MetadataWorkUnit: mce = MetadataChangeEvent(proposedSnapshot=snap_shot) return MetadataWorkUnit(id=snap_shot.urn, mce=mce) @@ -2615,7 +2658,10 @@ def emit_datasource( logger.debug(f"datasource {datasource.get(c.NAME)} browse-path {browse_path}") datasource_id = datasource[c.ID] datasource_urn = builder.make_dataset_urn_with_platform_instance( - self.platform, datasource_id, self.config.platform_instance, self.config.env + self.platform, + datasource_id, + self.config.platform_instance, + self.config.env, ) if datasource_id not in self.datasource_ids_being_used: self.datasource_ids_being_used.append(datasource_id) @@ -2629,7 +2675,7 @@ def emit_datasource( if datasource_info and self.config.ingest_tags: tags = self.get_tags(datasource_info) dataset_snapshot.aspects.append( - builder.make_global_tag_aspect_with_tag_list(tags) + builder.make_global_tag_aspect_with_tag_list(tags), ) # Browse path @@ -2640,7 +2686,7 @@ def emit_datasource( if browse_path: browse_paths = BrowsePathsClass( - paths=[f"{self.dataset_browse_prefix}/{browse_path}"] + 
paths=[f"{self.dataset_browse_prefix}/{browse_path}"], ) dataset_snapshot.aspects.append(browse_paths) @@ -2682,7 +2728,9 @@ def emit_datasource( upstream_tables, fine_grained_lineages, ) = self._create_upstream_table_lineage( - datasource, browse_path, is_embedded_ds=is_embedded_ds + datasource, + browse_path, + is_embedded_ds=is_embedded_ds, ) if upstream_tables: @@ -2702,12 +2750,12 @@ def emit_datasource( self.report.num_tables_with_upstream_lineage += 1 self.report.num_upstream_table_lineage += len(upstream_tables) self.report.num_upstream_fine_grained_lineage += len( - fine_grained_lineages + fine_grained_lineages, ) # Datasource Fields schema_metadata = self._get_schema_metadata_for_datasource( - datasource.get(c.FIELDS, []) + datasource.get(c.FIELDS, []), ) if schema_metadata is not None: dataset_snapshot.aspects.append(schema_metadata) @@ -2721,12 +2769,14 @@ def emit_datasource( ["Embedded Data Source"] if is_embedded_ds else ["Published Data Source"] - ) + ), ), ) container_key = self._get_datasource_container_key( - datasource, workbook, is_embedded_ds + datasource, + workbook, + is_embedded_ds, ) if container_key is not None: yield from add_entity_to_container( @@ -2739,7 +2789,10 @@ def get_custom_props_from_dict(self, obj: dict, keys: List[str]) -> Optional[dic return {key: str(obj[key]) for key in keys if obj.get(key)} or None def _get_datasource_container_key( - self, datasource: dict, workbook: Optional[dict], is_embedded_ds: bool + self, + datasource: dict, + workbook: Optional[dict], + is_embedded_ds: bool, ) -> Optional[ContainerKey]: container_key: Optional[ContainerKey] = None if is_embedded_ds: # It is embedded then parent is container is workbook @@ -2747,18 +2800,18 @@ def _get_datasource_container_key( container_key = self.gen_workbook_key(workbook[c.ID]) else: logger.warning( - f"Parent container not set for embedded datasource {datasource[c.ID]}" + f"Parent container not set for embedded datasource {datasource[c.ID]}", ) else: parent_project_luid = self._get_published_datasource_project_luid( - datasource + datasource, ) # It is published datasource and hence parent container is project if parent_project_luid is not None: container_key = self.gen_project_key(parent_project_luid) else: logger.warning( - f"Parent container not set for published datasource {datasource[c.ID]}" + f"Parent container not set for published datasource {datasource[c.ID]}", ) return container_key @@ -2792,7 +2845,7 @@ def update_datasource_for_field_upstream( # update datasource's field for its upstream for field_dict in datasource.get(c.FIELDS, []): field_upstream_dict: Optional[dict] = field_vs_upstream.get( - field_dict.get(c.ID) + field_dict.get(c.ID), ) if field_upstream_dict: # Add upstream fields to field @@ -2825,7 +2878,7 @@ def emit_upstream_tables(self) -> Iterable[MetadataWorkUnit]: tableau_database_table_id_to_urn_map[tbl.id] = urn tables_filter = { - c.ID_WITH_IN: list(tableau_database_table_id_to_urn_map.keys()) + c.ID_WITH_IN: list(tableau_database_table_id_to_urn_map.keys()), } # Emitting tables that came from Tableau metadata @@ -2842,7 +2895,7 @@ def emit_upstream_tables(self) -> Iterable[MetadataWorkUnit]: is_embedded = tableau_table.get(c.IS_EMBEDDED) or False if not is_embedded and not self.config.ingest_tables_external: logger.debug( - f"Skipping external table {database_table.urn} as ingest_tables_external is set to False" + f"Skipping external table {database_table.urn} as ingest_tables_external is set to False", ) continue @@ -2853,13 +2906,13 @@ def 
emit_upstream_tables(self) -> Iterable[MetadataWorkUnit]: # Only tables purely parsed from SQL queries don't have ID if database_table.id: logger.debug( - f"Skipping external table {database_table.urn} should have already been ingested from Tableau metadata" + f"Skipping external table {database_table.urn} should have already been ingested from Tableau metadata", ) continue if not self.config.ingest_tables_external: logger.debug( - f"Skipping external table {database_table.urn} as ingest_tables_external is set to False" + f"Skipping external table {database_table.urn} as ingest_tables_external is set to False", ) continue @@ -2871,7 +2924,7 @@ def emit_table( tableau_columns: Optional[List[Dict[str, Any]]], ) -> Iterable[MetadataWorkUnit]: logger.debug( - f"Emitting external table {database_table} tableau_columns {tableau_columns}" + f"Emitting external table {database_table} tableau_columns {tableau_columns}", ) dataset_urn = DatasetUrn.from_string(database_table.urn) dataset_snapshot = DatasetSnapshot( @@ -2884,14 +2937,15 @@ def emit_table( paths=[ f"{self.dataset_browse_prefix}/{path}" for path in sorted(database_table.paths, key=lambda p: (len(p), p)) - ] + ], ) dataset_snapshot.aspects.append(browse_paths) else: logger.debug(f"Browse path not set for table {database_table.urn}") schema_metadata = self.get_schema_metadata_for_table( - tableau_columns, database_table.parsed_columns + tableau_columns, + database_table.parsed_columns, ) if schema_metadata is not None: dataset_snapshot.aspects.append(schema_metadata) @@ -2899,7 +2953,7 @@ def emit_table( if not dataset_snapshot.aspects: # This should only happen with ingest_tables_external enabled. logger.warning( - f"Urn {database_table.urn} has no real aspects, adding a key aspect to ensure materialization" + f"Urn {database_table.urn} has no real aspects, adding a key aspect to ensure materialization", ) dataset_snapshot.aspects.append(dataset_urn.to_key_aspect()) @@ -2919,7 +2973,7 @@ def get_schema_metadata_for_table( if field.get(c.NAME) is None: self.report.num_table_field_skipped_no_name += 1 logger.warning( - f"Skipping field {field[c.ID]} from schema since its name is none" + f"Skipping field {field[c.ID]} from schema since its name is none", ) continue nativeDataType = field.get(c.REMOTE_TYPE, c.UNKNOWN) @@ -2986,7 +3040,9 @@ def _create_datahub_chart_usage_stat( ) def _get_chart_stat_wu( - self, sheet: dict, sheet_urn: str + self, + sheet: dict, + sheet_urn: str, ) -> Optional[MetadataWorkUnit]: luid: Optional[str] = sheet.get(c.LUID) if luid is None: @@ -3006,7 +3062,7 @@ def _get_chart_stat_wu( return None aspect: ChartUsageStatisticsClass = self._create_datahub_chart_usage_stat( - usage_stat + usage_stat, ) logger.debug( "stat: Chart usage stat work unit is created for %s(id:%s)", @@ -3032,14 +3088,18 @@ def emit_sheets(self) -> Iterable[MetadataWorkUnit]: else: self.report.num_hidden_assets_skipped += 1 logger.debug( - f"Skip view {sheet.get(c.ID)} because it's hidden (luid is blank)." 
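The datasource and sheet hunks above assemble Upstream and FineGrainedLineage aspects for table- and column-level lineage. For reference, a compact sketch of that aspect shape using the public schema classes; the platforms, dataset names, and field paths are placeholders, not values taken from Tableau:

# Sketch: an UpstreamLineage aspect with one fine-grained (column-level) edge.
import datahub.emitter.mce_builder as builder
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import (
    DatasetLineageTypeClass,
    FineGrainedLineageClass,
    FineGrainedLineageDownstreamTypeClass,
    FineGrainedLineageUpstreamTypeClass,
    UpstreamClass,
    UpstreamLineageClass,
)

upstream_urn = builder.make_dataset_urn("snowflake", "analytics.public.orders", "PROD")
downstream_urn = builder.make_dataset_urn("tableau", "example-datasource-id", "PROD")

lineage_mcp = MetadataChangeProposalWrapper(
    entityUrn=downstream_urn,
    aspect=UpstreamLineageClass(
        upstreams=[UpstreamClass(dataset=upstream_urn, type=DatasetLineageTypeClass.TRANSFORMED)],
        fineGrainedLineages=[
            FineGrainedLineageClass(
                upstreamType=FineGrainedLineageUpstreamTypeClass.FIELD_SET,
                upstreams=[builder.make_schema_field_urn(upstream_urn, "order_total")],
                downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD,
                downstreams=[builder.make_schema_field_urn(downstream_urn, "total")],
            ),
        ],
    ),
)
# Emit with DatahubRestEmitter(...).emit(lineage_mcp) or route it to a file sink.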
+ f"Skip view {sheet.get(c.ID)} because it's hidden (luid is blank).", ) def emit_sheets_as_charts( - self, sheet: dict, workbook: Optional[Dict] + self, + sheet: dict, + workbook: Optional[Dict], ) -> Iterable[MetadataWorkUnit]: sheet_urn: str = builder.make_chart_urn( - self.platform, sheet[c.ID], self.config.platform_instance + self.platform, + sheet[c.ID], + self.config.platform_instance, ) chart_snapshot = ChartSnapshot( urn=sheet_urn, @@ -3083,7 +3143,10 @@ def emit_sheets_as_charts( for ds_id in data_sources: ds_urn = builder.make_dataset_urn_with_platform_instance( - self.platform, ds_id, self.config.platform_instance, self.config.env + self.platform, + ds_id, + self.config.platform_instance, + self.config.env, ) datasource_urn.append(ds_urn) if ds_id not in self.datasource_ids_being_used: @@ -3115,7 +3178,7 @@ def emit_sheets_as_charts( chart_snapshot.aspects.append(browse_paths) else: logger.warning( - f"Could not set browse path for workbook {sheet[c.ID]}. Please check permissions." + f"Could not set browse path for workbook {sheet[c.ID]}. Please check permissions.", ) # Ownership @@ -3127,12 +3190,12 @@ def emit_sheets_as_charts( if self.config.ingest_tags: tags = self.get_tags(sheet) if len(self.config.tags_for_hidden_assets) > 0 and self._is_hidden_view( - sheet + sheet, ): tags.extend(self.config.tags_for_hidden_assets) chart_snapshot.aspects.append( - builder.make_global_tag_aspect_with_tag_list(tags) + builder.make_global_tag_aspect_with_tag_list(tags), ) yield self.get_metadata_change_event(chart_snapshot) @@ -3141,18 +3204,20 @@ def emit_sheets_as_charts( self.new_embed_aspect_mcp( entity_urn=sheet_urn, embed_url=sheet_external_url, - ) + ), ) if workbook is not None: yield from add_entity_to_container( - self.gen_workbook_key(workbook[c.ID]), c.CHART, chart_snapshot.urn + self.gen_workbook_key(workbook[c.ID]), + c.CHART, + chart_snapshot.urn, ) if input_fields: yield MetadataChangeProposalWrapper( entityUrn=sheet_urn, aspect=InputFields( - fields=sorted(input_fields, key=lambda x: x.schemaFieldUrn) + fields=sorted(input_fields, key=lambda x: x.schemaFieldUrn), ), ).as_workunit() @@ -3168,7 +3233,9 @@ def _get_project_path(self, project: TableauProject) -> str: return self.config.project_path_separator.join(project.path) def populate_sheet_upstream_fields( - self, sheet: dict, input_fields: List[InputField] + self, + sheet: dict, + input_fields: List[InputField], ) -> None: for field in sheet.get(c.DATA_SOURCE_FIELDS): # type: ignore if not field: @@ -3190,9 +3257,10 @@ def populate_sheet_upstream_fields( field_path=name, ), schemaField=tableau_field_to_schema_field( - field, self.config.ingest_tags + field, + self.config.ingest_tags, ), - ) + ), ) def emit_workbook_as_container(self, workbook: Dict) -> Iterable[MetadataWorkUnit]: @@ -3226,7 +3294,7 @@ def emit_workbook_as_container(self, workbook: Dict) -> Iterable[MetadataWorkUni workbook_id: Optional[str] = workbook.get(c.ID) workbook_name: Optional[str] = workbook.get(c.NAME) logger.warning( - f"Could not load project hierarchy for workbook {workbook_name}({workbook_id}). Please check permissions." + f"Could not load project hierarchy for workbook {workbook_name}({workbook_id}). 
Please check permissions.", ) custom_props = None @@ -3238,7 +3306,7 @@ def emit_workbook_as_container(self, workbook: Dict) -> Iterable[MetadataWorkUni workbook_instance = self.server.workbooks.get_by_id(workbook.get(c.LUID)) self.server.workbooks.populate_permissions(workbook_instance) custom_props = self._create_workbook_properties( - workbook_instance.permissions + workbook_instance.permissions, ) yield from gen_containers( @@ -3287,7 +3355,9 @@ def _create_datahub_dashboard_usage_stat( ) def _get_dashboard_stat_wu( - self, dashboard: dict, dashboard_urn: str + self, + dashboard: dict, + dashboard_urn: str, ) -> Optional[MetadataWorkUnit]: luid: Optional[str] = dashboard.get(c.LUID) if luid is None: @@ -3322,7 +3392,8 @@ def _get_dashboard_stat_wu( @staticmethod def new_embed_aspect_mcp( - entity_urn: str, embed_url: str + entity_urn: str, + embed_url: str, ) -> MetadataChangeProposalWrapper: return MetadataChangeProposalWrapper( entityUrn=entity_urn, @@ -3353,7 +3424,7 @@ def emit_dashboards(self) -> Iterable[MetadataWorkUnit]: else: self.report.num_hidden_assets_skipped += 1 logger.debug( - f"Skip dashboard {dashboard.get(c.ID)} because it's hidden (luid is blank)." + f"Skip dashboard {dashboard.get(c.ID)} because it's hidden (luid is blank).", ) def get_tags(self, obj: dict) -> List[str]: @@ -3367,10 +3438,14 @@ def get_tags(self, obj: dict) -> List[str]: return [] def emit_dashboard( - self, dashboard: dict, workbook: Optional[Dict] + self, + dashboard: dict, + workbook: Optional[Dict], ) -> Iterable[MetadataWorkUnit]: dashboard_urn: str = builder.make_dashboard_urn( - self.platform, dashboard[c.ID], self.config.platform_instance + self.platform, + dashboard[c.ID], + self.config.platform_instance, ) dashboard_snapshot = DashboardSnapshot( urn=dashboard_urn, @@ -3418,12 +3493,12 @@ def emit_dashboard( if self.config.ingest_tags: tags = self.get_tags(dashboard) if len(self.config.tags_for_hidden_assets) > 0 and self._is_hidden_view( - dashboard + dashboard, ): tags.extend(self.config.tags_for_hidden_assets) dashboard_snapshot.aspects.append( - builder.make_global_tag_aspect_with_tag_list(tags) + builder.make_global_tag_aspect_with_tag_list(tags), ) if self.config.extract_usage_stats: @@ -3437,7 +3512,7 @@ def emit_dashboard( dashboard_snapshot.aspects.append(browse_paths) else: logger.warning( - f"Could not set browse path for dashboard {dashboard[c.ID]}. Please check permissions." + f"Could not set browse path for dashboard {dashboard[c.ID]}. 
Please check permissions.", ) # Ownership @@ -3452,7 +3527,7 @@ def emit_dashboard( self.new_embed_aspect_mcp( entity_urn=dashboard_urn, embed_url=dashboard_external_url, - ) + ), ) if workbook is not None: @@ -3463,7 +3538,8 @@ def emit_dashboard( ) def get_browse_paths_aspect( - self, workbook: Optional[Dict] + self, + workbook: Optional[Dict], ) -> Optional[BrowsePathsClass]: browse_paths: Optional[BrowsePathsClass] = None if workbook and workbook.get(c.NAME): @@ -3472,8 +3548,8 @@ def get_browse_paths_aspect( browse_paths = BrowsePathsClass( paths=[ f"{self.no_env_browse_prefix}/{self._project_luid_to_browse_path_name(project_luid)}" - f"/{workbook[c.NAME].replace('/', REPLACE_SLASH_CHAR)}" - ] + f"/{workbook[c.NAME].replace('/', REPLACE_SLASH_CHAR)}", + ], ) elif workbook.get(c.PROJECT_NAME): @@ -3481,8 +3557,8 @@ def get_browse_paths_aspect( browse_paths = BrowsePathsClass( paths=[ f"{self.no_env_browse_prefix}/{workbook[c.PROJECT_NAME].replace('/', REPLACE_SLASH_CHAR)}" - f"/{workbook[c.NAME].replace('/', REPLACE_SLASH_CHAR)}" - ] + f"/{workbook[c.NAME].replace('/', REPLACE_SLASH_CHAR)}", + ], ) return browse_paths @@ -3509,7 +3585,10 @@ def emit_embedded_datasources(self) -> Iterable[MetadataWorkUnit]: @lru_cache(maxsize=None) def get_last_modified( - self, creator: Optional[str], created_at: bytes, updated_at: bytes + self, + creator: Optional[str], + created_at: bytes, + updated_at: bytes, ) -> ChangeAuditStamps: last_modified = ChangeAuditStamps() if creator: @@ -3531,8 +3610,8 @@ def _get_ownership(self, user: str) -> Optional[OwnershipClass]: OwnerClass( owner=owner_urn, type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return ownership @@ -3606,7 +3685,7 @@ def emit_project_in_topological_order( for project in self.tableau_project_registry.values(): logger.debug( - f"project {project.name} and it's parent {project.parent_name} and parent id {project.parent_id}" + f"project {project.name} and it's parent {project.parent_name} and parent id {project.parent_id}", ) yield from emit_project_in_topological_order(project) @@ -3635,7 +3714,8 @@ def _get_allowed_capabilities(self, capabilities: Dict[str, str]) -> List[str]: return allowed_capabilities def _create_workbook_properties( - self, permissions: List[PermissionsRule] + self, + permissions: List[PermissionsRule], ) -> Optional[Dict[str, str]]: if not self.config.permission_ingestion: return None @@ -3648,10 +3728,10 @@ def _create_workbook_properties( logger.debug(f"Group {rule.grantee.id} not found in group map.") continue if not self.config.permission_ingestion.group_name_pattern.allowed( - group.name + group.name, ): logger.info( - f"Skip permission '{group.name}' as it's excluded in group_name_pattern." 
+ f"Skip permission '{group.name}' as it's excluded in group_name_pattern.", ) continue @@ -3662,7 +3742,7 @@ def _create_workbook_properties( def ingest_tableau_site(self): with self.report.new_stage( - f"Ingesting Tableau Site: {self.site_id} {self.site_content_url}" + f"Ingesting Tableau Site: {self.site_id} {self.site_content_url}", ): # Initialise the dictionary to later look-up for chart and dashboard stat if self.config.extract_usage_stats: diff --git a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_common.py b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_common.py index 5d5103330fe302..ab94fec93ec773 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_common.py @@ -536,7 +536,8 @@ def tableau_field_to_schema_field(field, ingest_tags): fieldPath=field["name"], type=SchemaFieldDataType(type=TypeClass()), description=make_description_from_params( - field.get("description", ""), field.get("formula") + field.get("description", ""), + field.get("formula"), ), nativeDataType=nativeDataType, globalTags=( @@ -545,7 +546,7 @@ def tableau_field_to_schema_field(field, ingest_tags): field.get("role", ""), field.get("__typename", ""), field.get("aggregation", ""), - ] + ], ) if ingest_tags else None @@ -620,12 +621,12 @@ def get_fully_qualified_table_name( if platform in ("athena", "hive", "mysql", "clickhouse"): # it two tier database system (athena, hive, mysql), just take final 2 fully_qualified_table_name = ".".join( - fully_qualified_table_name.split(".")[-2:] + fully_qualified_table_name.split(".")[-2:], ) else: # if there are more than 3 tokens, just take the final 3 fully_qualified_table_name = ".".join( - fully_qualified_table_name.split(".")[-3:] + fully_qualified_table_name.split(".")[-3:], ) return fully_qualified_table_name @@ -642,7 +643,9 @@ class TableauUpstreamReference: @classmethod def create( - cls, d: Dict, default_schema_map: Optional[Dict[str, str]] = None + cls, + d: Dict, + default_schema_map: Optional[Dict[str, str]] = None, ) -> "TableauUpstreamReference": if d is None: raise ValueError("TableauUpstreamReference.create: d is None") @@ -676,17 +679,17 @@ def create( if database != t_database: logger.debug( f"Upstream urn generation ({t_id}):" - f" replacing database {t_database} with {database} from full name {t_full_name}" + f" replacing database {t_database} with {database} from full name {t_full_name}", ) if schema != t_schema: logger.debug( f"Upstream urn generation ({t_id}):" - f" replacing schema {t_schema} with {schema} from full name {t_full_name}" + f" replacing schema {t_schema} with {schema} from full name {t_full_name}", ) if table != t_table: logger.debug( f"Upstream urn generation ({t_id}):" - f" replacing table {t_table} with {table} from full name {t_full_name}" + f" replacing table {t_table} with {table} from full name {t_full_name}", ) # TODO: See if we can remove this -- made for redshift @@ -698,7 +701,7 @@ def create( and schema in t_table ): logger.debug( - f"Omitting schema for upstream table {t_id}, schema included in table name" + f"Omitting schema for upstream table {t_id}, schema included in table name", ) schema = "" @@ -756,7 +759,10 @@ def make_dataset_urn( ) return builder.make_dataset_urn_with_platform_instance( - platform, table_name, platform_instance, env + platform, + table_name, + platform_instance, + env, ) @@ -833,7 +839,7 @@ def make_upstream_class( for dataset_urn in parsed_result.in_tables: 
upstream_tables.append( - UpstreamClass(type=DatasetLineageType.TRANSFORMED, dataset=dataset_urn) + UpstreamClass(type=DatasetLineageType.TRANSFORMED, dataset=dataset_urn), ) return upstream_tables @@ -870,7 +876,7 @@ def make_fine_grained_lineage_class( cll_info.downstream.column.lower(), cll_info.downstream.column, ), - ) + ), ] if cll_info.downstream is not None and cll_info.downstream.column is not None @@ -887,7 +893,7 @@ def make_fine_grained_lineage_class( downstreams=downstream, upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=upstreams, - ) + ), ) return fine_grained_lineages @@ -999,7 +1005,7 @@ def get_filter_pages(query_filter: dict, page_size: int) -> List[dict]: start : ( start + page_size if start + page_size < len(ids) else len(ids) ) - ] + ], } for start in range(0, len(ids), page_size) ] @@ -1018,6 +1024,6 @@ def optimize_query_filter(query_filter: dict) -> dict: optimized_query[c.ID_WITH_IN] = list(OrderedSet(query_filter[c.ID_WITH_IN])) if query_filter.get(c.PROJECT_NAME_WITH_IN): optimized_query[c.PROJECT_NAME_WITH_IN] = list( - OrderedSet(query_filter[c.PROJECT_NAME_WITH_IN]) + OrderedSet(query_filter[c.PROJECT_NAME_WITH_IN]), ) return optimized_query diff --git a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_validation.py b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_validation.py index 4ec0e5ef01d3c6..361b3a8156f941 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_validation.py +++ b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau_validation.py @@ -14,7 +14,7 @@ def check_user_role( capability_dict: Dict[Union[SourceCapability, str], CapabilityReport] = { c.SITE_PERMISSION: CapabilityReport( capable=True, - ) + ), } failure_reason: str = ( diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/analyze_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/unity/analyze_profiler.py index 995690be790c4f..0a3c2d008ce551 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/analyze_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/analyze_profiler.py @@ -30,7 +30,8 @@ class UnityCatalogAnalyzeProfiler: dataset_urn_builder: Callable[[TableReference], str] def get_workunits( - self, table_refs: Collection[TableReference] + self, + table_refs: Collection[TableReference], ) -> Iterable[MetadataWorkUnit]: try: tables = self._filter_tables(table_refs) @@ -46,7 +47,8 @@ def get_workunits( return def _filter_tables( - self, table_refs: Collection[TableReference] + self, + table_refs: Collection[TableReference], ) -> Collection[TableReference]: return [ ref @@ -69,12 +71,15 @@ def process_table(self, ref: TableReference) -> Optional[MetadataWorkUnit]: except Exception as e: self.report.report_warning("profiling", str(e)) logger.warning( - f"Unexpected error during profiling table {ref}: {e}", exc_info=True + f"Unexpected error during profiling table {ref}: {e}", + exc_info=True, ) return None def gen_dataset_profile_workunit( - self, ref: TableReference, table_profile: TableProfile + self, + ref: TableReference, + table_profile: TableProfile, ) -> MetadataWorkUnit: row_count = table_profile.num_rows aspect = DatasetProfileClass( @@ -99,7 +104,8 @@ def gen_dataset_profile_workunit( @staticmethod def _gen_dataset_field_profile( - num_rows: Optional[int], column_profile: ColumnProfile + num_rows: Optional[int], + column_profile: ColumnProfile, ) -> DatasetFieldProfileClass: unique_proportion: Optional[float] = None null_proportion: 
Optional[float] = None diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/config.py b/metadata-ingestion/src/datahub/ingestion/source/unity/config.py index 6c3f7a51294797..4b6a85aa3c422b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/config.py @@ -47,7 +47,8 @@ class UnityCatalogProfilerConfig(ConfigModel): # TODO: Support cluster compute as well, for ge profiling warehouse_id: Optional[str] = Field( - default=None, description="SQL Warehouse id, for running profiling queries." + default=None, + description="SQL Warehouse id, for running profiling queries.", ) pattern: AllowDenyPattern = Field( @@ -62,7 +63,8 @@ class UnityCatalogProfilerConfig(ConfigModel): class DeltaLakeDetails(ConfigModel): platform_instance_name: Optional[str] = Field( - default=None, description="Delta-lake paltform instance name" + default=None, + description="Delta-lake paltform instance name", ) env: str = Field(default="PROD", description="Delta-lake environment") @@ -72,7 +74,8 @@ class UnityCatalogAnalyzeProfilerConfig(UnityCatalogProfilerConfig): # TODO: Reduce duplicate code with DataLakeProfilerConfig, GEProfilingConfig, SQLAlchemyConfig enabled: bool = Field( - default=False, description="Whether profiling should be done." + default=False, + description="Whether profiling should be done.", ) operation_config: OperationConfig = Field( default_factory=OperationConfig, @@ -126,7 +129,7 @@ class UnityCatalogSourceConfig( ): token: str = pydantic.Field(description="Databricks personal access token") workspace_url: str = pydantic.Field( - description="Databricks workspace url. e.g. https://my-workspace.cloud.databricks.com" + description="Databricks workspace url. e.g. https://my-workspace.cloud.databricks.com", ) warehouse_id: Optional[str] = pydantic.Field( default=None, @@ -163,7 +166,7 @@ class UnityCatalogSourceConfig( ) _only_ingest_assigned_metastore_removed = pydantic_removed_field( - "only_ingest_assigned_metastore" + "only_ingest_assigned_metastore", ) _metastore_id_pattern_removed = pydantic_removed_field("metastore_id_pattern") @@ -229,7 +232,8 @@ class UnityCatalogSourceConfig( ) _rename_table_ownership = pydantic_renamed_field( - "include_table_ownership", "include_ownership" + "include_table_ownership", + "include_ownership", ) include_column_lineage: bool = pydantic.Field( @@ -255,7 +259,8 @@ class UnityCatalogSourceConfig( # TODO: Remove `type:ignore` by refactoring config profiling: Union[ - UnityCatalogGEProfilerConfig, UnityCatalogAnalyzeProfilerConfig + UnityCatalogGEProfilerConfig, + UnityCatalogAnalyzeProfilerConfig, ] = Field( # type: ignore default=UnityCatalogGEProfilerConfig(), description="Data profiling configuration", @@ -289,14 +294,15 @@ def get_sql_alchemy_url(self, database: Optional[str] = None) -> str: def is_profiling_enabled(self) -> bool: return self.profiling.enabled and is_profiling_enabled( - self.profiling.operation_config + self.profiling.operation_config, ) def is_ge_profiling(self) -> bool: return self.profiling.method == "ge" stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( - default=None, description="Unity Catalog Stateful Ingestion Config." 
+ default=None, + description="Unity Catalog Stateful Ingestion Config.", ) @pydantic.validator("start_time") @@ -309,7 +315,7 @@ def within_thirty_days(cls, v: datetime) -> datetime: def workspace_url_should_start_with_http_scheme(cls, workspace_url: str) -> str: if not workspace_url.lower().startswith(("http://", "https://")): raise ValueError( - "Workspace URL must start with http scheme. e.g. https://my-workspace.cloud.databricks.com" + "Workspace URL must start with http scheme. e.g. https://my-workspace.cloud.databricks.com", ) return workspace_url @@ -340,12 +346,12 @@ def set_warehouse_id_from_profiling(cls, values: Dict[str, Any]) -> Dict[str, An and values["warehouse_id"] != profiling.warehouse_id ): raise ValueError( - "When `warehouse_id` is set, it must match the `warehouse_id` in `profiling`." + "When `warehouse_id` is set, it must match the `warehouse_id` in `profiling`.", ) if values.get("include_hive_metastore") and not values.get("warehouse_id"): raise ValueError( - "When `include_hive_metastore` is set, `warehouse_id` must be set." + "When `include_hive_metastore` is set, `warehouse_id` must be set.", ) if values.get("warehouse_id") and profiling and not profiling.warehouse_id: @@ -358,7 +364,8 @@ def set_warehouse_id_from_profiling(cls, values: Dict[str, Any]) -> Dict[str, An @pydantic.validator("schema_pattern", always=True) def schema_pattern_should__always_deny_information_schema( - cls, v: AllowDenyPattern + cls, + v: AllowDenyPattern, ) -> AllowDenyPattern: v.deny.append(".*\\.information_schema") return v diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py b/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py index 45d74edd9c8e61..a32685b7e09df5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/connection_test.py @@ -44,7 +44,8 @@ def usage_connectivity(self) -> Optional[CapabilityReport]: return None try: query_history = self.proxy.query_history( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) _ = next(iter(query_history)) return CapabilityReport(capable=True) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/ge_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/unity/ge_profiler.py index e24ca8330777ed..3562df19983bdd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/ge_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/ge_profiler.py @@ -56,7 +56,8 @@ def get_workunits(self, tables: List[Table]) -> Iterable[MetadataWorkUnit]: # Extra default SQLAlchemy option for better connection pooling and threading. 
# https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow self.config.options.setdefault( - "max_overflow", self.profiling_config.max_workers + "max_overflow", + self.profiling_config.max_workers, ) url = self.config.get_sql_alchemy_url() @@ -65,7 +66,7 @@ def get_workunits(self, tables: List[Table]) -> Iterable[MetadataWorkUnit]: profile_requests = [] with ThreadPoolExecutor( - max_workers=self.profiling_config.max_workers + max_workers=self.profiling_config.max_workers, ) as executor: futures = [ executor.submit( @@ -78,7 +79,7 @@ def get_workunits(self, tables: List[Table]) -> Iterable[MetadataWorkUnit]: try: for i, completed in enumerate( - as_completed(futures, timeout=self.profiling_config.max_wait_secs) + as_completed(futures, timeout=self.profiling_config.max_wait_secs), ): profile_request = completed.result() if profile_request is not None: @@ -103,7 +104,9 @@ def get_dataset_name(self, table_name: str, schema_name: str, db_name: str) -> s return f"{db_name}.{schema_name}.{table_name}" def get_unity_profile_request( - self, table: UnityCatalogSQLGenericTable, conn: Connection + self, + table: UnityCatalogSQLGenericTable, + conn: Connection, ) -> Optional[TableProfilerRequest]: # TODO: Reduce code duplication with get_profile_request skip_profiling = False @@ -154,7 +157,8 @@ def get_unity_profile_request( def _get_dataset_size_in_bytes( - table: UnityCatalogSQLGenericTable, conn: Connection + table: UnityCatalogSQLGenericTable, + conn: Connection, ) -> Optional[int]: name = ".".join( conn.dialect.identifier_preparer.quote(c) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/hive_metastore_proxy.py b/metadata-ingestion/src/datahub/ingestion/source/unity/hive_metastore_proxy.py index eea10d940bd1c8..f9c8cf544ecb6b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/hive_metastore_proxy.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/hive_metastore_proxy.py @@ -68,7 +68,10 @@ class HiveMetastoreProxy(Closeable): """ def __init__( - self, sqlalchemy_url: str, options: dict, report: UnityCatalogReport + self, + sqlalchemy_url: str, + options: dict, + report: UnityCatalogReport, ) -> None: try: self.inspector = HiveMetastoreProxy.get_inspector(sqlalchemy_url, options) @@ -125,10 +128,12 @@ def get_table_names(self, schema_name: str) -> List[str]: return [row.tableName for row in rows] except Exception as e: self.report.report_warning( - "Failed to get tables for schema", f"{HIVE_METASTORE}.{schema_name}" + "Failed to get tables for schema", + f"{HIVE_METASTORE}.{schema_name}", ) logger.warning( - f"Failed to get tables {schema_name} due to {e}", exc_info=True + f"Failed to get tables {schema_name} due to {e}", + exc_info=True, ) return [] @@ -140,7 +145,8 @@ def get_view_names(self, schema_name: str) -> List[str]: except Exception as e: self.report.report_warning("Failed to get views for schema", schema_name) logger.warning( - f"Failed to get views {schema_name} due to {e}", exc_info=True + f"Failed to get views {schema_name} due to {e}", + exc_info=True, ) return [] @@ -156,7 +162,7 @@ def _get_table( comment = detailed_info.pop("Comment", None) storage_location = detailed_info.pop("Location", None) datasource_format = self._get_datasource_format( - detailed_info.pop("Provider", None) + detailed_info.pop("Provider", None), ) created_at = self._get_created_at(detailed_info.pop("Created Time", None)) @@ -184,7 +190,9 @@ def _get_table( ) def get_table_profile( - self, ref: TableReference, 
include_column_stats: bool = False + self, + ref: TableReference, + include_column_stats: bool = False, ) -> Optional[TableProfile]: columns = self._get_columns( ref.schema, @@ -220,7 +228,9 @@ def get_table_profile( ) def _get_column_profile( - self, column: str, ref: TableReference + self, + column: str, + ref: TableReference, ) -> Optional[ColumnProfile]: try: props = self._column_describe_extended(ref.schema, ref.table, column) @@ -267,7 +277,8 @@ def _get_created_at(self, created_at: Optional[str]) -> Optional[datetime]: ) def _get_datasource_format( - self, provider: Optional[str] + self, + provider: Optional[str], ) -> Optional[DataSourceFormat]: raw_format = provider if raw_format: @@ -281,7 +292,7 @@ def _get_datasource_format( def _get_view_definition(self, schema_name: str, table_name: str) -> Optional[str]: try: rows = self._execute_sql( - f"SHOW CREATE TABLE `{schema_name}`.`{table_name}`" + f"SHOW CREATE TABLE `{schema_name}`.`{table_name}`", ) for row in rows: return row[0] @@ -370,7 +381,7 @@ def _get_columns(self, schema_name: str, table_name: str) -> List[Column]: position=None, nullable=None, comment=row[2], - ) + ), ) except Exception as e: self.report.report_warning( @@ -392,14 +403,17 @@ def _describe_extended(self, schema_name: str, table_name: str) -> List[Row]: return self._execute_sql(f"DESCRIBE EXTENDED `{schema_name}`.`{table_name}`") def _column_describe_extended( - self, schema_name: str, table_name: str, column_name: str + self, + schema_name: str, + table_name: str, + column_name: str, ) -> List[Row]: """ Rows are structured as shown in examples here https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-aux-describe-table.html#examples """ return self._execute_sql( - f"DESCRIBE EXTENDED `{schema_name}`.`{table_name}` {column_name}" + f"DESCRIBE EXTENDED `{schema_name}`.`{table_name}` {column_name}", ) def _execute_sql(self, sql: str) -> List[Row]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py index fd6fa8a50f707b..e841839744689c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy.py @@ -130,10 +130,13 @@ def catalogs(self, metastore: Optional[Metastore]) -> Iterable[Catalog]: yield optional_catalog def catalog( - self, catalog_name: str, metastore: Optional[Metastore] + self, + catalog_name: str, + metastore: Optional[Metastore], ) -> Optional[Catalog]: response = self._workspace_client.catalogs.get( - catalog_name, include_browse=True + catalog_name, + include_browse=True, ) if not response: logger.info(f"Catalog {catalog_name} not found") @@ -152,7 +155,8 @@ def schemas(self, catalog: Catalog) -> Iterable[Schema]: yield from self.hive_metastore_proxy.hive_metastore_schemas(catalog) return response = self._workspace_client.schemas.list( - catalog_name=catalog.name, include_browse=True + catalog_name=catalog.name, + include_browse=True, ) if not response: logger.info(f"Schemas not found for catalog {catalog.id}") @@ -181,7 +185,8 @@ def tables(self, schema: Schema) -> Iterable[Table]: for table in response: try: optional_table = self._create_table( - schema, cast(TableInfoWithGeneration, table) + schema, + cast(TableInfoWithGeneration, table), ) if optional_table: yield optional_table @@ -234,7 +239,7 @@ def query_history( }, "statuses": [QueryStatus.FINISHED], "statement_types": [typ.value for typ in ALLOWED_STATEMENT_TYPES], - } + }, ) for query_info in 
self._query_history(filter_by=filter_by): try: @@ -266,7 +271,9 @@ def _query_history( } response: dict = self._workspace_client.api_client.do( # type: ignore - method, path, body={**body, "filter_by": filter_by.as_dict()} + method, + path, + body={**body, "filter_by": filter_by.as_dict()}, ) # we use default raw=False(default) in above request, therefore will always get dict while True: @@ -277,11 +284,15 @@ def _query_history( if not response.get("next_page_token"): # last page return response = self._workspace_client.api_client.do( # type: ignore - method, path, body={**body, "page_token": response["next_page_token"]} + method, + path, + body={**body, "page_token": response["next_page_token"]}, ) def list_lineages_by_table( - self, table_name: str, include_entity_lineage: bool + self, + table_name: str, + include_entity_lineage: bool, ) -> dict: """List table lineage by table name.""" return self._workspace_client.api_client.do( # type: ignore @@ -315,13 +326,14 @@ def table_lineage(self, table: Table, include_entity_lineage: bool) -> None: for item in response.get("upstreams") or []: if "tableInfo" in item: table_ref = TableReference.create_from_lineage( - item["tableInfo"], table.schema.catalog.metastore + item["tableInfo"], + table.schema.catalog.metastore, ) if table_ref: table.upstreams[table_ref] = {} elif "fileInfo" in item: external_ref = ExternalTableReference.create_from_lineage( - item["fileInfo"] + item["fileInfo"], ) if external_ref: table.external_upstreams.add(external_ref) @@ -334,7 +346,8 @@ def table_lineage(self, table: Table, include_entity_lineage: bool) -> None: table.downstream_notebooks.add(notebook["notebook_id"]) except Exception as e: logger.warning( - f"Error getting lineage on table {table.ref}: {e}", exc_info=True + f"Error getting lineage on table {table.ref}: {e}", + exc_info=True, ) def get_column_lineage(self, table: Table, column_name: str) -> None: @@ -345,11 +358,13 @@ def get_column_lineage(self, table: Table, column_name: str) -> None: ) for item in response.get("upstream_cols") or []: table_ref = TableReference.create_from_lineage( - item, table.schema.catalog.metastore + item, + table.schema.catalog.metastore, ) if table_ref: table.upstreams.setdefault(table_ref, {}).setdefault( - column_name, [] + column_name, + [], ).append(item["name"]) except Exception as e: logger.warning( @@ -379,7 +394,9 @@ def _create_metastore( ) def _create_catalog( - self, metastore: Optional[Metastore], obj: CatalogInfo + self, + metastore: Optional[Metastore], + obj: CatalogInfo, ) -> Optional[Catalog]: if not obj.name: self.report.num_catalogs_missing_name += 1 @@ -423,7 +440,9 @@ def _create_column(self, table_id: str, obj: ColumnInfo) -> Optional[Column]: ) def _create_table( - self, schema: Schema, obj: TableInfoWithGeneration + self, + schema: Schema, + obj: TableInfoWithGeneration, ) -> Optional[Table]: if not obj.name: self.report.num_tables_missing_name += 1 @@ -454,7 +473,9 @@ def _create_table( ) def _extract_columns( - self, columns: List[ColumnInfo], table_id: str + self, + columns: List[ColumnInfo], + table_id: str, ) -> Iterable[Column]: for column in columns: optional_column = self._create_column(table_id, column) @@ -462,7 +483,8 @@ def _extract_columns( yield optional_column def _create_service_principal( - self, obj: DatabricksServicePrincipal + self, + obj: DatabricksServicePrincipal, ) -> Optional[ServicePrincipal]: if not obj.display_name or not obj.application_id: return None diff --git 
a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py index 51546a79e05c32..762656153f1df0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_profiling.py @@ -78,7 +78,8 @@ def get_table_stats( if call_analyze: response = self._analyze_table(ref, include_columns=include_columns) success = self._check_analyze_table_statement_status( - response, max_wait_secs=max_wait_secs + response, + max_wait_secs=max_wait_secs, ) if not success: self.report.profile_table_timeouts.append(str(ref)) @@ -90,7 +91,7 @@ def get_table_stats( idx = (str(msg).find("`") + 1) or (str(msg).find("'") + 1) or len(str(msg)) base_msg = msg[:idx] self.report.profile_table_errors.setdefault(base_msg, LossyList()).append( - (str(ref), msg) + (str(ref), msg), ) logger.warning( f"Failure during profiling {ref}, {e.kwargs}: ({e.error_code}) {e}", @@ -112,18 +113,22 @@ def get_table_stats( return None def _should_retry_unsupported_column( - self, ref: TableReference, e: DatabricksError + self, + ref: TableReference, + e: DatabricksError, ) -> bool: if "[UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE]" in str(e): logger.info( - f"Attempting to profile table without columns due to unsupported column type: {ref}" + f"Attempting to profile table without columns due to unsupported column type: {ref}", ) self.report.num_profile_failed_unsupported_column_type += 1 return True return False def _analyze_table( - self, ref: TableReference, include_columns: bool + self, + ref: TableReference, + include_columns: bool, ) -> StatementResponse: statement = f"ANALYZE TABLE {ref.schema}.{ref.table} COMPUTE STATISTICS" if include_columns: @@ -138,7 +143,9 @@ def _analyze_table( return response def _check_analyze_table_statement_status( - self, execute_response: StatementResponse, max_wait_secs: int + self, + execute_response: StatementResponse, + max_wait_secs: int, ) -> bool: if not execute_response.statement_id or not execute_response.status: return False @@ -155,7 +162,7 @@ def _check_analyze_table_statement_status( backoff_sec *= 2 response = self._workspace_client.statement_execution.get_statement( - statement_id + statement_id, ) self._raise_if_error(response, "get-statement") status = response.status # type: ignore @@ -163,7 +170,9 @@ def _check_analyze_table_statement_status( return status.state == StatementState.SUCCEEDED def _get_table_profile( - self, ref: TableReference, include_columns: bool + self, + ref: TableReference, + include_columns: bool, ) -> Optional[TableProfile]: if self.hive_metastore_proxy and ref.catalog == HIVE_METASTORE: return self.hive_metastore_proxy.get_table_profile(ref, include_columns) @@ -171,7 +180,9 @@ def _get_table_profile( return self._create_table_profile(table_info, include_columns=include_columns) def _create_table_profile( - self, table_info: TableInfo, include_columns: bool + self, + table_info: TableInfo, + include_columns: bool, ) -> TableProfile: # Warning: this implementation is brittle -- dependent on properties that can change columns_names = ( @@ -195,23 +206,27 @@ def _create_table_profile( ) def _create_column_profile( - self, column: str, table_info: TableInfo + self, + column: str, + table_info: TableInfo, ) -> ColumnProfile: tblproperties = table_info.properties or {} return ColumnProfile( name=column, null_count=self._get_int( - table_info, 
f"spark.sql.statistics.colStats.{column}.nullCount" + table_info, + f"spark.sql.statistics.colStats.{column}.nullCount", ), distinct_count=self._get_int( - table_info, f"spark.sql.statistics.colStats.{column}.distinctCount" + table_info, + f"spark.sql.statistics.colStats.{column}.distinctCount", ), min=tblproperties.get(f"spark.sql.statistics.colStats.{column}.min"), max=tblproperties.get(f"spark.sql.statistics.colStats.{column}.max"), avg_len=tblproperties.get(f"spark.sql.statistics.colStats.{column}.avgLen"), max_len=tblproperties.get(f"spark.sql.statistics.colStats.{column}.maxLen"), version=tblproperties.get( - f"spark.sql.statistics.colStats.{column}.version" + f"spark.sql.statistics.colStats.{column}.version", ), ) @@ -223,7 +238,7 @@ def _get_int(self, table_info: TableInfo, field: str) -> Optional[int]: return int(value) except ValueError: logger.warning( - f"Failed to parse int for {table_info.name} - {field}: {value}" + f"Failed to parse int for {table_info.name} - {field}: {value}", ) self.report.num_profile_failed_int_casts += 1 return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py index 9c5752c518df14..908d1c59c97114 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/proxy_types.py @@ -164,7 +164,9 @@ def create(cls, table: "Table") -> "TableReference": @classmethod def create_from_lineage( - cls, d: dict, metastore: Optional[Metastore] + cls, + d: dict, + metastore: Optional[Metastore], ) -> Optional["TableReference"]: try: return cls( @@ -282,7 +284,7 @@ def __bool__(self): self.num_columns is not None, self.total_size is not None, any(self.column_profiles), - ) + ), ) @@ -305,7 +307,7 @@ def __bool__(self): self.distinct_count is not None, self.min is not None, self.max is not None, - ) + ), ) @@ -325,5 +327,5 @@ def add_upstream(cls, upstream: TableReference, notebook: "Notebook") -> "Notebo **{ # type: ignore **dataclasses.asdict(notebook), "upstreams": frozenset([*notebook.upstreams, upstream]), - } + }, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/report.py b/metadata-ingestion/src/datahub/ingestion/source/unity/report.py index f16769341853a1..a3367e7d71fd3a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/report.py @@ -40,7 +40,7 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport): num_queries_duplicate_table: int = 0 num_queries_parsed_by_spark_plan: int = 0 usage_perf_report: UnityCatalogUsagePerfReport = field( - default_factory=UnityCatalogUsagePerfReport + default_factory=UnityCatalogUsagePerfReport, ) # Distinguish from Operations emitted for created / updated timestamps @@ -49,7 +49,7 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport): profile_table_timeouts: LossyList[str] = field(default_factory=LossyList) profile_table_empty: LossyList[str] = field(default_factory=LossyList) profile_table_errors: LossyDict[str, LossyList[Tuple[str, str]]] = field( - default_factory=LossyDict + default_factory=LossyDict, ) num_profile_missing_size_in_bytes: int = 0 num_profile_failed_unsupported_column_type: int = 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py index 29562eaf3ce5b1..5c9bde190da9cf 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py @@ -141,7 +141,8 @@ @capability(SourceCapability.CONTAINERS, "Enabled by default") @capability(SourceCapability.OWNERSHIP, "Supported via the `include_ownership` config") @capability( - SourceCapability.DATA_PROFILING, "Supported via the `profiling.enabled` config" + SourceCapability.DATA_PROFILING, + "Supported via the `profiling.enabled` config", ) @capability( SourceCapability.DELETION_DETECTION, @@ -195,7 +196,8 @@ def __init__(self, ctx: PipelineContext, config: UnityCatalogSourceConfig): if self.config.domain: self.domain_registry = DomainRegistry( - cached_domains=[k for k in self.config.domain], graph=self.ctx.graph + cached_domains=[k for k in self.config.domain], + graph=self.ctx.graph, ) # Global map of service principal application id -> ServicePrincipal @@ -258,7 +260,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), StaleEntityRemovalHandler.create( - self, self.config, self.ctx + self, + self.config, + self.ctx, ).workunit_processor, ] @@ -308,7 +312,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: user_urn_builder=self.gen_user_urn, ) yield from usage_extractor.get_usage_workunits( - self.table_refs | self.view_refs + self.table_refs | self.view_refs, ) if self.config.is_profiling_enabled(): @@ -350,7 +354,8 @@ def build_service_principal_map(self) -> None: self.service_principals[sp.application_id] = sp except Exception as e: self.report.report_warning( - "service-principals", f"Unable to fetch service principals: {e}" + "service-principals", + f"Unable to fetch service principals: {e}", ) def build_groups_map(self) -> None: @@ -380,7 +385,8 @@ def _gen_notebook_workunits(self, notebook: Notebook) -> Iterable[MetadataWorkUn name=notebook.path.rsplit("/", 1)[-1], customProperties=properties, externalUrl=urljoin( - self.config.workspace_url, f"#notebook/{notebook.id}" + self.config.workspace_url, + f"#notebook/{notebook.id}", ), created=( TimeStampClass(int(notebook.created_at.timestamp() * 1000)) @@ -416,7 +422,7 @@ def _gen_notebook_lineage(self, notebook: Notebook) -> Optional[MetadataWorkUnit type=DatasetLineageTypeClass.COPY, ) for upstream_ref in notebook.upstreams - ] + ], ), ).as_workunit() @@ -433,7 +439,8 @@ def process_metastores(self) -> Iterable[MetadataWorkUnit]: self.report.metastores.processed(metastore.id) def process_catalogs( - self, metastore: Optional[Metastore] + self, + metastore: Optional[Metastore], ) -> Iterable[MetadataWorkUnit]: for catalog in self._get_catalogs(metastore): if not self.config.catalog_pattern.allowed(catalog.id): @@ -449,7 +456,8 @@ def _get_catalogs(self, metastore: Optional[Metastore]) -> Iterable[Catalog]: if self.config.catalogs: for catalog_name in self.config.catalogs: catalog = self.unity_catalog_api_proxy.catalog( - catalog_name, metastore=metastore + catalog_name, + metastore=metastore, ) if catalog: yield catalog @@ -478,7 +486,7 @@ def process_tables(self, schema: Schema) -> Iterable[MetadataWorkUnit]: self.config.is_profiling_enabled() and self.config.is_ge_profiling() and self.config.profiling.pattern.allowed( - table.ref.qualified_table_name + table.ref.qualified_table_name, ) and not table.is_view ): @@ -513,7 +521,8 @@ def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUn for notebook_id in table.downstream_notebooks: if str(notebook_id) in self.notebooks: 
self.notebooks[str(notebook_id)] = Notebook.add_upstream( - table.ref, self.notebooks[str(notebook_id)] + table.ref, + self.notebooks[str(notebook_id)], ) # Sql parsing is required only for hive metastore view lineage @@ -522,7 +531,8 @@ def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUn and table.schema.catalog.type == CustomCatalogType.HIVE_METASTORE_CATALOG ): self.sql_parser_schema_resolver.add_schema_metadata( - dataset_urn, schema_metadata + dataset_urn, + schema_metadata, ) if table.view_definition: self.view_definitions[dataset_urn] = (table.ref, table.view_definition) @@ -552,7 +562,8 @@ def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUn patch_builder = create_dataset_owners_patch_builder(dataset_urn, ownership) for patch_mcp in patch_builder.build(): yield MetadataWorkUnit( - id=f"{dataset_urn}-{patch_mcp.aspectName}", mcp_raw=patch_mcp + id=f"{dataset_urn}-{patch_mcp.aspectName}", + mcp_raw=patch_mcp, ) if table_props: @@ -561,7 +572,8 @@ def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUn patch_builder = create_dataset_props_patch_builder(dataset_urn, table_props) for patch_mcp in patch_builder.build(): yield MetadataWorkUnit( - id=f"{dataset_urn}-{patch_mcp.aspectName}", mcp_raw=patch_mcp + id=f"{dataset_urn}-{patch_mcp.aspectName}", + mcp_raw=patch_mcp, ) yield from [ @@ -582,7 +594,8 @@ def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUn def ingest_lineage(self, table: Table) -> Optional[UpstreamLineageClass]: if self.config.include_table_lineage: self.unity_catalog_api_proxy.table_lineage( - table, include_entity_lineage=self.config.include_notebooks + table, + include_entity_lineage=self.config.include_notebooks, ) if self.config.include_column_lineage and table.upstreams: @@ -590,7 +603,7 @@ def ingest_lineage(self, table: Table) -> Optional[UpstreamLineageClass]: self.report.num_column_lineage_skipped_column_count += 1 with ThreadPoolExecutor( - max_workers=self.config.lineage_max_workers + max_workers=self.config.lineage_max_workers, ) as executor: for column in table.columns[: self.config.column_lineage_column_limit]: executor.submit( @@ -602,12 +615,14 @@ def ingest_lineage(self, table: Table) -> Optional[UpstreamLineageClass]: return self._generate_lineage_aspect(self.gen_dataset_urn(table.ref), table) def _generate_lineage_aspect( - self, dataset_urn: str, table: Table + self, + dataset_urn: str, + table: Table, ) -> Optional[UpstreamLineageClass]: upstreams: List[UpstreamClass] = [] finegrained_lineages: List[FineGrainedLineage] = [] for upstream_ref, downstream_to_upstream_cols in sorted( - table.upstreams.items() + table.upstreams.items(), ): upstream_urn = self.gen_dataset_urn(upstream_ref) @@ -629,7 +644,7 @@ def _generate_lineage_aspect( UpstreamClass( dataset=upstream_urn, type=DatasetLineageTypeClass.TRANSFORMED, - ) + ), ) for notebook in table.upstream_notebooks: @@ -637,7 +652,7 @@ def _generate_lineage_aspect( UpstreamClass( dataset=self.gen_notebook_urn(notebook), type=DatasetLineageTypeClass.TRANSFORMED, - ) + ), ) if self.config.include_external_lineage: @@ -645,21 +660,22 @@ def _generate_lineage_aspect( if not external_ref.has_permission or not external_ref.path: self.report.num_external_upstreams_lacking_permissions += 1 logger.warning( - f"Lacking permissions for external file upstream on {table.ref}" + f"Lacking permissions for external file upstream on {table.ref}", ) elif external_ref.path.startswith("s3://"): upstreams.append( 
UpstreamClass( dataset=make_s3_urn_for_lineage( - external_ref.path, self.config.env + external_ref.path, + self.config.env, ), type=DatasetLineageTypeClass.COPY, - ) + ), ) else: self.report.num_external_upstreams_unsupported += 1 logger.warning( - f"Unsupported external file upstream on {table.ref}: {external_ref.path}" + f"Unsupported external file upstream on {table.ref}: {external_ref.path}", ) if upstreams: @@ -722,7 +738,8 @@ def gen_schema_containers(self, schema: Schema) -> Iterable[MetadataWorkUnit]: ) def gen_metastore_containers( - self, metastore: Metastore + self, + metastore: Metastore, ) -> Iterable[MetadataWorkUnit]: domain_urn = self._gen_domain_urn(metastore.name) @@ -808,13 +825,15 @@ def _gen_domain_urn(self, dataset_name: str) -> Optional[str]: for domain, pattern in self.config.domain.items(): if pattern.allowed(dataset_name): domain_urn = make_domain_urn( - self.domain_registry.get_domain_urn(domain) + self.domain_registry.get_domain_urn(domain), ) return domain_urn def add_table_to_dataset_container( - self, dataset_urn: str, schema: Schema + self, + dataset_urn: str, + schema: Schema, ) -> Iterable[MetadataWorkUnit]: schema_container_key = self.gen_schema_key(schema) yield from add_dataset_to_container( @@ -879,8 +898,8 @@ def _create_table_ownership_aspect(self, table: Table) -> Optional[OwnershipClas OwnerClass( owner=owner_urn, type=OwnershipTypeClass.DATAOWNER, - ) - ] + ), + ], ) return None @@ -892,7 +911,8 @@ def _create_data_platform_instance_aspect( platform=make_data_platform_urn(self.platform), instance=( make_dataplatform_instance_urn( - self.platform, self.platform_instance_name + self.platform, + self.platform_instance_name, ) if self.platform_instance_name else None @@ -902,13 +922,17 @@ def _create_data_platform_instance_aspect( def _create_table_sub_type_aspect(self, table: Table) -> SubTypesClass: return SubTypesClass( - typeNames=[DatasetSubTypes.VIEW if table.is_view else DatasetSubTypes.TABLE] + typeNames=[ + DatasetSubTypes.VIEW if table.is_view else DatasetSubTypes.TABLE, + ], ) def _create_view_property_aspect(self, table: Table) -> ViewProperties: assert table.view_definition return ViewProperties( - materialized=False, viewLanguage="SQL", viewLogic=table.view_definition + materialized=False, + viewLanguage="SQL", + viewLogic=table.view_definition, ) def _create_schema_metadata_aspect(self, table: Table) -> SchemaMetadataClass: @@ -932,23 +956,28 @@ def _create_schema_field(column: Column) -> List[SchemaFieldClass]: if _COMPLEX_TYPE.match(column.type_text.lower()): return get_schema_fields_for_hive_column( - column.name, column.type_text.lower(), description=column.comment + column.name, + column.type_text.lower(), + description=column.comment, ) else: return [ SchemaFieldClass( fieldPath=column.name, type=SchemaFieldDataTypeClass( - type=DATA_TYPE_REGISTRY.get(column.type_name, NullTypeClass)() + type=DATA_TYPE_REGISTRY.get(column.type_name, NullTypeClass)(), ), nativeDataType=column.type_text, nullable=column.nullable, description=column.comment, - ) + ), ] def _run_sql_parser( - self, view_ref: TableReference, query: str, schema_resolver: SchemaResolver + self, + view_ref: TableReference, + query: str, + schema_resolver: SchemaResolver, ) -> Optional[SqlParsingResult]: raw_lineage = sqlglot_lineage( query, @@ -961,18 +990,18 @@ def _run_sql_parser( if raw_lineage.debug_info.table_error: logger.debug( f"Failed to parse lineage for view {view_ref}: " - f"{raw_lineage.debug_info.table_error}" + f"{raw_lineage.debug_info.table_error}", ) 
self.report.num_view_definitions_failed_parsing += 1 self.report.view_definitions_parsing_failures.append( - f"Table-level sql parsing error for view {view_ref}: {raw_lineage.debug_info.table_error}" + f"Table-level sql parsing error for view {view_ref}: {raw_lineage.debug_info.table_error}", ) return None elif raw_lineage.debug_info.column_error: self.report.num_view_definitions_failed_column_parsing += 1 self.report.view_definitions_parsing_failures.append( - f"Column-level sql parsing error for view {view_ref}: {raw_lineage.debug_info.column_error}" + f"Column-level sql parsing error for view {view_ref}: {raw_lineage.debug_info.column_error}", ) else: self.report.num_view_definitions_parsed += 1 @@ -1051,7 +1080,7 @@ def gen_lineage_workunit( entityUrn=dataset_urn, aspect=UpstreamLineage( upstreams=[ - Upstream(dataset=source_dataset_urn, type=DatasetLineageType.VIEW) - ] + Upstream(dataset=source_dataset_urn, type=DatasetLineageType.VIEW), + ], ), ).as_workunit() diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/usage.py b/metadata-ingestion/src/datahub/ingestion/source/unity/usage.py index 2e9f7fc00c8784..3ae96d3453117f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/usage.py @@ -67,7 +67,8 @@ def spark_sql_parser(self): return self._spark_sql_parser def get_usage_workunits( - self, table_refs: Set[TableReference] + self, + table_refs: Set[TableReference], ) -> Iterable[MetadataWorkUnit]: try: yield from self._get_workunits_internal(table_refs) @@ -76,7 +77,8 @@ def get_usage_workunits( self.report.report_warning("usage-extraction", str(e)) def _get_workunits_internal( - self, table_refs: Set[TableReference] + self, + table_refs: Set[TableReference], ) -> Iterable[MetadataWorkUnit]: table_map = defaultdict(list) query_hashes = set() @@ -92,15 +94,18 @@ def _get_workunits_internal( with self.report.usage_perf_report.query_fingerprinting_timer: query_hashes.add( get_query_fingerprint( - query.query_text, "databricks", fast=True - ) + query.query_text, + "databricks", + fast=True, + ), ) self.report.num_unique_queries = len(query_hashes) table_info = self._parse_query(query, table_map) if table_info is not None: if self.config.include_operational_stats: yield from self._generate_operation_workunit( - query, table_info + query, + table_info, ) for source_table in table_info.source_tables: with ( @@ -133,7 +138,9 @@ def _get_workunits_internal( ) def _generate_operation_workunit( - self, query: Query, table_info: QueryTableInfo + self, + query: Query, + table_info: QueryTableInfo, ) -> Iterable[MetadataWorkUnit]: with self.report.usage_perf_report.gen_operation_timer: if ( @@ -167,14 +174,17 @@ def _generate_operation_workunit( def _get_queries(self) -> Iterable[Query]: try: yield from self.proxy.query_history( - self.config.start_time, self.config.end_time + self.config.start_time, + self.config.end_time, ) except Exception as e: logger.warning("Error getting queries", exc_info=True) self.report.report_warning("get-queries", str(e)) def _parse_query( - self, query: Query, table_map: TableMap + self, + query: Query, + table_map: TableMap, ) -> Optional[QueryTableInfo]: with self.report.usage_perf_report.sql_parsing_timer: table_info = self._parse_query_via_sqlglot(query.query_text) @@ -188,10 +198,12 @@ def _parse_query( else: return QueryTableInfo( source_tables=self._resolve_tables( - table_info.source_tables, table_map + table_info.source_tables, + table_map, ), 
target_tables=self._resolve_tables( - table_info.target_tables, table_map + table_info.target_tables, + table_map, ), ) @@ -237,7 +249,8 @@ def _parse_query_via_spark_sql_plan(self, query: str) -> Optional[StringTableInf tables = [self._parse_plan_item(item) for item in plan] self.report.num_queries_parsed_by_spark_plan += 1 return GenericTableInfo( - source_tables=[t for t in tables if t], target_tables=[] + source_tables=[t for t in tables if t], + target_tables=[], ) except Exception as e: logger.info(f"Could not parse query via spark plan, {query}: {e!r}") @@ -250,7 +263,9 @@ def _parse_plan_item(item: dict) -> Optional[str]: return None def _resolve_tables( - self, tables: List[str], table_map: TableMap + self, + tables: List[str], + table_map: TableMap, ) -> List[TableReference]: """Resolve tables to TableReferences, filtering out unrecognized or unresolvable table names.""" @@ -268,7 +283,7 @@ def _resolve_tables( output.append(refs[0]) else: logger.warning( - f"Could not resolve table ref for {table}: {len(refs)} duplicates." + f"Could not resolve table ref for {table}: {len(refs)} duplicates.", ) duplicate_table = True diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py b/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py index 2e1e315c4df956..54de66729e1622 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py @@ -152,13 +152,13 @@ def _get_clickhouse_history(self): event_dict[k] = v.strip() if not self.config.database_pattern.allowed( - event_dict.get("database") + event_dict.get("database"), ) or not ( self.config.table_pattern.allowed(event_dict.get("full_table_name")) or self.config.view_pattern.allowed(event_dict.get("full_table_name")) ): logger.debug( - f"Dropping usage event for {event_dict.get('full_table_name')}" + f"Dropping usage event for {event_dict.get('full_table_name')}", ) continue @@ -192,10 +192,10 @@ def _get_joined_access_event(self, events): joined_access_events = [] for event_dict in events: event_dict["starttime"] = self._convert_str_to_datetime( - event_dict.get("starttime") + event_dict.get("starttime"), ) event_dict["endtime"] = self._convert_str_to_datetime( - event_dict.get("endtime") + event_dict.get("endtime"), ) if not (event_dict.get("database", None) and event_dict.get("table", None)): @@ -211,7 +211,8 @@ def _get_joined_access_event(self, events): return joined_access_events def _aggregate_access_events( - self, events: List[ClickHouseJoinedAccessEvent] + self, + events: List[ClickHouseJoinedAccessEvent], ) -> Dict[datetime, Dict[ClickHouseTableRef, AggregatedDataset]]: datasets: Dict[datetime, Dict[ClickHouseTableRef, AggregatedDataset]] = ( collections.defaultdict(dict) @@ -246,7 +247,9 @@ def _make_usage_stat(self, agg: AggregatedDataset) -> MetadataWorkUnit: return agg.make_usage_workunit( self.config.bucket_duration, lambda resource: builder.make_dataset_urn( - "clickhouse", resource, self.config.env + "clickhouse", + resource, + self.config.env, ), self.config.top_n_queries, self.config.format_sql_queries, diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py b/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py index e4138696186416..59f94a6dba7ef6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py @@ -87,13 +87,13 @@ class EnvBasedSourceBaseConfig: class TrinoUsageConfig(TrinoConfig, BaseUsageConfig, EnvBasedSourceBaseConfig): email_domain: str = Field( - description="The email domain which will be appended to the users " + description="The email domain which will be appended to the users ", ) audit_catalog: str = Field( - description="The catalog name where the audit table can be found " + description="The catalog name where the audit table can be found ", ) audit_schema: str = Field( - description="The schema name where the audit table can be found" + description="The schema name where the audit table can be found", ) options: dict = Field(default={}, description="") database: str = Field(description="The name of the catalog from getting the usage") @@ -199,7 +199,7 @@ def _get_joined_access_event(self, events): for event_dict in events: if event_dict.get("create_time"): event_dict["create_time"] = self._convert_str_to_datetime( - event_dict["create_time"] + event_dict["create_time"], ) else: self.report.num_joined_access_events_skipped += 1 @@ -207,7 +207,7 @@ def _get_joined_access_event(self, events): continue event_dict["end_time"] = self._convert_str_to_datetime( - event_dict.get("end_time") + event_dict.get("end_time"), ) if not event_dict["accessed_metadata"]: @@ -216,7 +216,7 @@ def _get_joined_access_event(self, events): continue event_dict["accessed_metadata"] = json.loads( - event_dict["accessed_metadata"] + event_dict["accessed_metadata"], ) if not event_dict.get("usr"): @@ -233,7 +233,8 @@ def _get_joined_access_event(self, events): return joined_access_events def _aggregate_access_events( - self, events: List[TrinoJoinedAccessEvent] + self, + events: List[TrinoJoinedAccessEvent], ) -> Dict[datetime, Dict[TrinoTableRef, AggregatedDataset]]: datasets: Dict[datetime, Dict[TrinoTableRef, AggregatedDataset]] = ( collections.defaultdict(dict) @@ -244,10 +245,10 @@ def _aggregate_access_events( for metadata in event.accessed_metadata: # Skipping queries starting with $system@ if metadata.catalog_name and metadata.catalog_name.startswith( - "$system@" + "$system@", ): logging.debug( - f"Skipping system query for {metadata.catalog_name}..." + f"Skipping system query for {metadata.catalog_name}...", ) continue diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py index 73e7e415e2b9eb..1f87b8f35ff3d9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py @@ -90,7 +90,7 @@ def make_usage_workunit( if query_freq is not None: if top_n_queries < len(query_freq): logger.warning( - f"Top N query limit exceeded on {str(resource)}. Max number of queries {top_n_queries} < {len(query_freq)}. Truncating top queries to {top_n_queries}." + f"Top N query limit exceeded on {str(resource)}. Max number of queries {top_n_queries} < {len(query_freq)}. Truncating top queries to {top_n_queries}.", ) query_freq = query_freq[0:top_n_queries] @@ -212,14 +212,16 @@ class BaseUsageConfig(BaseTimeWindowConfig): ) top_n_queries: pydantic.PositiveInt = Field( - default=10, description="Number of top queries to save to each table." 
+ default=10, + description="Number of top queries to save to each table.", ) user_email_pattern: AllowDenyPattern = Field( default=AllowDenyPattern.allow_all(), description="regex patterns for user emails to filter in usage.", ) include_operational_stats: bool = Field( - default=True, description="Whether to display operational stats." + default=True, + description="Whether to display operational stats.", ) include_read_operational_stats: bool = Field( @@ -228,10 +230,12 @@ class BaseUsageConfig(BaseTimeWindowConfig): ) format_sql_queries: bool = Field( - default=False, description="Whether to format sql queries" + default=False, + description="Whether to format sql queries", ) include_top_n_queries: bool = Field( - default=True, description="Whether to ingest the top_n_queries." + default=True, + description="Whether to ingest the top_n_queries.", ) @pydantic.validator("top_n_queries") @@ -241,7 +245,7 @@ def ensure_top_n_queries_is_not_too_big(cls, v: int, values: dict) -> int: max_queries = int(values["queries_character_limit"] / minimum_query_size) if v > max_queries: raise ValueError( - f"top_n_queries is set to {v} but it can be maximum {max_queries}" + f"top_n_queries is set to {v} but it can be maximum {max_queries}", ) return v @@ -252,7 +256,8 @@ class UsageAggregator(Generic[ResourceType]): def __init__(self, config: BaseUsageConfig): self.config = config self.aggregation: Dict[ - datetime, Dict[ResourceType, GenericAggregatedDataset[ResourceType]] + datetime, + Dict[ResourceType, GenericAggregatedDataset[ResourceType]], ] = defaultdict(dict) def aggregate_event( @@ -306,7 +311,7 @@ def convert_usage_aggregation_class( aspect = DatasetUsageStatistics( timestampMillis=obj.bucket, eventGranularity=TimeWindowSizeClass( - unit=convert_window_to_interval(obj.duration) + unit=convert_window_to_interval(obj.duration), ), uniqueUserCount=obj.metrics.uniqueUserCount, totalSqlQueries=obj.metrics.totalSqlQueries, @@ -314,7 +319,9 @@ def convert_usage_aggregation_class( userCounts=( [ DatasetUserUsageCountsClass( - user=u.user, count=u.count, userEmail=u.userEmail + user=u.user, + count=u.count, + userEmail=u.userEmail, ) for u in obj.metrics.users if u.user is not None @@ -334,7 +341,7 @@ def convert_usage_aggregation_class( return MetadataChangeProposalWrapper(entityUrn=obj.resource, aspect=aspect) else: raise Exception( - f"Skipping unsupported usage aggregation - invalid entity type: {obj}" + f"Skipping unsupported usage aggregation - invalid entity type: {obj}", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py b/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py index f0f0ab95ca8119..c1f2c15ff962ea 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py @@ -7,14 +7,15 @@ class CSVEnricherConfig(ConfigModel): filename: str = pydantic.Field( - description="File path or URL of CSV file to ingest." + description="File path or URL of CSV file to ingest.", ) write_semantics: str = pydantic.Field( default="PATCH", description='Whether the new tags, terms and owners to be added will override the existing ones added only by this source or not. Value for this config can be "PATCH" or "OVERRIDE". 
NOTE: this will apply to all metadata for the entity, not just a single aspect.', ) delimiter: str = pydantic.Field( - default=",", description="Delimiter to use when parsing CSV" + default=",", + description="Delimiter to use when parsing CSV", ) array_delimiter: str = pydantic.Field( default="|", @@ -27,7 +28,7 @@ def validate_write_semantics(cls, write_semantics: str) -> str: raise ValueError( "write_semantics cannot be any other value than PATCH or OVERRIDE. Default value is PATCH. " "For PATCH semantics consider using the datahub-rest sink or " - "provide a datahub_api: configuration on your ingestion recipe" + "provide a datahub_api: configuration on your ingestion recipe", ) return write_semantics @@ -35,6 +36,6 @@ def validate_write_semantics(cls, write_semantics: str) -> str: def validator_diff(cls, array_delimiter: str, values: Dict[str, Any]) -> str: if array_delimiter == values["delimiter"]: raise ValueError( - "array_delimiter and delimiter are the same. Please choose different delimiters." + "array_delimiter and delimiter are the same. Please choose different delimiters.", ) return array_delimiter diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py b/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py index a670173aa47519..cace042f2f56ce 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py @@ -46,7 +46,7 @@ def validate_profile_day_of_week(cls, v: Optional[int]) -> Optional[int]: return None if profile_day_of_week < 0 or profile_day_of_week > 6: raise ValueError( - f"Invalid value {profile_day_of_week} for profile_day_of_week. Must be between 0 to 6 (both inclusive)." + f"Invalid value {profile_day_of_week} for profile_day_of_week. Must be between 0 to 6 (both inclusive).", ) return profile_day_of_week @@ -57,7 +57,7 @@ def validate_profile_date_of_month(cls, v: Optional[int]) -> Optional[int]: return None if profile_date_of_month < 1 or profile_date_of_month > 31: raise ValueError( - f"Invalid value {profile_date_of_month} for profile_date_of_month. Must be between 1 to 31 (both inclusive)." + f"Invalid value {profile_date_of_month} for profile_date_of_month. Must be between 1 to 31 (both inclusive).", ) return profile_date_of_month diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py b/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py index 6b458dac60ea31..baa8d580883a26 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py @@ -33,10 +33,13 @@ def _is_valid_hostname(hostname: str) -> bool: class PulsarSourceConfig( - StatefulIngestionConfigBase, PlatformInstanceConfigMixin, EnvConfigMixin + StatefulIngestionConfigBase, + PlatformInstanceConfigMixin, + EnvConfigMixin, ): web_service_url: str = Field( - default="http://localhost:8080", description="The web URL for the cluster." + default="http://localhost:8080", + description="The web URL for the cluster.", ) timeout: int = Field( default=5, @@ -48,10 +51,12 @@ class PulsarSourceConfig( description="The complete URL for a Custom Authorization Server. 
Mandatory for OAuth based authentication.", ) client_id: Optional[str] = Field( - default=None, description="The application's client ID" + default=None, + description="The application's client ID", ) client_secret: Optional[str] = Field( - default=None, description="The application's client secret" + default=None, + description="The application's client secret", ) # Mandatory for token authentication token: Optional[str] = Field( @@ -86,36 +91,43 @@ class PulsarSourceConfig( ) domain: Dict[str, AllowDenyPattern] = Field( - default_factory=dict, description="Domain patterns" + default_factory=dict, + description="Domain patterns", ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = Field( - default=None, description="see Stateful Ingestion" + default=None, + description="see Stateful Ingestion", ) oid_config: dict = Field( - default_factory=dict, description="Placeholder for OpenId discovery document" + default_factory=dict, + description="Placeholder for OpenId discovery document", ) @validator("token") def ensure_only_issuer_or_token( - cls, token: Optional[str], values: Dict[str, Optional[str]] + cls, + token: Optional[str], + values: Dict[str, Optional[str]], ) -> Optional[str]: if token is not None and values.get("issuer_url") is not None: raise ValueError( - "Expected only one authentication method, either issuer_url or token." + "Expected only one authentication method, either issuer_url or token.", ) return token @validator("client_secret", always=True) def ensure_client_id_and_secret_for_issuer_url( - cls, client_secret: Optional[str], values: Dict[str, Optional[str]] + cls, + client_secret: Optional[str], + values: Dict[str, Optional[str]], ) -> Optional[str]: if values.get("issuer_url") is not None and ( client_secret is None or values.get("client_id") is None ): raise ValueError( - "Missing configuration: client_id and client_secret are mandatory when issuer_url is set." 
+ "Missing configuration: client_id and client_secret are mandatory when issuer_url is set.", ) return client_secret @@ -129,7 +141,7 @@ def web_service_url_scheme_host_port(cls, val: str) -> str: if not _is_valid_hostname(url.hostname.__str__()): raise ValueError( - f"Not a valid hostname, hostname contains invalid characters, found {url.hostname}" + f"Not a valid hostname, hostname contains invalid characters, found {url.hostname}", ) return config_clean.remove_trailing_slashes(val) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py index 55989cf17f2691..0fe52ea775f3c2 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py @@ -30,14 +30,18 @@ def __init__(self, config: AddDatasetBrowsePathConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "AddDatasetBrowsePathTransformer": config = AddDatasetBrowsePathConfig.parse_obj(config_dict) return cls(config, ctx) @staticmethod def _merge_with_server_browse_paths( - graph: DataHubGraph, urn: str, mce_browse_paths: Optional[BrowsePathsClass] + graph: DataHubGraph, + urn: str, + mce_browse_paths: Optional[BrowsePathsClass], ) -> Optional[BrowsePathsClass]: if not mce_browse_paths or not mce_browse_paths.paths: # nothing to add, no need to consult server @@ -59,7 +63,10 @@ def _merge_with_server_browse_paths( return mce_browse_paths def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: platform_part, dataset_fqdn, env = ( entity_urn.replace("urn:li:dataset:(", "").replace(")", "").split(",") @@ -85,7 +92,9 @@ def transform_aspect( return cast( Optional[Aspect], AddDatasetBrowsePathTransformer._merge_with_server_browse_paths( - self.ctx.graph, entity_urn, browse_paths + self.ctx.graph, + entity_urn, + browse_paths, ), ) else: diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py index b4dc8835f9fba9..b52147ddca56f6 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py @@ -43,7 +43,10 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetDataProdu return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: return None @@ -59,7 +62,7 @@ def handle_end_of_stream( if data_product_urn: if data_product_urn not in data_products: data_products[data_product_urn] = DataProductPatchBuilder( - data_product_urn + data_product_urn, ).add_asset(entity_urn) else: data_products[data_product_urn] = data_products[ @@ -69,20 +72,21 @@ def handle_end_of_stream( if is_container: assert self.ctx.graph container_aspect = self.ctx.graph.get_aspect( - entity_urn, aspect_type=ContainerClass + entity_urn, + aspect_type=ContainerClass, ) if not container_aspect: continue container_urn = container_aspect.container if data_product_urn not in data_products_container: container_product = 
DataProductPatchBuilder( - data_product_urn + data_product_urn, ).add_asset(container_urn) data_products_container[data_product_urn] = container_product else: data_products_container[data_product_urn] = ( data_products_container[data_product_urn].add_asset( - container_urn + container_urn, ) ) @@ -107,14 +111,16 @@ class SimpleAddDatasetDataProduct(AddDatasetDataProduct): def __init__(self, config: SimpleDatasetDataProductConfig, ctx: PipelineContext): generic_config = AddDatasetDataProductConfig( get_data_product_to_add=lambda dataset_urn: config.dataset_to_data_product_urns.get( - dataset_urn + dataset_urn, ), ) super().__init__(generic_config, ctx) @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "SimpleAddDatasetDataProduct": config = SimpleDatasetDataProductConfig.parse_obj(config_dict) return cls(config, ctx) @@ -130,7 +136,7 @@ def validate_pattern_value(cls, values: Dict) -> Dict: for key, value in rules.items(): if isinstance(value, list) and len(value) > 1: raise ValueError( - "Same dataset cannot be an asset of two different data product." + "Same dataset cannot be an asset of two different data product.", ) elif isinstance(value, str): rules[key] = [rules[key]] @@ -154,7 +160,9 @@ def __init__(self, config: PatternDatasetDataProductConfig, ctx: PipelineContext @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternAddDatasetDataProduct": config = PatternDatasetDataProductConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py index b107a62c905b4a..b889f77b39cbbc 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py @@ -50,7 +50,7 @@ def __init__(self, config: AddDatasetOwnershipConfig, ctx: PipelineContext): and self.ctx.graph is None ): raise ConfigurationError( - "With PATCH TransformerSemantics, AddDatasetOwnership requires a datahub_api to connect to. Consider using the datahub-rest sink or provide a datahub_api: configuration on your ingestion recipe" + "With PATCH TransformerSemantics, AddDatasetOwnership requires a datahub_api to connect to. Consider using the datahub-rest sink or provide a datahub_api: configuration on your ingestion recipe", ) @classmethod @@ -60,7 +60,9 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetOwnership @staticmethod def _merge_with_server_ownership( - graph: DataHubGraph, urn: str, mce_ownership: Optional[OwnershipClass] + graph: DataHubGraph, + urn: str, + mce_ownership: Optional[OwnershipClass], ) -> Optional[OwnershipClass]: if not mce_ownership or not mce_ownership.owners: # If there are no owners to add, we don't need to patch anything. 
@@ -100,7 +102,7 @@ def handle_end_of_stream( container_urn = path.urn if not container_urn or not container_urn.startswith( - "urn:li:container:" + "urn:li:container:", ): continue @@ -108,7 +110,7 @@ def handle_end_of_stream( ownership_container_mapping[container_urn] = data_ownerships else: ownership_container_mapping[container_urn] = list( - ownership_container_mapping[container_urn] + data_ownerships + ownership_container_mapping[container_urn] + data_ownerships, ) mcps: List[ @@ -124,7 +126,10 @@ def handle_end_of_stream( return mcps def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_ownership_aspect: Optional[OwnershipClass] = cast(OwnershipClass, aspect) out_ownership_aspect: OwnershipClass = OwnershipClass( @@ -149,7 +154,9 @@ def transform_aspect( return cast( Optional[Aspect], self._merge_with_server_ownership( - self.ctx.graph, entity_urn, out_ownership_aspect + self.ctx.graph, + entity_urn, + out_ownership_aspect, ), ) else: @@ -170,7 +177,7 @@ class SimpleAddDatasetOwnership(AddDatasetOwnership): def __init__(self, config: SimpleDatasetOwnershipConfig, ctx: PipelineContext): ownership_type, ownership_type_urn = builder.validate_ownership_type( - config.ownership_type + config.ownership_type, ) owners = [ OwnerClass( @@ -191,7 +198,9 @@ def __init__(self, config: SimpleDatasetOwnershipConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "SimpleAddDatasetOwnership": config = SimpleDatasetOwnershipConfig.parse_obj(config_dict) return cls(config, ctx) @@ -209,7 +218,7 @@ class PatternAddDatasetOwnership(AddDatasetOwnership): def __init__(self, config: PatternDatasetOwnershipConfig, ctx: PipelineContext): owner_pattern = config.owner_pattern ownership_type, ownership_type_urn = builder.validate_ownership_type( - config.ownership_type + config.ownership_type, ) generic_config = AddDatasetOwnershipConfig( get_owners_to_add=lambda urn: [ @@ -229,7 +238,9 @@ def __init__(self, config: PatternDatasetOwnershipConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternAddDatasetOwnership": config = PatternDatasetOwnershipConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py index 4b9b4c9e6f5da6..af7ae6717cda0c 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py @@ -77,30 +77,34 @@ def _merge_with_server_properties( custom_properties_to_add.update(dataset_properties_aspect.customProperties) patch_dataset_properties: DatasetPropertiesClass = copy.deepcopy( - dataset_properties_aspect + dataset_properties_aspect, ) patch_dataset_properties.customProperties = custom_properties_to_add return patch_dataset_properties def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_dataset_properties_aspect: DatasetPropertiesClass = cast( - DatasetPropertiesClass, aspect + DatasetPropertiesClass, + aspect, ) if not 
in_dataset_properties_aspect: in_dataset_properties_aspect = DatasetPropertiesClass() out_dataset_properties_aspect: DatasetPropertiesClass = copy.deepcopy( - in_dataset_properties_aspect + in_dataset_properties_aspect, ) if self.config.replace_existing is True: # clean the existing properties out_dataset_properties_aspect.customProperties = {} properties_to_add = self.config.add_properties_resolver_class( # type: ignore - **self.resolver_args + **self.resolver_args, ).get_properties_to_add(entity_urn) out_dataset_properties_aspect.customProperties.update(properties_to_add) @@ -108,7 +112,9 @@ def transform_aspect( assert self.ctx.graph patch_dataset_properties_aspect = ( AddDatasetProperties._merge_with_server_properties( - self.ctx.graph, entity_urn, out_dataset_properties_aspect + self.ctx.graph, + entity_urn, + out_dataset_properties_aspect, ) ) return cast(Optional[Aspect], patch_dataset_properties_aspect) @@ -142,7 +148,9 @@ def __init__(self, config: SimpleAddDatasetPropertiesConfig, ctx: PipelineContex @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "SimpleAddDatasetProperties": config = SimpleAddDatasetPropertiesConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py index d2687ebc5e76f6..df86a341125b70 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py @@ -42,7 +42,9 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetSchemaTag return cls(config, ctx) def extend_field( - self, schema_field: SchemaFieldClass, server_field: Optional[SchemaFieldClass] + self, + schema_field: SchemaFieldClass, + server_field: Optional[SchemaFieldClass], ) -> SchemaFieldClass: all_tags = self.config.get_tags_to_add(schema_field.fieldPath) if len(all_tags) == 0: @@ -79,11 +81,15 @@ def extend_field( return schema_field def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[builder.Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[builder.Aspect], ) -> Optional[builder.Aspect]: schema_metadata_aspect: SchemaMetadataClass = cast(SchemaMetadataClass, aspect) assert schema_metadata_aspect is None or isinstance( - schema_metadata_aspect, SchemaMetadataClass + schema_metadata_aspect, + SchemaMetadataClass, ) server_field_map: dict = {} @@ -140,7 +146,9 @@ def __init__(self, config: PatternDatasetTagsConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternAddDatasetSchemaTags": config = PatternDatasetTagsConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py index d17a39bee6cfbf..14d25949bb0a67 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py @@ -43,7 +43,9 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetSchemaTer return cls(config, ctx) def extend_field( - self, schema_field: SchemaFieldClass, server_field: 
Optional[SchemaFieldClass] + self, + schema_field: SchemaFieldClass, + server_field: Optional[SchemaFieldClass], ) -> SchemaFieldClass: all_terms = self.config.get_terms_to_add(schema_field.fieldPath) if len(all_terms) == 0: @@ -86,7 +88,8 @@ def extend_field( schema_field.glossaryTerms.auditStamp if schema_field.glossaryTerms is not None else AuditStampClass( - time=builder.get_sys_time(), actor="urn:li:corpUser:restEmitter" + time=builder.get_sys_time(), + actor="urn:li:corpUser:restEmitter", ) ), ) @@ -96,15 +99,20 @@ def extend_field( return schema_field def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[builder.Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[builder.Aspect], ) -> Optional[builder.Aspect]: schema_metadata_aspect: SchemaMetadataClass = cast(SchemaMetadataClass, aspect) assert schema_metadata_aspect is None or isinstance( - schema_metadata_aspect, SchemaMetadataClass + schema_metadata_aspect, + SchemaMetadataClass, ) server_field_map: Dict[ - str, SchemaFieldClass + str, + SchemaFieldClass, ] = {} # Map to cache server field objects, where fieldPath is key if self.config.semantics == TransformerSemantics.PATCH: assert self.ctx.graph @@ -160,7 +168,9 @@ def __init__(self, config: PatternDatasetTermsConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternAddDatasetSchemaTerms": config = PatternDatasetTermsConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py index 355ca7a373653f..251666a3e28d87 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py @@ -45,12 +45,17 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetTags": return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_global_tags_aspect: GlobalTagsClass = cast(GlobalTagsClass, aspect) out_global_tags_aspect: GlobalTagsClass = GlobalTagsClass(tags=[]) self.update_if_keep_existing( - self.config, in_global_tags_aspect, out_global_tags_aspect + self.config, + in_global_tags_aspect, + out_global_tags_aspect, ) tags_to_add = self.config.get_tags_to_add(entity_urn) @@ -61,7 +66,10 @@ def transform_aspect( self.processed_tags.setdefault(tag.tag, tag) return self.get_result_semantics( - self.config, self.ctx.graph, entity_urn, out_global_tags_aspect + self.config, + self.ctx.graph, + entity_urn, + out_global_tags_aspect, ) def handle_end_of_stream( @@ -79,7 +87,7 @@ def handle_end_of_stream( MetadataChangeProposalWrapper( entityUrn=tag_urn.urn(), aspect=tag_urn.to_key_aspect(), - ) + ), ) return mcps diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py index 3daf52e32ed4bb..3ec8dc3ce51fef 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py @@ -60,16 +60,20 @@ def _merge_with_server_glossary_terms( { **{term.urn: term for term in server_glossary_terms_aspect.terms}, **{term.urn: term for term 
in glossary_terms_aspect.terms}, - }.values() + }.values(), ) return glossary_terms_aspect def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_glossary_terms: Optional[GlossaryTermsClass] = cast( - Optional[GlossaryTermsClass], aspect + Optional[GlossaryTermsClass], + aspect, ) out_glossary_terms: GlossaryTermsClass = GlossaryTermsClass( terms=[], @@ -77,7 +81,8 @@ def transform_aspect( in_glossary_terms.auditStamp if in_glossary_terms is not None else AuditStampClass( - time=builder.get_sys_time(), actor="urn:li:corpUser:restEmitter" + time=builder.get_sys_time(), + actor="urn:li:corpUser:restEmitter", ) ), ) @@ -94,7 +99,9 @@ def transform_aspect( if self.config.semantics == TransformerSemantics.PATCH: assert self.ctx.graph patch_glossary_terms = AddDatasetTerms._merge_with_server_glossary_terms( - self.ctx.graph, entity_urn, out_glossary_terms + self.ctx.graph, + entity_urn, + out_glossary_terms, ) return cast(Optional[Aspect], patch_glossary_terms) else: @@ -145,7 +152,9 @@ def __init__(self, config: PatternDatasetTermsConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternAddDatasetTerms": config = PatternDatasetTermsConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/auto_helper_transformer.py b/metadata-ingestion/src/datahub/ingestion/transformer/auto_helper_transformer.py index cec7be36060b72..2a9f51f28abee2 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/auto_helper_transformer.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/auto_helper_transformer.py @@ -26,7 +26,8 @@ def __init__( self.converter = converter def transform( - self, record_envelopes: Iterable[RecordEnvelope] + self, + record_envelopes: Iterable[RecordEnvelope], ) -> Iterable[RecordEnvelope]: records = list(record_envelopes) @@ -36,7 +37,7 @@ def transform( yield from self._from_workunits( self.converter( self._into_workunits(normal_records), - ) + ), ) # Pass through control records as-is. 
Note that this isn't fully correct, since it technically @@ -64,7 +65,8 @@ def _into_workunits( @classmethod def _from_workunits( - cls, stream: Iterable[MetadataWorkUnit] + cls, + stream: Iterable[MetadataWorkUnit], ) -> Iterable[RecordEnvelope]: for workunit in stream: yield RecordEnvelope( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py b/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py index 0a59380531ad36..88982531ab01ce 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/base_transformer.py @@ -18,7 +18,9 @@ def _update_work_unit_id( - envelope: RecordEnvelope, urn: str, aspect_name: str + envelope: RecordEnvelope, + urn: str, + aspect_name: str, ) -> Dict[Any, Any]: structured_urn = Urn.from_string(urn) simple_name = "-".join(structured_urn.entity_ids) @@ -35,7 +37,9 @@ def handle_end_of_stream( class LegacyMCETransformer( - Transformer, HandleEndOfStreamTransformer, metaclass=ABCMeta + Transformer, + HandleEndOfStreamTransformer, + metaclass=ABCMeta, ): @abstractmethod def transform_one(self, mce: MetadataChangeEventClass) -> MetadataChangeEventClass: @@ -50,7 +54,10 @@ def aspect_name(self) -> str: @abstractmethod def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: """Implement this method to transform a single aspect for an entity. param: entity_urn: the entity that is being processed @@ -100,7 +107,8 @@ def _should_process( entity_type = guess_entity_type(record.proposedSnapshot.urn) return entity_type in entity_types elif isinstance( - record, (MetadataChangeProposalWrapper, MetadataChangeProposalClass) + record, + (MetadataChangeProposalWrapper, MetadataChangeProposalClass), ): return record.entityType in entity_types @@ -115,7 +123,8 @@ def _record_mce(self, mce: MetadataChangeEventClass) -> None: self.entity_map[mce.proposedSnapshot.urn] = record_entry def _record_mcp( - self, mcp: Union[MetadataChangeProposalWrapper, MetadataChangeProposalClass] + self, + mcp: Union[MetadataChangeProposalWrapper, MetadataChangeProposalClass], ) -> None: assert mcp.entityUrn record_entry = self.entity_map.get(mcp.entityUrn, {"seen": {}}) @@ -191,7 +200,7 @@ def _transform_or_record_mcpw( if transformed_aspect is None: # drop the record log.debug( - f"Dropping record {envelope} as transformation result is None" + f"Dropping record {envelope} as transformation result is None", ) envelope.record.aspect = transformed_aspect else: @@ -199,10 +208,12 @@ def _transform_or_record_mcpw( return envelope if envelope.record.aspect is not None else None def _handle_end_of_stream( - self, envelope: RecordEnvelope + self, + envelope: RecordEnvelope, ) -> Iterable[RecordEnvelope]: if not isinstance(self, SingleAspectTransformer) and not isinstance( - self, LegacyMCETransformer + self, + LegacyMCETransformer, ): return @@ -228,7 +239,8 @@ def _handle_end_of_stream( ) def transform( - self, record_envelopes: Iterable[RecordEnvelope] + self, + record_envelopes: Iterable[RecordEnvelope], ) -> Iterable[RecordEnvelope]: for envelope in record_envelopes: if not self._should_process(envelope.record): @@ -237,7 +249,8 @@ def transform( elif isinstance(envelope.record, MetadataChangeEventClass): envelope = self._transform_or_record_mce(envelope) elif isinstance( - envelope.record, MetadataChangeProposalWrapper + envelope.record, + 
MetadataChangeProposalWrapper, ) and isinstance(self, SingleAspectTransformer): return_envelope = self._transform_or_record_mcpw(envelope) if return_envelope is None: @@ -245,7 +258,8 @@ def transform( else: envelope = return_envelope elif isinstance(envelope.record, EndOfStream) and isinstance( - self, SingleAspectTransformer + self, + SingleAspectTransformer, ): # walk through state and call transform for any unprocessed entities for urn, state in self.entity_map.items(): diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py index 6b78b71eaa78e9..9ec8b5e24f4968 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py @@ -74,24 +74,28 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetDomain": def raise_ctx_configuration_error(ctx: PipelineContext) -> None: if ctx.graph is None: raise ConfigurationError( - "AddDatasetDomain requires a datahub_api to connect to. Consider using the datahub-rest sink or provide a datahub_api: configuration on your ingestion recipe" + "AddDatasetDomain requires a datahub_api to connect to. Consider using the datahub-rest sink or provide a datahub_api: configuration on your ingestion recipe", ) @staticmethod def get_domain_class( - graph: Optional[DataHubGraph], domains: List[str] + graph: Optional[DataHubGraph], + domains: List[str], ) -> DomainsClass: domain_registry: DomainRegistry = DomainRegistry( - cached_domains=[k for k in domains], graph=graph + cached_domains=[k for k in domains], + graph=graph, ) domain_class = DomainsClass( - domains=[domain_registry.get_domain_urn(domain) for domain in domains] + domains=[domain_registry.get_domain_urn(domain) for domain in domains], ) return domain_class @staticmethod def _merge_with_server_domains( - graph: Optional[DataHubGraph], urn: str, mce_domain: Optional[DomainsClass] + graph: Optional[DataHubGraph], + urn: str, + mce_domain: Optional[DomainsClass], ) -> Optional[DomainsClass]: if not mce_domain or not mce_domain.domains: # nothing to add, no need to consult server @@ -139,7 +143,7 @@ def handle_end_of_stream( container_urn = path.urn if not container_urn or not container_urn.startswith( - "urn:li:container:" + "urn:li:container:", ): continue @@ -149,8 +153,8 @@ def handle_end_of_stream( container_domain_mapping[container_urn] = list( set( container_domain_mapping[container_urn] - + domain_to_add.domains - ) + + domain_to_add.domains, + ), ) for urn, domains in container_domain_mapping.items(): @@ -158,13 +162,16 @@ def handle_end_of_stream( MetadataChangeProposalWrapper( entityUrn=urn, aspect=DomainsClass(domains=domains), - ) + ), ) return domain_mcps def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_domain_aspect: DomainsClass = cast(DomainsClass, aspect) domain_aspect: DomainsClass = DomainsClass(domains=[]) @@ -185,7 +192,9 @@ def transform_aspect( return None if self.config.semantics == TransformerSemantics.PATCH: final_aspect = AddDatasetDomain._merge_with_server_domains( - self.ctx.graph, entity_urn, domain_aspect + self.ctx.graph, + entity_urn, + domain_aspect, ) return cast(Optional[Aspect], final_aspect) @@ -194,7 +203,9 @@ class SimpleAddDatasetDomain(AddDatasetDomain): """Transformer that adds a specified set of domains to each 
dataset.""" def __init__( - self, config: SimpleDatasetDomainSemanticsConfig, ctx: PipelineContext + self, + config: SimpleDatasetDomainSemanticsConfig, + ctx: PipelineContext, ): AddDatasetDomain.raise_ctx_configuration_error(ctx) domains = AddDatasetDomain.get_domain_class(ctx.graph, config.domains) @@ -206,7 +217,9 @@ def __init__( @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "SimpleAddDatasetDomain": config = SimpleDatasetDomainSemanticsConfig.parse_obj(config_dict) return cls(config, ctx) @@ -216,7 +229,9 @@ class PatternAddDatasetDomain(AddDatasetDomain): """Transformer that adds a specified set of domains to each dataset.""" def __init__( - self, config: PatternDatasetDomainSemanticsConfig, ctx: PipelineContext + self, + config: PatternDatasetDomainSemanticsConfig, + ctx: PipelineContext, ): AddDatasetDomain.raise_ctx_configuration_error(ctx) @@ -236,7 +251,9 @@ def resolve_domain(domain_urn: str) -> DomainsClass: @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternAddDatasetDomain": config = PatternDatasetDomainSemanticsConfig.parse_obj(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py index bb2f318dcac8b8..178998a5a7c11c 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py @@ -25,13 +25,18 @@ def __init__(self, config: DatasetTagDomainMapperConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "DatasetTagDomainMapper": config = DatasetTagDomainMapperConfig.parse_obj(config_dict) return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: # Initialize the existing domain aspect existing_domain_aspect: DomainsClass = cast(DomainsClass, aspect) @@ -55,14 +60,17 @@ def transform_aspect( domains_to_add.append(domain_mapping[tag]) mapped_domains = AddDatasetDomain.get_domain_class( - self.ctx.graph, domains_to_add + self.ctx.graph, + domains_to_add, ) domain_aspect.domains.extend(mapped_domains.domains) if self.config.semantics == TransformerSemantics.PATCH: # Try merging with server-side domains patch_domain_aspect: Optional[DomainsClass] = ( AddDatasetDomain._merge_with_server_domains( - self.ctx.graph, entity_urn, domain_aspect + self.ctx.graph, + entity_urn, + domain_aspect, ) ) return cast(Optional[Aspect], patch_domain_aspect) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_transformer.py b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_transformer.py index 00b3a9ba59f924..eb98bc34fd21fb 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_transformer.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_transformer.py @@ -28,7 +28,9 @@ def entity_types(self) -> List[str]: class OwnershipTransformer( - DatasetTransformer, SingleAspectTransformer, metaclass=ABCMeta + DatasetTransformer, + SingleAspectTransformer, + metaclass=ABCMeta, ): def aspect_name(self) -> str: return "ownership" 
@@ -79,7 +81,9 @@ def aspect_name(self) -> str: @staticmethod def merge_with_server_global_tags( - graph: DataHubGraph, urn: str, global_tags_aspect: Optional[GlobalTagsClass] + graph: DataHubGraph, + urn: str, + global_tags_aspect: Optional[GlobalTagsClass], ) -> Optional[GlobalTagsClass]: if not global_tags_aspect or not global_tags_aspect.tags: # nothing to add, no need to consult server @@ -93,7 +97,7 @@ def merge_with_server_global_tags( { **{tag.tag: tag for tag in server_global_tags_aspect.tags}, **{tag.tag: tag for tag in global_tags_aspect.tags}, - }.values() + }.values(), ) return global_tags_aspect @@ -124,7 +128,9 @@ def get_result_semantics( return cast( Optional[Aspect], DatasetTagsTransformer.merge_with_server_global_tags( - graph, urn, out_global_tags_aspect + graph, + urn, + out_global_tags_aspect, ), ) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py index 4b64d38a9b42fa..96a8a1408f6e19 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py @@ -45,12 +45,17 @@ def _get_tags_to_add(self, entity_urn: str) -> List[TagAssociationClass]: raise NotImplementedError() def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_global_tags_aspect: GlobalTagsClass = cast(GlobalTagsClass, aspect) out_global_tags_aspect: GlobalTagsClass = GlobalTagsClass(tags=[]) self.update_if_keep_existing( - self.config, in_global_tags_aspect, out_global_tags_aspect + self.config, + in_global_tags_aspect, + out_global_tags_aspect, ) tags_to_add = self._get_tags_to_add(entity_urn) @@ -58,5 +63,8 @@ def transform_aspect( out_global_tags_aspect.tags.extend(tags_to_add) return self.get_result_semantics( - self.config, self.ctx.graph, entity_urn, out_global_tags_aspect + self.config, + self.ctx.graph, + entity_urn, + out_global_tags_aspect, ) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py index 32707dcd3a372f..9fc98f78bcb148 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py @@ -33,7 +33,8 @@ class ExtractOwnersFromTagsConfig(ConfigModel): owner_type_urn: Optional[str] = None _rename_tag_prefix_to_tag_pattern = pydantic_renamed_field( - "tag_prefix", "tag_pattern" + "tag_prefix", + "tag_pattern", ) @@ -60,7 +61,9 @@ def __init__(self, config: ExtractOwnersFromTagsConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "ExtractOwnersFromTagsTransformer": config = ExtractOwnersFromTagsConfig.parse_obj(config_dict) return cls(config, ctx) @@ -108,7 +111,10 @@ def handle_end_of_stream( return self.owner_mcps def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_tags_aspect: Optional[GlobalTagsClass] = cast(GlobalTagsClass, aspect) if in_tags_aspect is None: @@ -136,7 +142,7 @@ def transform_aspect( owner=owner_urn, 
type=OwnershipTypeClass.CUSTOM, typeUrn=make_ownership_type_urn(re_match.group(1)), - ) + ), ) else: owner_type = get_owner_type(self.config.owner_type) @@ -150,7 +156,7 @@ def transform_aspect( owner=owner_urn, type=owner_type, typeUrn=self.config.owner_type_urn, - ) + ), ) self.owner_mcps.append( @@ -159,6 +165,6 @@ def transform_aspect( aspect=OwnershipClass( owners=owners, ), - ) + ), ) return aspect diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/generic_aspect_transformer.py b/metadata-ingestion/src/datahub/ingestion/transformer/generic_aspect_transformer.py index 5bf70274dce89d..471e2114716501 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/generic_aspect_transformer.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/generic_aspect_transformer.py @@ -20,7 +20,9 @@ class GenericAspectTransformer( - BaseTransformer, SingleAspectTransformer, metaclass=ABCMeta + BaseTransformer, + SingleAspectTransformer, + metaclass=ABCMeta, ): """Transformer that does transform custom aspects using GenericAspectClass.""" @@ -28,14 +30,20 @@ def __init__(self): super().__init__() def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: """Do not implement.""" pass @abstractmethod def transform_generic_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[GenericAspectClass] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[GenericAspectClass], ) -> Optional[GenericAspectClass]: """Implement this method to transform the single custom aspect for an entity. The purpose of this abstract method is to reinforce the use of GenericAspectClass. @@ -57,7 +65,7 @@ def _transform_or_record_mcpc( self._mark_processed(envelope.record.entityUrn) if transformed_aspect is None: log.debug( - f"Dropping record {envelope} as transformation result is None" + f"Dropping record {envelope} as transformation result is None", ) envelope.record.aspect = transformed_aspect else: @@ -65,7 +73,8 @@ def _transform_or_record_mcpc( return envelope if envelope.record.aspect is not None else None def transform( - self, record_envelopes: Iterable[RecordEnvelope] + self, + record_envelopes: Iterable[RecordEnvelope], ) -> Iterable[RecordEnvelope]: """ This method overrides the original one from BaseTransformer in order to support @@ -86,7 +95,8 @@ def transform( else: envelope = return_envelope elif isinstance(envelope.record, EndOfStream) and isinstance( - self, SingleAspectTransformer + self, + SingleAspectTransformer, ): for urn, state in self.entity_map.items(): if "seen" in state: @@ -105,8 +115,8 @@ def transform( record_metadata = envelope.metadata.copy() record_metadata.update( { - "workunit_id": f"txform-{simple_name}-{self.aspect_name()}" - } + "workunit_id": f"txform-{simple_name}-{self.aspect_name()}", + }, ) yield RecordEnvelope( record=MetadataChangeProposalClass( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py b/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py index 00ef29183a0c9a..ec2edec38862a6 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py @@ -28,7 +28,10 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "MarkDatasetStatus": return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, 
aspect: Optional[builder.Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[builder.Aspect], ) -> Optional[builder.Aspect]: assert aspect is None or isinstance(aspect, StatusClass) status_aspect: StatusClass = aspect or StatusClass(removed=None) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py index a3d41c8e91ec52..898acf057796b8 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py @@ -36,21 +36,27 @@ def __init__( @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternCleanupDatasetUsageUser": config = PatternCleanupDatasetUsageUserConfig.parse_obj(config_dict) return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_dataset_properties_aspect: DatasetUsageStatisticsClass = cast( - DatasetUsageStatisticsClass, aspect + DatasetUsageStatisticsClass, + aspect, ) if in_dataset_properties_aspect.userCounts is not None: out_dataset_properties_aspect: DatasetUsageStatisticsClass = copy.deepcopy( - in_dataset_properties_aspect + in_dataset_properties_aspect, ) if out_dataset_properties_aspect.userCounts is not None: diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py index f17546d6f72990..d3c28246abe59e 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py @@ -31,7 +31,9 @@ def __init__(self, config: PatternCleanUpOwnershipConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "PatternCleanUpOwnership": config = PatternCleanUpOwnershipConfig.parse_obj(config_dict) return cls(config, ctx) @@ -50,7 +52,10 @@ def _get_current_owner_urns(self, entity_urn: str) -> Set[str]: return set() def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[builder.Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[builder.Aspect], ) -> Optional[builder.Aspect]: # get current owner URNs from the graph current_owner_urns = self._get_current_owner_urns(entity_urn) @@ -65,7 +70,7 @@ def transform_aspect( cleaned_owner_urns.append(_USER_URN_PREFIX + user_id) ownership_type, ownership_type_urn = builder.validate_ownership_type( - OwnershipTypeClass.DATAOWNER + OwnershipTypeClass.DATAOWNER, ) owners = [ OwnerClass( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py b/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py index 934e2a13d56314..a9319e7ba7ba33 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py @@ -19,13 +19,18 @@ def __init__(self, config: ClearDatasetOwnershipConfig, ctx: PipelineContext): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + 
cls, + config_dict: dict, + ctx: PipelineContext, ) -> "SimpleRemoveDatasetOwnership": config = ClearDatasetOwnershipConfig.parse_obj(config_dict) return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_ownership_aspect = cast(OwnershipClass, aspect) if in_ownership_aspect is None: diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py b/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py index f6847f234aefe6..a3a721924b0eb1 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py @@ -45,16 +45,22 @@ def __init__( @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "ReplaceExternalUrlDataset": config = ReplaceExternalUrlConfig.parse_obj(config_dict) return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_dataset_properties_aspect: DatasetPropertiesClass = cast( - DatasetPropertiesClass, aspect + DatasetPropertiesClass, + aspect, ) if ( @@ -64,7 +70,7 @@ def transform_aspect( return cast(Aspect, in_dataset_properties_aspect) else: out_dataset_properties_aspect: DatasetPropertiesClass = copy.deepcopy( - in_dataset_properties_aspect + in_dataset_properties_aspect, ) out_dataset_properties_aspect.externalUrl = self.replace_url( @@ -95,16 +101,22 @@ def __init__( @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "ReplaceExternalUrlContainer": config = ReplaceExternalUrlConfig.parse_obj(config_dict) return cls(config, ctx) def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_container_properties_aspect: ContainerPropertiesClass = cast( - ContainerPropertiesClass, aspect + ContainerPropertiesClass, + aspect, ) if ( not hasattr(in_container_properties_aspect, "externalUrl") @@ -113,7 +125,7 @@ def transform_aspect( return cast(Aspect, in_container_properties_aspect) else: out_container_properties_aspect: ContainerPropertiesClass = copy.deepcopy( - in_container_properties_aspect + in_container_properties_aspect, ) out_container_properties_aspect.externalUrl = self.replace_url( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/system_metadata_transformer.py b/metadata-ingestion/src/datahub/ingestion/transformer/system_metadata_transformer.py index 3b5f26c127741d..bd775b94ebc295 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/system_metadata_transformer.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/system_metadata_transformer.py @@ -21,7 +21,8 @@ def auto_system_metadata( for workunit in stream: if set_system_metadata: workunit.metadata.systemMetadata = SystemMetadataClass( - lastObserved=get_sys_time(), runId=ctx.run_id + lastObserved=get_sys_time(), + runId=ctx.run_id, ) if set_pipeline_name: workunit.metadata.systemMetadata.pipelineName = ctx.pipeline_name @@ -32,11 +33,12 @@ def auto_system_metadata( class SystemMetadataTransformer(Transformer): def __init__(self, ctx: 
PipelineContext): self._inner_transformer = AutoHelperTransformer( - functools.partial(auto_system_metadata, ctx) + functools.partial(auto_system_metadata, ctx), ) def transform( - self, record_envelopes: Iterable[RecordEnvelope] + self, + record_envelopes: Iterable[RecordEnvelope], ) -> Iterable[RecordEnvelope]: yield from self._inner_transformer.transform(record_envelopes) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py b/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py index 65cf2ac3614ae0..d9f71c92db6143 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py @@ -53,7 +53,7 @@ def _merge_with_server_glossary_terms( { **{term.urn: term for term in server_glossary_terms_aspect.terms}, **{term.urn: term for term in glossary_terms_aspect.terms}, - }.values() + }.values(), ) return glossary_terms_aspect @@ -77,20 +77,24 @@ def get_tags_from_schema_metadata( for field in schema_metadata.fields: if field.globalTags: tags.update( - TagsToTermMapper.get_tags_from_global_tags(field.globalTags) + TagsToTermMapper.get_tags_from_global_tags(field.globalTags), ) return tags def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[Aspect], ) -> Optional[Aspect]: in_glossary_terms: Optional[GlossaryTermsClass] = cast( - Optional[GlossaryTermsClass], aspect + Optional[GlossaryTermsClass], + aspect, ) assert self.ctx.graph in_global_tags_aspect: Optional[GlobalTagsClass] = self.ctx.graph.get_tags( - entity_urn + entity_urn, ) in_schema_metadata_aspect: Optional[SchemaMetadataClass] = ( self.ctx.graph.get_schema_metadata(entity_urn) @@ -101,7 +105,7 @@ def transform_aspect( global_tags = TagsToTermMapper.get_tags_from_global_tags(in_global_tags_aspect) schema_metadata_tags = TagsToTermMapper.get_tags_from_schema_metadata( - in_schema_metadata_aspect + in_schema_metadata_aspect, ) # Combine tags from both global and schema level @@ -129,14 +133,17 @@ def transform_aspect( out_glossary_terms = GlossaryTermsClass( terms=[GlossaryTermAssociationClass(urn=term) for term in terms_to_add], auditStamp=AuditStampClass( - time=builder.get_sys_time(), actor="urn:li:corpUser:restEmitter" + time=builder.get_sys_time(), + actor="urn:li:corpUser:restEmitter", ), ) if self.config.semantics == TransformerSemantics.PATCH: patch_glossary_terms: Optional[GlossaryTermsClass] = ( TagsToTermMapper._merge_with_server_glossary_terms( - self.ctx.graph, entity_urn, out_glossary_terms + self.ctx.graph, + entity_urn, + out_glossary_terms, ) ) return cast(Optional[Aspect], patch_glossary_terms) diff --git a/metadata-ingestion/src/datahub/integrations/assertion/registry.py b/metadata-ingestion/src/datahub/integrations/assertion/registry.py index 26015ddbf9a315..393d2da00c174d 100644 --- a/metadata-ingestion/src/datahub/integrations/assertion/registry.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/registry.py @@ -4,5 +4,5 @@ from datahub.integrations.assertion.snowflake.compiler import SnowflakeAssertionCompiler ASSERTION_PLATFORMS: Dict[str, Type[AssertionCompiler]] = { - "snowflake": SnowflakeAssertionCompiler + "snowflake": SnowflakeAssertionCompiler, } diff --git a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py index e32f1ddc3943ae..754c4a9fa2c39a 100644 --- 
a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/compiler.py @@ -50,7 +50,8 @@ def __init__(self, output_dir: str, extras: Dict[str, str]) -> None: self.output_dir = Path(output_dir) self.extras = extras self.metric_generator = SnowflakeMetricSQLGenerator( - SnowflakeFieldMetricSQLGenerator(), SnowflakeFieldValuesMetricSQLGenerator() + SnowflakeFieldMetricSQLGenerator(), + SnowflakeFieldValuesMetricSQLGenerator(), ) self.metric_evaluator = SnowflakeMetricEvalOperatorSQLGenerator() self.dmf_handler = SnowflakeDMFHandler() @@ -59,7 +60,9 @@ def __init__(self, output_dir: str, extras: Dict[str, str]) -> None: @classmethod def create( - cls, output_dir: str, extras: Dict[str, str] + cls, + output_dir: str, + extras: Dict[str, str], ) -> "SnowflakeAssertionCompiler": assert os.path.exists(output_dir), ( f"Specified location {output_dir} does not exist." @@ -76,7 +79,8 @@ def create( return SnowflakeAssertionCompiler(output_dir, extras) def compile( - self, assertion_config_spec: AssertionsConfigSpec + self, + assertion_config_spec: AssertionsConfigSpec, ) -> AssertionCompilationResult: result = AssertionCompilationResult("snowflake", "success") @@ -92,7 +96,7 @@ def compile( try: start_line = f"\n-- Start of Assertion {assertion_spec.get_id()}\n" (dmf_definition, dmf_association) = self.process_assertion( - assertion_spec + assertion_spec, ) end_line = f"\n-- End of Assertion {assertion_spec.get_id()}\n" @@ -119,7 +123,7 @@ def compile( path=dmf_definitions_path, type=CompileResultArtifactType.SQL_QUERIES, description="SQL file containing DMF create definitions equivalent to Datahub Assertions", - ) + ), ) result.add_artifact( CompileResultArtifact( @@ -127,7 +131,7 @@ def compile( path=dmf_associations_path, type=CompileResultArtifactType.SQL_QUERIES, description="ALTER TABLE queries to associate DMFs to table to run on configured schedule.", - ) + ), ) return result @@ -160,7 +164,8 @@ def process_assertion(self, assertion: DataHubAssertion) -> Tuple[str, str]: ) else: assertion_sql = self.metric_evaluator.operator_sql( - assertion.assertion.operator, metric_definition + assertion.assertion.operator, + metric_definition, ) dmf_name = get_dmf_name(assertion) @@ -171,7 +176,8 @@ def process_assertion(self, assertion: DataHubAssertion) -> Tuple[str, str]: entity_name = get_entity_name(assertion.assertion) self._entity_schedule_history.setdefault( - assertion.assertion.entity, assertion.assertion.trigger + assertion.assertion.entity, + assertion.assertion.trigger, ) if ( assertion.assertion.entity in self._entity_schedule_history @@ -182,7 +188,7 @@ def process_assertion(self, assertion: DataHubAssertion) -> Tuple[str, str]: "Assertions on same entity must have same schedules as of now." 
f" Found different schedules on entity {assertion.assertion.entity} ->" f" ({self._entity_schedule_history[assertion.assertion.entity].trigger})," - f" ({assertion.assertion.trigger.trigger})" + f" ({assertion.assertion.trigger.trigger})", ) dmf_schedule = get_dmf_schedule(assertion.assertion.trigger) @@ -220,7 +226,8 @@ def get_dmf_args(assertion: DataHubAssertion) -> Tuple[str, str]: if entity_schema: for col_dict in entity_schema: return args_create_dmf.format( - col_name=col_dict["col"], col_type=col_dict["native_type"] + col_name=col_dict["col"], + col_type=col_dict["native_type"], ), args_add_dmf.format(col_name=col_dict["col"]) raise ValueError("entity schema not available") diff --git a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/dmf_generator.py b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/dmf_generator.py index 4f50b7c2b81a57..8c470f691dd209 100644 --- a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/dmf_generator.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/dmf_generator.py @@ -1,6 +1,10 @@ class SnowflakeDMFHandler: def create_dmf( - self, dmf_name: str, dmf_args: str, dmf_comment: str, dmf_sql: str + self, + dmf_name: str, + dmf_args: str, + dmf_comment: str, + dmf_sql: str, ) -> str: return f""" CREATE or REPLACE DATA METRIC FUNCTION @@ -14,7 +18,11 @@ def create_dmf( """ def add_dmf_to_table( - self, dmf_name: str, dmf_col_args: str, dmf_schedule: str, table_identifier: str + self, + dmf_name: str, + dmf_col_args: str, + dmf_schedule: str, + table_identifier: str, ) -> str: return f""" ALTER TABLE {table_identifier} SET DATA_METRIC_SCHEDULE = '{dmf_schedule}'; diff --git a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_metric_sql_generator.py b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_metric_sql_generator.py index 3ff218a9f280b3..acab6027ba4875 100644 --- a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_metric_sql_generator.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_metric_sql_generator.py @@ -7,111 +7,162 @@ class SnowflakeFieldMetricSQLGenerator: def unique_count_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select count(distinct {field_name}) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def unique_percentage_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select count(distinct {field_name})/count(*) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def null_count_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: where_clause = self._setup_where_clause( - [dataset_filter, f"{field_name} is null"] + [dataset_filter, f"{field_name} is null"], ) return f"""select count(*) from {entity_name} {where_clause}""" def null_percentage_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select ({self.null_count_sql(field_name, entity_name, dataset_filter)})/count(*) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def min_sql( - self, 
field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select min({field_name}) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def max_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select max({field_name}) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def mean_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select avg({field_name}) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def median_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select median({field_name}) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def stddev_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select stddev({field_name}) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def negative_count_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: where_clause = self._setup_where_clause([dataset_filter, f"{field_name} < 0"]) return f"""select count(*) from {entity_name} {where_clause}""" def negative_percentage_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select ({self.negative_count_sql(field_name, entity_name, dataset_filter)})/count(*) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def zero_count_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: where_clause = self._setup_where_clause([dataset_filter, f"{field_name} = 0"]) return f"""select count(*) from {entity_name} {where_clause}""" def zero_percentage_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select ({self.zero_count_sql(field_name, entity_name, dataset_filter)})/count(*) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def min_length_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select min(length({field_name})) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def max_length_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select max(length({field_name})) from {entity_name} {self._setup_where_clause([dataset_filter])}""" def empty_count_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: where_clause = 
self._setup_where_clause( - [dataset_filter, f"({field_name} is null or trim({field_name})='')"] + [dataset_filter, f"({field_name} is null or trim({field_name})='')"], ) return f"""select count(*) from {entity_name} {where_clause}""" def empty_percentage_sql( - self, field_name: str, entity_name: str, dataset_filter: Optional[str] + self, + field_name: str, + entity_name: str, + dataset_filter: Optional[str], ) -> str: return f"""select ({self.empty_count_sql(field_name, entity_name, dataset_filter)})/count(*) from {entity_name} {self._setup_where_clause([dataset_filter])}""" diff --git a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_values_metric_sql_generator.py b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_values_metric_sql_generator.py index b77cc971d3a450..52b3be3a04dbc6 100644 --- a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_values_metric_sql_generator.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/field_values_metric_sql_generator.py @@ -236,7 +236,9 @@ def _setup_where_clause(self, filters: List[Optional[str]]) -> str: return f"where {where_clause}" if where_clause else "" def _setup_field_transform( - self, field: str, transform: Optional[FieldTransform] + self, + field: str, + transform: Optional[FieldTransform], ) -> str: if transform is None: return field @@ -266,14 +268,18 @@ def metric_sql(self, assertion: FieldValuesAssertion) -> str: [ dataset_filter, f"{assertion.field} is not null" if assertion.exclude_nulls else None, - ] + ], ) transformed_field = self._setup_field_transform( - assertion.field, assertion.field_transform + assertion.field, + assertion.field_transform, ) # this sql would return boolean value for each table row. 1 if fail and 0 if pass. sql = self.values_metric_sql( - assertion.operator, entity_name, transformed_field, where_clause + assertion.operator, + entity_name, + transformed_field, + where_clause, ) # metric would be number of failing rows OR percentage of failing rows. 
diff --git a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/metric_sql_generator.py b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/metric_sql_generator.py index facc7d107d1ba7..8d6d9d9a40aead 100644 --- a/metadata-ingestion/src/datahub/integrations/assertion/snowflake/metric_sql_generator.py +++ b/metadata-ingestion/src/datahub/integrations/assertion/snowflake/metric_sql_generator.py @@ -67,7 +67,7 @@ def _(self, assertion: FixedIntervalFreshnessAssertion) -> str: ) as metric from {entity_name} {where_clause}""" else: raise ValueError( - f"Unsupported freshness source type {assertion.source_type} " + f"Unsupported freshness source type {assertion.source_type} ", ) @metric_sql.register diff --git a/metadata-ingestion/src/datahub/lite/duckdb_lite.py b/metadata-ingestion/src/datahub/lite/duckdb_lite.py index fe025842822b13..bd5e2df6c3925e 100644 --- a/metadata-ingestion/src/datahub/lite/duckdb_lite.py +++ b/metadata-ingestion/src/datahub/lite/duckdb_lite.py @@ -50,17 +50,22 @@ def __init__(self, config: DuckDBLiteConfig) -> None: fpath = pathlib.Path(self.config.file) fpath.parent.mkdir(exist_ok=True) self.duckdb_client = duckdb.connect( - str(fpath), read_only=config.read_only, config=config.options + str(fpath), + read_only=config.read_only, + config=config.options, ) if not config.read_only: self._init_db() def _create_unique_index( - self, index_name: str, table_name: str, columns: list + self, + index_name: str, + table_name: str, + columns: list, ) -> None: try: self.duckdb_client.execute( - f"CREATE UNIQUE INDEX {index_name} ON {table_name} ({', '.join(columns)})" + f"CREATE UNIQUE INDEX {index_name} ON {table_name} ({', '.join(columns)})", ) except duckdb.CatalogException as e: if "already exists" not in str(e).lower(): @@ -69,20 +74,24 @@ def _create_unique_index( def _init_db(self) -> None: self.duckdb_client.execute( "CREATE TABLE IF NOT EXISTS metadata_aspect_v2 " - "(urn VARCHAR, aspect_name VARCHAR, version BIGINT, metadata JSON, system_metadata JSON, createdon BIGINT)" + "(urn VARCHAR, aspect_name VARCHAR, version BIGINT, metadata JSON, system_metadata JSON, createdon BIGINT)", ) self._create_unique_index( - "aspect_idx", "metadata_aspect_v2", ["urn", "aspect_name", "version"] + "aspect_idx", + "metadata_aspect_v2", + ["urn", "aspect_name", "version"], ) self.duckdb_client.execute( "CREATE TABLE IF NOT EXISTS metadata_edge_v2 " - "(src_id VARCHAR, relnship VARCHAR, dst_id VARCHAR, dst_label VARCHAR)" + "(src_id VARCHAR, relnship VARCHAR, dst_id VARCHAR, dst_label VARCHAR)", ) self._create_unique_index( - "edge_idx", "metadata_edge_v2", ["src_id", "relnship", "dst_id"] + "edge_idx", + "metadata_edge_v2", + ["src_id", "relnship", "dst_id"], ) def location(self) -> str: @@ -106,7 +115,7 @@ def write( writeables = mcps_from_mce(record) else: raise ValueError( - f"DuckDBCatalog only supports MCEs and MCPs, not {type(record)}" + f"DuckDBCatalog only supports MCEs and MCPs, not {type(record)}", ) if not writeables: @@ -130,7 +139,7 @@ def write( metadata_dict = json.loads(max_row[0]) # type: ignore system_metadata = json.loads(max_row[1]) # type: ignore real_version = system_metadata.get("properties", {}).get( - "sysVersion" + "sysVersion", ) if real_version is None: max_version_row = self.duckdb_client.execute( @@ -156,7 +165,8 @@ def write( if writeable.systemMetadata is None: writeable.systemMetadata = SystemMetadataClass( - lastObserved=created_on, properties={} + lastObserved=created_on, + properties={}, ) elif 
writeable.systemMetadata.lastObserved is None: writeable.systemMetadata.lastObserved = created_on @@ -205,7 +215,7 @@ def write( # this is a dup, we still want to update the lastObserved timestamp if not system_metadata: system_metadata = { - "lastObserved": writeable.systemMetadata.lastObserved + "lastObserved": writeable.systemMetadata.lastObserved, } else: system_metadata["lastObserved"] = ( @@ -229,7 +239,9 @@ def write( and writeable.aspect ) self.post_update_hook( - writeable.entityUrn, writeable.aspectName, writeable.aspect + writeable.entityUrn, + writeable.aspectName, + writeable.aspect, ) self.duckdb_client.commit() @@ -291,13 +303,17 @@ def search( base_query = f"SELECT distinct(urn), 'urn', NULL from metadata_aspect_v2 where urn ILIKE '%{query}%' UNION SELECT urn, aspect_name, metadata from metadata_aspect_v2 where metadata->>'$.name' ILIKE '%{query}%'" for r in self.duckdb_client.execute(base_query).fetchall(): yield Searchable( - id=r[0], aspect=r[1], snippet=r[2] if snippet else None + id=r[0], + aspect=r[1], + snippet=r[2] if snippet else None, ) elif flavor == SearchFlavor.EXACT: base_query = f"SELECT urn, aspect_name, metadata from metadata_aspect_v2 where version = 0 AND ({query})" for r in self.duckdb_client.execute(base_query).fetchall(): yield Searchable( - id=r[0], aspect=r[1], snippet=r[2] if snippet else None + id=r[0], + aspect=r[1], + snippet=r[2] if snippet else None, ) else: raise Exception(f"Unhandled search flavor {flavor}") @@ -368,7 +384,8 @@ def add_edge( ) except Exception as e: logger.error( - f"Failed to write {src_id}, {relnship}, {dst_id}, {dst_label}", e + f"Failed to write {src_id}, {relnship}, {dst_id}, {dst_label}", + e, ) raise @@ -443,14 +460,15 @@ def resolve_name_from_id(maybe_urn: str) -> str: success_path=success_path, failed_token=p, suggested_path=f"{success_path}/{r_name}".replace( - "//", "/" + "//", + "/", ), ), - ) + ), ) return results_list raise PathNotFoundException( - f"Path {path} not found at {p} for query: {query}, did you mean {alternatives}" + f"Path {path} not found at {p} for query: {query}, did you mean {alternatives}", ) in_list = [r[0] for r in results] in_list_quoted = [f"'{r}'" for r in in_list] @@ -486,7 +504,8 @@ def reindex(self) -> None: self.global_post_update_hook(urn, aspect_map) # type: ignore def get_all_entities( - self, typed: bool = False + self, + typed: bool = False, ) -> Iterable[Dict[str, Union[dict, _Aspect]]]: query = "SELECT urn, aspect_name, metadata, system_metadata from metadata_aspect_v2 where version = 0 order by (urn, aspect_name)" results = self.duckdb_client.execute(query) @@ -502,7 +521,7 @@ def get_all_entities( ) try: aspect_payload = ASPECT_MAP[aspect_name].from_obj( - post_json_transform(aspect_payload) + post_json_transform(aspect_payload), ) except Exception as e: logger.exception( @@ -532,7 +551,7 @@ def get_all_aspects(self) -> Iterable[MetadataChangeProposalWrapper]: urn = r[0] aspect_name = r[1] aspect_metadata = ASPECT_MAP[aspect_name].from_obj( - post_json_transform(json.loads(r[2])) + post_json_transform(json.loads(r[2])), ) # type: ignore system_metadata = SystemMetadataClass.from_obj(json.loads(r[3])) mcp = MetadataChangeProposalWrapper( @@ -580,35 +599,40 @@ def get_category_from_platform(self, data_platform_urn: DataPlatformUrn) -> Urn: return Urn(entity_type="systemNode", entity_id=[k]) logger.debug( - f"Failed to find category for platform {data_platform_urn}, mapping to generic data_platform" + f"Failed to find category for platform {data_platform_urn}, mapping to generic 
data_platform", ) return Urn(entity_type="systemNode", entity_id=["data_platforms"]) def global_post_update_hook( - self, entity_urn: str, aspect_map: Dict[str, _Aspect] + self, + entity_urn: str, + aspect_map: Dict[str, _Aspect], ) -> None: def pluralize(noun: str) -> str: return noun.lower() + "s" def get_typed_aspect( - aspect_map: Dict[str, _Aspect], aspect_type: Type[_Aspect] + aspect_map: Dict[str, _Aspect], + aspect_type: Type[_Aspect], ) -> Optional[_Aspect]: aspect_names = [k for k, v in ASPECT_MAP.items() if v == aspect_type] if aspect_names: return aspect_map.get(aspect_names[0]) raise Exception( - f"Unable to locate aspect type {aspect_type} in the registry" + f"Unable to locate aspect type {aspect_type} in the registry", ) if not entity_urn: logger.error(f"Bad input {entity_urn}: {aspect_map}") container: Optional[ContainerClass] = get_typed_aspect( # type: ignore - aspect_map, ContainerClass + aspect_map, + ContainerClass, ) # type: ignore subtypes: Optional[SubTypesClass] = get_typed_aspect(aspect_map, SubTypesClass) # type: ignore dpi: Optional[DataPlatformInstanceClass] = get_typed_aspect( # type: ignore - aspect_map, DataPlatformInstanceClass + aspect_map, + DataPlatformInstanceClass, ) # type: ignore needs_platform = Urn.from_string(entity_urn).get_type() in [ @@ -638,7 +662,7 @@ def get_typed_aspect( parent_urn = maybe_parent_urn if Urn.from_string(maybe_parent_urn).get_type() == "dataPlatform": data_platform_urn = DataPlatformUrn.from_string( - maybe_parent_urn + maybe_parent_urn, ) needs_dpi = True else: @@ -653,7 +677,7 @@ def get_typed_aspect( ) try: self._create_edges_from_data_platform_instance( - data_platform_instance_urn + data_platform_instance_urn, ) except Exception as e: logger.error(f"Failed to generate edges entity {entity_urn}", e) @@ -681,13 +705,15 @@ def get_typed_aspect( self.add_edge(type_urn, "name", pluralize(t), remove_existing=True) def _create_edges_from_data_platform_instance( - self, data_platform_instance_urn: Urn + self, + data_platform_instance_urn: Urn, ) -> None: data_platform_urn = DataPlatformUrn.from_string( - data_platform_instance_urn.get_entity_id()[0] + data_platform_instance_urn.get_entity_id()[0], ) data_platform_instances_urn = Urn( - entity_type="systemNode", entity_id=[str(data_platform_urn), "instances"] + entity_type="systemNode", + entity_id=[str(data_platform_urn), "instances"], ) data_platform_category = self.get_category_from_platform(data_platform_urn) @@ -716,7 +742,10 @@ def _create_edges_from_data_platform_instance( # ///instances self.add_edge(str(data_platform_urn), "child", str(data_platform_instances_urn)) self.add_edge( - str(data_platform_instances_urn), "name", "instances", remove_existing=True + str(data_platform_instances_urn), + "name", + "instances", + remove_existing=True, ) # ///instances/ self.add_edge( @@ -727,7 +756,10 @@ def _create_edges_from_data_platform_instance( ) def post_update_hook( - self, entity_urn: str, aspect_name: str, aspect: _Aspect + self, + entity_urn: str, + aspect_name: str, + aspect: _Aspect, ) -> None: if isinstance(aspect, DatasetPropertiesClass): dp: DatasetPropertiesClass = aspect diff --git a/metadata-ingestion/src/datahub/lite/lite_local.py b/metadata-ingestion/src/datahub/lite/lite_local.py index d767bbcec46215..93ecafec881e41 100644 --- a/metadata-ingestion/src/datahub/lite/lite_local.py +++ b/metadata-ingestion/src/datahub/lite/lite_local.py @@ -101,7 +101,8 @@ def ls(self, path: str) -> List[Browseable]: @abstractmethod def get_all_entities( - self, typed: bool = False 
+ self, + typed: bool = False, ) -> Iterable[Dict[str, Union[dict, _Aspect]]]: pass diff --git a/metadata-ingestion/src/datahub/lite/lite_util.py b/metadata-ingestion/src/datahub/lite/lite_util.py index b1631e233fa1a9..5df5da0e1ceeec 100644 --- a/metadata-ingestion/src/datahub/lite/lite_util.py +++ b/metadata-ingestion/src/datahub/lite/lite_util.py @@ -40,7 +40,8 @@ def write( self.lite.write(record) record_envelope = RecordEnvelope(record=record, metadata={}) self.forward_to.write_record_async( - record_envelope=record_envelope, write_callback=NoopWriteCallback() + record_envelope=record_envelope, + write_callback=NoopWriteCallback(), ) def close(self) -> None: @@ -79,7 +80,8 @@ def ls(self, path: str) -> List[Browseable]: return self.lite.ls(path) def get_all_entities( - self, typed: bool = False + self, + typed: bool = False, ) -> Iterable[Dict[str, Union[dict, _Aspect]]]: yield from self.lite.get_all_entities(typed) @@ -98,11 +100,11 @@ def get_datahub_lite(config_dict: dict, read_only: bool = False) -> "DataHubLite lite_class = lite_registry.get(lite_type) except KeyError: raise Exception( - f"Failed to find a registered lite implementation for {lite_type}. Valid values are {[k for k in lite_registry.mapping.keys()]}" + f"Failed to find a registered lite implementation for {lite_type}. Valid values are {[k for k in lite_registry.mapping.keys()]}", ) lite_specific_config = lite_class.get_config_class().parse_obj( - lite_local_config.config + lite_local_config.config, ) lite = lite_class(lite_specific_config) # we only set up forwarding if forwarding config is present and read_only is set to False @@ -118,15 +120,16 @@ def get_datahub_lite(config_dict: dict, read_only: bool = False) -> "DataHubLite return DataHubLiteWrapper(lite, forward_to) except Exception as e: logger.warning( - f"Failed to set up forwarding due to {e}, will not forward events" + f"Failed to set up forwarding due to {e}, will not forward events", ) logger.debug( - "Failed to set up forwarding, will not forward events", exc_info=e + "Failed to set up forwarding, will not forward events", + exc_info=e, ) return lite else: raise Exception( - f"Failed to find a registered forwarding sink for type {lite_local_config.forward_to.type}. Valid values are {[k for k in sink_registry.mapping.keys()]}" + f"Failed to find a registered forwarding sink for type {lite_local_config.forward_to.type}. Valid values are {[k for k in sink_registry.mapping.keys()]}", ) else: return lite diff --git a/metadata-ingestion/src/datahub/secret/datahub_secret_store.py b/metadata-ingestion/src/datahub/secret/datahub_secret_store.py index 8301ff2d9dc1a6..b0730493fa53a5 100644 --- a/metadata-ingestion/src/datahub/secret/datahub_secret_store.py +++ b/metadata-ingestion/src/datahub/secret/datahub_secret_store.py @@ -38,7 +38,7 @@ def __init__(self, config: DataHubSecretStoreConfig): self.client = DataHubSecretsClient(graph) else: raise Exception( - "Invalid configuration provided: unable to construct DataHub Graph Client." + "Invalid configuration provided: unable to construct DataHub Graph Client.", ) def get_secret_values(self, secret_names: List[str]) -> Dict[str, Union[str, None]]: @@ -49,7 +49,7 @@ def get_secret_values(self, secret_names: List[str]) -> Dict[str, Union[str, Non except Exception: # Failed to resolve secrets, return empty. logger.exception( - f"Caught exception while attempting to fetch secrets from DataHub. Secret names: {secret_names}" + f"Caught exception while attempting to fetch secrets from DataHub. 
Secret names: {secret_names}", ) return {} diff --git a/metadata-ingestion/src/datahub/secret/secret_common.py b/metadata-ingestion/src/datahub/secret/secret_common.py index a116c70407af23..a6f540fdd4cf08 100644 --- a/metadata-ingestion/src/datahub/secret/secret_common.py +++ b/metadata-ingestion/src/datahub/secret/secret_common.py @@ -34,13 +34,15 @@ def resolve_secrets(secret_names: List[str], secret_stores: List[SecretStore]) - final_secret_values[secret_name] = secret_value except Exception: logger.exception( - f"Failed to fetch secret values from secret store with id {secret_store.get_id()}" + f"Failed to fetch secret values from secret store with id {secret_store.get_id()}", ) return final_secret_values def resolve_recipe( - recipe: str, secret_stores: List[SecretStore], strict_env_syntax: bool = True + recipe: str, + secret_stores: List[SecretStore], + strict_env_syntax: bool = True, ) -> dict: # Note: the default for `strict_env_syntax` is normally False, but here we override # it to be true. Particularly when fetching secrets from external secret stores, we @@ -50,7 +52,8 @@ def resolve_recipe( # 1. Extract all secrets needing resolved. secrets_to_resolve = EnvResolver.list_referenced_variables( - json_recipe_raw, strict_env_syntax=strict_env_syntax + json_recipe_raw, + strict_env_syntax=strict_env_syntax, ) # 2. Resolve secret values @@ -58,7 +61,8 @@ def resolve_recipe( # 3. Substitute secrets into recipe file resolver = EnvResolver( - environ=secret_values_dict, strict_env_syntax=strict_env_syntax + environ=secret_values_dict, + strict_env_syntax=strict_env_syntax, ) json_recipe_resolved = resolver.resolve(json_recipe_raw) diff --git a/metadata-ingestion/src/datahub/specific/aspect_helpers/custom_properties.py b/metadata-ingestion/src/datahub/specific/aspect_helpers/custom_properties.py index 4b8b4d0bc99bc0..41daf92c5fee8d 100644 --- a/metadata-ingestion/src/datahub/specific/aspect_helpers/custom_properties.py +++ b/metadata-ingestion/src/datahub/specific/aspect_helpers/custom_properties.py @@ -31,7 +31,8 @@ def add_custom_property(self, key: str, value: str) -> Self: return self def add_custom_properties( - self, custom_properties: Optional[Dict[str, str]] = None + self, + custom_properties: Optional[Dict[str, str]] = None, ) -> Self: if custom_properties is not None: for key, value in custom_properties.items(): diff --git a/metadata-ingestion/src/datahub/specific/aspect_helpers/ownership.py b/metadata-ingestion/src/datahub/specific/aspect_helpers/ownership.py index 1e2c789c7def35..9a617a79c38036 100644 --- a/metadata-ingestion/src/datahub/specific/aspect_helpers/ownership.py +++ b/metadata-ingestion/src/datahub/specific/aspect_helpers/ownership.py @@ -29,7 +29,9 @@ def add_owner(self, owner: OwnerClass) -> Self: return self def remove_owner( - self, owner: str, owner_type: Optional[OwnershipTypeClass] = None + self, + owner: str, + owner_type: Optional[OwnershipTypeClass] = None, ) -> Self: """Remove an owner from the entity. @@ -62,6 +64,9 @@ def set_owners(self, owners: List[OwnerClass]) -> Self: The patch builder instance. 
""" self._add_patch( - OwnershipClass.ASPECT_NAME, "add", path=("owners",), value=owners + OwnershipClass.ASPECT_NAME, + "add", + path=("owners",), + value=owners, ) return self diff --git a/metadata-ingestion/src/datahub/specific/aspect_helpers/structured_properties.py b/metadata-ingestion/src/datahub/specific/aspect_helpers/structured_properties.py index 48050bbad8e50d..7b97ce563463f1 100644 --- a/metadata-ingestion/src/datahub/specific/aspect_helpers/structured_properties.py +++ b/metadata-ingestion/src/datahub/specific/aspect_helpers/structured_properties.py @@ -14,7 +14,9 @@ class HasStructuredPropertiesPatch(MetadataPatchProposal): def set_structured_property( - self, key: str, value: Union[str, float, List[Union[str, float]]] + self, + key: str, + value: Union[str, float, List[Union[str, float]]], ) -> Self: """Add or update a structured property. @@ -48,7 +50,9 @@ def remove_structured_property(self, key: str) -> Self: return self def add_structured_property( - self, key: str, value: Union[str, float, List[Union[str, float]]] + self, + key: str, + value: Union[str, float, List[Union[str, float]]], ) -> Self: """Add a structured property. diff --git a/metadata-ingestion/src/datahub/specific/aspect_helpers/tags.py b/metadata-ingestion/src/datahub/specific/aspect_helpers/tags.py index afbc9115ca6e2b..757554698d1775 100644 --- a/metadata-ingestion/src/datahub/specific/aspect_helpers/tags.py +++ b/metadata-ingestion/src/datahub/specific/aspect_helpers/tags.py @@ -23,7 +23,10 @@ def add_tag(self, tag: Tag) -> Self: # TODO: Make this support raw strings, in addition to Tag objects. self._add_patch( - GlobalTags.ASPECT_NAME, "add", path=("tags", tag.tag), value=tag + GlobalTags.ASPECT_NAME, + "add", + path=("tags", tag.tag), + value=tag, ) return self diff --git a/metadata-ingestion/src/datahub/specific/aspect_helpers/terms.py b/metadata-ingestion/src/datahub/specific/aspect_helpers/terms.py index ae199124372b40..674a164641e529 100644 --- a/metadata-ingestion/src/datahub/specific/aspect_helpers/terms.py +++ b/metadata-ingestion/src/datahub/specific/aspect_helpers/terms.py @@ -22,7 +22,10 @@ def add_term(self, term: Term) -> Self: """ # TODO: Make this support raw strings, in addition to Term objects. self._add_patch( - GlossaryTermsClass.ASPECT_NAME, "add", path=("terms", term.urn), value=term + GlossaryTermsClass.ASPECT_NAME, + "add", + path=("terms", term.urn), + value=term, ) return self @@ -38,6 +41,9 @@ def remove_term(self, term: Union[str, Urn]) -> Self: if isinstance(term, str) and not term.startswith("urn:li:glossaryTerm:"): term = GlossaryTermUrn(term) self._add_patch( - GlossaryTermsClass.ASPECT_NAME, "remove", path=("terms", term), value={} + GlossaryTermsClass.ASPECT_NAME, + "remove", + path=("terms", term), + value={}, ) return self diff --git a/metadata-ingestion/src/datahub/specific/chart.py b/metadata-ingestion/src/datahub/specific/chart.py index f44a2ffc0d68ab..12f870471ffe69 100644 --- a/metadata-ingestion/src/datahub/specific/chart.py +++ b/metadata-ingestion/src/datahub/specific/chart.py @@ -39,7 +39,9 @@ def __init__( audit_header: The Kafka audit header of the chart (optional). 
""" super().__init__( - urn, system_metadata=system_metadata, audit_header=audit_header + urn, + system_metadata=system_metadata, + audit_header=audit_header, ) @classmethod @@ -154,7 +156,8 @@ def set_last_refreshed(self, last_refreshed: Optional[int]) -> "ChartPatchBuilde return self def set_last_modified( - self, last_modified: "ChangeAuditStampsClass" + self, + last_modified: "ChangeAuditStampsClass", ) -> "ChartPatchBuilder": if last_modified: self._add_patch( @@ -188,7 +191,8 @@ def set_chart_url(self, dashboard_url: Optional[str]) -> "ChartPatchBuilder": return self def set_type( - self, type: Union[None, Union[str, "ChartTypeClass"]] = None + self, + type: Union[None, Union[str, "ChartTypeClass"]] = None, ) -> "ChartPatchBuilder": if type: self._add_patch( @@ -201,7 +205,8 @@ def set_type( return self def set_access( - self, access: Union[None, Union[str, "AccessLevelClass"]] = None + self, + access: Union[None, Union[str, "AccessLevelClass"]] = None, ) -> "ChartPatchBuilder": if access: self._add_patch( diff --git a/metadata-ingestion/src/datahub/specific/dashboard.py b/metadata-ingestion/src/datahub/specific/dashboard.py index 515fcf0c6da955..8a615974ecc2b7 100644 --- a/metadata-ingestion/src/datahub/specific/dashboard.py +++ b/metadata-ingestion/src/datahub/specific/dashboard.py @@ -38,7 +38,9 @@ def __init__( audit_header: The Kafka audit header of the dashboard (optional). """ super().__init__( - urn, system_metadata=system_metadata, audit_header=audit_header + urn, + system_metadata=system_metadata, + audit_header=audit_header, ) @classmethod @@ -46,7 +48,8 @@ def _custom_properties_location(cls) -> Tuple[str, PatchPath]: return DashboardInfo.ASPECT_NAME, ("customProperties",) def add_dataset_edge( - self, dataset: Union[Edge, Urn, str] + self, + dataset: Union[Edge, Urn, str], ) -> "DashboardPatchBuilder": """ Adds an dataset to the DashboardPatchBuilder. @@ -258,7 +261,8 @@ def add_charts(self, chart_urns: Optional[List[str]]) -> "DashboardPatchBuilder" return self def add_datasets( - self, dataset_urns: Optional[List[str]] + self, + dataset_urns: Optional[List[str]], ) -> "DashboardPatchBuilder": if dataset_urns: for urn in dataset_urns: @@ -272,7 +276,8 @@ def add_datasets( return self def set_dashboard_url( - self, dashboard_url: Optional[str] + self, + dashboard_url: Optional[str], ) -> "DashboardPatchBuilder": if dashboard_url: self._add_patch( @@ -285,7 +290,8 @@ def set_dashboard_url( return self def set_access( - self, access: Union[None, Union[str, "AccessLevelClass"]] = None + self, + access: Union[None, Union[str, "AccessLevelClass"]] = None, ) -> "DashboardPatchBuilder": if access: self._add_patch( @@ -298,7 +304,8 @@ def set_access( return self def set_last_refreshed( - self, last_refreshed: Optional[int] + self, + last_refreshed: Optional[int], ) -> "DashboardPatchBuilder": if last_refreshed: self._add_patch( @@ -311,7 +318,8 @@ def set_last_refreshed( return self def set_last_modified( - self, last_modified: "ChangeAuditStampsClass" + self, + last_modified: "ChangeAuditStampsClass", ) -> "DashboardPatchBuilder": if last_modified: self._add_patch( diff --git a/metadata-ingestion/src/datahub/specific/datajob.py b/metadata-ingestion/src/datahub/specific/datajob.py index fd826c6dd59ca3..79acdb03f16aa2 100644 --- a/metadata-ingestion/src/datahub/specific/datajob.py +++ b/metadata-ingestion/src/datahub/specific/datajob.py @@ -37,7 +37,9 @@ def __init__( audit_header: The Kafka audit header of the data job (optional). 
""" super().__init__( - urn, system_metadata=system_metadata, audit_header=audit_header + urn, + system_metadata=system_metadata, + audit_header=audit_header, ) @classmethod @@ -207,7 +209,8 @@ def set_input_datasets(self, inputs: List[Edge]) -> "DataJobPatchBuilder": return self def add_output_dataset( - self, output: Union[Edge, Urn, str] + self, + output: Union[Edge, Urn, str], ) -> "DataJobPatchBuilder": """ Adds an output dataset to the DataJobPatchBuilder. @@ -314,7 +317,8 @@ def add_input_dataset_field(self, input: Union[Urn, str]) -> "DataJobPatchBuilde return self def remove_input_dataset_field( - self, input: Union[str, Urn] + self, + input: Union[str, Urn], ) -> "DataJobPatchBuilder": """ Removes an input dataset field from the DataJobPatchBuilder. @@ -360,7 +364,8 @@ def set_input_dataset_fields(self, inputs: List[Edge]) -> "DataJobPatchBuilder": return self def add_output_dataset_field( - self, output: Union[Urn, str] + self, + output: Union[Urn, str], ) -> "DataJobPatchBuilder": """ Adds an output dataset field to the DataJobPatchBuilder. @@ -386,7 +391,8 @@ def add_output_dataset_field( return self def remove_output_dataset_field( - self, output: Union[str, Urn] + self, + output: Union[str, Urn], ) -> "DataJobPatchBuilder": """ Removes an output dataset field from the DataJobPatchBuilder. diff --git a/metadata-ingestion/src/datahub/specific/dataproduct.py b/metadata-ingestion/src/datahub/specific/dataproduct.py index d38d2d4156315d..c018c69fd0244b 100644 --- a/metadata-ingestion/src/datahub/specific/dataproduct.py +++ b/metadata-ingestion/src/datahub/specific/dataproduct.py @@ -55,7 +55,8 @@ def set_description(self, description: str) -> "DataProductPatchBuilder": return self def set_assets( - self, assets: List[DataProductAssociation] + self, + assets: List[DataProductAssociation], ) -> "DataProductPatchBuilder": self._add_patch( DataProductProperties.ASPECT_NAME, diff --git a/metadata-ingestion/src/datahub/specific/dataset.py b/metadata-ingestion/src/datahub/specific/dataset.py index 6332386684bbf0..2f25632691d578 100644 --- a/metadata-ingestion/src/datahub/specific/dataset.py +++ b/metadata-ingestion/src/datahub/specific/dataset.py @@ -109,7 +109,9 @@ def __init__( audit_header: Optional[KafkaAuditHeaderClass] = None, ) -> None: super().__init__( - urn, system_metadata=system_metadata, audit_header=audit_header + urn, + system_metadata=system_metadata, + audit_header=audit_header, ) @classmethod @@ -126,7 +128,8 @@ def add_upstream_lineage(self, upstream: Upstream) -> "DatasetPatchBuilder": return self def remove_upstream_lineage( - self, dataset: Union[str, Urn] + self, + dataset: Union[str, Urn], ) -> "DatasetPatchBuilder": self._add_patch( UpstreamLineage.ASPECT_NAME, @@ -138,12 +141,16 @@ def remove_upstream_lineage( def set_upstream_lineages(self, upstreams: List[Upstream]) -> "DatasetPatchBuilder": self._add_patch( - UpstreamLineage.ASPECT_NAME, "add", path=("upstreams",), value=upstreams + UpstreamLineage.ASPECT_NAME, + "add", + path=("upstreams",), + value=upstreams, ) return self def add_fine_grained_upstream_lineage( - self, fine_grained_lineage: FineGrainedLineage + self, + fine_grained_lineage: FineGrainedLineage, ) -> "DatasetPatchBuilder": ( transform_op, @@ -155,7 +162,10 @@ def add_fine_grained_upstream_lineage( UpstreamLineage.ASPECT_NAME, "add", path=self._build_fine_grained_path( - transform_op, downstream_urn, query_id, upstream_urn + transform_op, + downstream_urn, + query_id, + upstream_urn, ), value={"confidenceScore": 
fine_grained_lineage.confidenceScore}, ) @@ -175,7 +185,11 @@ def get_fine_grained_key( @classmethod def _build_fine_grained_path( - cls, transform_op: str, downstream_urn: str, query_id: str, upstream_urn: str + cls, + transform_op: str, + downstream_urn: str, + query_id: str, + upstream_urn: str, ) -> PatchPath: return ( "fineGrainedLineages", @@ -186,7 +200,8 @@ def _build_fine_grained_path( ) def remove_fine_grained_upstream_lineage( - self, fine_grained_lineage: FineGrainedLineage + self, + fine_grained_lineage: FineGrainedLineage, ) -> "DatasetPatchBuilder": ( transform_op, @@ -198,14 +213,18 @@ def remove_fine_grained_upstream_lineage( UpstreamLineage.ASPECT_NAME, "remove", path=self._build_fine_grained_path( - transform_op, downstream_urn, query_id, upstream_urn + transform_op, + downstream_urn, + query_id, + upstream_urn, ), value={}, ) return self def set_fine_grained_upstream_lineages( - self, fine_grained_lineages: List[FineGrainedLineage] + self, + fine_grained_lineages: List[FineGrainedLineage], ) -> "DatasetPatchBuilder": self._add_patch( UpstreamLineage.ASPECT_NAME, @@ -216,7 +235,9 @@ def set_fine_grained_upstream_lineages( return self def for_field( - self, field_path: str, editable: bool = True + self, + field_path: str, + editable: bool = True, ) -> FieldPatchHelper["DatasetPatchBuilder"]: """ Get a helper that can perform patches against fields in the dataset @@ -231,7 +252,9 @@ def for_field( ) def set_description( - self, description: Optional[str] = None, editable: bool = False + self, + description: Optional[str] = None, + editable: bool = False, ) -> "DatasetPatchBuilder": if description is not None: self._add_patch( @@ -247,7 +270,8 @@ def set_description( return self def set_display_name( - self, display_name: Optional[str] = None + self, + display_name: Optional[str] = None, ) -> "DatasetPatchBuilder": if display_name is not None: self._add_patch( @@ -259,7 +283,8 @@ def set_display_name( return self def set_qualified_name( - self, qualified_name: Optional[str] = None + self, + qualified_name: Optional[str] = None, ) -> "DatasetPatchBuilder": if qualified_name is not None: self._add_patch( @@ -271,7 +296,8 @@ def set_qualified_name( return self def set_created( - self, timestamp: Optional[TimeStamp] = None + self, + timestamp: Optional[TimeStamp] = None, ) -> "DatasetPatchBuilder": if timestamp is not None: self._add_patch( @@ -283,7 +309,8 @@ def set_created( return self def set_last_modified( - self, timestamp: Optional[TimeStamp] = None + self, + timestamp: Optional[TimeStamp] = None, ) -> "DatasetPatchBuilder": if timestamp is not None: self._add_patch( diff --git a/metadata-ingestion/src/datahub/specific/form.py b/metadata-ingestion/src/datahub/specific/form.py index 281b3cac99b2c1..2a7049d11484be 100644 --- a/metadata-ingestion/src/datahub/specific/form.py +++ b/metadata-ingestion/src/datahub/specific/form.py @@ -19,7 +19,9 @@ def __init__( audit_header: Optional[KafkaAuditHeaderClass] = None, ) -> None: super().__init__( - urn, system_metadata=system_metadata, audit_header=audit_header + urn, + system_metadata=system_metadata, + audit_header=audit_header, ) def set_name(self, name: Optional[str] = None) -> "FormPatchBuilder": diff --git a/metadata-ingestion/src/datahub/specific/structured_property.py b/metadata-ingestion/src/datahub/specific/structured_property.py index bcae174ed3c4f4..5112e82d5ede39 100644 --- a/metadata-ingestion/src/datahub/specific/structured_property.py +++ b/metadata-ingestion/src/datahub/specific/structured_property.py @@ -19,12 
+19,15 @@ def __init__( audit_header: Optional[KafkaAuditHeaderClass] = None, ) -> None: super().__init__( - urn, system_metadata=system_metadata, audit_header=audit_header + urn, + system_metadata=system_metadata, + audit_header=audit_header, ) # can only be used when creating a new structured property def set_qualified_name( - self, qualified_name: str + self, + qualified_name: str, ) -> "StructuredPropertyPatchBuilder": self._add_patch( StructuredPropertyDefinition.ASPECT_NAME, @@ -35,7 +38,8 @@ def set_qualified_name( return self def set_display_name( - self, display_name: Optional[str] = None + self, + display_name: Optional[str] = None, ) -> "StructuredPropertyPatchBuilder": if display_name is not None: self._add_patch( @@ -48,7 +52,8 @@ def set_display_name( # can only be used when creating a new structured property def set_value_type( - self, value_type: Union[str, Urn] + self, + value_type: Union[str, Urn], ) -> "StructuredPropertyPatchBuilder": self._add_patch( StructuredPropertyDefinition.ASPECT_NAME, @@ -60,7 +65,8 @@ def set_value_type( # can only be used when creating a new structured property def set_type_qualifier( - self, type_qualifier: Optional[Dict[str, List[str]]] = None + self, + type_qualifier: Optional[Dict[str, List[str]]] = None, ) -> "StructuredPropertyPatchBuilder": if type_qualifier is not None: self._add_patch( @@ -73,7 +79,8 @@ def set_type_qualifier( # can only be used when creating a new structured property def add_allowed_value( - self, allowed_value: PropertyValueClass + self, + allowed_value: PropertyValueClass, ) -> "StructuredPropertyPatchBuilder": self._add_patch( StructuredPropertyDefinition.ASPECT_NAME, @@ -93,7 +100,8 @@ def set_cardinality(self, cardinality: str) -> "StructuredPropertyPatchBuilder": return self def add_entity_type( - self, entity_type: Union[str, Urn] + self, + entity_type: Union[str, Urn], ) -> "StructuredPropertyPatchBuilder": self._add_patch( StructuredPropertyDefinition.ASPECT_NAME, @@ -104,7 +112,8 @@ def add_entity_type( return self def set_description( - self, description: Optional[str] = None + self, + description: Optional[str] = None, ) -> "StructuredPropertyPatchBuilder": if description is not None: self._add_patch( diff --git a/metadata-ingestion/src/datahub/sql_parsing/_sqlglot_patch.py b/metadata-ingestion/src/datahub/sql_parsing/_sqlglot_patch.py index 55f30b576b44ef..cc4821e49ea175 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/_sqlglot_patch.py +++ b/metadata-ingestion/src/datahub/sql_parsing/_sqlglot_patch.py @@ -39,7 +39,8 @@ def _new_apply_patch(source: str, patch_text: str, forwards: bool, name: str) -> logger.debug(f"Subprocess result:\n{result_subprocess}") logger.debug(f"Our result:\n{result}") diff = difflib.unified_diff( - result_subprocess.splitlines(), result.splitlines() + result_subprocess.splitlines(), + result.splitlines(), ) logger.debug("Diff:\n" + "\n".join(diff)) raise ValueError("Results from subprocess and _apply_diff do not match") diff --git a/metadata-ingestion/src/datahub/sql_parsing/datajob.py b/metadata-ingestion/src/datahub/sql_parsing/datajob.py index 215b207c3dcf51..109e8d8a7f7ca8 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/datajob.py +++ b/metadata-ingestion/src/datahub/sql_parsing/datajob.py @@ -12,7 +12,9 @@ def to_datajob_input_output( - *, mcps: Iterable[MetadataChangeProposalWrapper], ignore_extra_mcps: bool = True + *, + mcps: Iterable[MetadataChangeProposalWrapper], + ignore_extra_mcps: bool = True, ) -> Optional[DataJobInputOutputClass]: inputDatasets: 
List[str] = [] outputDatasets: List[str] = [] @@ -37,7 +39,7 @@ def to_datajob_input_output( pass else: raise ValueError( - f"Expected an upstreamLineage aspect, got {mcp.aspectName} for {mcp.entityUrn}" + f"Expected an upstreamLineage aspect, got {mcp.aspectName} for {mcp.entityUrn}", ) if not inputDatasets and not outputDatasets: diff --git a/metadata-ingestion/src/datahub/sql_parsing/query_types.py b/metadata-ingestion/src/datahub/sql_parsing/query_types.py index 802fb3e993f428..2a548a2ffca125 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/query_types.py +++ b/metadata-ingestion/src/datahub/sql_parsing/query_types.py @@ -37,7 +37,8 @@ def _get_create_type_from_kind(kind: Optional[str]) -> QueryType: def get_query_type_of_sql( - expression: sqlglot.exp.Expression, dialect: DialectOrStr + expression: sqlglot.exp.Expression, + dialect: DialectOrStr, ) -> Tuple[QueryType, QueryTypeProps]: dialect = get_dialect(dialect) query_type_props: QueryTypeProps = {} @@ -80,5 +81,6 @@ def get_query_type_of_sql( def is_create_table_ddl(statement: sqlglot.exp.Expression) -> bool: return isinstance(statement, sqlglot.exp.Create) and isinstance( - statement.this, sqlglot.exp.Schema + statement.this, + sqlglot.exp.Schema, ) diff --git a/metadata-ingestion/src/datahub/sql_parsing/schema_resolver.py b/metadata-ingestion/src/datahub/sql_parsing/schema_resolver.py index 55b026a144c6d5..1fde5b4ecbaa7b 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/schema_resolver.py +++ b/metadata-ingestion/src/datahub/sql_parsing/schema_resolver.py @@ -83,17 +83,20 @@ def get_urns(self) -> Set[str]: def schema_count(self) -> int: return int( self._schema_cache.sql_query( - f"SELECT COUNT(*) FROM {self._schema_cache.tablename} WHERE NOT is_missing" - )[0][0] + f"SELECT COUNT(*) FROM {self._schema_cache.tablename} WHERE NOT is_missing", + )[0][0], ) def get_urn_for_table( - self, table: _TableName, lower: bool = False, mixed: bool = False + self, + table: _TableName, + lower: bool = False, + mixed: bool = False, ) -> str: # TODO: Validate that this is the correct 2/3 layer hierarchy for the platform. table_name = ".".join( - filter(None, [table.database, table.db_schema, table.table]) + filter(None, [table.database, table.db_schema, table.table]), ) platform_instance = self.platform_instance @@ -109,7 +112,7 @@ def get_urn_for_table( # Normalize shard numbers and other BigQuery weirdness. 
with contextlib.suppress(IndexError): table_name = BigqueryTableIdentifier.from_string_name( - table_name + table_name, ).get_table_name() urn = make_dataset_urn_with_platform_instance( @@ -182,7 +185,9 @@ def _resolve_schema_info(self, urn: str) -> Optional[SchemaInfo]: return None def add_schema_metadata( - self, urn: str, schema_metadata: SchemaMetadataClass + self, + urn: str, + schema_metadata: SchemaMetadataClass, ) -> None: schema_info = _convert_schema_aspect_to_info(schema_metadata) self._save_to_cache(urn, schema_info) @@ -191,13 +196,16 @@ def add_raw_schema_info(self, urn: str, schema_info: SchemaInfo) -> None: self._save_to_cache(urn, schema_info) def add_graphql_schema_metadata( - self, urn: str, schema_metadata: GraphQLSchemaMetadata + self, + urn: str, + schema_metadata: GraphQLSchemaMetadata, ) -> None: schema_info = self.convert_graphql_schema_metadata_to_info(schema_metadata) self._save_to_cache(urn, schema_info) def with_temp_tables( - self, temp_tables: Dict[str, Optional[List[SchemaFieldClass]]] + self, + temp_tables: Dict[str, Optional[List[SchemaFieldClass]]], ) -> SchemaResolverInterface: extra_schemas = { urn: ( @@ -209,7 +217,8 @@ def with_temp_tables( } return _SchemaResolverWithExtras( - base_resolver=self, extra_schemas=extra_schemas + base_resolver=self, + extra_schemas=extra_schemas, ) def _save_to_cache(self, urn: str, schema_info: Optional[SchemaInfo]) -> None: @@ -224,7 +233,8 @@ def _fetch_schema_info(self, graph: DataHubGraph, urn: str) -> Optional[SchemaIn @classmethod def convert_graphql_schema_metadata_to_info( - cls, schema: GraphQLSchemaMetadata + cls, + schema: GraphQLSchemaMetadata, ) -> SchemaInfo: return { get_simple_field_path_from_v2_field_path(field["fieldPath"]): ( @@ -258,14 +268,16 @@ def includes_temp_tables(self) -> bool: def resolve_table(self, table: _TableName) -> Tuple[str, Optional[SchemaInfo]]: urn = self._base_resolver.get_urn_for_table( - table, lower=self._base_resolver._prefers_urn_lower() + table, + lower=self._base_resolver._prefers_urn_lower(), ) if urn in self._extra_schemas: return urn, self._extra_schemas[urn] return self._base_resolver.resolve_table(table) def add_temp_tables( - self, temp_tables: Dict[str, Optional[List[SchemaFieldClass]]] + self, + temp_tables: Dict[str, Optional[List[SchemaFieldClass]]], ) -> None: self._extra_schemas.update( { @@ -275,7 +287,7 @@ def add_temp_tables( else None ) for urn, fields in temp_tables.items() - } + }, ) @@ -298,7 +310,8 @@ def _convert_schema_aspect_to_info(schema_metadata: SchemaMetadataClass) -> Sche def match_columns_to_schema( - schema_info: SchemaInfo, input_columns: List[str] + schema_info: SchemaInfo, + input_columns: List[str], ) -> List[str]: column_from_gms: List[str] = list(schema_info.keys()) # list() to silent lint diff --git a/metadata-ingestion/src/datahub/sql_parsing/split_statements.py b/metadata-ingestion/src/datahub/sql_parsing/split_statements.py index 42dda4e62158b0..311db748853129 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/split_statements.py +++ b/metadata-ingestion/src/datahub/sql_parsing/split_statements.py @@ -50,7 +50,9 @@ def _is_keyword_at_position(sql: str, pos: int, keyword: str) -> bool: def _look_ahead_for_keywords( - sql: str, pos: int, keywords: List[str] + sql: str, + pos: int, + keywords: List[str], ) -> Tuple[bool, str, int]: """ Look ahead for SQL keywords at the current position. 
@@ -105,7 +107,9 @@ def yield_if_complete() -> Generator[str, None, None]: prev_real_char = c is_control_keyword, keyword, keyword_len = _look_ahead_for_keywords( - sql, i, keywords=CONTROL_FLOW_KEYWORDS + sql, + i, + keywords=CONTROL_FLOW_KEYWORDS, ) if is_control_keyword: # Yield current statement if any @@ -120,7 +124,9 @@ def yield_if_complete() -> Generator[str, None, None]: keyword, keyword_len, ) = _look_ahead_for_keywords( - sql, i, keywords=FORCE_NEW_STATEMENT_KEYWORDS + sql, + i, + keywords=FORCE_NEW_STATEMENT_KEYWORDS, ) if ( is_force_new_statement_keyword and most_recent_real_char != ")" diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index 8637802f6b9fee..21116b460ed14d 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -217,7 +217,7 @@ class PreparsedQuery: session_id: str = _MISSING_SESSION_ID query_type: QueryType = QueryType.UNKNOWN query_type_props: QueryTypeProps = dataclasses.field( - default_factory=lambda: QueryTypeProps() + default_factory=lambda: QueryTypeProps(), ) # Use this to store addtitional key-value information about query for debugging extra_info: Optional[dict] = None @@ -234,7 +234,7 @@ class SqlAggregatorReport(Report): num_observed_queries_column_timeout: int = 0 num_observed_queries_column_failed: int = 0 observed_query_parse_failures: LossyList[str] = dataclasses.field( - default_factory=LossyList + default_factory=LossyList, ) # Views. @@ -243,7 +243,7 @@ class SqlAggregatorReport(Report): num_views_column_timeout: int = 0 num_views_column_failed: int = 0 views_parse_failures: LossyDict[UrnStr, str] = dataclasses.field( - default_factory=LossyDict + default_factory=LossyDict, ) # SQL parsing (over all invocations). @@ -267,10 +267,10 @@ class SqlAggregatorReport(Report): num_inferred_temp_schemas: Optional[int] = None num_queries_with_temp_tables_in_session: int = 0 queries_with_temp_upstreams: LossyDict[QueryId, LossyList] = dataclasses.field( - default_factory=LossyDict + default_factory=LossyDict, ) queries_with_non_authoritative_session: LossyList[QueryId] = dataclasses.field( - default_factory=LossyList + default_factory=LossyList, ) make_schema_resolver_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) @@ -352,7 +352,7 @@ def __init__( self.generate_lineage or self.generate_query_usage_statistics ): logger.warning( - "Queries will not be generated, as neither lineage nor query usage statistics are enabled" + "Queries will not be generated, as neither lineage nor query usage statistics are enabled", ) self.usage_config = usage_config @@ -397,12 +397,13 @@ def __init__( platform_instance=self.platform_instance, env=self.env, graph=graph, - ) + ), ) # Schema resolver for special case (_MISSING_SESSION_ID) # This is particularly useful for via temp table lineage if session id is not available. self._missing_session_schema_resolver = _SchemaResolverWithExtras( - base_resolver=self._schema_resolver, extra_schemas={} + base_resolver=self._schema_resolver, + extra_schemas={}, ) # Initialize internal data structures. @@ -419,37 +420,42 @@ def __init__( # By providing a filename explicitly here, we also ensure that the file # is not automatically deleted on exit. 
self._shared_connection = self._exit_stack.enter_context( - ConnectionWrapper(filename=query_log_path) + ConnectionWrapper(filename=query_log_path), ) # Stores the logged queries. self._logged_queries = FileBackedList[LoggedQuery]( - shared_connection=self._shared_connection, tablename="stored_queries" + shared_connection=self._shared_connection, + tablename="stored_queries", ) self._exit_stack.push(self._logged_queries) # Map of query_id -> QueryMetadata self._query_map = FileBackedDict[QueryMetadata]( - shared_connection=self._shared_connection, tablename="query_map" + shared_connection=self._shared_connection, + tablename="query_map", ) self._exit_stack.push(self._query_map) # Map of downstream urn -> { query ids } self._lineage_map = FileBackedDict[OrderedSet[QueryId]]( - shared_connection=self._shared_connection, tablename="lineage_map" + shared_connection=self._shared_connection, + tablename="lineage_map", ) self._exit_stack.push(self._lineage_map) # Map of view urn -> view definition self._view_definitions = FileBackedDict[ViewDefinition]( - shared_connection=self._shared_connection, tablename="view_definitions" + shared_connection=self._shared_connection, + tablename="view_definitions", ) self._exit_stack.push(self._view_definitions) # Map of session ID -> {temp table name -> query id} # Needs to use the query_map to find the info about the query. self._temp_lineage_map = FileBackedDict[Dict[UrnStr, OrderedSet[QueryId]]]( - shared_connection=self._shared_connection, tablename="temp_lineage_map" + shared_connection=self._shared_connection, + tablename="temp_lineage_map", ) self._exit_stack.push(self._temp_lineage_map) @@ -462,13 +468,15 @@ def __init__( # Map of table renames, from original UrnStr to new UrnStr. self._table_renames = FileBackedDict[UrnStr]( - shared_connection=self._shared_connection, tablename="table_renames" + shared_connection=self._shared_connection, + tablename="table_renames", ) self._exit_stack.push(self._table_renames) # Map of table swaps, from unique swap id to TableSwap self._table_swaps = FileBackedDict[TableSwap]( - shared_connection=self._shared_connection, tablename="table_swaps" + shared_connection=self._shared_connection, + tablename="table_swaps", ) self._exit_stack.push(self._table_swaps) @@ -513,7 +521,9 @@ def _need_schemas(self) -> bool: ) def register_schema( - self, urn: Union[str, DatasetUrn], schema: models.SchemaMetadataClass + self, + urn: Union[str, DatasetUrn], + schema: models.SchemaMetadataClass, ) -> None: # If lineage or usage is enabled, adds the schema to the schema resolver # by putting the condition in here, we can avoid all the conditional @@ -523,7 +533,8 @@ def register_schema( self._schema_resolver.add_schema_metadata(str(urn), schema) def register_schemas_from_stream( - self, stream: Iterable[MetadataWorkUnit] + self, + stream: Iterable[MetadataWorkUnit], ) -> Iterable[MetadataWorkUnit]: for wu in stream: schema_metadata = wu.get_aspect_of_type(models.SchemaMetadataClass) @@ -600,7 +611,9 @@ def add( raise ValueError(f"Cannot add unknown item type: {type(item)}") def add_known_query_lineage( - self, known_query_lineage: KnownQueryLineageInfo, merge_lineage: bool = False + self, + known_query_lineage: KnownQueryLineageInfo, + merge_lineage: bool = False, ) -> None: """Add a query and it's precomputed lineage to the aggregator. @@ -649,7 +662,8 @@ def add_known_query_lineage( # Register the lineage. 
self._lineage_map.for_mutation( - known_query_lineage.downstream, OrderedSet() + known_query_lineage.downstream, + OrderedSet(), ).add(query_fingerprint) def add_known_lineage_mapping( @@ -674,7 +688,7 @@ def add_known_lineage_mapping( downstream_urn: The downstream dataset URN. """ logger.debug( - f"Adding lineage to the map, downstream: {downstream_urn}, upstream: {upstream_urn}" + f"Adding lineage to the map, downstream: {downstream_urn}, upstream: {upstream_urn}", ) self.report.num_known_mapping_lineage += 1 @@ -684,7 +698,8 @@ def add_known_lineage_mapping( # Generate CLL if schema of downstream is known column_lineage: List[ColumnLineageInfo] = ( self._generate_identity_column_lineage( - upstream_urn=upstream_urn, downstream_urn=downstream_urn + upstream_urn=upstream_urn, + downstream_urn=downstream_urn, ) ) @@ -702,14 +717,17 @@ def add_known_lineage_mapping( column_lineage=column_lineage, column_usage={}, confidence_score=1.0, - ) + ), ) # Register the lineage. self._lineage_map.for_mutation(downstream_urn, OrderedSet()).add(query_id) def _generate_identity_column_lineage( - self, *, upstream_urn: UrnStr, downstream_urn: UrnStr + self, + *, + upstream_urn: UrnStr, + downstream_urn: UrnStr, ) -> List[ColumnLineageInfo]: column_lineage: List[ColumnLineageInfo] = [] if self._schema_resolver.has_urn(downstream_urn): @@ -718,7 +736,8 @@ def _generate_identity_column_lineage( column_lineage = [ ColumnLineageInfo( downstream=DownstreamColumnRef( - table=downstream_urn, column=field_path + table=downstream_urn, + column=field_path, ), upstreams=[ColumnRef(table=upstream_urn, column=field_path)], ) @@ -789,7 +808,7 @@ def add_observed_query( ) if parsed.debug_info.error: self.report.observed_query_parse_failures.append( - f"{parsed.debug_info.error} on query: {observed.query[:100]}" + f"{parsed.debug_info.error} on query: {observed.query[:100]}", ) if parsed.debug_info.table_error: self.report.num_observed_queries_failed += 1 @@ -883,7 +902,8 @@ def add_preparsed_query( if self._query_usage_counts is not None and parsed.timestamp is not None: assert self.usage_config is not None bucket = get_time_bucket( - parsed.timestamp, self.usage_config.bucket_duration + parsed.timestamp, + self.usage_config.bucket_duration, ) counts = self._query_usage_counts.for_mutation(query_fingerprint, {}) counts[bucket] = counts.get(bucket, 0) + parsed.query_count @@ -903,7 +923,7 @@ def add_preparsed_query( column_usage=parsed.column_usage or {}, confidence_score=parsed.confidence_score, used_temp_tables=session_has_temp_tables, - ) + ), ) if not parsed.downstream: @@ -929,19 +949,20 @@ def add_preparsed_query( # Also track the lineage for the temp table, for merging purposes later. self._temp_lineage_map.for_mutation(parsed.session_id, {}).setdefault( - out_table, OrderedSet() + out_table, + OrderedSet(), ).add(query_fingerprint) # Also update schema resolver for missing session id if parsed.session_id == _MISSING_SESSION_ID and parsed.inferred_schema: self._missing_session_schema_resolver.add_temp_tables( - {out_table: parsed.inferred_schema} + {out_table: parsed.inferred_schema}, ) else: # Non-temp tables immediately generate lineage. 
self._lineage_map.for_mutation(out_table, OrderedSet()).add( - query_fingerprint + query_fingerprint, ) def add_table_rename( @@ -976,7 +997,7 @@ def add_table_rename( ), session_id=table_rename.session_id, timestamp=table_rename.timestamp, - ) + ), ) def add_table_swap(self, table_swap: TableSwap) -> None: @@ -1011,11 +1032,12 @@ def add_table_swap(self, table_swap: TableSwap) -> None: upstreams=[table_swap.urn1], downstream=table_swap.urn2, column_lineage=self._generate_identity_column_lineage( - upstream_urn=table_swap.urn1, downstream_urn=table_swap.urn2 + upstream_urn=table_swap.urn1, + downstream_urn=table_swap.urn2, ), session_id=table_swap.session_id, timestamp=table_swap.timestamp, - ) + ), ) if not self.is_temp_table(table_swap.urn1): @@ -1027,15 +1049,17 @@ def add_table_swap(self, table_swap: TableSwap) -> None: upstreams=[table_swap.urn2], downstream=table_swap.urn1, column_lineage=self._generate_identity_column_lineage( - upstream_urn=table_swap.urn2, downstream_urn=table_swap.urn1 + upstream_urn=table_swap.urn2, + downstream_urn=table_swap.urn1, ), session_id=table_swap.session_id, timestamp=table_swap.timestamp, - ) + ), ) def _make_schema_resolver_for_session( - self, session_id: str + self, + session_id: str, ) -> SchemaResolverInterface: schema_resolver: SchemaResolverInterface = self._schema_resolver if session_id == _MISSING_SESSION_ID: @@ -1052,14 +1076,16 @@ def _make_schema_resolver_for_session( if temp_table_schemas: schema_resolver = self._schema_resolver.with_temp_tables( - temp_table_schemas + temp_table_schemas, ) self.report.num_queries_with_temp_tables_in_session += 1 return schema_resolver def _process_view_definition( - self, view_urn: UrnStr, view_definition: ViewDefinition + self, + view_urn: UrnStr, + view_definition: ViewDefinition, ) -> None: # Note that in some cases, the view definition will be a SELECT statement # instead of a CREATE VIEW ... AS SELECT statement. In those cases, we can't @@ -1086,7 +1112,7 @@ def _process_view_definition( query_fingerprint = self._view_query_id(view_urn) formatted_view_definition = self._maybe_format_query( - view_definition.view_definition + view_definition.view_definition, ) # Register the query. @@ -1103,7 +1129,7 @@ def _process_view_definition( column_lineage=parsed.column_lineage or [], column_usage=compute_upstream_fields(parsed), confidence_score=parsed.debug_info.confidence, - ) + ), ) # Register the query's lineage. @@ -1152,7 +1178,9 @@ def _run_sql_parser( return parsed def _add_to_query_map( - self, new: QueryMetadata, merge_lineage: bool = False + self, + new: QueryMetadata, + merge_lineage: bool = False, ) -> None: query_fingerprint = new.query_id @@ -1172,7 +1200,7 @@ def _add_to_query_map( # lineage as more authoritative. This isn't technically correct, but it's # better than using the newer session's lineage, which is likely incorrect. self.report.queries_with_non_authoritative_session.append( - query_fingerprint + query_fingerprint, ) return current.session_id = new.session_id @@ -1189,10 +1217,11 @@ def _add_to_query_map( # In the case of known query lineage, we might get things one at a time. # TODO: We don't yet support merging CLL for a single query. 
current.upstreams = list( - OrderedSet(current.upstreams) | OrderedSet(new.upstreams) + OrderedSet(current.upstreams) | OrderedSet(new.upstreams), ) current.confidence_score = min( - current.confidence_score, new.confidence_score + current.confidence_score, + new.confidence_score, ) else: self._query_map[query_fingerprint] = new @@ -1207,7 +1236,8 @@ def gen_metadata(self) -> Iterable[MetadataChangeProposalWrapper]: yield from self._gen_remaining_queries(queries_generated) def _gen_lineage_mcps( - self, queries_generated: Set[QueryId] + self, + queries_generated: Set[QueryId], ) -> Iterable[MetadataChangeProposalWrapper]: if not self.generate_lineage: return @@ -1222,7 +1252,8 @@ def _gen_lineage_mcps( # Generate lineage and queries. for downstream_urn in sorted(self._lineage_map): yield from self._gen_lineage_for_downstream( - downstream_urn, queries_generated=queries_generated + downstream_urn, + queries_generated=queries_generated, ) @classmethod @@ -1240,7 +1271,9 @@ def _query_type_precedence(cls, query_type: str) -> int: return idx def _gen_lineage_for_downstream( - self, downstream_urn: str, queries_generated: Set[QueryId] + self, + downstream_urn: str, + queries_generated: Set[QueryId], ) -> Iterable[MetadataChangeProposalWrapper]: if not self.is_allowed_table(downstream_urn): self.report.num_lineage_skipped_due_to_filters += 1 @@ -1309,7 +1342,7 @@ def _gen_lineage_for_downstream( time=get_sys_time(), actor=_DEFAULT_USER_URN.urn(), ), - ) + ), ) upstream_aspect.fineGrainedLineages = [] for downstream_column, all_upstream_columns in cll.items(): @@ -1330,7 +1363,7 @@ def _gen_lineage_for_downstream( ], downstreamType=models.FineGrainedLineageDownstreamTypeClass.FIELD, downstreams=[ - SchemaFieldUrn(downstream_urn, downstream_column).urn() + SchemaFieldUrn(downstream_urn, downstream_column).urn(), ], query=( self._query_urn(query_id) @@ -1338,13 +1371,13 @@ def _gen_lineage_for_downstream( else None ), confidenceScore=queries_map[query_id].confidence_score, - ) + ), ) if len(upstream_aspect.upstreams) > MAX_UPSTREAM_TABLES_COUNT: logger.warning( f"Too many upstream tables for {downstream_urn}: {len(upstream_aspect.upstreams)}" - f"Keeping only {MAX_UPSTREAM_TABLES_COUNT} table level upstreams/" + f"Keeping only {MAX_UPSTREAM_TABLES_COUNT} table level upstreams/", ) upstream_aspect.upstreams = upstream_aspect.upstreams[ :MAX_UPSTREAM_TABLES_COUNT @@ -1353,7 +1386,7 @@ def _gen_lineage_for_downstream( if len(upstream_aspect.fineGrainedLineages) > MAX_FINEGRAINEDLINEAGE_COUNT: logger.warning( f"Too many upstream columns for {downstream_urn}: {len(upstream_aspect.fineGrainedLineages)}" - f"Keeping only {MAX_FINEGRAINEDLINEAGE_COUNT} column level upstreams/" + f"Keeping only {MAX_FINEGRAINEDLINEAGE_COUNT} column level upstreams/", ) upstream_aspect.fineGrainedLineages = upstream_aspect.fineGrainedLineages[ :MAX_FINEGRAINEDLINEAGE_COUNT @@ -1405,7 +1438,8 @@ def _is_known_lineage_query_id(cls, query_id: QueryId) -> bool: return query_id.startswith("known_") def _gen_remaining_queries( - self, queries_generated: Set[QueryId] + self, + queries_generated: Set[QueryId], ) -> Iterable[MetadataChangeProposalWrapper]: if not self.generate_queries or not self.generate_query_usage_statistics: return @@ -1422,7 +1456,9 @@ def can_generate_query(self, query_id: QueryId) -> bool: return self.generate_queries and not self._is_known_lineage_query_id(query_id) def _gen_query( - self, query: QueryMetadata, downstream_urn: Optional[str] = None + self, + query: QueryMetadata, + downstream_urn: 
Optional[str] = None, ) -> Iterable[MetadataChangeProposalWrapper]: query_id = query.query_id if not self.can_generate_query(query_id): @@ -1441,7 +1477,7 @@ def _gen_query( if self.generate_query_subject_fields: for column in sorted(query.column_usage.get(upstream, [])): query_subject_urns.add( - builder.make_schema_field_urn(upstream, column) + builder.make_schema_field_urn(upstream, column), ) if downstream_urn: query_subject_urns.add(downstream_urn) @@ -1449,8 +1485,9 @@ def _gen_query( for column_lineage in query.column_lineage: query_subject_urns.add( builder.make_schema_field_urn( - downstream_urn, column_lineage.downstream.column - ) + downstream_urn, + column_lineage.downstream.column, + ), ) yield from MetadataChangeProposalWrapper.construct_many( @@ -1469,7 +1506,7 @@ def _gen_query( subjects=[ models.QuerySubjectClass(entity=urn) for urn in query_subject_urns - ] + ], ), models.DataPlatformInstanceClass( platform=self.platform.urn(), @@ -1506,7 +1543,8 @@ def _gen_query( aspect=models.QueryUsageStatisticsClass( timestampMillis=make_ts_millis(bucket), eventGranularity=models.TimeWindowSizeClass( - unit=self.usage_config.bucket_duration, multiple=1 + unit=self.usage_config.bucket_duration, + multiple=1, ), queryCount=count, uniqueUserCount=1, @@ -1515,7 +1553,7 @@ def _gen_query( models.DatasetUserUsageCountsClass( user=user.urn(), count=count, - ) + ), ] if user else None @@ -1546,11 +1584,13 @@ def _merge_lineage_from(self, other_query: "QueryLineageInfo") -> None: self.upstreams += other_query.upstreams self.column_lineage += other_query.column_lineage self.confidence_score = min( - self.confidence_score, other_query.confidence_score + self.confidence_score, + other_query.confidence_score, ) def _recurse_into_query( - query: QueryMetadata, recursion_path: List[QueryId] + query: QueryMetadata, + recursion_path: List[QueryId], ) -> QueryLineageInfo: if query.query_id in recursion_path: # This is a cycle, so we just return the query as-is. @@ -1566,7 +1606,7 @@ def _recurse_into_query( temp_upstream_queries: Dict[UrnStr, QueryLineageInfo] = {} for upstream in query.upstreams: upstream_query_ids = self._temp_lineage_map.get(session_id, {}).get( - upstream + upstream, ) if upstream_query_ids: for upstream_query_id in upstream_query_ids: @@ -1576,11 +1616,12 @@ def _recurse_into_query( and upstream_query.query_id not in composed_of_queries ): temp_query_lineage_info = _recurse_into_query( - upstream_query, recursion_path + upstream_query, + recursion_path, ) if upstream in temp_upstream_queries: temp_upstream_queries[upstream]._merge_lineage_from( - temp_query_lineage_info + temp_query_lineage_info, ) else: temp_upstream_queries[upstream] = ( @@ -1610,7 +1651,7 @@ def _recurse_into_query( for temp_col_upstream in temp_lineage_info.upstreams if temp_lineage_info.downstream.column == existing_col_upstream.column - ] + ], ) else: new_column_upstreams.append(existing_col_upstream) @@ -1619,7 +1660,7 @@ def _recurse_into_query( ColumnLineageInfo( downstream=lineage_info.downstream, upstreams=new_column_upstreams, - ) + ), ) # Compute merged confidence score. 
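
The hunks above and below are representative of the whole patch: call sites that previously packed several arguments onto one line are split to one argument per line and given a trailing comma, which keeps later argument additions or removals to a one-line diff. A minimal, self-contained sketch of the before/after convention follows; the register_table helper is hypothetical and only mirrors the shared_connection=/tablename= constructors reformatted earlier in this file.

def register_table(shared_connection: str, tablename: str) -> tuple:
    # Hypothetical stand-in for the FileBackedDict-style constructors seen above.
    return (shared_connection, tablename)


# Before the patch, a call like this sat on a single line:
#     register_table(shared_connection="conn", tablename="query_map")
# After the patch, each argument gets its own line plus a trailing comma:
result = register_table(
    shared_connection="conn",
    tablename="query_map",
)
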
@@ -1630,7 +1671,7 @@ def _recurse_into_query( temp_upstream_query.confidence_score for temp_upstream_query in temp_upstream_queries.values() ], - ] + ], ) return QueryLineageInfo( @@ -1660,7 +1701,7 @@ def _recurse_into_query( key=lambda query: make_ts_millis(query.latest_timestamp) or 0, ) composite_query_id = self._composite_query_id( - [q.query_id for q in ordered_queries] + [q.query_id for q in ordered_queries], ) composed_of_queries_truncated: LossyList[str] = LossyList() for query_id in composed_of_queries: @@ -1670,7 +1711,7 @@ def _recurse_into_query( ) merged_query_text = ";\n\n".join( - [q.formatted_query_string for q in ordered_queries] + [q.formatted_query_string for q in ordered_queries], ) resolved_query = dataclasses.replace( @@ -1696,7 +1737,8 @@ def _gen_usage_statistics_mcps(self) -> Iterable[MetadataChangeProposalWrapper]: yield cast(MetadataChangeProposalWrapper, wu.metadata) def _gen_operation_mcps( - self, queries_generated: Set[QueryId] + self, + queries_generated: Set[QueryId], ) -> Iterable[MetadataChangeProposalWrapper]: if not self.generate_operations: return @@ -1712,7 +1754,9 @@ def _gen_operation_mcps( yield from self._gen_query(self._query_map[query_id], downstream_urn) def _gen_operation_for_downstream( - self, downstream_urn: UrnStr, query_id: QueryId + self, + downstream_urn: UrnStr, + query_id: QueryId, ) -> Iterable[MetadataChangeProposalWrapper]: query = self._query_map[query_id] if query.latest_timestamp is None: diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_common.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_common.py index ec7dbc8251b200..053d5cf3e539cb 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_common.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_common.py @@ -32,7 +32,7 @@ "snowflake", } assert DIALECTS_WITH_DEFAULT_UPPERCASE_COLS.issubset( - DIALECTS_WITH_CASE_INSENSITIVE_COLS + DIALECTS_WITH_CASE_INSENSITIVE_COLS, ) diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py index c825deeccd9592..713e9bb5809d15 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py @@ -63,7 +63,8 @@ SQL_PARSE_RESULT_CACHE_SIZE = 1000 SQL_LINEAGE_TIMEOUT_ENABLED = get_boolean_env_variable( - "SQL_LINEAGE_TIMEOUT_ENABLED", True + "SQL_LINEAGE_TIMEOUT_ENABLED", + True, ) SQL_LINEAGE_TIMEOUT_SECONDS = 10 SQL_PARSER_TRACE = get_boolean_env_variable("DATAHUB_SQL_PARSER_TRACE", False) @@ -119,7 +120,8 @@ class DownstreamColumnRef(_ParserBaseModel): @pydantic.validator("column_type", pre=True) def _load_column_type( - cls, v: Optional[Union[dict, SchemaFieldDataTypeClass]] + cls, + v: Optional[Union[dict, SchemaFieldDataTypeClass]], ) -> Optional[SchemaFieldDataTypeClass]: if v is None: return None @@ -183,7 +185,7 @@ class SqlParsingResult(_ParserBaseModel): # TODO include list of referenced columns debug_info: SqlParsingDebugInfo = pydantic.Field( - default_factory=lambda: SqlParsingDebugInfo() + default_factory=lambda: SqlParsingDebugInfo(), ) @classmethod @@ -198,7 +200,8 @@ def make_from_error(cls, error: Exception) -> "SqlParsingResult": def _table_level_lineage( - statement: sqlglot.Expression, dialect: sqlglot.Dialect + statement: sqlglot.Expression, + dialect: sqlglot.Dialect, ) -> Tuple[Set[_TableName], Set[_TableName]]: # Generate table-level lineage. 
modified = ( @@ -289,14 +292,17 @@ class _ColumnResolver: use_case_insensitive_cols: bool def schema_aware_fuzzy_column_resolve( - self, table: Optional[_TableName], sqlglot_column: str + self, + table: Optional[_TableName], + sqlglot_column: str, ) -> str: default_col_name = ( sqlglot_column.lower() if self.use_case_insensitive_cols else sqlglot_column ) if table: return self.table_schema_normalized_mapping[table].get( - sqlglot_column, default_col_name + sqlglot_column, + default_col_name, ) else: return default_col_name @@ -318,11 +324,12 @@ def _prepare_query_columns( and not is_create_ddl ): raise UnsupportedStatementTypeError( - f"Can only generate column-level lineage for select-like inner statements, not {type(statement)}" + f"Can only generate column-level lineage for select-like inner statements, not {type(statement)}", ) use_case_insensitive_cols = is_dialect_instance( - dialect, DIALECTS_WITH_CASE_INSENSITIVE_COLS + dialect, + DIALECTS_WITH_CASE_INSENSITIVE_COLS, ) sqlglot_db_schema = sqlglot.MappingSchema( @@ -331,7 +338,7 @@ def _prepare_query_columns( normalize=False, ) table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict( - dict + dict, ) for table, table_schema in table_schemas.items(): normalized_table_schema: SchemaInfo = {} @@ -341,7 +348,8 @@ def _prepare_query_columns( # This is required to match Sqlglot's behavior. col.upper() if is_dialect_instance( - dialect, DIALECTS_WITH_DEFAULT_UPPERCASE_COLS + dialect, + DIALECTS_WITH_DEFAULT_UPPERCASE_COLS, ) else col.lower() ) @@ -407,24 +415,27 @@ def _sqlglot_force_column_normalizer( ) except (sqlglot.errors.OptimizeError, ValueError) as e: raise SqlUnderstandingError( - f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}" + f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}", ) from e if logger.isEnabledFor(logging.DEBUG): logger.debug( - "Qualified sql %s", statement.sql(pretty=True, dialect=dialect) + "Qualified sql %s", + statement.sql(pretty=True, dialect=dialect), ) # Try to figure out the types of the output columns. try: statement = sqlglot.optimizer.annotate_types.annotate_types( - statement, schema=sqlglot_db_schema + statement, + schema=sqlglot_db_schema, ) except (sqlglot.errors.OptimizeError, sqlglot.errors.ParseError) as e: # This is not a fatal error, so we can continue. logger.debug("sqlglot failed to annotate or parse types: %s", e) if _DEBUG_TYPE_ANNOTATIONS and logger.isEnabledFor(logging.DEBUG): logger.debug( - "Type annotated sql %s", statement.sql(pretty=True, dialect=dialect) + "Type annotated sql %s", + statement.sql(pretty=True, dialect=dialect), ) return statement, _ColumnResolver( @@ -455,7 +466,8 @@ def _create_table_ddl_cll( continue output_col = column_resolver.schema_aware_fuzzy_column_resolve( - output_table, column_def.name + output_table, + column_def.name, ) output_col_type = column_def.args.get("kind") @@ -467,7 +479,7 @@ def _create_table_ddl_cll( column_type=output_col_type, ), upstreams=[], - ) + ), ) return column_lineage @@ -531,7 +543,8 @@ def _select_statement_cll( # noqa: C901 output_col = original_col_expression.this.sql(dialect=dialect) output_col = column_resolver.schema_aware_fuzzy_column_resolve( - output_table, output_col + output_table, + output_col, ) # Guess the output column type. 
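
For orientation, the column-lineage hunks above lean on plain sqlglot primitives: qualify column references against a known schema, then annotate the output types. The sketch below uses sqlglot directly with an invented orders schema; it is not DataHub code, and the dialect-specific case handling (the DIALECTS_WITH_CASE_INSENSITIVE_COLS logic above) is deliberately left out.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.qualify import qualify

# Invented table schema; column types are sqlglot type strings.
schema = {"orders": {"id": "INT", "amount": "DOUBLE"}}

statement = sqlglot.parse_one("SELECT id, amount * 2 AS doubled FROM orders")

# Map each column to its source table; this is the step whose failure surfaces
# above as "sqlglot failed to map columns to their source tables".
qualified = qualify(statement, schema=schema)

# Best-effort typing of the output columns, mirroring the annotate_types call above.
typed = annotate_types(qualified, schema=schema)
print(typed.sql(pretty=True))
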
@@ -544,7 +557,8 @@ def _select_statement_cll( # noqa: C901 _ColumnRef( table=edge.table, column=column_resolver.schema_aware_fuzzy_column_resolve( - edge.table, edge.column + edge.table, + edge.column, ), ) for edge in direct_raw_col_upstreams @@ -561,13 +575,13 @@ def _select_statement_cll( # noqa: C901 ), upstreams=sorted(direct_resolved_col_upstreams), # logic=column_logic.sql(pretty=True, dialect=dialect), - ) + ), ) # TODO: Also extract referenced columns (aka auxiliary / non-SELECT lineage) except (sqlglot.errors.OptimizeError, ValueError, IndexError) as e: raise SqlUnderstandingError( - f"sqlglot failed to compute some lineage: {e}" + f"sqlglot failed to compute some lineage: {e}", ) from e return column_lineage @@ -593,7 +607,7 @@ def _column_level_lineage( select_statement = _try_extract_select(statement) except Exception as e: raise SqlUnderstandingError( - f"Failed to extract select from statement: {e}" + f"Failed to extract select from statement: {e}", ) from e try: @@ -629,11 +643,11 @@ def _column_level_lineage( root_scope = sqlglot.optimizer.build_scope(select_statement) if root_scope is None: raise SqlUnderstandingError( - f"Failed to build scope for statement - scope was empty: {statement}" + f"Failed to build scope for statement - scope was empty: {statement}", ) except (sqlglot.errors.OptimizeError, ValueError, IndexError) as e: raise SqlUnderstandingError( - f"sqlglot failed to preprocess statement: {e}" + f"sqlglot failed to preprocess statement: {e}", ) from e # Generate column-level lineage. @@ -680,7 +694,7 @@ def _get_direct_raw_col_upstreams( normalized_col = f"{normalized_col}.{node.subfield}" direct_raw_col_upstreams.add( - _ColumnRef(table=table_ref, column=normalized_col) + _ColumnRef(table=table_ref, column=normalized_col), ) else: # This branch doesn't matter. For example, a count(*) column would go here, and @@ -703,7 +717,7 @@ def _extract_select_from_create( _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT: Set[str] = set( - sqlglot.exp.Update.arg_types.keys() + sqlglot.exp.Update.arg_types.keys(), ) - set(sqlglot.exp.Select.arg_types.keys()) _UPDATE_FROM_TABLE_ARGS_TO_MOVE = {"joins", "laterals", "pivot"} @@ -719,13 +733,14 @@ def _extract_select_from_update( new_expressions = [] for expr in statement.expressions: if isinstance(expr, sqlglot.exp.EQ) and isinstance( - expr.left, sqlglot.exp.Column + expr.left, + sqlglot.exp.Column, ): new_expressions.append( sqlglot.exp.Alias( this=expr.right, alias=expr.left.this, - ) + ), ) else: # If we don't know how to convert it, just leave it as-is. If this causes issues, @@ -754,14 +769,16 @@ def _extract_select_from_update( }, **extra_args, "expressions": new_expressions, - } + }, ) # Update statements always implicitly have the updated table in context. # TODO: Retain table name alias, if one was present. 
if select_statement.args.get("from"): select_statement = select_statement.join( - statement.this, append=True, join_kind="cross" + statement.this, + append=True, + join_kind="cross", ) else: select_statement = select_statement.from_(statement.this) @@ -951,7 +968,7 @@ def _sqlglot_lineage_inner( # See https://github.com/tobymao/sqlglot/commit/3a13fdf4e597a2f0a3f9fc126a129183fe98262f # and https://github.com/tobymao/sqlglot/pull/2874 raise UnsupportedStatementTypeError( - f"Got unsupported syntax for statement: {sql}" + f"Got unsupported syntax for statement: {sql}", ) original_statement, statement = statement, statement.copy() @@ -996,7 +1013,9 @@ def _sqlglot_lineage_inner( # For select statements, qualification will be a no-op. For other statements, this # is where the qualification actually happens. qualified_table = table.qualified( - dialect=dialect, default_db=default_db, default_schema=default_schema + dialect=dialect, + default_db=default_db, + default_schema=default_schema, ) urn, schema_info = schema_resolver.resolve_table(qualified_table) @@ -1022,7 +1041,7 @@ def _sqlglot_lineage_inner( table_schemas_resolved=total_schemas_resolved, ) logger.debug( - f"Resolved {total_schemas_resolved} of {total_tables_discovered} table schemas" + f"Resolved {total_schemas_resolved} of {total_tables_discovered} table schemas", ) if SQL_PARSER_TRACE: for qualified_table, schema_info in table_name_schema_mapping.items(): @@ -1038,7 +1057,7 @@ def _sqlglot_lineage_inner( with cooperative_timeout( timeout=( SQL_LINEAGE_TIMEOUT_SECONDS if SQL_LINEAGE_TIMEOUT_ENABLED else None - ) + ), ): column_lineage_debug_info = _column_level_lineage( statement, @@ -1070,7 +1089,9 @@ def _sqlglot_lineage_inner( try: column_lineage_urns = [ _translate_internal_column_lineage( - table_name_urn_mapping, internal_col_lineage, dialect=dialect + table_name_urn_mapping, + internal_col_lineage, + dialect=dialect, ) for internal_col_lineage in column_lineage ] @@ -1078,15 +1099,18 @@ def _sqlglot_lineage_inner( # When this happens, it's usually because of things like PIVOT where we can't # really go up the scope chain. 
logger.debug( - f"Failed to translate column lineage to urns: {e}", exc_info=True + f"Failed to translate column lineage to urns: {e}", + exc_info=True, ) debug_info.column_error = e query_type, query_type_props = get_query_type_of_sql( - original_statement, dialect=dialect + original_statement, + dialect=dialect, ) query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug( - original_statement, dialect + original_statement, + dialect, ) return SqlParsingResult( query_type=query_type, @@ -1170,7 +1194,7 @@ def _sqlglot_lineage_nocache( _sqlglot_lineage_cached = functools.lru_cache(maxsize=SQL_PARSE_RESULT_CACHE_SIZE)( - _sqlglot_lineage_nocache + _sqlglot_lineage_nocache, ) @@ -1183,11 +1207,19 @@ def sqlglot_lineage( ) -> SqlParsingResult: if schema_resolver.includes_temp_tables(): return _sqlglot_lineage_nocache( - sql, schema_resolver, default_db, default_schema, default_dialect + sql, + schema_resolver, + default_db, + default_schema, + default_dialect, ) else: return _sqlglot_lineage_cached( - sql, schema_resolver, default_db, default_schema, default_dialect + sql, + schema_resolver, + default_db, + default_schema, + default_dialect, ) @@ -1280,13 +1312,14 @@ def infer_output_schema(result: SqlParsingResult) -> Optional[List[SchemaFieldCl or SchemaFieldDataTypeClass(type=NullTypeClass()) ), nativeDataType=column_info.downstream.native_column_type or "", - ) + ), ) return output_schema def view_definition_lineage_helper( - result: SqlParsingResult, view_urn: str + result: SqlParsingResult, + view_urn: str, ) -> SqlParsingResult: if result.query_type is QueryType.SELECT or ( result.out_tables and result.out_tables != [view_urn] diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py index 5b12c64a831666..cba472d89b190b 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py @@ -54,7 +54,8 @@ def get_dialect(platform: DialectOrStr) -> sqlglot.Dialect: def is_dialect_instance( - dialect: sqlglot.Dialect, platforms: Union[str, Iterable[str]] + dialect: sqlglot.Dialect, + platforms: Union[str, Iterable[str]], ) -> bool: if isinstance(platforms, str): platforms = [platforms] @@ -70,16 +71,20 @@ def is_dialect_instance( @functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE) def _parse_statement( - sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect + sql: sqlglot.exp.ExpOrStr, + dialect: sqlglot.Dialect, ) -> sqlglot.Expression: statement: sqlglot.Expression = sqlglot.maybe_parse( - sql, dialect=dialect, error_level=sqlglot.ErrorLevel.IMMEDIATE + sql, + dialect=dialect, + error_level=sqlglot.ErrorLevel.IMMEDIATE, ) return statement def parse_statement( - sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect + sql: sqlglot.exp.ExpOrStr, + dialect: sqlglot.Dialect, ) -> sqlglot.Expression: # Parsing is significantly more expensive than copying the expression. # Because the expressions are mutable, we don't want to allow the caller @@ -104,13 +109,15 @@ def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expre # Usually the prior queries are going to be things like `CREATE FUNCTION` # or `GRANT ...`, which we don't care about. 
logger.debug( - "Found multiple statements in query, picking the last one: %s", sql + "Found multiple statements in query, picking the last one: %s", + sql, ) return statements[-1] def _expression_to_string( - expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr + expression: sqlglot.exp.ExpOrStr, + platform: DialectOrStr, ) -> str: if isinstance(expression, str): return expression @@ -154,7 +161,8 @@ def _expression_to_string( ): "00000000_0000_0000_0000_000000000000", # GE temporary table names (prefix + 8 digits of a UUIDv4) re.compile( - r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", re.IGNORECASE + r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", + re.IGNORECASE, ): r"\1abcdefgh", # Date-suffixed table names (e.g. _20210101) re.compile(r"\b(\w+)(19|20)\d{4}\b"): r"\1YYYYMM", @@ -260,7 +268,9 @@ def generate_hash(text: str) -> str: def get_query_fingerprint_debug( - expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False + expression: sqlglot.exp.ExpOrStr, + platform: DialectOrStr, + fast: bool = False, ) -> Tuple[str, Optional[str]]: try: if not fast: @@ -278,13 +288,15 @@ def get_query_fingerprint_debug( fingerprint = generate_hash( expression_sql if expression_sql is not None - else _expression_to_string(expression, platform=platform) + else _expression_to_string(expression, platform=platform), ) return fingerprint, expression_sql def get_query_fingerprint( - expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False + expression: sqlglot.exp.ExpOrStr, + platform: DialectOrStr, + fast: bool = False, ) -> str: """Get a fingerprint for a SQL query. @@ -311,7 +323,9 @@ def get_query_fingerprint( @functools.lru_cache(maxsize=FORMAT_QUERY_CACHE_SIZE) def try_format_query( - expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, raises: bool = False + expression: sqlglot.exp.ExpOrStr, + platform: DialectOrStr, + raises: bool = False, ) -> str: """Format a SQL query. @@ -338,7 +352,9 @@ def try_format_query( def detach_ctes( - sql: sqlglot.exp.ExpOrStr, platform: str, cte_mapping: Dict[str, str] + sql: sqlglot.exp.ExpOrStr, + platform: str, + cte_mapping: Dict[str, str], ) -> sqlglot.exp.Expression: """Replace CTE references with table references. 
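
A short usage sketch for two helpers touched above, using the module path from the diff header. The behavioral notes are hedged expectations rather than guarantees: fingerprinting is intended to normalize away formatting and literal differences (see the masking regexes above), but the exact collisions are an implementation detail.

from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint, try_format_query

q1 = "select id, amount from orders where id = 1"
q2 = "SELECT id,\n       amount\nFROM orders\nWHERE id = 2"

# Near-identical queries that differ only in whitespace, keyword casing, and
# literal values are expected to land on the same fingerprint.
print(get_query_fingerprint(q1, platform="snowflake"))
print(get_query_fingerprint(q2, platform="snowflake"))

# Pretty-print a query when the dialect can parse it; with raises=False (the
# default), parse failures are expected to fall back to the original text.
print(try_format_query(q1, platform="snowflake"))
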
@@ -372,7 +388,9 @@ def replace_cte_refs(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression: ): full_new_name = cte_mapping[node.name] table_expr = sqlglot.maybe_parse( - full_new_name, dialect=dialect, into=sqlglot.exp.Table + full_new_name, + dialect=dialect, + into=sqlglot.exp.Table, ) parent = node.parent @@ -399,12 +417,12 @@ def replace_cte_refs(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression: max_eliminate_calls = 5 for iteration in range(max_eliminate_calls): new_statement = sqlglot.optimizer.eliminate_ctes.eliminate_ctes( - statement.copy() + statement.copy(), ) if new_statement == statement: if iteration > 1: logger.debug( - f"Required {iteration + 1} iterations to detach and eliminate all CTEs" + f"Required {iteration + 1} iterations to detach and eliminate all CTEs", ) break statement = new_statement diff --git a/metadata-ingestion/src/datahub/sql_parsing/tool_meta_extractor.py b/metadata-ingestion/src/datahub/sql_parsing/tool_meta_extractor.py index d2682252e0fbf5..f2988c200d1d6a 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/tool_meta_extractor.py +++ b/metadata-ingestion/src/datahub/sql_parsing/tool_meta_extractor.py @@ -81,18 +81,21 @@ def create( if graph: try: looker_user_mapping = cls.extract_looker_user_mapping_from_graph( - graph, report + graph, + report, ) except Exception as e: report.failures.append( - f"Unexpected error during Looker user metadata extraction: {str(e)}" + f"Unexpected error during Looker user metadata extraction: {str(e)}", ) return cls(report, looker_user_mapping) @classmethod def extract_looker_user_mapping_from_graph( - cls, graph: DataHubGraph, report: ToolMetaExtractorReport + cls, + graph: DataHubGraph, + report: ToolMetaExtractorReport, ) -> Optional[Dict[str, str]]: looker_user_mapping = None query = ( @@ -106,14 +109,14 @@ def extract_looker_user_mapping_from_graph( .end() ) platform_resources = list( - PlatformResource.search_by_filters(query=query, graph_client=graph) + PlatformResource.search_by_filters(query=query, graph_client=graph), ) if len(platform_resources) == 0: report.looker_user_mapping_missing = True elif len(platform_resources) > 1: report.failures.append( - "Looker user metadata extraction failed. Found more than one looker user id mappings." + "Looker user metadata extraction failed. Found more than one looker user id mappings.", ) else: platform_resource = platform_resources[0] diff --git a/metadata-ingestion/src/datahub/telemetry/stats.py b/metadata-ingestion/src/datahub/telemetry/stats.py index d6835e49de56aa..765ee4f8a8af22 100644 --- a/metadata-ingestion/src/datahub/telemetry/stats.py +++ b/metadata-ingestion/src/datahub/telemetry/stats.py @@ -12,7 +12,8 @@ def __lt__(self, __other: Any) -> Any: ... 
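
A small usage sketch for the percentile helper whose signature is reformatted just below (import path taken from the diff header). How each requested percentile maps onto an element of the input is an implementation detail of the module, so the example only shows the call shape.

from datahub.telemetry.stats import calculate_percentiles

# Invented latency samples, already sorted ascending.
latencies_ms = [12, 15, 18, 22, 31, 44, 58, 73, 120, 450]

# Returns a dict keyed by the requested percentiles.
summary = calculate_percentiles(latencies_ms, [50, 90, 99])
print(summary)
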
def calculate_percentiles( - data: List[_SupportsComparisonT], percentiles: List[int] + data: List[_SupportsComparisonT], + percentiles: List[int], ) -> Dict[int, _SupportsComparisonT]: size = len(data) diff --git a/metadata-ingestion/src/datahub/telemetry/telemetry.py b/metadata-ingestion/src/datahub/telemetry/telemetry.py index 22b2cb6a101af9..733dd6fe54a97e 100644 --- a/metadata-ingestion/src/datahub/telemetry/telemetry.py +++ b/metadata-ingestion/src/datahub/telemetry/telemetry.py @@ -157,7 +157,8 @@ def __init__(self): self.mp = Mixpanel( MIXPANEL_TOKEN, consumer=Consumer( - request_timeout=int(TIMEOUT), api_host=MIXPANEL_ENDPOINT + request_timeout=int(TIMEOUT), + api_host=MIXPANEL_ENDPOINT, ), ) except Exception as e: @@ -186,15 +187,15 @@ def update_config(self) -> bool: except OSError as x: if x.errno == errno.ENOENT: logger.debug( - f"{CONFIG_FILE} does not exist and could not be created. Please check permissions on the parent folder." + f"{CONFIG_FILE} does not exist and could not be created. Please check permissions on the parent folder.", ) elif x.errno == errno.EACCES: logger.debug( - f"{CONFIG_FILE} cannot be read. Please check the permissions on this file." + f"{CONFIG_FILE} cannot be read. Please check the permissions on this file.", ) else: logger.debug( - f"{CONFIG_FILE} had an IOError, please inspect this file for issues." + f"{CONFIG_FILE} had an IOError, please inspect this file for issues.", ) except Exception as e: logger.debug(f"Failed to update config file at {CONFIG_FILE} due to {e}") @@ -232,15 +233,15 @@ def load_config(self) -> bool: except OSError as x: if x.errno == errno.ENOENT: logger.debug( - f"{CONFIG_FILE} does not exist and could not be created. Please check permissions on the parent folder." + f"{CONFIG_FILE} does not exist and could not be created. Please check permissions on the parent folder.", ) elif x.errno == errno.EACCES: logger.debug( - f"{CONFIG_FILE} cannot be read. Please check the permissions on this file." + f"{CONFIG_FILE} cannot be read. Please check the permissions on this file.", ) else: logger.debug( - f"{CONFIG_FILE} had an IOError, please inspect this file for issues." 
+ f"{CONFIG_FILE} had an IOError, please inspect this file for issues.", ) except Exception as e: logger.debug(f"Failed to load {CONFIG_FILE} due to {e}") @@ -328,7 +329,7 @@ def ping( try: if event_name == "function-call": logger.debug( - f"Sending telemetry for {event_name} {properties.get('function')}, status {properties.get('status')}" + f"Sending telemetry for {event_name} {properties.get('function')}, status {properties.get('status')}", ) else: logger.debug(f"Sending telemetry for {event_name}") @@ -353,7 +354,8 @@ def _server_props(cls, server: Optional["DataHubGraph"]) -> Dict[str, str]: else: return { "server_type": server.server_config.get("datahub", {}).get( - "serverType", "missing" + "serverType", + "missing", ), "server_version": server.server_config.get("versions", {}) .get("acryldata/datahub", {}) @@ -398,7 +400,8 @@ def _error_props(error: BaseException) -> Dict[str, Any]: def with_telemetry( - *, capture_kwargs: Optional[List[str]] = None + *, + capture_kwargs: Optional[List[str]] = None, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: kwargs_to_track = capture_kwargs or [] diff --git a/metadata-ingestion/src/datahub/testing/check_imports.py b/metadata-ingestion/src/datahub/testing/check_imports.py index e4bf07882b36ae..a0c2501fb755a9 100644 --- a/metadata-ingestion/src/datahub/testing/check_imports.py +++ b/metadata-ingestion/src/datahub/testing/check_imports.py @@ -30,5 +30,5 @@ def ensure_no_indirect_model_imports(dirs: List[pathlib.Path]) -> None: if denied_import in line: raise ValueError( f"Disallowed import found in {file}: `{line.rstrip()}`. " - f"Import from {replacement} instead." + f"Import from {replacement} instead.", ) diff --git a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py index 13be45ec1be28d..04a2b0fbc5d1c1 100644 --- a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py +++ b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py @@ -55,7 +55,7 @@ def assert_sql_result_with_resolver( expected_file.write_text(txt) raise AssertionError( f"Expected file {expected_file} does not exist. " - "Created it with the expected output. Please verify it." + "Created it with the expected output. Please verify it.", ) expected = SqlParsingResult.parse_raw(expected_file.read_text()) @@ -80,7 +80,8 @@ def assert_sql_result( **kwargs: Any, ) -> None: schema_resolver = SchemaResolver( - platform=dialect, platform_instance=platform_instance + platform=dialect, + platform_instance=platform_instance, ) if schemas: for urn, schema in schemas.items(): diff --git a/metadata-ingestion/src/datahub/testing/check_str_enum.py b/metadata-ingestion/src/datahub/testing/check_str_enum.py index 2d1a84aa5f738d..50577481656d2a 100644 --- a/metadata-ingestion/src/datahub/testing/check_str_enum.py +++ b/metadata-ingestion/src/datahub/testing/check_str_enum.py @@ -29,5 +29,5 @@ def ensure_no_enum_mixin(dirs: List[pathlib.Path]) -> None: raise ValueError( f"Disallowed enum mixin found in {file}: `{line.rstrip()}`. " "This enum mixin's behavior changed in Python 3.11, so it will work inconsistently across versions." - "Use datahub.utilities.str_enum.StrEnum instead." 
+ "Use datahub.utilities.str_enum.StrEnum instead.", ) diff --git a/metadata-ingestion/src/datahub/testing/compare_metadata_json.py b/metadata-ingestion/src/datahub/testing/compare_metadata_json.py index 9dbadd4804997d..976b3c2cc5c9bf 100644 --- a/metadata-ingestion/src/datahub/testing/compare_metadata_json.py +++ b/metadata-ingestion/src/datahub/testing/compare_metadata_json.py @@ -54,7 +54,7 @@ def assert_metadata_files_equal( if not update_golden and not golden_exists: raise FileNotFoundError( - "Golden file does not exist. Please run with the --update-golden-files option to create." + "Golden file does not exist. Please run with the --update-golden-files option to create.", ) output = load_json_file(output_path) diff --git a/metadata-ingestion/src/datahub/testing/docker_utils.py b/metadata-ingestion/src/datahub/testing/docker_utils.py index 7c1c0304f480e6..80435f422c95d8 100644 --- a/metadata-ingestion/src/datahub/testing/docker_utils.py +++ b/metadata-ingestion/src/datahub/testing/docker_utils.py @@ -50,11 +50,16 @@ def wait_for_port( @pytest.fixture(scope="module") def docker_compose_runner( - docker_compose_command, docker_compose_project_name, docker_setup, docker_cleanup + docker_compose_command, + docker_compose_project_name, + docker_setup, + docker_cleanup, ): @contextlib.contextmanager def run( - compose_file_path: Union[str, List[str]], key: str, cleanup: bool = True + compose_file_path: Union[str, List[str]], + key: str, + cleanup: bool = True, ) -> Iterator[pytest_docker.plugin.Services]: with pytest_docker.plugin.get_docker_services( docker_compose_command=docker_compose_command, diff --git a/metadata-ingestion/src/datahub/testing/mcp_diff.py b/metadata-ingestion/src/datahub/testing/mcp_diff.py index b58afc10148edc..b5d516f5a3491a 100644 --- a/metadata-ingestion/src/datahub/testing/mcp_diff.py +++ b/metadata-ingestion/src/datahub/testing/mcp_diff.py @@ -285,7 +285,8 @@ def report_aspect(ga: AspectForDiff, idx: int, msg: str = "") -> str: @staticmethod def report_diff_level(diff: DiffLevel, idx: int) -> str: return "\t" + deepdiff.serialization.pretty_print_diff(diff).replace( - f"root[{idx}].", "" + f"root[{idx}].", + "", ) diff --git a/metadata-ingestion/src/datahub/upgrade/upgrade.py b/metadata-ingestion/src/datahub/upgrade/upgrade.py index fb14514588e5fc..5d8a8ccf8afc07 100644 --- a/metadata-ingestion/src/datahub/upgrade/upgrade.py +++ b/metadata-ingestion/src/datahub/upgrade/upgrade.py @@ -48,7 +48,8 @@ async def get_client_version_stats(): current_version_string = __version__ current_version = Version(current_version_string) client_version_stats: ClientVersionStats = ClientVersionStats( - current=VersionStats(version=current_version, release_date=None), latest=None + current=VersionStats(version=current_version, release_date=None), + latest=None, ) async with aiohttp.ClientSession() as session: pypi_url = "https://pypi.org/pypi/acryl_datahub/json" @@ -65,20 +66,24 @@ async def get_client_version_stats(): current_version_date = None if current_version_info: current_version_date = datetime.strptime( - current_version_info[0].get("upload_time"), "%Y-%m-%dT%H:%M:%S" + current_version_info[0].get("upload_time"), + "%Y-%m-%dT%H:%M:%S", ) latest_release_info = releases.get(latest_cli_release_string) latest_version_date = None if latest_release_info: latest_version_date = datetime.strptime( - latest_release_info[0].get("upload_time"), "%Y-%m-%dT%H:%M:%S" + latest_release_info[0].get("upload_time"), + "%Y-%m-%dT%H:%M:%S", ) client_version_stats = ClientVersionStats( 
current=VersionStats( - version=current_version, release_date=current_version_date + version=current_version, + release_date=current_version_date, ), latest=VersionStats( - version=latest_cli_release, release_date=latest_version_date + version=latest_cli_release, + release_date=latest_version_date, ), ) except Exception as e: @@ -91,7 +96,7 @@ async def get_github_stats(): import aiohttp async with aiohttp.ClientSession( - headers={"Accept": "application/vnd.github.v3+json"} + headers={"Accept": "application/vnd.github.v3+json"}, ) as session: gh_url = "https://api.github.com/repos/datahub-project/datahub/releases" async with session.get(gh_url) as gh_response: @@ -153,7 +158,7 @@ async def get_server_version_stats( server_type = server_config.get("datahub", {}).get("serverType", "unknown") if server_type == "quickstart" and commit_hash: async with aiohttp.ClientSession( - headers={"Accept": "application/vnd.github.v3+json"} + headers={"Accept": "application/vnd.github.v3+json"}, ) as session: gh_url = f"https://api.github.com/repos/datahub-project/datahub/commits/{commit_hash}" async with session.get(gh_url) as gh_response: @@ -169,7 +174,8 @@ async def get_server_version_stats( def retrieve_version_stats( - timeout: float, graph: Optional[DataHubGraph] = None + timeout: float, + graph: Optional[DataHubGraph] = None, ) -> Optional[DataHubVersionStats]: version_stats: Optional[DataHubVersionStats] = None @@ -215,7 +221,8 @@ async def _retrieve_version_stats( if current_server_version: server_version_stats = ServerVersionStats( current=VersionStats( - version=current_server_version, release_date=current_server_release_date + version=current_server_version, + release_date=current_server_release_date, ), latest=( VersionStats(version=last_server_version, release_date=last_server_date) @@ -227,7 +234,8 @@ async def _retrieve_version_stats( if client_version_stats and server_version_stats: return DataHubVersionStats( - server=server_version_stats, client=client_version_stats + server=server_version_stats, + client=client_version_stats, ) else: return None @@ -237,7 +245,8 @@ def get_days(time_point: Optional[Any]) -> str: if time_point: return "(released " + ( humanfriendly.format_timespan( - datetime.now(timezone.utc) - time_point, max_units=1 + datetime.now(timezone.utc) - time_point, + max_units=1, ) + " ago)" ) @@ -273,7 +282,7 @@ def is_client_server_compatible(client: VersionStats, server: VersionStats) -> i +ve implies server is ahead of client """ if not valid_client_version(client.version) or not valid_server_version( - server.version + server.version, ): # we cannot evaluate compatibility, choose True as default return 0 @@ -307,7 +316,8 @@ def _maybe_print_upgrade_message( # noqa: C901 else None ) client_server_compat = is_client_server_compatible( - version_stats.client.current, version_stats.server.current + version_stats.client.current, + version_stats.server.current, ) if latest_release_date and current_release_date: @@ -330,7 +340,7 @@ def _maybe_print_upgrade_message( # noqa: C901 ) if time_delta > timedelta(days=days_before_quickstart_stale): log.debug( - f"will encourage upgrade due to server being old {version_stats.server.current.release_date},{time_delta}" + f"will encourage upgrade due to server being old {version_stats.server.current.release_date},{time_delta}", ) encourage_quickstart_upgrade = True if version_stats.server.latest and ( @@ -338,7 +348,7 @@ def _maybe_print_upgrade_message( # noqa: C901 > version_stats.server.current.version ): log.debug( - f"Will 
encourage upgrade due to newer version of server {version_stats.server.latest.version} being available compared to {version_stats.server.current.version}" + f"Will encourage upgrade due to newer version of server {version_stats.server.latest.version} being available compared to {version_stats.server.current.version}", ) encourage_quickstart_upgrade = True @@ -356,7 +366,7 @@ def _maybe_print_upgrade_message( # noqa: C901 + click.style( f"➡️ Downgrade via `\"pip install 'acryl-datahub=={version_stats.server.current.version}'\"", fg="cyan", - ) + ), ) elif client_server_compat > 0: with contextlib.suppress(Exception): @@ -371,7 +381,7 @@ def _maybe_print_upgrade_message( # noqa: C901 + click.style( f"➡️ Upgrade via \"pip install 'acryl-datahub=={version_stats.server.current.version}'\"", fg="cyan", - ) + ), ) elif client_server_compat == 0 and encourage_cli_upgrade: with contextlib.suppress(Exception): @@ -381,7 +391,7 @@ def _maybe_print_upgrade_message( # noqa: C901 + click.style( f"You seem to be running an old version of datahub cli: {current_version} {get_days(current_release_date)}. Latest version is {latest_version} {get_days(latest_release_date)}.\nUpgrade via \"pip install -U 'acryl-datahub'\"", fg="cyan", - ) + ), ) elif encourage_quickstart_upgrade: try: @@ -413,7 +423,8 @@ def check_upgrade_post( version_stats_timeout = clip(main_method_runtime / 10, 0.7, 3.0) try: version_stats = retrieve_version_stats( - timeout=version_stats_timeout, graph=graph + timeout=version_stats_timeout, + graph=graph, ) _maybe_print_upgrade_message(version_stats=version_stats) except Exception as e: diff --git a/metadata-ingestion/src/datahub/utilities/_custom_package_loader.py b/metadata-ingestion/src/datahub/utilities/_custom_package_loader.py index 02115efc4553e5..08fe83ef8c5990 100644 --- a/metadata-ingestion/src/datahub/utilities/_custom_package_loader.py +++ b/metadata-ingestion/src/datahub/utilities/_custom_package_loader.py @@ -34,7 +34,7 @@ def _get_custom_entrypoint_for_name(name: str) -> Optional[EntryPoint]: entrypoint.dist.name for entrypoint in entrypoints if entrypoint.dist ] raise CustomPackageException( - f"Multiple custom packages registered for {name}: cannot pick between {all_package_options}" + f"Multiple custom packages registered for {name}: cannot pick between {all_package_options}", ) return entrypoints[0] diff --git a/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py b/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py index 988bd91c4a642e..0270c018ad0d7b 100644 --- a/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py +++ b/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py @@ -58,7 +58,8 @@ def map( # If the pending list is full, wait until one is done. 
if len(pending_futures) >= max_pending: (done, _) = concurrent.futures.wait( - pending_futures, return_when=concurrent.futures.FIRST_COMPLETED + pending_futures, + return_when=concurrent.futures.FIRST_COMPLETED, ) for future in done: pending_futures.remove(future) diff --git a/metadata-ingestion/src/datahub/utilities/checkpoint_state_util.py b/metadata-ingestion/src/datahub/utilities/checkpoint_state_util.py index 601bbe2dc7d81b..fd60dd64b048d4 100644 --- a/metadata-ingestion/src/datahub/utilities/checkpoint_state_util.py +++ b/metadata-ingestion/src/datahub/utilities/checkpoint_state_util.py @@ -20,7 +20,8 @@ def get_separator() -> str: @staticmethod def get_encoded_urns_not_in( - encoded_urns_1: List[str], encoded_urns_2: List[str] + encoded_urns_1: List[str], + encoded_urns_2: List[str], ) -> Set[str]: return set(encoded_urns_1) - set(encoded_urns_2) @@ -29,8 +30,10 @@ def get_urn_from_encoded_dataset(encoded_urn: str) -> str: platform, name, env = encoded_urn.split(CheckpointStateUtil.get_separator()) return dataset_key_to_urn( DatasetKeyClass( - platform=make_data_platform_urn(platform), name=name, origin=env - ) + platform=make_data_platform_urn(platform), + name=name, + origin=env, + ), ) @staticmethod diff --git a/metadata-ingestion/src/datahub/utilities/cooperative_timeout.py b/metadata-ingestion/src/datahub/utilities/cooperative_timeout.py index f8cb12e3f7d013..26d3e47598d009 100644 --- a/metadata-ingestion/src/datahub/utilities/cooperative_timeout.py +++ b/metadata-ingestion/src/datahub/utilities/cooperative_timeout.py @@ -57,7 +57,7 @@ def cooperative_timeout(timeout: Optional[float] = None) -> Iterator[None]: if timeout is not None: token = _cooperation_deadline.set( - time.perf_counter_ns() + int(timeout * 1_000_000_000) + time.perf_counter_ns() + int(timeout * 1_000_000_000), ) try: yield diff --git a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py index 79da90ba20ea9f..abb39f06a69ea3 100644 --- a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py +++ b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py @@ -102,7 +102,9 @@ def __init__(self, filename: Optional[pathlib.Path] = None): # still need to be careful to avoid concurrent access. self.conn_lock = threading.Lock() self.conn = sqlite3.connect( - filename, isolation_level=None, check_same_thread=False + filename, + isolation_level=None, + check_same_thread=False, ) self.conn.row_factory = sqlite3.Row @@ -125,13 +127,17 @@ def allow_table_name_reuse(self) -> bool: return self._temp_directory is None def execute( - self, sql: str, parameters: Union[Dict[str, Any], Sequence[Any]] = () + self, + sql: str, + parameters: Union[Dict[str, Any], Sequence[Any]] = (), ) -> sqlite3.Cursor: with self.conn_lock: return self.conn.execute(sql, parameters) def executemany( - self, sql: str, parameters: Union[Dict[str, Any], Sequence[Any]] = () + self, + sql: str, + parameters: Union[Dict[str, Any], Sequence[Any]] = (), ) -> sqlite3.Cursor: with self.conn_lock: return self.conn.executemany(sql, parameters) @@ -219,7 +225,8 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]): # To improve performance, we maintain an in-memory LRU cache using an OrderedDict. # Maintains a dirty bit marking whether the value has been modified since it was persisted. 
_active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field( - init=False, repr=False + init=False, + repr=False, ) _use_sqlite_on_conflict: bool = field(repr=False, default=True) @@ -262,7 +269,7 @@ def __post_init__(self) -> None: key TEXT UNIQUE, value BLOB {"".join(f", {column_name} BLOB" for column_name in self.extra_columns.keys())} - )""" + )""", ) if not self.delay_index_creation: @@ -280,7 +287,7 @@ def create_indexes(self) -> None: # The key column will automatically be indexed, but we need indexes for the extra columns. for column_name in self.extra_columns.keys(): self._conn.execute( - f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})" + f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})", ) self.indexes_created = True @@ -294,7 +301,8 @@ def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None: # However, we don't want to prune the thing we just added, # in case there's a mark_dirty() call immediately after. num_items_to_prune = min( - len(self._active_object_cache) - 1, self.cache_eviction_batch_size + len(self._active_object_cache) - 1, + self.cache_eviction_batch_size, ) self._prune_cache(num_items_to_prune) @@ -355,7 +363,8 @@ def __getitem__(self, key: str) -> _VT: return self._active_object_cache[key][0] cursor = self._conn.execute( - f"SELECT value FROM {self.tablename} WHERE key = ?", (key,) + f"SELECT value FROM {self.tablename} WHERE key = ?", + (key,), ) result: Sequence[SqliteValue] = cursor.fetchone() if result is None: @@ -403,7 +412,8 @@ def __delitem__(self, key: str) -> None: in_cache = True n_deleted = self._conn.execute( - f"DELETE FROM {self.tablename} WHERE key = ?", (key,) + f"DELETE FROM {self.tablename} WHERE key = ?", + (key,), ).rowcount if not in_cache and not n_deleted: raise KeyError(key) @@ -412,7 +422,7 @@ def mark_dirty(self, key: str) -> None: if key not in self._active_object_cache: raise ValueError( f"key {key} not in active object cache, which means any dirty value " - "is already persisted or lost" + "is already persisted or lost", ) if not self._active_object_cache[key][1]: @@ -424,13 +434,14 @@ def __iter__(self) -> Iterator[str]: # Our active object cache should now be empty, so it's fine to # just pull from the DB. cursor = self._conn.execute( - f"SELECT key FROM {self.tablename} ORDER BY rowid ASC" + f"SELECT key FROM {self.tablename} ORDER BY rowid ASC", ) for row in cursor: yield row[0] def items_snapshot( - self, cond_sql: Optional[str] = None + self, + cond_sql: Optional[str] = None, ) -> Iterator[Tuple[str, _VT]]: """ Return a fixed snapshot, rather than a view, of the dictionary's items. 
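
To make the collection API above concrete, here is a minimal sketch of a file-backed mapping on a shared SQLite connection, mirroring how the aggregator wires these up at the top of this patch. Paths, table names, and values are invented, and the default serializer is assumed to round-trip plain strings.

import pathlib
import tempfile
from contextlib import ExitStack

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict

with tempfile.TemporaryDirectory() as tmp, ExitStack() as stack:
    shared = stack.enter_context(
        ConnectionWrapper(filename=pathlib.Path(tmp) / "state.db"),
    )

    # Each mapping gets its own table on the shared connection.
    renames = FileBackedDict[str](
        shared_connection=shared,
        tablename="table_renames",
    )
    stack.push(renames)  # closed when the stack unwinds, as in the aggregator

    renames["urn:li:dataset:old"] = "urn:li:dataset:new"
    assert renames["urn:li:dataset:old"] == "urn:li:dataset:new"
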
diff --git a/metadata-ingestion/src/datahub/utilities/hive_schema_to_avro.py b/metadata-ingestion/src/datahub/utilities/hive_schema_to_avro.py index fccd8dd8a60c35..89644110340c19 100644 --- a/metadata-ingestion/src/datahub/utilities/hive_schema_to_avro.py +++ b/metadata-ingestion/src/datahub/utilities/hive_schema_to_avro.py @@ -55,7 +55,8 @@ class HiveColumnToAvroConverter: @staticmethod def _parse_datatype_string( - s: str, **kwargs: Any + s: str, + **kwargs: Any, ) -> Union[object, Dict[str, object]]: s = s.strip() if s.startswith("array<"): @@ -73,7 +74,7 @@ def _parse_datatype_string( if len(parts) != 2: raise ValueError( "The map type string format is: 'map', " - + f"but got: {s}" + + f"but got: {s}", ) kt = HiveColumnToAvroConverter._parse_datatype_string(parts[0]) @@ -97,8 +98,9 @@ def _parse_datatype_string( # ustruct_seqn defines sequence number of struct in union t.append( HiveColumnToAvroConverter._parse_datatype_string( - part, ustruct_seqn=ustruct_seqn - ) + part, + ustruct_seqn=ustruct_seqn, + ), ) ustruct_seqn += 1 else: @@ -108,7 +110,8 @@ def _parse_datatype_string( if s[-1] != ">": raise ValueError("'>' should be the last char, but got: %s" % s) return HiveColumnToAvroConverter._parse_struct_fields_string( - s[7:-1], **kwargs + s[7:-1], + **kwargs, ) elif ":" in s: return HiveColumnToAvroConverter._parse_struct_fields_string(s, **kwargs) @@ -121,12 +124,13 @@ def _parse_struct_fields_string(s: str, **kwargs: Any) -> Dict[str, object]: fields: List[Dict] = [] for part in parts: name_and_type = HiveColumnToAvroConverter._ignore_brackets_split( - part.strip(), HiveColumnToAvroConverter._STRUCT_TYPE_SEPARATOR + part.strip(), + HiveColumnToAvroConverter._STRUCT_TYPE_SEPARATOR, ) if len(name_and_type) != 2: raise ValueError( "The struct field string format is: 'field_name:field_type', " - + f"but got: {part}" + + f"but got: {part}", ) field_name = name_and_type[0].strip() @@ -135,7 +139,7 @@ def _parse_struct_fields_string(s: str, **kwargs: Any) -> Dict[str, object]: raise ValueError("'`' should be the last char, but got: %s" % s) field_name = field_name[1:-1] field_type = HiveColumnToAvroConverter._parse_datatype_string( - name_and_type[1] + name_and_type[1], ) if not any(field["name"] == field_name for field in fields): @@ -245,7 +249,9 @@ def is_primitive_hive_type(hive_type: str) -> bool: @classmethod def get_avro_schema_for_hive_column( - cls, hive_column_name: str, hive_column_type: str + cls, + hive_column_name: str, + hive_column_type: str, ) -> Union[object, Dict[str, object]]: converter = cls() # Below Record structure represents the dataset level @@ -260,7 +266,7 @@ def get_avro_schema_for_hive_column( { "name": hive_column_name, "type": converter._parse_datatype_string(hive_column_type), - } + }, ], } @@ -270,7 +276,8 @@ def get_avro_schema_for_hive_column( hive_column_type: str, ) -> Union[object, Dict[str, object]]: return HiveColumnToAvroConverter.get_avro_schema_for_hive_column( - hive_column_name, hive_column_type + hive_column_name, + hive_column_type, ) @@ -283,7 +290,8 @@ def get_schema_fields_for_hive_column( ) -> List[SchemaField]: try: avro_schema_json = get_avro_schema_for_hive_column( - hive_column_name=hive_column_name, hive_column_type=hive_column_type + hive_column_name=hive_column_name, + hive_column_type=hive_column_type, ) schema_fields = avro_schema_to_mce_fields( avro_schema=json.dumps(avro_schema_json), @@ -292,14 +300,14 @@ def get_schema_fields_for_hive_column( ) except Exception as e: logger.warning( - f"Unable to parse column 
{hive_column_name} and type {hive_column_type} the error was: {e}" + f"Unable to parse column {hive_column_name} and type {hive_column_type} the error was: {e}", ) schema_fields = [ SchemaField( fieldPath=hive_column_name, type=SchemaFieldDataTypeClass(type=NullTypeClass()), nativeDataType=hive_column_type, - ) + ), ] assert schema_fields diff --git a/metadata-ingestion/src/datahub/utilities/logging_manager.py b/metadata-ingestion/src/datahub/utilities/logging_manager.py index a5fd20fef307d0..044fa3347b34d9 100644 --- a/metadata-ingestion/src/datahub/utilities/logging_manager.py +++ b/metadata-ingestion/src/datahub/utilities/logging_manager.py @@ -96,7 +96,7 @@ def extract_name_from_filename(filename: str, fallback_name: str) -> str: if "metadata-ingestion" in part ), [None], - ) + ), ) if src_dir_index is not None: # Join the parts after 'src' with '.' diff --git a/metadata-ingestion/src/datahub/utilities/lossy_collections.py b/metadata-ingestion/src/datahub/utilities/lossy_collections.py index 31d6d0eb842d04..b828224e43e14e 100644 --- a/metadata-ingestion/src/datahub/utilities/lossy_collections.py +++ b/metadata-ingestion/src/datahub/utilities/lossy_collections.py @@ -108,7 +108,7 @@ def as_obj(self) -> List[Union[T, str]]: base_list: List[Union[T, str]] = list(self.__iter__()) if self.sampled: base_list.append( - f"... sampled with at most {self._items_removed} elements missing." + f"... sampled with at most {self._items_removed} elements missing.", ) return base_list diff --git a/metadata-ingestion/src/datahub/utilities/mapping.py b/metadata-ingestion/src/datahub/utilities/mapping.py index 96870fc6fcd378..4497c48a45ce44 100644 --- a/metadata-ingestion/src/datahub/utilities/mapping.py +++ b/metadata-ingestion/src/datahub/utilities/mapping.py @@ -121,7 +121,7 @@ def make_owner_category_list(self) -> List[Dict]: "urn": owner_id, "category": owner_category, "categoryUrn": owner_category_urn, - } + }, ) return res @@ -190,7 +190,7 @@ def process(self, raw_props: Mapping[str, Any]) -> Dict[str, Any]: # noqa: C901 if datahub_prop.tags: # Note that tags get converted to urns later because we need to support the tag prefix. 
operations_map.setdefault(Constants.ADD_TAG_OPERATION, []).extend( - datahub_prop.tags + datahub_prop.tags, ) if datahub_prop.terms: @@ -200,12 +200,13 @@ def process(self, raw_props: Mapping[str, Any]) -> Dict[str, Any]: # noqa: C901 if datahub_prop.owners: operations_map.setdefault(Constants.ADD_OWNER_OPERATION, []).extend( - datahub_prop.make_owner_category_list() + datahub_prop.make_owner_category_list(), ) if datahub_prop.domain: operations_map.setdefault( - Constants.ADD_DOMAIN_OPERATION, [] + Constants.ADD_DOMAIN_OPERATION, + [], ).append(mce_builder.make_domain_urn(datahub_prop.domain)) except Exception as e: logger.error(f"Error while processing datahub property: {e}") @@ -214,10 +215,10 @@ def process(self, raw_props: Mapping[str, Any]) -> Dict[str, Any]: # noqa: C901 try: for operation_key in self.operation_defs: operation_type = self.operation_defs.get(operation_key, {}).get( - Constants.OPERATION + Constants.OPERATION, ) operation_config = self.operation_defs.get(operation_key, {}).get( - Constants.OPERATION_CONFIG + Constants.OPERATION_CONFIG, ) if not operation_type or not operation_config: continue @@ -225,7 +226,9 @@ def process(self, raw_props: Mapping[str, Any]) -> Dict[str, Any]: # noqa: C901 if not raw_props_value and self.match_nested_props: try: raw_props_value = reduce( - operator.getitem, operation_key.split("."), raw_props + operator.getitem, + operation_key.split("."), + raw_props, ) except KeyError: pass @@ -236,7 +239,10 @@ def process(self, raw_props: Mapping[str, Any]) -> Dict[str, Any]: # noqa: C901 ) if maybe_match is not None: operation = self.get_operation_value( - operation_key, operation_type, operation_config, maybe_match + operation_key, + operation_type, + operation_config, + maybe_match, ) if operation_type == Constants.ADD_TERMS_OPERATION: @@ -250,18 +256,18 @@ def process(self, raw_props: Mapping[str, Any]) -> Dict[str, Any]: # noqa: C901 and operation_type == Constants.ADD_OWNER_OPERATION ): operations_map.setdefault(operation_type, []).extend( - operation + operation, ) elif isinstance(operation, (str, list)): operations_map.setdefault(operation_type, []).extend( operation if isinstance(operation, list) - else [operation] + else [operation], ) else: operations_map.setdefault(operation_type, []).append( - operation + operation, ) except Exception as e: logger.error(f"Error while processing operation defs over raw_props: {e}") @@ -278,7 +284,7 @@ def convert_to_aspects(self, operation_map: Dict[str, list]) -> Dict[str, Any]: if Constants.ADD_TAG_OPERATION in operation_map: tag_aspect = mce_builder.make_global_tag_aspect_with_tag_list( - sorted(set(operation_map[Constants.ADD_TAG_OPERATION])) + sorted(set(operation_map[Constants.ADD_TAG_OPERATION])), ) aspect_map[Constants.ADD_TAG_OPERATION] = tag_aspect @@ -300,14 +306,14 @@ def convert_to_aspects(self, operation_map: Dict[str, list]) -> Dict[str, Any]: operation_map[Constants.ADD_OWNER_OPERATION], key=lambda x: x["urn"], ) - ] + ], ) aspect_map[Constants.ADD_OWNER_OPERATION] = owner_aspect if Constants.ADD_TERM_OPERATION in operation_map: term_aspect = mce_builder.make_glossary_terms_aspect_from_urn_list( - sorted(set(operation_map[Constants.ADD_TERM_OPERATION])) + sorted(set(operation_map[Constants.ADD_TERM_OPERATION])), ) aspect_map[Constants.ADD_TERM_OPERATION] = term_aspect @@ -316,23 +322,25 @@ def convert_to_aspects(self, operation_map: Dict[str, list]) -> Dict[str, Any]: domains=[ mce_builder.make_domain_urn(domain) for domain in operation_map[Constants.ADD_DOMAIN_OPERATION] - ] + ], ) 
aspect_map[Constants.ADD_DOMAIN_OPERATION] = domain_aspect if Constants.ADD_DOC_LINK_OPERATION in operation_map: try: if len( - operation_map[Constants.ADD_DOC_LINK_OPERATION] + operation_map[Constants.ADD_DOC_LINK_OPERATION], ) == 1 and isinstance( - operation_map[Constants.ADD_DOC_LINK_OPERATION], list + operation_map[Constants.ADD_DOC_LINK_OPERATION], + list, ): docs_dict = cast( - List[Dict], operation_map[Constants.ADD_DOC_LINK_OPERATION] + List[Dict], + operation_map[Constants.ADD_DOC_LINK_OPERATION], )[0] if "description" not in docs_dict or "link" not in docs_dict: raise Exception( - "Documentation_link meta_mapping config needs a description key and a link key" + "Documentation_link meta_mapping config needs a description key and a link key", ) now = int(time.time() * 1000) # milliseconds since epoch @@ -340,13 +348,14 @@ def convert_to_aspects(self, operation_map: Dict[str, list]) -> Dict[str, Any]: url=docs_dict["link"], description=docs_dict["description"], createStamp=AuditStampClass( - time=now, actor="urn:li:corpuser:ingestion" + time=now, + actor="urn:li:corpuser:ingestion", ), ) # create a new institutional memory aspect institutional_memory_aspect = InstitutionalMemoryClass( - elements=[institutional_memory_element] + elements=[institutional_memory_element], ) aspect_map[Constants.ADD_DOC_LINK_OPERATION] = ( @@ -356,12 +365,12 @@ def convert_to_aspects(self, operation_map: Dict[str, list]) -> Dict[str, Any]: raise Exception( f"Expected 1 item of type list for the documentation_link meta_mapping config," f" received type of {type(operation_map[Constants.ADD_DOC_LINK_OPERATION])}" - f", and size of {len(operation_map[Constants.ADD_DOC_LINK_OPERATION])}." + f", and size of {len(operation_map[Constants.ADD_DOC_LINK_OPERATION])}.", ) except Exception as e: logger.error( - f"Error while constructing aspect for documentation link and description : {e}" + f"Error while constructing aspect for documentation link and description : {e}", ) return aspect_map @@ -388,7 +397,8 @@ def get_operation_value( owner_ids: List[str] = [_id.strip() for _id in owner_id.split(",")] owner_type_raw = operation_config.get( - Constants.OWNER_TYPE, Constants.USER_OWNER + Constants.OWNER_TYPE, + Constants.USER_OWNER, ) owner_type_mapping: Dict[str, OwnerType] = { Constants.USER_OWNER: OwnerType.USER, @@ -396,7 +406,7 @@ def get_operation_value( } if owner_type_raw not in owner_type_mapping: logger.warning( - f"Invalid owner type: {owner_type_raw}. Valid owner types are {', '.join(owner_type_mapping.keys())}" + f"Invalid owner type: {owner_type_raw}. 
Valid owner types are {', '.join(owner_type_mapping.keys())}", ) return None owner_type = owner_type_mapping[owner_type_raw] diff --git a/metadata-ingestion/src/datahub/utilities/parsing_util.py b/metadata-ingestion/src/datahub/utilities/parsing_util.py index 6d91700aefdbf1..0b4262226acc9a 100644 --- a/metadata-ingestion/src/datahub/utilities/parsing_util.py +++ b/metadata-ingestion/src/datahub/utilities/parsing_util.py @@ -16,7 +16,8 @@ def get_first_missing_key(inp_dict: Dict, keys: List[str]) -> Optional[str]: def get_first_missing_key_any( - inp_dict: Dict[str, Any], keys: List[str] + inp_dict: Dict[str, Any], + keys: List[str], ) -> Optional[str]: for key in keys: if key not in inp_dict: diff --git a/metadata-ingestion/src/datahub/utilities/partition_executor.py b/metadata-ingestion/src/datahub/utilities/partition_executor.py index 542889f2f90e29..db0d0994b0ac66 100644 --- a/metadata-ingestion/src/datahub/utilities/partition_executor.py +++ b/metadata-ingestion/src/datahub/utilities/partition_executor.py @@ -65,7 +65,8 @@ def __init__(self, max_workers: int, max_pending: int) -> None: # Any entries in the key's value e.g. the deque are requests that are waiting # to be submitted once the current request for that key completes. self._pending_by_key: Dict[ - str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]] + str, + Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]], ] = {} def submit( @@ -283,7 +284,7 @@ def __init__( self._pending_count = BoundedSemaphore(max_pending) self._pending: "queue.Queue[Optional[_BatchPartitionWorkItem]]" = queue.Queue( - maxsize=max_pending + maxsize=max_pending, ) # If this is true, that means shutdown() has been called. @@ -308,7 +309,8 @@ def _clearinghouse_worker(self) -> None: # noqa: C901 last_submit_time = _now() def _handle_batch_completion( - batch: List[_BatchPartitionWorkItem], future: Future + batch: List[_BatchPartitionWorkItem], + future: Future, ) -> None: with clearinghouse_state_lock: nonlocal workers_available @@ -382,7 +384,7 @@ def _build_batch() -> List[_BatchPartitionWorkItem]: except queue.Empty: if blocking: next_batch.extend( - _find_ready_items(self.max_per_batch - len(next_batch)) + _find_ready_items(self.max_per_batch - len(next_batch)), ) else: break @@ -401,10 +403,11 @@ def _submit_batch(next_batch: List[_BatchPartitionWorkItem]) -> None: last_submit_time = _now() future = self._executor.submit( - self.process_batch, [item.args for item in next_batch] + self.process_batch, + [item.args for item in next_batch], ) future.add_done_callback( - functools.partial(_handle_batch_completion, next_batch) + functools.partial(_handle_batch_completion, next_batch), ) try: @@ -439,13 +442,14 @@ def _submit_batch(next_batch: List[_BatchPartitionWorkItem]) -> None: logger.error( f"{self.__class__.__name__}: submit() was called but the executor was not cleaned up properly. " "The data from the submit() calls will be lost. Use a context manager or call shutdown() explicitly " - "to ensure all submitted work is processed." + "to ensure all submitted work is processed.", ) return # This represents a fatal error that makes the entire executor defunct. 
logger.exception( - "Threaded executor's clearinghouse worker failed.", exc_info=e + "Threaded executor's clearinghouse worker failed.", + exc_info=e, ) finally: self._clearinghouse_started = False @@ -453,7 +457,7 @@ def _submit_batch(next_batch: List[_BatchPartitionWorkItem]) -> None: def _ensure_clearinghouse_started(self) -> None: if self._shutting_down: raise RuntimeError( - f"{self.__class__.__name__} is shutting down; cannot submit new work items." + f"{self.__class__.__name__} is shutting down; cannot submit new work items.", ) with self._state_lock: diff --git a/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py b/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py index 271c2517e87713..9180878a265bc3 100644 --- a/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py +++ b/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py @@ -11,13 +11,17 @@ class PrefixGroup: def build_prefix_batches( - names: List[str], max_batch_size: int, max_groups_in_batch: int + names: List[str], + max_batch_size: int, + max_groups_in_batch: int, ) -> List[List[PrefixGroup]]: """Split the names into a list of batches, where each batch is a list of groups and each group is a list of names with a common prefix.""" groups = _build_prefix_groups(names, max_batch_size=max_batch_size) batches = _batch_prefix_groups( - groups, max_batch_size=max_batch_size, max_groups_in_batch=max_groups_in_batch + groups, + max_batch_size=max_batch_size, + max_groups_in_batch=max_groups_in_batch, ) return batches @@ -53,7 +57,9 @@ def split_group(group: PrefixGroup) -> List[PrefixGroup]: def _batch_prefix_groups( - groups: List[PrefixGroup], max_batch_size: int, max_groups_in_batch: int + groups: List[PrefixGroup], + max_batch_size: int, + max_groups_in_batch: int, ) -> List[List[PrefixGroup]]: """Batch the groups together, so that no batch's total is larger than `max_batch_size` and no group in a batch is larger than `max_group_size`.""" diff --git a/metadata-ingestion/src/datahub/utilities/registries/domain_registry.py b/metadata-ingestion/src/datahub/utilities/registries/domain_registry.py index da8b6310c79f59..71e281b871c59f 100644 --- a/metadata-ingestion/src/datahub/utilities/registries/domain_registry.py +++ b/metadata-ingestion/src/datahub/utilities/registries/domain_registry.py @@ -24,7 +24,7 @@ def __init__( ] if domains_needing_resolution and not graph: raise ValueError( - f"Following domains need server-side resolution {domains_needing_resolution} but a DataHub server wasn't provided. Either use fully qualified domain ids (e.g. urn:li:domain:ec428203-ce86-4db3-985d-5a8ee6df32ba) or provide a datahub_api config in your recipe." + f"Following domains need server-side resolution {domains_needing_resolution} but a DataHub server wasn't provided. Either use fully qualified domain ids (e.g. urn:li:domain:ec428203-ce86-4db3-985d-5a8ee6df32ba) or provide a datahub_api config in your recipe.", ) for domain_identifier in domains_needing_resolution: assert graph @@ -40,10 +40,10 @@ def __init__( self.domain_registry[domain_identifier] = domain_urn else: logger.error( - f"Failed to retrieve domain id for domain {domain_identifier}" + f"Failed to retrieve domain id for domain {domain_identifier}", ) raise ValueError( - f"domain {domain_identifier} doesn't seem to be provisioned on DataHub. Either provision it first and re-run ingestion, or provide a fully qualified domain id (e.g. urn:li:domain:ec428203-ce86-4db3-985d-5a8ee6df32ba) to skip this check." 
+ f"domain {domain_identifier} doesn't seem to be provisioned on DataHub. Either provision it first and re-run ingestion, or provide a fully qualified domain id (e.g. urn:li:domain:ec428203-ce86-4db3-985d-5a8ee6df32ba) to skip this check.", ) def get_domain_urn(self, domain_identifier: str) -> str: diff --git a/metadata-ingestion/src/datahub/utilities/search_utils.py b/metadata-ingestion/src/datahub/utilities/search_utils.py index 0bd88addd86600..502eb9c78bd82a 100644 --- a/metadata-ingestion/src/datahub/utilities/search_utils.py +++ b/metadata-ingestion/src/datahub/utilities/search_utils.py @@ -73,7 +73,10 @@ def escape_special_characters(cls, value: str) -> str: return re.sub(f"([{re.escape(cls.SPECIAL_CHARACTERS)}])", r"\\\1", value) def _create_term( - self, field: SearchField, value: str, is_exact: bool = False + self, + field: SearchField, + value: str, + is_exact: bool = False, ) -> str: escaped_value = self.escape_special_characters(field.get_search_value(value)) field_name: str = field.field_name @@ -82,14 +85,20 @@ def _create_term( return f"{field_name}:{escaped_value}" def add_field_match( - self, field: SearchField, value: str, is_exact: bool = True + self, + field: SearchField, + value: str, + is_exact: bool = True, ) -> "ElasticsearchQueryBuilder": term = self._create_term(field, value, is_exact) self.root.add_child(term) return self def add_field_not_match( - self, field: SearchField, value: str, is_exact: bool = True + self, + field: SearchField, + value: str, + is_exact: bool = True, ) -> "ElasticsearchQueryBuilder": term = f"-{self._create_term(field, value, is_exact)}" self.root.add_child(term) @@ -117,14 +126,20 @@ def add_wildcard(self, field: str, pattern: str) -> "ElasticsearchQueryBuilder": return self def add_fuzzy( - self, field: str, value: str, fuzziness: int = 2 + self, + field: str, + value: str, + fuzziness: int = 2, ) -> "ElasticsearchQueryBuilder": fuzzy_query = f"{field}:{value}~{fuzziness}" self.root.add_child(fuzzy_query) return self def add_boost( - self, field: str, value: str, boost: float + self, + field: str, + value: str, + boost: float, ) -> "ElasticsearchQueryBuilder": boosted_query = f"{field}:{value}^{boost}" self.root.add_child(boosted_query) @@ -144,7 +159,10 @@ def __init__(self, parent: ElasticsearchQueryBuilder, operator: LogicalOperator) self.parent.root.add_child(self.node) def add_field_match( - self, field: Union[str, SearchField], value: str, is_exact: bool = True + self, + field: Union[str, SearchField], + value: str, + is_exact: bool = True, ) -> "QueryGroup": if isinstance(field, str): field = SearchField.from_string_field(field) @@ -153,7 +171,10 @@ def add_field_match( return self def add_field_not_match( - self, field: Union[str, SearchField], value: str, is_exact: bool = True + self, + field: Union[str, SearchField], + value: str, + is_exact: bool = True, ) -> "QueryGroup": if isinstance(field, str): field = SearchField.from_string_field(field) @@ -226,14 +247,18 @@ def create_from( instance.add_field_match(field, value) elif isinstance(field, str): instance.add_field_match( - SearchField.from_string_field(field), value + SearchField.from_string_field(field), + value, ) else: raise ValueError("Invalid field type {}".format(type(field))) return instance def add_field_match( - self, field: Union[str, SearchField], value: str, is_exact: bool = True + self, + field: Union[str, SearchField], + value: str, + is_exact: bool = True, ) -> "ElasticDocumentQuery": if isinstance(field, str): field = SearchField.from_string_field(field) @@ 
-241,7 +266,10 @@ def add_field_match( return self def add_field_not_match( - self, field: SearchField, value: str, is_exact: bool = True + self, + field: SearchField, + value: str, + is_exact: bool = True, ) -> "ElasticDocumentQuery": self.query_builder.add_field_not_match(field, value, is_exact) return self @@ -256,7 +284,11 @@ def add_range( ) -> "ElasticDocumentQuery": field_name: str = field.field_name # type: ignore self.query_builder.add_range( - field_name, min_value, max_value, include_min, include_max + field_name, + min_value, + max_value, + include_min, + include_max, ) return self @@ -266,14 +298,20 @@ def add_wildcard(self, field: SearchField, pattern: str) -> "ElasticDocumentQuer return self def add_fuzzy( - self, field: SearchField, value: str, fuzziness: int = 2 + self, + field: SearchField, + value: str, + fuzziness: int = 2, ) -> "ElasticDocumentQuery": field_name: str = field.field_name # type: ignore self.query_builder.add_fuzzy(field_name, value, fuzziness) return self def add_boost( - self, field: SearchField, value: str, boost: float + self, + field: SearchField, + value: str, + boost: float, ) -> "ElasticDocumentQuery": self.query_builder.add_boost(field.field_name, value, boost) return self diff --git a/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py b/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py index bdfe4285065522..7df89e77c0601c 100644 --- a/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py +++ b/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py @@ -42,7 +42,8 @@ def wrapper(*args: _F.args, **kwargs: _F.kwargs) -> _T: # We need a type ignore here because there's no way for us to require that # the args and kwargs are hashable while using ParamSpec. key: _Key = cachetools.keys.hashkey( - *args, **{k: v for k, v in kwargs.items() if "cache_exclude" not in k} + *args, + **{k: v for k, v in kwargs.items() if "cache_exclude" not in k}, ) # type: ignore with cache_lock: diff --git a/metadata-ingestion/src/datahub/utilities/sql_formatter.py b/metadata-ingestion/src/datahub/utilities/sql_formatter.py index 5b62c10378e995..649de494b4d7d0 100644 --- a/metadata-ingestion/src/datahub/utilities/sql_formatter.py +++ b/metadata-ingestion/src/datahub/utilities/sql_formatter.py @@ -18,7 +18,9 @@ def format_sql_query(query: str, **options: Any) -> str: def trim_query( - query: str, budget_per_query: int, query_trimmer_string: str = " ..." + query: str, + budget_per_query: int, + query_trimmer_string: str = " ...", ) -> str: trimmed_query = query if len(query) > budget_per_query: @@ -27,6 +29,6 @@ def trim_query( trimmed_query = query[:end_index] + query_trimmer_string else: raise Exception( - "Budget per query is too low. Please, decrease the number of top_n_queries." + "Budget per query is too low. 
Please, decrease the number of top_n_queries.", ) return trimmed_query diff --git a/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py b/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py index cf92336c68cdf6..7bb97a4e08ca10 100644 --- a/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py +++ b/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py @@ -84,7 +84,7 @@ def scalar(self) -> Any: return row[0] elif self._result: raise MultipleResultsFound( - "Multiple rows were found when exactly one was required" + "Multiple rows were found when exactly one was required", ) return None @@ -148,17 +148,17 @@ class SQLAlchemyQueryCombiner: # The Python GIL ensures that modifications to the report's counters # are safe. report: SQLAlchemyQueryCombinerReport = dataclasses.field( - default_factory=SQLAlchemyQueryCombinerReport + default_factory=SQLAlchemyQueryCombinerReport, ) # There will be one main greenlet per thread. As such, queries will be # queued according to the main greenlet's thread ID. We also keep track # of the greenlets we spawn for bookkeeping purposes. _queries_by_thread_lock: threading.Lock = dataclasses.field( - default_factory=lambda: threading.Lock() + default_factory=lambda: threading.Lock(), ) _greenlets_by_thread_lock: threading.Lock = dataclasses.field( - default_factory=lambda: threading.Lock() + default_factory=lambda: threading.Lock(), ) _queries_by_thread: Dict[greenlet.greenlet, Dict[str, _QueryFuture]] = ( dataclasses.field(default_factory=lambda: collections.defaultdict(dict)) @@ -186,7 +186,8 @@ def _get_queue(self, main_greenlet: greenlet.greenlet) -> Dict[str, _QueryFuture return self._queries_by_thread.setdefault(main_greenlet, {}) def _get_greenlet_pool( - self, main_greenlet: greenlet.greenlet + self, + main_greenlet: greenlet.greenlet, ) -> Set[greenlet.greenlet]: assert main_greenlet.parent is None @@ -194,7 +195,11 @@ def _get_greenlet_pool( return self._greenlets_by_thread[main_greenlet] def _handle_execute( - self, conn: Connection, query: Any, multiparams: Any, params: Any + self, + conn: Connection, + query: Any, + multiparams: Any, + params: Any, ) -> Tuple[bool, Optional[_QueryFuture]]: # Returns True with result if the query was handled, False if it # should be executed normally using the fallback method. @@ -224,7 +229,7 @@ def _handle_execute( assert len(get_query_columns(query)) > 0 except AttributeError as e: logger.debug( - f"Query of type: '{type(query)}' does not contain attributes required by 'get_query_columns()'. AttributeError: {e}" + f"Query of type: '{type(query)}' does not contain attributes required by 'get_query_columns()'. 
AttributeError: {e}", ) return False, None @@ -246,7 +251,10 @@ def _handle_execute( @contextlib.contextmanager def activate(self) -> Iterator["SQLAlchemyQueryCombiner"]: def _sa_execute_fake( - conn: Connection, query: Any, *args: Any, **kwargs: Any + conn: Connection, + query: Any, + *args: Any, + **kwargs: Any, ) -> Any: try: self.report.total_queries += 1 @@ -255,7 +263,7 @@ def _sa_execute_fake( if not self.catch_exceptions: raise e logger.warning( - f"Failed to execute query normally, using fallback: {str(query)}" + f"Failed to execute query normally, using fallback: {str(query)}", ) logger.debug("Failed to execute query normally", exc_info=e) self.report.query_exceptions += 1 @@ -274,7 +282,8 @@ def _sa_execute_fake( with _sa_execute_method_patching_lock: with unittest.mock.patch( - "sqlalchemy.engine.Connection.execute", _sa_execute_fake + "sqlalchemy.engine.Connection.execute", + _sa_execute_fake, ): yield self @@ -301,7 +310,7 @@ def _execute_queue(self, main_greenlet: greenlet.greenlet) -> None: pending_queue = {k: v for k, v in full_queue.items() if not v.done} pending_queue = dict( - itertools.islice(pending_queue.items(), MAX_QUERIES_TO_COMBINE_AT_ONCE) + itertools.islice(pending_queue.items(), MAX_QUERIES_TO_COMBINE_AT_ONCE), ) if pending_queue: @@ -323,7 +332,7 @@ def _execute_queue(self, main_greenlet: greenlet.greenlet) -> None: for col in get_query_columns(cte) ] for _, cte in ctes.items() - ] + ], ) combined_query = sqlalchemy.select(combined_cols) for cte in ctes.values(): @@ -402,7 +411,7 @@ def flush(self) -> None: if not self.serial_execution_fallback_enabled: raise e logger.warning( - "Failed to execute queue using combiner, will fallback to execute one by one." + "Failed to execute queue using combiner, will fallback to execute one by one.", ) logger.debug("Failed to execute queue using combiner", exc_info=e) self.report.query_exceptions += 1 diff --git a/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py b/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py index ad94c6904e2807..830c60b0054f55 100644 --- a/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py +++ b/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py @@ -46,11 +46,14 @@ class SqlAlchemyColumnToAvroConverter: @classmethod def get_avro_type( - cls, column_type: Union[types.TypeEngine, STRUCT, MapType], nullable: bool + cls, + column_type: Union[types.TypeEngine, STRUCT, MapType], + nullable: bool, ) -> Dict[str, Any]: """Determines the concrete AVRO schema type for a SQLalchemy-typed column""" if isinstance( - column_type, tuple(cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys()) + column_type, + tuple(cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys()), ): return { "type": cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE[type(column_type)], @@ -95,17 +98,19 @@ def get_avro_type( return { "type": "map", "values": cls.get_avro_type( - column_type=value_type, nullable=nullable + column_type=value_type, + nullable=nullable, ), "native_data_type": str(column_type), "key_type": cls.get_avro_type( - column_type=key_type, nullable=nullable + column_type=key_type, + nullable=nullable, ), "key_native_data_type": str(key_type), } except Exception as e: logger.warning( - f"Unable to parse MapType {column_type} the error was: {e}" + f"Unable to parse MapType {column_type} the error was: {e}", ) return { "type": "map", @@ -122,9 +127,10 @@ def get_avro_type( { "name": field_name, "type": cls.get_avro_type( - column_type=field_type, 
nullable=nullable + column_type=field_type, + nullable=nullable, ), - } + }, ) struct_name = f"__struct_{str(uuid.uuid4()).replace('-', '')}" try: @@ -167,9 +173,10 @@ def get_avro_for_sqlalchemy_column( { "name": column_name, "type": cls.get_avro_type( - column_type=column_type, nullable=nullable + column_type=column_type, + nullable=nullable, ), - } + }, ], } return cls.get_avro_type(column_type=column_type, nullable=nullable) @@ -211,7 +218,7 @@ def get_schema_fields_for_sqlalchemy_column( ) except Exception as e: logger.warning( - f"Unable to parse column {column_name} and type {column_type} the error was: {e} Traceback: {traceback.format_exc()}" + f"Unable to parse column {column_name} and type {column_type} the error was: {e} Traceback: {traceback.format_exc()}", ) # fallback description in case any exception occurred @@ -223,7 +230,7 @@ def get_schema_fields_for_sqlalchemy_column( column_type, inspector, ), - ) + ), ] # for all non-nested data types an additional modification of the `fieldPath` property is required @@ -249,7 +256,8 @@ def get_schema_fields_for_sqlalchemy_column( def get_native_data_type_for_sqlalchemy_type( - column_type: types.TypeEngine, inspector: Inspector + column_type: types.TypeEngine, + inspector: Inspector, ) -> str: if isinstance(column_type, types.NullType): return column_type.__visit_name__ @@ -258,7 +266,7 @@ def get_native_data_type_for_sqlalchemy_type( return column_type.compile(dialect=inspector.dialect) except Exception as e: logger.debug( - f"Unable to compile sqlalchemy type {column_type} the error was: {e}" + f"Unable to compile sqlalchemy type {column_type} the error was: {e}", ) if ( diff --git a/metadata-ingestion/src/datahub/utilities/sqllineage_patch.py b/metadata-ingestion/src/datahub/utilities/sqllineage_patch.py index 4c237d02727f72..e9a98bafc9949d 100644 --- a/metadata-ingestion/src/datahub/utilities/sqllineage_patch.py +++ b/metadata-ingestion/src/datahub/utilities/sqllineage_patch.py @@ -30,7 +30,7 @@ def end_of_query_cleanup_patch(self, holder: SubQueryLineageHolder) -> None: # for tgt_col in col_grp: tgt_col.parent = tgt_tbl for src_col in tgt_col.to_source_columns( - self._get_alias_mapping_from_table_group(tbl_grp, holder) + self._get_alias_mapping_from_table_group(tbl_grp, holder), ): if holder.write: holder.add_column_lineage(src_col, tgt_col) diff --git a/metadata-ingestion/src/datahub/utilities/stats_collections.py b/metadata-ingestion/src/datahub/utilities/stats_collections.py index c0bd9d058e5d37..80d6ce4dbb292c 100644 --- a/metadata-ingestion/src/datahub/utilities/stats_collections.py +++ b/metadata-ingestion/src/datahub/utilities/stats_collections.py @@ -39,7 +39,9 @@ def as_obj(self) -> Dict[_KT, _VT]: else: try: trimmed_dict = dict( - sorted(self.items(), key=lambda x: x[1], reverse=True)[: self.top_k] + sorted(self.items(), key=lambda x: x[1], reverse=True)[ + : self.top_k + ], ) except TypeError: trimmed_dict = dict(list(self.items())[: self.top_k]) diff --git a/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py b/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py index ab8987a7d2e8b2..8bdae57db8263d 100644 --- a/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py +++ b/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py @@ -22,7 +22,8 @@ def process( out_q: queue.Queue[T] = queue.Queue() def _worker_wrapper( - worker_func: Callable[..., Iterable[T]], *args: Any + worker_func: Callable[..., Iterable[T]], + *args: Any, ) -> None: for item in 
worker_func(*args): out_q.put(item) diff --git a/metadata-ingestion/src/datahub/utilities/threading_timeout.py b/metadata-ingestion/src/datahub/utilities/threading_timeout.py index e2caf57ad2116a..5fb41328271b4a 100644 --- a/metadata-ingestion/src/datahub/utilities/threading_timeout.py +++ b/metadata-ingestion/src/datahub/utilities/threading_timeout.py @@ -36,7 +36,7 @@ def threading_timeout(timeout: float) -> ContextManager[None]: if not _is_cpython(): raise RuntimeError( - f"Timeout is only supported on CPython, not {platform.python_implementation()}" + f"Timeout is only supported on CPython, not {platform.python_implementation()}", ) return _ThreadingTimeout(timeout, swallow_exc=False) diff --git a/metadata-ingestion/src/datahub/utilities/type_annotations.py b/metadata-ingestion/src/datahub/utilities/type_annotations.py index b139a0ed235da8..16cd7f0f996c1b 100644 --- a/metadata-ingestion/src/datahub/utilities/type_annotations.py +++ b/metadata-ingestion/src/datahub/utilities/type_annotations.py @@ -7,7 +7,9 @@ def get_class_from_annotation( - derived_cls: Type, super_class: Type, target_class: Type[TargetClass] + derived_cls: Type, + super_class: Type, + target_class: Type[TargetClass], ) -> Optional[Type[TargetClass]]: """ Attempts to find an instance of target_class in the type annotations of derived_class. diff --git a/metadata-ingestion/src/datahub/utilities/unified_diff.py b/metadata-ingestion/src/datahub/utilities/unified_diff.py index c896fd4df4d8f2..bac373fe5c56fe 100644 --- a/metadata-ingestion/src/datahub/utilities/unified_diff.py +++ b/metadata-ingestion/src/datahub/utilities/unified_diff.py @@ -125,7 +125,7 @@ def find_hunk_start(source_lines: List[str], hunk: Hunk) -> int: return hunk.source_start - 1 # Default to the original start if no context logger.debug( - f"Searching for {len(context_lines)} context lines, starting with {context_lines[0]}" + f"Searching for {len(context_lines)} context lines, starting with {context_lines[0]}", ) # Define the range to search for the context lines @@ -177,14 +177,14 @@ def apply_hunk(result_lines: List[str], hunk: Hunk, hunk_index: int) -> None: # If there's context or deletions past the end of the file, that's an error. if line_index < len(hunk.lines): raise DiffApplyError( - f"Found context or deletions after end of file in hunk {hunk_index + 1}" + f"Found context or deletions after end of file in hunk {hunk_index + 1}", ) break if prefix == "-": if result_lines[current_line].strip() != content.strip(): raise DiffApplyError( - f"Removing line that doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'" + f"Removing line that doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'", ) result_lines.pop(current_line) elif prefix == "+": @@ -193,12 +193,12 @@ def apply_hunk(result_lines: List[str], hunk: Hunk, hunk_index: int) -> None: elif prefix == " ": if result_lines[current_line].strip() != content.strip(): raise DiffApplyError( - f"Context line doesn't exactly match. Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'" + f"Context line doesn't exactly match. 
Expected: '{content.strip()}', Found: '{result_lines[current_line].strip()}'", ) current_line += 1 else: raise DiffApplyError( - f"Invalid line prefix '{prefix}' in hunk {hunk_index + 1}, line {line_index + 1}" + f"Invalid line prefix '{prefix}' in hunk {hunk_index + 1}, line {line_index + 1}", ) diff --git a/metadata-ingestion/src/datahub/utilities/urns/_urn_base.py b/metadata-ingestion/src/datahub/utilities/urns/_urn_base.py index e8e22cd85ac9ff..f1a9d9b4055128 100644 --- a/metadata-ingestion/src/datahub/utilities/urns/_urn_base.py +++ b/metadata-ingestion/src/datahub/utilities/urns/_urn_base.py @@ -127,17 +127,17 @@ def from_string(cls, urn_str: Union[str, "Urn"], /) -> Self: if not urn_str.startswith("urn:li:"): raise InvalidUrnError( - f"Invalid urn string: {urn_str}. Urns should start with 'urn:li:'" + f"Invalid urn string: {urn_str}. Urns should start with 'urn:li:'", ) parts: List[str] = urn_str.split(":", maxsplit=3) if len(parts) != 4: raise InvalidUrnError( - f"Invalid urn string: {urn_str}. Expect 4 parts from urn string but found {len(parts)}" + f"Invalid urn string: {urn_str}. Expect 4 parts from urn string but found {len(parts)}", ) if "" in parts: raise InvalidUrnError( - f"Invalid urn string: {urn_str}. There should not be empty parts in urn string." + f"Invalid urn string: {urn_str}. There should not be empty parts in urn string.", ) _urn, _li, entity_type, entity_ids_str = parts @@ -150,14 +150,14 @@ def from_string(cls, urn_str: Union[str, "Urn"], /) -> Self: # with Urn.from_string(), that's fine. However, if we're called as # DatasetUrn.from_string('urn:li:corpuser:foo'), that should throw an error. raise InvalidUrnError( - f"Passed an urn of type {entity_type} to the from_string method of {cls.__name__}. Use Urn.from_string() or {UrnCls.__name__}.from_string() instead." + f"Passed an urn of type {entity_type} to the from_string method of {cls.__name__}. Use Urn.from_string() or {UrnCls.__name__}.from_string() instead.", ) return UrnCls._parse_ids(entity_ids) # type: ignore # Fallback for unknown types. 
if cls != Urn: raise InvalidUrnError( - f"Unknown urn type {entity_type} for urn {urn_str} of type {cls}" + f"Unknown urn type {entity_type} for urn {urn_str} of type {cls}", ) return cls(entity_type, entity_ids) @@ -186,7 +186,7 @@ def __eq__(self, other: object) -> bool: def __lt__(self, other: object) -> bool: if not isinstance(other, Urn): raise TypeError( - f"'<' not supported between instances of '{type(self)}' and '{type(other)}'" + f"'<' not supported between instances of '{type(self)}' and '{type(other)}'", ) return self.urn() < other.urn() diff --git a/metadata-ingestion/src/datahub/utilities/urns/urn_iter.py b/metadata-ingestion/src/datahub/utilities/urns/urn_iter.py index d792e0bba649dd..35909c6e041720 100644 --- a/metadata-ingestion/src/datahub/utilities/urns/urn_iter.py +++ b/metadata-ingestion/src/datahub/utilities/urns/urn_iter.py @@ -15,7 +15,8 @@ def _add_prefix_to_paths( - prefix: _Path, items: List[Tuple[str, _Path]] + prefix: _Path, + items: List[Tuple[str, _Path]], ) -> List[Tuple[str, _Path]]: return [(urn, [*prefix, *path]) for urn, path in items] @@ -40,12 +41,13 @@ def list_urns_with_path( if model.entityKeyAspect: urns.extend( _add_prefix_to_paths( - ["entityKeyAspect"], list_urns_with_path(model.entityKeyAspect) - ) + ["entityKeyAspect"], + list_urns_with_path(model.entityKeyAspect), + ), ) if model.aspect: urns.extend( - _add_prefix_to_paths(["aspect"], list_urns_with_path(model.aspect)) + _add_prefix_to_paths(["aspect"], list_urns_with_path(model.aspect)), ) return urns @@ -65,7 +67,7 @@ def list_urns_with_path( for i, item in enumerate(value): if isinstance(item, DictWrapper): urns.extend( - _add_prefix_to_paths([key, i], list_urns_with_path(item)) + _add_prefix_to_paths([key, i], list_urns_with_path(item)), ) elif is_urn: urns.append((item, [key, i])) @@ -134,7 +136,9 @@ def _modify_at_path( def lowercase_dataset_urn(dataset_urn: str) -> str: cur_urn = DatasetUrn.from_string(dataset_urn) new_urn = DatasetUrn( - platform=cur_urn.platform, name=cur_urn.name.lower(), env=cur_urn.env + platform=cur_urn.platform, + name=cur_urn.name.lower(), + env=cur_urn.env, ) return str(new_urn) diff --git a/metadata-ingestion/tests/conftest.py b/metadata-ingestion/tests/conftest.py index 4685faabfcb285..6553884184a4c5 100644 --- a/metadata-ingestion/tests/conftest.py +++ b/metadata-ingestion/tests/conftest.py @@ -64,7 +64,8 @@ def pytest_addoption(parser): def pytest_collection_modifyitems( - config: pytest.Config, items: List[pytest.Item] + config: pytest.Config, + items: List[pytest.Item], ) -> None: # https://docs.pytest.org/en/latest/reference/reference.html#pytest.hookspec.pytest_collection_modifyitems # Adapted from https://stackoverflow.com/a/57046943/5004662. 
diff --git a/metadata-ingestion/tests/integration/athena/test_athena_source.py b/metadata-ingestion/tests/integration/athena/test_athena_source.py index 56e7cbe6b3e2dd..5e0d979d22e1fc 100644 --- a/metadata-ingestion/tests/integration/athena/test_athena_source.py +++ b/metadata-ingestion/tests/integration/athena/test_athena_source.py @@ -24,9 +24,11 @@ def test_athena_source_ingestion(pytestconfig, tmp_path): # Mock dependencies with patch.object( - AthenaSource, "get_inspectors" + AthenaSource, + "get_inspectors", ) as mock_get_inspectors, patch.object( - AthenaSource, "get_table_properties" + AthenaSource, + "get_table_properties", ) as mock_get_table_properties: # Mock engine and inspectors mock_inspector = MagicMock() @@ -91,7 +93,8 @@ def mock_get_view_definition(view_name, schema): { "name": "job_history", "type": MapType( - String(), STRUCT(year=INTEGER(), company=String(), role=String()) + String(), + STRUCT(year=INTEGER(), company=String(), role=String()), ), "nullable": True, "default": None, diff --git a/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py b/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py index 5bb078a368dd50..08d3f151fe84b4 100644 --- a/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py +++ b/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py @@ -57,13 +57,13 @@ def run_ingest( ) with patch( - "datahub.ingestion.source.identity.azure_ad.AzureADSource.get_token" + "datahub.ingestion.source.identity.azure_ad.AzureADSource.get_token", ) as mock_token, patch( - "datahub.ingestion.source.identity.azure_ad.AzureADSource._get_azure_ad_users" + "datahub.ingestion.source.identity.azure_ad.AzureADSource._get_azure_ad_users", ) as mock_users, patch( - "datahub.ingestion.source.identity.azure_ad.AzureADSource._get_azure_ad_groups" + "datahub.ingestion.source.identity.azure_ad.AzureADSource._get_azure_ad_groups", ) as mock_groups, patch( - "datahub.ingestion.source.identity.azure_ad.AzureADSource._get_azure_ad_group_members" + "datahub.ingestion.source.identity.azure_ad.AzureADSource._get_azure_ad_group_members", ) as mock_group_users, patch( "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph", mock_datahub_graph, @@ -71,7 +71,11 @@ def run_ingest( mock_checkpoint.return_value = mock_datahub_graph mocked_functions_reference( - test_resources_dir, mock_token, mock_users, mock_groups, mock_group_users + test_resources_dir, + mock_token, + mock_users, + mock_groups, + mock_group_users, ) # Run an azure usage ingestion run. 
@@ -102,7 +106,7 @@ def load_test_resources(test_resources_dir): azure_ad_nested_groups_members_json_file.open() ) as azure_ad_nested_groups_users_json: reference_nested_groups_users = json.loads( - azure_ad_nested_groups_users_json.read() + azure_ad_nested_groups_users_json.read(), ) return ( @@ -131,7 +135,7 @@ def mocked_functions( # mock users and groups response users, groups, nested_group, nested_group_members = load_test_resources( - test_resources_dir + test_resources_dir, ) mock_users.return_value = iter(list([users])) mock_groups.return_value = ( @@ -200,7 +204,7 @@ def test_azure_ad_config(): ingest_users=True, ingest_groups=True, ingest_group_membership=True, - ) + ), ) # Sanity on required configurations @@ -249,7 +253,9 @@ def test_azure_ad_source_default_configs(pytestconfig, mock_datahub_graph, tmp_p @freeze_time(FROZEN_TIME) def test_azure_ad_source_empty_group_membership( - pytestconfig, mock_datahub_graph, tmp_path + pytestconfig, + mock_datahub_graph, + tmp_path, ): test_resources_dir: pathlib.Path = ( pytestconfig.rootpath / "tests/integration/azure_ad" @@ -330,7 +336,10 @@ def test_azure_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_p @freeze_time(FROZEN_TIME) def test_azure_ad_stateful_ingestion( - pytestconfig, tmp_path, mock_time, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): new_recipe = default_recipe(tmp_path) @@ -366,10 +375,12 @@ def test_azure_ad_stateful_ingestion( # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline1, expected_providers=1 + pipeline=pipeline1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline2, expected_providers=1 + pipeline=pipeline2, + expected_providers=1, ) # Perform all assertions on the states. 
The deleted Dashboard should not be @@ -378,7 +389,7 @@ def test_azure_ad_stateful_ingestion( state2 = checkpoint2.state difference_dashboard_urns = list( - state1.get_urns_not_in(type="corpGroup", other_checkpoint_state=state2) + state1.get_urns_not_in(type="corpGroup", other_checkpoint_state=state2), ) assert len(difference_dashboard_urns) == 1 diff --git a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py index 2dd320041a1132..9c22b25c6ee83e 100644 --- a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py +++ b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py @@ -48,7 +48,7 @@ def random_email(): [ random.choice(string.ascii_lowercase) for i in range(random.randint(10, 15)) - ] + ], ) + "@xyz.com" ) @@ -80,7 +80,7 @@ def recipe(mcp_output_path: str, source_config_override: dict = {}) -> dict: config=DataHubClassifierConfig( minimum_values_threshold=1, ), - ) + ), ], max_workers=1, ).dict(), @@ -145,12 +145,14 @@ def side_effect(*args: Any) -> Optional[PlatformResource]: get_datasets_for_project_id.return_value = [ # BigqueryDataset(name=dataset_name, location="US") BigqueryDataset( - name=dataset_name, location="US", labels={"priority": "medium:test"} - ) + name=dataset_name, + location="US", + labels={"priority": "medium:test"}, + ), ] table_list_item = TableListItem( - {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}} + {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}}, ) table_name = "table-1" snapshot_table_name = "snapshot-table-1" @@ -271,15 +273,15 @@ def test_bigquery_v2_project_labels_ingest( mcp_output_path = "{}/{}".format(tmp_path, "bigquery_project_label_mcp_output.json") get_datasets_for_project_id.return_value = [ - BigqueryDataset(name="bigquery-dataset-1") + BigqueryDataset(name="bigquery-dataset-1"), ] get_projects_with_labels.return_value = [ - BigqueryProject(id="dev", name="development") + BigqueryProject(id="dev", name="development"), ] table_list_item = TableListItem( - {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}} + {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}}, ) table_name = "table-1" get_core_table_details.return_value = {table_name: table_list_item} @@ -306,7 +308,7 @@ def test_bigquery_v2_project_labels_ingest( is_partition_column=False, cluster_column_position=None, ), - ] + ], } get_sample_data_for_table.return_value = { "age": [random.randint(1, 80) for i in range(20)], @@ -328,7 +330,7 @@ def test_bigquery_v2_project_labels_ingest( del pipeline_config_dict["source"]["config"]["project_ids"] pipeline_config_dict["source"]["config"]["project_labels"] = [ - "environment:development" + "environment:development", ] run_and_get_pipeline(pipeline_config_dict) @@ -371,11 +373,11 @@ def test_bigquery_queries_v2_ingest( dataset_name = "bigquery-dataset-1" get_datasets_for_project_id.return_value = [ - BigqueryDataset(name=dataset_name, location="US") + BigqueryDataset(name=dataset_name, location="US"), ] table_list_item = TableListItem( - {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}} + {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}}, ) table_name = "table-1" snapshot_table_name = "snapshot-table-1" @@ -492,10 +494,16 @@ def test_bigquery_queries_v2_lineage_usage_ingest( get_bigquery_client.return_value = client client.list_tables.return_value = [ TableListItem( - {"tableReference": {"projectId": "", "datasetId": "", 
"tableId": "table-1"}} + { + "tableReference": { + "projectId": "", + "datasetId": "", + "tableId": "table-1", + }, + }, ), TableListItem( - {"tableReference": {"projectId": "", "datasetId": "", "tableId": "view-1"}} + {"tableReference": {"projectId": "", "datasetId": "", "tableId": "view-1"}}, ), ] @@ -624,11 +632,11 @@ def test_bigquery_lineage_v2_ingest_view_snapshots( dataset_name = "bigquery-dataset-1" get_datasets_for_project_id.return_value = [ - BigqueryDataset(name=dataset_name, location="US") + BigqueryDataset(name=dataset_name, location="US"), ] table_list_item = TableListItem( - {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}} + {"tableReference": {"projectId": "", "datasetId": "", "tableId": ""}}, ) table_name = "table-1" snapshot_table_name = "snapshot-table-1" diff --git a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py index 806779475dea9d..8bb1570d5657f1 100644 --- a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py +++ b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py @@ -102,7 +102,8 @@ def test_queries_ingestion(project_client, client, pytestconfig, monkeypatch, tm def test_source_close_cleans_tmp(projects_client, client, tmp_path): with patch("tempfile.tempdir", str(tmp_path)): source = BigQueryQueriesSource.create( - {"project_ids": ["project1"]}, PipelineContext("run-id") + {"project_ids": ["project1"]}, + PipelineContext("run-id"), ) assert len(os.listdir(tmp_path)) > 0 # This closes QueriesExtractor which in turn closes SqlParsingAggregator diff --git a/metadata-ingestion/tests/integration/business-glossary/test_business_glossary.py b/metadata-ingestion/tests/integration/business-glossary/test_business_glossary.py index 74cf9aa3b528f2..644704344ea879 100644 --- a/metadata-ingestion/tests/integration/business-glossary/test_business_glossary.py +++ b/metadata-ingestion/tests/integration/business-glossary/test_business_glossary.py @@ -11,7 +11,9 @@ def get_default_recipe( - glossary_yml_file_path: str, event_output_file_path: str, enable_auto_id: bool + glossary_yml_file_path: str, + event_output_file_path: str, + enable_auto_id: bool, ) -> Dict[str, Any]: return { "source": { @@ -57,7 +59,7 @@ def test_glossary_ingest( glossary_yml_file_path=f"{test_resources_dir}/business_glossary.yml", event_output_file_path=output_mces_path, enable_auto_id=enable_auto_id, - ) + ), ) pipeline.ctx.graph = mock_datahub_graph_instance pipeline.run() @@ -89,7 +91,7 @@ def test_single_owner_types( glossary_yml_file_path=f"{test_resources_dir}/single_owner_types.yml", event_output_file_path=output_mces_path, enable_auto_id=False, - ) + ), ) pipeline.ctx.graph = mock_datahub_graph_instance pipeline.run() @@ -122,7 +124,7 @@ def test_multiple_owners_same_type( glossary_yml_file_path=f"{test_resources_dir}/multiple_owners_same_type.yml", event_output_file_path=output_mces_path, enable_auto_id=False, - ) + ), ) pipeline.ctx.graph = mock_datahub_graph_instance pipeline.run() @@ -155,7 +157,7 @@ def test_multiple_owners_different_types( glossary_yml_file_path=f"{test_resources_dir}/multiple_owners_different_types.yml", event_output_file_path=output_mces_path, enable_auto_id=False, - ) + ), ) pipeline.ctx.graph = mock_datahub_graph_instance pipeline.run() @@ -186,7 +188,7 @@ def test_custom_ownership_urns( glossary_yml_file_path=f"{test_resources_dir}/custom_ownership_urns.yml", event_output_file_path=output_mces_path, 
enable_auto_id=False, - ) + ), ) pipeline.ctx.graph = mock_datahub_graph_instance pipeline.run() diff --git a/metadata-ingestion/tests/integration/cassandra/test_cassandra.py b/metadata-ingestion/tests/integration/cassandra/test_cassandra.py index d561308aaad20e..86d8edfe0c133e 100644 --- a/metadata-ingestion/tests/integration/cassandra/test_cassandra.py +++ b/metadata-ingestion/tests/integration/cassandra/test_cassandra.py @@ -15,7 +15,8 @@ def test_cassandra_ingest(docker_compose_runner, pytestconfig, tmp_path): test_resources_dir = pytestconfig.rootpath / "tests/integration/cassandra" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "cassandra" + test_resources_dir / "docker-compose.yml", + "cassandra", ) as docker_services: wait_for_port(docker_services, "test-cassandra", 9042) @@ -39,7 +40,7 @@ def test_cassandra_ingest(docker_compose_runner, pytestconfig, tmp_path): "filename": f"{tmp_path}/cassandra_mcps.json", }, }, - } + }, ) pipeline_default_platform_instance.run() pipeline_default_platform_instance.raise_from_status() diff --git a/metadata-ingestion/tests/integration/circuit_breaker/test_circuit_breaker.py b/metadata-ingestion/tests/integration/circuit_breaker/test_circuit_breaker.py index b9c661935c5e06..1816e888fff5a1 100644 --- a/metadata-ingestion/tests/integration/circuit_breaker/test_circuit_breaker.py +++ b/metadata-ingestion/tests/integration/circuit_breaker/test_circuit_breaker.py @@ -15,11 +15,11 @@ except ImportError: pass lastUpdatedResponseBeforeLastAssertion = { - "dataset": {"operations": [{"lastUpdatedTimestamp": 1640685600000}]} + "dataset": {"operations": [{"lastUpdatedTimestamp": 1640685600000}]}, } lastUpdatedResponseAfterLastAssertion = { - "dataset": {"operations": [{"lastUpdatedTimestamp": 1652450039000}]} + "dataset": {"operations": [{"lastUpdatedTimestamp": 1652450039000}]}, } @@ -37,7 +37,7 @@ def test_operation_circuit_breaker_with_empty_response(pytestconfig): cb = OperationCircuitBreaker(config) result = cb.is_circuit_breaker_active( - urn="urn:li:dataset:(urn:li:dataPlatform:hive,SampleHiveDataset,PROD))" + urn="urn:li:dataset:(urn:li:dataPlatform:hive,SampleHiveDataset,PROD))", ) assert result is True @@ -57,7 +57,7 @@ def test_operation_circuit_breaker_with_valid_response(pytestconfig): cb = OperationCircuitBreaker(config) result = cb.is_circuit_breaker_active( - urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,my_project.jaffle_shop.customers,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,my_project.jaffle_shop.customers,PROD)", ) assert result is False @@ -77,7 +77,7 @@ def test_operation_circuit_breaker_with_not_recent_operation(pytestconfig): cb = OperationCircuitBreaker(config) result = cb.is_circuit_breaker_active( - urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,my_project.jaffle_shop.customers,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,my_project.jaffle_shop.customers,PROD)", ) assert result is True @@ -96,7 +96,7 @@ def test_assertion_circuit_breaker_with_empty_response(pytestconfig): cb = AssertionCircuitBreaker(config) result = cb.is_circuit_breaker_active( - urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)", ) assert result is True @@ -115,7 +115,7 @@ def test_assertion_circuit_breaker_with_no_error(pytestconfig): cb = AssertionCircuitBreaker(config) result = cb.is_circuit_breaker_active( - 
urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)", ) assert result is False @@ -133,7 +133,7 @@ def test_assertion_circuit_breaker_updated_at_after_last_assertion(pytestconfig) config = AssertionCircuitBreakerConfig(datahub_host="dummy") cb = AssertionCircuitBreaker(config) result = cb.is_circuit_breaker_active( - urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)", ) assert result is True @@ -150,6 +150,6 @@ def test_assertion_circuit_breaker_assertion_with_active_assertion(pytestconfig) config = AssertionCircuitBreakerConfig(datahub_host="dummy") cb = AssertionCircuitBreaker(config) result = cb.is_circuit_breaker_active( - urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:postgres,postgres1.postgres.public.foo1,PROD)", ) assert result is True # add assertion here diff --git a/metadata-ingestion/tests/integration/clickhouse/test_clickhouse.py b/metadata-ingestion/tests/integration/clickhouse/test_clickhouse.py index 2c065b5b357b52..38cc737ffa4f39 100644 --- a/metadata-ingestion/tests/integration/clickhouse/test_clickhouse.py +++ b/metadata-ingestion/tests/integration/clickhouse/test_clickhouse.py @@ -15,7 +15,8 @@ def test_clickhouse_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): test_resources_dir = pytestconfig.rootpath / "tests/integration/clickhouse" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "clickhouse" + test_resources_dir / "docker-compose.yml", + "clickhouse", ) as docker_services: wait_for_port(docker_services, "testclickhouse", 8123, timeout=120) # Run the metadata ingestion pipeline. 
@@ -42,11 +43,15 @@ def test_clickhouse_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_t @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_clickhouse_ingest_uri_form( - docker_compose_runner, pytestconfig, tmp_path, mock_time + docker_compose_runner, + pytestconfig, + tmp_path, + mock_time, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/clickhouse" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "clickhouse" + test_resources_dir / "docker-compose.yml", + "clickhouse", ) as docker_services: wait_for_port(docker_services, "testclickhouse", 8123, timeout=120) diff --git a/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py b/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py index 4a447037d1dda2..e441bf8658cb19 100644 --- a/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py +++ b/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py @@ -16,7 +16,7 @@ def test_csv_enricher_config(): write_semantics="OVERRIDE", delimiter=",", array_delimiter="|", - ) + ), ) assert config @@ -45,7 +45,7 @@ def test_csv_enricher_source(pytestconfig, tmp_path): "filename": f"{tmp_path}/csv_enricher.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/dbt/test_dbt.py b/metadata-ingestion/tests/integration/dbt/test_dbt.py index c6a3dc4fd590bd..0cc16eafbe771c 100644 --- a/metadata-ingestion/tests/integration/dbt/test_dbt.py +++ b/metadata-ingestion/tests/integration/dbt/test_dbt.py @@ -104,7 +104,7 @@ def set_paths( "match": ".*", "operation": "add_tag", "config": {"tag": "{{ $match }}"}, - } + }, }, }, **self.source_config_modifiers, @@ -138,7 +138,7 @@ def set_paths( manifest_file="dbt_manifest_complex_owner_patterns.json", source_config_modifiers={ "node_name_pattern": { - "deny": ["source.sample_dbt.pagila.payment_p2020_06"] + "deny": ["source.sample_dbt.pagila.payment_p2020_06"], }, "owner_extraction_pattern": "(.*)(?P(?<=\\().*?(?=\\)))", "strip_user_ids_from_email": True, @@ -263,7 +263,7 @@ def test_dbt_ingest( "type": "file", "config": config.sink_config, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -299,17 +299,18 @@ def test_dbt_ingest( @freeze_time(FROZEN_TIME) def test_dbt_test_connection(test_resources_dir, config_dict, is_success): config_dict["manifest_path"] = str( - (test_resources_dir / config_dict["manifest_path"]).resolve() + (test_resources_dir / config_dict["manifest_path"]).resolve(), ) config_dict["catalog_path"] = str( - (test_resources_dir / config_dict["catalog_path"]).resolve() + (test_resources_dir / config_dict["catalog_path"]).resolve(), ) report = test_connection_helpers.run_test_connection(DBTCoreSource, config_dict) if is_success: test_connection_helpers.assert_basic_connectivity_success(report) else: test_connection_helpers.assert_basic_connectivity_failure( - report, "No such file or directory" + report, + "No such file or directory", ) @@ -327,23 +328,23 @@ def test_dbt_tests(test_resources_dir, pytestconfig, tmp_path, mock_time, **kwar config=DBTCoreConfig( **_default_dbt_source_args, manifest_path=str( - (test_resources_dir / "jaffle_shop_manifest.json").resolve() + (test_resources_dir / "jaffle_shop_manifest.json").resolve(), ), catalog_path=str( - (test_resources_dir / "jaffle_shop_catalog.json").resolve() + (test_resources_dir / "jaffle_shop_catalog.json").resolve(), ), target_platform="postgres", run_results_paths=[ str( ( test_resources_dir / 
"jaffle_shop_test_results.json" - ).resolve() - ) + ).resolve(), + ), ], ), ), sink=DynamicTypedConfig(type="file", config={"filename": str(output_file)}), - ) + ), ) pipeline.run() pipeline.raise_from_status() @@ -359,7 +360,11 @@ def test_dbt_tests(test_resources_dir, pytestconfig, tmp_path, mock_time, **kwar @pytest.mark.integration @freeze_time(FROZEN_TIME) def test_dbt_tests_only_assertions( - test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, + **kwargs, ): # Run the metadata ingestion pipeline. output_file = tmp_path / "test_only_assertions.json" @@ -371,26 +376,26 @@ def test_dbt_tests_only_assertions( config=DBTCoreConfig( **_default_dbt_source_args, manifest_path=str( - (test_resources_dir / "jaffle_shop_manifest.json").resolve() + (test_resources_dir / "jaffle_shop_manifest.json").resolve(), ), catalog_path=str( - (test_resources_dir / "jaffle_shop_catalog.json").resolve() + (test_resources_dir / "jaffle_shop_catalog.json").resolve(), ), target_platform="postgres", run_results_paths=[ str( ( test_resources_dir / "jaffle_shop_test_results.json" - ).resolve() - ) + ).resolve(), + ), ], entities_enabled=DBTEntitiesEnabled( - test_results=EmitDirective.ONLY + test_results=EmitDirective.ONLY, ), ), ), sink=DynamicTypedConfig(type="file", config={"filename": str(output_file)}), - ) + ), ) pipeline.run() pipeline.raise_from_status() @@ -407,14 +412,17 @@ def test_dbt_tests_only_assertions( number_of_valid_assertions_in_test_results = 24 assert ( mce_helpers.assert_entity_urn_like( - entity_type="assertion", regex_pattern="urn:li:assertion:", file=output_file + entity_type="assertion", + regex_pattern="urn:li:assertion:", + file=output_file, ) == number_of_valid_assertions_in_test_results ) # no assertionInfo should be emitted with pytest.raises( - AssertionError, match="Failed to find aspect_name assertionInfo for urns" + AssertionError, + match="Failed to find aspect_name assertionInfo for urns", ): mce_helpers.assert_for_each_entity( entity_type="assertion", @@ -439,7 +447,11 @@ def test_dbt_tests_only_assertions( @pytest.mark.integration @freeze_time(FROZEN_TIME) def test_dbt_only_test_definitions_and_results( - test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, + **kwargs, ): # Run the metadata ingestion pipeline. 
output_file = tmp_path / "test_only_definitions_and_assertions.json" @@ -451,18 +463,18 @@ def test_dbt_only_test_definitions_and_results( config=DBTCoreConfig( **_default_dbt_source_args, manifest_path=str( - (test_resources_dir / "jaffle_shop_manifest.json").resolve() + (test_resources_dir / "jaffle_shop_manifest.json").resolve(), ), catalog_path=str( - (test_resources_dir / "jaffle_shop_catalog.json").resolve() + (test_resources_dir / "jaffle_shop_catalog.json").resolve(), ), target_platform="postgres", run_results_paths=[ str( ( test_resources_dir / "jaffle_shop_test_results.json" - ).resolve() - ) + ).resolve(), + ), ], entities_enabled=DBTEntitiesEnabled( sources=EmitDirective.NO, @@ -472,7 +484,7 @@ def test_dbt_only_test_definitions_and_results( ), ), sink=DynamicTypedConfig(type="file", config={"filename": str(output_file)}), - ) + ), ) pipeline.run() pipeline.raise_from_status() @@ -488,7 +500,9 @@ def test_dbt_only_test_definitions_and_results( number_of_assertions = 25 assert ( mce_helpers.assert_entity_urn_like( - entity_type="assertion", regex_pattern="urn:li:assertion:", file=output_file + entity_type="assertion", + regex_pattern="urn:li:assertion:", + file=output_file, ) == number_of_assertions ) diff --git a/metadata-ingestion/tests/integration/delta_lake/test_delta_lake_minio.py b/metadata-ingestion/tests/integration/delta_lake/test_delta_lake_minio.py index 6146c6d1a948ca..2f03e0c3fb0939 100644 --- a/metadata-ingestion/tests/integration/delta_lake/test_delta_lake_minio.py +++ b/metadata-ingestion/tests/integration/delta_lake/test_delta_lake_minio.py @@ -35,7 +35,8 @@ def test_resources_dir(pytestconfig): def minio_runner(docker_compose_runner, pytestconfig, test_resources_dir): container_name = "minio_test" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", container_name + test_resources_dir / "docker-compose.yml", + container_name, ) as docker_services: wait_for_port( docker_services, @@ -101,7 +102,7 @@ def test_delta_lake_ingest(pytestconfig, tmp_path, test_resources_dir): "filename": f"{tmp_path}/delta_lake_minio_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/dremio/test_dremio.py b/metadata-ingestion/tests/integration/dremio/test_dremio.py index c286746c68b79d..67b6d1c963c9dc 100644 --- a/metadata-ingestion/tests/integration/dremio/test_dremio.py +++ b/metadata-ingestion/tests/integration/dremio/test_dremio.py @@ -411,7 +411,8 @@ def test_resources_dir(pytestconfig): def mock_dremio_service(docker_compose_runner, pytestconfig, test_resources_dir): # Spin up Dremio and MinIO (for mock S3) services using Docker Compose. 
with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "dremio" + test_resources_dir / "docker-compose.yml", + "dremio", ) as docker_services: wait_for_port(docker_services, "dremio", 9047, timeout=120) wait_for_port( @@ -431,7 +432,8 @@ def mock_dremio_service(docker_compose_runner, pytestconfig, test_resources_dir) # Ensure the admin and data setup scripts have the right permissions subprocess.run( - ["chmod", "+x", f"{test_resources_dir}/setup_dremio_admin.sh"], check=True + ["chmod", "+x", f"{test_resources_dir}/setup_dremio_admin.sh"], + check=True, ) # Run the setup_dremio_admin.sh script diff --git a/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py b/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py index 4edbbbb3ffc64f..16cfbdf484113b 100644 --- a/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py +++ b/metadata-ingestion/tests/integration/dynamodb/test_dynamodb.py @@ -46,16 +46,16 @@ def test_dynamodb(pytestconfig, tmp_path): "L": [ {"S": "+14150000000"}, {"S": "+14151111111"}, - ] + ], }, "services": { # Map type "M": { "parking": {"BOOL": True}, "wifi": {"S": "Free"}, "hours": { # Map type inside Map for nested structure - "M": {"open": {"S": "08:00"}, "close": {"S": "22:00"}} + "M": {"open": {"S": "08:00"}, "close": {"S": "22:00"}}, }, - } + }, }, }, ) @@ -78,7 +78,7 @@ def test_dynamodb(pytestconfig, tmp_path): "filename": f"{tmp_path}/dynamodb_default_platform_instance_mces.json", }, }, - } + }, ) pipeline_default_platform_instance.run() pipeline_default_platform_instance.raise_from_status() @@ -115,11 +115,11 @@ def test_dynamodb(pytestconfig, tmp_path): description=0, datatype=0, values=0.3, - ) - ) + ), + ), }, ), - ) + ), ], ), }, @@ -130,7 +130,7 @@ def test_dynamodb(pytestconfig, tmp_path): "filename": f"{tmp_path}/dynamodb_platform_instance_mces.json", }, }, - } + }, ) pipeline_with_platform_instance.run() pipeline_with_platform_instance.raise_from_status() diff --git a/metadata-ingestion/tests/integration/feast/test_feast_repository.py b/metadata-ingestion/tests/integration/feast/test_feast_repository.py index 80d7c6311a9589..3d40ac2b9efc03 100644 --- a/metadata-ingestion/tests/integration/feast/test_feast_repository.py +++ b/metadata-ingestion/tests/integration/feast/test_feast_repository.py @@ -34,7 +34,7 @@ def test_feast_repository_ingest(pytestconfig, tmp_path, mock_time): "feast_owner_name": "MOCK_OWNER", "datahub_owner_urn": "urn:li:corpGroup:MOCK_OWNER", "datahub_ownership_type": "BUSINESS_OWNER", - } + }, ], }, }, @@ -44,7 +44,7 @@ def test_feast_repository_ingest(pytestconfig, tmp_path, mock_time): "filename": str(output_path), }, }, - } + }, ) pipeline.run() diff --git a/metadata-ingestion/tests/integration/file/test_file_source.py b/metadata-ingestion/tests/integration/file/test_file_source.py index fb2cfd9fc9f0f4..739400d383aa96 100644 --- a/metadata-ingestion/tests/integration/file/test_file_source.py +++ b/metadata-ingestion/tests/integration/file/test_file_source.py @@ -36,7 +36,7 @@ def test_stateful_ingestion(tmp_path, pytestconfig): }, } with mock.patch( - "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj" + "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj", ) as mock_state: mock_state.return_value = GenericCheckpointState(serde="utf-8") pipeline = Pipeline.create(pipeline_config) diff --git a/metadata-ingestion/tests/integration/fivetran/test_fivetran.py 
b/metadata-ingestion/tests/integration/fivetran/test_fivetran.py index 2e6c2b1370d166..e4e300b7a35218 100644 --- a/metadata-ingestion/tests/integration/fivetran/test_fivetran.py +++ b/metadata-ingestion/tests/integration/fivetran/test_fivetran.py @@ -36,7 +36,8 @@ def default_query_results( - query, connector_query_results=default_connector_query_results + query, + connector_query_results=default_connector_query_results, ): fivetran_log_query = FivetranLogQuery() fivetran_log_query.set_db("test") @@ -45,7 +46,7 @@ def default_query_results( elif query == fivetran_log_query.get_connectors_query(): return connector_query_results elif query == fivetran_log_query.get_table_lineage_query( - connector_ids=["calendar_elected"] + connector_ids=["calendar_elected"], ): return [ { @@ -68,7 +69,7 @@ def default_query_results( }, ] elif query == fivetran_log_query.get_column_lineage_query( - connector_ids=["calendar_elected"] + connector_ids=["calendar_elected"], ): return [ { @@ -103,7 +104,7 @@ def default_query_results( "given_name": "Shubham", "family_name": "Jagtap", "email": "abc.xyz@email.com", - } + }, ] elif query == fivetran_log_query.get_sync_logs_query( syncs_interval=7, @@ -146,7 +147,7 @@ def test_fivetran_with_snowflake_dest(pytestconfig, tmp_path): golden_file = test_resources_dir / "fivetran_snowflake_golden.json" with mock.patch( - "datahub.ingestion.source.fivetran.fivetran_log_api.create_engine" + "datahub.ingestion.source.fivetran.fivetran_log_api.create_engine", ) as mock_create_engine: connection_magic_mock = MagicMock() connection_magic_mock.execute.side_effect = default_query_results @@ -174,18 +175,18 @@ def test_fivetran_with_snowflake_dest(pytestconfig, tmp_path): "connector_patterns": { "allow": [ "postgres", - ] + ], }, "destination_patterns": { "allow": [ "interval_unconstitutional", - ] + ], }, "sources_to_platform_instance": { "calendar_elected": { "database": "postgres_db", "env": "DEV", - } + }, }, }, }, @@ -195,7 +196,7 @@ def test_fivetran_with_snowflake_dest(pytestconfig, tmp_path): "filename": f"{output_file}", }, }, - } + }, ) pipeline.run() @@ -220,7 +221,7 @@ def test_fivetran_with_snowflake_dest_and_null_connector_user(pytestconfig, tmp_ ) with mock.patch( - "datahub.ingestion.source.fivetran.fivetran_log_api.create_engine" + "datahub.ingestion.source.fivetran.fivetran_log_api.create_engine", ) as mock_create_engine: connection_magic_mock = MagicMock() @@ -237,7 +238,8 @@ def test_fivetran_with_snowflake_dest_and_null_connector_user(pytestconfig, tmp_ ] connection_magic_mock.execute.side_effect = partial( - default_query_results, connector_query_results=connector_query_results + default_query_results, + connector_query_results=connector_query_results, ) mock_create_engine.return_value = connection_magic_mock @@ -263,19 +265,19 @@ def test_fivetran_with_snowflake_dest_and_null_connector_user(pytestconfig, tmp_ "connector_patterns": { "allow": [ "postgres", - ] + ], }, "destination_patterns": { "allow": [ "interval_unconstitutional", - ] + ], }, "sources_to_platform_instance": { "calendar_elected": { "platform": "postgres", "env": "DEV", "database": "postgres_db", - } + }, }, }, }, @@ -285,7 +287,7 @@ def test_fivetran_with_snowflake_dest_and_null_connector_user(pytestconfig, tmp_ "filename": f"{output_file}", }, }, - } + }, ) pipeline.run() diff --git a/metadata-ingestion/tests/integration/git/test_git_clone.py b/metadata-ingestion/tests/integration/git/test_git_clone.py index 01e075930998a4..6d82b0323ed7b2 100644 --- 
a/metadata-ingestion/tests/integration/git/test_git_clone.py +++ b/metadata-ingestion/tests/integration/git/test_git_clone.py @@ -28,7 +28,8 @@ def test_base_url_guessing() -> None: # GitLab repo (notice the trailing slash). config_ref = GitReference( - repo="https://gitlab.com/gitlab-tests/sample-project/", branch="master" + repo="https://gitlab.com/gitlab-tests/sample-project/", + branch="master", ) assert ( config_ref.get_url_for_file_path("hello_world.md") @@ -37,7 +38,8 @@ def test_base_url_guessing() -> None: # Three-tier GitLab repo. config = GitInfo( - repo="https://gitlab.com/gitlab-com/gl-infra/reliability", branch="master" + repo="https://gitlab.com/gitlab-com/gl-infra/reliability", + branch="master", ) assert ( config.get_url_for_file_path("onboarding/gitlab.nix") @@ -67,7 +69,7 @@ def test_base_url_guessing() -> None: repo="https://github.com/datahub-project/datahub", branch="master", base_url="http://mygithubmirror.local", - ) + ), ) diff --git a/metadata-ingestion/tests/integration/grafana/test_grafana.py b/metadata-ingestion/tests/integration/grafana/test_grafana.py index cbac965884365d..edd193dc35562b 100644 --- a/metadata-ingestion/tests/integration/grafana/test_grafana.py +++ b/metadata-ingestion/tests/integration/grafana/test_grafana.py @@ -74,7 +74,8 @@ def test_api_key(): # Step 1: Create the service account service_account = grafana_client.create_service_account( - name="example-service-account", role="Viewer" + name="example-service-account", + role="Viewer", ) if service_account: print(f"Service Account Created: {service_account}") @@ -97,7 +98,8 @@ def test_api_key(): @pytest.fixture(scope="module") def loaded_grafana(docker_compose_runner, test_resources_dir): with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "grafana" + test_resources_dir / "docker-compose.yml", + "grafana", ) as docker_services: wait_for_port( docker_services, @@ -144,7 +146,11 @@ def test_grafana_dashboard(loaded_grafana, pytestconfig, tmp_path, test_resource @freeze_time(FROZEN_TIME) def test_grafana_ingest( - loaded_grafana, pytestconfig, tmp_path, test_resources_dir, test_api_key + loaded_grafana, + pytestconfig, + tmp_path, + test_resources_dir, + test_api_key, ): # Wait for Grafana to be up and running url = "http://localhost:3000/api/health" @@ -175,7 +181,7 @@ def test_grafana_ingest( "type": "file", "config": {"filename": "./grafana_mcps.json"}, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/hana/test_hana.py b/metadata-ingestion/tests/integration/hana/test_hana.py index 726f8744167dbc..fe05b4b8cfb808 100644 --- a/metadata-ingestion/tests/integration/hana/test_hana.py +++ b/metadata-ingestion/tests/integration/hana/test_hana.py @@ -21,7 +21,8 @@ def test_hana_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): test_resources_dir = pytestconfig.rootpath / "tests/integration/hana" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "hana" + test_resources_dir / "docker-compose.yml", + "hana", ) as docker_services: # added longer timeout and pause due to slow start of hana wait_for_port( @@ -36,7 +37,8 @@ def test_hana_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): # Run the metadata ingestion pipeline. 
config_file = (test_resources_dir / "hana_to_file.yml").resolve() run_datahub_cmd( - ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path + ["ingest", "--strict-warnings", "-c", f"{config_file}"], + tmp_path=tmp_path, ) # Verify the output. diff --git a/metadata-ingestion/tests/integration/hive-metastore/test_hive_metastore.py b/metadata-ingestion/tests/integration/hive-metastore/test_hive_metastore.py index dbc1d0706c4b6b..4d497b110ec16a 100644 --- a/metadata-ingestion/tests/integration/hive-metastore/test_hive_metastore.py +++ b/metadata-ingestion/tests/integration/hive-metastore/test_hive_metastore.py @@ -20,7 +20,8 @@ def hive_metastore_runner(docker_compose_runner, pytestconfig): test_resources_dir = pytestconfig.rootpath / "tests/integration/hive-metastore" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "hive-metastore" + test_resources_dir / "docker-compose.yml", + "hive-metastore", ) as docker_services: wait_for_port(docker_services, "presto", 8080) wait_for_port(docker_services, "hiveserver2", 10000, timeout=120) @@ -147,7 +148,11 @@ def test_hive_metastore_ingest( @freeze_time(FROZEN_TIME) def test_hive_metastore_instance_ingest( - loaded_hive_metastore, test_resources_dir, pytestconfig, tmp_path, mock_time + loaded_hive_metastore, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, ): instance = "production_warehouse" platform = "hive" @@ -185,7 +190,7 @@ def test_hive_metastore_instance_ingest( # Assert that all events generated have instance specific urns urn_pattern = "^" + re.escape( - f"urn:li:dataset:(urn:li:dataPlatform:{platform},{instance}." + f"urn:li:dataset:(urn:li:dataPlatform:{platform},{instance}.", ) assert ( mce_helpers.assert_mce_entity_urn( @@ -214,7 +219,7 @@ def test_hive_metastore_instance_ingest( entity_type="dataset", aspect_name="dataPlatformInstance", aspect_field_matcher={ - "instance": f"urn:li:dataPlatformInstance:(urn:li:dataPlatform:{platform},{instance})" + "instance": f"urn:li:dataPlatformInstance:(urn:li:dataPlatform:{platform},{instance})", }, file=events_file, ) diff --git a/metadata-ingestion/tests/integration/hive/test_hive.py b/metadata-ingestion/tests/integration/hive/test_hive.py index caffb761380ddc..1bf285a900168b 100644 --- a/metadata-ingestion/tests/integration/hive/test_hive.py +++ b/metadata-ingestion/tests/integration/hive/test_hive.py @@ -19,7 +19,8 @@ def hive_runner(docker_compose_runner, pytestconfig): test_resources_dir = pytestconfig.rootpath / "tests/integration/hive" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "hive" + test_resources_dir / "docker-compose.yml", + "hive", ) as docker_services: wait_for_port(docker_services, "testhiveserver2", 10000, timeout=120) yield docker_services @@ -57,7 +58,11 @@ def base_pipeline_config(events_file, db=None): @freeze_time(FROZEN_TIME) def test_hive_ingest( - loaded_hive, pytestconfig, test_resources_dir, tmp_path, mock_time + loaded_hive, + pytestconfig, + test_resources_dir, + tmp_path, + mock_time, ): mce_out_file = "test_hive_ingest.json" events_file = tmp_path / mce_out_file @@ -85,7 +90,11 @@ def test_hive_ingest( @freeze_time(FROZEN_TIME) @pytest.mark.integration_batch_1 def test_hive_ingest_all_db( - loaded_hive, pytestconfig, test_resources_dir, tmp_path, mock_time + loaded_hive, + pytestconfig, + test_resources_dir, + tmp_path, + mock_time, ): mce_out_file = "test_hive_ingest.json" events_file = tmp_path / mce_out_file @@ -128,7 +137,7 @@ def test_hive_instance_check(loaded_hive, 
test_resources_dir, tmp_path, pytestco # Assert that all events generated have instance specific urns urn_pattern = "^" + re.escape( - f"urn:li:dataset:(urn:li:dataPlatform:{data_platform},{instance}." + f"urn:li:dataset:(urn:li:dataPlatform:{data_platform},{instance}.", ) mce_helpers.assert_mce_entity_urn( "ALL", @@ -151,7 +160,7 @@ def test_hive_instance_check(loaded_hive, test_resources_dir, tmp_path, pytestco entity_type="dataset", aspect_name="dataPlatformInstance", aspect_field_matcher={ - "instance": f"urn:li:dataPlatformInstance:(urn:li:dataPlatform:{data_platform},{instance})" + "instance": f"urn:li:dataPlatformInstance:(urn:li:dataPlatform:{data_platform},{instance})", }, file=events_file, ) diff --git a/metadata-ingestion/tests/integration/iceberg/setup/create.py b/metadata-ingestion/tests/integration/iceberg/setup/create.py index 0799ce9c93916c..0df209ef998a4a 100644 --- a/metadata-ingestion/tests/integration/iceberg/setup/create.py +++ b/metadata-ingestion/tests/integration/iceberg/setup/create.py @@ -24,7 +24,7 @@ def main(table_name: str) -> None: StructField("trip_distance", FloatType(), True), StructField("fare_amount", DoubleType(), True), StructField("store_and_fwd_flag", StringType(), True), - ] + ], ) data = [ diff --git a/metadata-ingestion/tests/integration/iceberg/test_iceberg.py b/metadata-ingestion/tests/integration/iceberg/test_iceberg.py index 85809e557dd8d3..d6472bfb26ac9f 100644 --- a/metadata-ingestion/tests/integration/iceberg/test_iceberg.py +++ b/metadata-ingestion/tests/integration/iceberg/test_iceberg.py @@ -44,12 +44,16 @@ def spark_submit(file_path: str, args: str = "") -> None: @freeze_time(FROZEN_TIME) def test_multiprocessing_iceberg_ingest( - docker_compose_runner, pytestconfig, tmp_path, mock_time + docker_compose_runner, + pytestconfig, + tmp_path, + mock_time, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/iceberg/" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "iceberg" + test_resources_dir / "docker-compose.yml", + "iceberg", ) as docker_services: wait_for_port(docker_services, "spark-iceberg", 8888, timeout=120) @@ -61,7 +65,8 @@ def test_multiprocessing_iceberg_ingest( test_resources_dir / "iceberg_multiprocessing_to_file.yml" ).resolve() run_datahub_cmd( - ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path + ["ingest", "--strict-warnings", "-c", f"{config_file}"], + tmp_path=tmp_path, ) # Verify the output. mce_helpers.check_golden_file( @@ -77,7 +82,8 @@ def test_iceberg_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time test_resources_dir = pytestconfig.rootpath / "tests/integration/iceberg/" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "iceberg" + test_resources_dir / "docker-compose.yml", + "iceberg", ) as docker_services: wait_for_port(docker_services, "spark-iceberg", 8888, timeout=120) @@ -87,7 +93,8 @@ def test_iceberg_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time # Run the metadata ingestion pipeline. config_file = (test_resources_dir / "iceberg_to_file.yml").resolve() run_datahub_cmd( - ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path + ["ingest", "--strict-warnings", "-c", f"{config_file}"], + tmp_path=tmp_path, ) # Verify the output. 
mce_helpers.check_golden_file( @@ -100,7 +107,11 @@ def test_iceberg_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time @freeze_time(FROZEN_TIME) def test_iceberg_stateful_ingest( - docker_compose_runner, pytestconfig, tmp_path, mock_time, mock_datahub_graph + docker_compose_runner, + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/iceberg" platform_instance = "test_platform_instance" @@ -137,13 +148,14 @@ def test_iceberg_stateful_ingest( }, "sink": { # we are not really interested in the resulting events for this test - "type": "console" + "type": "console", }, "pipeline_name": "test_pipeline", } with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "iceberg" + test_resources_dir / "docker-compose.yml", + "iceberg", ) as docker_services, patch( "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph", mock_datahub_graph, @@ -184,7 +196,7 @@ def test_iceberg_stateful_ingest( state1 = checkpoint1.state state2 = checkpoint2.state difference_urns = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) assert len(difference_urns) == 1 @@ -195,10 +207,12 @@ def test_iceberg_stateful_ingest( # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Verify the output. @@ -215,7 +229,8 @@ def test_iceberg_profiling(docker_compose_runner, pytestconfig, tmp_path, mock_t test_resources_dir = pytestconfig.rootpath / "tests/integration/iceberg/" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "iceberg" + test_resources_dir / "docker-compose.yml", + "iceberg", ) as docker_services: wait_for_port(docker_services, "spark-iceberg", 8888, timeout=120) @@ -225,7 +240,8 @@ def test_iceberg_profiling(docker_compose_runner, pytestconfig, tmp_path, mock_t # Run the metadata ingestion pipeline. config_file = (test_resources_dir / "iceberg_profile_to_file.yml").resolve() run_datahub_cmd( - ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path + ["ingest", "--strict-warnings", "-c", f"{config_file}"], + tmp_path=tmp_path, ) # Verify the output. 
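# The commit subject enables the flake8-comprehensions ("C4") rule family in ruff.
# Those rules flag redundant list()/set()/dict() calls wrapped around generators,
# comprehensions, or keyword arguments and suggest the literal or comprehension form.
# A small, self-contained sketch of typical rewrites (the data is made up; the rule
# codes are the upstream flake8-comprehensions ones, not taken from this diff):
rows = [("type", "file"), ("filename", "out.json")]

# Flagged forms: C400 (generator inside list()), C402 (generator inside dict()),
# C408 (dict() call with keyword arguments instead of a literal).
keys_old = list(key for key, _ in rows)
lookup_old = dict((key, value) for key, value in rows)
sink_old = dict(type="file", filename="out.json")

# Preferred rewrites:
keys_new = [key for key, _ in rows]
lookup_new = {key: value for key, value in rows}
sink_new = {"type": "file", "filename": "out.json"}

assert keys_old == keys_new
assert lookup_old == lookup_new
assert sink_old == sink_new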
diff --git a/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py b/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py index d8c98b12951f5d..0a9160fd8db147 100644 --- a/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py +++ b/metadata-ingestion/tests/integration/kafka-connect/test_kafka_connect.py @@ -53,7 +53,9 @@ def kafka_connect_runner(docker_compose_runner, pytestconfig, test_resources_dir str(test_resources_dir / "docker-compose.override.yml"), ] with docker_compose_runner( - docker_compose_file, "kafka-connect", cleanup=False + docker_compose_file, + "kafka-connect", + cleanup=False, ) as docker_services: wait_for_port( docker_services, @@ -367,7 +369,10 @@ def loaded_kafka_connect(kafka_connect_runner): @freeze_time(FROZEN_TIME) def test_kafka_connect_ingest( - loaded_kafka_connect, pytestconfig, tmp_path, test_resources_dir + loaded_kafka_connect, + pytestconfig, + tmp_path, + test_resources_dir, ): # Run the metadata ingestion pipeline. config_file = (test_resources_dir / "kafka_connect_to_file.yml").resolve() @@ -384,7 +389,10 @@ def test_kafka_connect_ingest( @freeze_time(FROZEN_TIME) def test_kafka_connect_mongosourceconnect_ingest( - loaded_kafka_connect, pytestconfig, tmp_path, test_resources_dir + loaded_kafka_connect, + pytestconfig, + tmp_path, + test_resources_dir, ): # Run the metadata ingestion pipeline. config_file = (test_resources_dir / "kafka_connect_mongo_to_file.yml").resolve() @@ -401,7 +409,10 @@ def test_kafka_connect_mongosourceconnect_ingest( @freeze_time(FROZEN_TIME) def test_kafka_connect_s3sink_ingest( - loaded_kafka_connect, pytestconfig, tmp_path, test_resources_dir + loaded_kafka_connect, + pytestconfig, + tmp_path, + test_resources_dir, ): # Run the metadata ingestion pipeline. config_file = (test_resources_dir / "kafka_connect_s3sink_to_file.yml").resolve() @@ -418,7 +429,11 @@ def test_kafka_connect_s3sink_ingest( @freeze_time(FROZEN_TIME) def test_kafka_connect_ingest_stateful( - loaded_kafka_connect, pytestconfig, tmp_path, mock_datahub_graph, test_resources_dir + loaded_kafka_connect, + pytestconfig, + tmp_path, + mock_datahub_graph, + test_resources_dir, ): output_file_name: str = "kafka_connect_before_mces.json" golden_file_name: str = "kafka_connect_before_golden_mces.json" @@ -475,7 +490,7 @@ def test_kafka_connect_ingest_stateful( ) as mock_checkpoint: mock_checkpoint.return_value = mock_datahub_graph pipeline_run1_config: Dict[str, Dict[str, Dict[str, Any]]] = dict( # type: ignore - base_pipeline_config # type: ignore + base_pipeline_config, # type: ignore ) # Set the special properties for this run pipeline_run1_config["source"]["config"]["connector_patterns"]["allow"] = [ @@ -507,7 +522,7 @@ def test_kafka_connect_ingest_stateful( ) as mock_checkpoint: mock_checkpoint.return_value = mock_datahub_graph pipeline_run2_config: Dict[str, Dict[str, Dict[str, Any]]] = dict( - base_pipeline_config # type: ignore + base_pipeline_config, # type: ignore ) # Set the special properties for this run pipeline_run1_config["source"]["config"]["connector_patterns"]["allow"] = [ @@ -532,10 +547,12 @@ def test_kafka_connect_ingest_stateful( # Validate that all providers have committed successfully. 
validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Perform all assertions on the states. The deleted table should not be @@ -544,17 +561,17 @@ def test_kafka_connect_ingest_stateful( state2 = cast(GenericCheckpointState, checkpoint2.state) difference_pipeline_urns = list( - state1.get_urns_not_in(type="dataFlow", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataFlow", other_checkpoint_state=state2), ) assert len(difference_pipeline_urns) == 1 deleted_pipeline_urns: List[str] = [ - "urn:li:dataFlow:(kafka-connect,connect-instance-1.mysql_source2,PROD)" + "urn:li:dataFlow:(kafka-connect,connect-instance-1.mysql_source2,PROD)", ] assert sorted(deleted_pipeline_urns) == sorted(difference_pipeline_urns) difference_job_urns = list( - state1.get_urns_not_in(type="dataJob", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataJob", other_checkpoint_state=state2), ) assert len(difference_job_urns) == 3 deleted_job_urns = [ @@ -591,7 +608,10 @@ def register_mock_api(request_mock: Any, override_data: Optional[dict] = None) - @freeze_time(FROZEN_TIME) def test_kafka_connect_snowflake_sink_ingest( - pytestconfig, tmp_path, mock_time, requests_mock + pytestconfig, + tmp_path, + mock_time, + requests_mock, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/kafka-connect" override_data = { @@ -641,7 +661,7 @@ def test_kafka_connect_snowflake_sink_ingest( "connector_patterns": { "allow": [ "snowflake_sink1", - ] + ], }, }, }, @@ -651,7 +671,7 @@ def test_kafka_connect_snowflake_sink_ingest( "filename": f"{tmp_path}/kafka_connect_snowflake_sink_mces.json", }, }, - } + }, ) pipeline.run() @@ -667,7 +687,10 @@ def test_kafka_connect_snowflake_sink_ingest( @freeze_time(FROZEN_TIME) def test_kafka_connect_bigquery_sink_ingest( - loaded_kafka_connect, pytestconfig, tmp_path, test_resources_dir + loaded_kafka_connect, + pytestconfig, + tmp_path, + test_resources_dir, ): # Run the metadata ingestion pipeline. config_file = ( diff --git a/metadata-ingestion/tests/integration/kafka/create_key_value_topic.py b/metadata-ingestion/tests/integration/kafka/create_key_value_topic.py index 4ce3fa0d1cd2d3..ef71247e4403b8 100644 --- a/metadata-ingestion/tests/integration/kafka/create_key_value_topic.py +++ b/metadata-ingestion/tests/integration/kafka/create_key_value_topic.py @@ -31,18 +31,23 @@ def parse_command_line_args(): help="Record key. 
If not provided, will be a random UUID", ) arg_parser.add_argument( - "--key-schema-file", required=False, help="File name of key Avro schema to use" + "--key-schema-file", + required=False, + help="File name of key Avro schema to use", ) arg_parser.add_argument("--record-value", required=False, help="Record value") arg_parser.add_argument( - "--value-schema-file", required=False, help="File name of Avro schema to use" + "--value-schema-file", + required=False, + help="File name of Avro schema to use", ) return arg_parser.parse_args() def load_avro_schema_from_file( - key_schema_file: str, value_schema_file: str + key_schema_file: str, + value_schema_file: str, ) -> Tuple[Schema, Schema]: key_schema = ( avro.load(key_schema_file) @@ -58,7 +63,8 @@ def load_avro_schema_from_file( def send_record(args): key_schema, value_schema = load_avro_schema_from_file( - args.key_schema_file, args.value_schema_file + args.key_schema_file, + args.value_schema_file, ) producer_config = { @@ -80,7 +86,7 @@ def send_record(args): producer.flush() except Exception as e: print( - f"Exception while producing record value - {value} to topic - {args.topic}: {e}" + f"Exception while producing record value - {value} to topic - {args.topic}: {e}", ) raise e else: diff --git a/metadata-ingestion/tests/integration/kafka/test_kafka.py b/metadata-ingestion/tests/integration/kafka/test_kafka.py index 648c4b26b20a76..174dd66bf719aa 100644 --- a/metadata-ingestion/tests/integration/kafka/test_kafka.py +++ b/metadata-ingestion/tests/integration/kafka/test_kafka.py @@ -25,13 +25,16 @@ def test_resources_dir(pytestconfig): @pytest.fixture(scope="module") def mock_kafka_service(docker_compose_runner, test_resources_dir): with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "kafka", cleanup=False + test_resources_dir / "docker-compose.yml", + "kafka", + cleanup=False, ) as docker_services: wait_for_port(docker_services, "test_zookeeper", 52181, timeout=120) # Running docker compose twice, since the broker sometimes fails to come up on the first try. with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "kafka" + test_resources_dir / "docker-compose.yml", + "kafka", ) as docker_services: wait_for_port(docker_services, "test_broker", 29092, timeout=120) wait_for_port(docker_services, "test_schema_registry", 8081, timeout=120) @@ -47,7 +50,12 @@ def mock_kafka_service(docker_compose_runner, test_resources_dir): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_kafka_ingest( - mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time, approach + mock_kafka_service, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, + approach, ): # Run the metadata ingestion pipeline. 
config_file = (test_resources_dir / f"{approach}_to_file.yml").resolve() @@ -97,12 +105,13 @@ def test_kafka_test_connection(mock_kafka_service, config_dict, is_success): ) else: test_connection_helpers.assert_basic_connectivity_failure( - report, "Failed to get metadata" + report, + "Failed to get metadata", ) test_connection_helpers.assert_capability_report( capability_report=report.capability_report, failure_capabilities={ - SourceCapability.SCHEMA_METADATA: "[Errno 111] Connection refused" + SourceCapability.SCHEMA_METADATA: "[Errno 111] Connection refused", }, ) @@ -110,7 +119,11 @@ def test_kafka_test_connection(mock_kafka_service, config_dict, is_success): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_kafka_oauth_callback( - mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time + mock_kafka_service, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, ): # Run the metadata ingestion pipeline. config_file = (test_resources_dir / "kafka_to_file_oauth.yml").resolve() @@ -118,7 +131,7 @@ def test_kafka_oauth_callback( log_file = tmp_path / "kafka_oauth_message.log" file_handler = logging.FileHandler( - str(log_file) + str(log_file), ) # Add a file handler to later validate a test-case logging.getLogger().addHandler(file_handler) @@ -171,8 +184,8 @@ def test_kafka_source_oauth_cb_signature(): "connection": { "bootstrap": "foobar:9092", "consumer_config": {"oauth_cb": "oauth:create_token_no_args"}, - } - } + }, + }, ) with pytest.raises( @@ -184,6 +197,6 @@ def test_kafka_source_oauth_cb_signature(): "connection": { "bootstrap": "foobar:9092", "consumer_config": {"oauth_cb": "oauth:create_token_only_kwargs"}, - } - } + }, + }, ) diff --git a/metadata-ingestion/tests/integration/kafka/test_kafka_state.py b/metadata-ingestion/tests/integration/kafka/test_kafka_state.py index 24e81fbf128b01..2c8682f916bab3 100644 --- a/metadata-ingestion/tests/integration/kafka/test_kafka_state.py +++ b/metadata-ingestion/tests/integration/kafka/test_kafka_state.py @@ -81,7 +81,11 @@ def __exit__(self, exc_type, exc, traceback): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_kafka_ingest_with_stateful( - docker_compose_runner, pytestconfig, tmp_path, mock_time, mock_datahub_graph + docker_compose_runner, + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/kafka" topic_prefix: str = "stateful_ingestion_test" @@ -89,7 +93,8 @@ def test_kafka_ingest_with_stateful( platform_instance = "test_platform_instance_1" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "kafka" + test_resources_dir / "docker-compose.yml", + "kafka", ) as docker_services: wait_for_port(docker_services, "test_broker", KAFKA_PORT, timeout=120) wait_for_port(docker_services, "test_schema_registry", 8081, timeout=120) @@ -118,20 +123,21 @@ def test_kafka_ingest_with_stateful( }, "sink": { # we are not really interested in the resulting events for this test - "type": "console" + "type": "console", }, "pipeline_name": "test_pipeline", # enable reporting "reporting": [ { "type": "datahub", - } + }, ], } # topics will be automatically created and deleted upon test completion with KafkaTopicsCxtManager( - topic_names, KAFKA_BOOTSTRAP_SERVER + topic_names, + KAFKA_BOOTSTRAP_SERVER, ) as kafka_ctx, patch( "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph", mock_datahub_graph, @@ -161,7 +167,7 @@ def test_kafka_ingest_with_stateful( state1 = 
checkpoint1.state state2 = checkpoint2.state difference_urns = list( - state1.get_urns_not_in(type="topic", other_checkpoint_state=state2) + state1.get_urns_not_in(type="topic", other_checkpoint_state=state2), ) assert len(difference_urns) == 1 @@ -174,8 +180,10 @@ def test_kafka_ingest_with_stateful( # NOTE: The following validation asserts for presence of state as well # and validates reporting. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) diff --git a/metadata-ingestion/tests/integration/ldap/test_ldap.py b/metadata-ingestion/tests/integration/ldap/test_ldap.py index 3e76f13fc823d2..6422f4b915927d 100644 --- a/metadata-ingestion/tests/integration/ldap/test_ldap.py +++ b/metadata-ingestion/tests/integration/ldap/test_ldap.py @@ -12,7 +12,8 @@ def test_ldap_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): test_resources_dir = pytestconfig.rootpath / "tests/integration/ldap" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "ldap" + test_resources_dir / "docker-compose.yml", + "ldap", ) as docker_services: # The openldap container loads the sample data after exposing the port publicly. As such, # we must wait a little bit extra to ensure that the sample data is loaded. @@ -42,7 +43,7 @@ def test_ldap_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/ldap_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -59,7 +60,8 @@ def test_ldap_memberof_ingest(docker_compose_runner, pytestconfig, tmp_path, moc test_resources_dir = pytestconfig.rootpath / "tests/integration/ldap" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "ldap" + test_resources_dir / "docker-compose.yml", + "ldap", ) as docker_services: # The openldap container loads the sample data after exposing the port publicly. As such, # we must wait a little bit extra to ensure that the sample data is loaded. @@ -90,7 +92,7 @@ def test_ldap_memberof_ingest(docker_compose_runner, pytestconfig, tmp_path, moc "filename": f"{tmp_path}/ldap_memberof_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -104,12 +106,16 @@ def test_ldap_memberof_ingest(docker_compose_runner, pytestconfig, tmp_path, moc @pytest.mark.integration def test_ldap_ingest_with_email_as_username( - docker_compose_runner, pytestconfig, tmp_path, mock_time + docker_compose_runner, + pytestconfig, + tmp_path, + mock_time, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/ldap" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "ldap" + test_resources_dir / "docker-compose.yml", + "ldap", ) as docker_services: # The openldap container loads the sample data after exposing the port publicly. As such, # we must wait a little bit extra to ensure that the sample data is loaded. 
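# Several hunks above add a trailing comma after a single wrapped argument, e.g.
#     difference_urns = list(
#         state1.get_urns_not_in(type="topic", other_checkpoint_state=state2),
#     )
# Inside a call's parentheses a trailing comma is inert; only a bare parenthesized
# expression turns into a tuple. A tiny self-contained check of that distinction:
current_urns = {"urn:li:dataset:a", "urn:li:dataset:b"}
previous_urns = {"urn:li:dataset:a"}

assert sorted(current_urns - previous_urns) == sorted(current_urns - previous_urns,)
assert (len(current_urns),) == (2,) and (len(current_urns)) == 2

# Keeping the comma means that adding a second argument later (for example a sort
# key) only touches the new line, which keeps future diffs like this one small.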
@@ -141,7 +147,7 @@ def test_ldap_ingest_with_email_as_username( "filename": f"{tmp_path}/ldap_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py b/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py index 0563c2771b0932..88fbf95213c4f2 100644 --- a/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py +++ b/metadata-ingestion/tests/integration/ldap/test_ldap_stateful.py @@ -30,7 +30,8 @@ def ldap_ingest_common( ): test_resources_dir = pathlib.Path(pytestconfig.rootpath / "tests/integration/ldap") with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "ldap" + test_resources_dir / "docker-compose.yml", + "ldap", ) as docker_services: # The openldap container loads the sample data after exposing the port publicly. As such, # we must wait a little bit extra to ensure that the sample data is loaded. @@ -74,7 +75,7 @@ def ldap_ingest_common( "filename": f"{tmp_path}/{output_file_name}", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -91,7 +92,11 @@ def ldap_ingest_common( @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_ldap_stateful( - docker_compose_runner, pytestconfig, tmp_path, mock_time, mock_datahub_graph + docker_compose_runner, + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): golden_file_name: str = "ldap_mces_golden_stateful.json" output_file_name: str = "ldap_mces_stateful.json" @@ -139,17 +144,19 @@ def test_ldap_stateful( assert checkpoint2.state validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) state1 = checkpoint1.state state2 = checkpoint2.state difference_dataset_urns = list( - state1.get_urns_not_in(type="corpuser", other_checkpoint_state=state2) + state1.get_urns_not_in(type="corpuser", other_checkpoint_state=state2), ) assert len(difference_dataset_urns) == 1 deleted_dataset_urns = [ @@ -189,7 +196,7 @@ def test_ldap_stateful( state4 = checkpoint4.state difference_dataset_urns = list( - state3.get_urns_not_in(type="corpGroup", other_checkpoint_state=state4) + state3.get_urns_not_in(type="corpGroup", other_checkpoint_state=state4), ) assert len(difference_dataset_urns) == 1 deleted_dataset_urns = [ diff --git a/metadata-ingestion/tests/integration/looker/test_looker.py b/metadata-ingestion/tests/integration/looker/test_looker.py index bbcc6332539c02..79f26d8714fbdc 100644 --- a/metadata-ingestion/tests/integration/looker/test_looker.py +++ b/metadata-ingestion/tests/integration/looker/test_looker.py @@ -111,7 +111,7 @@ def test_looker_ingest(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/looker_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -144,8 +144,8 @@ def setup_mock_external_project_view_explore(mocked_client): label_short="Dimensions One Label", view="faa_flights", source_file="imported_projects/datahub-demo/views/datahub-demo/datasets/faa_flights.view.lkml", - ) - ] + ), + ], ), source_file="test_source_file.lkml", ) @@ -179,7 +179,7 @@ def test_looker_ingest_external_project_view(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/looker_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -220,7 +220,7 @@ def test_looker_ingest_joins(pytestconfig, tmp_path, 
mock_time): "filename": f"{tmp_path}/looker_mces_joins.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -256,7 +256,7 @@ def test_looker_ingest_unaliased_joins(pytestconfig, tmp_path, mock_time): fields=["dim1"], dynamic_fields='[{"table_calculation":"calc","label":"foobar","expression":"offset(${my_table.value},1)","value_format":null,"value_format_name":"eur","_kind_hint":"measure","_type_hint":"number"}]', ), - ) + ), ], ) setup_mock_explore_unaliased_with_joins(mocked_client) @@ -281,7 +281,7 @@ def test_looker_ingest_unaliased_joins(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/looker_mces_unaliased_joins.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -314,7 +314,7 @@ def setup_mock_dashboard(mocked_client): view="my_view", dynamic_fields='[{"table_calculation":"calc","label":"foobar","expression":"offset(${my_table.value},1)","value_format":null,"value_format_name":"eur","_kind_hint":"measure","_type_hint":"number"}]', ), - ) + ), ], ) @@ -355,7 +355,7 @@ def setup_mock_look(mocked_client): ], dynamic_fields=None, filters=None, - ) + ), ), LookWithQuery( query=Query( @@ -367,7 +367,7 @@ def setup_mock_look(mocked_client): ], dynamic_fields=None, filters=None, - ) + ), ), ] @@ -379,7 +379,7 @@ def setup_mock_soft_deleted_look(mocked_client): title="Soft Deleted", description="I am not part of any Dashboard", query_id="1", - ) + ), ] @@ -418,7 +418,8 @@ def setup_mock_dashboard_multiple_charts(mocked_client): def setup_mock_dashboard_with_usage( - mocked_client: mock.MagicMock, skip_look: bool = False + mocked_client: mock.MagicMock, + skip_look: bool = False, ) -> None: mocked_client.all_dashboards.return_value = [Dashboard(id="1")] mocked_client.dashboard.return_value = Dashboard( @@ -471,8 +472,8 @@ def setup_mock_explore_with_joins(mocked_client): type="string", description="dimension one description", label_short="Dimensions One Label", - ) - ] + ), + ], ), source_file="test_source_file.lkml", joins=[ @@ -511,8 +512,8 @@ def setup_mock_explore_unaliased_with_joins(mocked_client): dimension_group=None, description="dimension one description", label_short="Dimensions One Label", - ) - ] + ), + ], ), source_file="test_source_file.lkml", joins=[ @@ -545,7 +546,7 @@ def setup_mock_explore( dimension_group=None, description="dimension one description", label_short="Dimensions One Label", - ) + ), ] lkml_fields.extend(additional_lkml_fields) @@ -590,7 +591,9 @@ def all_users( def side_effect_query_inline( - result_format: str, body: WriteQuery, transport_options: Optional[TransportOptions] + result_format: str, + body: WriteQuery, + transport_options: Optional[TransportOptions], ) -> str: query_type: looker_usage.QueryId if result_format == "sql": @@ -628,7 +631,7 @@ def side_effect_query_inline( HistoryViewField.HISTORY_DASHBOARD_USER: 1, HistoryViewField.HISTORY_DASHBOARD_RUN_COUNT: 5, }, - ] + ], ), looker_usage.QueryId.DASHBOARD_PER_USER_PER_DAY_USAGE_STAT: json.dumps( [ @@ -650,7 +653,7 @@ def side_effect_query_inline( UserViewField.USER_ID: 1, HistoryViewField.HISTORY_DASHBOARD_RUN_COUNT: 5, }, - ] + ], ), looker_usage.QueryId.LOOK_PER_DAY_USAGE_STAT: json.dumps( [ @@ -669,7 +672,7 @@ def side_effect_query_inline( HistoryViewField.HISTORY_COUNT: 35, LookViewField.LOOK_ID: 3, }, - ] + ], ), looker_usage.QueryId.LOOK_PER_USER_PER_DAY_USAGE_STAT: json.dumps( [ @@ -685,7 +688,7 @@ def side_effect_query_inline( LookViewField.LOOK_ID: 3, UserViewField.USER_ID: 2, }, - ] + ], ), } @@ -725,7 +728,7 @@ def 
test_looker_ingest_allow_pattern(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/looker_mces.json", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -772,7 +775,7 @@ def test_looker_ingest_usage_history(pytestconfig, tmp_path, mock_time): "filename": temp_output_file, }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -837,7 +840,7 @@ def test_looker_filter_usage_history(pytestconfig, tmp_path, mock_time): "filename": temp_output_file, }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -964,10 +967,12 @@ def looker_source_config(sink_file_name): # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Perform all assertions on the states. The deleted table should not be @@ -976,23 +981,23 @@ def looker_source_config(sink_file_name): state2 = cast(GenericCheckpointState, checkpoint2.state) difference_dataset_urns = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) assert len(difference_dataset_urns) == 1 deleted_dataset_urns: List[str] = [ - "urn:li:dataset:(urn:li:dataPlatform:looker,bogus data.explore.my_view,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:looker,bogus data.explore.my_view,PROD)", ] assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns) difference_chart_urns = list( - state1.get_urns_not_in(type="chart", other_checkpoint_state=state2) + state1.get_urns_not_in(type="chart", other_checkpoint_state=state2), ) assert len(difference_chart_urns) == 1 deleted_chart_urns = ["urn:li:chart:(looker,dashboard_elements.10)"] assert sorted(deleted_chart_urns) == sorted(difference_chart_urns) difference_dashboard_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) assert len(difference_dashboard_urns) == 1 deleted_dashboard_urns = ["urn:li:dashboard:(looker,dashboards.11)"] @@ -1064,7 +1069,10 @@ def ingest_independent_looks( @freeze_time(FROZEN_TIME) def test_independent_looks_ingest_with_personal_folder( - pytestconfig, tmp_path, mock_time, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): ingest_independent_looks( pytestconfig=pytestconfig, @@ -1078,7 +1086,10 @@ def test_independent_looks_ingest_with_personal_folder( @freeze_time(FROZEN_TIME) def test_independent_looks_ingest_without_personal_folder( - pytestconfig, tmp_path, mock_time, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): ingest_independent_looks( pytestconfig=pytestconfig, @@ -1092,7 +1103,10 @@ def test_independent_looks_ingest_without_personal_folder( @freeze_time(FROZEN_TIME) def test_file_path_in_view_naming_pattern( - pytestconfig, tmp_path, mock_time, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): mocked_client = mock.MagicMock() new_recipe = get_default_recipe(output_file_path=f"{tmp_path}/looker_mces.json") @@ -1119,7 +1133,7 @@ def test_file_path_in_view_naming_pattern( label_short="Dimensions One Label", view="underlying_view", source_file="views/underlying_view.view.lkml", - ) + ), ], ) 
setup_mock_look(mocked_client) @@ -1156,7 +1170,7 @@ def test_independent_soft_deleted_looks( base_url="https://fake.com", client_id="foo", client_secret="bar", - ) + ), ) looks: List[Look] = looker_api.all_looks( fields=["id"], @@ -1319,7 +1333,7 @@ def side_effect_function_for_dashboards(*args: Tuple[str], **kwargs: Any) -> Das fields=["dim1"], dynamic_fields='[{"table_calculation":"calc","label":"foobar","expression":"offset(${my_table.value},1)","value_format":null,"value_format_name":"eur","_kind_hint":"measure","_type_hint":"number"}]', ), - ) + ), ], ) @@ -1342,7 +1356,7 @@ def side_effect_function_for_dashboards(*args: Tuple[str], **kwargs: Any) -> Das fields=["dim1"], dynamic_fields='[{"table_calculation":"calc","label":"foobar","expression":"offset(${my_table.value},1)","value_format":null,"value_format_name":"eur","_kind_hint":"measure","_type_hint":"number"}]', ), - ) + ), ], ) @@ -1365,7 +1379,7 @@ def side_effect_function_for_dashboards(*args: Tuple[str], **kwargs: Any) -> Das fields=["dim1"], dynamic_fields='[{"table_calculation":"calc","label":"foobar","expression":"offset(${my_table.value},1)","value_format":null,"value_format_name":"eur","_kind_hint":"measure","_type_hint":"number"}]', ), - ) + ), ], ) @@ -1382,7 +1396,8 @@ def side_effect_function_for_dashboards(*args: Tuple[str], **kwargs: Any) -> Das def side_effect_function_folder_ancestors( - *args: Tuple[Any], **kwargs: Any + *args: Tuple[Any], + **kwargs: Any, ) -> Sequence[Folder]: assert args[0] in ["a", "b", "c"], "Invalid folder id" diff --git a/metadata-ingestion/tests/integration/lookml/test_lookml.py b/metadata-ingestion/tests/integration/lookml/test_lookml.py index d803b8498104fd..bd77d9d9ad9343 100644 --- a/metadata-ingestion/tests/integration/lookml/test_lookml.py +++ b/metadata-ingestion/tests/integration/lookml/test_lookml.py @@ -73,8 +73,9 @@ def test_lookml_ingest(pytestconfig, tmp_path, mock_time): pipeline = Pipeline.create( get_default_recipe( - f"{tmp_path}/{mce_out_file}", f"{test_resources_dir}/lkml_samples" - ) + f"{tmp_path}/{mce_out_file}", + f"{test_resources_dir}/lkml_samples", + ), ) pipeline.run() pipeline.pretty_print_summary() @@ -97,7 +98,8 @@ def test_lookml_refinement_ingest(pytestconfig, tmp_path, mock_time): # to resolve relative table names (which are not fully qualified) # We keep this check just to validate that ingestion doesn't croak on this config new_recipe = get_default_recipe( - f"{tmp_path}/{mce_out_file}", f"{test_resources_dir}/lkml_samples" + f"{tmp_path}/{mce_out_file}", + f"{test_resources_dir}/lkml_samples", ) new_recipe["source"]["config"]["process_refinements"] = True @@ -134,10 +136,10 @@ def test_lookml_refinement_include_order(pytestconfig, tmp_path, mock_time): new_recipe["source"]["config"]["process_refinements"] = True new_recipe["source"]["config"]["project_name"] = "lkml_refinement_sample1" new_recipe["source"]["config"]["view_naming_pattern"] = { - "pattern": "{project}.{model}.view.{name}" + "pattern": "{project}.{model}.view.{name}", } new_recipe["source"]["config"]["connection_to_platform_map"] = { - "db-connection": "conn" + "db-connection": "conn", } pipeline = Pipeline.create(new_recipe) pipeline.run() @@ -180,13 +182,13 @@ def test_lookml_explore_refinement(pytestconfig, tmp_path, mock_time): "client_id": "fake_client_id", "client_secret": "fake_client_secret", }, - } + }, ), connection_definition=None, # type: ignore ) new_explore: dict = refinement_resolver.apply_explore_refinement( - looker_model.explores[0] + looker_model.explores[0], ) assert 
new_explore.get("extends") is not None @@ -203,7 +205,7 @@ def test_lookml_view_merge(pytestconfig, tmp_path, mock_time): "primary_key": "yes", "sql": '${TABLE}."id"', "name": "id", - } + }, ], "name": "flights", } @@ -215,7 +217,7 @@ def test_lookml_view_merge(pytestconfig, tmp_path, mock_time): "type": "string", "sql": '${TABLE}."air_carrier"', "name": "air_carrier", - } + }, ], "name": "+flights", }, @@ -257,7 +259,8 @@ def test_lookml_view_merge(pytestconfig, tmp_path, mock_time): ] merged_view: dict = LookerRefinementResolver.merge_refinements( - raw_view=raw_view, refinement_views=refinement_views + raw_view=raw_view, + refinement_views=refinement_views, ) expected_view: dict = { @@ -289,7 +292,7 @@ def test_lookml_view_merge(pytestconfig, tmp_path, mock_time): "sql_start": '${TABLE}."enrollment_date"', "sql_end": '${TABLE}."graduation_date"', "name": "enrolled", - } + }, ], } @@ -313,7 +316,7 @@ def test_lookml_ingest_offline(pytestconfig, tmp_path, mock_time): "platform": "snowflake", "default_db": "default_db", "default_schema": "default_schema", - } + }, }, "parse_table_names_from_sql": True, "project_name": "lkml_samples", @@ -328,7 +331,7 @@ def test_lookml_ingest_offline(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -358,7 +361,7 @@ def test_lookml_ingest_offline_with_model_deny(pytestconfig, tmp_path, mock_time "platform": "snowflake", "default_db": "default_db", "default_schema": "default_schema", - } + }, }, "parse_table_names_from_sql": True, "project_name": "lkml_samples", @@ -373,7 +376,7 @@ def test_lookml_ingest_offline_with_model_deny(pytestconfig, tmp_path, mock_time "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -405,7 +408,7 @@ def test_lookml_ingest_offline_platform_instance(pytestconfig, tmp_path, mock_ti "platform_env": "dev", "default_db": "default_db", "default_schema": "default_schema", - } + }, }, "parse_table_names_from_sql": True, "project_name": "lkml_samples", @@ -420,7 +423,7 @@ def test_lookml_ingest_offline_platform_instance(pytestconfig, tmp_path, mock_ti "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -441,7 +444,9 @@ def test_lookml_ingest_api_bigquery(pytestconfig, tmp_path, mock_time): tmp_path, mock_time, DBConnection( - dialect_name="bigquery", host="project-foo", database="default-db" + dialect_name="bigquery", + host="project-foo", + database="default-db", ), ) @@ -503,7 +508,7 @@ def ingestion_test( "filename": f"{tmp_path}/{mce_out_file}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -533,7 +538,7 @@ def test_lookml_git_info(pytestconfig, tmp_path, mock_time): "platform": "snowflake", "default_db": "default_db", "default_schema": "default_schema", - } + }, }, "parse_table_names_from_sql": True, "project_name": "lkml_samples", @@ -549,7 +554,7 @@ def test_lookml_git_info(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -602,7 +607,7 @@ def test_reachable_views(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -648,7 +653,7 @@ def test_hive_platform_drops_ids(pytestconfig, tmp_path, mock_time): "platform": "hive", "default_db": "default_database", "default_schema": "default_schema", - } + }, }, "parse_table_names_from_sql": True, 
"project_name": "lkml_samples", @@ -664,7 +669,7 @@ def test_hive_platform_drops_ids(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -756,11 +761,12 @@ def test_lookml_base_folder(): "deploy_key": "this-is-fake", }, "api": fake_api, - } + }, ) with pytest.raises( - pydantic.ValidationError, match=r"base_folder.+nor.+git_info.+provided" + pydantic.ValidationError, + match=r"base_folder.+nor.+git_info.+provided", ): LookMLSourceConfig.parse_obj({"api": fake_api}) @@ -778,7 +784,7 @@ def test_same_name_views_different_file_path(pytestconfig, tmp_path, mock_time): "config": { "base_folder": str( test_resources_dir - / "lkml_same_name_views_different_file_path_samples" + / "lkml_same_name_views_different_file_path_samples", ), "connection_to_platform_map": { "my_connection": { @@ -802,7 +808,7 @@ def test_same_name_views_different_file_path(pytestconfig, tmp_path, mock_time): "filename": f"{tmp_path}/{mce_out}", }, }, - } + }, ) pipeline.run() pipeline.pretty_print_summary() @@ -920,7 +926,7 @@ def test_special_liquid_variables(): } actual_liquid_variable = SpecialVariable( - input_liquid_variable + input_liquid_variable, ).liquid_variable_with_default(text) assert ( expected_liquid_variable == actual_liquid_variable @@ -943,7 +949,7 @@ def test_special_liquid_variables(): } actual_liquid_variable = SpecialVariable( - input_liquid_variable + input_liquid_variable, ).liquid_variable_with_default(text) assert ( expected_liquid_variable == actual_liquid_variable @@ -1000,7 +1006,7 @@ def test_drop_hive(pytestconfig, tmp_path, mock_time): ) new_recipe["source"]["config"]["connection_to_platform_map"] = { - "my_connection": "hive" + "my_connection": "hive", } pipeline = Pipeline.create(new_recipe) @@ -1027,7 +1033,7 @@ def test_gms_schema_resolution(pytestconfig, tmp_path, mock_time): ) new_recipe["source"]["config"]["connection_to_platform_map"] = { - "my_connection": "hive" + "my_connection": "hive", } return_value: Tuple[str, Optional[SchemaInfo]] = ( diff --git a/metadata-ingestion/tests/integration/metabase/test_metabase.py b/metadata-ingestion/tests/integration/metabase/test_metabase.py index 2d67f0ca5223f8..a6e8dcfbfbbb37 100644 --- a/metadata-ingestion/tests/integration/metabase/test_metabase.py +++ b/metadata-ingestion/tests/integration/metabase/test_metabase.py @@ -29,7 +29,12 @@ class MockResponse: def __init__( - self, url, json_response_map=None, data=None, jsond=None, error_list=None + self, + url, + json_response_map=None, + data=None, + jsond=None, + error_list=None, ): self.json_data = data self.url = url @@ -46,7 +51,7 @@ def json(self): if not pathlib.Path(response_json_path).exists(): raise Exception( - f"mock response file not found {self.url} -> {mocked_response_file}" + f"mock response file not found {self.url} -> {mocked_response_file}", ) with open(response_json_path) as file: @@ -164,22 +169,26 @@ def test_pipeline(pytestconfig, tmp_path): @freeze_time(FROZEN_TIME) def test_metabase_ingest_success( - pytestconfig, tmp_path, test_pipeline, mock_datahub_graph, default_json_response_map + pytestconfig, + tmp_path, + test_pipeline, + mock_datahub_graph, + default_json_response_map, ): with patch( "datahub.ingestion.source.metabase.requests.session", side_effect=MockResponse.build_mocked_requests_sucess( - default_json_response_map + default_json_response_map, ), ), patch( "datahub.ingestion.source.metabase.requests.post", side_effect=MockResponse.build_mocked_requests_session_post( - 
default_json_response_map + default_json_response_map, ), ), patch( "datahub.ingestion.source.metabase.requests.delete", side_effect=MockResponse.build_mocked_requests_session_delete( - default_json_response_map + default_json_response_map, ), ), patch( "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph", @@ -201,7 +210,9 @@ def test_metabase_ingest_success( @freeze_time(FROZEN_TIME) def test_stateful_ingestion( - test_pipeline, mock_datahub_graph, default_json_response_map + test_pipeline, + mock_datahub_graph, + default_json_response_map, ): json_response_map = default_json_response_map with patch( @@ -213,7 +224,7 @@ def test_stateful_ingestion( ), patch( "datahub.ingestion.source.metabase.requests.delete", side_effect=MockResponse.build_mocked_requests_session_delete( - json_response_map + json_response_map, ), ), patch( "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph", @@ -245,17 +256,19 @@ def test_stateful_ingestion( state2 = checkpoint2.state difference_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) assert len(difference_urns) == 1 assert difference_urns[0] == "urn:li:dashboard:(metabase,20)" validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) @@ -264,17 +277,17 @@ def test_metabase_ingest_failure(pytestconfig, tmp_path, default_json_response_m with patch( "datahub.ingestion.source.metabase.requests.session", side_effect=MockResponse.build_mocked_requests_failure( - default_json_response_map + default_json_response_map, ), ), patch( "datahub.ingestion.source.metabase.requests.post", side_effect=MockResponse.build_mocked_requests_session_post( - default_json_response_map + default_json_response_map, ), ), patch( "datahub.ingestion.source.metabase.requests.delete", side_effect=MockResponse.build_mocked_requests_session_delete( - default_json_response_map + default_json_response_map, ), ): pipeline = Pipeline.create( @@ -294,7 +307,7 @@ def test_metabase_ingest_failure(pytestconfig, tmp_path, default_json_response_m "filename": f"{tmp_path}/metabase_mces.json", }, }, - } + }, ) pipeline.run() try: diff --git a/metadata-ingestion/tests/integration/mode/test_mode.py b/metadata-ingestion/tests/integration/mode/test_mode.py index 7f1e3935aa0fa1..cae50912a18038 100644 --- a/metadata-ingestion/tests/integration/mode/test_mode.py +++ b/metadata-ingestion/tests/integration/mode/test_mode.py @@ -127,7 +127,7 @@ def test_mode_ingest_success(pytestconfig, tmp_path): "filename": f"{tmp_path}/mode_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -167,7 +167,7 @@ def test_mode_ingest_failure(pytestconfig, tmp_path): "filename": f"{tmp_path}/mode_mces.json", }, }, - } + }, ) pipeline.run() with pytest.raises(PipelineExecutionError) as exec_error: @@ -186,7 +186,7 @@ def test_mode_ingest_json_empty(pytestconfig, tmp_path): with patch( "datahub.ingestion.source.mode.requests.Session", side_effect=lambda *args, **kwargs: MockResponseJson( - json_empty_list=["https://app.mode.com/api/modeuser"] + json_empty_list=["https://app.mode.com/api/modeuser"], ), ): global test_resources_dir @@ -210,7 +210,7 @@ def 
test_mode_ingest_json_empty(pytestconfig, tmp_path): "filename": f"{tmp_path}/mode_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status(raise_warnings=True) @@ -221,7 +221,7 @@ def test_mode_ingest_json_failure(pytestconfig, tmp_path): with patch( "datahub.ingestion.source.mode.requests.Session", side_effect=lambda *args, **kwargs: MockResponseJson( - json_error_list=["https://app.mode.com/api/modeuser"] + json_error_list=["https://app.mode.com/api/modeuser"], ), ): global test_resources_dir @@ -245,7 +245,7 @@ def test_mode_ingest_json_failure(pytestconfig, tmp_path): "filename": f"{tmp_path}/mode_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status(raise_warnings=False) diff --git a/metadata-ingestion/tests/integration/mongodb/test_mongodb.py b/metadata-ingestion/tests/integration/mongodb/test_mongodb.py index 6dc8bb295ed455..59d384fe8a1b90 100644 --- a/metadata-ingestion/tests/integration/mongodb/test_mongodb.py +++ b/metadata-ingestion/tests/integration/mongodb/test_mongodb.py @@ -10,7 +10,8 @@ def test_mongodb_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time test_resources_dir = pytestconfig.rootpath / "tests/integration/mongodb" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "mongo" + test_resources_dir / "docker-compose.yml", + "mongo", ) as docker_services: wait_for_port(docker_services, "testmongodb", 27017) @@ -35,7 +36,7 @@ def test_mongodb_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time "filename": f"{tmp_path}/mongodb_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -67,7 +68,7 @@ def test_mongodb_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time "filename": f"{tmp_path}/mongodb_mces_small_schema_size.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/mysql/test_mysql.py b/metadata-ingestion/tests/integration/mysql/test_mysql.py index c19198c7d2bbd0..986231a34a7700 100644 --- a/metadata-ingestion/tests/integration/mysql/test_mysql.py +++ b/metadata-ingestion/tests/integration/mysql/test_mysql.py @@ -31,7 +31,8 @@ def is_mysql_up(container_name: str, port: int) -> bool: @pytest.fixture(scope="module") def mysql_runner(docker_compose_runner, pytestconfig, test_resources_dir): with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "mysql" + test_resources_dir / "docker-compose.yml", + "mysql", ) as docker_services: wait_for_port( docker_services, @@ -109,5 +110,6 @@ def test_mysql_test_connection(mysql_runner, config_dict, is_success): test_connection_helpers.assert_basic_connectivity_success(report) else: test_connection_helpers.assert_basic_connectivity_failure( - report, "Connection refused" + report, + "Connection refused", ) diff --git a/metadata-ingestion/tests/integration/nifi/test_nifi.py b/metadata-ingestion/tests/integration/nifi/test_nifi.py index 924e854a47e4eb..43cc8b5b1b3ba0 100644 --- a/metadata-ingestion/tests/integration/nifi/test_nifi.py +++ b/metadata-ingestion/tests/integration/nifi/test_nifi.py @@ -22,7 +22,8 @@ def test_resources_dir(pytestconfig): @pytest.fixture(scope="module") def loaded_nifi(docker_compose_runner, test_resources_dir): with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "nifi" + test_resources_dir / "docker-compose.yml", + "nifi", ) as docker_services: wait_for_port( docker_services, @@ -56,7 +57,10 @@ def loaded_nifi(docker_compose_runner, test_resources_dir): @freeze_time(FROZEN_TIME) def 
test_nifi_ingest_standalone( - loaded_nifi, pytestconfig, tmp_path, test_resources_dir + loaded_nifi, + pytestconfig, + tmp_path, + test_resources_dir, ): # Wait for nifi standalone to execute all lineage processors, max wait time 120 seconds url = "http://localhost:9443/nifi-api/flow/process-groups/80404c81-017d-1000-e8e8-af7420af06c1" @@ -92,7 +96,7 @@ def test_nifi_ingest_standalone( "type": "file", "config": {"filename": "./nifi_mces.json"}, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -150,7 +154,7 @@ def test_nifi_ingest_cluster(loaded_nifi, pytestconfig, tmp_path, test_resources "type": "file", "config": {"filename": "./nifi_mces_cluster.json"}, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/okta/test_okta.py b/metadata-ingestion/tests/integration/okta/test_okta.py index 10148273c93666..35d8ec258864c4 100644 --- a/metadata-ingestion/tests/integration/okta/test_okta.py +++ b/metadata-ingestion/tests/integration/okta/test_okta.py @@ -59,7 +59,7 @@ def run_ingest( recipe, ): with patch( - "datahub.ingestion.source.identity.okta.OktaClient" + "datahub.ingestion.source.identity.okta.OktaClient", ) as MockClient, patch( "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph", mock_datahub_graph, @@ -78,7 +78,7 @@ def run_ingest( def test_okta_config(): config = OktaConfig.parse_obj( - dict(okta_domain="test.okta.com", okta_api_token="test-token") + dict(okta_domain="test.okta.com", okta_api_token="test-token"), ) # Sanity on required configurations @@ -108,7 +108,8 @@ def test_okta_source_default_configs(pytestconfig, mock_datahub_graph, tmp_path) run_ingest( mock_datahub_graph=mock_datahub_graph, mocked_functions_reference=partial( - _init_mock_okta_client, test_resources_dir=test_resources_dir + _init_mock_okta_client, + test_resources_dir=test_resources_dir, ), recipe=default_recipe(output_file_path), ) @@ -133,7 +134,8 @@ def test_okta_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_pa run_ingest( mock_datahub_graph=mock_datahub_graph, mocked_functions_reference=partial( - _init_mock_okta_client, test_resources_dir=test_resources_dir + _init_mock_okta_client, + test_resources_dir=test_resources_dir, ), recipe=new_recipe, ) @@ -148,7 +150,9 @@ def test_okta_source_ingestion_disabled(pytestconfig, mock_datahub_graph, tmp_pa @freeze_time(FROZEN_TIME) @pytest.mark.asyncio def test_okta_source_include_deprovisioned_suspended_users( - pytestconfig, mock_datahub_graph, tmp_path + pytestconfig, + mock_datahub_graph, + tmp_path, ): test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta" @@ -162,7 +166,8 @@ def test_okta_source_include_deprovisioned_suspended_users( run_ingest( mock_datahub_graph=mock_datahub_graph, mocked_functions_reference=partial( - _init_mock_okta_client, test_resources_dir=test_resources_dir + _init_mock_okta_client, + test_resources_dir=test_resources_dir, ), recipe=new_recipe, ) @@ -187,7 +192,8 @@ def test_okta_source_custom_user_name_regex(pytestconfig, mock_datahub_graph, tm run_ingest( mock_datahub_graph=mock_datahub_graph, mocked_functions_reference=partial( - _init_mock_okta_client, test_resources_dir=test_resources_dir + _init_mock_okta_client, + test_resources_dir=test_resources_dir, ), recipe=new_recipe, ) @@ -218,7 +224,8 @@ def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub pipeline1 = run_ingest( mock_datahub_graph=mock_datahub_graph, mocked_functions_reference=partial( 
- _init_mock_okta_client, test_resources_dir=test_resources_dir + _init_mock_okta_client, + test_resources_dir=test_resources_dir, ), recipe=new_recipe, ) @@ -234,7 +241,8 @@ def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub pipeline2 = run_ingest( mock_datahub_graph=mock_datahub_graph, mocked_functions_reference=partial( - overwrite_group_in_mocked_data, test_resources_dir=test_resources_dir + overwrite_group_in_mocked_data, + test_resources_dir=test_resources_dir, ), recipe=new_recipe, ) @@ -244,10 +252,12 @@ def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub # # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline1, expected_providers=1 + pipeline=pipeline1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline2, expected_providers=1 + pipeline=pipeline2, + expected_providers=1, ) # Perform all assertions on the states. The deleted group should not be @@ -256,7 +266,7 @@ def test_okta_stateful_ingestion(pytestconfig, tmp_path, mock_time, mock_datahub state2 = checkpoint2.state difference_group_urns = list( - state1.get_urns_not_in(type="corpGroup", other_checkpoint_state=state2) + state1.get_urns_not_in(type="corpGroup", other_checkpoint_state=state2), ) assert len(difference_group_urns) == 1 @@ -273,7 +283,10 @@ def overwrite_group_in_mocked_data(test_resources_dir, MockClient): # Initializes a Mock Okta Client to return users from okta_users.json and groups from okta_groups.json. def _init_mock_okta_client( - test_resources_dir, MockClient, mock_users_json=None, mock_groups_json=None + test_resources_dir, + MockClient, + mock_users_json=None, + mock_groups_json=None, ): okta_users_json_file = ( test_resources_dir / "okta_users.json" @@ -303,7 +316,7 @@ def _init_mock_okta_client( users_next_future = asyncio.Future() # type: asyncio.Future users_next_future.set_result( # users, err - ([users[-1]], None) + ([users[-1]], None), ) users_resp_mock.next.return_value = users_next_future @@ -311,7 +324,7 @@ def _init_mock_okta_client( list_users_future = asyncio.Future() # type: asyncio.Future list_users_future.set_result( # users, resp, err - (users[0:-1], users_resp_mock, None) + (users[0:-1], users_resp_mock, None), ) MockClient().list_users.return_value = list_users_future @@ -321,7 +334,7 @@ def _init_mock_okta_client( groups_next_future = asyncio.Future() # type: asyncio.Future groups_next_future.set_result( # groups, err - ([groups[-1]], None) + ([groups[-1]], None), ) groups_resp_mock.next.return_value = groups_next_future @@ -339,7 +352,7 @@ def _init_mock_okta_client( group_users_next_future = asyncio.Future() # type: asyncio.Future group_users_next_future.set_result( # users, err - ([users[-1]], None) + ([users[-1]], None), ) group_users_resp_mock.next.return_value = group_users_next_future # users, resp, err @@ -347,7 +360,7 @@ def _init_mock_okta_client( # Exclude last user from being in any groups filtered_users = [user for user in users if user.id != USER_ID_NOT_IN_GROUPS] list_group_users_future.set_result( - (filtered_users, group_users_resp_mock, None) + (filtered_users, group_users_resp_mock, None), ) list_group_users_result_values.append(list_group_users_future) diff --git a/metadata-ingestion/tests/integration/oracle/common.py b/metadata-ingestion/tests/integration/oracle/common.py index 9e2cc42ef10256..ace1b3f8ced3ca 100644 --- a/metadata-ingestion/tests/integration/oracle/common.py +++ 
b/metadata-ingestion/tests/integration/oracle/common.py @@ -57,7 +57,7 @@ def fetchall(self): self.rem_pos, self.search_condition, self.delete_rule, - ) + ), ] @@ -89,7 +89,7 @@ def execute(self): self.generated, self.default_on_nul, self.identity_options, - ] + ], ] @@ -161,7 +161,7 @@ def get_recipe_source(self) -> dict: "config": { **self.get_default_recipe_config().dict(), }, - } + }, } def get_username(self) -> str: @@ -207,7 +207,7 @@ def get_recipe_sink(self, output_path: str) -> dict: "config": { "filename": output_path, }, - } + }, } def get_output_mce_path(self): @@ -232,6 +232,7 @@ def apply(self): self.pytestconfig, output_path=output_path, golden_path="{}/{}".format( - self.get_test_resource_dir(), self.golden_file_name + self.get_test_resource_dir(), + self.golden_file_name, ), ) diff --git a/metadata-ingestion/tests/integration/postgres/test_postgres.py b/metadata-ingestion/tests/integration/postgres/test_postgres.py index 0c7b017949ac38..f1533345338900 100644 --- a/metadata-ingestion/tests/integration/postgres/test_postgres.py +++ b/metadata-ingestion/tests/integration/postgres/test_postgres.py @@ -30,7 +30,8 @@ def is_postgres_up(container_name: str) -> bool: @pytest.fixture(scope="module") def postgres_runner(docker_compose_runner, pytestconfig, test_resources_dir): with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "postgres" + test_resources_dir / "docker-compose.yml", + "postgres", ) as docker_services: wait_for_port( docker_services, @@ -45,7 +46,11 @@ def postgres_runner(docker_compose_runner, pytestconfig, test_resources_dir): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_postgres_ingest_with_db( - postgres_runner, pytestconfig, test_resources_dir, tmp_path, mock_time + postgres_runner, + pytestconfig, + test_resources_dir, + tmp_path, + mock_time, ): # Run the metadata ingestion pipeline. config_file = ( @@ -66,7 +71,11 @@ def test_postgres_ingest_with_db( @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_postgres_ingest_with_all_db( - postgres_runner, pytestconfig, test_resources_dir, tmp_path, mock_time + postgres_runner, + pytestconfig, + test_resources_dir, + tmp_path, + mock_time, ): # Run the metadata ingestion pipeline. 
config_file = ( diff --git a/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py b/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py index 00dc79ed38cfba..75a01d17c50bf4 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py +++ b/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py @@ -18,8 +18,8 @@ def scan_init_response(request, context): w_id_vs_response: Dict[str, Any] = { "64ED5CAD-7C10-4684-8180-826122881108": { - "id": "4674efd1-603c-4129-8d82-03cf2be05aff" - } + "id": "4674efd1-603c-4129-8d82-03cf2be05aff", + }, } return w_id_vs_response[workspace_id] @@ -33,8 +33,8 @@ def admin_datasets_response(request, context): "id": "05169CD2-E713-41E6-9600-1D8066D95445", "name": "library-dataset", "webUrl": "http://localhost/groups/64ED5CAD-7C10-4684-8180-826122881108/datasets/05169CD2-E713-41E6-9600-1D8066D95445", - } - ] + }, + ], } if "ba0130a1-5b03-40de-9535-b34e778ea6ed" in request.query: @@ -44,8 +44,8 @@ def admin_datasets_response(request, context): "id": "ba0130a1-5b03-40de-9535-b34e778ea6ed", "name": "hr_pbi_test", "webUrl": "http://localhost/groups/64ED5CAD-7C10-4684-8180-826122881108/datasets/ba0130a1-5b03-40de-9535-b34e778ea6ed", - } - ] + }, + ], } @@ -67,7 +67,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "name": "demo-workspace", "type": "Workspace", "state": "Active", - } + }, ], }, }, @@ -89,8 +89,8 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "displayName": "test_dashboard", "embedUrl": "https://localhost/dashboards/embed/1", "webUrl": "https://localhost/dashboards/web/1", - } - ] + }, + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/reports/5b218778-e7a5-4d73-8187-f10824047715/users": { @@ -116,7 +116,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "principalType": "User", "reportUserAccessRight": "Owner", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/dashboards/7D668CAD-7FFC-4505-9215-655BCA5BEBAE/users": { @@ -140,7 +140,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "graphId": "C9EE53F2-88EA-4711-A173-AF0515A5REWS", "principalType": "User", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/dashboards/7D668CAD-8FFC-4505-9215-655BCA5BEBAE/users": { @@ -164,7 +164,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "graphId": "C9EE53F2-88EA-4711-A173-AF0515A5REWS", "principalType": "User", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/dashboards/7D668CAD-7FFC-4505-9215-655BCA5BEBAE/tiles": { @@ -184,7 +184,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "embedUrl": "https://localhost/tiles/embed/2", "datasetId": "ba0130a1-5b03-40de-9535-b34e778ea6ed", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/datasets/05169CD2-E713-41E6-9600-1D8066D95445/datasources": { @@ -200,7 +200,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "server": "foo", }, }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/workspaces/scanStatus/4674efd1-603c-4129-8d82-03cf2be05aff": { @@ -238,12 +238,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": "dummy", - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -251,12 +251,12 @@ def 
register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = Snowflake.Databases("hp123rt5.ap-southeast-2.fakecomputing.com","PBI_TEST_WAREHOUSE_PROD",[Role="PBI_TEST_MEMBER"]),\n PBI_TEST_Database = Source{[Name="PBI_TEST",Kind="Database"]}[Data],\n TEST_Schema = PBI_TEST_Database{[Name="TEST",Kind="Schema"]}[Data],\n TESTTABLE_Table = TEST_Schema{[Name="TESTTABLE",Kind="Table"]}[Data]\nin\n TESTTABLE_Table', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -264,12 +264,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = Value.NativeQuery(Snowflake.Databases("bu20658.ap-southeast-2.snowflakecomputing.com","operations_analytics_warehouse_prod",[Role="OPERATIONS_ANALYTICS_MEMBER"]){[Name="OPERATIONS_ANALYTICS"]}[Data], "SELECT#(lf)concat((UPPER(REPLACE(SELLER,\'-\',\'\'))), MONTHID) as AGENT_KEY,#(lf)concat((UPPER(REPLACE(CLIENT_DIRECTOR,\'-\',\'\'))), MONTHID) as CD_AGENT_KEY,#(lf) *#(lf)FROM#(lf)OPERATIONS_ANALYTICS.TRANSFORMED_PROD.V_APS_SME_UNITS_V4", null, [EnableFolding=true]),\n #"Added Conditional Column" = Table.AddColumn(Source, "SME Units ENT", each if [DEAL_TYPE] = "SME Unit" then [UNIT] else 0),\n #"Added Conditional Column1" = Table.AddColumn(#"Added Conditional Column", "Banklink Units", each if [DEAL_TYPE] = "Banklink" then [UNIT] else 0),\n #"Removed Columns" = Table.RemoveColumns(#"Added Conditional Column1",{"Banklink Units"}),\n #"Added Custom" = Table.AddColumn(#"Removed Columns", "Banklink Units", each if [DEAL_TYPE] = "Banklink" and [SALES_TYPE] = "3 - Upsell"\nthen [UNIT]\n\nelse if [SALES_TYPE] = "Adjusted BL Migration"\nthen [UNIT]\n\nelse 0),\n #"Added Custom1" = Table.AddColumn(#"Added Custom", "SME Units in $ (*$361)", each if [DEAL_TYPE] = "SME Unit" \nand [SALES_TYPE] <> "4 - Renewal"\n then [UNIT] * 361\nelse 0),\n #"Added Custom2" = Table.AddColumn(#"Added Custom1", "Banklink in $ (*$148)", each [Banklink Units] * 148)\nin\n #"Added Custom2"', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -277,12 +277,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = Value.NativeQuery(Snowflake.Databases("xaa48144.snowflakecomputing.com","GSL_TEST_WH",[Role="ACCOUNTADMIN"]){[Name="GSL_TEST_DB"]}[Data], "select A.name from GSL_TEST_DB.PUBLIC.SALES_ANALYST as A inner join GSL_TEST_DB.PUBLIC.SALES_FORECAST as B on A.name = B.name where startswith(A.name, \'mo\')", null, [EnableFolding=true])\nin\n Source', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -290,12 +290,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = Oracle.Database("localhost:1521/salesdb.domain.com", [HierarchicalNavigation=true]), HR = Source{[Schema="HR"]}[Data], EMPLOYEES1 = HR{[Name="EMPLOYEES"]}[Data] \n in EMPLOYEES1', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -303,12 +303,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = PostgreSQL.Database("localhost" , "mics" ),\n public_order_date = Source{[Schema="public",Item="order_date"]}[Data] \n in 
\n public_order_date', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, ], @@ -324,12 +324,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = Sql.Database("localhost", "library"),\n dbo_book_issue = Source{[Schema="dbo",Item="book_issue"]}[Data]\n in dbo_book_issue', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -337,12 +337,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = Sql.Database("AUPRDWHDB", "COMMOPSDB", [Query="select *,#(lf)concat((UPPER(REPLACE(CLIENT_DIRECTOR,\'-\',\'\'))), MONTH_WID) as CD_AGENT_KEY,#(lf)concat((UPPER(REPLACE(CLIENT_MANAGER_CLOSING_MONTH,\'-\',\'\'))), MONTH_WID) as AGENT_KEY#(lf)#(lf)from V_PS_CD_RETENTION", CommandTimeout=#duration(0, 1, 30, 0)]),\n #"Changed Type" = Table.TransformColumnTypes(Source,{{"mth_date", type date}}),\n #"Added Custom" = Table.AddColumn(#"Changed Type", "Month", each Date.Month([mth_date])),\n #"Added Custom1" = Table.AddColumn(#"Added Custom", "TPV Opening", each if [Month] = 1 then [TPV_AMV_OPENING]\nelse if [Month] = 2 then 0\nelse if [Month] = 3 then 0\nelse if [Month] = 4 then [TPV_AMV_OPENING]\nelse if [Month] = 5 then 0\nelse if [Month] = 6 then 0\nelse if [Month] = 7 then [TPV_AMV_OPENING]\nelse if [Month] = 8 then 0\nelse if [Month] = 9 then 0\nelse if [Month] = 10 then [TPV_AMV_OPENING]\nelse if [Month] = 11 then 0\nelse if [Month] = 12 then 0\n\nelse 0)\nin\n #"Added Custom1"', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], }, { @@ -375,13 +375,13 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "description": "column description", "expression": "let\n x", "isHidden": False, - } + }, ], "isHidden": False, "source": [ { - "expression": 'let\n Source = Sql.Database("database.sql.net", "analytics", [Query="select * from analytics.sales_revenue"])\nin\n Source' - } + "expression": 'let\n Source = Sql.Database("database.sql.net", "analytics", [Query="select * from analytics.sales_revenue"])\nin\n Source', + }, ], }, ], @@ -391,7 +391,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None { "id": "7D668CAD-7FFC-4505-9215-655BCA5BEBAE", "isReadOnly": True, - } + }, ], "reports": [ { @@ -400,10 +400,10 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "reportType": "PowerBIReport", "name": "SalesMarketing", "description": "Acryl sales marketing report", - } + }, ], }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/workspaces/modified": { @@ -433,8 +433,8 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "description": "Acryl sales marketing report", "webUrl": "https://app.powerbi.com/groups/f089354e-8366-4e18-aea3-4cb4a3a50b48/reports/5b218778-e7a5-4d73-8187-f10824047715", "embedUrl": "https://app.powerbi.com/reportEmbed?reportId=5b218778-e7a5-4d73-8187-f10824047715&groupId=f089354e-8366-4e18-aea3-4cb4a3a50b48", - } - ] + }, + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/reports/5b218778-e7a5-4d73-8187-f10824047715": { @@ -465,7 +465,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "name": "ReportSection1", "order": "1", }, - ] + ], }, }, } @@ 
-538,7 +538,7 @@ def test_admin_only_apis(mock_msal, pytestconfig, tmp_path, mock_time, requests_ "filename": f"{tmp_path}/powerbi_admin_only_mces.json", }, }, - } + }, ) pipeline.run() @@ -555,7 +555,11 @@ def test_admin_only_apis(mock_msal, pytestconfig, tmp_path, mock_time, requests_ @freeze_time(FROZEN_TIME) @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_most_config_and_modified_since( - mock_msal, pytestconfig, tmp_path, mock_time, requests_mock + mock_msal, + pytestconfig, + tmp_path, + mock_time, + requests_mock, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" @@ -603,7 +607,7 @@ def test_most_config_and_modified_since( "filename": f"{tmp_path}/powerbi_test_most_config_and_modified_since_mces.json", }, }, - } + }, ) pipeline.run() diff --git a/metadata-ingestion/tests/integration/powerbi/test_m_parser.py b/metadata-ingestion/tests/integration/powerbi/test_m_parser.py index 0d85d370265cae..4c40e3e9c7423a 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_m_parser.py +++ b/metadata-ingestion/tests/integration/powerbi/test_m_parser.py @@ -95,7 +95,9 @@ def get_data_platform_tables_with_dummy_table( def get_default_instances( override_config: dict = {}, ) -> Tuple[ - PipelineContext, PowerBiDashboardSourceConfig, AbstractDataPlatformInstanceResolver + PipelineContext, + PowerBiDashboardSourceConfig, + AbstractDataPlatformInstanceResolver, ]: config: PowerBiDashboardSourceConfig = PowerBiDashboardSourceConfig.parse_obj( { @@ -104,7 +106,7 @@ def get_default_instances( "client_secret": "bar", "enable_advance_lineage_sql_construct": False, **override_config, - } + }, ) platform_instance_resolver: AbstractDataPlatformInstanceResolver = ( @@ -608,7 +610,7 @@ def test_multi_source_table(): ctx=ctx, config=config, platform_instance_resolver=platform_instance_resolver, - ) + ), ) assert len(data_platform_tables) == 2 @@ -643,7 +645,7 @@ def test_table_combine(): ctx=ctx, config=config, platform_instance_resolver=platform_instance_resolver, - ) + ), ) assert len(data_platform_tables) == 2 @@ -758,11 +760,11 @@ def test_sqlglot_parser(): "bu10758.ap-unknown-2.fakecomputing.com": { "platform_instance": "sales_deployment", "env": "PROD", - } + }, }, "native_query_parsing": True, "enable_advance_lineage_sql_construct": True, - } + }, ) lineage: List[datahub.ingestion.source.powerbi.m_query.data_classes.Lineage] = ( @@ -863,9 +865,9 @@ def test_databricks_catalog_pattern_2(): "abc.cloud.databricks.com": { "metastore": "central_metastore", "platform_instance": "abc", - } - } - } + }, + }, + }, ) data_platform_tables: List[DataPlatformTable] = parser.get_upstream_tables( table, @@ -897,11 +899,11 @@ def test_sqlglot_parser_2(): "0DD93C6BD5A6.snowflakecomputing.com": { "platform_instance": "sales_deployment", "env": "PROD", - } + }, }, "native_query_parsing": True, "enable_advance_lineage_sql_construct": True, - } + }, ) lineage: List[datahub.ingestion.source.powerbi.m_query.data_classes.Lineage] = ( @@ -1066,7 +1068,8 @@ def test_unsupported_data_platform(): ) info_entries: dict = reporter._structured_logs._entries.get( - StructuredLogLevel.INFO, {} + StructuredLogLevel.INFO, + {}, ) # type :ignore is_entry_present: bool = False @@ -1159,7 +1162,8 @@ def test_m_query_timeout(mock_get_lark_parser): ) warn_entries: dict = reporter._structured_logs._entries.get( - StructuredLogLevel.WARN, {} + StructuredLogLevel.WARN, + {}, ) # type :ignore is_entry_present: bool = False diff --git 
a/metadata-ingestion/tests/integration/powerbi/test_powerbi.py b/metadata-ingestion/tests/integration/powerbi/test_powerbi.py index 7f62e433bc8014..07f6b68a112f96 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_powerbi.py +++ b/metadata-ingestion/tests/integration/powerbi/test_powerbi.py @@ -56,16 +56,16 @@ def scan_init_response(request, context): w_id_vs_response: Dict[str, Any] = { "64ED5CAD-7C10-4684-8180-826122881108": { - "id": "4674efd1-603c-4129-8d82-03cf2be05aff" + "id": "4674efd1-603c-4129-8d82-03cf2be05aff", }, "64ED5CAD-7C22-4684-8180-826122881108": { - "id": "a674efd1-603c-4129-8d82-03cf2be05aff" + "id": "a674efd1-603c-4129-8d82-03cf2be05aff", }, "64ED5CAD-7C10-4684-8180-826122881108||64ED5CAD-7C22-4684-8180-826122881108": { - "id": "a674efd1-603c-4129-8d82-03cf2be05aff" + "id": "a674efd1-603c-4129-8d82-03cf2be05aff", }, "A8D655A6-F521-477E-8C22-255018583BF4": { - "id": "62DAF926-0B18-4FF1-982C-2A3EB6B8F0E4" + "id": "62DAF926-0B18-4FF1-982C-2A3EB6B8F0E4", }, "C5DA6EA8-625E-4AB1-90B6-CAEA0BF9F492": { "id": "81B02907-E2A3-45C3-B505-3781839C8CAA", @@ -84,7 +84,9 @@ def read_mock_data(path: Union[Path, str]) -> dict: def register_mock_api( - pytestconfig: pytest.Config, request_mock: Any, override_data: Optional[dict] = None + pytestconfig: pytest.Config, + request_mock: Any, + override_data: Optional[dict] = None, ) -> None: default_mock_data_path = ( pytestconfig.rootpath @@ -163,7 +165,7 @@ def test_powerbi_ingest( "filename": f"{tmp_path}/powerbi_mces.json", }, }, - } + }, ) pipeline.run() @@ -194,7 +196,7 @@ def test_powerbi_workspace_type_filter( pytestconfig=pytestconfig, override_data=read_mock_data( pytestconfig.rootpath - / "tests/integration/powerbi/mock_data/workspace_type_filter.json" + / "tests/integration/powerbi/mock_data/workspace_type_filter.json", ), ) @@ -222,7 +224,7 @@ def test_powerbi_workspace_type_filter( "filename": f"{tmp_path}/powerbi_mces.json", }, }, - } + }, ) pipeline.run() @@ -266,7 +268,7 @@ def test_powerbi_ingest_patch_disabled( "filename": f"{tmp_path}/powerbi_mces.json", }, }, - } + }, ) pipeline.run() @@ -285,7 +287,8 @@ def test_powerbi_ingest_patch_disabled( @pytest.mark.integration def test_powerbi_test_connection_success(mock_msal): report = test_connection_helpers.run_test_connection( - PowerBiDashboardSource, default_source_config() + PowerBiDashboardSource, + default_source_config(), ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -294,10 +297,12 @@ def test_powerbi_test_connection_success(mock_msal): @pytest.mark.integration def test_powerbi_test_connection_failure(): report = test_connection_helpers.run_test_connection( - PowerBiDashboardSource, default_source_config() + PowerBiDashboardSource, + default_source_config(), ) test_connection_helpers.assert_basic_connectivity_failure( - report, "Unable to get authority configuration" + report, + "Unable to get authority configuration", ) @@ -333,7 +338,7 @@ def test_powerbi_platform_instance_ingest( "filename": output_path, }, }, - } + }, ) pipeline.run() @@ -380,7 +385,7 @@ def test_powerbi_ingest_urn_lower_case( "filename": f"{tmp_path}/powerbi_lower_case_urn_mces.json", }, }, - } + }, ) pipeline.run() @@ -424,7 +429,7 @@ def test_override_ownership( "filename": f"{tmp_path}/powerbi_mces_disabled_ownership.json", }, }, - } + }, ) pipeline.run() @@ -472,7 +477,7 @@ def test_scan_all_workspaces( "filename": f"{tmp_path}/powerbi_mces_scan_all_workspaces.json", }, }, - } + }, ) pipeline.run() @@ -517,7 +522,7 @@ def test_extract_reports( "filename": 
f"{tmp_path}/powerbi_report_mces.json", }, }, - } + }, ) pipeline.run() @@ -556,7 +561,7 @@ def test_extract_lineage( "dataset_type_mapping": { "PostgreSql": {"platform_instance": "operational_instance"}, "Oracle": { - "platform_instance": "high_performance_production_unit" + "platform_instance": "high_performance_production_unit", }, "Sql": {"platform_instance": "reporting-db"}, "Snowflake": {"platform_instance": "sn-2"}, @@ -569,7 +574,7 @@ def test_extract_lineage( "filename": f"{tmp_path}/powerbi_lineage_mces.json", }, }, - } + }, ) pipeline.run() @@ -614,7 +619,7 @@ def test_extract_endorsements( "filename": f"{tmp_path}/powerbi_endorsement_mces.json", }, }, - } + }, ) pipeline.run() @@ -663,7 +668,7 @@ def test_admin_access_is_not_allowed( "dataset_type_mapping": { "PostgreSql": {"platform_instance": "operational_instance"}, "Oracle": { - "platform_instance": "high_performance_production_unit" + "platform_instance": "high_performance_production_unit", }, "Sql": {"platform_instance": "reporting-db"}, "Snowflake": {"platform_instance": "sn-2"}, @@ -676,7 +681,7 @@ def test_admin_access_is_not_allowed( "filename": f"{tmp_path}/golden_test_admin_access_not_allowed_mces.json", }, }, - } + }, ) pipeline.run() @@ -724,7 +729,7 @@ def test_workspace_container( "filename": f"{tmp_path}/powerbi_container_mces.json", }, }, - } + }, ) pipeline.run() @@ -764,7 +769,7 @@ def test_access_token_expiry_with_long_expiry( "filename": f"{tmp_path}/powerbi_access_token_mces.json", }, }, - } + }, ) # for long expiry, the token should only be requested once. @@ -804,7 +809,7 @@ def test_access_token_expiry_with_short_expiry( "filename": f"{tmp_path}/powerbi_access_token_mces.json", }, }, - } + }, ) # for short expiry, the token should be requested when expires. @@ -820,7 +825,8 @@ def test_access_token_expiry_with_short_expiry( def dataset_type_mapping_set_to_all_platform(pipeline: Pipeline) -> None: source_config: PowerBiDashboardSourceConfig = cast( - PowerBiDashboardSource, pipeline.source + PowerBiDashboardSource, + pipeline.source, ).source_config assert source_config.dataset_type_mapping is not None @@ -839,7 +845,11 @@ def dataset_type_mapping_set_to_all_platform(pipeline: Pipeline) -> None: @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) @pytest.mark.integration def test_dataset_type_mapping_should_set_to_all( - mock_msal, pytestconfig, tmp_path, mock_time, requests_mock + mock_msal, + pytestconfig, + tmp_path, + mock_time, + requests_mock, ): """ Here we don't need to run the pipeline. We need to verify dataset_type_mapping is set to default dataplatform @@ -865,7 +875,7 @@ def test_dataset_type_mapping_should_set_to_all( "filename": f"{tmp_path}/powerbi_lower_case_urn_mces.json", }, }, - } + }, ) dataset_type_mapping_set_to_all_platform(pipeline) @@ -875,7 +885,11 @@ def test_dataset_type_mapping_should_set_to_all( @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) @pytest.mark.integration def test_dataset_type_mapping_error( - mock_msal, pytestconfig, tmp_path, mock_time, requests_mock + mock_msal, + pytestconfig, + tmp_path, + mock_time, + requests_mock, ): """ Here we don't need to run the pipeline. 
We need to verify if both dataset_type_mapping and server_to_platform_instance @@ -894,7 +908,7 @@ def test_dataset_type_mapping_error( "server_to_platform_instance": { "localhost": { "platform_instance": "test", - } + }, }, }, }, @@ -904,14 +918,18 @@ def test_dataset_type_mapping_error( "filename": f"{tmp_path}/powerbi_lower_case_urn_mces.json", }, }, - } + }, ) @freeze_time(FROZEN_TIME) @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_server_to_platform_map( - mock_msal, pytestconfig, tmp_path, mock_time, requests_mock + mock_msal, + pytestconfig, + tmp_path, + mock_time, + requests_mock, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" new_config: dict = { @@ -951,7 +969,7 @@ def test_server_to_platform_map( "filename": output_path, }, }, - } + }, ) pipeline.run() @@ -988,7 +1006,8 @@ def validate_pipeline(pipeline: Pipeline) -> None: ) # Fetch actual reports reports: List[Report] = cast( - PowerBiDashboardSource, pipeline.source + PowerBiDashboardSource, + pipeline.source, ).powerbi_client.get_reports(workspace=mock_workspace) assert len(reports) == 2 @@ -1031,7 +1050,8 @@ def validate_pipeline(pipeline: Pipeline) -> None: pages=[ Page( id="{}.{}".format( - report[Constant.ID], page[Constant.NAME].replace(" ", "_") + report[Constant.ID], + page[Constant.NAME].replace(" ", "_"), ), name=page[Constant.NAME], displayName=page[Constant.DISPLAY_NAME], @@ -1095,7 +1115,7 @@ def test_reports_with_failed_page_request( "webUrl": "https://app.powerbi.com/groups/64ED5CAD-7C10-4684-8180-826122881108/reports/e9fd6b0b-d8c8-4265-8c44-67e183aebf97", "embedUrl": "https://app.powerbi.com/reportEmbed?reportId=e9fd6b0b-d8c8-4265-8c44-67e183aebf97&groupId=64ED5CAD-7C10-4684-8180-826122881108", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/reports/5b218778-e7a5-4d73-8187-f10824047715": { @@ -1139,7 +1159,7 @@ def test_reports_with_failed_page_request( "name": "ReportSection1", "order": "1", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/reports/e9fd6b0b-d8c8-4265-8c44-67e183aebf97/pages": { @@ -1149,7 +1169,7 @@ def test_reports_with_failed_page_request( "error": { "code": "InvalidRequest", "message": "Request is currently not supported for RDL reports", - } + }, }, }, }, @@ -1172,7 +1192,7 @@ def test_reports_with_failed_page_request( "filename": f"{tmp_path}powerbi_reports_with_failed_page_request_mces.json", }, }, - } + }, ) validate_pipeline(pipeline) @@ -1233,14 +1253,14 @@ def test_independent_datasets_extraction( "source": [ { "expression": "dummy", - } + }, ], - } + }, ], }, ], }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/dashboards": { @@ -1267,7 +1287,7 @@ def test_independent_datasets_extraction( "filename": f"{tmp_path}/powerbi_independent_mces.json", }, }, - } + }, ) pipeline.run() @@ -1323,7 +1343,7 @@ def test_cll_extraction( "filename": f"{tmp_path}/powerbi_cll_mces.json", }, }, - } + }, ) pipeline.run() @@ -1353,7 +1373,7 @@ def test_cll_extraction_flags( default_conf: dict = default_source_config() pattern: str = re.escape( - "Enable all these flags in recipe: ['native_query_parsing', 'enable_advance_lineage_sql_construct', 'extract_lineage']" + "Enable all these flags in recipe: ['native_query_parsing', 'enable_advance_lineage_sql_construct', 'extract_lineage']", ) with pytest.raises(Exception, match=pattern): @@ -1373,7 +1393,7 @@ def test_cll_extraction_flags( 
"filename": f"{tmp_path}/powerbi_cll_mces.json", }, }, - } + }, ) @@ -1392,7 +1412,7 @@ def test_powerbi_cross_workspace_reference_info_message( request_mock=requests_mock, override_data=read_mock_data( path=pytestconfig.rootpath - / "tests/integration/powerbi/mock_data/cross_workspace_mock_response.json" + / "tests/integration/powerbi/mock_data/cross_workspace_mock_response.json", ), ) @@ -1404,7 +1424,7 @@ def test_powerbi_cross_workspace_reference_info_message( "allow": [ "A8D655A6-F521-477E-8C22-255018583BF4", "C5DA6EA8-625E-4AB1-90B6-CAEA0BF9F492", - ] + ], } config["include_workspace_name_in_dataset_urn"] = True @@ -1424,7 +1444,7 @@ def test_powerbi_cross_workspace_reference_info_message( "filename": f"{tmp_path}/powerbi_mces.json", }, }, - } + }, ) pipeline.run() @@ -1433,7 +1453,8 @@ def test_powerbi_cross_workspace_reference_info_message( assert isinstance(pipeline.source, PowerBiDashboardSource) # to silent the lint info_entries: dict = pipeline.source.reporter._structured_logs._entries.get( - StructuredLogLevel.INFO, {} + StructuredLogLevel.INFO, + {}, ) # type :ignore is_entry_present: bool = False @@ -1469,7 +1490,7 @@ def common_app_ingest( request_mock=requests_mock, override_data=read_mock_data( path=pytestconfig.rootpath - / "tests/integration/powerbi/mock_data/workspace_with_app_mock_response.json" + / "tests/integration/powerbi/mock_data/workspace_with_app_mock_response.json", ), ) @@ -1480,7 +1501,7 @@ def common_app_ingest( config["workspace_id_pattern"] = { "allow": [ "8F756DE6-26AD-45FF-A201-44276FF1F561", - ] + ], } config.update(override_config) @@ -1500,7 +1521,7 @@ def common_app_ingest( "filename": output_mcp_path, }, }, - } + }, ) pipeline.run() @@ -1558,7 +1579,8 @@ def test_powerbi_app_ingest_info_message( assert isinstance(pipeline.source, PowerBiDashboardSource) # to silent the lint info_entries: dict = pipeline.source.reporter._structured_logs._entries.get( - StructuredLogLevel.INFO, {} + StructuredLogLevel.INFO, + {}, ) # type :ignore is_entry_present: bool = False diff --git a/metadata-ingestion/tests/integration/powerbi/test_profiling.py b/metadata-ingestion/tests/integration/powerbi/test_profiling.py index 78d35cf31a26d9..f32b33ad6792c4 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_profiling.py +++ b/metadata-ingestion/tests/integration/powerbi/test_profiling.py @@ -15,8 +15,8 @@ def scan_init_response(request, context): w_id_vs_response: Dict[str, Any] = { "64ED5CAD-7C10-4684-8180-826122881108": { - "id": "4674efd1-603c-4129-8d82-03cf2be05aff" - } + "id": "4674efd1-603c-4129-8d82-03cf2be05aff", + }, } return w_id_vs_response[workspace_id] @@ -29,8 +29,8 @@ def admin_datasets_response(request, context): "id": "05169CD2-E713-41E6-9600-1D8066D95445", "name": "library-dataset", "webUrl": "http://localhost/groups/64ED5CAD-7C10-4684-8180-826122881108/datasets/05169CD2-E713-41E6-9600-1D8066D95445", - } - ] + }, + ], } @@ -48,10 +48,10 @@ def execute_queries_response(request, context): "[max]": 34333, "[unique_count]": 15, }, - ] - } - ] - } + ], + }, + ], + }, ], } elif "COUNTROWS" in query: @@ -64,10 +64,10 @@ def execute_queries_response(request, context): { "[count]": 542300, }, - ] - } - ] - } + ], + }, + ], + }, ], } elif "TOPN" in query: @@ -95,10 +95,10 @@ def execute_queries_response(request, context): "[topic]": "normal matters", "[view_count]": 123455, }, - ] - } - ] - } + ], + }, + ], + }, ], } @@ -121,7 +121,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "name": "demo-workspace", "type": 
"Workspace", "state": "Active", - } + }, ], }, }, @@ -150,7 +150,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "server": "foo", }, }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/datasets/05169CD2-E713-41E6-9600-1D8066D95445/executeQueries": { @@ -193,12 +193,12 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "source": [ { "expression": 'let\n Source = PostgreSQL.Database("localhost" , "mics" ),\n public_order_date = Source{[Schema="public",Item="order_date"]}[Data] \n in \n public_order_date', - } + }, ], "datasourceUsages": [ { "datasourceInstanceId": "DCE90B40-84D6-467A-9A5C-648E830E72D3", - } + }, ], "columns": [ { @@ -229,7 +229,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "description": "column description", "expression": "let\n x", "isHidden": False, - } + }, ], }, ], @@ -238,7 +238,7 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None "dashboards": [], "reports": [], }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/admin/workspaces/getInfo": { @@ -322,7 +322,7 @@ def test_profiling(mock_msal, pytestconfig, tmp_path, mock_time, requests_mock): "filename": f"{tmp_path}/powerbi_profiling.json", }, }, - } + }, ) pipeline.run() diff --git a/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py b/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py index 84f7a87ce5d2d0..275d1c46ddd9d7 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py +++ b/metadata-ingestion/tests/integration/powerbi/test_stateful_ingestion.py @@ -71,7 +71,7 @@ def register_mock_api_state1(request_mock): "embedUrl": "https://localhost/dashboards/embed/1", "webUrl": "https://localhost/dashboards/web/1", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/44444444-7C10-4684-8180-826122881108/dashboards": { @@ -86,7 +86,7 @@ def register_mock_api_state1(request_mock): "embedUrl": "https://localhost/dashboards/embed/multi_workspace", "webUrl": "https://localhost/dashboards/web/multi_workspace", }, - ] + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/64ED5CAD-7C10-4684-8180-826122881108/dashboards/7D668CAD-7FFC-4505-9215-655BCA5BEBAE/tiles": { @@ -161,8 +161,8 @@ def register_mock_api_state2(request_mock): "displayName": "marketing", "embedUrl": "https://localhost/dashboards/embed/1", "webUrl": "https://localhost/dashboards/web/1", - } - ] + }, + ], }, }, "https://api.powerbi.com/v1.0/myorg/groups/44444444-7C10-4684-8180-826122881108/dashboards": { @@ -206,7 +206,7 @@ def default_source_config(): "allow": [ "64ED5CAD-7C10-4684-8180-826122881108", "44444444-7C10-4684-8180-826122881108", - ] + ], }, "dataset_type_mapping": { "PostgreSql": "postgres", @@ -234,7 +234,7 @@ def get_current_checkpoint_from_pipeline( for job_id in powerbi_source.state_provider._usecase_handlers.keys(): # for multi-workspace checkpoint, every good checkpoint will have an unique workspaceid suffix checkpoints[job_id] = powerbi_source.state_provider.get_current_checkpoint( - job_id + job_id, ) return checkpoints @@ -271,7 +271,12 @@ def ingest(pipeline_name, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_powerbi_stateful_ingestion( - mock_msal, pytestconfig, tmp_path, mock_time, requests_mock, mock_datahub_graph + mock_msal, + pytestconfig, + tmp_path, + mock_time, + 
requests_mock, + mock_datahub_graph, ): register_mock_api_state1(request_mock=requests_mock) pipeline1 = ingest("run1", tmp_path, mock_datahub_graph) @@ -289,17 +294,20 @@ def test_powerbi_stateful_ingestion( # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline1, expected_providers=1 + pipeline=pipeline1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline2, expected_providers=1 + pipeline=pipeline2, + expected_providers=1, ) # Perform all assertions on the states. The deleted Dashboard should not be # part of the second state for job_id in checkpoint1.keys(): if isinstance(checkpoint1[job_id], Checkpoint) and isinstance( - checkpoint2[job_id], Checkpoint + checkpoint2[job_id], + Checkpoint, ): state1 = checkpoint1[job_id].state # type:ignore state2 = checkpoint2[job_id].state # type:ignore @@ -309,22 +317,22 @@ def test_powerbi_stateful_ingestion( == "powerbi_stale_entity_removal_64ED5CAD-7C10-4684-8180-826122881108" ): difference_dashboard_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) assert len(difference_dashboard_urns) == 1 assert difference_dashboard_urns == [ - "urn:li:dashboard:(powerbi,dashboards.e41cbfe7-9f54-40ad-8d6a-043ab97cf303)" + "urn:li:dashboard:(powerbi,dashboards.e41cbfe7-9f54-40ad-8d6a-043ab97cf303)", ] elif ( job_id == "powerbi_stale_entity_removal_44444444-7C10-4684-8180-826122881108" ): difference_dashboard_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) assert len(difference_dashboard_urns) == 1 assert difference_dashboard_urns == [ - "urn:li:dashboard:(powerbi,dashboards.7D668CAD-4444-4505-9215-655BCA5BEBAE)" + "urn:li:dashboard:(powerbi,dashboards.7D668CAD-4444-4505-9215-655BCA5BEBAE)", ] diff --git a/metadata-ingestion/tests/integration/powerbi_report_server/test_powerbi_report_server.py b/metadata-ingestion/tests/integration/powerbi_report_server/test_powerbi_report_server.py index 826c2b77bce362..a91dd2a0e8be14 100644 --- a/metadata-ingestion/tests/integration/powerbi_report_server/test_powerbi_report_server.py +++ b/metadata-ingestion/tests/integration/powerbi_report_server/test_powerbi_report_server.py @@ -49,7 +49,7 @@ def register_mock_api(request_mock, override_mock_data={}): "HasSharedDataSets": True, "HasParameters": True, }, - ] + ], }, }, "https://host_port/Reports/api/v2.0/LinkedReports": { @@ -78,7 +78,7 @@ def register_mock_api(request_mock, override_mock_data={}): "HasParameters": True, "Link": "sjfgnk-7134-1234-abcd-ee5axvcv938b", }, - ] + ], }, }, "https://host_port/Reports/api/v2.0/PowerBIReports": { @@ -105,7 +105,7 @@ def register_mock_api(request_mock, override_mock_data={}): "HasDataSources": True, "Roles": [], }, - ] + ], }, }, } @@ -169,7 +169,7 @@ def test_powerbi_ingest(mock_msal, pytestconfig, tmp_path, mock_time, requests_m register_mock_api(request_mock=requests_mock) pipeline = Pipeline.create( - get_default_recipe(output_path=f"{tmp_path}/powerbi_report_server_mces.json") + get_default_recipe(output_path=f"{tmp_path}/powerbi_report_server_mces.json"), ) add_mock_method_in_pipeline(pipeline=pipeline) @@ -188,7 +188,11 @@ def test_powerbi_ingest(mock_msal, pytestconfig, tmp_path, mock_time, requests_m @freeze_time(FROZEN_TIME) @mock.patch("requests_ntlm.HttpNtlmAuth") def 
test_powerbi_ingest_with_failure( - mock_msal, pytestconfig, tmp_path, mock_time, requests_mock + mock_msal, + pytestconfig, + tmp_path, + mock_time, + requests_mock, ): test_resources_dir = ( pytestconfig.rootpath / "tests/integration/powerbi_report_server" @@ -201,12 +205,12 @@ def test_powerbi_ingest_with_failure( "method": "GET", "status_code": 404, "json": {"error": "Request Failed"}, - } + }, }, ) pipeline = Pipeline.create( - get_default_recipe(output_path=f"{tmp_path}/powerbi_report_server_mces.json") + get_default_recipe(output_path=f"{tmp_path}/powerbi_report_server_mces.json"), ) add_mock_method_in_pipeline(pipeline=pipeline) diff --git a/metadata-ingestion/tests/integration/preset/test_preset.py b/metadata-ingestion/tests/integration/preset/test_preset.py index f926a762e6a078..fa02a6bfb4823f 100644 --- a/metadata-ingestion/tests/integration/preset/test_preset.py +++ b/metadata-ingestion/tests/integration/preset/test_preset.py @@ -28,7 +28,7 @@ def register_mock_api(request_mock: Any, override_data: Optional[dict] = None) - "json": { "payload": { "access_token": "test_token", - } + }, }, }, "mock://mock-domain.preset.io/version": { @@ -220,7 +220,7 @@ def test_preset_ingest(pytestconfig, tmp_path, mock_time, requests_mock): "filename": f"{tmp_path}/preset_mces.json", }, }, - } + }, ) pipeline.run() @@ -237,7 +237,11 @@ def test_preset_ingest(pytestconfig, tmp_path, mock_time, requests_mock): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_preset_stateful_ingest( - pytestconfig, tmp_path, mock_time, requests_mock, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + requests_mock, + mock_datahub_graph, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/preset" @@ -266,7 +270,7 @@ def test_preset_stateful_ingest( }, "sink": { # we are not really interested in the resulting events for this test - "type": "console" + "type": "console", }, "pipeline_name": "test_pipeline", } @@ -321,7 +325,8 @@ def test_preset_stateful_ingest( # Remove one dashboard from the preset config. register_mock_api( - request_mock=requests_mock, override_data=dashboard_endpoint_override + request_mock=requests_mock, + override_data=dashboard_endpoint_override, ) # Capture MCEs of second run to validate Status(removed=true) @@ -341,7 +346,7 @@ def test_preset_stateful_ingest( state1 = checkpoint1.state state2 = checkpoint2.state difference_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) assert len(difference_urns) == 1 @@ -352,10 +357,12 @@ def test_preset_stateful_ingest( # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Verify the output. 
diff --git a/metadata-ingestion/tests/integration/qlik_sense/test_qlik_sense.py b/metadata-ingestion/tests/integration/qlik_sense/test_qlik_sense.py index 95f096cc3def35..a7af9c3cf28532 100644 --- a/metadata-ingestion/tests/integration/qlik_sense/test_qlik_sense.py +++ b/metadata-ingestion/tests/integration/qlik_sense/test_qlik_sense.py @@ -39,21 +39,21 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: }, "links": { "self": { - "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/spaces/659d0e41d1b0ecce6eebc9b1" + "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/spaces/659d0e41d1b0ecce6eebc9b1", }, "assignments": { - "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/spaces/659d0e41d1b0ecce6eebc9b1/assignments" + "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/spaces/659d0e41d1b0ecce6eebc9b1/assignments", }, }, "createdAt": "2024-01-09T09:13:38.002Z", "createdBy": "657b5abe656297cec3d8b205", "updatedAt": "2024-01-09T09:13:38.002Z", - } + }, ], "links": { "self": { - "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/spaces" - } + "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/spaces", + }, }, }, }, @@ -102,7 +102,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "meta": { "isFavorited": False, "tags": [ - {"id": "659ce561640a2affcf0d629f", "name": "test_tag"} + {"id": "659ce561640a2affcf0d629f", "name": "test_tag"}, ], "collections": [], }, @@ -191,8 +191,8 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: ], "links": { "self": { - "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/items" - } + "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/items", + }, }, }, }, @@ -223,8 +223,8 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "lastUpdated": "2024-01-25T06:28:22.629Z", "links": { "self": { - "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/users/657b5abe656297cec3d8b205" - } + "href": "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/users/657b5abe656297cec3d8b205", + }, }, }, }, @@ -338,7 +338,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "qHeaderSize": 0, "qRecordSize": 0, "qTabSize": 0, - } + }, }, "effectiveDate": "2024-01-09T18:05:39.713Z", "overrideSchemaAnomalies": False, @@ -405,7 +405,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "sensitive": False, "orphan": False, "nullable": False, - } + }, ], "loadOptions": {}, "effectiveDate": "2024-01-12T12:59:54.522Z", @@ -448,7 +448,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: }, ], "metadata": {}, - } + }, }, }, "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/lineage-graphs/nodes/qri%3Aapp%3Asense%3A%2F%2Ff0714ca7-7093-49e4-8b58-47bb38563647/actions/expand?node=qri%3Aapp%3Asense%3A%2F%2Ff0714ca7-7093-49e4-8b58-47bb38563647%23FcJ-H2TvmAyI--l6fn0VQGPtHf8kB2rj7sj0_ysRHgc&level=FIELD": { @@ -483,7 +483,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: }, ], "metadata": {}, - } + }, }, }, "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/lineage-graphs/nodes/qri%3Aapp%3Asense%3A%2F%2Ff0714ca7-7093-49e4-8b58-47bb38563647/actions/expand?node=qri%3Aapp%3Asense%3A%2F%2Ff0714ca7-7093-49e4-8b58-47bb38563647%23Rrg6-1CeRbo4ews9o--QUP3tOXhm5moLizGY6_wCxJE&level=FIELD": { @@ -499,17 +499,17 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: 
"qri:app:sense://f0714ca7-7093-49e4-8b58-47bb38563647#Rrg6-1CeRbo4ews9o--QUP3tOXhm5moLizGY6_wCxJE#gqNTf_Dbzn7sNdae3DoYnubxfYLzU6VT-aqWywvjzok": { "label": "name", "metadata": {"type": "FIELD"}, - } + }, }, "edges": [ { "relation": "read", "source": "qri:qdf:space://Ebw1EudUywmUi8p2bM7COr5OATHzuYxvT0BIrCc2irU#JOKG8u7CvizvGXwrFsyXRU0yKr2rL2WFD5djpH9bj5Q#Rrg6-1CeRbo4ews9o--QUP3tOXhm5moLizGY6_wCxJE#gqNTf_Dbzn7sNdae3DoYnubxfYLzU6VT-aqWywvjzok", "target": "qri:app:sense://f0714ca7-7093-49e4-8b58-47bb38563647#Rrg6-1CeRbo4ews9o--QUP3tOXhm5moLizGY6_wCxJE#gqNTf_Dbzn7sNdae3DoYnubxfYLzU6VT-aqWywvjzok", - } + }, ], "metadata": {}, - } + }, }, }, "https://iq37k6byr9lgam8.us.qlikcloud.com/api/v1/lineage-graphs/nodes/qri%3Aapp%3Asense%3A%2F%2Ff0714ca7-7093-49e4-8b58-47bb38563647/overview": { @@ -529,8 +529,8 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "tableLabel": "IPL_Matches_2022", }, ], - } - ] + }, + ], }, }, { @@ -545,10 +545,10 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "resourceQRI": "qri:qdf:space://Ebw1EudUywmUi8p2bM7COr5OATHzuYxvT0BIrCc2irU#JOKG8u7CvizvGXwrFsyXRU0yKr2rL2WFD5djpH9bj5Q", "tableQRI": "qri:qdf:space://Ebw1EudUywmUi8p2bM7COr5OATHzuYxvT0BIrCc2irU#JOKG8u7CvizvGXwrFsyXRU0yKr2rL2WFD5djpH9bj5Q#Rrg6-1CeRbo4ews9o--QUP3tOXhm5moLizGY6_wCxJE", "tableLabel": "test_table", - } + }, ], - } - ] + }, + ], }, }, ], @@ -587,7 +587,7 @@ def mock_websocket_response(*args, **kwargs): "qType": "Doc", "qHandle": 1, "qGenericId": "f0714ca7-7093-49e4-8b58-47bb38563647", - } + }, } elif request == { "jsonrpc": "2.0", @@ -631,7 +631,7 @@ def mock_websocket_response(*args, **kwargs): "source", "exportappdata", ], - } + }, } elif request == { "jsonrpc": "2.0", @@ -670,7 +670,7 @@ def mock_websocket_response(*args, **kwargs): }, "qData": {}, }, - ] + ], } elif request == { "jsonrpc": "2.0", @@ -685,7 +685,7 @@ def mock_websocket_response(*args, **kwargs): "qHandle": 2, "qGenericType": "sheet", "qGenericId": "f4f57386-263a-4ec9-b40c-abcd2467f423", - } + }, } elif request == { "jsonrpc": "2.0", @@ -732,14 +732,14 @@ def mock_websocket_response(*args, **kwargs): "qInfo": {"qId": "QYUUb", "qType": "barchart"}, "qMeta": {"privileges": ["read", "update", "delete"]}, "qData": {"title": ""}, - } - ] + }, + ], }, "customRowBase": 12, "gridResolution": "small", "layoutOptions": {"mobileLayout": "LIST", "extendable": False}, "gridMode": "simpleEdit", - } + }, } elif request == { "jsonrpc": "2.0", @@ -754,7 +754,7 @@ def mock_websocket_response(*args, **kwargs): "qHandle": 3, "qGenericType": "scatterplot", "qGenericId": "QYUUb", - } + }, } elif request == { "jsonrpc": "2.0", @@ -811,7 +811,7 @@ def mock_websocket_response(*args, **kwargs): "autoSort": True, "cId": "FdErA", "othersLabel": "Others", - } + }, ], "qMeasureInfo": [ { @@ -841,7 +841,7 @@ def mock_websocket_response(*args, **kwargs): "showDetails": True, "showDetailsExpression": False, "visualization": "scatterplot", - } + }, } elif request == { "jsonrpc": "2.0", @@ -856,7 +856,7 @@ def mock_websocket_response(*args, **kwargs): "qHandle": 4, "qGenericType": "LoadModel", "qGenericId": "LoadModel", - } + }, } elif request == { "jsonrpc": "2.0", @@ -878,7 +878,7 @@ def mock_websocket_response(*args, **kwargs): "tableName": "test_table", "tableAlias": "test_table", "loadProperties": { - "filterInfo": {"filterClause": "", "filterType": 1} + "filterInfo": {"filterClause": "", "filterType": 1}, }, "key": "Google_BigQuery_harshal-playground-306419:::test_table", "fields": [ @@ -888,7 +888,7 @@ def 
mock_websocket_response(*args, **kwargs): "selected": True, "checked": True, "id": "dsd.test_table.name", - } + }, ], "connectionInfo": { "name": "Google_BigQuery_harshal-playground-306419", @@ -982,7 +982,7 @@ def mock_websocket_response(*args, **kwargs): }, ], "schemaVersion": 2.1, - } + }, } else: return {} @@ -991,9 +991,9 @@ def mock_websocket_response(*args, **kwargs): @pytest.fixture(scope="module") def mock_websocket_send_request(): with patch( - "datahub.ingestion.source.qlik_sense.qlik_api.WebsocketConnection._send_request" + "datahub.ingestion.source.qlik_sense.qlik_api.WebsocketConnection._send_request", ) as mock_websocket_send_request, patch( - "datahub.ingestion.source.qlik_sense.websocket_connection.create_connection" + "datahub.ingestion.source.qlik_sense.websocket_connection.create_connection", ): mock_websocket_send_request.side_effect = mock_websocket_response yield mock_websocket_send_request @@ -1009,7 +1009,10 @@ def default_config(): @pytest.mark.integration def test_qlik_sense_ingest( - pytestconfig, tmp_path, requests_mock, mock_websocket_send_request + pytestconfig, + tmp_path, + requests_mock, + mock_websocket_send_request, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/qlik_sense" @@ -1032,7 +1035,7 @@ def test_qlik_sense_ingest( "filename": output_path, }, }, - } + }, ) pipeline.run() @@ -1048,7 +1051,10 @@ def test_qlik_sense_ingest( @pytest.mark.integration def test_platform_instance_ingest( - pytestconfig, tmp_path, requests_mock, mock_websocket_send_request + pytestconfig, + tmp_path, + requests_mock, + mock_websocket_send_request, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/qlik_sense" @@ -1068,7 +1074,7 @@ def test_platform_instance_ingest( "Google_BigQuery_harshal-playground-306419": { "platform_instance": "google-cloud", "env": "DEV", - } + }, }, }, }, @@ -1078,7 +1084,7 @@ def test_platform_instance_ingest( "filename": output_path, }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py b/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py index a9eebb8d54154e..16ab945c113564 100644 --- a/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py +++ b/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py @@ -36,7 +36,7 @@ def test_redshift_usage_config(): email_domain="xxxxx", include_views=True, include_tables=True, - ) + ), ) assert config.host_port == "xxxxx" @@ -52,12 +52,12 @@ def test_redshift_usage_config(): @patch("redshift_connector.Connection") def test_redshift_usage_source(mock_cursor, mock_connection, pytestconfig, tmp_path): test_resources_dir = pathlib.Path( - pytestconfig.rootpath / "tests/integration/redshift-usage" + pytestconfig.rootpath / "tests/integration/redshift-usage", ) generate_mcps_path = Path(f"{tmp_path}/redshift_usages.json") mock_usage_query_result = open(f"{test_resources_dir}/usage_events_history.json") mock_operational_query_result = open( - f"{test_resources_dir}/operational_events_history.json" + f"{test_resources_dir}/operational_events_history.json", ) mock_usage_query_result_dict = json.load(mock_usage_query_result) mock_operational_query_result_dict = json.load(mock_operational_query_result) @@ -109,7 +109,7 @@ def test_redshift_usage_source(mock_cursor, mock_connection, pytestconfig, tmp_p created=None, comment="", ), - ] + ], }, "dev": { "public": [ @@ -127,7 +127,7 @@ def test_redshift_usage_source(mock_cursor, 
mock_connection, pytestconfig, tmp_p created=None, comment="", ), - ] + ], }, } mwus = usage_extractor.get_usage_workunits(all_tables=all_tables) @@ -160,12 +160,12 @@ def test_redshift_usage_source(mock_cursor, mock_connection, pytestconfig, tmp_p @patch("redshift_connector.Connection") def test_redshift_usage_filtering(mock_cursor, mock_connection, pytestconfig, tmp_path): test_resources_dir = pathlib.Path( - pytestconfig.rootpath / "tests/integration/redshift-usage" + pytestconfig.rootpath / "tests/integration/redshift-usage", ) generate_mcps_path = Path(f"{tmp_path}/redshift_usages.json") mock_usage_query_result = open(f"{test_resources_dir}/usage_events_history.json") mock_operational_query_result = open( - f"{test_resources_dir}/operational_events_history.json" + f"{test_resources_dir}/operational_events_history.json", ) mock_usage_query_result_dict = json.load(mock_usage_query_result) mock_operational_query_result_dict = json.load(mock_operational_query_result) @@ -216,7 +216,7 @@ def test_redshift_usage_filtering(mock_cursor, mock_connection, pytestconfig, tm created=None, comment="", ), - ] + ], }, } mwus = usage_extractor.get_usage_workunits(all_tables=all_tables) diff --git a/metadata-ingestion/tests/integration/remote/test_remote.py b/metadata-ingestion/tests/integration/remote/test_remote.py index 881c8c1b3cd6d0..dce5085e1d2bf4 100644 --- a/metadata-ingestion/tests/integration/remote/test_remote.py +++ b/metadata-ingestion/tests/integration/remote/test_remote.py @@ -17,7 +17,8 @@ def test_remote_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time) test_resources_dir = pytestconfig.rootpath / "tests/integration/remote" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "remote" + test_resources_dir / "docker-compose.yml", + "remote", ) as docker_services: wait_for_port( docker_services=docker_services, @@ -45,7 +46,7 @@ def test_remote_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time) "filename": f"{tmp_path}/parsed_enriched_file.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -72,7 +73,7 @@ def test_remote_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time) "filename": f"{tmp_path}/remote_file_output.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -100,7 +101,7 @@ def test_remote_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time) "filename": f"{tmp_path}/parsed_lineage_output.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -127,7 +128,7 @@ def test_remote_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time) "filename": f"{tmp_path}/remote_glossary_output.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/s3/test_s3.py b/metadata-ingestion/tests/integration/s3/test_s3.py index 0e73cdca006bd0..b84193934635e2 100644 --- a/metadata-ingestion/tests/integration/s3/test_s3.py +++ b/metadata-ingestion/tests/integration/s3/test_s3.py @@ -100,7 +100,8 @@ def s3_populate(pytestconfig, s3_resource, s3_client, bucket_names): ) current_time_sec = datetime.strptime( - FROZEN_TIME, "%Y-%m-%d %H:%M:%S" + FROZEN_TIME, + "%Y-%m-%d %H:%M:%S", ).timestamp() file_list = [] for root, _dirs, files in os.walk(test_resources_dir): @@ -161,7 +162,11 @@ def touch_local_files(pytestconfig): @pytest.mark.integration @pytest.mark.parametrize("source_file_tuple", shared_source_files + s3_source_files) def test_data_lake_s3_ingest( - pytestconfig, s3_populate, source_file_tuple, tmp_path, 
mock_time + pytestconfig, + s3_populate, + source_file_tuple, + tmp_path, + mock_time, ): source_dir, source_file = source_file_tuple test_resources_dir = pytestconfig.rootpath / "tests/integration/s3/" @@ -198,7 +203,11 @@ def test_data_lake_s3_ingest( @pytest.mark.integration @pytest.mark.parametrize("source_file_tuple", shared_source_files) def test_data_lake_local_ingest( - pytestconfig, touch_local_files, source_file_tuple, tmp_path, mock_time + pytestconfig, + touch_local_files, + source_file_tuple, + tmp_path, + mock_time, ): source_dir, source_file = source_file_tuple test_resources_dir = pytestconfig.rootpath / "tests/integration/s3/" @@ -210,10 +219,12 @@ def test_data_lake_local_ingest( path_spec["include"] = ( path_spec["include"] .replace( - "s3://my-test-bucket/", "tests/integration/s3/test_data/local_system/" + "s3://my-test-bucket/", + "tests/integration/s3/test_data/local_system/", ) .replace( - "s3://my-test-bucket-2/", "tests/integration/s3/test_data/local_system/" + "s3://my-test-bucket-2/", + "tests/integration/s3/test_data/local_system/", ) ) @@ -259,7 +270,7 @@ def test_data_lake_incorrect_config_raises_error(tmp_path, mock_time): # Baseline: valid config source: dict = { - "path_spec": {"include": "a/b/c/d/{table}.*", "table_name": "{table}"} + "path_spec": {"include": "a/b/c/d/{table}.*", "table_name": "{table}"}, } s3 = S3Source.create(source, ctx) assert s3.source_config.platform == "file" @@ -283,7 +294,7 @@ def test_data_lake_incorrect_config_raises_error(tmp_path, mock_time): source = { "path_spec": { "include": "a/b/c/d/{table}/*.hd5", - } + }, } with pytest.raises(ValidationError, match="file type"): S3Source.create(source, ctx) diff --git a/metadata-ingestion/tests/integration/salesforce/test_salesforce.py b/metadata-ingestion/tests/integration/salesforce/test_salesforce.py index 9e68ff22a767e2..090a626d2ff5c4 100644 --- a/metadata-ingestion/tests/integration/salesforce/test_salesforce.py +++ b/metadata-ingestion/tests/integration/salesforce/test_salesforce.py @@ -39,19 +39,22 @@ def side_effect_call_salesforce(type, url): return MockResponse(_read_response("account_fields_soql_response.json"), 200) elif url.endswith("FROM CustomField WHERE EntityDefinitionId='Account'"): return MockResponse( - _read_response("account_custom_fields_soql_response.json"), 200 + _read_response("account_custom_fields_soql_response.json"), + 200, ) elif url.endswith("FROM CustomObject where DeveloperName='Property'"): return MockResponse( - _read_response("property_custom_object_soql_response.json"), 200 + _read_response("property_custom_object_soql_response.json"), + 200, ) elif url.endswith( - "FROM EntityParticle WHERE EntityDefinitionId='01I5i000000Y6fp'" + "FROM EntityParticle WHERE EntityDefinitionId='01I5i000000Y6fp'", ): # DurableId of Property__c return MockResponse(_read_response("property_fields_soql_response.json"), 200) elif url.endswith("FROM CustomField WHERE EntityDefinitionId='01I5i000000Y6fp'"): return MockResponse( - _read_response("property_custom_fields_soql_response.json"), 200 + _read_response("property_custom_fields_soql_response.json"), + 200, ) elif url.endswith("/recordCount?sObjects=Property__c"): return MockResponse(_read_response("record_count_property_response.json"), 200) @@ -83,9 +86,9 @@ def test_latest_version(mock_sdk): "profile_pattern": { "allow": [ "^Property__c$", - ] + ], }, - } + }, ) SalesforceSource(config=config, ctx=Mock()) calls = mock_sf._call_salesforce.mock_calls @@ -126,9 +129,9 @@ def test_custom_version(mock_sdk): 
"profile_pattern": { "allow": [ "^Property__c$", - ] + ], }, - } + }, ) SalesforceSource(config=config, ctx=Mock()) @@ -171,7 +174,7 @@ def test_salesforce_ingest(pytestconfig, tmp_path): "profile_pattern": { "allow": [ "^Property__c$", - ] + ], }, }, }, @@ -181,7 +184,7 @@ def test_salesforce_ingest(pytestconfig, tmp_path): "filename": f"{tmp_path}/salesforce_mces.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() diff --git a/metadata-ingestion/tests/integration/sigma/test_sigma.py b/metadata-ingestion/tests/integration/sigma/test_sigma.py index 19fa1448fee598..44451ca370719e 100644 --- a/metadata-ingestion/tests/integration/sigma/test_sigma.py +++ b/metadata-ingestion/tests/integration/sigma/test_sigma.py @@ -254,7 +254,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "source": "inode-2Fby2MBLPM5jUMfBB15On1", "target": "mnJ7_k2sbt", "type": "source", - } + }, ], }, }, @@ -289,7 +289,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "source": "inode-49HFLTr6xytgrPly3PFsNC", "target": "qrL7BEq8LR", "type": "source", - } + }, ], }, }, @@ -325,7 +325,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "Updated At (ADOPTIONS)", ], "vizualizationType": "levelTable", - } + }, ], "total": 1, "nextPage": None, @@ -436,7 +436,7 @@ def test_sigma_ingest(pytestconfig, tmp_path, requests_mock): "client_secret": "CLIENTSECRET", "chart_sources_platform_mapping": { "Acryl Data/Acryl Workbook": { - "data_source_platform": "snowflake" + "data_source_platform": "snowflake", }, }, }, @@ -447,7 +447,7 @@ def test_sigma_ingest(pytestconfig, tmp_path, requests_mock): "filename": output_path, }, }, - } + }, ) pipeline.run() @@ -493,7 +493,7 @@ def test_platform_instance_ingest(pytestconfig, tmp_path, requests_mock): "filename": output_path, }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -583,7 +583,7 @@ def test_sigma_ingest_shared_entities(pytestconfig, tmp_path, requests_mock): "ingest_shared_entities": True, "chart_sources_platform_mapping": { "Acryl Data/Acryl Workbook": { - "data_source_platform": "snowflake" + "data_source_platform": "snowflake", }, }, }, @@ -594,7 +594,7 @@ def test_sigma_ingest_shared_entities(pytestconfig, tmp_path, requests_mock): "filename": output_path, }, }, - } + }, ) pipeline.run() diff --git a/metadata-ingestion/tests/integration/snowflake/common.py b/metadata-ingestion/tests/integration/snowflake/common.py index 7b4f5abe1cd462..cf336cd3643ba7 100644 --- a/metadata-ingestion/tests/integration/snowflake/common.py +++ b/metadata-ingestion/tests/integration/snowflake/common.py @@ -201,7 +201,7 @@ def default_query_results( # noqa: C901 "name": "TEST_DB", "created_on": datetime(2021, 6, 8, 0, 0, 0, 0), "comment": "Comment for TEST_DB", - } + }, ] elif query == SnowflakeQuery.get_databases("TEST_DB"): return [ @@ -210,7 +210,7 @@ def default_query_results( # noqa: C901 "CREATED": datetime(2021, 6, 8, 0, 0, 0, 0), "LAST_ALTERED": datetime(2021, 6, 8, 0, 0, 0, 0), "COMMENT": "Comment for TEST_DB", - } + }, ] elif query == SnowflakeQuery.schemas_for_database("TEST_DB"): return [ @@ -304,7 +304,7 @@ def default_query_results( # noqa: C901 { "MIN_TIME": datetime(2021, 6, 8, 0, 0, 0, 0), "MAX_TIME": datetime(2022, 6, 7, 7, 17, 0, 0), - } + }, ] elif query == snowflake_query.SnowflakeQuery.operational_data_for_time_window( 1654473600000, @@ -313,10 +313,10 @@ def default_query_results( # noqa: C901 return [ { "QUERY_START_TIME": datetime(2022, 6, 2, 4, 41, 1, 
367000).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), "QUERY_TEXT": "create or replace table TABLE_{} as select * from TABLE_2 left join TABLE_3 using COL_1 left join TABLE 4 using COL2".format( - op_idx + op_idx, ), "QUERY_TYPE": "CREATE_TABLE_AS_SELECT", "ROWS_INSERTED": 0, @@ -351,7 +351,7 @@ def default_query_results( # noqa: C901 "objectId": 0, "objectName": "TEST_DB.TEST_SCHEMA.TABLE_4", }, - ] + ], ), "DIRECT_OBJECTS_ACCESSED": json.dumps( [ @@ -382,7 +382,7 @@ def default_query_results( # noqa: C901 "objectId": 0, "objectName": "TEST_DB.TEST_SCHEMA.TABLE_4", }, - ] + ], ), "OBJECTS_MODIFIED": json.dumps( [ @@ -397,7 +397,7 @@ def default_query_results( # noqa: C901 "objectDomain": "Table", "objectId": 0, "objectName": "TEST_DB.TEST_SCHEMA.TABLE_2", - } + }, ], } for col_idx in range(1, num_cols + 1) @@ -405,8 +405,8 @@ def default_query_results( # noqa: C901 "objectDomain": "Table", "objectId": 0, "objectName": f"TEST_DB.TEST_SCHEMA.TABLE_{op_idx}", - } - ] + }, + ], ), "USER_NAME": "SERVICE_ACCOUNT_TESTS_ADMIN", "FIRST_NAME": None, @@ -435,20 +435,20 @@ def default_query_results( # noqa: C901 { "OBJECT_NAME": f"TEST_DB.TEST_SCHEMA.TABLE_{i}{random.randint(99, 999) if i > num_tables else ''}", "BUCKET_START_TIME": datetime(2022, 6, 6, 0, 0, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), "OBJECT_DOMAIN": "Table", "TOTAL_QUERIES": 10, "TOTAL_USERS": 1, "TOP_SQL_QUERIES": json.dumps([large_sql_query for _ in range(10)]), "FIELD_COUNTS": json.dumps( - [{"col": f"col{c}", "total": 10} for c in range(num_cols)] + [{"col": f"col{c}", "total": 10} for c in range(num_cols)], ), "USER_COUNTS": json.dumps( [ {"email": f"abc{i}@xyz.com", "user_name": f"abc{i}", "total": 1} for i in range(10) - ] + ], ), } for i in range(num_usages) @@ -471,7 +471,7 @@ def default_query_results( # noqa: C901 "upstream_object_name": "TEST_DB.TEST_SCHEMA.TABLE_2", "upstream_object_domain": "TABLE", "query_id": f"01b2576e-0804-4957-0034-7d83066cd0ee{op_idx}", - } + }, ] + ( # This additional upstream is only for TABLE_1 [ @@ -488,7 +488,7 @@ def default_query_results( # noqa: C901 ] if op_idx == 1 else [] - ) + ), ), "UPSTREAM_COLUMNS": json.dumps( [ @@ -502,9 +502,9 @@ def default_query_results( # noqa: C901 "object_name": "TEST_DB.TEST_SCHEMA.TABLE_2", "object_domain": "Table", "column_name": f"COL_{col_idx}", - } + }, ], - } + }, ], } for col_idx in range(1, num_cols + 1) @@ -521,15 +521,15 @@ def default_query_results( # noqa: C901 "object_name": "OTHER_DB.OTHER_SCHEMA.TABLE_1", "object_domain": "Table", "column_name": "COL_1", - } + }, ], - } + }, ], - } + }, ] if op_idx == 1 else [] - ) + ), ), "QUERIES": json.dumps( [ @@ -537,8 +537,8 @@ def default_query_results( # noqa: C901 "query_text": f"INSERT INTO TEST_DB.TEST_SCHEMA.TABLE_{op_idx} SELECT * FROM TEST_DB.TEST_SCHEMA.TABLE_2", "query_id": f"01b2576e-0804-4957-0034-7d83066cd0ee{op_idx}", "start_time": "06-06-2022", - } - ] + }, + ], ), } for op_idx in range(1, num_ops + 1) @@ -572,7 +572,7 @@ def default_query_results( # noqa: C901 ] if op_idx == 1 else [] - ) + ), ), "QUERIES": json.dumps( [ @@ -580,10 +580,10 @@ def default_query_results( # noqa: C901 "query_text": f"INSERT INTO TEST_DB.TEST_SCHEMA.TABLE_{op_idx} SELECT * FROM TEST_DB.TEST_SCHEMA.TABLE_2", "query_id": f"01b2576e-0804-4957-0034-7d83066cd0ee{op_idx}", "start_time": datetime(2022, 6, 6, 0, 0, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), - } - ] + }, + ], ), } for op_idx in range(1, num_ops + 1) @@ -597,7 +597,7 @@ def default_query_results( # 
noqa: C901 "REFERENCING_OBJECT_DOMAIN": "view", "DOWNSTREAM_VIEW": "TEST_DB.TEST_SCHEMA.VIEW_2", "VIEW_UPSTREAM": "TEST_DB.TEST_SCHEMA.TABLE_2", - } + }, ] elif query in [ snowflake_query.SnowflakeQuery.view_dependencies_v2(), @@ -612,20 +612,22 @@ def default_query_results( # noqa: C901 { "upstream_object_name": "TEST_DB.TEST_SCHEMA.TABLE_2", "upstream_object_domain": "table", - } - ] + }, + ], ), - } + }, ] elif query in [ snowflake_query.SnowflakeQuery.view_dependencies_v2(), snowflake_query.SnowflakeQuery.view_dependencies(), snowflake_query.SnowflakeQuery.show_external_tables(), snowflake_query.SnowflakeQuery.copy_lineage_history( - start_time_millis=1654473600000, end_time_millis=1654621200000 + start_time_millis=1654473600000, + end_time_millis=1654621200000, ), snowflake_query.SnowflakeQuery.copy_lineage_history( - start_time_millis=1654473600000, end_time_millis=1654586220000 + start_time_millis=1654473600000, + end_time_millis=1654586220000, ), ]: return [] @@ -633,7 +635,7 @@ def default_query_results( # noqa: C901 elif ( query == snowflake_query.SnowflakeQuery.get_all_tags_in_database_without_propagation( - "TEST_DB" + "TEST_DB", ) ): return [ diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py index d2e20e784282ee..2bfb1ea43d04cf 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py @@ -38,7 +38,7 @@ def random_email(): [ random.choice(string.ascii_lowercase) for i in range(random.randint(10, 15)) - ] + ], ) + "@xyz.com" ) @@ -52,7 +52,7 @@ def random_cloud_region(): random.choice(["central", "north", "south", "east", "west"]), "-", str(random.randint(1, 2)), - ] + ], ) @@ -64,7 +64,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): golden_file = test_resources_dir / "snowflake_golden.json" with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch( - "datahub.ingestion.source.snowflake.snowflake_data_reader.SnowflakeDataReader.get_sample_data_for_table" + "datahub.ingestion.source.snowflake.snowflake_data_reader.SnowflakeDataReader.get_sample_data_for_table", ) as mock_sample_values: sf_connection = mock.MagicMock() sf_cursor = mock.MagicMock() @@ -85,8 +85,11 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): info_types_config={ "Age": InfoTypeConfig( prediction_factors_and_weights=PredictionFactorsAndWeights( - name=0, values=1, description=0, datatype=0 - ) + name=0, + values=1, + description=0, + datatype=0, + ), ), "CloudRegion": InfoTypeConfig( prediction_factors_and_weights=PredictionFactorsAndWeights( @@ -98,7 +101,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): values=ValuesFactorConfig( prediction_type="regex", regex=[ - r"(af|ap|ca|eu|me|sa|us)-(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+" + r"(af|ap|ca|eu|me|sa|us)-(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+", ], ), ), @@ -124,20 +127,21 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): email_as_user_identifier=True, incremental_lineage=False, start_time=datetime(2022, 6, 6, 0, 0, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), end_time=datetime(2022, 6, 7, 7, 17, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), classification=ClassificationConfig( enabled=True, 
column_pattern=AllowDenyPattern( - allow=[".*col_1$", ".*col_2$", ".*col_3$"] + allow=[".*col_1$", ".*col_2$", ".*col_3$"], ), classifiers=[ DynamicTypedClassifierConfig( - type="datahub", config=datahub_classifier_config - ) + type="datahub", + config=datahub_classifier_config, + ), ], max_workers=1, ), @@ -152,9 +156,10 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): ), ), sink=DynamicTypedConfig( - type="file", config={"filename": str(output_file)} + type="file", + config={"filename": str(output_file)}, ), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() @@ -186,7 +191,10 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): def test_snowflake_tags_as_structured_properties( - pytestconfig, tmp_path, mock_time, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/snowflake" @@ -222,9 +230,10 @@ def test_snowflake_tags_as_structured_properties( ), ), sink=DynamicTypedConfig( - type="file", config={"filename": str(output_file)} + type="file", + config={"filename": str(output_file)}, ), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() @@ -249,7 +258,10 @@ def test_snowflake_tags_as_structured_properties( @freeze_time(FROZEN_TIME) def test_snowflake_private_link_and_incremental_mcps( - pytestconfig, tmp_path, mock_time, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/snowflake" @@ -284,17 +296,18 @@ def test_snowflake_private_link_and_incremental_mcps( include_operational_stats=False, platform_instance="instance1", start_time=datetime(2022, 6, 6, 0, 0, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), end_time=datetime(2022, 6, 7, 7, 17, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), ), ), sink=DynamicTypedConfig( - type="file", config={"filename": str(output_file)} + type="file", + config={"filename": str(output_file)}, ), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py index 52453b30f740ab..209ed8639b8d5e 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py @@ -35,7 +35,7 @@ ) def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tables): with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch( - "datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table" + "datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table", ) as mock_sample_values: sf_connection = mock.MagicMock() sf_cursor = mock.MagicMock() @@ -43,11 +43,13 @@ def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tabl sf_connection.cursor.return_value = sf_cursor sf_cursor.execute.side_effect = partial( - default_query_results, num_tables=num_tables, num_cols=num_cols_per_table + default_query_results, + num_tables=num_tables, + num_cols=num_cols_per_table, ) mock_sample_values.return_value = pd.DataFrame( - data={f"col_{i}": sample_values for i in range(1, num_cols_per_table + 1)} + data={f"col_{i}": sample_values for i in range(1, num_cols_per_table + 1)}, ) datahub_classifier_config 
= DataHubClassifierConfig( @@ -74,14 +76,15 @@ def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tabl max_workers=num_workers, classifiers=[ DynamicTypedClassifierConfig( - type="datahub", config=datahub_classifier_config - ) + type="datahub", + config=datahub_classifier_config, + ), ], ), ), ), sink=DynamicTypedConfig(type="blackhole", config={}), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() @@ -97,7 +100,7 @@ def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tabl len( cast(SnowflakeV2Report, source_report).info_types_detected[ "Email_Address" - ] + ], ) == num_tables * num_cols_per_table ) diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py index de6e996a52642b..7cbdfda3b629d5 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py @@ -51,7 +51,7 @@ def snowflake_pipeline_config(tmp_path): schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), include_usage_stats=False, start_time=datetime(2022, 6, 6, 0, 0, 0, 0).replace( - tzinfo=timezone.utc + tzinfo=timezone.utc, ), end_time=datetime(2022, 6, 7, 7, 17, 0, 0).replace(tzinfo=timezone.utc), ), @@ -69,7 +69,7 @@ def test_snowflake_missing_role_access_causes_pipeline_failure( with mock.patch("snowflake.connector.connect") as mock_connect: # Snowflake connection fails role not granted error mock_connect.side_effect = Exception( - "250001 (08001): Failed to connect to DB: abc12345.ap-south-1.snowflakecomputing.com:443. Role 'TEST_ROLE' specified in the connect string is not granted to this user. Contact your local system administrator, or attempt to login with another role, e.g. PUBLIC" + "250001 (08001): Failed to connect to DB: abc12345.ap-south-1.snowflakecomputing.com:443. Role 'TEST_ROLE' specified in the connect string is not granted to this user. Contact your local system administrator, or attempt to login with another role, e.g. 
PUBLIC", ) with pytest.raises(PipelineInitError, match="Permissions error"): @@ -227,7 +227,7 @@ def test_snowflake_missing_snowflake_lineage_permission_causes_pipeline_failure( start_time_millis=1654473600000, end_time_millis=1654586220000, include_column_lineage=True, - ) + ), ], "Database 'SNOWFLAKE' does not exist or not authorized.", ) diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py index ae0f23d93215d4..835ad792c1aca2 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py @@ -14,7 +14,7 @@ def test_source_close_cleans_tmp(snowflake_connect, tmp_path): "account_id": "ABC12345.ap-south-1.aws", "username": "TST_USR", "password": "TST_PWD", - } + }, }, PipelineContext("run-id"), ) diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py index 7e2ac94fa4e35c..e8ea38ca78427e 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py @@ -41,7 +41,7 @@ def stateful_pipeline_config(include_tables: bool) -> PipelineConfig: "type": "datahub", "config": {"datahub_api": {"server": GMS_SERVER}}, }, - } + }, ), ), ), @@ -92,10 +92,12 @@ def test_stale_metadata_removal(mock_datahub_graph): # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Perform all assertions on the states. 
The deleted table should not be @@ -104,7 +106,7 @@ def test_stale_metadata_removal(mock_datahub_graph): state2 = checkpoint2.state difference_dataset_urns = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) assert sorted(difference_dataset_urns) == [ "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_1,PROD)", diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py index d4f6e92c93c1e0..4d5684f0f2abcf 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py @@ -26,7 +26,7 @@ def test_snowflake_tag_pattern(): match_fully_qualified_names=True, schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), tag_pattern=AllowDenyPattern( - allow=["TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1"] + allow=["TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1"], ), include_technical_schema=True, include_table_lineage=False, @@ -40,7 +40,7 @@ def test_snowflake_tag_pattern(): config=PipelineConfig( source=SourceConfig(type="snowflake", config=tag_config), sink=DynamicTypedConfig(type="blackhole", config={}), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() @@ -50,7 +50,7 @@ def test_snowflake_tag_pattern(): assert isinstance(source_report, SnowflakeV2Report) assert source_report.tags_scanned == 5 assert source_report._processed_tags == { - "TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1" + "TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1", } @@ -69,7 +69,7 @@ def test_snowflake_tag_pattern_deny(): match_fully_qualified_names=True, schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), tag_pattern=AllowDenyPattern( - deny=["TEST_DB.TEST_SCHEMA.my_tag_2:my_value_2"] + deny=["TEST_DB.TEST_SCHEMA.my_tag_2:my_value_2"], ), include_technical_schema=True, include_table_lineage=False, @@ -83,7 +83,7 @@ def test_snowflake_tag_pattern_deny(): config=PipelineConfig( source=SourceConfig(type="snowflake", config=tag_config), sink=DynamicTypedConfig(type="blackhole", config={}), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() @@ -116,10 +116,10 @@ def test_snowflake_structured_property_pattern_deny(): schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), extract_tags_as_structured_properties=True, tag_pattern=AllowDenyPattern( - deny=["TEST_DB.TEST_SCHEMA.my_tag_2:my_value_2"] + deny=["TEST_DB.TEST_SCHEMA.my_tag_2:my_value_2"], ), structured_property_pattern=AllowDenyPattern( - deny=["TEST_DB.TEST_SCHEMA.my_tag_[0-9]"] + deny=["TEST_DB.TEST_SCHEMA.my_tag_[0-9]"], ), include_technical_schema=True, include_table_lineage=False, @@ -133,7 +133,7 @@ def test_snowflake_structured_property_pattern_deny(): config=PipelineConfig( source=SourceConfig(type="snowflake", config=tag_config), sink=DynamicTypedConfig(type="blackhole", config={}), - ) + ), ) pipeline.run() pipeline.pretty_print_summary() diff --git a/metadata-ingestion/tests/integration/sql_server/test_sql_server.py b/metadata-ingestion/tests/integration/sql_server/test_sql_server.py index 7fab5fc7dae1ba..6bf5f18f91eabe 100644 --- a/metadata-ingestion/tests/integration/sql_server/test_sql_server.py +++ b/metadata-ingestion/tests/integration/sql_server/test_sql_server.py @@ -20,7 +20,8 @@ def mssql_runner(docker_compose_runner, pytestconfig): test_resources_dir = pytestconfig.rootpath / "tests/integration/sql_server" with 
docker_compose_runner( - test_resources_dir / "docker-compose.yml", "sql-server" + test_resources_dir / "docker-compose.yml", + "sql-server", ) as docker_services: # Wait for SQL Server to be ready. We wait an extra couple seconds, as the port being available # does not mean the server is accepting connections. @@ -49,7 +50,9 @@ def test_mssql_ingest(mssql_runner, pytestconfig, tmp_path, mock_time, config_fi # Run the metadata ingestion pipeline. config_file_path = (test_resources_dir / f"source_files/{config_file}").resolve() run_datahub_cmd( - ["ingest", "-c", f"{config_file_path}"], tmp_path=tmp_path, check_result=True + ["ingest", "-c", f"{config_file_path}"], + tmp_path=tmp_path, + check_result=True, ) # Verify the output. @@ -74,7 +77,8 @@ def test_mssql_ingest(mssql_runner, pytestconfig, tmp_path, mock_time, config_fi @pytest.mark.parametrize("procedure_sql_file", procedure_sqls) @pytest.mark.integration def test_stored_procedure_lineage( - pytestconfig: pytest.Config, procedure_sql_file: str + pytestconfig: pytest.Config, + procedure_sql_file: str, ) -> None: sql_file_path = PROCEDURE_SQLS_DIR / procedure_sql_file procedure_code = sql_file_path.read_text() @@ -102,7 +106,7 @@ def test_stored_procedure_lineage( procedure=procedure, procedure_job_urn=data_job_urn, is_temp_table=lambda name: "temp" in name.lower(), - ) + ), ) mce_helpers.check_goldens_stream( pytestconfig, diff --git a/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py b/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py index 9d23a6d7cd9fcc..af517673571e43 100644 --- a/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py +++ b/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py @@ -23,7 +23,7 @@ def test_trino_usage_config(): audit_schema="xxxxx", include_views=True, include_tables=True, - ) + ), ) assert config.host_port == "xxxxx" @@ -39,11 +39,11 @@ def test_trino_usage_config(): @freeze_time(FROZEN_TIME) def test_trino_usage_source(pytestconfig, tmp_path): test_resources_dir = pathlib.Path( - pytestconfig.rootpath / "tests/integration/starburst-trino-usage" + pytestconfig.rootpath / "tests/integration/starburst-trino-usage", ) with patch( - "datahub.ingestion.source.usage.starburst_trino_usage.TrinoUsageSource._get_trino_history" + "datahub.ingestion.source.usage.starburst_trino_usage.TrinoUsageSource._get_trino_history", ) as mock_event_history: access_events = load_access_events(test_resources_dir) mock_event_history.return_value = access_events diff --git a/metadata-ingestion/tests/integration/superset/test_superset.py b/metadata-ingestion/tests/integration/superset/test_superset.py index e8251e54a1f85a..9c647446640111 100644 --- a/metadata-ingestion/tests/integration/superset/test_superset.py +++ b/metadata-ingestion/tests/integration/superset/test_superset.py @@ -165,7 +165,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "id": 1, "last_name": "Owner1", "username": "test_username_1", - } + }, ], "schema": "test_schema1", "sql": "SELECT * FROM test_table1", @@ -195,7 +195,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "id": 2, "last_name": "Owner2", "username": "test_username_2", - } + }, ], "schema": "test_schema2", "sql": "SELECT * FROM test_table2", @@ -255,7 +255,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "rendered_expression": "count(*)", "verbose_name": 
None, "warning_text": None, - } + }, ], "name": "Test Table 1", "normalize_columns": True, @@ -326,7 +326,7 @@ def register_mock_api(request_mock: Any, override_data: dict = {}) -> None: "rendered_expression": "sum(value)", "verbose_name": "Total Value", "warning_text": None, - } + }, ], "name": "Test Table 2", "normalize_columns": True, @@ -423,7 +423,7 @@ def test_superset_ingest(pytestconfig, tmp_path, mock_time, requests_mock): "filename": f"{tmp_path}/superset_mces.json", }, }, - } + }, ) pipeline.run() @@ -440,7 +440,11 @@ def test_superset_ingest(pytestconfig, tmp_path, mock_time, requests_mock): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_superset_stateful_ingest( - pytestconfig, tmp_path, mock_time, requests_mock, mock_datahub_graph + pytestconfig, + tmp_path, + mock_time, + requests_mock, + mock_datahub_graph, ): test_resources_dir = pytestconfig.rootpath / "tests/integration/superset" @@ -470,7 +474,7 @@ def test_superset_stateful_ingest( }, "sink": { # we are not really interested in the resulting events for this test - "type": "console" + "type": "console", }, "pipeline_name": "test_pipeline", } @@ -584,7 +588,7 @@ def test_superset_stateful_ingest( "id": 2, "last_name": "Owner2", "username": "test_username_2", - } + }, ], "schema": "test_schema2", "sql": "SELECT * FROM test_table2", @@ -629,13 +633,13 @@ def test_superset_stateful_ingest( state1 = checkpoint1.state state2 = checkpoint2.state dashboard_difference_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) chart_difference_urns = list( - state1.get_urns_not_in(type="chart", other_checkpoint_state=state2) + state1.get_urns_not_in(type="chart", other_checkpoint_state=state2), ) dataset_difference_urns = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) assert len(dashboard_difference_urns) == 1 @@ -652,10 +656,12 @@ def test_superset_stateful_ingest( # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Verify the output. 
diff --git a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py index 32b1ef2ed1f835..ac9b15c7c19188 100644 --- a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py +++ b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py @@ -141,11 +141,11 @@ def side_effect_group_data(*arg, **kwargs): mock_pagination.total_available = None group1: GroupItem = GroupItem( - name="AB_XY00-Tableau-Access_A_123_PROJECT_XY_Consumer" + name="AB_XY00-Tableau-Access_A_123_PROJECT_XY_Consumer", ) group1._id = "79d02655-88e5-45a6-9f9b-eeaf5fe54903-group1" group2: GroupItem = GroupItem( - name="AB_XY00-Tableau-Access_A_123_PROJECT_XY_Analyst" + name="AB_XY00-Tableau-Access_A_123_PROJECT_XY_Analyst", ) group2._id = "79d02655-88e5-45a6-9f9b-eeaf5fe54903-group2" @@ -155,7 +155,8 @@ def side_effect_group_data(*arg, **kwargs): def side_effect_workbook_permissions(*arg, **kwargs): project_capabilities1 = {"Read": "Allow", "ViewComments": "Allow"} reference: ResourceReference = ResourceReference( - id_="79d02655-88e5-45a6-9f9b-eeaf5fe54903-group1", tag_name="group" + id_="79d02655-88e5-45a6-9f9b-eeaf5fe54903-group1", + tag_name="group", ) rule1 = PermissionsRule(grantee=reference, capabilities=project_capabilities1) @@ -166,7 +167,8 @@ def side_effect_workbook_permissions(*arg, **kwargs): "Write": "Allow", } reference2: ResourceReference = ResourceReference( - id_="79d02655-88e5-45a6-9f9b-eeaf5fe54903-group2", tag_name="group" + id_="79d02655-88e5-45a6-9f9b-eeaf5fe54903-group2", + tag_name="group", ) rule2 = PermissionsRule(grantee=reference2, capabilities=project_capabilities2) @@ -232,22 +234,26 @@ def side_effect_workbook_data(*arg, **kwargs): workbook1._id = "65a404a8-48a2-4c2a-9eb0-14ee5e78b22b" workbook2: WorkbookItem = WorkbookItem( - project_id="190a6a5c-63ed-4de1-8045-faeae5df5b01", name="Dvdrental Workbook" + project_id="190a6a5c-63ed-4de1-8045-faeae5df5b01", + name="Dvdrental Workbook", ) workbook2._id = "b2c84ac6-1e37-4ca0-bf9b-62339be046fc" workbook3: WorkbookItem = WorkbookItem( - project_id="190a6a5c-63ed-4de1-8045-faeae5df5b01", name="Executive Dashboard" + project_id="190a6a5c-63ed-4de1-8045-faeae5df5b01", + name="Executive Dashboard", ) workbook3._id = "68ebd5b2-ecf6-4fdf-ba1a-95427baef506" workbook4: WorkbookItem = WorkbookItem( - project_id="190a6a5c-63ed-4de1-8045-faeae5df5b01", name="Workbook published ds" + project_id="190a6a5c-63ed-4de1-8045-faeae5df5b01", + name="Workbook published ds", ) workbook4._id = "a059a443-7634-4abf-9e46-d147b99168be" workbook5: WorkbookItem = WorkbookItem( - project_id="79d02655-88e5-45a6-9f9b-eeaf5fe54903", name="Deny Pattern WorkBook" + project_id="79d02655-88e5-45a6-9f9b-eeaf5fe54903", + name="Deny Pattern WorkBook", ) workbook5._id = "b45eabfe-dc3d-4331-9324-cc1b14b0549b" @@ -304,7 +310,7 @@ def mock_sdk_client( mock_client.workbooks.get.side_effect = side_effect_workbook_data workbook_mock = mock.create_autospec(WorkbookItem, instance=True) type(workbook_mock).permissions = mock.PropertyMock( - return_value=side_effect_workbook_permissions() + return_value=side_effect_workbook_permissions(), ) mock_client.workbooks.get_by_id.return_value = workbook_mock @@ -355,7 +361,7 @@ def tableau_ingest_common( "filename": f"{tmp_path}/{output_file_name}", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -410,7 +416,8 @@ def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph): def test_tableau_test_connection_success(): with 
mock.patch("datahub.ingestion.source.tableau.tableau.Server"): report = test_connection_helpers.run_test_connection( - TableauSource, config_source_default + TableauSource, + config_source_default, ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -419,7 +426,8 @@ def test_tableau_test_connection_success(): @pytest.mark.integration def test_tableau_test_connection_failure(): report = test_connection_helpers.run_test_connection( - TableauSource, config_source_default + TableauSource, + config_source_default, ) test_connection_helpers.assert_basic_connectivity_failure(report, "Unable to login") @@ -565,7 +573,9 @@ def test_extract_all_project(pytestconfig, tmp_path, mock_datahub_graph): def test_value_error_projects_and_project_pattern( - pytestconfig, tmp_path, mock_datahub_graph + pytestconfig, + tmp_path, + mock_datahub_graph, ): new_config = config_source_default.copy() new_config["projects"] = ["default"] @@ -632,7 +642,9 @@ def test_project_path_pattern_deny(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_tableau_ingest_with_platform_instance( - pytestconfig, tmp_path, mock_datahub_graph + pytestconfig, + tmp_path, + mock_datahub_graph, ): output_file_name: str = "tableau_with_platform_instance_mces.json" golden_file_name: str = "tableau_with_platform_instance_mces_golden.json" @@ -689,7 +701,8 @@ def test_lineage_overrides(): "test-table", "presto", ).make_dataset_urn( - env=DEFAULT_ENV, platform_instance_map={"presto": "my_presto_instance"} + env=DEFAULT_ENV, + platform_instance_map={"presto": "my_presto_instance"}, ) == "urn:li:dataset:(urn:li:dataPlatform:presto,my_presto_instance.presto_catalog.test-schema.test-table,PROD)" ) @@ -757,7 +770,7 @@ def test_database_hostname_to_platform_instance_map(): env=DEFAULT_ENV, platform_instance_map={}, database_hostname_to_platform_instance_map={ - "test-hostname": "test-platform-instance" + "test-hostname": "test-platform-instance", }, database_server_hostname_map={"test-database-id": "test-hostname"}, ) @@ -802,10 +815,12 @@ def test_tableau_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph) # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Perform all assertions on the states. 
The deleted table should not be @@ -814,7 +829,7 @@ def test_tableau_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph) state2 = checkpoint2.state difference_dataset_urns = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) assert len(difference_dataset_urns) == 35 @@ -858,7 +873,7 @@ def test_tableau_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph) assert sorted(deleted_dataset_urns) == sorted(difference_dataset_urns) difference_chart_urns = list( - state1.get_urns_not_in(type="chart", other_checkpoint_state=state2) + state1.get_urns_not_in(type="chart", other_checkpoint_state=state2), ) assert len(difference_chart_urns) == 24 deleted_chart_urns = [ @@ -890,7 +905,7 @@ def test_tableau_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph) assert sorted(deleted_chart_urns) == sorted(difference_chart_urns) difference_dashboard_urns = list( - state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dashboard", other_checkpoint_state=state2), ) assert len(difference_dashboard_urns) == 4 deleted_dashboard_urns = [ @@ -948,11 +963,14 @@ def test_tableau_unsupported_csql(): config = TableauConfig.parse_obj(config_dict) config.extract_lineage_from_unsupported_custom_sql_queries = True config.lineage_overrides = TableauLineageOverrides( - database_override_map={"production database": "prod"} + database_override_map={"production database": "prod"}, ) def check_lineage_metadata( - lineage, expected_entity_urn, expected_upstream_table, expected_cll + lineage, + expected_entity_urn, + expected_upstream_table, + expected_cll, ): mcp = cast(MetadataChangeProposalWrapper, list(lineage)[0].metadata) @@ -961,17 +979,17 @@ def check_lineage_metadata( UpstreamClass( dataset=expected_upstream_table, type=DatasetLineageType.TRANSFORMED, - ) + ), ], fineGrainedLineages=[ FineGrainedLineage( upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, upstreams=[ - make_schema_field_urn(expected_upstream_table, upstream_column) + make_schema_field_urn(expected_upstream_table, upstream_column), ], downstreamType=FineGrainedLineageDownstreamType.FIELD, downstreams=[ - make_schema_field_urn(expected_entity_urn, downstream_column) + make_schema_field_urn(expected_entity_urn, downstream_column), ], ) for upstream_column, downstream_column in expected_cll.items() @@ -1278,7 +1296,7 @@ def test_permission_warning(pytestconfig, tmp_path, mock_datahub_graph): with mock.patch("datahub.ingestion.source.tableau.tableau.Server") as mock_sdk: mock_sdk.return_value = mock_sdk_client( side_effect_query_metadata_response=[ - read_response("permission_mode_switched_error.json") + read_response("permission_mode_switched_error.json"), ], sign_out_side_effect=[{}], datasources_side_effect=[{}], @@ -1330,7 +1348,7 @@ def test_retry_on_error(pytestconfig, tmp_path, mock_datahub_graph): mock_client = mock_sdk_client( side_effect_query_metadata_response=[ NonXMLResponseError( - """{"timestamp":"xxx","status":401,"error":"Unauthorized","path":"/relationship-service-war/graphql"}""" + """{"timestamp":"xxx","status":401,"error":"Unauthorized","path":"/relationship-service-war/graphql"}""", ), *mock_data(), ], @@ -1340,8 +1358,9 @@ def test_retry_on_error(pytestconfig, tmp_path, mock_datahub_graph): mock_client.users = mock.Mock() mock_client.users.get_by_id.side_effect = [ UserItem( - name="name", site_role=UserItem.Roles.SiteAdministratorExplorer - ) + 
name="name", + site_role=UserItem.Roles.SiteAdministratorExplorer, + ), ] mock_sdk.return_value = mock_client diff --git a/metadata-ingestion/tests/integration/trino/test_trino.py b/metadata-ingestion/tests/integration/trino/test_trino.py index 6437666fed62b5..41e0298b15c362 100644 --- a/metadata-ingestion/tests/integration/trino/test_trino.py +++ b/metadata-ingestion/tests/integration/trino/test_trino.py @@ -28,7 +28,8 @@ def trino_runner(docker_compose_runner, pytestconfig): test_resources_dir = pytestconfig.rootpath / "tests/integration/trino" with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "trino" + test_resources_dir / "docker-compose.yml", + "trino", ) as docker_services: wait_for_port(docker_services, "testtrino", 8080) wait_for_port(docker_services, "testhiveserver2", 10000, timeout=120) @@ -58,7 +59,11 @@ def loaded_trino(trino_runner): @freeze_time(FROZEN_TIME) def test_trino_ingest( - loaded_trino, test_resources_dir, pytestconfig, tmp_path, mock_time + loaded_trino, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, ): # Run the metadata ingestion pipeline. with fs_helpers.isolated_filesystem(tmp_path): @@ -76,7 +81,7 @@ def test_trino_ingest( username="foo", schema_pattern=AllowDenyPattern(allow=["^librarydb"]), profile_pattern=AllowDenyPattern( - allow=["postgresqldb.librarydb.*"] + allow=["postgresqldb.librarydb.*"], ), profiling=GEProfilingConfig( enabled=True, @@ -100,7 +105,7 @@ def test_trino_ingest( config=DataHubClassifierConfig( minimum_values_threshold=1, ), - ) + ), ], max_workers=1, ), @@ -108,7 +113,7 @@ def test_trino_ingest( "postgresqldb": ConnectorDetail( connector_database="postgres", platform_instance="local_server", - ) + ), }, ).dict(), }, @@ -133,7 +138,11 @@ def test_trino_ingest( @freeze_time(FROZEN_TIME) def test_trino_hive_ingest( - loaded_trino, test_resources_dir, pytestconfig, tmp_path, mock_time + loaded_trino, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, ): # Run the metadata ingestion pipeline for trino catalog referring to postgres database mce_out_file = "trino_hive_mces.json" @@ -156,7 +165,7 @@ def test_trino_hive_ingest( config=DataHubClassifierConfig( minimum_values_threshold=1, ), - ) + ), ], max_workers=1, ), @@ -200,7 +209,11 @@ def test_trino_hive_ingest( @freeze_time(FROZEN_TIME) def test_trino_instance_ingest( - loaded_trino, test_resources_dir, pytestconfig, tmp_path, mock_time + loaded_trino, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, ): mce_out_file = "trino_instance_mces.json" events_file = tmp_path / mce_out_file @@ -218,7 +231,7 @@ def test_trino_instance_ingest( "hivedb": ConnectorDetail( connector_platform="glue", platform_instance="local_server", - ) + ), }, ).dict(), }, diff --git a/metadata-ingestion/tests/integration/unity/test_unity_catalog_ingest.py b/metadata-ingestion/tests/integration/unity/test_unity_catalog_ingest.py index 9c7b86a275f6d0..f86e4d917d6f66 100644 --- a/metadata-ingestion/tests/integration/unity/test_unity_catalog_ingest.py +++ b/metadata-ingestion/tests/integration/unity/test_unity_catalog_ingest.py @@ -61,7 +61,7 @@ def register_mock_data(workspace_client): "updated_by": "abc@acryl.io", "cloud": "aws", "global_metastore_id": "aws:us-west-1:2c983545-d403-4f87-9063-5b7e3b6d3736", - } + }, ) workspace_client.catalogs.list.return_value = [ @@ -77,7 +77,7 @@ def register_mock_data(workspace_client): "updated_at": 1666186064332, "updated_by": "abc@acryl.io", "catalog_type": "MANAGED_CATALOG", - } + }, ] ] @@ -151,7 +151,7 @@ def 
register_mock_data(workspace_client): "updated_at": 1666186049633, "updated_by": "abc@acryl.io", "table_id": "cff27aa1-1c6a-4d78-b713-562c660c2896", - } + }, ), databricks.sdk.service.catalog.TableInfo.from_dict( { @@ -201,7 +201,7 @@ def register_mock_data(workspace_client): "updated_at": 1666186049633, "updated_by": "abc@acryl.io", "table_id": "cff27aa1-1c6a-4d78-b713-562c660c2896", - } + }, ), ] @@ -254,7 +254,7 @@ def register_mock_data(workspace_client): "updated_at": 1666186049633, "updated_by": "abc@acryl.io", "table_id": "cff27aa1-1c6a-4d78-b713-562c660c2896", - } + }, ) ) @@ -279,7 +279,8 @@ def register_mock_data(workspace_client): TableEntry = namedtuple("TableEntry", ["database", "tableName", "isTemporary"]) ViewEntry = namedtuple( - "ViewEntry", ["namespace", "viewName", "isTemporary", "isMaterialized"] + "ViewEntry", + ["namespace", "viewName", "isTemporary", "isMaterialized"], ) @@ -406,13 +407,13 @@ def mock_hive_sql(query): ] elif query == "DESCRIBE EXTENDED `bronze_kambi`.`delta_error_table`": raise Exception( - "[DELTA_PATH_DOES_NOT_EXIST] doesn't exist, or is not a Delta table." + "[DELTA_PATH_DOES_NOT_EXIST] doesn't exist, or is not a Delta table.", ) elif query == "SHOW CREATE TABLE `bronze_kambi`.`view1`": return [ ( "CREATE VIEW `hive_metastore`.`bronze_kambi`.`view1` AS SELECT * FROM `hive_metastore`.`bronze_kambi`.`bet`", - ) + ), ] elif query == "SHOW TABLES FROM `bronze_kambi`": return [ @@ -436,9 +437,10 @@ def test_ingestion(pytestconfig, tmp_path, requests_mock): output_file_name = "unity_catalog_mcps.json" with patch( - "datahub.ingestion.source.unity.proxy.WorkspaceClient" + "datahub.ingestion.source.unity.proxy.WorkspaceClient", ) as mock_client, patch.object( - HiveMetastoreProxy, "get_inspector" + HiveMetastoreProxy, + "get_inspector", ) as get_inspector, patch.object(HiveMetastoreProxy, "_execute_sql") as execute_sql: workspace_client: mock.MagicMock = mock.MagicMock() mock_client.return_value = workspace_client diff --git a/metadata-ingestion/tests/integration/vertica/test_vertica.py b/metadata-ingestion/tests/integration/vertica/test_vertica.py index d7b4c390f75d94..1056fb2189471b 100644 --- a/metadata-ingestion/tests/integration/vertica/test_vertica.py +++ b/metadata-ingestion/tests/integration/vertica/test_vertica.py @@ -28,7 +28,8 @@ def is_vertica_responsive(container_name: str) -> bool: @pytest.fixture(scope="module") def vertica_runner(docker_compose_runner, test_resources_dir): with docker_compose_runner( - test_resources_dir / "docker-compose.yml", "vertica" + test_resources_dir / "docker-compose.yml", + "vertica", ) as docker_services: wait_for_port( docker_services, @@ -61,7 +62,8 @@ def test_vertica_ingest_with_db(vertica_runner, pytestconfig, tmp_path): # Run the metadata ingestion pipeline. 
config_file = (test_resources_dir / "vertica_to_file.yml").resolve() run_datahub_cmd( - ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path + ["ingest", "--strict-warnings", "-c", f"{config_file}"], + tmp_path=tmp_path, ) ignore_paths: List[str] = [ diff --git a/metadata-ingestion/tests/performance/bigquery/bigquery_events.py b/metadata-ingestion/tests/performance/bigquery/bigquery_events.py index bf3d566da8d278..f2f35b2fc0a80a 100644 --- a/metadata-ingestion/tests/performance/bigquery/bigquery_events.py +++ b/metadata-ingestion/tests/performance/bigquery/bigquery_events.py @@ -62,7 +62,7 @@ def generate_events( ref_from_table(field.table, table_to_project) for field in query.fields_accessed if field.table.is_view() - ) + ), ) yield AuditEvent.create( @@ -83,7 +83,7 @@ def generate_events( ref_from_table(field.table, table_to_project) for field in query.fields_accessed if not field.table.is_view() - ) + ), ) + list( dict.fromkeys( # Preserve order @@ -91,7 +91,7 @@ def generate_events( for field in query.fields_accessed if field.table.is_view() for parent in field.table.upstreams - ) + ), ), referencedViews=referencedViews, payload=( @@ -100,19 +100,19 @@ def generate_events( else None ), query_on_view=True if referencedViews else False, - ) + ), ) table_accesses: Dict[BigQueryTableRef, Set[str]] = defaultdict(set) for field in query.fields_accessed: if not field.table.is_view(): table_accesses[ref_from_table(field.table, table_to_project)].add( - field.column + field.column, ) else: # assuming that same fields are accessed in parent tables for parent in field.table.upstreams: table_accesses[ref_from_table(parent, table_to_project)].add( - field.column + field.column, ) for ref, columns in table_accesses.items(): @@ -129,13 +129,15 @@ def generate_events( if config.debug_include_full_payloads else None ), - ) + ), ) def ref_from_table(table: Table, table_to_project: Dict[str, str]) -> BigQueryTableRef: return BigQueryTableRef( BigqueryTableIdentifier( - table_to_project[table.name], table.container.name, table.name - ) + table_to_project[table.name], + table.container.name, + table.name, + ), ) diff --git a/metadata-ingestion/tests/performance/bigquery/test_bigquery_usage.py b/metadata-ingestion/tests/performance/bigquery/test_bigquery_usage.py index 24460f38298069..894a7279e404f5 100644 --- a/metadata-ingestion/tests/performance/bigquery/test_bigquery_usage.py +++ b/metadata-ingestion/tests/performance/bigquery/test_bigquery_usage.py @@ -67,11 +67,11 @@ def run_test(): num_unique_queries=50_000, num_users=2000, query_length=NormalDistribution(2000, 500), - ) + ), ) queries.sort(key=lambda q: q.timestamp) events = list( - generate_events(queries, projects, table_to_project, config=config) + generate_events(queries, projects, table_to_project, config=config), ) print(f"Events generated: {len(events)}") pre_mem_usage = psutil.Process(os.getpid()).memory_info().rss @@ -86,7 +86,7 @@ def run_test(): print(f"Seconds Elapsed: {timer.elapsed_seconds(digits=2)} seconds") print( - f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}" + f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}", ) print(f"Disk Used: {report.processing_perf.usage_state_size}") print(f"Hash collisions: {report.num_usage_query_hash_collisions}") diff --git a/metadata-ingestion/tests/performance/data_generation.py b/metadata-ingestion/tests/performance/data_generation.py index 266c0d9af03224..4d8d84841b4a1a 100644 --- 
a/metadata-ingestion/tests/performance/data_generation.py +++ b/metadata-ingestion/tests/performance/data_generation.py @@ -50,7 +50,10 @@ def _sample(self) -> int: raise NotImplementedError def sample( - self, *, floor: Optional[int] = None, ceiling: Optional[int] = None + self, + *, + floor: Optional[int] = None, + ceiling: Optional[int] = None, ) -> int: value = self._sample() if floor is not None: @@ -208,7 +211,8 @@ def generate_queries( type="SELECT", actor=random.choice(users), timestamp=_random_time_between( - seed_metadata.start_time, seed_metadata.end_time + seed_metadata.start_time, + seed_metadata.end_time, ), fields_accessed=_sample_list(all_columns, columns_per_select), ) @@ -229,7 +233,8 @@ def generate_queries( type=random.choice(OPERATION_TYPES), actor=random.choice(users), timestamp=_random_time_between( - seed_metadata.start_time, seed_metadata.end_time + seed_metadata.start_time, + seed_metadata.end_time, ), # Can have no field accesses, e.g. on a standard INSERT fields_accessed=_sample_list(all_columns, num_columns_modified, 0), @@ -247,7 +252,9 @@ def _container_type(i: int) -> str: def _generate_table( - i: int, parents: List[Container], columns_per_table: Distribution + i: int, + parents: List[Container], + columns_per_table: Distribution, ) -> Table: num_columns = columns_per_table.sample(floor=1) diff --git a/metadata-ingestion/tests/performance/databricks/generator.py b/metadata-ingestion/tests/performance/databricks/generator.py index b11771e55b2c9e..78353e19a631de 100644 --- a/metadata-ingestion/tests/performance/databricks/generator.py +++ b/metadata-ingestion/tests/performance/databricks/generator.py @@ -35,7 +35,9 @@ def __init__(self, host: str, token: str, warehouse_id: str): uri_opts={"http_path": f"/sql/1.0/warehouses/{warehouse_id}"}, ) engine = create_engine( - url, connect_args={"timeout": 600}, pool_size=MAX_WORKERS + url, + connect_args={"timeout": 600}, + pool_size=MAX_WORKERS, ) self.connection = engine.connect() @@ -64,11 +66,14 @@ def create_data( "populate tables", seed_metadata.tables, lambda t: self._populate_table( - t, num_rows_distribution.sample(ceiling=1_000_000) + t, + num_rows_distribution.sample(ceiling=1_000_000), ), ) _thread_pool_execute( - "create table lineage", seed_metadata.tables, self._create_table_lineage + "create table lineage", + seed_metadata.tables, + self._create_table_lineage, ) def _create_catalog(self, catalog: Container) -> None: diff --git a/metadata-ingestion/tests/performance/databricks/test_unity.py b/metadata-ingestion/tests/performance/databricks/test_unity.py index 71192dc5b509bc..002df2e0e204a6 100644 --- a/metadata-ingestion/tests/performance/databricks/test_unity.py +++ b/metadata-ingestion/tests/performance/databricks/test_unity.py @@ -35,7 +35,9 @@ def run_test(): num_users=1000, ) proxy_mock = UnityCatalogApiProxyMock( - seed_metadata, queries=queries, num_service_principals=10000 + seed_metadata, + queries=queries, + num_service_principals=10000, ) print("Data generated") @@ -62,7 +64,7 @@ def run_test(): print(f"Seconds Elapsed: {timer.elapsed_seconds(digits=2)} seconds") print( - f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}" + f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}", ) print(source.report.as_string()) diff --git a/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py b/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py index 307a7ba71ef839..0e04d70a3a4c96 100644 --- 
a/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py +++ b/metadata-ingestion/tests/performance/databricks/unity_proxy_mock.py @@ -103,7 +103,7 @@ def tables(self, schema: Schema) -> Iterable[Table]: comment=None, type_precision=0, type_scale=0, - ) + ), ) yield Table( diff --git a/metadata-ingestion/tests/performance/helpers.py b/metadata-ingestion/tests/performance/helpers.py index 9bfd9ebc8de0d5..039c32ad30e292 100644 --- a/metadata-ingestion/tests/performance/helpers.py +++ b/metadata-ingestion/tests/performance/helpers.py @@ -12,10 +12,12 @@ def workunit_sink(workunits: Iterable[MetadataWorkUnit]) -> Tuple[int, int]: for i, _wu in enumerate(workunits): if i % 10_000 == 0: peak_memory_usage = max( - peak_memory_usage, psutil.Process(os.getpid()).memory_info().rss + peak_memory_usage, + psutil.Process(os.getpid()).memory_info().rss, ) peak_memory_usage = max( - peak_memory_usage, psutil.Process(os.getpid()).memory_info().rss + peak_memory_usage, + psutil.Process(os.getpid()).memory_info().rss, ) return i, peak_memory_usage diff --git a/metadata-ingestion/tests/performance/snowflake/test_snowflake.py b/metadata-ingestion/tests/performance/snowflake/test_snowflake.py index a940cce46a8f74..a1f3535fcadd9b 100644 --- a/metadata-ingestion/tests/performance/snowflake/test_snowflake.py +++ b/metadata-ingestion/tests/performance/snowflake/test_snowflake.py @@ -57,7 +57,7 @@ def run_test(): logging.info(source.get_report().as_string()) logging.info( - f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}" + f"Peak Memory Used: {humanfriendly.format_size(peak_memory_usage - pre_mem_usage)}", ) logging.info(source.report.aspects) diff --git a/metadata-ingestion/tests/performance/sql/test_sql_formatter.py b/metadata-ingestion/tests/performance/sql/test_sql_formatter.py index f09047c0ec4a4f..02caf7bb7e3b31 100644 --- a/metadata-ingestion/tests/performance/sql/test_sql_formatter.py +++ b/metadata-ingestion/tests/performance/sql/test_sql_formatter.py @@ -12,13 +12,13 @@ def run_test() -> None: for i in range(N): if i % 50 == 0: print( - f"Running iteration {i}, elapsed time: {timer.elapsed_seconds(digits=2)} seconds" + f"Running iteration {i}, elapsed time: {timer.elapsed_seconds(digits=2)} seconds", ) try_format_query.__wrapped__(large_sql_query, platform="snowflake") print( - f"Total time taken for {N} iterations: {timer.elapsed_seconds(digits=2)} seconds" + f"Total time taken for {N} iterations: {timer.elapsed_seconds(digits=2)} seconds", ) diff --git a/metadata-ingestion/tests/test_helpers/click_helpers.py b/metadata-ingestion/tests/test_helpers/click_helpers.py index 89ac6b143f4ed3..df3a94f68502da 100644 --- a/metadata-ingestion/tests/test_helpers/click_helpers.py +++ b/metadata-ingestion/tests/test_helpers/click_helpers.py @@ -18,7 +18,9 @@ def assert_result_ok(result: Result) -> None: def run_datahub_cmd( - command: List[str], tmp_path: Optional[Path] = None, check_result: bool = True + command: List[str], + tmp_path: Optional[Path] = None, + check_result: bool = True, ) -> Result: runner = CliRunner() diff --git a/metadata-ingestion/tests/test_helpers/graph_helpers.py b/metadata-ingestion/tests/test_helpers/graph_helpers.py index 127285a48d930b..a48b7b130e5e90 100644 --- a/metadata-ingestion/tests/test_helpers/graph_helpers.py +++ b/metadata-ingestion/tests/test_helpers/graph_helpers.py @@ -22,7 +22,8 @@ class MockDataHubGraph(DataHubGraph): def __init__( - self, entity_graph: Optional[Dict[str, Dict[str, Any]]] = None + self, + entity_graph: 
Optional[Dict[str, Dict[str, Any]]] = None, ) -> None: self.emitted: List[ Union[ @@ -38,7 +39,8 @@ def import_file(self, file: Path) -> None: This function can be called repeatedly on the same Mock instance to load up metadata from multiple files.""" file_source: GenericFileSource = GenericFileSource( - ctx=PipelineContext(run_id="test"), config=FileSourceConfig(path=str(file)) + ctx=PipelineContext(run_id="test"), + config=FileSourceConfig(path=str(file)), ) for wu in file_source.get_workunits(): if isinstance(wu, MetadataWorkUnit): @@ -52,12 +54,13 @@ def import_file(self, file: Path) -> None: if isinstance(metadata, MetadataChangeEvent): mcps = mcps_from_mce(metadata) elif isinstance( - metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper) + metadata, + (MetadataChangeProposal, MetadataChangeProposalWrapper), ): mcps = [metadata] else: raise Exception( - f"Unexpected metadata type {type(metadata)}. Was expecting MCE, MCP or MCPW" + f"Unexpected metadata type {type(metadata)}. Was expecting MCE, MCP or MCPW", ) for mcp in mcps: @@ -69,7 +72,10 @@ def import_file(self, file: Path) -> None: self.entity_graph[mcp.entityUrn][mcp.aspectName] = mcp.aspect def get_aspect( - self, entity_urn: str, aspect_type: Type[Aspect], version: int = 0 + self, + entity_urn: str, + aspect_type: Type[Aspect], + version: int = 0, ) -> Optional[Aspect]: aspect_name = [v for v in ASPECT_NAME_MAP if ASPECT_NAME_MAP[v] == aspect_type][ 0 @@ -129,7 +135,9 @@ def get_emitted( self, ) -> List[ Union[ - MetadataChangeEvent, MetadataChangeProposal, MetadataChangeProposalWrapper + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, ] ]: return self.emitted diff --git a/metadata-ingestion/tests/test_helpers/mce_helpers.py b/metadata-ingestion/tests/test_helpers/mce_helpers.py index d70a440dab0657..9bdc4dda6e2d02 100644 --- a/metadata-ingestion/tests/test_helpers/mce_helpers.py +++ b/metadata-ingestion/tests/test_helpers/mce_helpers.py @@ -125,7 +125,9 @@ def _get_field_for_entity_type_in_mce(entity_type: str) -> str: def _get_filter( - mce: bool = False, mcp: bool = False, entity_type: Optional[str] = None + mce: bool = False, + mcp: bool = False, + entity_type: Optional[str] = None, ) -> Callable[[Dict], bool]: if mce: # cheap way to determine if we are working with an MCE for the appropriate entity_type @@ -159,7 +161,9 @@ def _get_element(event: Dict[str, Any], path_spec: List[str]) -> Any: def _element_matches_pattern( - event: Dict[str, Any], path_spec: List[str], pattern: str + event: Dict[str, Any], + path_spec: List[str], + pattern: str, ) -> Tuple[bool, bool]: import re @@ -194,7 +198,10 @@ def _get_entity_urns(events_list: List[Dict]) -> Set[str]: def assert_mcp_entity_urn( - filter: str, entity_type: str, regex_pattern: str, file: str + filter: str, + entity_type: str, + regex_pattern: str, + file: str, ) -> int: def get_path_spec_for_urn() -> List[str]: return [MCPConstants.ENTITY_URN] @@ -214,7 +221,7 @@ def get_path_spec_for_urn() -> List[str]: return len(filtered_events) else: raise Exception( - f"Did not expect the file {file} to not contain a list of items" + f"Did not expect the file {file} to not contain a list of items", ) @@ -233,7 +240,10 @@ def _get_mcp_urn_path_spec() -> List[str]: def assert_mce_entity_urn( - filter: str, entity_type: str, regex_pattern: str, file: str + filter: str, + entity_type: str, + regex_pattern: str, + file: str, ) -> int: """Assert that all mce entity urns must match the regex pattern passed in. 
Return the number of events matched""" @@ -249,12 +259,12 @@ def assert_mce_entity_urn( failed_events = [y for y in filtered_events if not y[1][0] or not y[1][1]] if failed_events: raise Exception( - "Failed to match events: {json.dumps(failed_events, indent=2)}" + "Failed to match events: {json.dumps(failed_events, indent=2)}", ) return len(filtered_events) else: raise Exception( - f"Did not expect the file {file} to not contain a list of items" + f"Did not expect the file {file} to not contain a list of items", ) @@ -292,7 +302,8 @@ def assert_for_each_entity( if o.get(MCPConstants.ASPECT_NAME) == aspect_name: # load the inner aspect payload and assign to this urn aspect_map[o[MCPConstants.ENTITY_URN]] = o.get( - MCPConstants.ASPECT_VALUE, {} + MCPConstants.ASPECT_VALUE, + {}, ).get("json") success: List[str] = [] @@ -312,14 +323,17 @@ def assert_for_each_entity( print(f"Succeeded on assertion for urns {success}") if failures: raise AssertionError( - f"Failed to find aspect_name {aspect_name} for urns {json.dumps(failures, indent=2)}" + f"Failed to find aspect_name {aspect_name} for urns {json.dumps(failures, indent=2)}", ) return len(success) def assert_entity_mce_aspect( - entity_urn: str, aspect: Any, aspect_type: Type, file: str + entity_urn: str, + aspect: Any, + aspect_type: Type, + file: str, ) -> int: # TODO: Replace with read_metadata_file() test_output = load_json_file(file) @@ -342,7 +356,10 @@ def assert_entity_mce_aspect( def assert_entity_mcp_aspect( - entity_urn: str, aspect_field_matcher: Dict[str, Any], aspect_name: str, file: str + entity_urn: str, + aspect_field_matcher: Dict[str, Any], + aspect_name: str, + file: str, ) -> int: # TODO: Replace with read_metadata_file() test_output = load_json_file(file) @@ -416,5 +433,5 @@ def assert_entity_urn_like(entity_type: str, regex_pattern: str, file: str) -> i return len(matched_urns) else: raise AssertionError( - f"No urns found that match the pattern {regex_pattern}. Full list is {all_urns}" + f"No urns found that match the pattern {regex_pattern}. Full list is {all_urns}", ) diff --git a/metadata-ingestion/tests/test_helpers/sink_helpers.py b/metadata-ingestion/tests/test_helpers/sink_helpers.py index a467a13ef5783c..546afbcc22461b 100644 --- a/metadata-ingestion/tests/test_helpers/sink_helpers.py +++ b/metadata-ingestion/tests/test_helpers/sink_helpers.py @@ -15,7 +15,9 @@ def report_record_written(self, record_envelope: RecordEnvelope) -> None: class RecordingSink(Sink[ConfigModel, RecordingSinkReport]): def write_record_async( - self, record_envelope: RecordEnvelope, callback: WriteCallback + self, + record_envelope: RecordEnvelope, + callback: WriteCallback, ) -> None: self.report.report_record_written(record_envelope) callback.on_success(record_envelope, {}) diff --git a/metadata-ingestion/tests/test_helpers/state_helpers.py b/metadata-ingestion/tests/test_helpers/state_helpers.py index c469db6ce8cf80..f9035345653b69 100644 --- a/metadata-ingestion/tests/test_helpers/state_helpers.py +++ b/metadata-ingestion/tests/test_helpers/state_helpers.py @@ -23,7 +23,8 @@ def validate_all_providers_have_committed_successfully( - pipeline: Pipeline, expected_providers: int + pipeline: Pipeline, + expected_providers: int, ) -> None: """ makes sure the pipeline includes the desired number of providers @@ -62,17 +63,21 @@ def __init__(self) -> None: self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True} # Bind mock_graph's emit_mcp to testcase's monkey_patch_emit_mcp so that we can emulate emits. 
self.mock_graph.emit_mcp = types.MethodType( - self.monkey_patch_emit_mcp, self.mock_graph + self.monkey_patch_emit_mcp, + self.mock_graph, ) # Bind mock_graph's get_latest_timeseries_value to monkey_patch_get_latest_timeseries_value self.mock_graph.get_latest_timeseries_value = types.MethodType( - self.monkey_patch_get_latest_timeseries_value, self.mock_graph + self.monkey_patch_get_latest_timeseries_value, + self.mock_graph, ) # Tracking for emitted mcps. self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {} def monkey_patch_emit_mcp( - self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper + self, + graph_ref: MagicMock, + mcpw: MetadataChangeProposalWrapper, ) -> None: """ Mockey patched implementation of DatahubGraph.emit_mcp that caches the mcp locally in memory. @@ -117,6 +122,6 @@ def get_current_checkpoint_from_pipeline( stateful_source = cast(StatefulIngestionSourceBase, pipeline.source) return stateful_source.state_provider.get_current_checkpoint( StaleEntityRemovalHandler.compute_job_id( - getattr(stateful_source, "platform", "default") - ) + getattr(stateful_source, "platform", "default"), + ), ) diff --git a/metadata-ingestion/tests/test_helpers/test_connection_helpers.py b/metadata-ingestion/tests/test_helpers/test_connection_helpers.py index 45543033ae010c..adb3b71d3326e4 100644 --- a/metadata-ingestion/tests/test_helpers/test_connection_helpers.py +++ b/metadata-ingestion/tests/test_helpers/test_connection_helpers.py @@ -9,7 +9,8 @@ def run_test_connection( - source_cls: Type[TestableSource], config_dict: Dict + source_cls: Type[TestableSource], + config_dict: Dict, ) -> TestConnectionReport: return source_cls.test_connection(config_dict) @@ -22,7 +23,8 @@ def assert_basic_connectivity_success(report: TestConnectionReport) -> None: def assert_basic_connectivity_failure( - report: TestConnectionReport, expected_reason: str + report: TestConnectionReport, + expected_reason: str, ) -> None: assert report is not None assert report.basic_connectivity diff --git a/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py b/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py index 7be8b667a500b3..ed333ef2437f3e 100644 --- a/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py +++ b/metadata-ingestion/tests/unit/api/entities/datacontract/test_data_quality_assertion.py @@ -27,7 +27,8 @@ def test_parse_sql_assertion(): } assert DataQualityAssertion.parse_obj(d).generate_mcp( - assertion_urn, entity_urn + assertion_urn, + entity_urn, ) == [ MetadataChangeProposalWrapper( entityUrn=assertion_urn, @@ -51,5 +52,5 @@ def test_parse_sql_assertion(): ), ), ), - ) + ), ] diff --git a/metadata-ingestion/tests/unit/api/entities/dataproducts/test_dataproduct.py b/metadata-ingestion/tests/unit/api/entities/dataproducts/test_dataproduct.py index dad7662d9ad00b..4f7ec82185cbcf 100644 --- a/metadata-ingestion/tests/unit/api/entities/dataproducts/test_dataproduct.py +++ b/metadata-ingestion/tests/unit/api/entities/dataproducts/test_dataproduct.py @@ -18,9 +18,10 @@ def base_entity_metadata(): return { "urn:li:domain:12345": { "domainProperties": DomainPropertiesClass( - name="Marketing", description="Marketing Domain" - ) - } + name="Marketing", + description="Marketing Domain", + ), + }, } @@ -45,7 +46,11 @@ def check_yaml_golden_file(input_file: str, golden_file: str) -> bool: diff_exists = False for line in difflib.unified_diff( - input_lines, golden_lines, fromfile=input_file, 
tofile=golden_file, lineterm="" + input_lines, + golden_lines, + fromfile=input_file, + tofile=golden_file, + lineterm="", ): print(line) diff_exists = True @@ -99,7 +104,8 @@ def test_dataproduct_from_datahub( mock_graph.import_file(golden_file) data_product: DataProduct = DataProduct.from_datahub( - mock_graph, id="urn:li:dataProduct:pet_of_the_week" + mock_graph, + id="urn:li:dataProduct:pet_of_the_week", ) assert data_product.domain == "urn:li:domain:12345" assert data_product.assets is not None @@ -139,17 +145,20 @@ def test_dataproduct_patch_yaml( data_product_file = test_resources_dir / original_file original_data_product: DataProduct = DataProduct.from_yaml( - data_product_file, mock_graph + data_product_file, + mock_graph, ) data_product: DataProduct = DataProduct.from_datahub( - mock_graph, id="urn:li:dataProduct:pet_of_the_week" + mock_graph, + id="urn:li:dataProduct:pet_of_the_week", ) dataproduct_output_file = Path(tmp_path / f"patch_{original_file}") data_product.patch_yaml(original_data_product, dataproduct_output_file) dataproduct_golden_file = Path(test_resources_dir / "golden_dataproduct_v2.yaml") assert ( check_yaml_golden_file( - str(dataproduct_output_file), str(dataproduct_golden_file) + str(dataproduct_output_file), + str(dataproduct_golden_file), ) is False ) @@ -177,7 +186,9 @@ def test_dataproduct_ownership_type_urn_from_yaml( @freeze_time(FROZEN_TIME) def test_dataproduct_ownership_type_urn_patch_yaml( - tmp_path: Path, test_resources_dir: Path, base_mock_graph: MockDataHubGraph + tmp_path: Path, + test_resources_dir: Path, + base_mock_graph: MockDataHubGraph, ) -> None: mock_graph = base_mock_graph source_file = test_resources_dir / "golden_dataproduct_out_ownership_type_urn.json" @@ -187,11 +198,13 @@ def test_dataproduct_ownership_type_urn_patch_yaml( test_resources_dir / "dataproduct_ownership_type_urn_different_owner.yaml" ) original_data_product: DataProduct = DataProduct.from_yaml( - data_product_file, mock_graph + data_product_file, + mock_graph, ) data_product: DataProduct = DataProduct.from_datahub( - mock_graph, id="urn:li:dataProduct:pet_of_the_week" + mock_graph, + id="urn:li:dataProduct:pet_of_the_week", ) dataproduct_output_file = ( diff --git a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py index 6a03f511fa51c5..c90fa36f596149 100644 --- a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py +++ b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py @@ -35,20 +35,22 @@ def test_platform_resource_dict(): x for x in mcps if x.aspectName == "platformResourceInfo" ][0] assert isinstance( - platform_resource_info_mcp.aspect, models.PlatformResourceInfoClass + platform_resource_info_mcp.aspect, + models.PlatformResourceInfoClass, ) assert platform_resource_info_mcp.aspect.primaryKey == "test_primary_key" assert platform_resource_info_mcp.aspect.secondaryKeys == ["test_secondary_key"] assert platform_resource_info_mcp.aspect.resourceType == "test_resource_type" assert isinstance( - platform_resource_info_mcp.aspect.value, models.SerializedValueClass + platform_resource_info_mcp.aspect.value, + models.SerializedValueClass, ) assert ( platform_resource_info_mcp.aspect.value.contentType == models.SerializedValueContentTypeClass.JSON ) assert platform_resource_info_mcp.aspect.value.blob == json.dumps( - {"test_key": "test_value"} + {"test_key": "test_value"}, 
).encode("utf-8") assert platform_resource_info_mcp.aspect.value.schemaType is None @@ -63,7 +65,7 @@ def test_platform_resource_dict_removed(): platform="test_platform", resource_type="test_resource_type", primary_key="test_primary_key", - ) + ), ) mcps = [x for x in platform_resource.to_mcps()] @@ -104,13 +106,15 @@ def test_platform_resource_dictwrapper(): x for x in mcps if x.aspectName == "platformResourceInfo" ][0] assert isinstance( - platform_resource_info_mcp.aspect, models.PlatformResourceInfoClass + platform_resource_info_mcp.aspect, + models.PlatformResourceInfoClass, ) assert platform_resource_info_mcp.aspect.primaryKey == "U123456" assert platform_resource_info_mcp.aspect.secondaryKeys == ["a@b.com"] assert platform_resource_info_mcp.aspect.resourceType == "user_info" assert isinstance( - platform_resource_info_mcp.aspect.value, models.SerializedValueClass + platform_resource_info_mcp.aspect.value, + models.SerializedValueClass, ) assert ( platform_resource_info_mcp.aspect.value.contentType @@ -121,7 +125,7 @@ def test_platform_resource_dictwrapper(): == models.SerializedValueSchemaTypeClass.PEGASUS ) assert platform_resource_info_mcp.aspect.value.blob == json.dumps( - user_editable_info.to_obj() + user_editable_info.to_obj(), ).encode("utf-8") assert ( platform_resource_info_mcp.aspect.value.schemaRef @@ -165,20 +169,22 @@ class TestModel(BaseModel): x for x in mcps if x.aspectName == "platformResourceInfo" ][0] assert isinstance( - platform_resource_info_mcp.aspect, models.PlatformResourceInfoClass + platform_resource_info_mcp.aspect, + models.PlatformResourceInfoClass, ) assert platform_resource_info_mcp.aspect.primaryKey == "test_primary_key" assert platform_resource_info_mcp.aspect.secondaryKeys == ["test_secondary_key"] assert platform_resource_info_mcp.aspect.resourceType == "test_resource_type" assert isinstance( - platform_resource_info_mcp.aspect.value, models.SerializedValueClass + platform_resource_info_mcp.aspect.value, + models.SerializedValueClass, ) assert ( platform_resource_info_mcp.aspect.value.contentType == models.SerializedValueContentTypeClass.JSON ) assert platform_resource_info_mcp.aspect.value.blob == json.dumps( - test_model.dict() + test_model.dict(), ).encode("utf-8") assert platform_resource_info_mcp.aspect.value.schemaType == "JSON" assert platform_resource_info_mcp.aspect.value.schemaRef == TestModel.__name__ diff --git a/metadata-ingestion/tests/unit/api/source_helpers/test_auto_browse_path_v2.py b/metadata-ingestion/tests/unit/api/source_helpers/test_auto_browse_path_v2.py index 0ad777c577d70d..a14ecd98d3c54e 100644 --- a/metadata-ingestion/tests/unit/api/source_helpers/test_auto_browse_path_v2.py +++ b/metadata-ingestion/tests/unit/api/source_helpers/test_auto_browse_path_v2.py @@ -42,8 +42,9 @@ def test_auto_browse_path_v2_gen_containers_threaded(): if aspect: assert aspect.path == [ BrowsePathEntryClass( - id=database_key.as_urn(), urn=database_key.as_urn() - ) + id=database_key.as_urn(), + urn=database_key.as_urn(), + ), ] @@ -108,7 +109,7 @@ def test_auto_browse_path_v2_by_container_hierarchy(telemetry_ping_mock): if wu.get_aspect_of_type(models.StatusClass) and wu.get_urn() == urn ) assert new_wus[idx + 1].get_aspect_of_type( - models.BrowsePathsV2Class + models.BrowsePathsV2Class, ) or new_wus[idx + 2].get_aspect_of_type(models.BrowsePathsV2Class) @@ -124,17 +125,19 @@ def test_auto_browse_path_v2_ignores_urns_already_with(telemetry_ping_mock): "f": [ models.BrowsePathsClass(paths=["/one/two"]), models.BrowsePathsV2Class( - 
path=_make_browse_path_entries(["my", "path"]) + path=_make_browse_path_entries(["my", "path"]), ), ], "c": [ models.BrowsePathsV2Class( - path=_make_container_browse_path_entries(["custom", "path"]) - ) + path=_make_container_browse_path_entries( + ["custom", "path"], + ), + ), ], }, ), - ) + ), ] new_wus = list(auto_browse_path_v2(wus)) assert not telemetry_ping_mock.call_count, telemetry_ping_mock.call_args_list @@ -149,7 +152,7 @@ def test_auto_browse_path_v2_ignores_urns_already_with(telemetry_ping_mock): assert paths["f"] == _make_browse_path_entries(["my", "path"]) assert paths["d"] == _make_container_browse_path_entries(["custom", "path", "c"]) assert paths["e"] == _make_container_browse_path_entries( - ["custom", "path", "c", "d"] + ["custom", "path", "c", "d"], ) @@ -174,10 +177,10 @@ def test_auto_browse_path_v2_with_platform_instance_and_source_browse_path_v2( ], }, ), - ) + ), ] new_wus = list( - auto_browse_path_v2(wus, platform=platform, platform_instance=instance) + auto_browse_path_v2(wus, platform=platform, platform_instance=instance), ) assert not telemetry_ping_mock.call_count, telemetry_ping_mock.call_args_list assert ( @@ -254,7 +257,7 @@ def test_auto_browse_path_v2_container_over_legacy_browse_path(telemetry_ping_mo structure, other_aspects={"b": [models.BrowsePathsClass(paths=["/one/two"])]}, ), - ) + ), ) new_wus = list(auto_browse_path_v2(wus)) assert not telemetry_ping_mock.call_count, telemetry_ping_mock.call_args_list @@ -275,7 +278,8 @@ def test_auto_browse_path_v2_with_platform_instance(telemetry_ping_mock): platform_instance = "my_instance" platform_instance_urn = make_dataplatform_instance_urn(platform, platform_instance) platform_instance_entry = models.BrowsePathEntryClass( - platform_instance_urn, platform_instance_urn + platform_instance_urn, + platform_instance_urn, ) structure = {"a": {"b": ["c"]}} @@ -286,7 +290,7 @@ def test_auto_browse_path_v2_with_platform_instance(telemetry_ping_mock): wus, platform=platform, platform_instance=platform_instance, - ) + ), ) assert telemetry_ping_mock.call_count == 0 @@ -409,12 +413,14 @@ def _create_container_aspects( for k, v in d.items(): urn = make_container_urn(k) yield MetadataChangeProposalWrapper( - entityUrn=urn, aspect=models.StatusClass(removed=False) + entityUrn=urn, + aspect=models.StatusClass(removed=False), ).as_workunit() for aspect in other_aspects.pop(k, []): yield MetadataChangeProposalWrapper( - entityUrn=urn, aspect=aspect + entityUrn=urn, + aspect=aspect, ).as_workunit() for child in list(v): @@ -424,14 +430,17 @@ def _create_container_aspects( ).as_workunit() if isinstance(v, dict): yield from _create_container_aspects( - v, other_aspects=other_aspects, root=False + v, + other_aspects=other_aspects, + root=False, ) if root: for k, v in other_aspects.items(): for aspect in v: yield MetadataChangeProposalWrapper( - entityUrn=make_container_urn(k), aspect=aspect + entityUrn=make_container_urn(k), + aspect=aspect, ).as_workunit() diff --git a/metadata-ingestion/tests/unit/api/source_helpers/test_ensure_aspect_size.py b/metadata-ingestion/tests/unit/api/source_helpers/test_ensure_aspect_size.py index 8a45efb46893ae..a2c9ebd78596a7 100644 --- a/metadata-ingestion/tests/unit/api/source_helpers/test_ensure_aspect_size.py +++ b/metadata-ingestion/tests/unit/api/source_helpers/test_ensure_aspect_size.py @@ -63,7 +63,7 @@ def too_big_schema_metadata() -> SchemaMetadataClass: nativeDataType="int", type=SchemaFieldDataTypeClass(type=NumberTypeClass()), description=20000 * "a", - ) + ), ) # adding small 
field to check whether it will still be present in the output @@ -72,7 +72,7 @@ def too_big_schema_metadata() -> SchemaMetadataClass: "dddd", nativeDataType="int", type=SchemaFieldDataTypeClass(type=NumberTypeClass()), - ) + ), ) return SchemaMetadataClass( schemaName="abcdef", @@ -148,7 +148,8 @@ def proper_dataset_profile() -> DatasetProfileClass: DatasetFieldProfileClass(fieldPath="j", sampleValues=sample_values), ] return DatasetProfileClass( - timestampMillis=int(time.time()) * 1000, fieldProfiles=field_profiles + timestampMillis=int(time.time()) * 1000, + fieldProfiles=field_profiles, ) @@ -157,7 +158,8 @@ def test_ensure_size_of_proper_dataset_profile(processor): profile = proper_dataset_profile() orig_repr = json.dumps(profile.to_obj()) processor.ensure_dataset_profile_size( - "urn:li:dataset:(s3, dummy_dataset, DEV)", profile + "urn:li:dataset:(s3, dummy_dataset, DEV)", + profile, ) assert orig_repr == json.dumps(profile.to_obj()), ( "Aspect was modified in case where workunit processor should have been no-op" @@ -170,7 +172,8 @@ def test_ensure_size_of_too_big_schema_metadata(processor): assert len(schema.fields) == 1004 processor.ensure_schema_metadata_size( - "urn:li:dataset:(s3, dummy_dataset, DEV)", schema + "urn:li:dataset:(s3, dummy_dataset, DEV)", + schema, ) assert len(schema.fields) < 1004, "Schema has not been properly truncated" assert schema.fields[-1].fieldPath == "dddd", "Small field was not added at the end" @@ -187,7 +190,8 @@ def test_ensure_size_of_proper_schema_metadata(processor): schema = proper_schema_metadata() orig_repr = json.dumps(schema.to_obj()) processor.ensure_schema_metadata_size( - "urn:li:dataset:(s3, dummy_dataset, DEV)", schema + "urn:li:dataset:(s3, dummy_dataset, DEV)", + schema, ) assert orig_repr == json.dumps(schema.to_obj()), ( "Aspect was modified in case where workunit processor should have been no-op" @@ -204,7 +208,8 @@ def test_ensure_size_of_too_big_dataset_profile(processor): assert profile.fieldProfiles profile.fieldProfiles.insert(4, big_field) processor.ensure_dataset_profile_size( - "urn:li:dataset:(s3, dummy_dataset, DEV)", profile + "urn:li:dataset:(s3, dummy_dataset, DEV)", + profile, ) expected_profile = proper_dataset_profile() @@ -221,13 +226,15 @@ def test_ensure_size_of_too_big_dataset_profile(processor): @freeze_time("2023-01-02 00:00:00") @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size", ) @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size", ) def test_wu_processor_triggered_by_data_profile_aspect( - ensure_dataset_profile_size_mock, ensure_schema_metadata_size_mock, processor + ensure_dataset_profile_size_mock, + ensure_schema_metadata_size_mock, + processor, ): ret = [ # noqa: F841 *processor.ensure_aspect_size( @@ -235,9 +242,9 @@ def test_wu_processor_triggered_by_data_profile_aspect( MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:s3, dummy_name, DEV)", aspect=proper_dataset_profile(), - ).as_workunit() - ] - ) + ).as_workunit(), + ], + ), ] ensure_dataset_profile_size_mock.assert_called_once() ensure_schema_metadata_size_mock.assert_not_called() @@ -245,13 +252,15 @@ def 
test_wu_processor_triggered_by_data_profile_aspect( @freeze_time("2023-01-02 00:00:00") @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size", ) @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size", ) def test_wu_processor_triggered_by_data_profile_aspect_mcpc( - ensure_dataset_profile_size_mock, ensure_schema_metadata_size_mock, processor + ensure_dataset_profile_size_mock, + ensure_schema_metadata_size_mock, + processor, ): profile_aspect = proper_dataset_profile() mcpc = MetadataWorkUnit( @@ -274,20 +283,23 @@ def test_wu_processor_triggered_by_data_profile_aspect_mcpc( @freeze_time("2023-01-02 00:00:00") @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size", ) @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size", ) def test_wu_processor_triggered_by_data_profile_aspect_mce( - ensure_dataset_profile_size_mock, ensure_schema_metadata_size_mock, processor + ensure_dataset_profile_size_mock, + ensure_schema_metadata_size_mock, + processor, ): snapshot = DatasetSnapshotClass( urn="urn:li:dataset:(urn:li:dataPlatform:s3, dummy_name, DEV)", aspects=[proper_schema_metadata()], ) mce = MetadataWorkUnit( - id="test", mce=MetadataChangeEvent(proposedSnapshot=snapshot) + id="test", + mce=MetadataChangeEvent(proposedSnapshot=snapshot), ) ret = [*processor.ensure_aspect_size([mce])] # noqa: F841 ensure_schema_metadata_size_mock.assert_called_once() @@ -296,13 +308,15 @@ def test_wu_processor_triggered_by_data_profile_aspect_mce( @freeze_time("2023-01-02 00:00:00") @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size", ) @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size", ) def test_wu_processor_triggered_by_schema_metadata_aspect( - ensure_dataset_profile_size_mock, ensure_schema_metadata_size_mock, processor + ensure_dataset_profile_size_mock, + ensure_schema_metadata_size_mock, + processor, ): ret = [ # noqa: F841 *processor.ensure_aspect_size( @@ -310,9 +324,9 @@ def test_wu_processor_triggered_by_schema_metadata_aspect( MetadataChangeProposalWrapper( entityUrn="urn:li:dataset:(urn:li:dataPlatform:s3, dummy_name, DEV)", aspect=proper_schema_metadata(), - ).as_workunit() - ] - ) + ).as_workunit(), + ], + ), ] ensure_schema_metadata_size_mock.assert_called_once() ensure_dataset_profile_size_mock.assert_not_called() @@ -320,13 +334,15 @@ def test_wu_processor_triggered_by_schema_metadata_aspect( @freeze_time("2023-01-02 00:00:00") @patch( - 
"datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_schema_metadata_size", ) @patch( - "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size" + "datahub.ingestion.api.auto_work_units.auto_ensure_aspect_size.EnsureAspectSizeProcessor.ensure_dataset_profile_size", ) def test_wu_processor_not_triggered_by_unhandled_aspects( - ensure_dataset_profile_size_mock, ensure_schema_metadata_size_mock, processor + ensure_dataset_profile_size_mock, + ensure_schema_metadata_size_mock, + processor, ): ret = [ # noqa: F841 *processor.ensure_aspect_size( @@ -339,8 +355,8 @@ def test_wu_processor_not_triggered_by_unhandled_aspects( entityUrn="urn:li:dataset:(urn:li:dataPlatform:s3, dummy_name, DEV)", aspect=SubTypesClass(typeNames=["table"]), ).as_workunit(), - ] - ) + ], + ), ] ensure_schema_metadata_size_mock.assert_not_called() ensure_dataset_profile_size_mock.assert_not_called() diff --git a/metadata-ingestion/tests/unit/api/source_helpers/test_incremental_lineage_helper.py b/metadata-ingestion/tests/unit/api/source_helpers/test_incremental_lineage_helper.py index c5c4a378894c32..547708c3d8e8a5 100644 --- a/metadata-ingestion/tests/unit/api/source_helpers/test_incremental_lineage_helper.py +++ b/metadata-ingestion/tests/unit/api/source_helpers/test_incremental_lineage_helper.py @@ -32,7 +32,8 @@ def make_lineage_aspect( dataset=upstream_urn, type=models.DatasetLineageTypeClass.TRANSFORMED, auditStamp=models.AuditStampClass( - time=timestamp, actor="urn:li:corpuser:unknown" + time=timestamp, + actor="urn:li:corpuser:unknown", ), ) for upstream_urn in upstreams @@ -89,8 +90,10 @@ def test_incremental_table_lineage(tmp_path, pytestconfig): incremental_lineage=True, stream=[ MetadataChangeProposalWrapper( - entityUrn=urn, aspect=aspect, systemMetadata=system_metadata - ).as_workunit() + entityUrn=urn, + aspect=aspect, + systemMetadata=system_metadata, + ).as_workunit(), ], ) @@ -99,7 +102,9 @@ def test_incremental_table_lineage(tmp_path, pytestconfig): [wu.metadata for wu in processed_wus], ) mce_helpers.check_golden_file( - pytestconfig=pytestconfig, output_path=test_file, golden_path=golden_file + pytestconfig=pytestconfig, + output_path=test_file, + golden_path=golden_file, ) @@ -114,8 +119,10 @@ def test_incremental_table_lineage_empty_upstreams(tmp_path, pytestconfig): incremental_lineage=True, stream=[ MetadataChangeProposalWrapper( - entityUrn=urn, aspect=aspect, systemMetadata=system_metadata - ).as_workunit() + entityUrn=urn, + aspect=aspect, + systemMetadata=system_metadata, + ).as_workunit(), ], ) @@ -134,8 +141,10 @@ def test_incremental_column_lineage(tmp_path, pytestconfig): incremental_lineage=True, stream=[ MetadataChangeProposalWrapper( - entityUrn=urn, aspect=aspect, systemMetadata=system_metadata - ).as_workunit() + entityUrn=urn, + aspect=aspect, + systemMetadata=system_metadata, + ).as_workunit(), ], ) @@ -144,7 +153,9 @@ def test_incremental_column_lineage(tmp_path, pytestconfig): [wu.metadata for wu in processed_wus], ) mce_helpers.check_golden_file( - pytestconfig=pytestconfig, output_path=test_file, golden_path=golden_file + pytestconfig=pytestconfig, + output_path=test_file, + golden_path=golden_file, ) @@ -181,5 +192,7 @@ def test_incremental_lineage_pass_through(tmp_path, pytestconfig): [wu.metadata for wu in processed_wus], ) mce_helpers.check_golden_file( 
- pytestconfig=pytestconfig, output_path=test_file, golden_path=golden_file + pytestconfig=pytestconfig, + output_path=test_file, + golden_path=golden_file, ) diff --git a/metadata-ingestion/tests/unit/api/source_helpers/test_source_helpers.py b/metadata-ingestion/tests/unit/api/source_helpers/test_source_helpers.py index cdfd24554e5e53..360bcbfec45b2e 100644 --- a/metadata-ingestion/tests/unit/api/source_helpers/test_source_helpers.py +++ b/metadata-ingestion/tests/unit/api/source_helpers/test_source_helpers.py @@ -92,8 +92,8 @@ def test_auto_status_aspect(): entityUrn="urn:li:dataset:(urn:li:dataPlatform:bigquery,bigquery-public-data.covid19_aha.staffing,PROD)", aspect=models.StatusClass(removed=False), ), - ] - ) + ], + ), ), ] assert list(auto_status_aspect(initial_wu)) == expected @@ -104,10 +104,14 @@ def test_auto_lowercase_aspects(): [ MetadataChangeProposalWrapper( entityUrn=make_dataset_urn( - "bigquery", "myProject.mySchema.myTable", "PROD" + "bigquery", + "myProject.mySchema.myTable", + "PROD", ), aspect=models.DatasetKeyClass( - "urn:li:dataPlatform:bigquery", "myProject.mySchema.myTable", "PROD" + "urn:li:dataPlatform:bigquery", + "myProject.mySchema.myTable", + "PROD", ), ), MetadataChangeProposalWrapper( @@ -128,7 +132,7 @@ def test_auto_lowercase_aspects(): ], ), ), - ] + ], ) expected = [ @@ -161,8 +165,8 @@ def test_auto_lowercase_aspects(): ], ), ), - ] - ) + ], + ), ), ] assert list(auto_lowercase_urns(mcws)) == expected @@ -179,12 +183,12 @@ def test_auto_empty_dataset_usage_statistics(caplog: pytest.LogCaptureFixture) - aspect=models.DatasetUsageStatisticsClass( timestampMillis=int(config.start_time.timestamp() * 1000), eventGranularity=models.TimeWindowSizeClass( - models.CalendarIntervalClass.DAY + models.CalendarIntervalClass.DAY, ), uniqueUserCount=1, totalSqlQueries=1, ), - ).as_workunit() + ).as_workunit(), ] caplog.clear() with caplog.at_level(logging.WARNING): @@ -194,7 +198,7 @@ def test_auto_empty_dataset_usage_statistics(caplog: pytest.LogCaptureFixture) - dataset_urns={has_urn, empty_urn}, config=config, all_buckets=False, - ) + ), ) assert not caplog.records @@ -205,7 +209,7 @@ def test_auto_empty_dataset_usage_statistics(caplog: pytest.LogCaptureFixture) - aspect=models.DatasetUsageStatisticsClass( timestampMillis=int(datetime(2023, 1, 1).timestamp() * 1000), eventGranularity=models.TimeWindowSizeClass( - models.CalendarIntervalClass.DAY + models.CalendarIntervalClass.DAY, ), uniqueUserCount=0, totalSqlQueries=0, @@ -229,12 +233,12 @@ def test_auto_empty_dataset_usage_statistics_invalid_timestamp( aspect=models.DatasetUsageStatisticsClass( timestampMillis=0, eventGranularity=models.TimeWindowSizeClass( - models.CalendarIntervalClass.DAY + models.CalendarIntervalClass.DAY, ), uniqueUserCount=1, totalSqlQueries=1, ), - ).as_workunit() + ).as_workunit(), ] caplog.clear() with caplog.at_level(logging.WARNING): @@ -244,7 +248,7 @@ def test_auto_empty_dataset_usage_statistics_invalid_timestamp( dataset_urns={urn}, config=config, all_buckets=True, - ) + ), ) assert len(caplog.records) == 1 assert "1970-01-01 00:00:00+00:00" in caplog.records[0].msg @@ -256,7 +260,7 @@ def test_auto_empty_dataset_usage_statistics_invalid_timestamp( aspect=models.DatasetUsageStatisticsClass( timestampMillis=int(config.start_time.timestamp() * 1000), eventGranularity=models.TimeWindowSizeClass( - models.CalendarIntervalClass.DAY + models.CalendarIntervalClass.DAY, ), uniqueUserCount=0, totalSqlQueries=0, @@ -295,7 +299,8 @@ def get_sample_mcps(mcps_to_append: List = []) -> 
List[MetadataChangeProposalWra def to_patch_work_units(patch_builder: DatasetPatchBuilder) -> List[MetadataWorkUnit]: return [ MetadataWorkUnit( - id=MetadataWorkUnit.generate_workunit_id(patch_mcp), mcp_raw=patch_mcp + id=MetadataWorkUnit.generate_workunit_id(patch_mcp), + mcp_raw=patch_mcp, ) for patch_mcp in patch_builder.build() ] @@ -303,7 +308,7 @@ def to_patch_work_units(patch_builder: DatasetPatchBuilder) -> List[MetadataWork def get_auto_generated_wu() -> List[MetadataWorkUnit]: dataset_patch_builder = DatasetPatchBuilder( - urn="urn:li:dataset:(urn:li:dataPlatform:dbt,abc.foo.bar,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:dbt,abc.foo.bar,PROD)", ).set_last_modified(TimeStampClass(time=20)) auto_generated_work_units = to_patch_work_units(dataset_patch_builder) @@ -317,7 +322,7 @@ def test_auto_patch_last_modified_no_change(): MetadataChangeProposalWrapper( entityUrn="urn:li:container:008e111aa1d250dd52e0fd5d4b307b1a", aspect=models.StatusClass(removed=False), - ) + ), ] initial_wu = list(auto_workunit(mcps)) @@ -349,7 +354,7 @@ def test_auto_patch_last_modified_multi_patch(): mcps = get_sample_mcps() dataset_patch_builder = DatasetPatchBuilder( - urn="urn:li:dataset:(urn:li:dataPlatform:dbt,abc.foo.bar,PROD)" + urn="urn:li:dataset:(urn:li:dataPlatform:dbt,abc.foo.bar,PROD)", ) dataset_patch_builder.set_display_name("foo") diff --git a/metadata-ingestion/tests/unit/api/test_pipeline.py b/metadata-ingestion/tests/unit/api/test_pipeline.py index 324e4ed0f6e853..86fce9a15b40c8 100644 --- a/metadata-ingestion/tests/unit/api/test_pipeline.py +++ b/metadata-ingestion/tests/unit/api/test_pipeline.py @@ -35,7 +35,8 @@ class TestPipeline: @patch("confluent_kafka.Consumer", autospec=True) @patch( - "datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True + "datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", + autospec=True, ) @patch("datahub.ingestion.sink.console.ConsoleSink.close", autospec=True) @freeze_time(FROZEN_TIME) @@ -47,7 +48,7 @@ def test_configure(self, mock_sink, mock_source, mock_consumer): "config": {"connection": {"bootstrap": "fake-dns-name:9092"}}, }, "sink": {"type": "console"}, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -69,7 +70,10 @@ def test_configure(self, mock_sink, mock_source, mock_consumer): return_value=DatahubClientConfig(server="http://fake-gms-server:8080"), ) def test_configure_without_sink( - self, mock_emitter, mock_graph, mock_load_client_config + self, + mock_emitter, + mock_graph, + mock_load_client_config, ): pipeline = Pipeline.create( { @@ -77,7 +81,7 @@ def test_configure_without_sink( "type": "file", "config": {"path": "test_file.json"}, }, - } + }, ) # assert that the default sink is a DatahubRestSink assert isinstance(pipeline.sink, DatahubRestSink) @@ -102,7 +106,11 @@ def test_configure_without_sink( return_value="Basic user:pass", ) def test_configure_without_sink_use_system_auth( - self, mock_emitter, mock_graph, mock_load_client_config, mock_get_system_auth + self, + mock_emitter, + mock_graph, + mock_load_client_config, + mock_get_system_auth, ): pipeline = Pipeline.create( { @@ -110,7 +118,7 @@ def test_configure_without_sink_use_system_auth( "type": "file", "config": {"path": "test_file.json"}, }, - } + }, ) # assert that the default sink is a DatahubRestSink assert isinstance(pipeline.sink, DatahubRestSink) @@ -130,7 +138,9 @@ def test_configure_without_sink_use_system_auth( return_value={"noCode": True}, ) def test_configure_with_rest_sink_initializes_graph( - self, mock_source, 
mock_test_connection + self, + mock_source, + mock_test_connection, ): pipeline = Pipeline.create( { @@ -170,7 +180,9 @@ def test_configure_with_rest_sink_initializes_graph( return_value={"noCode": True}, ) def test_configure_with_rest_sink_with_additional_props_initializes_graph( - self, mock_source, mock_test_connection + self, + mock_source, + mock_test_connection, ): pipeline = Pipeline.create( { @@ -186,7 +198,7 @@ def test_configure_with_rest_sink_with_additional_props_initializes_graph( "mode": "sync", }, }, - } + }, ) # assert that the default sink config is for a DatahubRestSink assert isinstance(pipeline.config.sink, DynamicTypedConfig) @@ -202,7 +214,8 @@ def test_configure_with_rest_sink_with_additional_props_initializes_graph( @freeze_time(FROZEN_TIME) @patch( - "datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True + "datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", + autospec=True, ) def test_configure_with_file_sink_does_not_init_graph(self, mock_source, tmp_path): pipeline = Pipeline.create( @@ -217,7 +230,7 @@ def test_configure_with_file_sink_does_not_init_graph(self, mock_source, tmp_pat "filename": str(tmp_path / "test.json"), }, }, - } + }, ) # assert that the default sink config is for a DatahubRestSink assert isinstance(pipeline.config.sink, DynamicTypedConfig) @@ -231,11 +244,13 @@ def test_run_including_fake_transformation(self): { "source": {"type": "tests.unit.api.test_pipeline.FakeSource"}, "transformers": [ - {"type": "tests.unit.api.test_pipeline.AddStatusRemovedTransformer"} + { + "type": "tests.unit.api.test_pipeline.AddStatusRemovedTransformer", + }, ], "sink": {"type": "tests.test_helpers.sink_helpers.RecordingSink"}, "run_id": "pipeline_test", - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -246,7 +261,8 @@ def test_run_including_fake_transformation(self): dataset_snapshot.aspects.append(get_status_removed_aspect()) sink_report: RecordingSinkReport = cast( - RecordingSinkReport, pipeline.sink.get_report() + RecordingSinkReport, + pipeline.sink.get_report(), ) assert len(sink_report.received_records) == 1 @@ -266,10 +282,10 @@ def test_run_including_registered_transformation(self): "owner_urns": ["urn:li:corpuser:foo"], "ownership_type": "urn:li:ownershipType:__system__technical_owner", }, - } + }, ], "sink": {"type": "tests.test_helpers.sink_helpers.RecordingSink"}, - } + }, ) assert pipeline @@ -306,7 +322,7 @@ def test_pipeline_return_code(self, tmp_path, source, strict_warnings, exit_code config: {{}} sink: type: console -""" +""", ) res = run_datahub_cmd( @@ -387,7 +403,7 @@ def test_pipeline_process_commits(self, commit_policy, source, should_commit): "source": {"type": f"tests.unit.api.test_pipeline.{source}"}, "sink": {"type": "console"}, "run_id": "pipeline_test", - } + }, ) class FakeCommittable(Committable): @@ -401,7 +417,9 @@ def commit(self) -> None: fake_committable: Committable = FakeCommittable(commit_policy) with patch.object( - FakeCommittable, "commit", wraps=fake_committable.commit + FakeCommittable, + "commit", + wraps=fake_committable.commit, ) as mock_commit: pipeline.ctx.register_checkpointer(fake_committable) @@ -419,15 +437,17 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> "Transformer": return cls() def transform( - self, record_envelopes: Iterable[RecordEnvelope] + self, + record_envelopes: Iterable[RecordEnvelope], ) -> Iterable[RecordEnvelope]: for record_envelope in record_envelopes: if isinstance(record_envelope.record, MetadataChangeEventClass): assert 
isinstance( - record_envelope.record.proposedSnapshot, DatasetSnapshotClass + record_envelope.record.proposedSnapshot, + DatasetSnapshotClass, ) record_envelope.record.proposedSnapshot.aspects.append( - get_status_removed_aspect() + get_status_removed_aspect(), ) yield record_envelope @@ -437,7 +457,7 @@ def __init__(self, ctx: PipelineContext): super().__init__(ctx) self.source_report = SourceReport() self.work_units: List[MetadataWorkUnit] = [ - MetadataWorkUnit(id="workunit-1", mce=get_initial_mce()) + MetadataWorkUnit(id="workunit-1", mce=get_initial_mce()), ] @classmethod @@ -480,11 +500,12 @@ def get_initial_mce() -> MetadataChangeEventClass: aspects=[ DatasetPropertiesClass( description="test.description", - ) + ), ], ), systemMetadata=SystemMetadata( - lastObserved=1586847600000, runId="pipeline_test" + lastObserved=1586847600000, + runId="pipeline_test", ), ) diff --git a/metadata-ingestion/tests/unit/api/test_plugin_system.py b/metadata-ingestion/tests/unit/api/test_plugin_system.py index 0e12416325bf9a..3db93bdb41444a 100644 --- a/metadata-ingestion/tests/unit/api/test_plugin_system.py +++ b/metadata-ingestion/tests/unit/api/test_plugin_system.py @@ -86,7 +86,8 @@ def test_registry(): fake_registry.register("console", ConsoleSink) fake_registry.register_disabled("disabled", ModuleNotFoundError("disabled sink")) fake_registry.register_disabled( - "disabled-exception", Exception("second disabled sink") + "disabled-exception", + Exception("second disabled sink"), ) class DummyClass: @@ -101,7 +102,8 @@ class DummyClass: # Test lazy-loading capabilities. fake_registry.register_lazy( - "lazy-console", "datahub.ingestion.sink.console:ConsoleSink" + "lazy-console", + "datahub.ingestion.sink.console:ConsoleSink", ) assert fake_registry.get("lazy-console") == ConsoleSink diff --git a/metadata-ingestion/tests/unit/api/test_workunit.py b/metadata-ingestion/tests/unit/api/test_workunit.py index 5a4fbf315edb6f..b1c5b68a105864 100644 --- a/metadata-ingestion/tests/unit/api/test_workunit.py +++ b/metadata-ingestion/tests/unit/api/test_workunit.py @@ -18,7 +18,8 @@ def test_get_aspects_of_type_mcp(): aspect = StatusClass(False) wu = MetadataChangeProposalWrapper( - entityUrn="urn:li:container:asdf", aspect=aspect + entityUrn="urn:li:container:asdf", + aspect=aspect, ).as_workunit() assert wu.get_aspect_of_type(StatusClass) == aspect assert wu.get_aspect_of_type(ContainerClass) is None @@ -32,7 +33,7 @@ def test_get_aspects_of_type_mce(): proposedSnapshot=DatasetSnapshotClass( urn="urn:li:dataset:asdf", aspects=[status_aspect, lineage_aspect, status_aspect_2], - ) + ), ) wu = MetadataWorkUnit(id="id", mce=mce) assert wu.get_aspect_of_type(StatusClass) == status_aspect_2 diff --git a/metadata-ingestion/tests/unit/bigquery/test_bigquery_lineage.py b/metadata-ingestion/tests/unit/bigquery/test_bigquery_lineage.py index f494ed78211dcf..58d79ee3d2ac83 100644 --- a/metadata-ingestion/tests/unit/bigquery/test_bigquery_lineage.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bigquery_lineage.py @@ -39,14 +39,14 @@ def lineage_entries() -> List[QueryEvent]: end_time=None, referencedTables=[ BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_source_table1" + "projects/my_project/datasets/my_dataset/tables/my_source_table1", ), BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_source_table2" + "projects/my_project/datasets/my_dataset/tables/my_source_table2", ), ], destinationTable=BigQueryTableRef.from_string_name( - 
"projects/my_project/datasets/my_dataset/tables/my_table" + "projects/my_project/datasets/my_dataset/tables/my_table", ), ), QueryEvent( @@ -56,15 +56,16 @@ def lineage_entries() -> List[QueryEvent]: statementType="SELECT", project_id="proj_12344", end_time=datetime.datetime.fromtimestamp( - 1617295943.17321, tz=datetime.timezone.utc + 1617295943.17321, + tz=datetime.timezone.utc, ), referencedTables=[ BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_source_table3" + "projects/my_project/datasets/my_dataset/tables/my_source_table3", ), ], destinationTable=BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_table" + "projects/my_project/datasets/my_dataset/tables/my_table", ), ), QueryEvent( @@ -75,11 +76,11 @@ def lineage_entries() -> List[QueryEvent]: project_id="proj_12344", referencedViews=[ BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_source_view1" + "projects/my_project/datasets/my_dataset/tables/my_source_view1", ), ], destinationTable=BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_table" + "projects/my_project/datasets/my_dataset/tables/my_table", ), ), ] @@ -96,11 +97,11 @@ def test_lineage_with_timestamps(lineage_entries: List[QueryEvent]) -> None: ) bq_table = BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_table" + "projects/my_project/datasets/my_dataset/tables/my_table", ) lineage_map: Dict[str, Set[LineageEdge]] = extractor._create_lineage_map( - iter(lineage_entries) + iter(lineage_entries), ) upstream_lineage = extractor.get_lineage_for_table( @@ -123,7 +124,7 @@ def test_column_level_lineage(lineage_entries: List[QueryEvent]) -> None: ) bq_table = BigQueryTableRef.from_string_name( - "projects/my_project/datasets/my_dataset/tables/my_table" + "projects/my_project/datasets/my_dataset/tables/my_table", ) lineage_map: Dict[str, Set[LineageEdge]] = extractor._create_lineage_map( @@ -154,7 +155,7 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: version=0, hash="", platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( @@ -181,7 +182,7 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: PathSpec(include="gs://bigquery_data/customer3/{table}/*.parquet"), ] gcs_lineage_config: GcsLineageProviderConfig = GcsLineageProviderConfig( - path_specs=path_specs + path_specs=path_specs, ) config = BigQueryV2Config( @@ -250,7 +251,7 @@ def fake_schema_metadata(entity_urn: str) -> Optional[models.SchemaMetadataClass PathSpec(include="gs://bigquery_data/customer3/{table}/*.parquet"), ] gcs_lineage_config: GcsLineageProviderConfig = GcsLineageProviderConfig( - path_specs=path_specs + path_specs=path_specs, ) config = BigQueryV2Config( @@ -330,7 +331,8 @@ def test_lineage_for_external_table_path_not_matching_specs( PathSpec(include="gs://different_data/db2/db3/{table}/*.parquet"), ] gcs_lineage_config: GcsLineageProviderConfig = GcsLineageProviderConfig( - path_specs=path_specs, ignore_non_path_spec_path=True + path_specs=path_specs, + ignore_non_path_spec_path=True, ) config = BigQueryV2Config( include_table_lineage=True, diff --git a/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py b/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py index b605e9b3f8a3e6..5dc9321816c80a 100644 --- 
a/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py @@ -60,14 +60,14 @@ def test_bigquery_uri(): config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) assert config.get_sql_alchemy_url() == "bigquery://" def test_bigquery_uri_on_behalf(): config = BigQueryV2Config.parse_obj( - {"project_id": "test-project", "project_on_behalf": "test-project-on-behalf"} + {"project_id": "test-project", "project_on_behalf": "test-project-on-behalf"}, ) assert config.get_sql_alchemy_url() == "bigquery://test-project-on-behalf" @@ -86,7 +86,7 @@ def test_bigquery_dataset_pattern(): "project\\.second_dataset", ], }, - } + }, ) assert config.dataset_pattern.allow == [ r".*\.test-dataset", @@ -112,7 +112,7 @@ def test_bigquery_dataset_pattern(): ], }, "match_fully_qualified_names": False, - } + }, ) assert config.dataset_pattern.allow == [ r"test-dataset", @@ -149,7 +149,7 @@ def test_bigquery_uri_with_credential(): "client_email": "test@acryl.io", "client_id": "test_client-id", }, - } + }, ) try: @@ -181,7 +181,7 @@ def test_get_projects_with_project_ids( config = BigQueryV2Config.parse_obj( { "project_ids": ["test-1", "test-2"], - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test1")) assert get_projects( @@ -195,7 +195,7 @@ def test_get_projects_with_project_ids( assert client_mock.list_projects.call_count == 0 config = BigQueryV2Config.parse_obj( - {"project_ids": ["test-1", "test-2"], "project_id": "test-3"} + {"project_ids": ["test-1", "test-2"], "project_id": "test-3"}, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test2")) assert get_projects( @@ -219,7 +219,7 @@ def test_get_projects_with_project_ids_overrides_project_id_pattern( { "project_ids": ["test-project", "test-project-2"], "project_id_pattern": {"deny": ["^test-project$"]}, - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) projects = get_projects( @@ -235,12 +235,12 @@ def test_get_projects_with_project_ids_overrides_project_id_pattern( def test_platform_instance_config_always_none(): config = BigQueryV2Config.parse_obj( - {"include_data_platform_instance": True, "platform_instance": "something"} + {"include_data_platform_instance": True, "platform_instance": "something"}, ) assert config.platform_instance is None config = BigQueryV2Config.parse_obj( - dict(platform_instance="something", project_id="project_id") + dict(platform_instance="something", project_id="project_id"), ) assert config.project_ids == ["project_id"] assert config.platform_instance is None @@ -262,7 +262,8 @@ def test_get_dataplatform_instance_aspect_returns_project_id( schema_gen = source.bq_schema_extractor data_platform_instance = schema_gen.get_dataplatform_instance_aspect( - "urn:li:test", project_id + "urn:li:test", + project_id, ) metadata = data_platform_instance.metadata @@ -282,7 +283,8 @@ def test_get_dataplatform_instance_default_no_instance( schema_gen = source.bq_schema_extractor data_platform_instance = schema_gen.get_dataplatform_instance_aspect( - "urn:li:test", "project_id" + "urn:li:test", + "project_id", ) metadata = data_platform_instance.metadata @@ -322,7 +324,7 @@ def test_get_projects_by_list(get_projects_client, get_bigquery_client): [ SimpleNamespace(project_id="test-1", friendly_name="one"), SimpleNamespace(project_id="test-2", friendly_name="two"), - ] + ], ) first_page.next_page_token = "token1" @@ -331,7 +333,7 @@ def 
test_get_projects_by_list(get_projects_client, get_bigquery_client): [ SimpleNamespace(project_id="test-3", friendly_name="three"), SimpleNamespace(project_id="test-4", friendly_name="four"), - ] + ], ) second_page.next_page_token = None @@ -356,7 +358,9 @@ def test_get_projects_by_list(get_projects_client, get_bigquery_client): @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_get_projects_filter_by_pattern( - get_projects_client, get_bq_client_mock, get_projects_mock + get_projects_client, + get_bq_client_mock, + get_projects_mock, ): get_projects_mock.return_value = [ BigqueryProject("test-project", "Test Project"), @@ -364,7 +368,7 @@ def test_get_projects_filter_by_pattern( ] config = BigQueryV2Config.parse_obj( - {"project_id_pattern": {"deny": ["^test-project$"]}} + {"project_id_pattern": {"deny": ["^test-project$"]}}, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) projects = get_projects( @@ -381,12 +385,14 @@ def test_get_projects_filter_by_pattern( @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_get_projects_list_empty( - get_projects_client, get_bq_client_mock, get_projects_mock + get_projects_client, + get_bq_client_mock, + get_projects_mock, ): get_projects_mock.return_value = [] config = BigQueryV2Config.parse_obj( - {"project_id_pattern": {"deny": ["^test-project$"]}} + {"project_id_pattern": {"deny": ["^test-project$"]}}, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) projects = get_projects( @@ -411,7 +417,7 @@ def test_get_projects_list_failure( bq_client_mock.list_projects.side_effect = GoogleAPICallError(error_str) config = BigQueryV2Config.parse_obj( - {"project_id_pattern": {"deny": ["^test-project$"]}} + {"project_id_pattern": {"deny": ["^test-project$"]}}, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) caplog.clear() @@ -431,12 +437,14 @@ def test_get_projects_list_failure( @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_get_projects_list_fully_filtered( - get_projects_mock, get_bq_client_mock, get_projects_client + get_projects_mock, + get_bq_client_mock, + get_projects_client, ): get_projects_mock.return_value = [BigqueryProject("test-project", "Test Project")] config = BigQueryV2Config.parse_obj( - {"project_id_pattern": {"deny": ["^test-project$"]}} + {"project_id_pattern": {"deny": ["^test-project$"]}}, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) projects = get_projects( @@ -471,7 +479,9 @@ def bigquery_table() -> BigqueryTable: @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_gen_table_dataset_workunits( - get_projects_client, get_bq_client_mock, bigquery_table + get_projects_client, + get_bq_client_mock, + bigquery_table, ): project_id = "test-project" dataset_name = "test-dataset" @@ -479,15 +489,19 @@ def test_gen_table_dataset_workunits( { "project_id": project_id, "capture_table_label_as_tag": True, - } + }, ) source: BigqueryV2Source = BigqueryV2Source( - config=config, ctx=PipelineContext(run_id="test") + config=config, + ctx=PipelineContext(run_id="test"), ) schema_gen = source.bq_schema_extractor gen = schema_gen.gen_table_dataset_workunits( - bigquery_table, [], project_id, dataset_name + bigquery_table, + [], + project_id, + dataset_name, ) mcps 
= list(gen) @@ -520,10 +534,10 @@ def find_mcp_by_aspect(aspect_type): ) assert dataset_props_mcp.metadata.aspect.description == bigquery_table.comment assert dataset_props_mcp.metadata.aspect.created == TimeStampClass( - time=int(bigquery_table.created.timestamp() * 1000) + time=int(bigquery_table.created.timestamp() * 1000), ) assert dataset_props_mcp.metadata.aspect.lastModified == TimeStampClass( - time=int(bigquery_table.last_altered.timestamp() * 1000) + time=int(bigquery_table.last_altered.timestamp() * 1000), ) assert dataset_props_mcp.metadata.aspect.tags == [] @@ -546,8 +560,8 @@ def find_mcp_by_aspect(aspect_type): global_tags_mcp = find_mcp_by_aspect(GlobalTagsClass) assert global_tags_mcp.metadata.aspect.tags == [ TagAssociationClass( - "urn:li:tag:data_producer_owner_email:games_team-nytimes_com" - ) + "urn:li:tag:data_producer_owner_email:games_team-nytimes_com", + ), ] # Assert ContainerClass @@ -572,27 +586,33 @@ def find_mcp_by_aspect(aspect_type): def test_simple_upstream_table_generation(get_bq_client_mock, get_projects_client): a: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="a" - ) + project_id="test-project", + dataset="test-dataset", + table="a", + ), ) b: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="b" - ) + project_id="test-project", + dataset="test-dataset", + table="b", + ), ) config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) lineage_metadata = { str(a): { LineageEdge( - table=str(b), auditStamp=datetime.now(), column_mapping=frozenset() - ) - } + table=str(b), + auditStamp=datetime.now(), + column_mapping=frozenset(), + ), + }, } upstreams = source.lineage_extractor.get_upstream_tables(a, lineage_metadata) @@ -608,28 +628,34 @@ def test_upstream_table_generation_with_temporary_table_without_temp_upstream( ): a: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="a" - ) + project_id="test-project", + dataset="test-dataset", + table="a", + ), ) b: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="_temp-dataset", table="b" - ) + project_id="test-project", + dataset="_temp-dataset", + table="b", + ), ) config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) lineage_metadata = { str(a): { LineageEdge( - table=str(b), auditStamp=datetime.now(), column_mapping=frozenset() - ) - } + table=str(b), + auditStamp=datetime.now(), + column_mapping=frozenset(), + ), + }, } upstreams = source.lineage_extractor.get_upstream_tables(a, lineage_metadata) assert list(upstreams) == [] @@ -638,30 +664,37 @@ def test_upstream_table_generation_with_temporary_table_without_temp_upstream( @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_upstream_table_column_lineage_with_temp_table( - get_bq_client_mock, get_projects_client + get_bq_client_mock, + get_projects_client, ): from datahub.ingestion.api.common import PipelineContext a: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="a" - ) + project_id="test-project", + dataset="test-dataset", + table="a", + ), ) b: 
BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="_temp-dataset", table="b" - ) + project_id="test-project", + dataset="_temp-dataset", + table="b", + ), ) c: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="c" - ) + project_id="test-project", + dataset="test-dataset", + table="c", + ), ) config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) @@ -673,12 +706,13 @@ def test_upstream_table_column_lineage_with_temp_table( column_mapping=frozenset( [ LineageEdgeColumnMapping( - "a_col1", in_columns=frozenset(["b_col2", "b_col3"]) - ) - ] + "a_col1", + in_columns=frozenset(["b_col2", "b_col3"]), + ), + ], ), column_confidence=0.8, - ) + ), }, str(b): { LineageEdge( @@ -687,15 +721,17 @@ def test_upstream_table_column_lineage_with_temp_table( column_mapping=frozenset( [ LineageEdgeColumnMapping( - "b_col2", in_columns=frozenset(["c_col1", "c_col2"]) + "b_col2", + in_columns=frozenset(["c_col1", "c_col2"]), ), LineageEdgeColumnMapping( - "b_col3", in_columns=frozenset(["c_col2", "c_col3"]) + "b_col3", + in_columns=frozenset(["c_col2", "c_col3"]), ), - ] + ], ), column_confidence=0.7, - ) + ), }, } upstreams = source.lineage_extractor.get_upstream_tables(a, lineage_metadata) @@ -706,9 +742,10 @@ def test_upstream_table_column_lineage_with_temp_table( assert upstream.column_mapping == frozenset( [ LineageEdgeColumnMapping( - "a_col1", in_columns=frozenset(["c_col1", "c_col2", "c_col3"]) - ) - ] + "a_col1", + in_columns=frozenset(["c_col1", "c_col2", "c_col3"]), + ), + ], ) assert upstream.column_confidence == 0.7 @@ -716,58 +753,77 @@ def test_upstream_table_column_lineage_with_temp_table( @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_upstream_table_generation_with_temporary_table_with_multiple_temp_upstream( - get_bq_client_mock, get_projects_client + get_bq_client_mock, + get_projects_client, ): a: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="a" - ) + project_id="test-project", + dataset="test-dataset", + table="a", + ), ) b: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="_temp-dataset", table="b" - ) + project_id="test-project", + dataset="_temp-dataset", + table="b", + ), ) c: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="c" - ) + project_id="test-project", + dataset="test-dataset", + table="c", + ), ) d: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="_test-dataset", table="d" - ) + project_id="test-project", + dataset="_test-dataset", + table="d", + ), ) e: BigQueryTableRef = BigQueryTableRef( BigqueryTableIdentifier( - project_id="test-project", dataset="test-dataset", table="e" - ) + project_id="test-project", + dataset="test-dataset", + table="e", + ), ) config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) lineage_metadata = { str(a): { LineageEdge( - table=str(b), auditStamp=datetime.now(), column_mapping=frozenset() - ) + table=str(b), + auditStamp=datetime.now(), + column_mapping=frozenset(), + ), }, str(b): { LineageEdge( - table=str(c), 
auditStamp=datetime.now(), column_mapping=frozenset() + table=str(c), + auditStamp=datetime.now(), + column_mapping=frozenset(), ), LineageEdge( - table=str(d), auditStamp=datetime.now(), column_mapping=frozenset() + table=str(d), + auditStamp=datetime.now(), + column_mapping=frozenset(), ), }, str(d): { LineageEdge( - table=str(e), auditStamp=datetime.now(), column_mapping=frozenset() - ) + table=str(e), + auditStamp=datetime.now(), + column_mapping=frozenset(), + ), }, } upstreams = source.lineage_extractor.get_upstream_tables(a, lineage_metadata) @@ -781,14 +837,16 @@ def test_upstream_table_generation_with_temporary_table_with_multiple_temp_upstr @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_table_processing_logic( - get_projects_client, get_bq_client_mock, data_dictionary_mock + get_projects_client, + get_bq_client_mock, + data_dictionary_mock, ): client_mock = MagicMock() get_bq_client_mock.return_value = client_mock config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) tableListItems = [ @@ -798,8 +856,8 @@ def test_table_processing_logic( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "test-table", - } - } + }, + }, ), TableListItem( { @@ -807,8 +865,8 @@ def test_table_processing_logic( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "test-sharded-table_20220102", - } - } + }, + }, ), TableListItem( { @@ -816,8 +874,8 @@ def test_table_processing_logic( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "test-sharded-table_20210101", - } - } + }, + }, ), TableListItem( { @@ -825,8 +883,8 @@ def test_table_processing_logic( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "test-sharded-table_20220101", - } - } + }, + }, ), ] @@ -838,8 +896,9 @@ def test_table_processing_logic( _ = list( schema_gen.get_tables_for_dataset( - project_id="test-project", dataset=BigqueryDataset("test-dataset") - ) + project_id="test-project", + dataset=BigqueryDataset("test-dataset"), + ), ) assert data_dictionary_mock.call_count == 1 @@ -856,7 +915,9 @@ def test_table_processing_logic( @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_table_processing_logic_date_named_tables( - get_projects_client, get_bq_client_mock, data_dictionary_mock + get_projects_client, + get_bq_client_mock, + data_dictionary_mock, ): client_mock = MagicMock() get_bq_client_mock.return_value = client_mock @@ -864,7 +925,7 @@ def test_table_processing_logic_date_named_tables( config = BigQueryV2Config.parse_obj( { "project_id": "test-project", - } + }, ) tableListItems = [ @@ -874,8 +935,8 @@ def test_table_processing_logic_date_named_tables( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "test-table", - } - } + }, + }, ), TableListItem( { @@ -883,8 +944,8 @@ def test_table_processing_logic_date_named_tables( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "20220102", - } - } + }, + }, ), TableListItem( { @@ -892,8 +953,8 @@ def test_table_processing_logic_date_named_tables( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "20210101", - } - } + }, + }, ), TableListItem( { @@ -901,8 +962,8 @@ def test_table_processing_logic_date_named_tables( "projectId": "test-project", "datasetId": "test-dataset", "tableId": "20220103", - } - } + }, + }, ), ] @@ -914,8 +975,9 @@ def test_table_processing_logic_date_named_tables( _ = list( 
schema_gen.get_tables_for_dataset( - project_id="test-project", dataset=BigqueryDataset("test-dataset") - ) + project_id="test-project", + dataset=BigqueryDataset("test-dataset"), + ), ) assert data_dictionary_mock.call_count == 1 @@ -986,7 +1048,7 @@ def test_get_views_for_dataset( comment=bigquery_view_1.comment, view_definition=bigquery_view_1.view_definition, table_type="VIEW", - ) + ), ) row2 = create_row( # Materialized view, no last_altered dict( @@ -995,7 +1057,7 @@ def test_get_views_for_dataset( comment=bigquery_view_2.comment, view_definition=bigquery_view_2.view_definition, table_type="MATERIALIZED VIEW", - ) + ), ) query_mock.return_value = [row1, row2] bigquery_data_dictionary = BigQuerySchemaApi( @@ -1014,27 +1076,36 @@ def test_get_views_for_dataset( @patch.object( - BigQuerySchemaGenerator, "gen_dataset_workunits", lambda *args, **kwargs: [] + BigQuerySchemaGenerator, + "gen_dataset_workunits", + lambda *args, **kwargs: [], ) @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_gen_view_dataset_workunits( - get_projects_client, get_bq_client_mock, bigquery_view_1, bigquery_view_2 + get_projects_client, + get_bq_client_mock, + bigquery_view_1, + bigquery_view_2, ): project_id = "test-project" dataset_name = "test-dataset" config = BigQueryV2Config.parse_obj( { "project_id": project_id, - } + }, ) source: BigqueryV2Source = BigqueryV2Source( - config=config, ctx=PipelineContext(run_id="test") + config=config, + ctx=PipelineContext(run_id="test"), ) schema_gen = source.bq_schema_extractor gen = schema_gen.gen_view_dataset_workunits( - bigquery_view_1, [], project_id, dataset_name + bigquery_view_1, + [], + project_id, + dataset_name, ) mcp = cast(MetadataChangeProposalClass, next(iter(gen)).metadata) assert mcp.aspect == ViewProperties( @@ -1044,7 +1115,10 @@ def test_gen_view_dataset_workunits( ) gen = schema_gen.gen_view_dataset_workunits( - bigquery_view_2, [], project_id, dataset_name + bigquery_view_2, + [], + project_id, + dataset_name, ) mcp = cast(MetadataChangeProposalClass, next(iter(gen)).metadata) assert mcp.aspect == ViewProperties( @@ -1099,7 +1173,7 @@ def test_get_snapshots_for_dataset( base_table_catalog=bigquery_snapshot.base_table_identifier.project_id, base_table_schema=bigquery_snapshot.base_table_identifier.dataset, base_table_name=bigquery_snapshot.base_table_identifier.table, - ) + ), ) query_mock.return_value = [row1] bigquery_data_dictionary = BigQuerySchemaApi( @@ -1120,28 +1194,34 @@ def test_get_snapshots_for_dataset( @patch.object(BigQueryV2Config, "get_bigquery_client") @patch.object(BigQueryV2Config, "get_projects_client") def test_gen_snapshot_dataset_workunits( - get_bq_client_mock, get_projects_client, bigquery_snapshot + get_bq_client_mock, + get_projects_client, + bigquery_snapshot, ): project_id = "test-project" dataset_name = "test-dataset" config = BigQueryV2Config.parse_obj( { "project_id": project_id, - } + }, ) source: BigqueryV2Source = BigqueryV2Source( - config=config, ctx=PipelineContext(run_id="test") + config=config, + ctx=PipelineContext(run_id="test"), ) schema_gen = source.bq_schema_extractor gen = schema_gen.gen_snapshot_dataset_workunits( - bigquery_snapshot, [], project_id, dataset_name + bigquery_snapshot, + [], + project_id, + dataset_name, ) mcp = cast(MetadataChangeProposalWrapper, list(gen)[2].metadata) dataset_properties = cast(DatasetPropertiesClass, mcp.aspect) assert dataset_properties.customProperties["snapshot_ddl"] == bigquery_snapshot.ddl 
assert dataset_properties.customProperties["snapshot_time"] == str( - bigquery_snapshot.snapshot_time + bigquery_snapshot.snapshot_time, ) @@ -1168,7 +1248,9 @@ def test_gen_snapshot_dataset_workunits( ], ) def test_get_table_and_shard_default( - table_name: str, expected_table_prefix: Optional[str], expected_shard: Optional[str] + table_name: str, + expected_table_prefix: Optional[str], + expected_shard: Optional[str], ) -> None: with patch( "datahub.ingestion.source.bigquery_v2.bigquery_audit.BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX", @@ -1203,7 +1285,9 @@ def test_get_table_and_shard_default( ], ) def test_get_table_and_shard_custom_shard_pattern( - table_name: str, expected_table_prefix: Optional[str], expected_shard: Optional[str] + table_name: str, + expected_table_prefix: Optional[str], + expected_shard: Optional[str], ) -> None: with patch( "datahub.ingestion.source.bigquery_v2.bigquery_audit.BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX", @@ -1319,7 +1403,7 @@ def test_bigquery_config_deprecated_schema_pattern(): } config = BigQueryV2Config.parse_obj(config_with_dataset_pattern) assert config.dataset_pattern == AllowDenyPattern( - deny=["temp.*"] + deny=["temp.*"], ) # dataset_pattern @@ -1341,7 +1425,7 @@ def test_get_projects_with_project_labels( config = BigQueryV2Config.parse_obj( { "project_labels": ["environment:dev", "environment:qa"], - } + }, ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test1")) diff --git a/metadata-ingestion/tests/unit/bigquery/test_bigquery_usage.py b/metadata-ingestion/tests/unit/bigquery/test_bigquery_usage.py index 7ff83bff4a72a5..7b8185396c0027 100644 --- a/metadata-ingestion/tests/unit/bigquery/test_bigquery_usage.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bigquery_usage.py @@ -48,7 +48,10 @@ DATABASE_2 = Container("database_2") TABLE_1 = Table("table_1", DATABASE_1, columns=["id", "name", "age"], upstreams=[]) TABLE_2 = Table( - "table_2", DATABASE_1, columns=["id", "table_1_id", "value"], upstreams=[] + "table_2", + DATABASE_1, + columns=["id", "table_1_id", "value"], + upstreams=[], ) VIEW_1 = View( name="view_1", @@ -175,7 +178,8 @@ def make_usage_workunit( def make_operational_workunit( - resource_urn: str, operation: OperationClass + resource_urn: str, + operation: OperationClass, ) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( entityUrn=resource_urn, @@ -232,7 +236,8 @@ def make_zero_usage_workunit( def compare_workunits( - output: Iterable[MetadataWorkUnit], expected: Iterable[MetadataWorkUnit] + output: Iterable[MetadataWorkUnit], + expected: Iterable[MetadataWorkUnit], ) -> None: assert not diff_metadata_json( [wu.metadata.to_obj() for wu in output], @@ -266,7 +271,8 @@ def test_usage_counts_single_bucket_resource_project( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=len(queries), topSqlQueries=[query_table_1_a().text, query_table_1_b().text], @@ -346,7 +352,8 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=6, topSqlQueries=[ @@ -389,7 +396,8 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( 
dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=3, topSqlQueries=[ @@ -417,7 +425,8 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=3, topSqlQueries=[ @@ -429,7 +438,7 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( user=ACTOR_1_URN, count=3, userEmail=ACTOR_1, - ) + ), ], fieldCounts=[ DatasetFieldUsageCountsClass( @@ -450,7 +459,8 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_2.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=5, topSqlQueries=[ @@ -494,7 +504,8 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_2.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=2, topSqlQueries=[query_view_1().text, query_view_1_and_table_1().text], @@ -515,7 +526,8 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_2.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=2, topSqlQueries=[query_tables_1_and_2().text, query_table_2().text], @@ -525,7 +537,7 @@ def test_usage_counts_multiple_buckets_and_resources_view_usage( user=ACTOR_2_URN, count=2, userEmail=ACTOR_2, - ) + ), ], fieldCounts=[ DatasetFieldUsageCountsClass( @@ -592,7 +604,8 @@ def test_usage_counts_multiple_buckets_and_resources_no_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=9, topSqlQueries=[ @@ -640,7 +653,8 @@ def test_usage_counts_multiple_buckets_and_resources_no_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=6, topSqlQueries=[query_tables_1_and_2().text, query_view_1().text], @@ -684,7 +698,8 @@ def test_usage_counts_multiple_buckets_and_resources_no_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_2.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=6, topSqlQueries=[ @@ -733,7 +748,8 @@ def test_usage_counts_multiple_buckets_and_resources_no_view_usage( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_2.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=4, topSqlQueries=[ @@ -803,7 +819,7 @@ def 
test_usage_counts_no_query_event( fieldsRead=["id", "name", "total"], readReason="JOB", payload=None, - ) + ), ) workunits = usage_extractor._get_workunits_internal([event], [str(ref)]) expected = [ @@ -812,7 +828,8 @@ def test_usage_counts_no_query_event( aspect=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=0, uniqueUserCount=0, @@ -820,7 +837,7 @@ def test_usage_counts_no_query_event( userCounts=[], fieldCounts=[], ), - ).as_workunit() + ).as_workunit(), ] compare_workunits(workunits, expected) assert not caplog.records @@ -833,7 +850,7 @@ def test_usage_counts_no_columns( ) -> None: job_name = "job_name" ref = BigQueryTableRef( - BigqueryTableIdentifier(PROJECT_1, DATABASE_1.name, TABLE_1.name) + BigqueryTableIdentifier(PROJECT_1, DATABASE_1.name, TABLE_1.name), ) events = [ AuditEvent.create( @@ -859,13 +876,14 @@ def test_usage_counts_no_columns( referencedTables=[ref], referencedViews=[], payload=None, - ) + ), ), ] caplog.clear() with caplog.at_level(logging.WARNING): workunits = usage_extractor._get_workunits_internal( - events, [TABLE_REFS[TABLE_1.name]] + events, + [TABLE_REFS[TABLE_1.name]], ) expected = [ make_usage_workunit( @@ -873,7 +891,8 @@ def test_usage_counts_no_columns( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=1, topSqlQueries=["SELECT * FROM table_1"], @@ -888,7 +907,7 @@ def test_usage_counts_no_columns( fieldCounts=[], ), identifiers=usage_extractor.identifiers, - ) + ), ] compare_workunits(workunits, expected) assert not caplog.records @@ -903,7 +922,7 @@ def test_usage_counts_no_columns_and_top_n_limit_hit( job_name = "job_name" ref = BigQueryTableRef( - BigqueryTableIdentifier(PROJECT_1, DATABASE_1.name, TABLE_1.name) + BigqueryTableIdentifier(PROJECT_1, DATABASE_1.name, TABLE_1.name), ) events = [ AuditEvent.create( @@ -951,7 +970,7 @@ def test_usage_counts_no_columns_and_top_n_limit_hit( referencedTables=[ref], referencedViews=[], payload=None, - ) + ), ), AuditEvent.create( QueryEvent( @@ -965,7 +984,7 @@ def test_usage_counts_no_columns_and_top_n_limit_hit( referencedTables=[ref], referencedViews=[], payload=None, - ) + ), ), AuditEvent.create( QueryEvent( @@ -979,13 +998,14 @@ def test_usage_counts_no_columns_and_top_n_limit_hit( referencedTables=[ref], referencedViews=[], payload=None, - ) + ), ), ] caplog.clear() with caplog.at_level(logging.WARNING): workunits = usage_extractor._get_workunits_internal( - events, [TABLE_REFS[TABLE_1.name]] + events, + [TABLE_REFS[TABLE_1.name]], ) expected = [ make_usage_workunit( @@ -993,7 +1013,8 @@ def test_usage_counts_no_columns_and_top_n_limit_hit( dataset_usage_statistics=DatasetUsageStatisticsClass( timestampMillis=int(TS_1.timestamp() * 1000), eventGranularity=TimeWindowSizeClass( - unit=BucketDuration.DAY, multiple=1 + unit=BucketDuration.DAY, + multiple=1, ), totalSqlQueries=3, topSqlQueries=["SELECT * FROM table_1"], @@ -1008,7 +1029,7 @@ def test_usage_counts_no_columns_and_top_n_limit_hit( fieldCounts=[], ), identifiers=usage_extractor.identifiers, - ) + ), ] compare_workunits(workunits, expected) assert not caplog.records @@ -1046,7 +1067,7 @@ def test_operational_stats( num_operations=20, num_unique_queries=10, num_users=3, - ) + ), ) events = 
generate_events(queries, projects, table_to_project, config=config) @@ -1055,8 +1076,8 @@ def test_operational_stats( make_operational_workunit( usage_extractor.identifiers.gen_dataset_urn_from_raw_ref( BigQueryTableRef.from_string_name( - table_refs[query.object_modified.name] - ) + table_refs[query.object_modified.name], + ), ), OperationClass( timestampMillis=int(FROZEN_TIME.timestamp() * 1000), @@ -1076,22 +1097,22 @@ def test_operational_stats( dict.fromkeys( # Preserve order usage_extractor.identifiers.gen_dataset_urn_from_raw_ref( BigQueryTableRef.from_string_name( - table_refs[field.table.name] - ) + table_refs[field.table.name], + ), ) for field in query.fields_accessed if not field.table.is_view() - ) + ), ) + list( dict.fromkeys( # Preserve order usage_extractor.identifiers.gen_dataset_urn_from_raw_ref( - BigQueryTableRef.from_string_name(table_refs[parent.name]) + BigQueryTableRef.from_string_name(table_refs[parent.name]), ) for field in query.fields_accessed if field.table.is_view() for parent in field.table.upstreams - ) + ), ), ), ) @@ -1111,22 +1132,24 @@ def test_operational_stats( def test_get_tables_from_query(usage_extractor): assert usage_extractor.get_tables_from_query( - "SELECT * FROM project-1.database_1.view_1", default_project=PROJECT_1 + "SELECT * FROM project-1.database_1.view_1", + default_project=PROJECT_1, ) == [ - BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "view_1")) + BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "view_1")), ] assert usage_extractor.get_tables_from_query( - "SELECT * FROM database_1.view_1", default_project=PROJECT_1 + "SELECT * FROM database_1.view_1", + default_project=PROJECT_1, ) == [ - BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "view_1")) + BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "view_1")), ] assert sorted( usage_extractor.get_tables_from_query( "SELECT v.id, v.name, v.total, t.name as name1 FROM database_1.view_1 as v inner join database_1.table_1 as t on v.id=t.id", default_project=PROJECT_1, - ) + ), ) == [ BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "table_1")), BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "view_1")), @@ -1136,7 +1159,7 @@ def test_get_tables_from_query(usage_extractor): usage_extractor.get_tables_from_query( "CREATE TABLE database_1.new_table AS SELECT v.id, v.name, v.total, t.name as name1 FROM database_1.view_1 as v inner join database_1.table_1 as t on v.id=t.id", default_project=PROJECT_1, - ) + ), ) == [ BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "table_1")), BigQueryTableRef(BigqueryTableIdentifier("project-1", "database_1", "view_1")), diff --git a/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py b/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py index 3247a64631da76..ce2827e00b5db7 100644 --- a/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py @@ -44,7 +44,7 @@ def test_bigqueryv2_uri_with_credential(): "client_email": "test@acryl.io", "client_id": "test_client-id", }, - } + }, ) try: @@ -80,7 +80,7 @@ def test_bigqueryv2_filters(): "allow": ["test-regex", "test-regex-1"], "deny": ["excluded_table_regex", "excluded-regex-2"], }, - } + }, ) expected_filter: str = """resource.type=(\"bigquery_project\" OR \"bigquery_dataset\") AND @@ -130,7 +130,7 @@ def test_bigqueryv2_filters(): def 
test_bigquery_table_sanitasitation(): table_ref = BigQueryTableRef( - BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_*") + BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_*"), ) assert ( @@ -143,17 +143,17 @@ def test_bigquery_table_sanitasitation(): assert table_ref.table_identifier.get_table_display_name() == "foo" table_ref = BigQueryTableRef( - BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_2022") + BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_2022"), ) new_table_ref = BigqueryTableIdentifier.from_string_name( - table_ref.table_identifier.get_table_name() + table_ref.table_identifier.get_table_name(), ) assert new_table_ref.table == "foo_2022" assert new_table_ref.project_id == "project-1234" assert new_table_ref.dataset == "dataset-4567" table_ref = BigQueryTableRef( - BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_20221210") + BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_20221210"), ) new_table_identifier = table_ref.table_identifier assert new_table_identifier.table == "foo_20221210" @@ -163,17 +163,17 @@ def test_bigquery_table_sanitasitation(): assert new_table_identifier.dataset == "dataset-4567" table_ref = BigQueryTableRef( - BigqueryTableIdentifier("project-1234", "dataset-4567", "foo") + BigqueryTableIdentifier("project-1234", "dataset-4567", "foo"), ) new_table_ref = BigqueryTableIdentifier.from_string_name( - table_ref.table_identifier.get_table_name() + table_ref.table_identifier.get_table_name(), ) assert new_table_ref.table == "foo" assert new_table_ref.project_id == "project-1234" assert new_table_ref.dataset == "dataset-4567" table_ref = BigQueryTableRef( - BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_2016*") + BigqueryTableIdentifier("project-1234", "dataset-4567", "foo_2016*"), ) table_identifier = table_ref.table_identifier assert table_identifier.is_sharded_table() diff --git a/metadata-ingestion/tests/unit/bigquery/test_bq_get_partition_range.py b/metadata-ingestion/tests/unit/bigquery/test_bq_get_partition_range.py index 0d2a7410fb4651..11e1d55eec2bfe 100644 --- a/metadata-ingestion/tests/unit/bigquery/test_bq_get_partition_range.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bq_get_partition_range.py @@ -6,13 +6,16 @@ def test_get_partition_range_from_partition_id(): # yearly partition check assert BigqueryProfiler.get_partition_range_from_partition_id( - "2022", datetime.datetime(2022, 1, 1) + "2022", + datetime.datetime(2022, 1, 1), ) == (datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "2022", datetime.datetime(2022, 3, 12) + "2022", + datetime.datetime(2022, 3, 12), ) == (datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "2022", datetime.datetime(2021, 5, 2) + "2022", + datetime.datetime(2021, 5, 2), ) == (datetime.datetime(2021, 1, 1), datetime.datetime(2022, 1, 1)) assert BigqueryProfiler.get_partition_range_from_partition_id("2022", None) == ( datetime.datetime(2022, 1, 1), @@ -20,13 +23,16 @@ def test_get_partition_range_from_partition_id(): ) # monthly partition check assert BigqueryProfiler.get_partition_range_from_partition_id( - "202202", datetime.datetime(2022, 2, 1) + "202202", + datetime.datetime(2022, 2, 1), ) == (datetime.datetime(2022, 2, 1), datetime.datetime(2022, 3, 1)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "202202", datetime.datetime(2022, 2, 3) + 
"202202", + datetime.datetime(2022, 2, 3), ) == (datetime.datetime(2022, 2, 1), datetime.datetime(2022, 3, 1)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "202202", datetime.datetime(2021, 12, 13) + "202202", + datetime.datetime(2021, 12, 13), ) == (datetime.datetime(2021, 12, 1), datetime.datetime(2022, 1, 1)) assert BigqueryProfiler.get_partition_range_from_partition_id("202202", None) == ( datetime.datetime(2022, 2, 1), @@ -34,10 +40,12 @@ def test_get_partition_range_from_partition_id(): ) # daily partition check assert BigqueryProfiler.get_partition_range_from_partition_id( - "20220205", datetime.datetime(2022, 2, 5) + "20220205", + datetime.datetime(2022, 2, 5), ) == (datetime.datetime(2022, 2, 5), datetime.datetime(2022, 2, 6)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "20220205", datetime.datetime(2022, 2, 3) + "20220205", + datetime.datetime(2022, 2, 3), ) == (datetime.datetime(2022, 2, 3), datetime.datetime(2022, 2, 4)) assert BigqueryProfiler.get_partition_range_from_partition_id("20220205", None) == ( datetime.datetime(2022, 2, 5), @@ -45,13 +53,16 @@ def test_get_partition_range_from_partition_id(): ) # hourly partition check assert BigqueryProfiler.get_partition_range_from_partition_id( - "2022020509", datetime.datetime(2022, 2, 5, 9) + "2022020509", + datetime.datetime(2022, 2, 5, 9), ) == (datetime.datetime(2022, 2, 5, 9), datetime.datetime(2022, 2, 5, 10)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "2022020509", datetime.datetime(2022, 2, 3, 1) + "2022020509", + datetime.datetime(2022, 2, 3, 1), ) == (datetime.datetime(2022, 2, 3, 1), datetime.datetime(2022, 2, 3, 2)) assert BigqueryProfiler.get_partition_range_from_partition_id( - "2022020509", None + "2022020509", + None, ) == ( datetime.datetime(2022, 2, 5, 9), datetime.datetime(2022, 2, 5, 10), diff --git a/metadata-ingestion/tests/unit/cli/test_cli_utils.py b/metadata-ingestion/tests/unit/cli/test_cli_utils.py index c430f585200e5a..df93df719dab13 100644 --- a/metadata-ingestion/tests/unit/cli/test_cli_utils.py +++ b/metadata-ingestion/tests/unit/cli/test_cli_utils.py @@ -21,7 +21,8 @@ def test_correct_url_when_gms_host_in_old_format(): @mock.patch.dict( - os.environ, {"DATAHUB_GMS_HOST": "localhost", "DATAHUB_GMS_PORT": "8080"} + os.environ, + {"DATAHUB_GMS_HOST": "localhost", "DATAHUB_GMS_PORT": "8080"}, ) def test_correct_url_when_gms_host_and_port_set(): assert _get_config_from_env() == ("http://localhost:8080", None) diff --git a/metadata-ingestion/tests/unit/cli/test_quickstart_version_mapping.py b/metadata-ingestion/tests/unit/cli/test_quickstart_version_mapping.py index 38f3451a191a43..56007263d896ed 100644 --- a/metadata-ingestion/tests/unit/cli/test_quickstart_version_mapping.py +++ b/metadata-ingestion/tests/unit/cli/test_quickstart_version_mapping.py @@ -32,7 +32,7 @@ "mysql_tag": "8.2", }, }, - } + }, ) @@ -59,7 +59,9 @@ def test_quickstart_version_config_default(): def test_quickstart_version_config_stable(): execution_plan = example_version_mapper.get_quickstart_execution_plan("stable") expected = QuickstartExecutionPlan( - docker_tag="latest", composefile_git_ref="v1.0.1", mysql_tag="8.2" + docker_tag="latest", + composefile_git_ref="v1.0.1", + mysql_tag="8.2", ) assert execution_plan == expected @@ -87,7 +89,7 @@ def test_quickstart_forced_not_a_version_tag(): a recent change, otherwise it should be on an older version of datahub. 
""" execution_plan = example_version_mapper.get_quickstart_execution_plan( - "NOT A VERSION" + "NOT A VERSION", ) expected = QuickstartExecutionPlan( docker_tag="NOT A VERSION", diff --git a/metadata-ingestion/tests/unit/config/test_config_clean.py b/metadata-ingestion/tests/unit/config/test_config_clean.py index 178030c773e7f3..888cde22da62e6 100644 --- a/metadata-ingestion/tests/unit/config/test_config_clean.py +++ b/metadata-ingestion/tests/unit/config/test_config_clean.py @@ -4,7 +4,8 @@ def test_remove_suffix(): assert ( config_clean.remove_suffix( - "xaaabcdef.snowflakecomputing.com", ".snowflakecomputing.com" + "xaaabcdef.snowflakecomputing.com", + ".snowflakecomputing.com", ) == "xaaabcdef" ) diff --git a/metadata-ingestion/tests/unit/config/test_config_loader.py b/metadata-ingestion/tests/unit/config/test_config_loader.py index 43781acd7f80c0..6c1a3f3d22ff6a 100644 --- a/metadata-ingestion/tests/unit/config/test_config_loader.py +++ b/metadata-ingestion/tests/unit/config/test_config_loader.py @@ -90,7 +90,7 @@ "test_url$vanillavar", "test_urlstuff9vanillaVarstuff10", "${VAR11}", - ] + ], }, ], }, @@ -153,7 +153,7 @@ def test_load_strict_env_syntax() -> None: assert EnvResolver( environ={ "BAR": "bar", - } + }, ).resolve(config) == { "foo": "bar", "baz": "$BAZ", @@ -237,7 +237,7 @@ def test_write_file_directive(pytestconfig): -----BEGIN CERTIFICATE----- thisisnotarealcert -----END CERTIFICATE----- - """ + """, ).lstrip() ) diff --git a/metadata-ingestion/tests/unit/config/test_connection_resolver.py b/metadata-ingestion/tests/unit/config/test_connection_resolver.py index 592d145ac3c040..3aa3449d306e13 100644 --- a/metadata-ingestion/tests/unit/config/test_connection_resolver.py +++ b/metadata-ingestion/tests/unit/config/test_connection_resolver.py @@ -18,7 +18,7 @@ class MyConnectionType(ConfigModel): def test_auto_connection_resolver(): # Test a normal config. config = MyConnectionType.parse_obj( - {"username": "test_user", "password": "test_password"} + {"username": "test_user", "password": "test_password"}, ) assert config.username == "test_user" assert config.password == "test_password" @@ -28,7 +28,7 @@ def test_auto_connection_resolver(): config = MyConnectionType.parse_obj( { "connection": "test_connection", - } + }, ) # Missing connection -> should raise an error. @@ -39,7 +39,7 @@ def test_auto_connection_resolver(): config = MyConnectionType.parse_obj( { "connection": "urn:li:dataHubConnection:missing-connection", - } + }, ) # Bad connection config -> should raise an error. @@ -49,7 +49,7 @@ def test_auto_connection_resolver(): config = MyConnectionType.parse_obj( { "connection": "urn:li:dataHubConnection:bad-connection", - } + }, ) # Good connection config. @@ -62,7 +62,7 @@ def test_auto_connection_resolver(): { "connection": "urn:li:dataHubConnection:good-connection", "username": "override_user", - } + }, ) assert config.username == "override_user" assert config.password == "test_password" diff --git a/metadata-ingestion/tests/unit/config/test_datetime_parser.py b/metadata-ingestion/tests/unit/config/test_datetime_parser.py index 7b1a1f83f8373f..6e615d7567ab05 100644 --- a/metadata-ingestion/tests/unit/config/test_datetime_parser.py +++ b/metadata-ingestion/tests/unit/config/test_datetime_parser.py @@ -13,41 +13,99 @@ def test_user_time_parser(): # Absolute times. 
assert parse_user_datetime("2022-01-01 01:02:03 UTC") == datetime( - 2022, 1, 1, 1, 2, 3, tzinfo=timezone.utc + 2022, + 1, + 1, + 1, + 2, + 3, + tzinfo=timezone.utc, ) assert parse_user_datetime("2022-01-01 01:02:03 -02:00") == datetime( - 2022, 1, 1, 3, 2, 3, tzinfo=timezone.utc + 2022, + 1, + 1, + 3, + 2, + 3, + tzinfo=timezone.utc, ) assert parse_user_datetime("2024-03-01 00:46:33.000 -0800") == datetime( - 2024, 3, 1, 8, 46, 33, tzinfo=timezone.utc + 2024, + 3, + 1, + 8, + 46, + 33, + tzinfo=timezone.utc, ) # Times with no timestamp are assumed to be in UTC. assert parse_user_datetime("2022-01-01 01:02:03") == datetime( - 2022, 1, 1, 1, 2, 3, tzinfo=timezone.utc + 2022, + 1, + 1, + 1, + 2, + 3, + tzinfo=timezone.utc, ) assert parse_user_datetime("2022-02-03") == datetime( - 2022, 2, 3, tzinfo=timezone.utc + 2022, + 2, + 3, + tzinfo=timezone.utc, ) # Timestamps. assert parse_user_datetime("1630440123") == datetime( - 2021, 8, 31, 20, 2, 3, tzinfo=timezone.utc + 2021, + 8, + 31, + 20, + 2, + 3, + tzinfo=timezone.utc, ) assert parse_user_datetime("1630440123837.018") == datetime( - 2021, 8, 31, 20, 2, 3, 837018, tzinfo=timezone.utc + 2021, + 8, + 31, + 20, + 2, + 3, + 837018, + tzinfo=timezone.utc, ) # Relative times. assert parse_user_datetime("10m") == datetime( - 2021, 9, 1, 10, 12, 3, tzinfo=timezone.utc + 2021, + 9, + 1, + 10, + 12, + 3, + tzinfo=timezone.utc, ) assert parse_user_datetime("+ 1 day") == datetime( - 2021, 9, 2, 10, 2, 3, tzinfo=timezone.utc + 2021, + 9, + 2, + 10, + 2, + 3, + tzinfo=timezone.utc, ) assert parse_user_datetime("-2 days") == datetime( - 2021, 8, 30, 10, 2, 3, tzinfo=timezone.utc + 2021, + 8, + 30, + 10, + 2, + 3, + tzinfo=timezone.utc, ) # Invalid inputs. diff --git a/metadata-ingestion/tests/unit/config/test_key_value_pattern.py b/metadata-ingestion/tests/unit/config/test_key_value_pattern.py index 1ac5666f1f4b9a..8d9e6dfa5a6909 100644 --- a/metadata-ingestion/tests/unit/config/test_key_value_pattern.py +++ b/metadata-ingestion/tests/unit/config/test_key_value_pattern.py @@ -28,7 +28,8 @@ def test_no_fallthrough_pattern() -> None: def test_fallthrough_pattern() -> None: pattern = KeyValuePattern( - rules={"foo.*": ["bar", "baz"], ".*": ["qux"]}, first_match_only=False + rules={"foo.*": ["bar", "baz"], ".*": ["qux"]}, + first_match_only=False, ) assert pattern.value("foo") == ["bar", "baz", "qux"] assert pattern.value("foo.bar") == ["bar", "baz", "qux"] diff --git a/metadata-ingestion/tests/unit/config/test_pydantic_validators.py b/metadata-ingestion/tests/unit/config/test_pydantic_validators.py index f687a2776f6e2d..34669364d84473 100644 --- a/metadata-ingestion/tests/unit/config/test_pydantic_validators.py +++ b/metadata-ingestion/tests/unit/config/test_pydantic_validators.py @@ -99,7 +99,7 @@ class TestModel(ConfigModel): with pytest.warns(ConfigurationWarning, match=r"d\d.+ deprecated"): v = TestModel.parse_obj( - {"b": "original", "d1": "deprecated", "d2": "deprecated"} + {"b": "original", "d1": "deprecated", "d2": "deprecated"}, ) assert v.b == "original" assert v.d1 == "deprecated" diff --git a/metadata-ingestion/tests/unit/config/test_time_window_config.py b/metadata-ingestion/tests/unit/config/test_time_window_config.py index 847bda2511a0ce..8bec70fded8bd3 100644 --- a/metadata-ingestion/tests/unit/config/test_time_window_config.py +++ b/metadata-ingestion/tests/unit/config/test_time_window_config.py @@ -34,13 +34,13 @@ def test_relative_start_time(): assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) config = 
BaseTimeWindowConfig.parse_obj( - {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} + {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"}, ) assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc) config = BaseTimeWindowConfig.parse_obj( - {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} + {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"}, ) assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc) @@ -69,7 +69,8 @@ def test_invalid_relative_start_time(): BaseTimeWindowConfig.parse_obj({"start_time": "-2"}) with pytest.raises( - ValueError, match="Relative start time should start with minus sign" + ValueError, + match="Relative start time should start with minus sign", ): BaseTimeWindowConfig.parse_obj({"start_time": "2d"}) diff --git a/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py b/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py index a1ef02c27ea540..bdf7efab024557 100644 --- a/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py +++ b/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py @@ -36,12 +36,13 @@ "integer_field": [1, 2, 3], "boolean_field": [True, False, True], "string_field": ["a", "b", "c"], - } + }, ) def assert_field_types_match( - fields: List[SchemaField], expected_field_types: List[Type] + fields: List[SchemaField], + expected_field_types: List[Type], ) -> None: assert len(fields) == len(expected_field_types) for field, expected_type in zip(fields, expected_field_types): @@ -63,8 +64,9 @@ def test_infer_schema_tsv(): with tempfile.TemporaryFile(mode="w+b") as file: file.write( bytes( - test_table.to_csv(index=False, header=True, sep="\t"), encoding="utf-8" - ) + test_table.to_csv(index=False, header=True, sep="\t"), + encoding="utf-8", + ), ) file.seek(0) @@ -77,7 +79,7 @@ def test_infer_schema_tsv(): def test_infer_schema_jsonl(): with tempfile.TemporaryFile(mode="w+b") as file: file.write( - bytes(test_table.to_json(orient="records", lines=True), encoding="utf-8") + bytes(test_table.to_json(orient="records", lines=True), encoding="utf-8"), ) file.seek(0) @@ -120,8 +122,8 @@ def test_infer_schema_avro(): {"name": "boolean_field", "type": "boolean"}, {"name": "string_field", "type": "string"}, ], - } - ) + }, + ), ) writer = DataFileWriter(file, DatumWriter(), schema) records = test_table.to_dict(orient="records") diff --git a/metadata-ingestion/tests/unit/glue/test_glue_source.py b/metadata-ingestion/tests/unit/glue/test_glue_source.py index 9e3f260a23f1c8..ef2884dd43ba70 100644 --- a/metadata-ingestion/tests/unit/glue/test_glue_source.py +++ b/metadata-ingestion/tests/unit/glue/test_glue_source.py @@ -153,7 +153,8 @@ def glue_source_with_profiling( ) def test_column_type(hive_column_type: str, expected_type: Type) -> None: avro_schema = get_avro_schema_for_hive_column( - f"test_column_{hive_column_type}", hive_column_type + f"test_column_{hive_column_type}", + hive_column_type, ) schema_fields = avro_schema_to_mce_fields(json.dumps(avro_schema)) actual_schema_field_type = schema_fields[0].type @@ -213,7 +214,7 @@ def test_glue_ingest( with Stubber(glue_source_instance.s3_client) as s3_stubber: for _ in range( len(get_tables_response_1["TableList"]) - + len(get_tables_response_2["TableList"]) + + len(get_tables_response_2["TableList"]), ): s3_stubber.add_response( "get_bucket_tagging", @@ 
-334,21 +335,23 @@ def format_databases(databases): ) with Stubber(single_catalog_source.glue_client) as glue_stubber: glue_stubber.add_response( - "get_databases", get_databases_response, {"CatalogId": catalog_id} + "get_databases", + get_databases_response, + {"CatalogId": catalog_id}, ) expected = [flights_database, test_database] actual = single_catalog_source.get_all_databases() assert format_databases(actual) == format_databases(expected) assert single_catalog_source.report.databases.dropped_entities.as_obj() == [ - "empty-database" + "empty-database", ] @freeze_time(FROZEN_TIME) def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph): deleted_actor_golden_mcs = "{}/glue_deleted_actor_mces_golden.json".format( - test_resources_dir + test_resources_dir, ) stateful_config = { @@ -376,7 +379,7 @@ def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph): }, "sink": { # we are not really interested in the resulting events for this test - "type": "console" + "type": "console", }, "pipeline_name": "statefulpipeline", } @@ -416,10 +419,12 @@ def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph): # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Validate against golden MCEs where Status(removed=true) @@ -434,7 +439,7 @@ def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph): state1 = cast(BaseSQLAlchemyCheckpointState, checkpoint1.state) state2 = cast(BaseSQLAlchemyCheckpointState, checkpoint2.state) difference_urns = set( - state1.get_urns_not_in(type="*", other_checkpoint_state=state2) + state1.get_urns_not_in(type="*", other_checkpoint_state=state2), ) assert difference_urns == { "urn:li:dataset:(urn:li:dataPlatform:glue,flights-database.avro,PROD)", @@ -567,7 +572,7 @@ def test_glue_ingest_include_table_lineage( with Stubber(glue_source_instance.s3_client) as s3_stubber: for _ in range( len(get_tables_response_1["TableList"]) - + len(get_tables_response_2["TableList"]) + + len(get_tables_response_2["TableList"]), ): s3_stubber.add_response( "get_bucket_tagging", @@ -642,7 +647,7 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: version=0, hash="", platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( @@ -694,7 +699,9 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: with Stubber(glue_source_instance.glue_client) as glue_stubber: glue_stubber.add_response( - "get_databases", get_databases_response_for_lineage, {} + "get_databases", + get_databases_response_for_lineage, + {}, ) glue_stubber.add_response( "get_tables", diff --git a/metadata-ingestion/tests/unit/glue/test_glue_source_stubs.py b/metadata-ingestion/tests/unit/glue/test_glue_source_stubs.py index 43bf62fd4e3b8a..f4b6a4d35f99cb 100644 --- a/metadata-ingestion/tests/unit/glue/test_glue_source_stubs.py +++ b/metadata-ingestion/tests/unit/glue/test_glue_source_stubs.py @@ -43,7 +43,7 @@ "IsRegisteredWithLakeFormation": False, "CatalogId": "432143214321", "VersionId": "504", - } + }, ] get_tables_response_for_target_database = {"TableList": target_database_tables} @@ -55,10 +55,10 @@ 
"CreateTableDefaultPermissions": [ { "Principal": { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS", }, "Permissions": ["ALL"], - } + }, ], "CatalogId": "123412341234", "LocationUri": "s3://test-bucket/test-prefix", @@ -70,10 +70,10 @@ "CreateTableDefaultPermissions": [ { "Principal": { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS", }, "Permissions": ["ALL"], - } + }, ], "CatalogId": "123412341234", }, @@ -83,14 +83,14 @@ "CreateTableDefaultPermissions": [ { "Principal": { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS", }, "Permissions": ["ALL"], - } + }, ], "CatalogId": "000000000000", }, - ] + ], } flights_database = {"Name": "flights-database", "CatalogId": "123412341234"} test_database = {"Name": "test-database", "CatalogId": "123412341234"} @@ -145,7 +145,7 @@ "StoredAsSubDirectories": False, }, "PartitionKeys": [ - {"Name": "year", "Type": "string", "Comment": "partition test comment"} + {"Name": "year", "Type": "string", "Comment": "partition test comment"}, ], "TableType": "EXTERNAL_TABLE", "Parameters": { @@ -164,7 +164,7 @@ "CreatedBy": "arn:aws:sts::123412341234:assumed-role/AWSGlueServiceRole-flights-crawler/AWS-Crawler", "IsRegisteredWithLakeFormation": False, "CatalogId": "123412341234", - } + }, ] get_tables_response_1 = {"TableList": tables_1} tables_2 = [ @@ -181,7 +181,7 @@ { "Name": "markers", "Type": "array,location:array>>", - } + }, ], "Location": "s3://test-glue-jsons/markers/", "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", @@ -355,7 +355,7 @@ "NumberOfWorkers": 10, "GlueVersion": "2.0", }, - ] + ], } # for job 1 get_dataflow_graph_response_1 = { @@ -813,14 +813,14 @@ "CreateTableDefaultPermissions": [ { "Principal": { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS", }, "Permissions": ["ALL"], - } + }, ], "CatalogId": "123412341234", }, - ] + ], } delta_tables_1 = [ { @@ -848,7 +848,7 @@ "CreatedBy": "arn:aws:sts::123412341234:assumed-role/AWSGlueServiceRole-flights-crawler/AWS-Crawler", "IsRegisteredWithLakeFormation": False, "CatalogId": "123412341234", - } + }, ] get_delta_tables_response_1 = {"TableList": delta_tables_1} @@ -876,7 +876,7 @@ "CreatedBy": "arn:aws:sts::123412341234:assumed-role/AWSGlueServiceRole-flights-crawler/AWS-Crawler", "IsRegisteredWithLakeFormation": False, "CatalogId": "123412341234", - } + }, ] get_delta_tables_response_2 = {"TableList": delta_tables_2} @@ -888,16 +888,16 @@ "CreateTableDefaultPermissions": [ { "Principal": { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS", }, "Permissions": ["ALL"], - } + }, ], "CatalogId": "123412341234", "LocationUri": "s3://test-bucket/test-prefix", "Parameters": {"param1": "value1", "param2": "value2"}, }, - ] + ], } tables_lineage_1 = [ @@ -949,7 +949,7 @@ "StoredAsSubDirectories": False, }, "PartitionKeys": [ - {"Name": "year", "Type": "string", "Comment": "partition test comment"} + {"Name": "year", "Type": "string", "Comment": "partition test comment"}, ], "TableType": "EXTERNAL_TABLE", "Parameters": { @@ -968,7 +968,7 @@ "CreatedBy": "arn:aws:sts::123412341234:assumed-role/AWSGlueServiceRole-flights-crawler/AWS-Crawler", "IsRegisteredWithLakeFormation": False, "CatalogId": "123412341234", - } + }, ] get_tables_lineage_response_1 = {"TableList": 
tables_lineage_1} @@ -981,16 +981,16 @@ "CreateTableDefaultPermissions": [ { "Principal": { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS", }, "Permissions": ["ALL"], - } + }, ], "CatalogId": "123412341234", "LocationUri": "s3://test-bucket/test-prefix", "Parameters": {"param1": "value1", "param2": "value2"}, }, - ] + ], } tables_profiling_1 = [ @@ -1074,7 +1074,7 @@ "CreatedBy": "arn:aws:sts::123412341234:assumed-role/AWSGlueServiceRole-flights-crawler/AWS-Crawler", "IsRegisteredWithLakeFormation": False, "CatalogId": "123412341234", - } + }, ] get_tables_response_profiling_1 = {"TableList": tables_profiling_1} diff --git a/metadata-ingestion/tests/unit/patch/test_patch_builder.py b/metadata-ingestion/tests/unit/patch/test_patch_builder.py index f4bf501e0714d0..1caeb095066670 100644 --- a/metadata-ingestion/tests/unit/patch/test_patch_builder.py +++ b/metadata-ingestion/tests/unit/patch/test_patch_builder.py @@ -36,7 +36,7 @@ def test_basic_dataset_patch_builder(): patcher = DatasetPatchBuilder( - make_dataset_urn(platform="hive", name="fct_users_created", env="PROD") + make_dataset_urn(platform="hive", name="fct_users_created", env="PROD"), ).add_tag(TagAssociationClass(tag=make_tag_urn("test_tag"))) assert patcher.build() == [ @@ -54,11 +54,12 @@ def test_basic_dataset_patch_builder(): def test_complex_dataset_patch( - pytestconfig: pytest.Config, tmp_path: pathlib.Path + pytestconfig: pytest.Config, + tmp_path: pathlib.Path, ) -> None: patcher = ( DatasetPatchBuilder( - make_dataset_urn(platform="hive", name="fct_users_created", env="PROD") + make_dataset_urn(platform="hive", name="fct_users_created", env="PROD"), ) .set_description("test description") .add_custom_property("test_key_1", "test_value_1") @@ -67,18 +68,22 @@ def test_complex_dataset_patch( .add_upstream_lineage( upstream=UpstreamClass( dataset=make_dataset_urn( - platform="hive", name="fct_users_created_upstream", env="PROD" + platform="hive", + name="fct_users_created_upstream", + env="PROD", ), type=DatasetLineageTypeClass.TRANSFORMED, - ) + ), ) .add_upstream_lineage( upstream=UpstreamClass( dataset=make_dataset_urn( - platform="s3", name="my-bucket/my-folder/my-file.txt", env="PROD" + platform="s3", + name="my-bucket/my-folder/my-file.txt", + env="PROD", ), type=DatasetLineageTypeClass.TRANSFORMED, - ) + ), ) .add_fine_grained_upstream_lineage( fine_grained_lineage=FineGrainedLineageClass( @@ -91,7 +96,7 @@ def test_complex_dataset_patch( env="PROD", ), field_path="foo", - ) + ), ], upstreams=[ make_schema_field_urn( @@ -101,12 +106,12 @@ def test_complex_dataset_patch( env="PROD", ), field_path="bar", - ) + ), ], downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD, transformOperation="TRANSFORM", confidenceScore=1.0, - ) + ), ) .add_fine_grained_upstream_lineage( fine_grained_lineage=FineGrainedLineageClass( @@ -119,7 +124,7 @@ def test_complex_dataset_patch( env="PROD", ), field_path="foo", - ) + ), ], downstreamType=FineGrainedLineageDownstreamTypeClass.FIELD_SET, downstreams=[ @@ -130,9 +135,9 @@ def test_complex_dataset_patch( env="PROD", ), field_path="foo", - ) + ), ], - ) + ), ) ) patcher.for_field("field1").add_tag(TagAssociationClass(tag=make_tag_urn("tag1"))) @@ -149,7 +154,7 @@ def test_complex_dataset_patch( def test_basic_chart_patch_builder(): patcher = ChartPatchBuilder( - make_chart_urn(platform="hive", name="fct_users_created") + make_chart_urn(platform="hive", name="fct_users_created"), 
).add_tag(TagAssociationClass(tag=make_tag_urn("test_tag"))) assert patcher.build() == [ @@ -168,7 +173,7 @@ def test_basic_chart_patch_builder(): def test_basic_dashboard_patch_builder(): patcher = DashboardPatchBuilder( - make_dashboard_urn(platform="hive", name="fct_users_created") + make_dashboard_urn(platform="hive", name="fct_users_created"), ).add_tag(TagAssociationClass(tag=make_tag_urn("test_tag"))) assert patcher.build() == [ @@ -246,27 +251,29 @@ def get_edge_expectation(urn: str) -> Dict[str, Any]: return {"destinationUrn": str(urn)} flow_urn = make_data_flow_urn( - orchestrator="nifi", flow_id="252C34e5af19-0192-1000-b248-b1abee565b5d" + orchestrator="nifi", + flow_id="252C34e5af19-0192-1000-b248-b1abee565b5d", ) job_urn = make_data_job_urn_with_flow( - flow_urn, "5ca6fee7-0192-1000-f206-dfbc2b0d8bfb" + flow_urn, + "5ca6fee7-0192-1000-f206-dfbc2b0d8bfb", ) patcher = DataJobPatchBuilder(job_urn) patcher.add_output_dataset( make_edge_or_urn( - "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder1,DEV)" - ) + "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder1,DEV)", + ), ) patcher.add_output_dataset( make_edge_or_urn( - "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder3,DEV)" - ) + "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder3,DEV)", + ), ) patcher.add_output_dataset( make_edge_or_urn( - "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder2,DEV)" - ) + "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder2,DEV)", + ), ) assert patcher.build() == [ @@ -282,26 +289,26 @@ def get_edge_expectation(urn: str) -> Dict[str, Any]: "op": "add", "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket~1folder1,DEV)", "value": get_edge_expectation( - "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder1,DEV)" + "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder1,DEV)", ), }, { "op": "add", "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket~1folder3,DEV)", "value": get_edge_expectation( - "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder3,DEV)" + "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder3,DEV)", ), }, { "op": "add", "path": "/outputDatasetEdges/urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket~1folder2,DEV)", "value": get_edge_expectation( - "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder2,DEV)" + "urn:li:dataset:(urn:li:dataPlatform:s3,output-bucket/folder2,DEV)", ), }, - ] + ], ).encode("utf-8"), contentType="application/json-patch+json", ), - ) + ), ] diff --git a/metadata-ingestion/tests/unit/redshift/redshift_query_mocker.py b/metadata-ingestion/tests/unit/redshift/redshift_query_mocker.py index 10c0250e37e37f..254182f6e166d7 100644 --- a/metadata-ingestion/tests/unit/redshift/redshift_query_mocker.py +++ b/metadata-ingestion/tests/unit/redshift/redshift_query_mocker.py @@ -20,7 +20,7 @@ def mock_temp_table_cursor(cursor: MagicMock) -> None: "price_usd from player_activity group by player_id", "CREATE TABLE #player_price", datetime.now(), - ) + ), ], [ # Empty result to stop the while loop @@ -46,7 +46,7 @@ def mock_stl_insert_table_cursor(cursor: MagicMock) -> None: "player_price_with_hike_v6", "INSERT INTO player_price_with_hike_v6 SELECT (price_usd + 0.2 * price_usd) as price, '20%' FROM " "#player_price", - ) + ), ], [ # Empty result to stop the while loop diff --git a/metadata-ingestion/tests/unit/redshift/test_redshift_lineage.py b/metadata-ingestion/tests/unit/redshift/test_redshift_lineage.py index 
27045dfc656cbe..53a1dce8df4805 100644 --- a/metadata-ingestion/tests/unit/redshift/test_redshift_lineage.py +++ b/metadata-ingestion/tests/unit/redshift/test_redshift_lineage.py @@ -37,10 +37,13 @@ def test_get_sources_from_query(): select * from my_schema.my_table """ lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo") + config, + report, + PipelineContext(run_id="foo"), ) lineage_datasets, _ = lineage_extractor._get_sources_from_query( - db_name="test", query=test_query + db_name="test", + query=test_query, ) assert len(lineage_datasets) == 1 @@ -60,10 +63,13 @@ def test_get_sources_from_query_with_only_table_name(): select * from my_table """ lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo") + config, + report, + PipelineContext(run_id="foo"), ) lineage_datasets, _ = lineage_extractor._get_sources_from_query( - db_name="test", query=test_query + db_name="test", + query=test_query, ) assert len(lineage_datasets) == 1 @@ -83,10 +89,13 @@ def test_get_sources_from_query_with_database(): select * from test.my_schema.my_table """ lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo") + config, + report, + PipelineContext(run_id="foo"), ) lineage_datasets, _ = lineage_extractor._get_sources_from_query( - db_name="test", query=test_query + db_name="test", + query=test_query, ) assert len(lineage_datasets) == 1 @@ -106,10 +115,13 @@ def test_get_sources_from_query_with_non_default_database(): select * from test2.my_schema.my_table """ lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo") + config, + report, + PipelineContext(run_id="foo"), ) lineage_datasets, _ = lineage_extractor._get_sources_from_query( - db_name="test", query=test_query + db_name="test", + query=test_query, ) assert len(lineage_datasets) == 1 @@ -129,10 +141,13 @@ def test_get_sources_from_query_with_only_table(): select * from my_table """ lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo") + config, + report, + PipelineContext(run_id="foo"), ) lineage_datasets, _ = lineage_extractor._get_sources_from_query( - db_name="test", query=test_query + db_name="test", + query=test_query, ) assert len(lineage_datasets) == 1 @@ -151,7 +166,8 @@ def test_parse_alter_table_rename(): "bar", ) assert parse_alter_table_rename( - "public", "alter table second_schema.storage_v2_stg rename to storage_v2; " + "public", + "alter table second_schema.storage_v2_stg rename to storage_v2; ", ) == ( "second_schema", "storage_v2_stg", @@ -170,7 +186,9 @@ def get_lineage_extractor() -> RedshiftLineageExtractor: report = RedshiftReport() lineage_extractor = RedshiftLineageExtractor( - config, report, PipelineContext(run_id="foo", graph=mock_graph()) + config, + report, + PipelineContext(run_id="foo", graph=mock_graph()), ) return lineage_extractor @@ -247,7 +265,9 @@ def test_collapse_temp_lineage(): query="select * from test_collapse_temp_lineage", database=lineage_extractor.config.database, all_tables_set={ - lineage_extractor.config.database: {"public": {"player_price_with_hike_v6"}} + lineage_extractor.config.database: { + "public": {"player_price_with_hike_v6"}, + }, }, connection=connection, lineage_type=LineageCollectorType.QUERY_SQL_PARSER, @@ -294,10 +314,10 @@ def test_collapse_temp_recursive_cll_lineage(): parsed_result=SqlParsingResult( query_type=QueryType.CREATE_TABLE_AS_SELECT, in_tables=[ - 
"urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", ], out_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", ], debug_info=SqlParsingDebugInfo(), column_lineage=[ @@ -312,7 +332,7 @@ def test_collapse_temp_recursive_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -327,7 +347,7 @@ def test_collapse_temp_recursive_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="price_usd", - ) + ), ], logic=None, ), @@ -346,10 +366,10 @@ def test_collapse_temp_recursive_cll_lineage(): parsed_result=SqlParsingResult( query_type=QueryType.CREATE_TABLE_AS_SELECT, in_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", ], out_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", ], debug_info=SqlParsingDebugInfo(), column_lineage=[ @@ -364,7 +384,7 @@ def test_collapse_temp_recursive_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -379,7 +399,7 @@ def test_collapse_temp_recursive_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", column="price", - ) + ), ], logic=None, ), @@ -406,10 +426,10 @@ def test_collapse_temp_recursive_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", column="price_usd", - ) + ), ], logic=None, - ) + ), ] datasets = lineage_extractor._get_upstream_lineages( @@ -417,7 +437,7 @@ def test_collapse_temp_recursive_cll_lineage(): LineageDataset( platform=LineageDatasetPlatform.REDSHIFT, urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", - ) + ), ], target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)", raw_db_name="dev", @@ -425,7 +445,7 @@ def test_collapse_temp_recursive_cll_lineage(): all_tables_set={ "dev": { "public": set(), - } + }, }, connection=MagicMock(), target_dataset_cll=target_dataset_cll, @@ -457,10 +477,10 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): parsed_result=SqlParsingResult( query_type=QueryType.CREATE_TABLE_AS_SELECT, in_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", ], out_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", ], debug_info=SqlParsingDebugInfo(), column_lineage=[ @@ -475,7 +495,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -513,10 +533,10 @@ def 
test_collapse_temp_recursive_with_compex_column_cll_lineage(): parsed_result=SqlParsingResult( query_type=QueryType.CREATE_TABLE_AS_SELECT, in_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", ], out_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", ], debug_info=SqlParsingDebugInfo(), column_lineage=[ @@ -531,7 +551,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -546,7 +566,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", column="price", - ) + ), ], logic=None, ), @@ -561,7 +581,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", column="tax", - ) + ), ], logic=None, ), @@ -587,7 +607,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", column="price_usd", - ) + ), ], logic=None, ), @@ -602,7 +622,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -613,7 +633,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): LineageDataset( platform=LineageDatasetPlatform.REDSHIFT, urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", - ) + ), ], target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)", raw_db_name="dev", @@ -621,7 +641,7 @@ def test_collapse_temp_recursive_with_compex_column_cll_lineage(): all_tables_set={ "dev": { "public": set(), - } + }, }, connection=MagicMock(), target_dataset_cll=target_dataset_cll, @@ -655,10 +675,10 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): parsed_result=SqlParsingResult( query_type=QueryType.CREATE_TABLE_AS_SELECT, in_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", ], out_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", ], debug_info=SqlParsingDebugInfo(), column_lineage=[ @@ -673,7 +693,7 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -688,7 +708,7 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="price_usd", - ) + ), ], logic=None, ), @@ -707,10 +727,10 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): parsed_result=SqlParsingResult( 
query_type=QueryType.CREATE_TABLE_AS_SELECT, in_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_activity,PROD)", ], out_tables=[ - "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", ], debug_info=SqlParsingDebugInfo(), column_lineage=[ @@ -725,7 +745,7 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="player_id", - ) + ), ], logic=None, ), @@ -740,7 +760,7 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_activity_temp,PROD)", column="price_usd", - ) + ), ], logic=None, ), @@ -767,10 +787,10 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): sqlglot_l.ColumnRef( table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", column="price_usd", - ) + ), ], logic=None, - ) + ), ] datasets = lineage_extractor._get_upstream_lineages( @@ -778,7 +798,7 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): LineageDataset( platform=LineageDatasetPlatform.REDSHIFT, urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.#player_price,PROD)", - ) + ), ], target_table="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.player_price_with_hike_v4,PROD)", raw_db_name="dev", @@ -786,7 +806,7 @@ def test_collapse_temp_recursive_cll_lineage_with_circular_reference(): all_tables_set={ "dev": { "public": set(), - } + }, }, connection=MagicMock(), target_dataset_cll=target_dataset_cll, diff --git a/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py b/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py index 2ab6208e2dcc68..5d1245f1ca8f0d 100644 --- a/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py +++ b/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py @@ -160,10 +160,12 @@ def test_non_set_data() -> None: def test_empty_structures() -> None: """Test handling of empty structures""" input_data: Dict[ - str, Union[Set[Any], Dict[Any, Any], List[Any], Tuple[Any, ...]] + str, + Union[Set[Any], Dict[Any, Any], List[Any], Tuple[Any, ...]], ] = {"empty_set": set(), "empty_dict": {}, "empty_list": [], "empty_tuple": ()} expected: Dict[ - str, Union[List[Any], Dict[Any, Any], List[Any], Tuple[Any, ...]] + str, + Union[List[Any], Dict[Any, Any], List[Any], Tuple[Any, ...]], ] = {"empty_set": [], "empty_dict": {}, "empty_list": [], "empty_tuple": ()} result = DatahubIngestionRunSummaryProvider._convert_sets_to_lists(input_data) assert result == expected diff --git a/metadata-ingestion/tests/unit/s3/test_s3_source.py b/metadata-ingestion/tests/unit/s3/test_s3_source.py index 902987213e122f..d1c3fa320e9602 100644 --- a/metadata-ingestion/tests/unit/s3/test_s3_source.py +++ b/metadata-ingestion/tests/unit/s3/test_s3_source.py @@ -136,7 +136,9 @@ def test_path_spec_with_double_star_ending(): ], ) def test_path_spec_partition_detection( - path_spec: str, path: str, expected: List[Tuple[str, str]] + path_spec: str, + path: str, + expected: List[Tuple[str, str]], ) -> None: ps = PathSpec(include=path_spec, default_extension="csv", allow_double_stars=True) assert 
ps.allowed(path) @@ -175,7 +177,8 @@ def test_path_spec_dir_allowed(): def test_container_generation_without_folders(): cwu = ContainerWUCreator("s3", None, "PROD") mcps = cwu.create_container_hierarchy( - "s3://my-bucket/my-file.json.gz", "urn:li:dataset:123" + "s3://my-bucket/my-file.json.gz", + "urn:li:dataset:123", ) def container_properties_filter(x: MetadataWorkUnit) -> bool: @@ -194,7 +197,8 @@ def container_properties_filter(x: MetadataWorkUnit) -> bool: def test_container_generation_with_folder(): cwu = ContainerWUCreator("s3", None, "PROD") mcps = cwu.create_container_hierarchy( - "s3://my-bucket/my-dir/my-file.json.gz", "urn:li:dataset:123" + "s3://my-bucket/my-dir/my-file.json.gz", + "urn:li:dataset:123", ) def container_properties_filter(x: MetadataWorkUnit) -> bool: @@ -218,7 +222,8 @@ def container_properties_filter(x: MetadataWorkUnit) -> bool: def test_container_generation_with_multiple_folders(): cwu = ContainerWUCreator("s3", None, "PROD") mcps = cwu.create_container_hierarchy( - "s3://my-bucket/my-dir/my-dir2/my-file.json.gz", "urn:li:dataset:123" + "s3://my-bucket/my-dir/my-dir2/my-file.json.gz", + "urn:li:dataset:123", ) def container_properties_filter(x: MetadataWorkUnit) -> bool: @@ -291,12 +296,14 @@ def _get_s3_source(path_spec_: PathSpec) -> S3Source: last_modified=datetime(2025, 1, 1, 2), size=100, ), - ] + ], ) # act res = _get_s3_source(path_spec).get_folder_info( - path_spec, bucket, prefix="/my-folder" + path_spec, + bucket, + prefix="/my-folder", ) # assert diff --git a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py index c7a1fab068a838..b05335c9bd72b5 100644 --- a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py +++ b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py @@ -79,14 +79,14 @@ def test_sagemaker_ingest(tmp_path, pytestconfig): "list_associations", list_first_endpoint_incoming_response, { - "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-first-endpoint" + "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-first-endpoint", }, ) sagemaker_stubber.add_response( "list_associations", list_first_endpoint_outgoing_response, { - "SourceArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-first-endpoint" + "SourceArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-first-endpoint", }, ) @@ -94,14 +94,14 @@ def test_sagemaker_ingest(tmp_path, pytestconfig): "list_associations", list_second_endpoint_incoming_response, { - "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-second-endpoint" + "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-second-endpoint", }, ) sagemaker_stubber.add_response( "list_associations", list_second_endpoint_outgoing_response, { - "SourceArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-second-endpoint" + "SourceArn": "arn:aws:sagemaker:us-west-2:123412341234:action/deploy-the-second-endpoint", }, ) @@ -109,21 +109,21 @@ def test_sagemaker_ingest(tmp_path, pytestconfig): "list_associations", get_model_group_incoming_response, { - "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:context/a-model-package-group-context" + "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:context/a-model-package-group-context", }, ) sagemaker_stubber.add_response( "list_associations", get_first_model_package_incoming_response, { - "DestinationArn": 
"arn:aws:sagemaker:us-west-2:123412341234:artifact/the-first-model-package-artifact" + "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:artifact/the-first-model-package-artifact", }, ) sagemaker_stubber.add_response( "list_associations", get_second_model_package_incoming_response, { - "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:artifact/the-second-model-package-artifact" + "DestinationArn": "arn:aws:sagemaker:us-west-2:123412341234:artifact/the-second-model-package-artifact", }, ) diff --git a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source_stubs.py b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source_stubs.py index 59f04c8c1798b1..fd044b889d5402 100644 --- a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source_stubs.py +++ b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source_stubs.py @@ -122,7 +122,7 @@ "S3DataSource": { "S3DataType": "ManifestFile", # 'ManifestFile'|'S3Prefix' "S3Uri": "s3://auto-ml-job-input-bucket/file.txt", - } + }, }, "CompressionType": "None", # 'None'|'Gzip' }, @@ -144,7 +144,7 @@ "AutoMLProblemTypeResolvedAttributes": { "TabularResolvedAttributes": { "ProblemType": "BinaryClassification", # 'BinaryClassification'|'MulticlassClassification'|'Regression' - } + }, }, }, "SecurityConfig": { @@ -202,7 +202,7 @@ "LastModifiedTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "FailureReason": "string", "CandidateProperties": { - "CandidateArtifactLocations": {"Explainability": "string"} + "CandidateArtifactLocations": {"Explainability": "string"}, }, }, "AutoMLJobStatus": "Completed", # "Completed" | "InProgress" | "Failed" | "Stopped" | "Stopping" @@ -229,9 +229,9 @@ { "AutoMLAlgorithms": [ "xgboost", # 'xgboost'|'linear-learner'|'mlp'|'lightgbm'|'catboost'|'randomforest'|'extra-trees'|'nn-torch'|'fastai' - ] + ], }, - ] + ], }, "CompletionCriteria": { "MaxCandidates": 123, @@ -247,7 +247,7 @@ # | "Regression", "TargetAttributeName": "ChannelType", # ChannelType, ContentType, CompressionType, DataSource "SampleWeightAttributeName": "string", - } + }, }, "AutoMLJobArtifacts": { "CandidateDefinitionNotebookLocation": "string", @@ -293,7 +293,7 @@ "LastModifiedTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "FailureReason": "string", "ModelArtifacts": { - "S3ModelArtifacts": "s3://compilation-job-bucket/model-artifacts.tar.gz" + "S3ModelArtifacts": "s3://compilation-job-bucket/model-artifacts.tar.gz", }, "ModelDigests": {"ArtifactDigest": "string"}, "RoleArn": "arn:aws:iam::123412341234:role/service-role/AmazonSageMakerServiceCatalogProductsUseRole", @@ -688,7 +688,7 @@ "ContentClassifiers": [ "FreeOfPersonallyIdentifiableInformation", "FreeOfAdultContent", - ] + ], }, }, }, @@ -719,7 +719,7 @@ "ContentClassifiers": [ "FreeOfPersonallyIdentifiableInformation", "FreeOfAdultContent", - ] + ], }, }, "OutputConfig": { @@ -756,7 +756,7 @@ "MaxConcurrentTaskCount": 123, "AnnotationConsolidationConfig": {"AnnotationConsolidationLambdaArn": "string"}, "PublicWorkforceTaskPrice": { - "AmountInUsd": {"Dollars": 123, "Cents": 123, "TenthFractionsOfACent": 123} + "AmountInUsd": {"Dollars": 123, "Cents": 123, "TenthFractionsOfACent": 123}, }, }, "Tags": [ @@ -1037,7 +1037,7 @@ "InstanceType": "ml.t3.medium", "VolumeSizeInGB": 123, "VolumeKmsKeyId": "string", - } + }, }, "StoppingCondition": {"MaxRuntimeInSeconds": 123}, "AppSpecification": { @@ -1117,7 +1117,7 @@ "S3DataSource": { "S3DataType": "ManifestFile", # "ManifestFile" | "S3Prefix" | "AugmentedManifestFile" "S3Uri": 
"s3://transform-job/input-data-source.tar.gz", - } + }, }, "ContentType": "string", "CompressionType": "None", # "None" | "Gzip" @@ -1254,7 +1254,7 @@ "AutoRollbackConfiguration": { "Alarms": [ {"AlarmName": "string"}, - ] + ], }, }, } @@ -1305,7 +1305,7 @@ "AutoRollbackConfiguration": { "Alarms": [ {"AlarmName": "string"}, - ] + ], }, }, } @@ -1492,7 +1492,7 @@ "DestinationName": "the-first-endpoint-artifact", "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "CreatedBy": {}, - } + }, ], } @@ -1534,7 +1534,7 @@ "DestinationName": "the-second-endpoint-artifact", "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "CreatedBy": {}, - } + }, ], } @@ -1615,7 +1615,7 @@ "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "LastModifiedTime": datetime(2015, 1, 1, tzinfo=timezone.utc), }, - ] + ], } get_model_group_incoming_response = { @@ -1640,7 +1640,7 @@ "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "CreatedBy": {}, }, - ] + ], } get_first_model_package_incoming_response = { @@ -1663,7 +1663,7 @@ "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "CreatedBy": {}, }, - ] + ], } get_second_model_package_incoming_response = { @@ -1686,7 +1686,7 @@ "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "CreatedBy": {}, }, - ] + ], } list_groups_response = { @@ -1698,7 +1698,7 @@ "CreationTime": datetime(2015, 1, 1, tzinfo=timezone.utc), "ModelPackageGroupStatus": "Completed", # 'Pending'|'InProgress'|'Completed'|'Failed'|'Deleting'|'DeleteFailed' }, - ] + ], } describe_group_response = { diff --git a/metadata-ingestion/tests/unit/schema/test_json_schema_util.py b/metadata-ingestion/tests/unit/schema/test_json_schema_util.py index 34ccc3d4fb9225..1266599b0e4d9f 100644 --- a/metadata-ingestion/tests/unit/schema/test_json_schema_util.py +++ b/metadata-ingestion/tests/unit/schema/test_json_schema_util.py @@ -30,7 +30,7 @@ "my.field": { "type": ["string", "null"], "description": "some.doc", - } + }, }, } @@ -53,7 +53,8 @@ def assert_fields_are_valid(fields: Iterable[SchemaField]) -> None: def assert_field_paths_match( - fields: Iterable[SchemaField], expected_field_paths: Union[List[str], List[Dict]] + fields: Iterable[SchemaField], + expected_field_paths: Union[List[str], List[Dict]], ) -> None: log_field_paths(fields) assert len([f for f in fields]) == len(expected_field_paths) @@ -94,7 +95,7 @@ def test_json_schema_to_mce_fields_sample_events_with_different_field_types(): "a_map_of_longs_field": { "type": "object", "additionalProperties": {"type": "integer"}, - } + }, }, } fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) @@ -102,7 +103,7 @@ def test_json_schema_to_mce_fields_sample_events_with_different_field_types(): { "path": "[version=2.0].[type=R].[type=map].[type=integer].a_map_of_longs_field", "type": MapTypeClass, - } + }, ] assert_field_paths_match(fields, expected_field_paths) assert_fields_are_valid(fields) @@ -137,7 +138,7 @@ def test_json_schema_to_mce_fields_toplevel_isnt_a_record(): schema = {"type": "string"} fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) expected_field_paths = [ - {"path": "[version=2.0].[type=string]", "type": StringTypeClass} + {"path": "[version=2.0].[type=string]", "type": StringTypeClass}, ] assert_field_paths_match(fields, expected_field_paths) assert_fields_are_valid(fields) @@ -237,7 +238,7 @@ def test_json_schema_to_schema_fields_with_nesting_across_records(): "streetAddress": {"type": "string"}, "city": {"type": "string"}, }, - } + }, }, "oneOf": [ {"$ref": 
"#/definitions/Address"}, @@ -297,11 +298,11 @@ def test_simple_nested_record_with_a_string_field_for_key_schema(): "type": "object", "title": "InnerRcd", "properties": {"aStringField": {"type": "string"}}, - } + }, }, } fields = list( - JsonSchemaTranslator.get_fields_from_schema(schema, is_key_schema=True) + JsonSchemaTranslator.get_fields_from_schema(schema, is_key_schema=True), ) expected_field_paths: List[str] = [ "[version=2.0].[key=True].[type=SimpleNested].[type=InnerRcd].nestedRcd", @@ -324,8 +325,8 @@ def test_union_with_nested_record_of_union(): "title": "Rcd", "properties": {"aNullableStringField": {"type": "string"}}, }, - ] - } + ], + }, }, } fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) @@ -371,7 +372,7 @@ def test_nested_arrays(): "properties": {"a": {"type": "integer"}}, }, }, - } + }, }, } @@ -402,13 +403,13 @@ def test_map_of_union_of_int_and_record_of_union(): "title": "Rcd", "properties": { "aUnion": { - "oneOf": [{"type": "string"}, {"type": "integer"}] - } + "oneOf": [{"type": "string"}, {"type": "integer"}], + }, }, }, - ] + ], }, - } + }, }, } fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) @@ -454,7 +455,7 @@ def test_recursive_json(): "anIntegerField": {"type": "integer"}, "aRecursiveField": {"$ref": "#/properties/r"}, }, - } + }, }, } fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) @@ -498,8 +499,8 @@ def test_needs_disambiguation_nested_union_of_records_with_same_field_name(): }, }, }, - ] - } + ], + }, }, } fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) @@ -522,7 +523,7 @@ def test_datahub_json_schemas_parses_okay(tmp_path): """This is more like an integration test that helps us exercise the complexity in parsing and catch unexpected regressions.""" json_path: Path = Path(os.path.dirname(__file__)) / Path( - "../../../../metadata-models/src/generatedJsonSchema/json/" + "../../../../metadata-models/src/generatedJsonSchema/json/", ) pipeline = Pipeline.create( config_dict={ @@ -539,7 +540,7 @@ def test_datahub_json_schemas_parses_okay(tmp_path): "filename": f"{tmp_path}/json_schema_test.json", }, }, - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -575,12 +576,12 @@ def test_key_schema_handling(): }, }, }, - ] - } + ], + }, }, } fields: List[SchemaField] = list( - JsonSchemaTranslator.get_fields_from_schema(schema, is_key_schema=True) + JsonSchemaTranslator.get_fields_from_schema(schema, is_key_schema=True), ) expected_field_paths: List[str] = [ "[version=2.0].[key=True].[type=ABFooUnion].[type=union].a", @@ -606,7 +607,7 @@ def test_ignore_exceptions(): "tags": ["business-timestamp"], } fields: List[SchemaField] = list( - JsonSchemaTranslator.get_fields_from_schema(malformed_schema) + JsonSchemaTranslator.get_fields_from_schema(malformed_schema), ) assert not fields @@ -870,7 +871,7 @@ def test_description_extraction(): "type": "array", "items": {"type": "string"}, "description": "XYZ", - } + }, }, } fields = list(JsonSchemaTranslator.get_fields_from_schema(schema)) diff --git a/metadata-ingestion/tests/unit/sdk/test_client.py b/metadata-ingestion/tests/unit/sdk/test_client.py index 16795ef8c7f814..696a22c8defd5f 100644 --- a/metadata-ingestion/tests/unit/sdk/test_client.py +++ b/metadata-ingestion/tests/unit/sdk/test_client.py @@ -19,7 +19,7 @@ def test_get_aspect(mock_test_connection): return_value={ "version": 0, "aspect": {"com.linkedin.identity.CorpUserEditableInfo": {}}, - } + }, ) mock_get.return_value = mock_response editable = graph.get_aspect(user_urn, 
CorpUserEditableInfoClass) diff --git a/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py b/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py index 8154f179c36b8e..30a08386268f33 100644 --- a/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py +++ b/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py @@ -15,7 +15,7 @@ class KafkaEmitterTest(unittest.TestCase): def test_kafka_emitter_config(self): emitter_config = KafkaEmitterConfig.parse_obj( - {"connection": {"bootstrap": "foobar:9092"}} + {"connection": {"bootstrap": "foobar:9092"}}, ) assert emitter_config.topic_routes[MCE_KEY] == DEFAULT_MCE_KAFKA_TOPIC assert emitter_config.topic_routes[MCP_KEY] == DEFAULT_MCP_KAFKA_TOPIC @@ -31,7 +31,7 @@ def test_kafka_emitter_config_old_and_new(self): "connection": {"bootstrap": "foobar:9092"}, "topic": "NewTopic", "topic_routes": {MCE_KEY: "NewTopic"}, - } + }, ) """ @@ -40,7 +40,7 @@ def test_kafka_emitter_config_old_and_new(self): def test_kafka_emitter_config_topic_upgrade(self): emitter_config = KafkaEmitterConfig.parse_obj( - {"connection": {"bootstrap": "foobar:9092"}, "topic": "NewTopic"} + {"connection": {"bootstrap": "foobar:9092"}, "topic": "NewTopic"}, ) assert emitter_config.topic_routes[MCE_KEY] == "NewTopic" # MCE topic upgraded assert ( diff --git a/metadata-ingestion/tests/unit/sdk/test_mcp_builder.py b/metadata-ingestion/tests/unit/sdk/test_mcp_builder.py index e304edb24789cd..b84956b608441c 100644 --- a/metadata-ingestion/tests/unit/sdk/test_mcp_builder.py +++ b/metadata-ingestion/tests/unit/sdk/test_mcp_builder.py @@ -4,7 +4,10 @@ def test_guid_generator(): key = builder.SchemaKey( - database="test", schema="Test", platform="mysql", instance="TestInstance" + database="test", + schema="Test", + platform="mysql", + instance="TestInstance", ) guid = key.guid() @@ -78,7 +81,10 @@ def test_guid_generator_with_env(): def test_guid_generators(): key = builder.SchemaKey( - database="test", schema="Test", platform="mysql", instance="TestInstance" + database="test", + schema="Test", + platform="mysql", + instance="TestInstance", ) guid_datahub = key.guid() diff --git a/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py b/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py index 81120dfc87aba3..0aa414ef004ded 100644 --- a/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py +++ b/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py @@ -18,7 +18,9 @@ def test_datahub_rest_emitter_construction() -> None: def test_datahub_rest_emitter_timeout_construction() -> None: emitter = DatahubRestEmitter( - MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4 + MOCK_GMS_ENDPOINT, + connect_timeout_sec=2, + read_timeout_sec=4, ) assert emitter._session_config.timeout == (2, 4) @@ -40,7 +42,8 @@ def test_datahub_rest_emitter_retry_construction() -> None: def test_datahub_rest_emitter_extra_params() -> None: emitter = DatahubRestEmitter( - MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"} + MOCK_GMS_ENDPOINT, + extra_headers={"key1": "value1", "key2": "value2"}, ) assert emitter._session.headers.get("key1") == "value1" assert emitter._session.headers.get("key2") == "value2" diff --git a/metadata-ingestion/tests/unit/serde/test_codegen.py b/metadata-ingestion/tests/unit/serde/test_codegen.py index 13fcf3d919cc03..3516b98b480a2f 100644 --- a/metadata-ingestion/tests/unit/serde/test_codegen.py +++ b/metadata-ingestion/tests/unit/serde/test_codegen.py @@ -21,7 +21,7 @@ _UPDATE_ENTITY_REGISTRY = os.getenv("UPDATE_ENTITY_REGISTRY", "false").lower() == 
"true" ENTITY_REGISTRY_PATH = pathlib.Path( - "../metadata-models/src/main/resources/entity-registry.yml" + "../metadata-models/src/main/resources/entity-registry.yml", ) @@ -71,7 +71,7 @@ def test_urn_annotation(): == "DatasetUrn" ) assert not UpstreamClass.RECORD_SCHEMA.fields_dict["dataset"].get_prop( - "urn_is_array" + "urn_is_array", ) assert ( @@ -79,7 +79,7 @@ def test_urn_annotation(): == "Urn" ) assert FineGrainedLineageClass.RECORD_SCHEMA.fields_dict["upstreams"].get_prop( - "urn_is_array" + "urn_is_array", ) @@ -96,7 +96,7 @@ def _add_to_registry(entity: str, aspect: str) -> None: break else: raise ValueError( - f'could not find entity "{entity}" in entity registry at {ENTITY_REGISTRY_PATH}' + f'could not find entity "{entity}" in entity registry at {ENTITY_REGISTRY_PATH}', ) # Prevent line wrapping + preserve indentation. @@ -116,14 +116,15 @@ def _err(msg: str) -> None: errors.append(msg) snapshot_classes: List[Type] = typing_inspect.get_args( - typing.get_type_hints(MetadataChangeEventClass.__init__)["proposedSnapshot"] + typing.get_type_hints(MetadataChangeEventClass.__init__)["proposedSnapshot"], ) lowercase_entity_type_map = {name.lower(): name for name in KEY_ASPECTS} for snapshot_class in snapshot_classes: lowercase_entity_type: str = snapshot_class.__name__.replace( - "SnapshotClass", "" + "SnapshotClass", + "", ).lower() entity_type = lowercase_entity_type_map.get(lowercase_entity_type) if entity_type is None: @@ -135,8 +136,8 @@ def _err(msg: str) -> None: snapshot_aspect_types: List[Type[_Aspect]] = typing_inspect.get_args( typing_inspect.get_args( - typing.get_type_hints(snapshot_class.__init__)["aspects"] - )[0] + typing.get_type_hints(snapshot_class.__init__)["aspects"], + )[0], ) # print(f"Entity type: {entity_type}") @@ -153,7 +154,7 @@ def _err(msg: str) -> None: _add_to_registry(entity_type, aspect_name) else: _err( - f"entity {entity_type}: aspect {aspect_name} is missing from the entity registry" + f"entity {entity_type}: aspect {aspect_name} is missing from the entity registry", ) assert not errors, ( diff --git a/metadata-ingestion/tests/unit/serde/test_serde.py b/metadata-ingestion/tests/unit/serde/test_serde.py index a131ac9ce2a1bc..4d1a0d4c18a2db 100644 --- a/metadata-ingestion/tests/unit/serde/test_serde.py +++ b/metadata-ingestion/tests/unit/serde/test_serde.py @@ -40,7 +40,9 @@ ], ) def test_serde_to_json( - pytestconfig: pytest.Config, tmp_path: pathlib.Path, json_filename: str + pytestconfig: pytest.Config, + tmp_path: pathlib.Path, + json_filename: str, ) -> None: golden_file = pytestconfig.rootpath / json_filename output_file = tmp_path / "output.json" @@ -50,7 +52,7 @@ def test_serde_to_json( "source": {"type": "file", "config": {"filename": str(golden_file)}}, "sink": {"type": "file", "config": {"filename": str(output_file)}}, "run_id": "serde_test", - } + }, ) pipeline.run() pipeline.raise_from_status() @@ -81,13 +83,14 @@ def test_serde_to_avro( with patch("datahub.ingestion.api.common.PipelineContext") as mock_pipeline_context: json_path = pytestconfig.rootpath / json_filename source = GenericFileSource( - ctx=mock_pipeline_context, config=FileSourceConfig(path=str(json_path)) + ctx=mock_pipeline_context, + config=FileSourceConfig(path=str(json_path)), ) mces = list(source.iterate_mce_file(str(json_path))) # Serialize to Avro. 
parsed_schema = fastavro.parse_schema( - json.loads(getMetadataChangeEventSchema()) + json.loads(getMetadataChangeEventSchema()), ) fo = io.BytesIO() out_records = [mce.to_obj(tuples=True) for mce in mces] @@ -132,7 +135,8 @@ def test_check_metadata_schema(pytestconfig: pytest.Config, json_filename: str) def test_check_metadata_rewrite( - pytestconfig: pytest.Config, tmp_path: pathlib.Path + pytestconfig: pytest.Config, + tmp_path: pathlib.Path, ) -> None: json_input = ( pytestconfig.rootpath / "tests/unit/serde/test_canonicalization_input.json" @@ -144,11 +148,13 @@ def test_check_metadata_rewrite( output_file_path = tmp_path / "output.json" shutil.copyfile(json_input, output_file_path) run_datahub_cmd( - ["check", "metadata-file", f"{output_file_path}", "--rewrite", "--unpack-mces"] + ["check", "metadata-file", f"{output_file_path}", "--rewrite", "--unpack-mces"], ) mce_helpers.check_golden_file( - pytestconfig, output_path=output_file_path, golden_path=json_output_reference + pytestconfig, + output_path=output_file_path, + golden_path=json_output_reference, ) @@ -160,7 +166,8 @@ def test_check_metadata_rewrite( ], ) def test_check_mce_schema_failure( - pytestconfig: pytest.Config, json_filename: str + pytestconfig: pytest.Config, + json_filename: str, ) -> None: json_file_path = pytestconfig.rootpath / json_filename @@ -189,7 +196,9 @@ def test_field_discriminator() -> None: def test_type_error() -> None: dataflow = models.DataFlowSnapshotClass( urn=mce_builder.make_data_flow_urn( - orchestrator="argo", flow_id="42", cluster="DEV" + orchestrator="argo", + flow_id="42", + cluster="DEV", ), aspects=[ models.DataFlowInfoClass( @@ -198,7 +207,7 @@ def test_type_error() -> None: externalUrl="http://example.com", # This is a type error - custom properties should be a Dict[str, str]. customProperties={"x": 1}, # type: ignore - ) + ), ], ) @@ -237,10 +246,10 @@ def test_missing_optional_simple() -> None: "condition": "EQUALS", "field": "TYPE", "values": ["notebook", "dataset", "dashboard"], - } - ] + }, + ], }, - } + }, ) # This one is missing the optional filters.allResources field. @@ -251,8 +260,8 @@ def test_missing_optional_simple() -> None: "condition": "EQUALS", "field": "TYPE", "values": ["notebook", "dataset", "dashboard"], - } - ] + }, + ], }, } revised = models.DataHubResourceFilterClass.from_obj(revised_obj) @@ -264,13 +273,13 @@ def test_missing_optional_simple() -> None: def test_missing_optional_in_union() -> None: # This one doesn't contain any optional fields and should work fine. 
revised_json = json.loads( - '{"lastUpdatedTimestamp":1662356745807,"actors":{"groups":[],"resourceOwners":false,"allUsers":true,"allGroups":false,"users":[]},"privileges":["EDIT_ENTITY_ASSERTIONS","EDIT_DATASET_COL_GLOSSARY_TERMS","EDIT_DATASET_COL_TAGS","EDIT_DATASET_COL_DESCRIPTION"],"displayName":"customtest","resources":{"filter":{"criteria":[{"field":"TYPE","condition":"EQUALS","values":["notebook","dataset","dashboard"]}]},"allResources":false},"description":"","state":"ACTIVE","type":"METADATA"}' + '{"lastUpdatedTimestamp":1662356745807,"actors":{"groups":[],"resourceOwners":false,"allUsers":true,"allGroups":false,"users":[]},"privileges":["EDIT_ENTITY_ASSERTIONS","EDIT_DATASET_COL_GLOSSARY_TERMS","EDIT_DATASET_COL_TAGS","EDIT_DATASET_COL_DESCRIPTION"],"displayName":"customtest","resources":{"filter":{"criteria":[{"field":"TYPE","condition":"EQUALS","values":["notebook","dataset","dashboard"]}]},"allResources":false},"description":"","state":"ACTIVE","type":"METADATA"}', ) revised = models.DataHubPolicyInfoClass.from_obj(revised_json) # This one is missing the optional filters.allResources field. original_json = json.loads( - '{"privileges":["EDIT_ENTITY_ASSERTIONS","EDIT_DATASET_COL_GLOSSARY_TERMS","EDIT_DATASET_COL_TAGS","EDIT_DATASET_COL_DESCRIPTION"],"actors":{"resourceOwners":false,"groups":[],"allGroups":false,"allUsers":true,"users":[]},"lastUpdatedTimestamp":1662356745807,"displayName":"customtest","description":"","resources":{"filter":{"criteria":[{"field":"TYPE","condition":"EQUALS","values":["notebook","dataset","dashboard"]}]}},"state":"ACTIVE","type":"METADATA"}' + '{"privileges":["EDIT_ENTITY_ASSERTIONS","EDIT_DATASET_COL_GLOSSARY_TERMS","EDIT_DATASET_COL_TAGS","EDIT_DATASET_COL_DESCRIPTION"],"actors":{"resourceOwners":false,"groups":[],"allGroups":false,"allUsers":true,"users":[]},"lastUpdatedTimestamp":1662356745807,"displayName":"customtest","description":"","resources":{"filter":{"criteria":[{"field":"TYPE","condition":"EQUALS","values":["notebook","dataset","dashboard"]}]}},"state":"ACTIVE","type":"METADATA"}', ) original = models.DataHubPolicyInfoClass.from_obj(original_json) @@ -286,9 +295,9 @@ def test_reserved_keywords() -> None: models.ConjunctiveCriterionClass( and_=[ models.CriterionClass(field="foo", value="var", negated=True), - ] - ) - ] + ], + ), + ], ) assert "or" in filter2.to_obj() @@ -302,13 +311,15 @@ def test_read_empty_dict() -> None: model = models.AssertionResultClass.from_obj(json.loads(original)) assert model.nativeResults == {} assert model == models.AssertionResultClass( - type=models.AssertionResultTypeClass.SUCCESS, nativeResults={} + type=models.AssertionResultTypeClass.SUCCESS, + nativeResults={}, ) def test_write_optional_empty_dict() -> None: model = models.AssertionResultClass( - type=models.AssertionResultTypeClass.SUCCESS, nativeResults={} + type=models.AssertionResultTypeClass.SUCCESS, + nativeResults={}, ) assert model.nativeResults == {} @@ -329,7 +340,7 @@ def test_write_optional_empty_dict() -> None: fieldDiscriminator=models.CostCostDiscriminatorClass.costCode, costCode="sampleCostCode", ), - ) + ), ], ), { @@ -339,8 +350,8 @@ def test_write_optional_empty_dict() -> None: "com.linkedin.common.Cost": { "costType": "ORG_COST_TYPE", "cost": {"costCode": "sampleCostCode"}, - } - } + }, + }, ], }, ), diff --git a/metadata-ingestion/tests/unit/serde/test_urn_iterator.py b/metadata-ingestion/tests/unit/serde/test_urn_iterator.py index 135580dcdff13e..7f9ed9778c9b1c 100644 --- 
a/metadata-ingestion/tests/unit/serde/test_urn_iterator.py +++ b/metadata-ingestion/tests/unit/serde/test_urn_iterator.py @@ -116,14 +116,16 @@ def test_upstream_lineage_urn_iterator(): def _make_test_lineage_obj( - table: str, upstream: str, downstream: str + table: str, + upstream: str, + downstream: str, ) -> MetadataChangeProposalWrapper: lineage = UpstreamLineage( upstreams=[ Upstream( dataset=_datasetUrn(upstream), type=DatasetLineageTypeClass.TRANSFORMED, - ) + ), ], fineGrainedLineages=[ FineGrainedLineage( @@ -140,11 +142,15 @@ def _make_test_lineage_obj( def test_dataset_urn_lowercase_transformer(): original = _make_test_lineage_obj( - "mainTableName", "upstreamTable", "downstreamTable" + "mainTableName", + "upstreamTable", + "downstreamTable", ) expected = _make_test_lineage_obj( - "maintablename", "upstreamtable", "downstreamtable" + "maintablename", + "upstreamtable", + "downstreamtable", ) assert original != expected # sanity check diff --git a/metadata-ingestion/tests/unit/snowflake/test_snowflake_shares.py b/metadata-ingestion/tests/unit/snowflake/test_snowflake_shares.py index 2e78f0bb3ae659..4a8579efc6ae1c 100644 --- a/metadata-ingestion/tests/unit/snowflake/test_snowflake_shares.py +++ b/metadata-ingestion/tests/unit/snowflake/test_snowflake_shares.py @@ -84,7 +84,7 @@ def snowflake_databases() -> List[SnowflakeDatabase]: last_altered=None, tables=["table311", "table312"], views=["view311"], - ) + ), ], ), ] @@ -92,7 +92,9 @@ def snowflake_databases() -> List[SnowflakeDatabase]: def make_snowflake_urn(table_name, instance_name=None): return make_dataset_urn_with_platform_instance( - "snowflake", table_name, instance_name + "snowflake", + table_name, + instance_name, ) @@ -122,14 +124,14 @@ def test_same_database_inbound_and_outbound_invalid_config() -> None: database="db1", platform_instance="instance2", consumers=[ - DatabaseId(database="db1", platform_instance="instance1") + DatabaseId(database="db1", platform_instance="instance1"), ], ), "share2": SnowflakeShareConfig( database="db1", platform_instance="instance3", consumers=[ - DatabaseId(database="db1", platform_instance="instance1") + DatabaseId(database="db1", platform_instance="instance1"), ], ), }, @@ -147,14 +149,14 @@ def test_same_database_inbound_and_outbound_invalid_config() -> None: database="db1", platform_instance="instance2", consumers=[ - DatabaseId(database="db1", platform_instance="instance1") + DatabaseId(database="db1", platform_instance="instance1"), ], ), "share2": SnowflakeShareConfig( database="db1", platform_instance="instance1", consumers=[ - DatabaseId(database="db1", platform_instance="instance3") + DatabaseId(database="db1", platform_instance="instance3"), ], ), }, @@ -172,14 +174,14 @@ def test_same_database_inbound_and_outbound_invalid_config() -> None: database="db1", platform_instance="instance1", consumers=[ - DatabaseId(database="db1", platform_instance="instance3") + DatabaseId(database="db1", platform_instance="instance3"), ], ), "share1": SnowflakeShareConfig( database="db1", platform_instance="instance2", consumers=[ - DatabaseId(database="db1", platform_instance="instance1") + DatabaseId(database="db1", platform_instance="instance1"), ], ), }, @@ -197,7 +199,7 @@ def test_snowflake_shares_workunit_inbound_share( database="db1", platform_instance="instance2", consumers=[DatabaseId(database="db1", platform_instance="instance1")], - ) + ), }, ) @@ -214,14 +216,16 @@ def test_snowflake_shares_workunit_inbound_share( for wu in wus: assert isinstance( - wu.metadata, 
(MetadataChangeProposal, MetadataChangeProposalWrapper) + wu.metadata, + (MetadataChangeProposal, MetadataChangeProposalWrapper), ) if wu.metadata.aspectName == "upstreamLineage": upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) assert upstream_aspect is not None assert len(upstream_aspect.upstreams) == 1 assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( - "instance1.db1", "instance2.db1" + "instance1.db1", + "instance2.db1", ) upstream_lineage_aspect_entity_urns.add(wu.get_urn()) else: @@ -230,7 +234,7 @@ def test_snowflake_shares_workunit_inbound_share( assert not siblings_aspect.primary assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("instance1.db1", "instance2.db1") + wu.get_urn().replace("instance1.db1", "instance2.db1"), ] sibling_aspect_entity_urns.add(wu.get_urn()) @@ -249,11 +253,12 @@ def test_snowflake_shares_workunit_outbound_share( platform_instance="instance1", consumers=[ DatabaseId( - database="db2_from_share", platform_instance="instance2" + database="db2_from_share", + platform_instance="instance2", ), DatabaseId(database="db2", platform_instance="instance3"), ], - ) + ), }, ) @@ -298,7 +303,8 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( platform_instance="instance1", consumers=[ DatabaseId( - database="db2_from_share", platform_instance="instance2" + database="db2_from_share", + platform_instance="instance2", ), DatabaseId(database="db2", platform_instance="instance3"), ], @@ -317,14 +323,16 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( for wu in wus: assert isinstance( - wu.metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper) + wu.metadata, + (MetadataChangeProposal, MetadataChangeProposalWrapper), ) if wu.metadata.aspectName == "upstreamLineage": upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) assert upstream_aspect is not None assert len(upstream_aspect.upstreams) == 1 assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( - "instance1.db1", "instance2.db1" + "instance1.db1", + "instance2.db1", ) else: siblings_aspect = wu.get_aspect_of_type(Siblings) @@ -333,7 +341,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( assert not siblings_aspect.primary assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("instance1.db1", "instance2.db1") + wu.get_urn().replace("instance1.db1", "instance2.db1"), ] else: assert siblings_aspect.primary @@ -385,14 +393,16 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share_no_platform_instan for wu in wus: assert isinstance( - wu.metadata, (MetadataChangeProposal, MetadataChangeProposalWrapper) + wu.metadata, + (MetadataChangeProposal, MetadataChangeProposalWrapper), ) if wu.metadata.aspectName == "upstreamLineage": upstream_aspect = wu.get_aspect_of_type(UpstreamLineage) assert upstream_aspect is not None assert len(upstream_aspect.upstreams) == 1 assert upstream_aspect.upstreams[0].dataset == wu.get_urn().replace( - "db2.", "db2_main." 
+ "db2.", + "db2_main.", ) else: siblings_aspect = wu.get_aspect_of_type(Siblings) @@ -408,5 +418,5 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share_no_platform_instan assert not siblings_aspect.primary assert len(siblings_aspect.siblings) == 1 assert siblings_aspect.siblings == [ - wu.get_urn().replace("db2.", "db2_main.") + wu.get_urn().replace("db2.", "db2_main."), ] diff --git a/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py b/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py index 75f32b535eb2e8..50fa9192c90589 100644 --- a/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py +++ b/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py @@ -47,7 +47,7 @@ def test_snowflake_source_throws_error_on_account_id_missing(): { "username": "user", "password": "password", - } + }, ) @@ -72,7 +72,7 @@ def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_f "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", "oauth_config": oauth_dict, - } + }, ) @@ -89,7 +89,7 @@ def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_cert "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", "oauth_config": oauth_dict, - } + }, ) @@ -99,14 +99,15 @@ def test_snowflake_oauth_okta_does_not_support_certificate(): oauth_dict["provider"] = "okta" OAuthConfiguration.parse_obj(oauth_dict) with pytest.raises( - ValueError, match="Certificate authentication is not supported for Okta." + ValueError, + match="Certificate authentication is not supported for Okta.", ): SnowflakeV2Config.parse_obj( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", "oauth_config": oauth_dict, - } + }, ) @@ -118,7 +119,7 @@ def test_snowflake_oauth_happy_paths(): "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", "oauth_config": oauth_dict, - } + }, ) oauth_dict["use_certificate"] = True oauth_dict["provider"] = "microsoft" @@ -129,7 +130,7 @@ def test_snowflake_oauth_happy_paths(): "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", "oauth_config": oauth_dict, - } + }, ) @@ -141,20 +142,21 @@ def test_snowflake_oauth_token_happy_path(): "token": "valid-token", "username": "test-user", "oauth_config": None, - } + }, ) def test_snowflake_oauth_token_without_token(): with pytest.raises( - ValidationError, match="Token required for OAUTH_AUTHENTICATOR_TOKEN." + ValidationError, + match="Token required for OAUTH_AUTHENTICATOR_TOKEN.", ): SnowflakeV2Config.parse_obj( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN", "username": "test-user", - } + }, ) @@ -169,13 +171,14 @@ def test_snowflake_oauth_token_with_wrong_auth_type(): "authentication_type": "OAUTH_AUTHENTICATOR", "token": "some-token", "username": "test-user", - } + }, ) def test_snowflake_oauth_token_with_empty_token(): with pytest.raises( - ValidationError, match="Token required for OAUTH_AUTHENTICATOR_TOKEN." 
+ ValidationError, + match="Token required for OAUTH_AUTHENTICATOR_TOKEN.", ): SnowflakeV2Config.parse_obj( { @@ -183,7 +186,7 @@ def test_snowflake_oauth_token_with_empty_token(): "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN", "token": "", "username": "test-user", - } + }, ) @@ -286,7 +289,7 @@ def test_private_key_set_but_auth_not_changed(): { "account_id": "acctname", "private_key_path": "/a/random/path", - } + }, ) @@ -305,17 +308,20 @@ def test_snowflake_config_with_connect_args_overrides_base_connect_args(): def test_test_connection_failure(mock_connect): mock_connect.side_effect = Exception("Failed to connect to snowflake") report = test_connection_helpers.run_test_connection( - SnowflakeV2Source, default_config_dict + SnowflakeV2Source, + default_config_dict, ) test_connection_helpers.assert_basic_connectivity_failure( - report, "Failed to connect to snowflake" + report, + "Failed to connect to snowflake", ) @patch("snowflake.connector.connect") def test_test_connection_basic_success(mock_connect): report = test_connection_helpers.run_test_connection( - SnowflakeV2Source, default_config_dict + SnowflakeV2Source, + default_config_dict, ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -360,7 +366,8 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) report = test_connection_helpers.run_test_connection( - SnowflakeV2Source, default_config_dict + SnowflakeV2Source, + default_config_dict, ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -368,7 +375,7 @@ def query_results(query): capability_report=report.capability_report, success_capabilities=[SourceCapability.CONTAINERS], failure_capabilities={ - SourceCapability.SCHEMA_METADATA: "Current role TEST_ROLE does not have permissions to use warehouse" + SourceCapability.SCHEMA_METADATA: "Current role TEST_ROLE does not have permissions to use warehouse", }, ) @@ -383,7 +390,8 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) report = test_connection_helpers.run_test_connection( - SnowflakeV2Source, default_config_dict + SnowflakeV2Source, + default_config_dict, ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -391,7 +399,7 @@ def query_results(query): capability_report=report.capability_report, success_capabilities=[SourceCapability.CONTAINERS], failure_capabilities={ - SourceCapability.SCHEMA_METADATA: "Either no tables exist or current role does not have permissions to access them" + SourceCapability.SCHEMA_METADATA: "Either no tables exist or current role does not have permissions to access them", }, ) @@ -414,7 +422,8 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) report = test_connection_helpers.run_test_connection( - SnowflakeV2Source, default_config_dict + SnowflakeV2Source, + default_config_dict, ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -467,7 +476,8 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) report = test_connection_helpers.run_test_connection( - SnowflakeV2Source, default_config_dict + SnowflakeV2Source, + default_config_dict, ) test_connection_helpers.assert_basic_connectivity_success(report) @@ -488,7 +498,7 @@ def test_aws_cloud_region_from_snowflake_region_id(): cloud, cloud_region_id, ) = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id( - "aws_ca_central_1" + "aws_ca_central_1", ) assert cloud == SnowflakeCloudProvider.AWS @@ -498,7 +508,7 @@ def test_aws_cloud_region_from_snowflake_region_id(): 
cloud, cloud_region_id, ) = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id( - "aws_us_east_1_gov" + "aws_us_east_1_gov", ) assert cloud == SnowflakeCloudProvider.AWS @@ -510,7 +520,7 @@ def test_google_cloud_region_from_snowflake_region_id(): cloud, cloud_region_id, ) = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id( - "gcp_europe_west2" + "gcp_europe_west2", ) assert cloud == SnowflakeCloudProvider.GCP @@ -522,7 +532,7 @@ def test_azure_cloud_region_from_snowflake_region_id(): cloud, cloud_region_id, ) = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id( - "azure_switzerlandnorth" + "azure_switzerlandnorth", ) assert cloud == SnowflakeCloudProvider.AZURE @@ -532,7 +542,7 @@ def test_azure_cloud_region_from_snowflake_region_id(): cloud, cloud_region_id, ) = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id( - "azure_centralindia" + "azure_centralindia", ) assert cloud == SnowflakeCloudProvider.AZURE @@ -542,7 +552,7 @@ def test_azure_cloud_region_from_snowflake_region_id(): def test_unknown_cloud_region_from_snowflake_region_id(): with pytest.raises(Exception, match="Unknown snowflake region"): SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id( - "somecloud_someregion" + "somecloud_someregion", ) @@ -555,7 +565,7 @@ def test_snowflake_object_access_entry_missing_object_id(): ], "objectDomain": "View", "objectName": "SOME.OBJECT.NAME", - } + }, ) @@ -578,7 +588,8 @@ def test_snowflake_query_create_deny_regex_sql(): assert ( create_deny_regex_sql_filter( - DEFAULT_TEMP_TABLES_PATTERNS, ["upstream_table_name"] + DEFAULT_TEMP_TABLES_PATTERNS, + ["upstream_table_name"], ) == r"NOT RLIKE(upstream_table_name,'.*\.FIVETRAN_.*_STAGING\..*','i') AND NOT RLIKE(upstream_table_name,'.*__DBT_TMP$','i') AND NOT RLIKE(upstream_table_name,'.*\.SEGMENT_[a-f0-9]{8}[-_][a-f0-9]{4}[-_][a-f0-9]{4}[-_][a-f0-9]{4}[-_][a-f0-9]{12}','i') AND NOT RLIKE(upstream_table_name,'.*\.STAGING_.*_[a-f0-9]{8}[-_][a-f0-9]{4}[-_][a-f0-9]{4}[-_][a-f0-9]{4}[-_][a-f0-9]{12}','i') AND NOT RLIKE(upstream_table_name,'.*\.(GE_TMP_|GE_TEMP_|GX_TEMP_)[0-9A-F]{8}','i')" ) @@ -591,7 +602,7 @@ def test_snowflake_temporary_patterns_config_rename(): "username": "user", "password": "password", "upstreams_deny_pattern": [".*tmp.*"], - } + }, ) assert conf.temporary_tables_pattern == [".*tmp.*"] @@ -628,7 +639,9 @@ def test_email_filter_query_generation_one_allow_and_deny(): def test_email_filter_query_generation_with_case_insensitive_filter(): email_filter = AllowDenyPattern( - allow=[".*@example.com"], deny=[".*@example2.com"], ignoreCase=False + allow=[".*@example.com"], + deny=[".*@example2.com"], + ignoreCase=False, ) filter_query = SnowflakeQuery.gen_email_filter_query(email_filter) assert ( @@ -639,14 +652,18 @@ def test_email_filter_query_generation_with_case_insensitive_filter(): def test_create_snowsight_base_url_us_west(): result = SnowsightUrlBuilder( - "account_locator", "aws_us_west_2", privatelink=False + "account_locator", + "aws_us_west_2", + privatelink=False, ).snowsight_base_url assert result == "https://app.snowflake.com/us-west-2/account_locator/" def test_create_snowsight_base_url_ap_northeast_1(): result = SnowsightUrlBuilder( - "account_locator", "aws_ap_northeast_1", privatelink=False + "account_locator", + "aws_ap_northeast_1", + privatelink=False, ).snowsight_base_url assert result == "https://app.snowflake.com/ap-northeast-1.aws/account_locator/" @@ -664,7 +681,7 @@ def test_using_removed_fields_causes_no_error() -> None: "password": "snowflake", 
"include_view_lineage": "true", "include_view_column_lineage": "true", - } + }, ) @@ -677,7 +694,7 @@ def test_snowflake_query_result_parsing(): "query_id": "01b92f61-0611-c826-000d-0103cf9b5db7", "upstream_object_domain": "Table", "upstream_object_name": "db.schema.upstream_table", - } + }, ], "UPSTREAM_COLUMNS": [{}], "QUERIES": [ @@ -685,7 +702,7 @@ def test_snowflake_query_result_parsing(): "query_id": "01b92f61-0611-c826-000d-0103cf9b5db7", "query_text": "Query test", "start_time": "2022-12-01 19:56:34", - } + }, ], } assert UpstreamLineageEdge.parse_obj(db_row) diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_schemaresolver.py b/metadata-ingestion/tests/unit/sql_parsing/test_schemaresolver.py index 67222531d3bc15..e8bc25462ae15a 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_schemaresolver.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_schemaresolver.py @@ -29,7 +29,7 @@ def test_basic_schema_resolver(): schema_resolver = create_default_schema_resolver(urn=input_urn) urn, schema = schema_resolver.resolve_table( - _TableName(database="my_db", db_schema="public", table="test_table") + _TableName(database="my_db", db_schema="public", table="test_table"), ) assert urn == input_urn @@ -106,7 +106,8 @@ def test_match_columns_to_schema(): schema_info: SchemaInfo = {"id": "string", "Name": "string", "Address": "string"} output_columns = match_columns_to_schema( - schema_info, input_columns=["Id", "name", "address", "weight"] + schema_info, + input_columns=["Id", "name", "address", "weight"], ) assert output_columns == ["id", "Name", "Address", "weight"] diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py index 2a771a9847abd8..2c397b733d4919 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py @@ -33,7 +33,8 @@ FROZEN_TIME = "2024-02-06T01:23:45Z" check_goldens_stream = functools.partial( - mce_helpers.check_goldens_stream, ignore_order=False + mce_helpers.check_goldens_stream, + ignore_order=False, ) @@ -56,7 +57,7 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> N query="create table foo as select a, b from bar", default_db="dev", default_schema="public", - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -78,10 +79,12 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> N str(query_log_db), "--output", str(query_log_json), - ] + ], ) mce_helpers.check_golden_file( - pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json" + pytestconfig, + query_log_json, + RESOURCE_DIR / "test_basic_lineage_query_log.json", ) @@ -100,7 +103,7 @@ def test_overlapping_inserts(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", timestamp=_ts(20), - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -108,7 +111,7 @@ def test_overlapping_inserts(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", timestamp=_ts(25), - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -140,7 +143,7 @@ def test_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session1", - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -148,7 +151,7 @@ def test_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session2", - ) + ), ) 
aggregator.add_observed_query( ObservedQuery( @@ -156,7 +159,7 @@ def test_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session2", - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -164,7 +167,7 @@ def test_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session3", - ) + ), ) # foo_session2 should come from bar (via temp table foo), have columns a and c, and depend on bar.{a,b,c} @@ -194,7 +197,7 @@ def test_multistep_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session1", - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -202,7 +205,7 @@ def test_multistep_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session1", - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -210,7 +213,7 @@ def test_multistep_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session1", - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -218,7 +221,7 @@ def test_multistep_temp_table(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", session_id="session1", - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -230,7 +233,7 @@ def test_multistep_temp_table(pytestconfig: pytest.Config) -> None: len( report.queries_with_temp_upstreams[ "composite_48c238412066895ccad5d27f9425ce969b2c0633203627eb476d0c9e5357825a" - ] + ], ) == 4 ) @@ -263,7 +266,7 @@ def test_overlapping_inserts_from_temp_tables(pytestconfig: pytest.Config) -> No default_db="dev", default_schema="public", session_id="1234", - ) + ), ) aggregator.add_observed_query( @@ -275,7 +278,7 @@ def test_overlapping_inserts_from_temp_tables(pytestconfig: pytest.Config) -> No default_db="dev", default_schema="public", session_id="2323", - ) + ), ) aggregator.add_observed_query( @@ -284,7 +287,7 @@ def test_overlapping_inserts_from_temp_tables(pytestconfig: pytest.Config) -> No default_db="dev", default_schema="public", session_id="1234", - ) + ), ) aggregator.add_observed_query( @@ -293,7 +296,7 @@ def test_overlapping_inserts_from_temp_tables(pytestconfig: pytest.Config) -> No default_db="dev", default_schema="public", session_id="2323", - ) + ), ) # We only have one create temp table, but the same insert command from multiple sessions. @@ -305,7 +308,7 @@ def test_overlapping_inserts_from_temp_tables(pytestconfig: pytest.Config) -> No default_db="dev", default_schema="public", session_id="5435", - ) + ), ) assert len(report.queries_with_non_authoritative_session) == 1 @@ -334,7 +337,7 @@ def test_aggregate_operations(pytestconfig: pytest.Config) -> None: default_schema="public", timestamp=_ts(20), user=CorpUserUrn("user1"), - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -343,7 +346,7 @@ def test_aggregate_operations(pytestconfig: pytest.Config) -> None: default_schema="public", timestamp=_ts(25), user=CorpUserUrn("user2"), - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -352,7 +355,7 @@ def test_aggregate_operations(pytestconfig: pytest.Config) -> None: default_schema="public", timestamp=_ts(26), user=CorpUserUrn("user3"), - ) + ), ) # The first query will basically be ignored, as it's a duplicate of the second one. 
@@ -449,14 +452,14 @@ def test_column_lineage_deduplication(pytestconfig: pytest.Config) -> None: query="/* query 1 */ insert into foo (a, b, c) select a, b, c from bar", default_db="dev", default_schema="public", - ) + ), ) aggregator.add_observed_query( ObservedQuery( query="/* query 2 */ insert into foo (a, b) select a, b from bar", default_db="dev", default_schema="public", - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -537,7 +540,7 @@ def test_table_rename(pytestconfig: pytest.Config) -> None: TableRename( original_urn=DatasetUrn("redshift", "dev.public.foo_staging").urn(), new_urn=DatasetUrn("redshift", "dev.public.foo").urn(), - ) + ), ) # Add an unrelated query. @@ -546,7 +549,7 @@ def test_table_rename(pytestconfig: pytest.Config) -> None: query="create table bar as select a, b from baz", default_db="dev", default_schema="public", - ) + ), ) # Add the query that created the staging table. @@ -555,7 +558,7 @@ def test_table_rename(pytestconfig: pytest.Config) -> None: query="create table foo_staging as select a, b from foo_dep", default_db="dev", default_schema="public", - ) + ), ) # Add the query that created the downstream from foo_staging table. @@ -564,7 +567,7 @@ def test_table_rename(pytestconfig: pytest.Config) -> None: query="create table foo_downstream as select a, b from foo_staging", default_db="dev", default_schema="public", - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -597,7 +600,7 @@ def test_table_rename_with_temp(pytestconfig: pytest.Config) -> None: original_urn=DatasetUrn("redshift", "dev.public.foo_staging").urn(), new_urn=DatasetUrn("redshift", "dev.public.foo").urn(), query="alter table dev.public.foo_staging rename to dev.public.foo", - ) + ), ) # Add an unrelated query. @@ -606,7 +609,7 @@ def test_table_rename_with_temp(pytestconfig: pytest.Config) -> None: query="create table bar as select a, b from baz", default_db="dev", default_schema="public", - ) + ), ) # Add the query that created the staging table. @@ -615,7 +618,7 @@ def test_table_rename_with_temp(pytestconfig: pytest.Config) -> None: query="create table foo_staging as select a, b from foo_dep", default_db="dev", default_schema="public", - ) + ), ) # Add the query that created the downstream from foo_staging table. @@ -624,7 +627,7 @@ def test_table_rename_with_temp(pytestconfig: pytest.Config) -> None: query="create table foo_downstream as select a, b from foo_staging", default_db="dev", default_schema="public", - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -656,7 +659,7 @@ def test_table_swap(pytestconfig: pytest.Config) -> None: query="create table bar as select a, b from baz", default_db="dev", default_schema="public", - ) + ), ) # Add the query that created the swap table initially. @@ -666,7 +669,7 @@ def test_table_swap(pytestconfig: pytest.Config) -> None: query_text="CREATE TABLE person_info_swap CLONE person_info;", upstreams=[DatasetUrn("snowflake", "dev.public.person_info").urn()], downstream=DatasetUrn("snowflake", "dev.public.person_info_swap").urn(), - ) + ), ) # Add the query that created the incremental table. @@ -678,9 +681,10 @@ def test_table_swap(pytestconfig: pytest.Config) -> None: DatasetUrn("snowflake", "dev.public.person_info_dep").urn(), ], downstream=DatasetUrn( - "snowflake", "dev.public.person_info_incremental" + "snowflake", + "dev.public.person_info_incremental", ).urn(), - ) + ), ) # Add the query that updated the swap table. 
@@ -692,14 +696,14 @@ def test_table_swap(pytestconfig: pytest.Config) -> None: DatasetUrn("snowflake", "dev.public.person_info_incremental").urn(), ], downstream=DatasetUrn("snowflake", "dev.public.person_info_swap").urn(), - ) + ), ) aggregator.add_table_swap( TableSwap( urn1=DatasetUrn("snowflake", "dev.public.person_info").urn(), urn2=DatasetUrn("snowflake", "dev.public.person_info_swap").urn(), - ) + ), ) # Add the query that is created from swap table. @@ -711,7 +715,7 @@ def test_table_swap(pytestconfig: pytest.Config) -> None: DatasetUrn("snowflake", "dev.public.person_info_swap").urn(), ], downstream=DatasetUrn("snowflake", "dev.public.person_info_backup").urn(), - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -744,7 +748,7 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: query="create table bar as select a, b from baz", default_db="dev", default_schema="public", - ) + ), ) # Add the query that created the swap table initially. @@ -760,21 +764,23 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: ColumnLineageInfo( downstream=DownstreamColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_swap" + "snowflake", + "dev.public.person_info_swap", ).urn(), column="a", ), upstreams=[ ColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info" + "snowflake", + "dev.public.person_info", ).urn(), column="a", - ) + ), ], - ) + ), ], - ) + ), ) # Add the query that created the incremental table. @@ -786,7 +792,8 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: DatasetUrn("snowflake", "dev.public.person_info_dep").urn(), ], downstream=DatasetUrn( - "snowflake", "dev.public.person_info_incremental" + "snowflake", + "dev.public.person_info_incremental", ).urn(), session_id="xxx", timestamp=_ts(20), @@ -794,21 +801,23 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: ColumnLineageInfo( downstream=DownstreamColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_incremental" + "snowflake", + "dev.public.person_info_incremental", ).urn(), column="a", ), upstreams=[ ColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_dep" + "snowflake", + "dev.public.person_info_dep", ).urn(), column="a", - ) + ), ], - ) + ), ], - ) + ), ) # Add the query that updated the swap table. @@ -826,21 +835,23 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: ColumnLineageInfo( downstream=DownstreamColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_swap" + "snowflake", + "dev.public.person_info_swap", ).urn(), column="a", ), upstreams=[ ColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_incremental" + "snowflake", + "dev.public.person_info_incremental", ).urn(), column="a", - ) + ), ], - ) + ), ], - ) + ), ) aggregator.add_table_swap( @@ -849,7 +860,7 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: urn2=DatasetUrn("snowflake", "dev.public.person_info_swap").urn(), session_id="xxx", timestamp=_ts(40), - ) + ), ) # Add the query that is created from swap table. 
@@ -867,21 +878,23 @@ def test_table_swap_with_temp(pytestconfig: pytest.Config) -> None: ColumnLineageInfo( downstream=DownstreamColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_backup" + "snowflake", + "dev.public.person_info_backup", ).urn(), column="a", ), upstreams=[ ColumnRef( table=DatasetUrn( - "snowflake", "dev.public.person_info_swap" + "snowflake", + "dev.public.person_info_swap", ).urn(), column="a", - ) + ), ], - ) + ), ], - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -908,7 +921,7 @@ def test_create_table_query_mcps(pytestconfig: pytest.Config) -> None: default_db="dev", default_schema="public", timestamp=datetime.now(), - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -936,14 +949,14 @@ def test_table_lineage_via_temp_table_disordered_add( query="create table derived_from_foo as select * from foo", default_db="dev", default_schema="public", - ) + ), ) aggregator.add_observed_query( ObservedQuery( query="create temp table foo as select a, b+c as c from bar", default_db="dev", default_schema="public", - ) + ), ) mcps = list(aggregator.gen_metadata()) @@ -983,7 +996,7 @@ def test_basic_usage(pytestconfig: pytest.Config) -> None: usage_multiplier=5, timestamp=frozen_timestamp, user=CorpUserUrn("user1"), - ) + ), ) aggregator.add_observed_query( ObservedQuery( @@ -992,7 +1005,7 @@ def test_basic_usage(pytestconfig: pytest.Config) -> None: default_schema="public", timestamp=frozen_timestamp, user=CorpUserUrn("user2"), - ) + ), ) mcps = list(aggregator.gen_metadata()) diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py index 2df6f6fa26cb56..5c87ed6fd5b102 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py @@ -10,7 +10,8 @@ @pytest.fixture(autouse=True) def set_update_sql_parser( - pytestconfig: pytest.Config, monkeypatch: pytest.MonkeyPatch + pytestconfig: pytest.Config, + monkeypatch: pytest.MonkeyPatch, ) -> None: update_golden = pytestconfig.getoption("--update-golden-files") diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_patch.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_patch.py index dee6d9630c12eb..df4861191ba2e2 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_patch.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_patch.py @@ -20,7 +20,7 @@ def test_cooperative_timeout_sql() -> None: statement = sqlglot.parse_one("SELECT pg_sleep(3)", dialect="postgres") with pytest.raises( - CooperativeTimeoutError + CooperativeTimeoutError, ), PerfTimer() as timer, cooperative_timeout(timeout=0.6): while True: # sql() implicitly calls copy(), which is where we check for the timeout. 
@@ -31,7 +31,7 @@ def test_cooperative_timeout_sql() -> None: def test_scope_circular_dependency() -> None: scope = sqlglot.optimizer.build_scope( - sqlglot.parse_one("WITH w AS (SELECT * FROM q) SELECT * FROM w") + sqlglot.parse_one("WITH w AS (SELECT * FROM q) SELECT * FROM w"), ) assert scope is not None diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py index c3c3a4a15d915b..b35555b42dd83f 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py @@ -34,7 +34,8 @@ def test_is_dialect_instance(): def test_query_types(): assert get_query_type_of_sql( sqlglot.parse_one( - "create temp table foo as select * from bar", dialect="redshift" + "create temp table foo as select * from bar", + dialect="redshift", ), dialect="redshift", ) == (QueryType.CREATE_TABLE_AS_SELECT, {"kind": "TABLE", "temporary": True}) @@ -105,7 +106,7 @@ class QueryGeneralizationTestMode(Enum): (Column1, Column2, Column3) VALUES ('John', 123, 'Lloyds Office'); - """ + """, ), "mssql", "INSERT INTO MyTable (Column1, Column2, Column3) VALUES (?)", @@ -138,7 +139,7 @@ class QueryGeneralizationTestMode(Enum): ('Jane', 124, 'Lloyds Office'), ('Billy', 125, 'London Office'), ('Miranda', 126, 'Bristol Office'); - """ + """, ), "mssql", "INSERT INTO MyTable (Column1, Column2, Column3) VALUES (?), (?), (?), (?)", @@ -167,7 +168,10 @@ class QueryGeneralizationTestMode(Enum): ], ) def test_query_generalization( - query: str, dialect: str, expected: str, mode: QueryGeneralizationTestMode + query: str, + dialect: str, + expected: str, + mode: QueryGeneralizationTestMode, ) -> None: if mode in {QueryGeneralizationTestMode.FULL, QueryGeneralizationTestMode.BOTH}: assert generalize_query(query, dialect=dialect) == expected @@ -180,11 +184,13 @@ def test_query_generalization( def test_query_fingerprint(): assert get_query_fingerprint( - "select * /* everything */ from foo where ts = 34", platform="redshift" + "select * /* everything */ from foo where ts = 34", + platform="redshift", ) == get_query_fingerprint("SELECT * FROM foo where ts = 38", platform="redshift") assert get_query_fingerprint( - "select 1 + 1", platform="postgres" + "select 1 + 1", + platform="postgres", ) != get_query_fingerprint("select 2", platform="postgres") @@ -193,8 +199,11 @@ def test_redshift_query_fingerprint(): query2 = "INSERT INTO insert_into_table (SELECT * FROM base_table)" assert get_query_fingerprint(query1, "redshift") == get_query_fingerprint( - query2, "redshift" + query2, + "redshift", ) assert get_query_fingerprint(query1, "redshift", True) != get_query_fingerprint( - query2, "redshift", True + query2, + "redshift", + True, ) diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_tool_meta_extractor.py b/metadata-ingestion/tests/unit/sql_parsing/test_tool_meta_extractor.py index f6566f007f5e6b..8cdfc51037e177 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_tool_meta_extractor.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_tool_meta_extractor.py @@ -35,7 +35,8 @@ def test_extract_mode_metadata() -> None: def test_extract_looker_metadata() -> None: extractor = ToolMetaExtractor( - report=ToolMetaExtractorReport(), looker_user_mapping={"7": "john.doe@xyz.com"} + report=ToolMetaExtractorReport(), + looker_user_mapping={"7": "john.doe@xyz.com"}, ) looker_query = """\ SELECT diff --git 
a/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_provider.py b/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_provider.py index d097c75ebd952c..e87295291d23b3 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_provider.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_provider.py @@ -48,7 +48,8 @@ def _setup_mock_graph(self) -> None: Setup monkey-patched graph client. """ self.patcher = patch( - "datahub.ingestion.graph.client.DataHubGraph", autospec=True + "datahub.ingestion.graph.client.DataHubGraph", + autospec=True, ) self.addCleanup(self.patcher.stop) self.mock_graph = self.patcher.start() @@ -56,18 +57,21 @@ def _setup_mock_graph(self) -> None: self.mock_graph.get_config.return_value = {"statefulIngestionCapable": True} # Bind mock_graph's emit_mcp to testcase's monkey_patch_emit_mcp so that we can emulate emits. self.mock_graph.emit_mcp = types.MethodType( - self.monkey_patch_emit_mcp, self.mock_graph + self.monkey_patch_emit_mcp, + self.mock_graph, ) # Bind mock_graph's get_latest_timeseries_value to monkey_patch_get_latest_timeseries_value self.mock_graph.get_latest_timeseries_value = types.MethodType( - self.monkey_patch_get_latest_timeseries_value, self.mock_graph + self.monkey_patch_get_latest_timeseries_value, + self.mock_graph, ) # Tracking for emitted mcps. self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {} def _create_providers(self) -> None: ctx: PipelineContext = PipelineContext( - run_id=self.run_id, pipeline_name=self.pipeline_name + run_id=self.run_id, + pipeline_name=self.pipeline_name, ) ctx.graph = self.mock_graph self.providers: List[IngestionCheckpointingProviderBase] = [ @@ -79,7 +83,9 @@ def _create_providers(self) -> None: ] def monkey_patch_emit_mcp( - self, graph_ref: MagicMock, mcpw: MetadataChangeProposalWrapper + self, + graph_ref: MagicMock, + mcpw: MetadataChangeProposalWrapper, ) -> None: """ Mockey patched implementation of DatahubGraph.emit_mcp that caches the mcp locally in memory. @@ -132,7 +138,8 @@ def test_providers(self): ) # Job2 - Checkpoint with a BaseTimeWindowCheckpointState state job2_state_obj = BaseTimeWindowCheckpointState( - begin_timestamp_millis=10, end_timestamp_millis=100 + begin_timestamp_millis=10, + end_timestamp_millis=100, ) job2_checkpoint = Checkpoint( job_name=self.job_names[1], @@ -145,10 +152,10 @@ def test_providers(self): provider.state_to_commit = { # NOTE: state_to_commit accepts only the aspect version of the checkpoint. self.job_names[0]: assert_not_null( - job1_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20) + job1_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20), ), self.job_names[1]: assert_not_null( - job2_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20) + job2_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20), ), } @@ -162,10 +169,12 @@ def test_providers(self): # 4. Get last committed state. This must match what has been committed earlier. # NOTE: This will retrieve the state form where it is committed. job1_last_state = provider.get_latest_checkpoint( - self.pipeline_name, self.job_names[0] + self.pipeline_name, + self.job_names[0], ) job2_last_state = provider.get_latest_checkpoint( - self.pipeline_name, self.job_names[1] + self.pipeline_name, + self.job_names[1], ) # 5. 
Validate individual job checkpoint state values that have been committed and retrieved @@ -191,7 +200,8 @@ def test_state_provider_wrapper_with_config_provided(self): ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name) ctx.graph = self.mock_graph state_provider = StateProviderWrapper( - StatefulIngestionConfig(enabled=True), ctx + StatefulIngestionConfig(enabled=True), + ctx, ) assert state_provider.stateful_ingestion_config assert state_provider.ingestion_checkpointing_state_provider @@ -199,7 +209,8 @@ def test_state_provider_wrapper_with_config_provided(self): ctx = PipelineContext(run_id=self.run_id, pipeline_name=self.pipeline_name) ctx.graph = self.mock_graph state_provider = StateProviderWrapper( - StatefulIngestionConfig(enabled=False), ctx + StatefulIngestionConfig(enabled=False), + ctx, ) assert state_provider.stateful_ingestion_config assert not state_provider.ingestion_checkpointing_state_provider diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py index ecea3183393453..79b3d809223397 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py @@ -63,17 +63,21 @@ def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState: # all existing code uses correctly formed envs. base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState() base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn( - type="table", urn="urn:li:dataset:(urn:li:dataPlatform:mysql,db1.t1,prod)" + type="table", + urn="urn:li:dataset:(urn:li:dataPlatform:mysql,db1.t1,prod)", ) base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn( - type="view", urn="urn:li:dataset:(urn:li:dataPlatform:mysql,db1.v1,prod)" + type="view", + urn="urn:li:dataset:(urn:li:dataPlatform:mysql,db1.v1,prod)", ) return base_sql_alchemy_checkpoint_state_obj def _make_usage_checkpoint_state() -> BaseTimeWindowCheckpointState: base_usage_checkpoint_state_obj = BaseTimeWindowCheckpointState( - version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100 + version="2.0", + begin_timestamp_millis=1, + end_timestamp_millis=100, ) return base_usage_checkpoint_state_obj @@ -128,7 +132,7 @@ def test_serde_idempotence(state_obj): # 2. Convert it to the aspect form. checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect( - max_allowed_state_size=2**20 + max_allowed_state_size=2**20, ) assert checkpoint_aspect is not None @@ -146,7 +150,9 @@ def test_supported_encodings(): Tests utf-8 and base85-bz2-json encodings """ test_state = BaseTimeWindowCheckpointState( - version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100 + version="1.0", + begin_timestamp_millis=1, + end_timestamp_millis=100, ) # 1. 
Test UTF-8 encoding @@ -163,11 +169,14 @@ def test_base85_upgrade_pickle_to_json(): base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao6Yg3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n None: state2 = BaseSQLAlchemyCheckpointState() dataset_urns_diff = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) assert len(dataset_urns_diff) == 2 and sorted(dataset_urns_diff) == sorted( [ test_table_urn, test_view_urn, - ] + ], ) container_urns_diff = list( - state1.get_urns_not_in(type="container", other_checkpoint_state=state2) + state1.get_urns_not_in(type="container", other_checkpoint_state=state2), ) assert ( len(container_urns_diff) == 1 and container_urns_diff[0] == test_container_urn @@ -42,7 +42,7 @@ def test_state_backward_compat() -> None: encoded_view_urns=["mysql||db1.v1||PROD"], encoded_container_urns=["1154d1da73a95376c9f33f47694cf1de"], encoded_assertion_urns=["815963e1332b46a203504ba46ebfab24"], - ) + ), ) assert state == BaseSQLAlchemyCheckpointState( urns=[ @@ -50,7 +50,7 @@ def test_state_backward_compat() -> None: "urn:li:dataset:(urn:li:dataPlatform:mysql,db1.v1,PROD)", "urn:li:container:1154d1da73a95376c9f33f47694cf1de", "urn:li:assertion:815963e1332b46a203504ba46ebfab24", - ] + ], ) @@ -62,7 +62,7 @@ def test_deduplication_and_order_preservation() -> None: "urn:li:container:1154d1da73a95376c9f33f47694cf1de", "urn:li:assertion:815963e1332b46a203504ba46ebfab24", "urn:li:dataset:(urn:li:dataPlatform:mysql,db1.v1,PROD)", # duplicate - ] + ], ) assert len(state.urns) == 4 diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py index b04d4b86d2e4bb..95144dcdc2da37 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stale_entity_removal_handler.py @@ -41,10 +41,13 @@ ids=new_old_ent_tests.keys(), ) def test_change_percent( - new_entities: EntList, old_entities: EntList, expected_percent_change: float + new_entities: EntList, + old_entities: EntList, + expected_percent_change: float, ) -> None: actual_percent_change = compute_percent_entities_changed( - new_entities=new_entities, old_entities=old_entities + new_entities=new_entities, + old_entities=old_entities, ) assert actual_percent_change == expected_percent_change @@ -57,7 +60,7 @@ def test_filter_ignored_entity_types(): "urn:li:dataset:(urn:li:dataPlatform:postgres,dummy_dataset3,PROD)", "urn:li:dataProcessInstance:478810e859f870a54f72c681f41af619", "urn:li:query:query1", - ] + ], ) == [ "urn:li:dataset:(urn:li:dataPlatform:postgres,dummy_dataset1,PROD)", "urn:li:dataset:(urn:li:dataPlatform:postgres,dummy_dataset2,PROD)", diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py index e69727f73b6bf4..951b6c6c1e27e3 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py @@ -67,7 +67,8 @@ class DummySourceConfig(StatefulIngestionConfigBase, 
DatasetSourceConfigMixin): ) # Configuration for stateful ingestion stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = pydantic.Field( - default=None, description="Dummy source Ingestion Config." + default=None, + description="Dummy source Ingestion Config.", ) report_failure: bool = Field( default=False, @@ -78,7 +79,8 @@ class DummySourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin): description="Data process instance id to ingest.", ) query_id_to_ingest: Optional[str] = Field( - default=None, description="Query id to ingest" + default=None, + description="Query id to ingest", ) @@ -96,7 +98,9 @@ def __init__(self, config: DummySourceConfig, ctx: PipelineContext): self.reporter = DummySourceReport() # Create and register the stateful ingestion use-case handler. self.stale_entity_removal_handler = StaleEntityRemovalHandler.create( - self, self.source_config, self.ctx + self, + self.source_config, + self.ctx, ) @classmethod @@ -149,7 +153,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield MetadataChangeProposalWrapper( entityUrn=QueryUrn(self.source_config.query_id_to_ingest).urn(), aspect=DataPlatformInstanceClass( - platform=DataPlatformUrn("bigquery").urn() + platform=DataPlatformUrn("bigquery").urn(), ), ).as_workunit() @@ -163,7 +167,7 @@ def get_report(self) -> SourceReport: @pytest.fixture(scope="module") def mock_generic_checkpoint_state(): with mock.patch( - "datahub.ingestion.source.state.entity_removal_state.GenericCheckpointState" + "datahub.ingestion.source.state.entity_removal_state.GenericCheckpointState", ) as mock_checkpoint_state: checkpoint_state = mock_checkpoint_state.return_value checkpoint_state.serde.return_value = "utf-8" @@ -215,7 +219,7 @@ def test_stateful_ingestion(pytestconfig, tmp_path, mock_time): } with mock.patch( - "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj" + "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj", ) as mock_state, mock.patch( "datahub.ingestion.source.state.stale_entity_removal_handler.STATEFUL_INGESTION_IGNORED_ENTITY_TYPES", {}, @@ -224,7 +228,7 @@ def test_stateful_ingestion(pytestconfig, tmp_path, mock_time): mock_state.return_value = GenericCheckpointState(serde="utf-8") pipeline_run1 = None pipeline_run1_config: Dict[str, Dict[str, Dict[str, Any]]] = dict( # type: ignore - base_pipeline_config # type: ignore + base_pipeline_config, # type: ignore ) pipeline_run1_config["sink"]["config"]["filename"] = ( f"{tmp_path}/{output_file_name}" @@ -250,12 +254,12 @@ def test_stateful_ingestion(pytestconfig, tmp_path, mock_time): assert checkpoint1.state with mock.patch( - "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj" + "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj", ) as mock_state: mock_state.return_value = GenericCheckpointState(serde="utf-8") pipeline_run2 = None pipeline_run2_config: Dict[str, Dict[str, Dict[str, Any]]] = dict( - base_pipeline_config # type: ignore + base_pipeline_config, # type: ignore ) pipeline_run2_config["source"]["config"]["dataset_patterns"] = { "allow": ["dummy_dataset1", "dummy_dataset2"], @@ -288,10 +292,12 @@ def test_stateful_ingestion(pytestconfig, tmp_path, mock_time): # Validate that all providers have committed successfully. 
validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Perform all assertions on the states. The deleted table should not be @@ -300,7 +306,7 @@ def test_stateful_ingestion(pytestconfig, tmp_path, mock_time): state2 = cast(GenericCheckpointState, checkpoint2.state) difference_dataset_urns = list( - state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2), ) # the difference in dataset urns is the dataset which is not allowed to ingest assert len(difference_dataset_urns) == 1 @@ -317,7 +323,7 @@ def test_stateful_ingestion(pytestconfig, tmp_path, mock_time): "urn:li:query:query1", ] assert sorted(non_deletable_urns) == sorted( - report.last_state_non_deletable_entities + report.last_state_non_deletable_entities, ) @@ -365,12 +371,12 @@ def test_stateful_ingestion_failure(pytestconfig, tmp_path, mock_time): } with mock.patch( - "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj" + "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj", ) as mock_state: mock_state.return_value = GenericCheckpointState(serde="utf-8") pipeline_run1 = None pipeline_run1_config: Dict[str, Dict[str, Dict[str, Any]]] = dict( # type: ignore - base_pipeline_config # type: ignore + base_pipeline_config, # type: ignore ) pipeline_run1_config["sink"]["config"]["filename"] = ( f"{tmp_path}/{output_file_name}" @@ -396,12 +402,12 @@ def test_stateful_ingestion_failure(pytestconfig, tmp_path, mock_time): assert checkpoint1.state with mock.patch( - "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj" + "datahub.ingestion.source.state.stale_entity_removal_handler.StaleEntityRemovalHandler._get_state_obj", ) as mock_state: mock_state.return_value = GenericCheckpointState(serde="utf-8") pipeline_run2 = None pipeline_run2_config: Dict[str, Dict[str, Dict[str, Any]]] = dict( - base_pipeline_config # type: ignore + base_pipeline_config, # type: ignore ) pipeline_run2_config["source"]["config"]["dataset_patterns"] = { "allow": ["dummy_dataset1", "dummy_dataset2"], @@ -431,10 +437,12 @@ def test_stateful_ingestion_failure(pytestconfig, tmp_path, mock_time): # Validate that all providers have committed successfully. validate_all_providers_have_committed_successfully( - pipeline=pipeline_run1, expected_providers=1 + pipeline=pipeline_run1, + expected_providers=1, ) validate_all_providers_have_committed_successfully( - pipeline=pipeline_run2, expected_providers=1 + pipeline=pipeline_run2, + expected_providers=1, ) # Perform assertions on the states. 
The deleted table should be diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py b/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py index ba40962866f8cc..8138f3159284be 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py @@ -48,7 +48,7 @@ timeout_sec=10, extra_headers={}, max_threads=10, - ) + ), ), ), False, diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py b/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py index 3b0e4e31d4b4a2..d12a0b1d443a21 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py @@ -10,7 +10,7 @@ def test_kafka_common_state() -> None: state2 = GenericCheckpointState() topic_urns_diff = list( - state1.get_urns_not_in(type="topic", other_checkpoint_state=state2) + state1.get_urns_not_in(type="topic", other_checkpoint_state=state2), ) assert len(topic_urns_diff) == 1 and topic_urns_diff[0] == test_topic_urn @@ -21,8 +21,8 @@ def test_kafka_state_migration() -> None: "encoded_topic_urns": [ "kafka||test_topic1||test", "kafka||topic_2||DEV", - ] - } + ], + }, ) assert state.urns == [ "urn:li:dataset:(urn:li:dataPlatform:kafka,test_topic1,TEST)", diff --git a/metadata-ingestion/tests/unit/structured_properties/test_structured_properties.py b/metadata-ingestion/tests/unit/structured_properties/test_structured_properties.py index d03b08b77d5a96..09a99c038bec0d 100644 --- a/metadata-ingestion/tests/unit/structured_properties/test_structured_properties.py +++ b/metadata-ingestion/tests/unit/structured_properties/test_structured_properties.py @@ -45,7 +45,9 @@ def mock_graph(): def test_structured_properties_basic_creation(): props = StructuredProperties( - id="test_prop", type="string", description="Test description" + id="test_prop", + type="string", + description="Test description", ) assert props.id == "test_prop" assert props.type == "urn:li:dataType:datahub.string" @@ -95,7 +97,7 @@ def test_structured_properties_generate_mcps(): display_name="Test Property", entity_types=["dataset"], allowed_values=[ - AllowedValue(value="test_value", description="Test value description") + AllowedValue(value="test_value", description="Test value description"), ], ) @@ -121,14 +123,15 @@ def test_structured_properties_from_datahub(mock_graph): entityTypes=["urn:li:entityType:datahub.dataset"], cardinality="SINGLE", allowedValues=[ - PropertyValueClass(value="test_value", description="Test description") + PropertyValueClass(value="test_value", description="Test description"), ], ) mock_graph.get_aspect.return_value = mock_aspect props = StructuredProperties.from_datahub( - mock_graph, "urn:li:structuredProperty:test_prop" + mock_graph, + "urn:li:structuredProperty:test_prop", ) assert props.qualified_name == "test_prop" @@ -145,7 +148,7 @@ def test_structured_properties_to_yaml(tmp_path): type="string", description="Test description", allowed_values=[ - AllowedValue(value="test_value", description="Test value description") + AllowedValue(value="test_value", description="Test value description"), ], ) @@ -185,7 +188,7 @@ def test_structured_properties_type_qualifier(): mcps = props.generate_mcps() assert mcps[0].aspect assert mcps[0].aspect.typeQualifier["allowedTypes"] == [ # type: ignore - "urn:li:entityType:datahub.dataset" + "urn:li:entityType:datahub.dataset", ] @@ -206,7 +209,7 @@ def 
test_structured_properties_list(mock_graph): # Verify get_urns_by_filter was called with correct arguments mock_graph.get_urns_by_filter.assert_called_once_with( - entity_types=["structuredProperty"] + entity_types=["structuredProperty"], ) assert len(props) == 2 diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index f8b6220d182735..f871b03852820c 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -24,7 +24,7 @@ def test_athena_config_query_location_old_plus_new_value_not_allowed(): "s3_staging_dir": "s3://sample-staging-dir/", "query_result_location": "s3://query_result_location", "work_group": "test-workgroup", - } + }, ) @@ -36,7 +36,7 @@ def test_athena_config_staging_dir_is_set_as_query_result(): "aws_region": "us-west-1", "s3_staging_dir": "s3://sample-staging-dir/", "work_group": "test-workgroup", - } + }, ) expected_config = AthenaConfig.parse_obj( @@ -44,7 +44,7 @@ def test_athena_config_staging_dir_is_set_as_query_result(): "aws_region": "us-west-1", "query_result_location": "s3://sample-staging-dir/", "work_group": "test-workgroup", - } + }, ) assert config.json() == expected_config.json() @@ -58,7 +58,7 @@ def test_athena_uri(): "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", "work_group": "test-workgroup", - } + }, ) assert config.get_sql_alchemy_url() == ( "awsathena+rest://@athena.us-west-1.amazonaws.com:443" @@ -81,7 +81,7 @@ def test_athena_get_table_properties(): "aws_region": "us-west-1", "s3_staging_dir": "s3://sample-staging-dir/", "work_group": "test-workgroup", - } + }, ) schema: str = "test_schema" table: str = "test_table" @@ -110,7 +110,7 @@ def test_athena_get_table_properties(): mock_inspector = mock.MagicMock() mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor mock_cursor.get_table_metadata.return_value = AthenaTableMetadata( - response=table_metadata + response=table_metadata, ) # Mock partition query results @@ -126,7 +126,9 @@ def test_athena_get_table_properties(): # Test table properties description, custom_properties, location = source.get_table_properties( - inspector=mock_inspector, table=table, schema=schema + inspector=mock_inspector, + table=table, + schema=schema, ) assert custom_properties == { "comment": "testComment", @@ -143,7 +145,9 @@ def test_athena_get_table_properties(): # Test partition functionality partitions = source.get_partitions( - inspector=mock_inspector, schema=schema, table=table + inspector=mock_inspector, + schema=schema, + table=table, ) assert partitions == ["year", "month"] @@ -167,19 +171,24 @@ def test_athena_get_table_properties(): def test_get_column_type_simple_types(): assert isinstance( - CustomAthenaRestDialect()._get_column_type(type_="int"), types.Integer + CustomAthenaRestDialect()._get_column_type(type_="int"), + types.Integer, ) assert isinstance( - CustomAthenaRestDialect()._get_column_type(type_="string"), types.String + CustomAthenaRestDialect()._get_column_type(type_="string"), + types.String, ) assert isinstance( - CustomAthenaRestDialect()._get_column_type(type_="boolean"), types.BOOLEAN + CustomAthenaRestDialect()._get_column_type(type_="boolean"), + types.BOOLEAN, ) assert isinstance( - CustomAthenaRestDialect()._get_column_type(type_="long"), types.BIGINT + CustomAthenaRestDialect()._get_column_type(type_="long"), + types.BIGINT, ) assert isinstance( - 
CustomAthenaRestDialect()._get_column_type(type_="double"), types.FLOAT + CustomAthenaRestDialect()._get_column_type(type_="double"), + types.FLOAT, ) @@ -217,7 +226,7 @@ def test_column_type_decimal(): def test_column_type_complex_combination(): result = CustomAthenaRestDialect()._get_column_type( - type_="struct>>" + type_="struct>>", ) assert isinstance(result, STRUCT) @@ -239,13 +248,15 @@ def test_column_type_complex_combination(): assert isinstance(result._STRUCT_fields[2][1].item_type._STRUCT_fields[0], tuple) assert result._STRUCT_fields[2][1].item_type._STRUCT_fields[0][0] == "id" assert isinstance( - result._STRUCT_fields[2][1].item_type._STRUCT_fields[0][1], types.String + result._STRUCT_fields[2][1].item_type._STRUCT_fields[0][1], + types.String, ) assert isinstance(result._STRUCT_fields[2][1].item_type._STRUCT_fields[1], tuple) assert result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][0] == "label" assert isinstance( - result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], types.String + result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], + types.String, ) diff --git a/metadata-ingestion/tests/unit/test_aws_common.py b/metadata-ingestion/tests/unit/test_aws_common.py index dd1f06cf9bfd55..85b7901cb2a462 100644 --- a/metadata-ingestion/tests/unit/test_aws_common.py +++ b/metadata-ingestion/tests/unit/test_aws_common.py @@ -46,7 +46,8 @@ def test_environment_detection_lambda(self, mock_disable_ec2_metadata): assert detect_aws_environment() == AwsEnvironment.LAMBDA def test_environment_detection_lambda_cloudformation( - self, mock_disable_ec2_metadata + self, + mock_disable_ec2_metadata, ): """Test CloudFormation Lambda environment detection""" with patch.dict( @@ -77,7 +78,8 @@ def test_environment_detection_app_runner(self, mock_disable_ec2_metadata): def test_environment_detection_ecs(self, mock_disable_ec2_metadata): """Test ECS environment detection""" with patch.dict( - os.environ, {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"} + os.environ, + {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"}, ): assert detect_aws_environment() == AwsEnvironment.ECS @@ -164,11 +166,12 @@ def test_get_current_identity_lambda(self): "Effect": "Allow", "Principal": {"Service": "lambda.amazonaws.com"}, "Action": "sts:AssumeRole", - } + }, ], } iam_client.create_role( - RoleName="test-role", AssumeRolePolicyDocument=json.dumps(trust_policy) + RoleName="test-role", + AssumeRolePolicyDocument=json.dumps(trust_policy), ) lambda_client = boto3.client("lambda", region_name="us-east-1") @@ -197,7 +200,7 @@ def test_get_instance_role_arn_success(self, mock_put, mock_get): with patch("boto3.client") as mock_boto: mock_sts = MagicMock() mock_sts.get_caller_identity.return_value = { - "Arn": "arn:aws:sts::123456789012:assumed-role/test-role/instance" + "Arn": "arn:aws:sts::123456789012:assumed-role/test-role/instance", } mock_boto.return_value = mock_sts @@ -239,7 +242,7 @@ def test_aws_connection_config_role_assumption(self): ) with patch( - "datahub.ingestion.source.aws.aws_common.get_current_identity" + "datahub.ingestion.source.aws.aws_common.get_current_identity", ) as mock_identity: mock_identity.return_value = (None, None) session = config.get_session() @@ -255,7 +258,7 @@ def test_aws_connection_config_skip_role_assumption(self): ) with patch( - "datahub.ingestion.source.aws.aws_common.get_current_identity" + "datahub.ingestion.source.aws.aws_common.get_current_identity", ) as mock_identity: mock_identity.return_value = ( 
"arn:aws:iam::123456789012:role/current-role", @@ -278,7 +281,7 @@ def test_aws_connection_config_multiple_roles(self): ) with patch( - "datahub.ingestion.source.aws.aws_common.get_current_identity" + "datahub.ingestion.source.aws.aws_common.get_current_identity", ) as mock_identity: mock_identity.return_value = (None, None) session = config.get_session() @@ -334,7 +337,10 @@ def test_aws_connection_config_validation_error(self): ], ) def test_environment_detection_parametrized( - self, mock_disable_ec2_metadata, env_vars, expected_environment + self, + mock_disable_ec2_metadata, + env_vars, + expected_environment, ): """Parametrized test for environment detection with different configurations""" with patch.dict(os.environ, env_vars, clear=True): diff --git a/metadata-ingestion/tests/unit/test_capability_report.py b/metadata-ingestion/tests/unit/test_capability_report.py index 08ada0386b0eae..5bacaf73d62536 100644 --- a/metadata-ingestion/tests/unit/test_capability_report.py +++ b/metadata-ingestion/tests/unit/test_capability_report.py @@ -11,14 +11,20 @@ def test_basic_capability_report(): report = TestConnectionReport( basic_connectivity=CapabilityReport( - capable=True, failure_reason=None, mitigation_message=None + capable=True, + failure_reason=None, + mitigation_message=None, ), capability_report={ "CONTAINERS": CapabilityReport( - capable=True, failure_reason=None, mitigation_message=None + capable=True, + failure_reason=None, + mitigation_message=None, ), "SCHEMA_METADATA": CapabilityReport( - capable=True, failure_reason=None, mitigation_message=None + capable=True, + failure_reason=None, + mitigation_message=None, ), "DESCRIPTIONS": CapabilityReport( capable=False, @@ -26,7 +32,9 @@ def test_basic_capability_report(): mitigation_message="Enable admin privileges for this account.", ), "DATA_PROFILING": CapabilityReport( - capable=True, failure_reason=None, mitigation_message=None + capable=True, + failure_reason=None, + mitigation_message=None, ), SourceCapability.DOMAINS: CapabilityReport(capable=True), }, diff --git a/metadata-ingestion/tests/unit/test_cassandra_source.py b/metadata-ingestion/tests/unit/test_cassandra_source.py index 75dedde76c7c89..2a8106ae8ba841 100644 --- a/metadata-ingestion/tests/unit/test_cassandra_source.py +++ b/metadata-ingestion/tests/unit/test_cassandra_source.py @@ -20,7 +20,8 @@ def assert_field_paths_are_unique(fields: List[SchemaField]) -> None: def assert_field_paths_match( - fields: List[SchemaField], expected_field_paths: List[str] + fields: List[SchemaField], + expected_field_paths: List[str], ) -> None: logger.debug('FieldPaths=\n"' + '",\n"'.join(f.fieldPath for f in fields) + '"') assert len(fields) == len(expected_field_paths) @@ -44,7 +45,7 @@ def assert_field_paths_match( "email", "name", ], - ) + ), } @@ -54,7 +55,8 @@ def assert_field_paths_match( ids=schema_test_cases.keys(), ) def test_cassandra_schema_conversion( - schema: str, expected_field_paths: List[str] + schema: str, + expected_field_paths: List[str], ) -> None: schema_dict: Dict[str, List[Any]] = json.loads(schema) column_infos: List = schema_dict["column_infos"] diff --git a/metadata-ingestion/tests/unit/test_classification.py b/metadata-ingestion/tests/unit/test_classification.py index c79ae5808b2a69..74c4eb86ab9cfa 100644 --- a/metadata-ingestion/tests/unit/test_classification.py +++ b/metadata-ingestion/tests/unit/test_classification.py @@ -21,7 +21,7 @@ def test_default_datahub_classifier_config(): def test_selective_datahub_classifier_config_override(): 
simple_config_override = DataHubClassifier.create( - config_dict={"confidence_level_threshold": 0.7} + config_dict={"confidence_level_threshold": 0.7}, ).config assert simple_config_override.info_types_config is not None @@ -38,7 +38,7 @@ def test_selective_datahub_classifier_config_override(): }, }, }, - } + }, ).config assert complex_config_override.info_types_config is not None @@ -76,17 +76,17 @@ def test_custom_info_type_config(): "regex": [ ".*region.*id", ".*cloud.*region.*", - ] + ], }, "Values": { "prediction_type": "regex", "regex": [ - r"(af|ap|ca|eu|me|sa|us)-(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+" + r"(af|ap|ca|eu|me|sa|us)-(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+", ], }, }, }, - } + }, ).config assert custom_info_type_config.info_types_config @@ -110,7 +110,8 @@ def test_custom_info_type_config(): def test_incorrect_custom_info_type_config(): with pytest.raises( - ValidationError, match="Missing Configuration for Prediction Factor" + ValidationError, + match="Missing Configuration for Prediction Factor", ): DataHubClassifierConfig.parse_obj( { @@ -127,11 +128,11 @@ def test_incorrect_custom_info_type_config(): "regex": [ ".*region.*id", ".*cloud.*region.*", - ] + ], }, }, }, - } + }, ) with pytest.raises(ValidationError, match="Invalid Prediction Type"): @@ -150,12 +151,12 @@ def test_incorrect_custom_info_type_config(): "regex": [ ".*region.*id", ".*cloud.*region.*", - ] + ], }, "Values": {"prediction_type": "library", "library": ["spacy"]}, }, }, - } + }, ) @@ -180,14 +181,14 @@ def test_exclude_name_config(): "^.*add.*mail.*$", "email", "mail", - ] + ], }, "Description": {"regex": []}, "Datatype": {"type": ["str"]}, "Values": {"prediction_type": "regex", "regex": [], "library": []}, - } + }, }, - } + }, ).config assert config.info_types_config["Email_Address"].ExcludeName is not None assert config.info_types_config["Email_Address"].ExcludeName == [ @@ -216,13 +217,13 @@ def test_no_exclude_name_config(): "^.*add.*mail.*$", "email", "mail", - ] + ], }, "Description": {"regex": []}, "Datatype": {"type": ["str"]}, "Values": {"prediction_type": "regex", "regex": [], "library": []}, - } + }, }, - } + }, ).config assert config.info_types_config["Email_Address"].ExcludeName is None diff --git a/metadata-ingestion/tests/unit/test_clickhouse_source.py b/metadata-ingestion/tests/unit/test_clickhouse_source.py index 1b2ffb70c8d190..95571b60aa6f83 100644 --- a/metadata-ingestion/tests/unit/test_clickhouse_source.py +++ b/metadata-ingestion/tests/unit/test_clickhouse_source.py @@ -9,7 +9,7 @@ def test_clickhouse_uri_https(): "host_port": "host:1111", "database": "db", "uri_opts": {"protocol": "https"}, - } + }, ) assert ( config.get_sql_alchemy_url() @@ -24,7 +24,7 @@ def test_clickhouse_uri_native(): "password": "password", "host_port": "host:1111", "scheme": "clickhouse+native", - } + }, ) assert config.get_sql_alchemy_url() == "clickhouse+native://user:password@host:1111" @@ -38,7 +38,7 @@ def test_clickhouse_uri_native_secure(): "database": "db", "scheme": "clickhouse+native", "uri_opts": {"secure": True}, - } + }, ) assert ( config.get_sql_alchemy_url() @@ -53,7 +53,7 @@ def test_clickhouse_uri_default_password(): "host_port": "host:1111", "database": "db", "scheme": "clickhouse+native", - } + }, ) assert config.get_sql_alchemy_url() == "clickhouse+native://user@host:1111/db" @@ -67,7 +67,7 @@ def test_clickhouse_uri_native_secure_backward_compatibility(): "database": "db", "scheme": "clickhouse+native", 
"secure": True, - } + }, ) assert ( config.get_sql_alchemy_url() @@ -83,7 +83,7 @@ def test_clickhouse_uri_https_backward_compatibility(): "host_port": "host:1111", "database": "db", "protocol": "https", - } + }, ) assert ( config.get_sql_alchemy_url() diff --git a/metadata-ingestion/tests/unit/test_compare_metadata.py b/metadata-ingestion/tests/unit/test_compare_metadata.py index 8316e226d4847c..1509e1b1e4f5ce 100644 --- a/metadata-ingestion/tests/unit/test_compare_metadata.py +++ b/metadata-ingestion/tests/unit/test_compare_metadata.py @@ -49,7 +49,7 @@ }, "proposedDelta": null } -]""" +]""", ) # Timestamps changed from basic_1 but same otherwise. @@ -97,7 +97,7 @@ }, "proposedDelta": null } -]""" +]""", ) # Dataset owner changed from basic_2. @@ -145,7 +145,7 @@ }, "proposedDelta": null } -]""" +]""", ) @@ -156,12 +156,16 @@ def test_basic_diff_same() -> None: def test_basic_diff_only_owner_change() -> None: with pytest.raises(AssertionError): assert not diff_metadata_json( - basic_2, basic_3, mce_helpers.IGNORE_PATH_TIMESTAMPS + basic_2, + basic_3, + mce_helpers.IGNORE_PATH_TIMESTAMPS, ) def test_basic_diff_owner_change() -> None: with pytest.raises(AssertionError): assert not diff_metadata_json( - basic_1, basic_3, mce_helpers.IGNORE_PATH_TIMESTAMPS + basic_1, + basic_3, + mce_helpers.IGNORE_PATH_TIMESTAMPS, ) diff --git a/metadata-ingestion/tests/unit/test_confluent_schema_registry.py b/metadata-ingestion/tests/unit/test_confluent_schema_registry.py index effa6ba85acaeb..ca2fa5da05d40f 100644 --- a/metadata-ingestion/tests/unit/test_confluent_schema_registry.py +++ b/metadata-ingestion/tests/unit/test_confluent_schema_registry.py @@ -66,10 +66,11 @@ def test_get_schema_str_replace_confluent_ref_avro(self): "bootstrap": "localhost:9092", "schema_registry_url": "http://localhost:8081", }, - } + }, ) confluent_schema_registry = ConfluentSchemaRegistry.create( - kafka_source_config, KafkaSourceReport() + kafka_source_config, + KafkaSourceReport(), ) def new_get_latest_version(subject_name: str) -> RegisteredSchema: @@ -93,14 +94,16 @@ def new_get_latest_version(subject_name: str) -> RegisteredSchema: schema_type="AVRO", references=[ SchemaReference( - name="TestTopic1", subject="schema_subject_1", version=1 - ) + name="TestTopic1", + subject="schema_subject_1", + version=1, + ), ], - ) + ), ) ) assert schema_str == ConfluentSchemaRegistry._compact_schema( - schema_str_final + schema_str_final, ) with patch.object( @@ -116,14 +119,16 @@ def new_get_latest_version(subject_name: str) -> RegisteredSchema: schema_type="AVRO", references=[ SchemaReference( - name="schema_subject_1", subject="TestTopic1", version=1 - ) + name="schema_subject_1", + subject="TestTopic1", + version=1, + ), ], - ) + ), ) ) assert schema_str == ConfluentSchemaRegistry._compact_schema( - schema_str_final + schema_str_final, ) diff --git a/metadata-ingestion/tests/unit/test_csv_enricher_source.py b/metadata-ingestion/tests/unit/test_csv_enricher_source.py index 4e05daf0779c64..9019cf8879f420 100644 --- a/metadata-ingestion/tests/unit/test_csv_enricher_source.py +++ b/metadata-ingestion/tests/unit/test_csv_enricher_source.py @@ -18,7 +18,8 @@ def create_owners_list_from_urn_list( - owner_urns: List[str], source_type: str + owner_urns: List[str], + source_type: str, ) -> List[OwnerClass]: ownership_source_type: Union[None, OwnershipSourceClass] = None if source_type: @@ -38,21 +39,23 @@ def create_mocked_csv_enricher_source() -> CSVEnricherSource: ctx = PipelineContext("test-run-id") graph = mock.MagicMock() 
graph.get_ownership.return_value = mce_builder.make_ownership_aspect_from_urn_list( - ["urn:li:corpuser:olduser1"], "AUDIT" + ["urn:li:corpuser:olduser1"], + "AUDIT", ) graph.get_glossary_terms.return_value = ( mce_builder.make_glossary_terms_aspect_from_urn_list( - ["urn:li:glossaryTerm:oldterm1", "urn:li:glossaryTerm:oldterm2"] + ["urn:li:glossaryTerm:oldterm1", "urn:li:glossaryTerm:oldterm2"], ) ) graph.get_tags.return_value = mce_builder.make_global_tag_aspect_with_tag_list( - ["oldtag1", "oldtag2"] + ["oldtag1", "oldtag2"], ) graph.get_aspect_v2.return_value = None graph.get_domain.return_value = None ctx.graph = graph return CSVEnricherSource( - CSVEnricherConfig(**create_base_csv_enricher_config()), ctx + CSVEnricherConfig(**create_base_csv_enricher_config()), + ctx, ) @@ -83,7 +86,8 @@ def test_get_resource_glossary_terms_no_new_glossary_terms(): GlossaryTermAssociationClass(term) for term in new_glossary_terms ] maybe_terms_wu = source.get_resource_glossary_terms_work_unit( - DATASET_URN, term_associations + DATASET_URN, + term_associations, ) assert not maybe_terms_wu @@ -98,7 +102,8 @@ def test_get_resource_glossary_terms_work_unit_produced(): GlossaryTermAssociationClass(term) for term in new_glossary_terms ] maybe_terms_wu = source.get_resource_glossary_terms_work_unit( - DATASET_URN, term_associations + DATASET_URN, + term_associations, ) assert maybe_terms_wu @@ -154,8 +159,9 @@ def test_maybe_extract_owners_ownership_type_urn(): } assert source.maybe_extract_owners(row=row, is_resource_row=True) == [ OwnerClass( - owner="urn:li:corpuser:datahub", type=OwnershipTypeClass.TECHNICAL_OWNER - ) + owner="urn:li:corpuser:datahub", + type=OwnershipTypeClass.TECHNICAL_OWNER, + ), ] row2 = { @@ -169,7 +175,7 @@ def test_maybe_extract_owners_ownership_type_urn(): owner="urn:li:corpuser:datahub", type=OwnershipTypeClass.CUSTOM, typeUrn="urn:li:ownershipType:technical_owner", - ) + ), ] @@ -187,7 +193,8 @@ def test_get_resource_description_no_description(): source = create_mocked_csv_enricher_source() new_description = None maybe_description_wu = source.get_resource_description_work_unit( - DATASET_URN, new_description + DATASET_URN, + new_description, ) assert not maybe_description_wu @@ -196,7 +203,8 @@ def test_get_resource_description_work_unit_produced(): source = create_mocked_csv_enricher_source() new_description = "description" maybe_description_wu = source.get_resource_description_work_unit( - DATASET_URN, new_description + DATASET_URN, + new_description, ) assert maybe_description_wu diff --git a/metadata-ingestion/tests/unit/test_datahub_source.py b/metadata-ingestion/tests/unit/test_datahub_source.py index 67b2b85d9af6dd..4f92ff84dc9b41 100644 --- a/metadata-ingestion/tests/unit/test_datahub_source.py +++ b/metadata-ingestion/tests/unit/test_datahub_source.py @@ -31,7 +31,8 @@ def test_version_orderer(rows): orderer = VersionOrderer[Dict[str, Any]](enabled=True) ordered_rows = list(orderer(rows)) assert ordered_rows == sorted( - ordered_rows, key=lambda x: (x["createdon"], x["version"] == 0) + ordered_rows, + key=lambda x: (x["createdon"], x["version"] == 0), ) diff --git a/metadata-ingestion/tests/unit/test_dbt_source.py b/metadata-ingestion/tests/unit/test_dbt_source.py index d7899af69f8405..02cdb2f87249bb 100644 --- a/metadata-ingestion/tests/unit/test_dbt_source.py +++ b/metadata-ingestion/tests/unit/test_dbt_source.py @@ -30,7 +30,8 @@ def create_owners_list_from_urn_list( - owner_urns: List[str], source_type: str + owner_urns: List[str], + source_type: str, ) -> 
List[OwnerClass]: ownership_source_type: Union[None, OwnershipSourceClass] = None if source_type: @@ -50,15 +51,16 @@ def create_mocked_dbt_source() -> DBTCoreSource: ctx = PipelineContext(run_id="test-run-id", pipeline_name="dbt-source") graph = mock.MagicMock() graph.get_ownership.return_value = mce_builder.make_ownership_aspect_from_urn_list( - ["urn:li:corpuser:test_user"], "AUDIT" + ["urn:li:corpuser:test_user"], + "AUDIT", ) graph.get_glossary_terms.return_value = ( mce_builder.make_glossary_terms_aspect_from_urn_list( - ["urn:li:glossaryTerm:old", "urn:li:glossaryTerm:old2"] + ["urn:li:glossaryTerm:old", "urn:li:glossaryTerm:old2"], ) ) graph.get_tags.return_value = mce_builder.make_global_tag_aspect_with_tag_list( - ["non_dbt_existing", "dbt:existing"] + ["non_dbt_existing", "dbt:existing"], ) ctx.graph = graph return DBTCoreSource(DBTCoreConfig(**create_base_dbt_config()), ctx, "dbt") @@ -82,7 +84,9 @@ def test_dbt_source_patching_no_new(): # verifying when there are no new owners to be added assert source.ctx.graph transformed_owner_list = source.get_transformed_owners_by_source_type( - [], "urn:li:dataset:dummy", "SERVICE" + [], + "urn:li:dataset:dummy", + "SERVICE", ) assert len(transformed_owner_list) == 1 @@ -93,7 +97,9 @@ def test_dbt_source_patching_no_conflict(): new_owner_urns = ["urn:li:corpuser:new_test"] new_owners_list = create_owners_list_from_urn_list(new_owner_urns, "SERVICE") transformed_owner_list = source.get_transformed_owners_by_source_type( - new_owners_list, "urn:li:dataset:dummy", "DATABASE" + new_owners_list, + "urn:li:dataset:dummy", + "DATABASE", ) assert len(transformed_owner_list) == 2 owner_set = {"urn:li:corpuser:test_user", "urn:li:corpuser:new_test"} @@ -111,7 +117,9 @@ def test_dbt_source_patching_with_conflict(): new_owner_urns = ["urn:li:corpuser:new_test", "urn:li:corpuser:new_test2"] new_owners_list = create_owners_list_from_urn_list(new_owner_urns, "AUDIT") transformed_owner_list = source.get_transformed_owners_by_source_type( - new_owners_list, "urn:li:dataset:dummy", "AUDIT" + new_owners_list, + "urn:li:dataset:dummy", + "AUDIT", ) assert len(transformed_owner_list) == 2 expected_owner_set = {"urn:li:corpuser:new_test", "urn:li:corpuser:new_test2"} @@ -129,13 +137,16 @@ def test_dbt_source_patching_with_conflict_null_source_type_in_existing_owner(): source = create_mocked_dbt_source() graph = mock.MagicMock() graph.get_ownership.return_value = mce_builder.make_ownership_aspect_from_urn_list( - ["urn:li:corpuser:existing_test_user"], None + ["urn:li:corpuser:existing_test_user"], + None, ) source.ctx.graph = graph new_owner_urns = ["urn:li:corpuser:new_test", "urn:li:corpuser:new_test2"] new_owners_list = create_owners_list_from_urn_list(new_owner_urns, "AUDIT") transformed_owner_list = source.get_transformed_owners_by_source_type( - new_owners_list, "urn:li:dataset:dummy", "AUDIT" + new_owners_list, + "urn:li:dataset:dummy", + "AUDIT", ) assert len(transformed_owner_list) == 2 expected_owner_set = {"urn:li:corpuser:new_test", "urn:li:corpuser:new_test2"} @@ -153,10 +164,12 @@ def test_dbt_source_patching_tags(): # override the existing one with the same prefix. 
source = create_mocked_dbt_source() new_tag_aspect = mce_builder.make_global_tag_aspect_with_tag_list( - ["new_non_dbt", "dbt:new_dbt"] + ["new_non_dbt", "dbt:new_dbt"], ) transformed_tags = source.get_transformed_tags_by_prefix( - new_tag_aspect.tags, "urn:li:dataset:dummy", "dbt:" + new_tag_aspect.tags, + "urn:li:dataset:dummy", + "dbt:", ) expected_tags = { "urn:li:tag:new_non_dbt", @@ -172,10 +185,11 @@ def test_dbt_source_patching_terms(): # existing terms and new terms have two terms each and one common. After deduping we should only get 3 unique terms source = create_mocked_dbt_source() new_terms = mce_builder.make_glossary_terms_aspect_from_urn_list( - ["urn:li:glossaryTerm:old", "urn:li:glossaryTerm:new"] + ["urn:li:glossaryTerm:old", "urn:li:glossaryTerm:new"], ) transformed_terms = source.get_transformed_terms( - new_terms.terms, "urn:li:dataset:dummy" + new_terms.terms, + "urn:li:dataset:dummy", ) expected_terms = { "urn:li:glossaryTerm:old", @@ -266,7 +280,7 @@ def test_dbt_prefer_sql_parser_lineage_no_self_reference(): **create_base_dbt_config(), "skip_sources_in_lineage": True, "prefer_sql_parser_lineage": True, - } + }, ) source: DBTCoreSource = DBTCoreSource(config, ctx, "dbt") all_nodes_map = { @@ -294,7 +308,8 @@ def test_dbt_prefer_sql_parser_lineage_no_self_reference(): } source._infer_schemas_and_update_cll(all_nodes_map) upstream_lineage = source._create_lineage_aspect_for_dbt_node( - all_nodes_map["model1"], all_nodes_map + all_nodes_map["model1"], + all_nodes_map, ) assert upstream_lineage is not None assert len(upstream_lineage.upstreams) == 1 @@ -340,7 +355,7 @@ def test_default_convert_column_urns_to_lowercase(): **config_dict, "convert_column_urns_to_lowercase": False, "target_platform": "snowflake", - } + }, ) assert config.convert_column_urns_to_lowercase is False @@ -464,7 +479,7 @@ def test_dbt_time_parsing() -> None: # Ensure that we get an object with tzinfo set to UTC. 
assert timestamp.tzinfo is not None and timestamp.tzinfo.utcoffset( - timestamp + timestamp, ) == timedelta(0) diff --git a/metadata-ingestion/tests/unit/test_elasticsearch_source.py b/metadata-ingestion/tests/unit/test_elasticsearch_source.py index 3b8435e531fb21..d989e4ef6c8c52 100644 --- a/metadata-ingestion/tests/unit/test_elasticsearch_source.py +++ b/metadata-ingestion/tests/unit/test_elasticsearch_source.py @@ -26,8 +26,8 @@ def test_elasticsearch_throws_error_wrong_operation_config(): "operation_config": { "lower_freq_profile_enabled": True, }, - } - } + }, + }, ) @@ -39,7 +39,8 @@ def assert_field_paths_are_unique(fields: List[SchemaField]) -> None: def assret_field_paths_match( - fields: List[SchemaField], expected_field_paths: List[str] + fields: List[SchemaField], + expected_field_paths: List[str], ) -> None: logger.debug('FieldPaths=\n"' + '",\n"'.join(f.fieldPath for f in fields) + '"') assert len(fields) == len(expected_field_paths) @@ -2448,7 +2449,8 @@ def assret_field_paths_match( ids=schema_test_cases.keys(), ) def test_elastic_search_schema_conversion( - schema: str, expected_field_paths: List[str] + schema: str, + expected_field_paths: List[str], ) -> None: schema_dict: Dict[str, Any] = json.loads(schema) mappings: Dict[str, Any] = {"properties": schema_dict} @@ -2495,7 +2497,7 @@ def test_collapse_urns() -> None: collapse_urns=CollapseUrns( urns_suffix_regex=[ "-\\d+$", - ] + ], ), ) == "urn:li:dataset:(urn:li:dataPlatform:elasticsearch,platform1.prefix_datahub_usage_event,PROD)" @@ -2507,7 +2509,7 @@ def test_collapse_urns() -> None: collapse_urns=CollapseUrns( urns_suffix_regex=[ "-\\d{4}\\.\\d{2}\\.\\d{2}", - ] + ], ), ) == "urn:li:dataset:(urn:li:dataPlatform:elasticsearch,platform1.prefix_datahub_usage_event,PROD)" diff --git a/metadata-ingestion/tests/unit/test_gc.py b/metadata-ingestion/tests/unit/test_gc.py index fde9a3f2e0cf03..03074340097d56 100644 --- a/metadata-ingestion/tests/unit/test_gc.py +++ b/metadata-ingestion/tests/unit/test_gc.py @@ -23,7 +23,10 @@ def setUp(self): self.config = SoftDeletedEntitiesCleanupConfig() self.report = SoftDeletedEntitiesReport() self.cleanup = SoftDeletedEntitiesCleanup( - self.ctx, self.config, self.report, dry_run=True + self.ctx, + self.config, + self.report, + dry_run=True, ) def test_update_report(self): @@ -46,11 +49,14 @@ def setUp(self): self.config = DataProcessCleanupConfig() self.report = DataProcessCleanupReport() self.cleanup = DataProcessCleanup( - self.ctx, self.config, self.report, dry_run=True + self.ctx, + self.config, + self.report, + dry_run=True, ) @patch( - "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis" + "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis", ) def test_delete_dpi_from_datajobs(self, mock_fetch_dpis): job = DataJobEntity( @@ -65,7 +71,7 @@ def test_delete_dpi_from_datajobs(self, mock_fetch_dpis): { "urn": f"urn:li:dataprocessInstance:{i}", "created": { - "time": int(datetime.now(timezone.utc).timestamp() + i) * 1000 + "time": int(datetime.now(timezone.utc).timestamp() + i) * 1000, }, } for i in range(10) @@ -74,7 +80,7 @@ def test_delete_dpi_from_datajobs(self, mock_fetch_dpis): self.assertEqual(5, self.report.num_aspects_removed) @patch( - "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis" + "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis", ) def test_delete_dpi_from_datajobs_without_dpis(self, mock_fetch_dpis): job = DataJobEntity( @@ -90,7 +96,7 @@ def 
test_delete_dpi_from_datajobs_without_dpis(self, mock_fetch_dpis): self.assertEqual(0, self.report.num_aspects_removed) @patch( - "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis" + "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis", ) def test_delete_dpi_from_datajobs_without_dpi_created_time(self, mock_fetch_dpis): job = DataJobEntity( @@ -107,16 +113,17 @@ def test_delete_dpi_from_datajobs_without_dpi_created_time(self, mock_fetch_dpis { "urn": "urn:li:dataprocessInstance:11", "created": {"time": int(datetime.now(timezone.utc).timestamp() * 1000)}, - } + }, ] self.cleanup.delete_dpi_from_datajobs(job) self.assertEqual(10, self.report.num_aspects_removed) @patch( - "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis" + "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis", ) def test_delete_dpi_from_datajobs_without_dpi_null_created_time( - self, mock_fetch_dpis + self, + mock_fetch_dpis, ): job = DataJobEntity( urn="urn:li:dataJob:1", @@ -132,13 +139,13 @@ def test_delete_dpi_from_datajobs_without_dpi_null_created_time( { "urn": "urn:li:dataprocessInstance:11", "created": {"time": None}, - } + }, ] self.cleanup.delete_dpi_from_datajobs(job) self.assertEqual(11, self.report.num_aspects_removed) @patch( - "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis" + "datahub.ingestion.source.gc.dataprocess_cleanup.DataProcessCleanup.fetch_dpis", ) def test_delete_dpi_from_datajobs_without_dpi_without_time(self, mock_fetch_dpis): job = DataJobEntity( @@ -155,7 +162,7 @@ def test_delete_dpi_from_datajobs_without_dpi_without_time(self, mock_fetch_dpis { "urn": "urn:li:dataprocessInstance:11", "created": None, - } + }, ] self.cleanup.delete_dpi_from_datajobs(job) self.assertEqual(11, self.report.num_aspects_removed) @@ -170,12 +177,12 @@ def test_fetch_dpis(self): { "urn": "urn:li:dataprocessInstance:1", "created": { - "time": int(datetime.now(timezone.utc).timestamp()) + "time": int(datetime.now(timezone.utc).timestamp()), }, - } - ] - } - } + }, + ], + }, + }, } dpis = self.cleanup.fetch_dpis("urn:li:dataJob:1", 10) self.assertEqual(len(dpis), 1) diff --git a/metadata-ingestion/tests/unit/test_gcs_source.py b/metadata-ingestion/tests/unit/test_gcs_source.py index 9d5f4e915b18cf..2e69352cf6bb0d 100644 --- a/metadata-ingestion/tests/unit/test_gcs_source.py +++ b/metadata-ingestion/tests/unit/test_gcs_source.py @@ -15,7 +15,7 @@ def test_gcs_source_setup(): { "include": "gs://bucket_name/{table}/year={partition[0]}/month={partition[1]}/day={partition[1]}/*.parquet", "table_name": "{table}", - } + }, ], "credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"}, } @@ -23,7 +23,8 @@ def test_gcs_source_setup(): assert gcs.s3_source.source_config.platform == PLATFORM_GCS assert ( gcs.s3_source.create_s3_path( - "bucket-name", "food_parquet/year%3D2023/month%3D4/day%3D24/part1.parquet" + "bucket-name", + "food_parquet/year%3D2023/month%3D4/day%3D24/part1.parquet", ) == "s3://bucket-name/food_parquet/year=2023/month=4/day=24/part1.parquet" ) @@ -46,7 +47,7 @@ def test_data_lake_incorrect_config_raises_error(): { "include": "gs://a/b/c/d/{table}/*.*", "exclude": ["gs://a/b/c/d/a-{exclude}/**"], - } + }, ], "credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"}, } @@ -58,7 +59,7 @@ def test_data_lake_incorrect_config_raises_error(): "path_specs": [ { "include": "gs://a/b/c/d/{table}/*.hd5", - } + }, ], "credential": {"hmac_access_id": "id", 
"hmac_access_secret": "secret"}, } @@ -70,7 +71,7 @@ def test_data_lake_incorrect_config_raises_error(): "path_specs": [ { "include": "gs://a/b/c/d/**/*.*", - } + }, ], "credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"}, } diff --git a/metadata-ingestion/tests/unit/test_ge_profiling_config.py b/metadata-ingestion/tests/unit/test_ge_profiling_config.py index f4d73a6ffe1e4e..c8971605bcada3 100644 --- a/metadata-ingestion/tests/unit/test_ge_profiling_config.py +++ b/metadata-ingestion/tests/unit/test_ge_profiling_config.py @@ -5,7 +5,7 @@ def test_profile_table_level_only(): config = GEProfilingConfig.parse_obj( - {"enabled": True, "profile_table_level_only": True} + {"enabled": True, "profile_table_level_only": True}, ) assert config.any_field_level_metrics_enabled() is False @@ -14,7 +14,7 @@ def test_profile_table_level_only(): "enabled": True, "profile_table_level_only": True, "include_field_max_value": False, - } + }, ) assert config.any_field_level_metrics_enabled() is False @@ -29,5 +29,5 @@ def test_profile_table_level_only_fails_with_field_metric_enabled(): "enabled": True, "profile_table_level_only": True, "include_field_max_value": True, - } + }, ) diff --git a/metadata-ingestion/tests/unit/test_generic_aspect_transformer.py b/metadata-ingestion/tests/unit/test_generic_aspect_transformer.py index 52d7aa7f509c9e..b727e191620ed0 100644 --- a/metadata-ingestion/tests/unit/test_generic_aspect_transformer.py +++ b/metadata-ingestion/tests/unit/test_generic_aspect_transformer.py @@ -40,7 +40,7 @@ def make_mce_datajob( if aspects is None: aspects = [StatusClass(removed=False)] return MetadataChangeEventClass( - proposedSnapshot=DataJobSnapshotClass(urn=entity_urn, aspects=aspects) + proposedSnapshot=DataJobSnapshotClass(urn=entity_urn, aspects=aspects), ) @@ -78,7 +78,9 @@ def __init__(self): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "DummyGenericAspectTransformer": return cls() @@ -89,7 +91,10 @@ def aspect_name(self) -> str: return "customAspect" def transform_generic_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[GenericAspectClass] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[GenericAspectClass], ) -> Optional[GenericAspectClass]: value = ( aspect.value if aspect else json.dumps({"customAspect": 10}).encode("utf-8") @@ -108,8 +113,8 @@ def test_add_generic_aspect_when_mce_received(self): inputs = [mce_dataset, mce_datajob, EndOfStream()] outputs = list( DummyGenericAspectTransformer().transform( - [RecordEnvelope(i, metadata={}) for i in inputs] - ) + [RecordEnvelope(i, metadata={}) for i in inputs], + ), ) assert len(outputs) == len(inputs) + 1 @@ -129,13 +134,13 @@ def test_add_generic_aspect_when_mce_received(self): def test_add_generic_aspect_when_mcpw_received(self): mcpw_dataset = make_mcpw() mcpw_datajob = make_mcpw( - entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)" + entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)", ) inputs = [mcpw_dataset, mcpw_datajob, EndOfStream()] outputs = list( DummyGenericAspectTransformer().transform( - [RecordEnvelope(i, metadata={}) for i in inputs] - ) + [RecordEnvelope(i, metadata={}) for i in inputs], + ), ) assert len(outputs) == len(inputs) + 1 @@ -155,13 +160,13 @@ def test_add_generic_aspect_when_mcpw_received(self): def test_add_generic_aspect_when_mcpc_received(self): mcpc_dataset = make_mcpc() mcpc_datajob = make_mcpc( - 
entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)" + entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)", ) inputs = [mcpc_dataset, mcpc_datajob, EndOfStream()] outputs = list( DummyGenericAspectTransformer().transform( - [RecordEnvelope(i, metadata={}) for i in inputs] - ) + [RecordEnvelope(i, metadata={}) for i in inputs], + ), ) assert len(outputs) == len(inputs) + 1 @@ -189,7 +194,7 @@ def test_modify_generic_aspect_when_mcpc_received(self): ), ) mcpc_datajob = make_mcpc( - entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)" + entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)", ) inputs = [ mcpc_dataset_without_custom_aspect, @@ -199,8 +204,8 @@ def test_modify_generic_aspect_when_mcpc_received(self): ] outputs = list( DummyGenericAspectTransformer().transform( - [RecordEnvelope(i, metadata={}) for i in inputs] - ) + [RecordEnvelope(i, metadata={}) for i in inputs], + ), ) assert len(outputs) == len(inputs) + 1 @@ -229,7 +234,9 @@ def __init__(self): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "DummyRemoveGenericAspectTransformer": return cls() @@ -240,7 +247,10 @@ def aspect_name(self) -> str: return "customAspect" def transform_generic_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[GenericAspectClass] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[GenericAspectClass], ) -> Optional[GenericAspectClass]: return None @@ -257,7 +267,7 @@ def test_remove_generic_aspect_when_mcpc_received(self): ), ) mcpc_datajob = make_mcpc( - entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)" + entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)", ) inputs = [ mcpc_dataset_without_custom_aspect, @@ -267,8 +277,8 @@ def test_remove_generic_aspect_when_mcpc_received(self): ] outputs = list( DummyRemoveGenericAspectTransformer().transform( - [RecordEnvelope(i, metadata={}) for i in inputs] - ) + [RecordEnvelope(i, metadata={}) for i in inputs], + ), ) # Check that the second entry is removed. 
diff --git a/metadata-ingestion/tests/unit/test_hana_source.py b/metadata-ingestion/tests/unit/test_hana_source.py index aa9d37069092e3..b965a30136d647 100644 --- a/metadata-ingestion/tests/unit/test_hana_source.py +++ b/metadata-ingestion/tests/unit/test_hana_source.py @@ -29,7 +29,7 @@ def test_hana_uri_native(): "password": "password", "host_port": "host:39041", "scheme": "hana+hdbcli", - } + }, ) assert config.get_sql_alchemy_url() == "hana+hdbcli://user:password@host:39041" @@ -46,7 +46,7 @@ def test_hana_uri_native_db(): "host_port": "host:39041", "scheme": "hana+hdbcli", "database": "database", - } + }, ) assert ( config.get_sql_alchemy_url() diff --git a/metadata-ingestion/tests/unit/test_hive_source.py b/metadata-ingestion/tests/unit/test_hive_source.py index 2eeebdc8cd1f09..976866df0e4534 100644 --- a/metadata-ingestion/tests/unit/test_hive_source.py +++ b/metadata-ingestion/tests/unit/test_hive_source.py @@ -68,7 +68,7 @@ def test_hive_configuration_get_avro_schema_from_native_data_type(): "native_data_type": "string", "_nullable": True, }, - } + }, ], }, }, diff --git a/metadata-ingestion/tests/unit/test_iceberg.py b/metadata-ingestion/tests/unit/test_iceberg.py index 48524450caf36e..b13465a520ea8b 100644 --- a/metadata-ingestion/tests/unit/test_iceberg.py +++ b/metadata-ingestion/tests/unit/test_iceberg.py @@ -70,7 +70,9 @@ def with_iceberg_source(processing_threads: int = 1, **kwargs: Any) -> IcebergSo return IcebergSource( ctx=PipelineContext(run_id="iceberg-source-test"), config=IcebergSourceConfig( - catalog=catalog, processing_threads=processing_threads, **kwargs + catalog=catalog, + processing_threads=processing_threads, + **kwargs, ), ) @@ -78,7 +80,8 @@ def with_iceberg_source(processing_threads: int = 1, **kwargs: Any) -> IcebergSo def with_iceberg_profiler() -> IcebergProfiler: iceberg_source_instance = with_iceberg_source() return IcebergProfiler( - iceberg_source_instance.report, iceberg_source_instance.config.profiling + iceberg_source_instance.report, + iceberg_source_instance.config.profiling, ) @@ -155,7 +158,7 @@ def test_config_support_nested_dicts(): "nested_array": ["a1", "a2"], "subnested_dict": {"subnested_key": "subnested_value"}, }, - } + }, } test_config = IcebergSourceConfig(catalog=catalog) assert isinstance(test_config.catalog["test"]["nested_dict"], Dict) @@ -163,7 +166,8 @@ def test_config_support_nested_dicts(): assert isinstance(test_config.catalog["test"]["nested_dict"]["nested_array"], List) assert test_config.catalog["test"]["nested_dict"]["nested_array"][0] == "a1" assert isinstance( - test_config.catalog["test"]["nested_dict"]["subnested_dict"], Dict + test_config.catalog["test"]["nested_dict"]["subnested_dict"], + Dict, ) assert ( test_config.catalog["test"]["nested_dict"]["subnested_dict"]["subnested_key"] @@ -203,7 +207,8 @@ def test_config_support_nested_dicts(): ], ) def test_iceberg_primitive_type_to_schema_field( - iceberg_type: PrimitiveType, expected_schema_field_type: Any + iceberg_type: PrimitiveType, + expected_schema_field_type: Any, ) -> None: """ Test converting a primitive typed Iceberg field to a SchemaField @@ -211,10 +216,18 @@ def test_iceberg_primitive_type_to_schema_field( iceberg_source_instance = with_iceberg_source() for column in [ NestedField( - 1, "required_field", iceberg_type, True, "required field documentation" + 1, + "required_field", + iceberg_type, + True, + "required field documentation", ), NestedField( - 1, "optional_field", iceberg_type, False, "optional field documentation" + 1, + "optional_field", + 
iceberg_type, + False, + "optional field documentation", ), ]: schema = Schema(column) @@ -262,7 +275,8 @@ def test_iceberg_primitive_type_to_schema_field( ], ) def test_iceberg_list_to_schema_field( - iceberg_type: PrimitiveType, expected_array_nested_type: Any + iceberg_type: PrimitiveType, + expected_array_nested_type: Any, ) -> None: """ Test converting a list typed Iceberg field to an ArrayType SchemaField, including the list nested type. @@ -304,7 +318,10 @@ def test_iceberg_list_to_schema_field( f"Expected 1 field, but got {len(schema_fields)}" ) assert_field( - schema_fields[0], list_column.doc, list_column.optional, ArrayTypeClass + schema_fields[0], + list_column.doc, + list_column.optional, + ArrayTypeClass, ) assert isinstance(schema_fields[0].type.type, ArrayType), ( f"Field type {schema_fields[0].type.type} was expected to be {ArrayType}" @@ -347,7 +364,8 @@ def test_iceberg_list_to_schema_field( ], ) def test_iceberg_map_to_schema_field( - iceberg_type: PrimitiveType, expected_map_type: Any + iceberg_type: PrimitiveType, + expected_map_type: Any, ) -> None: """ Test converting a map typed Iceberg field to a MapType SchemaField, where the key is the same type as the value. @@ -391,7 +409,10 @@ def test_iceberg_map_to_schema_field( f"Expected 3 fields, but got {len(schema_fields)}" ) assert_field( - schema_fields[0], map_column.doc, map_column.optional, ArrayTypeClass + schema_fields[0], + map_column.doc, + map_column.optional, + ArrayTypeClass, ) # The second field will be the key type @@ -438,24 +459,35 @@ def test_iceberg_map_to_schema_field( ], ) def test_iceberg_struct_to_schema_field( - iceberg_type: PrimitiveType, expected_schema_field_type: Any + iceberg_type: PrimitiveType, + expected_schema_field_type: Any, ) -> None: """ Test converting a struct typed Iceberg field to a RecordType SchemaField. 
""" field1 = NestedField(11, "field1", iceberg_type, True, "field documentation") struct_column = NestedField( - 1, "structField", StructType(field1), True, "struct documentation" + 1, + "structField", + StructType(field1), + True, + "struct documentation", ) iceberg_source_instance = with_iceberg_source() schema = Schema(struct_column) schema_fields = iceberg_source_instance._get_schema_fields_for_schema(schema) assert len(schema_fields) == 2, f"Expected 2 fields, but got {len(schema_fields)}" assert_field( - schema_fields[0], struct_column.doc, struct_column.optional, RecordTypeClass + schema_fields[0], + struct_column.doc, + struct_column.optional, + RecordTypeClass, ) assert_field( - schema_fields[1], field1.doc, field1.optional, expected_schema_field_type + schema_fields[1], + field1.doc, + field1.optional, + expected_schema_field_type, ) @@ -491,7 +523,9 @@ def test_iceberg_struct_to_schema_field( ], ) def test_iceberg_profiler_value_render( - value_type: IcebergType, value: Any, expected_value: Optional[str] + value_type: IcebergType, + value: Any, + expected_value: Optional[str], ) -> None: iceberg_profiler_instance = with_iceberg_profiler() assert ( @@ -511,7 +545,7 @@ def test_avro_decimal_bytes_nullable() -> None: decimal_avro_schema = avro.schema.parse(decimal_avro_schema_string) print("\nDecimal (bytes)") print( - f"Original avro schema string: {decimal_avro_schema_string}" + f"Original avro schema string: {decimal_avro_schema_string}", ) print(f"After avro parsing, _nullable attribute is missing: {decimal_avro_schema}") @@ -519,20 +553,20 @@ def test_avro_decimal_bytes_nullable() -> None: decimal_fixed_avro_schema = avro.schema.parse(decimal_fixed_avro_schema_string) print("\nDecimal (fixed)") print( - f"Original avro schema string: {decimal_fixed_avro_schema_string}" + f"Original avro schema string: {decimal_fixed_avro_schema_string}", ) print( - f"After avro parsing, _nullable attribute is preserved: {decimal_fixed_avro_schema}" + f"After avro parsing, _nullable attribute is preserved: {decimal_fixed_avro_schema}", ) boolean_avro_schema_string = """{"type": "record", "name": "__struct_", "fields": [{"type": {"type": "boolean", "native_data_type": "boolean", "_nullable": false}, "name": "required_field", "doc": "required field documentation"}]}""" boolean_avro_schema = avro.schema.parse(boolean_avro_schema_string) print("\nBoolean") print( - f"Original avro schema string: {boolean_avro_schema_string}" + f"Original avro schema string: {boolean_avro_schema_string}", ) print( - f"After avro parsing, _nullable attribute is preserved: {boolean_avro_schema}" + f"After avro parsing, _nullable attribute is preserved: {boolean_avro_schema}", ) @@ -573,7 +607,7 @@ def test_exception_while_listing_namespaces() -> None: source = with_iceberg_source(processing_threads=2) mock_catalog = MockCatalogExceptionListingNamespaces({}) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog, pytest.raises(Exception): get_catalog.return_value = mock_catalog [*source.get_workunits_internal()] @@ -595,7 +629,7 @@ def test_known_exception_while_listing_tables() -> None: metadata_location="s3://abcdefg/namespaceA/table1", io=PyArrowFileIO(), catalog=None, - ) + ), }, "no_such_namespace": {}, "namespaceB": { @@ -636,7 +670,7 @@ def test_known_exception_while_listing_tables() -> None: metadata_location="s3://abcdefg/namespaceC/table4", io=PyArrowFileIO(), catalog=None, - ) + ), 
}, "namespaceD": { "table5": lambda: Table( @@ -650,12 +684,12 @@ def test_known_exception_while_listing_tables() -> None: metadata_location="s3://abcdefg/namespaceA/table5", io=PyArrowFileIO(), catalog=None, - ) + ), }, - } + }, ) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog: get_catalog.return_value = mock_catalog wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] @@ -696,7 +730,7 @@ def test_unknown_exception_while_listing_tables() -> None: metadata_location="s3://abcdefg/namespaceA/table1", io=PyArrowFileIO(), catalog=None, - ) + ), }, "generic_exception": {}, "namespaceB": { @@ -737,7 +771,7 @@ def test_unknown_exception_while_listing_tables() -> None: metadata_location="s3://abcdefg/namespaceC/table4", io=PyArrowFileIO(), catalog=None, - ) + ), }, "namespaceD": { "table5": lambda: Table( @@ -751,12 +785,12 @@ def test_unknown_exception_while_listing_tables() -> None: metadata_location="s3://abcdefg/namespaceA/table5", io=PyArrowFileIO(), catalog=None, - ) + ), }, - } + }, ) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog: get_catalog.return_value = mock_catalog wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] @@ -797,13 +831,13 @@ def test_proper_run_with_multiple_namespaces() -> None: metadata_location="s3://abcdefg/namespaceA/table1", io=PyArrowFileIO(), catalog=None, - ) + ), }, "namespaceB": {}, - } + }, ) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog: get_catalog.return_value = mock_catalog wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] @@ -927,10 +961,10 @@ def test_filtering() -> None: catalog=None, ), }, - } + }, ) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog: get_catalog.return_value = mock_catalog wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] @@ -1024,11 +1058,11 @@ def _raise_server_error(): "table7": _raise_file_not_found_error, "table8": _raise_no_such_iceberg_table_exception, "table9": _raise_server_error, - } - } + }, + }, ) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog: get_catalog.return_value = mock_catalog wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] @@ -1110,11 +1144,11 @@ def _raise_exception(): catalog=None, ), "table5": _raise_exception, - } - } + }, + }, ) with patch( - "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog", ) as get_catalog: get_catalog.return_value = mock_catalog wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] diff --git a/metadata-ingestion/tests/unit/test_kafka_sink.py b/metadata-ingestion/tests/unit/test_kafka_sink.py index 9f4062bf93bf82..37e8acebcf623a 100644 --- a/metadata-ingestion/tests/unit/test_kafka_sink.py +++ b/metadata-ingestion/tests/unit/test_kafka_sink.py @@ -20,7 +20,8 @@ class KafkaSinkTest(unittest.TestCase): 
@patch("datahub.emitter.kafka_emitter.SerializingProducer", autospec=True) def test_kafka_sink_config(self, mock_producer, mock_context): kafka_sink = DatahubKafkaSink.create( - {"connection": {"bootstrap": "foobar:9092"}}, mock_context + {"connection": {"bootstrap": "foobar:9092"}}, + mock_context, ) kafka_sink.close() assert ( @@ -51,7 +52,8 @@ def test_kafka_sink_mcp(self, mock_producer, mock_callback): PipelineContext(run_id="test"), ) kafka_sink.write_record_async( - RecordEnvelope(record=mcp, metadata={}), mock_callback + RecordEnvelope(record=mcp, metadata={}), + mock_callback, ) kafka_sink.close() assert mock_producer.call_count == 2 # constructor should be called @@ -63,7 +65,8 @@ def test_kafka_sink_write(self, mock_k_callback, mock_producer, mock_context): mock_k_callback_instance = mock_k_callback.return_value callback = MagicMock(spec=WriteCallback) kafka_sink = DatahubKafkaSink.create( - {"connection": {"bootstrap": "foobar:9092"}}, mock_context + {"connection": {"bootstrap": "foobar:9092"}}, + mock_context, ) mock_producer_instance = kafka_sink.emitter.producers[MCE_KEY] @@ -86,7 +89,9 @@ def test_kafka_sink_write(self, mock_k_callback, mock_producer, mock_context): mock_producer_instance.poll.assert_called_once() # producer should call poll() first self.validate_kafka_callback( - mock_k_callback, re, callback + mock_k_callback, + re, + callback, ) # validate kafka callback was constructed appropriately # validate that confluent_kafka.Producer.produce was called with the right arguments @@ -111,7 +116,9 @@ def test_kafka_sink_close(self, mock_producer, mock_context): @patch("datahub.ingestion.sink.datahub_kafka.WriteCallback", autospec=True) def test_kafka_callback_class(self, mock_w_callback, mock_re): callback = _KafkaCallback( - SinkReport(), record_envelope=mock_re, write_callback=mock_w_callback + SinkReport(), + record_envelope=mock_re, + write_callback=mock_w_callback, ) mock_error = MagicMock() mock_message = MagicMock() diff --git a/metadata-ingestion/tests/unit/test_kafka_source.py b/metadata-ingestion/tests/unit/test_kafka_source.py index 1a8afe1b956fae..9e27640c2cb96b 100644 --- a/metadata-ingestion/tests/unit/test_kafka_source.py +++ b/metadata-ingestion/tests/unit/test_kafka_source.py @@ -40,7 +40,8 @@ @pytest.fixture def mock_admin_client(): with patch( - "datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True + "datahub.ingestion.source.kafka.kafka.AdminClient", + autospec=True, ) as mock: yield mock @@ -152,7 +153,8 @@ def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_cl ] assert len(data_platform_aspects) == 1 assert data_platform_aspects[0].instance == make_dataplatform_instance_urn( - PLATFORM, PLATFORM_INSTANCE + PLATFORM, + PLATFORM_INSTANCE, ) # The default browse path should include the platform_instance value @@ -228,7 +230,9 @@ def test_close(mock_kafka, mock_admin_client): ) @patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True) def test_kafka_source_workunits_schema_registry_subject_name_strategies( - mock_kafka_consumer, mock_schema_registry_client, mock_admin_client + mock_kafka_consumer, + mock_schema_registry_client, + mock_admin_client, ): # Setup the topic to key/value schema mappings for all types of schema registry subject name strategies. 
# ,) @@ -489,7 +493,8 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]: @patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True) @patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True) def test_kafka_source_succeeds_with_admin_client_init_error( - mock_kafka, mock_kafka_admin_client + mock_kafka, + mock_kafka_admin_client, ): mock_kafka_instance = mock_kafka.return_value mock_cluster_metadata = MagicMock() @@ -519,7 +524,8 @@ def test_kafka_source_succeeds_with_admin_client_init_error( @patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True) @patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True) def test_kafka_source_succeeds_with_describe_configs_error( - mock_kafka, mock_kafka_admin_client + mock_kafka, + mock_kafka_admin_client, ): mock_kafka_instance = mock_kafka.return_value mock_cluster_metadata = MagicMock() @@ -555,7 +561,9 @@ def test_kafka_source_succeeds_with_describe_configs_error( ) @patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True) def test_kafka_source_topic_meta_mappings( - mock_kafka_consumer, mock_schema_registry_client, mock_admin_client + mock_kafka_consumer, + mock_schema_registry_client, + mock_admin_client, ): # Setup the topic to key/value schema mappings for all types of schema registry subject name strategies. # ,) @@ -585,14 +593,14 @@ def test_kafka_source_topic_meta_mappings( "has_pii": True, "int_property": 1, "double_property": 2.5, - } + }, ), schema_type="AVRO", ), subject="topic1-value", version=1, ), - ) + ), } # Mock the kafka consumer @@ -680,7 +688,7 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]: asp for asp in mce.proposedSnapshot.aspects if isinstance(asp, GlobalTagsClass) ][0] assert tags_aspect == make_global_tag_aspect_with_tag_list( - ["has_pii_test", "int_meta_property"] + ["has_pii_test", "int_meta_property"], ) terms_aspect = [ @@ -692,7 +700,7 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]: [ "urn:li:glossaryTerm:Finance_test", "urn:li:glossaryTerm:double_meta_property", - ] + ], ) assert isinstance(workunits[1].metadata, MetadataChangeProposalWrapper) mce = workunits[2].metadata @@ -716,7 +724,7 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]: asp for asp in mce.proposedSnapshot.aspects if isinstance(asp, GlobalTagsClass) ][0] assert tags_aspect == make_global_tag_aspect_with_tag_list( - ["has_pii_test", "int_meta_property"] + ["has_pii_test", "int_meta_property"], ) terms_aspect = [ @@ -728,7 +736,7 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]: [ "urn:li:glossaryTerm:Finance_test", "urn:li:glossaryTerm:double_meta_property", - ] + ], ) assert isinstance(workunits[5].metadata, MetadataChangeProposalWrapper) @@ -755,8 +763,8 @@ def test_kafka_source_oauth_cb_configuration(): "connection": { "bootstrap": "foobar:9092", "consumer_config": { - "oauth_cb": test_kafka_ignore_warnings_on_schema_type + "oauth_cb": test_kafka_ignore_warnings_on_schema_type, }, - } - } + }, + }, ) diff --git a/metadata-ingestion/tests/unit/test_ldap_source.py b/metadata-ingestion/tests/unit/test_ldap_source.py index cc2bd605c547d0..d141d1a0f5d105 100644 --- a/metadata-ingestion/tests/unit/test_ldap_source.py +++ b/metadata-ingestion/tests/unit/test_ldap_source.py @@ -32,7 +32,7 @@ def test_parse_ldap_dn(input, expected): "admins": [ 
b"uid=A.B,ou=People,dc=internal,dc=machines", b"uid=C.D,ou=People,dc=internal,dc=machines", - ] + ], }, ["urn:li:corpuser:A.B", "urn:li:corpuser:C.D"], ), @@ -40,7 +40,7 @@ def test_parse_ldap_dn(input, expected): { "not_admins": [ b"doesntmatter", - ] + ], }, [], ), @@ -64,7 +64,7 @@ def test_parse_users(input, expected): "memberOf": [ b"cn=group1,ou=Groups,dc=internal,dc=machines", b"cn=group2,ou=Groups,dc=internal,dc=machines", - ] + ], }, ["urn:li:corpGroup:group1", "urn:li:corpGroup:group2"], ), @@ -72,7 +72,7 @@ def test_parse_users(input, expected): { "not_member": [ b"doesntmatter", - ] + ], }, [], ), diff --git a/metadata-ingestion/tests/unit/test_mapping.py b/metadata-ingestion/tests/unit/test_mapping.py index 75eacf9509a533..07af3ea20bf8ac 100644 --- a/metadata-ingestion/tests/unit/test_mapping.py +++ b/metadata-ingestion/tests/unit/test_mapping.py @@ -387,7 +387,7 @@ def test_operation_processor_datahub_props(): }, ], "domain": "domain1", - } + }, } processor = OperationProcessor( diff --git a/metadata-ingestion/tests/unit/test_mlflow_source.py b/metadata-ingestion/tests/unit/test_mlflow_source.py index e882296b6f331d..e0e2380ba91659 100644 --- a/metadata-ingestion/tests/unit/test_mlflow_source.py +++ b/metadata-ingestion/tests/unit/test_mlflow_source.py @@ -142,7 +142,7 @@ def test_make_external_link_remote_via_config(source, model_version): custom_base_url = "https://custom-server.org" source.config.base_external_url = custom_base_url source.client = MlflowClient( - tracking_uri="https://dummy-mlflow-tracking-server.org" + tracking_uri="https://dummy-mlflow-tracking-server.org", ) expected_url = f"{custom_base_url}/#/models/{model_version.name}/versions/{model_version.version}" diff --git a/metadata-ingestion/tests/unit/test_neo4j_source.py b/metadata-ingestion/tests/unit/test_neo4j_source.py index 62586718e86067..426684f97b674e 100644 --- a/metadata-ingestion/tests/unit/test_neo4j_source.py +++ b/metadata-ingestion/tests/unit/test_neo4j_source.py @@ -19,7 +19,10 @@ def source(tracking_uri: str) -> Neo4jSource: return Neo4jSource( ctx=PipelineContext(run_id="neo4j-test"), config=Neo4jConfig( - uri=tracking_uri, env="Prod", username="test", password="test" + uri=tracking_uri, + env="Prod", + username="test", + password="test", ), ) @@ -39,11 +42,11 @@ def data(): "type": "STRING", "indexed": False, "array": False, - } + }, }, "direction": "in", "labels": ["Node_2"], - } + }, }, "RELATIONSHIP_2": { "count": 2, @@ -53,7 +56,7 @@ def data(): "type": "STRING", "indexed": False, "array": False, - } + }, }, "direction": "in", "labels": ["Node_3"], @@ -95,11 +98,11 @@ def data(): "type": "STRING", "indexed": False, "array": False, - } + }, }, "direction": "out", "labels": ["Node_2"], - } + }, }, "type": "node", "properties": { @@ -136,7 +139,7 @@ def data(): "type": "STRING", "indexed": False, "array": False, - } + }, }, }, }, @@ -150,7 +153,8 @@ def test_process_nodes(source): def test_process_relationships(source): df = source.process_relationships( - data=data(), node_df=source.process_nodes(data=data()) + data=data(), + node_df=source.process_nodes(data=data()), ) assert type(df) is pd.DataFrame @@ -188,7 +192,7 @@ def test_get_property_data_types(source): {"Node2_Property3": "STRING"}, ] assert source.get_property_data_types(results[2]["value"]["properties"]) == [ - {"Relationship1_Property1": "STRING"} + {"Relationship1_Property1": "STRING"}, ] @@ -205,14 +209,14 @@ def test_get_properties(source): "Node2_Property3", ] assert 
list(source.get_properties(results[2]["value"]).keys()) == [ - "Relationship1_Property1" + "Relationship1_Property1", ] def test_get_relationships(source): results = data() record = list( - results[0]["value"]["relationships"].keys() + results[0]["value"]["relationships"].keys(), ) # Get the first key from the dict_keys assert record == ["RELATIONSHIP_1"] diff --git a/metadata-ingestion/tests/unit/test_nifi_source.py b/metadata-ingestion/tests/unit/test_nifi_source.py index 30a0855d44f341..bd423c5e5656d4 100644 --- a/metadata-ingestion/tests/unit/test_nifi_source.py +++ b/metadata-ingestion/tests/unit/test_nifi_source.py @@ -23,9 +23,9 @@ def test_nifi_s3_provenance_event(): ctx = PipelineContext(run_id="test") with patch( - "datahub.ingestion.source.nifi.NifiSource.fetch_provenance_events" + "datahub.ingestion.source.nifi.NifiSource.fetch_provenance_events", ) as mock_provenance_events, patch( - "datahub.ingestion.source.nifi.NifiSource.delete_provenance" + "datahub.ingestion.source.nifi.NifiSource.delete_provenance", ) as mock_delete_provenance: mocked_functions(mock_provenance_events, mock_delete_provenance, "puts3") @@ -53,7 +53,7 @@ def test_nifi_s3_provenance_event(): config={}, target_uris=None, last_event_time=None, - ) + ), }, remotely_accessible_ports={}, connections=BidirectionalComponentGraph(), @@ -89,14 +89,14 @@ def test_nifi_s3_provenance_event(): ioAspect = workunits[5].metadata.aspect assert ioAspect.outputDatasets == [ - "urn:li:dataset:(urn:li:dataPlatform:s3,foo-nifi/tropical_data,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:s3,foo-nifi/tropical_data,PROD)", ] assert ioAspect.inputDatasets == [] with patch( - "datahub.ingestion.source.nifi.NifiSource.fetch_provenance_events" + "datahub.ingestion.source.nifi.NifiSource.fetch_provenance_events", ) as mock_provenance_events, patch( - "datahub.ingestion.source.nifi.NifiSource.delete_provenance" + "datahub.ingestion.source.nifi.NifiSource.delete_provenance", ) as mock_delete_provenance: mocked_functions(mock_provenance_events, mock_delete_provenance, "fetchs3") @@ -124,7 +124,7 @@ def test_nifi_s3_provenance_event(): config={}, target_uris=None, last_event_time=None, - ) + ), }, remotely_accessible_ports={}, connections=BidirectionalComponentGraph(), @@ -161,7 +161,7 @@ def test_nifi_s3_provenance_event(): ioAspect = workunits[5].metadata.aspect assert ioAspect.outputDatasets == [] assert ioAspect.inputDatasets == [ - "urn:li:dataset:(urn:li:dataPlatform:s3,enriched-topical-chat,PROD)" + "urn:li:dataset:(urn:li:dataPlatform:s3,enriched-topical-chat,PROD)", ] @@ -282,39 +282,42 @@ def mocked_functions(mock_provenance_events, mock_delete_provenance, provenance_ @pytest.mark.parametrize("auth", ["SINGLE_USER", "BASIC_AUTH"]) def test_auth_without_password(auth): with pytest.raises( - ValueError, match=f"`username` and `password` is required for {auth} auth" + ValueError, + match=f"`username` and `password` is required for {auth} auth", ): NifiSourceConfig.parse_obj( { "site_url": "https://localhost:8443", "auth": auth, "username": "someuser", - } + }, ) @pytest.mark.parametrize("auth", ["SINGLE_USER", "BASIC_AUTH"]) def test_auth_without_username_and_password(auth): with pytest.raises( - ValueError, match=f"`username` and `password` is required for {auth} auth" + ValueError, + match=f"`username` and `password` is required for {auth} auth", ): NifiSourceConfig.parse_obj( { "site_url": "https://localhost:8443", "auth": auth, - } + }, ) def test_client_cert_auth_without_client_cert_file(): with pytest.raises( - ValueError, 
match="`client_cert_file` is required for CLIENT_CERT auth" + ValueError, + match="`client_cert_file` is required for CLIENT_CERT auth", ): NifiSourceConfig.parse_obj( { "site_url": "https://localhost:8443", "auth": "CLIENT_CERT", - } + }, ) diff --git a/metadata-ingestion/tests/unit/test_oracle_source.py b/metadata-ingestion/tests/unit/test_oracle_source.py index 0477044354576b..4bef9c7459cbdd 100644 --- a/metadata-ingestion/tests/unit/test_oracle_source.py +++ b/metadata-ingestion/tests/unit/test_oracle_source.py @@ -17,7 +17,7 @@ def test_oracle_config(): { **base_config, "service_name": "svc01", - } + }, ) assert ( config.get_sql_alchemy_url() @@ -30,11 +30,11 @@ def test_oracle_config(): **base_config, "database": "db", "service_name": "svc01", - } + }, ) with unittest.mock.patch( - "datahub.ingestion.source.sql.sql_common.SQLAlchemySource.get_workunits" + "datahub.ingestion.source.sql.sql_common.SQLAlchemySource.get_workunits", ): OracleSource.create( { diff --git a/metadata-ingestion/tests/unit/test_packaging.py b/metadata-ingestion/tests/unit/test_packaging.py index 4b99be750a4da7..c4bc4f47650be2 100644 --- a/metadata-ingestion/tests/unit/test_packaging.py +++ b/metadata-ingestion/tests/unit/test_packaging.py @@ -4,7 +4,7 @@ @pytest.mark.filterwarnings( - "ignore:pkg_resources is deprecated as an API:DeprecationWarning" + "ignore:pkg_resources is deprecated as an API:DeprecationWarning", ) def test_datahub_version(): # Simply importing pkg_resources checks for unsatisfied dependencies. diff --git a/metadata-ingestion/tests/unit/test_postgres_source.py b/metadata-ingestion/tests/unit/test_postgres_source.py index 25140cf1b997f8..cb8f54623ceeff 100644 --- a/metadata-ingestion/tests/unit/test_postgres_source.py +++ b/metadata-ingestion/tests/unit/test_postgres_source.py @@ -51,7 +51,7 @@ def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock): execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}] config = PostgresConfig.parse_obj( - {**_base_config(), "sqlalchemy_uri": "custom_url"} + {**_base_config(), "sqlalchemy_uri": "custom_url"}, ) source = PostgresSource(config, PipelineContext(run_id="test")) _ = list(source.get_inspectors()) @@ -64,7 +64,9 @@ def test_database_in_identifier(): mock_inspector = mock.MagicMock() assert ( PostgresSource(config, PipelineContext(run_id="test")).get_identifier( - schema="superset", entity="logs", inspector=mock_inspector + schema="superset", + entity="logs", + inspector=mock_inspector, ) == "postgres.superset.logs" ) @@ -76,7 +78,9 @@ def test_current_sqlalchemy_database_in_identifier(): mock_inspector.engine.url.database = "current_db" assert ( PostgresSource(config, PipelineContext(run_id="test")).get_identifier( - schema="superset", entity="logs", inspector=mock_inspector + schema="superset", + entity="logs", + inspector=mock_inspector, ) == "current_db.superset.logs" ) diff --git a/metadata-ingestion/tests/unit/test_powerbi_parser.py b/metadata-ingestion/tests/unit/test_powerbi_parser.py index a487a3a5b87f8b..b0064e359b964c 100644 --- a/metadata-ingestion/tests/unit/test_powerbi_parser.py +++ b/metadata-ingestion/tests/unit/test_powerbi_parser.py @@ -31,7 +31,7 @@ def creator(): reporter=PowerBiDashboardSourceReport(), config=config, platform_instance_resolver=ResolvePlatformInstanceFromDatasetTypeMapping( - config + config, ), ) diff --git a/metadata-ingestion/tests/unit/test_protobuf_util.py b/metadata-ingestion/tests/unit/test_protobuf_util.py index 86418d2d97a59e..c93a5a791b8687 100644 --- 
a/metadata-ingestion/tests/unit/test_protobuf_util.py +++ b/metadata-ingestion/tests/unit/test_protobuf_util.py @@ -17,7 +17,7 @@ def test_protobuf_schema_to_mce_fields_with_single_empty_message() -> None: } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_1.proto", schema) + ProtobufSchema("main_1.proto", schema), ) assert 0 == len(fields) @@ -34,7 +34,8 @@ def test_protobuf_schema_to_mce_fields_with_single_message_single_field_key_sche } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_2.proto", schema), is_key_schema=True + ProtobufSchema("main_2.proto", schema), + is_key_schema=True, ) assert 1 == len(fields) assert ( @@ -64,7 +65,7 @@ def test_protobuf_schema_to_mce_fields_with_two_messages_enum() -> None: } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_3.proto", schema) + ProtobufSchema("main_3.proto", schema), ) assert 5 == len(fields) @@ -98,7 +99,7 @@ def test_protobuf_schema_to_mce_fields_nested(): } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_4.proto", schema) + ProtobufSchema("main_4.proto", schema), ) assert 4 == len(fields) @@ -127,7 +128,7 @@ def test_protobuf_schema_to_mce_fields_repeated() -> None: } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_5.proto", schema) + ProtobufSchema("main_5.proto", schema), ) assert 1 == len(fields) @@ -154,7 +155,7 @@ def test_protobuf_schema_to_mce_fields_nestd_repeated() -> None: } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_6.proto", schema) + ProtobufSchema("main_6.proto", schema), ) assert 2 == len(fields) @@ -189,7 +190,7 @@ def test_protobuf_schema_to_mce_fields_map() -> None: } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_7.proto", schema) + ProtobufSchema("main_7.proto", schema), ) assert 7 == len(fields) @@ -253,7 +254,7 @@ def test_protobuf_schema_to_mce_fields_with_complex_schema() -> None: } """ fields: List[SchemaFieldClass] = protobuf_schema_to_mce_fields( - ProtobufSchema("main_8.proto", schema) + ProtobufSchema("main_8.proto", schema), ) assert 10 == len(fields) diff --git a/metadata-ingestion/tests/unit/test_pulsar_source.py b/metadata-ingestion/tests/unit/test_pulsar_source.py index 7e83030f5c8404..add7a41b1fd6e7 100644 --- a/metadata-ingestion/tests/unit/test_pulsar_source.py +++ b/metadata-ingestion/tests/unit/test_pulsar_source.py @@ -29,17 +29,18 @@ class TestPulsarSourceConfig: def test_pulsar_source_config_valid_web_service_url(self): assert ( PulsarSourceConfig().web_service_url_scheme_host_port( - "http://localhost:8080/" + "http://localhost:8080/", ) == "http://localhost:8080" ) def test_pulsar_source_config_invalid_web_service_url_scheme(self): with pytest.raises( - ValueError, match=r"Scheme should be http or https, found ftp" + ValueError, + match=r"Scheme should be http or https, found ftp", ): PulsarSourceConfig().web_service_url_scheme_host_port( - "ftp://localhost:8080/" + "ftp://localhost:8080/", ) def test_pulsar_source_config_invalid_web_service_url_host(self): @@ -48,7 +49,7 @@ def test_pulsar_source_config_invalid_web_service_url_host(self): match=r"Not a valid hostname, hostname contains invalid characters, found localhost&", ): PulsarSourceConfig().web_service_url_scheme_host_port( - "http://localhost&:8080/" + "http://localhost&:8080/", ) @@ -97,7 +98,7 @@ def 
test_pulsar_source_get_token_jwt(self): def test_pulsar_source_get_token_oauth(self, mock_post, mock_get): ctx = PipelineContext(run_id="test") mock_get.return_value.json.return_value = { - "token_endpoint": "http://127.0.0.1:8083/realms/pulsar/protocol/openid-connect/token" + "token_endpoint": "http://127.0.0.1:8083/realms/pulsar/protocol/openid-connect/token", } pulsar_source = PulsarSource.create( @@ -124,7 +125,7 @@ def test_pulsar_source_get_workunits_all_tenant(self, mock_session): # Mock fetching Pulsar metadata with patch( - "datahub.ingestion.source.pulsar.PulsarSource._get_pulsar_metadata" + "datahub.ingestion.source.pulsar.PulsarSource._get_pulsar_metadata", ) as mock: mock.side_effect = [ ["t_1"], # tenant list @@ -165,7 +166,7 @@ def test_pulsar_source_get_workunits_custom_tenant(self, mock_session): # Mock fetching Pulsar metadata with patch( - "datahub.ingestion.source.pulsar.PulsarSource._get_pulsar_metadata" + "datahub.ingestion.source.pulsar.PulsarSource._get_pulsar_metadata", ) as mock: mock.side_effect = [ ["t_1/ns_1"], # namespaces list @@ -209,7 +210,7 @@ def test_pulsar_source_get_workunits_patterns(self, mock_session): # Mock fetching Pulsar metadata with patch( - "datahub.ingestion.source.pulsar.PulsarSource._get_pulsar_metadata" + "datahub.ingestion.source.pulsar.PulsarSource._get_pulsar_metadata", ) as mock: mock.side_effect = [ ["t_1/ns_1", "t_2/ns_1"], # namespaces list diff --git a/metadata-ingestion/tests/unit/test_redash_source.py b/metadata-ingestion/tests/unit/test_redash_source.py index 32ab200847dc6c..3de7c07d70e44b 100644 --- a/metadata-ingestion/tests/unit/test_redash_source.py +++ b/metadata-ingestion/tests/unit/test_redash_source.py @@ -118,7 +118,7 @@ "percentFormat": "0[.]00%", "sortX": True, "seriesOptions": { - "value": {"zIndex": 0, "index": 0, "type": "pie", "yAxis": 0} + "value": {"zIndex": 0, "index": 0, "type": "pie", "yAxis": 0}, }, "valuesOptions": {"Yes": {}, "No": {}}, "xAxis": {"labels": {"enabled": True}, "type": "-"}, @@ -447,7 +447,7 @@ "percentFormat": "0[.]00%", "sortX": True, "seriesOptions": { - "value": {"zIndex": 0, "index": 0, "type": "pie", "yAxis": 0} + "value": {"zIndex": 0, "index": 0, "type": "pie", "yAxis": 0}, }, "valuesOptions": {"Yes": {}, "No": {}}, "xAxis": {"labels": {"enabled": True}, "type": "-"}, @@ -491,16 +491,18 @@ def test_get_dashboard_snapshot_before_v10(): lastModified=ChangeAuditStamps( created=None, lastModified=AuditStamp( - time=1628882055288, actor="urn:li:corpuser:unknown" + time=1628882055288, + actor="urn:li:corpuser:unknown", ), ), dashboardUrl="http://localhost:5000/dashboard/my-dashboard", customProperties={}, - ) + ), ], ) result = redash_source()._get_dashboard_snapshot( - mock_dashboard_response, "9.0.0-beta" + mock_dashboard_response, + "9.0.0-beta", ) assert result == expected @@ -521,16 +523,18 @@ def test_get_dashboard_snapshot_after_v10(): lastModified=ChangeAuditStamps( created=None, lastModified=AuditStamp( - time=1628882055288, actor="urn:li:corpuser:unknown" + time=1628882055288, + actor="urn:li:corpuser:unknown", ), ), dashboardUrl="http://localhost:5000/dashboards/3", customProperties={}, - ) + ), ], ) result = redash_source()._get_dashboard_snapshot( - mock_dashboard_response, "10.0.0-beta" + mock_dashboard_response, + "10.0.0-beta", ) assert result == expected @@ -549,13 +553,14 @@ def test_get_known_viz_chart_snapshot(mocked_data_source): lastModified=ChangeAuditStamps( created=None, lastModified=AuditStamp( - time=1628882022544, actor="urn:li:corpuser:unknown" + 
time=1628882022544, + actor="urn:li:corpuser:unknown", ), ), chartUrl="http://localhost:5000/queries/4#10", inputs=["urn:li:dataset:(urn:li:dataPlatform:mysql,Rfam,PROD)"], type="PIE", - ) + ), ], ) viz_data = mock_chart_response.get("visualizations", [])[2] @@ -580,13 +585,14 @@ def test_get_unknown_viz_chart_snapshot(mocked_data_source): lastModified=ChangeAuditStamps( created=None, lastModified=AuditStamp( - time=1628882009571, actor="urn:li:corpuser:unknown" + time=1628882009571, + actor="urn:li:corpuser:unknown", ), ), chartUrl="http://localhost:5000/queries/4#9", inputs=["urn:li:dataset:(urn:li:dataPlatform:mysql,Rfam,PROD)"], type="TABLE", - ) + ), ], ) viz_data = mock_chart_response.get("visualizations", [])[1] @@ -673,8 +679,10 @@ def test_get_full_qualified_name(): result.append( get_full_qualified_name( - platform=platform, database_name=database_name, table_name=table_name - ) + platform=platform, + database_name=database_name, + table_name=table_name, + ), ) assert expected == result @@ -705,7 +713,8 @@ def test_get_chart_snapshot_parse_table_names_from_sql(mocked_data_source): lastModified=ChangeAuditStamps( created=None, lastModified=AuditStamp( - time=1628882022544, actor="urn:li:corpuser:unknown" + time=1628882022544, + actor="urn:li:corpuser:unknown", ), ), chartUrl="http://localhost:5000/queries/4#10", @@ -715,12 +724,13 @@ def test_get_chart_snapshot_parse_table_names_from_sql(mocked_data_source): "urn:li:dataset:(urn:li:dataPlatform:mysql,rfam.staffs,PROD)", ], type="PIE", - ) + ), ], ) viz_data = mock_chart_response.get("visualizations", [])[2] result = redash_source_parse_table_names_from_sql()._get_chart_snapshot( - mock_chart_response, viz_data + mock_chart_response, + viz_data, ) assert result == expected diff --git a/metadata-ingestion/tests/unit/test_rest_sink.py b/metadata-ingestion/tests/unit/test_rest_sink.py index 564cf613c04464..c10ab2505b9c92 100644 --- a/metadata-ingestion/tests/unit/test_rest_sink.py +++ b/metadata-ingestion/tests/unit/test_rest_sink.py @@ -40,8 +40,8 @@ dataset="urn:li:dataset:(urn:li:dataPlatform:bigquery,upstream2,PROD)", type="TRANSFORMED", ), - ] - ) + ], + ), ], ), ), @@ -71,12 +71,12 @@ "dataset": "urn:li:dataset:(urn:li:dataPlatform:bigquery,upstream2,PROD)", "type": "TRANSFORMED", }, - ] - } - } + ], + }, + }, ], - } - } + }, + }, }, "systemMetadata": { "lastObserved": FROZEN_TIME, @@ -105,7 +105,7 @@ type=models.ChartTypeClass.SCATTER, ), ], - ) + ), ), "/entities?action=ingest", { @@ -130,11 +130,11 @@ }, }, "type": "SCATTER", - } - } + }, + }, ], - } - } + }, + }, }, "systemMetadata": { "lastObserved": FROZEN_TIME, @@ -157,9 +157,9 @@ name="User Deletions", description="Constructs the fct_users_deleted from logging_events", type=models.AzkabanJobTypeClass.SQL, - ) + ), ], - ) + ), ), "/entities?action=ingest", { @@ -174,11 +174,11 @@ "name": "User Deletions", "description": "Constructs the fct_users_deleted from logging_events", "type": {"string": "SQL"}, - } - } + }, + }, ], - } - } + }, + }, }, "systemMetadata": { "lastObserved": FROZEN_TIME, @@ -234,8 +234,8 @@ "totalSqlQueries": 1, "topSqlQueries": ["SELECT * FROM foo"], }, - } - ] + }, + ], }, ), ( @@ -246,7 +246,7 @@ models.OwnerClass( owner="urn:li:corpuser:fbar", type=models.OwnershipTypeClass.DATAOWNER, - ) + ), ], lastModified=models.AuditStampClass( time=0, @@ -274,7 +274,7 @@ }, "runId": "no-run-id-provided", }, - } + }, }, ), ], diff --git a/metadata-ingestion/tests/unit/test_schema_util.py b/metadata-ingestion/tests/unit/test_schema_util.py index 
0a111d700cf8ce..937eb0153b9ffa 100644 --- a/metadata-ingestion/tests/unit/test_schema_util.py +++ b/metadata-ingestion/tests/unit/test_schema_util.py @@ -85,7 +85,7 @@ }, }, ], - } + }, ) @@ -103,7 +103,8 @@ def assert_field_paths_are_unique(fields: List[SchemaField]) -> None: def assert_field_paths_match( - fields: List[SchemaField], expected_field_paths: List[str] + fields: List[SchemaField], + expected_field_paths: List[str], ) -> None: log_field_paths(fields) assert len(fields) == len(expected_field_paths) @@ -640,7 +641,7 @@ def test_mce_avro_parses_okay(): "datahub", "metadata", "schema.avsc", - ) + ), ).read_text() fields = avro_schema_to_mce_fields(schema) assert len(fields) @@ -755,11 +756,12 @@ def test_logical_types_fully_specified_in_type(): "native_data_type": "decimal(3, 2)", "_nullable": True, }, - } + }, ], } fields: List[SchemaField] = avro_schema_to_mce_fields( - json.dumps(schema), default_nullable=True + json.dumps(schema), + default_nullable=True, ) assert len(fields) == 1 assert "[version=2.0].[type=test].[type=bytes].name" == fields[0].fieldPath @@ -852,7 +854,7 @@ def test_avro_schema_to_mce_fields_with_field_meta_mapping(): "operation": "add_term", "config": {"term": "{{ $match }}"}, }, - } + }, ) fields = avro_schema_to_mce_fields(schema, meta_mapping_processor=processor) expected_field_paths = [ @@ -874,9 +876,9 @@ def test_avro_schema_to_mce_fields_with_field_meta_mapping(): assert fields[2].globalTags == pii_tag_aspect assert fields[3].globalTags == pii_tag_aspect assert fields[3].glossaryTerms == make_glossary_terms_aspect_from_urn_list( - ["urn:li:glossaryTerm:PhoneNumber"] + ["urn:li:glossaryTerm:PhoneNumber"], ) assert fields[8].globalTags == pii_tag_aspect assert fields[8].glossaryTerms == make_glossary_terms_aspect_from_urn_list( - ["urn:li:glossaryTerm:Address"] + ["urn:li:glossaryTerm:Address"], ) diff --git a/metadata-ingestion/tests/unit/test_source.py b/metadata-ingestion/tests/unit/test_source.py index d2ed21fccb4cb9..67f8137957c3e6 100644 --- a/metadata-ingestion/tests/unit/test_source.py +++ b/metadata-ingestion/tests/unit/test_source.py @@ -19,11 +19,11 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: platform_id="elasticsearch", table_name="fooIndex", env="PROD", - ) + ), ), aspect=StatusClass(removed=False), ), - ) + ), ] def __init__(self, ctx: PipelineContext): diff --git a/metadata-ingestion/tests/unit/test_sql_common.py b/metadata-ingestion/tests/unit/test_sql_common.py index cfb8f55bd977f7..b86bcdfed2548e 100644 --- a/metadata-ingestion/tests/unit/test_sql_common.py +++ b/metadata-ingestion/tests/unit/test_sql_common.py @@ -24,7 +24,8 @@ def create(cls, config_dict, ctx): def get_test_sql_alchemy_source(): return _TestSQLAlchemySource.create( - config_dict={}, ctx=PipelineContext(run_id="test_ctx") + config_dict={}, + ctx=PipelineContext(run_id="test_ctx"), ) @@ -46,10 +47,10 @@ def test_generate_foreign_key(): assert fk_dict.get("name") == foreign_key.name assert [ - "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_referred_schema.test_table,PROD),test_referred_column)" + "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_referred_schema.test_table,PROD),test_referred_column)", ] == foreign_key.foreignFields assert [ - "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.base_urn,PROD),test_column)" + "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.base_urn,PROD),test_column)", ] == foreign_key.sourceFields @@ -70,10 +71,10 @@ def 
test_use_source_schema_for_foreign_key_if_not_specified(): assert fk_dict.get("name") == foreign_key.name assert [ - "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.test_table,PROD),test_referred_column)" + "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.test_table,PROD),test_referred_column)", ] == foreign_key.foreignFields assert [ - "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.base_urn,PROD),test_column)" + "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:TEST,test_schema.base_urn,PROD),test_column)", ] == foreign_key.sourceFields @@ -111,7 +112,7 @@ def test_get_platform_from_sqlalchemy_uri(uri: str, expected_platform: str) -> N def test_get_db_schema_with_dots_in_view_name(): source = get_test_sql_alchemy_source() database, schema = source.get_db_schema( - dataset_identifier="database.schema.long.view.name1" + dataset_identifier="database.schema.long.view.name1", ) assert database == "database" assert schema == "schema" diff --git a/metadata-ingestion/tests/unit/test_sql_utils.py b/metadata-ingestion/tests/unit/test_sql_utils.py index 1b7dc6bcf23f1e..63964f0187ef2f 100644 --- a/metadata-ingestion/tests/unit/test_sql_utils.py +++ b/metadata-ingestion/tests/unit/test_sql_utils.py @@ -44,7 +44,9 @@ def test_guid_generators(): test_profile_pattern_matching_on_table_allow_list_test_data, ) def test_profile_pattern_matching_on_table_allow_list( - allow_pattern: str, table_name: str, result: bool + allow_pattern: str, + table_name: str, + result: bool, ) -> None: pattern = AllowDenyPattern(allow=[allow_pattern]) assert check_table_with_profile_pattern(pattern, table_name) == result @@ -73,7 +75,9 @@ def test_profile_pattern_matching_on_table_allow_list( test_profile_pattern_matching_on_table_deny_list_test_data, ) def test_profile_pattern_matching_on_table_deny_list( - deny_pattern: str, table_name: str, result: bool + deny_pattern: str, + table_name: str, + result: bool, ) -> None: pattern = AllowDenyPattern(deny=[deny_pattern]) assert check_table_with_profile_pattern(pattern, table_name) == result diff --git a/metadata-ingestion/tests/unit/test_tableau_source.py b/metadata-ingestion/tests/unit/test_tableau_source.py index ba5e5a9832a62f..ec5ad389b2c046 100644 --- a/metadata-ingestion/tests/unit/test_tableau_source.py +++ b/metadata-ingestion/tests/unit/test_tableau_source.py @@ -35,14 +35,15 @@ def test_tablea_source_handles_none_nativedatatype(): "formula": "a/b + d", } schema_field: SchemaField = tableau_field_to_schema_field( - field=field, ingest_tags=False + field=field, + ingest_tags=False, ) assert schema_field.nativeDataType == "UNKNOWN" def test_tableau_source_unescapes_lt(): res = TableauSiteSource._clean_tableau_query_parameters( - "select * from t where c1 << 135" + "select * from t where c1 << 135", ) assert res == "select * from t where c1 < 135" @@ -50,7 +51,7 @@ def test_tableau_source_unescapes_lt(): def test_tableau_source_unescapes_gt(): res = TableauSiteSource._clean_tableau_query_parameters( - "select * from t where c1 >> 135" + "select * from t where c1 >> 135", ) assert res == "select * from t where c1 > 135" @@ -58,7 +59,7 @@ def test_tableau_source_unescapes_gt(): def test_tableau_source_unescapes_gte(): res = TableauSiteSource._clean_tableau_query_parameters( - "select * from t where c1 >>= 135" + "select * from t where c1 >>= 135", ) assert res == "select * from t where c1 >= 135" @@ -66,7 +67,7 @@ def test_tableau_source_unescapes_gte(): def 
test_tableau_source_unescapeslgte(): res = TableauSiteSource._clean_tableau_query_parameters( - "select * from t where c1 <<= 135" + "select * from t where c1 <<= 135", ) assert res == "select * from t where c1 <= 135" @@ -74,7 +75,7 @@ def test_tableau_source_unescapeslgte(): def test_tableau_source_doesnt_touch_not_escaped(): res = TableauSiteSource._clean_tableau_query_parameters( - "select * from t where c1 < 135 and c2 > 15" + "select * from t where c1 < 135 and c2 > 15", ) assert res == "select * from t where c1 < 135 and c2 > 15" @@ -106,7 +107,7 @@ def test_tableau_source_doesnt_touch_not_escaped(): def test_tableau_source_cleanups_tableau_parameters_in_equi_predicates(p): assert ( TableauSiteSource._clean_tableau_query_parameters( - f"select * from t where c1 = {p} and c2 = {p} and c3 = 7" + f"select * from t where c1 = {p} and c2 = {p} and c3 = 7", ) == "select * from t where c1 = 1 and c2 = 1 and c3 = 7" ) @@ -116,7 +117,7 @@ def test_tableau_source_cleanups_tableau_parameters_in_equi_predicates(p): def test_tableau_source_cleanups_tableau_parameters_in_lt_gt_predicates(p): assert ( TableauSiteSource._clean_tableau_query_parameters( - f"select * from t where c1 << {p} and c2<<{p} and c3 >> {p} and c4>>{p} or {p} >> c1 and {p}>>c2 and {p} << c3 and {p}<> {p} and c4>>{p} or {p} >> c1 and {p}>>c2 and {p} << c3 and {p}< 1 and c4>1 or 1 > c1 and 1>c2 and 1 < c3 and 1>= {p} and c4>>={p} or {p} >>= c1 and {p}>>=c2 and {p} <<= c3 and {p}<<=c4" + f"select * from t where c1 <<= {p} and c2<<={p} and c3 >>= {p} and c4>>={p} or {p} >>= c1 and {p}>>=c2 and {p} <<= c3 and {p}<<=c4", ) == "select * from t where c1 <= 1 and c2<=1 and c3 >= 1 and c4>=1 or 1 >= c1 and 1>=c2 and 1 <= c3 and 1<=c4" ) @@ -136,7 +137,7 @@ def test_tableau_source_cleanups_tableau_parameters_in_lte_gte_predicates(p): def test_tableau_source_cleanups_tableau_parameters_in_join_predicate(p): assert ( TableauSiteSource._clean_tableau_query_parameters( - f"select * from t1 inner join t2 on t1.id = t2.id and t2.c21 = {p} and t1.c11 = 123 + {p}" + f"select * from t1 inner join t2 on t1.id = t2.id and t2.c21 = {p} and t1.c11 = 123 + {p}", ) == "select * from t1 inner join t2 on t1.id = t2.id and t2.c21 = 1 and t1.c11 = 123 + 1" ) @@ -146,7 +147,7 @@ def test_tableau_source_cleanups_tableau_parameters_in_join_predicate(p): def test_tableau_source_cleanups_tableau_parameters_in_complex_expressions(p): assert ( TableauSiteSource._clean_tableau_query_parameters( - f"select myudf1(c1, {p}, c2) / myudf2({p}) > ({p} + 3 * {p} * c5) * {p} - c4" + f"select myudf1(c1, {p}, c2) / myudf2({p}) > ({p} + 3 * {p} * c5) * {p} - c4", ) == "select myudf1(c1, 1, c2) / myudf2(1) > (1 + 3 * 1 * c5) * 1 - c4" ) @@ -276,7 +277,7 @@ def test_tableau_upstream_reference(): try: ref = TableauUpstreamReference.create(None) # type: ignore[arg-type] raise AssertionError( - "TableauUpstreamReference.create with None should have raised exception" + "TableauUpstreamReference.create with None should have raised exception", ) except ValueError: assert True @@ -349,7 +350,7 @@ def test_fine_grained(self): assert config.effective_embedded_datasource_page_size == any_page_size config = TableauPageSizeConfig( - embedded_datasource_field_upstream_page_size=any_page_size + embedded_datasource_field_upstream_page_size=any_page_size, ) assert config.page_size == DEFAULT_PAGE_SIZE assert ( @@ -362,7 +363,7 @@ def test_fine_grained(self): assert config.effective_published_datasource_page_size == any_page_size config = TableauPageSizeConfig( - 
published_datasource_field_upstream_page_size=any_page_size + published_datasource_field_upstream_page_size=any_page_size, ) assert config.page_size == DEFAULT_PAGE_SIZE assert ( diff --git a/metadata-ingestion/tests/unit/test_transform_dataset.py b/metadata-ingestion/tests/unit/test_transform_dataset.py index 5151be9c8b1997..3f9fa0f0c8cd8a 100644 --- a/metadata-ingestion/tests/unit/test_transform_dataset.py +++ b/metadata-ingestion/tests/unit/test_transform_dataset.py @@ -151,16 +151,18 @@ def create_and_run_test_pipeline( path: str, ) -> str: with mock.patch( - "tests.unit.test_source.FakeSource.get_workunits" + "tests.unit.test_source.FakeSource.get_workunits", ) as mock_getworkunits: mock_getworkunits.return_value = [ ( workunit.MetadataWorkUnit( - id=f"test-workunit-mce-{e.proposedSnapshot.urn}", mce=e + id=f"test-workunit-mce-{e.proposedSnapshot.urn}", + mce=e, ) if isinstance(e, MetadataChangeEventClass) else workunit.MetadataWorkUnit( - id=f"test-workunit-mcp-{e.entityUrn}-{e.aspectName}", mcp=e + id=f"test-workunit-mcp-{e.entityUrn}-{e.aspectName}", + mcp=e, ) ) for e in events @@ -174,7 +176,7 @@ def create_and_run_test_pipeline( }, "transformers": transformers, "sink": {"type": "file", "config": {"filename": events_file}}, - } + }, ) pipeline.run() @@ -195,9 +197,10 @@ def make_dataset_with_owner() -> models.MetadataChangeEventClass: ), ], lastModified=models.AuditStampClass( - time=1625266033123, actor="urn:li:corpuser:datahub" + time=1625266033123, + actor="urn:li:corpuser:datahub", ), - ) + ), ], ), ) @@ -213,7 +216,7 @@ def make_dataset_with_properties() -> models.MetadataChangeEventClass: aspects=[ models.StatusClass(removed=False), models.DatasetPropertiesClass( - customProperties=EXISTING_PROPERTIES.copy() + customProperties=EXISTING_PROPERTIES.copy(), ), ], ), @@ -233,9 +236,9 @@ def test_dataset_ownership_transformation(mock_time): name="User Deletions", description="Constructs the fct_users_deleted from logging_events", type=models.AzkabanJobTypeClass.SQL, - ) + ), ], - ) + ), ) inputs = [no_owner_aspect, with_owner_aspect, not_a_dataset, EndOfStream()] @@ -245,20 +248,21 @@ def test_dataset_ownership_transformation(mock_time): "owner_urns": [ builder.make_user_urn("person1"), builder.make_user_urn("person2"), - ] + ], }, PipelineContext(run_id="test"), ) outputs = list( - transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert len(outputs) == len(inputs) + 2 # Check the first entry. first_ownership_aspect = builder.get_aspect_if_available( - outputs[0].record, models.OwnershipClass + outputs[0].record, + models.OwnershipClass, ) assert first_ownership_aspect is None @@ -271,12 +275,13 @@ def test_dataset_ownership_transformation(mock_time): [ owner.type == models.OwnershipTypeClass.DATAOWNER and owner.typeUrn is None for owner in last_event.aspect.owners - ] + ], ) # Check the second entry. 
second_ownership_aspect = builder.get_aspect_if_available( - outputs[1].record, models.OwnershipClass + outputs[1].record, + models.OwnershipClass, ) assert second_ownership_aspect assert len(second_ownership_aspect.owners) == 3 @@ -284,7 +289,7 @@ def test_dataset_ownership_transformation(mock_time): [ owner.type == models.OwnershipTypeClass.DATAOWNER and owner.typeUrn is None for owner in second_ownership_aspect.owners - ] + ], ) third_ownership_aspect = outputs[4].record.aspect @@ -294,7 +299,7 @@ def test_dataset_ownership_transformation(mock_time): [ owner.type == models.OwnershipTypeClass.DATAOWNER and owner.typeUrn is None for owner in second_ownership_aspect.owners - ] + ], ) # Verify that the third entry is unchanged. @@ -322,8 +327,8 @@ def test_simple_dataset_ownership_with_type_transformation(mock_time): [ RecordEnvelope(input, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(output) == 3 @@ -356,8 +361,8 @@ def test_simple_dataset_ownership_with_type_urn_transformation(mock_time): [ RecordEnvelope(input, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(output) == 3 @@ -391,8 +396,8 @@ def _test_extract_tags(in_urn: str, regex_str: str, out_tag: str) -> None: [ RecordEnvelope(input, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(output) == 3 @@ -447,11 +452,12 @@ def test_simple_remove_dataset_ownership(): PipelineContext(run_id="test"), ) outputs = list( - transformer.transform([RecordEnvelope(with_owner_aspect, metadata={})]) + transformer.transform([RecordEnvelope(with_owner_aspect, metadata={})]), ) ownership_aspect = builder.get_aspect_if_available( - outputs[0].record, models.OwnershipClass + outputs[0].record, + models.OwnershipClass, ) assert ownership_aspect assert len(ownership_aspect.owners) == 0 @@ -468,12 +474,13 @@ def test_mark_status_dataset(tmp_path): transformer.transform( [ RecordEnvelope(dataset, metadata={}), - ] - ) + ], + ), ) assert len(removed) == 1 status_aspect = builder.get_aspect_if_available( - removed[0].record, models.StatusClass + removed[0].record, + models.StatusClass, ) assert status_aspect assert status_aspect.removed is True @@ -486,12 +493,13 @@ def test_mark_status_dataset(tmp_path): transformer.transform( [ RecordEnvelope(dataset, metadata={}), - ] - ) + ], + ), ) assert len(not_removed) == 1 status_aspect = builder.get_aspect_if_available( - not_removed[0].record, models.StatusClass + not_removed[0].record, + models.StatusClass, ) assert status_aspect assert status_aspect.removed is False @@ -684,9 +692,9 @@ def _test_owner( dataset = make_generic_dataset( aspects=[ models.GlobalTagsClass( - tags=[TagAssociationClass(tag=builder.make_tag_urn(tag))] - ) - ] + tags=[TagAssociationClass(tag=builder.make_tag_urn(tag))], + ), + ], ) transformer = ExtractOwnersFromTagsTransformer.create( @@ -699,8 +707,8 @@ def _test_owner( [ RecordEnvelope(dataset, metadata={}), RecordEnvelope(record=EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(record_envelops) == 3 @@ -807,8 +815,8 @@ def test_add_dataset_browse_paths(): [ RecordEnvelope(dataset, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) browse_path_aspect = transformed[1].record.aspect assert browse_path_aspect @@ -816,7 +824,7 @@ def test_add_dataset_browse_paths(): # use an mce with a pre-existing browse path dataset_mce = make_generic_dataset( - aspects=[StatusClass(removed=False), browse_path_aspect] + 
aspects=[StatusClass(removed=False), browse_path_aspect], ) transformer = AddDatasetBrowsePathTransformer.create( @@ -824,7 +832,7 @@ def test_add_dataset_browse_paths(): "path_templates": [ "/PLATFORM/foo/DATASET_PARTS/ENV", "/ENV/PLATFORM/bar/DATASET_PARTS/", - ] + ], }, PipelineContext(run_id="test"), ) @@ -833,12 +841,13 @@ def test_add_dataset_browse_paths(): [ RecordEnvelope(dataset_mce, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(transformed) == 2 browse_path_aspect = builder.get_aspect_if_available( - transformed[0].record, BrowsePathsClass + transformed[0].record, + BrowsePathsClass, ) assert browse_path_aspect assert browse_path_aspect.paths == [ @@ -861,12 +870,13 @@ def test_add_dataset_browse_paths(): [ RecordEnvelope(dataset_mce, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(transformed) == 2 browse_path_aspect = builder.get_aspect_if_available( - transformed[0].record, BrowsePathsClass + transformed[0].record, + BrowsePathsClass, ) assert browse_path_aspect assert browse_path_aspect.paths == [ @@ -882,7 +892,7 @@ def test_simple_dataset_tags_transformation(mock_time): "tag_urns": [ builder.make_tag_urn("NeedsDocumentation"), builder.make_tag_urn("Legacy"), - ] + ], }, PipelineContext(run_id="test-tags"), ) @@ -892,8 +902,8 @@ def test_simple_dataset_tags_transformation(mock_time): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 5 @@ -932,7 +942,7 @@ def test_pattern_dataset_tags_transformation(mock_time): builder.make_tag_urn("Legacy"), ], ".*example2.*": [builder.make_term_urn("Needs Documentation")], - } + }, }, }, PipelineContext(run_id="test-tags"), @@ -943,8 +953,8 @@ def test_pattern_dataset_tags_transformation(mock_time): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 5 @@ -958,14 +968,14 @@ def test_pattern_dataset_tags_transformation(mock_time): def test_add_dataset_tags_transformation(): transformer = AddDatasetTags.create( { - "get_tags_to_add": "tests.unit.test_transform_dataset.dummy_tag_resolver_method" + "get_tags_to_add": "tests.unit.test_transform_dataset.dummy_tag_resolver_method", }, PipelineContext(run_id="test-tags"), ) output = list( transformer.transform( - [RecordEnvelope(input, metadata={}) for input in [make_generic_dataset()]] - ) + [RecordEnvelope(input, metadata={}) for input in [make_generic_dataset()]], + ), ) assert output @@ -985,9 +995,10 @@ def test_pattern_dataset_ownership_transformation(mock_time): ), ], lastModified=models.AuditStampClass( - time=1625266033123, actor="urn:li:corpuser:datahub" + time=1625266033123, + actor="urn:li:corpuser:datahub", ), - ) + ), ], ), ) @@ -1000,9 +1011,9 @@ def test_pattern_dataset_ownership_transformation(mock_time): name="User Deletions", description="Constructs the fct_users_deleted from logging_events", type=models.AzkabanJobTypeClass.SQL, - ) + ), ], - ) + ), ) inputs = [no_owner_aspect, with_owner_aspect, not_a_dataset, EndOfStream()] @@ -1014,7 +1025,7 @@ def test_pattern_dataset_ownership_transformation(mock_time): ".*example1.*": [builder.make_user_urn("person1")], ".*example2.*": [builder.make_user_urn("person2")], ".*dag_abc.*": [builder.make_user_urn("person2")], - } + }, }, "ownership_type": "DATAOWNER", }, @@ -1022,7 +1033,7 @@ def test_pattern_dataset_ownership_transformation(mock_time): ) outputs = list( - transformer.transform([RecordEnvelope(input, 
metadata={}) for input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert ( @@ -1039,12 +1050,13 @@ def test_pattern_dataset_ownership_transformation(mock_time): [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in first_ownership_aspect.owners - ] + ], ) # Check the second entry. second_ownership_aspect = builder.get_aspect_if_available( - outputs[1].record, models.OwnershipClass + outputs[1].record, + models.OwnershipClass, ) assert second_ownership_aspect assert len(second_ownership_aspect.owners) == 2 @@ -1052,7 +1064,7 @@ def test_pattern_dataset_ownership_transformation(mock_time): [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in second_ownership_aspect.owners - ] + ], ) third_ownership_aspect = outputs[4].record.aspect @@ -1062,7 +1074,7 @@ def test_pattern_dataset_ownership_transformation(mock_time): [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in third_ownership_aspect.owners - ] + ], ) # Verify that the third entry is unchanged. @@ -1080,7 +1092,7 @@ def test_pattern_dataset_ownership_with_type_transformation(mock_time): "owner_pattern": { "rules": { ".*example1.*": [builder.make_user_urn("person1")], - } + }, }, "ownership_type": "PRODUCER", }, @@ -1092,8 +1104,8 @@ def test_pattern_dataset_ownership_with_type_transformation(mock_time): [ RecordEnvelope(input, metadata={}), RecordEnvelope(EndOfStream(), metadata={}), - ] - ) + ], + ), ) assert len(output) == 3 @@ -1111,7 +1123,7 @@ def test_pattern_dataset_ownership_with_invalid_type_transformation(mock_time): "owner_pattern": { "rules": { ".*example1.*": [builder.make_user_urn("person1")], - } + }, }, "ownership_type": "INVALID_TYPE", }, @@ -1120,7 +1132,8 @@ def test_pattern_dataset_ownership_with_invalid_type_transformation(mock_time): def test_pattern_container_and_dataset_ownership_transformation( - mock_time, mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): def fake_get_aspect( entity_urn: str, @@ -1130,16 +1143,18 @@ def fake_get_aspect( return models.BrowsePathsV2Class( path=[ models.BrowsePathEntryClass( - id="container_1", urn="urn:li:container:container_1" + id="container_1", + urn="urn:li:container:container_1", ), models.BrowsePathEntryClass( - id="container_2", urn="urn:li:container:container_2" + id="container_2", + urn="urn:li:container:container_2", ), - ] + ], ) pipeline_context = PipelineContext( - run_id="test_pattern_container_and_dataset_ownership_transformation" + run_id="test_pattern_container_and_dataset_ownership_transformation", ) pipeline_context.graph = mock_datahub_graph_instance pipeline_context.graph.get_aspect = fake_get_aspect # type: ignore @@ -1164,9 +1179,10 @@ def fake_get_aspect( ), ], lastModified=models.AuditStampClass( - time=1625266033123, actor="urn:li:corpuser:datahub" + time=1625266033123, + actor="urn:li:corpuser:datahub", ), - ) + ), ], ), ) @@ -1179,9 +1195,9 @@ def fake_get_aspect( name="User Deletions", description="Constructs the fct_users_deleted from logging_events", type=models.AzkabanJobTypeClass.SQL, - ) + ), ], - ) + ), ) inputs = [ @@ -1199,7 +1215,7 @@ def fake_get_aspect( ".*example1.*": [builder.make_user_urn("person1")], ".*example2.*": [builder.make_user_urn("person2")], ".*dag_abc.*": [builder.make_user_urn("person3")], - } + }, }, "ownership_type": "DATAOWNER", "is_container": True, # Enable container ownership handling @@ -1208,7 +1224,7 @@ def fake_get_aspect( ) outputs = list( - transformer.transform([RecordEnvelope(input, metadata={}) for 
input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert len(outputs) == len(inputs) + 4 @@ -1224,12 +1240,13 @@ def fake_get_aspect( [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in first_ownership_aspect.owners - ] + ], ) # Check the ownership for the second dataset (example2) second_ownership_aspect = builder.get_aspect_if_available( - outputs[1].record, models.OwnershipClass + outputs[1].record, + models.OwnershipClass, ) assert second_ownership_aspect assert len(second_ownership_aspect.owners) == 2 # One existing + one new @@ -1237,7 +1254,7 @@ def fake_get_aspect( [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in second_ownership_aspect.owners - ] + ], ) third_ownership_aspect = outputs[4].record.aspect @@ -1258,7 +1275,8 @@ def fake_get_aspect( def test_pattern_container_and_dataset_ownership_with_no_container( - mock_time, mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): def fake_get_aspect( entity_urn: str, @@ -1268,7 +1286,7 @@ def fake_get_aspect( return None pipeline_context = PipelineContext( - run_id="test_pattern_container_and_dataset_ownership_with_no_container" + run_id="test_pattern_container_and_dataset_ownership_with_no_container", ) pipeline_context.graph = mock_datahub_graph_instance pipeline_context.graph.get_aspect = fake_get_aspect # type: ignore @@ -1282,12 +1300,14 @@ def fake_get_aspect( models.BrowsePathsV2Class( path=[ models.BrowsePathEntryClass( - id="container_1", urn="urn:li:container:container_1" + id="container_1", + urn="urn:li:container:container_1", ), models.BrowsePathEntryClass( - id="container_2", urn="urn:li:container:container_2" + id="container_2", + urn="urn:li:container:container_2", ), - ] + ], ), ], ), @@ -1305,18 +1325,21 @@ def fake_get_aspect( ), ], lastModified=models.AuditStampClass( - time=1625266033123, actor="urn:li:corpuser:datahub" + time=1625266033123, + actor="urn:li:corpuser:datahub", ), ), models.BrowsePathsV2Class( path=[ models.BrowsePathEntryClass( - id="container_1", urn="urn:li:container:container_1" + id="container_1", + urn="urn:li:container:container_1", ), models.BrowsePathEntryClass( - id="container_2", urn="urn:li:container:container_2" + id="container_2", + urn="urn:li:container:container_2", ), - ] + ], ), ], ), @@ -1335,7 +1358,7 @@ def fake_get_aspect( "rules": { ".*example1.*": [builder.make_user_urn("person1")], ".*example2.*": [builder.make_user_urn("person2")], - } + }, }, "ownership_type": "DATAOWNER", "is_container": True, # Enable container ownership handling @@ -1344,7 +1367,7 @@ def fake_get_aspect( ) outputs = list( - transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert len(outputs) == len(inputs) + 1 @@ -1357,12 +1380,13 @@ def fake_get_aspect( [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in first_ownership_aspect.owners - ] + ], ) # Check the ownership for the second dataset (example2) second_ownership_aspect = builder.get_aspect_if_available( - outputs[1].record, models.OwnershipClass + outputs[1].record, + models.OwnershipClass, ) assert second_ownership_aspect assert len(second_ownership_aspect.owners) == 2 # One existing + one new @@ -1370,12 +1394,13 @@ def fake_get_aspect( [ owner.type == models.OwnershipTypeClass.DATAOWNER for owner in second_ownership_aspect.owners - ] + ], ) def test_pattern_container_and_dataset_ownership_with_no_match( - mock_time, 
mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): def fake_get_aspect( entity_urn: str, @@ -1385,13 +1410,14 @@ def fake_get_aspect( return models.BrowsePathsV2Class( path=[ models.BrowsePathEntryClass( - id="container_1", urn="urn:li:container:container_1" - ) - ] + id="container_1", + urn="urn:li:container:container_1", + ), + ], ) pipeline_context = PipelineContext( - run_id="test_pattern_container_and_dataset_ownership_with_no_match" + run_id="test_pattern_container_and_dataset_ownership_with_no_match", ) pipeline_context.graph = mock_datahub_graph_instance pipeline_context.graph.get_aspect = fake_get_aspect # type: ignore @@ -1418,9 +1444,10 @@ def fake_get_aspect( ), ], lastModified=models.AuditStampClass( - time=1625266033123, actor="urn:li:corpuser:datahub" + time=1625266033123, + actor="urn:li:corpuser:datahub", ), - ) + ), ], ), ) @@ -1438,7 +1465,7 @@ def fake_get_aspect( "rules": { ".*example3.*": [builder.make_user_urn("person1")], ".*example4.*": [builder.make_user_urn("person2")], - } + }, }, "ownership_type": "DATAOWNER", "is_container": True, # Enable container ownership handling @@ -1447,7 +1474,7 @@ def fake_get_aspect( ) outputs = list( - transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert len(outputs) == len(inputs) + 1 @@ -1460,7 +1487,8 @@ def fake_get_aspect( # Check the ownership for the second dataset (example2) second_ownership_aspect = builder.get_aspect_if_available( - outputs[1].record, models.OwnershipClass + outputs[1].record, + models.OwnershipClass, ) assert second_ownership_aspect assert len(second_ownership_aspect.owners) == 1 @@ -1474,11 +1502,14 @@ def fake_get_aspect( def gen_owners( owners: List[str], ownership_type: Union[ - str, models.OwnershipTypeClass + str, + models.OwnershipTypeClass, ] = models.OwnershipTypeClass.DATAOWNER, ) -> models.OwnershipClass: return models.OwnershipClass( - owners=[models.OwnerClass(owner=owner, type=ownership_type) for owner in owners] + owners=[ + models.OwnerClass(owner=owner, type=ownership_type) for owner in owners + ], ) @@ -1489,7 +1520,9 @@ def test_ownership_patching_intersect(mock_time): mock_graph.get_ownership.return_value = server_ownership test_ownership = AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", mce_ownership + mock_graph, + "test_urn", + mce_ownership, ) assert test_ownership and test_ownership.owners assert "foo" in [o.owner for o in test_ownership.owners] @@ -1502,7 +1535,9 @@ def test_ownership_patching_with_nones(mock_time): mce_ownership = gen_owners(["baz", "foo"]) mock_graph.get_ownership.return_value = None test_ownership = AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", mce_ownership + mock_graph, + "test_urn", + mce_ownership, ) assert test_ownership and test_ownership.owners assert "foo" in [o.owner for o in test_ownership.owners] @@ -1511,7 +1546,9 @@ def test_ownership_patching_with_nones(mock_time): server_ownership = gen_owners(["baz", "foo"]) mock_graph.get_ownership.return_value = server_ownership test_ownership = AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", None + mock_graph, + "test_urn", + None, ) assert not test_ownership @@ -1521,7 +1558,9 @@ def test_ownership_patching_with_empty_mce_none_server(mock_time): mce_ownership = gen_owners([]) mock_graph.get_ownership.return_value = None test_ownership = 
AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", mce_ownership + mock_graph, + "test_urn", + mce_ownership, ) # nothing to add, so we omit writing assert test_ownership is None @@ -1533,7 +1572,9 @@ def test_ownership_patching_with_empty_mce_nonempty_server(mock_time): mce_ownership = gen_owners([]) mock_graph.get_ownership.return_value = server_ownership test_ownership = AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", mce_ownership + mock_graph, + "test_urn", + mce_ownership, ) # nothing to add, so we omit writing assert test_ownership is None @@ -1545,7 +1586,9 @@ def test_ownership_patching_with_different_types_1(mock_time): mce_ownership = gen_owners(["foo"], models.OwnershipTypeClass.DATAOWNER) mock_graph.get_ownership.return_value = server_ownership test_ownership = AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", mce_ownership + mock_graph, + "test_urn", + mce_ownership, ) assert test_ownership and test_ownership.owners # nothing to add, so we omit writing @@ -1563,7 +1606,9 @@ def test_ownership_patching_with_different_types_2(mock_time): mce_ownership = gen_owners(["foo", "baz"], models.OwnershipTypeClass.DATAOWNER) mock_graph.get_ownership.return_value = server_ownership test_ownership = AddDatasetOwnership._merge_with_server_ownership( - mock_graph, "test_urn", mce_ownership + mock_graph, + "test_urn", + mce_ownership, ) assert test_ownership and test_ownership.owners assert len(test_ownership.owners) == 2 @@ -1589,20 +1634,21 @@ def test_add_dataset_properties(mock_time): transformer = AddDatasetProperties.create( { - "add_properties_resolver_class": "tests.unit.test_transform_dataset.DummyPropertiesResolverClass" + "add_properties_resolver_class": "tests.unit.test_transform_dataset.DummyPropertiesResolverClass", }, PipelineContext(run_id="test-properties"), ) outputs = list( transformer.transform( - [RecordEnvelope(input, metadata={}) for input in [dataset_mce]] - ) + [RecordEnvelope(input, metadata={}) for input in [dataset_mce]], + ), ) assert len(outputs) == 1 custom_properties = builder.get_aspect_if_available( - outputs[0].record, models.DatasetPropertiesClass + outputs[0].record, + models.DatasetPropertiesClass, ) assert custom_properties is not None @@ -1631,7 +1677,7 @@ def fake_dataset_properties(entity_urn: str) -> models.DatasetPropertiesClass: transformer_type=SimpleAddDatasetProperties, pipeline_context=pipeline_context, aspect=models.DatasetPropertiesClass( - customProperties=EXISTING_PROPERTIES.copy() + customProperties=EXISTING_PROPERTIES.copy(), ), config={ "semantics": semantics, @@ -1657,7 +1703,8 @@ def test_simple_add_dataset_properties_overwrite(mock_datahub_graph_instance): assert output[0].record assert output[0].record.aspect custom_properties_aspect: models.DatasetPropertiesClass = cast( - models.DatasetPropertiesClass, output[0].record.aspect + models.DatasetPropertiesClass, + output[0].record.aspect, ) assert custom_properties_aspect.customProperties == { @@ -1681,7 +1728,8 @@ def test_simple_add_dataset_properties_patch(mock_datahub_graph_instance): assert output[0].record assert output[0].record.aspect custom_properties_aspect: models.DatasetPropertiesClass = cast( - models.DatasetPropertiesClass, output[0].record.aspect + models.DatasetPropertiesClass, + output[0].record.aspect, ) assert custom_properties_aspect.customProperties == { **EXISTING_PROPERTIES, @@ -1695,7 +1743,7 @@ def test_simple_add_dataset_properties(mock_time): outputs = 
run_dataset_transformer_pipeline( transformer_type=SimpleAddDatasetProperties, aspect=models.DatasetPropertiesClass( - customProperties=EXISTING_PROPERTIES.copy() + customProperties=EXISTING_PROPERTIES.copy(), ), config={ "properties": new_properties, @@ -1706,7 +1754,8 @@ def test_simple_add_dataset_properties(mock_time): assert outputs[0].record assert outputs[0].record.aspect custom_properties_aspect: models.DatasetPropertiesClass = cast( - models.DatasetPropertiesClass, outputs[0].record.aspect + models.DatasetPropertiesClass, + outputs[0].record.aspect, ) assert custom_properties_aspect.customProperties == { **EXISTING_PROPERTIES, @@ -1719,7 +1768,7 @@ def test_simple_add_dataset_properties_replace_existing(mock_time): outputs = run_dataset_transformer_pipeline( transformer_type=SimpleAddDatasetProperties, aspect=models.DatasetPropertiesClass( - customProperties=EXISTING_PROPERTIES.copy() + customProperties=EXISTING_PROPERTIES.copy(), ), config={ "replace_existing": True, @@ -1731,7 +1780,8 @@ def test_simple_add_dataset_properties_replace_existing(mock_time): assert outputs[0].record assert outputs[0].record.aspect custom_properties_aspect: models.DatasetPropertiesClass = cast( - models.DatasetPropertiesClass, outputs[0].record.aspect + models.DatasetPropertiesClass, + outputs[0].record.aspect, ) assert custom_properties_aspect.customProperties == { @@ -1747,7 +1797,7 @@ def test_simple_dataset_terms_transformation(mock_time): "term_urns": [ builder.make_term_urn("Test"), builder.make_term_urn("Needs Review"), - ] + ], }, PipelineContext(run_id="test-terms"), ) @@ -1757,8 +1807,8 @@ def test_simple_dataset_terms_transformation(mock_time): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 3 @@ -1781,7 +1831,7 @@ def test_pattern_dataset_terms_transformation(mock_time): builder.make_term_urn("Email"), ], ".*example2.*": [builder.make_term_urn("Address")], - } + }, }, }, PipelineContext(run_id="test-terms"), @@ -1792,8 +1842,8 @@ def test_pattern_dataset_terms_transformation(mock_time): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 3 @@ -1813,7 +1863,7 @@ def test_mcp_add_tags_missing(mock_time): "tag_urns": [ builder.make_tag_urn("NeedsDocumentation"), builder.make_tag_urn("Legacy"), - ] + ], }, PipelineContext(run_id="test-tags"), ) @@ -1836,7 +1886,7 @@ def test_mcp_add_tags_existing(mock_time): dataset_mcp = make_generic_dataset_mcp( aspect_name="globalTags", aspect=GlobalTagsClass( - tags=[TagAssociationClass(tag=builder.make_tag_urn("Test"))] + tags=[TagAssociationClass(tag=builder.make_tag_urn("Test"))], ), ) @@ -1845,7 +1895,7 @@ def test_mcp_add_tags_existing(mock_time): "tag_urns": [ builder.make_tag_urn("NeedsDocumentation"), builder.make_tag_urn("Legacy"), - ] + ], }, PipelineContext(run_id="test-tags"), ) @@ -1887,7 +1937,7 @@ def test_mcp_multiple_transformers(mock_time, tmp_path): { "type": "set_dataset_browse_path", "config": { - "path_templates": ["/ENV/PLATFORM/EsComments/DATASET_PARTS"] + "path_templates": ["/ENV/PLATFORM/EsComments/DATASET_PARTS"], }, }, { @@ -1896,14 +1946,14 @@ def test_mcp_multiple_transformers(mock_time, tmp_path): }, ], "sink": {"type": "file", "config": {"filename": events_file}}, - } + }, ) pipeline.run() pipeline.raise_from_status() urn_pattern = "^" + re.escape( - "urn:li:dataset:(urn:li:dataPlatform:elasticsearch,fooIndex,PROD)" + 
"urn:li:dataset:(urn:li:dataPlatform:elasticsearch,fooIndex,PROD)", ) assert ( tests.test_helpers.mce_helpers.assert_mcp_entity_urn( @@ -1959,7 +2009,7 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): platform_id="elasticsearch", table_name=f"fooBarIndex{i}", env="PROD", - ) + ), ), aspect=GlobalTagsClass(tags=[TagAssociationClass(tag="urn:li:tag:Test")]), ) @@ -1973,12 +2023,12 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): platform_id="elasticsearch", table_name=f"fooBarIndex{i}", env="PROD", - ) + ), ), aspect=DatasetPropertiesClass(description="test dataset"), ) for i in range(0, 10) - ] + ], ) # shuffle the mcps @@ -1992,7 +2042,7 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): { "type": "set_dataset_browse_path", "config": { - "path_templates": ["/ENV/PLATFORM/EsComments/DATASET_PARTS"] + "path_templates": ["/ENV/PLATFORM/EsComments/DATASET_PARTS"], }, }, { @@ -2004,7 +2054,7 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): ) urn_pattern = "^" + re.escape( - "urn:li:dataset:(urn:li:dataPlatform:elasticsearch,fooBarIndex" + "urn:li:dataset:(urn:li:dataPlatform:elasticsearch,fooBarIndex", ) # there should be 30 MCP-s @@ -2024,7 +2074,7 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): entity_type="dataset", aspect_name="globalTags", aspect_field_matcher={ - "tags": [{"tag": "urn:li:tag:Test"}, {"tag": "urn:li:tag:EsComments"}] + "tags": [{"tag": "urn:li:tag:Test"}, {"tag": "urn:li:tag:EsComments"}], }, file=events_file, ) @@ -2040,11 +2090,11 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): platform_id="elasticsearch", table_name=f"fooBarIndex{i}", env="PROD", - ) + ), ), aspect_name="browsePaths", aspect_field_matcher={ - "paths": [f"/prod/elasticsearch/EsComments/fooBarIndex{i}"] + "paths": [f"/prod/elasticsearch/EsComments/fooBarIndex{i}"], }, file=events_file, ) @@ -2055,7 +2105,9 @@ def test_mcp_multiple_transformers_replace(mock_time, tmp_path): class SuppressingTransformer(BaseTransformer, SingleAspectTransformer): @classmethod def create( - cls, config_dict: dict, ctx: PipelineContext + cls, + config_dict: dict, + ctx: PipelineContext, ) -> "SuppressingTransformer": return SuppressingTransformer() @@ -2066,7 +2118,10 @@ def aspect_name(self) -> str: return "datasetProperties" def transform_aspect( - self, entity_urn: str, aspect_name: str, aspect: Optional[builder.Aspect] + self, + entity_urn: str, + aspect_name: str, + aspect: Optional[builder.Aspect], ) -> Optional[builder.Aspect]: return None @@ -2087,8 +2142,8 @@ def test_supression_works(): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, dataset_mcp, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 2 # MCP will be dropped @@ -2100,20 +2155,20 @@ def test_pattern_dataset_schema_terms_transformation(mock_time): models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( fieldPath="address", type=models.SchemaFieldDataTypeClass( - type=models.StringTypeClass() + 
type=models.StringTypeClass(), ), nativeDataType="VARCHAR(100)", # use this to provide the type of the field in the source system's vernacular @@ -2121,7 +2176,7 @@ def test_pattern_dataset_schema_terms_transformation(mock_time): models.SchemaFieldClass( fieldPath="first_name", type=models.SchemaFieldDataTypeClass( - type=models.StringTypeClass() + type=models.StringTypeClass(), ), nativeDataType="VARCHAR(100)", # use this to provide the type of the field in the source system's vernacular @@ -2129,14 +2184,14 @@ def test_pattern_dataset_schema_terms_transformation(mock_time): models.SchemaFieldClass( fieldPath="last_name", type=models.SchemaFieldDataTypeClass( - type=models.StringTypeClass() + type=models.StringTypeClass(), ), nativeDataType="VARCHAR(100)", # use this to provide the type of the field in the source system's vernacular ), ], - ) - ] + ), + ], ) transformer = PatternAddDatasetSchemaTerms.create( @@ -2151,7 +2206,7 @@ def test_pattern_dataset_schema_terms_transformation(mock_time): builder.make_term_urn("Name"), builder.make_term_urn("LastName"), ], - } + }, }, }, PipelineContext(run_id="test-schema-terms"), @@ -2162,8 +2217,8 @@ def test_pattern_dataset_schema_terms_transformation(mock_time): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 2 @@ -2174,17 +2229,17 @@ def test_pattern_dataset_schema_terms_transformation(mock_time): assert schema_aspect.fields[0].glossaryTerms is None assert schema_aspect.fields[1].fieldPath == "first_name" assert schema_aspect.fields[1].glossaryTerms.terms[0].urn == builder.make_term_urn( - "Name" + "Name", ) assert schema_aspect.fields[1].glossaryTerms.terms[1].urn == builder.make_term_urn( - "FirstName" + "FirstName", ) assert schema_aspect.fields[2].fieldPath == "last_name" assert schema_aspect.fields[2].glossaryTerms.terms[0].urn == builder.make_term_urn( - "Name" + "Name", ) assert schema_aspect.fields[2].glossaryTerms.terms[1].urn == builder.make_term_urn( - "LastName" + "LastName", ) @@ -2194,20 +2249,20 @@ def test_pattern_dataset_schema_tags_transformation(mock_time): models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( fieldPath="address", type=models.SchemaFieldDataTypeClass( - type=models.StringTypeClass() + type=models.StringTypeClass(), ), nativeDataType="VARCHAR(100)", # use this to provide the type of the field in the source system's vernacular @@ -2215,7 +2270,7 @@ def test_pattern_dataset_schema_tags_transformation(mock_time): models.SchemaFieldClass( fieldPath="first_name", type=models.SchemaFieldDataTypeClass( - type=models.StringTypeClass() + type=models.StringTypeClass(), ), nativeDataType="VARCHAR(100)", # use this to provide the type of the field in the source system's vernacular @@ -2223,14 +2278,14 @@ def test_pattern_dataset_schema_tags_transformation(mock_time): models.SchemaFieldClass( fieldPath="last_name", type=models.SchemaFieldDataTypeClass( - type=models.StringTypeClass() + type=models.StringTypeClass(), ), 
nativeDataType="VARCHAR(100)", # use this to provide the type of the field in the source system's vernacular ), ], - ) - ] + ), + ], ) transformer = PatternAddDatasetSchemaTags.create( @@ -2245,7 +2300,7 @@ def test_pattern_dataset_schema_tags_transformation(mock_time): builder.make_tag_urn("Name"), builder.make_tag_urn("LastName"), ], - } + }, }, }, PipelineContext(run_id="test-schema-tags"), @@ -2256,8 +2311,8 @@ def test_pattern_dataset_schema_tags_transformation(mock_time): [ RecordEnvelope(input, metadata={}) for input in [dataset_mce, EndOfStream()] - ] - ) + ], + ), ) assert len(outputs) == 2 @@ -2268,17 +2323,17 @@ def test_pattern_dataset_schema_tags_transformation(mock_time): assert schema_aspect.fields[0].globalTags is None assert schema_aspect.fields[1].fieldPath == "first_name" assert schema_aspect.fields[1].globalTags.tags[0].tag == builder.make_tag_urn( - "Name" + "Name", ) assert schema_aspect.fields[1].globalTags.tags[1].tag == builder.make_tag_urn( - "FirstName" + "FirstName", ) assert schema_aspect.fields[2].fieldPath == "last_name" assert schema_aspect.fields[2].globalTags.tags[0].tag == builder.make_tag_urn( - "Name" + "Name", ) assert schema_aspect.fields[2].globalTags.tags[1].tag == builder.make_tag_urn( - "LastName" + "LastName", ) @@ -2292,7 +2347,8 @@ def run_dataset_transformer_pipeline( if pipeline_context is None: pipeline_context = PipelineContext(run_id="transformer_pipe_line") transformer: DatasetTransformer = cast( - DatasetTransformer, transformer_type.create(config, pipeline_context) + DatasetTransformer, + transformer_type.create(config, pipeline_context), ) dataset: Union[MetadataChangeEventClass, MetadataChangeProposalWrapper] @@ -2301,18 +2357,19 @@ def run_dataset_transformer_pipeline( proposedSnapshot=models.DatasetSnapshotClass( urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,example1,PROD)", aspects=[], - ) + ), ) else: assert aspect dataset = make_generic_dataset_mcp( - aspect=aspect, aspect_name=transformer.aspect_name() + aspect=aspect, + aspect_name=transformer.aspect_name(), ) outputs = list( transformer.transform( - [RecordEnvelope(input, metadata={}) for input in [dataset, EndOfStream()]] - ) + [RecordEnvelope(input, metadata={}) for input in [dataset, EndOfStream()]], + ), ) return outputs @@ -2327,7 +2384,8 @@ def run_container_transformer_pipeline( if pipeline_context is None: pipeline_context = PipelineContext(run_id="transformer_pipe_line") transformer: ContainerTransformer = cast( - ContainerTransformer, transformer_type.create(config, pipeline_context) + ContainerTransformer, + transformer_type.create(config, pipeline_context), ) container: Union[MetadataChangeEventClass, MetadataChangeProposalWrapper] @@ -2336,25 +2394,29 @@ def run_container_transformer_pipeline( proposedSnapshot=models.DatasetSnapshotClass( urn="urn:li:container:6338f55439c7ae58243a62c4d6fbffde", aspects=[], - ) + ), ) else: assert aspect container = make_generic_container_mcp( - aspect=aspect, aspect_name=transformer.aspect_name() + aspect=aspect, + aspect_name=transformer.aspect_name(), ) outputs = list( transformer.transform( - [RecordEnvelope(input, metadata={}) for input in [container, EndOfStream()]] - ) + [ + RecordEnvelope(input, metadata={}) + for input in [container, EndOfStream()] + ], + ), ) return outputs def test_simple_add_dataset_domain_aspect_name(mock_datahub_graph_instance): pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = 
mock_datahub_graph_instance @@ -2367,7 +2429,7 @@ def test_simple_add_dataset_domain(mock_datahub_graph_instance): datahub_domain = builder.make_domain_urn("datahubproject.io") pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -2395,7 +2457,7 @@ def test_simple_add_dataset_domain_mce_support(mock_datahub_graph_instance): datahub_domain = builder.make_domain_urn("datahubproject.io") pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -2426,7 +2488,7 @@ def test_simple_add_dataset_domain_replace_existing(mock_datahub_graph_instance) datahub_domain = builder.make_domain_urn("datahubproject.io") pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -2487,7 +2549,10 @@ def fake_get_domain(entity_urn: str) -> models.DomainsClass: def test_simple_add_dataset_domain_semantics_patch( - pytestconfig, tmp_path, mock_time, mock_datahub_graph_instance + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph_instance, ): acryl_domain = builder.make_domain_urn("acryl.io") datahub_domain = builder.make_domain_urn("datahubproject.io") @@ -2527,7 +2592,10 @@ def fake_get_domain(entity_urn: str) -> models.DomainsClass: def test_simple_add_dataset_domain_on_conflict_do_nothing( - pytestconfig, tmp_path, mock_time, mock_datahub_graph_instance + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph_instance, ): acryl_domain = builder.make_domain_urn("acryl.io") datahub_domain = builder.make_domain_urn("datahubproject.io") @@ -2561,7 +2629,10 @@ def fake_get_domain(entity_urn: str) -> models.DomainsClass: def test_simple_add_dataset_domain_on_conflict_do_nothing_no_conflict( - pytestconfig, tmp_path, mock_time, mock_datahub_graph_instance + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph_instance, ): acryl_domain = builder.make_domain_urn("acryl.io") datahub_domain = builder.make_domain_urn("datahubproject.io") @@ -2603,12 +2674,13 @@ def fake_get_domain(entity_urn: str) -> models.DomainsClass: def test_pattern_add_dataset_domain_aspect_name(mock_datahub_graph_instance): pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance transformer = PatternAddDatasetDomain.create( - {"domain_pattern": {"rules": {}}}, pipeline_context + {"domain_pattern": {"rules": {}}}, + pipeline_context, ) assert transformer.aspect_name() == models.DomainsClass.ASPECT_NAME @@ -2619,7 +2691,7 @@ def test_pattern_add_dataset_domain_match(mock_datahub_graph_instance): pattern = "urn:li:dataset:\\(urn:li:dataPlatform:bigquery,.*" pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -2650,7 +2722,7 @@ def test_pattern_add_dataset_domain_no_match(mock_datahub_graph_instance): pattern = "urn:li:dataset:\\(urn:li:dataPlatform:invalid,.*" pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = 
mock_datahub_graph_instance @@ -2681,7 +2753,7 @@ def test_pattern_add_dataset_domain_replace_existing_match(mock_datahub_graph_in pattern = "urn:li:dataset:\\(urn:li:dataPlatform:bigquery,.*" pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -2715,7 +2787,7 @@ def test_pattern_add_dataset_domain_replace_existing_no_match( pattern = "urn:li:dataset:\\(urn:li:dataPlatform:invalid,.*" pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -2778,7 +2850,10 @@ def fake_get_domain(entity_urn: str) -> models.DomainsClass: def test_pattern_add_dataset_domain_semantics_patch( - pytestconfig, tmp_path, mock_time, mock_datahub_graph_instance + pytestconfig, + tmp_path, + mock_time, + mock_datahub_graph_instance, ): acryl_domain = builder.make_domain_urn("acryl.io") datahub_domain = builder.make_domain_urn("datahubproject.io") @@ -2825,13 +2900,16 @@ def test_simple_dataset_ownership_transformer_semantics_patch( pipeline_context.graph = mock_datahub_graph_instance server_owner: str = builder.make_owner_urn( - "mohd@acryl.io", owner_type=builder.OwnerType.USER + "mohd@acryl.io", + owner_type=builder.OwnerType.USER, ) owner1: str = builder.make_owner_urn( - "john@acryl.io", owner_type=builder.OwnerType.USER + "john@acryl.io", + owner_type=builder.OwnerType.USER, ) owner2: str = builder.make_owner_urn( - "pedro@acryl.io", owner_type=builder.OwnerType.USER + "pedro@acryl.io", + owner_type=builder.OwnerType.USER, ) # Return fake aspect to simulate server behaviour @@ -2839,9 +2917,10 @@ def fake_ownership_class(entity_urn: str) -> models.OwnershipClass: return models.OwnershipClass( owners=[ models.OwnerClass( - owner=server_owner, type=models.OwnershipTypeClass.DATAOWNER - ) - ] + owner=server_owner, + type=models.OwnershipTypeClass.DATAOWNER, + ), + ], ) pipeline_context.graph.get_ownership = fake_ownership_class # type: ignore @@ -2850,8 +2929,11 @@ def fake_ownership_class(entity_urn: str) -> models.OwnershipClass: transformer_type=SimpleAddDatasetOwnership, aspect=models.OwnershipClass( owners=[ - models.OwnerClass(owner=owner1, type=models.OwnershipTypeClass.PRODUCER) - ] + models.OwnerClass( + owner=owner1, + type=models.OwnershipTypeClass.PRODUCER, + ), + ], ), config={ "replace_existing": False, @@ -2869,7 +2951,8 @@ def fake_ownership_class(entity_urn: str) -> models.OwnershipClass: assert output[0].record.aspect is not None assert isinstance(output[0].record.aspect, models.OwnershipClass) transformed_aspect: models.OwnershipClass = cast( - models.OwnershipClass, output[0].record.aspect + models.OwnershipClass, + output[0].record.aspect, ) assert len(transformed_aspect.owners) == 3 owner_urns: List[str] = [ @@ -2895,25 +2978,28 @@ def fake_get_aspect( return models.BrowsePathsV2Class( path=[ models.BrowsePathEntryClass( - id="container_1", urn="urn:li:container:container_1" + id="container_1", + urn="urn:li:container:container_1", ), models.BrowsePathEntryClass( - id="container_2", urn="urn:li:container:container_2" + id="container_2", + urn="urn:li:container:container_2", ), - ] + ], ) pipeline_context = PipelineContext( - run_id="test_pattern_container_and_dataset_domain_transformation" + run_id="test_pattern_container_and_dataset_domain_transformation", ) pipeline_context.graph = mock_datahub_graph_instance 
pipeline_context.graph.get_aspect = fake_get_aspect # type: ignore with_domain_aspect = make_generic_dataset_mcp( - aspect=models.DomainsClass(domains=[datahub_domain]), aspect_name="domains" + aspect=models.DomainsClass(domains=[datahub_domain]), + aspect_name="domains", ) no_domain_aspect = make_generic_dataset_mcp( - entity_urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,example2,PROD)" + entity_urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,example2,PROD)", ) # Not a dataset, should be ignored @@ -2925,9 +3011,9 @@ def fake_get_aspect( name="User Deletions", description="Constructs the fct_users_deleted from logging_events", type=models.AzkabanJobTypeClass.SQL, - ) + ), ], - ) + ), ) inputs = [ @@ -2944,7 +3030,7 @@ def fake_get_aspect( "rules": { ".*example1.*": [acryl_domain, server_domain], ".*example2.*": [server_domain], - } + }, }, "is_container": True, # Enable container domain handling }, @@ -2952,7 +3038,7 @@ def fake_get_aspect( ) outputs = list( - transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert ( @@ -3002,16 +3088,17 @@ def fake_get_aspect( return None pipeline_context = PipelineContext( - run_id="test_pattern_container_and_dataset_domain_transformation_with_no_container" + run_id="test_pattern_container_and_dataset_domain_transformation_with_no_container", ) pipeline_context.graph = mock_datahub_graph_instance pipeline_context.graph.get_aspect = fake_get_aspect # type: ignore with_domain_aspect = make_generic_dataset_mcp( - aspect=models.DomainsClass(domains=[datahub_domain]), aspect_name="domains" + aspect=models.DomainsClass(domains=[datahub_domain]), + aspect_name="domains", ) no_domain_aspect = make_generic_dataset_mcp( - entity_urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,example2,PROD)" + entity_urn="urn:li:dataset:(urn:li:dataPlatform:bigquery,example2,PROD)", ) inputs = [ @@ -3027,7 +3114,7 @@ def fake_get_aspect( "rules": { ".*example1.*": [acryl_domain, server_domain], ".*example2.*": [server_domain], - } + }, }, "is_container": True, # Enable container domain handling }, @@ -3035,7 +3122,7 @@ def fake_get_aspect( ) outputs = list( - transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]) + transformer.transform([RecordEnvelope(input, metadata={}) for input in inputs]), ) assert len(outputs) == len(inputs) + 1 @@ -3060,7 +3147,7 @@ def test_pattern_add_container_dataset_domain_no_match(mock_datahub_graph_instan pattern = "urn:li:dataset:\\(urn:li:dataPlatform:invalid,.*" pipeline_context: PipelineContext = PipelineContext( - run_id="test_simple_add_dataset_domain" + run_id="test_simple_add_dataset_domain", ) pipeline_context.graph = mock_datahub_graph_instance @@ -3072,9 +3159,10 @@ def fake_get_aspect( return models.BrowsePathsV2Class( path=[ models.BrowsePathEntryClass( - id="container_1", urn="urn:li:container:container_1" - ) - ] + id="container_1", + urn="urn:li:container:container_1", + ), + ], ) pipeline_context.graph.get_aspect = fake_get_aspect # type: ignore @@ -3112,14 +3200,14 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: return models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the source system has a notion of unique schemas 
identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( @@ -3127,8 +3215,8 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: glossaryTerms=models.GlossaryTermsClass( terms=[ models.GlossaryTermAssociationClass( - urn=builder.make_term_urn("pii") - ) + urn=builder.make_term_urn("pii"), + ), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -3141,8 +3229,8 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: glossaryTerms=models.GlossaryTermsClass( terms=[ models.GlossaryTermAssociationClass( - urn=builder.make_term_urn("pii") - ) + urn=builder.make_term_urn("pii"), + ), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -3170,20 +3258,20 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: builder.make_term_urn("Name"), builder.make_term_urn("LastName"), ], - } + }, }, }, aspect=models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( @@ -3212,10 +3300,12 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: def test_pattern_dataset_schema_terms_transformation_patch( - mock_time, mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): output = run_pattern_dataset_schema_terms_transformation_semantics( - TransformerSemantics.PATCH, mock_datahub_graph_instance + TransformerSemantics.PATCH, + mock_datahub_graph_instance, ) assert len(output) == 2 # Check that glossary terms were added. 
@@ -3245,10 +3335,12 @@ def test_pattern_dataset_schema_terms_transformation_patch( def test_pattern_dataset_schema_terms_transformation_overwrite( - mock_time, mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): output = run_pattern_dataset_schema_terms_transformation_semantics( - TransformerSemantics.OVERWRITE, mock_datahub_graph_instance + TransformerSemantics.OVERWRITE, + mock_datahub_graph_instance, ) assert len(output) == 2 @@ -3290,21 +3382,21 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: return models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( fieldPath="first_name", globalTags=models.GlobalTagsClass( tags=[ - models.TagAssociationClass(tag=builder.make_tag_urn("pii")) + models.TagAssociationClass(tag=builder.make_tag_urn("pii")), ], ), type=models.SchemaFieldDataTypeClass(type=models.StringTypeClass()), @@ -3315,7 +3407,7 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: fieldPath="mobile_number", globalTags=models.GlobalTagsClass( tags=[ - models.TagAssociationClass(tag=builder.make_tag_urn("pii")) + models.TagAssociationClass(tag=builder.make_tag_urn("pii")), ], ), type=models.SchemaFieldDataTypeClass(type=models.StringTypeClass()), @@ -3342,20 +3434,20 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: builder.make_tag_urn("Name"), builder.make_tag_urn("LastName"), ], - } + }, }, }, aspect=models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( @@ -3383,10 +3475,12 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: def test_pattern_dataset_schema_tags_transformation_overwrite( - mock_time, mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): output = run_pattern_dataset_schema_tags_transformation_semantics( - TransformerSemantics.OVERWRITE, mock_datahub_graph_instance + TransformerSemantics.OVERWRITE, + mock_datahub_graph_instance, ) assert len(output) == 2 @@ -3417,10 +3511,12 @@ def test_pattern_dataset_schema_tags_transformation_overwrite( def test_pattern_dataset_schema_tags_transformation_patch( - mock_time, mock_datahub_graph_instance + mock_time, + mock_datahub_graph_instance, ): output = run_pattern_dataset_schema_tags_transformation_semantics( - TransformerSemantics.PATCH, mock_datahub_graph_instance + TransformerSemantics.PATCH, + mock_datahub_graph_instance, ) assert len(output) == 2 @@ -3455,15 +3551,18 @@ def test_simple_dataset_data_product_transformation(mock_time): { "dataset_to_data_product_urns": { 
builder.make_dataset_urn( - "bigquery", "example1" + "bigquery", + "example1", ): "urn:li:dataProduct:first", builder.make_dataset_urn( - "bigquery", "example2" + "bigquery", + "example2", ): "urn:li:dataProduct:second", builder.make_dataset_urn( - "bigquery", "example3" + "bigquery", + "example3", ): "urn:li:dataProduct:first", - } + }, }, PipelineContext(run_id="test-dataproduct"), ) @@ -3474,18 +3573,18 @@ def test_simple_dataset_data_product_transformation(mock_time): RecordEnvelope(input, metadata={}) for input in [ make_generic_dataset( - entity_urn=builder.make_dataset_urn("bigquery", "example1") + entity_urn=builder.make_dataset_urn("bigquery", "example1"), ), make_generic_dataset( - entity_urn=builder.make_dataset_urn("bigquery", "example2") + entity_urn=builder.make_dataset_urn("bigquery", "example2"), ), make_generic_dataset( - entity_urn=builder.make_dataset_urn("bigquery", "example3") + entity_urn=builder.make_dataset_urn("bigquery", "example3"), ), EndOfStream(), ] - ] - ) + ], + ), ) assert len(outputs) == 6 @@ -3495,7 +3594,7 @@ def test_simple_dataset_data_product_transformation(mock_time): assert outputs[3].record.aspectName == "dataProductProperties" first_data_product_aspect = json.loads( - outputs[3].record.aspect.value.decode("utf-8") + outputs[3].record.aspect.value.decode("utf-8"), ) assert [item["value"]["destinationUrn"] for item in first_data_product_aspect] == [ builder.make_dataset_urn("bigquery", "example1"), @@ -3503,10 +3602,10 @@ def test_simple_dataset_data_product_transformation(mock_time): ] second_data_product_aspect = json.loads( - outputs[4].record.aspect.value.decode("utf-8") + outputs[4].record.aspect.value.decode("utf-8"), ) assert [item["value"]["destinationUrn"] for item in second_data_product_aspect] == [ - builder.make_dataset_urn("bigquery", "example2") + builder.make_dataset_urn("bigquery", "example2"), ] assert isinstance(outputs[5].record, EndOfStream) @@ -3519,7 +3618,7 @@ def test_pattern_dataset_data_product_transformation(mock_time): "rules": { ".*example1.*": "urn:li:dataProduct:first", ".*": "urn:li:dataProduct:second", - } + }, }, }, PipelineContext(run_id="test-dataproducts"), @@ -3531,18 +3630,18 @@ def test_pattern_dataset_data_product_transformation(mock_time): RecordEnvelope(input, metadata={}) for input in [ make_generic_dataset( - entity_urn=builder.make_dataset_urn("bigquery", "example1") + entity_urn=builder.make_dataset_urn("bigquery", "example1"), ), make_generic_dataset( - entity_urn=builder.make_dataset_urn("bigquery", "example2") + entity_urn=builder.make_dataset_urn("bigquery", "example2"), ), make_generic_dataset( - entity_urn=builder.make_dataset_urn("bigquery", "example3") + entity_urn=builder.make_dataset_urn("bigquery", "example3"), ), EndOfStream(), ] - ] - ) + ], + ), ) assert len(outputs) == 6 @@ -3552,14 +3651,14 @@ def test_pattern_dataset_data_product_transformation(mock_time): assert outputs[3].record.aspectName == "dataProductProperties" first_data_product_aspect = json.loads( - outputs[3].record.aspect.value.decode("utf-8") + outputs[3].record.aspect.value.decode("utf-8"), ) assert [item["value"]["destinationUrn"] for item in first_data_product_aspect] == [ - builder.make_dataset_urn("bigquery", "example1") + builder.make_dataset_urn("bigquery", "example1"), ] second_data_product_aspect = json.loads( - outputs[4].record.aspect.value.decode("utf-8") + outputs[4].record.aspect.value.decode("utf-8"), ) assert [item["value"]["destinationUrn"] for item in second_data_product_aspect] == [ 
builder.make_dataset_urn("bigquery", "example2"), @@ -3571,7 +3670,7 @@ def test_pattern_dataset_data_product_transformation(mock_time): def dummy_data_product_resolver_method(dataset_urn): dataset_to_data_product_map = { - builder.make_dataset_urn("bigquery", "example1"): "urn:li:dataProduct:first" + builder.make_dataset_urn("bigquery", "example1"): "urn:li:dataProduct:first", } return dataset_to_data_product_map.get(dataset_urn) @@ -3579,7 +3678,7 @@ def dummy_data_product_resolver_method(dataset_urn): def test_add_dataset_data_product_transformation(): transformer = AddDatasetDataProduct.create( { - "get_data_product_to_add": "tests.unit.test_transform_dataset.dummy_data_product_resolver_method" + "get_data_product_to_add": "tests.unit.test_transform_dataset.dummy_data_product_resolver_method", }, PipelineContext(run_id="test-dataproduct"), ) @@ -3588,18 +3687,18 @@ def test_add_dataset_data_product_transformation(): [ RecordEnvelope(input, metadata={}) for input in [make_generic_dataset(), EndOfStream()] - ] - ) + ], + ), ) # Check new dataproduct entity should be there assert outputs[1].record.entityUrn == "urn:li:dataProduct:first" assert outputs[1].record.aspectName == "dataProductProperties" first_data_product_aspect = json.loads( - outputs[1].record.aspect.value.decode("utf-8") + outputs[1].record.aspect.value.decode("utf-8"), ) assert [item["value"]["destinationUrn"] for item in first_data_product_aspect] == [ - builder.make_dataset_urn("bigquery", "example1") + builder.make_dataset_urn("bigquery", "example1"), ] @@ -3615,7 +3714,7 @@ def fake_ownership_class(entity_urn: str) -> models.OwnershipClass: owners=[ models.OwnerClass(owner=owner, type=models.OwnershipTypeClass.DATAOWNER) for owner in in_owners - ] + ], ) in_pipeline_context.graph.get_ownership = fake_ownership_class # type: ignore @@ -3626,7 +3725,7 @@ def fake_ownership_class(entity_urn: str) -> models.OwnershipClass: owners=[ models.OwnerClass(owner=owner, type=models.OwnershipTypeClass.DATAOWNER) for owner in in_owners - ] + ], ), config={"pattern_for_cleanup": config}, pipeline_context=in_pipeline_context, @@ -3660,7 +3759,7 @@ def test_clean_owner_urn_transformation_remove_fixed_string( in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # remove 'ABCDEF:' @@ -3677,7 +3776,7 @@ def test_clean_owner_urn_transformation_remove_fixed_string( expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, expected_owner_urns) @@ -3701,7 +3800,7 @@ def test_clean_owner_urn_transformation_remove_multiple_values( in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # remove multiple values @@ -3718,7 +3817,7 @@ def test_clean_owner_urn_transformation_remove_multiple_values( expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, 
expected_owner_urns) @@ -3742,7 +3841,7 @@ def test_clean_owner_urn_transformation_remove_values_using_regex( in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # remove words after `_` using RegEx i.e. `id`, `test` @@ -3759,7 +3858,7 @@ def test_clean_owner_urn_transformation_remove_values_using_regex( expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, expected_owner_urns) @@ -3781,7 +3880,7 @@ def test_clean_owner_urn_transformation_remove_digits(mock_datahub_graph_instanc in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # remove digits @@ -3798,7 +3897,7 @@ def test_clean_owner_urn_transformation_remove_digits(mock_datahub_graph_instanc expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, expected_owner_urns) @@ -3820,7 +3919,7 @@ def test_clean_owner_urn_transformation_remove_pattern(mock_datahub_graph_instan in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # remove `example.*` @@ -3837,7 +3936,7 @@ def test_clean_owner_urn_transformation_remove_pattern(mock_datahub_graph_instan expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, expected_owner_urns) @@ -3862,7 +3961,7 @@ def test_clean_owner_urn_transformation_remove_word_in_capital_letters( in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # if string between `:` and `@` is in CAPITAL then remove it @@ -3880,7 +3979,7 @@ def test_clean_owner_urn_transformation_remove_word_in_capital_letters( expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, expected_owner_urns) @@ -3904,7 +4003,7 @@ def test_clean_owner_urn_transformation_remove_pattern_with_alphanumeric_value( in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # remove any pattern having `id` followed by any digits @@ -3921,7 +4020,7 @@ def test_clean_owner_urn_transformation_remove_pattern_with_alphanumeric_value( 
expected_owner_urns: List[str] = [] for user in expected_user_emails: expected_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) _test_clean_owner_urns(pipeline_context, in_owner_urns, config, expected_owner_urns) @@ -3945,7 +4044,7 @@ def test_clean_owner_urn_transformation_should_not_remove_system_identifier( in_owner_urns: List[str] = [] for user in user_emails: in_owner_urns.append( - builder.make_owner_urn(user, owner_type=builder.OwnerType.USER) + builder.make_owner_urn(user, owner_type=builder.OwnerType.USER), ) # should not remove system identifier @@ -3958,7 +4057,7 @@ def test_replace_external_url_word_replace( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_replace_external_url" + run_id="test_replace_external_url", ) pipeline_context.graph = mock_datahub_graph_instance @@ -3985,7 +4084,7 @@ def test_replace_external_regex_replace_1( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_replace_external_url" + run_id="test_replace_external_url", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4012,7 +4111,7 @@ def test_replace_external_regex_replace_2( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_replace_external_url" + run_id="test_replace_external_url", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4039,7 +4138,7 @@ def test_pattern_cleanup_usage_statistics_user_1( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_pattern_cleanup_usage_statistics_user" + run_id="test_pattern_cleanup_usage_statistics_user", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4093,7 +4192,7 @@ def test_pattern_cleanup_usage_statistics_user_2( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_pattern_cleanup_usage_statistics_user" + run_id="test_pattern_cleanup_usage_statistics_user", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4147,7 +4246,7 @@ def test_pattern_cleanup_usage_statistics_user_3( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_pattern_cleanup_usage_statistics_user" + run_id="test_pattern_cleanup_usage_statistics_user", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4303,7 +4402,7 @@ def test_domain_mapping_based__r_on_tags_with_multiple_tags( # Return fake aspect to simulate server behaviour def fake_get_tags(entity_urn: str) -> models.GlobalTagsClass: return models.GlobalTagsClass( - tags=[TagAssociationClass(tag=tag_one), TagAssociationClass(tag=tag_two)] + tags=[TagAssociationClass(tag=tag_one), TagAssociationClass(tag=tag_two)], ) # Return fake aspect to simulate server behaviour @@ -4401,7 +4500,7 @@ def fake_get_tags(entity_urn: str) -> models.GlobalTagsClass: tags=[ TagAssociationClass(tag=builder.make_tag_urn("example1")), TagAssociationClass(tag=builder.make_tag_urn("example2")), - ] + ], ) # fake the server response @@ -4409,14 +4508,14 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: return models.SchemaMetadataClass( schemaName="customer", # not used platform=builder.make_data_platform_urn( - "hive" + "hive", ), # important <- platform must be an urn version=0, # when the source system has a notion of versioning of schemas, insert this in, otherwise leave as 0 hash="", # when the 
source system has a notion of unique schemas identified via hash, include a hash, else leave it as empty string platformSchema=models.OtherSchemaClass( - rawSchema="__insert raw schema here__" + rawSchema="__insert raw schema here__", ), fields=[ models.SchemaFieldClass( @@ -4424,15 +4523,15 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: globalTags=models.GlobalTagsClass( tags=[ models.TagAssociationClass( - tag=builder.make_tag_urn("example2") - ) + tag=builder.make_tag_urn("example2"), + ), ], ), glossaryTerms=models.GlossaryTermsClass( terms=[ models.GlossaryTermAssociationClass( - urn=builder.make_term_urn("pii") - ) + urn=builder.make_term_urn("pii"), + ), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -4445,8 +4544,8 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: glossaryTerms=models.GlossaryTermsClass( terms=[ models.GlossaryTermAssociationClass( - urn=builder.make_term_urn("pii") - ) + urn=builder.make_term_urn("pii"), + ), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -4470,7 +4569,7 @@ def fake_schema_metadata(entity_urn: str) -> models.SchemaMetadataClass: transformer_type=TagsToTermMapper, aspect=models.GlossaryTermsClass( terms=[ - models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")) + models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -4499,7 +4598,7 @@ def fake_get_tags_no_match(entity_urn: str) -> models.GlobalTagsClass: tags=[ TagAssociationClass(tag=builder.make_tag_urn("nonMatchingTag1")), TagAssociationClass(tag=builder.make_tag_urn("nonMatchingTag2")), - ] + ], ) pipeline_context = PipelineContext(run_id="transformer_pipe_line") @@ -4514,7 +4613,7 @@ def fake_get_tags_no_match(entity_urn: str) -> models.GlobalTagsClass: transformer_type=TagsToTermMapper, aspect=models.GlossaryTermsClass( terms=[ - models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")) + models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -4545,7 +4644,7 @@ def fake_get_no_tags(entity_urn: str) -> models.GlobalTagsClass: transformer_type=TagsToTermMapper, aspect=models.GlossaryTermsClass( terms=[ - models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")) + models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -4566,12 +4665,12 @@ def fake_get_partial_match_tags(entity_urn: str) -> models.GlobalTagsClass: return models.GlobalTagsClass( tags=[ TagAssociationClass( - tag=builder.make_tag_urn("example1") + tag=builder.make_tag_urn("example1"), ), # Should match TagAssociationClass( - tag=builder.make_tag_urn("nonMatchingTag") + tag=builder.make_tag_urn("nonMatchingTag"), ), # No match - ] + ], ) pipeline_context = PipelineContext(run_id="transformer_pipe_line") @@ -4585,7 +4684,7 @@ def fake_get_partial_match_tags(entity_urn: str) -> models.GlobalTagsClass: transformer_type=TagsToTermMapper, aspect=models.GlossaryTermsClass( terms=[ - models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")) + models.GlossaryTermAssociationClass(urn=builder.make_term_urn("pii")), ], auditStamp=models.AuditStampClass._construct_with_defaults(), ), @@ -4605,7 +4704,7 @@ def test_replace_external_url_container_word_replace( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = 
PipelineContext( - run_id="test_replace_external_url_container" + run_id="test_replace_external_url_container", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4633,7 +4732,7 @@ def test_replace_external_regex_container_replace_1( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_replace_external_url_container" + run_id="test_replace_external_url_container", ) pipeline_context.graph = mock_datahub_graph_instance @@ -4661,7 +4760,7 @@ def test_replace_external_regex_container_replace_2( mock_datahub_graph_instance, ): pipeline_context: PipelineContext = PipelineContext( - run_id="test_replace_external_url_container" + run_id="test_replace_external_url_container", ) pipeline_context.graph = mock_datahub_graph_instance diff --git a/metadata-ingestion/tests/unit/test_unity_catalog_config.py b/metadata-ingestion/tests/unit/test_unity_catalog_config.py index ba554e966669ec..936fd99f3dee2e 100644 --- a/metadata-ingestion/tests/unit/test_unity_catalog_config.py +++ b/metadata-ingestion/tests/unit/test_unity_catalog_config.py @@ -18,12 +18,13 @@ def test_within_thirty_days(): "include_usage_statistics": True, "include_hive_metastore": False, "start_time": FROZEN_TIME - timedelta(days=30), - } + }, ) assert config.start_time == FROZEN_TIME - timedelta(days=30) with pytest.raises( - ValueError, match="Query history is only maintained for 30 days." + ValueError, + match="Query history is only maintained for 30 days.", ): UnityCatalogSourceConfig.parse_obj( { @@ -31,7 +32,7 @@ def test_within_thirty_days(): "workspace_url": "https://workspace_url", "include_usage_statistics": True, "start_time": FROZEN_TIME - timedelta(days=31), - } + }, ) @@ -46,7 +47,7 @@ def test_profiling_requires_warehouses_id(): "method": "ge", "warehouse_id": "my_warehouse_id", }, - } + }, ) assert config.profiling.enabled is True @@ -56,7 +57,7 @@ def test_profiling_requires_warehouses_id(): "workspace_url": "https://workspace_url", "include_hive_metastore": False, "profiling": {"enabled": False, "method": "ge"}, - } + }, ) assert config.profiling.enabled is False @@ -66,7 +67,7 @@ def test_profiling_requires_warehouses_id(): "token": "token", "include_hive_metastore": False, "workspace_url": "workspace_url", - } + }, ) @@ -78,7 +79,7 @@ def test_workspace_url_should_start_with_https(): "token": "token", "workspace_url": "workspace_url", "profiling": {"enabled": True}, - } + }, ) @@ -92,7 +93,7 @@ def test_global_warehouse_id_is_set_from_profiling(): "enabled": True, "warehouse_id": "my_warehouse_id", }, - } + }, ) assert config.profiling.warehouse_id == "my_warehouse_id" assert config.warehouse_id == "my_warehouse_id" @@ -113,7 +114,7 @@ def test_set_different_warehouse_id_from_profiling(): "enabled": True, "warehouse_id": "my_warehouse_id", }, - } + }, ) @@ -127,7 +128,7 @@ def test_warehouse_id_must_be_set_if_include_hive_metastore_is_true(): "token": "token", "workspace_url": "https://XXXXXXXXXXXXXXXXXXXXX", "include_hive_metastore": True, - } + }, ) @@ -152,6 +153,6 @@ def test_set_profiling_warehouse_id_from_global(): "method": "ge", "enabled": True, }, - } + }, ) assert config.profiling.warehouse_id == "my_global_warehouse_id" diff --git a/metadata-ingestion/tests/unit/test_usage_common.py b/metadata-ingestion/tests/unit/test_usage_common.py index bd6d194835dd96..d13a9d036f6b30 100644 --- a/metadata-ingestion/tests/unit/test_usage_common.py +++ b/metadata-ingestion/tests/unit/test_usage_common.py @@ -314,7 +314,9 @@ def 
test_make_usage_workunit_include_top_n_queries(): @freeze_time("2023-01-01 00:00:00") def test_convert_usage_aggregation_class(): urn = make_dataset_urn_with_platform_instance( - "platform", "test_db.test_schema.test_table", None + "platform", + "test_db.test_schema.test_table", + None, ) usage_aggregation = UsageAggregationClass( bucket=int(time.time() * 1000), @@ -333,7 +335,7 @@ def test_convert_usage_aggregation_class(): ), ) assert convert_usage_aggregation_class( - usage_aggregation + usage_aggregation, ) == MetadataChangeProposalWrapper( entityUrn=urn, aspect=DatasetUsageStatisticsClass( @@ -344,7 +346,9 @@ def test_convert_usage_aggregation_class(): topSqlQueries=["SELECT * FROM my_table", "SELECT col from a.b.c"], userCounts=[ DatasetUserUsageCountsClass( - user="abc", count=3, userEmail="abc@acryl.io" + user="abc", + count=3, + userEmail="abc@acryl.io", ), DatasetUserUsageCountsClass(user="def", count=1), ], @@ -367,7 +371,7 @@ def test_convert_usage_aggregation_class(): metrics=UsageAggregationMetricsClass(), ) assert convert_usage_aggregation_class( - empty_usage_aggregation + empty_usage_aggregation, ) == MetadataChangeProposalWrapper( entityUrn=empty_urn, aspect=DatasetUsageStatisticsClass( diff --git a/metadata-ingestion/tests/unit/test_vertica_source.py b/metadata-ingestion/tests/unit/test_vertica_source.py index de888ddb559242..dd545aaaf9fee5 100644 --- a/metadata-ingestion/tests/unit/test_vertica_source.py +++ b/metadata-ingestion/tests/unit/test_vertica_source.py @@ -8,7 +8,7 @@ def test_vertica_uri_https(): "password": "password", "host_port": "host:5433", "database": "db", - } + }, ) assert ( config.get_sql_alchemy_url() diff --git a/metadata-ingestion/tests/unit/urns/test_data_job_urn.py b/metadata-ingestion/tests/unit/urns/test_data_job_urn.py index 484e5a474c0cd2..0895c1c95b48fc 100644 --- a/metadata-ingestion/tests/unit/urns/test_data_job_urn.py +++ b/metadata-ingestion/tests/unit/urns/test_data_job_urn.py @@ -14,10 +14,11 @@ def test_parse_urn(self) -> None: ) data_job_urn = DataJobUrn.create_from_string(data_job_urn_str) assert data_job_urn.get_data_flow_urn() == DataFlowUrn.create_from_string( - "urn:li:dataFlow:(airflow,flow_id,prod)" + "urn:li:dataFlow:(airflow,flow_id,prod)", ) assert data_job_urn.get_job_id() == "job_id" assert data_job_urn.__str__() == data_job_urn_str assert data_job_urn == DataJobUrn( - "urn:li:dataFlow:(airflow,flow_id,prod)", "job_id" + "urn:li:dataFlow:(airflow,flow_id,prod)", + "job_id", ) diff --git a/metadata-ingestion/tests/unit/urns/test_data_process_instance_urn.py b/metadata-ingestion/tests/unit/urns/test_data_process_instance_urn.py index f9087b19b13c32..83441c52a682ed 100644 --- a/metadata-ingestion/tests/unit/urns/test_data_process_instance_urn.py +++ b/metadata-ingestion/tests/unit/urns/test_data_process_instance_urn.py @@ -10,7 +10,7 @@ class TestDataProcessInstanceUrn(unittest.TestCase): def test_parse_urn(self) -> None: dataprocessinstance_urn_str = "urn:li:dataProcessInstance:abc" dataprocessinstance_urn = DataProcessInstanceUrn.create_from_string( - dataprocessinstance_urn_str + dataprocessinstance_urn_str, ) assert dataprocessinstance_urn.get_type() == DataProcessInstanceUrn.ENTITY_TYPE diff --git a/metadata-ingestion/tests/unit/urns/test_urn.py b/metadata-ingestion/tests/unit/urns/test_urn.py index 8490364326d940..d273bef85ae65f 100644 --- a/metadata-ingestion/tests/unit/urns/test_urn.py +++ b/metadata-ingestion/tests/unit/urns/test_urn.py @@ -41,7 +41,7 @@ def test_parse_urn() -> None: def test_url_encode_urn() -> 
None: urn_with_slash: Urn = Urn.create_from_string( - "urn:li:dataset:(urn:li:dataPlatform:abc,def/ghi,prod)" + "urn:li:dataset:(urn:li:dataPlatform:abc,def/ghi,prod)", ) assert ( Urn.url_encode(str(urn_with_slash)) diff --git a/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py b/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py index 5b320b8a232544..061f3111f66bad 100644 --- a/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py +++ b/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py @@ -11,7 +11,9 @@ def task(i): assert { res.result() for res in BackpressureAwareExecutor.map( - task, ((i,) for i in range(10)), max_workers=2 + task, + ((i,) for i in range(10)), + max_workers=2, ) } == set(range(10)) @@ -32,7 +34,10 @@ def task(x, y): with PerfTimer() as timer: results = BackpressureAwareExecutor.map( - task, args_list, max_workers=2, max_pending=4 + task, + args_list, + max_workers=2, + max_pending=4, ) assert timer.elapsed_seconds() < task_duration diff --git a/metadata-ingestion/tests/unit/utilities/test_cli_logging.py b/metadata-ingestion/tests/unit/utilities/test_cli_logging.py index aa22a3b5e7ceba..07b0ec5abb9f4a 100644 --- a/metadata-ingestion/tests/unit/utilities/test_cli_logging.py +++ b/metadata-ingestion/tests/unit/utilities/test_cli_logging.py @@ -46,7 +46,8 @@ def test_cli_logging(tmp_path): runner = CliRunner() result = runner.invoke( - datahub, ["--debug", "--log-file", str(log_file), "my-logging-fn"] + datahub, + ["--debug", "--log-file", str(log_file), "my-logging-fn"], ) assert result.exit_code == 0 diff --git a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py index 7e1627151c6ebf..fc8d30bab3f4a2 100644 --- a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py +++ b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py @@ -291,7 +291,8 @@ def test_custom_column(cache_max_size: int, use_sqlite_on_conflict: bool) -> Non # Test param binding. assert ( cache.sql_query( - f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", params=(50,) + f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", + params=(50,), )[0][0] == 31 ) @@ -346,7 +347,7 @@ def test_shared_connection(use_sqlite_on_conflict: bool) -> None: # Test advanced SQL queries and sql_query_iterator. 
iterator = cache2.sql_query_iterator( - f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y" + f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y", ) assert isinstance(iterator, sqlite3.Cursor) assert [tuple(r) for r in iterator] == [("a", 15), ("b", 11)] diff --git a/metadata-ingestion/tests/unit/utilities/test_hive_schema_to_avro.py b/metadata-ingestion/tests/unit/utilities/test_hive_schema_to_avro.py index d1d47e2d8593ae..58fd28bc58be26 100644 --- a/metadata-ingestion/tests/unit/utilities/test_hive_schema_to_avro.py +++ b/metadata-ingestion/tests/unit/utilities/test_hive_schema_to_avro.py @@ -21,7 +21,8 @@ def test_get_avro_schema_for_struct_hive_column(): def test_get_avro_schema_for_struct_hive_with_duplicate_column(): schema_fields = get_schema_fields_for_hive_column( - "test", "struct" + "test", + "struct", ) assert schema_fields[0].type.type == RecordTypeClass() # Len will be the struct + 2 key there which should remain after the deduplication @@ -39,7 +40,8 @@ def test_get_avro_schema_for_struct_hive_with_duplicate_column2(): def test_get_avro_schema_for_null_type_hive_column(): schema_fields = get_schema_fields_for_hive_column( - hive_column_name="test", hive_column_type="unknown" + hive_column_name="test", + hive_column_type="unknown", ) assert schema_fields[0].type.type == NullTypeClass() assert len(schema_fields) == 1 diff --git a/metadata-ingestion/tests/unit/utilities/test_lossy_collections.py b/metadata-ingestion/tests/unit/utilities/test_lossy_collections.py index e137d671e95d71..7c6581a8e486c2 100644 --- a/metadata-ingestion/tests/unit/utilities/test_lossy_collections.py +++ b/metadata-ingestion/tests/unit/utilities/test_lossy_collections.py @@ -35,7 +35,7 @@ def test_lossyset_sampling(length, sampling): assert lossy_set.sampled is sampling if sampling: assert f"... sampled with at most {length - 10} elements missing" in str( - lossy_set + lossy_set, ) else: assert "sampled" not in str(lossy_set) @@ -48,7 +48,8 @@ def test_lossyset_sampling(length, sampling): @pytest.mark.parametrize( - "length, sampling, sub_length", [(4, False, 4), (10, False, 14), (100, True, 1000)] + "length, sampling, sub_length", + [(4, False, 4), (10, False, 14), (100, True, 1000)], ) def test_lossydict_sampling(length, sampling, sub_length): lossy_dict: LossyDict[int, LossyList[str]] = LossyDict() diff --git a/metadata-ingestion/tests/unit/utilities/test_partition_executor.py b/metadata-ingestion/tests/unit/utilities/test_partition_executor.py index 89e95d185e8028..699581dd1e41eb 100644 --- a/metadata-ingestion/tests/unit/utilities/test_partition_executor.py +++ b/metadata-ingestion/tests/unit/utilities/test_partition_executor.py @@ -64,7 +64,8 @@ def task(id: str) -> str: return id with PartitionExecutor( - max_workers=5, max_pending=10 + max_workers=5, + max_pending=10, ) as executor, PerfTimer() as timer: # The first 15 submits should be non-blocking. for i in range(15): @@ -193,6 +194,9 @@ def process_batch(batch): def test_empty_batch_partition_executor(): # We want to test that even if no submit() calls are made, cleanup works fine. 
     with BatchPartitionExecutor(
-        max_workers=5, max_pending=20, process_batch=lambda batch: None, max_per_batch=2
+        max_workers=5,
+        max_pending=20,
+        process_batch=lambda batch: None,
+        max_per_batch=2,
     ) as executor:
         assert executor is not None
diff --git a/metadata-ingestion/tests/unit/utilities/test_search_utils.py b/metadata-ingestion/tests/unit/utilities/test_search_utils.py
index 6fa2e46c7f20e8..938270e1923a21 100644
--- a/metadata-ingestion/tests/unit/utilities/test_search_utils.py
+++ b/metadata-ingestion/tests/unit/utilities/test_search_utils.py
@@ -41,7 +41,7 @@ def test_simple_or_filters():

 def test_simple_field_match():
     query: ElasticDocumentQuery = ElasticDocumentQuery.create_from(
-        ("field1", "value1:1")
+        ("field1", "value1:1"),
     )

     assert query.build() == 'field1:"value1\\:1"'
diff --git a/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py b/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py
index b080819cea95be..d8532ee53f98be 100644
--- a/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py
+++ b/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py
@@ -23,7 +23,9 @@ def test_get_avro_schema_for_sqlalchemy_column():
     inspector_magic_mock.dialect = DefaultDialect()

     schema_fields = get_schema_fields_for_sqlalchemy_column(
-        column_name="test", column_type=types.INTEGER(), inspector=inspector_magic_mock
+        column_name="test",
+        column_type=types.INTEGER(),
+        inspector=inspector_magic_mock,
     )
     assert len(schema_fields) == 1
     assert schema_fields[0].fieldPath == "[version=2.0].[type=int].test"
@@ -77,7 +79,8 @@ def test_get_avro_schema_for_sqlalchemy_map_column():
         == "[version=2.0].[type=struct].[type=map].[type=boolean].test"
     )
     assert schema_fields[0].type.type == MapTypeClass(
-        keyType="string", valueType="boolean"
+        keyType="string",
+        valueType="boolean",
     )
     assert schema_fields[0].nativeDataType == "MapType(String(), BOOLEAN())"
@@ -112,7 +115,9 @@ def test_get_avro_schema_for_sqlalchemy_unknown_column():
     inspector_magic_mock.dialect = DefaultDialect()

     schema_fields = get_schema_fields_for_sqlalchemy_column(
-        "invalid", "test", inspector=inspector_magic_mock
+        "invalid",
+        "test",
+        inspector=inspector_magic_mock,
     )
     assert len(schema_fields) == 1
     assert schema_fields[0].type.type == NullTypeClass()
diff --git a/metadata-ingestion/tests/unit/utilities/test_threaded_iterator_executor.py b/metadata-ingestion/tests/unit/utilities/test_threaded_iterator_executor.py
index fb7e2266e1c9d3..5fc4802079734f 100644
--- a/metadata-ingestion/tests/unit/utilities/test_threaded_iterator_executor.py
+++ b/metadata-ingestion/tests/unit/utilities/test_threaded_iterator_executor.py
@@ -9,6 +9,8 @@ def table_of(i):
     assert {
         res
         for res in ThreadedIteratorExecutor.process(
-            table_of, [(i,) for i in range(1, 30)], max_workers=2
+            table_of,
+            [(i,) for i in range(1, 30)],
+            max_workers=2,
         )
     } == {x for i in range(1, 30) for x in table_of(i)}
diff --git a/metadata-ingestion/tests/unit/utilities/test_unified_diff.py b/metadata-ingestion/tests/unit/utilities/test_unified_diff.py
index 05277ec3fa0abb..5eedb45cb34551 100644
--- a/metadata-ingestion/tests/unit/utilities/test_unified_diff.py
+++ b/metadata-ingestion/tests/unit/utilities/test_unified_diff.py
@@ -78,10 +78,15 @@ def test_apply_hunk_success():
 def test_apply_hunk_mismatch():
     result_lines = ["Line 1", "Line 2", "Line X"]
     hunk = Hunk(
-        2, 2, 2, 2, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
+        2,
+        2,
+        2,
+        2,
+        [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")],
     )
     with pytest.raises(
-        DiffApplyError, match="Removing line that doesn't exactly match"
+        DiffApplyError,
+        match="Removing line that doesn't exactly match",
     ):
         apply_hunk(result_lines, hunk, 0)
@@ -96,7 +101,11 @@ def test_apply_hunk_context_mismatch():
 def test_apply_hunk_invalid_prefix():
     result_lines = ["Line 1", "Line 2", "Line 3"]
     hunk = Hunk(
-        2, 2, 2, 2, [(" ", "Line 2"), ("*", "Line 3"), ("+", "Line 3 modified")]
+        2,
+        2,
+        2,
+        2,
+        [(" ", "Line 2"), ("*", "Line 3"), ("+", "Line 3 modified")],
     )
     with pytest.raises(DiffApplyError, match="Invalid line prefix"):
         apply_hunk(result_lines, hunk, 0)
@@ -105,10 +114,15 @@ def test_apply_hunk_end_of_file():
     result_lines = ["Line 1", "Line 2"]
     hunk = Hunk(
-        2, 2, 2, 3, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
+        2,
+        2,
+        2,
+        3,
+        [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")],
     )
     with pytest.raises(
-        DiffApplyError, match="Found context or deletions after end of file"
+        DiffApplyError,
+        match="Found context or deletions after end of file",
     ):
         apply_hunk(result_lines, hunk, 0)
@@ -116,10 +130,15 @@ def test_apply_hunk_context_beyond_end_of_file():
     result_lines = ["Line 1", "Line 3"]
     hunk = Hunk(
-        2, 2, 2, 3, [(" ", "Line 1"), ("+", "Line 2"), (" ", "Line 3"), (" ", "Line 4")]
+        2,
+        2,
+        2,
+        3,
+        [(" ", "Line 1"), ("+", "Line 2"), (" ", "Line 3"), (" ", "Line 4")],
     )
     with pytest.raises(
-        DiffApplyError, match="Found context or deletions after end of file"
+        DiffApplyError,
+        match="Found context or deletions after end of file",
     ):
         apply_hunk(result_lines, hunk, 0)
@@ -127,10 +146,15 @@ def test_apply_hunk_remove_non_existent_line():
     result_lines = ["Line 1", "Line 2", "Line 4"]
     hunk = Hunk(
-        2, 2, 2, 3, [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")]
+        2,
+        2,
+        2,
+        3,
+        [(" ", "Line 2"), ("-", "Line 3"), ("+", "Line 3 modified")],
     )
     with pytest.raises(
-        DiffApplyError, match="Removing line that doesn't exactly match"
+        DiffApplyError,
+        match="Removing line that doesn't exactly match",
     ):
         apply_hunk(result_lines, hunk, 0)
@@ -138,7 +162,11 @@ def test_apply_hunk_addition_beyond_end_of_file():
     result_lines = ["Line 1", "Line 2"]
     hunk = Hunk(
-        2, 2, 2, 3, [(" ", "Line 2"), ("+", "Line 3 modified"), ("+", "Line 4")]
+        2,
+        2,
+        2,
+        3,
+        [(" ", "Line 2"), ("+", "Line 3 modified"), ("+", "Line 4")],
     )
     apply_hunk(result_lines, hunk, 0)
     assert result_lines == ["Line 1", "Line 2", "Line 3 modified", "Line 4"]
diff --git a/metadata-ingestion/tests/unit/utilities/test_urn_encoder.py b/metadata-ingestion/tests/unit/utilities/test_urn_encoder.py
index 5f50fef4d97de1..9d14bd042db499 100644
--- a/metadata-ingestion/tests/unit/utilities/test_urn_encoder.py
+++ b/metadata-ingestion/tests/unit/utilities/test_urn_encoder.py
@@ -25,5 +25,6 @@ def test_encode_string_without_reserved_chars_no_change(name):
 )
 def test_encode_string_with_reserved_chars(name):
     assert UrnEncoder.encode_string(name) == name.replace(",", "%2C").replace(
-        "(", "%28"
+        "(",
+        "%28",
     ).replace(")", "%29")
diff --git a/metadata-ingestion/tests/unit/utilities/test_yaml_sync_utils.py b/metadata-ingestion/tests/unit/utilities/test_yaml_sync_utils.py
index 511784bd2cbd86..b0a8bafd385121 100644
--- a/metadata-ingestion/tests/unit/utilities/test_yaml_sync_utils.py
+++ b/metadata-ingestion/tests/unit/utilities/test_yaml_sync_utils.py
@@ -18,7 +18,7 @@ def test_update_yaml_file(tmp_path: pathlib.Path) -> None:
     - foo
     - key1: value1
       key2: value2
-"""
+""",
     )

     # ind=4, bsi=2
@@ -59,7 +59,7 @@ def test_indentation_inference(tmp_path: pathlib.Path) -> None:
     - foo
     - key1: value1
       key2: value2
-"""
+""",
     )

     # ind=2, bsi=0
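The hunks above all apply the same mechanical style change: once a call or literal is wrapped across multiple lines, each argument gets its own line and the final argument keeps a trailing comma. The sketch below is illustrative only; build_config and its parameters are made up for this note and do not appear anywhere in the patch.

def build_config(name: str, retries: int, timeout: int) -> dict:
    # Stand-in helper so the sketch runs on its own; not part of this patch.
    return {"name": name, "retries": retries, "timeout": timeout}

# Before: arguments packed onto a single continuation line, no trailing comma.
config = build_config(
    name="example", retries=3, timeout=30
)

# After: one argument per line, each line ending in a comma (including the last),
# so adding or removing an argument later produces a one-line diff.
config = build_config(
    name="example",
    retries=3,
    timeout=30,
)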