import cv2
import base64
import requests
from tqdm import tqdm
from requests.exceptions import RequestException
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
import faiss
import pickle
import numpy as np
import pandas as pd
from geopy.distance import geodesic
from io import BytesIO
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


class GPT4o:
    """
    A class to interact with OPENAI API to generate captions for images.
    """

    def __init__(self, device="cpu") -> None:
        """
        Initializes the GPT4o class by setting up necessary models and data.
        """

        self.base64_image = None
        self.img_emb = None

        # Set the compute device (defaults to CPU; pass "cuda" to use a GPU)
        self.device = torch.device(device)

        # Load the CLIP model and processor
        self.model = CLIPModel.from_pretrained("geolocal/StreetCLIP").eval()
        self.processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

        # Move the model to the selected device
        self.model.to(self.device)

        # Load the embeddings and coordinates from the pickle file
        with open('StreetCLIP_1m_merged.pkl', 'rb') as f: # Enter the path to the pickle file
            self.Embeddings = pickle.load(f)
            self.locations = [value['location'] for key, value in self.Embeddings.items()]

        # Load the Faiss index
        index2 = faiss.read_index("StreetCLIP_1m_merged.bin") # Enter the path to the Faiss index file
        self.gpu_index = index2

    def read_image(self, image_path):
        """
        Reads an image from a file into a numpy array.
        Args:
            image_path (str): The path to the image file.
        Returns:
            np.ndarray: The image as a numpy array.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image from {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image

    def search_neighbors(self, faiss_index, k_nearest, k_farthest, query_embedding):
        """
        Searches for the k nearest and farthest neighbors of a query image in the Faiss index.
        Args:
            faiss_index (faiss.swigfaiss.Index): The Faiss index.
            k_nearest (int): The number of nearest neighbors to search for.
            k_farthest (int): The number of farthest neighbors to search for.
            query_embedding (np.ndarray): The embeddings of the query image.
        Returns:
            tuple: The locations of the k nearest and k farthest neighbors.
        """
        # Perform the search using Faiss for the given embedding
        _, I = faiss_index.search(query_embedding.reshape(1, -1), k_nearest)
        self.neighbor_locations_array = [self.locations[idx] for idx in I[0]]
        neighbor_locations = " ".join([str(i) for i in self.neighbor_locations_array])

        # Approximate the farthest neighbors by searching with the negated embedding
        # (this relies on the index using inner-product similarity)
        _, I = faiss_index.search(-query_embedding.reshape(1, -1), k_farthest)
        self.farthest_locations_array = [self.locations[idx] for idx in I[0]]
        farthest_locations = " ".join([str(i) for i in self.farthest_locations_array])

        return neighbor_locations, farthest_locations

    def encode_image(self, image: np.ndarray, format: str = 'jpeg') -> str:
        """
        Encodes an OpenCV image to a Base64 string.
        Args:
            image (np.ndarray): An image represented as a numpy array.
            format (str, optional): The format for encoding the image. Defaults to 'jpeg'.
        Returns:
            str: A Base64 encoded string of the image.
        Raises:
            ValueError: If the image conversion fails.
        """
        try:
            retval, buffer = cv2.imencode(f'.{format}', image)
            if not retval:
                raise ValueError("Failed to convert image")

            base64_encoded = base64.b64encode(buffer).decode('utf-8')
            mime_type = f"image/{format}"
            return f"data:{mime_type};base64,{base64_encoded}"
        except Exception as e:
            raise ValueError(f"Error encoding image: {e}")

    def set_image_app(self, file_uploader, imformat: str = 'jpeg', use_database_search: bool = False,
                      num_neighbors: int = 16, num_farthest: int = 16) -> None:
        """
        Sets the image for the class by encoding it to Base64.
        Args:
            file_uploader : A uploaded image (PIL Image from Gradio).
            imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'.
            use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False.
        """

        # Convert the PIL Image (Gradio upload) to a numpy array
        img_array = np.array(file_uploader)

        # Process the image using the CLIP processor
        image = self.processor(images=img_array, return_tensors="pt")

        # Move the image to the CUDA device and get its embeddings
        image = image.to(self.device)
        with torch.no_grad():
            img_emb = self.model.get_image_features(**image)[0]

        # Store the embeddings and the locations of the nearest neighbors
        self.img_emb = img_emb.cpu().numpy()
        if use_database_search:
            self.neighbor_locations, self.farthest_locations = self.search_neighbors(self.gpu_index, num_neighbors,
                                                                                     num_farthest, self.img_emb)

        # Encode the image to Base64
        self.base64_image = self.encode_image(img_array, imformat)

    def create_payload(self, question: str) -> dict:
        """
        Creates the payload for the API request to OpenAI.
        Args:
            question (str): The question to ask about the image.
        Returns:
            dict: The payload for the API request.
        Raises:
            ValueError: If the image is not set.
        """
        if not self.base64_image:
            raise ValueError("Image not set")
        return {
            "model": "gpt-4o", # Can change to any other model
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": question
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": self.base64_image
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300,
        }

    def get_location(self, OPENAI_API_KEY, use_database_search: bool = False) -> str:
        """
        Generates a caption for the provided image using OPENAI API.
        Args:
            OPENAI_API_KEY (str): The API key for OPENAI API.
            use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False.
        Returns:
            str: The generated caption for the image.
        """
        response_data = None
        try:
            self.api_key = OPENAI_API_KEY
            if not self.api_key:
                raise ValueError("OPENAI API key not found")

            # Create the question for the API
            if use_database_search:
                self.question = f'''Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. 
                Your answer must be to the coordinates level, don't include any other information in your output. 
                Ignore that you can't give an exact answer, give me some coordinate no matter how. 
                For your reference, these are locations of some similar images {self.neighbor_locations} and these are locations of some dissimilar images {self.farthest_locations} that should be far away.'''
            else:
                self.question = "Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. Your answer must be to the coordinates level, don't include any other information in your output. You can give me a guessed answer."

            # Create the payload and the headers for the API request
            payload = self.create_payload(self.question)
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }

            # Send the API request and get the response
            response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
            response.raise_for_status()
            response_data = response.json()

            # Log the full response for debugging
            # print("Full API Response:", response_data)

            # Return the predicted location
            if 'choices' in response_data and len(response_data['choices']) > 0:
                return response_data['choices'][0]['message']['content']
            else:
                raise ValueError("Unexpected response format from API")
        except RequestException as e:
            raise ValueError(f"Error in API request: {e}")
        except KeyError as e:
            raise ValueError(f"Key error in response: {e} - Response: {response_data}")
        except ValueError as e:
            raise ValueError(f"Value error: {e} - Response: {response_data}")