Skip to content

Commit

Permalink
Thread api key through
Browse files Browse the repository at this point in the history
  • Loading branch information
ankrgyl committed Nov 3, 2023
1 parent c47d88f commit a74b059
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion py/autoevals/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
max_tokens=None,
temperature=None,
engine=None,
api_key=None,
):
self.name = name
self.model = model
Expand All @@ -90,6 +91,8 @@ def __init__(
self.extra_args = {"temperature": temperature or 0}
if max_tokens:
self.extra_args["max_tokens"] = max(max_tokens, 5)
if api_key:
self.extra_args["api_key"] = api_key

self.render_args = {}
if render_args:
Expand Down Expand Up @@ -191,6 +194,7 @@ def __init__(
max_tokens=512,
temperature=0,
engine=None,
api_key=None,
):
choice_strings = list(choice_scores.keys())

Expand All @@ -211,6 +215,7 @@ def __init__(
max_tokens=max_tokens,
temperature=temperature,
engine=engine,
api_key=api_key,
render_args={"__choices": choice_strings},
)

Expand All @@ -226,7 +231,7 @@ def from_spec_file(cls, name: str, path: str, **kwargs):


class SpecFileClassifier(LLMClassifier):
def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None):
def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None):
kwargs = {}
if model is not None:
kwargs["model"] = model
Expand All @@ -238,6 +243,8 @@ def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, tempera
kwargs["max_tokens"] = max_tokens
if temperature is not None:
kwargs["temperature"] = temperature
if api_key is not None:
kwargs["api_key"] = api_key

# convert FooBar to foo_bar
template_name = re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__).lower()
Expand Down

0 comments on commit a74b059

Please sign in to comment.