From 04d10be7a9140341840b9e8e68e468bc8fda85a0 Mon Sep 17 00:00:00 2001 From: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:44:47 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20Raw=20=E8=AF=B7=E6=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go/qianfan/base_model.go | 68 +++++++++++++++++++++++++++++++++++ go/qianfan/chat_completion.go | 4 +-- go/qianfan/completion.go | 4 +-- go/qianfan/console_action.go | 2 +- go/qianfan/embdding.go | 2 +- go/qianfan/requestor.go | 4 +-- go/qianfan/service.go | 2 +- go/qianfan/text2img.go | 2 +- go/qianfan/tokenizer.go | 2 +- 9 files changed, 79 insertions(+), 11 deletions(-) diff --git a/go/qianfan/base_model.go b/go/qianfan/base_model.go index 958ca1c2..3369342d 100644 --- a/go/qianfan/base_model.go +++ b/go/qianfan/base_model.go @@ -32,6 +32,10 @@ type BaseModel struct { *Requestor // Requstor 作为基类 } +func NewBaseModel(options ...Option) *BaseModel { + return &BaseModel{Requestor: newRequestor(makeOptions(options...))} +} + // 使用量信息 type ModelUsage struct { PromptTokens int `json:"prompt_tokens"` // 问题tokens数 @@ -44,6 +48,26 @@ type ModelAPIResponse interface { ClearError() } +type RawRequest map[string]any + +func (r RawRequest) SetExtra(extra map[string]any) { + r["extra"] = extra +} + +func (r RawRequest) GetExtra() map[string]any { + extra, ok := r["extra"] + if !ok { + return make(map[string]any) + } else { + return extra.(map[string]any) + } +} + +type RawResponse struct { + baseResponse `json:",omitempty"` + ModelAPIError `json:",omitempty"` +} + // API 错误信息 type ModelAPIError struct { ErrorCode int `json:"error_code"` // 错误码 @@ -179,6 +203,33 @@ func (s *ModelResponseStream) Recv() (*ModelResponse, error) { return &resp, nil } +type RawModelResponseStream struct { + *ModelResponseStream +} + +func newRawModelResponseStream(si *streamInternal) (*RawModelResponseStream, error) { + s := &RawModelResponseStream{} + mrs, err := newModelResponseStream(si) + if err != nil { + return s, err + } + + s.ModelResponseStream = mrs + return s, nil +} + +func (s *RawModelResponseStream) Recv() (*RawResponse, error) { + var resp RawResponse + err := s.streamInternal.Recv(&resp) + if err != nil { + return nil, err + } + if err = checkResponseError(&resp); err != nil { + return &resp, err + } + return &resp, nil +} + func checkResponseError(resp ModelAPIResponse) error { errCode, errMsg := resp.GetError() if errCode != 0 { @@ -272,3 +323,20 @@ func isUnsupportedModelError(err error) bool { } return false } + +func (m *BaseModel) Do(ctx context.Context, request *QfRequest) (*RawResponse, error) { + rawResponse := &RawResponse{} + if err := m.requestResource(ctx, request, rawResponse); err != nil { + return nil, err + } + return rawResponse, nil +} + +func (m *BaseModel) Stream(ctx context.Context, request *QfRequest) (*RawModelResponseStream, error) { + request.Body["stream"] = true + si, err := m.requestStream(ctx, request) + if err != nil { + return nil, err + } + return newRawModelResponseStream(si) +} diff --git a/go/qianfan/chat_completion.go b/go/qianfan/chat_completion.go index 07945ceb..6b34cd89 100644 --- a/go/qianfan/chat_completion.go +++ b/go/qianfan/chat_completion.go @@ -316,7 +316,7 @@ func (c *ChatCompletion) do(ctx context.Context, request *ChatCompletionRequest) c.processWithInputLimit(ctx, request, url) - req, err := newModelRequest("POST", url, request) + req, err := NewModelRequest("POST", url, request) if err != nil { return nil, err } @@ -368,7 +368,7 @@ func (c *ChatCompletion) stream(ctx context.Context, request *ChatCompletionRequ c.processWithInputLimit(ctx, request, url) request.SetStream() - req, err := newModelRequest("POST", url, request) + req, err := NewModelRequest("POST", url, request) if err != nil { return nil, err } diff --git a/go/qianfan/completion.go b/go/qianfan/completion.go index fc8bafaa..9a9200ff 100644 --- a/go/qianfan/completion.go +++ b/go/qianfan/completion.go @@ -145,7 +145,7 @@ func (c *Completion) do(ctx context.Context, request *CompletionRequest) (*Model if err != nil { return nil, err } - req, err := newModelRequest("POST", url, request) + req, err := NewModelRequest("POST", url, request) if err != nil { return nil, err } @@ -195,7 +195,7 @@ func (c *Completion) stream(ctx context.Context, request *CompletionRequest) (*M return nil, err } request.SetStream() - req, err := newModelRequest("POST", url, request) + req, err := NewModelRequest("POST", url, request) if err != nil { return nil, err } diff --git a/go/qianfan/console_action.go b/go/qianfan/console_action.go index 24f67b45..cd39bc87 100644 --- a/go/qianfan/console_action.go +++ b/go/qianfan/console_action.go @@ -30,7 +30,7 @@ func (ca *ConsoleAction) Call(ctx context.Context, route string, action string, reqBody := BaseRequestBody{ Extra: params, } - req, err := newConsoleRequest("POST", ca.baseActionUrl(route, action), &reqBody) + req, err := NewConsoleRequest("POST", ca.baseActionUrl(route, action), &reqBody) if err != nil { logger.Error("new console req error", err) return nil, err diff --git a/go/qianfan/embdding.go b/go/qianfan/embdding.go index 1dee1f9d..4095801b 100644 --- a/go/qianfan/embdding.go +++ b/go/qianfan/embdding.go @@ -118,7 +118,7 @@ func (c *Embedding) do(ctx context.Context, request *EmbeddingRequest) (*Embeddi if err != nil { return nil, err } - req, err := newModelRequest("POST", url, request) + req, err := NewModelRequest("POST", url, request) if err != nil { return nil, err } diff --git a/go/qianfan/requestor.go b/go/qianfan/requestor.go index f07f50fe..da3eb6a7 100644 --- a/go/qianfan/requestor.go +++ b/go/qianfan/requestor.go @@ -97,7 +97,7 @@ type QfRequest struct { } // 创建一个用于模型类请求的 Request -func newModelRequest(method string, url string, body RequestBody) (*QfRequest, error) { +func NewModelRequest(method string, url string, body RequestBody) (*QfRequest, error) { return newRequest(modelRequest, method, url, body) } @@ -107,7 +107,7 @@ func newAuthRequest(method string, url string, body RequestBody) (*QfRequest, er } // 创建一个用于管控类请求的 Request -func newConsoleRequest(method string, url string, body RequestBody) (*QfRequest, error) { +func NewConsoleRequest(method string, url string, body RequestBody) (*QfRequest, error) { return newRequest(consoleRequest, method, url, body) } diff --git a/go/qianfan/service.go b/go/qianfan/service.go index 8d627bd8..30d77503 100644 --- a/go/qianfan/service.go +++ b/go/qianfan/service.go @@ -42,7 +42,7 @@ type ServiceListItemVersion struct { func (service *Service) List(ctx context.Context, request *ServiceListRequest) (*ServiceListResponse, error) { var s ServiceListResponse - req, err := newConsoleRequest("POST", serviceListURL, request) + req, err := NewConsoleRequest("POST", serviceListURL, request) if err != nil { return nil, err } diff --git a/go/qianfan/text2img.go b/go/qianfan/text2img.go index 41128477..4de85a35 100644 --- a/go/qianfan/text2img.go +++ b/go/qianfan/text2img.go @@ -121,7 +121,7 @@ func (c *Text2Image) do(ctx context.Context, request *Text2ImageRequest) (*Text2 if err != nil { return nil, err } - req, err := newModelRequest("POST", url, request) + req, err := NewModelRequest("POST", url, request) if err != nil { return nil, err } diff --git a/go/qianfan/tokenizer.go b/go/qianfan/tokenizer.go index 4d5fa9ac..1daccd00 100644 --- a/go/qianfan/tokenizer.go +++ b/go/qianfan/tokenizer.go @@ -87,7 +87,7 @@ func (t *Tokenizer) remoteCountTokensEB(text string, model string) (int, error) Model: model, } - req, err := newModelRequest("POST", modelAPIPrefix+"/tokenizer/erniebot", request) + req, err := NewModelRequest("POST", modelAPIPrefix+"/tokenizer/erniebot", request) if err != nil { return -1, err }