Skip to content

Commit

Permalink
feat: 优化 show_total_latency 为 True 时的表现 (#812)
Browse files Browse the repository at this point in the history
* 优化 show_total_latency 的表现

* 格式化

* 修复压测无法正常工作的问题

* 移除线程池

* 线程池和独立线程读取并行
  • Loading branch information
Dobiichi-Origami authored Sep 29, 2024
1 parent daec965 commit 153936d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 22 deletions.
6 changes: 6 additions & 0 deletions docs/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ async for r in resp:
> ```python
> qianfan.ChatCompletion().do(messages=[{"role":"user", "content":"hi"}], show_total_latency=True)
> ```
> 在开启 show_total_latency 时,SDK 会在每次请求时启动一个单独的后台线程进行输出的预读取。在极端场景下可能会因无法创建线程而异常抛出。
>
> 对于这种情况,用户在使用时,可以在创建请求对象时设置 `sync_reading_thread_count=最大线程数` 来避免此问题:
> ```python
> qianfan.ChatCompletion(sync_reading_thread_count=500).do(messages=[{"role":"user", "content":"hi"}], show_total_latency=True)
> ```
以下是一个获取流式请求包间延迟的例子:
```python
Expand Down
83 changes: 61 additions & 22 deletions python/qianfan/resources/requestor/openapi_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
"""
Qianfan API Requestor
"""
import asyncio
import copy
import json
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from queue import Queue
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -66,6 +71,14 @@ def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._token_limiter = TokenLimiter(**kwargs)
self._async_token_limiter = AsyncTokenLimiter(**kwargs)
self._sync_reading_thread_count: Optional[int] = kwargs.get(
"sync_reading_thread_count", None
)

if self._sync_reading_thread_count is not None:
self._sync_reading_thread_pool: ThreadPoolExecutor = ThreadPoolExecutor(
max_workers=self._sync_reading_thread_count
)

def _retry_if_token_expired(self, func: Callable[..., _T]) -> Callable[..., _T]:
"""
Expand Down Expand Up @@ -471,9 +484,34 @@ def _generator_wrapper(

yield res

def _list_generator(data: List) -> Any:
for res in data:
yield res
def _list_generator(generator: Iterator) -> Any:
    """Pre-read the whole upstream `generator` on a background worker.

    Items are handed over through a thread-safe queue so total latency
    can be measured while the consumer iterates at its own pace.
    Yields each `QfResponse` in order; re-raises any exception the
    upstream generator raised once the queue has been drained.
    """
    data: Queue[QfResponse] = Queue()
    # Exception raised by the upstream generator, captured on the worker
    # and re-raised to the consumer after draining.  Previously a worker
    # failure was silently swallowed (pool case: the Future's exception
    # was never retrieved; thread case: the daemon thread died), which
    # silently truncated the stream.
    worker_error: List[BaseException] = []

    def _inner_worker() -> None:
        # Runs on the pool / dedicated thread: drain upstream into `data`.
        try:
            for res in generator:
                data.put(res)
        except BaseException as e:  # re-raised on the consumer side below
            worker_error.append(e)

    if self._sync_reading_thread_count is not None:
        # User capped reader threads: reuse the shared pool.
        task = self._sync_reading_thread_pool.submit(_inner_worker)

        def is_alive() -> bool:
            return not task.done()

    else:
        # No cap configured: one dedicated daemon thread per request.
        t = threading.Thread(target=_inner_worker, daemon=True)
        t.start()

        def is_alive() -> bool:
            return t.is_alive()

    # Poll while the worker is running, then drain whatever is left.
    # The timeout keeps us from blocking forever if the worker finishes
    # between the `is_alive()` check and the `get()` call.
    while is_alive() or not data.empty():
        try:
            yield data.get(timeout=0.5)
        except queue.Empty:
            continue

    if self._sync_reading_thread_count is None:
        t.join()

    if worker_error:
        raise worker_error[0]

if stream:
generator = self._compensate_token_usage_stream(
Expand All @@ -484,11 +522,7 @@ def _list_generator(data: List) -> Any:
if not show_total_latency:
return _generator_wrapper(generator)
else:
result_list: List[QfResponse] = []
for res in generator:
result_list.append(res)

return _list_generator(result_list)
return _list_generator(generator)
else:
return self._compensate_token_usage_non_stream(
self._request(
Expand Down Expand Up @@ -522,20 +556,29 @@ async def async_llm(
m.pop("tool_call_id", None)

class AsyncListIterator:
    """Async iterator that eagerly pre-reads `data` via a background task.

    A producer task drains the upstream async iterator into an
    `asyncio.Queue` as fast as possible (so total latency can be
    measured) while the consumer pulls items out through the
    async-iterator protocol.  A sentinel marks end-of-stream, and any
    exception raised by the upstream iterator is captured and re-raised
    to the consumer.

    This fixes two bugs in the previous polling implementation:
    the exit condition dropped items still queued when the producer
    finished first, and a producer failure left `is_closed` unset so
    `__anext__` polled forever.
    """

    def __init__(self, data: "AsyncIterator[QfResponse]"):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.is_closed = False
        self._error: Optional[BaseException] = None
        # Unique object: can never collide with a real response item.
        self._sentinel: Any = object()

        async def _inner_worker() -> None:
            try:
                async for res in data:
                    await self.queue.put(res)
            except BaseException as e:
                # Remember the failure so __anext__ can surface it.
                self._error = e
            finally:
                self.is_closed = True
                # Sentinel wakes the consumer without sleep/poll loops.
                await self.queue.put(self._sentinel)

        self.task = asyncio.create_task(_inner_worker())

    def __aiter__(self) -> "AsyncListIterator":
        return self

    async def __anext__(self) -> Any:
        # Blocks until the producer supplies an item or the sentinel.
        item = await self.queue.get()
        if item is self._sentinel:
            # Re-insert the sentinel so repeated __anext__ calls after
            # exhaustion keep raising StopAsyncIteration (per the
            # iterator protocol) instead of blocking forever.
            self.queue.put_nowait(self._sentinel)
            if self._error is not None:
                raise self._error
            raise StopAsyncIteration
        return item

@self._async_retry_if_token_expired
async def _helper() -> Union[QfResponse, AsyncIterator[QfResponse]]:
Expand Down Expand Up @@ -571,11 +614,7 @@ async def _async_generator_wrapper(
if not show_total_latency:
return _async_generator_wrapper(generator)
else:
result_list: List[QfResponse] = []
async for res in generator:
result_list.append(res)

return AsyncListIterator(result_list)
return AsyncListIterator(generator)
else:
return await self._async_compensate_token_usage_non_stream(
await self._async_request(req, data_postprocess=data_postprocess),
Expand Down

0 comments on commit 153936d

Please sign in to comment.