diff --git a/cookbook/finetune/trainer_finetune.ipynb b/cookbook/finetune/trainer_finetune.ipynb index 30d7ddb0..d01065a0 100644 --- a/cookbook/finetune/trainer_finetune.ipynb +++ b/cookbook/finetune/trainer_finetune.ipynb @@ -18,7 +18,7 @@ "metadata": {}, "outputs": [], "source": [ - "! pip install \"qianfan>=0.2.8\" -U" + "! pip install \"qianfan>=0.3.0\" -U" ] }, { @@ -386,26 +386,6 @@ "trainer.output" ] }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'10268'" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "o[\"model_id\"]" - ] - }, { "attachments": {}, "cell_type": "markdown", diff --git a/docs/trainer.md b/docs/trainer.md index 90b44e2e..7ec7c387 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -6,6 +6,7 @@ ![trainer](./imgs/trainer.png) ## 快速开始 +### Finetune 以下以LLMFinetune(对应千帆平台 SFT语言大模型)为例,介绍如何使用`Trainer`进行训练。 ```python @@ -23,14 +24,56 @@ ds: Dataset = Dataset.load(qianfan_dataset_id=111, is_download_to_local=False) # 新建trainer LLMFinetune,最少传入train_type和dataset # 注意fine-tune任务需要指定的数据集类型要求为有标注的非排序对话数据集。 trainer = LLMFinetune( - train_type="ERNIE-Bot-turbo-0725", + train_type="ERNIE-Speed", dataset=ds, ) trainer.run() ``` -## 自定义训练参数 +### PostPretrain +除了使用`LLMFinetune`进行模型微调外,我们还可以使用`PostPretrain`: + +```python +from qianfan.trainer import PostPreTrain, LLMFinetune +from qianfan.trainer.configs import TrainConfig +from qianfan.trainer.consts import PeftType +from qianfan.dataset import Dataset + +# 泛文本 数据集 +ds = Dataset.load(qianfan_dataset_id="ds-ag138", is_download_to_local=False) + +# postpretrain +trainer = PostPreTrain( + train_type="ERNIE-Speed", + dataset=ds, +) +trainer.run() +# 这一步可以拿到训练完成的PostPretrain任务信息: +print(trainer.output) + + +# sft数据集 +sft_ds = Dataset.load(qianfan_dataset_id="ds-47j7ztjxfz60wb8x", is_download_to_local=False) +ppt_sft_trainer = LLMFinetune( + train_type="ERNIE-Speed", + dataset=sft_ds, + train_config=TrainConfig( + epoch=1, + learning_rate=0.00003, + max_seq_len=4096, + peft_type=PeftType.ALL, + ), + name="qianfantrainer01", + previous_trainer=trainer, +) + +ppt_sft_trainer.run() +# 拿到最终的可用于推理部署的模型: +print(ppt_sft_trainer.output) +``` + +### 自定义训练参数 如果需要自定义训练参数,可以根据不同的模型传入不同的TrainConfig 以指定训练过程中的参数,需要注意的是不同模型支持的参数不同,具体以API文档为准。 ```python import os @@ -43,7 +86,7 @@ from qianfan.trainer import LLMFinetune from qianfan.trainer.configs import TrainConfig trainer = LLMFinetune( - train_type="ERNIE-Bot-turbo-0516", + train_type="ERNIE-Speed", dataset=ds, train_config=TrainConfig( epochs=1, # 迭代轮次(Epoch),控制训练过程中的迭代轮数。 @@ -54,7 +97,7 @@ trainer = LLMFinetune( trainer.run() ``` -## 事件回调 +### 事件回调 如果需要在训练过程中监控每个阶段的各个节点的状态,可以通过事件回调函数来实现 @@ -80,7 +123,7 @@ class MyEventHandler(EventHandler): eh = MyEventHandler() trainer = LLMFinetune( - train_type="Llama-2-13b", + train_type="ERNIE-Speed", dataset=ds, train_config=TrainConfig( epochs=1, diff --git a/python/pyproject.toml b/python/pyproject.toml index 1ce7bf05..ce88492a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "qianfan" -version = "0.2.9" +version = "0.3.0" description = "文心千帆大模型平台 Python SDK" authors = [] license = "Apache-2.0" diff --git a/python/qianfan/common/client/trainer.py b/python/qianfan/common/client/trainer.py index 3e20e1eb..4599c685 100644 --- a/python/qianfan/common/client/trainer.py +++ b/python/qianfan/common/client/trainer.py @@ -53,7 +53,6 @@ from qianfan.trainer.configs import ModelInfo,
TrainLimit from qianfan.trainer.consts import ActionState, PeftType from qianfan.trainer.event import Event, EventHandler -from qianfan.utils.utils import remove_suffix_list trainer_app = typer.Typer( no_args_is_help=True, @@ -214,17 +213,14 @@ def print_trainer_config(config: ModelInfo) -> None: table.add_column("") for p in config.support_peft_types: table.add_column(Pretty(p.value, overflow="fold")) - example = TrainLimit() - limit_fields = [ - attr - for attr in dir(example) - if not attr.startswith("_") and not callable(getattr(example, attr)) - ] + from qianfan.trainer.configs import TrainConfig + + limit_fields = ( + TrainConfig().dict(exclude={"peft_type", "trainset_rate", "extras"}).keys() + ) for k in limit_fields: - if k in ["supported_hyper_params"]: - continue row_objs = [] - row_objs.append(remove_suffix_list(k, ["_options", "_limit"])) + row_objs.append(k) has_not_none_limit = False for peft in config.support_peft_types: peft_limit: Optional[TrainLimit] = config.common_params_limit @@ -232,8 +228,8 @@ def print_trainer_config(config: ModelInfo) -> None: specific_train_limit = config.specific_peft_types_params_limit.get(peft) if specific_train_limit is not None: peft_limit = specific_train_limit | config.common_params_limit - if peft_limit.__getattribute__(k): - row_objs.append(peft_limit.__getattribute__(k)) + if peft_limit and peft_limit.get(k): + row_objs.append(f"{peft_limit.get(k)}") has_not_none_limit = True else: row_objs.append("---") diff --git a/python/qianfan/consts.py b/python/qianfan/consts.py index 56d3f461..fa9fe6bf 100644 --- a/python/qianfan/consts.py +++ b/python/qianfan/consts.py @@ -142,7 +142,7 @@ class DefaultValue: ModelPublishStatusPollingInterval: float = 30 BatchRunStatusPollingInterval: float = 30 DeployStatusPollingInterval: float = 30 - DefaultFinetuneTrainType: str = "ERNIE-Bot-turbo-0725" + DefaultFinetuneTrainType: str = "ERNIE-Speed" # 目前可直接下载到本地的千帆数据集解压后的大小上限 # 后期研究更换为用户机内存大小的上限 @@ -172,6 +172,13 @@ class Consts: FineTuneCreateTaskAPI: str = "/wenxinworkshop/finetune/createTask" FineTuneCreateJobAPI: str = "/wenxinworkshop/finetune/createJob" FineTuneStopJobAPI: str = "/wenxinworkshop/finetune/stopJob" + ConsoleAPIQueryAction: str = "Action" + FineTuneV2BaseRouteAPI: str = "/v2/finetuning" + FineTuneCreateJobAction: str = "CreateFineTuningJob" + FineTuneCreateTaskAction: str = "CreateFineTuningTask" + FineTuneJobListAction: str = "DescribeFineTuningJobs" + FineTuneTaskListAction: str = "DescribeFineTuningTasks" + FineTuneTaskDetailAction: str = "DescribeFineTuningTask" ModelDetailAPI: str = "/wenxinworkshop/modelrepo/modelDetail" ModelVersionDetailAPI: str = "/wenxinworkshop/modelrepo/modelVersionDetail" ModelPublishAPI: str = "/wenxinworkshop/modelrepo/publishTrainModel" diff --git a/python/qianfan/model/model.py b/python/qianfan/model/model.py index 142e918f..641ed296 100644 --- a/python/qianfan/model/model.py +++ b/python/qianfan/model/model.py @@ -50,17 +50,17 @@ class Model( """model name""" service: Optional["Service"] = None """model service""" - task_id: Optional[int] + task_id: Optional[str] """train tkas id""" - job_id: Optional[int] + job_id: Optional[str] """train job id""" def __init__( self, id: Optional[str] = None, version_id: Optional[str] = None, - task_id: Optional[int] = None, - job_id: Optional[int] = None, + task_id: Optional[str] = None, + job_id: Optional[str] = None, name: Optional[str] = None, **kwargs: Any, ): @@ -215,11 +215,11 @@ def publish(self, name: str = "", **kwargs: Any) -> "Model": 
self._wait_for_publish(**kwargs) # 发布模型 - self.model_name = name if name != "" else f"m_{self.task_id}_{self.job_id}" + self.model_name = name if name != "" else f"m_{self.job_id}_{self.task_id}" model_publish_resp = ResourceModel.publish( is_new=True, model_name=self.model_name, - version_meta={"taskId": self.task_id, "iterationId": self.job_id}, + version_meta={"taskId": self.job_id, "iterationId": self.task_id}, **kwargs, ) log_info( @@ -232,12 +232,11 @@ def publish(self, name: str = "", **kwargs: Any) -> "Model": raise InvalidArgumentError("task id or job id not found") # 判断训练任务已经训练完成 while True: - job_status_resp = api.FineTune.get_job( + job_status_resp = api.FineTune.V2.task_detail( task_id=self.task_id, - job_id=self.job_id, **kwargs, ) - job_status = job_status_resp["result"]["trainStatus"] + job_status = job_status_resp["result"]["runStatus"] log_info(f"model publishing keep polling, current status {job_status}") if job_status == console_const.TrainStatus.Running: time.sleep(get_config().TRAIN_STATUS_POLLING_INTERVAL) diff --git a/python/qianfan/resources/console/consts.py b/python/qianfan/resources/console/consts.py index 6be3eee2..53e9c327 100644 --- a/python/qianfan/resources/console/consts.py +++ b/python/qianfan/resources/console/consts.py @@ -132,13 +132,13 @@ class ServiceStatus(str, Enum): class TrainStatus(str, Enum): - Finish = "FINISH" + Finish = "Done" """训练完成""" - Running = "RUNNING" + Running = "Running" """训练进行中""" - Fail = "FAIL" + Fail = "Fail" """训练失败""" - Stop = "STOP" + Stop = "Stopped" """训练停止""" @@ -158,9 +158,22 @@ class TrainDatasetType(int, Enum): """私有Bos数据集""" +class TrainDatasetSourceType(str, Enum): + Platform = "Platform" + PrivateBos = "Bos" + + class TrainMode(str, Enum): SFT = "SFT" """对应 LLMFinetune""" + PostPretrain = "PostPretrain" + """PostPretrain """ + + +class TrainParameterScale(str, Enum): + FullFineTuning = "FullFineTuning" + PromptTuning = "PromptTuning" + LoRA = "LoRA" class DeployPoolType(int, Enum): diff --git a/python/qianfan/resources/console/finetune.py b/python/qianfan/resources/console/finetune.py index bcf74a3b..c25bc12a 100644 --- a/python/qianfan/resources/console/finetune.py +++ b/python/qianfan/resources/console/finetune.py @@ -16,10 +16,11 @@ FineTune API """ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from qianfan.consts import Consts -from qianfan.resources.console.utils import console_api_request +from qianfan.resources.console import consts as console_consts +from qianfan.resources.console.utils import _get_console_v2_query, console_api_request from qianfan.resources.typing import QfRequest @@ -63,7 +64,7 @@ def create_task( base_train_type: str, train_type: str, description: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> QfRequest: """ Create a model fine-tuning task. @@ -74,6 +75,10 @@ def create_task( Parameters: name (str): The name of the fine-tuning task. + base_train_type (str): + The base training type of the fine-tuning task. e.g. "ERNIE-Bot-turbo" + train_type (str): + The training type of the fine-tuning task. e.g. "ERNIE-Bot-turbo-0922 description (Optional[str]): An optional description for the fine-tuning task. kwargs (Any): @@ -123,7 +128,7 @@ def create_job(cls, job: Dict[str, Any], **kwargs: Any) -> QfRequest: @classmethod @console_api_request - def stop_job(cls, task_id: int, job_id: int, **kwargs: Any) -> QfRequest: + def stop_job(cls, task_id: str, job_id: str, **kwargs: Any) -> QfRequest: """ Stop a fine-tuning job. 
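The publish flow above now polls the fine-tuning V2 `task_detail` endpoint and compares `runStatus` against the renamed `TrainStatus` values. A minimal sketch of that polling pattern, assuming the `result.runStatus` response shape used throughout this change (the 30-second interval is illustrative; the SDK reads it from configuration):

```python
import time

from qianfan.resources import FineTune
from qianfan.resources.console.consts import TrainStatus


def wait_for_task(task_id: str, interval: float = 30.0) -> str:
    """Poll a V2 fine-tuning task until it leaves the Running state."""
    while True:
        detail = FineTune.V2.task_detail(task_id=task_id)
        status = detail["result"]["runStatus"]
        if status != TrainStatus.Running:
            # With the new enum values this is "Done", "Fail" or "Stopped".
            return status
        time.sleep(interval)
```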
@@ -131,9 +136,9 @@ def stop_job(cls, task_id: int, job_id: int, **kwargs: Any) -> QfRequest: specific task. Parameters: - task_id (int): + task_id (str): The identifier of the task associated with the fine-tuning job. - job_id (int): + job_id (str): The identifier of the fine-tuning job to be stopped. kwargs: Additional keyword arguments that can be passed to customize the request. @@ -147,3 +152,207 @@ def stop_job(cls, task_id: int, job_id: int, **kwargs: Any) -> QfRequest: req = QfRequest(method="POST", url=Consts.FineTuneStopJobAPI) req.json_body = {"taskId": task_id, "jobId": job_id, **kwargs} return req + + class V2: + """ + this class provides methods to interact with the fine-tuning V2 API. + """ + + @classmethod + def base_api_route(cls) -> str: + """ + base api url route for fine-tuning V2. + + Returns: + str: base api url route + """ + return Consts.FineTuneV2BaseRouteAPI + + @classmethod + @console_api_request + def create_job( + cls, + name: str, + model: str, + train_mode: Union[str, console_consts.TrainMode], + description: Optional[str] = None, + **kwargs: Any, + ) -> QfRequest: + """ + create a fine-tuning job. + + This function creates a fine-tuning job. A job may be associated with + many tasks. + + Parameters: + name (str): + The name of the job. + model (str): + The base model to fine-tune. + e.g. "ERNIE-Speed" + train_mode (Union[str, console_consts.TrainMode]): + The train mode of the fine-tuning job, including "SFT" and + "PostPretrain" and so on. + description (Optional[str]): + The description of the fine-tuning job. + kwargs: + Additional keyword arguments that can be passed to customize the + request. + + Note: + The `@console_api_request` decorator is applied to this method, enabling + it to send the generated QfRequest and return a QfResponse to the user. + + """ + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.FineTuneCreateJobAction), + ) + req.json_body = {**kwargs, "name": name, "model": model} + if isinstance(train_mode, console_consts.TrainMode): + req.json_body["trainMode"] = train_mode.value + elif isinstance(train_mode, str): + req.json_body["trainMode"] = train_mode + else: + raise TypeError( + "train_mode must be a string or TrainMode, but got" + f" {type(train_mode)}" + ) + if description is not None: + req.json_body["description"] = description + return req + + @classmethod + @console_api_request + def create_task( + cls, + job_id: str, + params_scale: Union[str, console_consts.TrainParameterScale], + hyper_params: Dict[str, Any], + dataset_config: Dict[str, Any], + incrementTaskId: Optional[str] = None, + **kwargs: Any, + ) -> QfRequest: + """ + create a fine-tuning task. + + This function creates a fine-tuning task associated with a + specific job. + + Parameters: + job_id (str): + The id of the job this task belongs to. + params_scale (Union[str, console_consts.TrainParameterScale]): + The parameter scale of the task, e.g. "FullFineTuning", + "PromptTuning", "LoRA". + hyper_params (Dict[str, Any]): + The hyper parameter config of the task. + dataset_config (Dict[str, Any]): + The dataset config used for training. + incrementTaskId (Optional[str]): + The id of a finished task to continue training from, used for + incremental training. + kwargs: + Additional keyword arguments that can be passed to customize + the request. + + Note: + The `@console_api_request` decorator is applied to this method, enabling + it to send the generated QfRequest and return a QfResponse to the user.
+ """ + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.FineTuneCreateTaskAction), + ) + req.json_body = { + **kwargs, + "jobId": job_id, + "parameterScale": ( + params_scale.value + if isinstance(params_scale, console_consts.TrainParameterScale) + else params_scale + ), + "hyperParameterConfig": hyper_params, + "datasetConfig": dataset_config, + } + if incrementTaskId is not None: + req.json_body["incrementTaskId"] = incrementTaskId + return req + + @classmethod + @console_api_request + def job_list( + cls, + train_model: Optional[Union[str, console_consts.TrainMode]] = None, + marker: Optional[str] = None, + max_keys: Optional[int] = None, + page_reverse: Optional[bool] = None, + **kwargs: Any, + ) -> QfRequest: + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.FineTuneJobListAction), + ) + req.json_body = { + k: v + for k, v in { + **kwargs, + "trainModel": ( + train_model.value + if isinstance(train_model, console_consts.TrainMode) + else train_model + ), + "maker": marker, + "maxKeys": max_keys, + "pageReverse": page_reverse, + }.items() + if v is not None + } + return req + + @classmethod + @console_api_request + def task_list( + cls, + job_id: str, + marker: Optional[str] = None, + max_keys: Optional[int] = None, + page_reverse: Optional[bool] = None, + **kwargs: Any, + ) -> QfRequest: + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.FineTuneTaskListAction), + ) + req.json_body = { + k: v + for k, v in { + **kwargs, + "jobId": job_id, + "maker": marker, + "maxKeys": max_keys, + "pageReverse": page_reverse, + }.items() + if v is not None + } + return req + + @classmethod + @console_api_request + def task_detail( + cls, + task_id: str, + **kwargs: Any, + ) -> QfRequest: + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.FineTuneTaskDetailAction), + ) + req.json_body = { + **kwargs, + "taskId": task_id, + } + return req diff --git a/python/qianfan/resources/console/utils.py b/python/qianfan/resources/console/utils.py index 732f1309..d2ae1f80 100644 --- a/python/qianfan/resources/console/utils.py +++ b/python/qianfan/resources/console/utils.py @@ -15,10 +15,12 @@ """ Utils for console api """ +import copy import functools -from typing import Any, Awaitable, Callable, Tuple +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple from qianfan import get_config +from qianfan.consts import Consts from qianfan.errors import InvalidArgumentError from qianfan.resources.requestor.console_requestor import ConsoleAPIRequestor from qianfan.resources.typing import ParamSpec, QfRequest, QfResponse, RetryConfig @@ -46,16 +48,16 @@ def inner(*args: Any, **kwargs: Any) -> QfResponse: ak, sk = _get_console_ak_sk(**kwargs) config = get_config() retry_config = RetryConfig( - retry_count=kwargs.get("retry_count", config.CONSOLE_API_RETRY_COUNT), - timeout=kwargs.get("request_timeout", config.CONSOLE_API_RETRY_TIMEOUT), - backoff_factor=kwargs.get( + retry_count=kwargs.pop("retry_count", config.CONSOLE_API_RETRY_COUNT), + timeout=kwargs.pop("request_timeout", config.CONSOLE_API_RETRY_TIMEOUT), + backoff_factor=kwargs.pop( "backoff_factor", config.CONSOLE_API_RETRY_BACKOFF_FACTOR ), - jitter=kwargs.get("retry_jitter", config.CONSOLE_API_RETRY_JITTER), - retry_err_codes=kwargs.get( + jitter=kwargs.pop("retry_jitter", config.CONSOLE_API_RETRY_JITTER), + retry_err_codes=kwargs.pop( 
"retry_err_codes", config.CONSOLE_API_RETRY_ERR_CODES ), - max_wait_interval=kwargs.get( + max_wait_interval=kwargs.pop( "max_wait_interval", config.CONSOLE_API_RETRY_MAX_WAIT_INTERVAL ), ) @@ -108,6 +110,15 @@ async def inner(*args: Any, **kwargs: Any) -> QfResponse: return inner +def _get_console_v2_query( + action: Optional[str] = None, query: Dict[str, Any] = {} +) -> Dict[str, Any]: + res = copy.deepcopy(query) + if action is not None: + res[Consts.ConsoleAPIQueryAction] = action + return res + + def _get_console_ak_sk(pop: bool = True, **kwargs: Any) -> Tuple[str, str]: """ extract ak and sk from kwargs diff --git a/python/qianfan/resources/llm/completion.py b/python/qianfan/resources/llm/completion.py index 298bafff..de7c2ed6 100644 --- a/python/qianfan/resources/llm/completion.py +++ b/python/qianfan/resources/llm/completion.py @@ -333,8 +333,8 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]: "tool_choice", }, ), - "Yi-34B-Chat": QfLLMInfo( - endpoint="/chat/yi_34b_chat", + "Mixtral-8x7B-Instruct": QfLLMInfo( + endpoint="/chat/mixtral_8x7b_instruct", required_keys={"messages"}, optional_keys={ "stream", @@ -348,8 +348,8 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]: "tool_choice", }, ), - "Mixtral-8x7B-Instruct": QfLLMInfo( - endpoint="/chat/mixtral_8x7b_instruct", + "Yi-34B-Chat": QfLLMInfo( + endpoint="/chat/yi_34b_chat", required_keys={"messages"}, optional_keys={ "stream", diff --git a/python/qianfan/resources/requestor/base.py b/python/qianfan/resources/requestor/base.py index a7e8d5dc..e1078dd0 100644 --- a/python/qianfan/resources/requestor/base.py +++ b/python/qianfan/resources/requestor/base.py @@ -45,7 +45,7 @@ from qianfan.resources.http_client import HTTPClient from qianfan.resources.rate_limiter import RateLimiter from qianfan.resources.typing import QfRequest, QfResponse, RetryConfig -from qianfan.utils.logging import log_error, log_warn +from qianfan.utils.logging import log_error, log_trace, log_warn _T = TypeVar("_T") @@ -241,6 +241,7 @@ def _request( simple sync request """ with self._rate_limiter: + log_trace(f"raw request: {request}") response = self._client.request(request) _check_if_status_code_is_200(response) try: diff --git a/python/qianfan/resources/tools/tokenizer.py b/python/qianfan/resources/tools/tokenizer.py index b970bca8..16eb7b8f 100644 --- a/python/qianfan/resources/tools/tokenizer.py +++ b/python/qianfan/resources/tools/tokenizer.py @@ -66,11 +66,6 @@ def count_tokens( if mode == "local": return cls._local_count_tokens(text) if mode == "remote": - if model not in ["ERNIE-Bot", "ERNIE-Bot-turbo", "ERNIE-Bot-4"]: - raise InvalidArgumentError( - f"Model `{model} is not supported to calculate token count from" - " server.`" - ) return cls._remote_count_tokens_eb(text, model, **kwargs) # unreachable diff --git a/python/qianfan/tests/finetune_test.py b/python/qianfan/tests/finetune_test.py index 2dca7322..7d7954a7 100644 --- a/python/qianfan/tests/finetune_test.py +++ b/python/qianfan/tests/finetune_test.py @@ -18,6 +18,7 @@ from qianfan.resources import FineTune +from qianfan.resources.console import consts as console_consts def test_create_finetune_task(): @@ -62,7 +63,7 @@ def test_create_finetune_job(): "baseTrainType": "ERNIE-Bot-turbo", "trainType": "ERNIE-Bot-turbo-0725", "trainMode": "SFT", - "peftType": "ALL", + "peftType": console_consts.TrainParameterScale.FullFineTuning, "trainConfig": {"epoch": 1, "learningRate": 0.00003, "maxSeqLen": 4096}, "trainset": [{"type": 1, "id": 12563}], "trainsetRate": 20, @@ -91,3 +92,50 @@ def 
test_stop_finetune_job(): resp = FineTune.stop_job(task_id=147, job_id=258) assert resp["_request"] == {"taskId": 147, "jobId": 258} assert "result" in resp + + +def test_finetune_v2_create_job(): + resp = FineTune.V2.create_job(name="hiii", model="ERNIE-Speed", train_mode="SFT") + print("resp", resp) + + +def test_finetune_v2_create_task(): + from qianfan.resources.console.consts import TrainParameterScale + + job_id = "job-xx1234" + resp = FineTune.V2.create_task( + job_id=job_id, + params_scale=TrainParameterScale.FullFineTuning, + hyper_params={ + "learning_rate": 0.0001, + "epoch": 1, + }, + dataset_config={ + "sourceType": "Platform", + "corpusProportion": "1:5", + "datasets": [{"datasetId": "ds-p1t2wiv12f1vwsch"}], + "splitRatio": 20, + }, + ) + assert resp["result"]["jobId"] == job_id + assert resp["result"]["taskId"] != "" + + +def test_finetune_v2_task_detail(): + task_id = "task-xx1234" + resp = FineTune.V2.task_detail(task_id=task_id) + assert resp["result"]["taskId"] == task_id + + +def test_finetune_v2_job_list(): + resp = FineTune.V2.job_list() + assert "pageInfo" in resp["result"] + assert len(resp["result"]["jobList"]) > 0 + + +def test_finetune_v2_task_list(): + job_id = "job-xx1234" + resp = FineTune.V2.task_list(job_id=job_id) + assert "pageInfo" in resp["result"] + assert len(resp["result"]["taskList"]) > 0 + assert resp["result"]["taskList"][0]["jobId"] == job_id diff --git a/python/qianfan/tests/trainer_test.py b/python/qianfan/tests/trainer_test.py index 44d28f7a..2a94e346 100644 --- a/python/qianfan/tests/trainer_test.py +++ b/python/qianfan/tests/trainer_test.py @@ -36,6 +36,7 @@ from qianfan.trainer.consts import PeftType from qianfan.trainer.event import Event, EventHandler from qianfan.trainer.finetune import LLMFinetune +from qianfan.trainer.post_pretrain import PostPreTrain class MyEventHandler(EventHandler): @@ -60,24 +61,29 @@ def test_load_data_action(): qianfan_dataset_id="ds-9cetiuhvnbn4mqs3", is_download_to_local=False ) - res = LoadDataSetAction(preset).exec() + res = LoadDataSetAction( + preset, dataset_template=console_consts.DataTemplateType.NonSortedConversation + ).exec() assert isinstance(res, dict) assert "datasets" in res def test_train_action(): - ds_id = 111 - ta = TrainAction("ERNIE-Bot-turbo-0725") + ta = TrainAction( + train_type="ERNIE-Speed", train_mode=console_consts.TrainMode.PostPretrain + ) output = ta.exec( input={ - "datasets": [ - {"type": console_consts.TrainDatasetType.Platform.value, "id": ds_id} - ] + "datasets": { + "sourceType": console_consts.TrainDatasetSourceType.PrivateBos.value, + "versions": [{"versionBosUri": "bos:/aaa/"}], + } } ) assert isinstance(output, dict) assert "task_id" in output and "job_id" in output + assert isinstance(output["task_id"], str) and output["task_id"] != "" def test_model_publish_action(): @@ -102,7 +108,6 @@ def test_service_deploy_action(): def test_trainer_sft_run(): train_config = TrainConfig( epoch=1, - batch_size=4, learning_rate=0.00002, max_seq_len=4096, trainset_rate=20, @@ -137,7 +142,7 @@ def test_trainer_sft_run_from_bos(): ) sft_task.run() sft_task = LLMFinetune( - train_type="ERNIE-Bot-turbo-0725", dataset_bos_path="bos:/sdk-test/ds.jsonl" + train_type="ERNIE-Bot-turbo-0725", dataset_bos_path="bos:/sdk-test/" ) sft_task.run() res = sft_task.result @@ -150,7 +155,11 @@ def test_trainer_sft_run_from_bos(): def test_trainer_sft_with_deploy(): train_config = TrainConfig( - epoch=1, batch_size=4, learning_rate=0.00002, max_seq_len=4096 + epoch=1, + batch_size=4, + 
learning_rate=0.00002, + max_seq_len=4096, + peft_type=PeftType.ALL, ) deploy_config = DeployConfig(replicas=1, pool_type=1, service_type=ServiceType.Chat) qianfan_data_source = QianfanDataSource.create_bare_dataset( @@ -297,7 +306,11 @@ def test_eval_action_resume(): def test_trainer_sft_with_eval(): train_config = TrainConfig( - epoch=1, batch_size=4, learning_rate=0.00002, max_seq_len=4096 + epoch=1, + batch_size=4, + learning_rate=0.00002, + max_seq_len=4096, + peft_type=PeftType.LoRA, ) qianfan_data_source = QianfanDataSource.create_bare_dataset( "train", console_consts.DataTemplateType.NonSortedConversation @@ -374,25 +387,75 @@ def test_train_limit__or__(): def test_train_config_validate(): conf = TrainConfig(epoch=4, batch_size=4, max_seq_len=4096, learning_rate=0.0002) - res = conf.validate_config(TrainLimit(epoch_limit=(1, 2))) + # 不存在的字段 + res = conf.validate_config(TrainLimit(epoch=(1, 2))) assert not res - res = conf.validate_config(TrainLimit(epoch_limit=(1, 10))) - assert res - res = conf.validate_config(TrainLimit(max_seq_len_options=(1, 4096))) + res = conf.validate_config(TrainLimit(epoch=(1, 2), batch_size=(1, 20))) + assert not res + res = conf.validate_config( + TrainLimit( + epoch=(1, 8), + batch_size=(1, 10), + max_seq_len=[1024, 2048, 4096], + learning_rate=(0.000001, 0.1), + ) + ) assert res - res = conf.validate_valid_fields( - TrainLimit(supported_hyper_params=["epoch", "batch_size"]) + +def test_ppt(): + ppt_ds = Dataset.load( + qianfan_dataset_id="ds-mock-generic", is_download_to_local=False ) - assert res != "" - res = conf.validate_valid_fields( - TrainLimit( - supported_hyper_params=[ - "epoch", - "batch_size", - "max_seq_len", - "learning_rate", - ] + ppt_trainer = PostPreTrain( + train_type="ERNIE-Speed", + dataset=ppt_ds, + ) + ppt_trainer.run() + res = ppt_trainer.output + assert "task_id" in res and "job_id" in res + + +def test_ppt_with_sft(): + ppt_ds = Dataset.load( + qianfan_dataset_id="ds-mock-generic", is_download_to_local=False + ) + ppt_trainer = PostPreTrain( + train_type="ERNIE-Speed", + dataset=ppt_ds, + ) + ppt_trainer.run() + assert "task_id" in ppt_trainer.output and "job_id" in ppt_trainer.output + + sft_ds = Dataset.load(qianfan_dataset_id="ds-111", is_download_to_local=False) + sft_trainer = LLMFinetune( + dataset=sft_ds, previous_trainer=ppt_trainer, name="ppt_with_sft" + ) + sft_trainer.run() + assert "model_version_id" in sft_trainer.output and "model_id" in sft_trainer.output + + +def test_all_default_config(): + from qianfan.trainer.configs import ( + DefaultPostPretrainTrainConfigMapping, + DefaultTrainConfigMapping, + ) + + sft_ds = Dataset.load(qianfan_dataset_id="ds-111", is_download_to_local=False) + from qianfan.utils import log_info + + for k in DefaultTrainConfigMapping.keys(): + log_info(f"current: {k}") + LLMFinetune( + train_type=k, + dataset=sft_ds, ) + + ppt_ds = Dataset.load( + qianfan_dataset_id="ds-mock-generic", is_download_to_local=False ) - assert res == "" + for k in DefaultPostPretrainTrainConfigMapping.keys(): + PostPreTrain( + train_type=k, + dataset=ppt_ds, + ) diff --git a/python/qianfan/tests/utils/mock_server.py b/python/qianfan/tests/utils/mock_server.py index abacce0c..2f6da9b0 100644 --- a/python/qianfan/tests/utils/mock_server.py +++ b/python/qianfan/tests/utils/mock_server.py @@ -723,6 +723,147 @@ def create_finetune_task(): ) +@app.route(Consts.FineTuneV2BaseRouteAPI, methods=["POST"]) +def finetune_v2(): + action = request.args.get(Consts.ConsoleAPIQueryAction) + json_body = request.json + 
action_handler = { + Consts.FineTuneCreateJobAction: finetune_v2_create_job, + Consts.FineTuneCreateTaskAction: finetune_v2_create_task, + Consts.FineTuneJobListAction: finetune_v2_job_list, + Consts.FineTuneTaskListAction: finetune_v2_task_list, + Consts.FineTuneTaskDetailAction: finetune_v2_task_detail, + } + return action_handler.get(action)(body=json_body) + + +def finetune_v2_create_job(body): + return json_response( + { + "requestId": "98d2e3d7-a689-4255-91f1-da514a3a5777", + "result": {"jobId": "job-2qm2a9s9rj22"}, + } + ) + + +def finetune_v2_create_task(body): + return json_response( + { + "requestId": "aac33135-aed1-416a-8070-c6ecde325df5", + "result": {"jobId": body["jobId"], "taskId": "task-92zjbyinxruq"}, + } + ) + + +def finetune_v2_job_list(body): + return json_response( + { + "requestId": "f17326a0-91fd-404c-a9bc-db586166893e", + "result": { + "jobList": [ + { + "jobId": "job-b7hmiwmptntt", + "name": "ebspda2", + "description": "", + "model": "ERNIE-Speed", + "trainMode": "PostPretrain", + "createDate": "2024-01-29T16:24:32Z", + }, + { + "jobId": "job-yhddtcbesggz", + "name": "0129_yige", + "description": "", + "model": "WENXIN-YIGE", + "trainMode": "SFT", + "createDate": "2024-01-29T14:39:04Z", + }, + ], + "pageInfo": { + "marker": "", + "maxKeys": 20, + "isTruncated": True, + "nextMarker": "job-afik6nqipgnq", + }, + }, + } + ) + + +def finetune_v2_task_list(body): + return json_response( + { + "requestId": "eb3de810-3b21-4737-957f-ffd971a5610f", + "result": { + "pageInfo": {"marker": "", "maxKeys": 100, "isTruncated": False}, + "taskList": [ + { + "taskId": "task-92zjbyinxruq", + "jobId": body["jobId"], + "jobName": "hj_pptr", + "jobDescription": "", + "model": "ERNIE-Speed", + "trainMode": "PostPretrain", + "parameterScale": "FullFineTuning", + "runStatus": "Running", + "createDate": "2024-01-30T09:41:54Z", + "finishDate": "0000-00-00T00:00:00Z", + } + ], + }, + } + ) + + +def finetune_v2_task_detail(body): + r = request.json + task_id = r["taskId"] + global finetune_task_call_times + call_times = finetune_task_call_times.get(task_id) + if call_times is None: + finetune_task_call_times[task_id] = 0 + return json_response( + { + "requestId": "754dc75c-3515-4ddd-88ff-59caaad4358d", + "result": { + "taskId": task_id, + "jobId": "job-s66h7p9gqqu1", + "jobName": "hj_pptr", + "jobDescription": "", + "model": "ERNIE-Speed", + "trainMode": "PostPretrain", + "parameterScale": "FullFineTuning", + "runStatus": "Running", + "runProgress": "0%", + "vdlLink": "https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi1raXNyYzB4ZWlzcTM4MDgxIn0=", + "createDate": "2024-01-30T09:41:54Z", + "finishDate": "0000-00-00T00:00:00Z", + }, + } + ) + else: + MAX_CALL_TIMES = 10 + finetune_task_call_times[task_id] += 1 + return json_response( + { + "requestId": "754dc75c-3515-4ddd-88ff-59caaaaaaa", + "result": { + "taskId": task_id, + "jobId": "job-s66h7p9gqqu1", + "jobName": "hj_pptr", + "jobDescription": "", + "model": "ERNIE-Speed", + "trainMode": "PostPretrain", + "parameterScale": "FullFineTuning", + "runStatus": "Done" if call_times >= MAX_CALL_TIMES else "Running", + "runProgress": f"{int(100 * call_times / MAX_CALL_TIMES)}%", + "vdlLink": "https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi1raXNyYzB4ZWlzcTM4MDgxIn0=", + "createDate": "2024-01-30T09:41:54Z", + "finishDate": "0000-00-00T00:00:00Z", + }, + } + ) + + @app.route(Consts.FineTuneCreateJobAPI, methods=["POST"]) @iam_auth_checker def create_finetune_job(): @@ -1732,92 
+1873,92 @@ def release_dataset(): @iam_auth_checker def get_dataset_info(): args = request.json - return json_response( - { - "log_id": "log_id", - "result": { - "groupPK": "14510", - "name": "ChineseMedicalDialogueData中文医疗问答数据集", + resp = { + "log_id": "log_id", + "result": { + "groupPK": "14510", + "name": "ChineseMedicalDialogueData中文医疗问答数据集", + "dataType": 4, + "versionInfo": { + "id": 123, + "groupId": 14510, + "datasetId": 12444, + "datasetPK": args["datasetId"], + "importRecordCount": 1, + "exportRecordCount": 0, + "bmlDatasetId": "ds-7pkzh1exthpuy10n", + "userId": 0, + "versionId": 1, + "displayName": "", + "importStatus": 2, + "importProgress": 100, + "exportStatus": 2, + "exportProgress": 0, "dataType": 4, - "versionInfo": { - "id": 123, - "groupId": 14510, - "datasetId": 12444, - "datasetPK": args["datasetId"], - "importRecordCount": 1, - "exportRecordCount": 0, - "bmlDatasetId": "ds-7pkzh1exthpuy10n", - "userId": 0, - "versionId": 1, - "displayName": "", - "importStatus": 2, - "importProgress": 100, - "exportStatus": 2, - "exportProgress": 0, - "dataType": 4, - "projectType": 20, - "templateType": 2000, - "errCode": None, - "uniqueType": 0, - "importErrorInfo": None, - "createTime": "2023-09-08 17:10:11", - "modifyTime": "2023-10-25 20:45:23", - "storageType": "sysBos", - "storage": { - "storageId": "easydata", - "storageName": "easydata", - "storagePath": ( - "/easydata/_system_/dataset/ds-7pkzh1exthpuy10n/texts" - ), - "rawStoragePath": "", - "region": "bj", - }, - "releaseStatus": 2, - "releaseErrCode": 0, - "releaseStoragePath": ( - "/easydata/_system_/dataset/ds-7pkzh1exthpuy10n/texts/jsonl" - ), - "releaseProgress": 0, - "remark": "", - "annotatedEntityCount": 792099, - "entityCount": 792099, - "labelCount": 1, - "memorySize": 513.42, - "characterCount": 173338860, - "isEnhancing": False, - "enhanceStatus": -1, - "hasEnhance": False, - "isSelfInstructEnhance": False, - "interAnnoRunning": False, - "hardSampleCount": 0, - "etlStatus": 0, - "hasEtl": False, - "isPipelineEtl": False, - "teamAnnoStatus": -1, - "hasTeamAnno": False, - "promptOptimizeStatus": 0, - "demandStatus": "", - "view": 2446, - "usage": 262, - "description": ( - "中文医疗对话数据集由792099个问答对组成,包括男科、内科、妇产科、肿瘤科、儿科和外科" - ), - "tag": [ - {"name": "文本对话非排序"}, - {"name": "限定式问答"}, - {"name": "调优"}, - ], - "license": "MIT", - "copyright": "toyhom", - "copyrightLink": ( - "https://github.com/Toyhom/Chinese-medical-dialogue-data" + "projectType": 20, + "templateType": 2000, + "errCode": None, + "uniqueType": 0, + "importErrorInfo": None, + "createTime": "2023-09-08 17:10:11", + "modifyTime": "2023-10-25 20:45:23", + "storageType": "sysBos", + "storage": { + "storageId": "easydata", + "storageName": "easydata", + "storagePath": ( + "/easydata/_system_/dataset/ds-7pkzh1exthpuy10n/texts" ), + "rawStoragePath": "", + "region": "bj", }, + "releaseStatus": 2, + "releaseErrCode": 0, + "releaseStoragePath": ( + "/easydata/_system_/dataset/ds-7pkzh1exthpuy10n/texts/jsonl" + ), + "releaseProgress": 0, + "remark": "", + "annotatedEntityCount": 792099, + "entityCount": 792099, + "labelCount": 1, + "memorySize": 513.42, + "characterCount": 173338860, + "isEnhancing": False, + "enhanceStatus": -1, + "hasEnhance": False, + "isSelfInstructEnhance": False, + "interAnnoRunning": False, + "hardSampleCount": 0, + "etlStatus": 0, + "hasEtl": False, + "isPipelineEtl": False, + "teamAnnoStatus": -1, + "hasTeamAnno": False, + "promptOptimizeStatus": 0, + "demandStatus": "", + "view": 2446, + "usage": 262, + "description": 
"中文医疗对话数据集由792099个问答对组成,包括男科、内科、妇产科、肿瘤科、儿科和外科", + "tag": [ + {"name": "文本对话非排序"}, + {"name": "限定式问答"}, + {"name": "调优"}, + ], + "license": "MIT", + "copyright": "toyhom", + "copyrightLink": ( + "https://github.com/Toyhom/Chinese-medical-dialogue-data" + ), }, - "status": 200, - "success": True, - } - ) + }, + "status": 200, + "success": True, + } + if args["datasetId"] == "ds-mock-generic": + resp["result"]["versionInfo"]["projectType"] = 401 + resp["result"]["versionInfo"]["templateType"] = 40100 + return json_response(resp) @app.route(Consts.DatasetStatusFetchInBatchAPI, methods=["POST"]) diff --git a/python/qianfan/trainer/__init__.py b/python/qianfan/trainer/__init__.py index e87af090..da857b92 100644 --- a/python/qianfan/trainer/__init__.py +++ b/python/qianfan/trainer/__init__.py @@ -20,6 +20,7 @@ ) from qianfan.trainer.event import Event, EventHandler from qianfan.trainer.finetune import LLMFinetune, Trainer +from qianfan.trainer.post_pretrain import PostPreTrain __all__ = [ "LLMFinetune", @@ -31,4 +32,5 @@ "LoadDataSetAction", "DeployAction", "ModelPublishAction", + "PostPreTrain", ] diff --git a/python/qianfan/trainer/actions.py b/python/qianfan/trainer/actions.py index a719cc3c..4c4b28c0 100644 --- a/python/qianfan/trainer/actions.py +++ b/python/qianfan/trainer/actions.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast from qianfan import resources as api from qianfan.config import get_config -from qianfan.dataset.dataset import Dataset +from qianfan.dataset import BosDataSource, Dataset, QianfanDataSource from qianfan.errors import InternalError, InvalidArgumentError from qianfan.evaluation import EvaluationManager from qianfan.evaluation.evaluator import Evaluator, LocalEvaluator, QianfanEvaluator @@ -29,11 +30,14 @@ with_event, ) from qianfan.trainer.configs import ( + DefaultPostPretrainTrainConfigMapping, DefaultTrainConfigMapping, - ModelInfoMapping, + PeftType, TrainConfig, TrainLimit, + get_model_info, ) +from qianfan.trainer.consts import ServiceStatus, TrainStatus from qianfan.utils import ( bos_uploader, log_debug, @@ -42,6 +46,8 @@ log_warn, utils, ) +from qianfan.utils.bos_uploader import is_valid_bos_path +from qianfan.utils.utils import first_lower_case, snake_to_camel class LoadDataSetAction(BaseAction[Dict[str, Any], Dict[str, Any]]): @@ -65,27 +71,68 @@ class LoadDataSetAction(BaseAction[Dict[str, Any], Dict[str, Any]]): from qianfan.dataset.dataset import Dataset dataset: Optional[Dataset] = None + bos_path: Optional[str] = None def __init__( self, - dataset: Optional[Dataset] = None, + dataset: Optional[Union[Dataset, str]] = None, + dataset_template: Optional[console_consts.DataTemplateType] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.dataset = dataset + if dataset is None: + raise InvalidArgumentError("dataset must be set") + if isinstance(dataset, str): + if not is_valid_bos_path(dataset): + raise InvalidArgumentError(f"invalid bos_path {dataset}") + self.bos_path = dataset + elif isinstance(dataset.inner_data_source_cache, QianfanDataSource): + qf_data_src = cast(QianfanDataSource, dataset.inner_data_source_cache) + if ( + dataset_template is not None + and qf_data_src.template_type != dataset_template + ): + raise InvalidArgumentError( + f"dataset must be `{dataset_template}` template." 
+ ) + self.dataset = dataset + elif isinstance(dataset.inner_data_source_cache, BosDataSource): + self.dataset = dataset + else: + raise InvalidArgumentError( + "dataset must be either implemented with QianfanDataSource or" + " BosDataSource or a bos path" + ) @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: return self._exec(input, **kwargs) def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: - from qianfan.dataset.data_source import BosDataSource, QianfanDataSource - """ Load dataset implementation, may called by exec and resume. """ + if self.bos_path is not None: + if not self.bos_path.endswith("/"): + bos_path = f'{Path(f"/{self.bos_path}").parent}' + log_warn( + f"input bos_path {self.bos_path} is a file, auto_convert to dir:" + f" {bos_path}" + ) + else: + bos_path = self.bos_path + return { + "datasets": { + "sourceType": ( + console_consts.TrainDatasetSourceType.PrivateBos.value + ), + "versions": [{"versionBosUri": bos_path}], + } + } + from qianfan.dataset.data_source import BosDataSource, QianfanDataSource + if self.dataset is None: - raise InvalidArgumentError("dataset must be set") + raise InvalidArgumentError("dataset or bos_path must be set") if self.dataset.inner_data_source_cache is None: raise InvalidArgumentError("invalid dataset") if isinstance(self.dataset.inner_data_source_cache, QianfanDataSource): @@ -98,25 +145,31 @@ def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: log_debug("[load_dataset_action] dataset loaded successfully") self.qf_dataset_id = qf_data_src.id return { - "datasets": [ - { - "id": qf_data_src.old_dataset_id, - "type": console_consts.TrainDatasetType.Platform.value, - } - ] + "datasets": { + "sourceType": console_consts.TrainDatasetSourceType.Platform.value, + "versions": [ + { + "versionId": qf_data_src.id, + } + ], + } } elif isinstance(self.dataset.inner_data_source_cache, BosDataSource): log_debug("[load_dataset_action] prepare train-set in BOS") bos_data_src = cast(BosDataSource, self.dataset.inner_data_source_cache) return { - "datasets": [ - { - "type": console_consts.TrainDatasetType.PrivateBos.value, - "bosPath": bos_uploader.generate_bos_file_parent_path( - bos_data_src.bucket, bos_data_src.bos_file_path - ), - } - ] + "datasets": { + "sourceType": ( + console_consts.TrainDatasetSourceType.PrivateBos.value + ), + "versions": [ + { + "versionBosUri": bos_uploader.generate_bos_file_parent_path( + bos_data_src.bucket, bos_data_src.bos_file_path + ) + } + ], + } } else: raise InvalidArgumentError("dataset must be set") @@ -159,37 +212,38 @@ class TrainAction( Input: ``` - {'datasets':[{'type': 1, 'id': 111}]} + {'datasets': {"sourceType": ( + console_consts.TrainDatasetSourceType.PrivateBos.value + ), + "versions": [ + { + "versionBosUri": bos_uploader.generate_bos_file_parent_path( + bos_data_src.bucket, bos_data_src.bos_file_path + ) + } + ]} ``` Output: ``` - {'task_id': 47923, 'job_id': 33512} + {'task_id': "task-ddd", 'job_id': "job-xxxx"} Sample code: ``` """ - task_id: Optional[int] = None + task_id: Optional[str] = None """train task id""" - job_id: Optional[int] = None + job_id: Optional[str] = None """train job id""" - # 这里的id新API的原因task/job和调换了,现在具体 - # task_id对应job_str_id,job_id对应task_str_id - task_str_id: Optional[str] = None - """train task str id""" - job_str_id: Optional[str] = None - """job task str id""" train_type: Optional[str] = "" """train_type""" - base_model: Optional[str] = None - """base train type like 
'ERNIE-Bot-turbo'""" is_incr: bool = False """if it's incremental train or not""" train_config: Optional[TrainConfig] = None """train config""" - train_mode: console_consts.TrainMode = console_consts.TrainMode.SFT + train_mode: console_consts.TrainMode """train mode""" - task_name: str = "" + job_name: str = "" """train task name""" task_description: Optional[str] = None """train task description""" @@ -202,13 +256,13 @@ class TrainAction( def __init__( self, + train_mode: console_consts.TrainMode, train_type: Optional[str] = None, train_config: Optional[TrainConfig] = None, - base_model: Optional[str] = None, - task_id: Optional[int] = None, - job_id: Optional[int] = None, - train_mode: Optional[console_consts.TrainMode] = None, - task_name: Optional[str] = None, + task_id: Optional[str] = None, + job_id: Optional[str] = None, + peft_type: PeftType = PeftType.ALL, + job_name: Optional[str] = None, task_description: Optional[str] = None, job_description: Optional[str] = None, **kwargs: Any, @@ -216,6 +270,8 @@ def __init__( """ Parameters: + train_mode (Optional[console_consts.TrainMode], optional): + train mode, e.g. `SFT`, `PostPretrain`. Defaults to None. train_type (Optional[str], optional): train_type, must be specified when it's not increment training like 'ERNIE-Bot-turbo-0725' @@ -227,9 +283,7 @@ def __init__( used in incr train, model train task_id. Defaults to None. job_id (Optional[int], optional): used in incr train, mod train job_id. Defaults to None. - train_mode (Optional[console_consts.TrainMode], optional): - train mode, e.g. `sft`, `incremental`. Defaults to None. - task_name (Optional[str], optional): + job_name (Optional[str], optional): train task name. Defaults to None. task_description (Optional[str], optional): train task description. Defaults to None. 
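The reworked constructor and input format above are exercised by the updated unit tests; the condensed call pattern looks roughly like the sketch below (mirroring `test_train_action`, with the import path for `TrainAction` assumed and the BOS URI a placeholder):

```python
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.actions import TrainAction

# train_mode is now required; with no explicit TrainConfig the default
# config for the model and peft_type is looked up automatically.
ta = TrainAction(
    train_type="ERNIE-Speed",
    train_mode=console_consts.TrainMode.PostPretrain,
)

# The datasets input uses the new sourceType/versions layout.
output = ta.exec(
    input={
        "datasets": {
            "sourceType": console_consts.TrainDatasetSourceType.PrivateBos.value,
            "versions": [{"versionBosUri": "bos:/aaa/"}],
        }
    }
)
# output now carries string ids, e.g. {"task_id": "task-...", "job_id": "job-..."}
```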
@@ -239,44 +293,43 @@ def __init__( super().__init__(**kwargs) self.task_id = task_id self.job_id = job_id - if self.task_id is not None and self.job_id is not None: + self.train_mode = train_mode + if self.task_id is not None: # if incremental train + pre_task_detail = api.FineTune.V2.task_detail(task_id=self.task_id) + # 获取增量任务的训练model + if pre_task_detail.get("result") is not None: + self.train_type = pre_task_detail["result"]["model"] + self.train_mode = train_mode self.is_incr = True - self.train_config = train_config else: if train_type is None: raise InvalidArgumentError("train_type must be specified") - # train from base model + # 从基础模型开始训练 self.train_type = train_type - if base_model is None: - model_info = ModelInfoMapping.get(self.train_type) - if model_info is None: - raise InvalidArgumentError( - "base_model_type must be specified caused train_type:" - f" {self.train_type} is not found" - ) - self.base_model = model_info.base_model_type - else: - self.base_model = base_model - self.train_config = ( - train_config - if train_config is not None - else self.get_default_train_config(train_type) + model_info = get_model_info(train_mode, self.train_type) + if model_info is None: + log_warn(f"unknown train model type: {self.train_type} is not found") + assert self.train_type is not None + if train_config is None: + train_config = self.get_default_train_config( + self.train_type, self.train_mode, peft_type ) - self.validateTrainConfig() - if train_mode is not None: - self.train_mode = train_mode - self.task_name = self._generate_task_name(task_name, self.train_type) + self.train_config = train_config + self.validateTrainConfig(strict=kwargs.get("validate_strict", True)) + self.job_name = self._generate_job_name(job_name, self.train_type) self.task_description = task_description self.job_description = job_description - def _generate_task_name( - self, task_name: Optional[str], train_type: Optional[str] + def _generate_job_name( + self, job_name: Optional[str], train_type: Optional[str] ) -> str: - if task_name is not None: - return task_name + if job_name is not None: + return job_name model_info = ( - ModelInfoMapping.get(train_type) if train_type is not None else None + get_model_info(self.train_mode, train_type) + if train_type is not None + else None ) return ( f"job_{utils.generate_letter_num_random_id()}" @@ -284,7 +337,7 @@ def _generate_task_name( else f"{model_info.short_name}_{utils.generate_letter_num_random_id(5)}" ) - def validateTrainConfig(self) -> None: + def validateTrainConfig(self, strict: bool = True) -> None: """ validate train_config with ModelInfo Limits @@ -293,13 +346,11 @@ def validateTrainConfig(self) -> None: """ if self.train_config is None: raise InvalidArgumentError("none train_config") - if self.train_type not in ModelInfoMapping: - log_warn( - f"[train_action] train_type {self.train_type} not found, it may be not" - " supported" - ) else: - train_type_model_info = ModelInfoMapping[self.train_type] + assert self.train_type + train_type_model_info = get_model_info(self.train_mode, self.train_type) + if train_type_model_info is None: + return if ( self.train_config.peft_type not in train_type_model_info.support_peft_types @@ -308,24 +359,37 @@ def validateTrainConfig(self) -> None: f"[train_action] train_type {self.train_type}, peft_type" f" {self.train_config.peft_type} not found, it may be not supported" ) + if strict: + raise InvalidArgumentError( + f"[train_action] train_type {self.train_type}, peft_type" + f" {self.train_config.peft_type} not found, 
it may be not" + " supported" + ) + else: + assert train_type_model_info + res = False if ( train_type_model_info.specific_peft_types_params_limit is not None and self.train_config.peft_type in train_type_model_info.specific_peft_types_params_limit ): - self._validate_train_config( + res = self._validate_train_config( train_type_model_info.specific_peft_types_params_limit[ self.train_config.peft_type ] | train_type_model_info.common_params_limit, ) else: - self._validate_train_config( + res = self._validate_train_config( train_type_model_info.common_params_limit ) + if not res and strict: + raise InvalidArgumentError( + "invalid train_config, please check the config" + ) - def _validate_train_config(self, train_limit: TrainLimit) -> None: + def _validate_train_config(self, train_limit: TrainLimit) -> bool: """ validate train_config with a specific train_limit @@ -337,27 +401,7 @@ def _validate_train_config(self, train_limit: TrainLimit) -> None: """ if self.train_config is None: raise InvalidArgumentError("validate train_config is none") - self.train_config.validate_config(train_limit) - self.train_config.validate_valid_fields(train_limit) - - def _exec_incremental( - self, input: Dict[str, Any], **kwargs: Dict - ) -> Dict[str, Any]: - """ - increment train from task_id, job_id - - Parameters: - input (Dict[str, Any]): - input - - Raises: - NotImplementedError: not implemented yet - - Returns: - Dict[str, Any]: - output - """ - raise NotImplementedError("incr train not implemented") + return self.train_config.validate_config(train_limit) @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: @@ -389,64 +433,58 @@ def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: self._input = input return self._exec(self._input, **kwargs) - def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: + def _exec(self, input: Dict[str, Any] = {}, **kwargs: Any) -> Dict[str, Any]: # 校验数据集 - train_sets = input.get("datasets") - if train_sets is None or len(train_sets) == 0: + ds_config = input.get("datasets") + if ds_config is None: raise InvalidArgumentError("train set must be set") + assert isinstance(ds_config, dict) + assert self.train_config + ds_config["splitRatio"] = self.train_config.trainset_rate + + if self.job_id is None: + # request for create model train task + assert self.train_type is not None + resp = api.FineTune.V2.create_job( + name=self.job_name, + description=self.task_description, + model=self.train_type, + train_mode=self.train_mode, + **kwargs, + ) - # 判断是否增量训练 - if self.is_incr: - return self._exec_incremental(input, **kwargs) - - # request for create model train task - assert self.train_type is not None - assert self.base_model is not None - resp = api.FineTune.create_task( - name=self.task_name, - description=self.task_description, - train_type=self.train_type, - base_train_type=self.base_model, - **kwargs, - ) - self.task_id = cast(int, resp["result"]["id"]) - self.job_str_id = resp["result"]["uuid"] - log_debug(f"[train_action] create fine-tune task: {self.task_id}") + self.job_id = str(resp["result"]["jobId"]) + log_debug( + f"[train_action] create {self.train_mode} train job: {self.job_id}" + ) assert self.train_config is not None - req_job = { - "taskId": self.task_id, - "description": self.job_description, - "baseTrainType": self.base_model, - "trainType": self.train_type, - "trainMode": self.train_mode.value, - "peftType": self.train_config.peft_type, - "trainConfig": { - "epoch": 
self.train_config.epoch, - "learningRate": self.train_config.learning_rate, - "batchSize": self.train_config.batch_size, - "maxSeqLen": self.train_config.max_seq_len, - "loggingSteps": self.train_config.logging_steps, - "warmupRatio": self.train_config.warmup_ratio, - "weightDecay": self.train_config.weight_decay, - "loraRank": self.train_config.lora_rank, - "loraAllLinear": self.train_config.lora_all_linear, - "loraAlpha": self.train_config.lora_alpha, - "loraDropout": self.train_config.lora_dropout, - "schedulerName": self.train_config.scheduler_name, - **self.train_config.extras, - }, - "trainset": train_sets, - "trainsetRate": self.train_config.trainset_rate, + hyper_params_dict = { + **self.train_config.dict(exclude={"peft_type", "trainset_rate", "extras"}), + **self.train_config.extras, } - tc_dict = cast(dict, req_job["trainConfig"]) - req_job["trainConfig"] = { - key: value for key, value in tc_dict.items() if value is not None + hyper_params_dict = { + first_lower_case(snake_to_camel(key)): value + for key, value in hyper_params_dict.items() + if value is not None } - create_job_resp = api.FineTune.create_job(req_job, **kwargs) - self.job_id = cast(int, create_job_resp["result"]["id"]) - self.task_str_id = create_job_resp["result"]["uuid"] - log_debug(f"[train_action] create fine-tune job_id: {self.job_id}") + ds_config = input["datasets"] + log_debug(f"train with ds_config: { ds_config}") + log_debug(f"train with hyper_params: { hyper_params_dict}") + if self.is_incr: + # 增量训练 + kwargs["incrementTaskId"] = self.task_id + log_info(f"train with incrementTaskId: { self.task_id}") + assert self.train_config.peft_type is not None + create_task_resp = api.FineTune.V2.create_task( + job_id=self.job_id, + params_scale=self.train_config.peft_type, + hyper_params=hyper_params_dict, + dataset_config=ds_config, + **kwargs, + ) + self.task_id = str(create_task_resp["result"]["taskId"]) + log_debug(f"[train_action] create {self.train_mode} train task: {self.task_id}") # 获取job状态,是否训练完成 self._wait_model_trained(**kwargs) @@ -455,21 +493,20 @@ def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: return self.result def _wait_model_trained(self, **kwargs: Dict) -> None: - if self.task_id is None or self.job_id is None: - raise InvalidArgumentError("task_id and job_id must not be None") + if self.task_id is None: + raise InvalidArgumentError("task_id must not be None") while True: - job_status_resp = api.FineTune.get_job( + job_status_resp = api.FineTune.V2.task_detail( task_id=self.task_id, - job_id=self.job_id, **kwargs, ) - job_status = job_status_resp["result"]["trainStatus"] - job_progress = job_status_resp["result"]["progress"] + job_status = job_status_resp["result"]["runStatus"] + job_progress = int(job_status_resp["result"]["runProgress"][:-1]) log_info( "[train_action] fine-tune running..." 
- f" task_name:{self.task_name} current status: {job_status}," + f" job_name:{self.job_name} current status: {job_status}," f" {job_progress}% check train task log in" - f" https://console.bce.baidu.com/qianfan/train/sft/{self.job_str_id}/{self.task_str_id}/detail/traininglog" + f" https://console.bce.baidu.com/qianfan/train/sft/{self.job_id}/{self.task_id}/detail/traininglog" ) if job_progress >= 50: log_info(f" check vdl report in {job_status_resp['result']['vdlLink']}") @@ -482,15 +519,18 @@ def _wait_model_trained(self, **kwargs: Dict) -> None: ]: log_error( "[train_action] fine-tune job" - f" {self.job_str_id}/{self.task_str_id} has ended," + f" {self.job_id}/{self.task_id} has ended," f" {job_status_resp}" ) - break + raise InternalError( + f"fine-tune job {self.job_id}/{self.task_id} has ended with" + f" status: {job_status}" + ) else: time.sleep(get_config().TRAIN_STATUS_POLLING_INTERVAL) log_info( "[train_action] fine-tune job has ended:" - f" {self.job_str_id}/{self.task_str_id} with status: {job_status}" + f" {self.job_id}/{self.task_id} with status: {job_status}" ) @with_event @@ -533,11 +573,26 @@ def stop(self, **kwargs: Dict) -> None: api.FineTune.stop_job(self.task_id, self.job_id) log_debug(f"train job {self.task_id}/{self.job_id} stopped") - def get_default_train_config(self, model_type: str) -> TrainConfig: - return DefaultTrainConfigMapping.get( - model_type, - DefaultTrainConfigMapping[get_config().DEFAULT_FINE_TUNE_TRAIN_TYPE], - ) + def get_default_train_config( + self, model_type: str, train_mode: console_consts.TrainMode, peft_type: PeftType + ) -> TrainConfig: + if train_mode == console_consts.TrainMode.PostPretrain: + model_info = DefaultPostPretrainTrainConfigMapping.get( + model_type, + # DefaultTrainConfigMapping[get_config().DEFAULT_FINE_TUNE_TRAIN_TYPE], + ) + else: + model_info = DefaultTrainConfigMapping.get( + model_type, + # DefaultTrainConfigMapping[get_config().DEFAULT_FINE_TUNE_TRAIN_TYPE], + ) + if model_info is None: + raise InvalidArgumentError( + f"can not find default config for {model_type} in {peft_type}" + ) + train_config = model_info[peft_type] + train_config.peft_type = peft_type + return train_config class ModelPublishAction(BaseAction[Dict[str, Any], Dict[str, Any]]): @@ -557,9 +612,9 @@ class ModelPublishAction(BaseAction[Dict[str, Any], Dict[str, Any]]): ``` """ - task_id: Optional[int] = None + task_id: Optional[str] = None """task id""" - job_id: Optional[int] = None + job_id: Optional[str] = None """job id""" result: Optional[Dict[str, Any]] = None """result of model publish action""" @@ -570,15 +625,18 @@ class ModelPublishAction(BaseAction[Dict[str, Any], Dict[str, Any]]): def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: if self.task_id == "" or self.job_id == "": raise InvalidArgumentError("task_id or job_id must be set") - self.task_id = int(input.get("task_id", "")) - self.job_id = int(input.get("job_id", "")) + self.task_id = input.get("task_id", "") + self.job_id = input.get("job_id", "") self.model = Model(task_id=self.task_id, job_id=self.job_id) return self._exec(input, **kwargs) def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: if self.model is None: raise InvalidArgumentError("model must be set when in model publish._exec") - log_debug("[model_publish_action] start model publish") + log_debug( + f"[model_publish_action] start model publish task:, {self.task_id}," + f" {self.job_id}" + ) try: self.action_event( ActionState.Running, @@ -883,3 +941,42 @@ def 
resume(self, **kwargs: Dict) -> Dict[str, Any]:
         res = self._exec(llm, **kwargs)
         self.result = {"eval_res": res, **self._input}
         return self.result
+
+
+action_mapping: Dict[str, Dict[str, Any]] = {
+    LoadDataSetAction.__name__: {
+        ActionState.Preceding: TrainStatus.DatasetLoading,
+        ActionState.Running: TrainStatus.DatasetLoading,
+        ActionState.Done: TrainStatus.DatasetLoaded,
+        ActionState.Error: TrainStatus.DatasetLoadFailed,
+        ActionState.Stopped: TrainStatus.DatasetLoadStopped,
+    },
+    TrainAction.__name__: {
+        ActionState.Preceding: TrainStatus.TrainCreated,
+        ActionState.Running: TrainStatus.Training,
+        ActionState.Done: TrainStatus.TrainFinished,
+        ActionState.Error: TrainStatus.TrainFailed,
+        ActionState.Stopped: TrainStatus.TrainStopped,
+    },
+    ModelPublishAction.__name__: {
+        ActionState.Preceding: TrainStatus.ModelPublishing,
+        ActionState.Running: TrainStatus.ModelPublishing,
+        ActionState.Done: TrainStatus.ModelPublished,
+        ActionState.Error: TrainStatus.ModelPublishFailed,
+        ActionState.Stopped: TrainStatus.ModelPublishFailed,
+    },
+    DeployAction.__name__: {
+        ActionState.Preceding: ServiceStatus.Created,
+        ActionState.Running: ServiceStatus.Deploying,
+        ActionState.Done: ServiceStatus.Deployed,
+        ActionState.Error: ServiceStatus.DeployFailed,
+        ActionState.Stopped: ServiceStatus.DeployStopped,
+    },
+    EvaluateAction.__name__: {
+        ActionState.Preceding: TrainStatus.EvaluationCreated,
+        ActionState.Running: TrainStatus.EvaluationRunning,
+        ActionState.Done: TrainStatus.EvaluationFinished,
+        ActionState.Error: TrainStatus.EvaluationFailed,
+        ActionState.Stopped: TrainStatus.EvaluationStopped,
+    },
+}
diff --git a/python/qianfan/trainer/configs.py b/python/qianfan/trainer/configs.py
index f98d9f0c..7ac606bb 100644
--- a/python/qianfan/trainer/configs.py
+++ b/python/qianfan/trainer/configs.py
@@ -12,60 +12,79 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
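The `action_mapping` table defined in `actions.py` above is shared by `LLMFinetune` and `PostPreTrain`, replacing the finetune-only `fine_tune_action_mapping`: it translates an action class name plus its `ActionState` into the user-facing `TrainStatus` (or `ServiceStatus` for deployment). A minimal sketch of the lookup, assuming the mapping is keyed by plain class names; the `resolve_status` helper is illustrative and not part of the SDK:

```python
from qianfan.trainer.actions import TrainAction, action_mapping
from qianfan.trainer.consts import ActionState, TrainStatus


def resolve_status(action_cls_name: str, state: ActionState) -> str:
    # Same lookup that LLMFinetune.status / PostPreTrain.status perform:
    # unknown actions or states fall back to TrainStatus.Unknown.
    return action_mapping.get(action_cls_name, {}).get(state, TrainStatus.Unknown)


print(resolve_status(TrainAction.__name__, ActionState.Running))  # TrainStatus.Training
print(resolve_status("SomeOtherAction", ActionState.Done))        # TrainStatus.Unknown
```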
import copy +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from qianfan.config import encoding from qianfan.errors import InvalidArgumentError +from qianfan.resources.console import consts as console_consts from qianfan.trainer.consts import PeftType from qianfan.utils import log_error, log_warn -from qianfan.utils.pydantic import BaseModel +from qianfan.utils.pydantic import BaseModel, Field T = TypeVar("T") +class LimitType(int, Enum): + SingleChoice = 1 + MultipleChoice = 2 + Range = 3 + + class TrainConfig(BaseModel): - epoch: Optional[int] = None + epoch: Optional[int] = Field(default=None, limit_type=LimitType.Range) """ epoch number: differ from models """ - batch_size: Optional[int] = None + batch_size: Optional[int] = Field(default=None, limit_type=LimitType.Range) """ batch size: differ from models """ - learning_rate: Optional[float] = None + learning_rate: Optional[float] = Field(default=None, limit_type=LimitType.Range) """ learning rate: differ from models """ - max_seq_len: Optional[int] = None + max_seq_len: Optional[int] = Field(default=None, limit_type=LimitType.SingleChoice) """ max_seq_len: differ from models """ - peft_type: Optional[Union[str, PeftType]] = None - """ - parameter efficient FineTuning method, like `LoRA`, `P-tuning`, `ALL` - """ - trainset_rate: int = 20 - """ - rate for dataset to spilt - """ - logging_steps: Optional[int] = None + logging_steps: Optional[int] = Field(default=None, limit_type=LimitType.Range) """log saving interval steps""" - warmup_ratio: Optional[float] = None + warmup_ratio: Optional[float] = Field(default=None, limit_type=LimitType.Range) """warmup ratio""" - weight_decay: Optional[float] = None + weight_decay: Optional[float] = Field(default=None, limit_type=LimitType.Range) """normalization params""" - lora_rank: Optional[int] = None + lora_rank: Optional[int] = Field(default=None, limit_type=LimitType.SingleChoice) """loRA rank""" - lora_all_linear: Optional[str] = None + lora_all_linear: Optional[str] = Field( + default=None, limit_type=LimitType.SingleChoice + ) """loRA all linear layer""" - scheduler_name: Optional[str] = None + scheduler_name: Optional[str] = Field( + default=None, limit_type=LimitType.SingleChoice + ) """for learning rate schedule""" - lora_alpha: Optional[int] = None + lora_alpha: Optional[int] = Field(default=None, limit_type=LimitType.Range) """LoRA scaling params""" - lora_dropout: Optional[float] = None + lora_dropout: Optional[float] = Field(default=None, limit_type=LimitType.Range) """loRA dropout""" + lora_target_modules: Optional[List[str]] = Field( + default=None, limit_type=LimitType.MultipleChoice + ) + """LoRA参数层列表""" + peft_type: Optional[Union[str, PeftType]] = None + """ + parameter efficient FineTuning method, like `LoRA`, `P-tuning`, `ALL` + """ + trainset_rate: int = 20 + """ + rate for dataset to spilt + """ extras: Dict[str, Any] = {} + """ + extra fields for train_config + """ @classmethod def load(cls, path: str) -> "TrainConfig": @@ -89,44 +108,35 @@ def load(cls, path: str) -> "TrainConfig": except Exception as e: raise e - # 后期考虑迁移到pydantic做动态校验,当前需要根据train_type和train_limit做校验 def validate_config(self, train_limit: "TrainLimit") -> bool: + schema = self.schema() res = True - res &= self._validate_range(self.epoch, [train_limit.epoch_limit], "epoch") - res &= self._validate_range( - self.batch_size, [train_limit.batch_size_limit], "batch_size" - ) - res &= self._validate_range( - self.learning_rate, [train_limit.learning_rate_limit], 
"learning_rate" - ) - res &= self._validate_options( - self.max_seq_len, train_limit.max_seq_len_options, "max_seq_len" - ) - res &= self._validate_range( - self.logging_steps, [train_limit.log_steps_limit], "logging_steps" - ) - res &= self._validate_range( - self.warmup_ratio, [train_limit.warmup_ratio_limit], "warmup_ratio" - ) - res &= self._validate_range( - self.weight_decay, [train_limit.weight_decay_limit], "weight_decay" - ) - res &= self._validate_options( - self.lora_alpha, train_limit.lora_alpha_options, "lora_alpha" - ) - res &= self._validate_options( - self.lora_rank, train_limit.lora_rank_options, "lora_rank" - ) - res &= self._validate_range( - self.lora_dropout, [train_limit.lora_dropout_limit], "lora_dropout" - ) - res &= self._validate_options( - self.scheduler_name, train_limit.scheduler_name_options, "scheduler_name" - ) + for k, v in schema["properties"].items(): + limit_type = v.get("limit_type") + if limit_type is None: + continue + value = getattr(self, k) + if value is None: + continue + if k not in train_limit: + log_warn( + f"train_config hyper params '{k}' is not in supported_params:" + f" {train_limit}" + ) + return False + if limit_type == LimitType.Range: + res &= self._validate_range(value, train_limit[k], k) + elif limit_type == LimitType.SingleChoice: + res &= self._validate_options(value, train_limit[k], k) + elif limit_type == LimitType.MultipleChoice: + for v in value: + res &= self._validate_options(v, train_limit[k], k) + if not res: + break return res def _validate_range( - self, value: Any, limit_ranges: List[Optional[Tuple[T, T]]], field_name: str + self, value: Any, limit_ranges: Optional[Tuple[T, T]], field_name: str ) -> bool: """ return False if value is not in limit_ranges @@ -142,15 +152,14 @@ def _validate_range( """ if value is None or limit_ranges is None: return True - for r in limit_ranges: - if r is None: - continue - if r[0] > value or r[1] < value: - log_warn( - f"train_config current {field_name} is {value}:" - f" but suggested {field_name} is in {r}" - ) - return False + if limit_ranges is None: + return True + if limit_ranges[0] > value or limit_ranges[1] < value: + log_warn( + f"train_config current {field_name} is {value}:" + f" but suggested {field_name} is in {limit_ranges}" + ) + return False return True def _validate_options( @@ -178,58 +187,20 @@ def _validate_options( return False return True - def validate_valid_fields(self, limit: "TrainLimit") -> str: - """ - return invalid field name if value is not in limit.supported_hyper_params - return "" if all fields are valid. 
- """ - supported_fields = limit.supported_hyper_params - for field in self.dict(exclude_none=True): - if field in ["peft_type", "extras", "trainset_rate"]: - continue - if field not in supported_fields: - log_warn( - f"train_config hyper params '{field}' is not in supported_params:" - f" {supported_fields}" - ) - return field - return "" - - -class TrainLimit(BaseModel): - batch_size_limit: Optional[Tuple[int, int]] = None - """batch size limit""" - max_seq_len_options: Optional[List[int]] = None - """max seq len options""" - epoch_limit: Optional[Tuple[int, int]] = None - """epoch limit""" - learning_rate_limit: Optional[Tuple[float, float]] = None - """learning rate limit""" - log_steps_limit: Optional[Tuple[int, int]] = None - """log steps limit""" - warmup_ratio_limit: Optional[Tuple[float, float]] = None - """warmup_ratio limit""" - weight_decay_limit: Optional[Tuple[float, float]] = None - """weight_decay limit""" - lora_rank_options: Optional[List[int]] = None - """loRA rank options""" - lora_alpha_options: Optional[List[int]] = None - """loRA alpha limit""" - lora_dropout_limit: Optional[Tuple[float, float]] = None - """loRA dropout limit""" - scheduler_name_options: Optional[List[str]] = None - """scheduler name options""" - - supported_hyper_params: List[str] = [] - """supported hyper params""" + +class TrainLimit(dict): + def __init__(self, **kwargs: Any): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__setitem__(k, v) def __or__(self, other: Any) -> "TrainLimit": assert isinstance(other, TrainLimit) # 使用copy模块深拷贝a的数据,避免修改原始数据 - merged_data = copy.deepcopy(self.dict()) + merged_data = copy.deepcopy(self) # 遍历b的字段,如果a中的值为None,则取b的值 - for field, value in other.dict().items(): + for field, value in other.items(): if merged_data.get(field) is None: merged_data[field] = value @@ -258,6 +229,60 @@ class ModelInfo(BaseModel): """special params suggestion of specific peft types""" +def get_model_info( + train_mode: console_consts.TrainMode, model: str +) -> Optional[ModelInfo]: + if train_mode == console_consts.TrainMode.PostPretrain: + return PostPreTrainModelInfoMapping.get(model) + elif train_mode == console_consts.TrainMode.SFT: + return ModelInfoMapping.get(model) + else: + return None + + +PostPreTrainModelInfoMapping: Dict[str, ModelInfo] = { + "ERNIE-Speed": ModelInfo( + short_name="ERNIE_Speed", + base_model_type="ERNIE-Speed", + support_peft_types=[PeftType.ALL], + common_params_limit=TrainLimit(), + specific_peft_types_params_limit={ + PeftType.ALL: TrainLimit( + epoch=(1, 10), + learning_rate=(0.00001, 0.00004), + max_seq_len=[4096, 8192], + ), + }, + ), + "ERNIE-Bot-turbo-0922": ModelInfo( + short_name="turbo_0922", + base_model_type="ERNIE-Bot-turbo", + support_peft_types=[PeftType.ALL], + common_params_limit=TrainLimit(), + specific_peft_types_params_limit={ + PeftType.ALL: TrainLimit( + epoch=(1, 10), + learning_rate=(0.00001, 0.00004), + max_seq_len=[4096, 8192], + ), + }, + ), + "Qianfan-Chinese-Llama-2-13B": ModelInfo( + short_name="Llama2_13b", + base_model_type="Llama-2", + support_peft_types=[PeftType.ALL], + common_params_limit=TrainLimit(), + specific_peft_types_params_limit={ + PeftType.ALL: TrainLimit( + batch_size=(48, 960), + epoch=(1, 1), + learning_rate=(0.0000002, 0.0002), + weight_decay=(0.0001, 0.05), + ), + }, + ), +} + # model train type -> default train config ModelInfoMapping: Dict[str, ModelInfo] = { "ERNIE-Speed": ModelInfo( @@ -265,38 +290,21 @@ class ModelInfo(BaseModel): base_model_type="ERNIE-Speed", 
support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - log_steps_limit=(1, 100), - warmup_ratio_limit=(0.01, 0.5), - weight_decay_limit=(0.0001, 0.1), + batch_size=(1, 4), + max_seq_len=[4096, 8192], + epoch=(1, 50), + logging_steps=(1, 100), + warmup_ratio=(0.01, 0.5), + weight_decay=(0.0001, 0.1), ), specific_peft_types_params_limit={ PeftType.ALL: TrainLimit( - learning_rate_limit=(0.00001, 0.00004), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "log_steps", - "warmup_ratio", - "weight_decay", - ], + learning_rate=(0.00001, 0.00004), ), PeftType.LoRA: TrainLimit( - learning_rate_limit=(0.00003, 0.001), - lora_rank_options=[2, 4, 8], - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "log_steps", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_all_linear", - ], + learning_rate=(0.00003, 0.001), + lora_rank=[2, 4, 8], + lora_all_linear=["True", "False"], ), }, ), @@ -305,38 +313,20 @@ class ModelInfo(BaseModel): base_model_type="ERNIE-Bot-turbo", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - log_steps_limit=(1, 100), - warmup_ratio_limit=(0.01, 0.5), - weight_decay_limit=(0.0001, 0.1), + batch_size=(1, 4), + max_seq_len=[4096, 8192], + epoch=(1, 50), + logging_steps=(1, 100), + warmup_ratio=(0.01, 0.5), + weight_decay=(0.0001, 0.1), ), specific_peft_types_params_limit={ PeftType.ALL: TrainLimit( - learning_rate_limit=(0.00001, 0.00004), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "log_steps", - "warmup_ratio", - "weight_decay", - ], + learning_rate=(0.00001, 0.00004), ), PeftType.LoRA: TrainLimit( - learning_rate_limit=(0.00003, 0.001), - lora_rank_options=[2, 4, 8], - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "log_steps", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_all_linear", - ], + learning_rate=(0.00003, 0.001), + lora_rank=[2, 4, 8], ), }, ), @@ -345,17 +335,15 @@ class ModelInfo(BaseModel): base_model_type="ERNIE-Bot-turbo", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), + max_seq_len=[4096, 8192], + epoch=(1, 50), ), specific_peft_types_params_limit={ PeftType.ALL: TrainLimit( - learning_rate_limit=(0.00001, 0.00004), - supported_hyper_params=["epoch", "learning_rate", "max_seq_len"], + learning_rate=(0.00001, 0.00004), ), PeftType.LoRA: TrainLimit( - learning_rate_limit=(0.00003, 0.001), - supported_hyper_params=["epoch", "learning_rate", "max_seq_len"], + learning_rate=(0.00003, 0.001), ), }, ), @@ -364,116 +352,98 @@ class ModelInfo(BaseModel): base_model_type="ERNIE-Bot-turbo", support_peft_types=[PeftType.ALL, PeftType.LoRA, PeftType.PTuning], common_params_limit=TrainLimit( - epoch_limit=(1, 50), + epoch=(1, 50), ), specific_peft_types_params_limit={ PeftType.PTuning: TrainLimit( - learning_rate_limit=(0.003, 0.1), - supported_hyper_params=["epoch", "learning_rate"], + learning_rate=(0.003, 0.1), ), PeftType.ALL: TrainLimit( - learning_rate_limit=(0.00001, 0.00004), - supported_hyper_params=["epoch", "learning_rate"], + learning_rate=(0.00001, 0.00004), ), PeftType.LoRA: TrainLimit( - learning_rate_limit=(0.00003, 0.001), - supported_hyper_params=["epoch", "learning_rate"], + learning_rate=(0.00003, 
0.001), ), }, ), - "Llama-2-7b": ModelInfo( + "Qianfan-Chinese-Llama-2-7B": ModelInfo( short_name="Llama2_7b", base_model_type="Llama-2", support_peft_types=[PeftType.ALL, PeftType.LoRA, PeftType.PTuning], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[1024, 2048, 4096], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000002, 0.0002), - scheduler_name_options=[ + batch_size=(1, 4), + max_seq_len=[1024, 2048, 4096], + epoch=(1, 50), + learning_rate=(0.0000002, 0.0002), + scheduler_name=[ "linear", "cosine", "polynomial", "constant", "constant_with_warmup", ], - weight_decay_limit=(0.001, 1), - warmup_ratio_limit=(0.01, 0.1), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - ], + weight_decay=(0.001, 1), + warmup_ratio=(0.01, 0.1), ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - lora_rank_options=[8, 16, 32, 64], - lora_alpha_options=[8, 16, 32, 64], - lora_dropout_limit=(0.1, 0.5), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ), }, ), - "Llama-2-13b": ModelInfo( + "Qianfan-Chinese-Llama-2-13B": ModelInfo( short_name="Llama2_13b", base_model_type="Llama-2", support_peft_types=[PeftType.ALL, PeftType.LoRA, PeftType.PTuning], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[1024, 2048, 4096], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000002, 0.0002), - scheduler_name_options=[ + batch_size=(1, 4), + max_seq_len=[1024, 2048, 4096], + epoch=(1, 50), + learning_rate=(0.0000002, 0.0002), + scheduler_name=[ "linear", "cosine", "polynomial", "constant", "constant_with_warmup", ], - weight_decay_limit=(0.001, 1), - warmup_ratio_limit=(0.01, 0.1), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", + weight_decay=(0.001, 1), + warmup_ratio=(0.01, 0.1), + ), + specific_peft_types_params_limit={ + PeftType.LoRA: TrainLimit( + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), + ), + }, + ), + "Qianfan-Chinese-Llama-2-7B-32K": ModelInfo( + short_name="Llama2_13b", + base_model_type="Llama-2", + support_peft_types=[PeftType.ALL, PeftType.LoRA, PeftType.PTuning], + common_params_limit=TrainLimit( + batch_size=(1, 1), + max_seq_len=[4096, 8192, 16384, 32768], + epoch=(1, 50), + learning_rate=(0.0000000001, 0.0002), + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], + weight_decay=(0.001, 1), + warmup_ratio=(0.01, 0.1), ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - lora_rank_options=[8, 16, 32, 64], - lora_alpha_options=[8, 16, 32, 64], - lora_dropout_limit=(0.1, 0.5), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ), }, ), @@ -482,37 +452,16 @@ class ModelInfo(BaseModel): base_model_type="SQLCoder", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - 
epoch_limit=(1, 50), - learning_rate_limit=(0.0000002, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - ], + batch_size=(1, 4), + max_seq_len=[4096, 8192], + epoch=(1, 50), + learning_rate=(0.0000002, 0.0002), ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - lora_rank_options=[8, 16, 32, 64], - lora_alpha_options=[8, 16, 32, 64], - lora_dropout_limit=(0.1, 0.5), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ), }, ), @@ -521,37 +470,25 @@ class ModelInfo(BaseModel): base_model_type="ChatGLM2", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000002, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", + epoch=(1, 50), + batch_size=(1, 4), + max_seq_len=[1024, 2048, 4096], + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], + learning_rate=(0.0000002, 0.0002), + warmup_ratio=(0.01, 0.1), + weight_decay=(0.001, 1), ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - lora_rank_options=[8, 16, 32, 64], - lora_alpha_options=[8, 16, 32, 64], - lora_dropout_limit=(0.1, 0.5), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ), }, ), @@ -560,16 +497,18 @@ class ModelInfo(BaseModel): base_model_type="ChatGLM2", support_peft_types=[PeftType.ALL], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000002, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "scheduler_name", - "warmup_ratio", - "weight_decay", + batch_size=(1, 4), + max_seq_len=[1024, 2048, 4096], + epoch=(1, 50), + learning_rate=(0.0000002, 0.0002), + warmup_ratio=(0.01, 0.1), + weight_decay=(0.001, 1), + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], ), ), @@ -578,38 +517,25 @@ class ModelInfo(BaseModel): base_model_type="Baichuan2", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", + batch_size=(1, 4), + max_seq_len=[1024, 2048, 4096], + epoch=(1, 50), + learning_rate=(0.0000000001, 0.0002), + warmup_ratio=(0.01, 0.1), + weight_decay=(0.001, 1), + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - supported_hyper_params=[ - 
"epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ) }, ), @@ -618,37 +544,25 @@ class ModelInfo(BaseModel): base_model_type="Baichuan2", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "scheduler_name", - "warmup_ratio", - "weight_decay", + batch_size=(1, 4), + max_seq_len=[1024, 2048, 4096], + epoch=(1, 50), + learning_rate=(0.0000000001, 0.0002), + warmup_ratio=(0.01, 0.1), + weight_decay=(0.001, 1), + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ) }, ), @@ -657,41 +571,24 @@ class ModelInfo(BaseModel): base_model_type="BLOOMZ", support_peft_types=[PeftType.ALL, PeftType.LoRA, PeftType.PTuning], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000002, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", + batch_size=(1, 4), + epoch=(1, 50), + learning_rate=(0.0000002, 0.0002), + warmup_ratio=(0.01, 0.1), + weight_decay=(0.001, 1), + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ) }, ), @@ -700,128 +597,341 @@ class ModelInfo(BaseModel): base_model_type="CodeLlama", support_peft_types=[PeftType.ALL, PeftType.LoRA], common_params_limit=TrainLimit( - batch_size_limit=(1, 4), - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", + batch_size=(1, 4), + epoch=(1, 50), + max_seq_len=[1024, 2048, 4096], + learning_rate=(0.0000000001, 0.0002), + warmup_ratio=(0.01, 0.1), + weight_decay=(0.001, 1), + scheduler_name=[ + "linear", + "cosine", + "polynomial", + "constant", + "constant_with_warmup", ], ), specific_peft_types_params_limit={ PeftType.LoRA: TrainLimit( - batch_size_limit=(1, 4), - max_seq_len_options=[4096, 8192], - epoch_limit=(1, 50), - learning_rate_limit=(0.0000000001, 0.0002), - 
supported_hyper_params=[ - "epoch", - "learning_rate", - "max_seq_len", - "batch_size", - "scheduler_name", - "warmup_ratio", - "weight_decay", - "lora_rank", - "lora_alpha", - "lora_dropout", - "lora_target_modules", - ], + lora_rank=[8, 16, 32, 64], + lora_alpha=[8, 16, 32, 64], + lora_dropout=(0.1, 0.5), ) }, ), } -# model train type -> default train config -DefaultTrainConfigMapping: Dict[str, TrainConfig] = { - "ERNIE-Speed": TrainConfig( - epoch=1, - learning_rate=0.0003, - max_seq_len=4096, - peft_type=PeftType.LoRA, - logging_steps=1, - warmup_ratio=0.10, - weight_decay=0.0100, - lora_rank=8, - lora_all_linear="True", - ), - "ERNIE-Bot-turbo-0922": TrainConfig( - epoch=1, - learning_rate=0.0003, - max_seq_len=4096, - peft_type=PeftType.LoRA, - logging_steps=1, - warmup_ratio=0.10, - weight_decay=0.0100, - lora_rank=8, - lora_all_linear="True", - ), - "ERNIE-Bot-turbo-0725": TrainConfig( - epoch=1, - learning_rate=0.00003, - max_seq_len=4096, - peft_type=PeftType.LoRA, - ), - "ERNIE-Bot-turbo-0704": TrainConfig( - epoch=1, - learning_rate=0.00003, - peft_type=PeftType.LoRA, - ), - "Llama-2-7b": TrainConfig( - epoch=1, - batch_size=4, - learning_rate=0.00002, - peft_type=PeftType.LoRA, - ), - "Llama-2-13b": TrainConfig( - epoch=1, - batch_size=1, - learning_rate=0.00002, - peft_type=PeftType.LoRA, - ), - "SQLCoder-7B": TrainConfig( - epoch=1, - batch_size=1, - learning_rate=0.00002, - peft_type=PeftType.LoRA, - ), - "ChatGLM2-6B": TrainConfig( - epoch=1, - batch_size=1, - learning_rate=0.00002, - peft_type=PeftType.LoRA, - ), - "ChatGLM2-6B-32K": TrainConfig( - epoch=1, - learning_rate=0.00002, - peft_type=PeftType.ALL, - ), - "Baichuan2-7B": TrainConfig( - epoch=1, - batch_size=1, - learning_rate=0.000001, - peft_type=PeftType.LoRA, - ), - "Baichuan2-13B": TrainConfig( - epoch=1, - learning_rate=0.000001, - peft_type=PeftType.LoRA, - ), - "BLOOMZ-7B": TrainConfig( - epoch=1, - batch_size=1, - learning_rate=0.00002, - peft_type=PeftType.LoRA, - ), - "CodeLlama-7B": TrainConfig( - epoch=1, - learning_rate=0.000001, - batch_size=1, - peft_type=PeftType.LoRA, - ), +DefaultPostPretrainTrainConfigMapping: Dict[str, Dict[PeftType, TrainConfig]] = { + "ERNIE-Speed": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.00003, + max_seq_len=4096, + peft_type=PeftType.ALL, + ) + }, + "ERNIE-Bot-turbo-0922": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.00003, + max_seq_len=4096, + ) + }, + "Qianfan-Chinese-Llama-2-13B": { + PeftType.ALL: TrainConfig( + epoch=1, + batch_size=192, + learning_rate=0.000020, + weight_decay=0.01, + ) + }, +} + +tc = TrainConfig(learning_rate=0.333) + +# finetune model train type -> default finetune train config +DefaultTrainConfigMapping: Dict[str, Dict[PeftType, TrainConfig]] = { + "ERNIE-Speed": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.00003, + max_seq_len=4096, + logging_steps=1, + warmup_ratio=0.1, + weight_decay=0.01, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.0003, + max_seq_len=4096, + logging_steps=1, + warmup_ratio=0.10, + weight_decay=0.0100, + lora_rank=8, + lora_all_linear="True", + ), + }, + "ERNIE-Bot-turbo-0922": { + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.0003, + max_seq_len=4096, + logging_steps=1, + warmup_ratio=0.10, + weight_decay=0.0100, + lora_rank=8, + lora_all_linear="True", + ), + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.00003, + max_seq_len=4096, + logging_steps=1, + warmup_ratio=0.1, + weight_decay=0.01, + ), + }, + "ERNIE-Bot-turbo-0725": { + 
PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.00003, + max_seq_len=4096, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.0003, + max_seq_len=4096, + ), + }, + "ERNIE-Bot-turbo-0704": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.00003, + ), + PeftType.PTuning: TrainConfig( + epoch=1, + learning_rate=0.03, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.00003, + ), + }, + "Qianfan-Chinese-Llama-2-7B": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.PTuning: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + }, + "Qianfan-Chinese-Llama-2-13B": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.PTuning: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + }, + "Qianfan-Chinese-Llama-2-7B-32K": { + PeftType.LoRA: TrainConfig( + epoch=3, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=32768, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + PeftType.ALL: TrainConfig( + epoch=3, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=32768, + ), + }, + "ChatGLM2-6B": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + }, + "ChatGLM2-6B-32K": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + ), + }, + "Baichuan2-7B": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + }, + "Baichuan2-13B": { + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), 
+ }, + "BLOOMZ-7B": { + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + ), + PeftType.PTuning: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + ), + }, + "CodeLlama-7B": { + PeftType.LoRA: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + lora_target_modules=["self_attn.q_proj", "self_attn.v_proj"], + lora_rank=32, + lora_alpha=32, + lora_dropout=0.1, + ), + PeftType.ALL: TrainConfig( + epoch=1, + learning_rate=0.000001, + batch_size=1, + scheduler_name="cosine", + warmup_ratio=0.03, + weight_decay=0.01, + max_seq_len=4096, + ), + }, } diff --git a/python/qianfan/trainer/consts.py b/python/qianfan/trainer/consts.py index 33412112..af0efe04 100644 --- a/python/qianfan/trainer/consts.py +++ b/python/qianfan/trainer/consts.py @@ -33,7 +33,7 @@ class ActionState(str, Enum): """`Stopped` stands for the state when stop() is called.""" -class FinetuneStatus(str, Enum): +class TrainStatus(str, Enum): Unknown = "Unknown" """未知状态""" DatasetLoading = "DatasetLoading" @@ -88,9 +88,9 @@ class ServiceStatus(str, Enum): class PeftType(str, Enum): - ALL = "ALL" + ALL = "FullFineTuning" """全量更新""" - PTuning = "P-tuning" + PTuning = "PromptTuning" """p-tuning""" LoRA = "LoRA" """LoRA""" diff --git a/python/qianfan/trainer/finetune.py b/python/qianfan/trainer/finetune.py index 7f9042cd..ae284d14 100644 --- a/python/qianfan/trainer/finetune.py +++ b/python/qianfan/trainer/finetune.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
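The default-config tables above are now nested: model name first, then `PeftType`, so LoRA and full fine-tuning can ship different suggested hyper-parameters, and `get_default_train_config` picks one entry per `(train_mode, model, peft_type)`. Note that the `PeftType` wire values also change to `FullFineTuning` / `PromptTuning` in `consts.py`. A small sketch of the lookup; the printed values are the ERNIE-Speed LoRA defaults declared in this patch:

```python
from qianfan.trainer.configs import DefaultTrainConfigMapping
from qianfan.trainer.consts import PeftType

# Defaults are keyed by model name and then by tuning method.
defaults = DefaultTrainConfigMapping["ERNIE-Speed"][PeftType.LoRA]
print(defaults.learning_rate, defaults.lora_rank)  # 0.0003 8

# The enum members keep their Python names, but the platform values change.
print(PeftType.ALL.value, PeftType.PTuning.value)  # FullFineTuning PromptTuning
```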
-from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union from qianfan.config import get_config -from qianfan.dataset.data_source import BosDataSource, QianfanDataSource from qianfan.errors import InvalidArgumentError from qianfan.evaluation.evaluator import Evaluator from qianfan.model.configs import DeployConfig @@ -25,6 +24,7 @@ LoadDataSetAction, ModelPublishAction, TrainAction, + action_mapping, ) from qianfan.trainer.base import ( BaseAction, @@ -38,9 +38,7 @@ TrainConfig, ) from qianfan.trainer.consts import ( - ActionState, - FinetuneStatus, - ServiceStatus, + TrainStatus, ) @@ -53,15 +51,17 @@ class LLMFinetune(Trainer): def __init__( self, - train_type: str, + train_type: Optional[str] = None, dataset: Optional[Any] = None, train_config: Optional[Union[TrainConfig, str]] = None, deploy_config: Optional[DeployConfig] = None, event_handler: Optional[EventHandler] = None, - base_model: Optional[str] = None, eval_dataset: Optional[Any] = None, evaluators: Optional[List[Evaluator]] = None, dataset_bos_path: Optional[str] = None, + previous_trainer: Optional[Trainer] = None, + previous_task_id: Optional[str] = None, + name: Optional[str] = None, **kwargs: Any, ) -> None: """ @@ -107,9 +107,8 @@ def __init__( ) ``` """ - # 校验train_type - if train_type is None or train_type == "": - raise InvalidArgumentError("train_type is empty") + # 设置name + self.name = name if isinstance(train_config, str): train_config = TrainConfig.load(train_config) @@ -117,43 +116,53 @@ def __init__( actions: List[BaseAction] = [] # 校验dataset if dataset is not None: - if dataset.inner_data_source_cache is None: - raise InvalidArgumentError("invalid dataset") - if isinstance(dataset.inner_data_source_cache, QianfanDataSource): - qf_data_src = cast(QianfanDataSource, dataset.inner_data_source_cache) - if ( - qf_data_src.template_type - != console_consts.DataTemplateType.NonSortedConversation - ): - raise InvalidArgumentError( - "dataset must be `non-sorted conversation` template in" - " llm-fine-tune" - ) - self.load_data_action = LoadDataSetAction( - dataset=dataset, event_handler=event_handler, **kwargs - ) - elif isinstance(dataset.inner_data_source_cache, BosDataSource): - self.load_data_action = LoadDataSetAction( - dataset=dataset, event_handler=event_handler, **kwargs + self.load_data_action = LoadDataSetAction( + dataset=dataset, + dataset_template=console_consts.DataTemplateType.NonSortedConversation, + event_handler=event_handler, + **kwargs, + ) + elif dataset_bos_path: + self.load_data_action = LoadDataSetAction( + dataset=dataset_bos_path, + event_handler=event_handler, + **kwargs, + ) + else: + raise InvalidArgumentError("either dataset or bos_path is required") + actions.append(self.load_data_action) + if previous_trainer: + # init an increment training + if hasattr(previous_trainer, "train_action"): + self.train_action = TrainAction( + train_config=train_config, + task_id=previous_trainer.train_action.task_id, + train_mode=console_consts.TrainMode.SFT, + job_name=name, + **kwargs, ) else: raise InvalidArgumentError( - "dataset must be either implemented with QianfanDataSource or" - " BosDataSource" + "invalid trainer input without previous train action" ) - actions.append(self.load_data_action) - elif dataset_bos_path: - self.dataset_bos_path = dataset_bos_path + elif previous_task_id: + self.train_action = TrainAction( + train_config=train_config, + task_id=previous_task_id, + train_mode=console_consts.TrainMode.SFT, + job_name=name, + **kwargs, 
+ ) else: - raise InvalidArgumentError("either dataset or bos_path is required") - self.train_action = TrainAction( - train_config=train_config, - base_model=base_model, - train_type=train_type, - train_mode=console_consts.TrainMode.SFT, - event_handler=event_handler, - **kwargs, - ) + # init train action from base model + self.train_action = TrainAction( + train_config=train_config, + train_type=train_type, + train_mode=console_consts.TrainMode.SFT, + event_handler=event_handler, + job_name=name, + **kwargs, + ) actions.append(self.train_action) if not kwargs.get("model_not_publish"): self.model_publish = ModelPublishAction( @@ -205,10 +214,6 @@ def run(self, **kwargs: Any) -> Trainer: kwargs["retry_count"] = kwargs.get( "retry_count", get_config().TRAINER_STATUS_POLLING_RETRY_TIMES ) - if not hasattr(self, "load_data_action") and self.dataset_bos_path is not None: - kwargs["input"] = { - "datasets": [{"id": 2, "bosPath": self.dataset_bos_path}] - } self.result[0] = self.ppls[0].exec(**kwargs) return self @@ -224,10 +229,10 @@ def status(self) -> str: raise InvalidArgumentError("invalid pipeline to get status") action = self.ppls[0][str(self.ppls[0]._state)] if action is None: - return FinetuneStatus.Unknown + return TrainStatus.Unknown action_name = action.__class__.__name__ - return fine_tune_action_mapping.get(action_name, {}).get( - action.state, FinetuneStatus.Unknown + return action_mapping.get(action_name, {}).get( + action.state, TrainStatus.Unknown ) def stop(self, **kwargs: Dict) -> Trainer: @@ -261,43 +266,3 @@ def output(self) -> Any: @classmethod def train_type_list(cls) -> Dict[str, ModelInfo]: return ModelInfoMapping - - -# mapping for action state -> fine-tune status -fine_tune_action_mapping: Dict[str, Dict[str, Any]] = { - LoadDataSetAction.__class__.__name__: { - ActionState.Preceding: FinetuneStatus.DatasetLoading, - ActionState.Running: FinetuneStatus.DatasetLoading, - ActionState.Done: FinetuneStatus.DatasetLoaded, - ActionState.Error: FinetuneStatus.DatasetLoadFailed, - ActionState.Stopped: FinetuneStatus.DatasetLoadStopped, - }, - TrainAction.__class__.__name__: { - ActionState.Preceding: FinetuneStatus.TrainCreated, - ActionState.Running: FinetuneStatus.Training, - ActionState.Done: FinetuneStatus.TrainFinished, - ActionState.Error: FinetuneStatus.TrainFailed, - ActionState.Stopped: FinetuneStatus.TrainStopped, - }, - ModelPublishAction.__class__.__name__: { - ActionState.Preceding: FinetuneStatus.ModelPublishing, - ActionState.Running: FinetuneStatus.ModelPublishing, - ActionState.Done: FinetuneStatus.ModelPublished, - ActionState.Error: FinetuneStatus.ModelPublishFailed, - ActionState.Stopped: FinetuneStatus.ModelPublishFailed, - }, - DeployAction.__class__.__name__: { - ActionState.Preceding: ServiceStatus.Created, - ActionState.Running: ServiceStatus.Deploying, - ActionState.Done: ServiceStatus.Deployed, - ActionState.Error: ServiceStatus.DeployFailed, - ActionState.Stopped: ServiceStatus.DeployStopped, - }, - EvaluateAction.__class__.__name__: { - ActionState.Preceding: FinetuneStatus.EvaluationCreated, - ActionState.Running: FinetuneStatus.EvaluationRunning, - ActionState.Done: FinetuneStatus.EvaluationFinished, - ActionState.Error: FinetuneStatus.EvaluationFailed, - ActionState.Stopped: FinetuneStatus.EvaluationStopped, - }, -} diff --git a/python/qianfan/trainer/post_pretrain.py b/python/qianfan/trainer/post_pretrain.py new file mode 100644 index 00000000..051a48d9 --- /dev/null +++ b/python/qianfan/trainer/post_pretrain.py @@ -0,0 +1,194 @@ +# Copyright 
(c) 2023 Baidu, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Union
+
+from qianfan.config import get_config
+from qianfan.dataset import Dataset
+from qianfan.errors import InvalidArgumentError
+from qianfan.resources.console import consts as console_consts
+from qianfan.trainer.actions import (
+    LoadDataSetAction,
+    TrainAction,
+    action_mapping,
+)
+from qianfan.trainer.base import (
+    BaseAction,
+    EventHandler,
+    Pipeline,
+    Trainer,
+)
+from qianfan.trainer.configs import (
+    ModelInfo,
+    PostPreTrainModelInfoMapping,
+    TrainConfig,
+)
+from qianfan.trainer.consts import (
+    TrainStatus,
+)
+
+
+class PostPreTrain(Trainer):
+    """
+    Class implements the PostPreTrain training pipeline with several actions.
+    Use `run()` to synchronously run the training pipeline until the
+    model training pipeline is finished.
+    """
+
+    def __init__(
+        self,
+        train_type: str,
+        dataset: Optional[Union[Dataset, str]] = None,
+        train_config: Optional[Union[TrainConfig, str]] = None,
+        event_handler: Optional[EventHandler] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Initialization function for post-pretrain training.
+
+        Parameters:
+            train_type: str
+                A string representing the model version type,
+                like 'ERNIE-Speed', 'ERNIE-Bot-turbo-0922'
+            dataset: Optional[Union[Dataset, str]] = None,
+                A post-pretrain dataset instance, or a BOS path
+                pointing to the post-pretrain data.
+            train_config: TrainConfig
+                A TrainConfig for post-pretrain training parameters.
+                If not provided, default parameters of diverse
+                models will be used.
+            event_handler: EventHandler
+                An EventHandler instance for receiving events during
+                the training process
+            **kwargs: Any additional keyword arguments.
+
+        for calling example:
+        ```
+        ds = Dataset.load(qianfan_dataset_id="", ...)
+        ppt_task = PostPreTrain(
+            train_type="ERNIE-Speed",
+            dataset=ds,
+            train_config=TrainConfig(...),
+            event_handler=eh,
+        )
+        ```
+        """
+        # validate train_type
+        if train_type is None or train_type == "":
+            raise InvalidArgumentError("train_type is empty")
+
+        if isinstance(train_config, str):
+            train_config = TrainConfig.load(train_config)
+
+        actions: List[BaseAction] = []
+        # initialize the dataset loading action
+        self.load_data_action = LoadDataSetAction(
+            dataset,
+            console_consts.DataTemplateType.GenericText,
+            event_handler=event_handler,
+            **kwargs,
+        )
+        actions.append(self.load_data_action)
+        # initialize the train action
+        self.train_action = TrainAction(
+            train_config=train_config,
+            train_type=train_type,
+            train_mode=console_consts.TrainMode.PostPretrain,
+            event_handler=event_handler,
+            **kwargs,
+        )
+        actions.append(self.train_action)
+        ppl = Pipeline(
+            actions=actions,
+            event_handler=event_handler,
+        )
+        self.ppls = [ppl]
+        self.result = [None]
+
+    def run(self, **kwargs: Any) -> Trainer:
+        """
+        Run a pipeline to execute the post-pretrain process.
+
+        Parameters:
+            **kwargs:
+                Any additional keyword arguments.
+ {"input": {}} could be specified if needed + + Raises: + InvalidArgumentError: no pipeline bind + to run. + Returns: + Trainer: + self, for chain invocation. + """ + self.input: Any = kwargs.get("input") + if len(self.ppls) != 1: + raise InvalidArgumentError("invalid pipeline to run") + kwargs["backoff_factor"] = kwargs.get( + "backoff_factor", get_config().TRAINER_STATUS_POLLING_BACKOFF_FACTOR + ) + kwargs["retry_count"] = kwargs.get( + "retry_count", get_config().TRAINER_STATUS_POLLING_RETRY_TIMES + ) + self.result[0] = self.ppls[0].exec(**kwargs) + return self + + @property + def status(self) -> str: + """ + PostPreTrain status getter. + + Returns: + str: status for PostPreTrain, mapping from state of actions in pipeline. + """ + if len(self.ppls) != 1: + raise InvalidArgumentError("invalid pipeline to get status") + action = self.ppls[0][str(self.ppls[0]._state)] + if action is None: + return TrainStatus.Unknown + action_name = action.__class__.__name__ + return action_mapping.get(action_name, {}).get( + action.state, TrainStatus.Unknown + ) + + def stop(self, **kwargs: Dict) -> Trainer: + """ + stop method of PostPreTrain. PostPreTrain will stop + all actions in pipeline. In fact, PostPreTrain only take one + pipeline, so it will be equal to stop first of `ppls`. + + Returns: + Trainer: + self, for chain invocation. + """ + for ppl in self.ppls: + ppl.stop() + return self + + def resume(self, **kwargs: Dict) -> "PostPreTrain": + """ + PostPreTrain resume method. + + Returns: + PostPreTrain: _description_ + """ + self.result[0] = self.ppls[0].resume(**kwargs) + return self + + @property + def output(self) -> Any: + return self.result[0] + + @classmethod + def train_type_list(cls) -> Dict[str, ModelInfo]: + return PostPreTrainModelInfoMapping diff --git a/python/qianfan/utils/bos_uploader.py b/python/qianfan/utils/bos_uploader.py index cfe4f859..af53371e 100644 --- a/python/qianfan/utils/bos_uploader.py +++ b/python/qianfan/utils/bos_uploader.py @@ -15,6 +15,7 @@ utility for uploading content to bos """ +import re from pathlib import Path from typing import Any, Dict, Optional, Tuple @@ -125,6 +126,16 @@ def generate_bos_file_parent_path(bucket_name: str, absolute_path: str) -> str: return f"bos:{p.parent}" +def is_valid_bos_path(path: str) -> bool: + pattern = r"^bos:/([a-zA-Z0-9_-]+(\/)?)*$" + match = re.match(pattern, path) + + if match: + return True + else: + return False + + def parse_bos_path(bos_path: str) -> Tuple[str, str]: """解析 bos 路径,返回 bucket 和 path""" if bos_path.startswith("bos://"): diff --git a/python/qianfan/utils/logging.py b/python/qianfan/utils/logging.py index f2642c40..a00ff7e8 100644 --- a/python/qianfan/utils/logging.py +++ b/python/qianfan/utils/logging.py @@ -18,6 +18,10 @@ from functools import partial from typing import Any +TRACE_LEVEL = 5 + +logging.addLevelName(TRACE_LEVEL, "TRACE") + class Logger(object): _DEFAULT_MSG_FORMAT = ( @@ -44,7 +48,7 @@ def __init__( # 创建一个loggger self.__name = name self._logger = logging.getLogger(self.__name) - self._logger.setLevel(logging.WARN) + self._logger.setLevel(logging.INFO) formatter = logging.Formatter(format, datefmt) handler = logging.StreamHandler() handler.setFormatter(formatter) @@ -100,6 +104,16 @@ def warn(self, message: object, *args: object, **params: Any) -> None: """ self._logger.warning(message, *args, **params) + def trace(self, message: object, *args: object, **params: Any) -> None: + """ + TRACE level log + Args: + message (object): message content + Returns: + None + """ + 
self._logger.log(TRACE_LEVEL, message, *args, **params)
+
 
 logger = Logger()
 
@@ -109,11 +123,13 @@ def warn(self, message: object, *args: object, **params: Any) -> None:
     log_debug = logger.debug
     log_error = logger.error
     log_warn = logger.warn
+    log_trace = logger.trace
 else:
     log_info = partial(logger.info, stacklevel=2)
     log_debug = partial(logger.debug, stacklevel=2)
     log_error = partial(logger.error, stacklevel=2)
     log_warn = partial(logger.warn, stacklevel=2)
+    log_trace = partial(logger.trace, stacklevel=2)
 
 
 def enable_log(log_level: int = logging.INFO) -> None:
diff --git a/python/qianfan/utils/utils.py b/python/qianfan/utils/utils.py
index 28afcc71..c11a081c 100644
--- a/python/qianfan/utils/utils.py
+++ b/python/qianfan/utils/utils.py
@@ -183,6 +183,10 @@ def snake_to_camel(name: str) -> str:
     return "".join([x.capitalize() for x in name.split("_")])
 
 
+def first_lower_case(name: str) -> str:
+    return name[:1].lower() + name[1:]
+
+
 def remove_suffix(name: str, suffix: str) -> str:
     if name.endswith(suffix):
         return name[: -len(suffix)]
diff --git a/src/qianfan/tests/semantic_kernel/__init__.py b/src/qianfan/tests/semantic_kernel/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/qianfan/tests/semantic_kernel/chat_test.py b/src/qianfan/tests/semantic_kernel/chat_test.py
new file mode 100644
index 00000000..f52c1e4b
--- /dev/null
+++ b/src/qianfan/tests/semantic_kernel/chat_test.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import sys
+
+import pytest
+
+from qianfan.extensions.semantic_kernel.connectors.qianfan_chat_completion import (
+    QianfanChatCompletion,
+)
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 8, 1), reason="requires Python 3.8.1 or higher"
+)
+def test_chat():
+    chat = QianfanChatCompletion()
+    res = asyncio.run(
+        chat.complete_chat_async(messages=[{"role": "user", "content": "你好"}])
+    )
+    print(res)
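`TrainAction` now builds the v2 fine-tuning request by dumping the `TrainConfig` fields and rewriting the snake_case names into lowerCamelCase with `snake_to_camel` plus the new `first_lower_case` helper, instead of maintaining a hand-written field mapping. A minimal sketch of that conversion; the sample hyper-parameter values are illustrative only:

```python
from qianfan.utils.utils import first_lower_case, snake_to_camel

# Hypothetical hyper-parameters as they appear on TrainConfig (snake_case).
hyper_params = {"learning_rate": 0.00003, "max_seq_len": 4096, "lora_rank": 8}

# Convert the keys the same way TrainAction does before calling the v2 API.
payload = {
    first_lower_case(snake_to_camel(key)): value
    for key, value in hyper_params.items()
    if value is not None
}
print(payload)  # {'learningRate': 3e-05, 'maxSeqLen': 4096, 'loraRank': 8}
```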
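With `TrainLimit` turned into a dict, per-model limits are merged with `|` (peft-specific values win, missing keys fall back to the common limits) and `TrainConfig.validate_config` walks the `limit_type` markers attached to each pydantic field. A sketch of a manual check with illustrative values, roughly what the trainer does before submitting a job:

```python
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.configs import TrainConfig, get_model_info
from qianfan.trainer.consts import PeftType

# Hypothetical LoRA config to check against the ERNIE-Speed limits above.
config = TrainConfig(
    epoch=2,
    learning_rate=0.0003,
    lora_rank=8,
    peft_type=PeftType.LoRA,
)

model_info = get_model_info(console_consts.TrainMode.SFT, "ERNIE-Speed")
if model_info is not None:
    limit = model_info.common_params_limit
    specific = model_info.specific_peft_types_params_limit.get(PeftType.LoRA)
    if specific is not None:
        # peft-specific limits take precedence, common limits fill the gaps
        limit = specific | limit
    print(config.validate_config(limit))  # True when every value is in range
```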