Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
enoch3712 committed Apr 9, 2024
1 parent 29d2f78 commit 6a9a1ce
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 33 deletions.
3 changes: 0 additions & 3 deletions Antropic/AnthropicsApiRequest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from Antropic.Message import Message


from typing import List


class AnthropicsApiRequest:
def __init__(self, model: str, max_tokens: int, messages: List['Message'], system: str):
self.model = model
Expand Down
33 changes: 5 additions & 28 deletions Antropic/AnthropicsApiService.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from Antropic.AnthropicsApiRequest import AnthropicsApiRequest
from Antropic.Message import Message
import time


from utils import remove_json_format, remove_last_element
class AnthropicsApiService:
def __init__(self, api_key: str):
self.client = anthropic.Anthropic(api_key=api_key)
Expand Down Expand Up @@ -112,37 +112,14 @@ def send_image_message(self, initial_request: AnthropicsApiRequest, base64_image

final_response = message

content = self.remove_json_format(final_response.content[0].text)
content = remove_json_format(final_response.content[0].text)

if final_response.stop_reason != "end_turn":
content = self.removeLastElement(content)
content = remove_last_element(content)

sb.append(content)

if final_response.stop_reason == "end_turn":
break

return "".join(sb)

@staticmethod
def remove_json_format(json_string: str) -> str:
replace = json_string.replace("```json", "")
replace = replace.replace("```", "")
return replace.strip()

@staticmethod
def removeLastElement(json_string: str) -> str:
try:
json.loads(json_string)
return json_string
except json.JSONDecodeError:
pass

last_index = json_string.rfind("},")

if last_index == -1:
return json_string

trimmed_string = json_string[:last_index + 1]
trimmed_string += ","
return trimmed_string
return "".join(sb)
16 changes: 16 additions & 0 deletions Classification/Cascade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class Cascade:
def __init__(self, models, evaluator):
# List of models (ModelDecorator instances) in the order they should be tried
self.models = models
# The evaluator instance (Eval class) to use for evaluating model output
self.evaluator = evaluator

def process(self, input_data):
for model in self.models:
# Assume model.generate(input_data) returns {'content': [{'logprob': value}]}
result = model.generate(input_data)
if self.evaluator.evaluate(result):
# If the evaluator passes the output, return the response
return result.choices[0].message.content
# If none of the models produce a satisfactory response, return an indication of failure
return "No satisfactory response found."
6 changes: 6 additions & 0 deletions Classification/Classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import List, Dict
from pydantic import BaseModel

class Classification(BaseModel):
name: str
description: str
9 changes: 9 additions & 0 deletions Classification/Eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

class Eval(ABC):
def __init__(self, threshold: float):
self.threshold = threshold

@abstractmethod
def evaluate(self, result):
pass
20 changes: 20 additions & 0 deletions Classification/LogProbEval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import math
from Classification.Eval import Eval
import numpy as np

class LogProbEval(Eval):
def __init__(self, threshold):
self.threshold = threshold

def evaluate(self, result):
# Ensure there are log probabilities to process
if not result.choices[0].logprobs.content:
return 0.0

# Calculate the total log probability
total_logprob = sum([math.exp(c.logprob) for c in result.choices[0].logprobs.content])

# Calculate the average log probability
avg_logprob = total_logprob / len(result.choices[0].logprobs.content)

return avg_logprob > self.threshold
36 changes: 36 additions & 0 deletions Classification/Model35.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from Classification.ModelDecorator import ModelDecorator
from openai import OpenAI
from config import API_KEY_OPENAI

client = OpenAI(api_key=API_KEY_OPENAI)

class Model35(ModelDecorator):
def __init__(self):
super().__init__(self._model_function)

def _model_function(self, input_data):
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You are an API classification tools. You receive some text, and return a JSON with the result: \n\ninput:\n##Content\n.....\n##Classifications\nOne: this is a description\nTwo: this is another description\n\noutput:\n{\"result\": \"One\"}"
},
{
"role": "user",
"content": input_data
}
],
temperature=1,
max_tokens=256,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
logprobs=True,
top_logprobs=1
)

return response

def generate(self, input_data):
return self.model_function(input_data)
36 changes: 36 additions & 0 deletions Classification/Model4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from openai import OpenAI
from Classification.ModelDecorator import ModelDecorator
from config import API_KEY_OPENAI

client = OpenAI(api_key=API_KEY_OPENAI)

class Model4(ModelDecorator):
def __init__(self):
super().__init__(self._model_function)

def _model_function(self, input_data):
response = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{
"role": "system",
"content": "You are an API classification tools. You receive some text, and return a JSON with the result: \n\ninput:\n##Content\n.....\n##Classifications\nOne: this is a description\nTwo: this is another description\n\noutput:\n{\"result\": \"One\"}"
},
{
"role": "user",
"content": input_data
}
],
temperature=1,
max_tokens=256,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
logprobs=True,
top_logprobs=1
)

return response

def generate(self, input_data):
return self.model_function(input_data)
12 changes: 12 additions & 0 deletions Classification/ModelDecorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod

class ModelDecorator(ABC):
def __init__(self, model_function):
# model_function is a lambda or function that simulates interacting with a specific LLM's API.
self.model_function = model_function

@abstractmethod
def generate(self, input_data):
# This method simulates generating a response from a language model
# In a real scenario, this would interact with the model's API and return both the response and logprobs
pass
3 changes: 2 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
API_KEY = 'XXXX'
API_KEY_ANTROPIC = 'XXXX'
API_KEY_ANTROPIC = 'XXXX'
API_KEY_OPENAI = 'XXXX'
72 changes: 71 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import shutil
import tempfile
import time
from typing import List
import urllib.request
from http.client import HTTPException
from io import BytesIO
Expand All @@ -28,9 +29,16 @@

from Antropic.AnthropicsApiRequest import AnthropicsApiRequest
from Antropic.AnthropicsApiService import AnthropicsApiService
from Classification.Cascade import Cascade
from Classification.Classification import Classification
from Classification.Model35 import Model35
from Classification.Model4 import Model4
from Classification.LogProbEval import LogProbEval
from Classification.ModelDecorator import ModelDecorator
from CustomException import CustomException
from Payload import Message, Payload
from config import API_KEY, API_KEY_ANTROPIC
from utils import remove_json_format

# local path to tesseract
pytesseract.pytesseract.tesseract_cmd = 'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
Expand Down Expand Up @@ -102,7 +110,6 @@ async def extract_text(file: UploadFile = File(...)):

# Return the base64 images in a JSON response
return JSONResponse(content={"images": base64_images})


@app.post("/extract_text_from_url")
async def extract_text_from_url(url: str):
Expand Down Expand Up @@ -272,6 +279,69 @@ async def process_excel_file(file: UploadFile = File(...), extraction_contract:
# Return the response
return {"Content": response}


@app.post("/classify")
async def classify(file: UploadFile = File(...), classifications: str = Form(...)):
classifications_list = verify_classifications(classifications)

# Process the file to get the input data
extracted_text = process_file(file)

# Wrap the models with ModelDecorator
model1 = Model35()
model2 = Model4()

# Create an evaluator with a threshold of 0.15
evaluator = LogProbEval(0.99)

# Create a cascade with the models and evaluator
cascade = Cascade([model1, model2], evaluator)

# Test the cascade with some input data
input_data = f"##Content\n{extracted_text[0]}\n##Classifications\n" + "\n".join([f"{c['name']}: {c['description']}" for c in classifications_list]) + "\n\n##JSON Output\n"

result = cascade.process(input_data)

content = remove_json_format(result)

return json.loads(content)

def process_file(file):
# Create a temporary file and save the uploaded file to it
temp_file = tempfile.NamedTemporaryFile(delete=False)
shutil.copyfileobj(file.file, temp_file)
file_path = temp_file.name

# Check the file type
file_type = imghdr.what(file_path)
if file_type is None:
# If the file is not an image, assume it's a PDF and extract the text from it
images = convert_pdf_to_images(file_path)
extracted_text = extract_text_with_pytesseract(images)
input_data = "\n new page --- \n".join(extracted_text)
else:
# If the file is an image or text, read it directly
with open(file_path, 'r') as f:
input_data = f.read()

return input_data

def verify_classifications(classifications: str) -> List[dict]:
try:
# Attempt to deserialize the classifications string into a list of dictionaries
classifications_list = json.loads(classifications)
except json.JSONDecodeError:
# If deserialization fails, return a 400 error
raise HTTPException(status_code=400, detail="Invalid classifications format")

# Check that each classification has a 'name' and 'description'
for classification in classifications_list:
if 'name' not in classification or 'description' not in classification:
# If a classification doesn't have a 'name' or 'description', return a 401 error
raise HTTPException(status_code=400, detail="Invalid classification format")

return classifications_list

def send_request_to_mistral(content: str) -> str:
url = "https://api.mistral.ai/v1/chat/completions"
headers = {
Expand Down
Binary file modified requirements.txt
Binary file not shown.
22 changes: 22 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import json

def remove_json_format(json_string: str) -> str:
replace = json_string.replace("```json", "")
replace = replace.replace("```", "")
return replace.strip()

def remove_last_element(json_string: str) -> str:
try:
json.loads(json_string)
return json_string
except json.JSONDecodeError:
pass

last_index = json_string.rfind("},")

if last_index == -1:
return json_string

trimmed_string = json_string[:last_index + 1]
trimmed_string += ","
return trimmed_string

0 comments on commit 6a9a1ce

Please sign in to comment.