from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

# Load the tokenizer from the same checkpoint as the model so the vocabulary
# and special tokens match the model's embeddings.
tokenizer = AutoTokenizer.from_pretrained("Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum", from_tf=True)
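
# --- Sketch: building a training dataset (not part of the original snippet) ---
# Seq2SeqTrainer below expects a tokenized dataset exposing `input_ids` and `labels`.
# The file name ("prompts.csv") and column names ("source_text", "target_text") used
# here are hypothetical placeholders; adapt them to your own data.
from datasets import load_dataset

raw_dataset = load_dataset("csv", data_files="prompts.csv")["train"]  # hypothetical CSV

def tokenize_example(example):
    model_inputs = tokenizer(example["source_text"], max_length=512, truncation=True)
    labels = tokenizer(text_target=example["target_text"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

your_training_dataset = raw_dataset.map(tokenize_example, remove_columns=raw_dataset.column_names)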

# Dynamically pad inputs and labels to the longest sequence in each batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir="./gpt4-text-gen",
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=your_training_dataset,  # a tokenized dataset providing input_ids and labels
)

trainer.train()

# Persist the fine-tuned model and tokenizer so they can be reloaded later.
model.save_pretrained("./gpt4-text-gen")
tokenizer.save_pretrained("./gpt4-text-gen")

# Generate a continuation for a sample prompt; do_sample=True is required for the
# top_k/top_p sampling parameters to actually take effect.
input_text = "Once upon a time"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
output = model.generate(input_ids, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2, do_sample=True, top_k=50, top_p=0.95)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:", generated_text)