VinitT commited on
Commit
2c2a1bc
1 Parent(s): 4f5f162

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -19
app.py CHANGED
@@ -6,8 +6,6 @@ from diffusers import StableDiffusionPipeline
6
  from transformers import CLIPTokenizer
7
  import os
8
  import zipfile
9
- import tempfile
10
- import shutil
11
  import gradio as gr
12
 
13
  # Define the device
@@ -104,26 +102,18 @@ def zip_model(model_path):
104
  zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
105
  return zip_path
106
 
 
 
 
 
 
 
 
107
  # Gradio interface functions
108
  def start_fine_tuning(uploaded_files, prompts, num_epochs):
109
- # Create a temporary directory for storing files
110
- temp_dir = tempfile.mkdtemp()
111
- print("Temporary directory:", temp_dir)
112
-
113
- images = []
114
- for file in uploaded_files:
115
- # Store the uploaded file in the temp directory
116
- image_path = os.path.join(temp_dir, file.name)
117
- with open(image_path, 'wb') as f:
118
- f.write(file.read()) # Save file content
119
- images.append(Image.open(image_path).convert("RGB"))
120
-
121
  model_save_path = "fine_tuned_model"
122
  fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
123
-
124
- # Clean up the temporary directory after fine-tuning
125
- shutil.rmtree(temp_dir)
126
-
127
  return "Fine-tuning completed! Model is ready for download."
128
 
129
  def download_model():
@@ -173,4 +163,4 @@ with gr.Blocks() as demo:
173
 
174
  generate_button.click(generate_new_image, [prompt_input], generated_image)
175
 
176
- demo.launch()
 
6
  from transformers import CLIPTokenizer
7
  import os
8
  import zipfile
 
 
9
  import gradio as gr
10
 
11
  # Define the device
 
102
  zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
103
  return zip_path
104
 
105
+ # Function to save uploaded files
106
+ def save_uploaded_file(uploaded_file, save_path):
107
+ # Open the file in binary write mode
108
+ with open(save_path, 'wb') as f:
109
+ f.write(uploaded_file.data) # Use .data for the file content
110
+ return f"File saved at {save_path}"
111
+
112
  # Gradio interface functions
113
  def start_fine_tuning(uploaded_files, prompts, num_epochs):
114
+ images = [Image.open(file).convert("RGB") for file in uploaded_files]
 
 
 
 
 
 
 
 
 
 
 
115
  model_save_path = "fine_tuned_model"
116
  fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
 
 
 
 
117
  return "Fine-tuning completed! Model is ready for download."
118
 
119
  def download_model():
 
163
 
164
  generate_button.click(generate_new_image, [prompt_input], generated_image)
165
 
166
+ demo.launch()