Skip to content

Commit

Permalink
feat: 支持ChatCompletion调用ERNIE-Functions-8K
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhjz committed Jul 18, 2024
1 parent a71c59d commit f773f8a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python/qianfan/resources/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ def __init__(
self, version: Optional[Literal["1", "2", 1, 2]] = None, **kwargs: Any
) -> None:
self._version = str(version) if version else "1"
self._real = self._real_base(self._version)(**kwargs)
self._backup = self._real_base("1")(**kwargs)
self._real = self._real_base(self._version, **kwargs)(**kwargs)
self._backup = self._real_base("1", **kwargs)(**kwargs)

@classmethod
def _real_base(cls, version: str) -> Type[BaseResource]:
def _real_base(cls, version: str, **kwargs: Any) -> Type[BaseResource]:
"""
return the real base class
"""
Expand Down
20 changes: 19 additions & 1 deletion python/qianfan/resources/llm/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BatchRequestFuture,
VersionBase,
)
from qianfan.resources.llm.function import Function
from qianfan.resources.tools.tokenizer import Tokenizer
from qianfan.resources.typing import JsonBody, QfLLMInfo, QfMessages, QfResponse, QfRole
from qianfan.utils.logging import log_error, log_info
Expand Down Expand Up @@ -1537,7 +1538,24 @@ class ChatCompletion(VersionBase):
_real: Union[_ChatCompletionV1, _ChatCompletionV2]

@classmethod
def _real_base(cls, version: str) -> Type:
def _real_base(cls, version: str, **kwargs: Any) -> Type:
# convert to qianfan.Function
if kwargs.get("use_function"):
return Function
else:
model = kwargs.get("model", "")
func_model_info_list = {
k.lower(): v for k, v in Function._supported_models().items()
}
func_model_info = func_model_info_list.get(model.lower())
if model and func_model_info:
if func_model_info and func_model_info.endpoint:
return Function
endpoint = kwargs.get("endpoint", "")
for m in func_model_info_list.values():
if endpoint and m.endpoint == endpoint:
return Function

if version == "1":
return _ChatCompletionV1
elif version == "2":
Expand Down
2 changes: 2 additions & 0 deletions python/qianfan/resources/llm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
"user_id",
"stop",
"max_output_tokens",
"enable_user_memory",
"user_memory_extract_level",
},
max_input_chars=11200,
max_input_tokens=7168,
Expand Down

0 comments on commit f773f8a

Please sign in to comment.