Skip to content

Commit

Permalink
go support v2 api
Browse files Browse the repository at this point in the history
  • Loading branch information
ZingLix committed Nov 7, 2024
1 parent adb03fd commit 3da3dc6
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 3 deletions.
29 changes: 29 additions & 0 deletions go/qianfan/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,32 @@ func (m *AuthManager) GetAccessTokenWithRefresh(ctx context.Context, ak, sk stri
GetConfig().AccessToken = resp.AccessToken
return resp.AccessToken, nil
}

// IAMBearerTokenResponse is the response body that GetBearerToken decodes
// when it requests a bearer token from the IAM endpoint.
type IAMBearerTokenResponse struct {
	// UserID is read from the "userId" JSON key.
	UserID string `json:"userId"`
	// Token is read from the "token" JSON key. GetBearerToken stores this
	// value in the global config and returns it to its caller.
	Token string `json:"token"`
	// Status is read from the "status" JSON key.
	Status string `json:"status"`
	// CreateTime is read from the "createTime" JSON key.
	// NOTE(review): the timestamp format is not visible here — confirm against the IAM API.
	CreateTime string `json:"createTime"`
	// ExpireTime is read from the "expireTime" JSON key.
	// NOTE(review): the timestamp format is not visible here — confirm against the IAM API.
	ExpireTime string `json:"expireTime"`
	// baseResponse is embedded; it is declared elsewhere in the package.
	baseResponse
}

// GetErrorCode returns the fixed error text used for IAM bearer-token
// requests. The result does not depend on the contents of the receiver.
func (r *IAMBearerTokenResponse) GetErrorCode() string {
	const iamBearerTokenErr = "Get IAM Bearer Token Error"
	return iamBearerTokenErr
}

func GetBearerToken() (string, error) {
resp := IAMBearerTokenResponse{}
req, err := NewIAMBearerTokenRequest("GET", "/v1/BCE-BEARER/token", nil)
if err != nil {
return "", err
}

err = newRequestor(makeOptions()).request(context.TODO(), req, &resp)
if err != nil {
return "", err
}
logger.Info("Get IAM Bearer Token Success")
GetConfig().BearerToken = resp.Token
return resp.Token, nil
}
189 changes: 189 additions & 0 deletions go/qianfan/chat_completion_v2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package qianfan

import (
"context"
)

// ChatCompletionV2 is the client struct for chat v2 type models.
type ChatCompletionV2 struct {
	// Model is the model ID.
	Model string `mapstructure:"model"`
	// Requestor is embedded and serves as the base type.
	*Requestor
}

// ChatCompletionV2Request is the request body for chat models.
type ChatCompletionV2Request struct {
	// BaseRequestBody is embedded and excluded from mapstructure encoding ("-").
	BaseRequestBody `mapstructure:"-"`
	// Model is the model ID.
	Model string `mapstructure:"model"`
	// Messages is the chat context (conversation history).
	Messages []ChatCompletionMessage `mapstructure:"messages"`
	// StreamOption holds the streaming options.
	StreamOption *StreamOption `mapstructure:"stream_option,omitempty"`
	// Temperature: higher values make the output more random, lower values
	// make it more focused and deterministic. Range (0, 1.0]; must not be 0.
	Temperature float64 `mapstructure:"temperature,omitempty"`
	// TopP affects the diversity of the output text; the larger the value, the
	// more diverse the generated text. Range [0, 1.0].
	TopP float64 `mapstructure:"top_p,omitempty"`
	// PenaltyScore penalizes already generated tokens to reduce repetition;
	// a larger value means a stronger penalty. Range [1.0, 2.0].
	PenaltyScore float64 `mapstructure:"penalty_score,omitempty"`
	// MaxCompletionTokens is the maximum number of tokens the model may output.
	MaxCompletionTokens int `mapstructure:"max_completion_tokens,omitempty"`
	// ResponseFormat specifies the format of the response content.
	ResponseFormat string `mapstructure:"response_format,omitempty"`
	// Seed is the random seed.
	Seed int `mapstructure:"seed,omitempty"`
	// Stop lists stop markers: generation stops when the generated text ends
	// with one of these elements.
	Stop []string `mapstructure:"stop,omitempty"`
	// User is the unique identifier of the end user.
	User string `mapstructure:"user,omitempty"`
	// FrequencyPenalty is the frequency penalty, used to control how repetitive
	// the generated text is.
	// NOTE(review): the original comment's range is truncated after "[0.0," —
	// confirm the upper bound against the API documentation.
	FrequencyPenalty float64 `mapstructure:"frequency_penalty,omitempty"`
	// PresencePenalty is the presence penalty, used to control how repetitive
	// the generated text is.
	// NOTE(review): the original comment's range is truncated after "[0.0".
	// The original comment also contained what looks like a swallowed field
	// declaration (type int, tag `mapstructure:"num_samples,omitempty"`,
	// described as "number of samples, range [1, 20]"). That field is NOT
	// declared in this struct as written — confirm whether it should be restored.
	PresencePenalty float64 `mapstructure:"presence_penalty,omitempty"`
}

// ChatCompletionV2Response is the response decoded from a chat v2 request.
type ChatCompletionV2Response struct {
	// baseResponse is embedded; it is declared elsewhere in the package.
	baseResponse
	// ID is the request ID.
	ID string `mapstructure:"id"`
	// Object is the object type.
	Object string `mapstructure:"object"`
	// Created is the creation time.
	// NOTE(review): unit/epoch of this int64 is not visible here — confirm.
	Created int64 `mapstructure:"created"`
	// Model is the model ID.
	Model string `mapstructure:"model"`
	// Choices holds the generated results.
	Choices []ChatCompletionV2Choice `mapstructure:"choices"`
	// Usage holds the usage information of the request.
	Usage *ModelUsage `mapstructure:"usage"`
	// Error holds the error information; it is a pointer and may be nil.
	Error *ChatCompletionV2Error `mapstructure:"error"`
}
// ChatCompletionV2Choice is one entry of ChatCompletionV2Response.Choices.
type ChatCompletionV2Choice struct {
	// Index is the index of this generated result.
	Index int `mapstructure:"index"`
	// Message is the generated result.
	Message ChatCompletionMessage `mapstructure:"message"`
	// Delta is the generated result in delta form.
	// NOTE(review): presumably populated for streaming chunks — confirm.
	Delta ChatCompletionV2Delta `mapstructure:"delta"`
	// FinishReason is decoded from the "finish_reason" key.
	// NOTE(review): the original comment described this as the "score" of the
	// result, which looks like a copy-paste error — confirm.
	FinishReason string `mapstructure:"finish_reason"`
	// Flag is the flag of the generated result.
	Flag int `mapstructure:"flag"`
	// BanRound is decoded from the "ban_round" key.
	// NOTE(review): the original comment only says "generated result" — confirm the meaning.
	BanRound int `mapstructure:"ban_round"`
}

// ChatCompletionV2Delta is the type of ChatCompletionV2Choice.Delta.
type ChatCompletionV2Delta struct {
	// Content is the generated result text, decoded from the "content" key.
	Content string `mapstructure:"content"`
}

// StreamOption is the type of ChatCompletionV2Request.StreamOption.
type StreamOption struct {
	// IncludeUsage controls whether the streaming response outputs usage.
	IncludeUsage bool `mapstructure:"include_usage"`
}

// ChatCompletionV2Error is the error object carried in
// ChatCompletionV2Response.Error.
type ChatCompletionV2Error struct {
	// Code is decoded from the "code" key.
	Code string `mapstructure:"code"`
	// Msg is decoded from the "msg" key; it is the value returned by
	// ChatCompletionV2Response.GetErrorCode.
	Msg string `mapstructure:"msg"`
	// Type is decoded from the "type" key.
	Type string `mapstructure:"type"`
}

// GetErrorCode returns the message (Msg) of the error attached to the
// response, or the empty string when no error is attached.
//
// Error is a pointer field, so it must be checked before it is dereferenced:
// without the guard this method panics on a response that carries no error
// object. A nil receiver is treated the same way.
func (c *ChatCompletionV2Response) GetErrorCode() string {
	if c == nil || c.Error == nil {
		return ""
	}
	return c.Error.Msg
}

// ChatCompletionV2ResponseStream is the stream type returned by
// ChatCompletionV2.Stream. It embeds *streamInternal, which is declared
// elsewhere in the package.
type ChatCompletionV2ResponseStream struct {
	*streamInternal
}

// newChatCompletionV2 is the internal constructor: it builds a
// ChatCompletionV2 whose embedded Requestor is created from options.
// Model is left at its zero value.
func newChatCompletionV2(options *Options) *ChatCompletionV2 {
	return &ChatCompletionV2{Requestor: newRequestor(options)}
}

// Do sends a chat request.
//
// The call to c.do is executed through runWithContext. If runWithContext
// reports an error, that error is returned together with a nil response;
// otherwise the response and error produced by c.do are returned as-is.
func (c *ChatCompletionV2) Do(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2Response, error) {
	var (
		result *ChatCompletionV2Response
		doErr  error
	)
	if ctxErr := runWithContext(ctx, func() {
		result, doErr = c.do(ctx, request)
	}); ctxErr != nil {
		return nil, ctxErr
	}
	return result, doErr
}

// do performs the non-streaming chat call: it builds a bearer-token POST
// request for "/v2/chat/completions" from request, sends it through the
// embedded Requestor and returns the decoded response.
//
// If building or sending the request fails, a nil response and the error are
// returned.
func (c *ChatCompletionV2) do(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2Response, error) {
	const endpoint = "/v2/chat/completions"

	qfReq, err := NewBearerTokenRequest("POST", endpoint, request)
	if err != nil {
		return nil, err
	}

	result := new(ChatCompletionV2Response)
	if err := c.Requestor.request(ctx, qfReq, result); err != nil {
		return nil, err
	}
	return result, nil
}

// Stream sends a streaming chat request.
//
// The call to c.stream is executed through runWithContext. If runWithContext
// reports an error, that error is returned together with a nil stream;
// otherwise the stream and error produced by c.stream are returned as-is.
func (c *ChatCompletionV2) Stream(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2ResponseStream, error) {
	var (
		out       *ChatCompletionV2ResponseStream
		streamErr error
	)
	if ctxErr := runWithContext(ctx, func() {
		out, streamErr = c.stream(ctx, request)
	}); ctxErr != nil {
		return nil, ctxErr
	}
	return out, streamErr
}
// newChatCompletionV2ResponseStream wraps si in a
// ChatCompletionV2ResponseStream. The returned error is always nil.
func newChatCompletionV2ResponseStream(si *streamInternal) (*ChatCompletionV2ResponseStream, error) {
	return &ChatCompletionV2ResponseStream{streamInternal: si}, nil
}
// stream performs the streaming chat call: it marks request as streaming via
// SetStream, builds a bearer-token POST request for "/v2/chat/completions",
// opens the stream through the embedded Requestor and wraps it in a
// ChatCompletionV2ResponseStream.
//
// If building the request or opening the stream fails, a nil stream and the
// error are returned.
func (c *ChatCompletionV2) stream(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2ResponseStream, error) {
	const endpoint = "/v2/chat/completions"

	// SetStream is called before the request is built, as in the original.
	request.SetStream()

	// Build the same request type as the non-streaming do(): only
	// bearerTokenRequest gets the "Authorization: Bearer ..." header in
	// addAuthInfo, so a console request here would authenticate streaming
	// calls differently from non-streaming ones.
	req, err := NewBearerTokenRequest("POST", endpoint, request)
	if err != nil {
		return nil, err
	}

	si, err := c.Requestor.requestStream(ctx, req)
	if err != nil {
		return nil, err
	}
	return newChatCompletionV2ResponseStream(si)
}

// NewChatCompletionV2 creates a ChatCompletionV2 object.
//
//	chat := qianfan.NewChatCompletionV2()
//
// Any options passed in are combined by makeOptions and handed to the
// internal constructor, for example:
//
//	chat := qianfan.NewChatCompletionV2(
//		qianfan.WithModel("ERNIE-4.0-8K"),
//	)
//
// NOTE(review): the original comment was copied from NewChatCompletion and
// mentions WithModel, WithEndpoint and chat.ModelList(); whether those have
// any effect on ChatCompletionV2 is not visible here — confirm.
func NewChatCompletionV2(optionList ...Option) *ChatCompletionV2 {
	return newChatCompletionV2(makeOptions(optionList...))
}
4 changes: 4 additions & 0 deletions go/qianfan/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ var defaultConfig = map[string]string{
"QIANFAN_SK": "",
"QIANFAN_ACCESS_KEY": "",
"QIANFAN_SECRET_KEY": "",
"QIANFAN_BEARER_TOKEN": "",
"QIANFAN_BASE_URL": "https://aip.baidubce.com",
"QIANFAN_IAM_SIGN_EXPIRATION_SEC": "300",
"QIANFAN_CONSOLE_BASE_URL": "https://qianfan.baidubce.com",
"QIANFAN_IAM_BASE_URL": "http://iam.bj.baidubce.com",
"QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL": "3600",
"QIANFAN_LLM_API_RETRY_COUNT": "1",
"QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR": "0",
Expand All @@ -43,9 +45,11 @@ type Config struct {
AccessKey string `mapstructure:"QIANFAN_ACCESS_KEY"`
SecretKey string `mapstructure:"QIANFAN_SECRET_KEY"`
AccessToken string `mapstructure:"QIANFAN_ACCESS_TOKEN"`
BearerToken string `mapstructure:"QIANFAN_BEARER_TOKEN"`
BaseURL string `mapstructure:"QIANFAN_BASE_URL"`
IAMSignExpirationSeconds int `mapstructure:"QIANFAN_IAM_SIGN_EXPIRATION_SEC"`
ConsoleBaseURL string `mapstructure:"QIANFAN_CONSOLE_BASE_URL"`
IAMBaseURL string `mapstructure:"QIANFAN_IAM_BASE_URL"`
AccessTokenRefreshMinInterval int `mapstructure:"QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL"`
LLMRetryCount int `mapstructure:"QIANFAN_LLM_API_RETRY_COUNT"`
LLMRetryTimeout float32 `mapstructure:"QIANFAN_LLM_API_RETRY_TIMEOUT"`
Expand Down
30 changes: 27 additions & 3 deletions go/qianfan/requestor.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ func convertToMap(body RequestBody) (map[string]interface{}, error) {
// Request types, used to tell model requests apart from management (console)
// requests and from the authentication flows.
// They are used as the value of QfRequest.Type.
//
// Each constant is declared exactly once: repeating authRequest, modelRequest
// and consoleRequest inside the same block is a redeclaration error.
const (
	authRequest        = "auth"    // AccessToken authentication request
	modelRequest       = "model"   // model request
	consoleRequest     = "console" // management (console) request
	bearerTokenRequest = "bearer"  // request authenticated with a Bearer token
	iamRequest         = "iam"     // request sent to the IAM base URL
)

// SDK 内部表示请求的类
Expand Down Expand Up @@ -111,6 +113,16 @@ func NewConsoleRequest(method string, url string, body RequestBody) (*QfRequest,
return newRequest(consoleRequest, method, url, body)
}

// NewBearerTokenRequest creates a Request that is authenticated with a Bearer
// token. It forwards method, url and body to newRequest with the request type
// set to bearerTokenRequest and returns newRequest's result unchanged.
func NewBearerTokenRequest(method string, url string, body RequestBody) (*QfRequest, error) {
	return newRequest(bearerTokenRequest, method, url, body)
}

// NewIAMBearerTokenRequest creates a Request of type iamRequest. It forwards
// method, url and body to newRequest and returns newRequest's result
// unchanged. GetBearerToken uses it to request a bearer token.
//
// NOTE(review): the original comment was identical to the one on
// NewBearerTokenRequest ("creates a Request authenticated with a Bearer
// token"), which does not match the iamRequest type used here.
func NewIAMBearerTokenRequest(method string, url string, body RequestBody) (*QfRequest, error) {
	return newRequest(iamRequest, method, url, body)
}

// 创建一个 Request,body 可以是任意实现了 RequestBody 接口的类型
func newRequest(requestType string, method string, url string, body RequestBody) (*QfRequest, error) {
var b map[string]interface{} = nil
Expand Down Expand Up @@ -195,6 +207,13 @@ func (r *Requestor) addAuthInfo(ctx context.Context, request *QfRequest) error {
if request.Type == authRequest {
return nil
}
// Bearer-token requests carry the token in the Authorization header.
if request.Type == bearerTokenRequest {
	token := GetConfig().BearerToken
	if token == "" {
		// No token configured yet: fetch one now. GetBearerToken also stores
		// the token in the global config. Its error must not be dropped,
		// otherwise the request would be sent with an empty "Bearer " header.
		fetched, err := GetBearerToken()
		if err != nil {
			return err
		}
		token = fetched
	}
	request.Headers["Authorization"] = fmt.Sprintf("Bearer %s", token)
	return nil
}
if GetConfig().AK != "" && GetConfig().SK != "" {
return r.addAccessToken(ctx, request)
} else if GetConfig().AccessKey != "" && GetConfig().SecretKey != "" {
Expand Down Expand Up @@ -288,6 +307,11 @@ func (r *Requestor) prepareRequest(ctx context.Context, request QfRequest) (*htt
request.Headers["request-source"] = versionIndicator
} else if request.Type == authRequest {
request.URL = GetConfig().BaseURL + request.URL
} else if request.Type == bearerTokenRequest {
request.URL = GetConfig().ConsoleBaseURL + request.URL
request.Headers["request-source"] = versionIndicator
} else if request.Type == iamRequest {
request.URL = GetConfig().IAMBaseURL + request.URL
} else {
return nil, &InternalError{"unexpected request type: " + request.Type}
}
Expand Down

0 comments on commit 3da3dc6

Please sign in to comment.