AideepImage committed on
Commit
84b71f7
•
1 Parent(s): 553ebf1

Update txt2panoimg/text_to_360panorama_image_pipeline.py

Browse files
txt2panoimg/text_to_360panorama_image_pipeline.py CHANGED
@@ -1,19 +1,32 @@
1
- # Copyright © Alibaba, Inc. and its affiliates.
2
  import random
3
  from typing import Any, Dict
4
 
5
  import numpy as np
6
  import torch
7
- from basicsr.archs.rrdbnet_arch import RRDBNet
8
  from diffusers import (ControlNetModel, DiffusionPipeline,
9
  EulerAncestralDiscreteScheduler,
10
  UniPCMultistepScheduler)
11
  from PIL import Image
12
- from realesrgan import RealESRGANer
13
 
14
  from .pipeline_base import StableDiffusionBlendExtendPipeline
15
  from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class Text2360PanoramaImagePipeline(DiffusionPipeline):
19
  """ Stable Diffusion for 360 Panorama Image Generation Pipeline.
@@ -40,7 +53,7 @@ class Text2360PanoramaImagePipeline(DiffusionPipeline):
40
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'
41
  ) if device is None else device
42
  if device == 'gpu':
43
- device = 'cuda'
44
 
45
  torch_dtype = kwargs.get('torch_dtype', torch.float16)
46
  enable_xformers_memory_efficient_attention = kwargs.get(
@@ -60,7 +73,6 @@ class Text2360PanoramaImagePipeline(DiffusionPipeline):
60
  self.pipe.enable_xformers_memory_efficient_attention()
61
  except Exception as e:
62
  print(e)
63
- self.pipe.enable_model_cpu_offload()
64
 
65
  # init controlnet-sr model
66
  base_model_path = model + '/sr-base'
@@ -79,35 +91,15 @@ class Text2360PanoramaImagePipeline(DiffusionPipeline):
79
  self.pipe_sr.enable_xformers_memory_efficient_attention()
80
  except Exception as e:
81
  print(e)
82
- self.pipe_sr.enable_model_cpu_offload()
83
-
84
- # init realesrgan model
85
- sr_model = RRDBNet(
86
- num_in_ch=3,
87
- num_out_ch=3,
88
- num_feat=64,
89
- num_block=23,
90
- num_grow_ch=32,
91
- scale=2)
92
- netscale = 2
93
-
94
  model_path = model + '/RealESRGAN_x2plus.pth'
95
-
96
- dni_weight = None
97
- self.upsampler = RealESRGANer(
98
- scale=netscale,
99
- model_path=model_path,
100
- dni_weight=dni_weight,
101
- model=sr_model,
102
- tile=384,
103
- tile_pad=20,
104
- pre_pad=20,
105
- half=False,
106
- device=device,
107
- )
108
 
109
  @staticmethod
110
  def blend_h(a, b, blend_extent):
 
 
111
  blend_extent = min(a.shape[1], b.shape[1], blend_extent)
112
  for x in range(blend_extent):
113
  b[:, x, :] = a[:, -blend_extent
@@ -188,8 +180,8 @@ class Text2360PanoramaImagePipeline(DiffusionPipeline):
188
  output_img = np.array(output_img)
189
  output_img = np.concatenate(
190
  [output_img, output_img[:, :blend_extend, :]], axis=1)
191
- output_img, _ = self.upsampler.enhance(
192
- output_img, outscale=outscale)
193
  output_img = self.blend_h(output_img, output_img,
194
  blend_extend * outscale)
195
  output_img = Image.fromarray(output_img[:, :w * outscale, :])
 
 
1
  import random
2
  from typing import Any, Dict
3
 
4
  import numpy as np
5
  import torch
 
6
  from diffusers import (ControlNetModel, DiffusionPipeline,
7
  EulerAncestralDiscreteScheduler,
8
  UniPCMultistepScheduler)
9
  from PIL import Image
10
+ from RealESRGAN import RealESRGAN
11
 
12
  from .pipeline_base import StableDiffusionBlendExtendPipeline
13
  from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
14
 
15
class LazyRealESRGAN:
    """Wrapper that defers building the RealESRGAN model until first use.

    The constructor only records configuration. The model is created and its
    weights are loaded on the first call to ``load_model`` or ``predict``.
    """

    def __init__(self, device, scale, model_path=None):
        """Store the configuration without instantiating the model.

        Args:
            device: Device forwarded as the first argument to ``RealESRGAN``.
            scale: Value forwarded to ``RealESRGAN`` as ``scale=``.
            model_path: Optional path to the weights file. It may also be
                assigned later through the ``model_path`` attribute, but it
                must be set before the model is first loaded.
        """
        self.device = device
        self.scale = scale
        self.model = None
        self.model_path = model_path

    def load_model(self):
        """Create the model and load its weights if that has not happened yet.

        Raises:
            ValueError: If ``model_path`` has not been set.
        """
        if self.model is not None:
            return
        if self.model_path is None:
            raise ValueError(
                'LazyRealESRGAN.model_path must be set before the model '
                'can be loaded')
        # Build into a local first: if load_weights raises, self.model stays
        # None so the next call retries instead of reusing a model that has
        # no weights loaded.
        model = RealESRGAN(self.device, scale=self.scale)
        model.load_weights(self.model_path, download=False)
        self.model = model

    def predict(self, img):
        """Load the model on demand and return ``model.predict(img)``."""
        self.load_model()
        return self.model.predict(img)
30
 
31
  class Text2360PanoramaImagePipeline(DiffusionPipeline):
32
  """ Stable Diffusion for 360 Panorama Image Generation Pipeline.
 
53
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'
54
  ) if device is None else device
55
  if device == 'gpu':
56
+ device = torch.device('cuda')
57
 
58
  torch_dtype = kwargs.get('torch_dtype', torch.float16)
59
  enable_xformers_memory_efficient_attention = kwargs.get(
 
73
  self.pipe.enable_xformers_memory_efficient_attention()
74
  except Exception as e:
75
  print(e)
 
76
 
77
  # init controlnet-sr model
78
  base_model_path = model + '/sr-base'
 
91
  self.pipe_sr.enable_xformers_memory_efficient_attention()
92
  except Exception as e:
93
  print(e)
94
+ device = torch.device("cuda")
 
 
 
 
 
 
 
 
 
 
 
95
  model_path = model + '/RealESRGAN_x2plus.pth'
96
+ self.upsampler = LazyRealESRGAN(device=device, scale=2)
97
+ self.upsampler.model_path = model_path
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  @staticmethod
100
  def blend_h(a, b, blend_extent):
101
+ a = np.array(a)
102
+ b = np.array(b)
103
  blend_extent = min(a.shape[1], b.shape[1], blend_extent)
104
  for x in range(blend_extent):
105
  b[:, x, :] = a[:, -blend_extent
 
180
  output_img = np.array(output_img)
181
  output_img = np.concatenate(
182
  [output_img, output_img[:, :blend_extend, :]], axis=1)
183
+ output_img = self.upsampler.predict(
184
+ output_img)
185
  output_img = self.blend_h(output_img, output_img,
186
  blend_extend * outscale)
187
  output_img = Image.fromarray(output_img[:, :w * outscale, :])