# `from __future__ import annotations` is needed for the `X | Y` union annotations below on Python < 3.10
from __future__ import annotations

import os
import tempfile
from decimal import Decimal
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import open_clip
import pandas as pd
import torch
from gradio import processing_utils, utils
from PIL import Image

from download_example_images import read_actor_files, save_images_to_folder
DEFAULT_INITIAL_NAME = "John Doe"

PROMPTS = [
    '{0}',
    'an image of {0}',
    'a photo of {0}',
    '{0} on a photo',
    'a photo of a person named {0}',
    'a person named {0}',
    'a man named {0}',
    'a woman named {0}',
    'the name of the person is {0}',
    'a photo of a person with the name {0}',
    '{0} at a gala',
    'a photo of the celebrity {0}',
    'actor {0}',
    'actress {0}',
    'a colored photo of {0}',
    'a black and white photo of {0}',
    'a cool photo of {0}',
    'a cropped photo of {0}',
    'a cropped image of {0}',
    '{0} in a suit',
    '{0} in a dress'
]
OPEN_CLIP_MODEL_NAMES = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']
NUM_TOTAL_NAMES = 1_000
SEED = 42
MIN_NUM_CORRECT_PROMPT_PREDS = 1
EXAMPLE_IMAGE_DIR = './example_images/'
IMG_BATCHSIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EXAMPLE_IMAGE_URLS = read_actor_files(EXAMPLE_IMAGE_DIR)
save_images_to_folder(os.path.join(EXAMPLE_IMAGE_DIR, 'images'), EXAMPLE_IMAGE_URLS)
MODELS = {}
for model_name in OPEN_CLIP_MODEL_NAMES:
    dataset = 'LAION400M'
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name,
        pretrained=f'{dataset.lower()}_e32'
    )
    model = model.eval()
    MODELS[f'OpenClip {model_name} trained on {dataset}'] = {
        'model_instance': model,
        'preprocessing': preprocess,
        'model_name': model_name,
        # map to CPU so the app also starts on machines without a GPU
        'prompt_text_embeddings': torch.load(
            f'./prompt_text_embeddings/{model_name}_prompt_text_embeddings.pt',
            map_location='cpu'
        )
    }
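
# Note (an assumption inferred from how these tensors are used in `on_submit_btn_click` below):
# each precomputed `prompt_text_embeddings` tensor is expected to have the shape
# [len(PROMPTS), num_names_in_full_names_csv, embedding_dim], so that dimension 1 can be
# index-selected with the indices of the sampled decoy names.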

FULL_NAMES_DF = pd.read_csv('full_names.csv', index_col=0)
# the CSV file keeps its original (misspelled) name on disk
LAION_MEMBERSHIP_OCCURRENCE = pd.read_csv('laion_membership_occurence_count.csv', index_col=0)

EXAMPLE_ACTORS_BY_MODEL = {
    "ViT-B-32": ["T._J._Thyne"],
    "ViT-B-16": ["Barbara_Schöneberger", "Carolin_Kebekus"],
    "ViT-L-14": ["Max_Giermann", "Nicole_De_Boer"]
}

EXAMPLES = []
for model_name, person_names in EXAMPLE_ACTORS_BY_MODEL.items():
    for name in person_names:
        image_folder = os.path.join("./example_images/images/", name)
        for dd_model_name in MODELS.keys():
            if model_name not in dd_model_name:
                continue
            EXAMPLES.append([
                dd_model_name,
                name.replace("_", " "),
                [[x.format(name.replace("_", " ")) for x in PROMPTS]],
                [os.path.join(image_folder, x) for x in os.listdir(image_folder)]
            ])

LICENSE_DETAILS = """
<details>
<summary>Example Images License Information</summary>

### Barbara Schöneberger
| Image Name | Image Url | Author | License |
|---|---|---|---|
| Barbara_Schöneberger_0.jpg | [https://upload.wikimedia.org/wikipedia/commons/1/1d/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_13.jpg](https://upload.wikimedia.org/wikipedia/commons/1/1d/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_13.jpg) | Frank Schwichtenberg | CC-BY-SA-3.0 |
| Barbara_Schöneberger_1.jpg | [https://upload.wikimedia.org/wikipedia/commons/9/9d/Barbara_Sch%C3%B6neberger_%282007%29.jpg](https://upload.wikimedia.org/wikipedia/commons/9/9d/Barbara_Sch%C3%B6neberger_%282007%29.jpg) | Pottschalk | CC-BY-SA-3.0 |
| Barbara_Schöneberger_2.jpg | [https://upload.wikimedia.org/wikipedia/commons/f/f0/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_03.jpg](https://upload.wikimedia.org/wikipedia/commons/f/f0/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_03.jpg) | Frank Schwichtenberg | CC-BY-SA-3.0 |
| Barbara_Schöneberger_3.jpg | [https://upload.wikimedia.org/wikipedia/commons/f/fa/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_12.jpg](https://upload.wikimedia.org/wikipedia/commons/f/fa/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_12.jpg) | Frank Schwichtenberg | CC-BY-SA-3.0 |
| Barbara_Schöneberger_4.jpg | [https://upload.wikimedia.org/wikipedia/commons/0/0a/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_01.jpg](https://upload.wikimedia.org/wikipedia/commons/0/0a/Barbara_Sch%C3%B6neberger_-_Deutscher_Radiopreis_Hamburg_2016_01.jpg) | Frank Schwichtenberg | CC-BY-SA-3.0 |

### Carolin Kebekus
| Image Name | Image Url | Author | License |
|---|---|---|---|
| Carolin_Kebekus_0.jpg | [https://upload.wikimedia.org/wikipedia/commons/c/ce/Carolin_Kebekus_-_2019102193318_2019-04-12_Radio_Regenbogen_Award_2019_-_Sven_-_1D_X_MK_II_-_0905_-_AK8I0075.jpg](https://upload.wikimedia.org/wikipedia/commons/c/ce/Carolin_Kebekus_-_2019102193318_2019-04-12_Radio_Regenbogen_Award_2019_-_Sven_-_1D_X_MK_II_-_0905_-_AK8I0075.jpg) | Sven Mandel | CC-BY-SA-4.0 |
| Carolin_Kebekus_1.jpg | [https://upload.wikimedia.org/wikipedia/commons/4/45/Carolin-Kebekus-Bonn.jpg](https://upload.wikimedia.org/wikipedia/commons/4/45/Carolin-Kebekus-Bonn.jpg) | Superbass | CC-BY-SA-3.0 |
| Carolin_Kebekus_2.jpg | [https://upload.wikimedia.org/wikipedia/commons/4/45/Carolin-Kebekus-Bonn.jpg](https://upload.wikimedia.org/wikipedia/commons/4/45/Carolin-Kebekus-Bonn.jpg) | Sven Mandel | CC-BY-SA-4.0 |
| Carolin_Kebekus_3.jpg | [https://upload.wikimedia.org/wikipedia/commons/0/02/Carolin_Kebekus-5848.jpg](https://upload.wikimedia.org/wikipedia/commons/0/02/Carolin_Kebekus-5848.jpg) | Harald Krichel | CC-BY-SA-3.0 |
| Carolin_Kebekus_4.jpg | [https://upload.wikimedia.org/wikipedia/commons/e/e1/2021-09-16-Carolin_Kebekus_Deutscher_Fernsehpreis_2021_-3757.jpg](https://upload.wikimedia.org/wikipedia/commons/e/e1/2021-09-16-Carolin_Kebekus_Deutscher_Fernsehpreis_2021_-3757.jpg) | Superbass | CC-BY-SA-4.0 |

### Max Giermann
| Image Name | Image Url | Author | License |
|---|---|---|---|
| Max_Giermann_0.jpg | [https://upload.wikimedia.org/wikipedia/commons/4/4b/2018-01-26-DFP_2018-7513.jpg](https://upload.wikimedia.org/wikipedia/commons/4/4b/2018-01-26-DFP_2018-7513.jpg) | Superbass | CC-BY-SA-4.0 |
| Max_Giermann_1.jpg | [https://upload.wikimedia.org/wikipedia/commons/f/f6/Deutscher_Fernsehpreis_2012_-_Max_Giermann.jpg](https://upload.wikimedia.org/wikipedia/commons/f/f6/Deutscher_Fernsehpreis_2012_-_Max_Giermann.jpg) | JCS | CC-BY-3.0 |
| Max_Giermann_2.jpg | [https://upload.wikimedia.org/wikipedia/commons/1/1c/Hessischer_Filmpreis_2017_-_Max_Giermann_2.JPG](https://upload.wikimedia.org/wikipedia/commons/1/1c/Hessischer_Filmpreis_2017_-_Max_Giermann_2.JPG) | JCS | CC-BY-3.0 |
| Max_Giermann_3.jpg | [https://upload.wikimedia.org/wikipedia/commons/1/1d/Max_Giermann_%28extra_3%29_01.jpg](https://upload.wikimedia.org/wikipedia/commons/1/1d/Max_Giermann_%28extra_3%29_01.jpg) | Frank Schwichtenberg | CC-BY-SA-3.0 |
| Max_Giermann_4.jpg | [https://upload.wikimedia.org/wikipedia/commons/8/85/Max_Giermann_%28extra_3%29_03.jpg](https://upload.wikimedia.org/wikipedia/commons/8/85/Max_Giermann_%28extra_3%29_03.jpg) | Frank Schwichtenberg | CC-BY-SA-3.0 |

### Nicole De Boer
| Image Name | Image Url | Author | License |
|---|---|---|---|
| Nicole_De_Boer_0.jpg | [https://upload.wikimedia.org/wikipedia/commons/0/03/Praha%2C_Lhotka%2C_KC_Novodvorsk%C3%A1%2C_CzechTREK_2013_%2827%29.jpg](https://upload.wikimedia.org/wikipedia/commons/0/03/Praha%2C_Lhotka%2C_KC_Novodvorsk%C3%A1%2C_CzechTREK_2013_%2827%29.jpg) | Harold | CC-BY-SA-3.0 |
| Nicole_De_Boer_1.jpg | [https://upload.wikimedia.org/wikipedia/commons/d/db/Nicole_DeBoer_at_Toronto_Comicon_1.jpg](https://upload.wikimedia.org/wikipedia/commons/d/db/Nicole_DeBoer_at_Toronto_Comicon_1.jpg) | Tabercil | CC-BY-SA-3.0 |
| Nicole_De_Boer_2.jpg | [https://upload.wikimedia.org/wikipedia/commons/4/4b/Nicole_de_Boer_at_Toronto_Comicon_2_%28cropped%29.jpg](https://upload.wikimedia.org/wikipedia/commons/4/4b/Nicole_de_Boer_at_Toronto_Comicon_2_%28cropped%29.jpg) | Tabercil | CC-BY-SA-3.0 |
| Nicole_De_Boer_3.jpg | [https://upload.wikimedia.org/wikipedia/commons/b/b9/Nicole_de_boer_LFCC2015.jpg](https://upload.wikimedia.org/wikipedia/commons/b/b9/Nicole_de_boer_LFCC2015.jpg) | Dazzoboy | CC-BY-SA-4.0 |
| Nicole_De_Boer_4.jpg | [https://upload.wikimedia.org/wikipedia/commons/9/90/Nicole_de_Boer_at_Toronto_Comicon_2.jpg](https://upload.wikimedia.org/wikipedia/commons/9/90/Nicole_de_Boer_at_Toronto_Comicon_2.jpg) | Tabercil | CC-BY-SA-3.0 |

### T. J. Thyne
| Image Name | Image Url | Author | License |
|---|---|---|---|
| T._J._Thyne_0.jpg | [https://live.staticflickr.com/7036/6837850246_c09a148d70_o.jpg](https://live.staticflickr.com/7036/6837850246_c09a148d70_o.jpg) | Genevieve | CC-BY-2.0 |
| T._J._Thyne_1.jpg | [https://live.staticflickr.com/3273/5705869811_d9ff808383_o.jpg](https://live.staticflickr.com/3273/5705869811_d9ff808383_o.jpg) | Genevieve | CC-BY-2.0 |
| T._J._Thyne_2.jpg | [https://upload.wikimedia.org/wikipedia/commons/d/d8/TJThyneFanExpo2017.jpg](https://upload.wikimedia.org/wikipedia/commons/d/d8/TJThyneFanExpo2017.jpg) | Christian Dahl-Lacroix | CC-BY-SA-4.0 |
| T._J._Thyne_3.jpg | [https://live.staticflickr.com/7041/6984629777_8a415b72d9_b.jpg](https://live.staticflickr.com/7041/6984629777_8a415b72d9_b.jpg) | Genevieve | CC-BY-2.0 |
| T._J._Thyne_4.jpg | [https://live.staticflickr.com/7042/6837821654_d65ab80913_b.jpg](https://live.staticflickr.com/7042/6837821654_d65ab80913_b.jpg) | Genevieve | CC-BY-2.0 |
</details>
"""

CORRECT_RESULT_INTERPRETATION = """<br>
<h2>{0} is in the Training Data!</h2>
The name of {0} has been <b>correctly predicted for {1} out of {2} prompts.</b> This means that <b>{0} was in
the training data and was used to train the model.</b>
Keep in mind that the probability of correctly predicting the name for {3} by chance for {6} of the {4} uploaded images,
with {5} possible names for the model to choose from, is only
(<sup>1</sup> &frasl; <sub>{5}</sub>)<sup>{6}</sup> = {7}%.
"""

INDECISIVE_RESULT_INTERPRETATION = """<br>
<h2>{0} might be in the Training Data!</h2>
For none of the {1} prompts was the majority vote for the name of {0} correct. However, while the majority votes are not
correct, the name of {0} was correctly predicted {2} times for {3}. This is an indication that the model has seen {0}
during training. A different selection of images might give a clearer result. Keep in mind that the probability
that the name is correctly predicted by chance {2} times for {3} is
(<sup>1</sup> &frasl; <sub>{4}</sub>)<sup>{2}</sup> = {5}%.
"""

INCORRECT_RESULT_INTERPRETATION = """<br>
<h2>{0} is most likely not in the Training Data!</h2>
The name of {0} has not been correctly predicted for any of the {1} prompts. This is an indication that {0} has
most likely not been used to train the model.
"""

OCCURRENCE_INFORMATION = """<br><br>
According to our analysis, {0} appeared {1} times among the 400 million image-text pairs in the LAION-400M training dataset.
"""

CSS = """
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
#file_upload {
    max-height: 250px;
    overflow-y: auto !important;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4 {
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
"""

# monkey patch the preprocess method of the Files component since otherwise it is not possible to access the
# original file name of an uploaded file
def preprocess(
        self, x: List[Dict[str, Any]] | None
) -> bytes | tempfile._TemporaryFileWrapper | List[bytes | tempfile._TemporaryFileWrapper] | None:
    """
    Parameters:
        x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
    Returns:
        File objects in requested format
    """
    if x is None:
        return None

    def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper:
        file_name, orig_name, data, is_file = (
            f["name"] if "name" in f.keys() else f["orig_name"],
            f["orig_name"],
            f["data"],
            f.get("is_file", False),
        )
        if self.type == "file":
            if is_file:
                temp_file_path = self.make_temp_copy_if_needed(file_name)
                file = tempfile.NamedTemporaryFile(delete=False)
                file.name = temp_file_path
                file.orig_name = os.path.basename(orig_name.replace(self.hash_file(file_name), ""))  # type: ignore
            else:
                file = processing_utils.decode_base64_to_file(
                    data, file_path=file_name
                )
                file.orig_name = file_name  # type: ignore
            self.temp_files.add(str(utils.abspath(file.name)))
            return file
        elif (
            self.type == "binary" or self.type == "bytes"
        ):  # "bytes" is included for backwards compatibility
            if is_file:
                with open(file_name, "rb") as file_data:
                    return file_data.read()
            return processing_utils.decode_base64_to_binary(data)[0]
        else:
            raise ValueError(
                "Unknown type: "
                + str(self.type)
                + ". Please choose from: 'file', 'bytes'."
            )

    if self.file_count == "single":
        if isinstance(x, list):
            return process_single_file(x[0])
        else:
            return process_single_file(x)
    else:
        if isinstance(x, list):
            return [process_single_file(f) for f in x]
        else:
            return process_single_file(x)


gr.Files.preprocess = preprocess


def calculate_text_embeddings(model_name, prompts):
    # use the tokenizer that belongs to the selected model (it was previously fetched but never used)
    tokenizer = open_clip.get_tokenizer(MODELS[model_name]['model_name'])
    context_vecs = tokenizer(prompts)

    model_instance = MODELS[model_name]['model_instance'].to(DEVICE)
    context_vecs = context_vecs.to(DEVICE)
    with torch.no_grad():
        text_features = model_instance.encode_text(context_vecs, normalize=True).cpu()
    model_instance = model_instance.cpu()
    return text_features


def calculate_image_embeddings(model_name, images):
    preprocessing = MODELS[model_name]['preprocessing']
    model_instance = MODELS[model_name]['model_instance']

    # load and preprocess the uploaded images
    user_imgs = []
    for tmp_file_img in images:
        img = Image.open(tmp_file_img.name)
        user_imgs.append(preprocessing(img))

    # calculate the image embeddings in batches
    image_embeddings = []
    model_instance = model_instance.to(DEVICE)
    for batch_idx in range(0, len(user_imgs), IMG_BATCHSIZE):
        imgs = torch.stack(user_imgs[batch_idx:batch_idx + IMG_BATCHSIZE]).to(DEVICE)
        with torch.no_grad():
            emb = model_instance.encode_image(imgs, normalize=True).cpu()
        image_embeddings.append(emb)
    model_instance = model_instance.cpu()
    return torch.cat(image_embeddings)


def get_possible_names(true_name):
    # work on a copy so the global dataframe is not modified as a side effect
    possible_names = FULL_NAMES_DF.copy()
    possible_names['full_names'] = (
        possible_names['first_name'].astype(str) + ' ' + possible_names['last_name'].astype(str)
    )
    possible_names = possible_names[possible_names['full_names'] != true_name]
    # sample the same number of male and female names
    sampled_names = possible_names.groupby('sex').sample(NUM_TOTAL_NAMES // 2, random_state=SEED)
    # shuffle the rows randomly (seeded for reproducibility)
    sampled_names = sampled_names.sample(frac=1, random_state=SEED)
    # keep only the full names since first name, last name and gender are no longer needed
    return sampled_names['full_names']
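
# With NUM_TOTAL_NAMES = 1_000 this yields 500 male and 500 female decoy names, e.g.:
#   get_possible_names("John Doe")  # -> pd.Series of 1,000 shuffled full names, "John Doe" excluded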


def round_to_first_digit(value: Decimal):
    """Format a small probability value, keeping everything up to and including the first non-zero digit."""
    # np.format_float_positional expects a float, so convert the Decimal first
    tmp = np.format_float_positional(float(value))
    prob_str = []
    for c in tmp:
        prob_str.append(c)
        if c not in ("0", "."):
            break
    return "".join(prob_str)


def get_majority_predictions(predictions: pd.Series, values_only=False, counts_only=False, value=None):
    """Takes a series of predictions and returns the unique values and the number of prediction occurrences
    in descending order."""
    values, counts = np.unique(predictions, return_counts=True)
    descending_counts_indices = counts.argsort()[::-1]
    values, counts = values[descending_counts_indices], counts[descending_counts_indices]
    idx_most_often_pred_names = np.argwhere(counts == counts.max()).flatten()

    if values_only:
        return values[idx_most_often_pred_names]
    elif counts_only:
        return counts[idx_most_often_pred_names]
    elif value is not None:
        if value not in values:
            return [0]
        # return how often the given value appears in the predictions
        return counts[np.where(values == value)[0]]
    else:
        return values[idx_most_often_pred_names], counts[idx_most_often_pred_names]
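
# Illustrative example:
#   get_majority_predictions(pd.Series(["A", "A", "B"]))
#   # -> (array(['A']), array([2])), i.e., the most frequent prediction(s) and their counts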


def on_submit_btn_click(model_name, true_name, prompts, images):
    # assert that the name appears in every prompt (plain substring check, no regex)
    assert prompts.iloc[0].str.contains(true_name, regex=False).sum() == len(prompts.T)

    # calculate the image embeddings
    img_embeddings = calculate_image_embeddings(model_name, images)
    # calculate the text embeddings of the populated prompts
    user_text_emb = calculate_text_embeddings(model_name, prompts.values[0].tolist())

    # get the indices of the possible names
    possible_names = get_possible_names(true_name)
    # get the precomputed text embeddings of the possible names
    prompt_text_embeddings = MODELS[model_name]['prompt_text_embeddings']
    text_embeddings_used_for_prediction = prompt_text_embeddings.index_select(
        1, torch.tensor(possible_names.index.values)
    )

    # add the true name and its text embeddings to the possible names
    names_used_for_prediction = pd.concat([possible_names, pd.Series(true_name)], ignore_index=True)
    text_embeddings_used_for_prediction = torch.cat(
        [text_embeddings_used_for_prediction, user_text_emb.unsqueeze(1)], dim=1
    )

    # calculate the similarity of the images and the given texts
    with torch.no_grad():
        logits_per_image = MODELS[model_name][
            'model_instance'
        ].logit_scale.exp().cpu() * img_embeddings @ text_embeddings_used_for_prediction.swapaxes(-1, -2)
        preds = logits_per_image.argmax(-1)

    # get the predicted names for each prompt
    predicted_names = []
    for pred in preds:
        predicted_names.append(names_used_for_prediction.iloc[pred])
    predicted_names = np.array(predicted_names)

    # convert the predictions into a dataframe
    name_predictions = pd.DataFrame(predicted_names).T.reset_index().rename(
        columns={i: f'Prompt {i + 1}' for i in range(len(predicted_names))}
    ).rename(columns={'index': 'Image'})
    # add the image names
    name_predictions['Image'] = [x.orig_name for x in images]

    # get the majority votes
    majority_preds = name_predictions[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].apply(
        lambda x: get_majority_predictions(x, values_only=True)
    )
    # get how often the majority name was predicted
    majority_preds_counts = name_predictions[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].apply(
        lambda x: get_majority_predictions(x, counts_only=True)
    ).apply(lambda x: x[0])
    # get how often the correct name was predicted, even if it did not win the majority vote
    true_name_preds_counts = name_predictions[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].apply(
        lambda x: get_majority_predictions(x, value=true_name)
    ).apply(lambda x: x[0])

    # convert the majority preds to a series of lists if it is a dataframe
    majority_preds = majority_preds.T.squeeze().apply(lambda x: [x]) if len(majority_preds) == 1 else majority_preds

    # create the results dataframe for display
    result = pd.concat(
        [name_predictions,
         pd.concat([pd.Series({'Image': 'Correct Name Predictions'}), true_name_preds_counts]).to_frame().T],
        ignore_index=True
    )
    result = pd.concat(
        [result, pd.concat([pd.Series({'Image': 'Majority Vote'}), majority_preds]).to_frame().T],
        ignore_index=True
    )
    result = pd.concat(
        [result, pd.concat([pd.Series({'Image': 'Majority Vote Counts'}), majority_preds_counts]).to_frame().T],
        ignore_index=True
    )
    result = result.set_index('Image')

    # check whether there is only one majority vote; if not, display "N/A"
    result.loc['Majority Vote'] = result.loc['Majority Vote'].apply(lambda x: x[0] if len(x) == 1 else "N/A")
    # check whether the majority prediction is the correct name
    result.loc['Correct Majority Prediction'] = result.apply(lambda x: x['Majority Vote'] == true_name, axis=0)
    result = result[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].sort_values(
        ['Correct Name Predictions', 'Majority Vote Counts', 'Correct Majority Prediction'], axis=1, ascending=False
    )

    predictions = result.loc[[x.orig_name for x in images]]
    prediction_results = result.loc[['Correct Name Predictions', 'Majority Vote', 'Correct Majority Prediction']]

    # check whether there are correct (majority) predictions
    num_correct_maj_preds = prediction_results.loc['Correct Majority Prediction'].sum()
    num_correct_name_preds = result.loc['Correct Name Predictions'].max()
    if num_correct_maj_preds > 0:
        interpretation = CORRECT_RESULT_INTERPRETATION.format(
            true_name,
            num_correct_maj_preds,
            len(PROMPTS),
            prediction_results.columns[0],
            len(images),
            len(possible_names),
            predictions.iloc[:, 0].value_counts()[true_name],
            round_to_first_digit(
                # cast to a plain int so Decimal exponentiation also works with numpy integers
                (Decimal(1) / Decimal(len(possible_names))) ** int(predictions.iloc[:, 0].value_counts()[true_name])
                * Decimal(100)
            )
        )
    elif num_correct_name_preds > 0:
        interpretation = INDECISIVE_RESULT_INTERPRETATION.format(
            true_name,
            len(PROMPTS),
            num_correct_name_preds,
            prediction_results.columns[result.loc['Correct Name Predictions'].to_numpy().argmax()],
            len(possible_names),
            round_to_first_digit(
                (Decimal(1) / Decimal(len(possible_names))) ** int(num_correct_name_preds) * Decimal(100)
            )
        )
    else:
        interpretation = INCORRECT_RESULT_INTERPRETATION.format(
            true_name,
            len(PROMPTS)
        )

    if true_name.lower() in LAION_MEMBERSHIP_OCCURRENCE['name'].str.lower().values:
        row = LAION_MEMBERSHIP_OCCURRENCE[LAION_MEMBERSHIP_OCCURRENCE['name'].str.lower() == true_name.lower()]
        interpretation = interpretation + OCCURRENCE_INFORMATION.format(true_name, row['count'].values[0])

    return predictions.reset_index(), prediction_results.reset_index(names=[""]), interpretation


def populate_prompts(name):
    return [[x.format(name) for x in PROMPTS]]


def load_uploaded_imgs(images):
    if images is None:
        return None
    imgs = []
    for file_wrapper in images:
        img = Image.open(file_wrapper.name)
        imgs.append((img, file_wrapper.orig_name))
    return imgs


block = gr.Blocks(css=CSS)
with block as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 750px; margin: 0 auto;">
            <div>
                <img
                    class="logo"
                    src="https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1666181274838-62fa1d95e8c9c532aa75331c.png"
                    alt="AIML Logo"
                    style="margin: auto; max-width: 7rem;"
                >
                <h1 style="font-weight: 900; font-size: 3rem;">
                    Does CLIP Know My Face?
                </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 94%">
                Want to know whether images of you were used to train a CLIP model? Below you can choose a model, enter your name, and upload some pictures.
                If the model correctly predicts your name for multiple images, it is very likely that you were part of the training data.
                Pick some of the examples below and try it out!<br><br>
                Details and further analysis can be found in the paper
                <a href="https://arxiv.org/abs/2209.07341" style="text-decoration: underline;" target="_blank">
                    Does CLIP Know My Face?
                </a>.
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Box():
            gr.Markdown("## Inputs")
            with gr.Column():
                model_dd = gr.Dropdown(label="CLIP Model", choices=list(MODELS.keys()),
                                       value=list(MODELS.keys())[0])
                true_name = gr.Textbox(label='Name of Person:', lines=1, value=DEFAULT_INITIAL_NAME)
                prompts = gr.Dataframe(
                    value=[[x.format(DEFAULT_INITIAL_NAME) for x in PROMPTS]],
                    label='Prompts Used (hold shift to scroll sideways):',
                    interactive=False
                )
                true_name.change(fn=populate_prompts, inputs=[true_name], outputs=prompts, show_progress=True,
                                 status_tracker=None)
                uploaded_imgs = gr.Files(label='Upload Images:', file_types=['image'], elem_id='file_upload').style()
                image_gallery = gr.Gallery(label='Images Used:', show_label=True, elem_id="image_gallery").style(grid=[5])
                uploaded_imgs.change(load_uploaded_imgs, inputs=uploaded_imgs, outputs=image_gallery)
                submit_btn = gr.Button(value='Submit')
        with gr.Box():
            gr.Markdown("## Outputs")
            prediction_df = gr.Dataframe(label="Prediction Output (hold shift to scroll sideways):", interactive=False)
            result_df = gr.DataFrame(label="Result (hold shift to scroll sideways):", interactive=False)
            interpretation = gr.HTML()

    submit_btn.click(on_submit_btn_click, inputs=[model_dd, true_name, prompts, uploaded_imgs],
                     outputs=[prediction_df, result_df, interpretation])

    gr.Examples(
        examples=EXAMPLES,
        inputs=[model_dd, true_name, prompts, uploaded_imgs],
        outputs=[prediction_df, result_df, interpretation],
        fn=on_submit_btn_click,
        cache_examples=True
    )

    gr.Markdown(LICENSE_DETAILS)

    gr.HTML(
        """
        <div class="footer">
            <p>Gradio Demo by AIML@TU Darmstadt</p>
        </div>
        <div class="acknowledgments">
            <p>Created by <a href="https://www.ml.informatik.tu-darmstadt.de/people/dhintersdorf/">Dominik Hintersdorf</a> at <a href="https://www.aiml.informatik.tu-darmstadt.de">AIML Lab</a>.</p>
        </div>
        """
    )

if __name__ == "__main__":
    demo.launch()