File size: 5,565 Bytes
4859d06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
###
# Take a CSV file listing image file paths and produce a CSV file that also
# contains the objects recognized in each image.
#
# The input CSV file must contain an 'image_file' column holding all the image
# file paths; the output CSV adds a 'clip_recognized_objects' column.
###
import os
import clip
import torch
import pandas as pd
from PIL import Image
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# CIFAR-100 test split, stored under ~/.cache (download=True fetches it when needed).
# In the code shown here only its `classes` list is used: the class names are
# the candidate object labels that each image is compared against.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
def save_checkpoint(checkpoint_path, df, object_list):
    """Write the rows processed so far, plus their recognized objects, to a CSV.

    Args:
        checkpoint_path: Path of the CSV file to (over)write.
        df: DataFrame with the rows processed so far. It is copied, so the
            caller's frame is left untouched.
        object_list: Recognized objects, one entry per row of ``df``; stored
            in the 'clip_recognized_objects' column.

    Raises:
        ValueError: Raised by pandas when ``object_list`` and ``df`` differ
            in length.
    """
    output_df = df.copy()
    output_df['clip_recognized_objects'] = object_list

    # Write to a sibling temp file and move it into place afterwards, so an
    # interruption in the middle of the write cannot leave a truncated
    # checkpoint behind (os.replace swaps the file in one step).
    tmp_path = f"{checkpoint_path}.tmp"
    output_df.to_csv(
        tmp_path,
        index=False,  # don't write a new 'Index' column
    )
    os.replace(tmp_path, checkpoint_path)
    print("Saved checkpoint!")
def load_checkpoint(checkpoint_path):
    """Load already-recognized objects from a checkpoint CSV.

    Args:
        checkpoint_path: Path of a CSV file with 'image_file' and
            'clip_recognized_objects' columns.

    Returns:
        A dict mapping each 'image_file' value to its
        'clip_recognized_objects' value, exactly as pandas read it (the values
        are not parsed back into Python objects here). When the checkpoint
        cannot be read, the module-level ``cached_objects_dict`` is returned
        instead.
    """
    print("reading checkpoint at ", checkpoint_path)
    try:
        df = pd.read_csv(checkpoint_path)
        cached_objects = dict(zip(df['image_file'], df['clip_recognized_objects']))
    except (OSError, KeyError, ValueError) as error:
        # OSError: file missing/unreadable; KeyError: expected column absent;
        # ValueError: covers pandas' empty-file and parser errors.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        print("Checkpoint was not loaded")
        print(f"  reason: {error!r}")
        return cached_objects_dict
    print(f"Checkpoint loaded succesfully to cache: {len(cached_objects)} processed files")
    return cached_objects
def get_checkpoint_path(output_path):
    """Return the path at which checkpoints for ``output_path`` are stored.

    Checkpoints are written to the output file itself, so the given path is
    handed back unchanged. (A separate "checkpoint"-prefixed sibling file was
    sketched here earlier but is not used.)
    """
    checkpoint_path = output_path
    return checkpoint_path
# Module-level cache. load_checkpoint() returns this dict when no checkpoint
# can be read; clip_object_detection() then passes it to get_objects(), which
# stores image file path -> recognized objects entries in it.
cached_objects_dict = {} # to avoid recomputing
def get_objects(filepath, model, preprocess, device, cached_objects_dict):
    """Return the recognized objects for ``filepath``, using the cache when possible.

    A cached entry (any value other than ``None``) is returned as-is.
    Otherwise the image is analysed with ``get_objects_in_image`` and the
    result is stored in ``cached_objects_dict`` before being returned.
    """
    hit = cached_objects_dict.get(filepath)
    if hit is not None:
        return hit

    computed = get_objects_in_image(filepath, model, preprocess, device)
    cached_objects_dict[filepath] = computed
    return computed
def get_objects_in_image(image_filepath, model, preprocess, device, top_k=5):
    """Return the CIFAR-100 class names that CLIP rates most similar to an image.

    Args:
        image_filepath: Path of the image file to analyse.
        model: CLIP model providing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning a PIL image into the model's input tensor.
        device: Device the input tensors are moved to.
        top_k: Number of best-matching labels to return (default 5). Must not
            exceed the number of CIFAR-100 classes.

    Returns:
        A list of ``(class_name, probability)`` tuples, best match first. The
        probability is the softmax over the prompts of all CIFAR-100 classes.
    """
    # Prepare the inputs. The context manager guarantees the file handle is
    # released once the resized copy exists - also for multi-frame formats
    # and when resizing raises.
    with Image.open(image_filepath) as source_image:
        image = source_image.resize((600, 600))
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_inputs = torch.cat(
        [clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]
    ).to(device)

    # Calculate features
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

    # Normalize both feature sets, then pick the top_k most similar labels
    # for the image.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(top_k)

    # Convert to plain Python floats/ints so the result holds no tensors.
    return [
        (cifar100.classes[index], value)
        for value, index in zip(values.tolist(), indices.tolist())
    ]
def clip_object_detection(input_csv, output_csv, checkpoint_every=50):
    """Recognize the objects in every image listed in ``input_csv``.

    Results already present in the checkpoint (stored at the path returned by
    ``get_checkpoint_path(output_csv)``) are reused instead of recomputed.

    Args:
        input_csv: CSV file with an 'image_file' column of image file paths.
        output_csv: Output CSV path; used here to locate the checkpoint.
        checkpoint_every: Write a checkpoint each time this many additional
            images have been processed (default 50). A falsy value disables
            intermediate checkpoints.

    Returns:
        A ``pandas.Series`` with one entry per row of ``input_csv``, in the
        same order, holding the recognized objects of that row's image.
    """
    checkpoint_path = get_checkpoint_path(output_csv)
    cache = load_checkpoint(checkpoint_path)

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)

    df = pd.read_csv(input_csv)
    recognized_objects_per_image = []
    processed_files = set(cache)

    for idx, filepath in enumerate(tqdm(df['image_file'])):
        is_new = filepath not in processed_files
        objects = get_objects(filepath, model, preprocess, device, cache)
        recognized_objects_per_image.append(objects)
        if not is_new:
            # Nothing new was computed for this row, so there is nothing worth
            # checkpointing (and the existing checkpoint must not be replaced
            # by a shorter prefix while only cached rows are replayed).
            continue

        processed_files.add(filepath)
        # Checkpoint *after* processing, once per `checkpoint_every` new
        # images, so the saved prefix always includes the current row.
        if checkpoint_every and len(processed_files) % checkpoint_every == 0:
            print(f"Images processed: {len(processed_files)}")
            save_checkpoint(
                checkpoint_path,
                df.iloc[:idx + 1],
                recognized_objects_per_image,
            )

    return pd.Series(recognized_objects_per_image)
import argparse
# Command-line entry point: read the image list, run CLIP object recognition
# and write the list back out with an extra 'clip_recognized_objects' column.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="CLIP object recognition",
        description='Recognizes the top 5 main objects per image in an image list',
    )
    # Both arguments are required: without them there is nothing to read or
    # write, and argparse now reports that instead of a later TypeError.
    parser.add_argument(
        "--input_csv", "-in", metavar='in', type=str, nargs=1, required=True,
        help='input file containing images-paths for object recognition.',
    )
    parser.add_argument(
        "--output_csv", "-out", metavar='out', type=str, nargs=1, required=True,
        help='output file containing images-paths + recognized objects',
    )
    args = parser.parse_args()

    # nargs=1 yields one-element lists
    input_csv_file = args.input_csv[0]
    output_csv_file = args.output_csv[0]
    print(">>> input file: " , input_csv_file)
    print(">>> output file: ", output_csv_file)

    # perform object recognition
    recognized_objects_per_image = clip_object_detection(input_csv_file, output_csv_file)

    # add a column with the recognized objects
    output_df = pd.read_csv(input_csv_file)
    output_df['clip_recognized_objects'] = recognized_objects_per_image
    output_df.to_csv(
        output_csv_file,
        index=False,  # don't write a new 'Index' column
    )
|