import gradio as gr
import torch
from PIL import Image
import pandas as pd
from lavis.models import load_model_and_preprocess
from transformers import AutoModelForCausalLM, AutoProcessor
import tensorflow_hub as hub
import io
from sklearn.metrics.pairwise import cosine_similarity
import tempfile
import logging
import os

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Load model and preprocessors for Image-Text Matching (LAVIS)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)

# Load processor and model for Image Captioning (GIT large, TextCaps)
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)

# Load Universal Sentence Encoder model for textual similarity calculation
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
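# Note: the first call downloads the model from TF Hub and caches it locally
# (under TFHUB_CACHE_DIR if set); later runs reuse the cached copy.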

# Define a function to compute textual similarity between caption and statement
def compute_textual_similarity(caption, statement):
    # Convert caption and statement into sentence embeddings
    caption_embedding = embed([caption])[0].numpy()
    statement_embedding = embed([statement])[0].numpy()

    # Calculate cosine similarity between sentence embeddings
    similarity_score = cosine_similarity([caption_embedding], [statement_embedding])[0][0]
    return similarity_score
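
# Illustrative behavior (qualitative, not exact values): semantically close
# sentences score near 1.0, unrelated ones near 0.0. For example:
#   compute_textual_similarity("a dog running on grass", "a dog plays outside")   # relatively high
#   compute_textual_similarity("a dog running on grass", "a stock market chart")  # relatively low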

# Read statements from the external file 'statements.txt'
with open('statements.txt', 'r') as file:
    statements = file.read().splitlines()
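
# statements.txt is expected to contain one statement per line. The entries
# below are hypothetical placeholders, not the actual file contents:
#   a photo containing at least one person
#   an outdoor scene with visible vegetation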

# Function to compute ITM scores for the image-statement pair
def compute_itm_score(image, statement):
    logging.info('Starting compute_itm_score')
    # `image` arrives as a PIL image (opened in process_images_and_statements),
    # so preprocess it directly for the BLIP-2 ITM model
    img = vis_processors["eval"](image.convert("RGB")).unsqueeze(0).to(device)
    txt = text_processors["eval"](statement)
    itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
    # Column 1 of the softmaxed ITM head is the probability of an image-text match
    itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
    score = itm_scores[:, 1].item()
    logging.info('Finished compute_itm_score')
    return score

def generate_caption(processor, model, image):
    logging.info('Starting generate_caption')
    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    logging.info('Finished generate_caption')
    return generated_caption

def save_dataframe_to_csv(df):
    # Write the DataFrame straight to a temporary CSV file and return its path;
    # Gradio's File output accepts a plain file path
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as temp_file:
        df.to_csv(temp_file, index=False)
        temp_file_path = temp_file.name

    return temp_file_path

# Define a function to check if the uploaded file is an image
def is_image_file(file):
    allowed_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff"]
    file_extension = os.path.splitext(file.name)[1]
    return file_extension.lower() in allowed_extensions


# Main function to perform image captioning and image-text matching
def process_images_and_statements(file):
    logging.debug("Entered process_images_and_statements function")
    logging.debug(f"File object: {file}")
    logging.debug(f"File name: {file.name}")
    logging.debug(f"File size: {file.tell()}")

    # Check if the uploaded file is an image; return an error row plus an empty
    # CSV slot so both Gradio outputs still receive matching values
    if not is_image_file(file):
        error_df = pd.DataFrame({'Error': ["Invalid file type. Please upload an image file (e.g., .jpg, .png, .jpeg)."]})
        return error_df, None
    
    # Extract the filename from the file object
    filename = file.name

    # Load the image data from the file (convert file object to bytes using file.read())
    try:
        logging.debug("Attempting to open image")
        image = Image.open(io.BytesIO(file.read()))
        logging.debug("Image opened successfully")
    except Exception as e:
        logging.exception("Error occurred while opening image")
        return pd.DataFrame({'Error': [str(e)]}), None  # Surface the error in the results table

    # Generate image caption for the uploaded image using git-large-r-textcaps
    caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)

    # Define weights for combining textual similarity score and image-statement ITM score (adjust as needed)
    weight_textual_similarity = 0.5
    weight_statement = 0.5
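
    # Worked example (illustrative): with the equal weights above, a textual
    # similarity of 80.0 and an ITM score of 60.0 combine to
    # 0.5 * 80.0 + 0.5 * 60.0 = 70.0.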

    # Initialize an empty list to store the results
    results_list = []

    # Loop through each predefined statement
    for statement in statements:
        # Compute textual similarity between caption and statement (scaled to 0-100)
        textual_similarity_score = compute_textual_similarity(caption, statement) * 100

        # Compute ITM score for the image-statement pair (scaled to 0-100)
        itm_score_statement = compute_itm_score(image, statement) * 100

        # Combine the two scores using a weighted average
        final_score = ((weight_textual_similarity * textual_similarity_score) +
                       (weight_statement * itm_score_statement))

        # Append the result, including the image filename, with each score
        # formatted as a percentage with two decimal places
        results_list.append({
            'Image Filename': filename,
            'Statement': statement,
            'Generated Caption': caption,
            'Textual Similarity Score': f"{textual_similarity_score:.2f}%",
            'ITM Score': f"{itm_score_statement:.2f}%",
            'Final Combined Score': f"{final_score:.2f}%"
        })

    # Convert the results_list to a DataFrame (one row per statement)
    results_df = pd.DataFrame(results_list)

    logging.info('Finished process_images_and_statements')

    # Save results_df to a CSV file
    csv_results = save_dataframe_to_csv(results_df)

    # Return both the DataFrame and the CSV data for the Gradio interface
    return results_df, csv_results

# Gradio interface
file_input = gr.File(label="Upload Image")  # File input keeps the original filename available
output_df = gr.Dataframe(type="pandas", label="Results")
output_csv = gr.File(label="Download CSV")

iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=file_input,
    outputs=[output_df, output_csv],
    title="Image Captioning and Image-Text Matching",
    theme='sudeepshouche/minimalist',
    css=".output { flex-direction: column; } .output .outputs { width: 100%; }" # Custom CSS
)

# Launch the Gradio interface
iface.launch(debug=True)
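
# Illustrative environment note (assumed package names, adjust to your setup):
#   pip install gradio torch salesforce-lavis transformers tensorflow tensorflow-hub scikit-learn pandas pillow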