From a9a0766354c70d1fcb3c742dda273a49e741d9a5 Mon Sep 17 00:00:00 2001 From: Daniel Glogowski <167348611+dglogo@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:51:39 -0800 Subject: [PATCH] dglogo/code documentation (#4) * readme * readme update * code docs * code documentation --- README.md | 6 +- services/APIService/main.py | 182 ++++++++++++++++++- services/AgentService/main.py | 77 +++++++- services/AgentService/monologue_flow.py | 132 ++++++++++++-- services/AgentService/monologue_prompts.py | 54 +++++- services/AgentService/podcast_flow.py | 202 +++++++++++++++++++-- services/AgentService/podcast_prompts.py | 63 ++++++- services/AgentService/test_api.py | 19 ++ services/AgentService/test_llmmanager.py | 141 +++++++++++++- shared/setup.py | 19 +- shared/shared/api_types.py | 61 +++++-- shared/shared/connection.py | 66 ++++++- shared/shared/job.py | 81 +++++++++ shared/shared/llmmanager.py | 116 +++++++++++- shared/shared/otel.py | 44 ++++- shared/shared/pdf_types.py | 37 +++- shared/shared/podcast_types.py | 51 ++++++ shared/shared/prompt_tracker.py | 47 ++++- shared/shared/prompt_types.py | 20 ++ shared/shared/storage.py | 114 +++++++++++- tests/test.py | 76 +++++++- tests/test_db.py | 61 ++++++- tests/test_files.py | 22 +++ tests/test_invalid_filetype.py | 33 ++++ tests/test_list.py | 31 +++- 25 files changed, 1641 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index eff2b7f..9e3e0c7 100644 --- a/README.md +++ b/README.md @@ -6,12 +6,10 @@ This NVIDIA AI blueprint shows developers how to build a microservice that trans Screenshot 2024-12-30 at 8 43 43 PM -[mermaid diagram](docs/README.md) - ## Quick Start Guide 1. **Set environment variables** - + ```bash # Create .env file with required variables echo "ELEVENLABS_API_KEY=your_key" > .env @@ -83,7 +81,7 @@ It is easy to swap out different pieces of the stack to optimize GPU usage for a 4. **Enable Tracing** -We expose a Jaeger instance at `http://localhost:16686/` for tracing. This is useful for debugging and monitoring the system. +We expose a Jaeger instance at `http://localhost:16686/` for tracing. This is useful for debugging and monitoring the system ## Contributing diff --git a/services/APIService/main.py b/services/APIService/main.py index 5fe4b9f..c471c48 100644 --- a/services/APIService/main.py +++ b/services/APIService/main.py @@ -1,3 +1,23 @@ +""" +Main FastAPI application module for the AI Research Assistant API Service. + +This module provides the core API endpoints for the PDF-to-Podcast service, handling: +- PDF file uploads and processing +- WebSocket status updates +- Job management and status tracking +- Saved podcast retrieval and management +- Vector database querying +- Service health monitoring + +The service integrates with: +- PDF Service for document processing +- Agent Service for content generation +- TTS Service for audio synthesis +- Redis for caching and pub/sub +- MinIO for file storage +- OpenTelemetry for observability +""" + from fastapi import ( HTTPException, FastAPI, @@ -100,6 +120,19 @@ @app.websocket("/ws/status/{job_id}") async def websocket_endpoint(websocket: WebSocket, job_id: str): + """ + WebSocket endpoint for real-time job status updates. + + Handles client connections and sends status updates for all services processing a job. + Implements a ready-check protocol and maintains connection with periodic pings. 
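+
+    Example (illustrative client sketch; the port and the exact ready-check
+    handshake message are assumptions, not defined in this hunk):
+
+        >>> import asyncio, json, websockets
+        >>> async def watch(job_id: str):
+        ...     uri = f"ws://localhost:8002/ws/status/{job_id}"
+        ...     async with websockets.connect(uri) as ws:
+        ...         while True:
+        ...             print(json.loads(await ws.recv()))
+        >>> asyncio.run(watch("<job_id>"))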
+ + Args: + websocket (WebSocket): The WebSocket connection instance + job_id (str): Unique identifier for the job to track + + Raises: + WebSocketDisconnect: If the client disconnects + """ try: # Accept the WebSocket connection await manager.connect(websocket, job_id) @@ -166,6 +199,20 @@ def process_pdf_task( files_and_types: List[Tuple[bytes, str]], transcription_params: TranscriptionParams, ): + """ + Process PDF files through the conversion pipeline. + + Coordinates the workflow between PDF Service, Agent Service, and TTS Service + to convert PDFs into an audio podcast. + + Args: + job_id (str): Unique identifier for the job + files_and_types (List[Tuple[bytes, str]]): List of tuples containing file content and type (target/context) + transcription_params (TranscriptionParams): Parameters controlling the transcription process + + Raises: + Exception: If any service in the pipeline fails + """ with telemetry.tracer.start_as_current_span("api.process_pdf_task") as span: span.set_attribute("job_id", job_id) try: @@ -305,6 +352,21 @@ async def process_pdf( context_files: Union[UploadFile, List[UploadFile]] = File([]), transcription_params: str = Form(...), ): + """ + Process uploaded PDF files and generate a podcast. + + Args: + background_tasks (BackgroundTasks): FastAPI background tasks handler + target_files (Union[UploadFile, List[UploadFile]]): Primary PDF file(s) to process + context_files (Union[UploadFile, List[UploadFile]], optional): Supporting PDF files + transcription_params (str): JSON string containing transcription parameters + + Returns: + dict: Contains job_id for tracking the processing status + + Raises: + HTTPException: If file validation fails or parameters are invalid + """ with telemetry.tracer.start_as_current_span("api.process_pdf") as span: # Convert single file to list for consistent handling target_files_list = ( @@ -368,7 +430,19 @@ async def process_pdf( # TODO: wire up userId auth here @app.get("/status/{job_id}") async def get_status(job_id: str, userId: str = Query(..., description="KAS User ID")): - """Get aggregated status from all services""" + """ + Get aggregated status from all services for a specific job. + + Args: + job_id (str): Job identifier to check status for + userId (str): User identifier for authorization + + Returns: + dict: Status information from all services + + Raises: + HTTPException: If job is not found + """ with telemetry.tracer.start_as_current_span("api.job.status") as span: span.set_attribute("job_id", job_id) statuses = {} @@ -392,7 +466,19 @@ async def get_status(job_id: str, userId: str = Query(..., description="KAS User @app.get("/output/{job_id}") async def get_output(job_id: str, userId: str = Query(..., description="KAS User ID")): - """Get the final TTS output""" + """ + Get the final TTS output for a completed job. + + Args: + job_id (str): Job identifier to get output for + userId (str): User identifier for authorization + + Returns: + Response: Audio file response with appropriate headers + + Raises: + HTTPException: If result is not found or TTS not completed + """ with telemetry.tracer.start_as_current_span("api.job.output") as span: span.set_attribute("job_id", job_id) @@ -427,7 +513,14 @@ async def get_output(job_id: str, userId: str = Query(..., description="KAS User @app.post("/cleanup") async def cleanup_jobs(): - """Clean up old jobs across all services""" + """ + Clean up old jobs across all services. + + Removes job status and result data from Redis for all services. 
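+
+    Example (illustrative; the port and the exact response shape are
+    assumptions):
+
+        >>> import requests
+        >>> requests.post("http://localhost:8002/cleanup").json()
+        {'removed': 4}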
+ + Returns: + dict: Number of jobs removed + """ removed = 0 for service in ServiceType: pattern = f"status:*:{service}" @@ -443,7 +536,18 @@ async def cleanup_jobs(): async def get_saved_podcasts( userId: str = Query(..., description="KAS User ID", min_length=1), ): - """Get a list of all saved podcasts from storage with their audio data""" + """ + Get a list of all saved podcasts from storage with their audio data. + + Args: + userId (str): User identifier to filter podcasts + + Returns: + Dict[str, List[SavedPodcast]]: List of saved podcasts metadata + + Raises: + HTTPException: If retrieval fails + """ try: with telemetry.tracer.start_as_current_span("api.saved_podcasts") as span: if not userId.strip(): # Check for whitespace-only strings @@ -478,7 +582,19 @@ async def get_saved_podcasts( async def get_saved_podcast_metadata( job_id: str, userId: str = Query(..., description="KAS User ID") ): - """Get a specific saved podcast metadata without audio data""" + """ + Get a specific saved podcast metadata without audio data. + + Args: + job_id (str): Job identifier for the podcast + userId (str): User identifier for authorization + + Returns: + SavedPodcast: Podcast metadata + + Raises: + HTTPException: If podcast not found or retrieval fails + """ try: with telemetry.tracer.start_as_current_span( "api.saved_podcast.metadata" @@ -511,7 +627,19 @@ async def get_saved_podcast_metadata( async def get_saved_podcast( job_id: str, userId: str = Query(..., description="KAS User ID") ): - """Get a specific saved podcast with its audio data""" + """ + Get a specific saved podcast with its audio data. + + Args: + job_id (str): Job identifier for the podcast + userId (str): User identifier for authorization + + Returns: + SavedPodcastWithAudio: Podcast metadata and audio content + + Raises: + HTTPException: If podcast not found or retrieval fails + """ try: with telemetry.tracer.start_as_current_span("api.saved_podcast.audio") as span: span.set_attribute("job_id", job_id) @@ -556,7 +684,19 @@ async def get_saved_podcast( async def get_saved_podcast_transcript( job_id: str, userId: str = Query(..., description="KAS User ID") ): - """Get a specific saved podcast transcript""" + """ + Get a specific saved podcast transcript. + + Args: + job_id (str): Job identifier for the podcast + userId (str): User identifier for authorization + + Returns: + Conversation: Podcast transcript data + + Raises: + HTTPException: If transcript not found or invalid format + """ with telemetry.tracer.start_as_current_span("api.saved_podcast.transcript") as span: try: span.set_attribute("job_id", job_id) @@ -590,7 +730,19 @@ async def get_saved_podcast_transcript( async def get_saved_podcast_agent_workflow( job_id: str, userId: str = Query(..., description="KAS User ID") ): - """Get a specific saved podcast agent workflow""" + """ + Get a specific saved podcast agent workflow history. + + Args: + job_id (str): Job identifier for the podcast + userId (str): User identifier for authorization + + Returns: + PromptTracker: Agent workflow history data + + Raises: + HTTPException: If history not found or retrieval fails + """ with telemetry.tracer.start_as_current_span("api.saved_podcast.history") as span: try: span.set_attribute("job_id", job_id) @@ -618,7 +770,19 @@ async def get_saved_podcast_agent_workflow( async def get_saved_podcast_pdf( job_id: str, userId: str = Query(..., description="KAS User ID") ): - """Get the original PDF file for a specific podcast""" + """ + Get the original PDF file for a specific podcast. 
+ + Args: + job_id (str): Job identifier for the podcast + userId (str): User identifier for authorization + + Returns: + Response: PDF file response with appropriate headers + + Raises: + HTTPException: If PDF not found or retrieval fails + """ with telemetry.tracer.start_as_current_span("api.saved_podcast.pdf") as span: try: span.set_attribute("job_id", job_id) diff --git a/services/AgentService/main.py b/services/AgentService/main.py index e3f5022..2d0cfcb 100644 --- a/services/AgentService/main.py +++ b/services/AgentService/main.py @@ -1,3 +1,10 @@ +""" +Main FastAPI application for the Agent Service. + +This service coordinates the PDF-to-podcast conversion process by managing jobs, +orchestrating LLM calls, and handling both monologue and dialogue podcast generation. +""" + from fastapi import FastAPI, BackgroundTasks, HTTPException from shared.api_types import ( ServiceType, @@ -31,11 +38,14 @@ from shared.prompt_tracker import PromptTracker +# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# Initialize FastAPI app app = FastAPI(debug=True) +# Set up OpenTelemetry instrumentation telemetry = OpenTelemetryInstrumentation() config = OpenTelemetryConfig( service_name="agent-service", @@ -45,14 +55,33 @@ ) telemetry.initialize(config, app) +# Initialize managers job_manager = JobStatusManager(ServiceType.AGENT, telemetry=telemetry) storage_manager = StorageManager(telemetry=telemetry) async def process_transcription(job_id: str, request: TranscriptionRequest): - """Main processing function for transcription requests""" + """ + Main processing function for transcription requests. + + Handles both monologue and dialogue podcast generation workflows by coordinating + multiple steps including PDF summarization, outline generation, and conversation creation. + + Args: + job_id (str): Unique identifier for the transcription job + request (TranscriptionRequest): Contains all parameters for the transcription including: + - PDF metadata + - Voice mapping + - Speaker names + - Duration target + - Processing preferences + + Raises: + Exception: If any step in the process fails, with error details in job status + """ with telemetry.tracer.start_as_current_span("agent.process_transcription") as span: try: + # Initialize LLM manager and prompt tracker llm_manager = LLMManager( api_key=os.getenv("NVIDIA_API_KEY"), telemetry=telemetry, @@ -208,6 +237,19 @@ async def process_transcription(job_id: str, request: TranscriptionRequest): # API Endpoints @app.post("/transcribe", status_code=202) def transcribe(request: TranscriptionRequest, background_tasks: BackgroundTasks): + """ + Endpoint to start a new transcription job. + + Accepts a transcription request and starts an asynchronous job to process it. + The job runs in the background and its status can be checked using the /status endpoint. + + Args: + request (TranscriptionRequest): Contains job parameters and PDF metadata + background_tasks (BackgroundTasks): FastAPI background tasks handler + + Returns: + dict: Contains the job_id for tracking the request + """ with telemetry.tracer.start_as_current_span("agent.transcribe") as span: span.set_attribute("request", request.model_dump(exclude={"markdown"})) job_manager.create_job(request.job_id) @@ -217,6 +259,21 @@ def transcribe(request: TranscriptionRequest, background_tasks: BackgroundTasks) @app.get("/status/{job_id}") def get_status(job_id: str): + """ + Get the current status of a transcription job. 
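+
+    Example (illustrative; the port matches the default AGENT_SERVICE_URL used
+    in test_api.py, and the values shown are hypothetical):
+
+        >>> import requests
+        >>> requests.get("http://localhost:8964/status/<job_id>").json()
+        {'status': 'processing', 'message': 'Creating monologue transcript'}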
+ + Args: + job_id (str): ID of the job to check + + Returns: + dict: Current job status and details containing: + - status: Current job status (PENDING, PROCESSING, COMPLETED, FAILED) + - message: Status message or error details + - progress: Optional progress information + + Raises: + HTTPException: If job is not found + """ with telemetry.tracer.start_as_current_span("agent.get_status") as span: span.set_attribute("job_id", job_id) status = job_manager.get_status(job_id) @@ -228,6 +285,18 @@ def get_status(job_id: str): @app.get("/output/{job_id}") def get_output(job_id: str): + """ + Get the final output of a completed transcription job. + + Args: + job_id (str): ID of the completed job + + Returns: + dict: The generated podcast conversation + + Raises: + HTTPException: If result is not found + """ with telemetry.tracer.start_as_current_span("agent.get_output") as span: span.set_attribute("job_id", job_id) result = job_manager.get_result(job_id) @@ -238,6 +307,12 @@ def get_output(job_id: str): @app.get("/health") def health(): + """ + Simple health check endpoint. + + Returns: + dict: Service health status + """ return { "status": "healthy", } diff --git a/services/AgentService/monologue_flow.py b/services/AgentService/monologue_flow.py index 105a5a9..e00e2ca 100644 --- a/services/AgentService/monologue_flow.py +++ b/services/AgentService/monologue_flow.py @@ -1,21 +1,41 @@ -from shared.api_types import JobStatus, TranscriptionRequest -from shared.podcast_types import Conversation -from shared.pdf_types import PDFMetadata -from shared.llmmanager import LLMManager -from shared.job import JobStatusManager -from typing import List, Dict -import ujson as json -import logging -from shared.prompt_tracker import PromptTracker -from monologue_prompts import FinancialSummaryPrompts -from langchain_core.messages import AIMessage -import asyncio +""" +Monologue flow module for converting PDFs to podcast monologues. + +This module handles the workflow for generating single-speaker podcast content from PDF documents. +It includes functionality for summarizing PDFs, generating outlines, and creating monologue scripts. +""" + +from shared.api_types import JobStatus, TranscriptionRequest # Job status tracking and request types +from shared.podcast_types import Conversation # Podcast conversation data structures +from shared.pdf_types import PDFMetadata # PDF document metadata and content +from shared.llmmanager import LLMManager # LLM interaction management +from shared.job import JobStatusManager # Background job status tracking +from typing import List, Dict # Type hints +import ujson as json # Fast JSON processing +import logging # Logging utilities +from shared.prompt_tracker import PromptTracker # Tracks prompts sent to LLM +from monologue_prompts import FinancialSummaryPrompts # Prompt templates +from langchain_core.messages import AIMessage # LLM message type +import asyncio # Async functionality async def monologue_summarize_pdf( pdf_metadata: PDFMetadata, llm_manager: LLMManager, prompt_tracker: PromptTracker ) -> AIMessage: - """Summarize a single PDF document""" + """ + Summarize a single PDF document using the LLM. 
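+
+    Example (illustrative sketch; assumes an async context with the manager,
+    tracker, and PDF metadata already initialized):
+
+        >>> msg = await monologue_summarize_pdf(pdf_meta, llm_manager, tracker)
+        >>> print(msg.content)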
+ + Args: + pdf_metadata (PDFMetadata): Metadata and content of the PDF to summarize + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + + Returns: + AIMessage: The LLM's summary response + + The function uses a template to generate a summary prompt and tracks both the + prompt and response for monitoring purposes. + """ template = FinancialSummaryPrompts.get_template("monologue_summary_prompt") prompt = template.render(text=pdf_metadata.markdown) @@ -40,7 +60,23 @@ async def monologue_summarize_pdfs( job_manager: JobStatusManager, logger: logging.Logger, ) -> List[PDFMetadata]: - """Summarize all PDFs in the request""" + """ + Summarize multiple PDFs in the request. + + Args: + pdfs (List[PDFMetadata]): List of PDFs to summarize + job_id (str): ID for tracking job progress + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + List[PDFMetadata]: The input PDFs with summaries added + + Uses asyncio.gather to process multiple PDFs concurrently and updates + job status throughout the process. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, f"Summarizing {len(pdfs)} PDFs" ) @@ -65,7 +101,23 @@ async def monologue_generate_raw_outline( job_id: str, job_manager: JobStatusManager, ) -> str: - """Generate initial raw outline from summarized PDFs""" + """ + Generate an initial outline from the summarized PDFs. + + Args: + summarized_pdfs (List[PDFMetadata]): PDFs with their summaries + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + + Returns: + str: Raw outline text generated by the LLM + + Combines PDF summaries and any focus instructions to generate a structured + outline for the monologue. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, "Generating initial outline" ) @@ -105,7 +157,23 @@ async def monologue_generate_monologue( job_id: str, job_manager: JobStatusManager, ) -> str: - """Generate monologue transcript""" + """ + Generate a complete monologue transcript from the outline. + + Args: + raw_outline (str): Generated outline to expand into monologue + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + + Returns: + str: Complete monologue transcript + + Expands the outline into a natural-sounding monologue, incorporating + the speaker's name and any focus areas specified in the request. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, "Creating monologue transcript" ) @@ -144,7 +212,23 @@ async def monologue_create_final_conversation( job_id: str, job_manager: JobStatusManager, ) -> Conversation: - """Convert the monologue into structured Conversation format""" + """ + Convert the monologue into a structured Conversation format. 
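+
+    Example (illustrative sketch; argument names are assumed from the
+    surrounding flow):
+
+        >>> conversation = await monologue_create_final_conversation(
+        ...     monologue, request, llm_manager, tracker, job_id, job_manager
+        ... )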
+ + Args: + monologue (str): Generated monologue transcript + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + + Returns: + Conversation: Structured conversation object + + Formats the monologue into a JSON structure that matches the Conversation + schema, handling proper text escaping and validation. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, "Formatting final conversation" ) @@ -181,6 +265,16 @@ async def monologue_create_final_conversation( def unescape_unicode_string(s: str) -> str: - """Convert escaped Unicode sequences to actual Unicode characters""" - # This handles both raw strings (with extra backslashes) and regular strings + """ + Convert escaped Unicode sequences to actual Unicode characters. + + Args: + s (str): String potentially containing escaped Unicode sequences + + Returns: + str: String with Unicode sequences properly unescaped + + This handles both raw strings (with extra backslashes) and regular strings, + ensuring proper display of special characters in the final output. + """ return s.encode("utf-8").decode("unicode-escape") diff --git a/services/AgentService/monologue_prompts.py b/services/AgentService/monologue_prompts.py index 52b7594..037aab3 100644 --- a/services/AgentService/monologue_prompts.py +++ b/services/AgentService/monologue_prompts.py @@ -1,6 +1,15 @@ +""" +Module containing prompt templates and utilities for generating monologue podcasts. + +This module provides a collection of prompt templates used to guide LLM responses +when generating podcast monologues from PDF documents. It includes templates for +summarization, synthesis, transcript generation, and dialogue formatting. +""" + import jinja2 from typing import Dict +# Template for summarizing individual PDF documents MONOLOGUE_SUMMARY_PROMPT_STR = """ You are a knowledgeable analyst. Please provide a targeted analysis of the following document, focusing on: {{ focus }} @@ -40,6 +49,7 @@ You are presenting to the board of directors. Speak in a way that is engaging and informative, but not too technical and speak in the first person. """ +# Template for synthesizing multiple document summaries into an outline MONOLOGUE_MULTI_DOC_SYNTHESIS_PROMPT_STR = """ Create a structured monologue outline synthesizing the following document summaries. The monologue should be 30-45 seconds long. @@ -80,6 +90,7 @@ Output a structured outline that synthesizes insights across all documents, emphasizing Target Documents while using Context Documents for support.""" +# Template for generating the actual monologue transcript MONOLOGUE_TRANSCRIPT_PROMPT_STR = """ Create a focused update based on this outline and source documents. @@ -129,6 +140,7 @@ Create a concise, engaging monologue that follows the outline while delivering essential financial information.""" +# Template for converting monologue to structured dialogue format MONOLOGUE_DIALOGUE_PROMPT_STR = """You are tasked with converting a financial monologue into a structured JSON format. You have: 1. Speaker information: @@ -160,6 +172,7 @@ Please output the JSON following the provided schema, maintaining all financial details and proper formatting. The output should use proper Unicode characters directly, not escaped sequences. 
Do not output anything besides the JSON.""" +# Dictionary mapping template names to their content PROMPT_TEMPLATES = { "monologue_summary_prompt": MONOLOGUE_SUMMARY_PROMPT_STR, "monologue_multi_doc_synthesis_prompt": MONOLOGUE_MULTI_DOC_SYNTHESIS_PROMPT_STR, @@ -174,13 +187,50 @@ class FinancialSummaryPrompts: + """ + A class providing access to financial summary prompt templates. + + This class serves as an interface to access and render various prompt templates + used in the monologue generation process. Templates are accessed either through + attribute access or the get_template class method. + + Attributes: + None + + Methods: + __getattr__(name: str) -> str: Dynamically retrieves prompt template strings by name + get_template(name: str) -> jinja2.Template: Retrieves compiled Jinja templates by name + """ + def __getattr__(self, name: str) -> str: - """Dynamically handle prompt requests by name""" + """ + Get the Jinja template by name + + Args: + name (str): Name of the prompt template to retrieve + + Returns: + str: The prompt template string + + Raises: + AttributeError: If the requested template name doesn't exist + """ if name in PROMPT_TEMPLATES: return PROMPT_TEMPLATES[name] raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'") @classmethod def get_template(cls, name: str) -> jinja2.Template: - """Get the Jinja template by name""" + """ + Get the compiled Jinja template by name. + + Args: + name (str): Name of the template to retrieve + + Returns: + jinja2.Template: The compiled Jinja template object + + Raises: + KeyError: If the requested template name doesn't exist + """ return TEMPLATES[name] diff --git a/services/AgentService/podcast_flow.py b/services/AgentService/podcast_flow.py index 092509a..00b3c88 100644 --- a/services/AgentService/podcast_flow.py +++ b/services/AgentService/podcast_flow.py @@ -1,3 +1,10 @@ +""" +Podcast flow module for converting PDFs to podcast conversations. + +This module handles the workflow for generating multi-speaker podcast content from PDF documents. +It includes functionality for summarizing PDFs, generating outlines, and creating dialogue segments. +""" + from shared.pdf_types import PDFMetadata from shared.podcast_types import Conversation, PodcastOutline from shared.api_types import JobStatus, TranscriptionRequest @@ -15,7 +22,20 @@ async def podcast_summarize_pdf( pdf_metadata: PDFMetadata, llm_manager: LLMManager, prompt_tracker: PromptTracker ) -> AIMessage: - """Summarize a single PDF document""" + """ + Summarize a single PDF document using the LLM. + + Args: + pdf_metadata (PDFMetadata): The PDF document metadata and content to summarize + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + + Returns: + AIMessage: The LLM's summary response + + The function uses a template to generate a summary prompt and tracks both the + prompt and response for monitoring purposes. + """ template = PodcastPrompts.get_template("podcast_summary_prompt") prompt = template.render(text=pdf_metadata.markdown) @@ -40,7 +60,23 @@ async def podcast_summarize_pdfs( job_manager: JobStatusManager, logger: logging.Logger, ) -> List[PDFMetadata]: - """Summarize all PDFs in the request""" + """ + Summarize all PDFs in parallel and update their metadata with summaries. 
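+
+    Example (illustrative sketch; arguments are assumed from the surrounding
+    flow):
+
+        >>> summarized = await podcast_summarize_pdfs(
+        ...     pdfs, job_id, llm_manager, tracker, job_manager, logger
+        ... )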
+ + Args: + pdfs (List[PDFMetadata]): List of PDFs to summarize + job_id (str): ID for tracking job progress + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + List[PDFMetadata]: The input PDFs with summaries added + + Uses asyncio.gather to process multiple PDFs concurrently and updates + job status throughout the process. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, f"Summarizing {len(pdfs)} PDFs" ) @@ -66,7 +102,24 @@ async def podcast_generate_raw_outline( job_manager: JobStatusManager, logger: logging.Logger, ) -> str: - """Generate initial raw outline from summarized PDFs""" + """ + Generate initial raw outline from summarized PDFs. + + Args: + summarized_pdfs (List[PDFMetadata]): PDFs with their summaries + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + str: Raw outline text generated from the summaries + + Formats document summaries in XML and uses a template to generate + an initial podcast outline structure. + """ # Prepare document summaries in XML format job_manager.update_status( job_id, JobStatus.PROCESSING, "Generating initial outline" @@ -114,7 +167,24 @@ async def podcast_generate_structured_outline( job_manager: JobStatusManager, logger: logging.Logger, ) -> PodcastOutline: - """Convert raw outline to structured format""" + """ + Convert raw outline text to structured PodcastOutline format. + + Args: + raw_outline (str): Raw outline text to structure + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + PodcastOutline: Structured outline following the PodcastOutline schema + + Uses JSON schema validation to ensure the outline follows the required structure + and only references valid PDF filenames. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, @@ -157,7 +227,22 @@ async def podcast_process_segment( llm_manager: LLMManager, prompt_tracker: PromptTracker, ) -> tuple[str, str]: - """Process a single segment""" + """ + Process a single outline segment to generate initial content. + + Args: + segment (Any): Segment from the outline to process + idx (int): Index of the segment + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + + Returns: + tuple[str, str]: Tuple of (segment_id, generated_content) + + Generates initial content for a segment, incorporating referenced PDF content + if available. Uses different templates based on whether references exist. 
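+
+    Example (illustrative sketch; the outline, request, manager, and tracker
+    come from earlier steps in the flow):
+
+        >>> seg_id, text = await podcast_process_segment(
+        ...     outline.segments[0], 0, request, llm_manager, tracker
+        ... )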
+ """ # Get reference content if it exists text_content = [] if segment.references: @@ -215,7 +300,24 @@ async def podcast_process_segments( job_manager: JobStatusManager, logger: logging.Logger, ) -> Dict[str, str]: - """Process each segment in the outline""" + """ + Process all outline segments in parallel to generate initial content. + + Args: + outline (PodcastOutline): Structured outline to process + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + Dict[str, str]: Dictionary mapping segment IDs to their generated content + + Creates tasks for processing each segment and executes them in parallel using + asyncio.gather. + """ # Create tasks for processing each segment segment_tasks: List[Coroutine] = [] for idx, segment in enumerate(outline.segments): @@ -249,7 +351,23 @@ async def podcast_generate_dialogue_segment( llm_manager: LLMManager, prompt_tracker: PromptTracker, ) -> Dict[str, str]: - """Generate dialogue for a single segment""" + """ + Generate dialogue for a single segment. + + Args: + segment (Any): Segment from the outline + idx (int): Index of the segment + segment_text (str): Generated content for the segment + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + + Returns: + Dict[str, str]: Dictionary containing section name and generated dialogue + + Formats segment topics and uses a template to convert content into a dialogue + format between two speakers. + """ # Format topics for prompt topics_text = "\n".join( [ @@ -297,7 +415,24 @@ async def podcast_generate_dialogue( job_manager: JobStatusManager, logger: logging.Logger, ) -> List[Dict[str, str]]: - """Generate dialogue for each segment""" + """ + Generate dialogue for all segments in parallel. + + Args: + segments (Dict[str, str]): Dictionary of segment IDs and their content + outline (PodcastOutline): Structured outline + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + List[Dict[str, str]]: List of dictionaries containing section names and dialogues + + Creates tasks for generating dialogue for each segment and executes them in parallel. + """ job_manager.update_status(job_id, JobStatus.PROCESSING, "Generating dialogue") # Create tasks for generating dialogue for each segment @@ -346,7 +481,23 @@ async def podcast_combine_dialogues( job_manager: JobStatusManager, logger: logging.Logger, ) -> str: - """Iteratively combine dialogue segments into a cohesive conversation""" + """ + Iteratively combine dialogue segments into a cohesive conversation. 
+ + Args: + segment_dialogues (List[Dict[str, str]]): List of segment dialogues + outline (PodcastOutline): Structured outline + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + str: Combined dialogue text + + Iteratively combines dialogue segments, ensuring smooth transitions between sections. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, "Combining dialogue segments" ) @@ -405,7 +556,24 @@ async def podcast_create_final_conversation( job_manager: JobStatusManager, logger: logging.Logger, ) -> Conversation: - """Convert the dialogue into structured Conversation format""" + """ + Convert the dialogue into structured Conversation format. + + Args: + dialogue (str): Combined dialogue text + request (TranscriptionRequest): Original transcription request + llm_manager (LLMManager): Manager for LLM interactions + prompt_tracker (PromptTracker): Tracks prompts and responses + job_id (str): ID for tracking job progress + job_manager (JobStatusManager): Manages job status updates + logger (logging.Logger): Logger for tracking progress + + Returns: + Conversation: Structured conversation following the Conversation schema + + Formats the dialogue into a structured conversation format with proper speaker + attribution and timing information. + """ job_manager.update_status( job_id, JobStatus.PROCESSING, "Formatting final conversation" ) @@ -444,6 +612,18 @@ async def podcast_create_final_conversation( def unescape_unicode_string(s: str) -> str: - """Convert escaped Unicode sequences to actual Unicode characters""" + """ + Convert escaped Unicode sequences to actual Unicode characters. + + Args: + s (str): String potentially containing escaped Unicode sequences + + Returns: + str: String with Unicode sequences converted to actual characters + + Example: + >>> unescape_unicode_string("Hello\\u2019s World") + "Hello's World" + """ # This handles both raw strings (with extra backslashes) and regular strings return s.encode("utf-8").decode("unicode-escape") diff --git a/services/AgentService/podcast_prompts.py b/services/AgentService/podcast_prompts.py index 5959df8..ec3b837 100644 --- a/services/AgentService/podcast_prompts.py +++ b/services/AgentService/podcast_prompts.py @@ -1,6 +1,15 @@ +""" +Module containing prompt templates and utilities for generating podcast dialogues. + +This module provides a collection of prompt templates used to guide LLM responses +when generating podcast dialogues from PDF documents. It includes templates for +summarization, outline generation, transcript creation, and dialogue formatting. +""" + import jinja2 from typing import Dict +# Template for summarizing individual PDF documents PODCAST_SUMMARY_PROMPT_STR = """ Please provide a comprehensive summary of the following document. Note that this document may contain OCR/PDF conversion artifacts, so please interpret the content, especially numerical data and tables, with appropriate context. @@ -35,6 +44,7 @@ Note: Focus on extracting and organizing the most essential information while ensuring no critical details are omitted. Maintain the original document's tone and context in your summary. 
""" +# Template for synthesizing multiple document summaries into an outline PODCAST_MULTI_PDF_OUTLINE_PROMPT_STR = """ Create a structured podcast outline synthesizing the following document summaries. The podcast should be {{total_duration}} minutes long. @@ -70,6 +80,7 @@ Ensure the outline creates a cohesive narrative that emphasizes the Target Documents while using Context Documents to provide additional depth and background information. """ +# Template for converting outline into structured JSON format PODCAST_MULTI_PDF_STRUCUTRED_OUTLINE_PROMPT_STR = """ Convert the following outline into a structured JSON format. The final section should be marked as the conclusion segment. @@ -100,6 +111,7 @@ {{ schema }} """ +# Template for generating transcript with source references PODCAST_PROMPT_WITH_REFERENCES_STR = """ Create a transcript incorporating details from the provided source material: @@ -129,6 +141,7 @@ Ensure thorough coverage of each topic while preserving the accuracy and nuance of the source material. """ +# Template for generating transcript without source references PODCAST_PROMPT_NO_REFERENCES_STR = """ Create a knowledge-based transcript following this outline: @@ -162,6 +175,7 @@ Develop a thorough exploration of each topic using available knowledge. Begin with careful brainstorming to map connections between ideas, then build a clear narrative that makes complex concepts accessible while maintaining accuracy and completeness. """ +# Template for converting transcript to dialogue format PODCAST_TRANSCRIPT_TO_DIALOGUE_PROMPT_STR = """ Your task is to transform the provided input transcript into an engaging and informative podcast dialogue. @@ -221,6 +235,7 @@ *Only return the full dialogue transcript; do not include any other information like time budget or segment names.* """ +# Template for combining multiple dialogue sections PODCAST_COMBINE_DIALOGUES_PROMPT_STR = """You are revising a podcast transcript to make it more engaging while preserving its content and structure. You have access to three key elements: 1. The podcast outline @@ -257,6 +272,7 @@ Please output the complete revised dialogue transcript from the beginning, with the next section integrated seamlessly.""" +# Template for converting dialogue to JSON format PODCAST_DIALOGUE_PROMPT_STR = """You are tasked with converting a podcast transcript into a structured JSON format. You have: 1. Two speakers: @@ -289,6 +305,7 @@ Please output the JSON following the provided schema, maintaining all conversational details and speaker attributions. The output should use proper Unicode characters directly, not escaped sequences. Do not output anything besides the JSON.""" +# Dictionary mapping prompt names to their template strings PROMPT_TEMPLATES = { "podcast_summary_prompt": PODCAST_SUMMARY_PROMPT_STR, "podcast_multi_pdf_outline_prompt": PODCAST_MULTI_PDF_OUTLINE_PROMPT_STR, @@ -307,13 +324,55 @@ class PodcastPrompts: + """ + A class providing access to podcast-related prompt templates. + + This class manages a collection of Jinja2 templates used for generating + various prompts in the podcast creation process, from PDF summarization + to dialogue generation. + + The templates are pre-compiled for efficiency and can be accessed either + through attribute access or the get_template class method. 
+ + Attributes: + None - Templates are stored in module-level constants + + Methods: + __getattr__(name: str) -> str: + Dynamically retrieves prompt template strings by name + get_template(name: str) -> jinja2.Template: + Retrieves pre-compiled Jinja2 templates by name + """ + def __getattr__(self, name: str) -> str: - """Dynamically handle prompt requests by name""" + """ + Dynamically retrieve prompt templates by name. + + Args: + name (str): Name of the prompt template to retrieve + + Returns: + str: The prompt template string + + Raises: + AttributeError: If the requested template name doesn't exist + """ if name in PROMPT_TEMPLATES: return PROMPT_TEMPLATES[name] raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'") @classmethod def get_template(cls, name: str) -> jinja2.Template: - """Get the Jinja template by name""" + """ + Get a pre-compiled Jinja2 template by name. + + Args: + name (str): Name of the template to retrieve + + Returns: + jinja2.Template: The pre-compiled Jinja2 template object + + Raises: + KeyError: If the requested template name doesn't exist + """ return TEMPLATES[name] diff --git a/services/AgentService/test_api.py b/services/AgentService/test_api.py index 705028a..fc78286 100644 --- a/services/AgentService/test_api.py +++ b/services/AgentService/test_api.py @@ -1,3 +1,10 @@ +""" +Test module for the Agent Service API endpoints. + +This module contains integration tests for the transcription API endpoints, +verifying the full workflow from request submission to job completion. +""" + import requests import ujson as json import os @@ -7,6 +14,18 @@ def test_transcribe_api(): + """ + Test the transcription API workflow. + + This test function: + 1. Creates a TranscriptionRequest with sample PDF metadata + 2. Submits the request to the transcribe endpoint + 3. Polls the job status until completion + 4. Verifies the job completes successfully + + Raises: + AssertionError: If any step of the workflow fails or times out + """ # API endpoints BASE_URL = os.getenv("AGENT_SERVICE_URL", "http://localhost:8964") TRANSCRIBE_URL = f"{BASE_URL}/transcribe" diff --git a/services/AgentService/test_llmmanager.py b/services/AgentService/test_llmmanager.py index c5097e0..f053037 100644 --- a/services/AgentService/test_llmmanager.py +++ b/services/AgentService/test_llmmanager.py @@ -1,3 +1,12 @@ +""" +Test module for the LLMManager class. + +This module contains integration tests for the LLMManager class, testing various +capabilities like basic queries, parallel processing, JSON schema validation, +and streaming responses. It uses a mock FastAPI application and OpenTelemetry +instrumentation for testing purposes. +""" + import asyncio import os from shared.otel import OpenTelemetryInstrumentation, OpenTelemetryConfig @@ -8,6 +17,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# Set up mock FastAPI app and telemetry for testing mock_app = FastAPI() mock_telemetry = OpenTelemetryInstrumentation() mock_config = OpenTelemetryConfig( @@ -20,7 +30,25 @@ async def test_basic_queries(): - """Test both sync and async basic queries""" + """ + Test both synchronous and asynchronous basic queries. + + Tests the basic query functionality of LLMManager by making both sync + and async requests with simple prompts. Verifies that both methods + return expected responses. + + The test: + 1. Creates an LLMManager instance + 2. Tests synchronous query with robotics laws prompt + 3. Tests asynchronous query with machine learning prompt + 4. 
Prints responses for manual verification + + Returns: + None + + Raises: + Exception: If either query fails or returns unexpected response + """ print("\n=== Testing Basic Queries ===") manager = LLMManager(api_key=os.getenv("NVIDIA_API_KEY"), telemetry=mock_telemetry) @@ -55,7 +83,25 @@ async def test_basic_queries(): async def test_parallel_processing(): - """Test processing multiple queries in parallel""" + """ + Test processing multiple queries in parallel. + + Demonstrates the ability to process multiple queries concurrently using + asyncio.gather(). Sends three different programming language queries + simultaneously and collects their responses. + + The test: + 1. Creates an LLMManager instance + 2. Defines three programming language questions + 3. Processes queries in parallel using asyncio.gather() + 4. Prints responses in order with corresponding questions + + Returns: + None + + Raises: + Exception: If parallel processing fails or returns unexpected responses + """ print("\n=== Testing Parallel Processing ===") manager = LLMManager(api_key=os.getenv("NVIDIA_API_KEY"), telemetry=mock_telemetry) @@ -63,6 +109,16 @@ async def test_parallel_processing(): questions = ["What is Python?", "What is JavaScript?", "What is Rust?"] async def process_query(question: str, idx: int): + """ + Helper function to process individual queries. + + Args: + question (str): The question to ask the LLM + idx (int): Index for tracking parallel queries + + Returns: + AIMessage: The LLM's response + """ return await manager.query_async( model_key="reasoning", messages=[ @@ -81,7 +137,25 @@ async def process_query(question: str, idx: int): async def test_json_schema(): - """Test JSON schema structured output""" + """ + Test JSON schema structured output. + + Verifies that the LLMManager can generate responses conforming to a + specified JSON schema. Uses a sample schema for person details including + name, age, occupation, and hobbies. + + The test: + 1. Creates an LLMManager instance + 2. Defines a JSON schema for person details + 3. Requests a character generation conforming to schema + 4. Verifies response matches schema structure + + Returns: + None + + Raises: + Exception: If response doesn't conform to schema or query fails + """ print("\n=== Testing JSON Schema ===") manager = LLMManager(api_key=os.getenv("NVIDIA_API_KEY"), telemetry=mock_telemetry) @@ -111,7 +185,25 @@ async def test_json_schema(): async def test_streaming(): - """Test both sync and async streaming""" + """ + Test both synchronous and asynchronous streaming. + + Tests the streaming capabilities of LLMManager using both sync and async + methods. Verifies that streaming responses are received correctly for + simple counting and listing tasks. + + The test: + 1. Creates an LLMManager instance + 2. Tests sync streaming with counting prompt + 3. Tests async streaming with days of week prompt + 4. Verifies streaming responses are complete and coherent + + Returns: + None + + Raises: + Exception: If streaming fails or returns incomplete responses + """ print("\n=== Testing Streaming ===") manager = LLMManager(api_key=os.getenv("NVIDIA_API_KEY"), telemetry=mock_telemetry) @@ -146,7 +238,26 @@ async def test_streaming(): async def test_json_streaming(): - """Test JSON schema structured output with streaming""" + """ + Test JSON schema structured output with streaming. + + Tests the combination of JSON schema validation and streaming responses. 
+ Uses a simple story summary schema to verify that streamed responses + conform to the specified structure. + + The test: + 1. Creates an LLMManager instance + 2. Defines a story summary JSON schema + 3. Tests sync JSON streaming + 4. Tests async JSON streaming + 5. Verifies both responses conform to schema + + Returns: + None + + Raises: + Exception: If streaming fails or responses don't match schema + """ print("\n=== Testing JSON Streaming ===") manager = LLMManager(api_key=os.getenv("NVIDIA_API_KEY"), telemetry=mock_telemetry) @@ -194,7 +305,25 @@ async def test_json_streaming(): async def main_test(): - """Run all tests""" + """ + Run all tests sequentially. + + Main test runner that executes all test functions in sequence. + Currently configured to run only streaming tests, with other tests + commented out for focused testing. + + The function: + 1. Attempts to run each test in sequence + 2. Catches and reports any exceptions + 3. Currently focuses on streaming tests + 4. Other tests are commented out for selective testing + + Returns: + None + + Raises: + Exception: Prints error message if any test fails + """ try: # # Test basic queries # await test_basic_queries() diff --git a/shared/setup.py b/shared/setup.py index d781198..0ab2389 100644 --- a/shared/setup.py +++ b/shared/setup.py @@ -1,3 +1,12 @@ +"""Setup script for the shared package. + +This package contains shared utilities and functionality used across the pdf-to-podcast +application, including storage management, telemetry, and type definitions. + +The package requires several external dependencies for Redis caching, HTTP requests, +data validation, and AI model integration. +""" + from setuptools import setup, find_packages setup( @@ -5,10 +14,10 @@ version="0.1", packages=find_packages(), install_requires=[ - "redis", - "pydantic", - "httpx", - "requests", - "langchain-nvidia-ai-endpoints", + "redis", # For caching and message queuing + "pydantic", # For data validation and serialization + "httpx", # For async HTTP requests + "requests", # For sync HTTP requests + "langchain-nvidia-ai-endpoints", # For AI model integration ], ) diff --git a/shared/shared/api_types.py b/shared/shared/api_types.py index 1dcac0e..77df483 100644 --- a/shared/shared/api_types.py +++ b/shared/shared/api_types.py @@ -5,35 +5,40 @@ class JobStatus(str, Enum): - PENDING = "pending" - PROCESSING = "processing" - COMPLETED = "completed" - FAILED = "failed" + """Enum representing the possible states of a job.""" + PENDING = "pending" # Job has been created but not started + PROCESSING = "processing" # Job is currently being processed + COMPLETED = "completed" # Job has finished successfully + FAILED = "failed" # Job encountered an error and failed class ServiceType(str, Enum): - PDF = "pdf" - AGENT = "agent" - TTS = "tts" + """Enum representing the different service types in the system.""" + PDF = "pdf" # PDF processing service + AGENT = "agent" # Agent/LLM service + TTS = "tts" # Text-to-speech service class StatusUpdate(BaseModel): + """Model for job status updates sent between services.""" job_id: str status: JobStatus - message: Optional[str] = None - service: Optional[ServiceType] = None - timestamp: Optional[float] = None - data: Optional[dict] = None + message: Optional[str] = None # Optional status message + service: Optional[ServiceType] = None # Service sending the update + timestamp: Optional[float] = None # Unix timestamp of update + data: Optional[dict] = None # Additional status data class StatusResponse(BaseModel): - status: str - 
result: Optional[str] = None - error: Optional[str] = None - message: Optional[str] = None + """Model for API status responses.""" + status: str # Overall status of the operation + result: Optional[str] = None # Optional success result + error: Optional[str] = None # Optional error message + message: Optional[str] = None # Optional status message class TranscriptionParams(BaseModel): + """Base parameters for podcast transcription requests.""" userId: str = Field(..., description="KAS User ID") name: str = Field(..., description="Name of the podcast") duration: int = Field(..., description="Duration in minutes") @@ -64,6 +69,23 @@ class TranscriptionParams(BaseModel): @model_validator(mode="after") def validate_monologue_settings(self) -> "TranscriptionParams": + """ + Validates the configuration based on monologue/dialogue mode. + + For monologue mode: + - No second speaker name should be provided + - Voice mapping should only contain speaker-1 + + For dialogue mode: + - Second speaker name is required + - Voice mapping must contain both speakers + + Returns: + TranscriptionParams: The validated model instance + + Raises: + ValueError: If validation fails + """ if self.monologue: # Check speaker_2_name is not provided if self.speaker_2_name is not None: @@ -95,11 +117,16 @@ def validate_monologue_settings(self) -> "TranscriptionParams": class TranscriptionRequest(TranscriptionParams): - pdf_metadata: List[PDFMetadata] - job_id: str + """ + Complete transcription request model extending TranscriptionParams. + Includes PDF metadata and job tracking information. + """ + pdf_metadata: List[PDFMetadata] # List of PDFs to process + job_id: str # Unique identifier for the transcription job class RAGRequest(BaseModel): + """Model for Retrieval-Augmented Generation (RAG) requests.""" query: str = Field(..., description="The search query to process") k: int = Field(..., description="Number of results to retrieve", ge=1) job_id: str = Field(..., description="The unique job identifier") diff --git a/shared/shared/connection.py b/shared/shared/connection.py index 189d0b1..1afb904 100644 --- a/shared/shared/connection.py +++ b/shared/shared/connection.py @@ -14,7 +14,31 @@ class ConnectionManager: + """ + Manages WebSocket connections and Redis pub/sub for real-time status updates. + + This class handles: + - WebSocket connections for each job ID + - Redis pub/sub subscription for status updates + - Broadcasting messages to connected clients + - Connection cleanup and resource management + + Attributes: + active_connections (Dict[str, Set[WebSocket]]): Maps job IDs to sets of active WebSocket connections + pubsub: Redis pub/sub connection + message_queue (Queue): Thread-safe queue for message processing + redis_thread (Thread): Background thread for Redis subscription + should_stop (bool): Flag to control background thread termination + redis_client (redis.Redis): Redis client instance + """ + def __init__(self, redis_client: redis.Redis): + """ + Initialize the connection manager. + + Args: + redis_client (redis.Redis): Redis client for pub/sub functionality + """ self.active_connections: Dict[str, Set[WebSocket]] = defaultdict(set) self.pubsub = None self.message_queue = queue.Queue() @@ -23,6 +47,13 @@ def __init__(self, redis_client: redis.Redis): self.redis_client = redis_client async def connect(self, websocket: WebSocket, job_id: str): + """ + Accept a new WebSocket connection for a job. 
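+
+        Example (illustrative; mirrors the websocket endpoint in the API
+        service):
+
+            >>> @app.websocket("/ws/status/{job_id}")
+            ... async def websocket_endpoint(websocket: WebSocket, job_id: str):
+            ...     await manager.connect(websocket, job_id)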
+ + Args: + websocket (WebSocket): The WebSocket connection to accept + job_id (str): ID of the job this connection is monitoring + """ await websocket.accept() self.active_connections[job_id].add(websocket) logger.info( @@ -38,6 +69,13 @@ async def connect(self, websocket: WebSocket, job_id: str): asyncio.create_task(self._process_messages()) def disconnect(self, websocket: WebSocket, job_id: str): + """ + Remove a WebSocket connection for a job. + + Args: + websocket (WebSocket): The WebSocket connection to remove + job_id (str): ID of the job the connection was monitoring + """ if job_id in self.active_connections: self.active_connections[job_id].remove(websocket) if not self.active_connections[job_id]: @@ -47,7 +85,12 @@ def disconnect(self, websocket: WebSocket, job_id: str): ) def _redis_listener(self): - """Redis subscription running in a separate thread""" + """ + Background thread that listens for Redis pub/sub messages. + + Subscribes to the status_updates:all channel and queues received messages + for processing by the async message processor. + """ try: self.pubsub = self.redis_client.pubsub(ignore_subscribe_messages=True) self.pubsub.subscribe("status_updates:all") @@ -77,7 +120,12 @@ def _redis_listener(self): self.pubsub.close() async def _process_messages(self): - """Async task to process messages from the queue and broadcast them""" + """ + Async task that processes queued messages and broadcasts them to clients. + + Continuously checks the message queue and broadcasts valid messages + to all connected WebSocket clients for the relevant job ID. + """ while True: try: # Check queue in a non-blocking way @@ -116,7 +164,13 @@ async def _process_messages(self): await asyncio.sleep(1) async def broadcast_to_job(self, job_id: str, message: dict): - """Send message to all WebSocket connections for a job""" + """ + Send a message to all WebSocket connections for a specific job. + + Args: + job_id (str): ID of the job to broadcast to + message (dict): Message to broadcast to all connections + """ if job_id in self.active_connections: disconnected = set() for connection in self.active_connections[job_id]: @@ -133,7 +187,11 @@ async def broadcast_to_job(self, job_id: str, message: dict): self.disconnect(connection, job_id) def cleanup(self): - """Cleanup resources""" + """ + Clean up resources used by the connection manager. + + Stops the Redis listener thread and closes the pub/sub connection. + """ self.should_stop = True if self.redis_thread: self.redis_thread.join(timeout=1.0) diff --git a/shared/shared/job.py b/shared/shared/job.py index afc5f5b..597b892 100644 --- a/shared/shared/job.py +++ b/shared/shared/job.py @@ -7,18 +7,46 @@ class JobStatusManager: + """ + Manages job status and results using Redis as a backend store. + + This class provides methods to track job status, store results, and manage cleanup + of old jobs. It uses Redis hash sets for status storage and Redis pub/sub for + real-time status updates. + + Attributes: + telemetry (OpenTelemetryInstrumentation): Telemetry instrumentation instance + redis (redis.Redis): Redis client instance + service_type (ServiceType): Type of service using this manager + _lock (threading.Lock): Thread lock for synchronization + """ + def __init__( self, service_type: ServiceType, telemetry: OpenTelemetryInstrumentation, redis_url="redis://redis:6379", ): + """ + Initialize the JobStatusManager. 
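+
+        Example (illustrative; mirrors the Agent Service setup in this patch):
+
+            >>> job_manager = JobStatusManager(ServiceType.AGENT, telemetry=telemetry)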
+ + Args: + service_type (ServiceType): Type of service using this manager + telemetry (OpenTelemetryInstrumentation): Telemetry instrumentation instance + redis_url (str, optional): Redis connection URL. Defaults to "redis://redis:6379" + """ self.telemetry = telemetry self.redis = redis.Redis.from_url(redis_url, decode_responses=False) self.service_type = service_type self._lock = threading.Lock() def create_job(self, job_id: str): + """ + Create a new job with pending status. + + Args: + job_id (str): Unique identifier for the job + """ with self.telemetry.tracer.start_as_current_span("job.create_job") as span: span.set_attribute("job_id", job_id) update = { @@ -38,6 +66,14 @@ def create_job(self, job_id: str): self.redis.publish("status_updates:all", json.dumps(update).encode()) def update_status(self, job_id: str, status: str, message: str): + """ + Update the status of an existing job. + + Args: + job_id (str): Job identifier + status (str): New status value + message (str): Status update message + """ with self.telemetry.tracer.start_as_current_span("job.update_status") as span: span.set_attribute("job_id", job_id) update = { @@ -57,6 +93,13 @@ def update_status(self, job_id: str, status: str, message: str): self.redis.publish("status_updates:all", json.dumps(update).encode()) def set_result(self, job_id: str, result: bytes): + """ + Store the result data for a job. + + Args: + job_id (str): Job identifier + result (bytes): Result data to store + """ with self.telemetry.tracer.start_as_current_span("job.set_result") as span: span.set_attribute("job_id", job_id) set_key = f"result:{job_id}:{str(self.service_type)}" @@ -64,6 +107,14 @@ def set_result(self, job_id: str, result: bytes): self.redis.set(set_key, result) def set_result_with_expiration(self, job_id: str, result: bytes, ex: int): + """ + Store the result data with an expiration time. + + Args: + job_id (str): Job identifier + result (bytes): Result data to store + ex (int): Expiration time in seconds + """ with self.telemetry.tracer.start_as_current_span( "job.set_result_with_expiration" ) as span: @@ -73,6 +124,15 @@ def set_result_with_expiration(self, job_id: str, result: bytes, ex: int): self.redis.set(set_key, result, ex=ex) def get_result(self, job_id: str): + """ + Retrieve the result data for a job. + + Args: + job_id (str): Job identifier + + Returns: + bytes: Result data if found, None otherwise + """ with self.telemetry.tracer.start_as_current_span("job.get_result") as span: span.set_attribute("job_id", job_id) get_key = f"result:{job_id}:{str(self.service_type)}" @@ -81,6 +141,18 @@ def get_result(self, job_id: str): return result if result else None def get_status(self, job_id: str): + """ + Get the current status of a job. + + Args: + job_id (str): Job identifier + + Returns: + dict: Job status information + + Raises: + ValueError: If job not found + """ with self.telemetry.tracer.start_as_current_span("job.get_status") as span: span.set_attribute("job_id", job_id) # Get raw bytes and decode manually @@ -93,6 +165,15 @@ def get_status(self, job_id: str): return {k.decode(): v.decode() for k, v in status.items()} def cleanup_old_jobs(self, max_age=3600): + """ + Remove jobs older than the specified age. + + Args: + max_age (int, optional): Maximum age in seconds. Defaults to 3600. 
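+
+        Example (illustrative; the scan covers Redis keys matching
+        "status:*:<service_type>"):
+
+            removed = job_manager.cleanup_old_jobs(max_age=7200)  # two hours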
+ + Returns: + int: Number of jobs removed + """ current_time = time.time() removed = 0 pattern = f"status:*:{str(self.service_type)}" diff --git a/shared/shared/llmmanager.py b/shared/shared/llmmanager.py index 11bd036..01747df 100644 --- a/shared/shared/llmmanager.py +++ b/shared/shared/llmmanager.py @@ -14,11 +14,25 @@ @dataclass class ModelConfig: + """Configuration for a specific LLM model. + + Attributes: + name (str): Name/identifier of the model + api_base (str): Base URL for the model's API endpoint + """ name: str api_base: str @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig": + """Create a ModelConfig instance from a dictionary. + + Args: + data (Dict[str, Any]): Dictionary containing model configuration + + Returns: + ModelConfig: New ModelConfig instance + """ return cls( name=data["name"], api_base=data["api_base"], @@ -35,6 +49,12 @@ class LLMManager: Configs can be overridden by providing a custom config file. Currently the defaults are hardcoded to build.nvidia.com endpoints. + Attributes: + api_key (str): API key for NVIDIA endpoints + telemetry (OpenTelemetryInstrumentation): Telemetry instrumentation instance + _llm_cache (Dict[str, ChatNVIDIA]): Cache of initialized LLM models + model_configs (Dict[str, ModelConfig]): Model configurations + Usage: >>> llm_manager = LLMManager(api_key, telemetry) >>> llm_manager.query_sync("reasoning", [{"role": "user", "content": "Hello, world!"}], "test") @@ -62,8 +82,15 @@ def __init__( config_path: Optional[str] = None, ): """ - Initialize LLMManager with telemetry - requires: OpenTelemetryInstrumentation instance for tracing + Initialize LLMManager with telemetry. + + Args: + api_key (str): API key for NVIDIA endpoints + telemetry (OpenTelemetryInstrumentation): Telemetry instrumentation instance + config_path (Optional[str]): Path to custom model configurations file + + Raises: + Exception: If initialization fails """ try: self.api_key = api_key @@ -78,7 +105,14 @@ def __init__( def _load_configurations( self, config_path: Optional[str] ) -> Dict[str, ModelConfig]: - """Load model configurations from JSON file if provided, otherwise use defaults""" + """Load model configurations from JSON file if provided, otherwise use defaults. + + Args: + config_path (Optional[str]): Path to configuration JSON file + + Returns: + Dict[str, ModelConfig]: Dictionary mapping model keys to configurations + """ configs = self.DEFAULT_CONFIGS.copy() if config_path: try: @@ -97,7 +131,17 @@ def _load_configurations( return {key: ModelConfig.from_dict(config) for key, config in configs.items()} def get_llm(self, model_key: str) -> ChatNVIDIA: - """Get or create a ChatNVIDIA model for the specified model key""" + """Get or create a ChatNVIDIA model for the specified model key. + + Args: + model_key (str): Key identifying which model configuration to use + + Returns: + ChatNVIDIA: Initialized ChatNVIDIA instance + + Raises: + ValueError: If model_key is not found in configurations + """ if model_key not in self.model_configs: raise ValueError(f"Unknown model key: {model_key}") if model_key not in self._llm_cache: @@ -118,7 +162,21 @@ def query_sync( json_schema: Optional[Dict] = None, retries: int = 5, ) -> Union[AIMessage, Dict[str, Any]]: - """Send a synchronous query to the specified model""" + """Send a synchronous query to the specified model. 
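+
+        Illustrative call (mirrors the class-level usage example):
+
+            response = llm_manager.query_sync(
+                "reasoning",
+                [{"role": "user", "content": "Hello, world!"}],
+                "test",
+            )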
+ + Args: + model_key (str): Key identifying which model to use + messages (List[Dict[str, str]]): List of message dictionaries + query_name (str): Name of query for telemetry + json_schema (Optional[Dict]): Schema for structured output + retries (int): Number of retry attempts + + Returns: + Union[AIMessage, Dict[str, Any]]: Model response + + Raises: + Exception: If query fails after retries + """ with self.telemetry.tracer.start_as_current_span( f"agent.query.{query_name}" ) as span: @@ -151,7 +209,21 @@ async def query_async( json_schema: Optional[Dict] = None, retries: int = 5, ) -> Union[AIMessage, Dict[str, Any]]: - """Send an asynchronous query to the specified model""" + """Send an asynchronous query to the specified model. + + Args: + model_key (str): Key identifying which model to use + messages (List[Dict[str, str]]): List of message dictionaries + query_name (str): Name of query for telemetry + json_schema (Optional[Dict]): Schema for structured output + retries (int): Number of retry attempts + + Returns: + Union[AIMessage, Dict[str, Any]]: Model response + + Raises: + Exception: If query fails after retries + """ with self.telemetry.tracer.start_as_current_span( f"agent.query.{query_name}" ) as span: @@ -184,7 +256,21 @@ def stream_sync( json_schema: Optional[Dict] = None, retries: int = 5, ) -> Union[str, Dict[str, Any]]: - """Send a synchronous streaming query to the specified model""" + """Send a synchronous streaming query to the specified model. + + Args: + model_key (str): Key identifying which model to use + messages (List[Dict[str, str]]): List of message dictionaries + query_name (str): Name of query for telemetry + json_schema (Optional[Dict]): Schema for structured output + retries (int): Number of retry attempts + + Returns: + Union[str, Dict[str, Any]]: Final chunk from model stream + + Raises: + Exception: If streaming query fails after retries + """ with self.telemetry.tracer.start_as_current_span( f"agent.stream.{query_name}" ) as span: @@ -226,7 +312,21 @@ async def stream_async( json_schema: Optional[Dict] = None, retries: int = 5, ) -> Union[str, Dict[str, Any]]: - """Send an asynchronous streaming query to the specified model""" + """Send an asynchronous streaming query to the specified model. + + Args: + model_key (str): Key identifying which model to use + messages (List[Dict[str, str]]): List of message dictionaries + query_name (str): Name of query for telemetry + json_schema (Optional[Dict]): Schema for structured output + retries (int): Number of retry attempts + + Returns: + Union[str, Dict[str, Any]]: Final chunk from model stream + + Raises: + Exception: If streaming query fails after retries + """ with self.telemetry.tracer.start_as_current_span( f"agent.stream.{query_name}" ) as span: diff --git a/shared/shared/otel.py b/shared/shared/otel.py index 2e7200c..6123068 100644 --- a/shared/shared/otel.py +++ b/shared/shared/otel.py @@ -20,7 +20,16 @@ @dataclass class OpenTelemetryConfig: - """Configuration for OpenTelemetry setup.""" + """Configuration for OpenTelemetry setup. + + Attributes: + service_name (str): Name of the service to be used in traces + otlp_endpoint (str): OTLP endpoint URL for sending traces. Defaults to "http://jaeger:4317" + enable_redis (bool): Whether to enable Redis instrumentation. Defaults to True + enable_requests (bool): Whether to enable requests library instrumentation. Defaults to True + enable_httpx (bool): Whether to enable HTTPX client instrumentation. 
+            Defaults to True
+        enable_urllib3 (bool): Whether to enable urllib3 instrumentation. Defaults to True
+    """
 
     service_name: str
     otlp_endpoint: str = "http://jaeger:4317"
@@ -32,7 +41,11 @@ class OpenTelemetryConfig:
 
 class OpenTelemetryInstrumentation:
     """
-    Lightweight OTEL wrapper
+    Lightweight OpenTelemetry wrapper for easy instrumentation of FastAPI applications.
+
+    This class provides a simple interface to set up OpenTelemetry tracing with common
+    instrumentations like Redis, requests, HTTPX, and urllib3. It handles the configuration
+    of trace providers, processors, and exporters.
 
     Example usage:
         telemetry = OpenTelemetryInstrumentation()
@@ -45,12 +58,20 @@ class OpenTelemetryInstrumentation:
     """
 
     def __init__(self):
+        """Initialize the OpenTelemetryInstrumentation instance."""
        self._tracer: Optional[trace.Tracer] = None
        self._config: Optional[OpenTelemetryConfig] = None
 
    @property
    def tracer(self) -> trace.Tracer:
-        """Get the configured tracer instance."""
+        """Get the configured tracer instance.
+
+        Returns:
+            trace.Tracer: The configured OpenTelemetry tracer
+
+        Raises:
+            RuntimeError: If initialize() hasn't been called yet
+        """
         if not self._tracer:
             raise RuntimeError(
                 "OpenTelemetry has not been initialized. Call initialize() first."
@@ -68,7 +89,7 @@ def initialize(
             config: OpenTelemetryConfig instance containing configuration options
 
         Returns:
-            self for method chaining
+            OpenTelemetryInstrumentation: self for method chaining
         """
         self._config = config
         logger.info(f"Setting up tracing for service: {self._config.service_name}")
@@ -78,7 +99,11 @@ def initialize(
         return self
 
     def _setup_tracing(self) -> None:
-        """Set up the OpenTelemetry tracer provider and processors."""
+        """Set up the OpenTelemetry tracer provider and processors.
+
+        Configures the trace provider with the service name resource and sets up
+        batch processing of spans to the configured OTLP endpoint.
+        """
         resource = Resource.create({"service.name": self._config.service_name})
         provider = TracerProvider(resource=resource)
 
@@ -92,7 +117,14 @@ def _setup_tracing(self) -> None:
         self._tracer = trace.get_tracer(self._config.service_name)
 
     def _instrument_app(self, app=None) -> None:
-        """Instrument the FastAPI application and optional components."""
+        """Instrument the FastAPI application and optional components.
+
+        Enables instrumentation for FastAPI (if app provided) and optionally for
+        Redis, requests, HTTPX, and urllib3 based on configuration.
+
+        Args:
+            app: Optional FastAPI application instance to instrument
+        """
         # Instrument FastAPI
         if app:
             FastAPIInstrumentor.instrument_app(app)
diff --git a/shared/shared/pdf_types.py b/shared/shared/pdf_types.py
index ce56560..49d424b 100644
--- a/shared/shared/pdf_types.py
+++ b/shared/shared/pdf_types.py
@@ -5,18 +5,53 @@
 
 
 class ConversionStatus(str, Enum):
+    """Enum representing the status of a PDF conversion.
+
+    This enum is used to track the success or failure state of PDF conversion operations.
+
+    Attributes:
+        SUCCESS: Indicates successful conversion
+        FAILED: Indicates failed conversion
+    """
     SUCCESS = "success"
     FAILED = "failed"
 
 
 class PDFConversionResult(BaseModel):
+    """Model representing the result of a PDF conversion operation.
+
+    This model captures the output and status of converting a PDF file to text.
+    It includes both successful conversions with extracted content and failed
+    conversions with error details.
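+
+    Illustrative instances (field values are examples only):
+
+        PDFConversionResult(filename="doc.pdf", content="...", status=ConversionStatus.SUCCESS)
+        PDFConversionResult(filename="doc.pdf", status=ConversionStatus.FAILED, error="parse error")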
+ + Attributes: + filename (str): Name of the PDF file + content (str): Extracted text content from the PDF + status (ConversionStatus): Status of the conversion operation + error (Optional[str]): Error message if conversion failed + """ filename: str content: str = "" - status: ConversionStatus + status: ConversionStatus error: Optional[str] = None class PDFMetadata(BaseModel): + """Model representing metadata about a processed PDF document. + + This model stores metadata and processing results for a PDF document, including + both the converted content in markdown format and a generated summary. It tracks + whether the document is a primary target or supplementary context document. + + Attributes: + filename (str): Name of the PDF file + markdown (str): Markdown representation of the PDF content + summary (str): Generated summary of the PDF content + status (ConversionStatus): Status of the PDF processing + type (Union[Literal["target"], Literal["context"]]): Whether this is a target or context document + error (Optional[str]): Error message if processing failed + created_at (datetime): Timestamp when this metadata was created + """ filename: str markdown: str = "" summary: str = "" diff --git a/shared/shared/podcast_types.py b/shared/shared/podcast_types.py index cb6f435..99fe6a0 100644 --- a/shared/shared/podcast_types.py +++ b/shared/shared/podcast_types.py @@ -3,6 +3,15 @@ class SavedPodcast(BaseModel): + """Model representing a saved podcast file. + + Attributes: + job_id (str): Unique identifier for the podcast generation job + filename (str): Name of the saved podcast file + created_at (str): Timestamp when podcast was created + size (int): Size of the podcast file in bytes + transcription_params (Optional[Dict]): Optional parameters used for transcription + """ job_id: str filename: str created_at: str @@ -11,29 +20,65 @@ class SavedPodcast(BaseModel): class SavedPodcastWithAudio(SavedPodcast): + """Model extending SavedPodcast to include audio data. + + Attributes: + audio_data (str): Base64 encoded audio data of the podcast + """ audio_data: str class DialogueEntry(BaseModel): + """Model representing a single dialogue entry in a conversation. + + Attributes: + text (str): The spoken text content + speaker (Literal["speaker-1", "speaker-2"]): Identifier for which speaker is talking + """ text: str speaker: Literal["speaker-1", "speaker-2"] class Conversation(BaseModel): + """Model representing a conversation between two speakers. + + Attributes: + scratchpad (str): Working notes or context for the conversation + dialogue (List[DialogueEntry]): List of dialogue entries making up the conversation + """ scratchpad: str dialogue: List[DialogueEntry] class SegmentPoint(BaseModel): + """Model representing a key point within a podcast segment topic. + + Attributes: + description (str): Description of the point to be covered + """ description: str class SegmentTopic(BaseModel): + """Model representing a topic within a podcast segment. + + Attributes: + title (str): Title of the topic + points (List[SegmentPoint]): List of key points to cover in the topic + """ title: str points: List[SegmentPoint] class PodcastSegment(BaseModel): + """Model representing a segment of a podcast. 
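+
+    Illustrative instance (values are examples only):
+
+        PodcastSegment(
+            section="Introduction",
+            topics=[SegmentTopic(title="Overview", points=[SegmentPoint(description="...")])],
+            duration=60,
+            references=["paper.pdf"],
+        )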
+
+    Attributes:
+        section (str): Name or title of the segment
+        topics (List[SegmentTopic]): List of topics to cover in the segment
+        duration (int): Duration of the segment in seconds
+        references (List[str]): List of reference sources for the segment content
+    """
     section: str
     topics: List[SegmentTopic]
     duration: int
@@ -41,5 +86,11 @@ class PodcastSegment(BaseModel):
 
 
 class PodcastOutline(BaseModel):
+    """Model representing the complete outline of a podcast.
+
+    Attributes:
+        title (str): Title of the podcast
+        segments (List[PodcastSegment]): List of segments making up the podcast
+    """
     title: str
     segments: List[PodcastSegment]
diff --git a/shared/shared/prompt_tracker.py b/shared/shared/prompt_tracker.py
index 5c46aef..d777adf 100644
--- a/shared/shared/prompt_tracker.py
+++ b/shared/shared/prompt_tracker.py
@@ -9,16 +9,43 @@
 
 
 class PromptTracker:
-    """Track prompts and responses and save them to storage"""
+    """Track prompts and responses and save them to storage.
+
+    This class provides functionality to track and store prompts, responses and processing
+    steps for a given job. It maintains a history of interactions that can be persisted
+    to storage.
+
+    Attributes:
+        job_id (str): Unique identifier for the job being tracked
+        user_id (str): Identifier for the user who owns this job
+        steps (Dict[str, ProcessingStep]): Dictionary mapping step names to processing steps
+        storage_manager (StorageManager): Manager for persisting data to storage
+    """
 
     def __init__(self, job_id: str, user_id: str, storage_manager: StorageManager):
+        """Initialize a new PromptTracker instance.
+
+        Args:
+            job_id (str): Unique identifier for the job
+            user_id (str): Identifier for the user
+            storage_manager (StorageManager): Storage manager instance for persistence
+        """
         self.job_id = job_id
         self.user_id = user_id
         self.steps: Dict[str, ProcessingStep] = {}
         self.storage_manager = storage_manager
 
     def track(self, step_name: str, prompt: str, model: str, response: str = None):
-        """Track a processing step"""
+        """Track a processing step.
+
+        Creates a new ProcessingStep entry and optionally saves it if a response is provided.
+
+        Args:
+            step_name (str): Name identifying this processing step
+            prompt (str): The prompt text used
+            model (str): Name/identifier of the model used
+            response (str, optional): Response received from the model. Defaults to None.
+        """
         self.steps[step_name] = ProcessingStep(
             step_name=step_name,
             prompt=prompt,
@@ -31,7 +58,15 @@ def track(self, step_name: str, prompt: str, model: str, response: str = None):
         logger.info(f"Tracked step {step_name} for {self.job_id}")
 
     def update_result(self, step_name: str, response: str):
-        """Update the response for an existing step"""
+        """Update the response for an existing processing step.
+
+        Args:
+            step_name (str): Name of the step to update
+            response (str): New response text to store
+
+        Note:
+            If the step_name doesn't exist, a warning will be logged and no update occurs.
+        """
         if step_name in self.steps:
             self.steps[step_name].response = response
             self._save()
@@ -40,7 +75,11 @@ def update_result(self, step_name: str, response: str):
             logger.warning(f"Step {step_name} not found in prompt tracker")
 
     def _save(self):
-        """Save the current state to storage"""
+        """Save the current state to storage.
+
+        Converts the tracked steps to JSON format and stores them using the storage manager.
+        The file is saved with a name based on the job_id.
+ """ tracker = PromptTrackerModel(steps=list(self.steps.values())) self.storage_manager.store_file( self.user_id, diff --git a/shared/shared/prompt_types.py b/shared/shared/prompt_types.py index f15b507..f7e1c96 100644 --- a/shared/shared/prompt_types.py +++ b/shared/shared/prompt_types.py @@ -3,6 +3,18 @@ class ProcessingStep(BaseModel): + """Model representing a single processing step in an AI interaction. + + This model captures details about a specific interaction with an AI model, + including the prompt used, response received, and timing information. + + Attributes: + step_name (str): Name identifying this processing step + prompt (str): The prompt text sent to the model + response (str): The response received from the model + model (str): Name/identifier of the AI model used + timestamp (float): Unix timestamp when this step occurred + """ step_name: str prompt: str response: str @@ -11,4 +23,12 @@ class ProcessingStep(BaseModel): class PromptTracker(BaseModel): + """Model for tracking a sequence of AI processing steps. + + This model maintains an ordered list of processing steps that occurred + during a job, providing a complete history of AI interactions. + + Attributes: + steps (List[ProcessingStep]): Ordered list of processing steps that occurred + """ steps: List[ProcessingStep] diff --git a/shared/shared/storage.py b/shared/shared/storage.py index 213f8cf..8c62c9c 100644 --- a/shared/shared/storage.py +++ b/shared/shared/storage.py @@ -27,11 +27,29 @@ # TODO: wrap errors in StorageError # TODO: implement cleanup and delete as well class StorageManager: + """Manages storage operations using MinIO as the backend. + + This class provides an interface for storing and retrieving files using MinIO, + with support for user isolation, job tracking, and metadata management. + + Attributes: + telemetry (OpenTelemetryInstrumentation): Instance for tracing operations + client (Minio): MinIO client instance + bucket_name (str): Name of the MinIO bucket to use + """ + def __init__(self, telemetry: OpenTelemetryInstrumentation): - """ - Initialize MinIO client and ensure bucket exists + """Initialize MinIO client and ensure bucket exists. requires: OpenTelemetryInstrumentation instance for tracing since Minio does not have an auto otel instrumentor + + Requires: + Args: + telemetry (OpenTelemetryInstrumentation): Instance for tracing since MinIO + does not have an auto OpenTelemetry instrumentor + + Raises: + Exception: If MinIO client initialization fails """ try: self.telemetry: OpenTelemetryInstrumentation = telemetry @@ -60,6 +78,11 @@ def __init__(self, telemetry: OpenTelemetryInstrumentation): raise def _ensure_bucket_exists(self): + """Ensure the configured bucket exists, creating it if necessary. + + Raises: + Exception: If bucket creation fails + """ try: if not self.client.bucket_exists(self.bucket_name): self.client.make_bucket(self.bucket_name) @@ -68,7 +91,16 @@ def _ensure_bucket_exists(self): raise def _get_object_path(self, user_id: str, job_id: str, filename: str) -> str: - """Generate the full object path including user isolation""" + """Generate the full object path including user isolation. 
+ + Args: + user_id (str): ID of the user + job_id (str): ID of the job + filename (str): Name of the file + + Returns: + str: Full object path in format "user_id/job_id/filename" + """ return f"{user_id}/{job_id}/{filename}" def store_file( @@ -80,7 +112,19 @@ def store_file( content_type: str, metadata: dict = None, ) -> None: - """Store any file type in MinIO with metadata""" + """Store any file type in MinIO with metadata. + + Args: + user_id (str): ID of the user + job_id (str): ID of the job + content (bytes): File content to store + filename (str): Name of the file + content_type (str): MIME type of the file + metadata (dict, optional): Additional metadata to store. Defaults to None. + + Raises: + Exception: If file storage fails + """ with self.telemetry.tracer.start_as_current_span("store_file") as span: span.set_attribute("user_id", user_id) span.set_attribute("job_id", job_id) @@ -113,7 +157,18 @@ def store_audio( filename: str, transcription_params: TranscriptionParams, ): - """Store audio file with metadata in MinIO""" + """Store audio file with metadata in MinIO. + + Args: + user_id (str): ID of the user + job_id (str): ID of the job + audio_content (bytes): Audio file content + filename (str): Name of the audio file + transcription_params (TranscriptionParams): Parameters used for transcription + + Raises: + S3Error: If MinIO storage operation fails + """ with self.telemetry.tracer.start_as_current_span("store_audio") as span: span.set_attribute("job_id", job_id) span.set_attribute("user_id", user_id) @@ -146,7 +201,18 @@ def store_audio( raise def get_podcast_audio(self, user_id: str, job_id: str) -> Optional[str]: - """Get the audio data for a specific podcast by job_id""" + """Get the audio data for a specific podcast by job_id. + + Args: + user_id (str): ID of the user + job_id (str): ID of the job + + Returns: + Optional[str]: Base64 encoded audio data if found, None otherwise + + Raises: + Exception: If retrieval fails + """ with self.telemetry.tracer.start_as_current_span("get_podcast_audio") as span: span.set_attribute("job_id", job_id) span.set_attribute("user_id", user_id) @@ -176,7 +242,19 @@ def get_podcast_audio(self, user_id: str, job_id: str) -> Optional[str]: raise def get_file(self, user_id: str, job_id: str, filename: str) -> Optional[bytes]: - """Get any file from storage by user_id, job_id and filename""" + """Get any file from storage by user_id, job_id and filename. + + Args: + user_id (str): ID of the user + job_id (str): ID of the job + filename (str): Name of the file to retrieve + + Returns: + Optional[bytes]: File content if found, None if file doesn't exist + + Raises: + Exception: If retrieval fails for reasons other than missing file + """ with self.telemetry.tracer.start_as_current_span("get_file") as span: span.set_attribute("job_id", job_id) span.set_attribute("user_id", user_id) @@ -202,7 +280,15 @@ def get_file(self, user_id: str, job_id: str, filename: str) -> Optional[bytes]: raise def delete_job_files(self, user_id: str, job_id: str) -> bool: - """Delete all files associated with a user_id and job_id""" + """Delete all files associated with a user_id and job_id. 
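+
+        Example (illustrative; IDs are placeholders):
+
+            deleted = storage_manager.delete_job_files("user-1", "job-1")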
+ + Args: + user_id (str): ID of the user + job_id (str): ID of the job + + Returns: + bool: True if deletion successful, False otherwise + """ with self.telemetry.tracer.start_as_current_span("delete_job_files") as span: span.set_attribute("job_id", job_id) span.set_attribute("user_id", user_id) @@ -229,7 +315,17 @@ def delete_job_files(self, user_id: str, job_id: str) -> bool: return False def list_files_metadata(self, user_id: str = None): - """Lists metadata filtered by user_id if provided""" + """Lists metadata filtered by user_id if provided. + + Args: + user_id (str, optional): ID of user to filter results. Defaults to None. + + Returns: + list: List of dictionaries containing file metadata + + Raises: + Exception: If listing fails + """ with self.telemetry.tracer.start_as_current_span("list_files_metadata") as span: try: # If user_id is provided, use it as prefix to filter results diff --git a/tests/test.py b/tests/test.py index 46a0cfe..1bf2ba7 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,3 +1,9 @@ +"""Test module for PDF-to-Podcast API functionality. + +This module provides comprehensive testing capabilities for the PDF-to-Podcast API service, +including WebSocket status monitoring, file processing, and endpoint verification. +""" + import requests import os import json as json @@ -22,7 +28,25 @@ class StatusMonitor: + """Monitor WebSocket status updates for PDF-to-Podcast jobs. + + This class handles WebSocket connections to track the status of PDF processing, + agent processing, and text-to-speech conversion for a specific job. + + Attributes: + base_url (str): Base URL of the API service + job_id (str): Unique identifier for the job being monitored + services (set): Set of services to monitor (pdf, agent, tts) + tts_completed (Event): Event that is set when TTS processing completes + """ + def __init__(self, base_url, job_id): + """Initialize the status monitor. + + Args: + base_url (str): Base URL of the API service + job_id (str): Unique identifier for the job to monitor + """ self.base_url = base_url self.job_id = job_id self.ws_url = self._get_ws_url(base_url) @@ -44,6 +68,11 @@ def _get_ws_url(self, base_url): return urljoin(ws_base, f"/ws/status/{self.job_id}") def get_time(self): + """Get current time formatted as string. + + Returns: + str: Current time in HH:MM:SS format + """ return datetime.now().strftime("%H:%M:%S") def start(self): @@ -63,6 +92,7 @@ def _run_async_loop(self): loop.run_until_complete(self._monitor_status()) async def _monitor_status(self): + """Monitor WebSocket status updates with automatic reconnection""" while not self.stop_event.is_set(): try: async with websockets.connect(self.ws_url) as websocket: @@ -115,7 +145,11 @@ async def _monitor_status(self): ) async def _handle_message(self, message): - """Handle incoming WebSocket messages""" + """Handle incoming WebSocket messages. + + Args: + message (str): JSON message from WebSocket + """ try: data = json.loads(message) service = data.get("service") @@ -143,7 +177,20 @@ async def _handle_message(self, message): def get_output_with_retry(base_url: str, job_id: str, max_retries=5, retry_delay=1): - """Retry getting output with exponential backoff""" + """Retry getting output with exponential backoff. 
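+
+    Illustrative call (assumes the API service is reachable at the default
+    test URL):
+
+        audio_bytes = get_output_with_retry("http://localhost:8002", job_id)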
+ + Args: + base_url (str): Base URL of the API service + job_id (str): Job ID to get output for + max_retries (int): Maximum number of retry attempts + retry_delay (int): Initial delay between retries in seconds + + Returns: + bytes: Audio file content + + Raises: + TimeoutError: If maximum retries exceeded + """ for attempt in range(max_retries): try: response = requests.get( @@ -170,7 +217,17 @@ def get_output_with_retry(base_url: str, job_id: str, max_retries=5, retry_delay def test_saved_podcasts(base_url: str, job_id: str, max_retries=5, retry_delay=5): - """Test the saved podcasts endpoints with retry logic""" + """Test the saved podcasts endpoints with retry logic. + + Args: + base_url (str): Base URL of the API service + job_id (str): Job ID to test + max_retries (int): Maximum number of retry attempts + retry_delay (int): Initial delay between retries in seconds + + Raises: + AssertionError: If any endpoint tests fail + """ print( f"\n[{datetime.now().strftime('%H:%M:%S')}] Testing saved podcasts endpoints..." ) @@ -267,6 +324,19 @@ def test_api( monologue: bool = False, vdb: bool = False, ): + """Test the PDF-to-Podcast API functionality. + + Args: + base_url (str): Base URL of the API service + target_files (List[str]): List of target PDF files to process + context_files (List[str]): List of context PDF files + monologue (bool): Whether to generate monologue instead of dialogue + vdb (bool): Whether to enable vector database processing + + Raises: + AssertionError: If any API tests fail + Exception: For other errors during testing + """ voice_mapping = { "speaker-1": "iP95p4xoKVk53GoZ742B", } diff --git a/tests/test_db.py b/tests/test_db.py index 358cee7..b11788c 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -10,6 +10,16 @@ # Mock the TranscriptionParams that was in main.py @dataclass class TranscriptionParams: + """Parameters for transcription configuration. + + Attributes: + name (str): Name of the podcast + duration (int): Duration in minutes + speaker_1_name (str): Name of first speaker + speaker_2_name (str): Name of second speaker + model (str): Model to use for transcription + voice_mapping (Dict[str, str]): Mapping of speaker IDs to voice IDs + """ name: str duration: int speaker_1_name: str @@ -19,8 +29,18 @@ class TranscriptionParams: class StorageManager: + """Manages storage operations with MinIO for audio files. + + This class handles initialization of MinIO client, bucket creation, + and operations for storing and retrieving audio files. + """ + def __init__(self): - """Initialize MinIO client and ensure bucket exists""" + """Initialize MinIO client and ensure bucket exists. + + Raises: + Exception: If MinIO client initialization fails + """ try: self.client = Minio( os.getenv("MINIO_ENDPOINT", "localhost:9000"), @@ -38,9 +58,19 @@ def __init__(self): raise def get_time(self): + """Get current time formatted as string. + + Returns: + str: Current time in HH:MM:SS format + """ return datetime.now().strftime("%H:%M:%S") def _ensure_bucket_exists(self): + """Create bucket if it doesn't exist. + + Raises: + Exception: If bucket creation fails + """ try: if not self.client.bucket_exists(self.bucket_name): self.client.make_bucket(self.bucket_name) @@ -56,6 +86,17 @@ def store_audio( filename: str, transcription_params: TranscriptionParams, ): + """Store audio file in MinIO. 
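+
+        Illustrative call (values are examples only):
+
+            ok = storage.store_audio("job-123", audio_bytes, "podcast.mp3", params)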
+ + Args: + job_id (str): Unique identifier for the job + audio_content (bytes): Audio file content + filename (str): Name of the audio file + transcription_params (TranscriptionParams): Parameters for transcription + + Returns: + bool: True if storage successful, False otherwise + """ try: object_name = f"{job_id}/{filename}" self.client.put_object( @@ -74,6 +115,15 @@ def store_audio( return False def get_audio(self, job_id: str, filename: str): + """Retrieve audio file from MinIO. + + Args: + job_id (str): Unique identifier for the job + filename (str): Name of the audio file + + Returns: + bytes: Audio file content if successful, None otherwise + """ try: object_name = f"{job_id}/{filename}" result = self.client.get_object(self.bucket_name, object_name).read() @@ -87,6 +137,15 @@ def get_audio(self, job_id: str, filename: str): def test_storage_manager(): + """Run tests for StorageManager functionality. + + Tests include: + 1. Initialization of StorageManager + 2. Storing audio file + 3. Retrieving stored audio file + 4. Handling non-existent file retrieval + 5. Cleanup of test data + """ print("\n=== Starting Storage Manager Tests ===") # Initialize test data diff --git a/tests/test_files.py b/tests/test_files.py index 09ad716..d9e5173 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1,3 +1,9 @@ +"""Test module for transcript and prompt history retrieval endpoints. + +This module contains test functions to verify the functionality of the transcript +and prompt history retrieval endpoints of the PDF-to-Podcast API service. +""" + import requests import os from datetime import datetime @@ -8,6 +14,14 @@ def test_transcript(): + """Test the transcript retrieval endpoint. + + Makes a GET request to retrieve the transcript for a specific job ID. + Prints the transcript content if successful, otherwise prints error details. + + Environment Variables: + API_SERVICE_URL: Base URL of the API service (default: http://localhost:8002) + """ base_url = os.getenv("API_SERVICE_URL", "http://localhost:8002") print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Testing transcript endpoint...") print(f"Job ID: {JOB_ID}") @@ -33,6 +47,14 @@ def test_transcript(): def test_prompt_tracker(): + """Test the prompt history retrieval endpoint. + + Makes a GET request to retrieve the prompt generation history for a specific job ID. + Prints the history content if successful, otherwise prints error details. + + Environment Variables: + API_SERVICE_URL: Base URL of the API service (default: http://localhost:8002) + """ base_url = os.getenv("API_SERVICE_URL", "http://localhost:8002") print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Testing history endpoint...") diff --git a/tests/test_invalid_filetype.py b/tests/test_invalid_filetype.py index 7ef6687..95a17cf 100644 --- a/tests/test_invalid_filetype.py +++ b/tests/test_invalid_filetype.py @@ -1,3 +1,9 @@ +"""Test module for invalid file type handling. + +This module contains test functions to verify that the API properly handles +invalid file types and malformed transcription parameters. +""" + import os import requests from requests import Response @@ -6,6 +12,18 @@ def test(base_url: str): + """Test invalid file type and parameter handling. + + Tests two scenarios: + 1. Submitting a .txt file instead of PDF + 2. 
+       Submitting invalid transcription parameters
+
+    Args:
+        base_url (str): Base URL of the API service
+
+    Raises:
+        AssertionError: If response status codes don't match expected 400
+    """
     # Define default voice mapping
     voice_mapping = {
         "speaker-1": "iP95p4xoKVk53GoZ742B",  # Example voice ID for speaker 1
@@ -43,6 +61,21 @@ def test(base_url: str):
 def test_api(
     base_url: str, file_name: str, file_type: str, transcription_params: dict[str, any]
 ) -> Response:
+    """Test the PDF processing API endpoint with various inputs.
+
+    Args:
+        base_url (str): Base URL of the API service
+        file_name (str): Name of the file to upload
+        file_type (str): MIME type of the file
+        transcription_params (dict[str, any]): Parameters for transcription
+
+    Returns:
+        Response: Response object from the API request
+
+    Raises:
+        FileNotFoundError: If samples directory or test file not found
+        AssertionError: If the test file does not exist
+    """
     # API endpoint
     process_url = f"{base_url}/process_pdf"
diff --git a/tests/test_list.py b/tests/test_list.py
index db3b65c..03952a8 100644
--- a/tests/test_list.py
+++ b/tests/test_list.py
@@ -1,15 +1,42 @@
+"""Test module for listing saved podcasts.
+
+This module provides functionality to retrieve and display a list of all saved podcasts
+from the PDF-to-Podcast API service, including their metadata and transcription parameters.
+"""
+
 import requests
 import ujson as json
 from datetime import datetime
 
 
-def format_timestamp(timestamp_str):
-    """Format the ISO timestamp into a more readable format"""
+def format_timestamp(timestamp_str: str) -> str:
+    """Format an ISO timestamp string into a more readable format.
+
+    Args:
+        timestamp_str (str): ISO format timestamp string with timezone info
+
+    Returns:
+        str: Formatted timestamp string in the format "Month DD, YYYY at HH:MM AM/PM"
+    """
     dt = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
     return dt.strftime("%B %d, %Y at %I:%M %p")
 
 
 def list_saved_podcasts():
+    """Retrieve and display a list of all saved podcasts from the API.
+
+    Makes a GET request to the /saved_podcasts endpoint and prints details of each podcast
+    including:
+    - Job ID
+    - Filename
+    - Creation timestamp
+    - Transcription parameters
+
+    Handles various error cases like connection failures and invalid responses.
+
+    Environment Variables:
+        API_SERVICE_URL: Base URL of the API service (default: http://localhost:8002)
+    """
     try:
         print("\nAttempting to connect to API...")
         response = requests.get("http://localhost:8002/saved_podcasts")