"""Gradio demo: text-to-image retrieval with a trained CLIP model.

On startup every image in the validation split is embedded once with the
CLIP image encoder. Each query is then encoded with the text encoder and
ranked against those embeddings by cosine similarity.
"""

import gradio as gr
import gc
import cv2
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer
import matplotlib.pyplot as plt
from implement import *
import config as CFG
from main import build_loaders
from CLIP import CLIPModel
import os

# Constructing the tokenizer is expensive. Load it once and share it between
# the startup embedding pass and every query, instead of reloading per request.
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)


def get_image_embeddings(valid_df, model_path):
    """Load the CLIP checkpoint and embed every image in ``valid_df``.

    Args:
        valid_df: DataFrame handed to ``build_loaders`` in "valid" mode.
        model_path: Path to a ``state_dict`` saved for ``CLIPModel``.

    Returns:
        ``(model, embeddings)``: the model in eval mode, and the projected
        image embeddings concatenated in loader order.
    """
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)


# make_train_valid_dfs presumably comes from `from implement import *`.
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
# The image embeddings never change between queries, so L2-normalise them
# once here rather than on every search.
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
# Filenames indexed by the same position as the embedding rows. This assumes
# the "valid" loader preserves valid_df row order (no shuffling); TODO confirm.
image_names = valid_df["image"].values


def find_matches(query, n=9):
    """Return up to ``n`` distinct images that best match ``query``.

    Args:
        query: Free-text search string.
        n: Maximum number of images to return.

    Returns:
        A list of RGB images as decoded by OpenCV, best match first.
        Files that cannot be read are skipped. A blank query returns an
        empty list.
    """
    if not query or not query.strip():
        return []

    # truncation=True keeps very long queries within the text encoder's
    # maximum sequence length instead of failing at inference time.
    encoded_query = tokenizer([query], truncation=True)
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = (text_embeddings_n @ image_embeddings_n.T).squeeze(0)

    # The n * 5 over-fetch suggests valid_df holds several rows per image
    # (presumably one per caption). Duplicate rows get identical scores.
    # Over-fetch, then keep the first occurrence of each filename in rank
    # order. Clamp k so topk does not fail on a small validation set.
    k = min(n * 5, dot_similarity.numel())
    _, indices = torch.topk(dot_similarity, k)
    ranked_names = dict.fromkeys(image_names[idx] for idx in indices.tolist())
    matches = list(ranked_names)[:n]

    images = []
    for match in matches:
        image = cv2.imread(f"{CFG.image_path}/{match}")
        if image is None:  # missing or unreadable file: skip instead of crashing
            continue
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    return images


with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        textbox = gr.Textbox(
            label="Enter a query to find matching images using a CLIP model."
        )
        # A gallery shows all n matches. A single gr.Image could display
        # only one of them.
        gallery = gr.Gallery()
    button = gr.Button("Press")
    button.click(fn=find_matches, inputs=textbox, outputs=gallery)

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(share=True)