entropy committed on
Commit
20d0936
1 Parent(s): 963134f

Update train_script.py


Added loading of a pretrained model

Files changed (1)
  1. train_script.py +7 -2
train_script.py CHANGED
@@ -4,7 +4,7 @@ import os
 import torch
 import torch.nn as nn
 
-from transformers import GPT2TokenizerFast, GPT2LMHeadModel
+from transformers import GPT2TokenizerFast, GPT2LMHeadModel, AutoModelForCausalLM
 from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
 from transformers import Trainer, TrainingArguments, RobertaTokenizerFast
 
@@ -58,7 +58,7 @@ dataset = dataset.with_format("torch")
 tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)
 collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
 
-
+# train from scratch
 config = GPT2Config(
     vocab_size=len(tokenizer),
     n_positions=TOKENIZER_MAX_LEN,
@@ -71,6 +71,11 @@ config = GPT2Config(
 
 model = ConditionalGPT2LMHeadModel(config)
 
+# alternatively, load a pre-trained model
+# commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
+# model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
+#                                              trust_remote_code=True, revision=commit_hash)
+
 # change trainer args as needed
 args = TrainingArguments(
     output_dir=TRAINER_SAVE_DIR,
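
For reference, a minimal sketch of the commented-out loading path added in this commit, once uncommented. The model id, commit hash, and keyword arguments come from the diff above; the eval() call is an assumption about downstream inference use, not part of the training script.

# Sketch, assuming the transformers library is installed.
# Pinning `revision` to a known commit is a safeguard here, since
# trust_remote_code=True executes the model class code shipped in the repo.
from transformers import AutoModelForCausalLM

commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
model = AutoModelForCausalLM.from_pretrained(
    "entropy/roberta_zinc_decoder",
    trust_remote_code=True,  # repo provides the custom ConditionalGPT2LMHeadModel
    revision=commit_hash,    # pin to a known-good commit
)
model.eval()  # assumption: switch to inference mode after loading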