Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,6 @@ from diffusers import StableDiffusionPipeline
|
|
6 |
from transformers import CLIPTokenizer
|
7 |
import os
|
8 |
import zipfile
|
9 |
-
import tempfile
|
10 |
-
import shutil
|
11 |
import gradio as gr
|
12 |
|
13 |
# Define the device
|
@@ -104,26 +102,18 @@ def zip_model(model_path):
|
|
104 |
zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
|
105 |
return zip_path
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
# Gradio interface functions
|
108 |
def start_fine_tuning(uploaded_files, prompts, num_epochs):
|
109 |
-
|
110 |
-
temp_dir = tempfile.mkdtemp()
|
111 |
-
print("Temporary directory:", temp_dir)
|
112 |
-
|
113 |
-
images = []
|
114 |
-
for file in uploaded_files:
|
115 |
-
# Store the uploaded file in the temp directory
|
116 |
-
image_path = os.path.join(temp_dir, file.name)
|
117 |
-
with open(image_path, 'wb') as f:
|
118 |
-
f.write(file.read()) # Save file content
|
119 |
-
images.append(Image.open(image_path).convert("RGB"))
|
120 |
-
|
121 |
model_save_path = "fine_tuned_model"
|
122 |
fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
|
123 |
-
|
124 |
-
# Clean up the temporary directory after fine-tuning
|
125 |
-
shutil.rmtree(temp_dir)
|
126 |
-
|
127 |
return "Fine-tuning completed! Model is ready for download."
|
128 |
|
129 |
def download_model():
|
@@ -173,4 +163,4 @@ with gr.Blocks() as demo:
|
|
173 |
|
174 |
generate_button.click(generate_new_image, [prompt_input], generated_image)
|
175 |
|
176 |
-
demo.launch()
|
|
|
6 |
from transformers import CLIPTokenizer
|
7 |
import os
|
8 |
import zipfile
|
|
|
|
|
9 |
import gradio as gr
|
10 |
|
11 |
# Define the device
|
|
|
102 |
zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
|
103 |
return zip_path
|
104 |
|
105 |
+
# Function to save uploaded files
|
106 |
+
def save_uploaded_file(uploaded_file, save_path):
|
107 |
+
# Open the file in binary write mode
|
108 |
+
with open(save_path, 'wb') as f:
|
109 |
+
f.write(uploaded_file.data) # Use .data for the file content
|
110 |
+
return f"File saved at {save_path}"
|
111 |
+
|
112 |
# Gradio interface functions
|
113 |
def start_fine_tuning(uploaded_files, prompts, num_epochs):
|
114 |
+
images = [Image.open(file).convert("RGB") for file in uploaded_files]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
model_save_path = "fine_tuned_model"
|
116 |
fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
|
|
|
|
|
|
|
|
|
117 |
return "Fine-tuning completed! Model is ready for download."
|
118 |
|
119 |
def download_model():
|
|
|
163 |
|
164 |
generate_button.click(generate_new_image, [prompt_input], generated_image)
|
165 |
|
166 |
+
demo.launch()
|