Skip to content

Commit

Permalink
Merge pull request #36 from danilyef/pr3_fashapi
Browse files Browse the repository at this point in the history
PR3: Write a FastAPI server for your model, with tests and CI integration.
  • Loading branch information
danilyef authored Nov 25, 2024
2 parents adc5e89 + f0c15db commit 79ad295
Show file tree
Hide file tree
Showing 15 changed files with 156 additions and 6 deletions.
2 changes: 1 addition & 1 deletion homework_9/pr1/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Run
```bash
pip install -r requirements.txt
streamlit run project/main.py
streamlit run main.py
```

## Tests
Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion homework_9/pr2/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Run
```bash
pip install -r requirements.txt
python project/main.py
python main.py
```

## Tests
Expand Down
File renamed without changes.
File renamed without changes.
8 changes: 8 additions & 0 deletions homework_9/pr3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
## Run

```bash
cd homework_9
uvicorn pr3.app:app --reload
```

## Tests
File renamed without changes.
42 changes: 42 additions & 0 deletions homework_9/pr3/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from fastapi import FastAPI
from pydantic import BaseModel
from .utils import Model

model = Model(model_name='distilbert-base-uncased-finetuned-sst-2-english')

app = FastAPI()


class SentimentRequest(BaseModel):
text: str

class SentimentResponse(BaseModel):
text: str
sentiment: str
probability: float

class ProbabilityResponse(BaseModel):
text: str
probability: float

@app.get("/")
def read_root():
return {"message": "Welcome to the sentiment analysis API"}

@app.post("/predict")
def predict_sentiment(request: SentimentRequest) -> SentimentResponse:
label = model.predict(request.text)
probability = model.predict_proba(request.text)
return SentimentResponse(
text=request.text,
sentiment=label,
probability=float(probability)
)

@app.post("/probability")
def get_probability(request: SentimentRequest) -> ProbabilityResponse:
probability = model.predict_proba(request.text)
return ProbabilityResponse(
text=request.text,
probability=float(probability)
)
26 changes: 26 additions & 0 deletions homework_9/pr3/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

class Model:
def __init__(self, model_name="distilbert-base-uncased-finetuned-sst-2-english"):
self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
self.model = DistilBertForSequenceClassification.from_pretrained(model_name)
self.model.eval()

def predict(self, text):
inputs = self.tokenizer(
text, return_tensors="pt", truncation=True, padding=True
)
with torch.no_grad():
outputs = self.model(**inputs)
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
return self.model.config.id2label[predicted_class_id]

def predict_proba(self, text):
inputs = self.tokenizer(
text, return_tensors="pt", truncation=True, padding=True
)
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.softmax(outputs.logits, dim=1)
return probabilities.squeeze().max().item()
6 changes: 5 additions & 1 deletion homework_9/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ pandas==2.2.1
torch==2.2.1
streamlit==1.39.0
pytest==8.3.0
gradio==4.44.1
gradio==4.44.1
fastapi[standard]==0.114.0
pydantic==2.8.2
uvicorn==0.30.5
requests==2.32.2
File renamed without changes.
67 changes: 67 additions & 0 deletions homework_9/tests/fastapi_tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from fastapi.testclient import TestClient
from pr3.app import app

client = TestClient(app)

def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to the sentiment analysis API"}

def test_predict_sentiment():
# Test with positive text
response = client.post(
"/predict",
json={"text": "I love this movie!"}
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "sentiment" in data
assert "probability" in data
assert data["text"] == "I love this movie!"
assert isinstance(data["sentiment"], str)
assert isinstance(data["probability"], float)
assert 0 <= data["probability"] <= 1

# Test with negative text
response = client.post(
"/predict",
json={"text": "I hate this movie!"}
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "sentiment" in data
assert "probability" in data
assert data["text"] == "I hate this movie!"
assert isinstance(data["sentiment"], str)
assert isinstance(data["probability"], float)
assert 0 <= data["probability"] <= 1

def test_get_probability():
response = client.post(
"/probability",
json={"text": "This is a test message"}
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "probability" in data
assert data["text"] == "This is a test message"
assert isinstance(data["probability"], float)
assert 0 <= data["probability"] <= 1

def test_invalid_request():
# Test missing text field
response = client.post(
"/predict",
json={}
)
assert response.status_code == 422

response = client.post(
"/probability",
json={}
)
assert response.status_code == 422
7 changes: 5 additions & 2 deletions homework_9/tests/gradio_func/test_gradio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from pr2.func_gradio.utils import Model
from pr2.utils import Model

@pytest.fixture
def model():
Expand Down Expand Up @@ -36,4 +36,7 @@ def test_model_consistency(model):
# Test multiple times to ensure consistency
for _ in range(3):
assert model.predict(text) == first_prediction
assert abs(model.predict_proba(text) - first_probability) < 1e-6
assert abs(model.predict_proba(text) - first_probability) < 1e-6



2 changes: 1 addition & 1 deletion homework_9/tests/streamlit_func/test_streamlit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from pr1.func_st.utils import Model
from pr1.utils import Model

@pytest.fixture
def model():
Expand Down

0 comments on commit 79ad295

Please sign in to comment.