create provider plugin system
brandon-schabel committed Jan 9, 2025
1 parent 32ae9de commit 905f69e
Showing 18 changed files with 828 additions and 983 deletions.
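The four plugin files below all implement the ProviderPlugin contract imported from "../provider-plugin" and "../provider-types"; those files are presumably among the other changed files in this commit but are not shown in this excerpt. As a point of reference, here is a minimal sketch of what that contract might look like, inferred purely from how the plugins below use it — every name and field in this sketch is an assumption, not confirmed by the diff.

// Hypothetical sketch of ../provider-plugin and ../provider-types,
// inferred from usage in the plugins below.
export interface ProviderPlugin {
  // Returns either a raw body reader (Anthropic/Groq/Ollama) or a
  // ReadableStream of SSE lines (Gemini).
  prepareRequest(
    params: StreamParams
  ): Promise<ReadableStreamDefaultReader<Uint8Array> | ReadableStream<Uint8Array>>;

  // Returns the text chunk for a line, "[DONE]" to finalize, or null to skip.
  parseServerSentEvent(line: string): string | null;
}

export interface StreamParams {
  chatId: string;
  assistantMessageId: string;
  userMessage: string;
  chatService: {
    saveMessage(msg: { chatId: string; role: string; content: string }): Promise<unknown>;
    updateChatTimestamp(chatId: string): Promise<void>;
    getChatMessages(chatId: string): Promise<Array<{ role: string; content: string }>>;
  };
  options: {
    model?: string;
    max_tokens?: number;
    temperature?: number;
    top_p?: number;
    top_k?: number;
    frequency_penalty?: number;
    presence_penalty?: number;
    debug?: boolean;
    [key: string]: unknown;
  };
}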
@@ -0,0 +1,100 @@
import { ProviderPlugin } from "../provider-plugin";
import { StreamParams } from "../provider-types";

interface AnthropicStreamResponse {
  type: string;
  message?: {
    content?: Array<{ text?: string }>;
    stop_reason?: string;
  };
  delta?: {
    text?: string;
  };
  error?: { message: string };
}

/**
 * The new Anthropic plugin
 */
export class AnthropicPlugin implements ProviderPlugin {
  private apiKey: string;
  private version: string;
  private beta?: string;

  constructor(apiKey: string, version: string, beta?: string) {
    this.apiKey = apiKey;
    this.version = version;
    this.beta = beta;
  }

  async prepareRequest(params: StreamParams) {
    const { userMessage, chatService, chatId, assistantMessageId, options } = params;

    // Optionally save the user message or reconstruct chat messages here
    await chatService.saveMessage({
      chatId,
      role: "user",
      content: userMessage,
    });
    await chatService.updateChatTimestamp(chatId);

    const body = JSON.stringify({
      model: options.model || "claude-2",
      messages: [
        {
          role: "user",
          content: userMessage,
        },
      ],
      max_tokens: options.max_tokens ?? 1024,
      temperature: typeof options.temperature === "number" ? options.temperature : 1.0,
      top_p: options.top_p ?? 1,
      top_k: options.top_k ?? 0,
      stream: true,
    });

    const headers: Record<string, string> = {
      "Content-Type": "application/json",
      "anthropic-version": this.version,
      "x-api-key": this.apiKey,
    };
    if (this.beta) {
      headers["anthropic-beta"] = this.beta;
    }

    const response = await fetch("https://api.anthropic.com/v1/messages", {
      method: "POST",
      headers,
      body,
    });

    if (!response.ok || !response.body) {
      const errorText = await response.text();
      throw new Error(`Anthropic API error: ${response.status} - ${errorText}`);
    }

    return response.body.getReader();
  }

  parseServerSentEvent(line: string): string | null {
    // SSE lines are prefixed with "data:"
    if (!line.startsWith("data:")) return null;
    const jsonString = line.replace(/^data:\s*/, "").trim();

    // Treat "[DONE]" as the end-of-stream sentinel
    if (jsonString === "[DONE]") return "[DONE]";

    try {
      const parsed = JSON.parse(jsonString) as AnthropicStreamResponse;
      if (parsed.error) {
        throw new Error(`Anthropic SSE error: ${parsed.error.message}`);
      }
      if (parsed.type === "content_block_delta" && parsed.delta?.text) {
        return parsed.delta.text;
      }
      return null;
    } catch {
      return null;
    }
  }
}
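For context, here is a rough sketch of how a streaming engine might drive one of these plugins: obtain the reader (or readable stream) from prepareRequest, split the decoded bytes into lines, and pass each line through parseServerSentEvent. The function name and its surrounding machinery are hypothetical illustration only; the real streaming engine lives elsewhere in this commit and is not shown here.

// Hypothetical consumer loop (assumes the ProviderPlugin / StreamParams
// shapes sketched near the top of this diff).
async function streamWithPlugin(plugin: ProviderPlugin, params: StreamParams): Promise<string> {
  const result = await plugin.prepareRequest(params);
  const reader = result instanceof ReadableStream ? result.getReader() : result;
  const decoder = new TextDecoder();

  let buffer = "";
  let assistantText = "";

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;

    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop() || "";

    for (const line of lines) {
      const chunk = plugin.parseServerSentEvent(line.trim());
      if (chunk === "[DONE]") return assistantText;
      if (chunk) assistantText += chunk; // append to the assistant message
    }
  }
  return assistantText;
}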
@@ -0,0 +1,169 @@
import { ProviderPlugin } from "../provider-plugin";
import { StreamParams } from "../provider-types";
import { TextEncoder, TextDecoder } from "util";

export class GeminiPlugin implements ProviderPlugin {
  private geminiApiKey: string;
  private geminiBaseUrl: string;
  private modelId: string;

  constructor(
    geminiApiKey: string,
    geminiBaseUrl: string,
    modelId: string
  ) {
    this.geminiApiKey = geminiApiKey;
    this.geminiBaseUrl = geminiBaseUrl;
    this.modelId = modelId;
  }

  /**
   * Prepare the SSE request and return a ReadableStream of SSE lines:
   * "data: <chunk>\n\n" ... "data: [DONE]\n\n"
   */
  async prepareRequest(params: StreamParams) {
    const { chatId, userMessage, chatService, options } = params;

    // 1) Save the user message and update the chat timestamp
    await chatService.saveMessage({
      chatId,
      role: "user",
      content: userMessage,
    });
    await chatService.updateChatTimestamp(chatId);

    // 2) Rebuild chat history in Gemini's expected format
    const msgs = await chatService.getChatMessages(chatId);
    const messages = msgs
      .filter((m: any) => m.content.trim().length > 0)
      .map((m: any) => ({
        role: m.role === "assistant" ? "model" : m.role,
        parts: [{ text: m.content }],
      }));

    // 3) Build the Gemini request payload
    const payload = {
      contents: messages,
      generationConfig: {
        temperature: typeof options.temperature === "number" ? options.temperature : 0.7,
        maxOutputTokens: options.max_tokens ?? 1024,
        topP: options.top_p ?? 0.9,
        topK: options.top_k ?? 40,
      },
    };

    // 4) Send POST with alt=sse
    const endpoint = `${this.geminiBaseUrl}/${this.modelId}:streamGenerateContent?alt=sse&key=${this.geminiApiKey}`;
    if (options.debug) {
      console.debug("[GeminiPlugin] Sending request:", {
        endpoint,
        payload,
      });
    }

    const response = await fetch(endpoint, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
    });

    if (!response.ok || !response.body) {
      console.error("Gemini API error response:", await response.text());
      throw new Error(`Gemini API error: ${response.statusText}`);
    }

    // 5) Convert the fetch response into SSE lines
    const reader = response.body.getReader();
    const { readable, writable } = new TransformStream();
    (async () => {
      try {
        const decoder = new TextDecoder();
        const encoder = new TextEncoder();

        let buffer = "";
        while (true) {
          const { done, value } = await reader.read();
          if (done) break;

          const chunk = decoder.decode(value, { stream: true });
          buffer += chunk;

          // SSE lines generally split on "\n"
          const lines = buffer.split("\n");
          buffer = lines.pop() || "";

          for (const line of lines) {
            const trimmed = line.trim();
            if (!trimmed || trimmed.startsWith(":")) continue;

            if (trimmed.startsWith("data:")) {
              const jsonString = trimmed.replace(/^data:\s*/, "");
              if (jsonString === "[DONE]") {
                // 6) Emit the "done" line so the streaming engine can close
                const sseDone = `data: [DONE]\n\n`;
                const writer = writable.getWriter();
                await writer.write(encoder.encode(sseDone));
                writer.releaseLock();
                // Close the writable side so the readable stream ends cleanly
                writable.close();
                return;
              }

              try {
                const parsed = JSON.parse(jsonString);
                // The chunk text is in parsed.candidates[0].content.parts
                if (parsed.candidates && parsed.candidates[0]?.content?.parts) {
                  const chunkText = parsed.candidates[0].content.parts
                    .map((p: any) => p.text)
                    .join("");

                  if (chunkText) {
                    // Emit an SSE line with the JSON-encoded chunk to preserve newlines
                    const payloadJson = JSON.stringify(chunkText);
                    const sseLine = `data: ${payloadJson}\n\n`;

                    const writer = writable.getWriter();
                    await writer.write(encoder.encode(sseLine));
                    writer.releaseLock();
                  }
                }
              } catch (err) {
                console.error("[GeminiPlugin] SSE JSON parse error:", err);
              }
            }
          }
        }

        // If the stream ended without an explicit [DONE], emit
        // "data: [DONE]\n\n" so downstream consumers can still finalize
        const writer = writable.getWriter();
        await writer.write(encoder.encode("data: [DONE]\n\n"));
        writer.releaseLock();

        writable.close();
      } catch (error) {
        console.error("[GeminiPlugin] Error reading SSE from Gemini:", error);
        writable.abort(error);
      }
    })();

    return readable;
  }

  /**
   * Parse each SSE line. If it is "data: [DONE]", return "[DONE]" so the SSE engine knows to finalize.
   * Otherwise, JSON-parse the chunk text to preserve newlines and spaces.
   */
  parseServerSentEvent(line: string): string | null {
    if (!line.startsWith("data: ")) return null;

    const payload = line.slice("data: ".length).trim();
    if (payload === "[DONE]") {
      return "[DONE]";
    }

    try {
      return JSON.parse(payload); // decode the chunk text from JSON
    } catch {
      return payload;
    }
  }
}
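Note the round trip here: prepareRequest JSON-encodes each chunk before writing it as an SSE line, and parseServerSentEvent JSON-parses it back, so embedded newlines survive the line-based framing. A small illustration (the constructor arguments are example values, not taken from this diff):

// The "\\n" in the source string is a literal backslash-n inside the JSON payload.
const line = 'data: "Hello\\nworld"';
const plugin = new GeminiPlugin(
  "GEMINI_API_KEY",                                              // example key, assumed
  "https://generativelanguage.googleapis.com/v1beta/models",     // example base URL, assumed
  "gemini-1.5-flash"                                             // example model id, assumed
);
console.log(plugin.parseServerSentEvent(line)); // -> "Hello\nworld" with the newline restored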
@@ -0,0 +1,67 @@
import { ProviderPlugin } from "../provider-plugin";
import { StreamParams } from "../provider-types";

export class GroqPlugin implements ProviderPlugin {
  private apiKey: string;
  private baseUrl: string;

  constructor(apiKey: string, baseUrl: string) {
    this.apiKey = apiKey;
    this.baseUrl = baseUrl;
  }

  async prepareRequest(params: StreamParams) {
    const { chatId, userMessage, chatService, options } = params;

    await chatService.saveMessage({
      chatId,
      role: "user",
      content: userMessage,
    });
    await chatService.updateChatTimestamp(chatId);

    const payload = {
      model: options.model || "llama-3.1-70b-versatile",
      messages: [{ role: "user", content: userMessage }],
      stream: true,
      max_tokens: options.max_tokens ?? 1024,
      temperature: typeof options.temperature === "number" ? options.temperature : 0.7,
      top_p: options.top_p ?? 1,
      frequency_penalty: options.frequency_penalty ?? 0,
      presence_penalty: options.presence_penalty ?? 0,
    };

    const endpoint = `${this.baseUrl}/chat/completions`;
    const response = await fetch(endpoint, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify(payload),
    });

    if (!response.ok || !response.body) {
      const errorText = await response.text();
      throw new Error(`Groq API error: ${response.statusText} - ${errorText}`);
    }

    return response.body.getReader();
  }

  parseServerSentEvent(line: string): string | null {
    if (!line.startsWith("data:")) return null;
    const jsonString = line.replace(/^data:\s*/, "").trim();

    if (jsonString === "[DONE]") return "[DONE]";

    try {
      const parsed = JSON.parse(jsonString);
      // e.g. parsed.choices[0].delta.content
      const content = parsed.choices?.[0]?.delta?.content;
      return content || null;
    } catch {
      return null;
    }
  }
}
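Groq's streaming endpoint is OpenAI-compatible, which is why the parser reads choices[0].delta.content. A quick sketch of the line shapes this parser expects (constructor arguments are example values, not taken from this diff):

// Example SSE lines in the OpenAI-compatible shape assumed by the parser above:
//   data: {"choices":[{"delta":{"content":"Hello"}}]}
//   data: [DONE]
const plugin = new GroqPlugin("GROQ_API_KEY", "https://api.groq.com/openai/v1"); // example args, assumed
plugin.parseServerSentEvent('data: {"choices":[{"delta":{"content":"Hello"}}]}'); // -> "Hello"
plugin.parseServerSentEvent("data: [DONE]");                                      // -> "[DONE]"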
@@ -0,0 +1,53 @@
import { ProviderPlugin } from "../provider-plugin";
import { StreamParams } from "../provider-types";

export class OllamaPlugin implements ProviderPlugin {
  private baseUrl: string;

  constructor(baseUrl: string) {
    this.baseUrl = baseUrl;
  }

  async prepareRequest(params: StreamParams) {
    const { userMessage, chatService, chatId, assistantMessageId, options } = params;

    // Save the user message (chat history reconstruction could also happen here)
    await chatService.saveMessage({
      chatId,
      role: "user",
      content: userMessage,
    });
    await chatService.updateChatTimestamp(chatId);

    const response = await fetch(`${this.baseUrl}/api/chat`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        model: options.model || "llama3:latest",
        messages: [{ role: "user", content: userMessage }],
        stream: true,
        ...options, // pass along other config
      }),
    });

    if (!response.ok || !response.body) {
      throw new Error(`Ollama API error: ${response.statusText}`);
    }

    return response.body.getReader();
  }

  parseServerSentEvent(line: string): string | null {
    // Ollama streams newline-delimited JSON, so each line is a JSON object.
    // Lines that don't parse are ignored.
    // Return "[DONE]" here if an explicit done condition is ever needed.
    try {
      const data = JSON.parse(line);
      const chunk = data?.message?.content || "";
      return chunk || null;
    } catch {
      // Partial or invalid JSON
      return null;
    }
  }
}
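Unlike the other providers, Ollama streams newline-delimited JSON objects rather than "data:"-prefixed SSE, so the parser treats every line as JSON. A sketch of the shape this parser expects (field names taken from the parsing above; the base URL is Ollama's default local endpoint, assumed here for illustration):

// Example streamed lines (newline-delimited JSON, not SSE):
//   {"message":{"role":"assistant","content":"Hel"},"done":false}
//   {"message":{"role":"assistant","content":"lo"},"done":false}
//   {"message":{"role":"assistant","content":""},"done":true}
const plugin = new OllamaPlugin("http://localhost:11434"); // default local endpoint, assumed
plugin.parseServerSentEvent('{"message":{"content":"Hel"},"done":false}'); // -> "Hel"
plugin.parseServerSentEvent('{"done":true}'); // -> null (no content; stream end is handled elsewhere)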