File size: 5,701 Bytes
f4b1311 194b093 f4b1311 2a62a79 91d8e2e 2a62a79 91d8e2e 1a4386f f4b1311 cb8ddf6 2e89562 2de9666 c894abc 2e89562 2de9666 f4b1311 e1f60ba f4b1311 e1f60ba f4b1311 2de9666 cb8ddf6 f4b1311 cb8ddf6 2de9666 267519a edc9052 cb8ddf6 267519a edc9052 cb8ddf6 2de9666 bb70cd9 2de9666 77c92b5 2a62a79 2de9666 bb70cd9 2de9666 bb70cd9 2de9666 cb8ddf6 2a62a79 2de9666 2a62a79 2de9666 2a62a79 cb8ddf6 2a62a79 29c69b2 2de9666 cb8ddf6 29c69b2 2de9666 966e6ea 75dd19d 2de9666 509301a cb8ddf6 1a4386f cb8ddf6 29c69b2 2a62a79 29c69b2 2a62a79 f4b1311 91d8e2e f4b1311 267519a f4b1311 edd1aeb 509301a 966e6ea edd1aeb 29c69b2 edd1aeb e1f60ba 7a2bd30 9c9b891 7a2bd30 e1f60ba f4b1311 267519a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax
# Configure the root logger for INFO-and-above with timestamped records
# (logging.basicConfig is a no-op if the root logger already has handlers).
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used throughout the prediction pipeline.
logger = logging.getLogger("basebody")
# CLIP backbone used to embed both the text prompts and the input images.
CLIP_MODEL_NAME = "ViT-B/16"

# JSON prompt files (loaded from the directory containing this module).
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
TEXT_PROMPTS_OLD_FILE_NAME = "text_prompts.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"

# Pickled classifiers (loaded from the directory containing this module).
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
# Hugging Face Hub token read from the environment (None if unset).
HF_TOKEN = os.getenv('HF_TOKEN')

# Flagged samples from the UI are saved to this Hub dataset; the saver is
# passed to the Gradio interface as its flagging callback.
_FEEDBACK_DATASET = "Roblox/basebody_feedback"
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, _FEEDBACK_DATASET)

# Load CLIP on the CPU. `preprocess` is the transform applied to each input
# image before it is passed to `clip_model.encode_image`.
_CLIP_DEVICE = "cpu"
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device=_CLIP_DEVICE)
def _resource_path(file_name):
    """Return the path of *file_name* inside the directory of this module."""
    return os.path.join(os.path.dirname(__file__), file_name)


def _load_json(file_name):
    """Parse and return the JSON file *file_name* stored next to this module."""
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(_resource_path(file_name), "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pickle(file_name):
    """Unpickle and return the object in *file_name* stored next to this module.

    NOTE(security): ``pickle.load`` can execute arbitrary code while loading.
    Only use it on model files shipped with this app, never on user data.
    """
    with open(_resource_path(file_name), "rb") as f:
        return pickle.load(f)


# Prompt dictionaries. Each value is expected to hold exactly two prompts
# (enforced later by get_text_features).
text_prompts = _load_json(TEXT_PROMPTS_FILE_NAME)
hair_text_prompts = _load_json(HAIR_TEXT_PROMPTS_FILE_NAME)
text_prompts_old = _load_json(TEXT_PROMPTS_OLD_FILE_NAME)

# Classifiers consumed by predict_fn via predict_proba.
lr_model = _load_pickle(LOGISTIC_REGRESSION_MODEL_FILE_NAME)
lr_old_model = _load_pickle(LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME)
hair_rf_model = _load_pickle(HAIR_RF_CLASSIFIER_MODEL_FILE_NAME)
logger.info("Logistic regression and hair classifier models loaded")
def get_text_features(text_prompts):
    """Encode every prompt pair in *text_prompts* with the CLIP text encoder.

    Args:
        text_prompts: Mapping whose values are sequences of exactly two prompt
            strings. Values are encoded in the mapping's iteration order; keys
            are only used in error messages.

    Returns:
        A CPU tensor formed by concatenating (along dim 0) the
        ``clip_model.encode_text`` output of each pair, in mapping order.
        get_cosine_similarities later reshapes this as two rows per pair, so
        the order matters.

    Raises:
        ValueError: if *text_prompts* is empty or any entry does not contain
            exactly two prompts.
    """
    # Validate with real exceptions: `assert` is stripped under `python -O`,
    # and torch.cat on an empty list fails with an unhelpful message.
    if not text_prompts:
        raise ValueError("text_prompts must contain at least one prompt pair")
    for key, prompts in text_prompts.items():
        if len(prompts) != 2:
            raise ValueError(
                f"expected exactly 2 prompts for {key!r}, got {len(prompts)}"
            )

    with torch.no_grad():
        per_pair = [
            clip_model.encode_text(clip.tokenize(prompts))
            for prompts in text_prompts.values()
        ]
    return torch.cat(per_pair, dim=0).cpu()
# Text embeddings are computed once at import time and reused for every
# request; unpacked in the same order the prompt sets are listed.
all_text_features, hair_text_features, old_text_features = (
    get_text_features(prompt_set)
    for prompt_set in (text_prompts, hair_text_prompts, text_prompts_old)
)
def get_cosine_similarities(image_features, text_features, text_prompts):
    """Score each prompt pair against the image features with a 2-way softmax.

    Computes ``text_features @ image_features.T`` and regroups the scores into
    ``(len(text_prompts), 2, -1)``. It then applies a softmax over the pair
    axis and keeps the probability assigned to the first prompt of each pair.
    No normalization is applied here, so the scores are raw dot products
    unless the caller already normalized the embeddings.

    Args:
        image_features: Image embeddings; moved to CPU before use.
        text_features: Text embeddings, two consecutive rows per prompt pair.
        text_prompts: Prompt mapping; only its length is used.

    Returns:
        Array of shape ``(len(text_prompts), k)`` holding, for each pair, the
        softmax probability of the pair's first prompt.
    """
    n_pairs = len(text_prompts)
    logits = text_features @ image_features.cpu().T
    grouped = logits.squeeze().reshape(n_pairs, 2, -1)
    pair_probs = softmax(grouped, axis=1)
    first_prompt_probs = pair_probs[:, 0, :]
    return first_prompt_probs
# Thresholds applied to each classifier's positive-class probability.
BASE_BODY_ACCEPT_THRESHOLD = 0.77
BASE_BODY_REJECT_THRESHOLD = 0.2
HAIR_REJECT_THRESHOLD = 0.24
OLD_MODEL_REJECT_THRESHOLD = 0.06


def _decide(base_body_probability, hair_probability, old_probability):
    """Map the three classifier scores to ``(reported_probability, decision)``.

    Rules, checked in order:
      1. Base-body score above the accept threshold gives AUTO ACCEPT. The
         exception is a hair score below its threshold, which forces
         AUTO REJECT and reports the hair score instead.
      2. Base-body score below the reject threshold gives AUTO REJECT.
      3. Old-model score below its threshold gives AUTO REJECT and reports
         the old-model score.
      4. Anything else goes to MODERATION.
    """
    if base_body_probability > BASE_BODY_ACCEPT_THRESHOLD:
        if hair_probability < HAIR_REJECT_THRESHOLD:
            # Log the threshold that is actually applied.
            logger.info(f"hair_result_probabilty < {HAIR_REJECT_THRESHOLD}")
            return hair_probability, "AUTO REJECT"
        return base_body_probability, "AUTO ACCEPT"
    if base_body_probability < BASE_BODY_REJECT_THRESHOLD:
        logger.info(f"result_probabilty < {BASE_BODY_REJECT_THRESHOLD}")
        return base_body_probability, "AUTO REJECT"
    if old_probability < OLD_MODEL_REJECT_THRESHOLD:
        logger.info(f"old_result_probabilty < {OLD_MODEL_REJECT_THRESHOLD}")
        return old_probability, "AUTO REJECT"
    return base_body_probability, "MODERATION"


def predict_fn(input_img):
    """Classify an image as base body or not and return a JSON verdict.

    Args:
        input_img: Image array from the Gradio "image" input. It is cast to
            uint8 and interpreted as RGB (assumes an HxWx3 array; TODO confirm
            for grayscale/RGBA uploads).

    Returns:
        UTF-8 encoded JSON bytes of the form
        ``{"is_base_body": <float>, "decision": <str>}``. The decision is one
        of "AUTO ACCEPT", "AUTO REJECT" or "MODERATION". When a secondary
        check triggers the rejection, "is_base_body" carries that check's
        probability.
    """
    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    image = preprocess(pil_image).unsqueeze(0)  # add a leading batch dim
    with torch.no_grad():
        image_features = clip_model.encode_image(image)

    base_body_cosine_similarities = get_cosine_similarities(
        image_features, all_text_features, text_prompts
    )
    hair_cosine_similarities = get_cosine_similarities(
        image_features, hair_text_features, hair_text_prompts
    )
    old_cosine_similarities = get_cosine_similarities(
        image_features, old_text_features, text_prompts_old
    )
    logger.info(f"cosine_simlarities: {base_body_cosine_similarities}")

    # Each classifier receives a single sample: the flattened per-pair scores.
    probabilities = lr_model.predict_proba(
        base_body_cosine_similarities.reshape(1, -1)
    )
    hair_probabilities = hair_rf_model.predict_proba(
        hair_cosine_similarities.reshape(1, -1)
    )
    old_lr_probabilities = lr_old_model.predict_proba(
        old_cosine_similarities.reshape(1, -1)
    )
    logger.info(f"probabilities: {probabilities}")

    # Column 1 is treated as the positive class (presumably "is base body" /
    # "hair OK" -- confirm against training labels); rounded to 3 places.
    result_probability = float(probabilities[0][1].round(3))
    hair_result_probability = float(hair_probabilities[0][1].round(3))
    old_result_probability = float(old_lr_probabilities[0][1].round(3))

    result_probability, decision = _decide(
        result_probability, hair_result_probability, old_result_probability
    )
    logger.info(f"decision: {decision}")

    decision_json = json.dumps(
        {"is_base_body": result_probability, "decision": decision}
    ).encode("utf-8")
    logger.info(f"decision_json: {decision_json}")
    return decision_json
# Gradio UI: one image in, the JSON verdict from predict_fn out as text.
# Users can flag a sample with one of the options below; flagged samples are
# handed to hf_writer. The description must stay in sync with the thresholds
# used in predict_fn (accept > 0.77, reject < 0.2).
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description="""
The model returns the probability of the image being a base body. If
probability > 0.77 (and the hair check passes), the image can be automatically
tagged as a base body. If probability < 0.2, the image can be automatically
REJECTED as NOT a base body. Some other images may also be auto-rejected by
secondary checks; all remaining cases will be submitted for moderation.
Please flag if you think the decision is wrong.
""",
    allow_flagging="manual",
    flagging_options=[
        ": decision should be accept",
        ": decision should be reject",
        ": decision should be moderation",
    ],
    flagging_callback=hf_writer,
)
iface.launch()
|