Skip to content

Commit

Permalink
refactor[trainer]: 修改动态TrainConfig懒加载 & 新增Trainer 命令行查看trainer任务信息 (#679)
Browse files Browse the repository at this point in the history

* refactor[trainer]: 修改动态TrainConfig懒加载 & 新增Trainer 命令行查看trainer任务信息

* fix: remove log

* fix: add qianfan cache command & fix finetune client train from file

* fix: print

* fix: model batch_infer

* fix: cache credential
  • Loading branch information
danielhjz authored Jul 19, 2024
1 parent d7b2d93 commit 92a98f9
Show file tree
Hide file tree
Showing 14 changed files with 2,225 additions and 958 deletions.
1,939 changes: 1,804 additions & 135 deletions python/poetry.lock

Large diffs are not rendered by default.

23 changes: 12 additions & 11 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ aiolimiter = ">=1.1.0"
importlib-metadata = { version = ">=1.4.0", python = "<=3.7" }
bce-python-sdk = ">=0.8.79"
typing-extensions = { version = ">=4.0.0", python = "<=3.10" }
pydantic = "*"
pydantic = ">=1.0"
python-dotenv = [
{ version = "<=0.21.1", python = "<3.8" },
{ version = ">=1.0", python = ">=3.8" }
]
tenacity = "^8.2.3"
multiprocess = "*"
multiprocess = ">=0.70.12"
langchain = { version = ">=0.1.10", python = ">=3.8.1", optional = true }
langchain-community = { version = ">=0.2.0", python = ">=3.8.1", optional = true}
langchain-community = { version = ">=0.2.0", python = ">=3.8.1", optional = true }
numpy = [
{ version = "<1.22.0", python = ">=3.7 <3.8", optional = true },
{ version = ">=1.22.0", python = ">=3.8", optional = true }
Expand All @@ -39,26 +39,27 @@ pyarrow = [
{ version = ">=14.0.1", python = ">=3.8", optional = true },
{ version = "<=12.0.1", python = ">=3.7 <3.8", optional = true }
]
locust = { version = "*", optional = true }
tabulate = { version = "*", optional = true }
locust = { version = ">=2.9.0", optional = true }
tabulate = { version = ">=0.9.0", optional = true }
python-dateutil = { version = "^2.8.2", optional = true }
rich = ">=13.0.0"
typer = ">=0.9.0"
pyyaml = "^6.0.1"
prompt-toolkit = ">=3.0.38"
torch = [
{ version = "<=1.13.1", python = "<3.8", optional = true},
{ version = "<=1.13.1", python = "<3.8", optional = true },
{ version = ">=1.4.0", python = ">=3.8", optional = true }
]
ltp = { version = ">=4.2.0", optional = true }
emoji = { version = ">=2.2.0", optional = true }
sentencepiece = { version = ">=0.1.98", optional = true }
diskcache = "^5.6.3"
diskcache = ">=5.6.3"
cachetools = ">=5.0.0"

ijson = { version = "*", optional = true }
fastapi = { version = "*", optional = true }
uvicorn = { version = "*", optional = true }
filelock = { version = "*", optional = true}
ijson = { version = ">=3.0", optional = true }
fastapi = { version = ">=0.85.0", optional = true }
uvicorn = { version = ">=0.15.0", optional = true }
filelock = { version = ">=3.7.0", optional = true }

[tool.poetry.scripts]
qianfan = "qianfan.common.client.main:main"
Expand Down
26 changes: 24 additions & 2 deletions python/qianfan/common/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

import click
Expand All @@ -33,7 +32,7 @@
print_error_msg,
print_info_msg,
)
from qianfan.config import encoding
from qianfan.config import encoding, get_config
from qianfan.utils.utils import check_dependency

app = typer.Typer(
Expand All @@ -54,6 +53,29 @@
_enable_traceback = False


@app.command(name="cache")
def clear(
    # NOTE(review): this option's value is never inspected below — the cache
    # directory is removed whenever the command runs. Kept for CLI
    # backward-compatibility; confirm whether gating on it was intended.
    clear: Optional[bool] = typer.Option(
        None,
        "--clear",
        help="clear qianfan cache",
    ),
) -> None:
    """
    Remove the local qianfan cache directory.

    Deletes the directory tree at ``get_config().CACHE_DIR`` and prints the
    outcome. Failures (e.g. directory missing or permission denied) are
    reported instead of raised.
    """
    # Local import keeps startup of the CLI fast; shutil is only needed here.
    import shutil

    # Directory to delete: the SDK-wide cache location from global config.
    dir_path = get_config().CACHE_DIR
    try:
        shutil.rmtree(dir_path)
        print_info_msg(f"目录 {dir_path} 已删除")
    except OSError as e:
        # A failed deletion is an error, not informational output.
        print_error_msg(f"删除目录 {dir_path} 失败: {e}")


@app.command(name="openai")
@credential_required
def openai(
Expand Down
67 changes: 48 additions & 19 deletions python/qianfan/common/client/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.


import json
import time
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -41,7 +42,7 @@
from qianfan.model.configs import DeployConfig
from qianfan.model.consts import ServiceType
from qianfan.resources.console.consts import DeployPoolType, FinetuneSupportModelType
from qianfan.trainer import DPO, LLMFinetune, PostPreTrain
from qianfan.trainer import DPO, Finetune, PostPreTrain, Trainer
from qianfan.trainer.actions import (
DeployAction,
EvaluateAction,
Expand Down Expand Up @@ -113,18 +114,18 @@ def handle_train(self, event: Event) -> None:
elif status == "Stopped":
self.progress.log("Task stopped.")
return

if not self.vdl_printed:
self.progress.log(
f"{result['trainMode']} task id: {resp['result']['taskId']}, job"
f" id: {resp['result']['jobId']}, jobName:"
f" {resp['result']['jobName']}"
)
self.progress.log(
"Check this vdl link to view training progress: "
+ resp["result"]["vdlLink"]
)
self.vdl_printed = True
vdl_link = resp["result"].get("vdlLink")
if vdl_link:
self.progress.log(
"Check this vdl link to view training progress: " + vdl_link
)
self.vdl_printed = True

if event.action_state == ActionState.Done:
if self.current_task is not None:
Expand Down Expand Up @@ -220,7 +221,7 @@ def list_train_type(
if cmd == "postpretrain":
model_list = PostPreTrain.train_type_list()
elif cmd in ["finetune", "run"]:
model_list = LLMFinetune.train_type_list()
model_list = Finetune.train_type_list()
elif cmd in ["dpo"]:
model_list = DPO.train_type_list()
else:
Expand Down Expand Up @@ -306,7 +307,7 @@ def show_config_limit(
if cmd == "postpretrain":
model_list = PostPreTrain.train_type_list()
elif cmd in ["finetune", "run"]:
model_list = LLMFinetune.train_type_list()
model_list = Finetune.train_type_list()
elif cmd in ["dpo"]:
model_list = DPO.train_type_list()
else:
Expand Down Expand Up @@ -360,12 +361,15 @@ def finetune(
None, help="Task id of previous trainer output."
),
trainer_pipeline_file: Optional[str] = typer.Option(
None, help="Trainer pipeline file path"
None, "--trainer-pipeline-file", "-f", help="Trainer pipeline file path"
),
daemon: Optional[bool] = daemon_option,
list_train_type: Optional[bool] = list_train_type_option,
show_config_limit: Optional[str] = typer.Option(
None,
"--show-config-limit",
"--show",
"-s",
callback=show_config_limit,
is_eager=True,
help="Show config limit for specified train type.",
Expand Down Expand Up @@ -445,7 +449,7 @@ def finetune(
callback = MyEventHandler(console=console)

if trainer_pipeline_file is not None:
trainer = LLMFinetune.load(file=trainer_pipeline_file)
trainer = Finetune.load(file=trainer_pipeline_file)
trainer.register_event_handler(callback)
else:
ds = None
Expand All @@ -465,7 +469,7 @@ def finetune(
pool_type=DeployPoolType[deploy_pool_type],
service_type=ServiceType[deploy_service_type],
)
trainer = LLMFinetune(
trainer = Finetune(
dataset=ds,
train_type=train_type,
event_handler=callback,
Expand Down Expand Up @@ -515,6 +519,36 @@ def finetune(
time.sleep(0.1)


@trainer_app.command()
@credential_required
def info(
    trainer_id: Optional[str] = typer.Option(None, help="trainer id"),
    task_id: Optional[str] = typer.Option(None, help="task id"),
) -> None:
    """
    Show a trainer's info from the local cache.

    Looks the trainer up either directly by ``trainer_id`` or by scanning all
    cached Finetune trainers for a TrainAction whose ``task_id`` matches.
    Prints the trainer info as pretty JSON; prints nothing when no match is
    found. Exactly one of the two options should be supplied.
    """
    console = replace_logger_handler()
    trainer: Optional[Trainer] = None
    if trainer_id:
        trainer = Finetune.load(id=trainer_id)
    elif task_id:
        # Scan every cached trainer's actions for a matching train task.
        # Stop at the first match — the bare inner `break` alone would only
        # exit the action loop and keep scanning remaining trainers.
        found = False
        for t in Finetune.list():
            for action in t.actions:
                if isinstance(action, TrainAction) and action.task_id == task_id:
                    trainer = t
                    found = True
                    break
            if found:
                break
    else:
        console.log("Must provide either trainer id or task id.")
    if trainer:
        json_str = json.dumps(trainer.info(), ensure_ascii=False, indent=2)
        print(json_str)

    # wait a second for the log to be flushed
    time.sleep(0.1)


@trainer_app.command()
@credential_required
def postpretrain(
Expand Down Expand Up @@ -665,11 +699,6 @@ def postpretrain(
time.sleep(0.1)


@trainer_app.command(
"run",
deprecated=True,
help="Run a dpo trainer task.",
)
@trainer_app.command()
@credential_required
def dpo(
Expand Down Expand Up @@ -768,7 +797,7 @@ def dpo(
callback = MyEventHandler(console=console)

if trainer_pipeline_file is not None:
trainer = LLMFinetune.load(file=trainer_pipeline_file)
trainer = Finetune.load(file=trainer_pipeline_file)
trainer.register_event_handler(callback)
else:
ds = None
Expand All @@ -788,7 +817,7 @@ def dpo(
pool_type=DeployPoolType[deploy_pool_type],
service_type=ServiceType[deploy_service_type],
)
trainer = LLMFinetune(
trainer = Finetune(
dataset=ds,
train_type=train_type,
event_handler=callback,
Expand Down
4 changes: 2 additions & 2 deletions python/qianfan/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,10 +1634,10 @@ def _batch_inference_on_model(
log_error(err_msg)
raise ValueError(err_msg)

model_id = Model.detail(model_id)["result"]["modelIdStr"]
model_set_id = Model.detail(model_id)["result"]["modelIdStr"]

result_dataset_id = _start_an_evaluation_task_for_model_batch_inference(
self.inner_data_source_cache, model_id, model_id
self.inner_data_source_cache, model_set_id, model_id
)

result_dataset = Dataset.load(
Expand Down
14 changes: 9 additions & 5 deletions python/qianfan/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ def __init__(
self.name = name
if id is None or set_id is None:
self.auto_complete_info()
if id is None and set_id is None:
log_warn("set id or id should be provided")
if (
(id is None and set_id is None)
or self.task_id is None
or self.job_id is None
):
log_warn("set_id/id or job_id/task_id should be provided")

def exec(
self, input: Optional[Dict] = None, **kwargs: Dict
Expand Down Expand Up @@ -386,13 +390,13 @@ def compress(
comp_task_detail_resp["result"]["status"]
== console_const.ModelCompTaskStatus.Succeeded.value
):
new_model_version_id = comp_task_detail_resp["result"].get("modelId")
new_model_id = comp_task_detail_resp["result"].get("modelId")
log_info(
f"compress task {model_comp_task_id} run with status"
f" {comp_task_detail_resp['result']['status']}"
f" new model_version_id: {new_model_version_id}"
f" new model_id: {new_model_id}"
)
new_model = Model(id=new_model_version_id)
new_model = Model(id=new_model_id)
new_model.auto_complete_info()
return new_model
else:
Expand Down
8 changes: 4 additions & 4 deletions python/qianfan/tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_trainer_sft_run():

eh = MyEventHandler()
sft_task = LLMFinetune(
train_type="ERNIE-Bot-turbo-0725",
train_type="ERNIE-Speed-8K",
dataset=ds,
train_config=train_config,
event_handler=eh,
Expand All @@ -157,11 +157,11 @@ def test_trainer_sft_run():
def test_trainer_sft_run_from_bos():
with pytest.raises(InvalidArgumentError):
sft_task = LLMFinetune(
train_type="ERNIE-Bot-turbo-0725",
train_type="ERNIE-Speed-8K",
)
sft_task.run()
sft_task = LLMFinetune(
train_type="ERNIE-Bot-turbo-0725", dataset_bos_path="bos:/sdk-test/"
train_type="ERNIE-Speed-8K", dataset_bos_path="bos:/sdk-test/"
)
sft_task.run()
res = sft_task.result
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_trainer_resume():
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

sft_task = LLMFinetune(
train_type="ERNIE-Bot-turbo-0725",
train_type="ERNIE-Speed-8K",
dataset=ds,
)
ppl = sft_task.ppls[0]
Expand Down
Loading

0 comments on commit 92a98f9

Please sign in to comment.