Skip to content

Commit

Permalink
add retry test
Browse files Browse the repository at this point in the history
  • Loading branch information
ZingLix committed Feb 20, 2024
1 parent d713367 commit 6051bff
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 95 deletions.
72 changes: 71 additions & 1 deletion go/qianfan/base_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package qianfan

import (
"fmt"
"math"
"strconv"
"time"
)

// 模型相关的结构体基类
Expand All @@ -34,6 +37,7 @@ type ModelUsage struct {

// ModelAPIResponse is the contract for response values whose API-level
// error state can be inspected and reset.
type ModelAPIResponse interface {
// GetError returns an int and a string; callers in this file treat the
// int as the API error code.
// NOTE(review): the string is presumably the error message — confirm
// against the implementations.
GetError() (int, string)
// ClearError resets the stored error state so the same response value can
// be reused for another request attempt.
ClearError()
}

// API 错误信息
Expand All @@ -52,6 +56,12 @@ func (e *ModelAPIError) GetErrorCode() string {
return strconv.Itoa(e.ErrorCode)
}

// ClearError resets the stored error code and error message to their zero
// values.
func (e *ModelAPIError) ClearError() {
	e.ErrorCode, e.ErrorMsg = 0, ""
}

// 搜索结果
type SearchResult struct {
Index int `json:"index"` // 序号
Expand All @@ -72,7 +82,7 @@ type ModelResponse struct {
SentenceId int `json:"sentence_id"` // 表示当前子句的序号。只有在流式接口模式下会返回该字段
IsEnd bool `json:"is_end"` // 表示当前子句是否是最后一句。只有在流式接口模式下会返回该字段
IsTruncated bool `json:"is_truncated"` // 当前生成的结果是否被截断
Result string `json:"result"` // 对话返回结果
Result string `json:"result"` // 对话返回结果
NeedClearHistory bool `json:"need_clear_history"` // 表示用户输入是否存在安全风险,是否关闭当前会话,清理历史会话信息
Usage ModelUsage `json:"usage"` // token统计信息
FunctionCall *FunctionCall `json:"function_call"` // 由模型生成的函数调用,包含函数名称,和调用参数
Expand Down Expand Up @@ -107,3 +117,63 @@ func checkResponseError(resp ModelAPIResponse) error {
}
return nil
}

// withRetry calls fn until it succeeds or the attempt budget configured in
// m.Options.LLMRetryCount is exhausted, sleeping with exponential backoff
// between failed attempts.
//
// LLMRetryCount is treated as the total number of attempts; values below 1
// are clamped to 1 so that fn is always invoked at least once.
//
// A *tryAgainError returned by fn means "retry immediately": it neither
// consumes an attempt nor triggers a backoff sleep. fn is responsible for
// returning it only a bounded number of times, otherwise the loop never ends.
//
// It returns nil on the first success, otherwise the error produced by the
// last attempt.
func (m *BaseModel) withRetry(fn func() error) error {
	if fn == nil {
		// Calling a nil func would panic; report it as an error instead.
		return fmt.Errorf("withRetry: fn must not be nil")
	}

	maxAttempts := m.Options.LLMRetryCount
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	var lastErr error
	for attempt := 0; attempt < maxAttempts; {
		lastErr = fn()
		if lastErr == nil {
			return nil
		}
		if _, ok := lastErr.(*tryAgainError); ok {
			// Immediate retry requested: do not count it, do not sleep.
			continue
		}

		attempt++
		if attempt == maxAttempts {
			// No further attempt follows, so sleeping would only add latency.
			break
		}

		// Backoff is factor * 2^(failed attempts - 1) seconds. Convert to
		// nanoseconds while still in floating point so that fractional
		// factors are not truncated to whole seconds.
		seconds := math.Pow(2, float64(attempt-1)) * float64(m.Options.LLMRetryBackoffFactor)
		time.Sleep(time.Duration(seconds * float64(time.Second)))
	}
	return lastErr
}

// requestResource sends request through m.Requestor, decodes the result into
// response, and retries failed attempts via m.withRetry.
//
// response must implement both QfResponse and ModelAPIResponse; otherwise an
// *InternalError is returned without sending anything.
//
// If the API reports error code 110 or 111, the access token is refreshed once
// per call and the request is retried immediately.
// NOTE(review): 110/111 are presumably the "access token invalid/expired"
// codes — confirm against the service's error code table.
//
// It returns nil on success, otherwise the error produced by the last attempt.
func (m *BaseModel) requestResource(request *QfRequest, response any) error {
	qfResponse, ok := response.(QfResponse)
	if !ok {
		return &InternalError{Msg: "response is not QfResponse"}
	}
	modelAPIResponse, ok := response.(ModelAPIResponse)
	if !ok {
		return &InternalError{Msg: "response is not ModelResponse"}
	}

	tokenRefreshed := false
	attempt := func() error {
		// The response value is reused across attempts; drop any error left
		// behind by a previous attempt before decoding into it again.
		modelAPIResponse.ClearError()
		if err := m.Requestor.request(request, qfResponse); err != nil {
			return err
		}
		apiErr := checkResponseError(modelAPIResponse)
		if apiErr == nil {
			return nil
		}
		errCode, _ := modelAPIResponse.GetError()
		if !tokenRefreshed && (errCode == 110 || errCode == 111) {
			// Refresh at most once per call so a persistently rejected token
			// cannot cause an endless immediate-retry loop.
			tokenRefreshed = true
			if _, err := GetAuthManager().GetAccessTokenWithRefresh(GetConfig().AK, GetConfig().SK); err != nil {
				return err
			}
			return &tryAgainError{}
		}
		return apiErr
	}

	// Record exactly what the most recent attempt returned, so the caller gets
	// the real failure regardless of what error value withRetry reports.
	var lastErr error
	tracked := func() error {
		lastErr = attempt()
		return lastErr
	}

	if retryErr := m.withRetry(tracked); retryErr != nil {
		if lastErr != nil {
			return lastErr
		}
		// No attempt recorded an error (e.g. the request was never sent):
		// surface the retry error rather than reporting success.
		return retryErr
	}
	return nil
}
7 changes: 3 additions & 4 deletions go/qianfan/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,12 @@ func (c *ChatCompletion) Do(ctx context.Context, request *ChatCompletionRequest)
return nil, err
}
var resp ModelResponse
err = c.Requestor.request(req, &resp)

err = c.requestResource(req, &resp)
if err != nil {
return nil, err
}
if err = checkResponseError(&resp); err != nil {
return &resp, err
}

return &resp, nil
}

Expand Down
6 changes: 2 additions & 4 deletions go/qianfan/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,11 @@ func (c *Completion) Do(ctx context.Context, request *CompletionRequest) (*Model
return nil, err
}
var resp ModelResponse
err = c.Requestor.request(req, &resp)
err = c.requestResource(req, &resp)
if err != nil {
return nil, err
}
if err = checkResponseError(&resp); err != nil {
return &resp, err
}

return &resp, nil
}

Expand Down
46 changes: 37 additions & 9 deletions go/qianfan/completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ func TestCompletionStream(t *testing.T) {
modelList := []string{"ERNIE-Bot-turbo", "SQLCoder-7B"}
prompt := "hello"
for _, m := range modelList {
chat := NewCompletion(
comp := NewCompletion(
WithModel(m),
)
resp, err := chat.Stream(
resp, err := comp.Stream(
context.Background(),
&CompletionRequest{
Prompt: prompt,
Expand All @@ -111,10 +111,10 @@ func TestCompletionStream(t *testing.T) {
assert.Greater(t, turnCount, 1)
}
for _, endpoint := range testEndpointList {
chat := NewCompletion(
comp := NewCompletion(
WithEndpoint(endpoint),
)
resp, err := chat.Stream(
resp, err := comp.Stream(
context.Background(),
&CompletionRequest{
Prompt: prompt,
Expand Down Expand Up @@ -150,8 +150,8 @@ func TestCompletionModelList(t *testing.T) {
}

func TestCompletionUnsupportedModel(t *testing.T) {
chat := NewCompletion(WithModel("unsupported_model"))
_, err := chat.Do(
comp := NewCompletion(WithModel("unsupported_model"))
_, err := comp.Do(
context.Background(),
&CompletionRequest{
Prompt: "hello",
Expand All @@ -164,18 +164,46 @@ func TestCompletionUnsupportedModel(t *testing.T) {
}

func TestCompletionAPIError(t *testing.T) {
chat := NewCompletion(
comp := NewCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
)
resp, err := chat.Do(
_, err := comp.Do(
context.Background(),
&CompletionRequest{
Prompt: "",
},
)
fmt.Printf("%s", resp.Object)
assert.Error(t, err)
var target *APIError
assert.ErrorAs(t, err, &target)
assert.Equal(t, target.Code, 336100)
}

// TestStreamCompletionAPIError verifies that Stream reports an error when
// called with an empty prompt against a randomly suffixed "test_retry_"
// endpoint.
func TestStreamCompletionAPIError(t *testing.T) {
	endpoint := fmt.Sprintf("test_retry_%d", rand.Intn(100000))
	comp := NewCompletion(WithEndpoint(endpoint))

	request := &CompletionRequest{Prompt: ""}
	_, err := comp.Stream(context.Background(), request)

	assert.Error(t, err)
}

// TestCompletionRetry verifies that a completion request against a
// "test_retry_" endpoint succeeds when up to five attempts are allowed.
// NOTE(review): presumably the mock endpoint fails its first calls and then
// succeeds — confirm against the test server.
func TestCompletionRetry(t *testing.T) {
	defer resetTestEnv()
	comp := NewCompletion(
		WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
		WithLLMRetryCount(5),
	)
	resp, err := comp.Do(
		context.Background(),
		&CompletionRequest{
			Prompt: "",
		},
	)
	// Do returns a nil response on error; stop here so a failure is reported
	// as a failed assertion rather than a nil-pointer panic.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "completion", resp.Object)
}
11 changes: 4 additions & 7 deletions go/qianfan/embdding.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package qianfan

import (
"context"
"fmt"
)

// 用于 Embedding 相关操作的结构体
Expand Down Expand Up @@ -84,10 +83,10 @@ func newEmbedding(options *Options) *Embedding {
// endpoint 转成完整 url
func (c *Embedding) realEndpoint() (string, error) {
url := modelAPIPrefix
if c.Model != "" {
if c.Endpoint == "" {
endpoint, ok := EmbeddingEndpoint[c.Model]
if !ok {
return "", fmt.Errorf("model %s is not supported", c.Model)
return "", &ModelNotSupportedError{Model: c.Model}
}
url += endpoint
} else {
Expand All @@ -108,13 +107,11 @@ func (c *Embedding) Do(ctx context.Context, request *EmbeddingRequest) (*Embeddi
}
resp := &EmbeddingResponse{}

err = c.Requestor.request(req, resp)
err = c.requestResource(req, resp)
if err != nil {
return nil, err
}
if err = checkResponseError(resp); err != nil {
return resp, err
}

return resp, nil
}

Expand Down
47 changes: 46 additions & 1 deletion go/qianfan/embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,50 @@ func TestEmbedding(t *testing.T) {
req, err := getRequestBody[EmbeddingRequest](resp.RawResponse)
assert.NoError(t, err)
assert.Equal(t, req.Input[0], "hello1")
assert.Equal(t, req.Input[1], "hello2")
assert.Equal(t, len(req.Input), 2)

embed = NewEmbedding(WithModel("bge-large-zh"))
resp, err = embed.Do(context.Background(), &EmbeddingRequest{
Input: []string{"hello3"},
})
assert.NoError(t, err)
assert.Equal(t, resp.RawResponse.StatusCode, 200)
assert.Equal(t, len(resp.Data), 1)
assert.NotEqual(t, len(resp.Data), 0)
assert.Contains(t, resp.RawResponse.Request.URL.Path, EmbeddingEndpoint["bge-large-zh"])
req, err = getRequestBody[EmbeddingRequest](resp.RawResponse)
assert.NoError(t, err)
assert.Equal(t, req.Input[0], "hello3")
assert.Equal(t, len(req.Input), 1)

embed = NewEmbedding(WithEndpoint("custom_endpoint"))
resp, err = embed.Do(context.Background(), &EmbeddingRequest{
Input: []string{"hello4"},
})
assert.NoError(t, err)
assert.Equal(t, resp.RawResponse.StatusCode, 200)
assert.Equal(t, len(resp.Data), 1)
assert.NotEqual(t, len(resp.Data), 0)
assert.Contains(t, resp.RawResponse.Request.URL.Path, "custom_endpoint")
req, err = getRequestBody[EmbeddingRequest](resp.RawResponse)
assert.NoError(t, err)
assert.Equal(t, req.Input[0], "hello4")
assert.Equal(t, len(req.Input), 1)
}

// TestEmbeddingModelList verifies that a default Embedding reports at least
// one model in its model list.
func TestEmbeddingModelList(t *testing.T) {
	models := NewEmbedding().ModelList()
	assert.Greater(t, len(models), 0)
}

// TestEmbeddingUnexistedModel verifies that requesting an unknown model fails
// with a *ModelNotSupportedError that carries the requested model name.
func TestEmbeddingUnexistedModel(t *testing.T) {
	embed := NewEmbedding(WithModel("unexisted_model"))
	_, err := embed.Do(context.Background(), &EmbeddingRequest{
		Input: []string{"hello3"},
	})
	assert.Error(t, err)
	var target *ModelNotSupportedError
	// target stays nil when the error has a different type; stop here so the
	// failure is reported as a failed assertion rather than a nil-pointer panic.
	if !assert.ErrorAs(t, err, &target) {
		return
	}
	assert.Equal(t, "unexisted_model", target.Model)
}
19 changes: 19 additions & 0 deletions go/qianfan/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ func (e *ModelNotSupportedError) Error() string {
return fmt.Sprintf("model `%s` is not supported, use `ModelList()` to acquire supported model list", e.Model)
}

// API 返回错误
type APIError struct {
Code int
Msg string
Expand All @@ -20,9 +21,27 @@ func (e *APIError) Error() string {
return fmt.Sprintf("api error, code: %d, msg: %s", e.Code, e.Msg)
}

// CredentialNotFoundError reports that the information required for
// authentication is incomplete: either (AccessKey, SecretKey) or (AK, SK)
// must be provided.
type CredentialNotFoundError struct {
}

// Error implements the error interface.
func (e *CredentialNotFoundError) Error() string {
	return "not enough credentials found. Please set AK and SK or AccessKey and SecretKey"
}

// InternalError represents a failure inside the SDK itself. Its message asks
// the user to contact the maintainers, since it may indicate an SDK bug.
type InternalError struct {
	Msg string // description of what went wrong
}

// Error implements the error interface by embedding Msg in a fixed template.
func (e *InternalError) Error() string {
	const format = "internal error: %s. there might be a bug in sdk. please contact us"
	return fmt.Sprintf(format, e.Msg)
}

// tryAgainError is for internal use only: it signals that the operation
// should be retried.
type tryAgainError struct{}

// Error implements the error interface.
func (*tryAgainError) Error() string {
	return "try again"
}
Loading

0 comments on commit 6051bff

Please sign in to comment.