diff --git a/app.py b/app.py index f311e01..6ea1003 100644 --- a/app.py +++ b/app.py @@ -6,8 +6,10 @@ from base import ROOT_DIR from PIL import Image import os -from utils import main +from utils import main +from base import reader +from routes import predict_v2 async def predict(request): form = await request.form() @@ -29,9 +31,33 @@ async def predict(request): }, ) +async def get_ocr_data(request): + form = await request.form() + filename = form["image"].filename + contents = await form["image"].read() + contents = io.BytesIO(contents) + save_path = str(ROOT_DIR) + "/image/" + filename + Image.open(contents).save(save_path) + + data = reader.readtext(save_path) + data = [{"bounding_box": d[0], "text": d[1], "confidence": d[2]} for d in data] + for d in data: + d["bounding_box"] = [[int(x), int(y)] for x, y in d["bounding_box"]] + os.remove(save_path) + + return JSONResponse( + { + "status": "success", + "data": data, + }, + ) + + routes = [ - Route('/api/predict', predict, methods=['POST']) + Route('/api/predict', predict, methods=['POST']), + Route('/api/get-ocr-data', get_ocr_data, methods=['POST']), + Route('/api/predict-v2', predict_v2, methods=['POST']), ] app = Starlette(debug=True, routes=routes) diff --git a/base.py b/base.py index 39ef5e5..d24d034 100644 --- a/base.py +++ b/base.py @@ -1,5 +1,10 @@ from pathlib import Path import easyocr +from dotenv import dotenv_values +import os + +config = dotenv_values(".env") + # print("-"*20) # print("Initializeds") @@ -11,4 +16,6 @@ YOLO_PATH = str(ROOT_DIR / "yolov5") -reader = easyocr.Reader(['en']) \ No newline at end of file +reader = easyocr.Reader(['en']) + +OPENAI_API_KEY = config.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY")) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a919c20..a70e730 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ torchvision>=0.8.1 tqdm>=4.64.0 protobuf<4.21.3 easyocr>=1.6.2 +python-dotenv==1.0.0 # Logging ------------------------------------- # tensorboard>=2.4.1 # wandb @@ -31,3 +32,8 @@ thop>=0.1.1 # FLOPs computation uvicorn==0.20.0 python-multipart==0.0.5 starlette==0.22.0 + +# Ocr V2 -------------------------------------- +langchain==0.0.292 +openai==0.28.1 +Jinja2==3.1.2 diff --git a/routes/__init__.py b/routes/__init__.py new file mode 100644 index 0000000..f3e50a3 --- /dev/null +++ b/routes/__init__.py @@ -0,0 +1 @@ +from routes.predict_v2.index import predict_v2 # noqa \ No newline at end of file diff --git a/routes/predict_v2/ai_helpers/chain.py b/routes/predict_v2/ai_helpers/chain.py new file mode 100644 index 0000000..79c5444 --- /dev/null +++ b/routes/predict_v2/ai_helpers/chain.py @@ -0,0 +1,45 @@ +import json + +from langchain.chat_models import ChatOpenAI +from langchain.prompts import ( + SystemMessagePromptTemplate, + ChatPromptTemplate, +) +from langchain import PromptTemplate, LLMChain + +from base import OPENAI_API_KEY + + +class ChatChain: + def __init__(self, ocr_data): + llm = ChatOpenAI(temperature=0.2, openai_api_key=OPENAI_API_KEY, model="gpt-4") + + template = f"""You are an OCR to JSON converter for 5ParaMonitor. You are given ocr json data for a 5 Para monitor, analyze the brand and predict patients reading, you will output the readings in JSON format. +5ParaMonitor OCR data: {json.dumps(ocr_data)} +Tips to analyze the ocr data: monitor can be zoomed in or zoomed out, ocr data is read from left to right of an image from top to bottom(with every row you go down), most of the times readings that we want are at extreme right of the monitor screen, there can be params like spo2, temp, bp etc present in ocr data, use them to your benefits and identify the correct value, temperature is always a decimal value, don't repeat a single value for multiple params, if you are not sure about a value, you can answer it as null, use common sense to get the correct field of a value. +Example output: +{{"spo2": "value/null", "resp": "value/null", "temperature": "value/null", "pulse":"value/null", "bp":"value/null"}} +""" + + system_prompt = PromptTemplate( + template=template, input_variables=[], template_format="jinja2" + ) + system_message_prompt = SystemMessagePromptTemplate(prompt=system_prompt) + + chat_prompt = ChatPromptTemplate.from_messages( + [ + system_message_prompt, + ] + ) + + self.chain = LLMChain( + llm=llm, + prompt=chat_prompt, + verbose=True, + ) + + async def async_predict(self): + prediction = await self.chain.apredict() + + parsed_prediction = json.loads(prediction) + return parsed_prediction diff --git a/routes/predict_v2/index.py b/routes/predict_v2/index.py new file mode 100644 index 0000000..64df97d --- /dev/null +++ b/routes/predict_v2/index.py @@ -0,0 +1,25 @@ +import traceback + +from starlette.responses import JSONResponse + +from routes.predict_v2.ai_helpers.chain import ChatChain + + +async def predict_v2(request): + + try: + data = await request.json() + + if "ocr_data" not in data: + return JSONResponse(status_code= 400, content={"error": "ocr_data not found"}) + + ocr_data = data["ocr_data"] + + chat_chain = ChatChain(ocr_data) + response = await chat_chain.async_predict() + + return JSONResponse({"data": response}) + + except Exception as e: + traceback.print_exc() + return JSONResponse(status_code=500, content={"error": "Something went wrong"})