"""FastAPI service exposing a BERT-based restricted-item text classifier."""

from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import BertForSequenceClassification, BertTokenizer

app = FastAPI()

# Load the trained model and tokenizer once at import time so every request
# reuses the same objects instead of reloading them.
model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_detector")
# NOTE(review): the tokenizer comes from the base 'bert-base-uncased'
# checkpoint rather than the fine-tuned model repo -- confirm this matches the
# vocabulary used during training.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Maps the model's predicted class index to a human-readable label.
_LABEL_MAP = {0: 'Allowed Item', 1: 'Restricted Item'}


@app.get("/")
def greet_json() -> dict:
    """Return a static greeting payload for the root route."""
    return {"Hello": "World!"}


class Predict(BaseModel):
    """Request body for the ``/predict`` endpoint."""

    # Free-form text describing the item to classify.
    input: str


def predict(request: Predict) -> str:
    """Classify the text in ``request.input`` and describe the result.

    Args:
        request: Request body whose ``input`` field holds the text to classify.

    Returns:
        A sentence quoting the input text and the predicted label.

    Raises:
        KeyError: If the predicted class index has no entry in the label map.
    """
    # Preprocess the input text into PyTorch tensors.
    inputs = tokenizer(request.input, return_tensors='pt', truncation=True, padding=True)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # The highest-scoring logit along dim 1 is the predicted class; .item()
    # converts the single-element result to a plain Python int.
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    predicted_label = _LABEL_MAP[predicted_class]

    return f'The item "{request.input}" is classified as: "{predicted_label}"'


@app.post("/predict")
def predictApi(request: Predict) -> str:
    """Handle POST ``/predict`` by delegating to :func:`predict`."""
    return predict(request)