Skip to content

Commit

Permalink
feat: support translations
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmcguire1 committed May 16, 2024
1 parent 6fa3a61 commit b1e1ab3
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 2 deletions.
11 changes: 11 additions & 0 deletions cmd/sqs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ func (handler *SqsHandler) ProcessChatGPTRequest(ctx context.Context, req *chatm
With("error", err).
Error("failed to persist image resolutions")

errorMsg = err.Error()
break
}
case chatmodels.CHAT_MODEL_TRANSLATIONS:
response, err = handler.ChatModelSvc.Translate(ctx, req.Prompt, req.SourceLanguage, req.TargetLanguage, req.Model)
if err != nil {
handler.Logger.
With("prompt", req.Prompt).
With("error", err).
Error("failed to process translation request")

errorMsg = err.Error()
break
}
Expand Down
16 changes: 16 additions & 0 deletions internal/api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ func (h *Handler) DispatchIntents(ctx context.Context, req alexa.Request) (res a
break
}

res, err = h.GetResponse(ctx, h.PollDelay, false)
case alexa.TranslateIntent:
prompt := req.Body.Intent.Slots["prompt"].Value
sourceLanguage := req.Body.Intent.Slots["sourcelang"].Value
targetLanguage := req.Body.Intent.Slots["targetlang"].Value

err = h.RequestsQueue.PushMessage(ctx, &chatmodels.Request{
Prompt: prompt,
TargetLanguage: targetLanguage,
SourceLanguage: sourceLanguage,
Model: chatmodels.CHAT_MODEL_TRANSLATIONS,
})
if err != nil {
break
}

res, err = h.GetResponse(ctx, h.PollDelay, false)
case alexa.AutoCompleteIntent:
prompt := req.Body.Intent.Slots["prompt"].Value
Expand Down
13 changes: 13 additions & 0 deletions internal/api/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ response:
)
h.lastResponse = response
return
case chatmodels.CHAT_MODEL_TRANSLATIONS:
res = alexa.NewResponse(
"Response",
fmt.Sprintf(

Check failure on line 69 in internal/api/response.go

View workflow job for this annotation

GitHub Actions / build-deploy

fmt.Sprintf call needs 2 args but has 3 args
"your translated prompt is %s, this took %s seconds to fetch the answer",
response.Response,
response.Model,
response.TimeDiff,
),
false,
)
h.lastResponse = response
return
default:
res = alexa.NewResponse("Response",
fmt.Sprintf(
Expand Down
6 changes: 6 additions & 0 deletions internal/dom/chatmodels/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type GeminiAPI interface {
type CloudFlareAiWorkerAPI interface {
GenerateText(context.Context, string, string) (string, error)
GenerateImage(ctx context.Context, prompt string, model string) ([]byte, error)
GenerateTranslation(ctx context.Context, req *GenerateTranslationRequest) (string, error)
}

type mockAPI struct {
Expand Down Expand Up @@ -49,3 +50,8 @@ func (api *mockAPI) GenerateImage(ctx context.Context, prompt string, model stri
}
return res, args.Error(1)
}

func (api *mockAPI) GenerateTranslation(ctx context.Context, req *GenerateTranslationRequest) (string, error) {
args := api.Called(ctx, req)
return args.String(0), args.Error(1)
}
58 changes: 58 additions & 0 deletions internal/dom/chatmodels/cloudflare_ai_worker_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const (
CF_AWQ_MODEL = "@hf/thebloke/llama-2-13b-chat-awq"
CF_OPEN_CHAT_MODEL = "@cf/openchat/openchat-3.5-0106"
CF_STABLE_DIFFUSION = "@cf/stabilityai/stable-diffusion-xl-base-1.0"
CF_META_TRANSLATION_MODEL = "@cf/meta/m2m100-1.2b"
)

var CHAT_MODEL_TO_CF_MODEL = map[ChatModel]string{
Expand All @@ -26,6 +27,7 @@ var CHAT_MODEL_TO_CF_MODEL = map[ChatModel]string{
CHAT_MODEL_META: CF_LLAMA_3_8B_INSTRUCT_MODEL,
CHAT_MODEL_OPEN: CF_OPEN_CHAT_MODEL,
CHAT_MODEL_STABLE_DIFFUSION: CF_STABLE_DIFFUSION,
CHAT_MODEL_TRANSLATIONS: CF_META_TRANSLATION_MODEL,
}

type Response struct {
Expand Down Expand Up @@ -129,3 +131,59 @@ func (api *CloudflareApiClient) GenerateImage(ctx context.Context, prompt string

return data, nil
}

type GenerateTranslationRequest struct {
SourceLanguage string
TargetLanguage string
Prompt string
Model string
}

func (api *CloudflareApiClient) GenerateTranslation(ctx context.Context, req *GenerateTranslationRequest) (string, error) {
url := fmt.Sprintf("https://api.cloudflare.com/client/v4/accounts/%s/ai/run/%s", api.AccountID, req.Model)

if req.SourceLanguage == "" {
req.SourceLanguage = "en"
}
payload := map[string]string{
"prompt": req.Prompt,
"source_language": req.SourceLanguage,
"target_language": req.TargetLanguage,
}

jsonPayload, err := json.Marshal(payload)
if err != nil {
return "", err
}

httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
if err != nil {
return "", err
}
httpReq.Header.Set("Authorization", "Bearer "+api.APIKey)
httpReq.Header.Set("Content-Type", "application/json")

client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()

data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

type Result struct {
TranslatedText string `json:"translated_text"`
}

var result *Result
err = json.Unmarshal(data, &result)
if err != nil {
return "", err
}

return result.TranslatedText, nil
}
1 change: 1 addition & 0 deletions internal/dom/chatmodels/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const (
CHAT_MODEL_GPT ChatModel = "gpt"
CHAT_MODEL_META ChatModel = "llama"
CHAT_MODEL_AWQ ChatModel = "awq"
CHAT_MODEL_TRANSLATIONS ChatModel = "translate"
CHAT_MODEL_OPEN ChatModel = "open chat"
CHAT_MODEL_SQL ChatModel = "sql"
CHAT_MODEL_STABLE_DIFFUSION ChatModel = "stable diffusion"
Expand Down
24 changes: 24 additions & 0 deletions internal/dom/chatmodels/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,27 @@ func (client *Client) GenerateImage(ctx context.Context, prompt string, model Ch
}
return nil, fmt.Errorf("unidentified image generation model")
}

func (client *Client) Translate(
ctx context.Context,
prompt string,
sourceLang string,
targetLang string,
model ChatModel,
) (string, error) {
if sourceLang == "" {
sourceLang = "en"
}
if targetLang == "" {
targetLang = "jp"
}
if model == "" {
model = CHAT_MODEL_TRANSLATIONS
}
return client.CloudflareApiClient.GenerateTranslation(ctx, &GenerateTranslationRequest{
SourceLanguage: sourceLang,
TargetLanguage: targetLang,
Prompt: prompt,
Model: CHAT_MODEL_TO_CF_MODEL[model],
})
}
6 changes: 4 additions & 2 deletions internal/dom/chatmodels/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ type LastResponse struct {
}

type Request struct {
Prompt string `json:"prompt"`
Model ChatModel `json:"model"`
Prompt string `json:"prompt"`
TargetLanguage string `json:"target_language,omitempty"`
SourceLanguage string `json:"source_language,omitempty"`
Model ChatModel `json:"model"`
}
7 changes: 7 additions & 0 deletions internal/dom/chatmodels/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ type Resources struct {
type Service interface {
AutoComplete(context.Context, string, ChatModel) (string, error)
GenerateImage(context.Context, string, ChatModel) ([]byte, error)
Translate(
ctx context.Context,
prompt string,
sourceLang string,
targetLang string,
model ChatModel,
) (string, error)
}

type Client struct {
Expand Down
1 change: 1 addition & 0 deletions internal/pkg/alexa/intents.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const (
FallbackIntent = "AMAZON.FallbackIntent"
AutoCompleteIntent = "AutoCompleteIntent"
ImageIntent = "ImageIntent"
TranslateIntent = "TranslateIntent"
RandomFactIntent = "RandomFactIntent"
ModelIntent = "Model"
PurgeIntent = "Purge"
Expand Down

0 comments on commit b1e1ab3

Please sign in to comment.