diff --git a/apis/cloudflare/package.json b/apis/cloudflare/package.json index 0bdd168..16b9436 100644 --- a/apis/cloudflare/package.json +++ b/apis/cloudflare/package.json @@ -24,6 +24,7 @@ "@opentelemetry/sdk-metrics": "^1.18.1", "braintrust": "link:../../../sdk/js", "dotenv": "^16.3.1", - "openai": "^4.67.1" + "openai": "^4.67.1", + "zod": "3.22.4" } } diff --git a/apis/cloudflare/src/realtime-logger.ts b/apis/cloudflare/src/realtime-logger.ts index b7e8799..5d223be 100644 --- a/apis/cloudflare/src/realtime-logger.ts +++ b/apis/cloudflare/src/realtime-logger.ts @@ -129,7 +129,7 @@ const messageContentSchema = z.discriminatedUnion("role", [ }), ]); -const outputItemSchema = z.union([ +const inputItemSchema = z.union([ z.object({ type: z.literal("function_call"), call_id: z.string(), @@ -152,6 +152,14 @@ const outputItemSchema = z.union([ }), ]); +const outputItemSchema = inputItemSchema.and( + z.object({ + id: z.string(), + object: z.literal("realtime.item"), + status: responseStatusSchema, + }), +); + const baseResponseSchema = z.object({ object: z.literal("realtime.response"), id: z.string(), @@ -246,7 +254,19 @@ const functionCallDoneMessageSchema = functionCallBaseMessageSchema.extend({ const conversationItemCreateMessageSchema = baseMessageSchema.extend({ type: z.literal("conversation.item.create"), previous_item_id: z.string().nullish(), - item: outputItemSchema, + item: inputItemSchema, +}); + +const speechStartedMessageSchema = baseMessageSchema.extend({ + type: z.literal("input_audio_buffer.speech_started"), + audio_start_ms: z.number(), + item_id: z.string(), +}); + +const speechEndedMessageSchema = baseMessageSchema.extend({ + type: z.literal("input_audio_buffer.speech_stopped"), + audio_end_ms: z.number(), + item_id: z.string(), }); const errorMessageSchema = baseMessageSchema.extend({ @@ -285,10 +305,17 @@ const openAiRealtimeMessageSchema = z.discriminatedUnion("type", [ functionCallDeltaMessageSchema, functionCallDoneMessageSchema, 
conversationItemCreateMessageSchema, + speechStartedMessageSchema, + speechEndedMessageSchema, errorMessageSchema, unhandledMessageSchema, ]); +// The maximum audio buffer size after pushing. +const maxAudioBufferBytes = 50 * 1024 * 1024; +// When the buffer rolls over, the target size. +const targetAudioBufferBytes = 40 * 1024 * 1024; // 40 MB = about 10 minutes in base64. + /** * Helper class to accumulate and encode a single audio stream, which can then * be logged to Braintrust. @@ -296,21 +323,39 @@ const openAiRealtimeMessageSchema = z.discriminatedUnion("type", [ class AudioBuffer { private inputCodec: PcmAudioFormat; private audioBuffers: string[]; - private totalLength: number; + private totalByteLength: number; constructor({ inputCodec }: { inputCodec: PcmAudioFormat }) { this.inputCodec = inputCodec; this.audioBuffers = []; - this.totalLength = 0; + this.totalByteLength = 0; } - push(audioBufferBase64: string): void { - this.audioBuffers.push(audioBufferBase64); - this.totalLength += audioBufferBase64.length; + push(audioBuffer: string): void { + this.audioBuffers.push(audioBuffer); + this.totalByteLength += audioBuffer.length; + + // May run out of memory on Cloudflare Workers, so drop the oldest chunks once over the limit. + if (this.totalByteLength > maxAudioBufferBytes) { + console.warn( + `Audio buffer reached trimming threshold at ${this.totalByteLength} bytes`, + ); + let i = 0; + for ( + ; + i < this.audioBuffers.length && + this.totalByteLength > targetAudioBufferBytes; + i++ + ) { + this.totalByteLength -= this.audioBuffers[i].length; + } + this.audioBuffers = this.audioBuffers.slice(i); + console.warn(`Trimmed audio buffer to ${this.totalByteLength} bytes`); + } } - get length(): number { - return this.totalLength; + get byteLength(): number { + return this.totalByteLength; } encode(compress: boolean): [Blob, string] { @@ -470,6 +515,7 @@ export class OpenAiRealtimeLogger { ) { this.rootSpan.log({ metadata: { + // Consider disabling merging.
openai_realtime_session: message.session, }, }); @@ -483,7 +529,17 @@ export class OpenAiRealtimeLogger { message.session.output_audio_format, ); } + } else if (message.type === "response.created") { + const id = message.response.id; + if (this.serverSpans.has(id)) { + throw new Error(`Duplicate response ID ${id}`); + } + this.serverSpans.set(id, this.rootSpan.startSpan({ name: "item" })); } else if (message.type === "response.output_item.added") { + // Create a new span for each item since the response can contain + // multiple. For example, the model could generate audio and a + // function call in a single turn of the conversation. + const id = message.item.id; if ( message.item.type === "message" && message.item.role === "assistant" @@ -492,100 +548,101 @@ export class OpenAiRealtimeLogger { throw new Error("Messages may have been received out of order."); } this.serverAudioBuffer.set( - message.response_id, + id, new AudioBuffer({ inputCodec: this.outputAudioFormat }), ); } - if (!this.serverSpans.has(message.response_id)) { + if (!this.serverSpans.has(id)) { + let parentSpan = this.serverSpans.get(message.response_id); + if (!parentSpan) { + console.warn(`No parent span for response ID ${message.response_id}`); + parentSpan = this.rootSpan; + } this.serverSpans.set( - message.response_id, - this.rootSpan.startSpan({ + id, + parentSpan.startSpan({ name: message.item.type === "message" ? message.item.role : "function", type: message.item.type === "message" ? 
"llm" : "function", - event: { - id: message.response_id, - }, + event: { id }, }), ); } } else if (message.type === "response.audio.delta") { - const audioBuffer = this.serverAudioBuffer.get(message.response_id); + const id = message.item_id; + const audioBuffer = this.serverAudioBuffer.get(id); if (!audioBuffer) { - throw new Error(`Invalid response ID: ${message.response_id}`); + throw new Error( + `Invalid response ID: ${message.response_id}, item ID: ${id}`, + ); } audioBuffer.push(message.delta); } else if (message.type === "response.audio.done") { - const audioBuffer = this.serverAudioBuffer.get(message.response_id); - const span = this.serverSpans.get(message.response_id); + const id = message.item_id; + const audioBuffer = this.serverAudioBuffer.get(id); + const span = this.serverSpans.get(id); if (!audioBuffer || !span) { - throw new Error(`Invalid response ID: ${message.response_id}`); - } - // May run out of memory on Cloudflare Workers. - if (audioBuffer.length > 5 * 1024 * 1024) { - console.warn( - `Writing a large audio buffer (${audioBuffer.length} bytes in base64)`, + throw new Error( + `Invalid response ID: ${message.response_id}, item ID: ${id}`, ); } this.closeAudio(audioBuffer, span, "output"); - this.serverAudioBuffer.delete(message.response_id); + this.serverAudioBuffer.delete(id); } else if ( message.type === "conversation.item.input_audio_transcription.completed" ) { this.clientSpan?.log({ - input: { - transcript: message.transcript, - }, + input: { transcript: message.transcript }, }); // The transcript can never come before we finish logging audio. this.clientSpan?.close(); this.clientSpan = undefined; } else if (message.type === "response.done") { - const span = this.serverSpans.get(message.response.id); - if (!span) { - throw new Error(`Invalid response ID: ${message.response.id}`); - } - if (message.response.output.length > 1) { - // TODO: - console.warn( - `Response had ${message.response.output.length} items. 
The first one will be used.`, - ); - } if (message.response.output.length === 0) { - throw new Error(`Response ID ${message.response.id} had no items`); + console.warn(`Response ID ${message.response.id} had no items`); } - const item = message.response.output[0]; + for (const item of message.response.output) { + const id = item.id; + const span = this.serverSpans.get(id); + if (!span) { + throw new Error( + `Invalid response ID: ${message.response.id}, item ID: ${id}`, + ); + } + if (item.type === "message") { + span.log({ + output: { content: item.content }, + }); + span.close(); + } else if (item.type === "function_call") { + let args: unknown = item.arguments; + try { + args = JSON.parse(item.arguments); + } catch {} + span.log({ + input: { + name: item.name, + arguments: args, + }, + }); + // Wait for function call output before closing the span. + this.toolSpans.set(item.call_id, span); + } - if (item.type === "message") { - span.log({ - output: { - content: item.content, - }, - metadata: { - usage: message.response.usage ?? undefined, - }, - }); - span.close(); - } else if (item.type === "function_call") { - let args: unknown = item.arguments; - try { - args = JSON.parse(item.arguments); - } catch {} - span.log({ - input: { - name: item.name, - arguments: args, - }, + this.serverSpans.delete(id); + } + const parentSpan = this.serverSpans.get(message.response.id); + if (!parentSpan) { + throw new Error(`Unknown response ID ${message.response.id}`); + } + parentSpan.log({ metadata: { usage: message.response.usage ?? undefined, }, }); - // Wait for function call output before closing the span. - this.toolSpans.set(item.call_id, span); + parentSpan.close(); - } - - this.serverSpans.delete(message.response.id); + this.serverSpans.delete(message.response.id); } } @@ -612,9 +669,16 @@ export class OpenAiRealtimeLogger { * Close all pending spans. */ public async close() { - // Check if there is a pending audio buffer. + // Check if there are pending audio buffers.
+ if (this.serverAudioBuffer.size || this.clientAudioBuffer) { + console.warn( + `Closing with ${this.serverAudioBuffer.size} pending server + ${this.clientAudioBuffer ? 1 : 0} pending client audio buffers`, + ); + } + if (this.clientAudioBuffer && this.clientSpan) { this.closeAudio(this.clientAudioBuffer, this.clientSpan, "input"); + this.clientAudioBuffer = undefined; } for (const [responseId, audioBuffer] of this.serverAudioBuffer) { const span = this.serverSpans.get(responseId); @@ -622,21 +686,26 @@ export class OpenAiRealtimeLogger { continue; } this.closeAudio(audioBuffer, span, "output"); } + this.serverAudioBuffer.clear(); - if (this.serverAudioBuffer.size || this.clientAudioBuffer) { - console.warn( - `Closing with ${this.serverAudioBuffer.size} pending server + ${this.clientAudioBuffer ? 1 : 0} pending client audio buffers`, - ); - } + this.clientSpan?.close(); + this.clientSpan = undefined; + for (const span of this.serverSpans.values()) { span.close(); } + this.serverSpans.clear(); + for (const span of this.toolSpans.values()) { span.close(); } - this.clientSpan?.close(); - this.rootSpan.close(); - await this.rootSpan.flush(); + this.toolSpans.clear(); + + const rootSpan = this.rootSpan; + this.rootSpan = Braintrust.NOOP_SPAN; + + rootSpan.close(); + await rootSpan.flush(); } } diff --git a/apis/cloudflare/src/realtime.ts b/apis/cloudflare/src/realtime.ts index a7ce69e..021088f 100644 --- a/apis/cloudflare/src/realtime.ts +++ b/apis/cloudflare/src/realtime.ts @@ -130,9 +130,7 @@ export async function handleRealtimeProxy({ try { realtimeLogger?.handleMessageServer(event); } catch (e) { - console.warn( - `Error logging server event: ${e} ${JSON.stringify(event, null, 2)}`, - ); + console.warn(`Error logging server event: ${e} ${event.type}`); } server.send(JSON.stringify(event)); }); @@ -155,9 +153,7 @@ export async function handleRealtimeProxy({ try { realtimeLogger?.handleMessageClient(parsedEvent); } catch (e) { - console.warn( - `Error logging client
event: ${e} ${JSON.stringify(parsedEvent, null, 2)}`, - ); + console.warn(`Error logging client event: ${e} ${parsedEvent.type}`); } // console.log(`Relaying "${event.type}" to OpenAI`); realtimeClient.realtime.send(parsedEvent.type, parsedEvent); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index baee8c8..1623a92 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -47,6 +47,9 @@ importers: openai: specifier: ^4.67.1 version: 4.68.4 + zod: + specifier: 3.22.4 + version: 3.22.4 devDependencies: '@cloudflare/workers-types': specifier: ^4.20241022.0 @@ -7800,7 +7803,6 @@ packages: /zod@3.22.4: resolution: {integrity: sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==} - dev: true /zod@3.23.8: resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}