From e9ceb984f14cae814ca62f73eb214c8bdbf5dce2 Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Tue, 3 Sep 2024 04:33:10 -0500 Subject: [PATCH 01/10] feat: migrating scheduler to marie package initial workign version where the client monitors it own jobs --- {marie_server => marie}/job/__init__.py | 0 {marie_server => marie}/job/common.py | 2 +- .../job/event_publisher.py | 0 .../job/gateway_job_distributor.py | 17 ++- .../job/job_distributor.py | 6 +- .../job/job_log_storage_client.py | 2 +- {marie_server => marie}/job/job_manager.py | 24 ++-- .../job/job_storage_client_proxy.py | 6 +- {marie_server => marie}/job/job_supervisor.py | 112 +++++++++++----- .../job/placement_group.py | 0 .../job/pydantic_models.py | 0 .../job/scheduling_strategies.py | 2 +- {marie_server => marie}/job/utils.py | 0 marie/serve/executors/__init__.py | 2 +- .../serve/runtimes/worker/request_handling.py | 125 ++++++++++++++---- .../storage => marie/storage/kv}/__init__.py | 0 .../storage => marie/storage/kv}/in_memory.py | 2 +- .../storage => marie/storage/kv}/psql.py | 2 +- .../storage/kv}/storage_client.py | 0 marie_server/scheduler/psql.py | 10 +- poc/custom_gateway/create_jobs.sh | 2 +- poc/custom_gateway/direct-flow.py | 8 +- poc/custom_gateway/server_gateway.py | 12 +- 23 files changed, 229 insertions(+), 105 deletions(-) rename {marie_server => marie}/job/__init__.py (100%) rename {marie_server => marie}/job/common.py (99%) rename {marie_server => marie}/job/event_publisher.py (100%) rename {marie_server => marie}/job/gateway_job_distributor.py (82%) rename {marie_server => marie}/job/job_distributor.py (80%) rename {marie_server => marie}/job/job_log_storage_client.py (96%) rename {marie_server => marie}/job/job_manager.py (96%) rename {marie_server => marie}/job/job_storage_client_proxy.py (88%) rename {marie_server => marie}/job/job_supervisor.py (57%) rename {marie_server => marie}/job/placement_group.py (100%) rename {marie_server => marie}/job/pydantic_models.py (100%) rename {marie_server => marie}/job/scheduling_strategies.py (92%) rename {marie_server => marie}/job/utils.py (100%) rename {marie_server/storage => marie/storage/kv}/__init__.py (100%) rename {marie_server/storage => marie/storage/kv}/in_memory.py (98%) rename {marie_server/storage => marie/storage/kv}/psql.py (98%) rename {marie_server/storage => marie/storage/kv}/storage_client.py (100%) diff --git a/marie_server/job/__init__.py b/marie/job/__init__.py similarity index 100% rename from marie_server/job/__init__.py rename to marie/job/__init__.py diff --git a/marie_server/job/common.py b/marie/job/common.py similarity index 99% rename from marie_server/job/common.py rename to marie/job/common.py index 44f9f429..a0a083fc 100644 --- a/marie_server/job/common.py +++ b/marie/job/common.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union from marie.constants import KV_NAMESPACE_JOB -from marie_server.storage.storage_client import StorageArea +from marie.storage.kv.storage_client import StorageArea JOB_ID_METADATA_KEY = "job_submission_id" JOB_NAME_METADATA_KEY = "job_name" diff --git a/marie_server/job/event_publisher.py b/marie/job/event_publisher.py similarity index 100% rename from marie_server/job/event_publisher.py rename to marie/job/event_publisher.py diff --git a/marie_server/job/gateway_job_distributor.py b/marie/job/gateway_job_distributor.py similarity index 82% rename from marie_server/job/gateway_job_distributor.py rename to marie/job/gateway_job_distributor.py index 26ab831b..1b15ab53 100644 --- a/marie_server/job/gateway_job_distributor.py +++ b/marie/job/gateway_job_distributor.py @@ -1,15 +1,13 @@ from typing import Callable, List, Optional -from django.views.debug import CallableSettingWrapper from docarray import BaseDoc, DocList from docarray.documents import TextDoc -from marie import DocumentArray +from marie.job.common import JobInfo, JobStatus +from marie.job.job_distributor import JobDistributor from marie.logging.logger import MarieLogger from marie.serve.runtimes.gateway.streamer import GatewayStreamer from marie.types.request.data import DataRequest -from marie_server.job.common import JobInfo, JobStatus -from marie_server.job.job_distributor import JobDistributor class GatewayJobDistributor(JobDistributor): @@ -23,6 +21,7 @@ def __init__( async def submit_job( self, + submission_id: str, job_info: JobInfo, send_callback: Callable[[List[DataRequest]], DataRequest] = None, ) -> DataRequest: @@ -32,7 +31,7 @@ async def submit_job( if curr_status != JobStatus.PENDING: raise RuntimeError( - f"Job {job_info._job_id} is not in PENDING state. " + f"Job {submission_id} is not in PENDING state. " f"Current status is {curr_status} with message {curr_message}." ) @@ -41,12 +40,18 @@ async def submit_job( self.logger.warning(f"Gateway streamer is not initialized") raise RuntimeError("Gateway streamer is not initialized") + print("job_info.metadata", job_info.metadata) + parameters = {"job_id": submission_id} # "#job_info.job_id, + if job_info.metadata: + parameters.update(job_info.metadata) + doc = TextDoc(text=f"sample text : {job_info.entrypoint}") + request = DataRequest() request.document_array_cls = DocList[BaseDoc]() request.header.exec_endpoint = "/extract" request.header.target_executor = "executor0" # job_info.entrypoint - request.parameters = {} # job_info.metadata + request.parameters = parameters request.data.docs = DocList([doc]) response = await self.streamer.process_single_data( diff --git a/marie_server/job/job_distributor.py b/marie/job/job_distributor.py similarity index 80% rename from marie_server/job/job_distributor.py rename to marie/job/job_distributor.py index 1befc9e0..c4e9237a 100644 --- a/marie_server/job/job_distributor.py +++ b/marie/job/job_distributor.py @@ -1,9 +1,9 @@ import abc from typing import Callable, Dict, List, Optional +from marie.job.common import JobInfo from marie.types.request import Request from marie.types.request.data import DataRequest -from marie_server.job.common import JobInfo class JobDistributor(abc.ABC): @@ -14,12 +14,14 @@ class JobDistributor(abc.ABC): @abc.abstractmethod async def submit_job( self, + submission_id: str, job_info: JobInfo, send_callback: Optional[Callable[[List[Request], Dict[str, str]], None]] = None, ) -> DataRequest: """ - Publish a job. + Publish a job to the underlying executor. + :param submission_id: The submission id of the job. :param job_info: The job info to publish. :param send_callback: The callback after the job is submitted over the network. :return: diff --git a/marie_server/job/job_log_storage_client.py b/marie/job/job_log_storage_client.py similarity index 96% rename from marie_server/job/job_log_storage_client.py rename to marie/job/job_log_storage_client.py index 6712e350..8140a5a9 100644 --- a/marie_server/job/job_log_storage_client.py +++ b/marie/job/job_log_storage_client.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Iterator, List, Tuple -from marie_server.job.utils import JOB_LOGS_PATH_TEMPLATE, file_tail_iterator +from marie.job.utils import JOB_LOGS_PATH_TEMPLATE, file_tail_iterator class JobLogStorageClient: diff --git a/marie_server/job/job_manager.py b/marie/job/job_manager.py similarity index 96% rename from marie_server/job/job_manager.py rename to marie/job/job_manager.py index 6652c55b..812707ed 100644 --- a/marie_server/job/job_manager.py +++ b/marie/job/job_manager.py @@ -7,23 +7,18 @@ from uuid_extensions import uuid7str from marie._core.utils import run_background_task -from marie.logging.logger import MarieLogger -from marie_server.job.common import ( - ActorHandle, - JobInfo, - JobInfoStorageClient, - JobStatus, -) -from marie_server.job.event_publisher import EventPublisher -from marie_server.job.job_distributor import JobDistributor -from marie_server.job.job_log_storage_client import JobLogStorageClient -from marie_server.job.job_storage_client_proxy import JobInfoStorageClientProxy -from marie_server.job.job_supervisor import JobSupervisor -from marie_server.job.scheduling_strategies import ( +from marie.job.common import ActorHandle, JobInfo, JobInfoStorageClient, JobStatus +from marie.job.event_publisher import EventPublisher +from marie.job.job_distributor import JobDistributor +from marie.job.job_log_storage_client import JobLogStorageClient +from marie.job.job_storage_client_proxy import JobInfoStorageClientProxy +from marie.job.job_supervisor import JobSupervisor +from marie.job.scheduling_strategies import ( NodeAffinitySchedulingStrategy, SchedulingStrategyT, ) -from marie_server.storage.storage_client import StorageArea +from marie.logging.logger import MarieLogger +from marie.storage.kv.storage_client import StorageArea # The max time to wait for the JobSupervisor to start before failing the job. DEFAULT_JOB_START_TIMEOUT_SECONDS = 60 * 15 @@ -357,6 +352,7 @@ async def submit_job( job_id=submission_id, job_info_client=self._job_info_client, job_distributor=self._job_distributor, + event_publisher=self.event_publisher, ) await supervisor.run(_start_signal_actor=_start_signal_actor) diff --git a/marie_server/job/job_storage_client_proxy.py b/marie/job/job_storage_client_proxy.py similarity index 88% rename from marie_server/job/job_storage_client_proxy.py rename to marie/job/job_storage_client_proxy.py index 25796016..3848acc5 100644 --- a/marie_server/job/job_storage_client_proxy.py +++ b/marie/job/job_storage_client_proxy.py @@ -1,8 +1,8 @@ from typing import Any, Dict, Optional -from marie_server.job.common import JobInfo, JobInfoStorageClient, JobStatus -from marie_server.job.event_publisher import EventPublisher -from marie_server.storage.storage_client import StorageArea +from marie.job.common import JobInfoStorageClient, JobStatus +from marie.job.event_publisher import EventPublisher +from marie.storage.kv.storage_client import StorageArea class JobInfoStorageClientProxy(JobInfoStorageClient): diff --git a/marie_server/job/job_supervisor.py b/marie/job/job_supervisor.py similarity index 57% rename from marie_server/job/job_supervisor.py rename to marie/job/job_supervisor.py index 13f5cb66..74375330 100644 --- a/marie_server/job/job_supervisor.py +++ b/marie/job/job_supervisor.py @@ -4,30 +4,40 @@ from docarray import BaseDoc, DocList from docarray.documents import TextDoc +from marie.job.common import ActorHandle, JobInfoStorageClient, JobStatus +from marie.job.event_publisher import EventPublisher +from marie.job.job_distributor import JobDistributor from marie.logging.logger import MarieLogger +from marie.proto import jina_pb2 from marie.serve.networking import _NetworkingHistograms, _NetworkingMetrics +from marie.serve.networking.connection_stub import _ConnectionStubs +from marie.serve.networking.utils import get_grpc_channel from marie.types.request.data import DataRequest -from marie_server.job.common import ActorHandle, JobInfoStorageClient, JobStatus -from marie_server.job.job_distributor import JobDistributor class JobSupervisor: """ Supervise jobs and keep track of their status on the remote executor. + + Executors are responsible for running the job and updating the status of the job, however, the Executor does not update the WorkState. + The JobSupervisor is responsible for updating the WorkState based on the status of the job. """ DEFAULT_JOB_STOP_WAIT_TIME_S = 3 + DEFAULT_JOB_TIMEOUT_S = 60 # 60 seconds, there should be no job that takes more than 60 seconds to process def __init__( self, job_id: str, job_info_client: JobInfoStorageClient, job_distributor: JobDistributor, + event_publisher: EventPublisher, ): self.logger = MarieLogger(self.__class__.__name__) self._job_id = job_id self._job_info_client = job_info_client self._job_distributor = job_distributor + self._event_publisher = event_publisher self.request_info = None async def ping(self): @@ -44,9 +54,6 @@ async def ping(self): f"Sending ping to {address} for request {request_id} on deployment {deployment_name}" ) - from marie.serve.networking.connection_stub import _ConnectionStubs - from marie.serve.networking.utils import get_grpc_channel - channel = get_grpc_channel(address=address, asyncio=True) connection_stub = _ConnectionStubs( address=address, @@ -60,13 +67,14 @@ async def ping(self): histograms=_NetworkingHistograms(), ) - # print("DryRun - Response: ", response) - doc = TextDoc(text=f"Text : _jina_dry_run_") + doc = TextDoc(text=f"ping : _jina_dry_run_") request = DataRequest() request.document_array_cls = DocList[BaseDoc]() request.header.exec_endpoint = "_jina_dry_run_" request.header.target_executor = deployment_name - request.parameters = {} + request.parameters = { + "job_id": self._job_id, + } request.data.docs = DocList([doc]) try: @@ -74,7 +82,7 @@ async def ping(self): requests=[request], metadata={}, compression=False ) self.logger.debug(f"DryRun - Response: {response}") - if response.status.code == response.status.SUCCESS: + if response.status.code == jina_pb2.StatusProto.SUCCESS: return True else: raise RuntimeError( @@ -117,21 +125,37 @@ async def run( # Block in PENDING state until start signal received. await _start_signal_actor.wait.remote() - # this is our gateway address - driver_agent_http_address = "grpc://127.0.0.1" - driver_node_id = "CURRENT_NODE_ID" - - await self._job_info_client.put_status( - self._job_id, - JobStatus.RUNNING, - jobinfo_replace_kwargs={ - "driver_agent_http_address": driver_agent_http_address, - "driver_node_id": driver_node_id, - }, - ) + # moved to request_handling + # # this is our gateway address + # driver_agent_http_address = "grpc://127.0.0.1" + # driver_node_id = "CURRENT_NODE_ID" + # + # await self._job_info_client.put_status( + # self._job_id, + # JobStatus.RUNNING, + # jobinfo_replace_kwargs={ + # "driver_agent_http_address": driver_agent_http_address, + # "driver_node_id": driver_node_id, + # }, + # ) + # Run the job submission in the background - task = asyncio.create_task(self._submit_job_in_background(curr_info)) - print("Task: ", task) + if self.DEFAULT_JOB_TIMEOUT_S > 0: + try: + await asyncio.wait_for( + self._submit_job_in_background(curr_info), + timeout=self.DEFAULT_JOB_TIMEOUT_S, + ) + except asyncio.TimeoutError: + self.logger.error( + f"Job {self._job_id} timed out after {self.DEFAULT_JOB_TIMEOUT_S} seconds." + ) + await self._job_info_client.put_status( + self._job_id, JobStatus.FAILED, message="Job submission timed out." + ) + else: + task = asyncio.create_task(self._submit_job_in_background(curr_info)) + self.logger.debug(f"Job {self._job_id} submitted in the background.") def send_callback( self, requests: Union[List[DataRequest] | DataRequest], request_info: Dict @@ -151,7 +175,9 @@ def send_callback( async def _submit_job_in_background(self, curr_info): try: response = await self._job_distributor.submit_job( - curr_info, self.send_callback + submission_id=self._job_id, + job_info=curr_info, + send_callback=self.send_callback, ) # printing the whole response will trigger a bug in rich.print with stackoverflow # format the response @@ -161,17 +187,37 @@ async def _submit_job_in_background(self, curr_info): print("Response docs: ", response.data.docs) print("Response status: ", response.status) - job_status = await self._job_info_client.get_status(self._job_id) - - if job_status.is_terminal(): - # If the job is already in a terminal state, then we don't need to update it. This can happen if the - # job was cancelled while the job was being submitted. - self.logger.warning( - f"Job {self._job_id} is already in terminal state {job_status}." - ) + if response.status.code == jina_pb2.StatusProto.SUCCESS: + job_status = await self._job_info_client.get_status(self._job_id) + # "STOPPED", "SUCCEEDED", "FAILED" + if job_status.is_terminal(): + # If the job is already in a terminal state, then we don't need to update it. This can happen if the + # job was cancelled while the job was being submitted. + # or while the job was marked from the executor side. + self.logger.warning( + f"Job {self._job_id} is already in terminal state {job_status}." + ) + # triggers the event to update the WorkStatus + await self._event_publisher.publish( + job_status, + { + "job_id": self._job_id, + "status": job_status, + "message": f"Job {self._job_id} is already in terminal state {job_status}.", + "jobinfo_replace_kwargs": False, + }, + ) + else: + await self._job_info_client.put_status( + self._job_id, JobStatus.SUCCEEDED + ) else: + # FIXME : Need to store the exception in the job info + e: jina_pb2.StatusProto.ExceptionProto = response.status.exception + name = str(e.name) + # stack = to_json(e.stacks) await self._job_info_client.put_status( - self._job_id, JobStatus.SUCCEEDED + self._job_id, JobStatus.FAILED, message=f"{name}" ) except Exception as e: await self._job_info_client.put_status( diff --git a/marie_server/job/placement_group.py b/marie/job/placement_group.py similarity index 100% rename from marie_server/job/placement_group.py rename to marie/job/placement_group.py diff --git a/marie_server/job/pydantic_models.py b/marie/job/pydantic_models.py similarity index 100% rename from marie_server/job/pydantic_models.py rename to marie/job/pydantic_models.py diff --git a/marie_server/job/scheduling_strategies.py b/marie/job/scheduling_strategies.py similarity index 92% rename from marie_server/job/scheduling_strategies.py rename to marie/job/scheduling_strategies.py index 84dd75be..80602776 100644 --- a/marie_server/job/scheduling_strategies.py +++ b/marie/job/scheduling_strategies.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Optional, Union -from marie_server.job.placement_group import PlacementGroup +from marie.job.placement_group import PlacementGroup class PlacementGroupSchedulingStrategy: diff --git a/marie_server/job/utils.py b/marie/job/utils.py similarity index 100% rename from marie_server/job/utils.py rename to marie/job/utils.py diff --git a/marie/serve/executors/__init__.py b/marie/serve/executors/__init__.py index 9b4f8f2e..92b51fcb 100644 --- a/marie/serve/executors/__init__.py +++ b/marie/serve/executors/__init__.py @@ -500,7 +500,7 @@ def _init_monitoring(self): ): with ImportExtensions( required=True, - help_text="You need to install the `prometheus_client` to use the montitoring functionality of marie", + help_text="You need to install the `prometheus_client` to use the monitoring functionality of marie", ): from prometheus_client import Summary diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py index acdb094c..85e0f68f 100644 --- a/marie/serve/runtimes/worker/request_handling.py +++ b/marie/serve/runtimes/worker/request_handling.py @@ -25,10 +25,12 @@ from marie.excepts import BadConfigSource, RuntimeTerminated from marie.helper import get_full_version from marie.importer import ImportExtensions +from marie.job.common import JobInfoStorageClient, JobStatus from marie.proto import jina_pb2 from marie.serve.executors import BaseExecutor from marie.serve.instrumentation import MetricsTimer from marie.serve.runtimes.worker.batch_queue import BatchQueue +from marie.storage.kv.psql import PostgreSQLKV from marie.types.request.data import DataRequest, SingleDocumentRequest if docarray_v2: @@ -165,6 +167,7 @@ def __init__( self._hot_reload_task = None if self.args.reload: self._hot_reload_task = asyncio.create_task(self._hot_reload()) + self._init_job_info_client() def _http_fastapi_default_app(self, **kwargs): from marie.serve.runtimes.worker.http_fastapi_app import ( # For Gateway, it works as for head @@ -193,7 +196,7 @@ async def _shutdown(): return extend_rest_interface(app) def _http_fastapi_csp_app(self, **kwargs): - from jina.serve.runtimes.worker.http_csp_app import get_fastapi_app + from marie.serve.runtimes.worker.http_csp_app import get_fastapi_app request_models_map = self._executor._get_endpoint_models_dict() @@ -383,16 +386,16 @@ def _load_executor( uses_requests=self.args.uses_requests, uses_dynamic_batching=self.args.uses_dynamic_batching, runtime_args={ # these are not parsed to the yaml config file but are pass directly during init - 'workspace': self.args.workspace, - 'shard_id': self.args.shard_id, - 'shards': self.args.shards, - 'replicas': self.args.replicas, - 'name': self.args.name, - 'provider': self.args.provider, - 'provider_endpoint': self.args.provider_endpoint, - 'metrics_registry': metrics_registry, - 'tracer_provider': tracer_provider, - 'meter_provider': meter_provider, + "workspace": self.args.workspace, + "shard_id": self.args.shard_id, + "shards": self.args.shards, + "replicas": self.args.replicas, + "name": self.args.name, + "provider": self.args.provider, + "provider_endpoint": self.args.provider_endpoint, + "metrics_registry": metrics_registry, + "tracer_provider": tracer_provider, + "meter_provider": meter_provider, }, py_modules=self.args.py_modules, extra_search_paths=self.args.extra_search_paths, @@ -678,6 +681,14 @@ async def handle( return requests[0] requests, params = self._setup_requests(requests, exec_endpoint) + + print("requests", requests) + print("params", params) + job_id = None + if params is not None: + job_id = params.get("job_id", None) + await self._record_started_job(job_id, exec_endpoint, requests, params) + len_docs = len(requests[0].docs) # TODO we can optimize here and access the if exec_endpoint in self._batchqueue_config: assert len(requests) == 1, "dynamic batching does not support no_reduce" @@ -709,15 +720,19 @@ async def handle( docs_matrix, docs_map = WorkerRequestHandler._get_docs_matrix_from_request( requests ) - return_data = await self._executor.__acall__( - req_endpoint=exec_endpoint, - docs=docs, - parameters=params, - docs_matrix=docs_matrix, - docs_map=docs_map, - tracing_context=tracing_context, - ) - _ = self._set_result(requests, return_data, docs) + try: + return_data = await self._executor.__acall__( + req_endpoint=exec_endpoint, + docs=docs, + parameters=params, + docs_matrix=docs_matrix, + docs_map=docs_map, + tracing_context=tracing_context, + ) + _ = self._set_result(requests, return_data, docs) + except Exception as e: + await self._record_failed_job(job_id, e) + raise e for req in requests: req.add_executor(self.deployment_name) @@ -729,6 +744,7 @@ async def handle( pass self._record_response_size_monitoring(requests) + await self._record_successful_job(job_id) return requests[0] @staticmethod @@ -947,10 +963,10 @@ async def stream_doc( ex = ValueError("endpoint must be generator") self.logger.error( ( - f'{ex!r}' + f"{ex!r}" + f'\n add "--quiet-error" to suppress the exception details' if not self.args.quiet_error - else '' + else "" ), exc_info=not self.args.quiet_error, ) @@ -984,10 +1000,10 @@ async def stream_doc( ) self.logger.error( ( - f'{ex!r}' + f"{ex!r}" + f'\n add "--quiet-error" to suppress the exception details' if not self.args.quiet_error - else '' + else "" ), exc_info=not self.args.quiet_error, ) @@ -1011,7 +1027,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: """ from google.protobuf import json_format - self.logger.debug('recv an endpoint discovery request') + self.logger.debug("recv an endpoint discovery request") endpoints_proto = jina_pb2.EndpointsProto() endpoints_proto.endpoints.extend(list(self._executor.requests.keys())) endpoints_proto.write_endpoints.extend(list(self._executor.write_endpoints)) @@ -1112,10 +1128,10 @@ async def process_data( except (RuntimeError, Exception) as ex: self.logger.error( ( - f'{ex!r}' + f"{ex!r}" + f'\n add "--quiet-error" to suppress the exception details' if not self.args.quiet_error - else '' + else "" ), exc_info=not self.args.quiet_error, ) @@ -1324,3 +1340,58 @@ async def restore_status( id=jina_pb2.RestoreId(value=request.value), status=jina_pb2.RestoreSnapshotStatusProto.Status.NOT_FOUND, ) + + def _init_job_info_client(self): + # storage = self.runtime_args.storage + # FIXME : This should be coming from the runtime_args + kv_storage_config = { + "hostname": "127.0.0.1", + "port": 5432, + "username": "postgres", + "password": "123456", + "database": "postgres", + "default_table": "kv_store_a", + "max_pool_size": 5, + "max_connections": 5, + } + + storage = PostgreSQLKV(config=kv_storage_config, reset=False) + self._job_info_client = JobInfoStorageClient(storage) + + async def _record_failed_job(self, job_id: str, e: Exception): + if job_id is not None and self._job_info_client is not None: + print(f"Monitoring JOB: {job_id} - {e}") + try: + await self._job_info_client.put_status( + job_id, + JobStatus.FAILED, + jobinfo_replace_kwargs={"error_message": str(e)}, + ) + except Exception as e: + self.logger.error(f"Error in recording job status: {e}") + + async def _record_started_job(self, job_id: str, exec_endpoint, requests, params): + if job_id is not None and self._job_info_client is not None: + print(f"Monitoring JOB: {exec_endpoint} - {job_id}") + # this is our gateway address + driver_agent_http_address = "grpc://127.0.0.1" + driver_node_id = "CURRENT_NODE_ID" + try: + await self._job_info_client.put_status( + job_id, + JobStatus.RUNNING, + jobinfo_replace_kwargs={ + "driver_agent_http_address": driver_agent_http_address, + "driver_node_id": driver_node_id, + }, + ) + except Exception as e: + self.logger.error(f"Error in recording job status: {e}") + + async def _record_successful_job(self, job_id): + if job_id is not None and self._job_info_client is not None: + print(f"Monitoring JOB: {job_id}") + try: + await self._job_info_client.put_status(job_id, JobStatus.SUCCEEDED) + except Exception as e: + self.logger.error(f"Error in recording job status: {e}") diff --git a/marie_server/storage/__init__.py b/marie/storage/kv/__init__.py similarity index 100% rename from marie_server/storage/__init__.py rename to marie/storage/kv/__init__.py diff --git a/marie_server/storage/in_memory.py b/marie/storage/kv/in_memory.py similarity index 98% rename from marie_server/storage/in_memory.py rename to marie/storage/kv/in_memory.py index c4469f8f..1d404fd8 100644 --- a/marie_server/storage/in_memory.py +++ b/marie/storage/kv/in_memory.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional from marie.logging.logger import MarieLogger -from marie_server.storage.storage_client import StorageArea +from marie.storage.kv.storage_client import StorageArea class InMemoryKV(StorageArea): diff --git a/marie_server/storage/psql.py b/marie/storage/kv/psql.py similarity index 98% rename from marie_server/storage/psql.py rename to marie/storage/kv/psql.py index 97c4e033..fc85d0da 100644 --- a/marie_server/storage/psql.py +++ b/marie/storage/kv/psql.py @@ -5,7 +5,7 @@ from marie.logging.logger import MarieLogger from marie.storage.database.postgres import PostgresqlMixin -from marie_server.storage.storage_client import StorageArea +from marie.storage.kv.storage_client import StorageArea class PostgreSQLKV(PostgresqlMixin, StorageArea): diff --git a/marie_server/storage/storage_client.py b/marie/storage/kv/storage_client.py similarity index 100% rename from marie_server/storage/storage_client.py rename to marie/storage/kv/storage_client.py diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index 7283515f..1b7be6b9 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -7,11 +7,11 @@ import psycopg2 from marie.helper import get_or_reuse_loop +from marie.job.common import JobStatus +from marie.job.job_manager import JobManager from marie.logging.logger import MarieLogger from marie.logging.predefined import default_logger as logger from marie.storage.database.postgres import PostgresqlMixin -from marie_server.job.common import JobStatus -from marie_server.job.job_manager import JobManager from marie_server.scheduler.fixtures import * from marie_server.scheduler.job_scheduler import JobScheduler from marie_server.scheduler.models import WorkInfo @@ -262,7 +262,6 @@ async def get_work_items( :param stop_event: an event to signal when to stop iterating over the records :return: """ - print("elf._lock.locked():", self._lock.locked()) async with self._lock: with self: try: @@ -277,9 +276,8 @@ async def get_work_items( cursor = self.connection.cursor() cursor.itersize = limit cursor.execute(f"{query}") - records = [] - for record in cursor: - records.append(record) + records = [record for record in cursor] + return records except (Exception, psycopg2.Error) as error: self.logger.error(f"Error fetching next job: {error}") diff --git a/poc/custom_gateway/create_jobs.sh b/poc/custom_gateway/create_jobs.sh index dfa1d1fb..4601491b 100755 --- a/poc/custom_gateway/create_jobs.sh +++ b/poc/custom_gateway/create_jobs.sh @@ -2,7 +2,7 @@ # create number of jobs based on the input from the user -for i in $(seq 1 $1) +for i in $(seq 1 "$1") do echo "Submitting job $i" python ./send_request_to_gateway.py job submit "hello world - $i" diff --git a/poc/custom_gateway/direct-flow.py b/poc/custom_gateway/direct-flow.py index 9867d001..5600a2b1 100644 --- a/poc/custom_gateway/direct-flow.py +++ b/poc/custom_gateway/direct-flow.py @@ -1,5 +1,6 @@ import inspect import os +import random import time from docarray import DocList @@ -44,7 +45,7 @@ def __init__(self, *args, **kwargs): time.sleep(1) @requests(on="/extract") - def funcXX( + def func_extract( self, docs: DocList[TextDoc], parameters=None, @@ -55,6 +56,11 @@ def funcXX( parameters = {} print(f"FirstExec func called : {len(docs)}, {parameters}") + + # randomly throw an error to test the error handling + if random.random() > 0.5: + raise Exception("random error") + for doc in docs: doc.text += " First Exec" print("Sleeping for 5 seconds : ", time.time()) diff --git a/poc/custom_gateway/server_gateway.py b/poc/custom_gateway/server_gateway.py index f050c7df..f83eda63 100644 --- a/poc/custom_gateway/server_gateway.py +++ b/poc/custom_gateway/server_gateway.py @@ -15,6 +15,9 @@ from marie import Gateway as BaseGateway from marie.excepts import RuntimeFailToStart from marie.helper import get_or_reuse_loop +from marie.job.common import JobInfo, JobStatus +from marie.job.gateway_job_distributor import GatewayJobDistributor +from marie.job.job_manager import JobManager from marie.logging.logger import MarieLogger from marie.proto import jina_pb2, jina_pb2_grpc from marie.serve.discovery import JsonAddress @@ -28,17 +31,14 @@ from marie.serve.runtimes.gateway.streamer import GatewayStreamer from marie.serve.runtimes.servers.composite import CompositeServer from marie.serve.runtimes.servers.grpc import GRPCServer +from marie.storage.kv.in_memory import InMemoryKV +from marie.storage.kv.psql import PostgreSQLKV from marie.types.request import Request from marie.types.request.data import DataRequest, Response from marie.types.request.status import StatusMessage -from marie_server.job.common import JobInfo, JobStatus -from marie_server.job.gateway_job_distributor import GatewayJobDistributor -from marie_server.job.job_manager import JobManager from marie_server.scheduler import PostgreSQLJobScheduler from marie_server.scheduler.models import WorkInfo from marie_server.scheduler.state import WorkState -from marie_server.storage.in_memory import InMemoryKV -from marie_server.storage.psql import PostgreSQLKV def create_balancer_interceptor() -> LoadBalancerInterceptor: @@ -387,7 +387,7 @@ async def handle_job_submit_command(self, message: dict) -> Request: response = Response() response.parameters = { "status": "ok", - "msg": "job submitted", + "msg": f"job submitted with id {job_id}", "job_id": job_id, } From 0a61cc403b99d02984761242b41b19a99c626808 Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Wed, 4 Sep 2024 05:22:20 -0500 Subject: [PATCH 02/10] wip: working on scheduler --- marie/job/sync_manager.py | 14 +++ marie_server/scheduler/fixtures.py | 127 ++++++++++++++++++++++++--- marie_server/scheduler/models.py | 1 - marie_server/scheduler/plans.py | 94 ++++++++++++++++++-- marie_server/scheduler/psql.py | 49 ++++++----- poc/custom_gateway/server_gateway.py | 6 +- 6 files changed, 245 insertions(+), 46 deletions(-) create mode 100644 marie/job/sync_manager.py diff --git a/marie/job/sync_manager.py b/marie/job/sync_manager.py new file mode 100644 index 00000000..33b6b2aa --- /dev/null +++ b/marie/job/sync_manager.py @@ -0,0 +1,14 @@ +from typing import Any, Dict + + +class SyncManager: + """ + SyncManager is responsible for synchronizing the state of the scheduler with the state of the executor if they are out of sync. + We will also publish events to the event publisher to notify the user of the state of the Job + + Example : If we restart the scheduler, we need to synchronize the state of the executor jobs that have completed while the scheduler was down. + + """ + + def __init__(self, config: Dict[str, Any]): + self.config = config diff --git a/marie_server/scheduler/fixtures.py b/marie_server/scheduler/fixtures.py index 38dd58b5..aaec2824 100644 --- a/marie_server/scheduler/fixtures.py +++ b/marie_server/scheduler/fixtures.py @@ -33,12 +33,13 @@ def create_job_table(schema: str): return f""" CREATE TABLE {schema}.job ( -- id uuid primary key not null default gen_random_uuid(), - id text primary key not null, + id uuid not null default gen_random_uuid(), + --id text primary key not null, name text not null, priority integer not null default(0), data jsonb, state {schema}.job_state not null default('{WorkState.CREATED.value}'), - retry_limit integer not null default(0), + retry_limit integer not null default(2), retry_count integer not null default(0), retry_delay integer not null default(0), retry_backoff boolean not null default false, @@ -50,12 +51,18 @@ def create_job_table(schema: str): created_on timestamp with time zone not null default now(), completed_on timestamp with time zone, keep_until timestamp with time zone NOT NULL default now() + interval '14 days', - on_complete boolean not null default false, - output jsonb - ) + output jsonb, + dead_letter text, + policy text + ) + PARTITION BY LIST (name) """ +def create_primary_key_job(schema: str) -> str: + return f"ALTER TABLE {schema}.job ADD PRIMARY KEY (name, id)" + + def create_job_history_table(schema: str): return f""" CREATE TABLE {schema}.job_history ( @@ -65,7 +72,7 @@ def create_job_history_table(schema: str): priority integer not null default(0), data jsonb, state {schema}.job_state not null, - retry_limit integer not null default(0), + retry_limit integer not null default(2), retry_count integer not null default(0), retry_delay integer not null default(0), retry_backoff boolean not null default false, @@ -74,9 +81,10 @@ def create_job_history_table(schema: str): expire_in interval not null default interval '15 minutes', created_on timestamp with time zone not null default now(), completed_on timestamp with time zone, - keep_until timestamp with time zone not null default now() + interval '14 days', - on_complete boolean not null default false, + keep_until timestamp with time zone not null default now() + interval '14 days', output jsonb, + dead_letter text, + policy text, history_created_on timestamp with time zone not null default now() ) """ @@ -90,13 +98,13 @@ def create_job_update_trigger_function(schema: str): INSERT INTO {schema}.job_history ( id, name, priority, data, state, retry_limit, retry_count, retry_delay, retry_backoff, start_after, started_on, expire_in, created_on, - completed_on, keep_until, on_complete, output, history_created_on + completed_on, keep_until, output, dead_letter, policy, history_created_on ) SELECT NEW.id, NEW.name, NEW.priority, NEW.data, NEW.state, NEW.retry_limit, NEW.retry_count, NEW.retry_delay, NEW.retry_backoff, NEW.start_after, NEW.started_on, NEW.expire_in, NEW.created_on, NEW.completed_on, - NEW.keep_until, NEW.on_complete, NEW.output, now() as history_created_on + NEW.keep_until, NEW.output, NEW.dead_letter, NEW.policy, now() as history_created_on FROM {schema}.job WHERE id = NEW.id; RETURN NEW; @@ -118,6 +126,25 @@ def clone_job_table_for_archive(schema): return f"CREATE TABLE {schema}.archive (LIKE {schema}.job)" +def create_table_queue(schema: str) -> str: + return f""" + CREATE TABLE {schema}.queue ( + name text, + policy text, + retry_limit int, + retry_delay int, + retry_backoff bool, + expire_seconds int, + retention_minutes int, + dead_letter text REFERENCES {schema}.queue (name), + partition_name text, + created_on timestamp with time zone not null default now(), + updated_on timestamp with time zone not null default now(), + PRIMARY KEY (name) + ) + """ + + def create_schedule_table(schema): return f""" CREATE TABLE {schema}.schedule ( @@ -144,6 +171,86 @@ def create_subscription_table(schema): """ +def delete_queue_function(schema: str) -> str: + return f""" + CREATE FUNCTION {schema}.delete_queue(queue_name text) + RETURNS VOID AS + $$ + DECLARE + table_name varchar; + BEGIN + WITH deleted AS ( + DELETE FROM {schema}.queue + WHERE name = queue_name + RETURNING partition_name + ) + SELECT partition_name FROM deleted INTO table_name; + + EXECUTE format('DROP TABLE IF EXISTS {schema}.%I', table_name); + END; + $$ + LANGUAGE plpgsql; + """ + + +def create_queue_function(schema: str) -> str: + return f""" + CREATE FUNCTION {schema}.create_queue(queue_name text, options json) + RETURNS VOID AS + $$ + DECLARE + table_name varchar := 'j' || encode(sha224(queue_name::bytea), 'hex'); + queue_created_on timestamptz; + BEGIN + + WITH q AS ( + INSERT INTO {schema}.queue ( + name, + policy, + retry_limit, + retry_delay, + retry_backoff, + expire_seconds, + retention_minutes, + dead_letter, + partition_name + ) + VALUES ( + queue_name, + options->>'policy', + (options->>'retry_limit')::int, + (options->>'retry_delay')::int, + (options->>'retry_backoff')::bool, + (options->>'expire_in_seconds')::int, + (options->>'retention_minutes')::int, + options->>'dead_letter', + table_name + ) + ON CONFLICT DO NOTHING + RETURNING created_on + ) + SELECT created_on INTO queue_created_on FROM q; + + IF queue_created_on IS NULL THEN + RETURN; + END IF; + + EXECUTE format('CREATE TABLE {schema}.%I (LIKE {schema}.job INCLUDING DEFAULTS)', table_name); + EXECUTE format('{format_partition_command(create_primary_key_job(schema))}', table_name); + EXECUTE format('ALTER TABLE {schema}.%I ADD CONSTRAINT cjc CHECK (name=%L)', table_name, queue_name); + EXECUTE format('ALTER TABLE {schema}.job ATTACH PARTITION {schema}.%I FOR VALUES IN (%L)', table_name, queue_name); + END; + $$ + LANGUAGE plpgsql; + """ + + +def format_partition_command(command: str) -> str: + return ( + command.replace(".job", ".%1$I").replace("job_i", "%1$s_i").replace("'", "''") + ) + + def add_archived_on_to_archive(schema): return f"ALTER TABLE {schema}.archive ADD archived_on timestamptz NOT NULL DEFAULT now()" diff --git a/marie_server/scheduler/models.py b/marie_server/scheduler/models.py index 033ab11e..a27c2365 100644 --- a/marie_server/scheduler/models.py +++ b/marie_server/scheduler/models.py @@ -21,7 +21,6 @@ class WorkInfo(BaseModel): start_after: datetime expire_in_seconds: int keep_until: datetime - on_complete: bool class ExistingWorkPolicy(Enum): diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index f03fef27..a019e880 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -1,5 +1,7 @@ from datetime import datetime, timezone +from typing import Dict +from marie.utils.json import to_json from marie_server.scheduler.models import WorkInfo from marie_server.scheduler.state import WorkState @@ -14,7 +16,7 @@ def to_timestamp_with_tz(dt: datetime): return datetime.utcfromtimestamp(timestamp).isoformat() + "Z" -def insert_job(schema: str, work_info: WorkInfo) -> str: +def insert_job_v1(schema: str, work_info: WorkInfo) -> str: return f""" INSERT INTO {schema}.job ( id, @@ -27,8 +29,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: data, retry_delay, retry_backoff, - keep_until, - on_complete + keep_until ) SELECT id, @@ -41,8 +42,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: data, retry_delay, retry_backoff, - keep_until, - on_complete + keep_until FROM ( SELECT *, CASE @@ -57,7 +57,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: END as start_after FROM ( SELECT - '{work_info.id}'::text as id, + '{work_info.id}'::uuid as id, '{work_info.name}'::text as name, {work_info.priority}::int as priority, '{WorkState.CREATED.value}'::{schema}.job_state as state, @@ -67,8 +67,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: '{work_info.data}'::jsonb as data, {work_info.retry_delay}::int as retry_delay, {work_info.retry_backoff}::bool as retry_backoff, - '{to_timestamp_with_tz(work_info.keep_until)}'::text as keepUntilValue, - {work_info.on_complete}::boolean as on_complete + '{to_timestamp_with_tz(work_info.keep_until)}'::text as keepUntilValue ) j1 ) j2 ) j3 @@ -77,6 +76,83 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: """ +def insert_job(schema: str, work_info: WorkInfo) -> str: + return f""" + INSERT INTO {schema}.job ( + id, + name, + priority, + state, + data, + start_after, + expire_in, + keep_until, + retry_limit, + retry_delay, + retry_backoff, + policy + ) + SELECT + id, + j.name, + priority, + state, + data, + start_after, + CASE + WHEN expire_in IS NOT NULL THEN CAST(expire_in as interval) + WHEN q.expire_seconds IS NOT NULL THEN q.expire_seconds * interval '1s' + WHEN expire_in_default IS NOT NULL THEN CAST(expire_in_default as interval) + ELSE interval '15 minutes' + END as expire_in, + CASE + WHEN right(keep_until, 1) = 'Z' THEN CAST(keep_until as timestamp with time zone) + --ELSE start_after + CAST(COALESCE(keep_until, (q.retention_minutes * 60)::text, keep_until_default, '14 days') as interval) + ELSE start_after + COALESCE(keep_until::interval, (q.retention_minutes * 60) * interval '1 second', keep_until_default, interval '14 days') + END as keep_until, + + COALESCE(j.retry_limit, q.retry_limit, retry_limit_default, 2) as retry_limit, + CASE + WHEN COALESCE(j.retry_backoff, q.retry_backoff, retry_backoff_default, false) + THEN GREATEST(COALESCE(j.retry_delay, q.retry_delay, retry_delay_default), 1) + ELSE COALESCE(j.retry_delay, q.retry_delay, retry_delay_default, 0) + END as retry_delay, + + COALESCE(j.retry_backoff, q.retry_backoff, retry_backoff_default, false) as retry_backoff, + q.policy + FROM + ( SELECT + '{work_info.id}'::uuid as id, + '{work_info.name}'::text as name, + {work_info.priority}::int as priority, + '{WorkState.CREATED.value}'::{schema}.job_state as state, + {work_info.retry_limit}::int as retry_limit, + '{to_timestamp_with_tz(work_info.start_after)}'::text as start_after, + CAST('{work_info.expire_in_seconds}' as interval) as expire_in, + '{work_info.data}'::jsonb as data, + {work_info.retry_delay}::int as retry_delay, + {work_info.retry_backoff}::bool as retry_backoff, + '{to_timestamp_with_tz(work_info.keep_until)}'::text as keep_until, + + 2::int as retry_limit_default, + 2::int as retry_delay_default, + 0::int as retry_backoff_default, + interval '60s'::interval as expire_in_default, + now() + interval '14 days'::interval as keep_until_default + ) j JOIN {schema}.queue q ON j.name = q.name + ON CONFLICT DO NOTHING + RETURNING id + """ + + +def create_queue(schema: str, queue_name: str, options: Dict[str, str]) -> str: + return f"SELECT {schema}.create_queue('{queue_name}', '{to_json(options)}')" + + +def delete_queue(schema: str, queue_name: str) -> str: + return f"SELECT {schema}.delete_queue({queue_name})" + + def fetch_next_job(schema: str): def query( name: str, @@ -102,7 +178,7 @@ def query( retry_count = CASE WHEN started_on IS NOT NULL THEN retry_count + 1 ELSE retry_count END FROM next WHERE name = '{name}' AND j.id = next.id - RETURNING j.{'*' if include_metadata else 'id,name, priority,state,retry_limit,start_after,expire_in,data,retry_delay,retry_backoff,keep_until,on_complete'} + RETURNING j.{'*' if include_metadata else 'id,name, priority,state,retry_limit,start_after,expire_in,data,retry_delay,retry_backoff,keep_until'} """ return query diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index 1b7be6b9..c5da5df0 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -16,6 +16,7 @@ from marie_server.scheduler.job_scheduler import JobScheduler from marie_server.scheduler.models import WorkInfo from marie_server.scheduler.plans import ( + create_queue, fetch_next_job, insert_job, to_timestamp_with_tz, @@ -26,6 +27,7 @@ MAX_POLL_PERIOD = 16.0 # 16s DEFAULT_SCHEMA = "marie_scheduler" +DEFAULT_JOB_TABLE = "job" COMPLETION_JOB_PREFIX = f"__state__{WorkState.COMPLETED.value}__" @@ -116,8 +118,10 @@ def create_tables(self, schema: str): commands = [ create_schema(schema), create_version_table(schema), + create_table_queue(schema), create_job_state_enum(schema), create_job_table(schema), + create_primary_key_job(schema), create_job_history_table(schema), create_job_update_trigger_function(schema), create_job_update_trigger(schema), @@ -131,6 +135,8 @@ def create_tables(self, schema: str): # create_index_singleton_key_on(schema), create_index_job_name(schema), create_index_job_fetch(schema), + create_queue_function(schema), + delete_queue_function(schema), ] query = ";\n".join(commands) @@ -167,6 +173,18 @@ async def wipe(self) -> None: self.logger.error(f"Error clearing tables: {error}") self.connection.rollback() + async def create_queue(self) -> None: + """Setup the queue for the scheduler.""" + + with self: + try: + self._execute_sql_gracefully( + create_queue(DEFAULT_SCHEMA, "extract", {}) + ) + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error setting up queue: {error}") + self.connection.rollback() + async def start(self) -> None: """ Starts the job scheduling agent. @@ -175,6 +193,7 @@ async def start(self) -> None: """ logger.info("Starting job scheduling agent") self.create_tables(DEFAULT_SCHEMA) + await self.create_queue() self.running = True self.task = asyncio.create_task(self._poll()) @@ -267,7 +286,7 @@ async def get_work_items( try: fetch_query_def = fetch_next_job(DEFAULT_SCHEMA) query = fetch_query_def( - name="WorkInfo-001", + name="extract", # TODO this is a placeholder batch_size=limit, include_metadata=False, priority=True, @@ -291,7 +310,7 @@ async def get_job(self, job_id: str) -> Optional[WorkInfo]: :param job_id: """ schema = DEFAULT_SCHEMA - table = "job" + table = DEFAULT_JOB_TABLE with self: try: @@ -309,8 +328,7 @@ async def get_job(self, job_id: str) -> Optional[WorkInfo]: data, retry_delay, retry_backoff, - keep_until, - on_complete + keep_until FROM {schema}.{table} WHERE id = '{job_id}' """ @@ -328,7 +346,7 @@ async def get_job(self, job_id: str) -> Optional[WorkInfo]: async def list_jobs(self, state: Optional[str] = None) -> Dict[str, WorkInfo]: work_items = {} schema = DEFAULT_SCHEMA - table = "job" + table = DEFAULT_JOB_TABLE states = "','".join(WorkState.__members__.keys()) if state is not None: if state.upper() not in WorkState.__members__: @@ -379,6 +397,8 @@ async def submit_job(self, work_info: WorkInfo, overwrite: bool = True) -> str: new_key_added = False submission_id = work_info.id + work_info.retry_limit = 2 + with self: try: cursor = self._execute_sql_gracefully( @@ -456,24 +476,6 @@ async def put_status( finally: self.connection.commit() - def reset_locked_items(self, schema: str): - query = f""" - UPDATE {schema}.job - SET state = '{WorkState.FAILED.value}', - started_on = NULL, - retry_count = retry_count - 1 - WHERE state = '{WorkState.ACTIVE.value}' - AND started_on IS NOT NULL - AND started_on < now() - interval '1 hour' - """ - with self: - try: - self._execute_sql_gracefully(query) - except (Exception, psycopg2.Error) as error: - self.logger.error(f"Error resetting locked items: {error}") - self.connection.rollback() - self.connection.commit() - async def maintenance(self): """ Performs the maintenance process, including expiring, archiving, and purging. @@ -527,5 +529,4 @@ def record_to_work_info(self, record): retry_delay=record[8], retry_backoff=record[9], keep_until=record[10], - on_complete=record[11], ) diff --git a/poc/custom_gateway/server_gateway.py b/poc/custom_gateway/server_gateway.py index f83eda63..4d59145c 100644 --- a/poc/custom_gateway/server_gateway.py +++ b/poc/custom_gateway/server_gateway.py @@ -18,6 +18,7 @@ from marie.job.common import JobInfo, JobStatus from marie.job.gateway_job_distributor import GatewayJobDistributor from marie.job.job_manager import JobManager +from marie.job.sync_manager import SyncManager from marie.logging.logger import MarieLogger from marie.proto import jina_pb2, jina_pb2_grpc from marie.serve.discovery import JsonAddress @@ -88,6 +89,7 @@ def __init__(self, **kwargs): "password": "123456", } + self.syncer = SyncManager(scheduler_config) self.distributor = GatewayJobDistributor( gateway_streamer=None, logger=self.logger ) @@ -121,7 +123,7 @@ async def _shutdown(): async def job_submit(text: str): self.logger.info(f"Received request at {datetime.now}") work_info = WorkInfo( - name="WorkInfo-001", + name="extract", priority=0, data={}, state=WorkState.CREATED, @@ -368,7 +370,7 @@ async def handle_job_submit_command(self, message: dict) -> Request: :return: The response with the submission result. """ work_info = WorkInfo( - name="WorkInfo-001", + name="extract", priority=0, data={}, state=WorkState.CREATED, From 98d5a18209c2f79d1b402ad593c2d5932a03aed3 Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Mon, 9 Sep 2024 06:42:31 -0500 Subject: [PATCH 03/10] wip: scheduler --- marie_server/scheduler/fixtures.py | 1 + marie_server/scheduler/plans.py | 80 ++++++------------------------ marie_server/scheduler/psql.py | 32 ++++++++++-- 3 files changed, 44 insertions(+), 69 deletions(-) diff --git a/marie_server/scheduler/fixtures.py b/marie_server/scheduler/fixtures.py index aaec2824..14cadb2b 100644 --- a/marie_server/scheduler/fixtures.py +++ b/marie_server/scheduler/fixtures.py @@ -54,6 +54,7 @@ def create_job_table(schema: str): output jsonb, dead_letter text, policy text + -- CONSTRAINT job_pkey PRIMARY KEY (name, id) -- adde via partition ) PARTITION BY LIST (name) """ diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index a019e880..7f331c72 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -16,66 +16,6 @@ def to_timestamp_with_tz(dt: datetime): return datetime.utcfromtimestamp(timestamp).isoformat() + "Z" -def insert_job_v1(schema: str, work_info: WorkInfo) -> str: - return f""" - INSERT INTO {schema}.job ( - id, - name, - priority, - state, - retry_limit, - start_after, - expire_in, - data, - retry_delay, - retry_backoff, - keep_until - ) - SELECT - id, - name, - priority, - state, - retry_limit, - start_after, - expire_in, - data, - retry_delay, - retry_backoff, - keep_until - FROM - ( SELECT *, - CASE - WHEN right(keepUntilValue, 1) = 'Z' THEN CAST(keepUntilValue as timestamp with time zone) - ELSE (start_after + CAST(COALESCE(keepUntilValue,'0') as interval)) - END as keep_until - FROM - ( SELECT *, - CASE - WHEN right(startAfterValue, 1) = 'Z' THEN CAST(startAfterValue as timestamp with time zone) - ELSE now() + CAST(COALESCE(startAfterValue,'0') as interval) - END as start_after - FROM - ( SELECT - '{work_info.id}'::uuid as id, - '{work_info.name}'::text as name, - {work_info.priority}::int as priority, - '{WorkState.CREATED.value}'::{schema}.job_state as state, - {work_info.retry_limit}::int as retry_limit, - '{to_timestamp_with_tz(work_info.start_after)}'::text as startAfterValue, - CAST('{work_info.expire_in_seconds}' as interval) as expire_in, - '{work_info.data}'::jsonb as data, - {work_info.retry_delay}::int as retry_delay, - {work_info.retry_backoff}::bool as retry_backoff, - '{to_timestamp_with_tz(work_info.keep_until)}'::text as keepUntilValue - ) j1 - ) j2 - ) j3 - ON CONFLICT DO NOTHING - RETURNING id - """ - - def insert_job(schema: str, work_info: WorkInfo) -> str: return f""" INSERT INTO {schema}.job ( @@ -108,7 +48,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: CASE WHEN right(keep_until, 1) = 'Z' THEN CAST(keep_until as timestamp with time zone) --ELSE start_after + CAST(COALESCE(keep_until, (q.retention_minutes * 60)::text, keep_until_default, '14 days') as interval) - ELSE start_after + COALESCE(keep_until::interval, (q.retention_minutes * 60) * interval '1 second', keep_until_default, interval '14 days') + -- ELSE start_after + COALESCE(keep_until::interval, (q.retention_minutes * 60) * interval '1 second', keep_until_default, interval '14 days') END as keep_until, COALESCE(j.retry_limit, q.retry_limit, retry_limit_default, 2) as retry_limit, @@ -127,7 +67,12 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: {work_info.priority}::int as priority, '{WorkState.CREATED.value}'::{schema}.job_state as state, {work_info.retry_limit}::int as retry_limit, - '{to_timestamp_with_tz(work_info.start_after)}'::text as start_after, + --'{to_timestamp_with_tz(work_info.start_after)}'::text as start_after, + CASE + WHEN right('{to_timestamp_with_tz(work_info.start_after)}', 1) = 'Z' THEN CAST('{to_timestamp_with_tz(work_info.start_after)}' as timestamp with time zone) + ELSE now() + CAST(COALESCE('{to_timestamp_with_tz(work_info.start_after)}','0') as interval) + END as start_after, + CAST('{work_info.expire_in_seconds}' as interval) as expire_in, '{work_info.data}'::jsonb as data, {work_info.retry_delay}::int as retry_delay, @@ -136,7 +81,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: 2::int as retry_limit_default, 2::int as retry_delay_default, - 0::int as retry_backoff_default, + False::boolean as retry_backoff_default, interval '60s'::interval as expire_in_default, now() + interval '14 days'::interval as keep_until_default ) j JOIN {schema}.queue q ON j.name = q.name @@ -146,13 +91,20 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: def create_queue(schema: str, queue_name: str, options: Dict[str, str]) -> str: - return f"SELECT {schema}.create_queue('{queue_name}', '{to_json(options)}')" + # return f"SELECT {schema}.create_queue('{queue_name}', {to_json(options)})" + return f""" + SELECT {schema}.create_queue('{queue_name}', '{{"retry_limit":2}}'::json) + """ def delete_queue(schema: str, queue_name: str) -> str: return f"SELECT {schema}.delete_queue({queue_name})" +def version_table_exists(schema: str) -> str: + return f"SELECT to_regclass('{schema}.version') as name" + + def fetch_next_job(schema: str): def query( name: str, diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index c5da5df0..0b3e9f37 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -20,6 +20,7 @@ fetch_next_job, insert_job, to_timestamp_with_tz, + version_table_exists, ) from marie_server.scheduler.state import WorkState @@ -173,13 +174,25 @@ async def wipe(self) -> None: self.logger.error(f"Error clearing tables: {error}") self.connection.rollback() - async def create_queue(self) -> None: + async def is_installed(self) -> bool: + """check if the tables are installed""" + schema = DEFAULT_SCHEMA + with self: + try: + cursor = self._execute_sql_gracefully(version_table_exists(schema)) + return cursor is not None and cursor.rowcount > 0 + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error clearing tables: {error}") + self.connection.rollback() + return False + + async def create_queue(self, queue_name: str) -> None: """Setup the queue for the scheduler.""" with self: try: self._execute_sql_gracefully( - create_queue(DEFAULT_SCHEMA, "extract", {}) + create_queue(DEFAULT_SCHEMA, queue_name, {}) ) except (Exception, psycopg2.Error) as error: self.logger.error(f"Error setting up queue: {error}") @@ -191,10 +204,19 @@ async def start(self) -> None: :return: None """ + logger.info("Starting job scheduling agent") - self.create_tables(DEFAULT_SCHEMA) - await self.create_queue() - self.running = True + installed = await self.is_installed() + logger.info(f"Tables installed: {installed}") + + if not installed: + self.create_tables(DEFAULT_SCHEMA) + + # TODO : This is a placeholder + queue = "extract" + await self.create_queue(queue) + await self.create_queue(f"${queue}_dlq") + self.task = asyncio.create_task(self._poll()) async def _poll(self): From c3b62790d189eb61ab8583095075e5a4fb794ded Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Wed, 11 Sep 2024 17:34:10 -0500 Subject: [PATCH 04/10] wip: add monitoring to get job status --- marie/storage/database/postgres.py | 4 +- marie_server/scheduler/plans.py | 13 +++++++ marie_server/scheduler/psql.py | 62 +++++++++++++++++++++++++----- 3 files changed, 67 insertions(+), 12 deletions(-) diff --git a/marie/storage/database/postgres.py b/marie/storage/database/postgres.py index 96284024..d5a01b61 100644 --- a/marie/storage/database/postgres.py +++ b/marie/storage/database/postgres.py @@ -118,8 +118,8 @@ def _table_exists(self) -> bool: def _execute_sql_gracefully( self, - statement, - data=tuple(), + statement: object, + data: object = tuple(), *, named_cursor_name: Optional[str] = None, itersize: Optional[int] = 10000, diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index 7f331c72..be84306d 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -105,6 +105,19 @@ def version_table_exists(schema: str) -> str: return f"SELECT to_regclass('{schema}.version') as name" +def count_states(schema: str): + return f""" + SELECT name, state, count(*) size + FROM {schema}.job + GROUP BY ROLLUP(name), ROLLUP(state) + """ + + +# Example usage: +# schema = 'public' +# print(count_states(schema)) + + def fetch_next_job(schema: str): def query( name: str, diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index 0b3e9f37..fd868c9c 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -16,6 +16,7 @@ from marie_server.scheduler.job_scheduler import JobScheduler from marie_server.scheduler.models import WorkInfo from marie_server.scheduler.plans import ( + count_states, create_queue, fetch_next_job, insert_job, @@ -27,6 +28,8 @@ INIT_POLL_PERIOD = 1.250 # 250ms MAX_POLL_PERIOD = 16.0 # 16s +MONITORING_POLL_PERIOD = 5.0 # 5s + DEFAULT_SCHEMA = "marie_scheduler" DEFAULT_JOB_TABLE = "job" COMPLETION_JOB_PREFIX = f"__state__{WorkState.COMPLETED.value}__" @@ -62,15 +65,18 @@ def __init__(self, config: Dict[str, Any], job_manager: JobManager): self.running = False self.task = None - self.job_manager = job_manager - self._loop = get_or_reuse_loop() - self._setup_storage(config, connection_only=True) - self._setup_event_subscriptions() + self.monitoring_task = None + lock_free = True self._lock = ( asyncio.Lock() if lock_free else asyncio.Lock() ) # Lock to prevent concurrent access to the database + self.job_manager = job_manager + self._loop = get_or_reuse_loop() + self._setup_storage(config, connection_only=True) + self._setup_event_subscriptions() + async def handle_job_event(self, event_type: str, message: Any): """ Handles a job event. @@ -204,11 +210,9 @@ async def start(self) -> None: :return: None """ - logger.info("Starting job scheduling agent") installed = await self.is_installed() logger.info(f"Tables installed: {installed}") - if not installed: self.create_tables(DEFAULT_SCHEMA) @@ -217,13 +221,14 @@ async def start(self) -> None: await self.create_queue(queue) await self.create_queue(f"${queue}_dlq") + self.running = True self.task = asyncio.create_task(self._poll()) + self.monitoring_task = asyncio.create_task(self._monitor()) async def _poll(self): self.logger.info("Starting database scheduler") wait_time = INIT_POLL_PERIOD sleep_chunk = 0.250 - self.running = True while self.running: self.logger.info(f"Polling for new jobs : {wait_time}") @@ -267,8 +272,11 @@ async def _poll(self): async def stop(self) -> None: self.logger.info("Stopping job scheduling agent") self.running = False + if self.task is not None: await self.task + if self.monitoring_task is not None: + await self.monitoring_task def debug_info(self) -> str: print("Debugging info") @@ -494,9 +502,6 @@ async def put_status( self._execute_sql_gracefully(update_query) except (Exception, psycopg2.Error) as error: self.logger.error(f"Error handling job event: {error}") - self.connection.rollback() - finally: - self.connection.commit() async def maintenance(self): """ @@ -533,6 +538,29 @@ def _setup_event_subscriptions(self): self.handle_job_event, ) + async def count_states(self): + state_count_default = {key.lower(): 0 for key in WorkState.__members__.keys()} + + counts = [] + with self: + try: + cursor = self._execute_sql_gracefully(count_states(DEFAULT_SCHEMA)) + counts = cursor.fetchall() + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error handling job event: {error}") + + states = {"queues": {}} + for item in counts: + name, state, size = item + if name: + if name not in states["queues"]: + states["queues"][name] = state_count_default.copy() + queue = states["queues"].get(name, states) + state = state or "all" + queue[state] = int(size) + + return states + def record_to_work_info(self, record): """ Convert a record to a WorkInfo object. @@ -552,3 +580,17 @@ def record_to_work_info(self, record): retry_backoff=record[9], keep_until=record[10], ) + + async def _monitor(self): + wait_time = MONITORING_POLL_PERIOD + while self.running: + self.logger.debug(f"Polling jobs status : {wait_time}") + await asyncio.sleep(wait_time) + try: + states = await self.count_states() + logger.info(f"job state: {states}") + # TODO: emit event + except Exception as e: + logger.error(f"Error monitoring jobs: {e}") + traceback.print_exc() + # TODO: emit error event From a7580b567ba1548800a152f7494259dc3490af7d Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Thu, 12 Sep 2024 04:10:53 -0500 Subject: [PATCH 05/10] wip: work on scheduler --- marie/job/gateway_job_distributor.py | 1 - marie/job/job_storage_client_proxy.py | 1 + marie/job/job_supervisor.py | 25 ++- .../serve/runtimes/worker/request_handling.py | 9 +- marie/storage/database/postgres.py | 2 +- marie_server/scheduler/fixtures.py | 3 +- marie_server/scheduler/plans.py | 148 +++++++++++++++++- marie_server/scheduler/psql.py | 124 +++++++++++++-- 8 files changed, 272 insertions(+), 41 deletions(-) diff --git a/marie/job/gateway_job_distributor.py b/marie/job/gateway_job_distributor.py index 1b15ab53..aeb3bbde 100644 --- a/marie/job/gateway_job_distributor.py +++ b/marie/job/gateway_job_distributor.py @@ -40,7 +40,6 @@ async def submit_job( self.logger.warning(f"Gateway streamer is not initialized") raise RuntimeError("Gateway streamer is not initialized") - print("job_info.metadata", job_info.metadata) parameters = {"job_id": submission_id} # "#job_info.job_id, if job_info.metadata: parameters.update(job_info.metadata) diff --git a/marie/job/job_storage_client_proxy.py b/marie/job/job_storage_client_proxy.py index 3848acc5..535eb6d1 100644 --- a/marie/job/job_storage_client_proxy.py +++ b/marie/job/job_storage_client_proxy.py @@ -34,6 +34,7 @@ async def put_status( message: Optional[str] = None, jobinfo_replace_kwargs: Optional[Dict[str, Any]] = None, ): + print("put_status called : ", job_id, status) await super().put_status(job_id, status, message, jobinfo_replace_kwargs) await self._event_publisher.publish( status, diff --git a/marie/job/job_supervisor.py b/marie/job/job_supervisor.py index 74375330..0a19dee1 100644 --- a/marie/job/job_supervisor.py +++ b/marie/job/job_supervisor.py @@ -125,19 +125,18 @@ async def run( # Block in PENDING state until start signal received. await _start_signal_actor.wait.remote() - # moved to request_handling - # # this is our gateway address - # driver_agent_http_address = "grpc://127.0.0.1" - # driver_node_id = "CURRENT_NODE_ID" - # - # await self._job_info_client.put_status( - # self._job_id, - # JobStatus.RUNNING, - # jobinfo_replace_kwargs={ - # "driver_agent_http_address": driver_agent_http_address, - # "driver_node_id": driver_node_id, - # }, - # ) + # TODO : This should be moved to the request_handling#_record_started_job + driver_agent_http_address = "grpc://127.0.0.1" + driver_node_id = "CURRENT_NODE_ID" + + await self._job_info_client.put_status( + self._job_id, + JobStatus.RUNNING, + jobinfo_replace_kwargs={ + "driver_agent_http_address": driver_agent_http_address, + "driver_node_id": driver_node_id, + }, + ) # Run the job submission in the background if self.DEFAULT_JOB_TIMEOUT_S > 0: diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py index 85e0f68f..6021efd7 100644 --- a/marie/serve/runtimes/worker/request_handling.py +++ b/marie/serve/runtimes/worker/request_handling.py @@ -1359,8 +1359,9 @@ def _init_job_info_client(self): self._job_info_client = JobInfoStorageClient(storage) async def _record_failed_job(self, job_id: str, e: Exception): + return if job_id is not None and self._job_info_client is not None: - print(f"Monitoring JOB: {job_id} - {e}") + self.logger.info(f"Monitoring JOB: {job_id} - {e}") try: await self._job_info_client.put_status( job_id, @@ -1371,8 +1372,9 @@ async def _record_failed_job(self, job_id: str, e: Exception): self.logger.error(f"Error in recording job status: {e}") async def _record_started_job(self, job_id: str, exec_endpoint, requests, params): + return if job_id is not None and self._job_info_client is not None: - print(f"Monitoring JOB: {exec_endpoint} - {job_id}") + self.logger.info(f"Monitoring JOB: {exec_endpoint} - {job_id}") # this is our gateway address driver_agent_http_address = "grpc://127.0.0.1" driver_node_id = "CURRENT_NODE_ID" @@ -1389,8 +1391,9 @@ async def _record_started_job(self, job_id: str, exec_endpoint, requests, params self.logger.error(f"Error in recording job status: {e}") async def _record_successful_job(self, job_id): + return if job_id is not None and self._job_info_client is not None: - print(f"Monitoring JOB: {job_id}") + self.logger.info(f"Monitoring JOB: {job_id}") try: await self._job_info_client.put_status(job_id, JobStatus.SUCCEEDED) except Exception as e: diff --git a/marie/storage/database/postgres.py b/marie/storage/database/postgres.py index d5a01b61..7301f4eb 100644 --- a/marie/storage/database/postgres.py +++ b/marie/storage/database/postgres.py @@ -134,7 +134,7 @@ def _execute_sql_gracefully( cursor.execute(statement, data) else: cursor.execute(statement) - except psycopg2.Error as error: + except (Exception, psycopg2.Error) as error: # except psycopg2.errors.UniqueViolation as error: print(statement) self.logger.debug(f"Error while executing {statement}: {error}.") diff --git a/marie_server/scheduler/fixtures.py b/marie_server/scheduler/fixtures.py index 14cadb2b..6c45f541 100644 --- a/marie_server/scheduler/fixtures.py +++ b/marie_server/scheduler/fixtures.py @@ -10,7 +10,8 @@ def create_version_table(schema: str): CREATE TABLE {schema}.version ( version int primary key, maintained_on timestamp with time zone, - cron_on timestamp with time zone + cron_on timestamp with time zone, + monitored_on timestamp with time zone ) """ diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index be84306d..e7efdc55 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -1,7 +1,8 @@ from datetime import datetime, timezone from typing import Dict -from marie.utils.json import to_json +from psycopg2.extras import Json + from marie_server.scheduler.models import WorkInfo from marie_server.scheduler.state import WorkState @@ -16,6 +17,28 @@ def to_timestamp_with_tz(dt: datetime): return datetime.utcfromtimestamp(timestamp).isoformat() + "Z" +def try_set_maintenance_time(schema: str, maintenance_state_interval_seconds: int): + return try_set_timestamp( + schema, "maintained_on", maintenance_state_interval_seconds + ) + + +def try_set_monitor_time(schema: str, monitor_state_interval_seconds: int): + return try_set_timestamp(schema, "monitored_on", monitor_state_interval_seconds) + + +def try_set_cron_time(schema: str, cron_state_interval_seconds: int): + return try_set_timestamp(schema, "cron_on", cron_state_interval_seconds) + + +def try_set_timestamp(schema: str, column: str, interval: int) -> str: + return f""" + UPDATE {schema}.version SET {column} = now() + WHERE EXTRACT(EPOCH FROM (now() - COALESCE({column}, now() - interval '1 week'))) > {interval} + RETURNING true + """ + + def insert_job(schema: str, work_info: WorkInfo) -> str: return f""" INSERT INTO {schema}.job ( @@ -74,7 +97,7 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: END as start_after, CAST('{work_info.expire_in_seconds}' as interval) as expire_in, - '{work_info.data}'::jsonb as data, + {Json(work_info.data)}::jsonb as data, {work_info.retry_delay}::int as retry_delay, {work_info.retry_backoff}::bool as retry_backoff, '{to_timestamp_with_tz(work_info.keep_until)}'::text as keep_until, @@ -105,7 +128,21 @@ def version_table_exists(schema: str) -> str: return f"SELECT to_regclass('{schema}.version') as name" +def insert_version(schema: str, version: str) -> str: + query = f"INSERT INTO {schema}.version(version) VALUES ('{version}')" + return query + + def count_states(schema: str): + """ + Count the number of jobs in each state. + + Example usage: + schema = 'public' + print(count_states(schema)) + :param schema: + :return: + """ return f""" SELECT name, state, count(*) size FROM {schema}.job @@ -113,9 +150,38 @@ def count_states(schema: str): """ -# Example usage: -# schema = 'public' -# print(count_states(schema)) +def cancel_jobs(schema, name: str, ids: list): + ids_string = "ARRAY[" + ",".join(f"'{str(_id)}'" for _id in ids) + "]" + + return f""" + WITH results AS ( + UPDATE {schema}.job + SET completed_on = now(), + state = '{WorkState.CANCELLED.value}' + WHERE name = {name} + AND id IN (SELECT UNNEST({ids_string}::uuid[])) + AND state < '{WorkState.COMPLETED.value}' + RETURNING 1 + ) + SELECT COUNT(*) FROM results + """ + + +def resume_jobs(schema, name: str, ids: list): + ids_string = "ARRAY[" + ",".join(f"'{str(_id)}'" for _id in ids) + "]" + + return f""" + WITH results AS ( + UPDATE {schema}.job + SET completed_on = NULL, + state = '{WorkState.CREATED.value}' + WHERE name = {name} + AND id IN (SELECT UNNEST({ids_string}::uuid[])) + AND state = '{WorkState.CANCELLED.value}' + RETURNING 1 + ) + SELECT COUNT(*) FROM results + """ def fetch_next_job(schema: str): @@ -147,3 +213,75 @@ def query( """ return query + + +def complete_jobs(schema: str, name: str, ids: list, output: dict): + ids_string = "ARRAY[" + ",".join(f"'{str(_id)}'" for _id in ids) + "]" + query = f""" + WITH results AS ( + UPDATE {schema}.job + SET completed_on = now(), + state = '{WorkState.COMPLETED.value}', + output = {Json(output)}::jsonb + WHERE name = '{name}' + AND id IN (SELECT UNNEST({ids_string}::uuid[])) + AND state = '{WorkState.ACTIVE.value}' + RETURNING * + ) + SELECT COUNT(*) FROM results + """ + return query + + +def fail_jobs_by_id(schema: str, name: str, ids: list, output: dict): + ids_string = "ARRAY[" + ",".join(f"'{str(_id)}'" for _id in ids) + "]" + where = f"name = '{name}' AND id IN (SELECT UNNEST({ids_string}::uuid[])) AND state < '{WorkState.COMPLETED.value}'" + return fail_jobs(schema, where, output) + + +def fail_jobs_by_timeout(schema: str): + where = f"state = '{WorkState.ACTIVE.value}' AND (started_on + expire_in) < now()" + return fail_jobs( + schema, where, {"value": {"message": "job failed by timeout in active state"}} + ) + + +def fail_jobs(schema: str, where: str, output: dict): + query = f""" + WITH results AS ( + UPDATE {schema}.job SET + state = CASE + WHEN retry_count < retry_limit THEN '{WorkState.RETRY.value}'::{schema}.job_state + ELSE '{WorkState.FAILED.value}'::{schema}.job_state + END, + completed_on = CASE + WHEN retry_count < retry_limit THEN NULL + ELSE now() + END, + start_after = CASE + WHEN retry_count = retry_limit THEN start_after + WHEN NOT retry_backoff THEN now() + retry_delay * interval '1' + ELSE now() + ( + retry_delay * 2 ^ LEAST(16, retry_count + 1) / 2 + + retry_delay * 2 ^ LEAST(16, retry_count + 1) / 2 * random() + ) * interval '1' + END, + output = {output} + WHERE {where} + RETURNING * + ), dlq_jobs AS ( + INSERT INTO {schema}.job (name, data, output, retry_limit, keep_until) + SELECT + dead_letter, + data, + output, + retry_limit, + keep_until + (keep_until - start_after) + FROM results + WHERE state = '{WorkState.FAILED.value}' + AND dead_letter IS NOT NULL + AND NOT name = dead_letter + ) + SELECT COUNT(*) FROM results + """ + return query diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index fd868c9c..971e6f3c 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -16,11 +16,17 @@ from marie_server.scheduler.job_scheduler import JobScheduler from marie_server.scheduler.models import WorkInfo from marie_server.scheduler.plans import ( + cancel_jobs, + complete_jobs, count_states, create_queue, + fail_jobs_by_id, fetch_next_job, insert_job, + insert_version, + resume_jobs, to_timestamp_with_tz, + try_set_monitor_time, version_table_exists, ) from marie_server.scheduler.state import WorkState @@ -74,8 +80,8 @@ def __init__(self, config: Dict[str, Any], job_manager: JobManager): self.job_manager = job_manager self._loop = get_or_reuse_loop() - self._setup_storage(config, connection_only=True) self._setup_event_subscriptions() + self._setup_storage(config, connection_only=True) async def handle_job_event(self, event_type: str, message: Any): """ @@ -89,30 +95,25 @@ async def handle_job_event(self, event_type: str, message: Any): self.logger.info(f"received message: {event_type} > {message}") job_id = message.get("job_id") status = JobStatus(event_type) - work_item = await self.get_job(job_id) + work_item: WorkInfo = await self.get_job(job_id) if work_item is None: self.logger.error(f"WorkItem not found: {job_id}") return - - completed_on = None - started_on = None - + work_state = convert_job_status_to_work_state(status) if status == JobStatus.PENDING: self.logger.info(f"Job pending : {job_id}") elif status == JobStatus.SUCCEEDED: self.logger.info(f"Job succeeded : {job_id}") - completed_on = datetime.now() + await self.complete(job_id, work_item) elif status == JobStatus.FAILED: self.logger.info(f"Job failed : {job_id}") + await self.fail(job_id, work_item) elif status == JobStatus.RUNNING: self.logger.info(f"Job running : {job_id}") - started_on = datetime.now() + await self.put_status(job_id, work_state, datetime.now(), None) else: self.logger.error(f"Unhandled status : {status}") - work_state = convert_job_status_to_work_state(status) - await self.put_status(job_id, work_state, started_on, completed_on) - if status.is_terminal(): self.logger.info(f"Job {job_id} is in terminal state {status}") self._reset_on_complete = True @@ -122,6 +123,7 @@ def create_tables(self, schema: str): :param schema: The name of the schema where the tables will be created. :return: None """ + version = 1 commands = [ create_schema(schema), create_version_table(schema), @@ -144,6 +146,7 @@ def create_tables(self, schema: str): create_index_job_fetch(schema), create_queue_function(schema), delete_queue_function(schema), + insert_version(schema, version), ] query = ";\n".join(commands) @@ -186,10 +189,12 @@ async def is_installed(self) -> bool: with self: try: cursor = self._execute_sql_gracefully(version_table_exists(schema)) - return cursor is not None and cursor.rowcount > 0 + if cursor and cursor.rowcount > 0: + result = cursor.fetchone() + if result and result[0] is not None: + return True except (Exception, psycopg2.Error) as error: self.logger.error(f"Error clearing tables: {error}") - self.connection.rollback() return False async def create_queue(self, queue_name: str) -> None: @@ -223,7 +228,7 @@ async def start(self) -> None: self.running = True self.task = asyncio.create_task(self._poll()) - self.monitoring_task = asyncio.create_task(self._monitor()) + # self.monitoring_task = asyncio.create_task(self._monitor()) async def _poll(self): self.logger.info("Starting database scheduler") @@ -464,6 +469,36 @@ async def delete_job(self, job_id: str): raise NotImplementedError + async def cancel_job(self, job_id: str) -> None: + """ + Cancel a job by its ID. + :param job_id: + """ + name = "extract" # TODO this is a placeholder + with self: + try: + self.logger.info(f"Cancelling job: {job_id}") + self._execute_sql_gracefully( + cancel_jobs(DEFAULT_SCHEMA, name, [job_id]) + ) + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error handling job event: {error}") + + async def resume_job(self, job_id: str) -> None: + """ + Resume a job by its ID. + :param job_id: + """ + name = "extract" # TODO this is a placeholder + with self: + try: + self.logger.info(f"Resuming job: {job_id}") + self._execute_sql_gracefully( + resume_jobs(DEFAULT_SCHEMA, name, [job_id]) + ) + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error handling job event: {error}") + async def put_status( self, job_id: str, @@ -540,8 +575,8 @@ def _setup_event_subscriptions(self): async def count_states(self): state_count_default = {key.lower(): 0 for key in WorkState.__members__.keys()} - counts = [] + with self: try: cursor = self._execute_sql_gracefully(count_states(DEFAULT_SCHEMA)) @@ -556,8 +591,7 @@ async def count_states(self): if name not in states["queues"]: states["queues"][name] = state_count_default.copy() queue = states["queues"].get(name, states) - state = state or "all" - queue[state] = int(size) + queue[state or "all"] = int(size) return states @@ -587,6 +621,22 @@ async def _monitor(self): self.logger.debug(f"Polling jobs status : {wait_time}") await asyncio.sleep(wait_time) try: + monitored_on = None + try: + cursor = self._execute_sql_gracefully( + try_set_monitor_time( + DEFAULT_SCHEMA, + monitor_state_interval_seconds=MONITORING_POLL_PERIOD, + ) + ) + monitored_on = cursor.fetchone() + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error handling job event: {error}") + + if monitored_on is None: + self.logger.error("Error setting monitor time") + continue + states = await self.count_states() logger.info(f"job state: {states}") # TODO: emit event @@ -594,3 +644,43 @@ async def _monitor(self): logger.error(f"Error monitoring jobs: {e}") traceback.print_exc() # TODO: emit error event + + async def complete(self, job_id: str, work_item: WorkInfo): + self.logger.info(f"Job completed : {job_id}, {work_item}") + with self: + try: + cursor = self._execute_sql_gracefully( + complete_jobs( + DEFAULT_SCHEMA, + work_item.name, + [job_id], + {"on_complete": "done"}, + ) + ) + counts = cursor.fetchone()[0] + if counts > 0: + self.logger.info(f"Completed job: {job_id} : {counts}") + else: + self.logger.error(f"Error completing job: {job_id}") + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error completing job: {error}") + + async def fail(self, job_id: str, work_item: WorkInfo): + self.logger.info(f"Job failed : {job_id}, {work_item}") + with self: + try: + cursor = self._execute_sql_gracefully( + fail_jobs_by_id( + DEFAULT_SCHEMA, + work_item.name, + [job_id], + {"on_complete": "failed"}, + ) + ) + counts = cursor.fetchone()[0] + if counts > 0: + self.logger.info(f"Completed failed job: {job_id}") + else: + self.logger.error(f"Error completing failed job: {job_id}") + except (Exception, psycopg2.Error) as error: + self.logger.error(f"Error completing job: {error}") From a287324817bf8294d4bba008eb029927f08d2df0 Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Wed, 25 Sep 2024 06:31:29 -0500 Subject: [PATCH 06/10] wip: scheduling --- .../meta_template_matching.py | 3 +- marie/job/common.py | 2 +- marie/job/job_manager.py | 4 +-- marie/parsers/server.py | 2 +- marie/pipe/components.py | 2 +- marie/storage/kv/psql.py | 13 +++++++- marie_server/rest_extension.py | 4 +-- marie_server/scheduler/plans.py | 2 +- marie_server/scheduler/psql.py | 33 +++++++++++++------ poc/custom_gateway/direct-flow.py | 10 +++--- 10 files changed, 49 insertions(+), 26 deletions(-) diff --git a/marie/components/template_matching/meta_template_matching.py b/marie/components/template_matching/meta_template_matching.py index 1edec5a5..2cf51885 100644 --- a/marie/components/template_matching/meta_template_matching.py +++ b/marie/components/template_matching/meta_template_matching.py @@ -213,7 +213,7 @@ def predict( if candidates: sorted_candidates = sorted( candidates, - key=lambda x: (x['ngram']), + key=lambda x: (x["ngram"]), reverse=False, ) for sc in sorted_candidates: @@ -308,5 +308,4 @@ def score( total_sim = (sim_val + cos_sim_val + embedding_sim) / 3 sout = f"similarity : {sim_val:<10} - {cos_sim_val:<10} > {embedding_sim:<10} ---- {total_sim:<10} --- {ngram_words}" self.logger.info(sout) - print(sout) return total_sim diff --git a/marie/job/common.py b/marie/job/common.py index a0a083fc..f9fa00b7 100644 --- a/marie/job/common.py +++ b/marie/job/common.py @@ -169,7 +169,7 @@ class JobInfoStorageClient: # Please keep this format in sync with JobDataKey() # in src/ray/gcs/gcs_server/gcs_job_manager.h. - JOB_DATA_KEY_PREFIX = f"{INTERNAL_NAMESPACE_PREFIX}job_info_" + JOB_DATA_KEY_PREFIX = f"{INTERNAL_NAMESPACE_PREFIX}/job_info_" JOB_DATA_KEY = f"{JOB_DATA_KEY_PREFIX}{{job_id}}" def __init__(self, storage: StorageArea): diff --git a/marie/job/job_manager.py b/marie/job/job_manager.py index 812707ed..343989e2 100644 --- a/marie/job/job_manager.py +++ b/marie/job/job_manager.py @@ -133,11 +133,11 @@ async def _monitor_job_internal( if job_status.is_terminal(): if job_status == JobStatus.SUCCEEDED: is_alive = False - self.logger.info(f"Job {job_id} succeeded.") + self.logger.info(f"Job succeeded : {job_id}") break elif job_status == JobStatus.FAILED: is_alive = False - self.logger.error(f"Job {job_id} failed.") + self.logger.error(f"Job failed : {job_id}") break if job_status == JobStatus.PENDING: diff --git a/marie/parsers/server.py b/marie/parsers/server.py index 3e0b6cd1..62ffef05 100644 --- a/marie/parsers/server.py +++ b/marie/parsers/server.py @@ -15,7 +15,7 @@ def set_server_parser(parser=None): sp = parser.add_subparsers( dest='ctl_cli', - required=True, + required=False, ) watch_parser = sp.add_parser( diff --git a/marie/pipe/components.py b/marie/pipe/components.py index 1bf57ceb..14360709 100644 --- a/marie/pipe/components.py +++ b/marie/pipe/components.py @@ -410,7 +410,7 @@ def setup_template_matching( if key not in pipeline_config: logger.warning(f"Missing {key} in pipeline config, using default config") - return NoopDocumentBoundaryRegistration() + return None, None config = pipeline_config[key] if key in pipeline_config else {} diff --git a/marie/storage/kv/psql.py b/marie/storage/kv/psql.py index fc85d0da..25856eca 100644 --- a/marie/storage/kv/psql.py +++ b/marie/storage/kv/psql.py @@ -111,7 +111,18 @@ async def internal_kv_del( namespace: Optional[bytes], timeout: Optional[float] = None, ) -> int: - raise NotImplementedError + self.logger.debug(f"internal_kv_del: {key!r}, {namespace!r}, {del_by_prefix}") + if namespace is None: + namespace = b"DEFAULT" + + if del_by_prefix: + raise NotImplementedError + else: + query = f"DELETE FROM {self.table} WHERE key = '{key.decode()}' AND namespace = '{namespace.decode()}'" + cursor = self._execute_sql_gracefully(query, data=()) + if cursor is not None: + return 1 + return 0 async def internal_kv_exists( self, key: bytes, namespace: Optional[bytes], timeout: Optional[float] = None diff --git a/marie_server/rest_extension.py b/marie_server/rest_extension.py index 8a6b712d..673dc384 100644 --- a/marie_server/rest_extension.py +++ b/marie_server/rest_extension.py @@ -12,13 +12,13 @@ from marie._core.utils import run_background_task from marie.api import extract_payload_to_uri, value_from_payload_or_args from marie.api.docs import AssetKeyDoc +from marie.job.job_manager import generate_job_id from marie.logging.mdc import MDC from marie.logging.predefined import default_logger as logger from marie.messaging import mark_as_complete, mark_as_failed, mark_as_started from marie.messaging.publisher import mark_as_scheduled from marie.types.request.data import DataRequest from marie.utils.types import strtobool -from marie_server.job.job_manager import generate_job_id if TYPE_CHECKING: # pragma: no cover from fastapi import FastAPI, Request @@ -308,7 +308,7 @@ async def process_document_request( parameters=parameters, request_size=-1, return_responses=True, - prefetch=4 + prefetch=4, # return_type=OutputDoc, ): payload = parse_response_to_payload(resp, expect_return_value=False) diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index e7efdc55..cb63fa09 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -266,7 +266,7 @@ def fail_jobs(schema: str, where: str, output: dict): retry_delay * 2 ^ LEAST(16, retry_count + 1) / 2 * random() ) * interval '1' END, - output = {output} + output = {Json(output)}::jsonb WHERE {where} RETURNING * ), dlq_jobs AS ( diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index 971e6f3c..a751f1d5 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import traceback from contextlib import AsyncExitStack from datetime import datetime @@ -75,7 +76,7 @@ def __init__(self, config: Dict[str, Any], job_manager: JobManager): lock_free = True self._lock = ( - asyncio.Lock() if lock_free else asyncio.Lock() + contextlib.AsyncExitStack() if lock_free else asyncio.Lock() ) # Lock to prevent concurrent access to the database self.job_manager = job_manager @@ -263,9 +264,14 @@ async def _poll(self): for record in records: has_records = True work_item = self.record_to_work_info(record) - has_available_slots, job_id = await self.enqueue(work_item) - self.logger.info(f"Work item scheduled with ID: {job_id}") - if not has_available_slots: + job_id = await self.enqueue(work_item) + if job_id is None: + self.logger.error( + f"Error scheduling work item: {work_item.id}" + ) + else: + self.logger.info(f"Work item scheduled with ID: {job_id}") + if not self.job_manager.has_available_slot(): self.logger.info( f"No more available slots for work, waiting for slots :{wait_time}" ) @@ -286,7 +292,7 @@ async def stop(self) -> None: def debug_info(self) -> str: print("Debugging info") - async def enqueue(self, work_info: WorkInfo) -> tuple[bool, str]: + async def enqueue(self, work_info: WorkInfo) -> str: """ Enqueues a work item for processing on the next available executor. @@ -297,13 +303,20 @@ async def enqueue(self, work_info: WorkInfo) -> tuple[bool, str]: self.logger.info( f"No available slots for work, scheduling : {work_info.id}" ) - return False, None + return None submission_id = work_info.id - returned_id = await self.job_manager.submit_job( - entrypoint="echo hello", submission_id=submission_id - ) - return True, returned_id + # FIXME : This is a hack to allow the job to be re-submitted after a failure + await self.job_manager.job_info_client().delete_info(submission_id) + + try: + returned_id = await self.job_manager.submit_job( + entrypoint="echo hello", submission_id=submission_id + ) + except ValueError as e: + self.logger.error(f"Error submitting job: {e}") + return None + return returned_id async def get_work_items( self, diff --git a/poc/custom_gateway/direct-flow.py b/poc/custom_gateway/direct-flow.py index 5600a2b1..82693b73 100644 --- a/poc/custom_gateway/direct-flow.py +++ b/poc/custom_gateway/direct-flow.py @@ -58,16 +58,16 @@ def func_extract( print(f"FirstExec func called : {len(docs)}, {parameters}") # randomly throw an error to test the error handling - if random.random() > 0.5: + if random.random() > 0: raise Exception("random error") for doc in docs: doc.text += " First Exec" - print("Sleeping for 5 seconds : ", time.time()) - time.sleep(5) - - print("Sleeping for 5 seconds - done: ", time.time()) + sec = 1 + print(f"Sleeping for {sec} seconds : ", time.time()) + time.sleep(1) + print(f"Sleeping for {sec} seconds - done : ", time.time()) return { "parameters": parameters, "data": "Data reply", From 4ab732f59abd2c1fa12973f18b1a6159906ffc2a Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Fri, 27 Sep 2024 00:36:53 -0500 Subject: [PATCH 07/10] feat: externalize eponential backoff for easier work --- marie_server/scheduler/fixtures.py | 21 ++++++++++++++++++--- marie_server/scheduler/plans.py | 5 +---- marie_server/scheduler/psql.py | 1 + 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/marie_server/scheduler/fixtures.py b/marie_server/scheduler/fixtures.py index 6c45f541..2e6fd79d 100644 --- a/marie_server/scheduler/fixtures.py +++ b/marie_server/scheduler/fixtures.py @@ -267,13 +267,15 @@ def add_id_index_to_archive(schema): def create_index_singleton_on(schema): return f""" - CREATE UNIQUE INDEX job_singleton_on ON {schema}.job (name, singleton_on) WHERE state < '{WorkState.EXPIRED.value}' AND singleton_key IS NULL + CREATE UNIQUE INDEX job_singleton_on ON {schema}.job (name, singleton_on) + WHERE state < '{WorkState.EXPIRED.value}' AND singleton_key IS NULL """ def create_index_singleton_key_on(schema): return f""" - CREATE UNIQUE INDEX job_singleton_key_on ON {schema}.job (name, singleton_on, singleton_key) WHERE state < '{WorkState.EXPIRED.value}' + CREATE UNIQUE INDEX job_singleton_key_on ON {schema}.job (name, singleton_on, singleton_key) + WHERE state < '{WorkState.EXPIRED.value}' """ @@ -285,5 +287,18 @@ def create_index_job_name(schema): def create_index_job_fetch(schema): return f""" - CREATE INDEX job_fetch ON {schema}.job (name text_pattern_ops, start_after) WHERE state < '{WorkState.ACTIVE.value}' + CREATE INDEX job_fetch ON {schema}.job (name text_pattern_ops, start_after) + WHERE state < '{WorkState.ACTIVE.value}' + """ + + +def create_exponential_backoff_function(schema): + return f""" + CREATE OR REPLACE FUNCTION exponential_backoff(retry_delay INT, retry_count INT) + RETURNS TIMESTAMP WITH TIME ZONE AS $$ + BEGIN + RETURN now() + (retry_delay * (2 ^ LEAST(16, retry_count + 1) / 2) + + retry_delay * (2 ^ LEAST(16, retry_count + 1) / 2) * random()) * INTERVAL '1 second'; + END; + $$ LANGUAGE plpgsql; """ diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index cb63fa09..240e0d85 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -261,10 +261,7 @@ def fail_jobs(schema: str, where: str, output: dict): start_after = CASE WHEN retry_count = retry_limit THEN start_after WHEN NOT retry_backoff THEN now() + retry_delay * interval '1' - ELSE now() + ( - retry_delay * 2 ^ LEAST(16, retry_count + 1) / 2 + - retry_delay * 2 ^ LEAST(16, retry_count + 1) / 2 * random() - ) * interval '1' + ELSE {schema}.exponential_backoff(retry_delay, retry_count) END, output = {Json(output)}::jsonb WHERE {where} diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index a751f1d5..2ab0872b 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -148,6 +148,7 @@ def create_tables(self, schema: str): create_queue_function(schema), delete_queue_function(schema), insert_version(schema, version), + create_exponential_backoff_function(schema), ] query = ";\n".join(commands) From f7d807c0cc12f914de803eb4fb563c900a341d30 Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Fri, 27 Sep 2024 07:08:26 -0500 Subject: [PATCH 08/10] wip: work scheduler --- marie/job/job_supervisor.py | 75 ++++++++--- marie/job/sync_manager.py | 40 +++++- marie/serve/executors/__init__.py | 29 +++-- .../serve/runtimes/worker/request_handling.py | 118 ++++++++++++++---- marie/storage/kv/psql.py | 74 ++++++++--- marie/utils/network.py | 31 +++-- marie_server/scheduler/plans.py | 13 +- marie_server/scheduler/psql.py | 45 ++++--- poc/custom_gateway/direct-flow.py | 8 +- poc/custom_gateway/server_gateway.py | 3 +- tests/core/test_job_manager.py | 36 +++--- 11 files changed, 345 insertions(+), 127 deletions(-) diff --git a/marie/job/job_supervisor.py b/marie/job/job_supervisor.py index 0a19dee1..d83e5376 100644 --- a/marie/job/job_supervisor.py +++ b/marie/job/job_supervisor.py @@ -129,16 +129,27 @@ async def run( driver_agent_http_address = "grpc://127.0.0.1" driver_node_id = "CURRENT_NODE_ID" - await self._job_info_client.put_status( - self._job_id, - JobStatus.RUNNING, - jobinfo_replace_kwargs={ - "driver_agent_http_address": driver_agent_http_address, - "driver_node_id": driver_node_id, - }, - ) + # check if we a calling floating executor if so then we need to update the job status to RUNNING + # as floating executors are not part of the main deployment and they don't update the job status. + # TODO : need to get this from the request_info + + floating_executor = False + + if floating_executor: + self.logger.info( + f"Job {self._job_id} is running on a floating executor. " + f"Updating the job status to RUNNING." + ) + await self._job_info_client.put_status( + self._job_id, + JobStatus.RUNNING, + jobinfo_replace_kwargs={ + "driver_agent_http_address": driver_agent_http_address, + "driver_node_id": driver_node_id, + }, + ) - # Run the job submission in the background + # invoke the job submission in the background if self.DEFAULT_JOB_TIMEOUT_S > 0: try: await asyncio.wait_for( @@ -186,13 +197,24 @@ async def _submit_job_in_background(self, curr_info): print("Response docs: ", response.data.docs) print("Response status: ", response.status) + # This monitoring strategy allows us to have Floating Executors that can be used to run jobs outside of the main + # deployment. This is useful for running jobs that are not part of the main deployment, but are still part of the + # same deployment workflow. + # Example would be calling a custom API that we don't control. + + # If the job is already in a terminal state, then we don't need to update it. This can happen if the + # job was cancelled while the job was being submitted. + # or while the job was marked from the EXECUTOR worker node as "STOPPED", "SUCCEEDED", "FAILED". + + job_status = await self._job_info_client.get_status(self._job_id) + print( + "Job status from _submit_job_in_background: ", + job_status, + job_status.is_terminal(), + ) + if response.status.code == jina_pb2.StatusProto.SUCCESS: - job_status = await self._job_info_client.get_status(self._job_id) - # "STOPPED", "SUCCEEDED", "FAILED" if job_status.is_terminal(): - # If the job is already in a terminal state, then we don't need to update it. This can happen if the - # job was cancelled while the job was being submitted. - # or while the job was marked from the executor side. self.logger.warning( f"Job {self._job_id} is already in terminal state {job_status}." ) @@ -213,11 +235,26 @@ async def _submit_job_in_background(self, curr_info): else: # FIXME : Need to store the exception in the job info e: jina_pb2.StatusProto.ExceptionProto = response.status.exception - name = str(e.name) - # stack = to_json(e.stacks) - await self._job_info_client.put_status( - self._job_id, JobStatus.FAILED, message=f"{name}" - ) + if job_status.is_terminal(): + self.logger.warning( + f"Job {self._job_id} is already in terminal state {job_status}." + ) + # triggers the event to update the WorkStatus + await self._event_publisher.publish( + job_status, + { + "job_id": self._job_id, + "status": job_status, + "message": f"Job {self._job_id} is already in terminal state {job_status}.", + "jobinfo_replace_kwargs": False, + }, + ) + else: + name = str(e.name) + # stack = to_json(e.stacks) + await self._job_info_client.put_status( + self._job_id, JobStatus.FAILED, message=f"{name}" + ) except Exception as e: await self._job_info_client.put_status( self._job_id, JobStatus.FAILED, message=str(e) diff --git a/marie/job/sync_manager.py b/marie/job/sync_manager.py index 33b6b2aa..ae1e50b7 100644 --- a/marie/job/sync_manager.py +++ b/marie/job/sync_manager.py @@ -1,5 +1,9 @@ from typing import Any, Dict +from marie.job.common import JobInfoStorageClient +from marie.logging.logger import MarieLogger +from marie.storage.database.postgres import PostgresqlMixin + class SyncManager: """ @@ -7,8 +11,42 @@ class SyncManager: We will also publish events to the event publisher to notify the user of the state of the Job Example : If we restart the scheduler, we need to synchronize the state of the executor jobs that have completed while the scheduler was down. + Downside to this is that floating executors will not be able to sync their state with the scheduler during the time the scheduler is down. """ - def __init__(self, config: Dict[str, Any]): + def __init__( + self, + config: Dict[str, Any], + job_info_client: JobInfoStorageClient, + psql_mixin: PostgresqlMixin, + ): self.config = config + self.logger = MarieLogger(self.__class__.__name__) + self.job_info_client = job_info_client + self.psql_mixin = psql_mixin + + print("SyncManager init called") + print(job_info_client) + print(psql_mixin) + + self.run_sync() + + async def start(self) -> None: + """ + Starts the job synchronization agent. + + :return: None + """ + + pass + + def run_sync(self): + """ + Run the synchronization process. + + :return: None + """ + print("Running sync") + + # Get all the jobs from the scheduler that are not in the TERMINAL state diff --git a/marie/serve/executors/__init__.py b/marie/serve/executors/__init__.py index 92b51fcb..1cc1205b 100644 --- a/marie/serve/executors/__init__.py +++ b/marie/serve/executors/__init__.py @@ -834,17 +834,26 @@ async def wrapper(*args, **kwargs): async def exec_func( summary, histogram, histogram_metric_labels, tracing_context ): - with MetricsTimer(summary, histogram, histogram_metric_labels): - if iscoroutinefunction(func): - return await func(self, tracing_context=tracing_context, **kwargs) - else: - async with self._lock: - return await get_or_reuse_loop().run_in_executor( - None, - functools.partial( - func, self, tracing_context=tracing_context, **kwargs - ), + try: + with MetricsTimer(summary, histogram, histogram_metric_labels): + if iscoroutinefunction(func): + return await func( + self, tracing_context=tracing_context, **kwargs ) + else: + async with self._lock: + return await get_or_reuse_loop().run_in_executor( + None, + functools.partial( + func, + self, + tracing_context=tracing_context, + **kwargs, + ), + ) + except Exception as e: + self.logger.error(f"Error while executing {req_endpoint} endpoint: {e}") + raise e runtime_name = ( self.runtime_args.name if hasattr(self.runtime_args, "name") else None diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py index 6021efd7..6bf05f22 100644 --- a/marie/serve/runtimes/worker/request_handling.py +++ b/marie/serve/runtimes/worker/request_handling.py @@ -3,8 +3,10 @@ import functools import json import os +import sys import tempfile import threading +import traceback import uuid import warnings from typing import ( @@ -32,6 +34,8 @@ from marie.serve.runtimes.worker.batch_queue import BatchQueue from marie.storage.kv.psql import PostgreSQLKV from marie.types.request.data import DataRequest, SingleDocumentRequest +from marie.utils.network import get_ip_address +from marie.utils.types import strtobool if docarray_v2: from docarray import DocList @@ -84,6 +88,7 @@ def __init__( self.args = args self.logger = logger self._is_closed = False + if self.metrics_registry: with ImportExtensions( required=True, @@ -682,12 +687,11 @@ async def handle( requests, params = self._setup_requests(requests, exec_endpoint) - print("requests", requests) - print("params", params) + self.logger.info(f"requests TO MONITOR : {requests}") job_id = None if params is not None: job_id = params.get("job_id", None) - await self._record_started_job(job_id, exec_endpoint, requests, params) + await self._record_started_job(job_id, requests, params) len_docs = len(requests[0].docs) # TODO we can optimize here and access the if exec_endpoint in self._batchqueue_config: @@ -730,9 +734,19 @@ async def handle( tracing_context=tracing_context, ) _ = self._set_result(requests, return_data, docs) - except Exception as e: - await self._record_failed_job(job_id, e) + await self._record_successful_job(job_id, requests) + except asyncio.CancelledError: + print("Task was cancelled due to client disconnect") + raise + except BaseException as e: + self.logger.error(f"Error during __acall__: {e}") + await self._record_failed_job(job_id, requests, e) raise e + finally: + pass + # we do this here to make sure that we record the successful job even if the response fails back to the gateway + # if not failed: + # await self._record_successful_job(job_id, requests) for req in requests: req.add_executor(self.deployment_name) @@ -743,8 +757,6 @@ async def handle( except AttributeError: pass self._record_response_size_monitoring(requests) - - await self._record_successful_job(job_id) return requests[0] @staticmethod @@ -1350,7 +1362,7 @@ def _init_job_info_client(self): "username": "postgres", "password": "123456", "database": "postgres", - "default_table": "kv_store_a", + "default_table": "kv_store_worker", "max_pool_size": 5, "max_connections": 5, } @@ -1358,43 +1370,97 @@ def _init_job_info_client(self): storage = PostgreSQLKV(config=kv_storage_config, reset=False) self._job_info_client = JobInfoStorageClient(storage) - async def _record_failed_job(self, job_id: str, e: Exception): - return + async def _record_failed_job( + self, + job_id: str, + requests: List["DataRequest"], + e: Exception, + ): + print(f"Record job failed: {job_id} - {e}") if job_id is not None and self._job_info_client is not None: - self.logger.info(f"Monitoring JOB: {job_id} - {e}") try: + # Extract the traceback information from the exception + tb = e.__traceback__ + while tb.tb_next: + tb = tb.tb_next + + filename = tb.tb_frame.f_code.co_filename + name = tb.tb_frame.f_code.co_name + line_no = tb.tb_lineno + # Clear the frames after extracting the information to avoid memory leaks + traceback.clear_frames(tb) + + detail = "Internal Server Error" + silence_exceptions = strtobool( + os.environ.get("MARIE_SILENCE_EXCEPTIONS", "false") + ) + + if not silence_exceptions: + detail = str(e) + + exc = { + "type": type(e).__name__, + "message": detail, + "filename": filename.split("/")[-1], + "name": name, + "line_no": line_no, + } + await self._job_info_client.put_status( job_id, JobStatus.FAILED, - jobinfo_replace_kwargs={"error_message": str(e)}, + jobinfo_replace_kwargs={ + "metadata": { + "attributes": self._request_attributes(requests), + "error": exc, + } + }, ) except Exception as e: self.logger.error(f"Error in recording job status: {e}") - async def _record_started_job(self, job_id: str, exec_endpoint, requests, params): - return + async def _record_started_job( + self, job_id: str, requests: List["DataRequest"], params + ): + print(f"Record job started: {job_id}") if job_id is not None and self._job_info_client is not None: - self.logger.info(f"Monitoring JOB: {exec_endpoint} - {job_id}") - # this is our gateway address - driver_agent_http_address = "grpc://127.0.0.1" - driver_node_id = "CURRENT_NODE_ID" try: await self._job_info_client.put_status( job_id, JobStatus.RUNNING, jobinfo_replace_kwargs={ - "driver_agent_http_address": driver_agent_http_address, - "driver_node_id": driver_node_id, + "metadata": { + "params": params, + "attributes": self._request_attributes(requests), + } }, ) except Exception as e: - self.logger.error(f"Error in recording job status: {e}") + self.logger.error(f"Error recording job status: {e}") - async def _record_successful_job(self, job_id): - return + async def _record_successful_job(self, job_id: str, requests: List["DataRequest"]): + print(f"Record job success: {job_id}") if job_id is not None and self._job_info_client is not None: - self.logger.info(f"Monitoring JOB: {job_id}") try: - await self._job_info_client.put_status(job_id, JobStatus.SUCCEEDED) + await self._job_info_client.put_status( + job_id, + JobStatus.SUCCEEDED, + jobinfo_replace_kwargs={ + "metadata": {"attributes": self._request_attributes(requests)} + }, + ) except Exception as e: - self.logger.error(f"Error in recording job status: {e}") + self.logger.error(f"Error recording job status: {e}") + + def _request_attributes(self, requests: List["DataRequest"]) -> Dict: + exec_endpoint: str = requests[0].header.exec_endpoint + if exec_endpoint not in self._executor.requests: + if __default_endpoint__ in self._executor.requests: + exec_endpoint = __default_endpoint__ + + return { + "executor_endpoint": exec_endpoint, + "executor": self._executor.__class__.__name__, + "runtime_name": self.args.name, + "host": get_ip_address(flush_cache=False), + } diff --git a/marie/storage/kv/psql.py b/marie/storage/kv/psql.py index 25856eca..029bd34c 100644 --- a/marie/storage/kv/psql.py +++ b/marie/storage/kv/psql.py @@ -26,22 +26,63 @@ def __init__(self, config: Dict[str, Any], reset=True): ) def create_table_callback(self, table_name: str): + """ + :param table_name: Name of the table to be created. + :return: None + """ self.logger.info(f"Creating table : {table_name}") self._execute_sql_gracefully( f""" - CREATE TABLE IF NOT EXISTS {self.table} ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - namespace VARCHAR(1024) NULL, - key VARCHAR(1024) NOT NULL, - value JSONB NULL, - shard int DEFAULT 0, - created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP, - updated_at timestamp with time zone DEFAULT NULL, - is_deleted BOOL DEFAULT FALSE - ); - CREATE UNIQUE INDEX idx_{self.table}_ns_key ON {self.table} (namespace, key); - """, + CREATE TABLE IF NOT EXISTS {self.table} ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + namespace VARCHAR(1024) NULL, + key VARCHAR(1024) NOT NULL, + value JSONB NULL, + shard int DEFAULT 0, + created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP, + updated_at timestamp with time zone DEFAULT NULL, + is_deleted BOOL DEFAULT FALSE + ); + CREATE UNIQUE INDEX idx_{self.table}_ns_key ON {self.table} (namespace, key); + + CREATE TABLE IF NOT EXISTS {self.table}_history ( + history_id SERIAL PRIMARY KEY, + id UUID, + namespace VARCHAR(1024), + key VARCHAR(1024), + value JSONB, + shard int, + created_at timestamp with time zone, + updated_at timestamp with time zone, + is_deleted BOOL, + change_time timestamp with time zone DEFAULT CURRENT_TIMESTAMP, + operation CHAR(1) CHECK (operation IN ('I', 'U', 'D')) + ); + + CREATE OR REPLACE FUNCTION log_changes_{self.table}() RETURNS TRIGGER AS $$ + BEGIN + IF (TG_OP = 'INSERT') THEN + INSERT INTO {self.table}_history (id, namespace, key, value, shard, created_at, updated_at, is_deleted, operation) + VALUES (NEW.id, NEW.namespace, NEW.key, NEW.value, NEW.shard, NEW.created_at, NEW.updated_at, NEW.is_deleted, 'I'); + RETURN NEW; + ELSIF (TG_OP = 'UPDATE') THEN + INSERT INTO {self.table}_history (id, namespace, key, value, shard, created_at, updated_at, is_deleted, operation) + VALUES (NEW.id, NEW.namespace, NEW.key, NEW.value, NEW.shard, NEW.created_at, NEW.updated_at, NEW.is_deleted, 'U'); + RETURN NEW; + ELSIF (TG_OP = 'DELETE') THEN + INSERT INTO {self.table}_history (id, namespace, key, value, shard, created_at, updated_at, is_deleted, operation) + VALUES (OLD.id, OLD.namespace, OLD.key, OLD.value, OLD.shard, OLD.created_at, OLD.updated_at, OLD.is_deleted, 'D'); + RETURN OLD; + END IF; + RETURN NULL; + END; + $$ LANGUAGE plpgsql; + + CREATE TRIGGER log_changes_{self.table}_trigger + AFTER INSERT OR UPDATE OR DELETE ON {self.table} + FOR EACH ROW EXECUTE FUNCTION log_changes_{self.table}(); + """ ) async def internal_kv_get( @@ -139,7 +180,6 @@ async def internal_kv_keys( try: query = f"SELECT key FROM {self.table} WHERE namespace = '{namespace.decode()}' AND is_deleted = FALSE" for record in self._execute_sql_gracefully(query, data=()): - print(result) result.append(record[0]) except (Exception, psycopg2.Error) as error: self.logger.error(f"Error executing sql statement: {error}") @@ -147,8 +187,12 @@ async def internal_kv_keys( def internal_kv_reset(self) -> None: self.logger.info(f"internal_kv_reset : {self.table}") - query = f"DROP TABLE IF EXISTS {self.table}" - self._execute_sql_gracefully(query) + + self._execute_sql_gracefully(f"DROP TABLE IF EXISTS {self.table}") + self._execute_sql_gracefully(f"DROP TABLE IF EXISTS {self.table}_history") + self._execute_sql_gracefully( + f"DROP FUNCTION IF EXISTS log_changes_{self.table} CASCADE" + ) def debug_info(self) -> str: return "PostgreSQLKV" diff --git a/marie/utils/network.py b/marie/utils/network.py index bf5fd2ea..c0abc4ce 100644 --- a/marie/utils/network.py +++ b/marie/utils/network.py @@ -22,19 +22,24 @@ def is_docker(): ) -def get_ip_address(): +_cached_ip_address = None + + +def get_ip_address(flush_cache=False): """ + Get the IP address of the current machine. Caches the result for future calls. + Set `flush_cache` to True to refresh the cached IP address. + https://stackoverflow.com/questions/24196932/how-can-i-get-the-ip-address-from-nic-in-python """ - # TODO : Add support for IP detection - # if there is an access to external network we can try this - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(("8.8.8.8", 80)) - sockname = s.getsockname() - return s.getsockname()[0] - except Exception as e: - # raise e # For debug - pass - - return "127.0.0.1" + global _cached_ip_address + + if flush_cache or _cached_ip_address is None: + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + _cached_ip_address = s.getsockname()[0] + except Exception as e: + _cached_ip_address = "127.0.0.1" + + return _cached_ip_address diff --git a/marie_server/scheduler/plans.py b/marie_server/scheduler/plans.py index 240e0d85..d29daa53 100644 --- a/marie_server/scheduler/plans.py +++ b/marie_server/scheduler/plans.py @@ -114,7 +114,6 @@ def insert_job(schema: str, work_info: WorkInfo) -> str: def create_queue(schema: str, queue_name: str, options: Dict[str, str]) -> str: - # return f"SELECT {schema}.create_queue('{queue_name}', {to_json(options)})" return f""" SELECT {schema}.create_queue('{queue_name}', '{{"retry_limit":2}}'::json) """ @@ -282,3 +281,15 @@ def fail_jobs(schema: str, where: str, output: dict): SELECT COUNT(*) FROM results """ return query + + +def get_active_jobs(schema: str) -> str: + """ + Get all items in the active state. + :param schema: The schema name. + """ + return f""" + SELECT * + FROM {schema}.job + WHERE state = '{WorkState.ACTIVE.value}' + """ diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index 2ab0872b..c4687649 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -1,9 +1,8 @@ import asyncio import contextlib import traceback -from contextlib import AsyncExitStack from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, Dict, List, Optional import psycopg2 @@ -36,6 +35,7 @@ MAX_POLL_PERIOD = 16.0 # 16s MONITORING_POLL_PERIOD = 5.0 # 5s +SYNC_POLL_PERIOD = 5.0 # 5s DEFAULT_SCHEMA = "marie_scheduler" DEFAULT_JOB_TABLE = "job" @@ -97,6 +97,7 @@ async def handle_job_event(self, event_type: str, message: Any): job_id = message.get("job_id") status = JobStatus(event_type) work_item: WorkInfo = await self.get_job(job_id) + if work_item is None: self.logger.error(f"WorkItem not found: {job_id}") return @@ -229,7 +230,8 @@ async def start(self) -> None: await self.create_queue(f"${queue}_dlq") self.running = True - self.task = asyncio.create_task(self._poll()) + self.sync_task = asyncio.create_task(self._sync()) + # self.task = asyncio.create_task(self._poll()) # self.monitoring_task = asyncio.create_task(self._monitor()) async def _poll(self): @@ -392,7 +394,9 @@ async def get_job(self, job_id: str) -> Optional[WorkInfo]: finally: self.connection.commit() - async def list_jobs(self, state: Optional[str] = None) -> Dict[str, WorkInfo]: + async def list_jobs( + self, state: Optional[str] = None, batch_size: int = 0 + ) -> Dict[str, WorkInfo]: work_items = {} schema = DEFAULT_SCHEMA table = DEFAULT_JOB_TABLE @@ -409,23 +413,11 @@ async def list_jobs(self, state: Optional[str] = None) -> Dict[str, WorkInfo]: cursor.itersize = 10000 cursor.execute( f""" - SELECT - id, - name, - priority, - state, - retry_limit, - start_after, - expire_in, - data, - retry_delay, - retry_backoff, - keep_until, - on_complete + SELECT id,name, priority,state,retry_limit,start_after,expire_in,data,retry_delay,retry_backoff,keep_until FROM {schema}.{table} WHERE state IN ('{states}') + {f"LIMIT {batch_size}" if batch_size > 0 else ""} """ - # + (f" limit = {limit}" if limit > 0 else "") ) for record in cursor: work_items[record[0]] = self.record_to_work_info(record) @@ -698,3 +690,20 @@ async def fail(self, job_id: str, work_item: WorkInfo): self.logger.error(f"Error completing failed job: {job_id}") except (Exception, psycopg2.Error) as error: self.logger.error(f"Error completing job: {error}") + + async def _sync(self): + wait_time = SYNC_POLL_PERIOD + while self.running: + self.logger.info(f"Syncing jobs status : {wait_time}") + await asyncio.sleep(wait_time) + + try: + active_jobs = await self.list_jobs(state=WorkState.ACTIVE.value) + if active_jobs: + self.logger.info(f"Active jobs: {active_jobs}") + for job_id, work_item in active_jobs.items(): + self.logger.info(f"Syncing job: {job_id}, {work_item}") + + except Exception as e: + logger.error(f"Error syncing jobs: {e}") + traceback.print_exc() diff --git a/poc/custom_gateway/direct-flow.py b/poc/custom_gateway/direct-flow.py index 82693b73..6755130f 100644 --- a/poc/custom_gateway/direct-flow.py +++ b/poc/custom_gateway/direct-flow.py @@ -58,15 +58,15 @@ def func_extract( print(f"FirstExec func called : {len(docs)}, {parameters}") # randomly throw an error to test the error handling - if random.random() > 0: - raise Exception("random error") + if random.random() > 0.999: + raise Exception("random error in FirstExec") for doc in docs: doc.text += " First Exec" - sec = 1 + sec = 10 print(f"Sleeping for {sec} seconds : ", time.time()) - time.sleep(1) + time.sleep(sec) print(f"Sleeping for {sec} seconds - done : ", time.time()) return { "parameters": parameters, diff --git a/poc/custom_gateway/server_gateway.py b/poc/custom_gateway/server_gateway.py index 4d59145c..8b8f6b2f 100644 --- a/poc/custom_gateway/server_gateway.py +++ b/poc/custom_gateway/server_gateway.py @@ -75,7 +75,7 @@ def __init__(self, **kwargs): "username": "postgres", "password": "123456", "database": "postgres", - "default_table": "kv_store_a", + "default_table": "kv_store_worker", "max_pool_size": 5, "max_connections": 5, } @@ -89,7 +89,6 @@ def __init__(self, **kwargs): "password": "123456", } - self.syncer = SyncManager(scheduler_config) self.distributor = GatewayJobDistributor( gateway_streamer=None, logger=self.logger ) diff --git a/tests/core/test_job_manager.py b/tests/core/test_job_manager.py index 0620d412..c8a82a1a 100644 --- a/tests/core/test_job_manager.py +++ b/tests/core/test_job_manager.py @@ -12,6 +12,9 @@ from marie import Deployment, Document, DocumentArray, Executor, requests from marie.enums import PollingType +from marie.job.common import JobInfo, JobStatus +from marie.job.job_distributor import JobDistributor +from marie.job.job_manager import JobManager from marie.parsers import set_deployment_parser from marie.proto import jina_pb2 from marie.serve.networking.balancer.load_balancer import LoadBalancerType @@ -19,12 +22,9 @@ from marie.serve.runtimes.gateway.streamer import GatewayStreamer from marie.serve.runtimes.servers import BaseServer from marie.serve.runtimes.worker.request_handling import WorkerRequestHandler +from marie.storage.kv.in_memory import InMemoryKV +from marie.storage.kv.psql import PostgreSQLKV from marie.types.request.data import DataRequest -from marie_server.job.common import JobInfo, JobStatus -from marie_server.job.job_distributor import JobDistributor -from marie_server.job.job_manager import JobManager -from marie_server.storage.in_memory import InMemoryKV -from marie_server.storage.psql import PostgreSQLKV from tests.core.test_utils import async_delay, async_wait_for_condition_async_predicate from tests.helper import _generate_pod_args @@ -243,7 +243,7 @@ def _setup(pod0_port, pod1_port): @pytest.mark.parametrize("results_in_order", [False, True]) @pytest.mark.asyncio async def test_gateway_job_manager( - port_generator, parameters, target_executor, expected_text, results_in_order + port_generator, parameters, target_executor, expected_text, results_in_order ): pod0_port = port_generator() pod1_port = port_generator() @@ -264,11 +264,11 @@ async def test_gateway_job_manager( resp = DocList([]) num_resp = 0 async for r in gateway_streamer.stream_docs( - docs=input_da, - request_size=10, - parameters=parameters, - target_executor=target_executor, - results_in_order=results_in_order, + docs=input_da, + request_size=10, + parameters=parameters, + target_executor=target_executor, + results_in_order=results_in_order, ): num_resp += 1 resp.extend(r) @@ -294,13 +294,13 @@ async def test_gateway_job_manager( def _create_regular_deployment( - port, - name="", - executor=None, - noblock_on_start=True, - polling=PollingType.ANY, - shards=None, - replicas=None, + port, + name="", + executor=None, + noblock_on_start=True, + polling=PollingType.ANY, + shards=None, + replicas=None, ): # return Deployment(uses=executor, include_gateway=False, noblock_on_start=noblock_on_start, replicas=replicas, # shards=shards) From cef84f9e6b60e165ec097b6c00622fc23b3092df Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Wed, 2 Oct 2024 03:43:21 -0500 Subject: [PATCH 09/10] wip: work scheduler --- marie/serve/executors/__init__.py | 66 +++++++++++++--- .../serve/runtimes/worker/request_handling.py | 76 +++++++++++++++---- marie/storage/kv/psql.py | 2 +- marie_server/scheduler/psql.py | 75 ++++++++++++++---- poc/custom_gateway/direct-flow.py | 7 +- poc/custom_gateway/server_gateway.py | 2 +- 6 files changed, 183 insertions(+), 45 deletions(-) diff --git a/marie/serve/executors/__init__.py b/marie/serve/executors/__init__.py index 1cc1205b..f3116483 100644 --- a/marie/serve/executors/__init__.py +++ b/marie/serve/executors/__init__.py @@ -747,7 +747,6 @@ async def __acall__(self, req_endpoint: str, **kwargs): # noqa: DAR102 # noqa: DAR201 """ - if req_endpoint in self.requests: return await self.__acall_endpoint__(req_endpoint, **kwargs) elif __default_endpoint__ in self.requests: @@ -831,26 +830,71 @@ async def wrapper(*args, **kwargs): if is_parameters_pydantic_model: func = parameters_as_pydantic_models_decorator(func, parameters_model) + completion_callback = kwargs.pop("completion_callback") + async def exec_func( summary, histogram, histogram_metric_labels, tracing_context ): try: + # wrap the func to allow for capturing a return value and calling our completion + # callback to indicate that the job has completed + def completion_function_wrapper(fn): + if iscoroutinefunction(fn): + + @functools.wraps(fn) + async def arg_wrapper(*args, **kwargs): + ex: Exception = None + retval: Any = None + try: + retval = await fn(*args, **kwargs) + return retval + except Exception as exc: + ex = exc + finally: + await completion_callback(retval, ex) + + return arg_wrapper + else: + + @functools.wraps(fn) + def arg_wrapper(*args, **kwargs): + ex: Exception = None + retval: Any = None + try: + retval = fn(*args, **kwargs) + return retval + except Exception as exc: + ex = exc + finally: + loop = get_or_reuse_loop() + task = loop.create_task(completion_callback(retval, ex)) + loop.run_until_complete(task) + + return arg_wrapper + with MetricsTimer(summary, histogram, histogram_metric_labels): if iscoroutinefunction(func): - return await func( + wrapped_func = completion_function_wrapper(func) + return await wrapped_func( self, tracing_context=tracing_context, **kwargs ) else: async with self._lock: - return await get_or_reuse_loop().run_in_executor( - None, - functools.partial( - func, - self, - tracing_context=tracing_context, - **kwargs, - ), - ) + try: + return await get_or_reuse_loop().run_in_executor( + None, + functools.partial( + completion_function_wrapper(func), + self, + tracing_context=tracing_context, + **kwargs, + ), + ) + except asyncio.CancelledError as e: + self.logger.error( + f"Task was cancelled due to client request, for {req_endpoint} endpoint: {e}" + ) + raise e except Exception as e: self.logger.error(f"Error while executing {req_endpoint} endpoint: {e}") raise e diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py index 6bf05f22..9f636533 100644 --- a/marie/serve/runtimes/worker/request_handling.py +++ b/marie/serve/runtimes/worker/request_handling.py @@ -11,6 +11,7 @@ import warnings from typing import ( TYPE_CHECKING, + Any, AsyncIterator, Dict, Generator, @@ -21,6 +22,8 @@ ) from google.protobuf.struct_pb2 import Struct +from IPython.terminal.shortcuts.auto_suggest import discard +from tvm.relay.backend.interpreter import Executor from marie._docarray import DocumentArray, docarray_v2 from marie.constants import __default_endpoint__ @@ -724,7 +727,42 @@ async def handle( docs_matrix, docs_map = WorkerRequestHandler._get_docs_matrix_from_request( requests ) + + client_disconnected = False + + async def executor_completion_callback( + job_id: str, + requests: List["DataRequest"], + return_data: Any, + raised_exception: Exception, + ): + self.logger.info(f"executor_completion_callback : {job_id}") + # TODO : add support for handling client disconnect rejects + additional_metadata = {"client_disconnected": client_disconnected} + + if raised_exception: + val = "".join( + traceback.format_exception( + raised_exception, limit=None, chain=True + ) + ) + self.logger.error( + f"{raised_exception!r} during executor handling" + + f'\n add "--quiet-error" to suppress the exception details' + + f"\n {val}" + ) + + await self._record_failed_job( + job_id, requests, raised_exception, additional_metadata + ) + else: + await self._record_successful_job( + job_id, requests, additional_metadata + ) + try: + # we adding a callback to track when the executor have finished as the client disconnect will trigger + # `asyncio.CancelledError` however the Task is still running in the background with success or exception return_data = await self._executor.__acall__( req_endpoint=exec_endpoint, docs=docs, @@ -732,21 +770,18 @@ async def handle( docs_matrix=docs_matrix, docs_map=docs_map, tracing_context=tracing_context, + completion_callback=functools.partial( + executor_completion_callback, job_id, requests + ), ) _ = self._set_result(requests, return_data, docs) - await self._record_successful_job(job_id, requests) except asyncio.CancelledError: - print("Task was cancelled due to client disconnect") + self.logger.warning("Task was cancelled due to client disconnect") + client_disconnected = True raise - except BaseException as e: - self.logger.error(f"Error during __acall__: {e}") - await self._record_failed_job(job_id, requests, e) + except Exception as e: + self.logger.error(f"Error during __acall__ {client_disconnected}: {e}") raise e - finally: - pass - # we do this here to make sure that we record the successful job even if the response fails back to the gateway - # if not failed: - # await self._record_successful_job(job_id, requests) for req in requests: req.add_executor(self.deployment_name) @@ -1375,6 +1410,7 @@ async def _record_failed_job( job_id: str, requests: List["DataRequest"], e: Exception, + metadata_attributes: Optional[Dict], ): print(f"Record job failed: {job_id} - {e}") if job_id is not None and self._job_info_client is not None: @@ -1406,12 +1442,16 @@ async def _record_failed_job( "line_no": line_no, } + request_attributes = self._request_attributes(requests) + if metadata_attributes: + request_attributes.update(metadata_attributes) + await self._job_info_client.put_status( job_id, JobStatus.FAILED, jobinfo_replace_kwargs={ "metadata": { - "attributes": self._request_attributes(requests), + "attributes": request_attributes, "error": exc, } }, @@ -1425,6 +1465,7 @@ async def _record_started_job( print(f"Record job started: {job_id}") if job_id is not None and self._job_info_client is not None: try: + await self._job_info_client.put_status( job_id, JobStatus.RUNNING, @@ -1438,15 +1479,24 @@ async def _record_started_job( except Exception as e: self.logger.error(f"Error recording job status: {e}") - async def _record_successful_job(self, job_id: str, requests: List["DataRequest"]): + async def _record_successful_job( + self, + job_id: str, + requests: List["DataRequest"], + metadata_attributes: Optional[Dict], + ): print(f"Record job success: {job_id}") if job_id is not None and self._job_info_client is not None: try: + request_attributes = self._request_attributes(requests) + if metadata_attributes: + request_attributes.update(metadata_attributes) + await self._job_info_client.put_status( job_id, JobStatus.SUCCEEDED, jobinfo_replace_kwargs={ - "metadata": {"attributes": self._request_attributes(requests)} + "metadata": {"attributes": request_attributes} }, ) except Exception as e: diff --git a/marie/storage/kv/psql.py b/marie/storage/kv/psql.py index 029bd34c..7ed333dd 100644 --- a/marie/storage/kv/psql.py +++ b/marie/storage/kv/psql.py @@ -15,7 +15,7 @@ class PostgreSQLKV(PostgresqlMixin, StorageArea): JSONB data type. """ - def __init__(self, config: Dict[str, Any], reset=True): + def __init__(self, config: Dict[str, Any], reset=False): super().__init__() self.logger = MarieLogger(self.__class__.__name__) self.running = False diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py index c4687649..ade2b0fc 100644 --- a/marie_server/scheduler/psql.py +++ b/marie_server/scheduler/psql.py @@ -56,6 +56,8 @@ def convert_job_status_to_work_state(job_status: JobStatus) -> WorkState: return WorkState.COMPLETED elif job_status == JobStatus.FAILED: return WorkState.FAILED + elif job_status == JobStatus.STOPPED: + return WorkState.CANCELLED else: raise ValueError(f"Unknown JobStatus: {job_status}") @@ -395,17 +397,23 @@ async def get_job(self, job_id: str) -> Optional[WorkInfo]: self.connection.commit() async def list_jobs( - self, state: Optional[str] = None, batch_size: int = 0 + self, state: Optional[str | list[str]] = None, batch_size: int = 0 ) -> Dict[str, WorkInfo]: work_items = {} schema = DEFAULT_SCHEMA table = DEFAULT_JOB_TABLE - states = "','".join(WorkState.__members__.keys()) + if state is not None: - if state.upper() not in WorkState.__members__: - raise ValueError(f"Invalid state: {state}") - states = state - states = states.lower() + if isinstance(state, str): + state = [state] + invalid_states = [ + s for s in state if s.upper() not in WorkState.__members__ + ] + if invalid_states: + raise ValueError(f"Invalid state(s): {', '.join(invalid_states)}") + states = "','".join(s.lower() for s in state) + else: + states = "','".join(WorkState.__members__.keys()).lower() with self: try: @@ -438,7 +446,7 @@ async def submit_job(self, work_info: WorkInfo, overwrite: bool = True) -> str: new_key_added = False submission_id = work_info.id - work_info.retry_limit = 2 + work_info.retry_limit = 0 with self: try: @@ -475,17 +483,16 @@ async def delete_job(self, job_id: str): raise NotImplementedError - async def cancel_job(self, job_id: str) -> None: + async def cancel_job(self, job_id: str, work_item: WorkInfo) -> None: """ Cancel a job by its ID. :param job_id: """ - name = "extract" # TODO this is a placeholder with self: try: self.logger.info(f"Cancelling job: {job_id}") self._execute_sql_gracefully( - cancel_jobs(DEFAULT_SCHEMA, name, [job_id]) + cancel_jobs(DEFAULT_SCHEMA, work_item.name, [job_id]) ) except (Exception, psycopg2.Error) as error: self.logger.error(f"Error handling job event: {error}") @@ -651,7 +658,9 @@ async def _monitor(self): traceback.print_exc() # TODO: emit error event - async def complete(self, job_id: str, work_item: WorkInfo): + async def complete( + self, job_id: str, work_item: WorkInfo, output_metadata: dict = None + ): self.logger.info(f"Job completed : {job_id}, {work_item}") with self: try: @@ -660,7 +669,7 @@ async def complete(self, job_id: str, work_item: WorkInfo): DEFAULT_SCHEMA, work_item.name, [job_id], - {"on_complete": "done"}, + {"on_complete": "done", **(output_metadata or {})}, ) ) counts = cursor.fetchone()[0] @@ -671,7 +680,9 @@ async def complete(self, job_id: str, work_item: WorkInfo): except (Exception, psycopg2.Error) as error: self.logger.error(f"Error completing job: {error}") - async def fail(self, job_id: str, work_item: WorkInfo): + async def fail( + self, job_id: str, work_item: WorkInfo, output_metadata: dict = None + ): self.logger.info(f"Job failed : {job_id}, {work_item}") with self: try: @@ -680,7 +691,7 @@ async def fail(self, job_id: str, work_item: WorkInfo): DEFAULT_SCHEMA, work_item.name, [job_id], - {"on_complete": "failed"}, + {"on_complete": "failed", **(output_metadata or {})}, ) ) counts = cursor.fetchone()[0] @@ -696,13 +707,45 @@ async def _sync(self): while self.running: self.logger.info(f"Syncing jobs status : {wait_time}") await asyncio.sleep(wait_time) + job_info_client = self.job_manager.job_info_client() try: - active_jobs = await self.list_jobs(state=WorkState.ACTIVE.value) + active_jobs = await self.list_jobs( + state=[WorkState.ACTIVE.value, WorkState.CREATED.value] + ) if active_jobs: - self.logger.info(f"Active jobs: {active_jobs}") for job_id, work_item in active_jobs.items(): self.logger.info(f"Syncing job: {job_id}, {work_item}") + job_info = await job_info_client.get_info(job_id) + if job_info is None: + self.logger.error(f"Job not found: {job_id}") + continue + + job_info_state = convert_job_status_to_work_state( + job_info.status + ) + if work_item.state != job_info_state: + self.logger.info( + f"State mismatch for job {job_id}: " + f"WorkState={work_item.state}, JobInfoState={job_info_state}. " + f"Updating to JobInfoState." + ) + # check if terminal status + if job_info.status.is_terminal(): + self.logger.info( + f"Job {job_id} is in terminal state, synchronizing." + ) + meta = {"synced": True} + if job_info.status == JobStatus.SUCCEEDED: + await self.complete(job_id, work_item, meta) + elif job_info.status == JobStatus.FAILED: + await self.fail(job_id, work_item, meta) + elif job_info.status == JobStatus.STOPPED: + await self.cancel_job(job_id, work_item) + else: + self.logger.error( + f"Unhandled terminal status: {job_info.status}" + ) except Exception as e: logger.error(f"Error syncing jobs: {e}") diff --git a/poc/custom_gateway/direct-flow.py b/poc/custom_gateway/direct-flow.py index 6755130f..6c00214e 100644 --- a/poc/custom_gateway/direct-flow.py +++ b/poc/custom_gateway/direct-flow.py @@ -42,10 +42,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) print("TestExecutor init called") # emulate the long loading time - time.sleep(1) + # time.sleep(1 @requests(on="/extract") - def func_extract( + async def func_extract( self, docs: DocList[TextDoc], parameters=None, @@ -63,10 +63,11 @@ def func_extract( for doc in docs: doc.text += " First Exec" - sec = 10 print(f"Sleeping for {sec} seconds : ", time.time()) time.sleep(sec) + raise Exception("random error in FirstExec") + print(f"Sleeping for {sec} seconds - done : ", time.time()) return { "parameters": parameters, diff --git a/poc/custom_gateway/server_gateway.py b/poc/custom_gateway/server_gateway.py index 8b8f6b2f..52bf7ba6 100644 --- a/poc/custom_gateway/server_gateway.py +++ b/poc/custom_gateway/server_gateway.py @@ -93,7 +93,7 @@ def __init__(self, **kwargs): gateway_streamer=None, logger=self.logger ) - storage = PostgreSQLKV(config=kv_storage_config, reset=True) + storage = PostgreSQLKV(config=kv_storage_config, reset=False) job_manager = JobManager(storage=storage, job_distributor=self.distributor) self.job_scheduler = PostgreSQLJobScheduler( config=scheduler_config, job_manager=job_manager From f125276567b7a64ef4916d207d797cc384c24df8 Mon Sep 17 00:00:00 2001 From: Grzegorz Bugaj Date: Wed, 2 Oct 2024 04:29:49 -0500 Subject: [PATCH 10/10] feat: Merge JINA 3.27.18 --- extra-requirements.txt | 8 +- marie/clients/__init__.py | 3 + marie/clients/base/__init__.py | 68 ++-- marie/clients/base/grpc.py | 5 +- marie/clients/base/helper.py | 113 ++++--- marie/clients/base/http.py | 139 ++++++--- marie/clients/base/websocket.py | 5 +- marie/clients/http.py | 7 + marie/clients/request/helper.py | 1 + marie/orchestrate/flow/base.py | 13 +- marie/parsers/client.py | 6 + marie/resources/logging.plain.yml | 7 + .../project-template/deployment/client.py | 4 +- .../resources/project-template/flow/client.py | 4 +- marie/serve/executors/__init__.py | 17 +- marie/serve/executors/decorators.py | 20 +- marie/serve/runtimes/worker/batch_queue.py | 291 +++++++++++------- marie/serve/runtimes/worker/http_csp_app.py | 11 +- .../serve/runtimes/worker/http_fastapi_app.py | 14 +- .../serve/runtimes/worker/request_handling.py | 87 ++++-- marie_cli/autocomplete.py | 120 +++++--- 21 files changed, 617 insertions(+), 326 deletions(-) create mode 100644 marie/resources/logging.plain.yml diff --git a/extra-requirements.txt b/extra-requirements.txt index df3eadef..eb9fedfb 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -72,19 +72,19 @@ aiofiles: standard,devel aiohttp: standard,devel aiostream: standard,devel -pytest: test +pytest<8.0.0: test pytest-timeout: test pytest-mock: test pytest-cov==3.0.0: test coverage==6.2: test pytest-repeat: test -pytest-asyncio: test +pytest-asyncio<0.23.0: test pytest-reraise: test mock: test requests-mock: test pytest-custom_exit_code: test -black==22.3.0: test -kubernetes>=18.20.0: test +black==24.3.0: test +kubernetes>=18.20.0,<31.0.0: test pytest-kind==22.11.1: test pytest-lazy-fixture: test torch: cicd diff --git a/marie/clients/__init__.py b/marie/clients/__init__.py index 3ddb97a7..99e9608e 100644 --- a/marie/clients/__init__.py +++ b/marie/clients/__init__.py @@ -30,6 +30,7 @@ def Client( prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = 'GRPC', proxy: Optional[bool] = False, + reuse_session: Optional[bool] = False, suppress_root_logging: Optional[bool] = False, tls: Optional[bool] = False, traces_exporter_host: Optional[str] = None, @@ -59,6 +60,7 @@ def Client( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. @@ -113,6 +115,7 @@ def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[ Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. diff --git a/marie/clients/base/__init__.py b/marie/clients/base/__init__.py index 845a1deb..001ef52c 100644 --- a/marie/clients/base/__init__.py +++ b/marie/clients/base/__init__.py @@ -5,7 +5,15 @@ import inspect import os from abc import ABC -from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union +from typing import ( + TYPE_CHECKING, + AsyncIterator, + Callable, + Iterator, + Optional, + Tuple, + Union, +) from marie.excepts import BadClientInput from marie.helper import T, parse_client, send_telemetry_event, typename @@ -47,7 +55,6 @@ def __init__( # affect users os-level envs. os.unsetenv('http_proxy') os.unsetenv('https_proxy') - self._inputs = None self._setup_instrumentation( name=( self.args.name @@ -63,6 +70,12 @@ def __init__( ) send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__) + async def close(self): + """Closes the potential resources of the Client. + :return: Return whatever a close method may return + """ + return self.teardown_instrumentation() + def teardown_instrumentation(self): """Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for exporting metrics data is properly cleaned up. @@ -118,62 +131,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None: raise BadClientInput from ex def _get_requests( - self, **kwargs - ) -> Union[Iterator['Request'], AsyncIterator['Request']]: + self, inputs, **kwargs + ) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]: """ Get request in generator. + :param inputs: The inputs argument to get the requests from. :param kwargs: Keyword arguments. - :return: Iterator of request. + :return: Iterator of request and the length of the inputs. """ _kwargs = vars(self.args) - _kwargs['data'] = self.inputs + if hasattr(inputs, '__call__'): + inputs = inputs() + + _kwargs['data'] = inputs # override by the caller-specific kwargs _kwargs.update(kwargs) - if hasattr(self._inputs, '__len__'): - total_docs = len(self._inputs) + if hasattr(inputs, '__len__'): + total_docs = len(inputs) elif 'total_docs' in _kwargs: total_docs = _kwargs['total_docs'] else: total_docs = None - self._inputs_length = None - if total_docs: - self._inputs_length = max(1, total_docs / _kwargs['request_size']) + inputs_length = max(1, total_docs / _kwargs['request_size']) + else: + inputs_length = None - if inspect.isasyncgen(self.inputs): + if inspect.isasyncgen(inputs): from marie.clients.request.asyncio import request_generator - return request_generator(**_kwargs) + return request_generator(**_kwargs), inputs_length else: from marie.clients.request import request_generator - return request_generator(**_kwargs) - - @property - def inputs(self) -> 'InputType': - """ - An iterator of bytes, each element represents a Document's raw content. - - ``inputs`` defined in the protobuf - - :return: inputs - """ - return self._inputs - - @inputs.setter - def inputs(self, bytes_gen: 'InputType') -> None: - """ - Set the input data. - - :param bytes_gen: input type - """ - if hasattr(bytes_gen, '__call__'): - self._inputs = bytes_gen() - else: - self._inputs = bytes_gen + return request_generator(**_kwargs), inputs_length @abc.abstractmethod async def _get_results( diff --git a/marie/clients/base/grpc.py b/marie/clients/base/grpc.py index 5a3f1089..0c8c5a8f 100644 --- a/marie/clients/base/grpc.py +++ b/marie/clients/base/grpc.py @@ -90,8 +90,7 @@ async def _get_results( else grpc.Compression.NoCompression ) - self.inputs = inputs - req_iter = self._get_requests(**kwargs) + req_iter, inputs_length = self._get_requests(inputs=inputs, **kwargs) continue_on_error = self.continue_on_error # while loop with retries, check in which state the `iterator` remains after failure options = client_grpc_options( @@ -120,7 +119,7 @@ async def _get_results( self.logger.debug(f'connected to {self.args.host}:{self.args.port}') with ProgressBar( - total_length=self._inputs_length, disable=not self.show_progress + total_length=inputs_length, disable=not self.show_progress ) as p_bar: try: if stream: diff --git a/marie/clients/base/helper.py b/marie/clients/base/helper.py index 1a78ff97..eae40d00 100644 --- a/marie/clients/base/helper.py +++ b/marie/clients/base/helper.py @@ -48,18 +48,16 @@ class AioHttpClientlet(ABC): def __init__( self, - url: str, - logger: 'MarieLogger', + logger: "MarieLogger", max_attempts: int = 1, initial_backoff: float = 0.5, max_backoff: float = 2, backoff_multiplier: float = 1.5, - tracer_provider: Optional['trace.TraceProvider'] = None, + tracer_provider: Optional["trace.TraceProvider"] = None, **kwargs, ) -> None: """HTTP Client to be used with the streamer - :param url: url to send http/websocket request to :param logger: jina logger :param max_attempts: Number of sending attempts, including the original request. :param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff) @@ -68,7 +66,6 @@ def __init__( :param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing. :param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests """ - self.url = url self.logger = logger self.msg_recv = 0 self.msg_sent = 0 @@ -80,15 +77,15 @@ def __init__( self._trace_config = None self.session = None self._session_kwargs = {} - if kwargs.get('headers', None): - self._session_kwargs['headers'] = kwargs.get('headers') - if kwargs.get('auth', None): - self._session_kwargs['auth'] = kwargs.get('auth') - if kwargs.get('cookies', None): - self._session_kwargs['cookies'] = kwargs.get('cookies') - if kwargs.get('timeout', None): - timeout = aiohttp.ClientTimeout(total=kwargs.get('timeout')) - self._session_kwargs['timeout'] = timeout + if kwargs.get("headers", None): + self._session_kwargs["headers"] = kwargs.get("headers") + if kwargs.get("auth", None): + self._session_kwargs["auth"] = kwargs.get("auth") + if kwargs.get("cookies", None): + self._session_kwargs["cookies"] = kwargs.get("cookies") + if kwargs.get("timeout", None): + timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout")) + self._session_kwargs["timeout"] = timeout self.max_attempts = max_attempts self.initial_backoff = initial_backoff self.max_backoff = max_backoff @@ -154,33 +151,45 @@ class HTTPClientlet(AioHttpClientlet): UPDATE_EVENT_PREFIX = 14 # the update event has the following format: "event: update: {document_json}" - async def send_message(self, request: 'Request'): + async def send_message(self, url, request: "Request"): """Sends a POST request to the server + :param url: the URL where to send the message :param request: request as dict :return: send post message """ req_dict = request.to_dict() - req_dict['exec_endpoint'] = req_dict['header']['exec_endpoint'] - if 'target_executor' in req_dict['header']: - req_dict['target_executor'] = req_dict['header']['target_executor'] + req_dict["exec_endpoint"] = req_dict["header"]["exec_endpoint"] + if "target_executor" in req_dict["header"]: + req_dict["target_executor"] = req_dict["header"]["target_executor"] for attempt in range(1, self.max_attempts + 1): try: - request_kwargs = {'url': self.url} + request_kwargs = {"url": url} if not docarray_v2: - request_kwargs['json'] = req_dict + request_kwargs["json"] = req_dict else: from docarray.base_doc.io.json import orjson_dumps - request_kwargs['data'] = JinaJsonPayload(value=req_dict) - response = await self.session.post(**request_kwargs).__aenter__() - try: - r_str = await response.json() - except aiohttp.ContentTypeError: - r_str = await response.text() - handle_response_status(response.status, r_str, self.url) - return response - except (ValueError, ConnectionError, BadClient, aiohttp.ClientError) as err: + request_kwargs["data"] = JinaJsonPayload(value=req_dict) + + async with self.session.post(**request_kwargs) as response: + try: + r_str = await response.json() + except aiohttp.ContentTypeError: + r_str = await response.text() + r_status = response.status + handle_response_status(r_status, r_str, url) + return r_status, r_str + except ( + ValueError, + ConnectionError, + BadClient, + aiohttp.ClientError, + aiohttp.ClientConnectionError, + ) as err: + self.logger.debug( + f"Got an error of type {type(err)}: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}" + ) await retry.wait_or_raise_err( attempt=attempt, err=err, @@ -189,37 +198,44 @@ async def send_message(self, request: 'Request'): initial_backoff=self.initial_backoff, max_backoff=self.max_backoff, ) + except Exception as exc: + self.logger.debug( + f"Got a non-retried error of type {type(exc)}: {exc} sending POST to {url}" + ) + raise exc - async def send_streaming_message(self, doc: 'Document', on: str): + async def send_streaming_message(self, url, doc: "Document", on: str): """Sends a GET SSE request to the server + :param url: the URL where to send the message :param doc: Request Document :param on: Request endpoint :yields: responses """ req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict() request_kwargs = { - 'url': self.url, - 'headers': {'Accept': 'text/event-stream'}, - 'json': req_dict, + "url": url, + "headers": {"Accept": "text/event-stream"}, + "json": req_dict, } async with self.session.get(**request_kwargs) as response: async for chunk in response.content.iter_any(): - events = chunk.split(b'event: ')[1:] + events = chunk.split(b"event: ")[1:] for event in events: - if event.startswith(b'update'): + if event.startswith(b"update"): yield event[self.UPDATE_EVENT_PREFIX :].decode() - elif event.startswith(b'end'): + elif event.startswith(b"end"): pass - async def send_dry_run(self, **kwargs): + async def send_dry_run(self, url, **kwargs): """Query the dry_run endpoint from Gateway + :param url: the URL where to send the message :param kwargs: keyword arguments to make sure compatible API with other clients :return: send get message """ return await self.session.get( - url=self.url, timeout=kwargs.get('timeout', None) + url=url, timeout=kwargs.get("timeout", None) ).__aenter__() async def recv_message(self): @@ -261,12 +277,13 @@ async def __anext__(self): class WebsocketClientlet(AioHttpClientlet): """Websocket Client to be used with the streamer""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, url, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self.url = url self.websocket = None self.response_iter = None - async def send_message(self, request: 'Request'): + async def send_message(self, request: "Request"): """Send request in bytes to the server. :param request: request object @@ -292,9 +309,9 @@ async def send_dry_run(self, **kwargs): """ try: - return await self.websocket.send_bytes(b'') + return await self.websocket.send_bytes(b"") except ConnectionResetError: - self.logger.critical(f'server connection closed already!') + self.logger.critical(f"server connection closed already!") async def send_eoi(self): """To confirm end of iteration, we send `bytes(True)` to the server. @@ -308,7 +325,7 @@ async def send_eoi(self): # which raises a `ConnectionResetError`, this can be ignored. pass - async def recv_message(self) -> 'DataRequest': + async def recv_message(self) -> "DataRequest": """Receive messages in bytes from server and convert to `DataRequest` ..note:: @@ -376,18 +393,18 @@ def handle_response_status( :param url: request url string """ if http_status == status.HTTP_404_NOT_FOUND: - raise BadClient(f'no such endpoint {url}') + raise BadClient(f"no such endpoint {url}") elif ( http_status == status.HTTP_503_SERVICE_UNAVAILABLE or http_status == status.HTTP_504_GATEWAY_TIMEOUT ): if ( isinstance(response_content, dict) - and 'header' in response_content - and 'status' in response_content['header'] - and 'description' in response_content['header']['status'] + and "header" in response_content + and "status" in response_content["header"] + and "description" in response_content["header"]["status"] ): - raise ConnectionError(response_content['header']['status']['description']) + raise ConnectionError(response_content["header"]["status"]["description"]) else: raise ValueError(response_content) elif ( diff --git a/marie/clients/base/http.py b/marie/clients/base/http.py index b4c2b27f..a4b22c4b 100644 --- a/marie/clients/base/http.py +++ b/marie/clients/base/http.py @@ -23,8 +23,20 @@ class HTTPBaseClient(BaseClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._endpoints = [] + self.reuse_session = False + self._lock = AsyncExitStack() + self.iolet = None - async def _get_endpoints_from_openapi(self): + async def close(self): + """Closes the potential resources of the Client. + :return: Return whatever a close method may return + """ + ret = super().close() + if self.iolet is not None: + await self.iolet.__aexit__(None, None, None) + return ret + + async def _get_endpoints_from_openapi(self, **kwargs): def extract_paths_by_method(spec): paths_by_method = {} for path, methods in spec['paths'].items(): @@ -39,10 +51,15 @@ def extract_paths_by_method(spec): import aiohttp + session_kwargs = {} + if 'headers' in kwargs: + session_kwargs = {'headers': kwargs['headers']} + proto = 'https' if self.args.tls else 'http' target_url = f'{proto}://{self.args.host}:{self.args.port}/openapi.json' try: - async with aiohttp.ClientSession() as session: + + async with aiohttp.ClientSession(**session_kwargs) as session: async with session.get(target_url) as response: content = await response.read() openapi_response = json.loads(content.decode()) @@ -64,16 +81,27 @@ async def _is_flow_ready(self, **kwargs) -> bool: try: proto = 'https' if self.args.tls else 'http' url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' - iolet = await stack.enter_async_context( - HTTPClientlet( - url=url, - logger=self.logger, - tracer_provider=self.tracer_provider, - **kwargs, + + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) ) - ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) + await self.iolet.__aenter__() + iolet = self.iolet - response = await iolet.send_dry_run(**kwargs) + response = await iolet.send_dry_run(url=url, **kwargs) r_status = response.status r_str = await response.json() @@ -125,15 +153,14 @@ async def _get_results( with ImportExtensions(required=True): pass - self.inputs = inputs - request_iterator = self._get_requests(**kwargs) + request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs) on = kwargs.get('on', '/post') if len(self._endpoints) == 0: - await self._get_endpoints_from_openapi() + await self._get_endpoints_from_openapi(**kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( - total_length=self._inputs_length, disable=not self.show_progress + total_length=inputs_length, disable=not self.show_progress ) p_bar = stack.enter_context(cm1) proto = 'https' if self.args.tls else 'http' @@ -146,19 +173,35 @@ async def _get_results( url = f'{proto}://{self.args.host}:{self.args.port}/default' else: url = f'{proto}://{self.args.host}:{self.args.port}/post' - iolet = await stack.enter_async_context( - HTTPClientlet( - url=url, - logger=self.logger, - tracer_provider=self.tracer_provider, - max_attempts=max_attempts, - initial_backoff=initial_backoff, - max_backoff=max_backoff, - backoff_multiplier=backoff_multiplier, - timeout=timeout, - **kwargs, + + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + max_attempts=max_attempts, + initial_backoff=initial_backoff, + max_backoff=max_backoff, + backoff_multiplier=backoff_multiplier, + **kwargs, + ) ) - ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + max_attempts=max_attempts, + initial_backoff=initial_backoff, + max_backoff=max_backoff, + backoff_multiplier=backoff_multiplier, + **kwargs, + ) + self.iolet = await self.iolet.__aenter__() + iolet = self.iolet def _request_handler( request: 'Request', **kwargs @@ -170,7 +213,10 @@ def _request_handler( :param kwargs: kwargs :return: asyncio Task for sending message """ - return asyncio.ensure_future(iolet.send_message(request=request)), None + return ( + asyncio.ensure_future(iolet.send_message(url=url, request=request)), + None, + ) def _result_handler(result): return result @@ -187,9 +233,7 @@ def _result_handler(result): async for response in streamer.stream( request_iterator=request_iterator, results_in_order=results_in_order ): - r_status = response.status - - r_str = await response.json() + r_status, r_str = response handle_response_status(r_status, r_str, url) da = None @@ -213,7 +257,7 @@ def _result_handler(result): resp = DataRequest(r_str) if da is not None: - resp.data.docs = da + resp.direct_docs = da callback_exec( response=resp, @@ -244,17 +288,28 @@ async def _get_streaming_results( url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}' else: url = f'{proto}://{self.args.host}:{self.args.port}/default' - - iolet = HTTPClientlet( - url=url, - logger=self.logger, - tracer_provider=self.tracer_provider, - timeout=timeout, - **kwargs, - ) - - async with iolet: - async for doc in iolet.send_streaming_message(doc=inputs, on=on): + async with AsyncExitStack() as stack: + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + await self.iolet.__aenter__() + iolet = self.iolet + async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on): if not docarray_v2: yield Document.from_dict(json.loads(doc)) else: diff --git a/marie/clients/base/websocket.py b/marie/clients/base/websocket.py index a37d2c32..8b565a47 100644 --- a/marie/clients/base/websocket.py +++ b/marie/clients/base/websocket.py @@ -108,12 +108,11 @@ async def _get_results( with ImportExtensions(required=True): pass - self.inputs = inputs - request_iterator = self._get_requests(**kwargs) + request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( - total_length=self._inputs_length, disable=not (self.show_progress) + total_length=inputs_length, disable=not (self.show_progress) ) p_bar = stack.enter_context(cm1) diff --git a/marie/clients/http.py b/marie/clients/http.py index fb1fd92f..fa102bd6 100644 --- a/marie/clients/http.py +++ b/marie/clients/http.py @@ -1,3 +1,5 @@ +import asyncio + from marie.clients.base.http import HTTPBaseClient from marie.clients.mixin import ( AsyncHealthCheckMixin, @@ -81,3 +83,8 @@ async def async_inputs(): print(resp) """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._lock = asyncio.Lock() + self.reuse_session = self.args.reuse_session diff --git a/marie/clients/request/helper.py b/marie/clients/request/helper.py index 10de681b..92481f51 100644 --- a/marie/clients/request/helper.py +++ b/marie/clients/request/helper.py @@ -1,4 +1,5 @@ """Module for helper functions for clients.""" + from typing import Optional, Tuple from marie._docarray import Document, DocumentArray, docarray_v2 diff --git a/marie/orchestrate/flow/base.py b/marie/orchestrate/flow/base.py index ba8e2273..b2795ee5 100644 --- a/marie/orchestrate/flow/base.py +++ b/marie/orchestrate/flow/base.py @@ -142,6 +142,7 @@ def __init__( prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = 'GRPC', proxy: Optional[bool] = False, + reuse_session: Optional[bool] = False, suppress_root_logging: Optional[bool] = False, tls: Optional[bool] = False, traces_exporter_host: Optional[str] = None, @@ -164,6 +165,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. @@ -426,6 +428,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. @@ -475,7 +478,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE']. :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -1148,7 +1151,7 @@ def add( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE']. :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param py_modules: The customized python modules need to be imported before loading the executor @@ -1517,7 +1520,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE']. :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -2903,9 +2906,9 @@ def to_docker_compose_yaml( yaml.dump(docker_compose_dict, fp, sort_keys=False) command = ( - 'docker-compose up' + 'docker compose up' if output_path is None - else f'docker-compose -f {output_path} up' + else f'docker compose -f {output_path} up' ) self.logger.info( diff --git a/marie/parsers/client.py b/marie/parsers/client.py index 7b09eead..9bffb5cf 100644 --- a/marie/parsers/client.py +++ b/marie/parsers/client.py @@ -81,3 +81,9 @@ def mixin_client_features_parser(parser): default='default', help='The config name or the absolute path to the YAML config file of the logger used in this object.', ) + parser.add_argument( + '--reuse-session', + action='store_true', + default=False, + help='True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it', + ) diff --git a/marie/resources/logging.plain.yml b/marie/resources/logging.plain.yml new file mode 100644 index 00000000..747061d1 --- /dev/null +++ b/marie/resources/logging.plain.yml @@ -0,0 +1,7 @@ +handlers: + - StreamHandler +level: INFO +configs: + StreamHandler: + format: '{name:>15}@%(process)2d[%(levelname).1s]:%(message)s' + formatter: PlainFormatter diff --git a/marie/resources/project-template/deployment/client.py b/marie/resources/project-template/deployment/client.py index 897dc31b..ec1a4f2f 100644 --- a/marie/resources/project-template/deployment/client.py +++ b/marie/resources/project-template/deployment/client.py @@ -4,5 +4,7 @@ if __name__ == '__main__': c = Client(host='grpc://0.0.0.0:54321') - da = c.post('/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc]) + da = c.post( + '/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc] + ) print(da.text) diff --git a/marie/resources/project-template/flow/client.py b/marie/resources/project-template/flow/client.py index 897dc31b..ec1a4f2f 100644 --- a/marie/resources/project-template/flow/client.py +++ b/marie/resources/project-template/flow/client.py @@ -4,5 +4,7 @@ if __name__ == '__main__': c = Client(host='grpc://0.0.0.0:54321') - da = c.post('/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc]) + da = c.post( + '/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc] + ) print(da.text) diff --git a/marie/serve/executors/__init__.py b/marie/serve/executors/__init__.py index f3116483..cd4c9df1 100644 --- a/marie/serve/executors/__init__.py +++ b/marie/serve/executors/__init__.py @@ -658,9 +658,22 @@ def _validate_sagemaker(self): return def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]): + from collections.abc import Mapping + + def deep_update(source, overrides): + for key, value in overrides.items(): + if isinstance(value, Mapping) and value: + returned = deep_update(source.get(key, {}), value) + source[key] = returned + else: + source[key] = overrides[key] + return source + if _dynamic_batching: - self.dynamic_batching = getattr(self, "dynamic_batching", {}) - self.dynamic_batching.update(_dynamic_batching) + self.dynamic_batching = getattr(self, 'dynamic_batching', {}) + self.dynamic_batching = deep_update( + self.dynamic_batching, _dynamic_batching + ) def _add_metas(self, _metas: Optional[Dict]): from marie.serve.executors.metas import get_default_metas diff --git a/marie/serve/executors/decorators.py b/marie/serve/executors/decorators.py index 64896816..55c1be8a 100644 --- a/marie/serve/executors/decorators.py +++ b/marie/serve/executors/decorators.py @@ -1,11 +1,12 @@ """Decorators and wrappers designed for wrapping :class:`BaseExecutor` functions. """ + import functools import inspect import os from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union -from marie._docarray import Document, DocumentArray, docarray_v2 +from marie._docarray import Document, DocumentArray from marie.constants import __cache_path__ from marie.helper import is_generator, iscoroutinefunction from marie.importer import ImportExtensions @@ -415,6 +416,10 @@ def dynamic_batching( *, preferred_batch_size: Optional[int] = None, timeout: Optional[float] = 10_000, + flush_all: bool = False, + custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None, + use_custom_metric: bool = False, + use_dynamic_batching: bool = True, ): """ `@dynamic_batching` defines the dynamic batching behavior of an Executor. @@ -425,11 +430,16 @@ def dynamic_batching( :param func: the method to decorate :param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached, - or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`. + or until `timeout` is reached. Therefore, the actual batch size can be smaller or equal to `preferred_batch_size`, except if `flush_all` is set to True :param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch. If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor, even if it contains fewer than `preferred_batch_size` Documents. Default is 10_000ms (10 seconds). + :param flush_all: Determines if once the batches is triggered by timeout or preferred_batch_size, the function will receive everything that the batcher has accumulated or not. + If this is true, `preferred_batch_size` is used as a trigger mechanism. + :param custom_metric: Potential lambda function to measure the "weight" of each request. + :param use_custom_metric: Determines if we need to use the `custom_metric` to determine preferred_batch_size. + :param use_dynamic_batching: Determines if we should apply dynamic batching for this method. :return: decorated function """ @@ -475,6 +485,12 @@ def _inject_owner_attrs(self, owner, name): 'preferred_batch_size' ] = preferred_batch_size owner.dynamic_batching[fn_name]['timeout'] = timeout + owner.dynamic_batching[fn_name]['flush_all'] = flush_all + owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric + owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric + owner.dynamic_batching[fn_name][ + 'use_dynamic_batching' + ] = use_dynamic_batching setattr(owner, name, self.fn) def __set_name__(self, owner, name): diff --git a/marie/serve/runtimes/worker/batch_queue.py b/marie/serve/runtimes/worker/batch_queue.py index 8999a09c..357af64a 100644 --- a/marie/serve/runtimes/worker/batch_queue.py +++ b/marie/serve/runtimes/worker/batch_queue.py @@ -1,8 +1,9 @@ import asyncio +import copy from asyncio import Event, Task -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from marie._docarray import docarray_v2 +from jina._docarray import docarray_v2 if not docarray_v2: from docarray import DocumentArray @@ -25,10 +26,14 @@ def __init__( response_docarray_cls, output_array_type: Optional[str] = None, params: Optional[Dict] = None, + flush_all: bool = False, preferred_batch_size: int = 4, timeout: int = 10_000, + custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None, + use_custom_metric: bool = False, + **kwargs, ) -> None: - self._data_lock = asyncio.Lock() + # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent self.func = func if params is None: params = dict() @@ -37,7 +42,10 @@ def __init__( self.params = params self._request_docarray_cls = request_docarray_cls self._response_docarray_cls = response_docarray_cls + self._flush_all = flush_all self._preferred_batch_size: int = preferred_batch_size + self._custom_metric = None if not use_custom_metric else custom_metric + self._metric_value = 0 self._timeout: int = timeout self._reset() self._flush_trigger: Event = Event() @@ -56,13 +64,16 @@ def _reset(self) -> None: # a list of every request ID self._request_idxs: List[int] = [] self._request_lens: List[int] = [] + self._docs_metrics: List[int] = [] self._requests_completed: List[asyncio.Queue] = [] if not docarray_v2: self._big_doc: DocumentArray = DocumentArray.empty() else: self._big_doc = self._request_docarray_cls() + self._metric_value = 0 self._flush_task: Optional[Task] = None + self._flush_trigger: Event = Event() def _cancel_timer_if_pending(self): if ( @@ -84,13 +95,14 @@ async def _sleep_then_set(self): self._flush_trigger.set() self._timer_finished = True - async def push(self, request: DataRequest) -> asyncio.Queue: + async def push(self, request: DataRequest, http=False) -> asyncio.Queue: """Append request to the the list of requests to be processed. This method creates an asyncio Queue for that request and keeps track of it. It returns this queue to the caller so that the caller can now when this request has been processed :param request: The request to append to the queue. + :param http: Flag to determine if the request is served via HTTP for some optims :return: The queue that will receive when the request is processed. """ @@ -101,30 +113,37 @@ async def push(self, request: DataRequest) -> asyncio.Queue: # this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc` # before the `flush` task processes it. self._start_timer() - async with self._data_lock: - if not self._flush_task: - self._flush_task = asyncio.create_task(self._await_then_flush()) - - self._big_doc.extend(docs) - next_req_idx = len(self._requests) - num_docs = len(docs) - self._request_idxs.extend([next_req_idx] * num_docs) - self._request_lens.append(len(docs)) - self._requests.append(request) - queue = asyncio.Queue() - self._requests_completed.append(queue) - if len(self._big_doc) >= self._preferred_batch_size: - self._flush_trigger.set() + if not self._flush_task: + self._flush_task = asyncio.create_task(self._await_then_flush(http)) + self._big_doc.extend(docs) + next_req_idx = len(self._requests) + num_docs = len(docs) + metric_value = num_docs + if self._custom_metric is not None: + metrics = [self._custom_metric(doc) for doc in docs] + metric_value += sum(metrics) + self._docs_metrics.extend(metrics) + self._metric_value += metric_value + self._request_idxs.extend([next_req_idx] * num_docs) + self._request_lens.append(num_docs) + self._requests.append(request) + queue = asyncio.Queue() + self._requests_completed.append(queue) + if self._metric_value >= self._preferred_batch_size: + self._flush_trigger.set() return queue - async def _await_then_flush(self) -> None: - """Process all requests in the queue once flush_trigger event is set.""" + async def _await_then_flush(self, http=False) -> None: + """Process all requests in the queue once flush_trigger event is set. + :param http: Flag to determine if the request is served via HTTP for some optims + """ def _get_docs_groups_completed_request_indexes( non_assigned_docs, non_assigned_docs_reqs_idx, sum_from_previous_mini_batch_in_first_req_idx, + requests_lens_in_batch, ): """ This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to. @@ -133,6 +152,7 @@ def _get_docs_groups_completed_request_indexes( :param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result :param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed) :param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed. + :param requests_lens_in_batch: List of lens of documents for each request in the batch. :return: list of document groups and a list of request Idx to which each of these groups belong """ @@ -161,7 +181,7 @@ def _get_docs_groups_completed_request_indexes( if ( req_idx not in completed_req_idx and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx - == self._request_lens[req_idx] + == requests_lens_in_batch[req_idx] ): completed_req_idx.append(req_idx) request_bucket = non_assigned_docs[ @@ -175,6 +195,9 @@ async def _assign_results( non_assigned_docs, non_assigned_docs_reqs_idx, sum_from_previous_mini_batch_in_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, ): """ This method aims to assign to the corresponding request objects the resulting documents from the mini batches. @@ -184,6 +207,9 @@ async def _assign_results( :param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result :param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed) :param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed. + :param requests_lens_in_batch: List of lens of documents for each request in the batch. + :param requests_in_batch: List requests in batch + :param requests_completed_in_batch: List of queues for requests to be completed :return: amount of assigned documents so that some documents can come back in the next iteration """ @@ -194,113 +220,154 @@ async def _assign_results( non_assigned_docs, non_assigned_docs_reqs_idx, sum_from_previous_mini_batch_in_first_req_idx, + requests_lens_in_batch, ) num_assigned_docs = sum(len(group) for group in docs_grouped) for docs_group, request_idx in zip(docs_grouped, completed_req_idxs): - request = self._requests[request_idx] - request_completed = self._requests_completed[request_idx] - request.data.set_docs_convert_arrays( - docs_group, ndarray_type=self._output_array_type - ) + request = requests_in_batch[request_idx] + request_completed = requests_completed_in_batch[request_idx] + if http is False or self._output_array_type is not None: + request.direct_docs = None # batch queue will work in place, therefore result will need to read from data. + request.data.set_docs_convert_arrays( + docs_group, ndarray_type=self._output_array_type + ) + else: + request.direct_docs = docs_group await request_completed.put(None) return num_assigned_docs - def batch(iterable_1, iterable_2, n=1): - items = len(iterable_1) - for ndx in range(0, items, n): - yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ - ndx : min(ndx + n, items) - ] + def batch( + iterable_1, + iterable_2, + n: Optional[int] = 1, + iterable_metrics: Optional = None, + ): + if n is None: + yield iterable_1, iterable_2 + return + elif iterable_metrics is None: + items = len(iterable_1) + for ndx in range(0, items, n): + yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ + ndx : min(ndx + n, items) + ] + else: + batch_idx = 0 + batch_weight = 0 - await self._flush_trigger.wait() + for i, (item, weight) in enumerate(zip(iterable_1, iterable_metrics)): + batch_weight += weight + + if batch_weight >= n: + yield iterable_1[batch_idx : i + 1], iterable_2[ + batch_idx : i + 1 + ] + batch_idx = i + 1 + batch_weight = 0 + + # Yield any remaining items + if batch_weight > 0: + yield iterable_1[batch_idx : len(iterable_1)], iterable_2[ + batch_idx : len(iterable_1) + ] + await self._flush_trigger.wait() # writes to shared data between tasks need to be mutually exclusive - async with self._data_lock: - # At this moment, we have documents concatenated in self._big_doc corresponding to requests in - # self._requests with its lengths stored in self._requests_len. For each requests, there is a queue to - # communicate that the request has been processed properly. At this stage the data_lock is ours and - # therefore noone can add requests to this list. - self._flush_trigger: Event = Event() + big_doc_in_batch = copy.copy(self._big_doc) + requests_idxs_in_batch = copy.copy(self._request_idxs) + requests_lens_in_batch = copy.copy(self._request_lens) + docs_metrics_in_batch = copy.copy(self._docs_metrics) + requests_in_batch = copy.copy(self._requests) + requests_completed_in_batch = copy.copy(self._requests_completed) + + self._reset() + + # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in + # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to + # communicate that the request has been processed properly. + + if not docarray_v2: + non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() + else: + non_assigned_to_response_docs = self._response_docarray_cls() + + non_assigned_to_response_request_idxs = [] + sum_from_previous_first_req_idx = 0 + for docs_inner_batch, req_idxs in batch( + big_doc_in_batch, + requests_idxs_in_batch, + self._preferred_batch_size if not self._flush_all else None, + docs_metrics_in_batch if self._custom_metric is not None else None, + ): + involved_requests_min_indx = req_idxs[0] + involved_requests_max_indx = req_idxs[-1] + input_len_before_call: int = len(docs_inner_batch) + batch_res_docs = None try: - if not docarray_v2: - non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() - else: - non_assigned_to_response_docs = self._response_docarray_cls() - non_assigned_to_response_request_idxs = [] - sum_from_previous_first_req_idx = 0 - for docs_inner_batch, req_idxs in batch( - self._big_doc, self._request_idxs, self._preferred_batch_size + batch_res_docs = await self.func( + docs=docs_inner_batch, + parameters=self.params, + docs_matrix=None, # joining manually with batch queue is not supported right now + tracing_context=None, + ) + # Output validation + if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( + not docarray_v2 and isinstance(batch_res_docs, DocumentArray) ): - involved_requests_min_indx = req_idxs[0] - involved_requests_max_indx = req_idxs[-1] - input_len_before_call: int = len(docs_inner_batch) - batch_res_docs = None - try: - batch_res_docs = await self.func( - docs=docs_inner_batch, - parameters=self.params, - docs_matrix=None, # joining manually with batch queue is not supported right now - tracing_context=None, - ) - # Output validation - if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( - not docarray_v2 - and isinstance(batch_res_docs, DocumentArray) - ): - if not len(batch_res_docs) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' - ) - elif batch_res_docs is None: - if not len(docs_inner_batch) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' - ) - else: - array_name = ( - 'DocumentArray' if not docarray_v2 else 'DocList' - ) - raise TypeError( - f'The return type must be {array_name} / `None` when using dynamic batching, ' - f'but getting {batch_res_docs!r}' - ) - except Exception as exc: - # All the requests containing docs in this Exception should be raising it - for request_full in self._requests_completed[ - involved_requests_min_indx : involved_requests_max_indx + 1 - ]: - await request_full.put(exc) - else: - # We need to attribute the docs to their requests - non_assigned_to_response_docs.extend( - batch_res_docs or docs_inner_batch - ) - non_assigned_to_response_request_idxs.extend(req_idxs) - num_assigned_docs = await _assign_results( - non_assigned_to_response_docs, - non_assigned_to_response_request_idxs, - sum_from_previous_first_req_idx, - ) - - sum_from_previous_first_req_idx = ( - len(non_assigned_to_response_docs) - num_assigned_docs + if not len(batch_res_docs) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' ) - non_assigned_to_response_docs = non_assigned_to_response_docs[ - num_assigned_docs: - ] - non_assigned_to_response_request_idxs = ( - non_assigned_to_response_request_idxs[num_assigned_docs:] + elif batch_res_docs is None: + if not len(docs_inner_batch) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' ) - if len(non_assigned_to_response_request_idxs) > 0: - _ = await _assign_results( - non_assigned_to_response_docs, - non_assigned_to_response_request_idxs, - sum_from_previous_first_req_idx, + else: + array_name = 'DocumentArray' if not docarray_v2 else 'DocList' + raise TypeError( + f'The return type must be {array_name} / `None` when using dynamic batching, ' + f'but getting {batch_res_docs!r}' ) - finally: - self._reset() + except Exception as exc: + # All the requests containing docs in this Exception should be raising it + for request_full in requests_completed_in_batch[ + involved_requests_min_indx : involved_requests_max_indx + 1 + ]: + await request_full.put(exc) + else: + # We need to attribute the docs to their requests + non_assigned_to_response_docs.extend(batch_res_docs or docs_inner_batch) + non_assigned_to_response_request_idxs.extend(req_idxs) + num_assigned_docs = await _assign_results( + non_assigned_to_response_docs, + non_assigned_to_response_request_idxs, + sum_from_previous_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, + ) + + sum_from_previous_first_req_idx = ( + len(non_assigned_to_response_docs) - num_assigned_docs + ) + non_assigned_to_response_docs = non_assigned_to_response_docs[ + num_assigned_docs: + ] + non_assigned_to_response_request_idxs = ( + non_assigned_to_response_request_idxs[num_assigned_docs:] + ) + if len(non_assigned_to_response_request_idxs) > 0: + _ = await _assign_results( + non_assigned_to_response_docs, + non_assigned_to_response_request_idxs, + sum_from_previous_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, + ) async def close(self): """Closes the batch queue by flushing pending requests.""" diff --git a/marie/serve/runtimes/worker/http_csp_app.py b/marie/serve/runtimes/worker/http_csp_app.py index 6fc3743b..0746f415 100644 --- a/marie/serve/runtimes/worker/http_csp_app.py +++ b/marie/serve/runtimes/worker/http_csp_app.py @@ -191,7 +191,16 @@ def construct_model_from_line( parsed_fields[field_name] = parsed_list # Handle direct assignment for basic types else: - parsed_fields[field_name] = field_info.type_(field_str) + if field_str: + try: + parsed_fields[field_name] = field_info.type_( + field_str + ) + except (ValueError, TypeError): + # Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes + parsed_fields[field_name] = parse_obj_as( + field_info.type_, field_str + ) return model(**parsed_fields) diff --git a/marie/serve/runtimes/worker/http_fastapi_app.py b/marie/serve/runtimes/worker/http_fastapi_app.py index 5ff41392..1561f213 100644 --- a/marie/serve/runtimes/worker/http_fastapi_app.py +++ b/marie/serve/runtimes/worker/http_fastapi_app.py @@ -99,16 +99,18 @@ async def post(body: input_model, response: Response): data = body.data if isinstance(data, list): if not docarray_v2: - req.data.docs = DocumentArray.from_pydantic_model(data) + req.direct_docs = DocumentArray.from_pydantic_model(data) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_list_model](data) + req.direct_docs = DocList[input_doc_list_model](data) else: if not docarray_v2: - req.data.docs = DocumentArray([Document.from_pydantic_model(data)]) + req.direct_docs = DocumentArray( + [Document.from_pydantic_model(data)] + ) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_list_model]([data]) + req.direct_docs = DocList[input_doc_list_model]([data]) if body.header is None: req.header.request_id = req.docs[0].id @@ -149,10 +151,10 @@ async def streaming_get(request: Request = None, body: input_doc_model = None): req = DataRequest() req.header.exec_endpoint = endpoint_path if not docarray_v2: - req.data.docs = DocumentArray([body]) + req.direct_docs = DocumentArray([body]) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_model]([body]) + req.direct_docs = DocList[input_doc_model]([body]) event_generator = _gen_dict_documents(await caller(req)) return EventSourceResponse(event_generator) diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py index 9f636533..b47f8e22 100644 --- a/marie/serve/runtimes/worker/request_handling.py +++ b/marie/serve/runtimes/worker/request_handling.py @@ -22,8 +22,6 @@ ) from google.protobuf.struct_pb2 import Struct -from IPython.terminal.shortcuts.auto_suggest import discard -from tvm.relay.backend.interpreter import Executor from marie._docarray import DocumentArray, docarray_v2 from marie.constants import __default_endpoint__ @@ -189,7 +187,9 @@ def call_handle(request): "is_generator" ] - return self.process_single_data(request, None, is_generator=is_generator) + return self.process_single_data( + request, None, http=True, is_generator=is_generator + ) app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs @@ -213,7 +213,9 @@ def call_handle(request): "is_generator" ] - return self.process_single_data(request, None, is_generator=is_generator) + return self.process_single_data( + request, None, http=True, is_generator=is_generator + ) app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs @@ -275,9 +277,24 @@ def _init_batchqueue_dict(self): if getattr(self._executor, "dynamic_batching", None) is not None: # We need to sort the keys into endpoints and functions # Endpoints allow specific configurations while functions allow configs to be applied to all endpoints of the function + self.logger.debug( + f"Executor Dynamic Batching configs: {self._executor.dynamic_batching}" + ) dbatch_endpoints = [] dbatch_functions = [] + request_models_map = self._executor._get_endpoint_models_dict() + for key, dbatch_config in self._executor.dynamic_batching.items(): + if ( + request_models_map.get(key, {}) + .get("parameters", {}) + .get("model", None) + is not None + ): + error_msg = f"Executor Dynamic Batching cannot be used for endpoint {key} because it depends on parameters." + self.logger.error(error_msg) + raise Exception(error_msg) + if key.startswith("/"): dbatch_endpoints.append((key, dbatch_config)) else: @@ -297,10 +314,21 @@ def _init_batchqueue_dict(self): for endpoint in func_endpoints[func_name]: if endpoint not in self._batchqueue_config: self._batchqueue_config[endpoint] = dbatch_config + else: + # we need to eventually copy the `custom_metric` + if dbatch_config.get("custom_metric", None) is not None: + self._batchqueue_config[endpoint]["custom_metric"] = ( + dbatch_config.get("custom_metric") + ) + + keys_to_remove = [] + for k, batch_config in self._batchqueue_config.items(): + if not batch_config.get("use_dynamic_batching", True): + keys_to_remove.append(k) + + for k in keys_to_remove: + self._batchqueue_config.pop(k) - self.logger.debug( - f"Executor Dynamic Batching configs: {self._executor.dynamic_batching}" - ) self.logger.debug( f"Endpoint Batch Queue Configs: {self._batchqueue_config}" ) @@ -404,6 +432,7 @@ def _load_executor( "metrics_registry": metrics_registry, "tracer_provider": tracer_provider, "meter_provider": meter_provider, + "allow_concurrent": self.args.allow_concurrent, }, py_modules=self.args.py_modules, extra_search_paths=self.args.extra_search_paths, @@ -553,7 +582,7 @@ def _record_response_size_monitoring(self, requests): requests[0].nbytes, attributes=attributes ) - def _set_result(self, requests, return_data, docs): + def _set_result(self, requests, return_data, docs, http=False): # assigning result back to request if return_data is not None: if isinstance(return_data, DocumentArray): @@ -574,10 +603,12 @@ def _set_result(self, requests, return_data, docs): f"The return type must be DocList / Dict / `None`, " f"but getting {return_data!r}" ) - - WorkerRequestHandler.replace_docs( - requests[0], docs, self.args.output_array_type - ) + if not http: + WorkerRequestHandler.replace_docs( + requests[0], docs, self.args.output_array_type + ) + else: + requests[0].direct_docs = docs return docs def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False): @@ -665,11 +696,15 @@ async def handle_generator( ) async def handle( - self, requests: List["DataRequest"], tracing_context: Optional["Context"] = None + self, + requests: List["DataRequest"], + http=False, + tracing_context: Optional["Context"] = None, ) -> DataRequest: """Initialize private parameters and execute private loading functions. :param requests: The messages to handle containing a DataRequest + :param http: Flag indicating if it is used by the HTTP server for some optims :param tracing_context: Optional OpenTelemetry tracing context from the originating request. :returns: the processed message """ @@ -716,7 +751,7 @@ async def handle( ) # This is necessary because push might need to await for the queue to be emptied queue = await self._batchqueue_instances[exec_endpoint][param_key].push( - requests[0] + requests[0], http=http ) item = await queue.get() queue.task_done() @@ -774,7 +809,7 @@ async def executor_completion_callback( executor_completion_callback, job_id, requests ), ) - _ = self._set_result(requests, return_data, docs) + _ = self._set_result(requests, return_data, docs, http=http) except asyncio.CancelledError: self.logger.warning("Task was cancelled due to client disconnect") client_disconnected = True @@ -968,18 +1003,25 @@ def reduce_requests(requests: List["DataRequest"]) -> "DataRequest": # serving part async def process_single_data( - self, request: DataRequest, context, is_generator: bool = False + self, + request: DataRequest, + context, + http: bool = False, + is_generator: bool = False, ) -> DataRequest: """ Process the received requests and return the result as a new request :param request: the data request to process :param context: grpc context + :param http: Flag indicating if it is used by the HTTP server for some optims :param is_generator: whether the request should be handled with streaming :returns: the response request """ self.logger.debug("recv a process_single_data request") - return await self.process_data([request], context, is_generator=is_generator) + return await self.process_data( + [request], context, http=http, is_generator=is_generator + ) async def stream_doc( self, request: SingleDocumentRequest, context: "grpc.aio.ServicerContext" @@ -1124,13 +1166,18 @@ def _extract_tracing_context( return None async def process_data( - self, requests: List[DataRequest], context, is_generator: bool = False + self, + requests: List[DataRequest], + context, + http=False, + is_generator: bool = False, ) -> DataRequest: """ Process the received requests and return the result as a new request :param requests: the data requests to process :param context: grpc context + :param http: Flag indicating if it is used by the HTTP server for some optims :param is_generator: whether the request should be handled with streaming :returns: the response request """ @@ -1157,7 +1204,7 @@ async def process_data( ) else: result = await self.handle( - requests=requests, tracing_context=tracing_context + requests=requests, http=http, tracing_context=tracing_context ) if self._successful_requests_metrics: @@ -1230,7 +1277,7 @@ async def stream( :param kwargs: keyword arguments :yield: responses to the request """ - self.logger.debug("recv a stream request from client") + self.logger.debug("recv a stream request") async for request in request_iterator: yield await self.process_data([request], context) diff --git a/marie_cli/autocomplete.py b/marie_cli/autocomplete.py index 79c2a33a..f4692741 100644 --- a/marie_cli/autocomplete.py +++ b/marie_cli/autocomplete.py @@ -46,26 +46,34 @@ "--disable-reduce", "--allow-concurrent", "--grpc-server-options", + '--raft-configuration', "--grpc-channel-options", "--entrypoint", "--docker-kwargs", "--volumes", "--gpus", "--disable-auto-volume", + '--force-network-mode', "--host", "--host-in", "--runtime-cls", "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", "--floating", + '--replica-id', "--reload", "--install-requirements", "--port", "--ports", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', "--monitoring", "--port-monitoring", "--retries", @@ -75,14 +83,16 @@ "--metrics", "--metrics-exporter-host", "--metrics-exporter-port", - "--force-update", - "--force", - "--prefer-platform", - "--compression", - "--uses-before-address", - "--uses-after-address", - "--connection-list", - "--timeout-send", + '--stateful', + '--peer-ports', + '--force-update', + '--force', + '--prefer-platform', + '--compression', + '--uses-before-address', + '--uses-after-address', + '--connection-list', + '--timeout-send', ], "flow": [ "--help", @@ -130,15 +140,13 @@ "--title", "--description", "--cors", - "--no-debug-endpoints", - "--no-crud-endpoints", - "--expose-endpoints", "--uvicorn-kwargs", "--ssl-certfile", "--ssl-keyfile", + '--no-debug-endpoints', + '--no-crud-endpoints', + '--expose-endpoints', "--expose-graphql-endpoint", - "--protocol", - "--protocols", "--host", "--host-in", "--proxy", @@ -164,24 +172,32 @@ "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", "--floating", + '--replica-id', "--reload", "--port", "--ports", "--port-expose", "--port-in", - "--monitoring", - "--port-monitoring", - "--retries", - "--tracing", - "--traces-exporter-host", - "--traces-exporter-port", - "--metrics", - "--metrics-exporter-host", - "--metrics-exporter-port", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', + '--monitoring', + '--port-monitoring', + '--retries', + '--tracing', + '--traces-exporter-host', + '--traces-exporter-port', + '--metrics', + '--metrics-exporter-host', + '--metrics-exporter-port', + '--stateful', + '--peer-ports', ], "auth login": ["--help", "--force"], "auth logout": ["--help"], @@ -265,43 +281,53 @@ "--disable-reduce", "--allow-concurrent", "--grpc-server-options", + '--raft-configuration', "--grpc-channel-options", "--entrypoint", "--docker-kwargs", "--volumes", "--gpus", "--disable-auto-volume", + '--force-network-mode', "--host", "--host-in", "--runtime-cls", "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", "--floating", + '--replica-id', "--reload", "--install-requirements", "--port", "--ports", - "--monitoring", - "--port-monitoring", - "--retries", - "--tracing", - "--traces-exporter-host", - "--traces-exporter-port", - "--metrics", - "--metrics-exporter-host", - "--metrics-exporter-port", - "--force-update", - "--force", - "--prefer-platform", - "--compression", - "--uses-before-address", - "--uses-after-address", - "--connection-list", - "--timeout-send", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', + '--monitoring', + '--port-monitoring', + '--retries', + '--tracing', + '--traces-exporter-host', + '--traces-exporter-port', + '--metrics', + '--metrics-exporter-host', + '--metrics-exporter-port', + '--stateful', + '--peer-ports', + '--force-update', + '--force', + '--prefer-platform', + '--compression', + '--uses-before-address', + '--uses-after-address', + '--connection-list', + '--timeout-send', ], "deployment": [ "--help", @@ -330,18 +356,21 @@ "--disable-reduce", "--allow-concurrent", "--grpc-server-options", + '--raft-configuration', "--grpc-channel-options", "--entrypoint", "--docker-kwargs", "--volumes", "--gpus", "--disable-auto-volume", + '--force-network-mode', "--host", "--host-in", "--runtime-cls", "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", @@ -350,6 +379,10 @@ "--install-requirements", "--port", "--ports", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', "--monitoring", "--port-monitoring", "--retries", @@ -359,6 +392,8 @@ "--metrics", "--metrics-exporter-host", "--metrics-exporter-port", + '--stateful', + '--peer-ports', "--force-update", "--force", "--prefer-platform", @@ -374,6 +409,12 @@ "--grpc-metadata", "--deployment-role", "--tls", + '--title', + '--description', + '--cors', + '--uvicorn-kwargs', + '--ssl-certfile', + '--ssl-keyfile', ], "client": [ "--help", @@ -390,6 +431,7 @@ "--metrics-exporter-host", "--metrics-exporter-port", "--log-config", + '--reuse-session', "--protocol", "--grpc-channel-options", "--prefetch",