# Page-scrape residue (not part of the script): "Spaces: Sleeping"
import math

import torch
from sklearn.model_selection import train_test_split
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)
# Load dataset
# The dataset is a list of (text, label) tuples:
# dataset = [("item description 1", 0), ("item description 2", 1), ...]
# Label 1 marks entries from `restricted_dataset`, label 0 entries from
# `normal_dataset`.
# NOTE(review): the restricted entries look like gift cards / prepaid cards
# and similar items -- confirm the intended definition of "restricted".
# NOTE(review): several restricted descriptions are repeated verbatim; they
# are kept exactly as in the original data.
restricted_dataset = [
    ("Promotional Email GiftCard $10", 1),
    ("$100 Vanilla® Visa", 1),
    ("Promotional Email GiftCard $10", 1),
    ("GPS7000 GPS Tracker for", 1),
    ("$50 Vanilla Visa eGift", 1),
    ("GPS7000 GPS Tracker for", 1),
    ("$200 Vanilla Visa Shiny", 1),
    ("$200 Vanilla Visa Shiny", 1),
    ("$200 Vanilla Visa Shiny", 1),
    ("$200 Vanilla® Visa", 1),
    ("Promotional Email GiftCard $20", 1),
    ("Xbox $100 Gift Card", 1),
    ("$25 Vanilla Visa Shiny", 1),
    ("$25 Vanilla Visa Shiny", 1),
    ("$25 Vanilla Visa Shiny", 1),
    ("Xbox $100 Gift Card", 1),
    ("Home Depot Birthday Cupcake", 1),
    ("$100 Vanilla® Visa", 1),
    ("Birthday Celebration Target Giftcard", 1),
    ("Birthday Celebration Target Giftcard", 1),
    ("Promotional Email GiftCard $10", 1),
    ("Birthday Celebration Target Giftcard", 1),
    ("Nintendo Switch Family Online", 1),
    ("Promotional Email GiftCard $10", 1),
    ("$50 Vanilla® Mastercard", 1),
    ("$50 Vanilla® Mastercard", 1),
    ("$50 Vanilla® Mastercard", 1),
    ("$50 Vanilla® Visa", 1),
    ("$50 Vanilla Visa Shiny", 1),
    ("Delta Airlines Wedding $250", 1),
]

normal_dataset = [
    ("Kerrygold Grass-Fed Pure Irish Garlic & Herb Butter Stick, 3", 0),
    ("bettergoods Garlic, Parmesan, & Basil Butter, 3 oz", 0),
    ("Birds Eye Savory Herb Riced Cauliflower, 10 oz (Frozen)", 0),
    ("Great Value Root Blend, Beets, Carrots, Parsnips and Sweet Potatoes", 0),
    ("Fresh Blueberries, 18 oz Container", 0),
    ("Mixpresso 3 Piece Black Canisters Sets For The Kitchen, Kitchen Jars With", 0),
    ("Freshness Guaranteed Chicken Breast Tenderloins, 2.25 - 3.2", 0),
    ("Kiolbassa Smoked Meats Beef Hickory Smoked Sausage, 4 links - 13oz", 0),
    ("Hot Pockets Frozen Snacks, Pepperoni Pizza Buttery Crust, 5 Sandwiches", 0),
    ("Kool Aid Jammers Tropical Punch Kids Drink 0% Juice Box Pouches, 10", 0),
    ("Frito-Lay Flavor Mix Variety Pack Snack Chips, 1oz Bags, 18 Count", 0),
    ("State Fair Classic Corn Dogs, 42.7 oz, 16 Count", 0),
    ("ASURION 2 Year Sporting Goods Protection Plan ($175 - $199.99)", 0),
    ("6% Incline Walking Pad Treadmill 320+ lb Capacity, Under The Desk", 0),
    ("Renpure Biotin & Collagen Thickening Conditioner for All Hair Types, 32 fl", 0),
    ("Renpure Biotin & Collagen Thickening Hair Shampoo for All Hair Types, 32", 0),
    ("eos Shea Better Body Lotion for Dry Skin, Vanilla Cashmere, 16 fl", 0),
    ("Degree Ultra Clear Long Lasting Men's Antiperspirant Deodorant Dry Spray,", 0),
    ("Tide PODS Liquid Laundry Detergent, Original Scent, HE Compatible, 42 Count", 0),
    ("DEER PARK Brand 100% Natural Spring Water, 16.9-ounce", 0),
    ("Great Value Milk Whole Vitamin D Gallon Plastic Jug", 0),
    ("Jumbo Russet Potatoes Whole Fresh, 8 lb Bag", 0),
    ("Great Value Butter Pecan Flavored Ice Cream, 16 fl oz", 0),
    ("Beef Lean Stew Meat, 1.0 - 1.5 lb Tray", 0),
    ("Great Value Spaghetti 16oz", 0),
    ("Great Value Flavored with Meat Pasta Sauce, 24 oz", 0),
    ("Kentucky Kernel Original Seasoned Flour, Coating Mix for Frying, Value Size", 0),
]

# Full labelled corpus: restricted examples first, then normal ones.
dataset = restricted_dataset + normal_dataset
# Split dataset into 80% training / 20% validation.
texts = [text for text, _ in dataset]
labels = [label for _, label in dataset]

# stratify keeps the restricted/normal ratio the same in both splits, which
# matters for a corpus this small; random_state makes the split reproducible.
# NOTE(review): the restricted examples contain verbatim duplicates, so the
# same text can land in both splits and make validation look optimistic.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)
# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize data
# truncation=True clips over-long inputs to the model's maximum length;
# padding=True pads every sequence to the longest one in the same call, so
# the training and validation encodings may end up with different widths.
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
# Convert to torch Dataset
class ShoppingCartDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with their labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``) to a
            sequence holding one entry per example.
        labels: Sequence of labels, aligned index-for-index with the
            entries in ``encodings``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx: int) -> dict:
        """Return example ``idx`` as a dict of tensors, plus a ``labels`` key."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # The label is stored under 'labels' next to the encoded features.
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self) -> int:
        """Return the number of examples (one per label)."""
        return len(self.labels)
# Wrap the tokenized splits so they can be indexed example by example.
train_dataset = ShoppingCartDataset(train_encodings, train_labels)
val_dataset = ShoppingCartDataset(val_encodings, val_labels)

# Load pre-trained BERT model
# num_labels=2 spells out the two-class (normal = 0 / restricted = 1) head
# instead of relying on the library default.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# Training arguments
NUM_TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 16

# Estimate the number of optimizer steps (single device, no gradient
# accumulation) so the warm-up and logging intervals fit the actual run.
# With the small corpus in this file the whole run is only about 9 steps, so
# a fixed 500-step warm-up would keep the learning rate near zero throughout
# and logging every 10 steps would never trigger during training.
steps_per_epoch = math.ceil(len(train_dataset) / TRAIN_BATCH_SIZE)
total_train_steps = steps_per_epoch * NUM_TRAIN_EPOCHS

training_args = TrainingArguments(
    output_dir='../results',
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=16,
    # Warm up over ~10% of the run, capped at the original 500 steps.
    warmup_steps=min(500, max(1, total_train_steps // 10)),
    weight_decay=0.01,
    logging_dir='./logs',
    # Log at least once per epoch, and no less often than every 10 steps.
    logging_steps=max(1, min(10, steps_per_epoch)),
)
# Trainer
# Ties the model, the training configuration and both data splits together;
# val_dataset is the split used by trainer.evaluate().
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
# Train model
trainer.train()

# Evaluate model
# Keep and print the metrics so the result of the run is visible.
eval_metrics = trainer.evaluate()
print(eval_metrics)

# Save the fine-tuned weights together with the tokenizer files so that
# 'trained_model' can be reloaded on its own with from_pretrained().
model.save_pretrained('trained_model')
tokenizer.save_pretrained('trained_model')