Skip to content

Commit

Permalink
Propagate args to embedding metrics (#83)
Browse files Browse the repository at this point in the history
Fixes #81
  • Loading branch information
ankrgyl authored Jul 23, 2024
1 parent 29060ed commit aa42b64
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
14 changes: 13 additions & 1 deletion js/oai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ export interface OpenAIAuth {
openAiDangerouslyAllowBrowser?: boolean;
}

export function extractOpenAIArgs<T extends Record<string, unknown>>(
args: OpenAIAuth & T,
): OpenAIAuth {
return {
openAiApiKey: args.openAiApiKey,
openAiOrganizationId: args.openAiOrganizationId,
openAiBaseUrl: args.openAiBaseUrl,
openAiDefaultHeaders: args.openAiDefaultHeaders,
openAiDangerouslyAllowBrowser: args.openAiDangerouslyAllowBrowser,
};
}

const PROXY_URL = "https://braintrustproxy.com/v1";

export function buildOpenAIClient(options: OpenAIAuth): OpenAI {
Expand Down Expand Up @@ -64,6 +76,6 @@ export async function cachedChatCompletion(
params: CachedLLMParams,
options: { cache?: ChatCache } & OpenAIAuth,
): Promise<ChatCompletion> {
let openai = buildOpenAIClient(options);
const openai = buildOpenAIClient(options);
return await openai.chat.completions.create(params);
}
19 changes: 6 additions & 13 deletions js/ragas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import mustache from "mustache";

import { Scorer, ScorerArgs } from "@braintrust/core";
import { DEFAULT_MODEL, LLMArgs } from "./llm";
import { buildOpenAIClient } from "./oai";
import { buildOpenAIClient, extractOpenAIArgs } from "./oai";
import OpenAI from "openai";
import { ListContains } from "./list";
import { EmbeddingSimilarity } from "./string";
Expand Down Expand Up @@ -99,6 +99,7 @@ export const ContextEntityRecall: ScorerWithPartial<
const [expectedEntities, contextEntities] = responses.map(mustParseArgs);

const score = await ListContains({
...extractOpenAIArgs(args),
pairwiseScorer: args.pairwiseScorer ?? EmbeddingSimilarity,
allowExtraEntities: true,
output: entitySchema.parse(contextEntities).entities,
Expand Down Expand Up @@ -652,6 +653,7 @@ export const AnswerRelevancy: ScorerWithPartial<
const similarity = await Promise.all(
questions.map(async ({ question }) => {
const { score } = await EmbeddingSimilarity({
...extractOpenAIArgs(args),
output: question,
expected: input,
});
Expand Down Expand Up @@ -679,18 +681,18 @@ export const AnswerRelevancy: ScorerWithPartial<
*/
export const AnswerSimilarity: ScorerWithPartial<string, RagasArgs> =
makePartial(async (args) => {
const { chatArgs, client, ...inputs } = parseArgs(args);
const { ...inputs } = parseArgs(args);

const { output, expected } = checkRequired(
{ output: inputs.output, expected: inputs.expected },
"AnswerSimilarity",
);

const { score, error } = await EmbeddingSimilarity({
...extractOpenAIArgs(args),
output,
expected,
expectedMin: 0,
model: args.model,
});

return {
Expand Down Expand Up @@ -854,16 +856,7 @@ function parseArgs(args: ScorerArgs<string, RagasArgs>): {
>;
client: OpenAI;
} {
const {
input,
output,
expected,
context,
model,
temperature,
maxTokens,
...clientArgs
} = args;
const { input, output, expected, context, ...clientArgs } = args;
const chatArgs: Omit<
OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
"messages"
Expand Down

0 comments on commit aa42b64

Please sign in to comment.