Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ZingLix committed Feb 19, 2024
1 parent 3249fd9 commit d713367
Show file tree
Hide file tree
Showing 10 changed files with 436 additions and 24 deletions.
14 changes: 13 additions & 1 deletion go/qianfan/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ type AuthManager struct {
*Requestor
}

func maskAk(ak string) string {
unmaskLen := 6
if len(ak) < unmaskLen {
return ak
}
return fmt.Sprintf("%s******", ak[:unmaskLen])
}

var _authManager *AuthManager

func GetAuthManager() *AuthManager {
Expand All @@ -77,6 +85,7 @@ func (m *AuthManager) GetAccessToken(ak, sk string) (string, error) {
if ok {
return token.token, nil
}
logger.Infof("Access token of ak `%s` not found, tring to refresh it...", maskAk(ak))
return m.GetAccessTokenWithRefresh(ak, sk)
}

Expand All @@ -91,6 +100,7 @@ func (m *AuthManager) GetAccessTokenWithRefresh(ak, sk string) (string, error) {
// 最近更新时间小于最小刷新间隔,则直接返回
// 避免多个请求同时刷新,导致token被刷新多次
if current.Sub(lastUpdate) < time.Duration(GetConfig().AccessTokenRefreshMinInterval)*time.Second {
logger.Debugf("Access token of ak `%s` was freshed %s ago, skip refreshing", maskAk(ak), current.Sub(lastUpdate))
return token.token, nil
}
}
Expand All @@ -112,8 +122,10 @@ func (m *AuthManager) GetAccessTokenWithRefresh(ak, sk string) (string, error) {
return "", err
}
if resp.Error != "" {
return "", fmt.Errorf("refresh access token failed: %s", resp.ErrorDescription)
logger.Errorf("refresh access token of ak `%s` failed with error: %s", maskAk(ak), resp.ErrorDescription)
return "", &APIError{Msg: resp.ErrorDescription}
}
logger.Infof("Access token of ak `%s` was refreshed", maskAk(ak))
m.tokenMap[credential{ak, sk}] = &accessToken{
token: resp.AccessToken,
lastUpateTime: time.Now(),
Expand Down
180 changes: 180 additions & 0 deletions go/qianfan/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package qianfan

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func fakeAccessToken(ak, sk string) string {
return fmt.Sprintf("%s.%s", ak, sk)
}

func resetAuthManager() {
_authManager = nil
}

func setAccessTokenExpired(ak, sk string) {
GetAuthManager().tokenMap[credential{ak, sk}] = &accessToken{
token: "expired_token",
lastUpateTime: time.Now().Add(-100 * time.Hour), // 100s 过期
}
}

func TestAuth(t *testing.T) {
resetAuthManager()
ak, sk := "ak_33", "sk_4235"
// 第一次获取前,缓存里应当没有
_, ok := GetAuthManager().tokenMap[credential{ak, sk}]
assert.False(t, ok)

accessTok, err := GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, fakeAccessToken(ak, sk))
updateTime := GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime
// 再测试一次,应当从缓存里获取,更新时间不变
accessTok, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, fakeAccessToken(ak, sk))
assert.Equal(
t,
updateTime,
GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime,
)
// 模拟过期
ak, sk = "ak_95411", "sk_87135"
setAccessTokenExpired(ak, sk)
// 设置一个附近的更新时间,用来测试是否会忽略刚更新过的 token
GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime = time.Now()

accessTok, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, "expired_token") // 直接获取还是从缓存获取

accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, "expired_token") // 刷新后,由于 lastUpdateTime 太接近,依旧使用缓存
setAccessTokenExpired(ak, sk)

accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) // 应当刷新
elaplsed := time.Since(GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime)
assert.Less(t, elaplsed, 10*time.Second) // 刷新后,lastUpdateTime 应当更新
}

func TestAuthFailed(t *testing.T) {
ak, sk := "bad_ak", "bad_sk"
_, err := GetAuthManager().GetAccessToken(ak, sk)
assert.Error(t, err)
var target *APIError
assert.ErrorAs(t, err, &target)
assert.Contains(t, err.Error(), target.Msg)
assert.Equal(t, target.Msg, "Client authentication failed")
}

func TestAuthWhenUsing(t *testing.T) {
defer resetTestEnv()
_authManager = nil
GetConfig().AccessKey = "access_key_484913"
GetConfig().SecretKey = "secret_key_48135"
GetConfig().AK = ""
GetConfig().SK = ""
// 未设置 AKSK,所以用 IAM 鉴权
chat := NewChatCompletion()
resp, err := chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")},
},
)
assert.NoError(t, err)
signedKey, ok := resp.RawResponse.Request.Header["Authorization"]
assert.True(t, ok)
assert.Contains(t, signedKey[0], GetConfig().AccessKey)
assert.NotContains(t, resp.RawResponse.Request.URL.RawQuery, "access_token")
// 设置了 AKSK,所以用 AKSK 鉴权
GetConfig().AK = "ak_48915684"
GetConfig().SK = "sk_78941813"
resp, err = chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")},
},
)
assert.NoError(t, err)
_, ok = resp.RawResponse.Request.Header["Authorization"]
assert.False(t, ok)
assert.Contains(t, resp.RawResponse.Request.URL.RawQuery, "access_token")
assert.Equal(
t,
resp.RawResponse.Request.URL.Query().Get("access_token"),
fakeAccessToken(GetConfig().AK, GetConfig().SK),
)
// 如果只设置了部分鉴权信息,则报错
GetConfig().AK = ""
GetConfig().AccessKey = ""
_, err = chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")},
},
)
assert.Error(t, err)
var target *CredentialNotFoundError
assert.ErrorAs(t, err, &target)
}

func TestAccessTokenExpired(t *testing.T) {
defer resetTestEnv()
_authManager = nil
ak, sk := "ak_48915684", "sk_78941813"
GetConfig().AK = ak
GetConfig().SK = sk
setAccessTokenExpired(ak, sk)
token, err := GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Contains(t, token, "expired")
prompt := "你好"
chat := NewChatCompletion()
resp, err := chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)},
},
)
assert.NoError(t, err)
assert.Contains(t, resp.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk))
assert.Contains(t, resp.Result, prompt)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, token, fakeAccessToken(ak, sk))

// 测试流式请求的刷新 token
setAccessTokenExpired(ak, sk)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Contains(t, token, "expired")
stream, err := chat.Stream(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)},
},
)
assert.NoError(t, err)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, token, fakeAccessToken(ak, sk))
for {
r, err := stream.Recv()
assert.NoError(t, err)
assert.Contains(t, r.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk))
if r.IsEnd {
break
}
}

}
3 changes: 1 addition & 2 deletions go/qianfan/base_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package qianfan

import (
"fmt"
"strconv"
)

Expand Down Expand Up @@ -104,7 +103,7 @@ func (s *ModelResponseStream) Recv() (*ModelResponse, error) {
func checkResponseError(resp ModelAPIResponse) error {
errCode, errMsg := resp.GetError()
if errCode != 0 {
return fmt.Errorf("API return error. code: %d, msg: %s", errCode, errMsg)
return &APIError{Code: errCode, Msg: errMsg}
}
return nil
}
5 changes: 2 additions & 3 deletions go/qianfan/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package qianfan

import (
"context"
"fmt"
)

// 表示对话内容的结构体
Expand Down Expand Up @@ -141,10 +140,10 @@ func newChatCompletion(options *Options) *ChatCompletion {
// 将 endpoint 转换成完整的 url
func (c *ChatCompletion) realEndpoint() (string, error) {
url := modelAPIPrefix
if c.Model != "" {
if c.Endpoint == "" {
endpoint, ok := ChatModelEndpoint[c.Model]
if !ok {
return "", fmt.Errorf("model %s is not supported", c.Model)
return "", &ModelNotSupportedError{Model: c.Model}
}
url += endpoint
} else {
Expand Down
Loading

0 comments on commit d713367

Please sign in to comment.