Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pr3 #36

Merged
merged 3 commits into from
Nov 25, 2024
Merged

pr3 #36

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion homework_9/pr1/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Run
```bash
pip install -r requirements.txt
streamlit run project/main.py
streamlit run main.py
```

## Tests
Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion homework_9/pr2/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Run
```bash
pip install -r requirements.txt
python project/main.py
python main.py
```

## Tests
Expand Down
File renamed without changes.
File renamed without changes.
8 changes: 8 additions & 0 deletions homework_9/pr3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
## Run

```bash
cd homework_9
uvicorn pr3.app:app --reload
```

## Tests
File renamed without changes.
42 changes: 42 additions & 0 deletions homework_9/pr3/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from fastapi import FastAPI
from pydantic import BaseModel
from .utils import Model

model = Model(model_name='distilbert-base-uncased-finetuned-sst-2-english')

app = FastAPI()


class SentimentRequest(BaseModel):
text: str

class SentimentResponse(BaseModel):
text: str
sentiment: str
probability: float

class ProbabilityResponse(BaseModel):
text: str
probability: float

@app.get("/")
def read_root():
return {"message": "Welcome to the sentiment analysis API"}

@app.post("/predict")
def predict_sentiment(request: SentimentRequest) -> SentimentResponse:
label = model.predict(request.text)
probability = model.predict_proba(request.text)
return SentimentResponse(
text=request.text,
sentiment=label,
probability=float(probability)
)

@app.post("/probability")
def get_probability(request: SentimentRequest) -> ProbabilityResponse:
probability = model.predict_proba(request.text)
return ProbabilityResponse(
text=request.text,
probability=float(probability)
)
26 changes: 26 additions & 0 deletions homework_9/pr3/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

class Model:
def __init__(self, model_name="distilbert-base-uncased-finetuned-sst-2-english"):
self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
self.model = DistilBertForSequenceClassification.from_pretrained(model_name)
self.model.eval()

def predict(self, text):
inputs = self.tokenizer(
text, return_tensors="pt", truncation=True, padding=True
)
with torch.no_grad():
outputs = self.model(**inputs)
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
return self.model.config.id2label[predicted_class_id]

def predict_proba(self, text):
inputs = self.tokenizer(
text, return_tensors="pt", truncation=True, padding=True
)
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.softmax(outputs.logits, dim=1)
return probabilities.squeeze().max().item()
6 changes: 5 additions & 1 deletion homework_9/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ pandas==2.2.1
torch==2.2.1
streamlit==1.39.0
pytest==8.3.0
gradio==4.44.1
gradio==4.44.1
fastapi[standard]==0.114.0
pydantic==2.8.2
uvicorn==0.30.5
requests==2.32.2
67 changes: 67 additions & 0 deletions homework_9/tests/fastapi_tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from fastapi.testclient import TestClient
from pr3.app import app

client = TestClient(app)

def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to the sentiment analysis API"}

def test_predict_sentiment():
# Test with positive text
response = client.post(
"/predict",
json={"text": "I love this movie!"}
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "sentiment" in data
assert "probability" in data
assert data["text"] == "I love this movie!"
assert isinstance(data["sentiment"], str)
assert isinstance(data["probability"], float)
assert 0 <= data["probability"] <= 1

# Test with negative text
response = client.post(
"/predict",
json={"text": "I hate this movie!"}
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "sentiment" in data
assert "probability" in data
assert data["text"] == "I hate this movie!"
assert isinstance(data["sentiment"], str)
assert isinstance(data["probability"], float)
assert 0 <= data["probability"] <= 1

def test_get_probability():
response = client.post(
"/probability",
json={"text": "This is a test message"}
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "probability" in data
assert data["text"] == "This is a test message"
assert isinstance(data["probability"], float)
assert 0 <= data["probability"] <= 1

def test_invalid_request():
# Test missing text field
response = client.post(
"/predict",
json={}
)
assert response.status_code == 422

response = client.post(
"/probability",
json={}
)
assert response.status_code == 422
7 changes: 5 additions & 2 deletions homework_9/tests/gradio_func/test_gradio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from pr2.func_gradio.utils import Model
from pr2.utils import Model

@pytest.fixture
def model():
Expand Down Expand Up @@ -36,4 +36,7 @@ def test_model_consistency(model):
# Test multiple times to ensure consistency
for _ in range(3):
assert model.predict(text) == first_prediction
assert abs(model.predict_proba(text) - first_probability) < 1e-6
assert abs(model.predict_proba(text) - first_probability) < 1e-6



2 changes: 1 addition & 1 deletion homework_9/tests/streamlit_func/test_streamlit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from pr1.func_st.utils import Model
from pr1.utils import Model

@pytest.fixture
def model():
Expand Down