Spaces:
Running
Running
NotShrirang
commited on
Commit
·
bdc8737
1
Parent(s):
a865eda
fix: bugs
Browse files
model_finetuning/components/model_training_component.py
CHANGED
@@ -18,7 +18,7 @@ import os
|
|
18 |
|
19 |
def model_training():
|
20 |
dataset_path = st.session_state.get("selected_dataset", None)
|
21 |
-
if not dataset_path:
|
22 |
st.error("Please select a dataset to proceed.")
|
23 |
return
|
24 |
|
@@ -36,6 +36,8 @@ def model_training():
|
|
36 |
|
37 |
test_size = st.selectbox("Select Test Size", options=[0.1, 0.2, 0.3, 0.4, 0.5], index=1)
|
38 |
train_df, val_df = train_test_split(annotations_df, test_size=test_size, random_state=42)
|
|
|
|
|
39 |
st.write(f"Train Size: {len(train_df)} | Validation Size: {len(val_df)}")
|
40 |
col1, col2 = st.columns(2)
|
41 |
with col1:
|
@@ -50,6 +52,8 @@ def model_training():
|
|
50 |
for batch_size in batch_size_options:
|
51 |
if batch_size > ideal_batch_size:
|
52 |
ideal_batch_size_index = batch_size_options.index(batch_size) - 1
|
|
|
|
|
53 |
break
|
54 |
batch_size = st.selectbox("Select Batch Size", options=[2, 4, 8, 16, 32, 64, 128], index=ideal_batch_size_index)
|
55 |
|
|
|
18 |
|
19 |
def model_training():
|
20 |
dataset_path = st.session_state.get("selected_dataset", None)
|
21 |
+
if not dataset_path or dataset_path == "":
|
22 |
st.error("Please select a dataset to proceed.")
|
23 |
return
|
24 |
|
|
|
36 |
|
37 |
test_size = st.selectbox("Select Test Size", options=[0.1, 0.2, 0.3, 0.4, 0.5], index=1)
|
38 |
train_df, val_df = train_test_split(annotations_df, test_size=test_size, random_state=42)
|
39 |
+
if len(train_df) < 2:
|
40 |
+
st.error("Not enough data to train the model.")
|
41 |
st.write(f"Train Size: {len(train_df)} | Validation Size: {len(val_df)}")
|
42 |
col1, col2 = st.columns(2)
|
43 |
with col1:
|
|
|
52 |
for batch_size in batch_size_options:
|
53 |
if batch_size > ideal_batch_size:
|
54 |
ideal_batch_size_index = batch_size_options.index(batch_size) - 1
|
55 |
+
if ideal_batch_size_index < 0:
|
56 |
+
ideal_batch_size_index = 0
|
57 |
break
|
58 |
batch_size = st.selectbox("Select Batch Size", options=[2, 4, 8, 16, 32, 64, 128], index=ideal_batch_size_index)
|
59 |
|