Spaces:
Sleeping
Sleeping
File size: 9,653 Bytes
dbef910 c02bc9b a0d9a4a dbef910 c02bc9b dbef910 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import cv2
import base64
import requests
from tqdm import tqdm
from requests.exceptions import RequestException
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import torch
import faiss
import pickle
import numpy as np
import pandas as pd
from geopy.distance import geodesic
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
from PIL import Image
import requests
from io import BytesIO
import os
# Expose only GPU 0 to CUDA-aware libraries loaded by this process.
# NOTE(review): torch is imported above this line; presumably this still takes
# effect because CUDA is initialised lazily on first use — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
class GPT4o:
    """
    Geo-localizes an image by asking the OpenAI chat-completions API for coordinates.

    Optionally, the image is embedded with StreetCLIP and looked up in a Faiss
    index so that the locations of similar / dissimilar reference images can be
    added to the prompt as hints.
    """

    def __init__(self, device="cpu", embeddings_path: str = 'StreetCLIP_1m_merged.pkl',
                 index_path: str = "StreetCLIP_1m_merged.bin") -> None:
        """
        Loads the StreetCLIP model, the reference locations and the Faiss index.
        Args:
            device (str, optional): Torch device string the CLIP model is moved to. Defaults to "cpu".
            embeddings_path (str, optional): Path to the pickle file holding the reference embeddings.
            index_path (str, optional): Path to the Faiss index file.
        """
        self.base64_image = None
        self.img_emb = None
        # Set by set_image_app(..., use_database_search=True); None means "no search done".
        self.neighbor_locations = None
        self.farthest_locations = None

        self.device = torch.device(device)

        # Load the CLIP model and processor, then move the model to the chosen device
        self.model = CLIPModel.from_pretrained("geolocal/StreetCLIP").eval()
        self.processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
        self.model.to(self.device)

        # SECURITY NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(embeddings_path, 'rb') as f:
            self.Embeddings = pickle.load(f)

        # NOTE(review): assumes the pickle's entry order matches the row order of the
        # Faiss index, since search results are used as positions in this list — confirm.
        self.locations = [value['location'] for value in self.Embeddings.values()]

        # Load the Faiss index (kept under the historical attribute name `gpu_index`)
        self.gpu_index = faiss.read_index(index_path)

    def read_image(self, image_path):
        """
        Reads an image from a file into an RGB numpy array.
        Args:
            image_path (str): The path to the image file.
        Returns:
            np.ndarray: The image as a numpy array in RGB channel order.
        Raises:
            ValueError: If the file cannot be read or decoded.
        """
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _locations_for(self, indices):
        """Maps Faiss result labels to locations, skipping the -1 'no result' label."""
        return [self.locations[idx] for idx in indices if idx >= 0]

    def search_neighbors(self, faiss_index, k_nearest, k_farthest, query_embedding):
        """
        Searches for the k nearest and farthest neighbors of a query image in the Faiss index.
        Args:
            faiss_index (faiss.swigfaiss.Index): The Faiss index.
            k_nearest (int): The number of nearest neighbors to search for.
            k_farthest (int): The number of farthest neighbors to search for.
            query_embedding (np.ndarray): The embeddings of the query image.
        Returns:
            tuple: Two space-separated strings with the locations of the nearest and
            the farthest neighbors. The underlying lists are also stored on
            `self.neighbor_locations_array` and `self.farthest_locations_array`.
        """
        # Faiss expects a contiguous float32 matrix with one query per row
        query = np.ascontiguousarray(query_embedding, dtype=np.float32).reshape(1, -1)

        _, nearest_idx = faiss_index.search(query, k_nearest)
        self.neighbor_locations_array = self._locations_for(nearest_idx[0])
        neighbor_locations = " ".join(str(loc) for loc in self.neighbor_locations_array)

        # "Farthest" is approximated by searching with the negated embedding.
        # NOTE(review): this presumably relies on an inner-product style index — confirm.
        _, farthest_idx = faiss_index.search(-query, k_farthest)
        self.farthest_locations_array = self._locations_for(farthest_idx[0])
        farthest_locations = " ".join(str(loc) for loc in self.farthest_locations_array)

        return neighbor_locations, farthest_locations

    def encode_image(self, image: np.ndarray, format: str = 'jpeg') -> str:
        """
        Encodes an OpenCV image to a Base64 data URL.
        Args:
            image (np.ndarray): An image as a numpy array. cv2.imencode interprets
                colour data as BGR, so pass BGR (or BGRA) channel order.
            format (str, optional): The format for encoding the image. Defaults to 'jpeg'.
        Returns:
            str: A `data:image/<format>;base64,...` string of the encoded image.
        Raises:
            ValueError: If the image conversion fails.
        """
        extension = format.lower().lstrip('.')
        try:
            retval, buffer = cv2.imencode(f'.{extension}', image)
        except Exception as e:
            raise ValueError(f"Error encoding image: {e}") from e
        if not retval:
            raise ValueError("Error encoding image: Failed to convert image")

        base64_encoded = base64.b64encode(buffer).decode('utf-8')
        # 'jpg' is a file extension, not a registered MIME subtype
        mime_subtype = 'jpeg' if extension == 'jpg' else extension
        return f"data:image/{mime_subtype};base64,{base64_encoded}"

    def set_image_app(self, file_uploader, imformat: str = 'jpeg', use_database_search: bool = False,
                      num_neighbors: int = 16, num_farthest: int = 16) -> None:
        """
        Sets the current image: computes its CLIP embedding, optionally searches the
        reference database, and stores the image as a Base64 data URL.
        Args:
            file_uploader : An uploaded image (PIL Image from Gradio, or an RGB array).
            imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'.
            use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False.
            num_neighbors (int, optional): Number of nearest neighbors to retrieve. Defaults to 16.
            num_farthest (int, optional): Number of farthest neighbors to retrieve. Defaults to 16.
        """
        # Normalise PIL inputs (e.g. RGBA or grayscale uploads) to three RGB channels
        if isinstance(file_uploader, Image.Image):
            file_uploader = file_uploader.convert("RGB")
        img_array = np.array(file_uploader)

        # Embed the image with CLIP on the configured device
        inputs = self.processor(images=img_array, return_tensors="pt").to(self.device)
        with torch.no_grad():
            img_emb = self.model.get_image_features(**inputs)[0]
        self.img_emb = img_emb.cpu().numpy()

        if use_database_search:
            self.neighbor_locations, self.farthest_locations = self.search_neighbors(
                self.gpu_index, num_neighbors, num_farthest, self.img_emb)
        else:
            # Drop hints that belong to a previously set image
            self.neighbor_locations = None
            self.farthest_locations = None

        # The array is RGB(A) but cv2.imencode expects BGR(A); without this swap the
        # image sent to the API would have its red and blue channels exchanged.
        if img_array.ndim == 3 and img_array.shape[2] == 3:
            encodable = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        elif img_array.ndim == 3 and img_array.shape[2] == 4:
            encodable = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGRA)
        else:
            encodable = img_array
        self.base64_image = self.encode_image(encodable, imformat)

    def create_payload(self, question: str, model: str = "gpt-4o", max_tokens: int = 300) -> dict:
        """
        Creates the payload for the API request to OpenAI.
        Args:
            question (str): The question to ask about the image.
            model (str, optional): The model name to request. Defaults to "gpt-4o".
            max_tokens (int, optional): Upper bound on generated tokens. Defaults to 300.
        Returns:
            dict: The payload for the API request.
        Raises:
            ValueError: If the image is not set.
        """
        if not self.base64_image:
            raise ValueError("Image not set")

        text_part = {"type": "text", "text": question}
        image_part = {"type": "image_url", "image_url": {"url": self.base64_image}}
        return {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": [text_part, image_part],
                }
            ],
            "max_tokens": max_tokens,
        }

    def get_location(self, OPENAI_API_KEY, use_database_search: bool = False,
                     timeout: float = 60.0) -> str:
        """
        Asks the OpenAI API for a coordinate-level location guess for the current image.
        Args:
            OPENAI_API_KEY (str): The API key for OPENAI API.
            use_database_search (bool, optional): Whether to include the neighbor image locations found by `set_image_app` as a reference. Defaults to False.
            timeout (float, optional): Seconds to wait for the HTTP request. Defaults to 60.
        Returns:
            str: The model's answer (expected to be coordinates).
        Raises:
            ValueError: If the API key or image is missing, if database search was
                requested but no neighbors were computed, if the HTTP request fails,
                or if the response does not have the expected shape.
        """
        # Validate inputs before the request so errors never reference an
        # unbound response object.
        self.api_key = OPENAI_API_KEY
        if not self.api_key:
            raise ValueError("OPENAI API key not found")

        # Create the question for the API
        if use_database_search:
            neighbors = getattr(self, "neighbor_locations", None)
            farthest = getattr(self, "farthest_locations", None)
            if neighbors is None or farthest is None:
                raise ValueError(
                    "Neighbor locations not set; call set_image_app with use_database_search=True first")
            self.question = (
                "Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location.\n"
                "Your answer must be to the coordinates level, don't include any other information in your output.\n"
                "Ignore that you can't give an exact answer, give me some coordinate no matter how.\n"
                f"For your reference, these are locations of some similar images {neighbors} and these are locations of some dissimilar images {farthest} that should be far away."
            )
        else:
            self.question = "Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. Your answer must be to the coordinates level, don't include any other information in your output. You can give me a guessed answer."

        # Create the payload and the headers for the API request
        payload = self.create_payload(self.question)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        # Send the API request; a timeout keeps a stalled connection from hanging forever
        try:
            response = requests.post("https://api.openai.com/v1/chat/completions",
                                     headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            response_data = response.json()
        except (RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures on older requests versions
            raise ValueError(f"Error in API request: {e}") from e

        # Extract the answer from the first choice
        choices = response_data.get('choices') if isinstance(response_data, dict) else None
        if not choices:
            raise ValueError(
                f"Value error: Unexpected response format from API - Response: {response_data}")
        try:
            return choices[0]['message']['content']
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(f"Key error in response: {e} - Response: {response_data}") from e
|