Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: go 支持 v2 api #858

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions go/qianfan/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,32 @@ func (m *AuthManager) GetAccessTokenWithRefresh(ctx context.Context, ak, sk stri
GetConfig().AccessToken = resp.AccessToken
return resp.AccessToken, nil
}

// IAMBearerTokenResponse is the JSON payload decoded from the IAM bearer
// token endpoint queried by GetBearerToken.
type IAMBearerTokenResponse struct {
	UserID     string `json:"userId"`     // decoded from the "userId" key
	Token      string `json:"token"`      // the bearer token; GetBearerToken stores it in the global config
	Status     string `json:"status"`     // decoded from the "status" key; possible values are not visible here
	CreateTime string `json:"createTime"` // presumably the token creation time, kept as the raw string — TODO confirm format
	ExpireTime string `json:"expireTime"` // presumably the token expiry time, kept as the raw string — TODO confirm format
	baseResponse
}

// GetErrorCode returns a fixed label for the IAM bearer token request.
// The value does not depend on the contents of the response.
func (r *IAMBearerTokenResponse) GetErrorCode() string {
	const label = "Get IAM Bearer Token Error"
	return label
}

func GetBearerToken() (string, error) {
resp := IAMBearerTokenResponse{}
req, err := NewIAMBearerTokenRequest("GET", "/v1/BCE-BEARER/token", nil)
if err != nil {
return "", err
}

err = newRequestor(makeOptions()).request(context.TODO(), req, &resp)
if err != nil {
return "", err
}
logger.Info("Get IAM Bearer Token Success")
GetConfig().BearerToken = resp.Token
return resp.Token, nil
}
189 changes: 189 additions & 0 deletions go/qianfan/chat_completion_v2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package qianfan

import (
"context"
)

// ChatCompletionV2 is the client type used for chat v2 models.
type ChatCompletionV2 struct {
	Model      string `mapstructure:"model"` // model ID
	*Requestor        // embedded Requestor, acting as the base type
}

// ChatCompletionV2Request is the request body sent to the chat v2 endpoint.
type ChatCompletionV2Request struct {
	BaseRequestBody `mapstructure:"-"`

	// Model is the model ID.
	Model string `mapstructure:"model"`
	// Messages holds the chat context.
	Messages []ChatCompletionMessage `mapstructure:"messages"`
	// StreamOption holds the streaming options.
	// NOTE(review): the wire key here is "stream_option"; confirm against the
	// API documentation whether the service expects "stream_options".
	StreamOption *StreamOption `mapstructure:"stream_option,omitempty"`
	// Temperature: higher values make the output more random, lower values
	// make it more focused and deterministic. Range (0, 1.0]; must not be 0.
	Temperature float64 `mapstructure:"temperature,omitempty"`
	// TopP affects the diversity of the output; larger values produce more
	// diverse text. Range [0, 1.0].
	TopP float64 `mapstructure:"top_p,omitempty"`
	// PenaltyScore penalizes already generated tokens to reduce repetition;
	// larger values mean a stronger penalty. Range [1.0, 2.0].
	PenaltyScore float64 `mapstructure:"penalty_score,omitempty"`
	// MaxCompletionTokens is the maximum number of tokens the model may output.
	MaxCompletionTokens int `mapstructure:"max_completion_tokens,omitempty"`
	// ResponseFormat specifies the format of the response content.
	ResponseFormat string `mapstructure:"response_format,omitempty"`
	// Seed is the random seed.
	Seed int `mapstructure:"seed,omitempty"`
	// Stop lists stop markers: generation stops when the generated text ends
	// with any element of Stop.
	Stop []string `mapstructure:"stop,omitempty"`
	// User is a unique identifier for the end user.
	User string `mapstructure:"user,omitempty"`
	// FrequencyPenalty is the frequency penalty used to control repetition in
	// the generated text.
	// NOTE(review): the value range in the original comment was truncated
	// ("[0.0, ..."); confirm the bounds against the API documentation.
	FrequencyPenalty float64 `mapstructure:"frequency_penalty,omitempty"`
	// PresencePenalty is the presence penalty used to control repetition in
	// the generated text.
	// NOTE(review): the value range in the original comment was truncated
	// ("[0.0 ..."); confirm the bounds against the API documentation.
	PresencePenalty float64 `mapstructure:"presence_penalty,omitempty"`
	// NumSamples is the number of samples to draw. Range [1, 20].
	// NOTE(review): this field's declaration had been swallowed into the
	// preceding comment; the Go field name is reconstructed from the
	// "num_samples" tag — confirm against the upstream source.
	NumSamples int `mapstructure:"num_samples,omitempty"`
}

// ChatCompletionV2Response is the decoded response of a chat v2 request.
type ChatCompletionV2Response struct {
	baseResponse
	ID      string                   `mapstructure:"id"`      // request ID
	Object  string                   `mapstructure:"object"`  // object type
	Created int64                    `mapstructure:"created"` // creation time (presumably a Unix timestamp — TODO confirm unit)
	Model   string                   `mapstructure:"model"`   // model ID
	Choices []ChatCompletionV2Choice `mapstructure:"choices"` // generated results
	Usage   *ModelUsage              `mapstructure:"usage"`   // usage information for the request
	Error   *ChatCompletionV2Error   `mapstructure:"error"`   // error information; a pointer, so it may be nil
}
// ChatCompletionV2Choice is one entry of ChatCompletionV2Response.Choices.
type ChatCompletionV2Choice struct {
	Index        int                   `mapstructure:"index"`         // index of this generated result
	Message      ChatCompletionMessage `mapstructure:"message"`       // generated message
	Delta        ChatCompletionV2Delta `mapstructure:"delta"`         // generated content delta (presumably filled for streaming responses — TODO confirm)
	FinishReason string                `mapstructure:"finish_reason"` // reason generation finished (decoded from "finish_reason")
	Flag         int                   `mapstructure:"flag"`          // flag attached to the generated result; semantics not visible here
	BanRound     int                   `mapstructure:"ban_round"`     // decoded from "ban_round"; NOTE(review): meaning not documented in this file
}

// ChatCompletionV2Delta carries the content of ChatCompletionV2Choice.Delta.
type ChatCompletionV2Delta struct {
	Content string `mapstructure:"content"` // generated content
}

// StreamOption holds the streaming options of a ChatCompletionV2Request.
type StreamOption struct {
	IncludeUsage bool `mapstructure:"include_usage"` // whether the streaming response should output usage
}

// ChatCompletionV2Error is the error object carried by
// ChatCompletionV2Response.Error.
type ChatCompletionV2Error struct {
	Code string `mapstructure:"code"` // decoded from the "code" key
	Msg  string `mapstructure:"msg"`  // decoded from the "msg" key; returned by ChatCompletionV2Response.GetErrorCode
	Type string `mapstructure:"type"` // decoded from the "type" key
}

// GetErrorCode returns the message of the error object attached to the
// response, or the empty string when there is none.
//
// Error is a pointer field and is left nil when the response carries no
// "error" object, so it has to be checked before it is dereferenced;
// without the guard this method panics on a response that has no error.
// Despite the method name, the value returned is Error.Msg (kept as in the
// original implementation).
func (c *ChatCompletionV2Response) GetErrorCode() string {
	if c == nil || c.Error == nil {
		return ""
	}
	return c.Error.Msg
}

// ChatCompletionV2ResponseStream is the handle returned by
// ChatCompletionV2.Stream; it embeds the internal stream state.
type ChatCompletionV2ResponseStream struct {
	*streamInternal
}

// newChatCompletionV2 builds a ChatCompletionV2 whose embedded Requestor is
// created from options. The Model field is left at its zero value.
func newChatCompletionV2(options *Options) *ChatCompletionV2 {
	return &ChatCompletionV2{Requestor: newRequestor(options)}
}

// Do sends a chat request and returns the decoded response.
//
// The call is executed through runWithContext; when that helper reports an
// error, the error is returned with a nil response. Otherwise the result
// and error of the underlying request are returned as-is.
func (c *ChatCompletionV2) Do(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2Response, error) {
	var (
		result  *ChatCompletionV2Response
		callErr error
	)
	task := func() {
		result, callErr = c.do(ctx, request)
	}
	if runErr := runWithContext(ctx, task); runErr != nil {
		return nil, runErr
	}
	return result, callErr
}

// do builds a bearer-token POST request for "/v2/chat/completions", sends it
// through the embedded Requestor and returns the decoded response.
// On any error the returned response is nil.
func (c *ChatCompletionV2) do(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2Response, error) {
	const endpoint = "/v2/chat/completions"

	req, err := NewBearerTokenRequest("POST", endpoint, request)
	if err != nil {
		return nil, err
	}

	result := new(ChatCompletionV2Response)
	if err := c.Requestor.request(ctx, req, result); err != nil {
		return nil, err
	}
	return result, nil
}

// Stream sends a streaming chat request and returns the stream handle.
//
// The call is executed through runWithContext; when that helper reports an
// error, the error is returned with a nil stream. Otherwise the result and
// error of the underlying request are returned as-is.
func (c *ChatCompletionV2) Stream(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2ResponseStream, error) {
	var (
		result  *ChatCompletionV2ResponseStream
		callErr error
	)
	task := func() {
		result, callErr = c.stream(ctx, request)
	}
	if runErr := runWithContext(ctx, task); runErr != nil {
		return nil, runErr
	}
	return result, callErr
}
// newChatCompletionV2ResponseStream wraps si in a
// ChatCompletionV2ResponseStream. The returned error is always nil.
func newChatCompletionV2ResponseStream(si *streamInternal) (*ChatCompletionV2ResponseStream, error) {
	return &ChatCompletionV2ResponseStream{streamInternal: si}, nil
}
// stream marks the request as streaming, builds a bearer-token POST request
// for "/v2/chat/completions" and opens the stream through the embedded
// Requestor. Note that request is modified in place by SetStream.
func (c *ChatCompletionV2) stream(ctx context.Context, request *ChatCompletionV2Request) (*ChatCompletionV2ResponseStream, error) {
	const endpoint = "/v2/chat/completions"

	request.SetStream()

	req, err := NewBearerTokenRequest("POST", endpoint, request)
	if err != nil {
		return nil, err
	}

	si, err := c.Requestor.requestStream(ctx, req)
	if err != nil {
		return nil, err
	}
	return newChatCompletionV2ResponseStream(si)
}

// NewChatCompletionV2 creates a ChatCompletionV2 client.
//
//	chat := qianfan.NewChatCompletionV2()
//
// Any options passed in are combined by makeOptions and used to build the
// client's Requestor, for example:
//
//	chat := qianfan.NewChatCompletionV2(
//		qianfan.WithModel("ERNIE-4.0-8K"),
//	)
//
// or:
//
//	chat := qianfan.NewChatCompletionV2(
//		qianfan.WithEndpoint("your_custom_endpoint"),
//	)
//
// NOTE(review): these examples were carried over from the v1 constructor;
// how WithModel / WithEndpoint affect v2 requests is not visible in this
// file — confirm before relying on them.
func NewChatCompletionV2(optionList ...Option) *ChatCompletionV2 {
	return newChatCompletionV2(makeOptions(optionList...))
}
4 changes: 4 additions & 0 deletions go/qianfan/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ var defaultConfig = map[string]string{
"QIANFAN_SK": "",
"QIANFAN_ACCESS_KEY": "",
"QIANFAN_SECRET_KEY": "",
"QIANFAN_BEARER_TOKEN": "",
"QIANFAN_BASE_URL": "https://aip.baidubce.com",
"QIANFAN_IAM_SIGN_EXPIRATION_SEC": "300",
"QIANFAN_CONSOLE_BASE_URL": "https://qianfan.baidubce.com",
"QIANFAN_IAM_BASE_URL": "http://iam.bj.baidubce.com",
"QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL": "3600",
"QIANFAN_LLM_API_RETRY_COUNT": "1",
"QIANFAN_LLM_API_RETRY_BACKOFF_FACTOR": "0",
Expand All @@ -43,9 +45,11 @@ type Config struct {
AccessKey string `mapstructure:"QIANFAN_ACCESS_KEY"`
SecretKey string `mapstructure:"QIANFAN_SECRET_KEY"`
AccessToken string `mapstructure:"QIANFAN_ACCESS_TOKEN"`
BearerToken string `mapstructure:"QIANFAN_BEARER_TOKEN"`
BaseURL string `mapstructure:"QIANFAN_BASE_URL"`
IAMSignExpirationSeconds int `mapstructure:"QIANFAN_IAM_SIGN_EXPIRATION_SEC"`
ConsoleBaseURL string `mapstructure:"QIANFAN_CONSOLE_BASE_URL"`
IAMBaseURL string `mapstructure:"QIANFAN_IAM_BASE_URL"`
AccessTokenRefreshMinInterval int `mapstructure:"QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL"`
LLMRetryCount int `mapstructure:"QIANFAN_LLM_API_RETRY_COUNT"`
LLMRetryTimeout float32 `mapstructure:"QIANFAN_LLM_API_RETRY_TIMEOUT"`
Expand Down
39 changes: 36 additions & 3 deletions go/qianfan/requestor.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ func convertToMap(body RequestBody) (map[string]interface{}, error) {
// Request types, used to tell model requests apart from management
// (console) requests and from the authentication-related requests.
// Consumed via QfRequest.Type.
//
// Each name is declared exactly once: declaring the same constant twice in
// one package is a compile error.
const (
	authRequest             = "auth" // AccessToken authentication request
	modelRequest            = "model"
	consoleRequest          = "console"
	bearerTokenRequest      = "bearer"           // request authenticated with a bearer token
	fetchBearerTokenRequest = "fetchBearerToken" // request that fetches the bearer token from IAM
)

// SDK 内部表示请求的类
Expand Down Expand Up @@ -111,6 +113,16 @@ func NewConsoleRequest(method string, url string, body RequestBody) (*QfRequest,
return newRequest(consoleRequest, method, url, body)
}

// NewBearerTokenRequest creates a Request of type bearerTokenRequest, i.e.
// one that is authenticated with a bearer token.
func NewBearerTokenRequest(method string, url string, body RequestBody) (*QfRequest, error) {
	const kind = bearerTokenRequest
	return newRequest(kind, method, url, body)
}

// NewIAMBearerTokenRequest creates a Request of type fetchBearerTokenRequest,
// the request type used to obtain a bearer token (see GetBearerToken).
func NewIAMBearerTokenRequest(method string, url string, body RequestBody) (*QfRequest, error) {
	const kind = fetchBearerTokenRequest
	return newRequest(kind, method, url, body)
}

// 创建一个 Request,body 可以是任意实现了 RequestBody 接口的类型
func newRequest(requestType string, method string, url string, body RequestBody) (*QfRequest, error) {
var b map[string]interface{} = nil
Expand Down Expand Up @@ -195,6 +207,16 @@ func (r *Requestor) addAuthInfo(ctx context.Context, request *QfRequest) error {
if request.Type == authRequest {
return nil
}
if request.Type == bearerTokenRequest {
if GetConfig().BearerToken == "" {
_, err := GetBearerToken()
if err != nil {
return err
}
}
request.Headers["Authorization"] = fmt.Sprintf("Bearer %s", GetConfig().BearerToken)
return nil
}
if GetConfig().AK != "" && GetConfig().SK != "" {
return r.addAccessToken(ctx, request)
} else if GetConfig().AccessKey != "" && GetConfig().SecretKey != "" {
Expand Down Expand Up @@ -288,6 +310,11 @@ func (r *Requestor) prepareRequest(ctx context.Context, request QfRequest) (*htt
request.Headers["request-source"] = versionIndicator
} else if request.Type == authRequest {
request.URL = GetConfig().BaseURL + request.URL
} else if request.Type == bearerTokenRequest {
request.URL = GetConfig().ConsoleBaseURL + request.URL
request.Headers["request-source"] = versionIndicator
} else if request.Type == fetchBearerTokenRequest {
request.URL = GetConfig().IAMBaseURL + request.URL
} else {
return nil, &InternalError{"unexpected request type: " + request.Type}
}
Expand Down Expand Up @@ -448,6 +475,12 @@ func (si *streamInternal) recv(resp QfResponse) error {
RawResponse: si.httpResponse,
}

if string(eventData) == "[DONE]" {
si.IsEnd = true
si.Close()
return nil
}

resp.SetResponse(response.Body, response.RawResponse)
err := json.Unmarshal(response.Body, resp)
if err != nil {
Expand Down
Loading