diff --git a/go/README.md b/go/README.md index 0bbd7f53..888c3ada 100644 --- a/go/README.md +++ b/go/README.md @@ -16,6 +16,8 @@ import ( ) ``` +> 我们提供了一些 [示例](./examples),可以帮助快速了解 SDK 的使用方法并完成常见功能。 + ### 鉴权 在使用千帆 SDK 之前,用户需要 [百度智能云控制台 - 安全认证](https://console.bce.baidu.com/iam/#/iam/accesslist) 页面获取 Access Key 与 Secret Key,并在 [千帆控制台](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application) 中创建应用,选择需要启用的服务,具体流程参见平台 [说明文档](https://cloud.baidu.com/doc/Reference/s/9jwvz2egb)。 diff --git a/go/examples/README.md b/go/examples/README.md new file mode 100644 index 00000000..59940ca3 --- /dev/null +++ b/go/examples/README.md @@ -0,0 +1,7 @@ +# 千帆 Go SDK 示例 + +本文件夹中包含一些示例程序,展示了如何使用千帆 Go SDK。 + +- [`stream_chat`](./stream_chat/main.go):实现了简易的能够在命令行与 LLM 聊天的程序,并使用流式输出,加快响应速度。 +- [`embedding_distance`](./embedding_distance/main.go):展示了如何使用千帆提供的 Embedding 模型,并计算两个文本的余弦距离。 +- [`list_models`](./list_model/main.go):展示了如何获取所有可用的模型。 \ No newline at end of file diff --git a/go/examples/embedding_distance/main.go b/go/examples/embedding_distance/main.go new file mode 100644 index 00000000..47b50a08 --- /dev/null +++ b/go/examples/embedding_distance/main.go @@ -0,0 +1,66 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "math" + + "github.com/baidubce/bce-qianfan-sdk/go/qianfan" +) + +// 展示了如何使用千帆提供的 Embedding 模型,并计算两个文本的余弦距离 +func cosDistance(embed1, embed2 []float64) (float64, error) { + length := len(embed1) + if length != len(embed2) { + return -1, fmt.Errorf("length of embed1 and embed2 must be the same") + } + s1 := 0.0 + s2 := 0.0 + sum := 0.0 + for i := 0; i < length; i++ { + s1 += math.Pow(embed1[i], 2) + s2 += math.Pow(embed2[i], 2) + sum += embed1[i] * embed2[i] + } + return sum / (math.Sqrt(s1) * math.Sqrt(s2)), nil + +} + +func main() { + // 使用前请先设置 AccessKey 和 SecretKey,通过环境变量设置可省略如下两行 + // qianfan.GetConfig().AccessKey = "your_access_key" + // qianfan.GetConfig().SecretKey = "your_secret_key" + + sentence1 := "你好" + sentence2 := "hello" + + embed := qianfan.NewEmbedding() + resp, err := embed.Do(context.TODO(), &qianfan.EmbeddingRequest{ + Input: []string{sentence1, sentence2}, + }) + if err != nil { + panic(err) + } + embed1 := resp.Data[0].Embedding + embed2 := resp.Data[1].Embedding + + distance, err := cosDistance(embed1, embed2) + if err != nil { + panic(err) + } + fmt.Println(distance) +} diff --git a/go/examples/go.mod b/go/examples/go.mod new file mode 100644 index 00000000..fdfc2e4d --- /dev/null +++ b/go/examples/go.mod @@ -0,0 +1,29 @@ +module github.com/baidubce/bce-qianfan-sdk/go/examples + +go 1.19 + +require ( + github.com/baidubce/bce-qianfan-sdk/go/qianfan v0.0.1 // indirect + github.com/baidubce/bce-sdk-go v0.9.164 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/viper v1.18.2 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go/examples/go.sum b/go/examples/go.sum new file mode 100644 index 00000000..1b877775 --- /dev/null +++ b/go/examples/go.sum @@ -0,0 +1,60 @@ +github.com/baidubce/bce-qianfan-sdk/go/qianfan v0.0.1 h1:Nfjklb07jSDD7qovDy0oz+pawEr/7vLP+BXPrQsKZos= +github.com/baidubce/bce-qianfan-sdk/go/qianfan v0.0.1/go.mod h1:f/kIWWvAHAcU7bzgkfN30SkpN0I4lLvsJkljVK6v5YY= +github.com/baidubce/bce-sdk-go v0.9.164 h1:7gswLMsdQyarovMKuv3i6wxFQ3BQgvc5CmyGXb/D/xA= +github.com/baidubce/bce-sdk-go v0.9.164/go.mod h1:zbYJMQwE4IZuyrJiFO8tO8NbtYiKTFTbwh4eIsqjVdg= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= +github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/examples/list_model/main.go b/go/examples/list_model/main.go new file mode 100644 index 00000000..e2d3e63f --- /dev/null +++ b/go/examples/list_model/main.go @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "github.com/baidubce/bce-qianfan-sdk/go/qianfan" + +// 展示了如何获取所有可用的模型 +func main() { + // 根据用途选择 + u := qianfan.NewChatCompletion() + // u := qianfan.NewCompletion() + // u := qianfan.NewEmbedding() + + for _, m := range u.ModelList() { + println(m) + } +} diff --git a/go/examples/stream_chat/main.go b/go/examples/stream_chat/main.go new file mode 100644 index 00000000..4b1e7ead --- /dev/null +++ b/go/examples/stream_chat/main.go @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + + "github.com/baidubce/bce-qianfan-sdk/go/qianfan" +) + +// 实现了简易的能够在命令行与 LLM 聊天的程序,并使用流式输出,加快响应速度 +func main() { + // 使用前请先设置 AccessKey 和 SecretKey,通过环境变量设置可省略如下两行 + // qianfan.GetConfig().AccessKey = "your_access_key" + // qianfan.GetConfig().SecretKey = "your_secret_key" + + chat := qianfan.NewChatCompletion( + qianfan.WithModel("ERNIE-Bot-4"), + ) + chatHistory := []qianfan.ChatCompletionMessage{} + + for { + var userMsg string + fmt.Println("User Input:") + fmt.Scan(&userMsg) + fmt.Println() + + chatHistory = append(chatHistory, qianfan.ChatCompletionUserMessage(userMsg)) + + stream, err := chat.Stream(context.TODO(), &qianfan.ChatCompletionRequest{ + Messages: chatHistory, + }) + + if err != nil { + panic(err) + } + + fmt.Println("Assistant Output:") + var outputMsg string + for { + r, err := stream.Recv() + if err != nil { + panic(err) + } + if r.IsEnd { + break + } + fmt.Print(r.Result) + outputMsg = outputMsg + r.Result + } + chatHistory = append(chatHistory, qianfan.ChatCompletionAssistantMessage(outputMsg)) + fmt.Print("\n\n") + } +} diff --git a/go/qianfan/auth.go b/go/qianfan/auth.go new file mode 100644 index 00000000..b41512ef --- /dev/null +++ b/go/qianfan/auth.go @@ -0,0 +1,148 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "fmt" + "sync" + "time" + + "github.com/mitchellh/mapstructure" +) + +type AccessTokenRequest struct { + GrantType string `mapstructure:"grant_type"` + ClientId string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` +} + +func newAccessTokenRequest(ak, sk string) *AccessTokenRequest { + return &AccessTokenRequest{ + GrantType: "client_credentials", + ClientId: ak, + ClientSecret: sk, + } +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + SessionKey string `json:"session_key"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + SessionSecret string `json:"session_secret"` + baseResponse +} + +func (r *AccessTokenResponse) GetErrorCode() string { + return r.Error +} + +type credential struct { + AK string + SK string +} + +type accessToken struct { + token string + lastUpateTime time.Time +} + +type AuthManager struct { + tokenMap map[credential]*accessToken + lock sync.Mutex + *Requestor +} + +func maskAk(ak string) string { + unmaskLen := 6 + if len(ak) < unmaskLen { + return ak + } + return fmt.Sprintf("%s******", ak[:unmaskLen]) +} + +var _authManager *AuthManager + +func GetAuthManager() *AuthManager { + if _authManager == nil { + _authManager = &AuthManager{ + tokenMap: make(map[credential]*accessToken), + lock: sync.Mutex{}, + Requestor: newRequestor(makeOptions()), + } + } + return _authManager +} + +func (m *AuthManager) GetAccessToken(ak, sk string) (string, error) { + token, ok := func() (*accessToken, bool) { + m.lock.Lock() + defer m.lock.Unlock() + token, ok := m.tokenMap[credential{ak, sk}] + return token, ok + }() + if ok { + return token.token, nil + } + logger.Infof("Access token of ak `%s` not found, tring to refresh it...", maskAk(ak)) + return m.GetAccessTokenWithRefresh(ak, sk) +} + +func (m *AuthManager) GetAccessTokenWithRefresh(ak, sk string) (string, error) { + m.lock.Lock() + defer m.lock.Unlock() + + token, ok := m.tokenMap[credential{ak, sk}] + if ok { + lastUpdate := token.lastUpateTime + current := time.Now() + // 最近更新时间小于最小刷新间隔,则直接返回 + // 避免多个请求同时刷新,导致token被刷新多次 + if current.Sub(lastUpdate) < time.Duration(GetConfig().AccessTokenRefreshMinInterval)*time.Second { + logger.Debugf("Access token of ak `%s` was freshed %s ago, skip refreshing", maskAk(ak), current.Sub(lastUpdate)) + return token.token, nil + } + } + + resp := AccessTokenResponse{} + req, err := newAuthRequest("POST", authAPIPrefix, nil) + if err != nil { + return "", err + } + params := newAccessTokenRequest(ak, sk) + paramsMap := make(map[string]string) + err = mapstructure.Decode(params, ¶msMap) + if err != nil { + return "", err + } + req.Params = paramsMap + err = m.Requestor.request(req, &resp) + if err != nil { + return "", err + } + if resp.Error != "" { + logger.Errorf("refresh access token of ak `%s` failed with error: %s", maskAk(ak), resp.ErrorDescription) + return "", &APIError{Msg: resp.ErrorDescription} + } + logger.Infof("Access token of ak `%s` was refreshed", maskAk(ak)) + m.tokenMap[credential{ak, sk}] = &accessToken{ + token: resp.AccessToken, + lastUpateTime: time.Now(), + } + return resp.AccessToken, nil +} diff --git a/go/qianfan/auth_test.go b/go/qianfan/auth_test.go new file mode 100644 index 00000000..62f1c9b2 --- /dev/null +++ b/go/qianfan/auth_test.go @@ -0,0 +1,195 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func fakeAccessToken(ak, sk string) string { + return fmt.Sprintf("%s.%s", ak, sk) +} + +func resetAuthManager() { + _authManager = nil +} + +func setAccessTokenExpired(ak, sk string) { + GetAuthManager().tokenMap[credential{ak, sk}] = &accessToken{ + token: "expired_token", + lastUpateTime: time.Now().Add(-100 * time.Hour), // 100s 过期 + } +} + +func TestAuth(t *testing.T) { + resetAuthManager() + ak, sk := "ak_33", "sk_4235" + // 第一次获取前,缓存里应当没有 + _, ok := GetAuthManager().tokenMap[credential{ak, sk}] + assert.False(t, ok) + + accessTok, err := GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) + updateTime := GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime + // 再测试一次,应当从缓存里获取,更新时间不变 + accessTok, err = GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) + assert.Equal( + t, + updateTime, + GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime, + ) + // 模拟过期 + ak, sk = "ak_95411", "sk_87135" + setAccessTokenExpired(ak, sk) + // 设置一个附近的更新时间,用来测试是否会忽略刚更新过的 token + GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime = time.Now() + + accessTok, err = GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Equal(t, accessTok, "expired_token") // 直接获取还是从缓存获取 + + accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk) + assert.NoError(t, err) + assert.Equal(t, accessTok, "expired_token") // 刷新后,由于 lastUpdateTime 太接近,依旧使用缓存 + setAccessTokenExpired(ak, sk) + + accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk) + assert.NoError(t, err) + assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) // 应当刷新 + elaplsed := time.Since(GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime) + assert.Less(t, elaplsed, 10*time.Second) // 刷新后,lastUpdateTime 应当更新 +} + +func TestAuthFailed(t *testing.T) { + ak, sk := "bad_ak", "bad_sk" + _, err := GetAuthManager().GetAccessToken(ak, sk) + assert.Error(t, err) + var target *APIError + assert.ErrorAs(t, err, &target) + assert.Contains(t, err.Error(), target.Msg) + assert.Equal(t, target.Msg, "Client authentication failed") +} + +func TestAuthWhenUsing(t *testing.T) { + defer resetTestEnv() + _authManager = nil + GetConfig().AccessKey = "access_key_484913" + GetConfig().SecretKey = "secret_key_48135" + GetConfig().AK = "" + GetConfig().SK = "" + // 未设置 AKSK,所以用 IAM 鉴权 + chat := NewChatCompletion() + resp, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")}, + }, + ) + assert.NoError(t, err) + signedKey, ok := resp.RawResponse.Request.Header["Authorization"] + assert.True(t, ok) + assert.Contains(t, signedKey[0], GetConfig().AccessKey) + assert.NotContains(t, resp.RawResponse.Request.URL.RawQuery, "access_token") + // 设置了 AKSK,所以用 AKSK 鉴权 + GetConfig().AK = "ak_48915684" + GetConfig().SK = "sk_78941813" + resp, err = chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")}, + }, + ) + assert.NoError(t, err) + _, ok = resp.RawResponse.Request.Header["Authorization"] + assert.False(t, ok) + assert.Contains(t, resp.RawResponse.Request.URL.RawQuery, "access_token") + assert.Equal( + t, + resp.RawResponse.Request.URL.Query().Get("access_token"), + fakeAccessToken(GetConfig().AK, GetConfig().SK), + ) + // 如果只设置了部分鉴权信息,则报错 + GetConfig().AK = "" + GetConfig().AccessKey = "" + _, err = chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")}, + }, + ) + assert.Error(t, err) + var target *CredentialNotFoundError + assert.ErrorAs(t, err, &target) +} + +func TestAccessTokenExpired(t *testing.T) { + defer resetTestEnv() + _authManager = nil + ak, sk := "ak_48915684", "sk_78941813" + GetConfig().AK = ak + GetConfig().SK = sk + setAccessTokenExpired(ak, sk) + token, err := GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Contains(t, token, "expired") + prompt := "你好" + chat := NewChatCompletion() + resp, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)}, + }, + ) + assert.NoError(t, err) + assert.Contains(t, resp.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk)) + assert.Contains(t, resp.Result, prompt) + token, err = GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Equal(t, token, fakeAccessToken(ak, sk)) + + // 测试流式请求的刷新 token + setAccessTokenExpired(ak, sk) + token, err = GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Contains(t, token, "expired") + stream, err := chat.Stream( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)}, + }, + ) + assert.NoError(t, err) + + for { + r, err := stream.Recv() + assert.NoError(t, err) + token, err = GetAuthManager().GetAccessToken(ak, sk) + assert.NoError(t, err) + assert.Equal(t, token, fakeAccessToken(ak, sk)) + assert.Contains(t, r.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk)) + if r.IsEnd { + break + } + } + +} diff --git a/go/qianfan/base_model.go b/go/qianfan/base_model.go index 203a364f..70ec2f4a 100644 --- a/go/qianfan/base_model.go +++ b/go/qianfan/base_model.go @@ -14,7 +14,14 @@ package qianfan -import "fmt" +import ( + "encoding/json" + "errors" + "io" + "math" + "strconv" + "time" +) // 模型相关的结构体基类 type BaseModel struct { @@ -32,6 +39,7 @@ type ModelUsage struct { type ModelAPIResponse interface { GetError() (int, string) + ClearError() } // API 错误信息 @@ -40,10 +48,22 @@ type ModelAPIError struct { ErrorMsg string `json:"error_msg"` // 错误消息 } +// 获取错误码和错误信息 func (e *ModelAPIError) GetError() (int, string) { return e.ErrorCode, e.ErrorMsg } +// 获取错误码 +func (e *ModelAPIError) GetErrorCode() string { + return strconv.Itoa(e.ErrorCode) +} + +// 清除错误码 +func (e *ModelAPIError) ClearError() { + e.ErrorCode = 0 + e.ErrorMsg = "" +} + // 搜索结果 type SearchResult struct { Index int `json:"index"` // 序号 @@ -64,7 +84,7 @@ type ModelResponse struct { SentenceId int `json:"sentence_id"` // 表示当前子句的序号。只有在流式接口模式下会返回该字段 IsEnd bool `json:"is_end"` // 表示当前子句是否是最后一句。只有在流式接口模式下会返回该字段 IsTruncated bool `json:"is_truncated"` // 当前生成的结果是否被截断 - Result string `json:"result"` // 对话返回结果 + Result string `json:"result"` // 对话返回结果 NeedClearHistory bool `json:"need_clear_history"` // 表示用户输入是否存在安全风险,是否关闭当前会话,清理历史会话信息 Usage ModelUsage `json:"usage"` // token统计信息 FunctionCall *FunctionCall `json:"function_call"` // 由模型生成的函数调用,包含函数名称,和调用参数 @@ -79,9 +99,71 @@ type ModelResponseStream struct { *streamInternal } +func newModelResponseStream(si *streamInternal) *ModelResponseStream { + return &ModelResponseStream{streamInternal: si} +} + +func (s *ModelResponseStream) checkResponseError() error { + tokenRefreshed := false + var apiError *APIError + // LLMRetryCount 为 0 时表示不限制重试次数 + for retryCount := 0; retryCount < s.Options.LLMRetryCount || s.Options.LLMRetryCount == 0; retryCount++ { + contentType := s.httpResponse.Header.Get("Content-Type") + if contentType == "application/json" { + // 遇到错误 + var resp ModelResponse + content, err := io.ReadAll(s.httpResponse.Body) + if err != nil { + return err + } + + err = json.Unmarshal(content, &resp) + if err != nil { + return err + } + apiError = &APIError{Code: resp.ErrorCode, Msg: resp.ErrorMsg} + if !tokenRefreshed && (resp.ErrorCode == APITokenInvalidErrCode || resp.ErrorCode == APITokenExpiredErrCode) { + tokenRefreshed = true + _, err := GetAuthManager().GetAccessTokenWithRefresh(GetConfig().AK, GetConfig().SK) + if err != nil { + return err + } + retryCount-- + } else if resp.ErrorCode != QPSLimitReachedErrCode && resp.ErrorCode != ServerHighLoadErrCode { + return apiError + } + err = s.reset() + if err != nil { + return err + } + logger.Warnf("stream request got error: %s, retrying request... retry count: %d", apiError, retryCount) + } else { + return nil + } + time.Sleep( + time.Duration( + math.Pow( + 2, + float64(retryCount))*float64(s.Options.LLMRetryBackoffFactor), + ) * time.Second, + ) + } + + if apiError == nil { + return &InternalError{Msg: "there must be an api error here"} + } + return apiError +} + // 获取ModelResponse流式结果 func (s *ModelResponseStream) Recv() (*ModelResponse, error) { var resp ModelResponse + if s.firstResponse { + err := s.checkResponseError() + if err != nil { + return nil, err + } + } err := s.streamInternal.Recv(&resp) if err != nil { return nil, err @@ -95,7 +177,79 @@ func (s *ModelResponseStream) Recv() (*ModelResponse, error) { func checkResponseError(resp ModelAPIResponse) error { errCode, errMsg := resp.GetError() if errCode != 0 { - return fmt.Errorf("API return error. code: %d, msg: %s", errCode, errMsg) + return &APIError{Code: errCode, Msg: errMsg} + } + return nil +} + +func (m *BaseModel) withRetry(fn func() error) error { + var err error + // 当 LLMRetryCount 为 0 表示不限制重试次数 + for retryCount := 0; retryCount < m.Options.LLMRetryCount || m.Options.LLMRetryCount == 0; retryCount++ { + err = fn() + if err == nil { + return nil + } + if _, ok := err.(*tryAgainError); ok { + retryCount -= 1 + continue + } + var apiErr *APIError + ok := errors.As(err, &apiErr) + if ok { + if apiErr.Code != QPSLimitReachedErrCode && apiErr.Code != ServerHighLoadErrCode { + return err + } + } + logger.Warnf("request got error: %s, retrying request... retry count: %d", err, retryCount) + time.Sleep( + time.Duration( + math.Pow( + 2, + float64(retryCount))*float64(m.Options.LLMRetryBackoffFactor), + ) * time.Second, + ) + } + return err +} + +func (m *BaseModel) requestResource(request *QfRequest, response any) error { + qfResponse, ok := response.(QfResponse) + if !ok { + return &InternalError{Msg: "response is not QfResponse"} + } + modelApiResponse, ok := response.(ModelAPIResponse) + if !ok { + return &InternalError{Msg: "response is not ModelResponse"} + } + var err error + tokenRefreshed := false + requestFunc := func() error { + modelApiResponse.ClearError() + err = m.Requestor.request(request, qfResponse) + if err != nil { + return err + } + err = checkResponseError(modelApiResponse) + if err != nil { + errCode, _ := modelApiResponse.GetError() + if !tokenRefreshed && (errCode == APITokenInvalidErrCode || errCode == APITokenExpiredErrCode) { + // access token 过期,重新获取 access token 并重试,且不占用重试次数 + tokenRefreshed = true + _, err := GetAuthManager().GetAccessTokenWithRefresh(GetConfig().AK, GetConfig().SK) + if err != nil { + return err + } + return &tryAgainError{} + } + // 其他错误直接返回 + return err + } + return nil + } + retryErr := m.withRetry(requestFunc) + if retryErr != nil { + return err } return nil } diff --git a/go/qianfan/chat_completion.go b/go/qianfan/chat_completion.go index 3ada973c..a3f11be6 100644 --- a/go/qianfan/chat_completion.go +++ b/go/qianfan/chat_completion.go @@ -16,7 +16,6 @@ package qianfan import ( "context" - "fmt" ) // 表示对话内容的结构体 @@ -141,15 +140,16 @@ func newChatCompletion(options *Options) *ChatCompletion { // 将 endpoint 转换成完整的 url func (c *ChatCompletion) realEndpoint() (string, error) { url := modelAPIPrefix - if c.Model != "" { + if c.Endpoint == "" { endpoint, ok := ChatModelEndpoint[c.Model] if !ok { - return "", fmt.Errorf("model %s is not supported", c.Model) + return "", &ModelNotSupportedError{Model: c.Model} } url += endpoint } else { url += "/chat/" + c.Endpoint } + logger.Debugf("requesting endpoint: %s", url) return url, nil } @@ -164,13 +164,12 @@ func (c *ChatCompletion) Do(ctx context.Context, request *ChatCompletionRequest) return nil, err } var resp ModelResponse - err = c.Requestor.request(req, &resp) + + err = c.requestResource(req, &resp) if err != nil { return nil, err } - if err = checkResponseError(&resp); err != nil { - return &resp, err - } + return &resp, nil } @@ -189,9 +188,7 @@ func (c *ChatCompletion) Stream(ctx context.Context, request *ChatCompletionRequ if err != nil { return nil, err } - return &ModelResponseStream{ - streamInternal: stream, - }, nil + return newModelResponseStream(stream), nil } // chat 支持的模型列表 diff --git a/go/qianfan/chat_completion_test.go b/go/qianfan/chat_completion_test.go index c95c3823..f51186b1 100644 --- a/go/qianfan/chat_completion_test.go +++ b/go/qianfan/chat_completion_test.go @@ -17,19 +17,55 @@ package qianfan import ( "context" "encoding/json" + "fmt" "io" + "math/rand" "net/http" "os" "strings" "testing" + "time" "github.com/mitchellh/mapstructure" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) +var testEndpointList = []string{ + "endpoint1", + "sidaofjnon", + "98349823", + "fjid_432", +} + func TestChatCompletion(t *testing.T) { for model, endpoint := range ChatModelEndpoint { chat := NewChatCompletion(WithModel(model)) + resp, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + ChatCompletionAssistantMessage("回复"), + ChatCompletionUserMessage("哈哈"), + }, + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.RawResponse.StatusCode, 200) + assert.NotEqual(t, resp.Id, nil) + assert.Equal(t, resp.Object, "chat.completion") + assert.Contains(t, resp.RawResponse.Request.URL.Path, endpoint) + assert.Contains(t, resp.Result, "你好") + assert.Contains(t, resp.Result, "回复") + assert.Contains(t, resp.Result, "哈哈") + + req, err := getRequestBody[ChatCompletionRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Messages[0].Content, "你好") + } + for _, endpoint := range testEndpointList { + chat := NewChatCompletion(WithEndpoint(endpoint)) resp, err := chat.Do( context.Background(), &ChatCompletionRequest{ @@ -82,12 +118,182 @@ func TestChatCompletionStream(t *testing.T) { } assert.True(t, turn_count > 1) } + for _, endpoint := range testEndpointList { + chat := NewChatCompletion(WithEndpoint(endpoint)) + resp, err := chat.Stream( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.NoError(t, err) + turn_count := 0 + for { + r, err := resp.Recv() + assert.NoError(t, err) + if resp.IsEnd { + break + } + turn_count++ + assert.Equal(t, r.RawResponse.StatusCode, 200) + assert.NotEqual(t, r.Id, nil) + assert.Equal(t, r.Object, "chat.completion") + assert.Contains(t, r.RawResponse.Request.URL.Path, endpoint) + assert.Contains(t, r.Result, "你好") + req, err := getRequestBody[ChatCompletionRequest](r.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Messages[0].Content, "你好") + } + assert.True(t, turn_count > 1) + } } -func TestMain(m *testing.M) { +func TestChatCompletionUnsupportedModel(t *testing.T) { + chat := NewChatCompletion(WithModel("unsupported_model")) + _, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported_model") + var target *ModelNotSupportedError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Model, "unsupported_model") +} + +func TestChatCompletionAPIError(t *testing.T) { + chat := NewChatCompletion() + _, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{}, + }, + ) + assert.Error(t, err) + var target *APIError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Code, 336003) +} + +func TestChatCompletionModelList(t *testing.T) { + list := NewChatCompletion().ModelList() + assert.Greater(t, len(list), 0) +} + +func TestChatCompletionRetry(t *testing.T) { + defer resetTestEnv() + chat := NewChatCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + WithLLMRetryCount(5), + ) + resp, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.Object, "chat.completion") + _, err = chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{}, + }, + ) + var target *APIError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Code, InvalidParamErrCode) + + chat = NewChatCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + ) + _, err = chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.Error(t, err) + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Code, ServerHighLoadErrCode) +} + +func TestChatCompletionStreamRetry(t *testing.T) { + GetConfig().LLMRetryCount = 5 + defer resetTestEnv() + chat := NewChatCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + ) + resp, err := chat.Stream( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.NoError(t, err) + turn_count := 0 + for { + r, err := resp.Recv() + assert.NoError(t, err) + if resp.IsEnd { + break + } + turn_count++ + assert.Equal(t, r.RawResponse.StatusCode, 200) + assert.NotEqual(t, r.Id, nil) + assert.Equal(t, r.Object, "chat.completion") + assert.Contains(t, r.RawResponse.Request.URL.Path, "test_retry") + assert.Contains(t, r.Result, "你好") + req, err := getRequestBody[ChatCompletionRequest](r.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Messages[0].Content, "你好") + } + assert.True(t, turn_count > 1) + + chat = NewChatCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + WithLLMRetryCount(1), + ) + resp, err = chat.Stream( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.NoError(t, err) + _, err = resp.Recv() + assert.Error(t, err) + var target *APIError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Code, ServerHighLoadErrCode) +} + +func resetTestEnv() { + rand.Seed(time.Now().UnixNano()) + logger.SetLevel(logrus.DebugLevel) os.Setenv("QIANFAN_BASE_URL", "http://127.0.0.1:8866") os.Setenv("QIANFAN_ACCESS_KEY", "test_access_key") os.Setenv("QIANFAN_SECRET_KEY", "test_secret_key") + _authManager = nil + _config = nil +} + +func TestMain(m *testing.M) { + resetTestEnv() os.Exit(m.Run()) } diff --git a/go/qianfan/completion.go b/go/qianfan/completion.go index 7a654b25..55adb5ed 100644 --- a/go/qianfan/completion.go +++ b/go/qianfan/completion.go @@ -16,7 +16,6 @@ package qianfan import ( "context" - "fmt" ) // Completion 模型请求的参数结构体,但并非每个模型都完整支持如下参数,具体是否支持以 API 文档为准 @@ -83,15 +82,16 @@ func newCompletion(options *Options) *Completion { // 将 endpoint 转换成完整的 endpoint func (c *Completion) realEndpoint() (string, error) { url := modelAPIPrefix - if c.Model != "" { + if c.Endpoint == "" { endpoint, ok := CompletionModelEndpoint[c.Model] if !ok { - return "", fmt.Errorf("model %s is not supported", c.Model) + return "", &ModelNotSupportedError{Model: c.Model} } url += endpoint } else { url += "/completions/" + c.Endpoint } + logger.Debugf("requesting endpoint: %s", url) return url, nil } @@ -130,13 +130,11 @@ func (c *Completion) Do(ctx context.Context, request *CompletionRequest) (*Model return nil, err } var resp ModelResponse - err = c.Requestor.request(req, &resp) + err = c.requestResource(req, &resp) if err != nil { return nil, err } - if err = checkResponseError(&resp); err != nil { - return &resp, err - } + return &resp, nil } diff --git a/go/qianfan/completion_test.go b/go/qianfan/completion_test.go index d96544e0..79eef25f 100644 --- a/go/qianfan/completion_test.go +++ b/go/qianfan/completion_test.go @@ -16,6 +16,8 @@ package qianfan import ( "context" + "fmt" + "math/rand" "testing" "github.com/stretchr/testify/assert" @@ -60,19 +62,62 @@ func TestCompletion(t *testing.T) { assert.NoError(t, err) assert.Equal(t, reqComp.Prompt, prompt) assert.Equal(t, reqComp.Temperature, 0.5) + + completion = NewCompletion(WithEndpoint("endpoint111")) + resp, err = completion.Do( + context.Background(), + &CompletionRequest{ + Prompt: prompt, + Temperature: 0.5, + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.Object, "completion") + assert.Contains(t, resp.RawResponse.Request.URL.Path, "endpoint111") + assert.Contains(t, resp.Result, prompt) + reqComp, err = getRequestBody[CompletionRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, reqComp.Prompt, prompt) + assert.Equal(t, reqComp.Temperature, 0.5) } func TestCompletionStream(t *testing.T) { modelList := []string{"ERNIE-Bot-turbo", "SQLCoder-7B"} + prompt := "hello" for _, m := range modelList { - chat := NewCompletion( + comp := NewCompletion( WithModel(m), ) - resp, err := chat.Stream( + resp, err := comp.Stream( + context.Background(), + &CompletionRequest{ + Prompt: prompt, + Temperature: 0.5, + }, + ) + assert.NoError(t, err) + defer resp.Close() + turnCount := 0 + for { + resp, err := resp.Recv() + assert.NoError(t, err) + if resp.IsEnd { + break + } + turnCount++ + assert.Contains(t, resp.Result, prompt) + } + assert.Greater(t, turnCount, 1) + } + for _, endpoint := range testEndpointList { + comp := NewCompletion( + WithEndpoint(endpoint), + ) + resp, err := comp.Stream( context.Background(), &CompletionRequest{ - Prompt: "hello", + Prompt: prompt, Temperature: 0.5, }, ) @@ -86,8 +131,125 @@ func TestCompletionStream(t *testing.T) { break } turnCount++ - assert.Contains(t, resp.Result, "hello") + assert.Contains(t, resp.Result, prompt) + assert.Equal(t, resp.Object, "completion") + assert.Contains(t, resp.RawResponse.Request.URL.Path, endpoint) + assert.Contains(t, resp.Result, prompt) + reqComp, err := getRequestBody[CompletionRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, reqComp.Prompt, prompt) + assert.Equal(t, reqComp.Temperature, 0.5) } assert.Greater(t, turnCount, 1) } } + +func TestCompletionModelList(t *testing.T) { + list := NewCompletion().ModelList() + assert.Greater(t, len(list), 0) +} + +func TestCompletionUnsupportedModel(t *testing.T) { + comp := NewCompletion(WithModel("unsupported_model")) + _, err := comp.Do( + context.Background(), + &CompletionRequest{ + Prompt: "hello", + }, + ) + assert.Error(t, err) + var target *ModelNotSupportedError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Model, "unsupported_model") +} + +func TestCompletionAPIError(t *testing.T) { + comp := NewCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + ) + _, err := comp.Do( + context.Background(), + &CompletionRequest{ + Prompt: "", + }, + ) + assert.Error(t, err) + var target *APIError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Code, 336100) +} + +func TestStreamCompletionAPIError(t *testing.T) { + comp := NewCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + ) + s, err := comp.Stream( + context.Background(), + &CompletionRequest{ + Prompt: "", + }, + ) + assert.NoError(t, err) + _, err = s.Recv() + assert.Error(t, err) +} + +func TestCompletionRetry(t *testing.T) { + defer resetTestEnv() + comp := NewCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + WithLLMRetryCount(5), + ) + resp, err := comp.Do( + context.Background(), + &CompletionRequest{ + Prompt: "", + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.Object, "completion") +} + +func TestCompletionStreamRetry(t *testing.T) { + GetConfig().LLMRetryCount = 5 + defer resetTestEnv() + prompt := "promptprompt" + comp := NewCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + ) + stream, err := comp.Stream( + context.Background(), + &CompletionRequest{ + Prompt: prompt, + }, + ) + assert.NoError(t, err) + turnCount := 0 + for { + resp, err := stream.Recv() + assert.NoError(t, err) + if resp.IsEnd { + break + } + turnCount++ + assert.Contains(t, resp.Result, prompt) + } + assert.Greater(t, turnCount, 1) + + comp = NewCompletion( + WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))), + WithLLMRetryCount(1), + ) + stream, err = comp.Stream( + context.Background(), + &CompletionRequest{ + Prompt: prompt, + }, + ) + assert.NoError(t, err) + _, err = stream.Recv() + assert.Error(t, err) + var target *APIError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Code, ServerHighLoadErrCode) +} diff --git a/go/qianfan/config.go b/go/qianfan/config.go index abbca1e5..ed7b6b8f 100644 --- a/go/qianfan/config.go +++ b/go/qianfan/config.go @@ -20,20 +20,32 @@ import ( // 默认配置 var defaultConfig = map[string]string{ - "QIANFAN_ACCESS_KEY": "", - "QIANFAN_SECRET_KEY": "", - "QIANFAN_BASE_URL": "https://aip.baidubce.com", - "QIANFAN_IAM_SIGN_EXPIRATION_SEC": "300", - "QIANFAN_CONSOLE_BASE_URL": "https://qianfan.baidubce.com", + "QIANFAN_AK": "", + "QIANFAN_SK": "", + "QIANFAN_ACCESS_KEY": "", + "QIANFAN_SECRET_KEY": "", + "QIANFAN_BASE_URL": "https://aip.baidubce.com", + "QIANFAN_IAM_SIGN_EXPIRATION_SEC": "300", + "QIANFAN_CONSOLE_BASE_URL": "https://qianfan.baidubce.com", + "QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL": "3600", + "QIANFAN_LLM_API_RETRY_COUNT": "1", + "QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR": "0", + "QIANFAN_LLM_API_RETRY_TIMEOUT": "0", } // SDK 使用的全局配置,可以用 GetConfig() 获取 type Config struct { - AccessKey string `mapstructure:"QIANFAN_ACCESS_KEY"` - SecretKey string `mapstructure:"QIANFAN_SECRET_KEY"` - BaseURL string `mapstructure:"QIANFAN_BASE_URL"` - IAMSignExpirationSeconds int `mapstructure:"QIANFAN_IAM_SIGN_EXPIRATION_SEC"` - ConsoleBaseURL string `mapstructure:"QIANFAN_CONSOLE_BASE_URL"` + AK string `mapstructure:"QIANFAN_AK"` + SK string `mapstructure:"QIANFAN_SK"` + AccessKey string `mapstructure:"QIANFAN_ACCESS_KEY"` + SecretKey string `mapstructure:"QIANFAN_SECRET_KEY"` + BaseURL string `mapstructure:"QIANFAN_BASE_URL"` + IAMSignExpirationSeconds int `mapstructure:"QIANFAN_IAM_SIGN_EXPIRATION_SEC"` + ConsoleBaseURL string `mapstructure:"QIANFAN_CONSOLE_BASE_URL"` + AccessTokenRefreshMinInterval int `mapstructure:"QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL"` + LLMRetryCount int `mapstructure:"QIANFAN_LLM_API_RETRY_COUNT"` + LLMRetryTimeout float32 `mapstructure:"QIANFAN_LLM_API_RETRY_TIMEOUT"` + LLMRetryBackoffFactor float32 `mapstructure:"QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR"` } func setConfigDeafultValue(vConfig *viper.Viper) { diff --git a/go/qianfan/consts.go b/go/qianfan/consts.go index bb202aab..00b8bd78 100644 --- a/go/qianfan/consts.go +++ b/go/qianfan/consts.go @@ -15,7 +15,10 @@ package qianfan // 模型请求的前缀 -const modelAPIPrefix = "/rpc/2.0/ai_custom/v1/wenxinworkshop" +const ( + modelAPIPrefix = "/rpc/2.0/ai_custom/v1/wenxinworkshop" + authAPIPrefix = "/oauth/2.0/token" +) // 默认使用的模型 const ( @@ -23,3 +26,33 @@ const ( DefaultCompletionModel = "ERNIE-Bot-turbo" DefaultEmbeddingModel = "Embedding-V1" ) + +// API 错误码 +const ( + NoErrorErrCode = 0 + UnknownErrorErrCode = 1 + ServiceUnavailableErrCode = 2 + UnsupportedMethodErrCode = 3 + RequestLimitReachedErrCode = 4 + NoPermissionToAccessDataErrCode = 6 + GetServiceTokenFailedErrCode = 13 + AppNotExistErrCode = 15 + DailyLimitReachedErrCode = 17 + QPSLimitReachedErrCode = 18 + TotalRequestLimitReachedErrCode = 19 + InvalidRequestErrCode = 100 + APITokenInvalidErrCode = 110 + APITokenExpiredErrCode = 111 + InternalErrorErrCode = 336000 + InvalidArgumentErrCode = 336001 + InvalidJSONErrCode = 336002 + InvalidParamErrCode = 336003 + PermissionErrorErrCode = 336004 + APINameNotExistErrCode = 336005 + ServerHighLoadErrCode = 336100 + InvalidHTTPMethodErrCode = 336101 + InvalidArgumentSystemErrCode = 336104 + InvalidArgumentUserSettingErrCode = 336105 + + ConsoleInternalErrorErrCode = 500000 +) diff --git a/go/qianfan/embdding.go b/go/qianfan/embdding.go index 4a4bcda7..b4ecdb43 100644 --- a/go/qianfan/embdding.go +++ b/go/qianfan/embdding.go @@ -16,7 +16,6 @@ package qianfan import ( "context" - "fmt" ) // 用于 Embedding 相关操作的结构体 @@ -84,15 +83,16 @@ func newEmbedding(options *Options) *Embedding { // endpoint 转成完整 url func (c *Embedding) realEndpoint() (string, error) { url := modelAPIPrefix - if c.Model != "" { + if c.Endpoint == "" { endpoint, ok := EmbeddingEndpoint[c.Model] if !ok { - return "", fmt.Errorf("model %s is not supported", c.Model) + return "", &ModelNotSupportedError{Model: c.Model} } url += endpoint } else { url += "/embeddings/" + c.Endpoint } + logger.Debugf("requesting endpoint: %s", url) return url, nil } @@ -108,13 +108,11 @@ func (c *Embedding) Do(ctx context.Context, request *EmbeddingRequest) (*Embeddi } resp := &EmbeddingResponse{} - err = c.Requestor.request(req, resp) + err = c.requestResource(req, resp) if err != nil { return nil, err } - if err = checkResponseError(resp); err != nil { - return resp, err - } + return resp, nil } diff --git a/go/qianfan/embedding_test.go b/go/qianfan/embedding_test.go index 8ed78760..3cbb4a8d 100644 --- a/go/qianfan/embedding_test.go +++ b/go/qianfan/embedding_test.go @@ -34,5 +34,50 @@ func TestEmbedding(t *testing.T) { req, err := getRequestBody[EmbeddingRequest](resp.RawResponse) assert.NoError(t, err) assert.Equal(t, req.Input[0], "hello1") - assert.Equal(t, req.Input[1], "hello2") + assert.Equal(t, len(req.Input), 2) + + embed = NewEmbedding(WithModel("bge-large-zh")) + resp, err = embed.Do(context.Background(), &EmbeddingRequest{ + Input: []string{"hello3"}, + }) + assert.NoError(t, err) + assert.Equal(t, resp.RawResponse.StatusCode, 200) + assert.Equal(t, len(resp.Data), 1) + assert.NotEqual(t, len(resp.Data), 0) + assert.Contains(t, resp.RawResponse.Request.URL.Path, EmbeddingEndpoint["bge-large-zh"]) + req, err = getRequestBody[EmbeddingRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Input[0], "hello3") + assert.Equal(t, len(req.Input), 1) + + embed = NewEmbedding(WithEndpoint("custom_endpoint")) + resp, err = embed.Do(context.Background(), &EmbeddingRequest{ + Input: []string{"hello4"}, + }) + assert.NoError(t, err) + assert.Equal(t, resp.RawResponse.StatusCode, 200) + assert.Equal(t, len(resp.Data), 1) + assert.NotEqual(t, len(resp.Data), 0) + assert.Contains(t, resp.RawResponse.Request.URL.Path, "custom_endpoint") + req, err = getRequestBody[EmbeddingRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Input[0], "hello4") + assert.Equal(t, len(req.Input), 1) +} + +func TestEmbeddingModelList(t *testing.T) { + embed := NewEmbedding() + list := embed.ModelList() + assert.Greater(t, len(list), 0) +} + +func TestEmbeddingUnexistedModel(t *testing.T) { + embed := NewEmbedding(WithModel("unexisted_model")) + _, err := embed.Do(context.Background(), &EmbeddingRequest{ + Input: []string{"hello3"}, + }) + assert.Error(t, err) + var target *ModelNotSupportedError + assert.ErrorAs(t, err, &target) + assert.Equal(t, target.Model, "unexisted_model") } diff --git a/go/qianfan/error.go b/go/qianfan/error.go new file mode 100644 index 00000000..58a933f3 --- /dev/null +++ b/go/qianfan/error.go @@ -0,0 +1,70 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import "fmt" + +// 模型不被支持,请使用 `ModelList()` 获取支持的模型列表 +type ModelNotSupportedError struct { + Model string +} + +func (e *ModelNotSupportedError) Error() string { + return fmt.Sprintf("model `%s` is not supported, use `ModelList()` to acquire supported model list", e.Model) +} + +// API 返回错误 +type APIError struct { + Code int + Msg string +} + +func (e *APIError) Error() string { + return fmt.Sprintf("api error, code: %d, msg: %s", e.Code, e.Msg) +} + +// 鉴权所需信息不足,需确保 (AccessKey, SecretKey) 或 (AK, SK) 存在 +type CredentialNotFoundError struct { +} + +func (e *CredentialNotFoundError) Error() string { + return "no enough credentails found. Please set AK and SK or AccessKey and SecretKey" +} + +// SDK 内部错误,若遇到请联系我们 +type InternalError struct { + Msg string +} + +func (e *InternalError) Error() string { + return fmt.Sprintf("internal error: %s. there might be a bug in sdk. please contact us", e.Msg) +} + +// 参数非法 +type InvalidParamError struct { + Msg string +} + +func (e *InvalidParamError) Error() string { + return fmt.Sprint("invalid param ", e.Msg) +} + +// 内部使用,表示重试 +type tryAgainError struct { +} + +func (e *tryAgainError) Error() string { + return "try again" +} diff --git a/go/qianfan/logger.go b/go/qianfan/logger.go index 21ad0b3d..ea215551 100644 --- a/go/qianfan/logger.go +++ b/go/qianfan/logger.go @@ -15,7 +15,14 @@ package qianfan import ( + "os" + "github.com/sirupsen/logrus" ) -var logger = logrus.New() +var logger = &logrus.Logger{ + Out: os.Stderr, + Formatter: new(logrus.TextFormatter), + Hooks: make(logrus.LevelHooks), + Level: logrus.InfoLevel, +} diff --git a/go/qianfan/options.go b/go/qianfan/options.go index 0db46f6a..343d0304 100644 --- a/go/qianfan/options.go +++ b/go/qianfan/options.go @@ -16,8 +16,11 @@ package qianfan type Option func(*Options) type Options struct { - Model *string - Endpoint *string + Model *string + Endpoint *string + LLMRetryCount int + LLMRetryTimeout float32 + LLMRetryBackoffFactor float32 } // 用于模型类对象设置使用的模型 @@ -34,9 +37,34 @@ func WithEndpoint(endpoint string) Option { } } +// 设置重试次数 +func WithLLMRetryCount(count int) Option { + return func(options *Options) { + options.LLMRetryCount = count + } +} + +// 设置重试超时时间 +func WithLLMRetryTimeout(timeout float32) Option { + return func(options *Options) { + options.LLMRetryTimeout = timeout + } +} + +// 设置重试退避因子 +func WithLLMRetryBackoffFactor(factor float32) Option { + return func(options *Options) { + options.LLMRetryBackoffFactor = factor + } +} + // 将多个 Option 转换成最终的 Options 对象 func makeOptions(options ...Option) *Options { - option := Options{} + option := Options{ + LLMRetryCount: GetConfig().LLMRetryCount, + LLMRetryTimeout: GetConfig().LLMRetryTimeout, + LLMRetryBackoffFactor: GetConfig().LLMRetryBackoffFactor, + } for _, opt := range options { opt(&option) } diff --git a/go/qianfan/requestor.go b/go/qianfan/requestor.go index b4cefaf4..50a505b9 100644 --- a/go/qianfan/requestor.go +++ b/go/qianfan/requestor.go @@ -24,6 +24,7 @@ import ( "net/url" "strconv" "strings" + "time" "github.com/baidubce/bce-sdk-go/auth" bceHTTP "github.com/baidubce/bce-sdk-go/http" @@ -79,6 +80,7 @@ func convertToMap(body RequestBody) (map[string]interface{}, error) { // 请求类型,用于区分是模型的请求还是管控类请求 // 在 QfRequest.Type 处被使用 const ( + authRequest = "auth" // AccessToken 鉴权请求 modelRequest = "model" consoleRequest = "console" ) @@ -98,6 +100,11 @@ func newModelRequest(method string, url string, body RequestBody) (*QfRequest, e return newRequest(modelRequest, method, url, body) } +// 创建一个用于鉴权类请求的 Request +func newAuthRequest(method string, url string, body RequestBody) (*QfRequest, error) { + return newRequest(authRequest, method, url, body) +} + // 创建一个用于管控类请求的 Request // 暂时注释避免 lint 报错 // func newConsoleRequest(method string, url string, body RequestBody) (*QfRequest, error) { @@ -106,9 +113,13 @@ func newModelRequest(method string, url string, body RequestBody) (*QfRequest, e // 创建一个 Request,body 可以是任意实现了 RequestBody 接口的类型 func newRequest(requestType string, method string, url string, body RequestBody) (*QfRequest, error) { - b, err := convertToMap(body) - if err != nil { - return nil, err + var b map[string]interface{} = nil + if body != nil { + bodyMap, err := convertToMap(body) + if err != nil { + return nil, err + } + b = bodyMap } return newRequestFromMap(requestType, method, url, b) } @@ -134,6 +145,7 @@ type baseResponse struct { // 所有回复类型需实现的接口 type QfResponse interface { SetResponse(Body []byte, RawResponse *http.Response) + GetErrorCode() string } // 设置回复中通用参数的字段 @@ -150,10 +162,37 @@ type Requestor struct { // 创建一个 Requestor func newRequestor(options *Options) *Requestor { - return &Requestor{ + r := &Requestor{ client: &http.Client{}, Options: options, } + if r.Options.LLMRetryTimeout != 0 { + r.client.Timeout = time.Duration(r.Options.LLMRetryTimeout) * time.Second + } + return r +} + +func (r *Requestor) addAuthInfo(request *QfRequest) error { + if request.Type == authRequest { + return nil + } + if GetConfig().AK != "" && GetConfig().SK != "" { + return r.addAccessToken(request) + } else if GetConfig().AccessKey != "" && GetConfig().SecretKey != "" { + return r.sign(request) + } + logger.Error("no enough credential found. Please check whether (ak, sk) or (access_key, secret_key) is set in config") + return &CredentialNotFoundError{} +} + +// 增加 accesstoken 鉴权信息 +func (r *Requestor) addAccessToken(request *QfRequest) error { + token, err := GetAuthManager().GetAccessToken(GetConfig().AK, GetConfig().SK) + if err != nil { + return err + } + request.Params["access_token"] = token + return nil } // IAM 签名 @@ -176,7 +215,11 @@ func (r *Requestor) sign(request *QfRequest) error { } else if u.Scheme == "https" { port = "443" } else { - return fmt.Errorf("unrecognized scheme: %s", u.Scheme) + logger.Errorf("Got unexpected protocol `%s` in requested API url `%s`.", u.Scheme, request.URL) + return &InvalidParamError{ + Msg: fmt.Sprintf("unrecognized protocol `%s` is set in API base url."+ + "Only http and https are supported.", u.Scheme), + } } } porti, err := strconv.Atoi(port) @@ -210,16 +253,21 @@ func (r *Requestor) sign(request *QfRequest) error { } // 对请求进行统一处理,并转换成 http.Request -func (r *Requestor) prepareRequest(request *QfRequest) (*http.Request, error) { +func (r *Requestor) prepareRequest(request QfRequest) (*http.Request, error) { // 设置溯源标识 if request.Type == modelRequest { request.URL = GetConfig().BaseURL + request.URL - request.Body["extra_parameters"] = map[string]string{ - "request_source": versionIndicator, + if _, ok := request.Body["extra_parameters"]; !ok { + request.Body["extra_parameters"] = map[string]interface{}{} } + request.Body["extra_parameters"].(map[string]interface{})["request_source"] = versionIndicator } else if request.Type == consoleRequest { request.URL = GetConfig().ConsoleBaseURL + request.URL request.Headers["request-source"] = versionIndicator + } else if request.Type == authRequest { + request.URL = GetConfig().BaseURL + request.URL + } else { + return nil, &InternalError{"unexpected request type: " + request.Type} } bodyBytes, err := json.Marshal(request.Body) if err != nil { @@ -230,8 +278,8 @@ func (r *Requestor) prepareRequest(request *QfRequest) (*http.Request, error) { return nil, err } request.Headers["Content-Type"] = "application/json" - // IAM 签名 - err = r.sign(request) + // 增加鉴权信息 + err = r.addAuthInfo(&request) if err != nil { return nil, err } @@ -250,11 +298,19 @@ func (r *Requestor) prepareRequest(request *QfRequest) (*http.Request, error) { // 进行请求,返回原始的 baseResponse,并将结果解析至 resp func (r *Requestor) request(request *QfRequest, response QfResponse) error { - req, err := r.prepareRequest(request) + req, err := r.prepareRequest(*request) + if err != nil { + return err + } + err = r.sendRequestAndParse(req, response) if err != nil { return err } - resp, err := r.client.Do(req) + return nil +} + +func (r *Requestor) sendRequestAndParse(request *http.Request, response QfResponse) error { + resp, err := r.client.Do(request) if err != nil { return err } @@ -269,23 +325,49 @@ func (r *Requestor) request(request *QfRequest, response QfResponse) error { if err != nil { return err } + return nil } // 流的内部实现,用于接收流中的响应 type streamInternal struct { - httpResponse *http.Response // 原始的 http.Response - scanner *bufio.Scanner // 读取流的 scanner - IsEnd bool // 流是否已经结束 + *Requestor // 请求器 + requestFunc func() (*http.Response, error) // 请求流中的响应的函数 + httpResponse *http.Response // 原始的 http.Response + scanner *bufio.Scanner // 读取流的 scanner + IsEnd bool // 流是否已经结束 + firstResponse bool // 是否已经读取过第一个响应 } // 创建一个流 -func newStreamInternal(httpResponse *http.Response) (*streamInternal, error) { - return &streamInternal{ - httpResponse: httpResponse, - scanner: bufio.NewScanner(httpResponse.Body), - IsEnd: false, - }, nil +func newStreamInternal(requestor *Requestor, requestFunc func() (*http.Response, error)) (*streamInternal, error) { + si := &streamInternal{ + Requestor: requestor, + requestFunc: requestFunc, + httpResponse: nil, + scanner: nil, + IsEnd: false, + firstResponse: false, + } + // 初始化请求 + err := si.reset() + if err != nil { + return nil, err + } + return si, nil +} + +func (si *streamInternal) reset() error { + response, err := si.requestFunc() + si.IsEnd = false + si.firstResponse = true + if err != nil { + si.IsEnd = true + return err + } + si.httpResponse = response + + return nil } // 关闭流 @@ -295,7 +377,11 @@ func (si *streamInternal) Close() { // 接受流中的响应,并将结果解析至 resp func (si *streamInternal) Recv(resp QfResponse) error { + si.firstResponse = false var eventData []byte + if si.scanner == nil { + si.scanner = bufio.NewScanner(si.httpResponse.Body) + } for len(eventData) == 0 { for { if !si.scanner.Scan() { @@ -338,17 +424,17 @@ func (si *streamInternal) Recv(resp QfResponse) error { // 发送请求,返回流对象 func (r *Requestor) requestStream(request *QfRequest) (*streamInternal, error) { - req, err := r.prepareRequest(request) - if err != nil { - return nil, err - } - resp, err := r.client.Do(req) - if err != nil { - return nil, err - } - stream, err := newStreamInternal(resp) - if err != nil { - return nil, err + sendRequest := func() (*http.Response, error) { + req, err := r.prepareRequest(*request) + if err != nil { + return nil, err + } + resp, err := r.client.Do(req) + if err != nil { + return nil, err + } + return resp, nil } - return stream, nil + + return newStreamInternal(r, sendRequest) } diff --git a/go/qianfan/version.go b/go/qianfan/version.go index 2480beda..6f5d99c3 100644 --- a/go/qianfan/version.go +++ b/go/qianfan/version.go @@ -21,5 +21,5 @@ package qianfan // SDK 版本 -const Version = "v0.0.1" +const Version = "v0.0.2" const versionIndicator = "qianfan_go_sdk_" + Version diff --git a/python/qianfan/tests/utils/mock_server.py b/python/qianfan/tests/utils/mock_server.py index 096e2c55..62d245f1 100644 --- a/python/qianfan/tests/utils/mock_server.py +++ b/python/qianfan/tests/utils/mock_server.py @@ -260,6 +260,20 @@ def chat(model_name): r = request.json request_header = request.headers request_id = request_header.get(Consts.XRequestID) + if model_name.startswith("test_retry"): + global retry_cnt + print("mock retry cnt", retry_cnt) + if model_name not in retry_cnt: + retry_cnt[model_name] = 1 + if retry_cnt[model_name] % 3 != 0: + # need retry + retry_cnt[model_name] = (retry_cnt[model_name] + 1) % 3 + return json_response( + { + "error_code": 336100, + "error_msg": "high load", + } + ) if request_id == "custom_req": return json_response( { @@ -656,6 +670,13 @@ def auth(): ) ak = request.args.get("client_id") sk = request.args.get("client_secret") + if "bad" in sk: + return json_response( + { + "error_description": "Client authentication failed", + "error": "invalid_client", + } + ) # check messages return json_response( {