Commit

fix: rate limiter
danielhjz committed Sep 2, 2024
1 parent 19cb840 commit 0ea3bb2
Showing 4 changed files with 58 additions and 42 deletions.
31 changes: 22 additions & 9 deletions python/qianfan/resources/llm/base.py
@@ -663,15 +663,20 @@ def _update_model_and_endpoint(
model = self._model
endpoint = self._endpoint
if endpoint is None:
# get the locally preset model list
final_model = self._default_model() if model is None else model
model_info = self.get_model_info(final_model)
model_info_list = {k.lower(): v for k, v in self._local_models().items()}
model_info = model_info_list.get(final_model.lower())
if model_info is None:
raise errors.InvalidArgumentError(
f"The provided model `{model}` is not in the list of supported"
" models. If this is a recently added model, try using the"
" `endpoint` arguments and create an issue to tell us. Supported"
f" models: {self.models()}"
)
# fetch dynamically
model_info = self.get_model_info(final_model)
if model_info is None:
raise errors.InvalidArgumentError(
f"The provided model `{model}` is not in the list of supported"
" models. If this is a recently added model, try using the"
" `endpoint` arguments and create an issue to tell us."
f" Supported models: {self.models()}"
)
endpoint = model_info.endpoint
else:
# adapt endpoints (e.g. non-public cloud) that do not need a chat/ style prefix
@@ -786,13 +791,18 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:

def _self_supported_models(self) -> Dict[str, QfLLMInfo]:
"""
base implementation of _supported_models
preset model services list of current config
Args:
None
Returns:
Dict[str, QfLLMInfo]: a dict mapping model name to QfLLMInfo
"""
info_list = self._local_models()
# fetch the latest model list
return self.get_latest_api_type_models()
info_list = self._merge_local_models_with_latest(info_list)
return info_list

def _merge_local_models_with_latest(
self, info_list: Dict[str, QfLLMInfo]
@@ -941,6 +951,9 @@ def models(
models.remove(UNSPECIFIED_MODEL)
return models

def _local_models(self) -> Dict[str, QfLLMInfo]:
return {}

def get_model_info(self, model: str) -> QfLLMInfo:
"""
Get the model info of `model`
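For context, the hunk above changes model resolution in `_update_model_and_endpoint`: the locally preset list from `_local_models()` is checked first with a case-insensitive key, and only then does the dynamic `get_model_info` lookup run. A minimal standalone sketch of that order follows; `QfLLMInfo` mirrors the diff, while `resolve_endpoint` and the sample data are hypothetical illustrations rather than SDK API.

# Hypothetical sketch of the lookup order shown in the diff above.
from dataclasses import dataclass
from typing import Dict


@dataclass
class QfLLMInfo:
    endpoint: str


def resolve_endpoint(
    model: str,
    local_models: Dict[str, QfLLMInfo],
    dynamic_models: Dict[str, QfLLMInfo],
) -> str:
    # 1) case-insensitive hit against the locally preset model list
    lowered = {k.lower(): v for k, v in local_models.items()}
    info = lowered.get(model.lower())
    # 2) fall back to the dynamically fetched model list
    if info is None:
        info = dynamic_models.get(model)
    if info is None:
        raise ValueError(f"model `{model}` is not in the supported list")
    return info.endpoint


# usage: the lookup tolerates a differently cased model name
local = {"ERNIE-Speed-8K": QfLLMInfo(endpoint="/chat/ernie_speed")}
print(resolve_endpoint("ernie-speed-8k", local, {}))  # -> /chat/ernie_speed
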
15 changes: 2 additions & 13 deletions python/qianfan/resources/llm/chat_completion.py
@@ -47,17 +47,7 @@ class _ChatCompletionV1(BaseResourceV1):
QianFan ChatCompletion is an agent for calling the QianFan ChatCompletion API.
"""

def _self_supported_models(self) -> Dict[str, QfLLMInfo]:
"""
preset model services list of ChatCompletion to current config
Args:
None
Returns:
Dict[str, QfLLMInfo]: _description_
"""

def _local_models(self) -> Dict[str, QfLLMInfo]:
info_list = {
"ERNIE-4.0-8K-Latest": QfLLMInfo(
endpoint="/chat/ernie-4.0-8k-latest",
@@ -1018,8 +1008,7 @@ def _self_supported_models(self) -> Dict[str, QfLLMInfo]:
optional_keys=set(),
),
}
# fetch the latest model list
info_list = self._merge_local_models_with_latest(info_list)

# handle legacy model names / aliases
alias = {
"ERNIE-Speed": "ERNIE-Speed-8K",
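Taken together with the base.py hunk, the change is a template-method split: subclasses such as _ChatCompletionV1 now only declare their static presets via `_local_models`, while the base class's `_self_supported_models` merges them with the latest fetched list. A hedged sketch of that shape (the classes below are illustrative stand-ins, not the real BaseResourceV1 or _ChatCompletionV1):

# Illustrative stand-ins only; not the SDK's real classes.
from typing import Dict


class BaseResourceSketch:
    def _local_models(self) -> Dict[str, dict]:
        # default: no locally preset models
        return {}

    def _merge_local_models_with_latest(self, info: Dict[str, dict]) -> Dict[str, dict]:
        # placeholder for merging in the dynamically fetched list
        return dict(info)

    def _self_supported_models(self) -> Dict[str, dict]:
        # the base class owns the flow: local presets, then the merge
        return self._merge_local_models_with_latest(self._local_models())


class ChatCompletionSketch(BaseResourceSketch):
    def _local_models(self) -> Dict[str, dict]:
        # the subclass only declares its presets
        return {"ERNIE-4.0-8K-Latest": {"endpoint": "/chat/ernie-4.0-8k-latest"}}


print(ChatCompletionSketch()._self_supported_models())
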
34 changes: 30 additions & 4 deletions python/qianfan/resources/rate_limiter.py
@@ -18,7 +18,7 @@
import asyncio
import threading
import time
from queue import Queue
from queue import Empty, Queue
from types import TracebackType
from typing import Any, Optional, Type

@@ -290,6 +290,14 @@ async def __aexit__(
) -> None:
return

def __del__(self) -> None:
if hasattr(self, "_internal_qp10s_rate_limiter"):
del self._internal_qp10s_rate_limiter
if hasattr(self, "_internal_qps_rate_limiter"):
del self._internal_qps_rate_limiter
if hasattr(self, "_internal_rpm_rate_limiter"):
del self._internal_rpm_rate_limiter


class RateLimiter:
"""
@@ -346,10 +354,17 @@ def _leak(self) -> None:
)

def _worker(self) -> None:
while True:
task = self._condition_queue.get(True)
self._running = True
while self._running:
task: Optional[RateLimiter._AcquireTask] = None
try:
task = self._condition_queue.get(False)
except Empty:
# time.sleep(0.5)
task = None
continue
amount = task.amount
while True:
while self._running:
with self._sync_lock:
self._leak()
if self._token_count >= amount:
@@ -360,6 +375,14 @@ def _worker(self) -> None:
with task.condition:
task.condition.notify()

def stop(self, block: bool = True) -> None:
self._running = False
if block:
self._working_thread.join()

def __del__(self) -> None:
self.stop()

def acquire(self, amount: float = 1) -> None:
if amount > self._query_per_period:
raise ValueError("Can't acquire more than the maximum capacity")
@@ -520,3 +543,6 @@
exit
"""
return

def __del__(self) -> None:
self._sync_limiter.stop()
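
The core of the rate-limiter fix is making the worker thread stoppable: the loop now polls the queue non-blockingly and checks a `_running` flag, so `stop()` (called from `__del__`) can actually terminate it instead of leaving a thread blocked on `Queue.get` forever. A self-contained sketch of that pattern, with illustrative names rather than the SDK's:

# Minimal stoppable-worker sketch; names are illustrative.
import threading
import time
from queue import Empty, Queue


class StoppableWorker:
    def __init__(self) -> None:
        self._queue: "Queue[int]" = Queue()
        self._running = True
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def _worker(self) -> None:
        while self._running:
            try:
                # non-blocking get so the loop can re-check _running
                item = self._queue.get(False)
            except Empty:
                time.sleep(0.01)  # brief sleep to avoid a busy spin while idle
                continue
            print(f"processed {item}")

    def submit(self, item: int) -> None:
        self._queue.put(item)

    def stop(self, block: bool = True) -> None:
        # flip the flag, then optionally wait for the thread to exit
        self._running = False
        if block:
            self._thread.join()


# usage
w = StoppableWorker()
w.submit(1)
time.sleep(0.1)  # give the worker a moment to drain the queue
w.stop()
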
20 changes: 4 additions & 16 deletions python/qianfan/tests/conftest.py
@@ -15,7 +15,6 @@
"""
test config for qianfan pytest
"""
import subprocess
import threading

import pytest
@@ -50,13 +49,6 @@ def reset_config_automatically():
return


def print_ulimit():
result = subprocess.run("ulimit -s", shell=True, capture_output=True, text=True)
print(f"ulimit -s in pytest: {result.stdout.strip()}")
global res_ulimit
res_ulimit = result.stdout.strip()


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_setup(item):
# record the number of active threads when the test starts
@@ -70,19 +62,15 @@ def pytest_runtest_setup(item):
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_teardown(item):
yield
import time

time.sleep(3)
# record the number of active threads when the test ends
final_thread_count = threading.active_count()
initial_thread_count = getattr(item, "_initial_thread_count", final_thread_count)

global res_ulimit
# compute the number of newly created threads
threads_created = final_thread_count - initial_thread_count
print(
f"max threads{res_ulimit}, '{item.nodeid}' init:{initial_thread_count}, curr: "
f" {final_thread_count} , 新增线程数: {threads_created}"
f" '{item.nodeid}' thread stat: init:{initial_thread_count}, curr: "
f" {final_thread_count} , new threads count: {threads_created}"
)

# list the thread diff
@@ -95,8 +83,8 @@ def pytest_runtest_teardown(item):
if ident not in initial_threads
}
# print info about the newly created threads
print(f"测试 '{item.nodeid}' 结束后,新增的线程数: {len(new_threads)}")
print(f"ut: '{item.nodeid}' new threads count: {len(new_threads)}")
for ident, name in new_threads.items():
print(
f"线程ID: {ident}, 线程名: {name}",
f"thread ID: {ident}, thread name: {name}",
)
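
The conftest hook pair above implements a per-test thread-leak check: snapshot the thread table in `pytest_runtest_setup`, compare it in `pytest_runtest_teardown`, and print whatever appeared. A compact hedged sketch of the same hookwrapper pattern, usable as a standalone conftest.py (attribute names are illustrative):

# Standalone sketch of the thread-leak check; attribute names are illustrative.
import threading

import pytest


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_setup(item):
    # snapshot the live threads before the test body runs
    item._threads_before = {t.ident: t.name for t in threading.enumerate()}
    yield


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_teardown(item):
    yield
    before = getattr(item, "_threads_before", {})
    after = {t.ident: t.name for t in threading.enumerate()}
    leaked = {ident: name for ident, name in after.items() if ident not in before}
    if leaked:
        print(f"'{item.nodeid}' left {len(leaked)} new thread(s): {leaked}")
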
