Skip to content

Commit

Permalink
rename AutoEvalClient to LLMClient and add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ibolmo committed Dec 12, 2024
1 parent 24522e6 commit a28091f
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 50 deletions.
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,37 @@ print(f"Factuality score: {result.score}")
print(f"Factuality metadata: {result.metadata['rationale']}")
```

#### Custom Client

If you need to use a custom OpenAI client, you can initialize the library with a custom client.

```python
import openai
from autoevals import init
from autoevals.oai import LLMClient

openai_client = openai.OpenAI(base_url="https://api.openai.com/v1/")

class CustomClient(LLMClient):
openai=openai_client # you can also pass in the openai module and we will instantiate the client for you
embed = openai.embeddings.create
moderation = openai.moderations.create
RateLimitError = openai.RateLimitError

def complete(self, **kwargs):
# make adjustments as needed
return self.openai.chat.completions.create(**kwargs)

# Autoevals will now use your custom client
client = init(client=CustomClient)
```

If you only need a custom client for a specific evaluator, you can pass the client directly to that evaluator.

```python
evaluator = Factuality(client=CustomClient)
```

### Node.js

```javascript
Expand Down
17 changes: 9 additions & 8 deletions py/autoevals/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from autoevals.partial import ScorerWithPartial

from .oai import AutoEvalClient, arun_cached_request, run_cached_request
from .oai import LLMClient, arun_cached_request, run_cached_request

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
self,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
client: Optional[LLMClient] = None,
):
self.extra_args = {}
if api_key:
Expand All @@ -95,7 +95,7 @@ def __init__(
temperature=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
client: Optional[LLMClient] = None,
):
super().__init__(
api_key=api_key,
Expand All @@ -119,7 +119,7 @@ def __init__(
engine=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
client: Optional[LLMClient] = None,
):
super().__init__(
client=client,
Expand Down Expand Up @@ -240,7 +240,7 @@ def __init__(
engine=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
client: Optional[LLMClient] = None,
**extra_render_args,
):
choice_strings = list(choice_scores.keys())
Expand Down Expand Up @@ -269,11 +269,11 @@ def __init__(
)

@classmethod
def from_spec(cls, name: str, spec: ModelGradedSpec, client: Optional[AutoEvalClient] = None, **kwargs):
def from_spec(cls, name: str, spec: ModelGradedSpec, client: Optional[LLMClient] = None, **kwargs):
return cls(name, spec.prompt, spec.choice_scores, client=client, **kwargs)

@classmethod
def from_spec_file(cls, name: str, path: str, client: Optional[AutoEvalClient] = None, **kwargs):
def from_spec_file(cls, name: str, path: str, client: Optional[LLMClient] = None, **kwargs):
if cls._SPEC_FILE_CONTENTS is None:
with open(path) as f:
cls._SPEC_FILE_CONTENTS = f.read()
Expand All @@ -291,7 +291,7 @@ def __new__(
temperature=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
client: Optional[LLMClient] = None,
):
kwargs = {}
if model is not None:
Expand Down Expand Up @@ -386,3 +386,4 @@ class Translation(SpecFileClassifier):
as an expert (`expected`) value."""

pass
5 changes: 3 additions & 2 deletions py/autoevals/moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from autoevals.llm import OpenAIScorer

from .oai import AutoEvalClient, arun_cached_request, run_cached_request
from .oai import LLMClient, arun_cached_request, run_cached_request

REQUEST_TYPE = "moderation"

Expand All @@ -22,7 +22,7 @@ def __init__(
threshold=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
client: Optional[LLMClient] = None,
):
"""
Create a new Moderation scorer.
Expand Down Expand Up @@ -72,3 +72,4 @@ def compute_score(moderation_result, threshold):


__all__ = ["Moderation"]
81 changes: 69 additions & 12 deletions py/autoevals/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,70 @@


@dataclass
class AutoEvalClient:
# TODO: add docs
# TODO: how to type if we don't depend on openai
class LLMClient:
"""A client wrapper for LLM operations that supports both OpenAI SDK v0 and v1.
This class provides a consistent interface for common LLM operations regardless of the
underlying OpenAI SDK version. It's designed to be extensible for custom implementations.
Attributes:
openai: The OpenAI module or client instance (either v0 or v1 SDK).
complete: Completion function that creates chat completions.
- For v0: openai.ChatCompletion.create or acreate
- For v1: openai.chat.completions.create
embed: Embedding function that creates embeddings.
- For v0: openai.Embedding.create or acreate
- For v1: openai.embeddings.create
moderation: Moderation function that creates content moderations.
- For v0: openai.Moderations.create or acreate
- For v1: openai.moderations.create
RateLimitError: The rate limit exception class for the SDK version.
- For v0: openai.error.RateLimitError
- For v1: openai.RateLimitError
Note:
If using async OpenAI methods, you must use the async methods in Autoevals.
Example:
```python
# Using with OpenAI v1
import openai
client = LLMClient(
openai=openai,
complete=openai.chat.completions.create,
embed=openai.embeddings.create,
moderation=openai.moderations.create,
RateLimitError=openai.RateLimitError
)
# Extending for custom implementation
@dataclass
class CustomLLMClient(LLMClient):
def complete(self, **kwargs):
# make adjustments as needed
return openai.chat.completions.create(**kwargs)
```
Note:
This class is typically instantiated via the `prepare_openai()` function, which handles
the SDK version detection and proper function assignment automatically.
"""

openai: Any
complete: Any
embed: Any
moderation: Any
RateLimitError: Exception


_client_var = ContextVar[Optional[AutoEvalClient]]("client")
_client_var = ContextVar[Optional[LLMClient]]("client")


def init(*, client: Optional[AutoEvalClient] = None):
def init(*, client: Optional[LLMClient] = None):
_client_var.set(client)


def prepare_openai(client: Optional[AutoEvalClient] = None, is_async=False, api_key=None, base_url=None):
def prepare_openai(client: Optional[LLMClient] = None, is_async=False, api_key=None, base_url=None):
"""Prepares and configures an OpenAI client for use with AutoEval, if client is not provided.
This function handles both v0 and v1 of the OpenAI SDK, configuring the client
Expand All @@ -37,7 +83,7 @@ def prepare_openai(client: Optional[AutoEvalClient] = None, is_async=False, api_
We will also attempt to enable Braintrust tracing export, if you've configured tracing.
Args:
client (Optional[AutoEvalClient], optional): Existing AutoEvalClient instance.
client (Optional[LLMClient], optional): Existing LLMClient instance.
If provided, this client will be used instead of creating a new one.
is_async (bool, optional): Whether to create a client with async operations. Defaults to False.
Expand All @@ -54,8 +100,8 @@ def prepare_openai(client: Optional[AutoEvalClient] = None, is_async=False, api_
Deprecated: Use the `client` argument and set the `openai`.
Returns:
Tuple[AutoEvalClient, bool]: A tuple containing:
- The configured AutoEvalClient instance, or the client you've provided
Tuple[LLMClient, bool]: A tuple containing:
- The configured LLMClient instance, or the client you've provided
- A boolean indicating whether the client was wrapped with Braintrust tracing
Raises:
Expand Down Expand Up @@ -124,7 +170,7 @@ def prepare_openai(client: Optional[AutoEvalClient] = None, is_async=False, api_
complete_fn = None
rate_limit_error = None

Client = AutoEvalClient
Client = LLMClient

if is_v1:
client = Client(
Expand Down Expand Up @@ -170,7 +216,7 @@ def set_span_purpose(kwargs):


def run_cached_request(
*, client: Optional[AutoEvalClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs
*, client: Optional[LLMClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs
):
wrapper, wrapped = prepare_openai(client=client, is_async=False, api_key=api_key, base_url=base_url)
if wrapped:
Expand All @@ -191,7 +237,7 @@ def run_cached_request(


async def arun_cached_request(
*, client: Optional[AutoEvalClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs
*, client: Optional[LLMClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs
):
wrapper, wrapped = prepare_openai(client=client, is_async=True, api_key=api_key, base_url=base_url)
if wrapped:
Expand All @@ -210,3 +256,14 @@ async def arun_cached_request(
retries += 1

return resp
Loading

0 comments on commit a28091f

Please sign in to comment.