Skip to content

Commit

Permalink
新增 Raw 请求
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiichi-Origami committed Sep 10, 2024
1 parent 0038584 commit 04d10be
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 11 deletions.
68 changes: 68 additions & 0 deletions go/qianfan/base_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ type BaseModel struct {
*Requestor // Requstor 作为基类
}

func NewBaseModel(options ...Option) *BaseModel {
return &BaseModel{Requestor: newRequestor(makeOptions(options...))}
}

// 使用量信息
type ModelUsage struct {
PromptTokens int `json:"prompt_tokens"` // 问题tokens数
Expand All @@ -44,6 +48,26 @@ type ModelAPIResponse interface {
ClearError()
}

type RawRequest map[string]any

func (r RawRequest) SetExtra(extra map[string]any) {
r["extra"] = extra
}

func (r RawRequest) GetExtra() map[string]any {
extra, ok := r["extra"]
if !ok {
return make(map[string]any)
} else {
return extra.(map[string]any)
}
}

type RawResponse struct {
baseResponse `json:",omitempty"`
ModelAPIError `json:",omitempty"`
}

// API 错误信息
type ModelAPIError struct {
ErrorCode int `json:"error_code"` // 错误码
Expand Down Expand Up @@ -179,6 +203,33 @@ func (s *ModelResponseStream) Recv() (*ModelResponse, error) {
return &resp, nil
}

type RawModelResponseStream struct {
*ModelResponseStream
}

func newRawModelResponseStream(si *streamInternal) (*RawModelResponseStream, error) {
s := &RawModelResponseStream{}
mrs, err := newModelResponseStream(si)
if err != nil {
return s, err
}

s.ModelResponseStream = mrs
return s, nil
}

func (s *RawModelResponseStream) Recv() (*RawResponse, error) {
var resp RawResponse
err := s.streamInternal.Recv(&resp)
if err != nil {
return nil, err
}
if err = checkResponseError(&resp); err != nil {
return &resp, err
}
return &resp, nil
}

func checkResponseError(resp ModelAPIResponse) error {
errCode, errMsg := resp.GetError()
if errCode != 0 {
Expand Down Expand Up @@ -272,3 +323,20 @@ func isUnsupportedModelError(err error) bool {
}
return false
}

func (m *BaseModel) Do(ctx context.Context, request *QfRequest) (*RawResponse, error) {
rawResponse := &RawResponse{}
if err := m.requestResource(ctx, request, rawResponse); err != nil {
return nil, err
}
return rawResponse, nil
}

func (m *BaseModel) Stream(ctx context.Context, request *QfRequest) (*RawModelResponseStream, error) {
request.Body["stream"] = true
si, err := m.requestStream(ctx, request)
if err != nil {
return nil, err
}
return newRawModelResponseStream(si)
}
4 changes: 2 additions & 2 deletions go/qianfan/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func (c *ChatCompletion) do(ctx context.Context, request *ChatCompletionRequest)

c.processWithInputLimit(ctx, request, url)

req, err := newModelRequest("POST", url, request)
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -368,7 +368,7 @@ func (c *ChatCompletion) stream(ctx context.Context, request *ChatCompletionRequ
c.processWithInputLimit(ctx, request, url)

request.SetStream()
req, err := newModelRequest("POST", url, request)
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/qianfan/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (c *Completion) do(ctx context.Context, request *CompletionRequest) (*Model
if err != nil {
return nil, err
}
req, err := newModelRequest("POST", url, request)
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -195,7 +195,7 @@ func (c *Completion) stream(ctx context.Context, request *CompletionRequest) (*M
return nil, err
}
request.SetStream()
req, err := newModelRequest("POST", url, request)
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/qianfan/console_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (ca *ConsoleAction) Call(ctx context.Context, route string, action string,
reqBody := BaseRequestBody{
Extra: params,
}
req, err := newConsoleRequest("POST", ca.baseActionUrl(route, action), &reqBody)
req, err := NewConsoleRequest("POST", ca.baseActionUrl(route, action), &reqBody)
if err != nil {
logger.Error("new console req error", err)
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion go/qianfan/embdding.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (c *Embedding) do(ctx context.Context, request *EmbeddingRequest) (*Embeddi
if err != nil {
return nil, err
}
req, err := newModelRequest("POST", url, request)
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/qianfan/requestor.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ type QfRequest struct {
}

// 创建一个用于模型类请求的 Request
func newModelRequest(method string, url string, body RequestBody) (*QfRequest, error) {
func NewModelRequest(method string, url string, body RequestBody) (*QfRequest, error) {
return newRequest(modelRequest, method, url, body)
}

Expand All @@ -107,7 +107,7 @@ func newAuthRequest(method string, url string, body RequestBody) (*QfRequest, er
}

// 创建一个用于管控类请求的 Request
func newConsoleRequest(method string, url string, body RequestBody) (*QfRequest, error) {
func NewConsoleRequest(method string, url string, body RequestBody) (*QfRequest, error) {
return newRequest(consoleRequest, method, url, body)
}

Expand Down
2 changes: 1 addition & 1 deletion go/qianfan/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type ServiceListItemVersion struct {

func (service *Service) List(ctx context.Context, request *ServiceListRequest) (*ServiceListResponse, error) {
var s ServiceListResponse
req, err := newConsoleRequest("POST", serviceListURL, request)
req, err := NewConsoleRequest("POST", serviceListURL, request)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/qianfan/text2img.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (c *Text2Image) do(ctx context.Context, request *Text2ImageRequest) (*Text2
if err != nil {
return nil, err
}
req, err := newModelRequest("POST", url, request)
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/qianfan/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (t *Tokenizer) remoteCountTokensEB(text string, model string) (int, error)
Model: model,
}

req, err := newModelRequest("POST", modelAPIPrefix+"/tokenizer/erniebot", request)
req, err := NewModelRequest("POST", modelAPIPrefix+"/tokenizer/erniebot", request)
if err != nil {
return -1, err
}
Expand Down

0 comments on commit 04d10be

Please sign in to comment.