Skip to content

Commit

Permalink
fix: support log entries created for secured API keys (#99)
Browse files: browse the repository at this point in the history
  • Loading branch information
adubovik authored Jan 15, 2025
1 parent 7526bca commit 68ab942
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 68 deletions.
88 changes: 51 additions & 37 deletions aidial_analytics_realtime/analytics.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,8 +82,8 @@ async def make_point(
user_hash: str,
user_title: str,
timestamp: datetime,
request: dict,
response: dict,
request: dict | None,
response: dict | None,
request_type: RequestType,
usage: dict | None,
topic_model: TopicModel,
Expand All @@ -95,36 +95,40 @@ async def make_point(
topic = None
response_content = ""
request_content = ""
match request_type:
case RequestType.CHAT_COMPLETION:
response_contents = get_chat_completion_response_contents(
logger, response
)
request_contents = get_chat_completion_request_contents(
logger, request
)

request_content = "\n".join(request_contents)
response_content = "\n".join(response_contents)
if response is not None and request is not None:
match request_type:
case RequestType.CHAT_COMPLETION:
response_contents = get_chat_completion_response_contents(
logger, response
)
request_contents = get_chat_completion_request_contents(
logger, request
)

request_content = "\n".join(request_contents)
response_content = "\n".join(response_contents)

if chat_id:
topic = to_string(
await topic_model.get_topic_by_text(
"\n\n".join(request_contents + response_contents)
if chat_id:
topic = to_string(
await topic_model.get_topic_by_text(
"\n\n".join(request_contents + response_contents)
)
)
case RequestType.EMBEDDING:
request_contents = get_embeddings_request_contents(
logger, request
)
case RequestType.EMBEDDING:
request_contents = get_embeddings_request_contents(logger, request)

request_content = "\n".join(request_contents)
if chat_id:
topic = to_string(
await topic_model.get_topic_by_text(
"\n\n".join(request_contents)
request_content = "\n".join(request_contents)
if chat_id:
topic = to_string(
await topic_model.get_topic_by_text(
"\n\n".join(request_contents)
)
)
)
case _:
assert_never(request_type)
case _:
assert_never(request_type)

price = Decimal(0)
deployment_price = Decimal(0)
Expand Down Expand Up @@ -162,7 +166,7 @@ async def make_point(
"language",
(
"undefined"
if not chat_id
if not chat_id or request is None or response is None
else await detect_lang(logger, request, response, request_type)
),
)
Expand All @@ -174,6 +178,7 @@ async def make_point(
(
response["id"]
if request_type == RequestType.CHAT_COMPLETION
and response is not None
else uuid4()
),
)
Expand All @@ -183,12 +188,16 @@ async def make_point(
.field(
"number_request_messages",
(
len(request["messages"])
if request_type == RequestType.CHAT_COMPLETION
0
if request is None
else (
1
if isinstance(request["input"], str)
else len(request["input"])
len(request["messages"])
if request_type == RequestType.CHAT_COMPLETION
else (
1
if isinstance(request["input"], str)
else len(request["input"])
)
)
),
)
Expand Down Expand Up @@ -239,7 +248,10 @@ def make_rate_point(
return point


async def parse_usage_per_model(response: dict):
async def parse_usage_per_model(response: dict | None):
if response is None:
return []

statistics = response.get("statistics")
if statistics is None:
return []
Expand All @@ -265,8 +277,8 @@ async def on_message(
user_hash: str,
user_title: str,
timestamp: datetime,
request: dict,
response: dict,
request: dict | None,
response: dict | None,
type: RequestType,
topic_model: TopicModel,
rates_calculator: RatesCalculator,
Expand All @@ -275,9 +287,11 @@ async def on_message(
trace: dict | None,
execution_path: list | None,
):
logger.info(f"Chat completion response length {len(response)}")
logger.info(f"Chat completion response length {len(response or [])}")

usage_per_model = await parse_usage_per_model(response)
response_usage = None if response is None else response.get("usage")

if token_usage is not None:
point = await make_point(
logger,
Expand Down Expand Up @@ -314,7 +328,7 @@ async def on_message(
request,
response,
type,
response.get("usage"),
response_usage,
topic_model,
rates_calculator,
parent_deployment,
Expand Down
69 changes: 42 additions & 27 deletions aidial_analytics_realtime/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,42 +96,47 @@ async def on_chat_completion_message(
if response["status"] != "200":
return

request_body = json.loads(request["body"])
stream = request_body.get("stream", False)
model = request_body.get("model", deployment)

response_body = None
if stream:
body = response["body"]
chunks = body.split("\n\ndata: ")
request_body = None
model: str | None = None

chunks = [chunk.strip() for chunk in chunks]
if (request_body_str := request.get("body")) is not None:

chunks[0] = chunks[0][chunks[0].find("data: ") + 6 :]
if chunks[-1] == "[DONE]":
chunks.pop(len(chunks) - 1)
request_body = json.loads(request_body_str)
stream = request_body.get("stream", False)
model = request_body.get("model", deployment)

response_body = json.loads(chunks[-1])
for chunk in chunks[0 : len(chunks) - 1]:
chunk = json.loads(chunk)
if stream:
body = response["body"]
chunks = body.split("\n\ndata: ")

response_body["choices"] = merge(
response_body["choices"], chunk["choices"]
)
chunks = [chunk.strip() for chunk in chunks]

for i in range(len(response_body["choices"])):
response_body["choices"][i]["message"] = response_body["choices"][
i
]["delta"]
del response_body["choices"][i]["delta"]
else:
response_body = json.loads(response["body"])
chunks[0] = chunks[0][chunks[0].find("data: ") + 6 :]
if chunks[-1] == "[DONE]":
chunks.pop(len(chunks) - 1)

response_body = json.loads(chunks[-1])
for chunk in chunks[0 : len(chunks) - 1]:
chunk = json.loads(chunk)

response_body["choices"] = merge(
response_body["choices"], chunk["choices"]
)

for i in range(len(response_body["choices"])):
response_body["choices"][i]["message"] = response_body[
"choices"
][i]["delta"]
del response_body["choices"][i]["delta"]
else:
response_body = json.loads(response["body"])

await on_message(
logger,
influx_writer,
deployment,
model,
model or deployment,
project_id,
chat_id,
upstream_url,
Expand Down Expand Up @@ -171,6 +176,16 @@ async def on_embedding_message(
if response["status"] != "200":
return

request_body_str = request.get("body")
response_body_str = response.get("body")

request_body = (
None if request_body_str is None else json.loads(request_body_str)
)
response_body = (
None if response_body_str is None else json.loads(response_body_str)
)

await on_message(
logger,
influx_writer,
Expand All @@ -182,8 +197,8 @@ async def on_embedding_message(
user_hash,
user_title,
timestamp,
json.loads(request["body"]),
json.loads(response["body"]),
request_body,
response_body,
RequestType.EMBEDDING,
topic_model,
rates_calculator,
Expand Down
Loading

0 comments on commit 68ab942

Please sign in to comment.