import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax

# set logging level
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("basebody")
CLIP_MODEL_NAME = "ViT-B/16"

TEXT_PROMPTS_OLD_FILE_NAME = "text_prompts.json"
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"

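# Flagged examples from the UI are written to a Hugging Face dataset for
# later review.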
HF_TOKEN = os.getenv('HF_TOKEN')
hf_writer = gr.HuggingFaceDatasetSaver(
    HF_TOKEN, "Roblox/basebody_feedback"
)

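# Load the CLIP model on CPU; `preprocess` converts a PIL image into the
# normalized tensor the image encoder expects.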
clip_model, preprocess = clip.load(
    CLIP_MODEL_NAME, device="cpu"
)

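# Each prompt file maps an attribute name to a pair of text prompts; the
# first prompt in each pair is treated as the positive match downstream.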
with open(
    os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME), "r"
) as f:
    text_prompts = json.load(f)

with open(
    os.path.join(os.path.dirname(__file__), HAIR_TEXT_PROMPTS_FILE_NAME), "r"
) as f:
    hair_text_prompts = json.load(f)


with open(
    os.path.join(
        os.path.dirname(__file__), TEXT_PROMPTS_OLD_FILE_NAME
    ),
    "r",
) as f:
    text_prompts_old = json.load(f)


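# Load the classifiers that run on top of the CLIP similarity features: the
# current and previous logistic regression base-body models and a random
# forest hair classifier.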
with open(
    os.path.join(
        os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
    ),
    "rb",
) as f:
    lr_model = pickle.load(f)


with open(
    os.path.join(
        os.path.dirname(__file__), LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME
    ),
    "rb",
) as f:
    lr_old_model = pickle.load(f)


with open(
    os.path.join(
        os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
    ),
    "rb",
) as f:
    hair_rf_model = pickle.load(f)


logger.info(
    f"Logistic regression model loaded, coefficients: {lr_model.coef_}"
)


def get_text_features(text_prompts):
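    """Encode each attribute's (positive, negative) prompt pair with the CLIP
    text encoder and return the stacked text features as a CPU tensor."""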
    all_text_features = []
    with torch.no_grad():
        for k, prompts in text_prompts.items():
            assert len(prompts) == 2
            inputs = clip.tokenize(prompts)
            outputs = clip_model.encode_text(inputs)
            all_text_features.append(outputs)
        all_text_features = torch.cat(all_text_features, dim=0)
        all_text_features = all_text_features.cpu()
    return all_text_features


all_text_features = get_text_features(text_prompts)
hair_text_features = get_text_features(hair_text_prompts)
old_text_features = get_text_features(text_prompts_old)


def get_cosine_similarities(image_features, text_features, text_prompts):
    # Dot products between the image embedding and every prompt embedding,
    # softmaxed over each (positive, negative) prompt pair; only the score of
    # the positive prompt is kept for each attribute.
    cosine_similarities = softmax(
        (text_features @ image_features.cpu().T)
        .squeeze()
        .reshape(len(text_prompts), 2, -1),
        axis=1,
    )[:, 0, :]
    return cosine_similarities


def predict_fn(input_img):
    """Score an input image and return a JSON string with the base-body
    probability and the moderation decision."""
    input_img = Image.fromarray(input_img.astype("uint8"), "RGB")
    image = preprocess(input_img).unsqueeze(0)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        base_body_cosine_similarities = get_cosine_similarities(
            image_features, all_text_features, text_prompts
        )
        hair_cosine_similarities = get_cosine_similarities(
            image_features, hair_text_features, hair_text_prompts
        )
        old_cosine_similarities = get_cosine_similarities(
            image_features, old_text_features, text_prompts_old
        )
        # logger.info(f"cosine_similarities shape: {base_body_cosine_similarities.shape}")
        logger.info(f"cosine_similarities: {base_body_cosine_similarities}")
    probabilities = lr_model.predict_proba(
        base_body_cosine_similarities.reshape(1, -1)
    )
    hair_probabilities = hair_rf_model.predict_proba(
        hair_cosine_similarities.reshape(1, -1)
    )
    old_lr_probabilities = lr_old_model.predict_proba(
        old_cosine_similarities.reshape(1, -1)
    )
    logger.info(f"probabilities: {probabilities}")
    result_probability = float(probabilities[0][1].round(3))
    hair_result_probability = float(hair_probabilities[0][1].round(3))
    old_result_probability = float(old_lr_probabilities[0][1].round(3))
    # Decision rules: accept only when the base-body classifier is confident
    # and the hair classifier does not veto; reject when either the new or the
    # old base-body classifier is confidently negative; otherwise defer to
    # human moderation.
    if result_probability > 0.77:
        if hair_result_probability < 0.24:
            logger.info("hair_result_probability < 0.24")
            result_probability = hair_result_probability
            decision = "AUTO REJECT"
        else:
            decision = "AUTO ACCEPT"
    elif result_probability < 0.2:
        logger.info("result_probability < 0.2")
        decision = "AUTO REJECT"
    elif old_result_probability < 0.06:
        logger.info("old_result_probability < 0.06")
        result_probability = old_result_probability
        decision = "AUTO REJECT"
    else:
        decision = "MODERATION"
    logger.info(f"decision: {decision}")
    decision_json = json.dumps(
        {"is_base_body": result_probability, "decision": decision}
    )
    logger.info(f"decision_json: {decision_json}")
    return decision_json

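# Minimal local sketch for testing predict_fn without the UI (hypothetical
# file name, assumes numpy is installed):
#   import numpy as np
#   img = np.array(Image.open("avatar.png").convert("RGB"))
#   print(predict_fn(img))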

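# Gradio app: an image input, a text output carrying the JSON decision, and
# manual flagging whose selections are saved through hf_writer.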
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description="""
    The model returns the probability that the image is a base body. If the
    probability is above 0.77, the image can be automatically tagged as a base
    body. If the probability is below 0.24, the image can be automatically
    REJECTED as NOT a base body. All other cases are submitted for moderation.

    Please flag the result if you think the decision is wrong.
    """,

    allow_flagging="manual",
    flagging_options=[
        ": decision should be accept",
        ": decision should be reject",
        ": decision should be moderation"
    ],
    flagging_callback=hf_writer
)
iface.launch()