From 24522e62675ae88cac66bdbc7fe3959cf8f21726 Mon Sep 17 00:00:00 2001
From: Olmo Maldonado <olmo.maldonado@gmail.com>
Date: Thu, 12 Dec 2024 11:28:26 -0800
Subject: [PATCH] fix gap in client= coverage for spec-file classifiers
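
Thread the `client` argument through `LLMClassifier.from_spec`,
`LLMClassifier.from_spec_file`, and `SpecFileClassifier.__new__` so that
spec-file based evaluators (Battle, Factuality, ...) honor an explicitly
provided client instead of it being dropped on the spec-file construction
path. Adds `test_factuality_client`, which drives `Factuality` entirely
through a mocked client and asserts the completion call goes through it.

Rough usage sketch of what this enables (hedged: `my_client` below is an
illustrative caller-provided AutoEvalClient-compatible wrapper, not part
of this change):

    from autoevals import Factuality

    # client= now reaches from_spec_file -> from_spec -> __init__
    evaluator = Factuality(client=my_client)
    result = evaluator.eval(
        input="Add the following numbers: 1, 2, 3",
        output="6",
        expected="6",
    )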

---
 py/autoevals/llm.py      | 20 +++++++++----
 py/autoevals/test_llm.py | 61 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 6 deletions(-)

diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py
index bca51c4..2d26aae 100644
--- a/py/autoevals/llm.py
+++ b/py/autoevals/llm.py
@@ -269,21 +269,29 @@ def __init__(
         )
 
     @classmethod
-    def from_spec(cls, name: str, spec: ModelGradedSpec, **kwargs):
-        return cls(name, spec.prompt, spec.choice_scores, **kwargs)
+    def from_spec(cls, name: str, spec: ModelGradedSpec, client: Optional[AutoEvalClient] = None, **kwargs):
+        return cls(name, spec.prompt, spec.choice_scores, client=client, **kwargs)
 
     @classmethod
-    def from_spec_file(cls, name: str, path: str, **kwargs):
+    def from_spec_file(cls, name: str, path: str, client: Optional[AutoEvalClient] = None, **kwargs):
         if cls._SPEC_FILE_CONTENTS is None:
             with open(path) as f:
                 cls._SPEC_FILE_CONTENTS = f.read()
         spec = yaml.safe_load(cls._SPEC_FILE_CONTENTS)
-        return cls.from_spec(name, ModelGradedSpec(**spec), **kwargs)
+        return cls.from_spec(name, ModelGradedSpec(**spec), client=client, **kwargs)
 
 
 class SpecFileClassifier(LLMClassifier):
     def __new__(
-        cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None, base_url=None
+        cls,
+        model=None,
+        engine=None,
+        use_cot=None,
+        max_tokens=None,
+        temperature=None,
+        api_key=None,
+        base_url=None,
+        client: Optional[AutoEvalClient] = None,
     ):
         kwargs = {}
         if model is not None:
@@ -311,7 +319,7 @@ def __new__(
 
         extra_render_args = cls._partial_args() if hasattr(cls, "_partial_args") else {}
 
-        return LLMClassifier.from_spec_file(cls_name, template_path, **kwargs, **extra_render_args)
+        return LLMClassifier.from_spec_file(cls_name, template_path, client=client, **kwargs, **extra_render_args)
 
 
 class Battle(SpecFileClassifier):
diff --git a/py/autoevals/test_llm.py b/py/autoevals/test_llm.py
index 4b3c89a..e55dfe4 100644
--- a/py/autoevals/test_llm.py
+++ b/py/autoevals/test_llm.py
@@ -168,6 +168,67 @@ def test_factuality():
     assert result.score == 1
 
 
+def test_factuality_client():
+    client = Mock()
+    client.RateLimitError = Exception
+
+    completion = Mock()
+    completion.to_dict.return_value = {
+        "id": "chatcmpl-AdiS4bHWjqSclA5rx7OkuZ6EA9QIp",
+        "choices": [
+            {
+                "finish_reason": "stop",
+                "index": 0,
+                "logprobs": None,
+                "message": {
+                    "content": None,
+                    "refusal": None,
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "id": "call_JKoeGAX2zGPJAmF2muDgjpHp",
+                            "function": {
+                                "arguments": '{"reasons":"1. The question asks to add the numbers 1, 2, and 3.\\n2. The expert answer provides the sum of these numbers as 6.\\n3. The submitted answer also provides the sum as 6.\\n4. Both the expert and submitted answers provide the same numerical result, which is 6.\\n5. Since both answers provide the same factual content, the submitted answer contains all the same details as the expert answer.\\n6. There is no additional information or discrepancy between the two answers.\\n7. Therefore, the submitted answer is neither a subset nor a superset; it is exactly the same as the expert answer in terms of factual content.","choice":"C"}',
+                                "name": "select_choice",
+                            },
+                            "type": "function",
+                        }
+                    ],
+                },
+            }
+        ],
+        "created": 1734029028,
+        "model": "gpt-4o-2024-08-06",
+        "object": "chat.completion",
+        "system_fingerprint": "fp_cc5cf1c6e3",
+        "usage": {
+            "completion_tokens": 149,
+            "prompt_tokens": 404,
+            "total_tokens": 553,
+            "completion_tokens_details": {
+                "accepted_prediction_tokens": 0,
+                "audio_tokens": 0,
+                "reasoning_tokens": 0,
+                "rejected_prediction_tokens": 0,
+            },
+            "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+        },
+    }
+
+    client.complete.return_value = completion
+
+    llm = Factuality(client=cast(AutoEvalClient, client))
+    result = llm.eval(
+        output="6",
+        expected="6",
+        input="Add the following numbers: 1, 2, 3",
+    )
+
+    assert client.complete.call_count == 1
+
+    assert result.score == 1
+
+
 # make sure we deny any leaked calls to OpenAI
 @respx.mock(base_url="https://api.openai.com/v1/")
 def test_init_client():