From 7fa53af2aff2fd50756f369b0aa66422b346a30c Mon Sep 17 00:00:00 2001 From: George Date: Mon, 14 Oct 2024 16:38:11 +0800 Subject: [PATCH 1/2] feat: support spark new model --- adapter/sparkdesk/chat.go | 36 +++++++++++++++++++++++++------- adapter/sparkdesk/struct.go | 24 ++++++++++++--------- app/src/admin/channel.ts | 10 +++++---- app/src/admin/datasets/charge.ts | 20 ++++++++++++++---- globals/variables.go | 10 +++++---- 5 files changed, 70 insertions(+), 30 deletions(-) diff --git a/adapter/sparkdesk/chat.go b/adapter/sparkdesk/chat.go index 474d2de7..b19d4da3 100644 --- a/adapter/sparkdesk/chat.go +++ b/adapter/sparkdesk/chat.go @@ -9,8 +9,8 @@ import ( ) var FunctionCallingModels = []string{ - globals.SparkDeskV3, - globals.SparkDeskV35, + globals.SparkDeskMax, + globals.SparkDeskV4Ultra, } func GetToken(props *adaptercommon.ChatProps) *int { @@ -19,11 +19,11 @@ func GetToken(props *adaptercommon.ChatProps) *int { } switch props.Model { - case globals.SparkDeskV2, globals.SparkDeskV3, globals.SparkDeskV35: + case globals.SparkDeskLite, globals.SparkDeskPro128K: if *props.MaxTokens > 8192 { return utils.ToPtr(8192) } - case globals.SparkDesk: + case globals.SparkDeskPro, globals.SparkDeskMax, globals.SparkDeskMax32K, globals.SparkDeskV4Ultra: if *props.MaxTokens > 4096 { return utils.ToPtr(4096) } @@ -32,6 +32,18 @@ func GetToken(props *adaptercommon.ChatProps) *int { return props.MaxTokens } +func GetTopK(props *adaptercommon.ChatProps) *int { + if props.TopK == nil { + return nil + } + // topk max value is 6 + if *props.TopK > 6 { + return utils.ToPtr(6) + } + + return props.TopK +} + func (c *ChatInstance) GetMessages(props *adaptercommon.ChatProps) []Message { var messages []Message for _, message := range props.Message { @@ -103,8 +115,13 @@ func getChoice(form *ChatResponse) *globals.Chunk { } func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error { - endpoint := fmt.Sprintf("%s/%s/chat", c.Endpoint, TransformAddr(props.Model)) - + var endpoint string + switch props.Model { + case globals.SparkDeskPro128K, globals.SparkDeskMax32K: + endpoint = fmt.Sprintf("%s/chat/%s", c.Endpoint, TransformModel(props.Model)) + default: + endpoint = fmt.Sprintf("%s/%s/chat", c.Endpoint, TransformAddr(props.Model)) + } var conn *utils.WebSocket if conn = utils.NewWebsocketClient(c.GenerateUrl(endpoint)); conn == nil { return fmt.Errorf("sparkdesk error: websocket connection failed") @@ -121,10 +138,13 @@ func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, h }, Functions: c.GetFunctionCalling(props), }, + Parameter: RequestParameter{ Chat: ChatParameter{ - Domain: TransformModel(props.Model), - MaxToken: GetToken(props), + Domain: TransformModel(props.Model), + MaxToken: GetToken(props), + Temperature: props.Temperature, + TopK: GetTopK(props), }, }, }); err != nil { diff --git a/adapter/sparkdesk/struct.go b/adapter/sparkdesk/struct.go index c49b51cd..bd5aaaab 100644 --- a/adapter/sparkdesk/struct.go +++ b/adapter/sparkdesk/struct.go @@ -21,14 +21,14 @@ type ChatInstance struct { func TransformAddr(model string) string { switch model { - case globals.SparkDesk: + case globals.SparkDeskLite: return "v1.1" - case globals.SparkDeskV2: - return "v2.1" - case globals.SparkDeskV3: + case globals.SparkDeskPro: return "v3.1" - case globals.SparkDeskV35: + case globals.SparkDeskMax: return "v3.5" + case globals.SparkDeskV4Ultra: + return "v4.0" default: return "v1.1" } @@ -36,14 +36,18 @@ func TransformAddr(model string) string { func TransformModel(model string) string { switch model { - case globals.SparkDesk: + case globals.SparkDeskLite: return "general" - case globals.SparkDeskV2: - return "generalv2" - case globals.SparkDeskV3: + case globals.SparkDeskPro: return "generalv3" - case globals.SparkDeskV35: + case globals.SparkDeskPro128K: + return "pro-128k" + case globals.SparkDeskMax: return "generalv3.5" + case globals.SparkDeskMax32K: + return "max-32k" + case globals.SparkDeskV4Ultra: + return "4.0Ultra" default: return "general" } diff --git a/app/src/admin/channel.ts b/app/src/admin/channel.ts index c7bd5432..4522da41 100644 --- a/app/src/admin/channel.ts +++ b/app/src/admin/channel.ts @@ -188,10 +188,12 @@ export const ChannelInfos: Record = { endpoint: "wss://spark-api.xf-yun.com", format: "||", models: [ - "spark-desk-v1.5", - "spark-desk-v2", - "spark-desk-v3", - "spark-desk-v3.5", + "spark-desk-lite", + "spark-desk-pro", + "spark-desk-pro-128k", + "spark-desk-max", + "spark-desk-max-32k", + "spark-desk-4.0-ultra", ], }, chatglm: { diff --git a/app/src/admin/datasets/charge.ts b/app/src/admin/datasets/charge.ts index ec2b61c4..4e0092c1 100644 --- a/app/src/admin/datasets/charge.ts +++ b/app/src/admin/datasets/charge.ts @@ -155,17 +155,29 @@ export const pricing: PricingDataset = [ billing_type: timesBilling, }, { - models: ["spark-desk-v1.5"], - input: 0.015, - output: 0.015, + models: ["spark-desk-lite"], // free + input: 0.001, + output: 0.001, currency: Currency.CNY, }, { - models: ["spark-desk-v2", "spark-desk-v3", "spark-desk-v3.5"], + models: ["spark-desk-pro", "spark-desk-pro-128k","spark-desk-max"], input: 0.03, output: 0.03, currency: Currency.CNY, }, + { + models: ["spark-desk-max-32k"], + input: 0.032, + output: 0.032, + currency: Currency.CNY, + }, + { + models: ["spark-desk-4.0-ultra"], + input: 0.1, + output: 0.1, + currency: Currency.CNY, + }, { models: ["moonshot-v1-8k"], input: 0.012, diff --git a/globals/variables.go b/globals/variables.go index 8032b943..5a0b1147 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -98,10 +98,12 @@ const ( Claude2200k = "claude-2.1" Claude3 = "claude-3" ClaudeSlack = "claude-slack" - SparkDesk = "spark-desk-v1.5" - SparkDeskV2 = "spark-desk-v2" - SparkDeskV3 = "spark-desk-v3" - SparkDeskV35 = "spark-desk-v3.5" + SparkDeskLite = "spark-desk-lite" + SparkDeskPro = "spark-desk-pro" + SparkDeskPro128K = "spark-desk-pro-128k" + SparkDeskMax = "spark-desk-max" + SparkDeskMax32K = "spark-desk-max-32k" + SparkDeskV4Ultra = "spark-desk-4.0-ultra" ChatBison001 = "chat-bison-001" GeminiPro = "gemini-pro" GeminiProVision = "gemini-pro-vision" From 1ba6faaa79ecedc1f897a65b14fb38d192ec9bd6 Mon Sep 17 00:00:00 2001 From: George Date: Mon, 14 Oct 2024 17:05:02 +0800 Subject: [PATCH 2/2] fix: fix max_tokens --- adapter/sparkdesk/chat.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/adapter/sparkdesk/chat.go b/adapter/sparkdesk/chat.go index b19d4da3..5316ea54 100644 --- a/adapter/sparkdesk/chat.go +++ b/adapter/sparkdesk/chat.go @@ -20,13 +20,13 @@ func GetToken(props *adaptercommon.ChatProps) *int { switch props.Model { case globals.SparkDeskLite, globals.SparkDeskPro128K: - if *props.MaxTokens > 8192 { - return utils.ToPtr(8192) - } - case globals.SparkDeskPro, globals.SparkDeskMax, globals.SparkDeskMax32K, globals.SparkDeskV4Ultra: if *props.MaxTokens > 4096 { return utils.ToPtr(4096) } + case globals.SparkDeskPro, globals.SparkDeskMax, globals.SparkDeskMax32K, globals.SparkDeskV4Ultra: + if *props.MaxTokens > 8192 { + return utils.ToPtr(8192) + } } return props.MaxTokens