JackAILab commited on
Commit
65cbb98
·
verified ·
1 Parent(s): 526e738

Update pipline_StableDiffusion_ConsistentID.py

Browse files
pipline_StableDiffusion_ConsistentID.py CHANGED
@@ -21,6 +21,7 @@ from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, Facial
21
  ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
22
  ### Thanks for the open source of face-parsing model.
23
  from models.BiSeNet.model import BiSeNet
 
24
 
25
  PipelineImageInput = Union[
26
  PIL.Image.Image,
@@ -31,7 +32,6 @@ PipelineImageInput = Union[
31
 
32
  ### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
33
  class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
34
-
35
  @validate_hf_hub_args
36
  def load_ConsistentID_model(
37
  self,
@@ -65,7 +65,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
65
  ### BiSeNet
66
  self.bise_net = BiSeNet(n_classes = 19)
67
  self.bise_net.cuda()
68
- self.bise_net_cp='JackAILab/ConsistentID/face_parsing.pth'
69
  self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
70
  self.bise_net.eval()
71
  # Colors for all 20 parts
@@ -124,11 +124,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
124
  if weight_name.endswith(".safetensors"):
125
  state_dict = {"id_encoder": {}, "lora_weights": {}}
126
  with safe_open(model_file, framework="pt", device="cpu") as f:
 
127
  for key in f.keys():
128
- if key.startswith("id_encoder."):
129
- state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
130
- elif key.startswith("lora_weights."):
131
- state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
132
  else:
133
  state_dict = torch.load(model_file, map_location="cpu")
134
  else:
 
21
  ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
22
  ### Thanks for the open source of face-parsing model.
23
  from models.BiSeNet.model import BiSeNet
24
+ bise_net_cp_path = hf_hub_download(repo_id="JackAILab/ConsistentID", filename="face_parsing.pth", repo_type="model")
25
 
26
  PipelineImageInput = Union[
27
  PIL.Image.Image,
 
32
 
33
  ### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
34
  class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
 
35
  @validate_hf_hub_args
36
  def load_ConsistentID_model(
37
  self,
 
65
  ### BiSeNet
66
  self.bise_net = BiSeNet(n_classes = 19)
67
  self.bise_net.cuda()
68
+ self.bise_net_cp=bise_net_cp_path
69
  self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
70
  self.bise_net.eval()
71
  # Colors for all 20 parts
 
124
  if weight_name.endswith(".safetensors"):
125
  state_dict = {"id_encoder": {}, "lora_weights": {}}
126
  with safe_open(model_file, framework="pt", device="cpu") as f:
127
+ ### TODO safetensors add
128
  for key in f.keys():
129
+ if key.startswith("FacialEncoder."):
130
+ state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
131
+ elif key.startswith("image_proj."):
132
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
133
  else:
134
  state_dict = torch.load(model_file, map_location="cpu")
135
  else: