Kevin Fink
committed on
Commit
·
bc59d39
1
Parent(s):
69cfd5f
dev
Browse files
app.py
CHANGED
@@ -83,6 +83,8 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
83 |
print("Loading model from checkpoint...")
|
84 |
model = AutoModelForSeq2SeqLM.from_pretrained(training_args.output_dir)
|
85 |
|
|
|
|
|
86 |
def tokenize_function(examples):
|
87 |
|
88 |
# Assuming 'text' is the input and 'target' is the expected output
|
@@ -115,7 +117,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
115 |
third_size = train_size // 3
|
116 |
max_length = model.get_input_embeddings().weight.shape[0]
|
117 |
try:
|
118 |
-
saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
|
119 |
if 'test' in saved_dataset.keys():
|
120 |
print("FOUND TEST")
|
121 |
# Create Trainer
|
@@ -144,7 +146,6 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
144 |
return
|
145 |
|
146 |
except:
|
147 |
-
tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')
|
148 |
# Tokenize the dataset
|
149 |
first_third = dataset['train'].select(range(third_size))
|
150 |
dataset['train'] = first_third
|
|
|
83 |
print("Loading model from checkpoint...")
|
84 |
model = AutoModelForSeq2SeqLM.from_pretrained(training_args.output_dir)
|
85 |
|
86 |
+
tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')
|
87 |
+
|
88 |
def tokenize_function(examples):
|
89 |
|
90 |
# Assuming 'text' is the input and 'target' is the expected output
|
|
|
117 |
third_size = train_size // 3
|
118 |
max_length = model.get_input_embeddings().weight.shape[0]
|
119 |
try:
|
120 |
+
saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
|
121 |
if 'test' in saved_dataset.keys():
|
122 |
print("FOUND TEST")
|
123 |
# Create Trainer
|
|
|
146 |
return
|
147 |
|
148 |
except:
|
|
|
149 |
# Tokenize the dataset
|
150 |
first_third = dataset['train'].select(range(third_size))
|
151 |
dataset['train'] = first_third
|