from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer, TrainingArguments, Trainer
import torch
# Load Dataset
dataset_path = "train-lf-final.jsonl" # Ensure this file is uploaded
dataset = load_dataset("json", data_files=dataset_path)
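# The JSONL file is assumed (not shown here) to contain one record per line with
# parallel "tokens" and "tags" lists matching the label set defined below, e.g.:
#   {"tokens": ["Gucci", "leather", "belt", "under", "200", "USD"],
#    "tags":   ["B-BRAND", "B-CATEGORY", "I-CATEGORY", "O", "B-PRICE", "I-PRICE"]}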
# Split dataset into training and validation sets
dataset = dataset["train"].train_test_split(test_size=0.1)
# Define label mapping
label_list = ["O", "B-BRAND", "I-BRAND", "B-CATEGORY", "I-CATEGORY", "B-GENDER", "B-PRICE", "I-PRICE"]
label_map = {label: i for i, label in enumerate(label_list)}
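# Optional addition (not in the original script): the inverse mapping is handy for decoding
# predictions, and both maps can be passed to the model via
# AutoModelForTokenClassification.from_pretrained(..., id2label=id2label, label2id=label_map)
# so the saved config reports readable label names instead of LABEL_0, LABEL_1, ...
id2label = {i: label for i, label in enumerate(label_list)}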
# Load Tokenizer
model_name = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Tokenization function
def tokenize_and_align_labels(example):
    """Tokenize one pre-split example and align its word-level tags to sub-word tokens."""
    tokenized_inputs = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, padding="max_length", max_length=128)
    labels = []
    word_ids = tokenized_inputs.word_ids()
    prev_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            # Special tokens ([CLS], [SEP]) and padding are ignored by the loss.
            labels.append(-100)
        elif word_idx != prev_word_idx:
            # The first sub-word of each word carries that word's label.
            labels.append(label_map[example["tags"][word_idx]])
        else:
            # Remaining sub-words of the same word are also ignored by the loss.
            labels.append(-100)
        prev_word_idx = word_idx
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
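# Example (hypothetical record): for {"tokens": ["Nike", "sneakers"], "tags": ["B-BRAND", "B-CATEGORY"]}
# the function returns input_ids/attention_mask padded to 128 positions plus a "labels" list in
# which only the first sub-word of each word holds a label id and every other position is -100.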
# Apply tokenization (one example at a time; the function above is not batch-aware)
tokenized_datasets = dataset.map(tokenize_and_align_labels)
# Load Model
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=len(label_list))
# Training Arguments
training_args = TrainingArguments(
    output_dir="./ner_model",
    evaluation_strategy="epoch",  # note: renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=True,  # requires being logged in to the Hugging Face Hub
    logging_dir="./logs"
)
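# Optional sketch (not part of the original script): a minimal compute_metrics callback that
# reports token-level accuracy over positions whose label is not -100. Entity-level
# precision/recall/F1 via the `seqeval` package is the more common choice for NER.
# To enable it, pass compute_metrics=compute_metrics to the Trainer below.
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    mask = labels != -100  # skip special tokens, padding, and non-initial sub-words
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"token_accuracy": float(accuracy)}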
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer
)
# Train the model
trainer.train()
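# Note: with push_to_hub=True the Trainer can also upload checkpoints itself (or via
# trainer.push_to_hub()); the explicit calls below push the final model and tokenizer
# once the placeholder repo id is replaced with your own Hub namespace.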
# Push to Hugging Face Hub
model.push_to_hub("your-hf-username/distilbert-ner")
tokenizer.push_to_hub("your-hf-username/distilbert-ner")