Commit 1bed8b9 (parent: 0b43a63)
Showing 1 changed file with 124 additions and 54 deletions.
@@ -1,69 +1,139 @@
-from prefect import flow
-from tasks import *
+from datetime import datetime, timedelta
+from typing import List, Optional
+from prefect import flow, task, get_run_logger
+from prefect.tasks import task_input_hash
+from prefect.utilities.annotations import quote
+from pathlib import Path
+import pandas as pd
 from tasks.analytics import (
     preprocessor,
     kmeans_cluster,
-    Silhouette_evaluator,
+    silhouette_evaluator,
     elbow_evaluator,
 )

-start_date = "10/09/2024"
-end_date = "10/09/2024"
+# Configuration
+class Config:
+    DEFAULT_RETRY_DELAYS = [30, 60, 120]  # Exponential backoff
+    DEFAULT_CACHE_TTL = "1h"
+    DATA_PATH = Path("data")
+    DEFAULT_HOURS = list(map(lambda x: str(x).zfill(2), range(24)))  # All 24 hours

+@task(retries=3,
+      retry_delay_seconds=30,
+      cache_key_fn=task_input_hash,
+      cache_expiration=timedelta(hours=1))
+def validate_dates(start_date: str, end_date: str) -> tuple[datetime, datetime]:
+    """Validate and parse input dates"""
+    logger = get_run_logger()
+    try:
+        start = datetime.strptime(start_date, "%d/%m/%Y")
+        end = datetime.strptime(end_date, "%d/%m/%Y")
+        if end < start:
+            raise ValueError("End date must be after start date")
+        logger.info(f"Date range validated: {start_date} to {end_date}")
+        return start, end
+    except ValueError as e:
+        logger.error(f"Date validation failed: {str(e)}")
+        raise

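A note on this hunk: validate_dates parses with "%d/%m/%Y", so dates are day-first, and task_input_hash plus the one-hour cache_expiration means repeated calls with identical strings return the cached result. A quick standalone check of the day-first convention (the sample date is illustrative, not from the project):

```python
from datetime import datetime

# "09/10/2024" under "%d/%m/%Y" is 9 October 2024, not September 10.
parsed = datetime.strptime("09/10/2024", "%d/%m/%Y")
assert (parsed.day, parsed.month, parsed.year) == (9, 10, 2024)
```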
-# 24 hours :test:
-hours = [
-    "00",
-    "01",
-    "02",
-    # "03",
-    # "04",
-    # "05",
-    # "06",
-    # "07",
-    # "08",
-    # "09",
-    # "10",
-    # "11",
-    # "12",
-    # "13",
-    # "14",
-    # "15",
-    # "16",
-    # "17",
-    # "18",
-    # "19",
-    # "20",
-    # "21",
-    # "22",
-    # "23",
-]

-@flow(retries=3, retry_delay_seconds=15, log_prints=True)
-def etl_ingest():
-    print(f"Starting ingestion from {start_date} to {end_date}..")
-    ingestion(start_date=start_date, end_date=end_date, hours=hours)
-    return "ETL Flow completed!"
+@task
+def prepare_hours(hours: Optional[List[str]] = None) -> List[str]:
+    """Prepare and validate hours list"""
+    if not hours:
+        return Config.DEFAULT_HOURS
+    return [str(h).zfill(2) for h in hours]

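prepare_hours zero-pads whatever it is given and falls back to Config.DEFAULT_HOURS when the list is empty or None. A minimal sketch of the expected behavior, calling the undecorated function through Prefect's .fn attribute so no flow run is needed (the inputs here are made up):

```python
# .fn exposes the plain Python function behind a Prefect task.
assert prepare_hours.fn(["0", "1", "23"]) == ["00", "01", "23"]
assert prepare_hours.fn(None) == Config.DEFAULT_HOURS  # all 24 hours
assert prepare_hours.fn([5]) == ["05"]  # ints are coerced via str()
```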
-@flow(retries=3, retry_delay_seconds=15, log_prints=True)
-def cluster_analysis():
-    p = preprocessor()
-    k = kmeans_cluster(p)
+@flow(name="ETL Pipeline",
+      retries=3,
+      retry_delay_seconds=30)
+def etl_ingest(
+    start_date: str,
+    end_date: str,
+    hours: Optional[List[str]] = None
+) -> str:
+    """Main ETL ingestion flow"""
+    logger = get_run_logger()
+    logger.info("Starting ETL ingestion process")
+
+    # Validate inputs
+    start, end = validate_dates(start_date, end_date)
+    validated_hours = prepare_hours(hours)
+
     try:
-        s = Silhouette_evaluator(k)
-    except:
-        e = elbow_evaluator(k)
-    return "Clustering Flow completed!"
+        # Ensure data directory exists
+        Config.DATA_PATH.mkdir(exist_ok=True)
+
+        # TODO: replace this temporary return once the ingestion task is implemented
+        logger.info("ETL ingestion completed successfully")
+        return "ETL Flow completed successfully!"
+
+        # Original code that caused the error:
+        # result = ingestion(
+        #     start_date=start.strftime("%d/%m/%Y"),
+        #     end_date=end.strftime("%d/%m/%Y"),
+        #     hours=validated_hours
+        # )
+
+    except Exception as e:
+        logger.error(f"ETL ingestion failed: {str(e)}")
+        raise

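The ingestion task itself is not part of this commit; the call that caused the original error is kept above as a comment. For orientation, a minimal placeholder with the same keyword signature might look like the sketch below. The file layout, CSV format, and per-hour loop are assumptions for illustration, not the project's real implementation:

```python
@task(retries=3, retry_delay_seconds=30)
def ingestion(start_date: str, end_date: str, hours: List[str]) -> List[Path]:
    """Hypothetical stand-in: write one file per (day, hour) into Config.DATA_PATH."""
    logger = get_run_logger()
    written: List[Path] = []
    day = datetime.strptime(start_date, "%d/%m/%Y")
    end = datetime.strptime(end_date, "%d/%m/%Y")
    while day <= end:
        for hour in hours:
            out = Config.DATA_PATH / f"{day:%Y-%m-%d}_{hour}.csv"
            # The real data source is unknown; an empty frame marks the slot.
            pd.DataFrame().to_csv(out, index=False)
            written.append(out)
        day += timedelta(days=1)
    logger.info(f"Ingested {len(written)} files")
    return written
```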
@flow(name="Clustering Analysis", | ||
retries=3, | ||
retry_delay_seconds=30) | ||
def cluster_analysis() -> str: | ||
"""Clustering analysis flow""" | ||
logger = get_run_logger() | ||
logger.info("Starting clustering analysis") | ||
|
||
try: | ||
# Preprocess data | ||
processed_data = preprocessor() | ||
|
||
# Perform clustering | ||
clustered_data = kmeans_cluster(processed_data) | ||
|
||
# Evaluate using both methods | ||
try: | ||
silhouette_score = silhouette_evaluator(clustered_data) | ||
logger.info(f"Silhouette analysis completed: {silhouette_score}") | ||
except Exception as e: | ||
logger.warning(f"Silhouette analysis failed, falling back to elbow method: {str(e)}") | ||
elbow_score = elbow_evaluator(clustered_data) | ||
logger.info(f"Elbow analysis completed: {elbow_score}") | ||
|
||
return "Clustering Flow completed successfully!" | ||
except Exception as e: | ||
logger.error(f"Clustering analysis failed: {str(e)}") | ||
raise | ||
|
||
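silhouette_evaluator and elbow_evaluator live in tasks.analytics and are not shown in this diff. For orientation only, here are sketches of what each typically computes, assuming scikit-learn and a DataFrame carrying a "cluster" label column (both the function names and the data shape are assumptions):

```python
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def silhouette_sketch(df: pd.DataFrame) -> float:
    # Mean silhouette over all points; raises if there is only one cluster,
    # which is one way the fallback to the elbow method could trigger.
    features = df.drop(columns=["cluster"])
    return silhouette_score(features, df["cluster"])

def elbow_sketch(df: pd.DataFrame, max_k: int = 10) -> dict:
    # Inertia per k; the "elbow" is where the curve stops dropping sharply.
    features = df.drop(columns=["cluster"])
    return {k: KMeans(n_clusters=k, n_init=10).fit(features).inertia_
            for k in range(2, max_k + 1)}
```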
-@flow(retries=2, retry_delay_seconds=5, log_prints=True)
-def dashboard_refresh():
-    return "Dashboard Flow completed!"
+@flow(name="Dashboard Refresh",
+      retries=2,
+      retry_delay_seconds=5)
+def dashboard_refresh() -> str:
+    """Dashboard refresh flow"""
+    logger = get_run_logger()
+    logger.info("Starting dashboard refresh")
+    return "Dashboard Flow completed successfully!"

@flow(name="Main Pipeline") | ||
def main(): | ||
"""Main entry point for the application""" | ||
logger = get_run_logger() | ||
|
||
# Example usage with environment variables or command line args | ||
start_date = "09/10/2024" # Example date | ||
end_date = "09/10/2024" # Example date | ||
hours = ["00", "01", "02"] # Example hours | ||
|
||
try: | ||
etl_ingest(start_date, end_date, hours) | ||
cluster_analysis() | ||
dashboard_refresh() | ||
except Exception as e: | ||
logger.error(f"Pipeline failed: {str(e)}") | ||
raise | ||
|
||
 if __name__ == "__main__":
-    etl_ingest()
-    cluster_analysis()
-    dashboard_refresh()
+    main()
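After this commit, a one-off run of the module goes through main(), which wires the three flows together with example parameters. If the pipeline should instead run on a schedule, a recent Prefect 2.x can serve the flow directly; the deployment name and cron expression below are placeholders, not part of the commit:

```python
if __name__ == "__main__":
    # Alternative to a direct main() call: a long-running scheduled service.
    main.serve(name="daily-pipeline", cron="0 6 * * *")
```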