dahongj commited on
Commit
f9cf78c
1 Parent(s): 5af7428
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import numpy as np
2
  import streamlit as st
3
  from transformers import pipeline
 
4
 
5
  import torch
6
 
@@ -36,19 +37,23 @@ def siebert(data):
36
  return label, score
37
 
38
  def finetuned(data):
39
- specific_model = pipeline(model='dahongj/finetuned_toxictweets')
40
- result = specific_model(data)
41
- maxres = result[0]['label']
42
- maxscore = result[0]['score']
43
- sec = result[1]['label']
44
- secscore = result[1]['score']
 
 
 
 
45
 
46
- for i in result:
47
- if i['score'] > secscore:
48
- sec = i['label']
49
- secscore = i['score']
50
 
51
- return maxres, maxscore, sec, secscore
52
 
53
  def getSent(data, model):
54
  if(model == 'Bertweet'):
 
1
  import numpy as np
2
  import streamlit as st
3
  from transformers import pipeline
4
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
5
 
6
  import torch
7
 
 
37
  return label, score
38
 
39
  def finetuned(data):
40
+ model_name = "dahongj/finetuned_toxictweets"
41
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
42
+ model = DistilBertForSequenceClassification.from_pretrained(model_name)
43
+ tokenized_text = tokenizer(data, return_tensors="pt")
44
+ res = model(**tokenized_text)
45
+ mes = torch.sigmoid(res.logits)
46
+
47
+ Dict = {0: "toxic", 1: "severe_toxic", 2: "obscene", 3: "threat", 4: "insult", 5: "identity_hate"}
48
+
49
+ maxres, maxscore, sec, secscore = Dict[0], mes[0][0].item(), 0, 0
50
 
51
+ for i in range(1,6):
52
+ if mes[0][i].item() > secscore:
53
+ sec = i
54
+ secscore = mes[0][i].item()
55
 
56
+ return maxres, maxscore, Dict[sec], secscore
57
 
58
  def getSent(data, model):
59
  if(model == 'Bertweet'):