From 262de9b37e3a0afbe96f6d20de8cb162f929622c Mon Sep 17 00:00:00 2001 From: danilyef Date: Mon, 25 Nov 2024 11:52:25 +0100 Subject: [PATCH 1/3] pr3 --- homework_9/pr1/README.md | 2 +- homework_9/pr1/{func_st => }/main.py | 0 homework_9/pr1/{func_st => }/utils.py | 0 homework_9/pr2/README.md | 2 +- homework_9/pr2/{func_gradio => }/main.py | 0 homework_9/pr2/{func_gradio => }/utils.py | 0 homework_9/pr3/README.md | 8 +++ homework_9/{pr1/func_st => pr3}/__init__.py | 0 homework_9/pr3/app.py | 42 ++++++++++++ homework_9/pr3/utils.py | 26 +++++++ homework_9/requirements.txt | 6 +- .../fastapi_tests}/__init__.py | 0 .../tests/fastapi_tests/test_fastapi.py | 67 +++++++++++++++++++ homework_9/tests/gradio_func/test_gradio.py | 2 +- .../tests/streamlit_func/test_streamlit.py | 2 +- 15 files changed, 152 insertions(+), 5 deletions(-) rename homework_9/pr1/{func_st => }/main.py (100%) rename homework_9/pr1/{func_st => }/utils.py (100%) rename homework_9/pr2/{func_gradio => }/main.py (100%) rename homework_9/pr2/{func_gradio => }/utils.py (100%) create mode 100644 homework_9/pr3/README.md rename homework_9/{pr1/func_st => pr3}/__init__.py (100%) create mode 100644 homework_9/pr3/app.py create mode 100644 homework_9/pr3/utils.py rename homework_9/{pr2/func_gradio => tests/fastapi_tests}/__init__.py (100%) create mode 100644 homework_9/tests/fastapi_tests/test_fastapi.py diff --git a/homework_9/pr1/README.md b/homework_9/pr1/README.md index 3f69e3f..bb03f81 100644 --- a/homework_9/pr1/README.md +++ b/homework_9/pr1/README.md @@ -1,7 +1,7 @@ ## Run ```bash pip install -r requirements.txt -streamlit run project/main.py +streamlit run main.py ``` ## Tests diff --git a/homework_9/pr1/func_st/main.py b/homework_9/pr1/main.py similarity index 100% rename from homework_9/pr1/func_st/main.py rename to homework_9/pr1/main.py diff --git a/homework_9/pr1/func_st/utils.py b/homework_9/pr1/utils.py similarity index 100% rename from homework_9/pr1/func_st/utils.py rename to homework_9/pr1/utils.py diff --git a/homework_9/pr2/README.md b/homework_9/pr2/README.md index bd4d38c..c0be38b 100644 --- a/homework_9/pr2/README.md +++ b/homework_9/pr2/README.md @@ -1,7 +1,7 @@ ## Run ```bash pip install -r requirements.txt -python project/main.py +python main.py ``` ## Tests diff --git a/homework_9/pr2/func_gradio/main.py b/homework_9/pr2/main.py similarity index 100% rename from homework_9/pr2/func_gradio/main.py rename to homework_9/pr2/main.py diff --git a/homework_9/pr2/func_gradio/utils.py b/homework_9/pr2/utils.py similarity index 100% rename from homework_9/pr2/func_gradio/utils.py rename to homework_9/pr2/utils.py diff --git a/homework_9/pr3/README.md b/homework_9/pr3/README.md new file mode 100644 index 0000000..53f310a --- /dev/null +++ b/homework_9/pr3/README.md @@ -0,0 +1,8 @@ +## Run + +```bash +cd homework_9/pr3 +uvicorn app:app --reload +``` + +## Tests diff --git a/homework_9/pr1/func_st/__init__.py b/homework_9/pr3/__init__.py similarity index 100% rename from homework_9/pr1/func_st/__init__.py rename to homework_9/pr3/__init__.py diff --git a/homework_9/pr3/app.py b/homework_9/pr3/app.py new file mode 100644 index 0000000..24c5d62 --- /dev/null +++ b/homework_9/pr3/app.py @@ -0,0 +1,42 @@ +from fastapi import FastAPI +from pydantic import BaseModel +from utils import Model + +model = Model(model_name='distilbert-base-uncased-finetuned-sst-2-english') + +app = FastAPI() + + +class SentimentRequest(BaseModel): + text: str + +class SentimentResponse(BaseModel): + text: str + sentiment: str + probability: float + +class ProbabilityResponse(BaseModel): + text: str + probability: float + +@app.get("/") +def read_root(): + return {"message": "Welcome to the sentiment analysis API"} + +@app.post("/predict") +def predict_sentiment(request: SentimentRequest) -> SentimentResponse: + label = model.predict(request.text) + probability = model.predict_proba(request.text) + return SentimentResponse( + text=request.text, + sentiment=label, + probability=float(probability) + ) + +@app.post("/probability") +def get_probability(request: SentimentRequest) -> ProbabilityResponse: + probability = model.predict_proba(request.text) + return ProbabilityResponse( + text=request.text, + probability=float(probability) + ) \ No newline at end of file diff --git a/homework_9/pr3/utils.py b/homework_9/pr3/utils.py new file mode 100644 index 0000000..7565235 --- /dev/null +++ b/homework_9/pr3/utils.py @@ -0,0 +1,26 @@ +from transformers import DistilBertTokenizer, DistilBertForSequenceClassification +import torch + +class Model: + def __init__(self, model_name="distilbert-base-uncased-finetuned-sst-2-english"): + self.tokenizer = DistilBertTokenizer.from_pretrained(model_name) + self.model = DistilBertForSequenceClassification.from_pretrained(model_name) + self.model.eval() + + def predict(self, text): + inputs = self.tokenizer( + text, return_tensors="pt", truncation=True, padding=True + ) + with torch.no_grad(): + outputs = self.model(**inputs) + predicted_class_id = torch.argmax(outputs.logits, dim=1).item() + return self.model.config.id2label[predicted_class_id] + + def predict_proba(self, text): + inputs = self.tokenizer( + text, return_tensors="pt", truncation=True, padding=True + ) + with torch.no_grad(): + outputs = self.model(**inputs) + probabilities = torch.softmax(outputs.logits, dim=1) + return probabilities.squeeze().max().item() diff --git a/homework_9/requirements.txt b/homework_9/requirements.txt index 1a121ad..a3af5dd 100644 --- a/homework_9/requirements.txt +++ b/homework_9/requirements.txt @@ -4,4 +4,8 @@ pandas==2.2.1 torch==2.2.1 streamlit==1.39.0 pytest==8.3.0 -gradio==4.44.1 \ No newline at end of file +gradio==4.44.1 +fastapi[standard]==0.114.0 +pydantic==2.8.2 +uvicorn==0.30.5 +requests==2.32.2 \ No newline at end of file diff --git a/homework_9/pr2/func_gradio/__init__.py b/homework_9/tests/fastapi_tests/__init__.py similarity index 100% rename from homework_9/pr2/func_gradio/__init__.py rename to homework_9/tests/fastapi_tests/__init__.py diff --git a/homework_9/tests/fastapi_tests/test_fastapi.py b/homework_9/tests/fastapi_tests/test_fastapi.py new file mode 100644 index 0000000..6f8d919 --- /dev/null +++ b/homework_9/tests/fastapi_tests/test_fastapi.py @@ -0,0 +1,67 @@ +from fastapi.testclient import TestClient +from pr3.app import app + +client = TestClient(app) + +def test_read_root(): + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "Welcome to the sentiment analysis API"} + +def test_predict_sentiment(): + # Test with positive text + response = client.post( + "/predict", + json={"text": "I love this movie!"} + ) + assert response.status_code == 200 + data = response.json() + assert "text" in data + assert "sentiment" in data + assert "probability" in data + assert data["text"] == "I love this movie!" + assert isinstance(data["sentiment"], str) + assert isinstance(data["probability"], float) + assert 0 <= data["probability"] <= 1 + + # Test with negative text + response = client.post( + "/predict", + json={"text": "I hate this movie!"} + ) + assert response.status_code == 200 + data = response.json() + assert "text" in data + assert "sentiment" in data + assert "probability" in data + assert data["text"] == "I hate this movie!" + assert isinstance(data["sentiment"], str) + assert isinstance(data["probability"], float) + assert 0 <= data["probability"] <= 1 + +def test_get_probability(): + response = client.post( + "/probability", + json={"text": "This is a test message"} + ) + assert response.status_code == 200 + data = response.json() + assert "text" in data + assert "probability" in data + assert data["text"] == "This is a test message" + assert isinstance(data["probability"], float) + assert 0 <= data["probability"] <= 1 + +def test_invalid_request(): + # Test missing text field + response = client.post( + "/predict", + json={} + ) + assert response.status_code == 422 + + response = client.post( + "/probability", + json={} + ) + assert response.status_code == 422 \ No newline at end of file diff --git a/homework_9/tests/gradio_func/test_gradio.py b/homework_9/tests/gradio_func/test_gradio.py index 3be5980..5ae10bf 100644 --- a/homework_9/tests/gradio_func/test_gradio.py +++ b/homework_9/tests/gradio_func/test_gradio.py @@ -1,5 +1,5 @@ import pytest -from pr2.func_gradio.utils import Model +from pr2.utils import Model @pytest.fixture def model(): diff --git a/homework_9/tests/streamlit_func/test_streamlit.py b/homework_9/tests/streamlit_func/test_streamlit.py index 914de79..24211b0 100644 --- a/homework_9/tests/streamlit_func/test_streamlit.py +++ b/homework_9/tests/streamlit_func/test_streamlit.py @@ -1,5 +1,5 @@ import pytest -from pr1.func_st.utils import Model +from pr1.utils import Model @pytest.fixture def model(): From a9b48b5f1ac776952921c3e2fdf4a94dc757342b Mon Sep 17 00:00:00 2001 From: danilyef Date: Mon, 25 Nov 2024 13:40:47 +0100 Subject: [PATCH 2/3] fix --- homework_9/pr3/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homework_9/pr3/app.py b/homework_9/pr3/app.py index 24c5d62..331260a 100644 --- a/homework_9/pr3/app.py +++ b/homework_9/pr3/app.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from pydantic import BaseModel -from utils import Model +from .utils import Model model = Model(model_name='distilbert-base-uncased-finetuned-sst-2-english') From f0c15db5739510459b5a67bd7955ff7c6b6b5198 Mon Sep 17 00:00:00 2001 From: danilyef Date: Mon, 25 Nov 2024 14:45:10 +0100 Subject: [PATCH 3/3] pr3 small fix --- homework_9/pr3/README.md | 4 ++-- homework_9/tests/gradio_func/test_gradio.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/homework_9/pr3/README.md b/homework_9/pr3/README.md index 53f310a..5a6b53d 100644 --- a/homework_9/pr3/README.md +++ b/homework_9/pr3/README.md @@ -1,8 +1,8 @@ ## Run ```bash -cd homework_9/pr3 -uvicorn app:app --reload +cd homework_9 +uvicorn pr3.app:app --reload ``` ## Tests diff --git a/homework_9/tests/gradio_func/test_gradio.py b/homework_9/tests/gradio_func/test_gradio.py index 5ae10bf..22fa786 100644 --- a/homework_9/tests/gradio_func/test_gradio.py +++ b/homework_9/tests/gradio_func/test_gradio.py @@ -36,4 +36,7 @@ def test_model_consistency(model): # Test multiple times to ensure consistency for _ in range(3): assert model.predict(text) == first_prediction - assert abs(model.predict_proba(text) - first_probability) < 1e-6 \ No newline at end of file + assert abs(model.predict_proba(text) - first_probability) < 1e-6 + + +