forked from astronomer/2-9-example-dags
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_to_snowflake.py
100 lines (81 loc) · 3.24 KB
/
load_to_snowflake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
## Load to Snowflake
This DAG loads data from the loading location in S3 to Snowflake.
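To use this DAG you need one stage per source in Snowflake, named `<source_name>_stage`
(the `source_name` values come from `include/ingestion_source_config.json`).
For example, for three sources you would need to create these three stages: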
```sql
CREATE STAGE sales_reports_stage
URL = 's3://ce-2-8-examples-bucket/load/sales_reports/'
CREDENTIALS = (AWS_KEY_ID = '<your aws key id>' AWS_SECRET_KEY = '<your aws secret>')
FILE_FORMAT = (TYPE = 'CSV' FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1);
CREATE STAGE customer_feedback_stage
URL = 's3://ce-2-8-examples-bucket/load/customer_feedback/'
CREDENTIALS = (AWS_KEY_ID = '<your aws key id>' AWS_SECRET_KEY = '<your aws secret>')
FILE_FORMAT = (TYPE = 'CSV' FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1);
CREATE STAGE customer_data_stage
URL = 's3://ce-2-8-examples-bucket/load/customer_data/'
CREDENTIALS = (AWS_KEY_ID = '<your aws key id>' AWS_SECRET_KEY = '<your aws secret>')
FILE_FORMAT = (TYPE = 'CSV' FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1);
```
"""
import json
import operator
from datetime import timedelta
from functools import reduce

from airflow.datasets import Dataset
from airflow.decorators import dag
from airflow.models.baseoperator import chain
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
from pendulum import datetime
# Ingestion source definitions. Each entry of config["sources"] is read in
# this file for the keys "dataset_uri", "source_name" and "table_creation_sql".
# NOTE(review): the path is relative to the process working directory (the
# project root in an Astro deployment) — confirm for other deployments.
with open("include/ingestion_source_config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# One Dataset per ingestion source; the DAG below is scheduled on these.
ingestion_datasets = tuple(
    Dataset(source["dataset_uri"]) for source in config["sources"]
)

# Fail fast with a readable message: the schedule expression needs at least
# one Dataset, otherwise reduce() raises an opaque TypeError at parse time.
if not ingestion_datasets:
    raise ValueError(
        "include/ingestion_source_config.json defines no sources; "
        "at least one source is required to schedule load_to_snowflake."
    )

SNOWFLAKE_CONN_ID = "snowflake_de_team"
@dag(
    start_date=datetime(2024, 1, 1),
    # NEW in Airflow 2.9: Schedule on logical expressions involving datasets.
    # OR-ing all ingestion Datasets together runs this DAG whenever any one of
    # them is updated.
    schedule=reduce(operator.or_, ingestion_datasets),
    catchup=False,
    tags=["Conditional Dataset Scheduling", "use-case"],
    default_args={
        "owner": "Piglet",
        "retries": 3,
        # Airflow interprets a bare number as seconds; make that explicit.
        "retry_delay": timedelta(seconds=5),
    },
    description="Load data from S3 to Snowflake",
    doc_md=__doc__,
)
def load_to_snowflake():
    # Builds: create_file_format -> (create table -> COPY INTO) for each
    # configured source. A comment rather than a docstring is used here so the
    # DAG's doc_md stays the module docstring.
    create_file_format = SnowflakeOperator(
        task_id="create_file_format",
        sql="""
        CREATE FILE FORMAT IF NOT EXISTS my_csv_format
        TYPE = 'CSV'
        FIELD_OPTIONALLY_ENCLOSED_BY = '"'
        SKIP_HEADER = 1;
        """,
        snowflake_conn_id=SNOWFLAKE_CONN_ID,
    )

    for source in config["sources"]:
        source_name = source["source_name"]

        # DDL for the target table is supplied per source by the config file.
        create_table_if_not_exists = SnowflakeOperator(
            task_id=f"create_table_if_not_exists_{source_name}",
            sql=source["table_creation_sql"],
            snowflake_conn_id=SNOWFLAKE_CONN_ID,
        )

        # NOTE(review): the COPY repeats the file-format options inline and
        # never references my_csv_format created above; consider
        # FILE_FORMAT = (FORMAT_NAME = 'my_csv_format') or dropping that task.
        load_data = SnowflakeOperator(
            task_id=f"load_{source_name}",
            snowflake_conn_id=SNOWFLAKE_CONN_ID,
            sql=f"""
            COPY INTO {source_name}_table
            FROM @{source_name}_stage
            FILE_FORMAT = (TYPE = 'CSV' FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1);
            """,
            # Marks the Snowflake table as updated for downstream consumers.
            outlets=[Dataset(f"snowflake://{source_name}_table")],
        )

        chain(create_file_format, create_table_if_not_exists, load_data)


load_to_snowflake()