# Spaces: Sleeping
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# Load the trained model and tokenizer once at import time so every call to
# predict() reuses the same objects.
model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_detector")
# NOTE(review): the tokenizer is loaded from 'bert-base-uncased' rather than
# from the model identifier above -- confirm this matches the vocabulary the
# model was trained with.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Predict the class of a single input text.
def predict(text: str) -> int:
    """Return the predicted class index for a single input text.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The text to classify.

    Returns:
        The index of the highest-scoring class, as a plain Python ``int``.
    """
    # Preprocess the input text into PyTorch tensors, truncating over-long input.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Pick the top-scoring class along dim 1. ``.item()`` only works on a
    # one-element tensor, so this expects a single text rather than a batch.
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class
# Human-readable label for each class index returned by predict().
label_map = {
    0: 'Allowed Item',
    1: 'Restricted Item',
}
def main() -> None:
    """Run an interactive prompt that classifies each line the user enters.

    Keeps prompting until standard input is closed (EOF) or the user
    interrupts with Ctrl-C, then returns normally instead of ending the
    program with a traceback.
    """
    while True:
        # input() raises EOFError when stdin is exhausted and
        # KeyboardInterrupt on Ctrl-C; both mean "stop prompting".
        try:
            user_input = input("Enter something: ")
        except (EOFError, KeyboardInterrupt):
            print()  # finish the prompt line before leaving
            break

        predicted_class = predict(user_input)

        # Map the predicted class to a human-readable label.
        predicted_label = label_map[predicted_class]

        print(f'The item "{user_input}" is classified as: "{predicted_label}"')
# Start the interactive loop only when run as a script, not when imported.
if __name__ == "__main__":
    main()