-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
790 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package models | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
// CombineMessages combines the content of multiple messages into a single string | ||
func CombineMessages(messages []Message) string { | ||
var builder strings.Builder | ||
for _, message := range messages { | ||
builder.WriteString(message.Content) | ||
builder.WriteString("\n") | ||
} | ||
return builder.String() | ||
} | ||
|
||
// ExtractMessageRequest extracts a MessageRequest from a request body | ||
func ExtractMessageRequest(body []byte) (*MessageRequest, error) { | ||
if len(body) == 0 { | ||
return nil, fmt.Errorf("request body is empty") | ||
} | ||
var request MessageRequest | ||
if err := json.Unmarshal(body, &request); err != nil { | ||
return nil, fmt.Errorf("failed to unmarshal request body: %w", err) | ||
} | ||
return &request, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
package models | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/pkoukk/tiktoken-go" | ||
"github.com/rs/zerolog" | ||
"github.com/rs/zerolog/log" | ||
) | ||
|
||
const (
	// Model-type identifiers; compared case-insensitively in
	// modelTypeToEncoding.
	ChatGPT string = "chatgpt"
	Claude  string = "claude"
	Gemini  string = "gemini"

	// all Claude models (including Claude 3 family - Haiku, Sonnet, and Opus)
	// use the same "cl100k_base" encoding for tokenization
	ClaudeDefaultEncoding  = "cl100k_base"
	GeminiDefaultEncoding  = "cl100k_base"
	ChatGPTDefaultEncoding = "cl100k_base"
)
|
||
// ModelI abstracts a tokenizing model: it exposes an identifier plus token
// counting for plain text and for a JSON-encoded LLM message body.
type ModelI interface {
	GetID() string
	CountTokensOfText(string) (int, error)
	CountTokensOfLLMMessage([]byte) (int, error)
}
|
||
// Message represents a single message in the conversation.
type Message struct {
	Role    string `json:"role"`    // e.g. "system" or "user"; optional
	Content string `json:"content"` // the text that gets tokenized
}
|
||
// MessageRequest represents the request body for creating a message.
type MessageRequest struct {
	Model    string    `json:"model"`    // target model name
	Messages []Message `json:"messages"` // conversation messages to tokenize
}
|
||
// Model resolves and caches a tiktoken encoder from a model name, a model
// type, or an explicit encoding name (see Init for the resolution order).
type Model struct {
	modelName string // model name (can consist wildcard to specify range): gpt-4o-*, gpt-3.5-turbo
	modelType string // model family: ChatGPT, Claude, or Gemini (case-insensitive)
	encoding  string // explicit tiktoken encoding name, e.g. "cl100k_base"
	encoder   *tiktoken.Tiktoken // Cache the encoder for performance
	logger    zerolog.Logger
}
|
||
// NewModel creates a new Model | ||
func NewModel() *Model { | ||
return &Model{ | ||
logger: log.With().Str("component", "ai-model").Logger(), | ||
} | ||
} | ||
|
||
func (m *Model) Init() error { | ||
var err error | ||
if m.encoding != "" { | ||
m.encoder, err = tiktoken.GetEncoding(m.encoding) | ||
} else if m.modelName != "" { | ||
m.encoder, err = tiktoken.EncodingForModel(m.modelName) | ||
if err != nil { | ||
log.Warn().Err(err).Msgf("Failed to get encoder for model %s, using model type", m.modelName) | ||
m.modelType = m.modelName | ||
m.encoder, err = tiktoken.GetEncoding(m.modelTypeToEncoding()) | ||
} | ||
} else if m.modelType != "" { | ||
m.encoder, err = tiktoken.GetEncoding(m.modelTypeToEncoding()) | ||
} else { | ||
err = fmt.Errorf("no model name or type specified") | ||
} | ||
|
||
return err | ||
} | ||
|
||
// WithName sets the model name (e.g. "gpt-4o") and returns m for chaining.
func (m *Model) WithName(name string) *Model {
	m.modelName = name
	return m
}
|
||
// WithType sets the model type (ChatGPT, Claude, or Gemini) and returns m
// for chaining.
func (m *Model) WithType(modelType string) *Model {
	m.modelType = modelType
	return m
}
|
||
// WithEncoding sets an explicit tiktoken encoding name (e.g. "cl100k_base")
// and returns m for chaining.
func (m *Model) WithEncoding(encoding string) *Model {
	m.encoding = encoding
	return m
}
|
||
func (m *Model) GetID() string { | ||
if m.modelName != "" { | ||
return m.modelName | ||
} | ||
if m.modelType != "" { | ||
return string(m.modelType) | ||
} | ||
return m.encoding | ||
} | ||
|
||
func (m *Model) CountTokensOfLLMMessage(body []byte) (int, error) { | ||
request, err := ExtractMessageRequest(body) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
// Combine all message contents to form the full prompt | ||
prompt := CombineMessages(request.Messages) | ||
|
||
return m.CountTokensOfText(prompt) | ||
} | ||
|
||
func (m *Model) CountTokensOfText(text string) (int, error) { | ||
if m.encoder == nil { | ||
return 0, fmt.Errorf("encoder not initialized") | ||
} | ||
|
||
tokens := m.encoder.Encode(text, nil, nil) | ||
tokenCount := len(tokens) | ||
return tokenCount, nil | ||
} | ||
|
||
func (m *Model) modelTypeToEncoding() string { | ||
if m.modelType == "" { | ||
m.logger.Warn().Msg("Model type not set, using default encoding") | ||
return ChatGPTDefaultEncoding | ||
} | ||
switch strings.ToLower(m.modelType) { | ||
case ChatGPT: | ||
return ChatGPTDefaultEncoding | ||
case Claude: | ||
return ClaudeDefaultEncoding | ||
case Gemini: | ||
return GeminiDefaultEncoding | ||
default: | ||
m.logger.Error().Msgf("Model type %v not supported, using default encoding", m.modelType) | ||
return ChatGPTDefaultEncoding | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package ai | ||
|
||
import ( | ||
"lunar/toolkit-core/ai/models" | ||
|
||
"github.com/rs/zerolog/log" | ||
) | ||
|
||
// Tokenizer counts tokens via an underlying models.ModelI implementation.
type Tokenizer struct {
	model models.ModelI
}
|
||
func NewTokenizer(modelName, modelType, encoding string) (*Tokenizer, error) { | ||
log.Trace(). | ||
Msgf("Creating tokenizer: model %v, type %v, encoding %v", modelName, modelType, encoding) | ||
model := models.NewModel().WithName(modelName).WithType(modelType).WithEncoding(encoding) | ||
err := model.Init() | ||
if err != nil { | ||
log.Error().Err(err).Msgf("Error initializing model %v", modelName) | ||
return nil, err | ||
} | ||
return &Tokenizer{model: model}, nil | ||
} | ||
|
||
func NewTokenizerFromModel(modelName string) (*Tokenizer, error) { | ||
log.Trace().Msgf("Creating tokenizer for model %v", modelName) | ||
model := models.NewModel().WithName(modelName) | ||
err := model.Init() | ||
if err != nil { | ||
log.Error().Err(err).Msgf("Error initializing model %v", modelName) | ||
return nil, err | ||
} | ||
return &Tokenizer{model: model}, nil | ||
} | ||
|
||
func NewTokenizerFromModelType(modelType string) (*Tokenizer, error) { | ||
log.Trace().Msgf("Creating tokenizer for model %v", modelType) | ||
model := models.NewModel().WithType(modelType) | ||
err := model.Init() | ||
if err != nil { | ||
log.Error().Err(err).Msgf("Error initializing model %v", modelType) | ||
return nil, err | ||
} | ||
return &Tokenizer{model: model}, nil | ||
} | ||
|
||
func NewTokenizerFromEncoding(encoding string) (*Tokenizer, error) { | ||
log.Trace().Msgf("Creating tokenizer for encoding %v", encoding) | ||
model := models.NewModel().WithEncoding(encoding) | ||
err := model.Init() | ||
if err != nil { | ||
log.Error().Err(err).Msgf("Error initializing model %v", encoding) | ||
return nil, err | ||
} | ||
return &Tokenizer{model: model}, nil | ||
} | ||
|
||
// CountTokensOfLLMMessage counts the number of tokens in the given body | ||
// LLM message is structured as a JSON object like this: | ||
// | ||
// { | ||
// "messages": [ | ||
// { | ||
// "role": "system", | ||
// "content": "You are a helpful assistant." | ||
// }, | ||
// { | ||
// "role": "user", | ||
// "content": "Explain how airplanes fly." | ||
// } | ||
// ] | ||
// } | ||
// | ||
// The content of each message is tokenized and counted. | ||
// role field is optional and can be omitted. | ||
func (t *Tokenizer) CountTokensOfLLMMessage(body []byte) (int, error) { | ||
tokens, err := t.model.CountTokensOfLLMMessage(body) | ||
if err != nil { | ||
log.Error().Err(err).Msgf("Error counting LLM tokens for model %v", t.model.GetID()) | ||
return 0, err | ||
} | ||
return tokens, err | ||
} | ||
|
||
// CountTokensOfText counts the number of tokens in the given text | ||
func (t *Tokenizer) CountTokensOfText(text string) (int, error) { | ||
tokens, err := t.model.CountTokensOfText(text) | ||
if err != nil { | ||
log.Error().Err(err).Msgf("Error counting LLM tokens of message for model %v", t.model.GetID()) | ||
return 0, err | ||
} | ||
return tokens, err | ||
} |
Oops, something went wrong.