-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
102 lines (84 loc) · 2.74 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from http.client import HTTPResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
import uvicorn
from fastapi import FastAPI, Request, Response
from pydantic import BaseModel
import uvicorn, json, datetime
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from starlette.middleware.cors import CORSMiddleware
import logging
import time
# Root logger emits INFO and above.
logging.basicConfig(level=logging.INFO)
# The ASGI application object that all routes and exception handlers below attach to.
app = FastAPI()
# CORS policy: any origin, any method, any header, with credentials enabled.
# NOTE(review): the FastAPI CORS documentation says wildcard values should not
# be combined with allow_credentials=True; explicit origins are expected when
# credentials are allowed. Confirm whether the frontend really needs
# credentialed requests, then either list its origins here or drop the flag.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])
# Request model with a single required string field.
# NOTE(review): no route in this file accepts a Query; it looks unused here —
# confirm nothing else imports it before removing.
class Query(BaseModel):
    text: str  # free-form text payload
from src.system import *
import numpy as np
import json
import torch
import random
from typing import List, Optional
from pydantic import BaseModel
from fastapi import status
from fastapi import HTTPException
from pydantic import ValidationError
# Single OneEdit instance shared by the request handlers; it is constructed
# once, when this module is imported.
# The hyper-parameter file location can be overridden with the
# ONEEDIT_HPARAMS environment variable so the server is not tied to one
# machine's directory layout; when the variable is unset the original path is
# used, so existing deployments behave exactly as before.
oneedit = OneEdit(os.environ.get("ONEEDIT_HPARAMS", '/mnt/xzk/OneEdit/hparams.yaml'))
# One entry of the chat history posted to /v1/chat/completions.
class Message(BaseModel):
    role: str  # speaker label; required, but never read by the handler in this file
    content: str  # message text; the handler only uses the content of the last message
# Request body of POST /v1/chat/completions.
# Every field is required (none has a default), so a request missing any of
# them is rejected during validation before the handler runs.
class InputData(BaseModel):
    messages: List[Message]  # chat history; only the last entry is processed
    stream: bool  # required by the schema; not read by the handler in this file
    model: str  # "Chat Mode" selects generation, any other value selects knowledge editing
    # The sampling fields below are required by the schema but are not read by
    # the handler in this file.
    temperature: float
    presence_penalty: float
    frequency_penalty: float
    top_p: float
@app.post("/v1/chat/completions")
async def chatQuery(input_data: InputData):
    """Run the last chat message through OneEdit and return the result as text.

    When ``input_data.model`` equals ``"Chat Mode"`` the message is passed to
    ``oneedit.generate``; for any other value it is passed to
    ``oneedit.edit_knowledge``.

    Args:
        input_data: Validated request body. Only ``model`` and the content of
            the last entry in ``messages`` are used.

    Returns:
        A ``text/plain`` response with status 200 whose body is the value
        returned by OneEdit.

    Raises:
        HTTPException: 400 when ``messages`` is empty, 422 when a pydantic
            ``ValidationError`` escapes the OneEdit call, 500 for any other
            failure.
    """
    # An empty history has no "last message"; report it as a client error
    # instead of letting the IndexError surface as a 500.
    if not input_data.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    prompt = input_data.messages[-1].content
    try:
        # NOTE(review): these calls are synchronous inside an async handler, so
        # they presumably block the event loop while the model runs — confirm
        # whether that is acceptable for the expected load.
        if input_data.model == "Chat Mode":
            reply = oneedit.generate(prompt)
        else:
            reply = oneedit.edit_knowledge(prompt)
        return Response(
            status_code=200,
            content=reply,
            media_type="text/plain",
        )
    except ValidationError as e:
        logging.error("Validation error: %s", e.json())
        raise HTTPException(status_code=422, detail=json.loads(e.json())) from e
    except Exception as e:
        # The result variable used to be named ``str``, which made ``str`` a
        # local for the whole function; when OneEdit raised before it was
        # assigned, ``str(e)`` here failed with UnboundLocalError and hid the
        # real error. logging.exception also records the traceback.
        logging.exception("Other error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException):
    """Turn an HTTPException into a JSON body tagged with status "fail".

    The response reuses the exception's status code and carries its ``detail``
    unchanged.
    """
    payload = {"status": "fail", "detail": exc.detail}
    return JSONResponse(content=payload, status_code=exc.status_code)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc: RequestValidationError):
    """Report a request-body validation failure as a 422 JSON response.

    Args:
        request: The incoming request (unused).
        exc: The validation error raised while parsing the request.

    Returns:
        A 422 ``JSONResponse`` of the form
        ``{"status": "fail", "detail": [...]}``.
    """
    print(f"Request validation error: {exc.errors()}")
    # The raw error list may hold values that json cannot serialise (for
    # example exception objects inside an error's context), which would turn
    # this 422 into a 500. jsonable_encoder converts them to plain JSON types
    # first, as FastAPI's own error-handling documentation does.
    detail = jsonable_encoder(exc.errors())
    return JSONResponse(
        status_code=422,
        content={"status": "fail", "detail": detail},
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc: Exception):
    """Catch-all handler: print the error and answer with a 500 JSON body.

    The exception's string form is echoed back to the client in ``detail``.
    """
    reason = str(exc)
    print(f"General error: {reason}")
    body = {"status": "error", "detail": reason}
    return JSONResponse(content=body, status_code=500)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 2001.
    listen_options = {"host": "0.0.0.0", "port": 2001}
    uvicorn.run(app, **listen_options)