-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstreamlit-demo.py
39 lines (28 loc) · 1.09 KB
/
streamlit-demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
from transformers import pipeline
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
# Page header and the free-text prompt box the classifier will score.
# NOTE(review): `input` shadows the Python builtin; renaming it would
# require updating every reader of this global (e.g. run_prompt below).
st.title('Sensitive Prompt Detection - Finetuned model')
input = st.text_area('Prompt', 'Enter input prompt')
# Load the model (only executed once!)
# NOTE: Don't set ttl or max_entries in this case
@st.cache_resource
def load_model():
    """Load the fine-tuned binary sequence classifier, cached for the app's lifetime.

    Returns:
        AutoModelForSequenceClassification: the "Harish-wald/sensitive-bert"
        checkpoint with ``num_labels=2``.
    """
    import os  # local import: keep the secret lookup next to its only use

    # SECURITY FIX: the original code hard-coded a live Hugging Face access
    # token ("hf_...") in source. A committed token is compromised and must be
    # revoked; read the credential from the environment instead. `None` is a
    # valid value for public repos.
    return AutoModelForSequenceClassification.from_pretrained(
        "Harish-wald/sensitive-bert",
        token=os.environ.get("HF_TOKEN"),
        num_labels=2,
    )
@st.cache_resource
def load_tokenizer():
    """Return the pretrained bert-base-cased tokenizer (built once, cached by Streamlit)."""
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    return tokenizer
# Instantiate the cached model and tokenizer once at script start-up;
# Streamlit reruns this file on every interaction, so the @st.cache_resource
# wrappers above keep these cheap after the first run.
model = load_model()
model.eval()  # inference mode: disables dropout/batch-norm training behavior
tokenizer = load_tokenizer()
def run_prompt():
    """Classify the text in the global `input` box and render the verdict.

    Tokenizes the prompt, runs the classifier, and writes 'confidential'
    when the argmax logit is class 1, otherwise 'non-confidential'.
    """
    # `input` is the module-level text_area value (it shadows the builtin).
    tokenized_input = tokenizer(
        [input], padding="max_length", truncation=True, return_tensors='pt'
    )
    # FIX: call the module itself rather than model.forward(...) so that
    # __call__ machinery (hooks etc.) runs, and wrap inference in no_grad()
    # to skip pointless autograd bookkeeping.
    with torch.no_grad():
        output = model(**tokenized_input)
    st.subheader('Output: ')
    # .item() is the idiomatic scalar extraction from a 0-d tensor
    # (the original used .tolist(), which yields the same scalar here).
    label = torch.argmax(output['logits']).item()
    # label 1 -> confidential, 0 -> non-confidential
    st.write('confidential' if label else 'non-confidential')
st.button('Run', on_click=run_prompt)