-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
436 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
package qianfan | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func fakeAccessToken(ak, sk string) string { | ||
return fmt.Sprintf("%s.%s", ak, sk) | ||
} | ||
|
||
func resetAuthManager() { | ||
_authManager = nil | ||
} | ||
|
||
func setAccessTokenExpired(ak, sk string) { | ||
GetAuthManager().tokenMap[credential{ak, sk}] = &accessToken{ | ||
token: "expired_token", | ||
lastUpateTime: time.Now().Add(-100 * time.Hour), // 100s 过期 | ||
} | ||
} | ||
|
||
func TestAuth(t *testing.T) { | ||
resetAuthManager() | ||
ak, sk := "ak_33", "sk_4235" | ||
// 第一次获取前,缓存里应当没有 | ||
_, ok := GetAuthManager().tokenMap[credential{ak, sk}] | ||
assert.False(t, ok) | ||
|
||
accessTok, err := GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) | ||
updateTime := GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime | ||
// 再测试一次,应当从缓存里获取,更新时间不变 | ||
accessTok, err = GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) | ||
assert.Equal( | ||
t, | ||
updateTime, | ||
GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime, | ||
) | ||
// 模拟过期 | ||
ak, sk = "ak_95411", "sk_87135" | ||
setAccessTokenExpired(ak, sk) | ||
// 设置一个附近的更新时间,用来测试是否会忽略刚更新过的 token | ||
GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime = time.Now() | ||
|
||
accessTok, err = GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, accessTok, "expired_token") // 直接获取还是从缓存获取 | ||
|
||
accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, accessTok, "expired_token") // 刷新后,由于 lastUpdateTime 太接近,依旧使用缓存 | ||
setAccessTokenExpired(ak, sk) | ||
|
||
accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) // 应当刷新 | ||
elaplsed := time.Since(GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime) | ||
assert.Less(t, elaplsed, 10*time.Second) // 刷新后,lastUpdateTime 应当更新 | ||
} | ||
|
||
func TestAuthFailed(t *testing.T) { | ||
ak, sk := "bad_ak", "bad_sk" | ||
_, err := GetAuthManager().GetAccessToken(ak, sk) | ||
assert.Error(t, err) | ||
var target *APIError | ||
assert.ErrorAs(t, err, &target) | ||
assert.Contains(t, err.Error(), target.Msg) | ||
assert.Equal(t, target.Msg, "Client authentication failed") | ||
} | ||
|
||
func TestAuthWhenUsing(t *testing.T) { | ||
defer resetTestEnv() | ||
_authManager = nil | ||
GetConfig().AccessKey = "access_key_484913" | ||
GetConfig().SecretKey = "secret_key_48135" | ||
GetConfig().AK = "" | ||
GetConfig().SK = "" | ||
// 未设置 AKSK,所以用 IAM 鉴权 | ||
chat := NewChatCompletion() | ||
resp, err := chat.Do( | ||
context.Background(), | ||
&ChatCompletionRequest{ | ||
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")}, | ||
}, | ||
) | ||
assert.NoError(t, err) | ||
signedKey, ok := resp.RawResponse.Request.Header["Authorization"] | ||
assert.True(t, ok) | ||
assert.Contains(t, signedKey[0], GetConfig().AccessKey) | ||
assert.NotContains(t, resp.RawResponse.Request.URL.RawQuery, "access_token") | ||
// 设置了 AKSK,所以用 AKSK 鉴权 | ||
GetConfig().AK = "ak_48915684" | ||
GetConfig().SK = "sk_78941813" | ||
resp, err = chat.Do( | ||
context.Background(), | ||
&ChatCompletionRequest{ | ||
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")}, | ||
}, | ||
) | ||
assert.NoError(t, err) | ||
_, ok = resp.RawResponse.Request.Header["Authorization"] | ||
assert.False(t, ok) | ||
assert.Contains(t, resp.RawResponse.Request.URL.RawQuery, "access_token") | ||
assert.Equal( | ||
t, | ||
resp.RawResponse.Request.URL.Query().Get("access_token"), | ||
fakeAccessToken(GetConfig().AK, GetConfig().SK), | ||
) | ||
// 如果只设置了部分鉴权信息,则报错 | ||
GetConfig().AK = "" | ||
GetConfig().AccessKey = "" | ||
_, err = chat.Do( | ||
context.Background(), | ||
&ChatCompletionRequest{ | ||
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")}, | ||
}, | ||
) | ||
assert.Error(t, err) | ||
var target *CredentialNotFoundError | ||
assert.ErrorAs(t, err, &target) | ||
} | ||
|
||
func TestAccessTokenExpired(t *testing.T) { | ||
defer resetTestEnv() | ||
_authManager = nil | ||
ak, sk := "ak_48915684", "sk_78941813" | ||
GetConfig().AK = ak | ||
GetConfig().SK = sk | ||
setAccessTokenExpired(ak, sk) | ||
token, err := GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Contains(t, token, "expired") | ||
prompt := "你好" | ||
chat := NewChatCompletion() | ||
resp, err := chat.Do( | ||
context.Background(), | ||
&ChatCompletionRequest{ | ||
Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)}, | ||
}, | ||
) | ||
assert.NoError(t, err) | ||
assert.Contains(t, resp.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk)) | ||
assert.Contains(t, resp.Result, prompt) | ||
token, err = GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, token, fakeAccessToken(ak, sk)) | ||
|
||
// 测试流式请求的刷新 token | ||
setAccessTokenExpired(ak, sk) | ||
token, err = GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Contains(t, token, "expired") | ||
stream, err := chat.Stream( | ||
context.Background(), | ||
&ChatCompletionRequest{ | ||
Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)}, | ||
}, | ||
) | ||
assert.NoError(t, err) | ||
token, err = GetAuthManager().GetAccessToken(ak, sk) | ||
assert.NoError(t, err) | ||
assert.Equal(t, token, fakeAccessToken(ak, sk)) | ||
for { | ||
r, err := stream.Recv() | ||
assert.NoError(t, err) | ||
assert.Contains(t, r.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk)) | ||
if r.IsEnd { | ||
break | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.