Merge pull request #123 from marieai/develop
Develop
gregbugaj authored Oct 2, 2024
2 parents deb20c9 + f125276 commit 088a6fb
Showing 54 changed files with 2,051 additions and 784 deletions.
8 changes: 4 additions & 4 deletions extra-requirements.txt
@@ -72,19 +72,19 @@ aiofiles: standard,devel
aiohttp: standard,devel
aiostream: standard,devel

pytest: test
pytest<8.0.0: test
pytest-timeout: test
pytest-mock: test
pytest-cov==3.0.0: test
coverage==6.2: test
pytest-repeat: test
pytest-asyncio: test
pytest-asyncio<0.23.0: test
pytest-reraise: test
mock: test
requests-mock: test
pytest-custom_exit_code: test
black==22.3.0: test
kubernetes>=18.20.0: test
black==24.3.0: test
kubernetes>=18.20.0,<31.0.0: test
pytest-kind==22.11.1: test
pytest-lazy-fixture: test
torch: cicd
3 changes: 3 additions & 0 deletions marie/clients/__init__.py
@@ -30,6 +30,7 @@ def Client(
prefetch: Optional[int] = 1000,
protocol: Optional[Union[str, List[str]]] = 'GRPC',
proxy: Optional[bool] = False,
reuse_session: Optional[bool] = False,
suppress_root_logging: Optional[bool] = False,
tls: Optional[bool] = False,
traces_exporter_host: Optional[str] = None,
@@ -59,6 +60,7 @@ def Client(
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol between server and client.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param reuse_session: True if the HTTPClient should reuse its ClientSession. If true, the user is responsible for closing it
:param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
:param tls: If set, connect to gateway using tls encryption
:param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
@@ -113,6 +115,7 @@ def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol between server and client.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param reuse_session: True if the HTTPClient should reuse its ClientSession. If true, the user is responsible for closing it
:param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
:param tls: If set, connect to gateway using tls encryption
:param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
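Since the new reuse_session flag shifts ownership of the HTTP session to the caller, a usage sketch may help. This is a minimal sketch assuming the HTTP protocol client honors the flag as the docstring above describes and that Client is re-exported from the top-level marie package; the host, port, and call pattern shown are hypothetical.

```python
# Sketch only (assumed API): with reuse_session=True the HTTPClient keeps one
# aiohttp ClientSession open across requests, and the caller must close it.
import asyncio

from marie import Client  # assumes the factory is re-exported at package level

client = Client(
    host='localhost',       # hypothetical gateway host
    port=52000,             # hypothetical gateway port
    protocol='HTTP',
    reuse_session=True,     # reuse one ClientSession instead of one per call
)

try:
    # ... issue requests here, e.g. client.post(on='/', inputs=...) ...
    pass
finally:
    # The caller owns the shared session, so release it explicitly.
    asyncio.run(client.close())
```

Whether the synchronous client wraps close() itself is not visible in this diff, so the explicit asyncio.run(client.close()) above is a conservative assumption.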
68 changes: 31 additions & 37 deletions marie/clients/base/__init__.py
@@ -5,7 +5,15 @@
import inspect
import os
from abc import ABC
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union
from typing import (
TYPE_CHECKING,
AsyncIterator,
Callable,
Iterator,
Optional,
Tuple,
Union,
)

from marie.excepts import BadClientInput
from marie.helper import T, parse_client, send_telemetry_event, typename
@@ -47,7 +55,6 @@ def __init__(
# affect users os-level envs.
os.unsetenv('http_proxy')
os.unsetenv('https_proxy')
self._inputs = None
self._setup_instrumentation(
name=(
self.args.name
@@ -63,6 +70,12 @@ def __init__(
)
send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__)

async def close(self):
"""Closes the potential resources of the Client.
:return: Return whatever a close method may return
"""
return self.teardown_instrumentation()

def teardown_instrumentation(self):
"""Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for
exporting metrics data are properly cleaned up.
@@ -118,62 +131,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
raise BadClientInput from ex

def _get_requests(
self, **kwargs
) -> Union[Iterator['Request'], AsyncIterator['Request']]:
self, inputs, **kwargs
) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]:
"""
Get request in generator.
:param inputs: The inputs argument to get the requests from.
:param kwargs: Keyword arguments.
:return: Iterator of request.
:return: Iterator of request and the length of the inputs.
"""
_kwargs = vars(self.args)
_kwargs['data'] = self.inputs
if hasattr(inputs, '__call__'):
inputs = inputs()

_kwargs['data'] = inputs
# override by the caller-specific kwargs
_kwargs.update(kwargs)

if hasattr(self._inputs, '__len__'):
total_docs = len(self._inputs)
if hasattr(inputs, '__len__'):
total_docs = len(inputs)
elif 'total_docs' in _kwargs:
total_docs = _kwargs['total_docs']
else:
total_docs = None

self._inputs_length = None

if total_docs:
self._inputs_length = max(1, total_docs / _kwargs['request_size'])
inputs_length = max(1, total_docs / _kwargs['request_size'])
else:
inputs_length = None

if inspect.isasyncgen(self.inputs):
if inspect.isasyncgen(inputs):
from marie.clients.request.asyncio import request_generator

return request_generator(**_kwargs)
return request_generator(**_kwargs), inputs_length
else:
from marie.clients.request import request_generator

return request_generator(**_kwargs)

@property
def inputs(self) -> 'InputType':
"""
An iterator of bytes, each element represents a Document's raw content.
``inputs`` defined in the protobuf
:return: inputs
"""
return self._inputs

@inputs.setter
def inputs(self, bytes_gen: 'InputType') -> None:
"""
Set the input data.
:param bytes_gen: input type
"""
if hasattr(bytes_gen, '__call__'):
self._inputs = bytes_gen()
else:
self._inputs = bytes_gen
return request_generator(**_kwargs), inputs_length

@abc.abstractmethod
async def _get_results(
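The refactor above drops the inputs property/setter and the _inputs / _inputs_length instance state: _get_requests now receives the inputs directly and returns the request iterator together with an estimated input length. The standalone sketch below mirrors that contract without importing marie internals; the batching logic, batch size, and sample data are illustrative only.

```python
# Standalone illustration of the new, stateless _get_requests contract:
# take the inputs explicitly (calling them if they are callable), batch them
# into requests of `request_size`, and return (request_iterator, inputs_length).
from typing import Any, Iterable, Iterator, Optional, Tuple


def get_requests(
    inputs: Any, request_size: int = 100
) -> Tuple[Iterator[list], Optional[float]]:
    if callable(inputs):  # mirrors `if hasattr(inputs, '__call__')` in the diff
        inputs = inputs()

    total_docs = len(inputs) if hasattr(inputs, '__len__') else None

    # Estimated number of requests; only used to size a progress bar, so the
    # float result of the division (as in the diff) is left as-is.
    inputs_length = max(1, total_docs / request_size) if total_docs else None

    def request_generator(data: Iterable) -> Iterator[list]:
        batch: list = []
        for doc in data:
            batch.append(doc)
            if len(batch) == request_size:
                yield batch
                batch = []
        if batch:
            yield batch

    return request_generator(inputs), inputs_length


# The caller gets the iterator and the length in one call; nothing is stored
# on a client instance in between.
req_iter, inputs_length = get_requests(list(range(250)), request_size=100)
print(inputs_length)              # 2.5
print(sum(1 for _ in req_iter))   # 3 requests
```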
5 changes: 2 additions & 3 deletions marie/clients/base/grpc.py
@@ -90,8 +90,7 @@ async def _get_results(
else grpc.Compression.NoCompression
)

self.inputs = inputs
req_iter = self._get_requests(**kwargs)
req_iter, inputs_length = self._get_requests(inputs=inputs, **kwargs)
continue_on_error = self.continue_on_error
# while loop with retries, check in which state the `iterator` remains after failure
options = client_grpc_options(
@@ -120,7 +119,7 @@
self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

with ProgressBar(
total_length=self._inputs_length, disable=not self.show_progress
total_length=inputs_length, disable=not self.show_progress
) as p_bar:
try:
if stream:
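On the caller side, grpc.py now unpacks the (iterator, length) tuple and hands the estimate straight to the progress bar instead of reading self._inputs_length afterwards. The toy below reproduces only that shape; the async generator and counter-based output stand in for the real gRPC stream and marie's ProgressBar.

```python
# Toy stand-in for the updated _get_results flow in grpc.py: the request
# iterator and the estimated total travel together into the streaming loop.
import asyncio
from typing import AsyncIterator, Iterator, Optional


async def fake_grpc_stream(req_iter: Iterator[str]) -> AsyncIterator[str]:
    for req in req_iter:
        await asyncio.sleep(0)          # pretend to await a gRPC response
        yield f'response for {req}'


async def get_results(req_iter: Iterator[str], inputs_length: Optional[float]) -> None:
    done = 0
    total = inputs_length if inputs_length is not None else '?'
    async for resp in fake_grpc_stream(req_iter):
        done += 1
        print(f'[{done}/{total}] {resp}')  # stand-in for ProgressBar(total_length=...)


# In the real client these two values come from
# self._get_requests(inputs=inputs, **kwargs).
asyncio.run(get_results(iter(['req-1', 'req-2']), inputs_length=2))
```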
(Diffs for the remaining 50 changed files are not shown.)
