diff --git a/extra-requirements.txt b/extra-requirements.txt index df3eadef..eb9fedfb 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -72,19 +72,19 @@ aiofiles: standard,devel aiohttp: standard,devel aiostream: standard,devel -pytest: test +pytest<8.0.0: test pytest-timeout: test pytest-mock: test pytest-cov==3.0.0: test coverage==6.2: test pytest-repeat: test -pytest-asyncio: test +pytest-asyncio<0.23.0: test pytest-reraise: test mock: test requests-mock: test pytest-custom_exit_code: test -black==22.3.0: test -kubernetes>=18.20.0: test +black==24.3.0: test +kubernetes>=18.20.0,<31.0.0: test pytest-kind==22.11.1: test pytest-lazy-fixture: test torch: cicd diff --git a/marie/clients/__init__.py b/marie/clients/__init__.py index 3ddb97a7..99e9608e 100644 --- a/marie/clients/__init__.py +++ b/marie/clients/__init__.py @@ -30,6 +30,7 @@ def Client( prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = 'GRPC', proxy: Optional[bool] = False, + reuse_session: Optional[bool] = False, suppress_root_logging: Optional[bool] = False, tls: Optional[bool] = False, traces_exporter_host: Optional[str] = None, @@ -59,6 +60,7 @@ def Client( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. @@ -113,6 +115,7 @@ def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[ Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. diff --git a/marie/clients/base/__init__.py b/marie/clients/base/__init__.py index 845a1deb..001ef52c 100644 --- a/marie/clients/base/__init__.py +++ b/marie/clients/base/__init__.py @@ -5,7 +5,15 @@ import inspect import os from abc import ABC -from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union +from typing import ( + TYPE_CHECKING, + AsyncIterator, + Callable, + Iterator, + Optional, + Tuple, + Union, +) from marie.excepts import BadClientInput from marie.helper import T, parse_client, send_telemetry_event, typename @@ -47,7 +55,6 @@ def __init__( # affect users os-level envs. os.unsetenv('http_proxy') os.unsetenv('https_proxy') - self._inputs = None self._setup_instrumentation( name=( self.args.name @@ -63,6 +70,12 @@ def __init__( ) send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__) + async def close(self): + """Closes the potential resources of the Client. + :return: Return whatever a close method may return + """ + return self.teardown_instrumentation() + def teardown_instrumentation(self): """Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for exporting metrics data is properly cleaned up. @@ -118,62 +131,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None: raise BadClientInput from ex def _get_requests( - self, **kwargs - ) -> Union[Iterator['Request'], AsyncIterator['Request']]: + self, inputs, **kwargs + ) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]: """ Get request in generator. + :param inputs: The inputs argument to get the requests from. :param kwargs: Keyword arguments. - :return: Iterator of request. + :return: Iterator of request and the length of the inputs. """ _kwargs = vars(self.args) - _kwargs['data'] = self.inputs + if hasattr(inputs, '__call__'): + inputs = inputs() + + _kwargs['data'] = inputs # override by the caller-specific kwargs _kwargs.update(kwargs) - if hasattr(self._inputs, '__len__'): - total_docs = len(self._inputs) + if hasattr(inputs, '__len__'): + total_docs = len(inputs) elif 'total_docs' in _kwargs: total_docs = _kwargs['total_docs'] else: total_docs = None - self._inputs_length = None - if total_docs: - self._inputs_length = max(1, total_docs / _kwargs['request_size']) + inputs_length = max(1, total_docs / _kwargs['request_size']) + else: + inputs_length = None - if inspect.isasyncgen(self.inputs): + if inspect.isasyncgen(inputs): from marie.clients.request.asyncio import request_generator - return request_generator(**_kwargs) + return request_generator(**_kwargs), inputs_length else: from marie.clients.request import request_generator - return request_generator(**_kwargs) - - @property - def inputs(self) -> 'InputType': - """ - An iterator of bytes, each element represents a Document's raw content. - - ``inputs`` defined in the protobuf - - :return: inputs - """ - return self._inputs - - @inputs.setter - def inputs(self, bytes_gen: 'InputType') -> None: - """ - Set the input data. - - :param bytes_gen: input type - """ - if hasattr(bytes_gen, '__call__'): - self._inputs = bytes_gen() - else: - self._inputs = bytes_gen + return request_generator(**_kwargs), inputs_length @abc.abstractmethod async def _get_results( diff --git a/marie/clients/base/grpc.py b/marie/clients/base/grpc.py index 5a3f1089..0c8c5a8f 100644 --- a/marie/clients/base/grpc.py +++ b/marie/clients/base/grpc.py @@ -90,8 +90,7 @@ async def _get_results( else grpc.Compression.NoCompression ) - self.inputs = inputs - req_iter = self._get_requests(**kwargs) + req_iter, inputs_length = self._get_requests(inputs=inputs, **kwargs) continue_on_error = self.continue_on_error # while loop with retries, check in which state the `iterator` remains after failure options = client_grpc_options( @@ -120,7 +119,7 @@ async def _get_results( self.logger.debug(f'connected to {self.args.host}:{self.args.port}') with ProgressBar( - total_length=self._inputs_length, disable=not self.show_progress + total_length=inputs_length, disable=not self.show_progress ) as p_bar: try: if stream: diff --git a/marie/clients/base/helper.py b/marie/clients/base/helper.py index 1a78ff97..eae40d00 100644 --- a/marie/clients/base/helper.py +++ b/marie/clients/base/helper.py @@ -48,18 +48,16 @@ class AioHttpClientlet(ABC): def __init__( self, - url: str, - logger: 'MarieLogger', + logger: "MarieLogger", max_attempts: int = 1, initial_backoff: float = 0.5, max_backoff: float = 2, backoff_multiplier: float = 1.5, - tracer_provider: Optional['trace.TraceProvider'] = None, + tracer_provider: Optional["trace.TraceProvider"] = None, **kwargs, ) -> None: """HTTP Client to be used with the streamer - :param url: url to send http/websocket request to :param logger: jina logger :param max_attempts: Number of sending attempts, including the original request. :param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff) @@ -68,7 +66,6 @@ def __init__( :param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing. :param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests """ - self.url = url self.logger = logger self.msg_recv = 0 self.msg_sent = 0 @@ -80,15 +77,15 @@ def __init__( self._trace_config = None self.session = None self._session_kwargs = {} - if kwargs.get('headers', None): - self._session_kwargs['headers'] = kwargs.get('headers') - if kwargs.get('auth', None): - self._session_kwargs['auth'] = kwargs.get('auth') - if kwargs.get('cookies', None): - self._session_kwargs['cookies'] = kwargs.get('cookies') - if kwargs.get('timeout', None): - timeout = aiohttp.ClientTimeout(total=kwargs.get('timeout')) - self._session_kwargs['timeout'] = timeout + if kwargs.get("headers", None): + self._session_kwargs["headers"] = kwargs.get("headers") + if kwargs.get("auth", None): + self._session_kwargs["auth"] = kwargs.get("auth") + if kwargs.get("cookies", None): + self._session_kwargs["cookies"] = kwargs.get("cookies") + if kwargs.get("timeout", None): + timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout")) + self._session_kwargs["timeout"] = timeout self.max_attempts = max_attempts self.initial_backoff = initial_backoff self.max_backoff = max_backoff @@ -154,33 +151,45 @@ class HTTPClientlet(AioHttpClientlet): UPDATE_EVENT_PREFIX = 14 # the update event has the following format: "event: update: {document_json}" - async def send_message(self, request: 'Request'): + async def send_message(self, url, request: "Request"): """Sends a POST request to the server + :param url: the URL where to send the message :param request: request as dict :return: send post message """ req_dict = request.to_dict() - req_dict['exec_endpoint'] = req_dict['header']['exec_endpoint'] - if 'target_executor' in req_dict['header']: - req_dict['target_executor'] = req_dict['header']['target_executor'] + req_dict["exec_endpoint"] = req_dict["header"]["exec_endpoint"] + if "target_executor" in req_dict["header"]: + req_dict["target_executor"] = req_dict["header"]["target_executor"] for attempt in range(1, self.max_attempts + 1): try: - request_kwargs = {'url': self.url} + request_kwargs = {"url": url} if not docarray_v2: - request_kwargs['json'] = req_dict + request_kwargs["json"] = req_dict else: from docarray.base_doc.io.json import orjson_dumps - request_kwargs['data'] = JinaJsonPayload(value=req_dict) - response = await self.session.post(**request_kwargs).__aenter__() - try: - r_str = await response.json() - except aiohttp.ContentTypeError: - r_str = await response.text() - handle_response_status(response.status, r_str, self.url) - return response - except (ValueError, ConnectionError, BadClient, aiohttp.ClientError) as err: + request_kwargs["data"] = JinaJsonPayload(value=req_dict) + + async with self.session.post(**request_kwargs) as response: + try: + r_str = await response.json() + except aiohttp.ContentTypeError: + r_str = await response.text() + r_status = response.status + handle_response_status(r_status, r_str, url) + return r_status, r_str + except ( + ValueError, + ConnectionError, + BadClient, + aiohttp.ClientError, + aiohttp.ClientConnectionError, + ) as err: + self.logger.debug( + f"Got an error of type {type(err)}: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}" + ) await retry.wait_or_raise_err( attempt=attempt, err=err, @@ -189,37 +198,44 @@ async def send_message(self, request: 'Request'): initial_backoff=self.initial_backoff, max_backoff=self.max_backoff, ) + except Exception as exc: + self.logger.debug( + f"Got a non-retried error of type {type(exc)}: {exc} sending POST to {url}" + ) + raise exc - async def send_streaming_message(self, doc: 'Document', on: str): + async def send_streaming_message(self, url, doc: "Document", on: str): """Sends a GET SSE request to the server + :param url: the URL where to send the message :param doc: Request Document :param on: Request endpoint :yields: responses """ req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict() request_kwargs = { - 'url': self.url, - 'headers': {'Accept': 'text/event-stream'}, - 'json': req_dict, + "url": url, + "headers": {"Accept": "text/event-stream"}, + "json": req_dict, } async with self.session.get(**request_kwargs) as response: async for chunk in response.content.iter_any(): - events = chunk.split(b'event: ')[1:] + events = chunk.split(b"event: ")[1:] for event in events: - if event.startswith(b'update'): + if event.startswith(b"update"): yield event[self.UPDATE_EVENT_PREFIX :].decode() - elif event.startswith(b'end'): + elif event.startswith(b"end"): pass - async def send_dry_run(self, **kwargs): + async def send_dry_run(self, url, **kwargs): """Query the dry_run endpoint from Gateway + :param url: the URL where to send the message :param kwargs: keyword arguments to make sure compatible API with other clients :return: send get message """ return await self.session.get( - url=self.url, timeout=kwargs.get('timeout', None) + url=url, timeout=kwargs.get("timeout", None) ).__aenter__() async def recv_message(self): @@ -261,12 +277,13 @@ async def __anext__(self): class WebsocketClientlet(AioHttpClientlet): """Websocket Client to be used with the streamer""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, url, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self.url = url self.websocket = None self.response_iter = None - async def send_message(self, request: 'Request'): + async def send_message(self, request: "Request"): """Send request in bytes to the server. :param request: request object @@ -292,9 +309,9 @@ async def send_dry_run(self, **kwargs): """ try: - return await self.websocket.send_bytes(b'') + return await self.websocket.send_bytes(b"") except ConnectionResetError: - self.logger.critical(f'server connection closed already!') + self.logger.critical(f"server connection closed already!") async def send_eoi(self): """To confirm end of iteration, we send `bytes(True)` to the server. @@ -308,7 +325,7 @@ async def send_eoi(self): # which raises a `ConnectionResetError`, this can be ignored. pass - async def recv_message(self) -> 'DataRequest': + async def recv_message(self) -> "DataRequest": """Receive messages in bytes from server and convert to `DataRequest` ..note:: @@ -376,18 +393,18 @@ def handle_response_status( :param url: request url string """ if http_status == status.HTTP_404_NOT_FOUND: - raise BadClient(f'no such endpoint {url}') + raise BadClient(f"no such endpoint {url}") elif ( http_status == status.HTTP_503_SERVICE_UNAVAILABLE or http_status == status.HTTP_504_GATEWAY_TIMEOUT ): if ( isinstance(response_content, dict) - and 'header' in response_content - and 'status' in response_content['header'] - and 'description' in response_content['header']['status'] + and "header" in response_content + and "status" in response_content["header"] + and "description" in response_content["header"]["status"] ): - raise ConnectionError(response_content['header']['status']['description']) + raise ConnectionError(response_content["header"]["status"]["description"]) else: raise ValueError(response_content) elif ( diff --git a/marie/clients/base/http.py b/marie/clients/base/http.py index b4c2b27f..a4b22c4b 100644 --- a/marie/clients/base/http.py +++ b/marie/clients/base/http.py @@ -23,8 +23,20 @@ class HTTPBaseClient(BaseClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._endpoints = [] + self.reuse_session = False + self._lock = AsyncExitStack() + self.iolet = None - async def _get_endpoints_from_openapi(self): + async def close(self): + """Closes the potential resources of the Client. + :return: Return whatever a close method may return + """ + ret = super().close() + if self.iolet is not None: + await self.iolet.__aexit__(None, None, None) + return ret + + async def _get_endpoints_from_openapi(self, **kwargs): def extract_paths_by_method(spec): paths_by_method = {} for path, methods in spec['paths'].items(): @@ -39,10 +51,15 @@ def extract_paths_by_method(spec): import aiohttp + session_kwargs = {} + if 'headers' in kwargs: + session_kwargs = {'headers': kwargs['headers']} + proto = 'https' if self.args.tls else 'http' target_url = f'{proto}://{self.args.host}:{self.args.port}/openapi.json' try: - async with aiohttp.ClientSession() as session: + + async with aiohttp.ClientSession(**session_kwargs) as session: async with session.get(target_url) as response: content = await response.read() openapi_response = json.loads(content.decode()) @@ -64,16 +81,27 @@ async def _is_flow_ready(self, **kwargs) -> bool: try: proto = 'https' if self.args.tls else 'http' url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' - iolet = await stack.enter_async_context( - HTTPClientlet( - url=url, - logger=self.logger, - tracer_provider=self.tracer_provider, - **kwargs, + + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) ) - ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) + await self.iolet.__aenter__() + iolet = self.iolet - response = await iolet.send_dry_run(**kwargs) + response = await iolet.send_dry_run(url=url, **kwargs) r_status = response.status r_str = await response.json() @@ -125,15 +153,14 @@ async def _get_results( with ImportExtensions(required=True): pass - self.inputs = inputs - request_iterator = self._get_requests(**kwargs) + request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs) on = kwargs.get('on', '/post') if len(self._endpoints) == 0: - await self._get_endpoints_from_openapi() + await self._get_endpoints_from_openapi(**kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( - total_length=self._inputs_length, disable=not self.show_progress + total_length=inputs_length, disable=not self.show_progress ) p_bar = stack.enter_context(cm1) proto = 'https' if self.args.tls else 'http' @@ -146,19 +173,35 @@ async def _get_results( url = f'{proto}://{self.args.host}:{self.args.port}/default' else: url = f'{proto}://{self.args.host}:{self.args.port}/post' - iolet = await stack.enter_async_context( - HTTPClientlet( - url=url, - logger=self.logger, - tracer_provider=self.tracer_provider, - max_attempts=max_attempts, - initial_backoff=initial_backoff, - max_backoff=max_backoff, - backoff_multiplier=backoff_multiplier, - timeout=timeout, - **kwargs, + + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + max_attempts=max_attempts, + initial_backoff=initial_backoff, + max_backoff=max_backoff, + backoff_multiplier=backoff_multiplier, + **kwargs, + ) ) - ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + max_attempts=max_attempts, + initial_backoff=initial_backoff, + max_backoff=max_backoff, + backoff_multiplier=backoff_multiplier, + **kwargs, + ) + self.iolet = await self.iolet.__aenter__() + iolet = self.iolet def _request_handler( request: 'Request', **kwargs @@ -170,7 +213,10 @@ def _request_handler( :param kwargs: kwargs :return: asyncio Task for sending message """ - return asyncio.ensure_future(iolet.send_message(request=request)), None + return ( + asyncio.ensure_future(iolet.send_message(url=url, request=request)), + None, + ) def _result_handler(result): return result @@ -187,9 +233,7 @@ def _result_handler(result): async for response in streamer.stream( request_iterator=request_iterator, results_in_order=results_in_order ): - r_status = response.status - - r_str = await response.json() + r_status, r_str = response handle_response_status(r_status, r_str, url) da = None @@ -213,7 +257,7 @@ def _result_handler(result): resp = DataRequest(r_str) if da is not None: - resp.data.docs = da + resp.direct_docs = da callback_exec( response=resp, @@ -244,17 +288,28 @@ async def _get_streaming_results( url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}' else: url = f'{proto}://{self.args.host}:{self.args.port}/default' - - iolet = HTTPClientlet( - url=url, - logger=self.logger, - tracer_provider=self.tracer_provider, - timeout=timeout, - **kwargs, - ) - - async with iolet: - async for doc in iolet.send_streaming_message(doc=inputs, on=on): + async with AsyncExitStack() as stack: + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + await self.iolet.__aenter__() + iolet = self.iolet + async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on): if not docarray_v2: yield Document.from_dict(json.loads(doc)) else: diff --git a/marie/clients/base/websocket.py b/marie/clients/base/websocket.py index a37d2c32..8b565a47 100644 --- a/marie/clients/base/websocket.py +++ b/marie/clients/base/websocket.py @@ -108,12 +108,11 @@ async def _get_results( with ImportExtensions(required=True): pass - self.inputs = inputs - request_iterator = self._get_requests(**kwargs) + request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( - total_length=self._inputs_length, disable=not (self.show_progress) + total_length=inputs_length, disable=not (self.show_progress) ) p_bar = stack.enter_context(cm1) diff --git a/marie/clients/http.py b/marie/clients/http.py index fb1fd92f..fa102bd6 100644 --- a/marie/clients/http.py +++ b/marie/clients/http.py @@ -1,3 +1,5 @@ +import asyncio + from marie.clients.base.http import HTTPBaseClient from marie.clients.mixin import ( AsyncHealthCheckMixin, @@ -81,3 +83,8 @@ async def async_inputs(): print(resp) """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._lock = asyncio.Lock() + self.reuse_session = self.args.reuse_session diff --git a/marie/clients/request/helper.py b/marie/clients/request/helper.py index 10de681b..92481f51 100644 --- a/marie/clients/request/helper.py +++ b/marie/clients/request/helper.py @@ -1,4 +1,5 @@ """Module for helper functions for clients.""" + from typing import Optional, Tuple from marie._docarray import Document, DocumentArray, docarray_v2 diff --git a/marie/orchestrate/flow/base.py b/marie/orchestrate/flow/base.py index ba8e2273..b2795ee5 100644 --- a/marie/orchestrate/flow/base.py +++ b/marie/orchestrate/flow/base.py @@ -142,6 +142,7 @@ def __init__( prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = 'GRPC', proxy: Optional[bool] = False, + reuse_session: Optional[bool] = False, suppress_root_logging: Optional[bool] = False, tls: Optional[bool] = False, traces_exporter_host: Optional[str] = None, @@ -164,6 +165,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. @@ -426,6 +428,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol between server and client. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy + :param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it :param suppress_root_logging: If set, then no root handlers will be suppressed from logging. :param tls: If set, connect to gateway using tls encryption :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent. @@ -475,7 +478,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE']. :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -1148,7 +1151,7 @@ def add( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE']. :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param py_modules: The customized python modules need to be imported before loading the executor @@ -1517,7 +1520,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE']. :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -2903,9 +2906,9 @@ def to_docker_compose_yaml( yaml.dump(docker_compose_dict, fp, sort_keys=False) command = ( - 'docker-compose up' + 'docker compose up' if output_path is None - else f'docker-compose -f {output_path} up' + else f'docker compose -f {output_path} up' ) self.logger.info( diff --git a/marie/parsers/client.py b/marie/parsers/client.py index 7b09eead..9bffb5cf 100644 --- a/marie/parsers/client.py +++ b/marie/parsers/client.py @@ -81,3 +81,9 @@ def mixin_client_features_parser(parser): default='default', help='The config name or the absolute path to the YAML config file of the logger used in this object.', ) + parser.add_argument( + '--reuse-session', + action='store_true', + default=False, + help='True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it', + ) diff --git a/marie/resources/logging.plain.yml b/marie/resources/logging.plain.yml new file mode 100644 index 00000000..747061d1 --- /dev/null +++ b/marie/resources/logging.plain.yml @@ -0,0 +1,7 @@ +handlers: + - StreamHandler +level: INFO +configs: + StreamHandler: + format: '{name:>15}@%(process)2d[%(levelname).1s]:%(message)s' + formatter: PlainFormatter diff --git a/marie/resources/project-template/deployment/client.py b/marie/resources/project-template/deployment/client.py index 897dc31b..ec1a4f2f 100644 --- a/marie/resources/project-template/deployment/client.py +++ b/marie/resources/project-template/deployment/client.py @@ -4,5 +4,7 @@ if __name__ == '__main__': c = Client(host='grpc://0.0.0.0:54321') - da = c.post('/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc]) + da = c.post( + '/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc] + ) print(da.text) diff --git a/marie/resources/project-template/flow/client.py b/marie/resources/project-template/flow/client.py index 897dc31b..ec1a4f2f 100644 --- a/marie/resources/project-template/flow/client.py +++ b/marie/resources/project-template/flow/client.py @@ -4,5 +4,7 @@ if __name__ == '__main__': c = Client(host='grpc://0.0.0.0:54321') - da = c.post('/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc]) + da = c.post( + '/', DocList[TextDoc]([TextDoc(), TextDoc()]), return_type=DocList[TextDoc] + ) print(da.text) diff --git a/marie/serve/executors/__init__.py b/marie/serve/executors/__init__.py index f3116483..cd4c9df1 100644 --- a/marie/serve/executors/__init__.py +++ b/marie/serve/executors/__init__.py @@ -658,9 +658,22 @@ def _validate_sagemaker(self): return def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]): + from collections.abc import Mapping + + def deep_update(source, overrides): + for key, value in overrides.items(): + if isinstance(value, Mapping) and value: + returned = deep_update(source.get(key, {}), value) + source[key] = returned + else: + source[key] = overrides[key] + return source + if _dynamic_batching: - self.dynamic_batching = getattr(self, "dynamic_batching", {}) - self.dynamic_batching.update(_dynamic_batching) + self.dynamic_batching = getattr(self, 'dynamic_batching', {}) + self.dynamic_batching = deep_update( + self.dynamic_batching, _dynamic_batching + ) def _add_metas(self, _metas: Optional[Dict]): from marie.serve.executors.metas import get_default_metas diff --git a/marie/serve/executors/decorators.py b/marie/serve/executors/decorators.py index 64896816..55c1be8a 100644 --- a/marie/serve/executors/decorators.py +++ b/marie/serve/executors/decorators.py @@ -1,11 +1,12 @@ """Decorators and wrappers designed for wrapping :class:`BaseExecutor` functions. """ + import functools import inspect import os from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union -from marie._docarray import Document, DocumentArray, docarray_v2 +from marie._docarray import Document, DocumentArray from marie.constants import __cache_path__ from marie.helper import is_generator, iscoroutinefunction from marie.importer import ImportExtensions @@ -415,6 +416,10 @@ def dynamic_batching( *, preferred_batch_size: Optional[int] = None, timeout: Optional[float] = 10_000, + flush_all: bool = False, + custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None, + use_custom_metric: bool = False, + use_dynamic_batching: bool = True, ): """ `@dynamic_batching` defines the dynamic batching behavior of an Executor. @@ -425,11 +430,16 @@ def dynamic_batching( :param func: the method to decorate :param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached, - or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`. + or until `timeout` is reached. Therefore, the actual batch size can be smaller or equal to `preferred_batch_size`, except if `flush_all` is set to True :param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch. If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor, even if it contains fewer than `preferred_batch_size` Documents. Default is 10_000ms (10 seconds). + :param flush_all: Determines if once the batches is triggered by timeout or preferred_batch_size, the function will receive everything that the batcher has accumulated or not. + If this is true, `preferred_batch_size` is used as a trigger mechanism. + :param custom_metric: Potential lambda function to measure the "weight" of each request. + :param use_custom_metric: Determines if we need to use the `custom_metric` to determine preferred_batch_size. + :param use_dynamic_batching: Determines if we should apply dynamic batching for this method. :return: decorated function """ @@ -475,6 +485,12 @@ def _inject_owner_attrs(self, owner, name): 'preferred_batch_size' ] = preferred_batch_size owner.dynamic_batching[fn_name]['timeout'] = timeout + owner.dynamic_batching[fn_name]['flush_all'] = flush_all + owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric + owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric + owner.dynamic_batching[fn_name][ + 'use_dynamic_batching' + ] = use_dynamic_batching setattr(owner, name, self.fn) def __set_name__(self, owner, name): diff --git a/marie/serve/runtimes/worker/batch_queue.py b/marie/serve/runtimes/worker/batch_queue.py index 8999a09c..357af64a 100644 --- a/marie/serve/runtimes/worker/batch_queue.py +++ b/marie/serve/runtimes/worker/batch_queue.py @@ -1,8 +1,9 @@ import asyncio +import copy from asyncio import Event, Task -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from marie._docarray import docarray_v2 +from jina._docarray import docarray_v2 if not docarray_v2: from docarray import DocumentArray @@ -25,10 +26,14 @@ def __init__( response_docarray_cls, output_array_type: Optional[str] = None, params: Optional[Dict] = None, + flush_all: bool = False, preferred_batch_size: int = 4, timeout: int = 10_000, + custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None, + use_custom_metric: bool = False, + **kwargs, ) -> None: - self._data_lock = asyncio.Lock() + # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent self.func = func if params is None: params = dict() @@ -37,7 +42,10 @@ def __init__( self.params = params self._request_docarray_cls = request_docarray_cls self._response_docarray_cls = response_docarray_cls + self._flush_all = flush_all self._preferred_batch_size: int = preferred_batch_size + self._custom_metric = None if not use_custom_metric else custom_metric + self._metric_value = 0 self._timeout: int = timeout self._reset() self._flush_trigger: Event = Event() @@ -56,13 +64,16 @@ def _reset(self) -> None: # a list of every request ID self._request_idxs: List[int] = [] self._request_lens: List[int] = [] + self._docs_metrics: List[int] = [] self._requests_completed: List[asyncio.Queue] = [] if not docarray_v2: self._big_doc: DocumentArray = DocumentArray.empty() else: self._big_doc = self._request_docarray_cls() + self._metric_value = 0 self._flush_task: Optional[Task] = None + self._flush_trigger: Event = Event() def _cancel_timer_if_pending(self): if ( @@ -84,13 +95,14 @@ async def _sleep_then_set(self): self._flush_trigger.set() self._timer_finished = True - async def push(self, request: DataRequest) -> asyncio.Queue: + async def push(self, request: DataRequest, http=False) -> asyncio.Queue: """Append request to the the list of requests to be processed. This method creates an asyncio Queue for that request and keeps track of it. It returns this queue to the caller so that the caller can now when this request has been processed :param request: The request to append to the queue. + :param http: Flag to determine if the request is served via HTTP for some optims :return: The queue that will receive when the request is processed. """ @@ -101,30 +113,37 @@ async def push(self, request: DataRequest) -> asyncio.Queue: # this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc` # before the `flush` task processes it. self._start_timer() - async with self._data_lock: - if not self._flush_task: - self._flush_task = asyncio.create_task(self._await_then_flush()) - - self._big_doc.extend(docs) - next_req_idx = len(self._requests) - num_docs = len(docs) - self._request_idxs.extend([next_req_idx] * num_docs) - self._request_lens.append(len(docs)) - self._requests.append(request) - queue = asyncio.Queue() - self._requests_completed.append(queue) - if len(self._big_doc) >= self._preferred_batch_size: - self._flush_trigger.set() + if not self._flush_task: + self._flush_task = asyncio.create_task(self._await_then_flush(http)) + self._big_doc.extend(docs) + next_req_idx = len(self._requests) + num_docs = len(docs) + metric_value = num_docs + if self._custom_metric is not None: + metrics = [self._custom_metric(doc) for doc in docs] + metric_value += sum(metrics) + self._docs_metrics.extend(metrics) + self._metric_value += metric_value + self._request_idxs.extend([next_req_idx] * num_docs) + self._request_lens.append(num_docs) + self._requests.append(request) + queue = asyncio.Queue() + self._requests_completed.append(queue) + if self._metric_value >= self._preferred_batch_size: + self._flush_trigger.set() return queue - async def _await_then_flush(self) -> None: - """Process all requests in the queue once flush_trigger event is set.""" + async def _await_then_flush(self, http=False) -> None: + """Process all requests in the queue once flush_trigger event is set. + :param http: Flag to determine if the request is served via HTTP for some optims + """ def _get_docs_groups_completed_request_indexes( non_assigned_docs, non_assigned_docs_reqs_idx, sum_from_previous_mini_batch_in_first_req_idx, + requests_lens_in_batch, ): """ This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to. @@ -133,6 +152,7 @@ def _get_docs_groups_completed_request_indexes( :param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result :param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed) :param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed. + :param requests_lens_in_batch: List of lens of documents for each request in the batch. :return: list of document groups and a list of request Idx to which each of these groups belong """ @@ -161,7 +181,7 @@ def _get_docs_groups_completed_request_indexes( if ( req_idx not in completed_req_idx and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx - == self._request_lens[req_idx] + == requests_lens_in_batch[req_idx] ): completed_req_idx.append(req_idx) request_bucket = non_assigned_docs[ @@ -175,6 +195,9 @@ async def _assign_results( non_assigned_docs, non_assigned_docs_reqs_idx, sum_from_previous_mini_batch_in_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, ): """ This method aims to assign to the corresponding request objects the resulting documents from the mini batches. @@ -184,6 +207,9 @@ async def _assign_results( :param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result :param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed) :param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed. + :param requests_lens_in_batch: List of lens of documents for each request in the batch. + :param requests_in_batch: List requests in batch + :param requests_completed_in_batch: List of queues for requests to be completed :return: amount of assigned documents so that some documents can come back in the next iteration """ @@ -194,113 +220,154 @@ async def _assign_results( non_assigned_docs, non_assigned_docs_reqs_idx, sum_from_previous_mini_batch_in_first_req_idx, + requests_lens_in_batch, ) num_assigned_docs = sum(len(group) for group in docs_grouped) for docs_group, request_idx in zip(docs_grouped, completed_req_idxs): - request = self._requests[request_idx] - request_completed = self._requests_completed[request_idx] - request.data.set_docs_convert_arrays( - docs_group, ndarray_type=self._output_array_type - ) + request = requests_in_batch[request_idx] + request_completed = requests_completed_in_batch[request_idx] + if http is False or self._output_array_type is not None: + request.direct_docs = None # batch queue will work in place, therefore result will need to read from data. + request.data.set_docs_convert_arrays( + docs_group, ndarray_type=self._output_array_type + ) + else: + request.direct_docs = docs_group await request_completed.put(None) return num_assigned_docs - def batch(iterable_1, iterable_2, n=1): - items = len(iterable_1) - for ndx in range(0, items, n): - yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ - ndx : min(ndx + n, items) - ] + def batch( + iterable_1, + iterable_2, + n: Optional[int] = 1, + iterable_metrics: Optional = None, + ): + if n is None: + yield iterable_1, iterable_2 + return + elif iterable_metrics is None: + items = len(iterable_1) + for ndx in range(0, items, n): + yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ + ndx : min(ndx + n, items) + ] + else: + batch_idx = 0 + batch_weight = 0 - await self._flush_trigger.wait() + for i, (item, weight) in enumerate(zip(iterable_1, iterable_metrics)): + batch_weight += weight + + if batch_weight >= n: + yield iterable_1[batch_idx : i + 1], iterable_2[ + batch_idx : i + 1 + ] + batch_idx = i + 1 + batch_weight = 0 + + # Yield any remaining items + if batch_weight > 0: + yield iterable_1[batch_idx : len(iterable_1)], iterable_2[ + batch_idx : len(iterable_1) + ] + await self._flush_trigger.wait() # writes to shared data between tasks need to be mutually exclusive - async with self._data_lock: - # At this moment, we have documents concatenated in self._big_doc corresponding to requests in - # self._requests with its lengths stored in self._requests_len. For each requests, there is a queue to - # communicate that the request has been processed properly. At this stage the data_lock is ours and - # therefore noone can add requests to this list. - self._flush_trigger: Event = Event() + big_doc_in_batch = copy.copy(self._big_doc) + requests_idxs_in_batch = copy.copy(self._request_idxs) + requests_lens_in_batch = copy.copy(self._request_lens) + docs_metrics_in_batch = copy.copy(self._docs_metrics) + requests_in_batch = copy.copy(self._requests) + requests_completed_in_batch = copy.copy(self._requests_completed) + + self._reset() + + # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in + # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to + # communicate that the request has been processed properly. + + if not docarray_v2: + non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() + else: + non_assigned_to_response_docs = self._response_docarray_cls() + + non_assigned_to_response_request_idxs = [] + sum_from_previous_first_req_idx = 0 + for docs_inner_batch, req_idxs in batch( + big_doc_in_batch, + requests_idxs_in_batch, + self._preferred_batch_size if not self._flush_all else None, + docs_metrics_in_batch if self._custom_metric is not None else None, + ): + involved_requests_min_indx = req_idxs[0] + involved_requests_max_indx = req_idxs[-1] + input_len_before_call: int = len(docs_inner_batch) + batch_res_docs = None try: - if not docarray_v2: - non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() - else: - non_assigned_to_response_docs = self._response_docarray_cls() - non_assigned_to_response_request_idxs = [] - sum_from_previous_first_req_idx = 0 - for docs_inner_batch, req_idxs in batch( - self._big_doc, self._request_idxs, self._preferred_batch_size + batch_res_docs = await self.func( + docs=docs_inner_batch, + parameters=self.params, + docs_matrix=None, # joining manually with batch queue is not supported right now + tracing_context=None, + ) + # Output validation + if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( + not docarray_v2 and isinstance(batch_res_docs, DocumentArray) ): - involved_requests_min_indx = req_idxs[0] - involved_requests_max_indx = req_idxs[-1] - input_len_before_call: int = len(docs_inner_batch) - batch_res_docs = None - try: - batch_res_docs = await self.func( - docs=docs_inner_batch, - parameters=self.params, - docs_matrix=None, # joining manually with batch queue is not supported right now - tracing_context=None, - ) - # Output validation - if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( - not docarray_v2 - and isinstance(batch_res_docs, DocumentArray) - ): - if not len(batch_res_docs) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' - ) - elif batch_res_docs is None: - if not len(docs_inner_batch) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' - ) - else: - array_name = ( - 'DocumentArray' if not docarray_v2 else 'DocList' - ) - raise TypeError( - f'The return type must be {array_name} / `None` when using dynamic batching, ' - f'but getting {batch_res_docs!r}' - ) - except Exception as exc: - # All the requests containing docs in this Exception should be raising it - for request_full in self._requests_completed[ - involved_requests_min_indx : involved_requests_max_indx + 1 - ]: - await request_full.put(exc) - else: - # We need to attribute the docs to their requests - non_assigned_to_response_docs.extend( - batch_res_docs or docs_inner_batch - ) - non_assigned_to_response_request_idxs.extend(req_idxs) - num_assigned_docs = await _assign_results( - non_assigned_to_response_docs, - non_assigned_to_response_request_idxs, - sum_from_previous_first_req_idx, - ) - - sum_from_previous_first_req_idx = ( - len(non_assigned_to_response_docs) - num_assigned_docs + if not len(batch_res_docs) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' ) - non_assigned_to_response_docs = non_assigned_to_response_docs[ - num_assigned_docs: - ] - non_assigned_to_response_request_idxs = ( - non_assigned_to_response_request_idxs[num_assigned_docs:] + elif batch_res_docs is None: + if not len(docs_inner_batch) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' ) - if len(non_assigned_to_response_request_idxs) > 0: - _ = await _assign_results( - non_assigned_to_response_docs, - non_assigned_to_response_request_idxs, - sum_from_previous_first_req_idx, + else: + array_name = 'DocumentArray' if not docarray_v2 else 'DocList' + raise TypeError( + f'The return type must be {array_name} / `None` when using dynamic batching, ' + f'but getting {batch_res_docs!r}' ) - finally: - self._reset() + except Exception as exc: + # All the requests containing docs in this Exception should be raising it + for request_full in requests_completed_in_batch[ + involved_requests_min_indx : involved_requests_max_indx + 1 + ]: + await request_full.put(exc) + else: + # We need to attribute the docs to their requests + non_assigned_to_response_docs.extend(batch_res_docs or docs_inner_batch) + non_assigned_to_response_request_idxs.extend(req_idxs) + num_assigned_docs = await _assign_results( + non_assigned_to_response_docs, + non_assigned_to_response_request_idxs, + sum_from_previous_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, + ) + + sum_from_previous_first_req_idx = ( + len(non_assigned_to_response_docs) - num_assigned_docs + ) + non_assigned_to_response_docs = non_assigned_to_response_docs[ + num_assigned_docs: + ] + non_assigned_to_response_request_idxs = ( + non_assigned_to_response_request_idxs[num_assigned_docs:] + ) + if len(non_assigned_to_response_request_idxs) > 0: + _ = await _assign_results( + non_assigned_to_response_docs, + non_assigned_to_response_request_idxs, + sum_from_previous_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, + ) async def close(self): """Closes the batch queue by flushing pending requests.""" diff --git a/marie/serve/runtimes/worker/http_csp_app.py b/marie/serve/runtimes/worker/http_csp_app.py index 6fc3743b..0746f415 100644 --- a/marie/serve/runtimes/worker/http_csp_app.py +++ b/marie/serve/runtimes/worker/http_csp_app.py @@ -191,7 +191,16 @@ def construct_model_from_line( parsed_fields[field_name] = parsed_list # Handle direct assignment for basic types else: - parsed_fields[field_name] = field_info.type_(field_str) + if field_str: + try: + parsed_fields[field_name] = field_info.type_( + field_str + ) + except (ValueError, TypeError): + # Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes + parsed_fields[field_name] = parse_obj_as( + field_info.type_, field_str + ) return model(**parsed_fields) diff --git a/marie/serve/runtimes/worker/http_fastapi_app.py b/marie/serve/runtimes/worker/http_fastapi_app.py index 5ff41392..1561f213 100644 --- a/marie/serve/runtimes/worker/http_fastapi_app.py +++ b/marie/serve/runtimes/worker/http_fastapi_app.py @@ -99,16 +99,18 @@ async def post(body: input_model, response: Response): data = body.data if isinstance(data, list): if not docarray_v2: - req.data.docs = DocumentArray.from_pydantic_model(data) + req.direct_docs = DocumentArray.from_pydantic_model(data) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_list_model](data) + req.direct_docs = DocList[input_doc_list_model](data) else: if not docarray_v2: - req.data.docs = DocumentArray([Document.from_pydantic_model(data)]) + req.direct_docs = DocumentArray( + [Document.from_pydantic_model(data)] + ) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_list_model]([data]) + req.direct_docs = DocList[input_doc_list_model]([data]) if body.header is None: req.header.request_id = req.docs[0].id @@ -149,10 +151,10 @@ async def streaming_get(request: Request = None, body: input_doc_model = None): req = DataRequest() req.header.exec_endpoint = endpoint_path if not docarray_v2: - req.data.docs = DocumentArray([body]) + req.direct_docs = DocumentArray([body]) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_model]([body]) + req.direct_docs = DocList[input_doc_model]([body]) event_generator = _gen_dict_documents(await caller(req)) return EventSourceResponse(event_generator) diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py index 9f636533..b47f8e22 100644 --- a/marie/serve/runtimes/worker/request_handling.py +++ b/marie/serve/runtimes/worker/request_handling.py @@ -22,8 +22,6 @@ ) from google.protobuf.struct_pb2 import Struct -from IPython.terminal.shortcuts.auto_suggest import discard -from tvm.relay.backend.interpreter import Executor from marie._docarray import DocumentArray, docarray_v2 from marie.constants import __default_endpoint__ @@ -189,7 +187,9 @@ def call_handle(request): "is_generator" ] - return self.process_single_data(request, None, is_generator=is_generator) + return self.process_single_data( + request, None, http=True, is_generator=is_generator + ) app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs @@ -213,7 +213,9 @@ def call_handle(request): "is_generator" ] - return self.process_single_data(request, None, is_generator=is_generator) + return self.process_single_data( + request, None, http=True, is_generator=is_generator + ) app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs @@ -275,9 +277,24 @@ def _init_batchqueue_dict(self): if getattr(self._executor, "dynamic_batching", None) is not None: # We need to sort the keys into endpoints and functions # Endpoints allow specific configurations while functions allow configs to be applied to all endpoints of the function + self.logger.debug( + f"Executor Dynamic Batching configs: {self._executor.dynamic_batching}" + ) dbatch_endpoints = [] dbatch_functions = [] + request_models_map = self._executor._get_endpoint_models_dict() + for key, dbatch_config in self._executor.dynamic_batching.items(): + if ( + request_models_map.get(key, {}) + .get("parameters", {}) + .get("model", None) + is not None + ): + error_msg = f"Executor Dynamic Batching cannot be used for endpoint {key} because it depends on parameters." + self.logger.error(error_msg) + raise Exception(error_msg) + if key.startswith("/"): dbatch_endpoints.append((key, dbatch_config)) else: @@ -297,10 +314,21 @@ def _init_batchqueue_dict(self): for endpoint in func_endpoints[func_name]: if endpoint not in self._batchqueue_config: self._batchqueue_config[endpoint] = dbatch_config + else: + # we need to eventually copy the `custom_metric` + if dbatch_config.get("custom_metric", None) is not None: + self._batchqueue_config[endpoint]["custom_metric"] = ( + dbatch_config.get("custom_metric") + ) + + keys_to_remove = [] + for k, batch_config in self._batchqueue_config.items(): + if not batch_config.get("use_dynamic_batching", True): + keys_to_remove.append(k) + + for k in keys_to_remove: + self._batchqueue_config.pop(k) - self.logger.debug( - f"Executor Dynamic Batching configs: {self._executor.dynamic_batching}" - ) self.logger.debug( f"Endpoint Batch Queue Configs: {self._batchqueue_config}" ) @@ -404,6 +432,7 @@ def _load_executor( "metrics_registry": metrics_registry, "tracer_provider": tracer_provider, "meter_provider": meter_provider, + "allow_concurrent": self.args.allow_concurrent, }, py_modules=self.args.py_modules, extra_search_paths=self.args.extra_search_paths, @@ -553,7 +582,7 @@ def _record_response_size_monitoring(self, requests): requests[0].nbytes, attributes=attributes ) - def _set_result(self, requests, return_data, docs): + def _set_result(self, requests, return_data, docs, http=False): # assigning result back to request if return_data is not None: if isinstance(return_data, DocumentArray): @@ -574,10 +603,12 @@ def _set_result(self, requests, return_data, docs): f"The return type must be DocList / Dict / `None`, " f"but getting {return_data!r}" ) - - WorkerRequestHandler.replace_docs( - requests[0], docs, self.args.output_array_type - ) + if not http: + WorkerRequestHandler.replace_docs( + requests[0], docs, self.args.output_array_type + ) + else: + requests[0].direct_docs = docs return docs def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False): @@ -665,11 +696,15 @@ async def handle_generator( ) async def handle( - self, requests: List["DataRequest"], tracing_context: Optional["Context"] = None + self, + requests: List["DataRequest"], + http=False, + tracing_context: Optional["Context"] = None, ) -> DataRequest: """Initialize private parameters and execute private loading functions. :param requests: The messages to handle containing a DataRequest + :param http: Flag indicating if it is used by the HTTP server for some optims :param tracing_context: Optional OpenTelemetry tracing context from the originating request. :returns: the processed message """ @@ -716,7 +751,7 @@ async def handle( ) # This is necessary because push might need to await for the queue to be emptied queue = await self._batchqueue_instances[exec_endpoint][param_key].push( - requests[0] + requests[0], http=http ) item = await queue.get() queue.task_done() @@ -774,7 +809,7 @@ async def executor_completion_callback( executor_completion_callback, job_id, requests ), ) - _ = self._set_result(requests, return_data, docs) + _ = self._set_result(requests, return_data, docs, http=http) except asyncio.CancelledError: self.logger.warning("Task was cancelled due to client disconnect") client_disconnected = True @@ -968,18 +1003,25 @@ def reduce_requests(requests: List["DataRequest"]) -> "DataRequest": # serving part async def process_single_data( - self, request: DataRequest, context, is_generator: bool = False + self, + request: DataRequest, + context, + http: bool = False, + is_generator: bool = False, ) -> DataRequest: """ Process the received requests and return the result as a new request :param request: the data request to process :param context: grpc context + :param http: Flag indicating if it is used by the HTTP server for some optims :param is_generator: whether the request should be handled with streaming :returns: the response request """ self.logger.debug("recv a process_single_data request") - return await self.process_data([request], context, is_generator=is_generator) + return await self.process_data( + [request], context, http=http, is_generator=is_generator + ) async def stream_doc( self, request: SingleDocumentRequest, context: "grpc.aio.ServicerContext" @@ -1124,13 +1166,18 @@ def _extract_tracing_context( return None async def process_data( - self, requests: List[DataRequest], context, is_generator: bool = False + self, + requests: List[DataRequest], + context, + http=False, + is_generator: bool = False, ) -> DataRequest: """ Process the received requests and return the result as a new request :param requests: the data requests to process :param context: grpc context + :param http: Flag indicating if it is used by the HTTP server for some optims :param is_generator: whether the request should be handled with streaming :returns: the response request """ @@ -1157,7 +1204,7 @@ async def process_data( ) else: result = await self.handle( - requests=requests, tracing_context=tracing_context + requests=requests, http=http, tracing_context=tracing_context ) if self._successful_requests_metrics: @@ -1230,7 +1277,7 @@ async def stream( :param kwargs: keyword arguments :yield: responses to the request """ - self.logger.debug("recv a stream request from client") + self.logger.debug("recv a stream request") async for request in request_iterator: yield await self.process_data([request], context) diff --git a/marie_cli/autocomplete.py b/marie_cli/autocomplete.py index 79c2a33a..f4692741 100644 --- a/marie_cli/autocomplete.py +++ b/marie_cli/autocomplete.py @@ -46,26 +46,34 @@ "--disable-reduce", "--allow-concurrent", "--grpc-server-options", + '--raft-configuration', "--grpc-channel-options", "--entrypoint", "--docker-kwargs", "--volumes", "--gpus", "--disable-auto-volume", + '--force-network-mode', "--host", "--host-in", "--runtime-cls", "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", "--floating", + '--replica-id', "--reload", "--install-requirements", "--port", "--ports", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', "--monitoring", "--port-monitoring", "--retries", @@ -75,14 +83,16 @@ "--metrics", "--metrics-exporter-host", "--metrics-exporter-port", - "--force-update", - "--force", - "--prefer-platform", - "--compression", - "--uses-before-address", - "--uses-after-address", - "--connection-list", - "--timeout-send", + '--stateful', + '--peer-ports', + '--force-update', + '--force', + '--prefer-platform', + '--compression', + '--uses-before-address', + '--uses-after-address', + '--connection-list', + '--timeout-send', ], "flow": [ "--help", @@ -130,15 +140,13 @@ "--title", "--description", "--cors", - "--no-debug-endpoints", - "--no-crud-endpoints", - "--expose-endpoints", "--uvicorn-kwargs", "--ssl-certfile", "--ssl-keyfile", + '--no-debug-endpoints', + '--no-crud-endpoints', + '--expose-endpoints', "--expose-graphql-endpoint", - "--protocol", - "--protocols", "--host", "--host-in", "--proxy", @@ -164,24 +172,32 @@ "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", "--floating", + '--replica-id', "--reload", "--port", "--ports", "--port-expose", "--port-in", - "--monitoring", - "--port-monitoring", - "--retries", - "--tracing", - "--traces-exporter-host", - "--traces-exporter-port", - "--metrics", - "--metrics-exporter-host", - "--metrics-exporter-port", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', + '--monitoring', + '--port-monitoring', + '--retries', + '--tracing', + '--traces-exporter-host', + '--traces-exporter-port', + '--metrics', + '--metrics-exporter-host', + '--metrics-exporter-port', + '--stateful', + '--peer-ports', ], "auth login": ["--help", "--force"], "auth logout": ["--help"], @@ -265,43 +281,53 @@ "--disable-reduce", "--allow-concurrent", "--grpc-server-options", + '--raft-configuration', "--grpc-channel-options", "--entrypoint", "--docker-kwargs", "--volumes", "--gpus", "--disable-auto-volume", + '--force-network-mode', "--host", "--host-in", "--runtime-cls", "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", "--floating", + '--replica-id', "--reload", "--install-requirements", "--port", "--ports", - "--monitoring", - "--port-monitoring", - "--retries", - "--tracing", - "--traces-exporter-host", - "--traces-exporter-port", - "--metrics", - "--metrics-exporter-host", - "--metrics-exporter-port", - "--force-update", - "--force", - "--prefer-platform", - "--compression", - "--uses-before-address", - "--uses-after-address", - "--connection-list", - "--timeout-send", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', + '--monitoring', + '--port-monitoring', + '--retries', + '--tracing', + '--traces-exporter-host', + '--traces-exporter-port', + '--metrics', + '--metrics-exporter-host', + '--metrics-exporter-port', + '--stateful', + '--peer-ports', + '--force-update', + '--force', + '--prefer-platform', + '--compression', + '--uses-before-address', + '--uses-after-address', + '--connection-list', + '--timeout-send', ], "deployment": [ "--help", @@ -330,18 +356,21 @@ "--disable-reduce", "--allow-concurrent", "--grpc-server-options", + '--raft-configuration', "--grpc-channel-options", "--entrypoint", "--docker-kwargs", "--volumes", "--gpus", "--disable-auto-volume", + '--force-network-mode', "--host", "--host-in", "--runtime-cls", "--timeout-ready", "--env", "--env-from-secret", + '--image-pull-secrets', "--shard-id", "--pod-role", "--noblock-on-start", @@ -350,6 +379,10 @@ "--install-requirements", "--port", "--ports", + '--protocol', + '--protocols', + '--provider', + '--provider-endpoint', "--monitoring", "--port-monitoring", "--retries", @@ -359,6 +392,8 @@ "--metrics", "--metrics-exporter-host", "--metrics-exporter-port", + '--stateful', + '--peer-ports', "--force-update", "--force", "--prefer-platform", @@ -374,6 +409,12 @@ "--grpc-metadata", "--deployment-role", "--tls", + '--title', + '--description', + '--cors', + '--uvicorn-kwargs', + '--ssl-certfile', + '--ssl-keyfile', ], "client": [ "--help", @@ -390,6 +431,7 @@ "--metrics-exporter-host", "--metrics-exporter-port", "--log-config", + '--reuse-session', "--protocol", "--grpc-channel-options", "--prefetch",