cdb24 commited on
Commit
a971e76
1 Parent(s): 9bbae05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -45
app.py CHANGED
@@ -1,49 +1,42 @@
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from streamlit_chat import message
3
  import requests
4
 
5
- st.set_page_config(
6
- page_title="Streamlit Chat - Demo",
7
- page_icon=":robot:"
8
- )
9
-
10
- API_URL = "https://api-inference.huggingface.co/models/facebook/blenderbot-400M-distill"
11
- headers = {"Authorization": st.secrets['api_key']}
12
-
13
- st.header("Streamlit Chat - Demo")
14
- st.markdown("[Github](https://github.com/ai-yash/st-chat)")
15
-
16
- if 'generated' not in st.session_state:
17
- st.session_state['generated'] = []
18
-
19
- if 'past' not in st.session_state:
20
- st.session_state['past'] = []
21
-
22
- def query(payload):
23
- response = requests.post(API_URL, headers=headers, json=payload)
24
- return response.json()
25
-
26
- def get_text():
27
- input_text = st.text_input("You: ","Hello, how are you?", key="input")
28
- return input_text
29
-
30
-
31
- user_input = get_text()
32
-
33
- if user_input:
34
- output = query({
35
- "inputs": {
36
- "past_user_inputs": st.session_state.past,
37
- "generated_responses": st.session_state.generated,
38
- "text": user_input,
39
- },"parameters": {"repetition_penalty": 1.33},
40
- })
41
-
42
- st.session_state.past.append(user_input)
43
- st.session_state.generated.append(output["generated_text"])
44
-
45
- if st.session_state['generated']:
46
-
47
- for i in range(len(st.session_state['generated'])-1, -1, -1):
48
- message(st.session_state["generated"][i], key=str(i))
49
- message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
 
1
+ from pathlib import Path
2
+ from sklearn.model_selection import train_test_split
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
6
+ from transformers import Trainer, TrainingArguments
7
+
8
  import streamlit as st
9
  from streamlit_chat import message
10
  import requests
11
 
12
+ model_one = "distilbert-base-uncased-finetuned-sst-2-english"
13
+ model_two = "Newtral/xlm-r-finetuned-toxic-political-tweets-es"
14
+
15
+ def toxicRating(text, model):
16
+ model = AutoModelForSequenceClassification.from_pretrained(model)
17
+ tokenizer = AutoTokenizer.from_pretrained(model)
18
+
19
+ classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
20
+ results = classifier(text)
21
+ return results
22
+
23
+ def main():
24
+ st.title("TOXIC TWEETS, \n TOXIC OR NOT?")
25
+ prompt = st.header("Select Model")
26
+ selection = st.radio("Models",('Model 1', 'Model 2'))
27
+
28
+
29
+ input = st.text_area("Enter Tweet: ")
30
+ if input:
31
+ if selection == 'Model 1':
32
+ rating = rate_ModelOne(input, model_one)
33
+ st.write(f"Label: {rating[1]} \n Score : {rating[3]}")
34
+ elif selection == 'Model 2':
35
+ rating = rate_ModelTwo(input, model_two)
36
+ rating = rate_ModelOne(input, model_one)
37
+ st.write(f"Label: {rating[1]} \n Score : {rating[3]}")
38
+ else:
39
+ st.warning("Enter Tweet")
40
+
41
+ if __name__ == "__main__":
42
+ main();