Kevin Fink committed on
Commit
94aee2e
·
1 Parent(s): 50f7a65
Files changed (1)
  1. app.py +9 -7
app.py CHANGED
@@ -86,23 +86,25 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
 
     tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')
 
+    max_length = model.get_input_embeddings().weight.shape[0]
+
     def tokenize_function(examples):
 
         # Assuming 'text' is the input and 'target' is the expected output
         model_inputs = tokenizer(
             examples['text'],
-            #max_length=max_length, # Set to None for dynamic padding
-            #truncation=True,
-            #padding='max_length',
+            max_length=max_length, # Set to None for dynamic padding
+            truncation=True,
+            padding=True,
             return_tensors='pt',
         )
 
         # Setup the decoder input IDs (shifted right)
         labels = tokenizer(
             examples['target'],
-            #max_length=max_length, # Set to None for dynamic padding
-            #truncation=True,
-            #padding='max_length',
+            max_length=max_length, # Set to None for dynamic padding
+            truncation=True,
+            padding=True,
             #text_target=examples['target'],
             return_tensors='pt',
         )
@@ -113,7 +115,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
 
     #max_length = 512
     # Load the dataset
-    max_length = model.get_input_embeddings().weight.shape[0]
+
 
     try:
         saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
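For reference, here is a minimal sketch of the tokenization step as it reads after this commit, applied to a toy dataset with datasets.map. The checkpoint name and the 'text'/'target' columns are taken from the diff; the fixed max_length of 512, the toy examples, and the final labels assignment are illustrative assumptions (the app itself derives max_length from the model's embedding matrix and loads its real data with load_from_disk).

from datasets import Dataset
from transformers import AutoTokenizer

# Tokenizer checkpoint taken from the diff above.
tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')

# Illustrative stand-in; the app computes this value from the loaded model instead.
max_length = 512

def tokenize_function(examples):
    # 'text' is the input and 'target' is the expected output, as in the diff.
    model_inputs = tokenizer(
        examples['text'],
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    labels = tokenizer(
        examples['target'],
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # Assumed wiring of targets into labels; this line is not shown in the diff.
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# Toy data purely for illustration; the Space loads its dataset from disk.
toy = Dataset.from_dict({
    'text': ['translate English to German: Hello world', 'summarize: A long article about T5.'],
    'target': ['Hallo Welt', 'An article about T5.'],
})

tokenized = toy.map(tokenize_function, batched=True)
print(tokenized.column_names)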