-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
code for the new article: https://medium.com/@enoch3712/optimizing-ll…
- Loading branch information
Showing
13 changed files
with
235 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
class Cascade: | ||
def __init__(self, models, evaluator): | ||
# List of models (ModelDecorator instances) in the order they should be tried | ||
self.models = models | ||
# The evaluator instance (Eval class) to use for evaluating model output | ||
self.evaluator = evaluator | ||
|
||
def process(self, input_data): | ||
for model in self.models: | ||
# Assume model.generate(input_data) returns {'content': [{'logprob': value}]} | ||
result = model.generate(input_data) | ||
if self.evaluator.evaluate(result): | ||
# If the evaluator passes the output, return the response | ||
return result.choices[0].message.content | ||
# If none of the models produce a satisfactory response, return an indication of failure | ||
return "No satisfactory response found." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from typing import List, Dict | ||
from pydantic import BaseModel | ||
|
||
class Classification(BaseModel): | ||
name: str | ||
description: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class Eval(ABC): | ||
def __init__(self, threshold: float): | ||
self.threshold = threshold | ||
|
||
@abstractmethod | ||
def evaluate(self, result): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import math | ||
from Classification.Eval import Eval | ||
import numpy as np | ||
|
||
class LogProbEval(Eval): | ||
def __init__(self, threshold): | ||
self.threshold = threshold | ||
|
||
def evaluate(self, result): | ||
# Ensure there are log probabilities to process | ||
if not result.choices[0].logprobs.content: | ||
return 0.0 | ||
|
||
# Calculate the total log probability | ||
total_logprob = sum([math.exp(c.logprob) for c in result.choices[0].logprobs.content]) | ||
|
||
# Calculate the average log probability | ||
avg_logprob = total_logprob / len(result.choices[0].logprobs.content) | ||
|
||
return avg_logprob > self.threshold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from Classification.ModelDecorator import ModelDecorator | ||
from openai import OpenAI | ||
from config import API_KEY_OPENAI | ||
|
||
client = OpenAI(api_key=API_KEY_OPENAI) | ||
|
||
class Model35(ModelDecorator): | ||
def __init__(self): | ||
super().__init__(self._model_function) | ||
|
||
def _model_function(self, input_data): | ||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are an API classification tools. You receive some text, and return a JSON with the result: \n\ninput:\n##Content\n.....\n##Classifications\nOne: this is a description\nTwo: this is another description\n\noutput:\n{\"result\": \"One\"}" | ||
}, | ||
{ | ||
"role": "user", | ||
"content": input_data | ||
} | ||
], | ||
temperature=1, | ||
max_tokens=256, | ||
top_p=1, | ||
frequency_penalty=0, | ||
presence_penalty=0, | ||
logprobs=True, | ||
top_logprobs=1 | ||
) | ||
|
||
return response | ||
|
||
def generate(self, input_data): | ||
return self.model_function(input_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from openai import OpenAI | ||
from Classification.ModelDecorator import ModelDecorator | ||
from config import API_KEY_OPENAI | ||
|
||
client = OpenAI(api_key=API_KEY_OPENAI) | ||
|
||
class Model4(ModelDecorator): | ||
def __init__(self): | ||
super().__init__(self._model_function) | ||
|
||
def _model_function(self, input_data): | ||
response = client.chat.completions.create( | ||
model="gpt-4-turbo-preview", | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are an API classification tools. You receive some text, and return a JSON with the result: \n\ninput:\n##Content\n.....\n##Classifications\nOne: this is a description\nTwo: this is another description\n\noutput:\n{\"result\": \"One\"}" | ||
}, | ||
{ | ||
"role": "user", | ||
"content": input_data | ||
} | ||
], | ||
temperature=1, | ||
max_tokens=256, | ||
top_p=1, | ||
frequency_penalty=0, | ||
presence_penalty=0, | ||
logprobs=True, | ||
top_logprobs=1 | ||
) | ||
|
||
return response | ||
|
||
def generate(self, input_data): | ||
return self.model_function(input_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class ModelDecorator(ABC): | ||
def __init__(self, model_function): | ||
# model_function is a lambda or function that simulates interacting with a specific LLM's API. | ||
self.model_function = model_function | ||
|
||
@abstractmethod | ||
def generate(self, input_data): | ||
# This method simulates generating a response from a language model | ||
# In a real scenario, this would interact with the model's API and return both the response and logprobs | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
API_KEY = 'XXXX' | ||
API_KEY_ANTROPIC = 'XXXX' | ||
API_KEY_ANTROPIC = 'XXXX' | ||
API_KEY_OPENAI = 'XXXX' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import json | ||
|
||
def remove_json_format(json_string: str) -> str: | ||
replace = json_string.replace("```json", "") | ||
replace = replace.replace("```", "") | ||
return replace.strip() | ||
|
||
def remove_last_element(json_string: str) -> str: | ||
try: | ||
json.loads(json_string) | ||
return json_string | ||
except json.JSONDecodeError: | ||
pass | ||
|
||
last_index = json_string.rfind("},") | ||
|
||
if last_index == -1: | ||
return json_string | ||
|
||
trimmed_string = json_string[:last_index + 1] | ||
trimmed_string += "," | ||
return trimmed_string |