# Author: Sheng Lei
# Commit 28de1fd — "Add application file"
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
import torch
# Labelled training data, defined inline below.
# dataset is a list of (text, label) tuples, e.g.
# dataset = [("item description 1", 0), ("item description 2", 1), ...]
#
# Label 1: items treated as restricted (gift cards, prepaid cards, trackers, ...).
# Several titles repeat (e.g. "Promotional Email GiftCard $10" appears four
# times), so identical texts can land in both the train and validation
# splits produced by train_test_split, which inflates validation scores.
restricted_dataset = [
("Promotional Email GiftCard $10", 1),
("$100 Vanilla® Visa", 1),
("Promotional Email GiftCard $10", 1),
("GPS7000 GPS Tracker for", 1),
("$50 Vanilla Visa eGift", 1),
("GPS7000 GPS Tracker for", 1),
("$200 Vanilla Visa Shiny", 1),
("$200 Vanilla Visa Shiny", 1),
("$200 Vanilla Visa Shiny", 1),
("$200 Vanilla® Visa", 1),
("Promotional Email GiftCard $20", 1),
("Xbox $100 Gift Card", 1),
("$25 Vanilla Visa Shiny", 1),
("$25 Vanilla Visa Shiny", 1),
("$25 Vanilla Visa Shiny", 1),
("Xbox $100 Gift Card", 1),
("Home Depot Birthday Cupcake", 1),
("$100 Vanilla® Visa", 1),
("Birthday Celebration Target Giftcard", 1),
("Birthday Celebration Target Giftcard", 1),
("Promotional Email GiftCard $10", 1),
("Birthday Celebration Target Giftcard", 1),
("Nintendo Switch Family Online", 1),
("Promotional Email GiftCard $10", 1),
("$50 Vanilla® Mastercard", 1),
("$50 Vanilla® Mastercard", 1),
("$50 Vanilla® Mastercard", 1),
("$50 Vanilla® Visa", 1),
("$50 Vanilla Visa Shiny", 1),
("Delta Airlines Wedding $250", 1)
]
# Label 0: ordinary shopping-cart items (groceries, household, personal care).
# NOTE(review): many titles look truncated mid-phrase (e.g. "..., 3",
# "..., 32 fl"); presumably they were cut off when collected — confirm.
normal_dataset =[
("Kerrygold Grass-Fed Pure Irish Garlic & Herb Butter Stick, 3", 0),
("bettergoods Garlic, Parmesan, & Basil Butter, 3 oz", 0),
("Birds Eye Savory Herb Riced Cauliflower, 10 oz (Frozen)", 0),
("Great Value Root Blend, Beets, Carrots, Parsnips and Sweet Potatoes", 0),
("Fresh Blueberries, 18 oz Container", 0),
("Mixpresso 3 Piece Black Canisters Sets For The Kitchen, Kitchen Jars With", 0),
("Freshness Guaranteed Chicken Breast Tenderloins, 2.25 - 3.2", 0),
("Kiolbassa Smoked Meats Beef Hickory Smoked Sausage, 4 links - 13oz", 0),
("Hot Pockets Frozen Snacks, Pepperoni Pizza Buttery Crust, 5 Sandwiches", 0),
("Kool Aid Jammers Tropical Punch Kids Drink 0% Juice Box Pouches, 10", 0),
("Frito-Lay Flavor Mix Variety Pack Snack Chips, 1oz Bags, 18 Count", 0),
("State Fair Classic Corn Dogs, 42.7 oz, 16 Count", 0),
("ASURION 2 Year Sporting Goods Protection Plan ($175 - $199.99)", 0),
("6% Incline Walking Pad Treadmill 320+ lb Capacity, Under The Desk", 0),
("Renpure Biotin & Collagen Thickening Conditioner for All Hair Types, 32 fl", 0),
("Renpure Biotin & Collagen Thickening Hair Shampoo for All Hair Types, 32", 0),
("eos Shea Better Body Lotion for Dry Skin, Vanilla Cashmere, 16 fl", 0),
("Degree Ultra Clear Long Lasting Men's Antiperspirant Deodorant Dry Spray,", 0),
("Tide PODS Liquid Laundry Detergent, Original Scent, HE Compatible, 42 Count", 0),
("DEER PARK Brand 100% Natural Spring Water, 16.9-ounce", 0),
("Great Value Milk Whole Vitamin D Gallon Plastic Jug", 0),
("Jumbo Russet Potatoes Whole Fresh, 8 lb Bag", 0),
("Great Value Butter Pecan Flavored Ice Cream, 16 fl oz", 0),
("Beef Lean Stew Meat, 1.0 - 1.5 lb Tray", 0),
("Great Value Spaghetti 16oz", 0),
("Great Value Flavored with Meat Pasta Sauce, 24 oz", 0),
("Kentucky Kernel Original Seasoned Flour, Coating Mix for Frying, Value Size", 0)
]
# All labelled examples (restricted first, then normal); train_test_split
# shuffles them, so this ordering does not leak into the splits.
dataset = restricted_dataset + normal_dataset
# Split dataset into train/validation text and label lists.
# stratify keeps the restricted/normal ratio identical in both splits (with
# so few examples an unstratified split can noticeably skew the validation
# set), and random_state makes the split reproducible between runs.
all_texts = [text for text, _ in dataset]
all_labels = [label for _, label in dataset]
train_texts, val_texts, train_labels, val_labels = train_test_split(
    all_texts,
    all_labels,
    test_size=0.2,
    stratify=all_labels,
    random_state=42,
)
# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Tokenize each split; padding=True pads every text to the longest one in the
# list passed in, and truncation=True caps length at the model maximum.
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
# Convert to torch Dataset
class ShoppingCartDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with integer class labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to a per-example sequence, such as the output
            of a Hugging Face tokenizer called on a list of texts.
        labels: One class label per example, aligned with ``encodings``.

    Raises:
        ValueError: If any encoding feature holds a different number of
            examples than ``labels``.
    """

    def __init__(self, encodings, labels):
        # Fail fast on misaligned inputs instead of raising IndexError (or
        # silently dropping examples) partway through training.
        for key, values in encodings.items():
            if len(values) != len(labels):
                raise ValueError(
                    f"encoding {key!r} has {len(values)} examples "
                    f"but {len(labels)} labels were given"
                )
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors plus a ``labels`` tensor."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # The 'labels' key is what the Trainer forwards to the model for the loss.
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of labelled examples."""
        return len(self.labels)
# Wrap both tokenized splits so the Trainer can index them example by example.
train_dataset, val_dataset = (
    ShoppingCartDataset(train_encodings, train_labels),
    ShoppingCartDataset(val_encodings, val_labels),
)
# Load pre-trained BERT with a two-class sequence-classification head
# (restricted vs. normal); num_labels=2 is spelled out for clarity.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
)
# Training arguments
# With the current 57-example dataset and an 80/20 split, the training set
# holds about 45 examples: at batch size 16 for 3 epochs that is only ~9
# optimizer steps on a single device. Warmup is therefore a ratio of the
# total steps: a fixed 500-step warmup would never finish, leaving the
# learning rate far below its peak for the entire run. logging_steps=1
# likewise ensures the training loss is actually logged (an interval of 10
# is longer than the whole run).
training_args = TrainingArguments(
    # NOTE(review): output goes to the parent directory while logs go to
    # ./logs — confirm this split is intended.
    output_dir='../results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=1,
)
def _compute_accuracy(eval_pred):
    """Return validation accuracy from a Trainer ``EvalPrediction``.

    Assumes the predictions are the classifier's logits with one row per
    example (what BertForSequenceClassification returns by default).
    """
    logits, label_ids = eval_pred
    predicted = logits.argmax(axis=-1)
    return {"accuracy": float((predicted == label_ids).mean())}


# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Without compute_metrics, evaluate() reports only the loss.
    compute_metrics=_compute_accuracy,
)
# Train model
trainer.train()
# Evaluate model and report the metrics rather than discarding them.
eval_metrics = trainer.evaluate()
print(eval_metrics)
# Save the tokenizer alongside the model so 'trained_model' can be reloaded
# with from_pretrained for both.
model.save_pretrained('trained_model')
tokenizer.save_pretrained('trained_model')