Skip to content

Commit

Permalink
Merge branch 'feat-update-apis' into feat-trainer-ds-cors-split
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhjz committed Jun 13, 2024
2 parents 8d87cfa + 230d8a6 commit 2ea7f47
Show file tree
Hide file tree
Showing 17 changed files with 1,207 additions and 15 deletions.
169 changes: 160 additions & 9 deletions go/qianfan/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package qianfan

import (
"context"
"unicode/utf8"
)

// 表示对话内容的结构体
Expand Down Expand Up @@ -84,6 +85,7 @@ type ChatCompletionRequest struct {
var ChatModelEndpoint = map[string]string{
"ERNIE-Bot-turbo": "/chat/eb-instant",
"ERNIE-Lite-8K-0922": "/chat/eb-instant",
"ERNIE-Lite-8K": "/chat/ernie-lite-8k",
"ERNIE-Lite-8K-0308": "/chat/ernie-lite-8k",
"ERNIE-3.5-8K": "/chat/completions",
"ERNIE-Bot": "/chat/completions",
Expand Down Expand Up @@ -129,6 +131,101 @@ var ChatModelEndpoint = map[string]string{
"Gemma-7B-it": "/chat/gemma_7b_it",
}

// inputLimitInfo describes the input-size limits of a chat model.
// A zero value on either field means that limit is not enforced
// (see ChatCompletion.processWithInputLimit).
type inputLimitInfo struct {
	MaxInputChars  int // maximum input length counted in runes
	MaxInputTokens int // maximum input length counted in (estimated) tokens
}

// limitMapInModelName maps a model name to its input limits, used as a
// fallback when the request endpoint is not found in limitMapInEndpoint.
// "UNSPECIFIED_MODEL" is the catch-all entry meaning "no known limits".
var limitMapInModelName = map[string]inputLimitInfo{
	"ERNIE-Lite-8K-0922":           {MaxInputChars: 11200, MaxInputTokens: 7168},
	"ERNIE-Lite-8K":                {MaxInputChars: 11200, MaxInputTokens: 7168},
	"ERNIE-Lite-8K-0308":           {MaxInputChars: 11200, MaxInputTokens: 7168},
	"ERNIE-3.5-8K":                 {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-4.0-8K":                 {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-4.0-8K-0329":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-4.0-8K-0104":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-4.0-preemptible":        {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-4.0-8K-Preview-0518":    {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-4.0-8K-preview":         {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-3.5-8K-preemptible":     {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-3.5-128K":               {MaxInputChars: 516096, MaxInputTokens: 126976},
	"ERNIE-3.5-8K-preview":         {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-Bot-8K":                 {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-3.5-4K-0205":            {MaxInputChars: 8000, MaxInputTokens: 2048},
	"ERNIE-3.5-8K-0205":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-3.5-8K-1222":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"ERNIE-3.5-8K-0329":            {MaxInputChars: 8000, MaxInputTokens: 2048},
	"ERNIE-Speed-8K":               {MaxInputChars: 11200, MaxInputTokens: 7168},
	"ERNIE-Speed-128K":             {MaxInputChars: 507904, MaxInputTokens: 126976},
	"ERNIE Speed-AppBuilder":       {MaxInputChars: 11200, MaxInputTokens: 7168},
	"ERNIE-Tiny-8K":                {MaxInputChars: 24000, MaxInputTokens: 6144},
	"ERNIE-Function-8K":            {MaxInputChars: 24000, MaxInputTokens: 6144},
	"ERNIE-Character-8K":           {MaxInputChars: 24000, MaxInputTokens: 6144},
	// Third-party models: only a character limit is known; MaxInputTokens 0
	// disables token-based truncation for them.
	"BLOOMZ-7B":                    {MaxInputChars: 4800, MaxInputTokens: 0},
	"Llama-2-7B-Chat":              {MaxInputChars: 4800, MaxInputTokens: 0},
	"Llama-2-13B-Chat":             {MaxInputChars: 4800, MaxInputTokens: 0},
	"Llama-2-70B-Chat":             {MaxInputChars: 4800, MaxInputTokens: 0},
	"Meta-Llama-3-8B":              {MaxInputChars: 4800, MaxInputTokens: 0},
	"Meta-Llama-3-70B":             {MaxInputChars: 4800, MaxInputTokens: 0},
	"Qianfan-BLOOMZ-7B-compressed": {MaxInputChars: 4800, MaxInputTokens: 0},
	"Qianfan-Chinese-Llama-2-7B":   {MaxInputChars: 4800, MaxInputTokens: 0},
	"ChatGLM2-6B-32K":              {MaxInputChars: 4800, MaxInputTokens: 0},
	"AquilaChat-7B":                {MaxInputChars: 4800, MaxInputTokens: 0},
	"XuanYuan-70B-Chat-4bit":       {MaxInputChars: 4800, MaxInputTokens: 0},
	"Qianfan-Chinese-Llama-2-13B":  {MaxInputChars: 4800, MaxInputTokens: 0},
	"Qianfan-Chinese-Llama-2-70B":  {MaxInputChars: 4800, MaxInputTokens: 0},
	"ChatLaw":                      {MaxInputChars: 4800, MaxInputTokens: 0},
	"Yi-34B-Chat":                  {MaxInputChars: 4800, MaxInputTokens: 0},
	"Mixtral-8x7B-Instruct":        {MaxInputChars: 4800, MaxInputTokens: 0},
	"Gemma-7B-it":                  {MaxInputChars: 4800, MaxInputTokens: 0},
	// Fallback entry: both limits zero, so no truncation is applied.
	"UNSPECIFIED_MODEL":            {MaxInputChars: 0, MaxInputTokens: 0},
}

// limitMapInEndpoint maps a bare endpoint path (the "/chat/..." suffix,
// after modelAPIPrefix is stripped) to its input limits. Consulted first by
// ChatCompletion.processWithInputLimit, before the model-name lookup.
var limitMapInEndpoint = map[string]inputLimitInfo{
	"/chat/eb-instant":                   {MaxInputChars: 11200, MaxInputTokens: 7168},
	"/chat/ernie-lite-8k":                {MaxInputChars: 11200, MaxInputTokens: 7168},
	"/chat/completions":                  {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/completions_pro":              {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-4.0-8k-0329":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-4.0-8k-0104":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/completions_pro_preemptible":  {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/completions_adv_pro":          {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-4.0-8k-preview":         {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/completions_preemptible":      {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-3.5-128k":               {MaxInputChars: 516096, MaxInputTokens: 126976},
	"/chat/ernie-3.5-8k-preview":         {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie_bot_8k":                 {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-3.5-4k-0205":            {MaxInputChars: 8000, MaxInputTokens: 2048},
	"/chat/ernie-3.5-8k-0205":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-3.5-8k-1222":            {MaxInputChars: 20000, MaxInputTokens: 5120},
	"/chat/ernie-3.5-8k-0329":            {MaxInputChars: 8000, MaxInputTokens: 2048},
	"/chat/ernie_speed":                  {MaxInputChars: 11200, MaxInputTokens: 7168},
	"/chat/ernie-speed-128k":             {MaxInputChars: 507904, MaxInputTokens: 126976},
	"/chat/ai_apaas":                     {MaxInputChars: 11200, MaxInputTokens: 7168},
	"/chat/ernie-tiny-8k":                {MaxInputChars: 24000, MaxInputTokens: 6144},
	"/chat/ernie-func-8k":                {MaxInputChars: 24000, MaxInputTokens: 6144},
	"/chat/ernie-char-8k":                {MaxInputChars: 24000, MaxInputTokens: 6144},
	// Third-party models: character limit only (token limit 0 = unenforced).
	"/chat/bloomz_7b1":                   {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/llama_2_7b":                   {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/llama_2_13b":                  {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/llama_2_70b":                  {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/llama_3_8b":                   {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/llama_3_70b":                  {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/qianfan_bloomz_7b_compressed": {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/qianfan_chinese_llama_2_7b":   {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/chatglm2_6b_32k":              {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/aquilachat_7b":                {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/xuanyuan_70b_chat":            {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/qianfan_chinese_llama_2_13b":  {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/qianfan_chinese_llama_2_70b":  {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/chatlaw":                      {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/yi_34b_chat":                  {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/mixtral_8x7b_instruct":        {MaxInputChars: 4800, MaxInputTokens: 0},
	"/chat/gemma_7b_it":                  {MaxInputChars: 4800, MaxInputTokens: 0},
}

// 创建一个 User 的消息
func ChatCompletionUserMessage(message string) ChatCompletionMessage {
return ChatCompletionMessage{
Expand Down Expand Up @@ -201,6 +298,9 @@ func (c *ChatCompletion) do(ctx context.Context, request *ChatCompletionRequest)
if err != nil {
return nil, err
}

c.processWithInputLimit(ctx, request, url)

req, err := newModelRequest("POST", url, request)
if err != nil {
return nil, err
Expand Down Expand Up @@ -249,6 +349,9 @@ func (c *ChatCompletion) stream(ctx context.Context, request *ChatCompletionRequ
if err != nil {
return nil, err
}

c.processWithInputLimit(ctx, request, url)

request.SetStream()
req, err := newModelRequest("POST", url, request)
if err != nil {
Expand All @@ -275,6 +378,50 @@ func (c *ChatCompletion) stream(ctx context.Context, request *ChatCompletionRequ
return resp, err
}

// processWithInputLimit truncates request.Messages in place so that the
// request approximately fits the input limits of the target model/endpoint.
//
// The limit is looked up by the bare endpoint path first (url with
// modelAPIPrefix stripped), then by c.Model in limitMapInModelName, finally
// falling back to "UNSPECIFIED_MODEL" (both limits zero → no-op). Message
// sizes are accumulated from the newest message backwards, using the rune
// count and the local tokenizer estimate; once either running total exceeds
// its limit the scan stops and only the suffix
// request.Messages[truncatedIndex:] is kept.
//
// NOTE(review): the message at truncatedIndex — the one whose addition
// exceeded the limit — is still retained, so the kept suffix can exceed the
// limit; and messages[0] is never measured (loop condition is > 0), so a
// conversation whose oldest message is huge passes unchecked when nothing
// else overflows. Presumably intentional ("always keep at least the newest
// messages") — confirm against the SDKs in other languages.
//
// ctx is currently unused; kept for signature symmetry with do/stream.
func (c *ChatCompletion) processWithInputLimit(ctx context.Context, request *ChatCompletionRequest, url string) {
	if len(request.Messages) == 0 {
		return
	}

	// Strip the model-API prefix so the remaining "/chat/..." path matches
	// the keys of limitMapInEndpoint.
	url = url[len(modelAPIPrefix):]
	limit, ok := limitMapInEndpoint[url]
	if !ok {
		limit, ok = limitMapInModelName[c.Model]
		if !ok {
			limit = limitMapInModelName["UNSPECIFIED_MODEL"]
		}
	}

	// No known limit for this endpoint/model: leave the request untouched.
	if limit.MaxInputTokens == 0 && limit.MaxInputChars == 0 {
		return
	}

	messages := request.Messages
	totalMessageChars := 0
	totalMessageTokens := 0

	tokenizer := NewTokenizer()
	additionalArguments := make(map[string]interface{})

	truncatedIndex := len(messages) - 1

	// Walk from the newest message towards the oldest, accumulating both the
	// rune count and the locally-estimated token count of each message.
	for truncatedIndex > 0 {
		// Local token counting never fails; the error is deliberately ignored.
		tokens, _ := tokenizer.CountTokens(messages[truncatedIndex].Content, TokenizerModeLocal, "", additionalArguments)

		totalMessageChars += utf8.RuneCountInString(messages[truncatedIndex].Content)
		totalMessageTokens += tokens

		// A zero limit on either axis means that axis is unenforced.
		if (limit.MaxInputTokens > 0 && totalMessageTokens > limit.MaxInputTokens) ||
			(limit.MaxInputChars > 0 && totalMessageChars > limit.MaxInputChars) {
			break
		}

		truncatedIndex--
	}

	request.Messages = request.Messages[truncatedIndex:]
}

// chat 支持的模型列表
func (c *ChatCompletion) ModelList() []string {
i := 0
Expand All @@ -289,16 +436,20 @@ func (c *ChatCompletion) ModelList() []string {

// 创建一个 ChatCompletion 对象
//
//	chat := qianfan.NewChatCompletion()  // 默认使用 ERNIE-Bot-turbo 模型
//
// 可以通过 WithModel 指定模型
//
//	chat := qianfan.NewChatCompletion(
//		qianfan.WithModel("ERNIE-Bot-4"),  // 支持的模型可以通过 chat.ModelList() 获取
//	)
//
// 或者通过 WithEndpoint 指定 endpoint
//
//	chat := qianfan.NewChatCompletion(
//		qianfan.WithEndpoint("your_custom_endpoint"),
//	)
func NewChatCompletion(optionList ...Option) *ChatCompletion {
options := makeOptions(optionList...)
return newChatCompletion(options)
Expand Down
4 changes: 2 additions & 2 deletions go/qianfan/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Config struct {
InferResourceRefreshInterval int `mapstructure:"QIANFAN_INFER_RESOURCE_REFRESH_MIN_INTERVAL"`
}

func setConfigDeafultValue(vConfig *viper.Viper) {
func setConfigDefaultValue(vConfig *viper.Viper) {
// 因为 viper 自动绑定无法在 unmarshal 时使用,所以这里要手动设置默认值
for k, v := range defaultConfig {
vConfig.SetDefault(k, v)
Expand All @@ -65,7 +65,7 @@ func loadConfigFromEnv() *Config {
vConfig.SetConfigFile(".env")
vConfig.SetConfigType("dotenv")
vConfig.AutomaticEnv()
setConfigDeafultValue(vConfig)
setConfigDefaultValue(vConfig)

// ignore error if config file not found
_ = vConfig.ReadInConfig()
Expand Down
117 changes: 117 additions & 0 deletions go/qianfan/tokenizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package qianfan

import (
"context"
"errors"
"strings"
"unicode"
)

// ebTokenizerRequest is the request body sent to the remote ERNIE tokenizer
// endpoint (see Tokenizer.remoteCountTokensEB).
type ebTokenizerRequest struct {
	BaseRequestBody `mapstructure:"-"`
	Prompt          string // text whose tokens are to be counted
	Model           string // model whose tokenizer should be used
}

// ErrInternal is returned by CountTokens when an unsupported TokenizerMode
// is supplied.
var (
	ErrInternal = errors.New("internal error")
)

// Tokenizer counts tokens of a text, either with a local heuristic or by
// calling the remote tokenizer API.
type Tokenizer struct {
	BaseModel
}

// TokenizerMode selects the strategy used by Tokenizer.CountTokens.
type TokenizerMode string

// Supported tokenizer modes.
const (
	TokenizerModeLocal  = TokenizerMode("local")  // offline heuristic, no network call
	TokenizerModeRemote = TokenizerMode("remote") // query the remote tokenizer API
)

// NewTokenizer creates a Tokenizer instance with a zero-valued embedded
// BaseModel.
func NewTokenizer() *Tokenizer {
	return &Tokenizer{}
}

// CountTokens returns the number of tokens in text.
//
// mode selects the strategy: TokenizerModeLocal uses a fast offline
// heuristic (model is ignored), while TokenizerModeRemote queries the
// tokenizer API for the given model. Any other mode yields ErrInternal.
// additionalArguments tunes the local heuristic (see localCountTokens).
func (t *Tokenizer) CountTokens(text string, mode TokenizerMode, model string, additionalArguments map[string]interface{}) (int, error) {
	switch mode {
	case TokenizerModeLocal:
		return t.localCountTokens(text, additionalArguments)
	case TokenizerModeRemote:
		return t.remoteCountTokensEB(text, model)
	default:
		return 0, ErrInternal
	}
}

// localCountTokens estimates the number of tokens in text without any
// network call.
//
// Heuristic: each Han (CJK) character contributes hanTokens tokens, and each
// remaining whitespace/punctuation-delimited word contributes wordTokens
// tokens. The default weights (0.625 and 1.0) can be overridden through
// additionalArguments["han_tokens"] / additionalArguments["word_tokens"]
// (float64 values). The returned error is always nil.
func (t *Tokenizer) localCountTokens(text string, additionalArguments map[string]interface{}) (int, error) {
	hanTokens := 0.625
	wordTokens := 1.0

	// Allow callers to override the per-character / per-word weights.
	if val, ok := additionalArguments["han_tokens"].(float64); ok {
		hanTokens = val
	}
	if val, ok := additionalArguments["word_tokens"].(float64); ok {
		wordTokens = val
	}

	// Replace CJK characters, punctuation and whitespace with spaces so the
	// leftover runs of characters can be counted as words. A strings.Builder
	// avoids the O(n²) cost of the previous repeated string concatenation.
	hanCount := 0
	var nonHan strings.Builder
	nonHan.Grow(len(text))

	for _, ch := range text {
		switch {
		case isCJKCharacter(ch):
			hanCount++
			nonHan.WriteByte(' ')
		case isPunctuation(ch), isSpace(ch):
			nonHan.WriteByte(' ')
		default:
			nonHan.WriteRune(ch)
		}
	}

	wordCount := len(strings.Fields(nonHan.String()))
	return int(float64(hanCount)*hanTokens + float64(wordCount)*wordTokens), nil
}

// remoteCountTokensEB counts tokens by POSTing the text to the remote ERNIE
// tokenizer endpoint ("/tokenizer/erniebot") and returning Usage.TotalTokens
// from the response.
//
// Returns -1 together with a non-nil error when building or executing the
// request fails.
//
// NOTE(review): uses context.Background(), so the caller cannot cancel or
// apply a deadline to this call — consider threading a ctx through if that
// matters.
func (t *Tokenizer) remoteCountTokensEB(text string, model string) (int, error) {
	request := &ebTokenizerRequest{
		Prompt: text,
		Model:  model,
	}

	req, err := newModelRequest("POST", modelAPIPrefix+"/tokenizer/erniebot", request)
	if err != nil {
		return -1, err
	}

	var resp ModelResponse

	err = t.requestResource(context.Background(), req, &resp)
	if err != nil {
		return -1, err
	}
	return resp.Usage.TotalTokens, nil
}

// isCJKCharacter reports whether ch belongs to the Unicode Han script
// (CJK ideographs).
func isCJKCharacter(ch rune) bool {
	return unicode.In(ch, unicode.Han)
}

// isSpace 检查字符是否是空格
// isSpace reports whether ch is a Unicode whitespace character.
func isSpace(ch rune) bool {
	return unicode.IsSpace(ch)
}

// isPunctuation 检查字符是否是标点符号
// isPunctuation reports whether ch is a Unicode punctuation character.
func isPunctuation(ch rune) bool {
	return unicode.IsPunct(ch)
}
15 changes: 15 additions & 0 deletions java/src/main/java/com/baidubce/qianfan/Qianfan.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import com.baidubce.qianfan.model.image.Image2TextResponse;
import com.baidubce.qianfan.model.image.Text2ImageRequest;
import com.baidubce.qianfan.model.image.Text2ImageResponse;
import com.baidubce.qianfan.model.plugin.PluginRequest;
import com.baidubce.qianfan.model.plugin.PluginResponse;
import com.baidubce.qianfan.model.rerank.RerankRequest;
import com.baidubce.qianfan.model.rerank.RerankResponse;

Expand Down Expand Up @@ -124,6 +126,19 @@ public RerankResponse rerank(RerankRequest request) {
return request(request, RerankResponse.class);
}

/**
 * Creates a builder for assembling a plugin request fluently.
 *
 * @return a new {@link PluginBuilder} bound to this client
 */
public PluginBuilder plugin() {
    return new PluginBuilder(this);
}

/**
 * Executes a plugin request synchronously.
 *
 * @param request the plugin request to send
 * @return the plugin response
 */
public PluginResponse plugin(PluginRequest request) {
    return request(request, PluginResponse.class);
}

/**
 * Executes a plugin request in streaming mode.
 *
 * <p>Note: this mutates the given request by forcing {@code stream = true}.
 *
 * @param request the plugin request to send (modified in place)
 * @return an iterator over the streamed plugin responses
 */
public Iterator<PluginResponse> pluginStream(PluginRequest request) {
    request.setStream(true);
    return requestStream(request, PluginResponse.class);
}

/**
 * Sends an arbitrary request through the underlying client and deserializes
 * the reply into the given response class.
 *
 * @param request       the request to send
 * @param responseClass the class to deserialize the response into
 * @param <T>           response type
 * @param <U>           request type
 * @return the deserialized response
 */
public <T extends BaseResponse<T>, U extends BaseRequest<U>> T request(BaseRequest<U> request, Class<T> responseClass) {
    return client.request(request, responseClass);
}
Expand Down
Loading

0 comments on commit 2ea7f47

Please sign in to comment.