Commit: nest the spans

kevin1 committed Oct 31, 2024
1 parent 810c91e commit 09d4b56
Showing 4 changed files with 151 additions and 83 deletions.
3 changes: 2 additions & 1 deletion apis/cloudflare/package.json
@@ -24,6 +24,7 @@
"@opentelemetry/sdk-metrics": "^1.18.1",
"braintrust": "link:../../../sdk/js",
"dotenv": "^16.3.1",
"openai": "^4.67.1"
"openai": "^4.67.1",
"zod": "3.22.4"
}
}
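
zod becomes a direct dependency of the Cloudflare API package here, presumably because the realtime logger below validates incoming realtime messages with zod schemas. A minimal sketch of that validation pattern, with an illustrative schema and sample payload that are not the package's actual exports:

import { z } from "zod";

// Illustrative schema for one realtime message type; field names follow the diff below.
const sessionUpdatedSchema = z.object({
  type: z.literal("session.updated"),
  session: z.object({ output_audio_format: z.string() }),
});

// safeParse returns a tagged result instead of throwing on unexpected input.
const raw = '{"type":"session.updated","session":{"output_audio_format":"pcm16"}}';
const parsed = sessionUpdatedSchema.safeParse(JSON.parse(raw));
if (parsed.success) {
  console.log(parsed.data.session.output_audio_format); // "pcm16"
}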
219 changes: 144 additions & 75 deletions apis/cloudflare/src/realtime-logger.ts
@@ -129,7 +129,7 @@ const messageContentSchema = z.discriminatedUnion("role", [
}),
]);

const outputItemSchema = z.union([
const inputItemSchema = z.union([
z.object({
type: z.literal("function_call"),
call_id: z.string(),
@@ -152,6 +152,14 @@ const outputItemSchema = z.union([
}),
]);

const outputItemSchema = inputItemSchema.and(
z.object({
id: z.string(),
object: z.literal("realtime.item"),
status: responseStatusSchema,
}),
);

const baseResponseSchema = z.object({
object: z.literal("realtime.response"),
id: z.string(),
@@ -246,7 +254,19 @@ const functionCallDoneMessageSchema = functionCallBaseMessageSchema.extend({
const conversationItemCreateMessageSchema = baseMessageSchema.extend({
type: z.literal("conversation.item.create"),
previous_item_id: z.string().nullish(),
item: outputItemSchema,
item: inputItemSchema,
});

const speechStartedMessageSchema = baseMessageSchema.extend({
type: z.literal("input_audio_buffer.speech_started"),
audio_start_ms: z.number(),
item_id: z.string(),
});

const speechEndedMessageSchema = baseMessageSchema.extend({
type: z.literal("input_audio_buffer.speech_stopped"),
audio_end_ms: z.number(),
item_id: z.string(),
});

const errorMessageSchema = baseMessageSchema.extend({
@@ -285,32 +305,57 @@ const openAiRealtimeMessageSchema = z.discriminatedUnion("type", [
functionCallDeltaMessageSchema,
functionCallDoneMessageSchema,
conversationItemCreateMessageSchema,
speechStartedMessageSchema,
speechEndedMessageSchema,
errorMessageSchema,
unhandledMessageSchema,
]);

// The maximum audio buffer size after pushing.
const maxAudioBufferBytes = 50 * 1024 * 1024;
// The target size to trim down to when the buffer rolls over.
const targetAudioBufferBytes = 40 * 1024 * 1024; // 40 MB = about 10 minutes in base64.

/**
* Helper class to accumulate and encode a single audio stream, which can then
* be logged to Braintrust.
*/
class AudioBuffer {
private inputCodec: PcmAudioFormat;
private audioBuffers: string[];
private totalLength: number;
private totalByteLength: number;

constructor({ inputCodec }: { inputCodec: PcmAudioFormat }) {
this.inputCodec = inputCodec;
this.audioBuffers = [];
this.totalLength = 0;
this.totalByteLength = 0;
}

push(audioBufferBase64: string): void {
this.audioBuffers.push(audioBufferBase64);
this.totalLength += audioBufferBase64.length;
push(audioBuffer: string): void {
this.audioBuffers.push(audioBuffer);
this.totalByteLength += audioBuffer.length;

// May run out of memory on Cloudflare Workers.
if (this.totalByteLength > maxAudioBufferBytes) {
console.warn(
`Audio buffer reached trimming threshold at ${this.totalByteLength} bytes`,
);
let i = 0;
for (
;
i < this.audioBuffers.length &&
this.totalByteLength > targetAudioBufferBytes;
i++
) {
this.totalByteLength -= this.audioBuffers[i].length;
}
this.audioBuffers = this.audioBuffers.slice(i);
console.warn(`Trimmed audio buffer to ${this.totalByteLength} bytes`);
}
}

get length(): number {
return this.totalLength;
get byteLength(): number {
return this.totalByteLength;
}

encode(compress: boolean): [Blob, string] {
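
As a rough sanity check on the "40 MB = about 10 minutes in base64" comment above, assuming the Realtime API default audio format of 24 kHz, 16-bit mono PCM and base64's 4/3 size overhead (both assumptions, not stated in this diff):

// 24,000 samples/s * 2 bytes/sample = 48,000 raw bytes/s; base64 inflates that by 4/3.
const base64BytesPerSecond = 24_000 * 2 * (4 / 3); // 64,000 bytes/s
const minutesAt40MB = (40 * 1024 * 1024) / base64BytesPerSecond / 60; // ≈ 10.9 minutes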
@@ -470,6 +515,7 @@ export class OpenAiRealtimeLogger {
) {
this.rootSpan.log({
metadata: {
// Consider disabling merging.
openai_realtime_session: message.session,
},
});
@@ -483,7 +529,17 @@
message.session.output_audio_format,
);
}
} else if (message.type === "response.created") {
const id = message.response.id;
if (this.serverSpans.has(id)) {
throw new Error(`Duplicate response ID ${id}`);
}
this.serverSpans.set(id, this.rootSpan.startSpan({ name: "item" }));
} else if (message.type === "response.output_item.added") {
// Create a new span for each item since the response can contain
// multiple. For example, the model could generate audio and a
// function call in a single turn of the conversation.
const id = message.item.id;
if (
message.item.type === "message" &&
message.item.role === "assistant"
@@ -492,100 +548,101 @@
throw new Error("Messages may have been received out of order.");
}
this.serverAudioBuffer.set(
message.response_id,
id,
new AudioBuffer({ inputCodec: this.outputAudioFormat }),
);
}

if (!this.serverSpans.has(message.response_id)) {
if (!this.serverSpans.has(id)) {
let parentSpan = this.serverSpans.get(message.response_id);
if (!parentSpan) {
console.warn(`No parent span for response ID ${message.response_id}`);
parentSpan = this.rootSpan;
}
this.serverSpans.set(
message.response_id,
this.rootSpan.startSpan({
id,
parentSpan.startSpan({
name:
message.item.type === "message" ? message.item.role : "function",
type: message.item.type === "message" ? "llm" : "function",
event: {
id: message.response_id,
},
event: { id },
}),
);
}
} else if (message.type === "response.audio.delta") {
const audioBuffer = this.serverAudioBuffer.get(message.response_id);
const id = message.item_id;
const audioBuffer = this.serverAudioBuffer.get(id);
if (!audioBuffer) {
throw new Error(`Invalid response ID: ${message.response_id}`);
throw new Error(
`Invalid response ID: ${message.response_id}, item ID: ${id}`,
);
}
audioBuffer.push(message.delta);
} else if (message.type === "response.audio.done") {
const audioBuffer = this.serverAudioBuffer.get(message.response_id);
const span = this.serverSpans.get(message.response_id);
const id = message.item_id;
const audioBuffer = this.serverAudioBuffer.get(id);
const span = this.serverSpans.get(id);
if (!audioBuffer || !span) {
throw new Error(`Invalid response ID: ${message.response_id}`);
}
// May run out of memory on Cloudflare Workers.
if (audioBuffer.length > 5 * 1024 * 1024) {
console.warn(
`Writing a large audio buffer (${audioBuffer.length} bytes in base64)`,
throw new Error(
`Invalid response ID: ${message.response_id}, item ID: ${id}`,
);
}
this.closeAudio(audioBuffer, span, "output");
this.serverAudioBuffer.delete(message.response_id);
this.serverAudioBuffer.delete(id);
} else if (
message.type === "conversation.item.input_audio_transcription.completed"
) {
this.clientSpan?.log({
input: {
transcript: message.transcript,
},
input: { transcript: message.transcript },
});
// The transcript can never come before we finish logging audio.
this.clientSpan?.close();
this.clientSpan = undefined;
} else if (message.type === "response.done") {
const span = this.serverSpans.get(message.response.id);
if (!span) {
throw new Error(`Invalid response ID: ${message.response.id}`);
}
if (message.response.output.length > 1) {
// TODO:
console.warn(
`Response had ${message.response.output.length} items. The first one will be used.`,
);
}
if (message.response.output.length === 0) {
throw new Error(`Response ID ${message.response.id} had no items`);
console.warn(`Response ID ${message.response.id} had no items`);
}
const item = message.response.output[0];
for (const item of message.response.output) {
const id = item.id;
const span = this.serverSpans.get(id);
if (!span) {
throw new Error(
`Invalid response ID: ${message.response.id}, item ID: ${id}`,
);
}
if (item.type === "message") {
span.log({
output: { content: item.content },
});
span.close();
} else if (item.type === "function_call") {
let args: unknown = item.arguments;
try {
args = JSON.parse(item.arguments);
} catch {}
span.log({
input: {
name: item.name,
arguments: args,
},
});
// Wait for function call output before closing the span.
this.toolSpans.set(item.call_id, span);
}

if (item.type === "message") {
span.log({
output: {
content: item.content,
},
metadata: {
usage: message.response.usage ?? undefined,
},
});
span.close();
} else if (item.type === "function_call") {
let args: unknown = item.arguments;
try {
args = JSON.parse(item.arguments);
} catch {}
span.log({
input: {
name: item.name,
arguments: args,
},
this.serverSpans.delete(id);

const parentSpan = this.serverSpans.get(message.response.id);
if (!parentSpan) {
throw new Error(`Unknown response ID ${id}`);
}
parentSpan.log({
metadata: {
usage: message.response.usage ?? undefined,
},
});
// Wait for function call output before closing the span.
this.toolSpans.set(item.call_id, span);
parentSpan.close();
}

this.serverSpans.delete(message.response.id);
}
}

@@ -612,31 +669,43 @@ export class OpenAiRealtimeLogger {
* Close all pending spans.
*/
public async close() {
// Check if there is a pending audio buffer.
// Check if there are any pending audio buffers.
if (this.serverAudioBuffer.size || this.clientAudioBuffer) {
console.warn(
`Closing with ${this.serverAudioBuffer.size} pending server + ${this.clientAudioBuffer ? 1 : 0} pending client audio buffers`,
);
}

if (this.clientAudioBuffer && this.clientSpan) {
this.closeAudio(this.clientAudioBuffer, this.clientSpan, "input");
this.clientAudioBuffer = undefined;
}
for (const [responseId, audioBuffer] of this.serverAudioBuffer) {
const span = this.serverSpans.get(responseId);
if (!span) {
continue;
}
this.closeAudio(audioBuffer, span, "output");
this.serverAudioBuffer.clear();
}

if (this.serverAudioBuffer.size || this.clientAudioBuffer) {
console.warn(
`Closing with ${this.serverAudioBuffer.size} pending server + ${this.clientAudioBuffer ? 1 : 0} pending client audio buffers`,
);
}
this.clientSpan?.close();
this.clientSpan = undefined;

for (const span of this.serverSpans.values()) {
span.close();
}
this.serverSpans.clear();

for (const span of this.toolSpans.values()) {
span.close();
}
this.clientSpan?.close();
this.rootSpan.close();
await this.rootSpan.flush();
this.toolSpans.clear();

const rootSpan = this.rootSpan;
this.rootSpan = Braintrust.NOOP_SPAN;

rootSpan.close();
await rootSpan.flush();
}
}
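
The substance of this commit is the move from one span per response to a nested layout: a parent span is opened per response on response.created, each output item gets its own child span on response.output_item.added, and response.done logs usage on the parent before closing it. A condensed, illustrative sketch of that shape, using only the span calls that appear in the diff (the Span interface and the IDs below are stand-ins, not Braintrust's actual types):

// Pared-down stand-in for the Braintrust span interface used above.
interface Span {
  startSpan(args: { name: string; type?: string; event?: { id: string } }): Span;
  log(fields: Record<string, unknown>): void;
  close(): void;
}

function sketchNestedSpans(rootSpan: Span): void {
  const responseSpans = new Map<string, Span>(); // response.id -> parent span
  const itemSpans = new Map<string, Span>(); // item.id -> child span

  // "response.created": open one parent span for the whole response.
  responseSpans.set("resp_1", rootSpan.startSpan({ name: "item" }));

  // "response.output_item.added": open a child span per item under that parent.
  const parent = responseSpans.get("resp_1")!;
  itemSpans.set(
    "item_1",
    parent.startSpan({ name: "assistant", type: "llm", event: { id: "item_1" } }),
  );

  // "response.done": log each item's output on its child span, log usage on the
  // parent, then close the children before the parent.
  itemSpans.get("item_1")!.log({ output: { content: "..." } });
  itemSpans.get("item_1")!.close();
  parent.log({ metadata: { usage: undefined } });
  parent.close();
}

Nesting keeps the audio, text, and function-call items produced in a single model turn grouped under one response in the resulting trace.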
8 changes: 2 additions & 6 deletions apis/cloudflare/src/realtime.ts
@@ -130,9 +130,7 @@ export async function handleRealtimeProxy({
try {
realtimeLogger?.handleMessageServer(event);
} catch (e) {
console.warn(
`Error logging server event: ${e} ${JSON.stringify(event, null, 2)}`,
);
console.warn(`Error logging server event: ${e} ${event.type}`);
}
server.send(JSON.stringify(event));
});
@@ -155,9 +153,7 @@
try {
realtimeLogger?.handleMessageClient(parsedEvent);
} catch (e) {
console.warn(
`Error logging client event: ${e} ${JSON.stringify(parsedEvent, null, 2)}`,
);
console.warn(`Error logging client event: ${e} ${parsedEvent.type}`);
}
// console.log(`Relaying "${event.type}" to OpenAI`);
realtimeClient.realtime.send(parsedEvent.type, parsedEvent);
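
Both handlers in realtime.ts now log only the failing event's type instead of serializing the whole payload, which keeps a logging failure from dumping large audio deltas into the worker logs. A minimal sketch of the guarded relay pattern, with placeholder names:

// Logging must never block forwarding; on failure, report only the event type.
function relayEvent(
  event: { type: string },
  logEvent: (e: { type: string }) => void,
  send: (payload: string) => void,
): void {
  try {
    logEvent(event);
  } catch (e) {
    console.warn(`Error logging event: ${e} ${event.type}`);
  }
  send(JSON.stringify(event));
}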
4 changes: 3 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default.
