Embedding distance (#39)
Add the embedding distance building block, which computes the cosine
similarity between the embeddings of the output and expected values and
reports it as a score.
ankrgyl authored Dec 15, 2023
1 parent 868eea6 commit 8a00ce6
Showing 9 changed files with 293 additions and 33 deletions.
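
For orientation: cosine similarity is u·v / (‖u‖‖v‖), so two strings whose embeddings point in the same direction score 1. Below is a minimal usage sketch of the new scorer, based on the EmbeddingDistance signature added in js/string.ts in this commit; the top-level autoevals import path and the sample strings are illustrative assumptions.

import { EmbeddingDistance } from "autoevals";

// Embeds both strings (text-embedding-ada-002 by default), computes their
// cosine similarity, and rescales it so that `expectedMin` (default 0.7)
// maps to 0. Assumes OPENAI_API_KEY is set in the environment.
const result = await EmbeddingDistance({
  output: "H2O",
  expected: "water",
});
console.log(result.name, result.score); // "EmbeddingDistance", a score in [0, 1]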
README.md (1 addition, 1 deletion)
@@ -159,8 +159,8 @@ npx braintrust run example.eval.js

 ### Embeddings
 
+- Embedding distance
 - [ ] BERTScore
-- [ ] Ada Embedding distance
 
 ### Heuristic
 
js/embeddings.test.ts (53 additions, 0 deletions)
@@ -0,0 +1,53 @@
import { EmbeddingDistance } from "./string.js";

const SYNONYMS = [
  {
    word: "water",
    synonyms: ["water", "H2O", "agua"],
  },
  {
    word: "fire",
    synonyms: ["fire", "flame"],
  },
  {
    word: "earth",
    synonyms: ["earth", "Planet Earth"],
  },
];

const UNRELATED = [
  "water",
  "The quick brown fox jumps over the lazy dog",
  "I like to eat apples",
];

test("Embeddings Test", async () => {
  const prefix = "resource type: ";
  for (const { word, synonyms } of SYNONYMS) {
    for (const synonym of synonyms) {
      const result = await EmbeddingDistance({
        prefix,
        output: word,
        expected: synonym,
      });
      expect(result.score).toBeGreaterThan(0.6);
    }
  }

  for (let i = 0; i < UNRELATED.length; i++) {
    for (let j = 0; j < UNRELATED.length; j++) {
      if (i == j) {
        continue;
      }

      const word1 = UNRELATED[i];
      const word2 = UNRELATED[j];
      const result = await EmbeddingDistance({
        prefix,
        output: word1,
        expected: word2,
      });
      expect(result.score).toBeLessThan(0.5);
    }
  }
});
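
Note that the thresholds above apply to the scaled score returned by EmbeddingDistance, not to the raw cosine similarity. A small sketch of the inverse mapping under the default expectedMin of 0.7 used in js/string.ts below (the helper name is hypothetical):

// Hypothetical helper: recover the raw cosine similarity that a given
// scaled score corresponds to.
const rawCosine = (scaled: number, expectedMin = 0.7) =>
  expectedMin + scaled * (1 - expectedMin);

rawCosine(0.6); // 0.88: synonym pairs must embed at least this similarly
rawCosine(0.5); // 0.85: unrelated pairs must stay below this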
js/llm.ts (3 additions, 6 deletions)
@@ -2,7 +2,7 @@ import * as yaml from "js-yaml";
 import mustache from "mustache";
 
 import { Score, Scorer, ScorerArgs } from "@braintrust/core";
-import { ChatCache, cachedChatCompletion } from "./oai.js";
+import { ChatCache, OpenAIAuth, cachedChatCompletion } from "./oai.js";
 import { templates } from "./templates.js";
 import {
   ChatCompletionCreateParams,
@@ -18,13 +18,10 @@ const COT_SUFFIX =

 const SUPPORTED_MODELS = ["gpt-3.5-turbo", "gpt-4"];
 
-interface LLMArgs {
+type LLMArgs = {
   maxTokens?: number;
   temperature?: number;
-  openAiApiKey?: string;
-  openAiOrganizationId?: string;
-  openAiBaseUrl?: string;
-}
+} & OpenAIAuth;
 
 const PLAIN_RESPONSE_SCHEMA = {
   properties: {
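
Switching LLMArgs from an interface to an intersection type means the OpenAI credential fields now come from the shared OpenAIAuth type in oai.ts instead of being redeclared here. A sketch of what callers can still pass (values are placeholders):

const args: LLMArgs = {
  maxTokens: 512,
  temperature: 0,
  // Inherited from OpenAIAuth rather than declared inline:
  openAiApiKey: "sk-placeholder",
  openAiBaseUrl: "https://braintrustproxy.com/v1",
};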
js/oai.ts (12 additions, 11 deletions)
@@ -30,11 +30,21 @@ export interface OpenAIAuth {

 const PROXY_URL = "https://braintrustproxy.com/v1";
 
+export function buildOpenAIClient(options: OpenAIAuth): OpenAI {
+  const { openAiApiKey, openAiOrganizationId, openAiBaseUrl } = options;
+
+  return new OpenAI({
+    apiKey: openAiApiKey || Env.OPENAI_API_KEY,
+    organization: openAiOrganizationId,
+    baseURL: openAiBaseUrl || PROXY_URL,
+  });
+}
+
 export async function cachedChatCompletion(
   params: CachedLLMParams,
   options: { cache?: ChatCache } & OpenAIAuth
 ): Promise<ChatCompletion> {
-  const { cache, openAiApiKey, openAiOrganizationId, openAiBaseUrl } = options;
+  const { cache } = options;
 
   return await currentSpanTraced(
     "OpenAI Completion",
@@ -44,16 +54,7 @@ export async function cachedChatCompletion(
       if (ret) {
         cached = true;
       } else {
-        const openai = new OpenAI({
-          apiKey: openAiApiKey || Env.OPENAI_API_KEY,
-          organization: openAiOrganizationId,
-          baseURL: openAiBaseUrl || PROXY_URL,
-        });
-
-        if (openai === null) {
-          throw new Error("OPENAI_API_KEY not set");
-        }
-
+        const openai = buildOpenAIClient(options);
         const completion = await openai.chat.completions.create(params);
 
         await cache?.set(params, completion);
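
Factoring construction into buildOpenAIClient gives chat completions and the new embedding call one shared credential-resolution path: explicit options win, then Env.OPENAI_API_KEY, with the Braintrust proxy as the default base URL. It also drops the old null check, which was unreachable since the OpenAI constructor never returns null. A usage sketch (the key is a placeholder):

import { buildOpenAIClient } from "./oai.js";

// Omitted fields fall back to Env.OPENAI_API_KEY and the proxy URL.
const openai = buildOpenAIClient({ openAiApiKey: "sk-placeholder" });
const completion = await openai.chat.completions.create({
  model: "gpt-3.5-turbo",
  messages: [{ role: "user", content: "Say hello" }],
});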
js/string.ts (81 additions, 1 deletion)
@@ -1,5 +1,10 @@
 import { Scorer } from "@braintrust/core";
 import levenshtein from "js-levenshtein";
+import { OpenAIAuth, buildOpenAIClient } from "./oai.js";
+import { CreateEmbeddingResponse } from "openai/resources/embeddings.mjs";
+import { SpanLogFn, currentSpanTraced } from "./util.js";
+import { OpenAI } from "openai";
+import cossim from "compute-cosine-similarity";
 
 /**
  * A simple scorer that uses the Levenshtein distance to compare two strings.
@@ -18,7 +23,82 @@ export const LevenshteinScorer: Scorer<string, {}> = (args) => {
   }
 
   return {
-    name: "levenshtein",
+    name: "Levenshtein",
     score,
   };
 };
+
+export const EmbeddingDistance: Scorer<
+  string,
+  {
+    prefix?: string;
+    expectedMin?: number;
+    model?: string;
+  } & OpenAIAuth
+> = async (args) => {
+  if (args.expected === undefined) {
+    throw new Error("EmbeddingDistance requires an expected value");
+  }
+
+  const prefix = args.prefix ?? "";
+  const expectedMin = args.expectedMin ?? 0.7;
+
+  const [output, expected] = [
+    `${prefix}${args.output}`,
+    `${prefix}${args.expected}`,
+  ];
+
+  const openai = buildOpenAIClient(args);
+
+  const [outputResult, expectedResult] = await Promise.all(
+    [output, expected].map((input) =>
+      embed(openai, {
+        input,
+        model: args.model ?? "text-embedding-ada-002",
+      })
+    )
+  );
+
+  const score = cossim(
+    outputResult.data[0].embedding,
+    expectedResult.data[0].embedding
+  );
+
+  return {
+    name: "EmbeddingDistance",
+    score: scaleScore(score ?? 0, expectedMin),
+    error: score === null ? "EmbeddingDistance failed" : undefined,
+  };
+};
+
+function scaleScore(score: number, expectedMin: number): number {
+  return Math.max((score - expectedMin) / (1 - expectedMin), 0);
+}
+
+async function embed(
+  openai: OpenAI,
+  params: OpenAI.Embeddings.EmbeddingCreateParams
+): Promise<CreateEmbeddingResponse> {
+  return await currentSpanTraced(
+    "OpenAI Embedding",
+    async (spanLog: SpanLogFn) => {
+      const result = await openai.embeddings.create(params);
+      const output = result.data[0].embedding;
+
+      const { input, ...rest } = params;
+      spanLog({
+        input,
+        output,
+        metadata: {
+          ...rest,
+        },
+        metrics: {
+          tokens: result.usage?.total_tokens,
+          prompt_tokens: result.usage?.prompt_tokens,
+        },
+      });
+
+      return result;
+    }
+  );
+}
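
To make the rescaling concrete: scaleScore maps raw cosine similarity from [expectedMin, 1] onto [0, 1], clamping anything below expectedMin to zero. A few worked values at the default expectedMin of 0.7:

scaleScore(1.0, 0.7);  // (1.0 - 0.7) / 0.3 = 1.0 (identical embeddings)
scaleScore(0.85, 0.7); // (0.85 - 0.7) / 0.3 = 0.5
scaleScore(0.55, 0.7); // max(-0.5, 0) = 0 (similarity below the floor scores zero)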
package.json (1 addition, 0 deletions)
@@ -47,6 +47,7 @@
"dependencies": {
"@braintrust/core": "^0.0.6",
"@types/node": "^20.4.4",
"compute-cosine-similarity": "^1.1.0",
"esbuild": "^0.19.1",
"js-levenshtein": "^1.1.6",
"js-yaml": "^4.1.0",
py/autoevals/oai.py (32 additions, 13 deletions)
@@ -2,12 +2,21 @@
 import sys
 import textwrap
 import time
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Any
 
 PROXY_URL = "https://braintrustproxy.com/v1"
 
 
-def prepare_openai_complete(is_async=False, api_key=None, base_url=None):
+@dataclass
+class OpenAIWrapper:
+    complete: Any
+    embed: Any
+    RateLimitError: Exception
+
+
+def prepare_openai(is_async=False, api_key=None, base_url=None):
     if base_url is None:
         base_url = PROXY_URL
 
@@ -30,6 +39,7 @@ def prepare_openai_complete(is_async=False, api_key=None, base_url=None):

     openai_obj = openai
     is_v1 = False
+
     if hasattr(openai, "OpenAI"):
         # This is the new v1 API
         is_v1 = True
@@ -52,16 +62,26 @@ def prepare_openai_complete(is_async=False, api_key=None, base_url=None):
     complete_fn = None
     rate_limit_error = None
     if is_v1:
-        rate_limit_error = openai.RateLimitError
-        complete_fn = openai_obj.chat.completions.create
+        wrapper = OpenAIWrapper(
+            complete=openai_obj.chat.completions.create,
+            embed=openai_obj.embeddings.create,
+            RateLimitError=openai.RateLimitError,
+        )
     else:
         rate_limit_error = openai.error.RateLimitError
         if is_async:
             complete_fn = openai_obj.ChatCompletion.acreate
+            embedding_fn = openai_obj.Embedding.acreate
         else:
             complete_fn = openai_obj.ChatCompletion.create
+            embedding_fn = openai_obj.Embedding.create
+        wrapper = OpenAIWrapper(
+            complete=complete_fn,
+            embed=embedding_fn,
+            RateLimitError=rate_limit_error,
+        )
 
-    return complete_fn, rate_limit_error
+    return wrapper
 
 
 def post_process_response(resp):
@@ -74,34 +94,33 @@ def post_process_response(resp):
         return resp.dict()
 
 
-def run_cached_request(api_key=None, base_url=None, **kwargs):
-    # OpenAI is very slow to import, so we only do it if we need it
-    complete, RateLimitError = prepare_openai_complete(is_async=False, api_key=api_key, base_url=base_url)
+def run_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs):
+    wrapper = prepare_openai(is_async=False, api_key=api_key, base_url=base_url)
 
     retries = 0
     sleep_time = 0.1
     while retries < 100:
         try:
-            resp = post_process_response(complete(**kwargs))
+            resp = post_process_response(getattr(wrapper, request_type)(**kwargs))
             break
-        except RateLimitError:
+        except wrapper.RateLimitError:
             sleep_time *= 1.5
             time.sleep(sleep_time)
             retries += 1
 
     return resp
 
 
-async def arun_cached_request(api_key=None, base_url=None, **kwargs):
-    complete, RateLimitError = prepare_openai_complete(is_async=True, api_key=api_key, base_url=base_url)
+async def arun_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs):
+    wrapper = prepare_openai(is_async=True, api_key=api_key, base_url=base_url)
 
     retries = 0
     sleep_time = 0.1
     while retries < 100:
         try:
-            resp = post_process_response(await complete(**kwargs))
+            resp = post_process_response(await getattr(wrapper, request_type)(**kwargs))
             break
-        except RateLimitError:
+        except wrapper.RateLimitError:
             # Just assume it's a rate limit error
             sleep_time *= 1.5
             await asyncio.sleep(sleep_time)