Ffftdtd5dtft commited on
Commit
2087ed7
·
verified ·
1 Parent(s): e0fce48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -21
app.py CHANGED
@@ -72,31 +72,38 @@ def get_model_or_download(model_id, redis_key, loader_func):
72
 
73
  def generate_image(prompt):
74
  redis_key = f"generated_image:{prompt}"
75
- image = load_object_from_redis(redis_key)
76
- if not image:
77
  try:
78
  with tqdm(total=1, desc="Generating image") as pbar:
79
  image = text_to_image_pipeline(prompt).images[0]
80
  pbar.update(1)
81
- save_object_to_redis(redis_key, image)
 
 
 
82
  except Exception as e:
83
  print(f"Failed to generate image: {e}")
84
  return None
85
- return image
86
 
87
- def edit_image_with_prompt(image, prompt, strength=0.75):
88
  redis_key = f"edited_image:{prompt}:{strength}"
89
- edited_image = load_object_from_redis(redis_key)
90
- if not edited_image:
91
  try:
 
92
  with tqdm(total=1, desc="Editing image") as pbar:
93
  edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
94
  pbar.update(1)
95
- save_object_to_redis(redis_key, edited_image)
 
 
 
96
  except Exception as e:
97
  print(f"Failed to edit image: {e}")
98
  return None
99
- return edited_image
100
 
101
  def generate_song(prompt, duration=10):
102
  redis_key = f"generated_song:{prompt}:{duration}"
@@ -128,8 +135,8 @@ def generate_text(prompt):
128
 
129
  def generate_flux_image(prompt):
130
  redis_key = f"generated_flux_image:{prompt}"
131
- flux_image = load_object_from_redis(redis_key)
132
- if not flux_image:
133
  try:
134
  with tqdm(total=1, desc="Generating FLUX image") as pbar:
135
  flux_image = flux_pipeline(
@@ -140,11 +147,14 @@ def generate_flux_image(prompt):
140
  generator=torch.Generator("cpu").manual_seed(0)
141
  ).images[0]
142
  pbar.update(1)
143
- save_object_to_redis(redis_key, flux_image)
 
 
 
144
  except Exception as e:
145
  print(f"Failed to generate flux image: {e}")
146
  return None
147
- return flux_image
148
 
149
  def generate_code(prompt):
150
  redis_key = f"generated_code:{prompt}"
@@ -240,19 +250,19 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
240
  text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
241
  img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
242
  flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
243
- text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b")
244
  music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody')
245
  meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
246
  starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
247
  starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
248
 
249
- gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Image")
250
- edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Image")
251
- generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs")
252
- generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text")
253
- generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images")
254
- generate_code_tab = gr.Interface(generate_code, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Code:"), title="Generate Code")
255
- model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.inputs.Textbox(label="Test Input:"), gr.outputs.Textbox(label="Model Output:"), title="Test Meta-Llama")
256
 
257
  app = gr.TabbedInterface(
258
  [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, generate_code_tab, model_meta_llama_test_tab],
 
72
 
73
  def generate_image(prompt):
74
  redis_key = f"generated_image:{prompt}"
75
+ image_bytes = load_object_from_redis(redis_key)
76
+ if not image_bytes:
77
  try:
78
  with tqdm(total=1, desc="Generating image") as pbar:
79
  image = text_to_image_pipeline(prompt).images[0]
80
  pbar.update(1)
81
+ buffered = io.BytesIO()
82
+ image.save(buffered, format="JPEG")
83
+ image_bytes = buffered.getvalue()
84
+ save_object_to_redis(redis_key, image_bytes)
85
  except Exception as e:
86
  print(f"Failed to generate image: {e}")
87
  return None
88
+ return image_bytes
89
 
90
+ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
91
  redis_key = f"edited_image:{prompt}:{strength}"
92
+ edited_image_bytes = load_object_from_redis(redis_key)
93
+ if not edited_image_bytes:
94
  try:
95
+ image = Image.open(io.BytesIO(image_bytes))
96
  with tqdm(total=1, desc="Editing image") as pbar:
97
  edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
98
  pbar.update(1)
99
+ buffered = io.BytesIO()
100
+ edited_image.save(buffered, format="JPEG")
101
+ edited_image_bytes = buffered.getvalue()
102
+ save_object_to_redis(redis_key, edited_image_bytes)
103
  except Exception as e:
104
  print(f"Failed to edit image: {e}")
105
  return None
106
+ return edited_image_bytes
107
 
108
  def generate_song(prompt, duration=10):
109
  redis_key = f"generated_song:{prompt}:{duration}"
 
135
 
136
  def generate_flux_image(prompt):
137
  redis_key = f"generated_flux_image:{prompt}"
138
+ flux_image_bytes = load_object_from_redis(redis_key)
139
+ if not flux_image_bytes:
140
  try:
141
  with tqdm(total=1, desc="Generating FLUX image") as pbar:
142
  flux_image = flux_pipeline(
 
147
  generator=torch.Generator("cpu").manual_seed(0)
148
  ).images[0]
149
  pbar.update(1)
150
+ buffered = io.BytesIO()
151
+ flux_image.save(buffered, format="JPEG")
152
+ flux_image_bytes = buffered.getvalue()
153
+ save_object_to_redis(redis_key, flux_image_bytes)
154
  except Exception as e:
155
  print(f"Failed to generate flux image: {e}")
156
  return None
157
+ return flux_image_bytes
158
 
159
  def generate_code(prompt):
160
  redis_key = f"generated_code:{prompt}"
 
250
  text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
251
  img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
252
  flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
253
+ text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b", device=0)
254
  music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody')
255
  meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
256
  starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
257
  starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
258
 
259
+ gen_image_tab = gr.Interface(fn=generate_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate Image")
260
+ edit_image_tab = gr.Interface(fn=edit_image_with_prompt, inputs=[gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], outputs=gr.Image(type="pil"), title="Edit Image")
261
+ generate_song_tab = gr.Interface(fn=generate_song, inputs=[gr.Textbox(label="Prompt:"), gr.Slider(5, 60, 10, step=1, label="Duration (s):")], outputs=gr.Audio(type="numpy"), title="Generate Songs")
262
+ generate_text_tab = gr.Interface(fn=generate_text, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Textbox(label="Generated Text:"), title="Generate Text")
263
+ generate_flux_image_tab = gr.Interface(fn=generate_flux_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate FLUX Images")
264
+ generate_code_tab = gr.Interface(fn=generate_code, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Textbox(label="Generated Code:"), title="Generate Code")
265
+ model_meta_llama_test_tab = gr.Interface(fn=test_model_meta_llama, inputs=gr.Textbox(label="Test Input:"), outputs=gr.Textbox(label="Model Output:"), title="Test Meta-Llama")
266
 
267
  app = gr.TabbedInterface(
268
  [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, generate_code_tab, model_meta_llama_test_tab],