hwajjala commited on
Commit
2a62a79
1 Parent(s): 77c92b5

Add execution logic

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -4,7 +4,9 @@ import torch
4
  import logging
5
  import json
6
  import pickle
 
7
  import gradio as gr
 
8
 
9
 
10
  logger = logging.getLogger("basebody")
@@ -14,7 +16,6 @@ TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
14
  LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
15
 
16
 
17
-
18
  clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
19
 
20
  with open(
@@ -44,9 +45,29 @@ with torch.no_grad():
44
 
45
 
46
  def predict_fn(input_img):
47
- print(type(input_img))
48
- print(input_img)
49
- return "Hello " + "!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  iface = gr.Interface(
 
4
  import logging
5
  import json
6
  import pickle
7
+ from PIL import Image
8
  import gradio as gr
9
+ from scipy.special import softmax
10
 
11
 
12
  logger = logging.getLogger("basebody")
 
16
  LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
17
 
18
 
 
19
  clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
20
 
21
  with open(
 
45
 
46
 
47
  def predict_fn(input_img):
48
+ input_img = Image.fromarray(input_img.astype("uint8"), "RGB")
49
+ image = preprocess(
50
+ input_img
51
+ ).unsqueeze(0)
52
+ with torch.no_grad():
53
+ image_features = clip_model.encode_image(image)
54
+ cosine_simlarities = softmax(
55
+ (all_text_features @ image_features.cpu().T)
56
+ .squeeze()
57
+ .reshape(len(text_prompts), 2, -1),
58
+ axis=1,
59
+ )[:, 0, :]
60
+ # logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
61
+ logger.info(f"cosine_simlarities: {cosine_simlarities}")
62
+ probabilities = lr_model.predict_proba(
63
+ cosine_simlarities.reshape(1, -1)
64
+ )
65
+ logger.info(f"probabilities: {probabilities}")
66
+ decision_json = json.dumps(
67
+ {"is_base_body": float(probabilities[0][1])}
68
+ ).encode("utf-8")
69
+ logger.info(f"decision_json: {decision_json}")
70
+ return decision_json
71
 
72
 
73
  iface = gr.Interface(