# RealTimeAnswer / utils.py
# Commit 1b1d234 (verified) by GabrielSalem: "Create utils.py" - 1.04 kB
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
def preprocess_data(df, tokenizer):
    """Turn a question/answer DataFrame into a tokenized ``Dataset``.

    Each row is rendered as ``"Question: <Question> Answer: <Answer>"`` and
    stored in a ``text`` column, which is then passed to ``tokenizer`` in
    batches with truncation and padding to a fixed length of 512.

    The input DataFrame is not modified; the ``text`` column is added to a
    copy.

    Args:
        df: pandas DataFrame with ``Question`` and ``Answer`` columns.
        tokenizer: Callable invoked as ``tokenizer(list_of_texts,
            truncation=True, padding="max_length", max_length=512)``.
            NOTE(review): presumably a Hugging Face tokenizer; padding will
            likely require it to have a pad token configured - confirm.

    Returns:
        The ``Dataset`` built from the DataFrame (including the ``text``
        column) with the tokenizer's outputs added by ``Dataset.map``.

    Raises:
        KeyError: If ``df`` lacks a ``Question`` or ``Answer`` column.
    """
    # Build the prompts column-wise instead of with ``df.apply(axis=1)``:
    # row-wise apply returns a DataFrame (not a Series) for an empty frame,
    # which breaks the column assignment, and it upcasts each row to a common
    # dtype before formatting. Zipping the columns handles zero rows and
    # formats every value with its own column's dtype.
    texts = [
        f"Question: {question} Answer: {answer}"
        for question, answer in zip(df["Question"], df["Answer"])
    ]

    # ``assign`` returns a new frame, so the caller's DataFrame is left as-is.
    dataset = Dataset.from_pandas(df.assign(text=texts))

    def _tokenize(batch):
        # ``batched=True`` means ``batch["text"]`` is a list of strings.
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
        )

    return dataset.map(_tokenize, batched=True)
def train_model(
    model,
    tokenizer,
    dataset,
    output_dir,
    *,
    per_device_train_batch_size=4,
    num_train_epochs=1,
    logging_dir="./logs",
    save_steps=10,
    logging_steps=10,
):
    """Fine-tune ``model`` on ``dataset`` and save the model and tokenizer.

    Training runs through ``transformers.Trainer`` with a
    ``DataCollatorForLanguageModeling`` configured with ``mlm=False``
    (i.e. not masked-language-modeling). After training, both the model and
    the tokenizer are written to ``output_dir`` via ``save_pretrained``.

    The keyword-only arguments expose the hyperparameters that were
    previously hard-coded; their defaults reproduce the earlier behavior, so
    existing ``train_model(model, tokenizer, dataset, output_dir)`` calls are
    unaffected.

    Args:
        model: Model passed to ``Trainer`` and saved with
            ``model.save_pretrained``.
        tokenizer: Tokenizer given to the data collator and saved with
            ``tokenizer.save_pretrained``.
        dataset: Training dataset handed to ``Trainer`` as ``train_dataset``.
            NOTE(review): presumably the output of ``preprocess_data`` -
            confirm against callers.
        output_dir: Directory for ``TrainingArguments`` output and for the
            final saved model and tokenizer.
        per_device_train_batch_size: Forwarded to ``TrainingArguments``.
        num_train_epochs: Forwarded to ``TrainingArguments``.
        logging_dir: Forwarded to ``TrainingArguments``.
        save_steps: Forwarded to ``TrainingArguments``.
        logging_steps: Forwarded to ``TrainingArguments``.

    Returns:
        None.
    """
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        logging_dir=logging_dir,
        save_steps=save_steps,
        logging_steps=logging_steps,
    )

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    trainer.train()

    # Save the final weights and the tokenizer side by side in ``output_dir``.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)