MonsterMMORPG commited on
Commit
36ae03a
·
1 Parent(s): 7d405ed

Upload web-ui.py

Browse files
Files changed (1) hide show
  1. web-ui.py +26 -11
web-ui.py CHANGED
@@ -14,6 +14,7 @@ from pyngrok import ngrok
14
  import threading
15
  import time
16
  from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
 
17
 
18
  # Argument parser for command line options
19
  parser = argparse.ArgumentParser()
@@ -51,6 +52,12 @@ static_model_names = [
51
  model_cache = {}
52
  max_cache_size = args.cache_limit
53
 
 
 
 
 
 
 
54
  def convert_model(checkpoint_path, output_path, isSDXL):
55
  try:
56
  if isSDXL:
@@ -126,19 +133,27 @@ def generate_image(input_image, positive_prompt, negative_prompt, width, height,
126
  # Load and prepare the model
127
  ip_model = load_model(model_name, isSDXL)
128
 
129
- # Convert input image to the format expected by the model
130
  input_image = input_image.convert("RGB")
131
- input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
132
- app = FaceAnalysis(
133
- name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
134
- )
135
- app.prepare(ctx_id=0, det_size=(640, 640))
136
- faces = app.get(input_image)
137
- if not faces:
138
- raise ValueError("No faces found in the image.")
139
 
140
- faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
141
- face_image = face_align.norm_crop(input_image, landmark=faces[0].kps, image_size=224)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  for image_index in range(num_images):
144
  if randomize_seed or image_index > 0:
 
14
  import threading
15
  import time
16
  from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
17
+ import hashlib
18
 
19
  # Argument parser for command line options
20
  parser = argparse.ArgumentParser()
 
52
  model_cache = {}
53
  max_cache_size = args.cache_limit
54
 
55
+ embeddings_cache = {}
56
+
57
+ def get_image_hash(image):
58
+ image_bytes = image.tobytes()
59
+ return hashlib.sha256(image_bytes).hexdigest()
60
+
61
  def convert_model(checkpoint_path, output_path, isSDXL):
62
  try:
63
  if isSDXL:
 
133
  # Load and prepare the model
134
  ip_model = load_model(model_name, isSDXL)
135
 
136
+ # Convert input image to the format expected by the model and calculate its hash
137
  input_image = input_image.convert("RGB")
138
+ input_image_cv2 = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
139
+ image_hash = get_image_hash(input_image)
 
 
 
 
 
 
140
 
141
+ # Check if embeddings are cached
142
+ if image_hash in embeddings_cache:
143
+ faceid_embeds, face_image = embeddings_cache[image_hash]
144
+ else:
145
+ app = FaceAnalysis(
146
+ name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
147
+ )
148
+ app.prepare(ctx_id=0, det_size=(640, 640))
149
+ faces = app.get(input_image_cv2)
150
+ if not faces:
151
+ raise ValueError("No faces found in the image.")
152
+
153
+ faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
154
+ face_image = face_align.norm_crop(input_image_cv2, landmark=faces[0].kps, image_size=224)
155
+ # Cache the embeddings
156
+ embeddings_cache[image_hash] = (faceid_embeds, face_image)
157
 
158
  for image_index in range(num_images):
159
  if randomize_seed or image_index > 0: