cocktailpeanut commited on
Commit
24edeed
·
1 Parent(s): 0fe01cf
Files changed (3) hide show
  1. OmniGen/pipeline.py +3 -3
  2. app.py +1 -3
  3. requirements.txt +1 -3
OmniGen/pipeline.py CHANGED
@@ -18,7 +18,7 @@ from diffusers.utils import (
18
  )
19
 
20
  from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
21
-
22
 
23
  logger = logging.get_logger(__name__)
24
 
@@ -52,7 +52,7 @@ class OmniGenPipeline:
52
  self.model = model
53
  self.processor = processor
54
 
55
- self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
56
  self.model.to(self.device)
57
  self.model.eval()
58
  self.vae.to(self.device)
@@ -226,4 +226,4 @@ class OmniGenPipeline:
226
  for i, sample in enumerate(output_samples):
227
  output_images.append(Image.fromarray(sample))
228
 
229
- return output_images
 
18
  )
19
 
20
  from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
21
+ import devicetorch
22
 
23
  logger = logging.get_logger(__name__)
24
 
 
52
  self.model = model
53
  self.processor = processor
54
 
55
+ self.device = devicetorch.get(torch)
56
  self.model.to(self.device)
57
  self.model.eval()
58
  self.vae.to(self.device)
 
226
  for i, sample in enumerate(output_samples):
227
  output_images.append(Image.fromarray(sample))
228
 
229
+ return output_images
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import os
4
- import spaces
5
 
6
  from OmniGen import OmniGenPipeline
7
 
@@ -9,7 +8,6 @@ pipe = OmniGenPipeline.from_pretrained(
9
  "Shitao/OmniGen-v1"
10
  )
11
 
12
- @spaces.GPU(duration=180)
13
  # 示例处理函数:生成图像
14
  def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
15
  input_images = [img1, img2, img3]
@@ -250,4 +248,4 @@ with gr.Blocks() as demo:
250
  )
251
 
252
  # 启动应用
253
- demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import os
 
4
 
5
  from OmniGen import OmniGenPipeline
6
 
 
8
  "Shitao/OmniGen-v1"
9
  )
10
 
 
11
  # 示例处理函数:生成图像
12
  def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
13
  input_images = [img1, img2, img3]
 
248
  )
249
 
250
  # 启动应用
251
+ demo.launch()
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
  accelerate
2
  diffusers
3
  invisible_watermark
4
- torch
5
  transformers
6
- xformers
7
  timm
8
  peft
9
- safetensors
 
1
  accelerate
2
  diffusers
3
  invisible_watermark
 
4
  transformers
 
5
  timm
6
  peft
7
+ safetensors