Skip to content

Commit

Permalink
Add scorer purpose (#87)
Browse files Browse the repository at this point in the history
This allows us to exclude LLM-as-a-judge calls from certain metrics
(e.g. token counts)
  • Loading branch information
ankrgyl authored Aug 7, 2024
1 parent 4ceb49c commit 2c47406
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
19 changes: 18 additions & 1 deletion js/oai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ export interface CachedLLMParams {
tool_choice?: ChatCompletionToolChoiceOption;
temperature?: number;
max_tokens?: number;
span_info?: {
spanAttributes?: Record<string, string>;
};
}

export interface ChatCache {
Expand Down Expand Up @@ -69,6 +72,7 @@ export function buildOpenAIClient(options: OpenAIAuth): OpenAI {
}

// Global hook used to detect whether the braintrust SDK has wrapped OpenAI
// clients for tracing. Checked (truthiness only) before tagging requests.
declare global {
/* eslint-disable no-var */
// NOTE(review): presumably assigned by braintrust's wrap_openai integration
// in a parent process/module — confirm where this global is set.
var __inherited_braintrust_wrap_openai: ((openai: any) => any) | undefined;
}

Expand All @@ -77,5 +81,18 @@ export async function cachedChatCompletion(
options: { cache?: ChatCache } & OpenAIAuth,
): Promise<ChatCompletion> {
const openai = buildOpenAIClient(options);
return await openai.chat.completions.create(params);

const fullParams = globalThis.__inherited_braintrust_wrap_openai
? {
...params,
span_info: {
spanAttributes: {
...params.span_info?.spanAttributes,
purpose: "scorer",
},
},
}
: params;

return await openai.chat.completions.create(fullParams);
}
16 changes: 13 additions & 3 deletions py/autoevals/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def prepare_openai(is_async=False, api_key=None, base_url=None):
openai.api_key = api_key
openai.api_base = base_url

wrapped = False
try:
from braintrust.oai import wrap_openai

openai_obj = wrap_openai(openai_obj)
wrapped = True
except ImportError:
pass

Expand Down Expand Up @@ -89,7 +91,7 @@ def prepare_openai(is_async=False, api_key=None, base_url=None):
RateLimitError=rate_limit_error,
)

return wrapper
return wrapper, wrapped


def post_process_response(resp):
Expand All @@ -102,8 +104,14 @@ def post_process_response(resp):
return resp.dict()


def set_span_purpose(kwargs):
    """Tag the request kwargs so the traced span is marked as a scorer call.

    Mutates ``kwargs`` in place: ensures ``kwargs["span_info"]["span_attributes"]``
    exists and sets its ``"purpose"`` key to ``"scorer"``. Existing keys in
    either nested dict are preserved.
    """
    span_info = kwargs.setdefault("span_info", {})
    attributes = span_info.setdefault("span_attributes", {})
    attributes["purpose"] = "scorer"


def run_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs):
wrapper = prepare_openai(is_async=False, api_key=api_key, base_url=base_url)
wrapper, wrapped = prepare_openai(is_async=False, api_key=api_key, base_url=base_url)
if wrapped:
set_span_purpose(kwargs)

retries = 0
sleep_time = 0.1
Expand All @@ -120,7 +128,9 @@ def run_cached_request(request_type="complete", api_key=None, base_url=None, **k


async def arun_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs):
wrapper = prepare_openai(is_async=True, api_key=api_key, base_url=base_url)
wrapper, wrapped = prepare_openai(is_async=True, api_key=api_key, base_url=base_url)
if wrapped:
set_span_purpose(kwargs)

retries = 0
sleep_time = 0.1
Expand Down

0 comments on commit 2c47406

Please sign in to comment.