Sheng Lei commited on
Commit
305150f
1 Parent(s): 0a07308

Add predict

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py CHANGED
@@ -5,3 +5,40 @@ app = FastAPI()
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
8
+
9
+
10
+ from transformers import BertTokenizer, BertForSequenceClassification
11
+ import torch
12
+
13
+ from bertopic import BERTopic
14
+
15
+ model = BERTopic.load("sleiyer/restricted_item_detector")
16
+ # Load the trained model and tokenizer
17
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
18
+
19
+ # Function to predict the class of a single input text
20
+ def predict(text):
21
+ # Preprocess the input text
22
+ inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
23
+
24
+ # Make predictions
25
+ with torch.no_grad():
26
+ outputs = model(**inputs)
27
+
28
+ # Get the predicted class
29
+ logits = outputs.logits
30
+ predicted_class = torch.argmax(logits, dim=1).item()
31
+
32
+ label_map = {0: 'Allowed Item', 1: 'Restricted Item'}
33
+
34
+ # Map the predicted class to a human-readable label
35
+ predicted_label = label_map[predicted_class]
36
+
37
+ # Displaying the user input
38
+ return f'The item "{text}" is classified as: "{predicted_label}"'
39
+
40
+ return predicted_class
41
+
42
+ @app.post("/predict")
43
+ def predict(input):
44
+ return predict(input)