jonathanjordan21 commited on
Commit
6a128a7
1 Parent(s): ea8b4a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py CHANGED
@@ -2,12 +2,14 @@ from fastapi import FastAPI, Request
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
  import torch
4
  from pydantic import BaseModel
 
5
 
6
  app = FastAPI()
7
 
8
 
9
  class InputText(BaseModel):
10
  text : str
 
11
 
12
 
13
  model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
@@ -15,11 +17,45 @@ sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_name)
15
  sentiment_tokenizer = AutoTokenizer.from_pretrained(model_name)
16
  sentiment_model.config.id2label[3] = "mixed"
17
 
 
 
 
 
 
 
 
 
 
 
18
  @app.get("/")
19
  def greet_json():
20
  return {"Hello": "World!"}
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @app.post("/sentiment_score")
24
  async def sentiment_score(inp: InputText):
25
  text = inp.text
 
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
  import torch
4
  from pydantic import BaseModel
5
+ from typing import Optional
6
 
7
  app = FastAPI()
8
 
9
 
10
  class InputText(BaseModel):
11
  text : str
12
+ threshold: Optional[float] = None
13
 
14
 
15
  model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
 
17
  sentiment_tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  sentiment_model.config.id2label[3] = "mixed"
19
 
20
+ model_name = 'qanastek/51-languages-classifier'
21
+ language_model = AutoModelForSequenceClassification.from_pretrained(model_name)
22
+ language_tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
  @app.get("/")
31
  def greet_json():
32
  return {"Hello": "World!"}
33
 
34
 
35
+
36
+ @app.post("/language_detection")
37
+ async def sentiment_score(inp: InputText):
38
+ inputs = tokenizer(inp.text, return_tensors='pt')
39
+ with torch.no_grad():
40
+ logits = language_model(**inputs).logits
41
+
42
+ softmax = torch.nn.functional.sigmoid(logits)
43
+
44
+ # Apply the threshold by creating a mask
45
+ mask = softmax >= inp.threshold
46
+
47
+ # Filter the tensor based on the threshold
48
+ filtered_x = softmax[mask]
49
+
50
+ # Get the sorted indices of the filtered tensor
51
+ sorted_indices = torch.argsort(filtered_x, descending=True)
52
+
53
+ # Map the sorted indices back to the original tensor indices
54
+ original_indices = torch.nonzero(mask, as_tuple=True)[1][sorted_indices]
55
+
56
+ return [{"label":model.config.id2label[predicted_class_id.tolist()], "score":softmax[0, predicted_class_id].tolist()} for predicted_class_id in original_indices]
57
+
58
+
59
  @app.post("/sentiment_score")
60
  async def sentiment_score(inp: InputText):
61
  text = inp.text