Commit

Add provider as response header - BRA-1882 (#126)
Adds the provider endpoint used to complete an LLM call as a response header.
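As context for the diff below, here is a minimal sketch of how a caller might observe the new header. The proxy URL, request path, and request body are illustrative assumptions, not part of this commit; only the header name comes from the change itself.

    // Hypothetical client: reads which endpoint served the completion from
    // the x-bt-used-endpoint response header introduced by this commit.
    const USED_ENDPOINT_HEADER = "x-bt-used-endpoint";

    async function callProxy(): Promise<void> {
      // URL and body below are assumptions for illustration only.
      const response = await fetch(
        "https://proxy.example.invalid/v1/chat/completions",
        {
          method: "POST",
          headers: { "Content-Type": "application/json" },
          body: JSON.stringify({
            model: "gpt-4o-mini",
            messages: [{ role: "user", content: "Hello" }],
          }),
        },
      );
      // The header is only set when the model loop resolved a named endpoint.
      const usedEndpoint = response.headers.get(USED_ENDPOINT_HEADER);
      console.log(`Served by endpoint: ${usedEndpoint ?? "unknown"}`);
    }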

tara-nagar authored Jan 7, 2025
1 parent 0546ef3 commit d3977ad
Showing 1 changed file with 71 additions and 58 deletions.
129 changes: 71 additions & 58 deletions packages/proxy/src/proxy.ts
@@ -83,6 +83,8 @@ export const FORMAT_HEADER = "x-bt-stream-fmt";
 export const LEGACY_CACHED_HEADER = "x-cached";
 export const CACHED_HEADER = "x-bt-cached";
 
+export const USED_ENDPOINT_HEADER = "x-bt-used-endpoint";
+
 const CACHE_MODES = ["auto", "always", "never"] as const;
 
 // Options to control how the cache key is generated.
@@ -361,61 +363,63 @@ export async function proxyV1({
       );
     }
 
-    const { response: proxyResponse, stream: proxyStream } =
-      await fetchModelLoop(
-        meter,
-        method,
-        url,
-        headers,
-        bodyData,
-        setOverriddenHeader,
-        async (model) => {
-          // First, try to use temp credentials, because then we'll get access
-          // to the model.
-          let cachedAuthToken: string | undefined;
-          if (
-            useCredentialsCacheMode !== "never" &&
-            isTempCredential(authToken)
-          ) {
-            const { credentialCacheValue, jwtPayload } =
-              await verifyTempCredentials({
-                jwt: authToken,
-                cacheGet,
-              });
-            // Unwrap the API key here to avoid a duplicate call to
-            // `verifyTempCredentials` inside `getApiSecrets`. That call will
-            // use Redis which is not available in Cloudflare.
-            cachedAuthToken = credentialCacheValue.authToken;
-            if (jwtPayload.bt.logging) {
-              console.warn(
-                `Logging was requested, but not supported on ${method} ${url}`,
-              );
-            }
-            if (jwtPayload.bt.model && jwtPayload.bt.model !== model) {
-              console.warn(
-                `Temp credential allows model "${jwtPayload.bt.model}", but "${model}" was requested`,
-              );
-              return [];
-            }
-          }
-
-          const secrets = await getApiSecrets(
-            useCredentialsCacheMode !== "never",
-            cachedAuthToken || authToken,
-            model,
-            orgName,
-          );
-          if (endpointName) {
-            return secrets.filter((s) => s.name === endpointName);
-          } else {
-            return secrets;
-          }
-        },
-        spanLogger,
-        (st) => {
-          spanType = st;
-        },
-      );
+    const {
+      modelResponse: { response: proxyResponse, stream: proxyStream },
+      secretName,
+    } = await fetchModelLoop(
+      meter,
+      method,
+      url,
+      headers,
+      bodyData,
+      setOverriddenHeader,
+      async (model) => {
+        // First, try to use temp credentials, because then we'll get access
+        // to the model.
+        let cachedAuthToken: string | undefined;
+        if (
+          useCredentialsCacheMode !== "never" &&
+          isTempCredential(authToken)
+        ) {
+          const { credentialCacheValue, jwtPayload } =
+            await verifyTempCredentials({
+              jwt: authToken,
+              cacheGet,
+            });
+          // Unwrap the API key here to avoid a duplicate call to
+          // `verifyTempCredentials` inside `getApiSecrets`. That call will
+          // use Redis which is not available in Cloudflare.
+          cachedAuthToken = credentialCacheValue.authToken;
+          if (jwtPayload.bt.logging) {
+            console.warn(
+              `Logging was requested, but not supported on ${method} ${url}`,
+            );
+          }
+          if (jwtPayload.bt.model && jwtPayload.bt.model !== model) {
+            console.warn(
+              `Temp credential allows model "${jwtPayload.bt.model}", but "${model}" was requested`,
+            );
+            return [];
+          }
+        }
+
+        const secrets = await getApiSecrets(
+          useCredentialsCacheMode !== "never",
+          cachedAuthToken || authToken,
+          model,
+          orgName,
+        );
+        if (endpointName) {
+          return secrets.filter((s) => s.name === endpointName);
+        } else {
+          return secrets;
+        }
+      },
+      spanLogger,
+      (st) => {
+        spanType = st;
+      },
+    );
     stream = proxyStream;
 
     if (!proxyResponse.ok) {
Expand Down Expand Up @@ -449,6 +453,10 @@ export async function proxyV1({
}
proxyResponseHeaders[name] = value;
});
if (secretName) {
setHeader(USED_ENDPOINT_HEADER, secretName);
proxyResponseHeaders[USED_ENDPOINT_HEADER] = secretName;
}

for (const [name, value] of Object.entries(proxyResponseHeaders)) {
setHeader(name, value);
@@ -718,7 +726,7 @@ async function fetchModelLoop(
   getApiSecrets: (model: string | null) => Promise<APISecret[]>,
   spanLogger: SpanLogger | undefined,
   setSpanType: (spanType: SpanType) => void,
-): Promise<ModelResponse> {
+): Promise<{ modelResponse: ModelResponse; secretName?: string | null }> {
   const requestId = ++loopIndex;
 
   const endpointCalls = meter.createCounter("endpoint_calls");
@@ -750,7 +758,8 @@
   // TODO: Make this smarter. For now, just pick a random one.
   const secrets = await getApiSecrets(model);
   const initialIdx = getRandomInt(secrets.length);
-  let proxyResponse = null;
+  let proxyResponse: ModelResponse | null = null;
+  let secretName: string | null | undefined = null;
   let lastException = null;
   let loggableInfo: Record<string, any> = {};

@@ -821,6 +830,7 @@
         bodyData,
         setHeader,
       );
+      secretName = secret.name;
      if (
        proxyResponse.response.ok ||
        (proxyResponse.response.status >= 400 &&
@@ -935,8 +945,11 @@
     stream = stream.pipeThrough(timingStream);
   }
   return {
-    stream,
-    response: proxyResponse.response,
+    modelResponse: {
+      stream,
+      response: proxyResponse.response,
+    },
+    secretName,
   };
 }

@@ -948,7 +961,7 @@ async function fetchModel(
   secret: APISecret,
   bodyData: null | any,
   setHeader: (name: string, value: string) => void,
-) {
+): Promise<ModelResponse> {
   switch (format) {
     case "openai":
       if (secret.type === "bedrock") {
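Taken together, the hunks change fetchModelLoop's return type so the name of the chosen secret travels alongside the model response, letting proxyV1 surface it as a header. A hedged sketch of the resulting contract follows; the ModelResponse fields are inferred from the hunks above, not verified against the full file.

    // Sketch of the updated shapes, reconstructed from the diff above.
    interface ModelResponse {
      response: Response; // upstream fetch Response
      stream: ReadableStream<Uint8Array> | null; // proxied body stream (assumed nullable)
    }

    // fetchModelLoop now reports which secret/endpoint served the call so
    // proxyV1 can emit it as the x-bt-used-endpoint response header.
    type FetchModelLoopResult = {
      modelResponse: ModelResponse;
      secretName?: string | null;
    };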
