Kevin Fink
commited on
Commit
·
53d2cb3
1
Parent(s):
64a72dd
dev
Browse files
app.py
CHANGED
@@ -54,7 +54,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
54 |
num_train_epochs=int(num_epochs),
|
55 |
weight_decay=0.01,
|
56 |
#gradient_accumulation_steps=int(grad),
|
57 |
-
max_grad_norm = 1.0,
|
58 |
load_best_model_at_end=True,
|
59 |
metric_for_best_model="accuracy",
|
60 |
greater_is_better=True,
|
@@ -156,7 +156,7 @@ def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
|
|
156 |
|
157 |
config = AutoConfig.from_pretrained("google/t5-efficient-tiny")
|
158 |
model = AutoModelForSeq2SeqLM.from_config(config)
|
159 |
-
print(model.named_parameters())
|
160 |
initialize_weights(model)
|
161 |
lora_config = LoraConfig(
|
162 |
r=16, # Rank of the low-rank adaptation
|
@@ -165,7 +165,6 @@ def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
|
|
165 |
bias="none" # Bias handling
|
166 |
)
|
167 |
model = get_peft_model(model, lora_config)
|
168 |
-
model.gradient_checkpointing_enable()
|
169 |
result = fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad)
|
170 |
return result
|
171 |
# Create Gradio interface
|
|
|
54 |
num_train_epochs=int(num_epochs),
|
55 |
weight_decay=0.01,
|
56 |
#gradient_accumulation_steps=int(grad),
|
57 |
+
#max_grad_norm = 1.0,
|
58 |
load_best_model_at_end=True,
|
59 |
metric_for_best_model="accuracy",
|
60 |
greater_is_better=True,
|
|
|
156 |
|
157 |
config = AutoConfig.from_pretrained("google/t5-efficient-tiny")
|
158 |
model = AutoModelForSeq2SeqLM.from_config(config)
|
159 |
+
print(list(model.named_parameters()))
|
160 |
initialize_weights(model)
|
161 |
lora_config = LoraConfig(
|
162 |
r=16, # Rank of the low-rank adaptation
|
|
|
165 |
bias="none" # Bias handling
|
166 |
)
|
167 |
model = get_peft_model(model, lora_config)
|
|
|
168 |
result = fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad)
|
169 |
return result
|
170 |
# Create Gradio interface
|