hwajjala commited on
Commit
bb70cd9
1 Parent(s): 2de9666

Bugfix length of text prompts

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -78,7 +78,7 @@ with torch.no_grad():
78
  hair_text_features = hair_text_features.cpu()
79
 
80
 
81
- def get_cosine_similarities(image_features, text_features):
82
  cosine_simlarities = softmax(
83
  (text_features @ image_features.cpu().T)
84
  .squeeze()
@@ -96,10 +96,10 @@ def predict_fn(input_img):
96
  with torch.no_grad():
97
  image_features = clip_model.encode_image(image)
98
  base_body_cosine_simlarities = get_cosine_similarities(
99
- image_features, all_text_features
100
  )
101
  hair_cosine_simlarities = get_cosine_similarities(
102
- image_features, hair_text_features
103
  )
104
  # logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
105
  logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
 
78
  hair_text_features = hair_text_features.cpu()
79
 
80
 
81
+ def get_cosine_similarities(image_features, text_features, text_prompts):
82
  cosine_simlarities = softmax(
83
  (text_features @ image_features.cpu().T)
84
  .squeeze()
 
96
  with torch.no_grad():
97
  image_features = clip_model.encode_image(image)
98
  base_body_cosine_simlarities = get_cosine_similarities(
99
+ image_features, all_text_features, text_prompts
100
  )
101
  hair_cosine_simlarities = get_cosine_similarities(
102
+ image_features, hair_text_features, hair_text_prompts
103
  )
104
  # logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
105
  logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")