Spaces:
Sleeping
Sleeping
Sheng Lei
commited on
Commit
•
2a6f3b3
1
Parent(s):
51c5d35
more fixes
Browse files- app.py +10 -5
- restrictedItems/predict.py +2 -1
- restrictedItems/train.py +2 -1
app.py
CHANGED
@@ -15,10 +15,15 @@ model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_d
|
|
15 |
# Load the trained model and tokenizer
|
16 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
17 |
|
|
|
|
|
|
|
|
|
|
|
18 |
# Function to predict the class of a single input text
|
19 |
-
def predict(
|
20 |
# Preprocess the input text
|
21 |
-
inputs = tokenizer(
|
22 |
|
23 |
# Make predictions
|
24 |
with torch.no_grad():
|
@@ -34,10 +39,10 @@ def predict(text: str):
|
|
34 |
predicted_label = label_map[predicted_class]
|
35 |
|
36 |
# Displaying the user input
|
37 |
-
return f'The item "{
|
38 |
|
39 |
return predicted_class
|
40 |
|
41 |
@app.post("/predict")
|
42 |
-
def
|
43 |
-
return predict(
|
|
|
15 |
# Load the trained model and tokenizer
|
16 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
17 |
|
18 |
+
from pydantic import BaseModel
|
19 |
+
|
20 |
+
class Predict(BaseModel):
|
21 |
+
input: str
|
22 |
+
|
23 |
# Function to predict the class of a single input text
|
24 |
+
def predict(request: Predict):
|
25 |
# Preprocess the input text
|
26 |
+
inputs = tokenizer(request.input, return_tensors='pt', truncation=True, padding=True)
|
27 |
|
28 |
# Make predictions
|
29 |
with torch.no_grad():
|
|
|
39 |
predicted_label = label_map[predicted_class]
|
40 |
|
41 |
# Displaying the user input
|
42 |
+
return f'The item "{request.input}" is classified as: "{predicted_label}"'
|
43 |
|
44 |
return predicted_class
|
45 |
|
46 |
@app.post("/predict")
|
47 |
+
def predictApi(request: Predict):
|
48 |
+
return predict(request)
|
restrictedItems/predict.py
CHANGED
@@ -2,7 +2,8 @@ from transformers import BertTokenizer, BertForSequenceClassification
|
|
2 |
import torch
|
3 |
|
4 |
# Load the trained model and tokenizer
|
5 |
-
model = BertForSequenceClassification.from_pretrained("/Users/slei/hackweek2024-sup-genai-tools/
|
|
|
6 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
7 |
|
8 |
# Function to predict the class of a single input text
|
|
|
2 |
import torch
|
3 |
|
4 |
# Load the trained model and tokenizer
|
5 |
+
# model = BertForSequenceClassification.from_pretrained("/Users/slei/hackweek2024-sup-genai-tools/spaces/restricted_item_detector/trained_model")
|
6 |
+
model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_detector")
|
7 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
8 |
|
9 |
# Function to predict the class of a single input text
|
restrictedItems/train.py
CHANGED
@@ -100,7 +100,7 @@ train_dataset = ShoppingCartDataset(train_encodings, train_labels)
|
|
100 |
val_dataset = ShoppingCartDataset(val_encodings, val_labels)
|
101 |
|
102 |
# Load pre-trained BERT model
|
103 |
-
model = BertForSequenceClassification.from_pretrained('
|
104 |
|
105 |
# Training arguments
|
106 |
training_args = TrainingArguments(
|
@@ -128,4 +128,5 @@ trainer.train()
|
|
128 |
# Evaluate model
|
129 |
trainer.evaluate()
|
130 |
|
|
|
131 |
model.push_to_hub("sleiyer/restricted_item_detector")
|
|
|
100 |
val_dataset = ShoppingCartDataset(val_encodings, val_labels)
|
101 |
|
102 |
# Load pre-trained BERT model
|
103 |
+
model = BertForSequenceClassification.from_pretrained('sleiyer/restricted_item_detector')
|
104 |
|
105 |
# Training arguments
|
106 |
training_args = TrainingArguments(
|
|
|
128 |
# Evaluate model
|
129 |
trainer.evaluate()
|
130 |
|
131 |
+
model.save_pretrained('trained_model')
|
132 |
model.push_to_hub("sleiyer/restricted_item_detector")
|