Spaces:
Build error
Build error
Update inference.py
Browse files- inference.py +10 -29
inference.py
CHANGED
@@ -10,7 +10,7 @@ import torch
|
|
10 |
from diffusers import StableDiffusionPipeline
|
11 |
|
12 |
sys.path.insert(0, 'lora')
|
13 |
-
from
|
14 |
|
15 |
|
16 |
class InferencePipeline:
|
@@ -28,24 +28,16 @@ class InferencePipeline:
|
|
28 |
gc.collect()
|
29 |
|
30 |
@staticmethod
|
31 |
-
def
|
32 |
curr_dir = pathlib.Path(__file__).parent
|
33 |
return curr_dir / name
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
parent_dir = path.parent
|
38 |
-
stem = path.stem
|
39 |
-
text_encoder_filename = f'{stem}.text_encoder.pt'
|
40 |
-
path = parent_dir / text_encoder_filename
|
41 |
-
return path.as_posix() if path.exists() else ''
|
42 |
-
|
43 |
-
def load_pipe(self, model_id: str, lora_filename: str) -> None:
|
44 |
-
weight_path = self.get_lora_weight_path(lora_filename)
|
45 |
if weight_path == self.weight_path:
|
46 |
return
|
47 |
self.weight_path = weight_path
|
48 |
-
|
49 |
|
50 |
if self.device.type == 'cpu':
|
51 |
pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
@@ -54,40 +46,29 @@ class InferencePipeline:
|
|
54 |
model_id, torch_dtype=torch.float16)
|
55 |
pipe = pipe.to(self.device)
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
lora_text_encoder_weight_path = self.get_lora_text_encoder_weight_path(
|
60 |
-
weight_path)
|
61 |
-
if lora_text_encoder_weight_path:
|
62 |
-
lora_text_encoder_weight = torch.load(
|
63 |
-
lora_text_encoder_weight_path, map_location=self.device)
|
64 |
-
monkeypatch_lora(pipe.text_encoder,
|
65 |
-
lora_text_encoder_weight,
|
66 |
-
target_replace_module=['CLIPAttention'])
|
67 |
|
68 |
self.pipe = pipe
|
69 |
|
70 |
def run(
|
71 |
self,
|
72 |
base_model: str,
|
73 |
-
|
74 |
prompt: str,
|
75 |
-
alpha: float,
|
76 |
-
alpha_for_text: float,
|
77 |
seed: int,
|
78 |
n_steps: int,
|
79 |
guidance_scale: float,
|
|
|
80 |
) -> PIL.Image.Image:
|
81 |
if not torch.cuda.is_available():
|
82 |
raise gr.Error('CUDA is not available.')
|
83 |
|
84 |
-
self.load_pipe(base_model,
|
85 |
|
86 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
87 |
-
tune_lora_scale(self.pipe.unet, alpha) # type: ignore
|
88 |
-
tune_lora_scale(self.pipe.text_encoder, alpha_for_text) # type: ignore
|
89 |
out = self.pipe(prompt,
|
90 |
num_inference_steps=n_steps,
|
91 |
guidance_scale=guidance_scale,
|
|
|
92 |
generator=generator) # type: ignore
|
93 |
return out.images[0]
|
|
|
10 |
from diffusers import StableDiffusionPipeline
|
11 |
|
12 |
sys.path.insert(0, 'lora')
|
13 |
+
from src import sample_diffuser, diffuser_training
|
14 |
|
15 |
|
16 |
class InferencePipeline:
|
|
|
28 |
gc.collect()
|
29 |
|
30 |
@staticmethod
|
31 |
+
def get_weight_path(name: str) -> pathlib.Path:
|
32 |
curr_dir = pathlib.Path(__file__).parent
|
33 |
return curr_dir / name
|
34 |
|
35 |
+
def load_pipe(self, model_id: str, filename: str) -> None:
|
36 |
+
weight_path = self.get_weight_path(filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
if weight_path == self.weight_path:
|
38 |
return
|
39 |
self.weight_path = weight_path
|
40 |
+
weight = torch.load(self.weight_path, map_location=self.device)
|
41 |
|
42 |
if self.device.type == 'cpu':
|
43 |
pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
|
|
46 |
model_id, torch_dtype=torch.float16)
|
47 |
pipe = pipe.to(self.device)
|
48 |
|
49 |
+
diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
self.pipe = pipe
|
52 |
|
53 |
def run(
|
54 |
self,
|
55 |
base_model: str,
|
56 |
+
weight_name: str,
|
57 |
prompt: str,
|
|
|
|
|
58 |
seed: int,
|
59 |
n_steps: int,
|
60 |
guidance_scale: float,
|
61 |
+
eta: float,
|
62 |
) -> PIL.Image.Image:
|
63 |
if not torch.cuda.is_available():
|
64 |
raise gr.Error('CUDA is not available.')
|
65 |
|
66 |
+
self.load_pipe(base_model, weight_name)
|
67 |
|
68 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
|
|
|
|
69 |
out = self.pipe(prompt,
|
70 |
num_inference_steps=n_steps,
|
71 |
guidance_scale=guidance_scale,
|
72 |
+
eta = eta,
|
73 |
generator=generator) # type: ignore
|
74 |
return out.images[0]
|