feat: support start & wait; fix models (#282)
* feat: support start & wait; fix models

* fix: compatibility and add log redirect

* fix: lint

* fix: lint

* fix: lint about platform

* fix: lint ignore windows

* fix: deps & add comments
danielhjz authored Feb 22, 2024
1 parent bb72e8d commit 4bbf810
Showing 11 changed files with 213 additions and 29 deletions.
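The headline feature is an asynchronous trainer workflow: start() launches the training pipeline in a background process (with its logs redirected to a dated file under a local cache directory), wait() blocks until that process exits, and stop() asks it to shut down. A minimal usage sketch under assumptions — the dataset id, train_type, and constructor wiring below are illustrative and not taken from this commit:

from qianfan.dataset import Dataset
from qianfan.trainer import LLMFinetune

# hypothetical dataset id; QIANFAN_ACCESS_KEY / QIANFAN_SECRET_KEY are assumed in the environment
ds = Dataset.load(qianfan_dataset_id="ds-hypothetical-id")
trainer = LLMFinetune(train_type="ERNIE-Speed", dataset=ds)

trainer.start()   # run the pipeline in a background process
# logs land under .qianfan_trainer_cache/<trainer name>/<YYYY-MM-DD>.log
trainer.wait()    # block until the background process exits
# trainer.stop()  # alternatively, signal the subprocess to stop early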
4 changes: 3 additions & 1 deletion go/qianfan/chat_completion.go
@@ -86,7 +86,9 @@ var ChatModelEndpoint = map[string]string{
"ERNIE-Bot": "/chat/completions",
"ERNIE-Bot-4": "/chat/completions_pro",
"ERNIE-Bot-8k": "/chat/ernie_bot_8k",
"ERNIE-Speed": "/chat/eb_speed",
"ERNIE-3.5-4K-0205": "/chat/ernie-3.5-4k-0205",
"ERNIE-3.5-8K-0205": "/chat/ernie-3.5-8k-0205",
"ERNIE-Speed": "/chat/ernie_speed",
"ERNIE-Bot-turbo-AI": "/chat/ai_apaas",
"EB-turbo-AppBuilder": "/chat/ai_apaas",
"BLOOMZ-7B": "/chat/bloomz_7b1",
3 changes: 2 additions & 1 deletion python/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qianfan"
version = "0.3.1"
version = "0.3.2"
description = "文心千帆大模型平台 Python SDK"
authors = []
license = "Apache-2.0"
@@ -49,6 +49,7 @@ torch = [
ltp = { version = "*", optional = true}
emoji = { version = "*", optional = true}
sentencepiece = { version = "*", optional = true}
multiprocess = "*"

[tool.poetry.scripts]
qianfan = "qianfan.common.client.main:main"
32 changes: 32 additions & 0 deletions python/qianfan/resources/llm/chat_completion.py
@@ -111,6 +111,38 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
"user_id",
},
),
"ERNIE-3.5-4K-0205": QfLLMInfo(
endpoint="/chat/ernie-3.5-4k-0205",
required_keys={"messages"},
optional_keys={
"functions",
"temperature",
"top_p",
"penalty_score",
"stream",
"system",
"stop",
"disable_search",
"enable_citation",
"user_id",
},
),
"ERNIE-3.5-8K-0205": QfLLMInfo(
endpoint="/chat/ernie-3.5-8k-0205",
required_keys={"messages"},
optional_keys={
"functions",
"temperature",
"top_p",
"penalty_score",
"stream",
"system",
"stop",
"disable_search",
"enable_citation",
"user_id",
},
),
"ERNIE-Speed": QfLLMInfo(
endpoint="/chat/ernie_speed",
required_keys={"messages"},
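For context, calling one of the newly registered models goes through the SDK's usual ChatCompletion entry point. A short sketch — the prompt and the printed field are illustrative, and credentials are assumed to be set via QIANFAN_ACCESS_KEY / QIANFAN_SECRET_KEY:

import qianfan

chat = qianfan.ChatCompletion(model="ERNIE-3.5-8K-0205")  # one of the models added above
resp = chat.do(messages=[{"role": "user", "content": "Hello"}])
print(resp["result"])  # generated reply text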
32 changes: 32 additions & 0 deletions python/qianfan/resources/llm/completion.py
@@ -113,6 +113,38 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
"user_id",
},
),
"ERNIE-3.5-4K-0205": QfLLMInfo(
endpoint="/chat/ernie-3.5-4k-0205",
required_keys={"messages"},
optional_keys={
"functions",
"temperature",
"top_p",
"penalty_score",
"stream",
"system",
"stop",
"disable_search",
"enable_citation",
"user_id",
},
),
"ERNIE-3.5-8K-0205": QfLLMInfo(
endpoint="/chat/ernie-3.5-8k-0205",
required_keys={"messages"},
optional_keys={
"functions",
"temperature",
"top_p",
"penalty_score",
"stream",
"system",
"stop",
"disable_search",
"enable_citation",
"user_id",
},
),
"ERNIE-Speed": QfLLMInfo(
endpoint="/chat/ernie_speed",
required_keys={"messages"},
2 changes: 1 addition & 1 deletion python/qianfan/tests/trainer_test.py
@@ -222,7 +222,7 @@ def test_trainer_resume():
train_action_key = k
ac.task_id = 112
ac.job_id = 123
ppl._state = train_action_key
ppl.current_action = train_action_key
sft_task.resume()
res = sft_task.result
assert res is not None
7 changes: 5 additions & 2 deletions python/qianfan/trainer/actions.py
@@ -597,8 +597,11 @@ def stop(self, **kwargs: Dict) -> None:
if self.task_id is None or self.job_id is None:
log_warn("[train_action] task_id or job_id not set, training not started")
return
api.FineTune.V2.stop_task(self.task_id)
log_debug(f"train job {self.task_id}/{self.job_id} stopped")
resp = api.FineTune.V2.stop_task(self.task_id)
if resp.get("result"):
log_debug(f"train task {self.task_id}/{self.job_id} stopped successfully")
else:
log_debug(f"train task {self.task_id}/{self.job_id} stopped failed")

def get_default_train_config(
self, model_type: str, train_mode: console_consts.TrainMode, peft_type: PeftType
107 changes: 95 additions & 12 deletions python/qianfan/trainer/base.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import datetime
import os
import pickle
import platform
import sys
import threading
from abc import ABC, abstractmethod
from threading import Lock
from typing import (
@@ -26,11 +31,14 @@
Union,
)

import multiprocess as multiprocessing

from qianfan.common.runnable.base import ExecuteSerializable
from qianfan.config import encoding
from qianfan.errors import InternalError, InvalidArgumentError
from qianfan.trainer.consts import ActionState
from qianfan.trainer.consts import ActionState, QianfanTrainerLocalCacheDir, StopMessage
from qianfan.trainer.event import Event, EventHandler, dispatch_event
from qianfan.utils import log_debug, log_error, utils
from qianfan.utils import log_debug, log_error, log_info, utils

Input = TypeVar("Input")
Output = TypeVar("Output")
@@ -236,7 +244,7 @@ def __init__(
self.actions[action.id] = action
self.seq.append(action.id)
self.post_actions = post_actions
self._state: str = ""
self.current_action: str = ""
self._sync_lock = Lock()
self._stop: bool = False
self._last_output: Optional[Dict[str, Any]] = None
@@ -278,7 +286,7 @@ def exec_from(
self.action_event(
ActionState.Running, "pipeline running", {"action": k}
)
self._state = k
self.current_action = k
output = self.actions[k].exec(input=output, **kwargs)
err = output.get("error")
if err is not None:
@@ -311,11 +319,11 @@ def resume(self, **kwargs: Dict) -> Dict[str, Any]:
resume pipeline running from last stopped or failed action.
"""
self._stop = False
last_output = self.actions[self._state].resume(**kwargs)
if self.seq[-1] == self._state:
last_output = self.actions[self.current_action].resume(**kwargs)
if self.seq[-1] == self.current_action:
# last node return directly
return last_output
idx = self.seq.index(self._state) + 1
idx = self.seq.index(self.current_action) + 1
return self.exec_from(last_output, idx, **kwargs)

def stop(self, **kwargs: Dict) -> None:
@@ -325,10 +333,8 @@ def stop(self, **kwargs: Dict) -> None:
with self._sync_lock:
self._stop = True

action = self.actions.get(self._state)
if action is None:
raise InternalError("unknown action to stop")
else:
action = self.actions.get(self.current_action)
if action is not None:
action.stop()
return super().stop()

@@ -359,13 +365,17 @@ class Trainer(ABC):
- stop() stop the training process
"""

name: Optional[str] = ""
"""trainer name"""

ppls: List[Pipeline] = []
"""
Pipelines for training, there may be multiple pipelines in
the training process.
"""
result: List[Any] = []
"""pipeline running results, which may be an error or an object"""
process: Optional[multiprocessing.Process] = None

@abstractmethod
def run(self, **kwargs: Dict) -> "Trainer":
@@ -377,14 +387,87 @@ def run(self, **kwargs: Dict) -> "Trainer":
"""
...

@abstractmethod
def _get_specific_cache_path(self) -> str:
cache_path = os.path.join(
QianfanTrainerLocalCacheDir,
self.name or utils.generate_letter_num_random_id(8),
)
if not os.path.exists(cache_path):
os.makedirs(cache_path)

return cache_path

def _get_log_path(self) -> str:
current_date = datetime.datetime.now()
date_str = current_date.strftime("%Y-%m-%d")

return os.path.join(self._get_specific_cache_path(), f"{date_str}.log")

def start(self, join_on_exited: bool = False, **kwargs: Dict) -> "Trainer":
"""
Trainer start method, which launches the training process in the
background. Use `wait()` to block until the training process
finishes.
Returns:
Trainer: Trainer instance
"""

def run_subprocess(pipe: multiprocessing.Pipe) -> None:
if platform.system() != "Windows":
os.setsid() # type: ignore[attr-defined]
# redirect output
log_path = self._get_log_path()
with open(log_path, "a", encoding=encoding()) as f:
log_info(f"check trainer running log in {log_path}")
sys.stdout = f
from qianfan.utils.logging import redirect_log_to_file

redirect_log_to_file(log_path)

# start a thread for run
main_t = threading.Thread(target=self.run)
main_t.start()

import time

while True:
time.sleep(1)
msg = pipe.recv()  # receive a message from the parent process
if msg == StopMessage:
log_debug("Child process received STOP signal, exiting...")
self.stop()
break
log_info("trainer subprocess exited")

parent_pipe, child_pipe = multiprocessing.Pipe()
p = multiprocessing.Process(target=run_subprocess, args=(child_pipe,))
p.start()
if not join_on_exited:
self.join = p.join
# multiprocess registers an automatic join via atexit; disable it on this handle
p.join = lambda: None
self.parent_pipe = parent_pipe
self.process = p
return self

def wait(self, **kwargs: Dict) -> "Trainer":
"""
Trainer wait method. Wait for the training process to finish.
"""
if self.process and self.join:
self.join()
return self

def stop(self, **kwargs: Dict) -> "Trainer":
"""
Trainer stop method. Subclasses may override it for more
controllable behavior in concrete situations.
Returns:
Trainer: Trainer instance
"""
if self.process:
self.parent_pipe.send(StopMessage)
return self

@abstractmethod
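The background process is controlled through a message sent over a pipe rather than a hard kill, so the child can call self.stop() and exit cleanly. A stripped-down sketch of that protocol against the standard-library multiprocessing module (the SDK uses the multiprocess fork, which exposes the same API); all names here are illustrative:

import multiprocessing
import time
from multiprocessing.connection import Connection

STOP = "STOP"  # plays the role of StopMessage in qianfan.trainer.consts


def worker(pipe: Connection) -> None:
    # poll the pipe once per second and exit when the parent sends STOP,
    # roughly mirroring the loop inside run_subprocess above
    while True:
        time.sleep(1)
        if pipe.poll() and pipe.recv() == STOP:
            print("child received STOP, exiting")
            break


if __name__ == "__main__":
    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=worker, args=(child_conn,))
    p.start()
    parent_conn.send(STOP)  # what Trainer.stop() does via self.parent_pipe
    p.join()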
5 changes: 5 additions & 0 deletions python/qianfan/trainer/consts.py
@@ -105,3 +105,8 @@ class ServiceType(str, Enum):
"""Corresponding to the `Embedding`"""
Text2Image = "Text2Image"
"""Corresponding to the `Text2Image"""


StopMessage = "STOP"
# local cache directory for the trainer
QianfanTrainerLocalCacheDir = ".qianfan_trainer_cache"
13 changes: 9 additions & 4 deletions python/qianfan/trainer/finetune.py
@@ -47,6 +47,7 @@ class LLMFinetune(Trainer):
Class implements the SFT training pipeline with several actions.
Use `run()` to synchronously run the training pipeline until the
model training is finished, or use `start()`, `wait()`, and `stop()`
to run it asynchronously.
"""

def __init__(
Expand Down Expand Up @@ -236,7 +237,7 @@ def status(self) -> str:
"""
if len(self.ppls) != 1:
raise InvalidArgumentError("invalid pipeline to get status")
action = self.ppls[0][str(self.ppls[0]._state)]
action = self.ppls[0][str(self.ppls[0].current_action)]
if action is None:
return TrainStatus.Unknown
action_name = action.__class__.__name__
@@ -254,9 +255,13 @@ def stop(self, **kwargs: Dict) -> Trainer:
Trainer:
self, for chain invocation.
"""
for ppl in self.ppls:
ppl.stop()
return self
# running in a background process: delegate to Trainer.stop()
if self.process:
return super().stop(**kwargs)
else:
for ppl in self.ppls:
ppl.stop()
return self

def resume(self, **kwargs: Dict) -> "LLMFinetune":
"""
12 changes: 8 additions & 4 deletions python/qianfan/trainer/post_pretrain.py
@@ -161,7 +161,7 @@ def status(self) -> str:
"""
if len(self.ppls) != 1:
raise InvalidArgumentError("invalid pipeline to get status")
action = self.ppls[0][str(self.ppls[0]._state)]
action = self.ppls[0][str(self.ppls[0].current_action)]
if action is None:
return TrainStatus.Unknown
action_name = action.__class__.__name__
@@ -179,9 +179,13 @@ def stop(self, **kwargs: Dict) -> Trainer:
Trainer:
self, for chain invocation.
"""
for ppl in self.ppls:
ppl.stop()
return self
# running in a background process: delegate to Trainer.stop()
if self.process:
return super().stop(**kwargs)
else:
for ppl in self.ppls:
ppl.stop()
return self

def resume(self, **kwargs: Dict) -> "PostPreTrain":
"""