iamrobotbear commited on
Commit
8e34f80
·
1 Parent(s): b0eb421

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import pandas as pd
5
+ from lavis.models import load_model_and_preprocess
6
+ from lavis.processors import load_processor
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
8
+ import tensorflow as tf
9
+ import tensorflow_hub as hub
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+
12
+ # Import logging module
13
+ import logging
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
17
+
18
+ # Load model and preprocessors for Image-Text Matching (LAVIS)
19
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
20
+ model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)
21
+
22
+ # Load tokenizer and model for Image Captioning (TextCaps)
23
+ git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
24
+ git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
25
+
26
+ # Load Universal Sentence Encoder model for textual similarity calculation
27
+ embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
28
+
29
+ # Define a function to compute textual similarity between caption and statement
30
+ def compute_textual_similarity(caption, statement):
31
+ # Convert caption and statement into sentence embeddings
32
+ caption_embedding = embed([caption])[0].numpy()
33
+ statement_embedding = embed([statement])[0].numpy()
34
+
35
+ # Calculate cosine similarity between sentence embeddings
36
+ similarity_score = cosine_similarity([caption_embedding], [statement_embedding])[0][0]
37
+ return similarity_score
38
+
39
+ # List of statements for Image-Text Matching
40
+ statements = [
41
+ "cartoon, figurine, or toy",
42
+ "appears to be for children",
43
+ "includes children",
44
+ "is sexual",
45
+ "depicts a child or portrays objects, images, or cartoon figures that primarily appeal to persons below the legal purchase age",
46
+ "uses the name of or depicts Santa Claus",
47
+ 'promotes alcohol use as a "rite of passage" to adulthood',
48
+ "uses brand identification—including logos, trademarks, or names—on clothing, toys, games, game equipment, or other items intended for use primarily by persons below the legal purchase age",
49
+ "portrays persons in a state of intoxication or in any way suggests that intoxication is socially acceptable conduct",
50
+ "makes curative or therapeutic claims, except as permitted by law",
51
+ "makes claims or representations that individuals can attain social, professional, educational, or athletic success or status due to beverage alcohol consumption",
52
+ "degrades the image, form, or status of women, men, or of any ethnic group, minority, sexual orientation, religious affiliation, or other such group?",
53
+ "uses lewd or indecent images or language",
54
+ "employs religion or religious themes?",
55
+ "relies upon sexual prowess or sexual success as a selling point for the brand",
56
+ "uses graphic or gratuitous nudity, overt sexual activity, promiscuity, or sexually lewd or indecent images or language",
57
+ "associates with anti-social or dangerous behavior",
58
+ "depicts illegal activity of any kind?",
59
+ 'uses the term "spring break" or sponsors events or activities that use the term "spring break," unless those events or activities are located at a licensed retail establishment',
60
+ "baseball",
61
+ ]
62
+
63
+ # Function to compute ITM scores for the image-statement pair
64
+ def compute_itm_score(image, statement):
65
+ logging.info('Starting compute_itm_score')
66
+ pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
67
+ img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
68
+ # Pass the statement text directly to model_itm
69
+ itm_output = model_itm({"image": img, "text_input": statement}, match_head="itm")
70
+ itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
71
+ score = itm_scores[:, 1].item()
72
+ logging.info('Finished compute_itm_score')
73
+ return score
74
+
75
+ def generate_caption(processor, model, image):
76
+ logging.info('Starting generate_caption')
77
+ inputs = processor(images=image, return_tensors="pt").to(device)
78
+ generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
79
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
80
+ logging.info('Finished generate_caption')
81
+ return generated_caption
82
+
83
+ # Main function to perform image captioning and image-text matching
84
+ def process_images_and_statements(image):
85
+ logging.info('Starting process_images_and_statements')
86
+
87
+ # Generate image caption for the uploaded image using git-large-r-textcaps
88
+ caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)
89
+
90
+ # Initialize an empty list to store the results
91
+ results = []
92
+
93
+ # Define weights for combining textual similarity score and image-statement ITM score (adjust as needed)
94
+ weight_textual_similarity = 0.5
95
+ weight_statement = 0.5
96
+
97
+ # Loop through each predefined statement
98
+ for statement in statements:
99
+ # Compute textual similarity between caption and statement
100
+ textual_similarity_score = compute_textual_similarity(caption, statement)
101
+
102
+ # Compute ITM score for the image-statement pair
103
+ itm_score_statement = compute_itm_score(image, statement)
104
+
105
+ # Combine the two scores using a weighted average
106
+ final_score = (weight_textual_similarity * textual_similarity_score) + (weight_statement * itm_score_statement)
107
+
108
+ # Store the result
109
+ result_text = (f'Textual similarity between caption ("{caption}") and statement ("{statement}") is {textual_similarity_score:.3f}\n'
110
+ f'The image-statement pair ("{statement}") is matched with a probability of {itm_score_statement:.3%}\n'
111
+ f'The final combined score is {final_score:.3%}')
112
+ results.append(result_text)
113
+
114
+ logging.info('Finished process_images_and_statements')
115
+
116
+ # Combine the results and return them
117
+ output = "\n\n".join(results)
118
+ return output
119
+
120
+ # Gradio interface
121
+ image_input = gr.inputs.Image()
122
+ output = gr.outputs.Textbox(label="Results")
123
+
124
+ iface = gr.Interface(fn=process_images_and_statements, inputs=image_input, outputs=output, title="Image Captioning and Image-Text Matching")
125
+ iface.launch()