Skip to content

Commit

Permalink
fix retry
Browse files Browse the repository at this point in the history
  • Loading branch information
ZingLix committed Feb 20, 2024
1 parent 6051bff commit a1cff48
Show file tree
Hide file tree
Showing 10 changed files with 295 additions and 93 deletions.
7 changes: 4 additions & 3 deletions go/qianfan/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,13 @@ func TestAccessTokenExpired(t *testing.T) {
},
)
assert.NoError(t, err)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, token, fakeAccessToken(ak, sk))

for {
r, err := stream.Recv()
assert.NoError(t, err)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, token, fakeAccessToken(ak, sk))
assert.Contains(t, r.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk))
if r.IsEnd {
break
Expand Down
55 changes: 54 additions & 1 deletion go/qianfan/base_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package qianfan

import (
"encoding/json"
"fmt"
"io"
"math"
"strconv"
"time"
Expand Down Expand Up @@ -97,9 +99,59 @@ type ModelResponseStream struct {
*streamInternal
}

func newModelResponseStream(si *streamInternal) *ModelResponseStream {
return &ModelResponseStream{streamInternal: si}
}

func (s *ModelResponseStream) checkResponseError() error {
tokenRefreshed := false
var apiError *APIError
for retryCount := 0; retryCount < s.Options.LLMRetryCount || s.Options.LLMRetryCount == 0; retryCount++ {
contentType := s.httpResponse.Header.Get("Content-Type")
if contentType == "application/json" {
// 遇到错误
var resp ModelResponse
content, err := io.ReadAll(s.httpResponse.Body)
if err != nil {
return err
}

err = json.Unmarshal(content, &resp)
if err != nil {
return err
}
apiError = &APIError{Code: resp.ErrorCode, Msg: resp.ErrorMsg}
if !tokenRefreshed && (resp.ErrorCode == APITokenInvalidErrCode || resp.ErrorCode == APITokenExpiredErrCode) {
tokenRefreshed = true
_, err := GetAuthManager().GetAccessTokenWithRefresh(GetConfig().AK, GetConfig().SK)
if err != nil {
return err
}
retryCount--
} else if resp.ErrorCode != QPSLimitReachedErrCode && resp.ErrorCode != ServerHighLoadErrCode {
return apiError
}
s.reset()
} else {
return nil
}
}

if apiError == nil {
return &InternalError{Msg: "there must be an api error here"}
}
return apiError
}

// 获取ModelResponse流式结果
func (s *ModelResponseStream) Recv() (*ModelResponse, error) {
var resp ModelResponse
if s.firstResponse {
err := s.checkResponseError()
if err != nil {
return nil, err
}
}
err := s.streamInternal.Recv(&resp)
if err != nil {
return nil, err
Expand Down Expand Up @@ -159,7 +211,8 @@ func (m *BaseModel) requestResource(request *QfRequest, response any) error {
}
if err = checkResponseError(modelApiResponse); err != nil {
errCode, _ := modelApiResponse.GetError()
if !tokenRefreshed && (errCode == 110 || errCode == 111) {
if !tokenRefreshed && (errCode == APITokenInvalidErrCode ||
errCode == APITokenExpiredErrCode) {
tokenRefreshed = true
_, err := GetAuthManager().GetAccessTokenWithRefresh(GetConfig().AK, GetConfig().SK)
if err != nil {
Expand Down
4 changes: 1 addition & 3 deletions go/qianfan/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@ func (c *ChatCompletion) Stream(ctx context.Context, request *ChatCompletionRequ
if err != nil {
return nil, err
}
return &ModelResponseStream{
streamInternal: stream,
}, nil
return newModelResponseStream(stream), nil
}

// chat 支持的模型列表
Expand Down
97 changes: 97 additions & 0 deletions go/qianfan/chat_completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package qianfan
import (
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
Expand Down Expand Up @@ -185,6 +186,102 @@ func TestChatCompletionModelList(t *testing.T) {
assert.Greater(t, len(list), 0)
}

func TestChatCompletionRetry(t *testing.T) {
defer resetTestEnv()
chat := NewChatCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
WithLLMRetryCount(5),
)
resp, err := chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{
ChatCompletionUserMessage("你好"),
},
},
)
assert.NoError(t, err)
assert.Equal(t, resp.Object, "chat.completion")
_, err = chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{},
},
)
var target *APIError
assert.ErrorAs(t, err, &target)
assert.Equal(t, target.Code, InvalidParamErrCode)

chat = NewChatCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
)
_, err = chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{
ChatCompletionUserMessage("你好"),
},
},
)
assert.Error(t, err)
assert.ErrorAs(t, err, &target)
assert.Equal(t, target.Code, ServerHighLoadErrCode)
}

func TestChatCompletionStreamRetry(t *testing.T) {
GetConfig().LLMRetryCount = 5
defer resetTestEnv()
chat := NewChatCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
)
resp, err := chat.Stream(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{
ChatCompletionUserMessage("你好"),
},
},
)
assert.NoError(t, err)
turn_count := 0
for {
r, err := resp.Recv()
assert.NoError(t, err)
if resp.IsEnd {
break
}
turn_count++
assert.Equal(t, r.RawResponse.StatusCode, 200)
assert.NotEqual(t, r.Id, nil)
assert.Equal(t, r.Object, "chat.completion")
assert.Contains(t, r.RawResponse.Request.URL.Path, "test_retry")
assert.Contains(t, r.Result, "你好")
req, err := getRequestBody[ChatCompletionRequest](r.RawResponse)
assert.NoError(t, err)
assert.Equal(t, req.Messages[0].Content, "你好")
}
assert.True(t, turn_count > 1)

chat = NewChatCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
WithLLMRetryCount(1),
)
resp, err = chat.Stream(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{
ChatCompletionUserMessage("你好"),
},
},
)
assert.NoError(t, err)
_, err = resp.Recv()
assert.Error(t, err)
var target *APIError
assert.ErrorAs(t, err, &target)
assert.Equal(t, target.Code, ServerHighLoadErrCode)
}

func resetTestEnv() {
rand.Seed(time.Now().UnixNano())
logger.SetLevel(logrus.DebugLevel)
Expand Down
48 changes: 47 additions & 1 deletion go/qianfan/completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@ func TestStreamCompletionAPIError(t *testing.T) {
comp := NewCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
)
_, err := comp.Stream(
s, err := comp.Stream(
context.Background(),
&CompletionRequest{
Prompt: "",
},
)
assert.NoError(t, err)
_, err = s.Recv()
assert.Error(t, err)
}

Expand All @@ -207,3 +209,47 @@ func TestCompletionRetry(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, resp.Object, "completion")
}

func TestCompletionStreamRetry(t *testing.T) {
GetConfig().LLMRetryCount = 5
defer resetTestEnv()
prompt := "promptprompt"
comp := NewCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
)
stream, err := comp.Stream(
context.Background(),
&CompletionRequest{
Prompt: prompt,
},
)
assert.NoError(t, err)
turnCount := 0
for {
resp, err := stream.Recv()
assert.NoError(t, err)
if resp.IsEnd {
break
}
turnCount++
assert.Contains(t, resp.Result, prompt)
}
assert.Greater(t, turnCount, 1)

comp = NewCompletion(
WithEndpoint(fmt.Sprintf("test_retry_%d", rand.Intn(100000))),
WithLLMRetryCount(1),
)
stream, err = comp.Stream(
context.Background(),
&CompletionRequest{
Prompt: prompt,
},
)
assert.NoError(t, err)
_, err = stream.Recv()
assert.Error(t, err)
var target *APIError
assert.ErrorAs(t, err, &target)
assert.Equal(t, target.Code, ServerHighLoadErrCode)
}
30 changes: 30 additions & 0 deletions go/qianfan/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,33 @@ const (
DefaultCompletionModel = "ERNIE-Bot-turbo"
DefaultEmbeddingModel = "Embedding-V1"
)

// API 错误码
const (
NoErrorErrCode = 0
UnknownErrorErrCode = 1
ServiceUnavailableErrCode = 2
UnsupportedMethodErrCode = 3
RequestLimitReachedErrCode = 4
NoPermissionToAccessDataErrCode = 6
GetServiceTokenFailedErrCode = 13
AppNotExistErrCode = 15
DailyLimitReachedErrCode = 17
QPSLimitReachedErrCode = 18
TotalRequestLimitReachedErrCode = 19
InvalidRequestErrCode = 100
APITokenInvalidErrCode = 110
APITokenExpiredErrCode = 111
InternalErrorErrCode = 336000
InvalidArgumentErrCode = 336001
InvalidJSONErrCode = 336002
InvalidParamErrCode = 336003
PermissionErrorErrCode = 336004
APINameNotExistErrCode = 336005
ServerHighLoadErrCode = 336100
InvalidHTTPMethodErrCode = 336101
InvalidArgumentSystemErrCode = 336104
InvalidArgumentUserSettingErrCode = 336105

ConsoleInternalErrorErrCode = 500000
)
9 changes: 8 additions & 1 deletion go/qianfan/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
package qianfan

import (
"os"

"github.com/sirupsen/logrus"
)

var logger = logrus.New()
var logger = &logrus.Logger{
Out: os.Stderr,
Formatter: new(logrus.TextFormatter),
Hooks: make(logrus.LevelHooks),
Level: logrus.WarnLevel,
}
Loading

0 comments on commit a1cff48

Please sign in to comment.