"""FastAPI service exposing a BERT classifier that flags restricted items."""

import torch
from bertopic import BERTopic  # noqa: F401  (kept from the original file; not used below)
from fastapi import FastAPI
from transformers import BertForSequenceClassification, BertTokenizer

app = FastAPI()

# Hugging Face Hub repository holding the fine-tuned classifier weights.
MODEL_NAME = "sleiyer/restricted_item_detector"

# Maps the classifier's argmax output index to a human-readable label.
LABEL_MAP = {0: "Allowed Item", 1: "Restricted Item"}

# Load the trained model and tokenizer once, at import time.
# NOTE(review): the original code loaded this repo with BERTopic.load(), but
# predict() calls the model on tokenizer output and reads `.logits`, which only
# a sequence-classification model supports. Confirm the repo holds a
# BertForSequenceClassification checkpoint.
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # inference only: disable dropout
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


@app.get("/")
def greet_json():
    """Return a static greeting (useful as a liveness check)."""
    return {"Hello": "World!"}


def predict(text):
    """Classify a single item description as allowed or restricted.

    Args:
        text: The item description to classify.

    Returns:
        A sentence naming the input text and its predicted label, e.g.
        'The item "knife" is classified as: "Restricted Item"'.
    """
    # Tokenize into PyTorch tensors; truncate overly long input to the
    # model's maximum sequence length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # The highest-scoring logit in the single batch row is the predicted class.
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    predicted_label = LABEL_MAP[predicted_class]

    return f'The item "{text}" is classified as: "{predicted_label}"'


@app.post("/predict")
def predict_endpoint(input: str):
    """HTTP wrapper around `predict`.

    The handler has its own name so that it does not shadow the module-level
    `predict` helper. Reusing that name made the handler call itself
    recursively.

    Args:
        input: Item description, supplied as the `input` query parameter.

    Returns:
        The classification sentence produced by `predict`.
    """
    return predict(input)