multimodalart HF staff commited on
Commit
1f2f15c
·
verified ·
1 Parent(s): 77e039c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -5
app.py CHANGED
@@ -1,8 +1,50 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  with gr.Blocks() as demo:
4
- gr.Markdown('''# CosXL unofficial demo
5
- CosXL is a SDXL model tuned to produce full color range images. CosXL Edit allows you to perform edits on images. Both models have a [non-commercial community license](https://huggingface.co/stabilityai/cosxl/blob/main/LICENSE)
6
  ''')
7
  with gr.Tab("CosXL"):
8
  with gr.Group():
@@ -14,13 +56,22 @@ with gr.Blocks() as demo:
14
  pass
15
  with gr.Tab("CosXL Edit"):
16
  with gr.Group():
17
- image_edit = gr.Image(label="Image you would like to edit")
18
  with gr.Row():
19
  prompt_edit = gr.Textbox(show_label=False, scale=4, placeholder="Edit instructions, e.g.: Make the day cloudy")
20
  button_edit = gr.Button("Generate", min_width=120)
21
  output_edit = gr.Image(label="Your result image", interactive=False)
22
  with gr.Accordion("Advanced Settings", open=False):
23
  pass
24
-
 
 
 
 
 
 
 
 
 
25
  if __name__ == "__main__":
26
- demo.launch()
 
1
  import gradio as gr
2
+ from diffusers import StableDiffusionXLPipeline, EDMEulerScheduler
3
+ from custom_pipeline import CosStableDiffusionXLInstructPix2PixPipeline
4
+ import spaces
5
+
6
+ edit_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
7
+ normal_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl.safetensors")
8
+
9
+ def set_timesteps_patched(self, num_inference_steps: int, device = None):
10
+ self.num_inference_steps = num_inference_steps
11
+
12
+ ramp = np.linspace(0, 1, self.num_inference_steps)
13
+ sigmas = torch.linspace(math.log(self.config.sigma_min), math.log(self.config.sigma_max), len(ramp)).exp().flip(0)
14
+
15
+ sigmas = (sigmas).to(dtype=torch.float32, device=device)
16
+ self.timesteps = self.precondition_noise(sigmas)
17
+
18
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
19
+ self._step_index = None
20
+ self._begin_index = None
21
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
22
+
23
+ EDMEulerScheduler.set_timesteps = set_timesteps_patched
24
+
25
+ pipe_edit = CosStableDiffusionXLInstructPix2PixPipeline.from_single_file(
26
+ edit_file, num_in_channels=8
27
+ )
28
+ pipe_edit.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
29
+ pipe_edit.to("cuda")
30
+
31
+ pipe_normal = StableDiffusionXLPipeline.from_single_file(normal_file, torch_dtype=torch.float16)
32
+ pipe_normal.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
33
+ pipe_normal.to("cuda")
34
+
35
+ @spaces.GPU
36
+ def run_normal(prompt):
37
+ return pipe_normal(prompt, num_inference_steps=20).images[0]
38
+
39
+ @spaces.GPU
40
+ def run_edit(image, prompt):
41
+ resolution = 1024
42
+ image.resize((resolution, resolution))
43
+ return pipe_edit(prompt=prompt,image=image,height=resolution,width=resolution,num_inference_steps=20).images[0]
44
 
45
  with gr.Blocks() as demo:
46
+ gr.Markdown('''# CosXL demo
47
+ Unofficial demo for CosXL, a SDXL model tuned to produce full color range images. CosXL Edit allows you to perform edits on images. Both have a [non-commercial community license](https://huggingface.co/stabilityai/cosxl/blob/main/LICENSE)
48
  ''')
49
  with gr.Tab("CosXL"):
50
  with gr.Group():
 
56
  pass
57
  with gr.Tab("CosXL Edit"):
58
  with gr.Group():
59
+ image_edit = gr.Image(label="Image you would like to edit", type="pil")
60
  with gr.Row():
61
  prompt_edit = gr.Textbox(show_label=False, scale=4, placeholder="Edit instructions, e.g.: Make the day cloudy")
62
  button_edit = gr.Button("Generate", min_width=120)
63
  output_edit = gr.Image(label="Your result image", interactive=False)
64
  with gr.Accordion("Advanced Settings", open=False):
65
  pass
66
+ button_normal.click(
67
+ fn=run_normal,
68
+ inputs=[prompt_normal],
69
+ outputs=[output_normal]
70
+ )
71
+ button_edit.click(
72
+ fn=run_edit,
73
+ inputs=[image_edit, prompt_edit],
74
+ outputs=[output_edit]
75
+ )
76
  if __name__ == "__main__":
77
+ demo.launch(share=True)