Spaces:
Sleeping
Sleeping
File size: 9,068 Bytes
20ea451 1a36398 20ea451 1a36398 20ea451 1a36398 20ea451 6b1bbaf 20ea451 6b1bbaf 20ea451 6b1bbaf 81bf6cd 20ea451 6b1bbaf 20ea451 9bd955c 20ea451 1a36398 20ea451 1a36398 20ea451 1a36398 20ea451 1a36398 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import numpy as np
from sentence_transformers import SentenceTransformer, util
from open_clip import create_model_from_pretrained, get_tokenizer
import torch
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn as nn
import boto3
import streamlit as st
from PIL import Image
from PIL import ImageDraw
from io import BytesIO
import pandas as pd
from typing import List, Union
import concurrent.futures
# Load the SigLIP model 'timm/ViT-SO400M-14-SigLIP-384' (via open_clip), its
# image preprocessing transform and its text tokenizer once at import time so
# every query reuses them instead of reloading the weights on each call.
_SIGLIP_HUB_ID = 'hf-hub:timm/ViT-SO400M-14-SigLIP-384'
model, preprocess = create_model_from_pretrained(_SIGLIP_HUB_ID)
# The open_clip README notes models are created in train mode; switch to
# inference mode so dropout / stochastic-depth style layers (if any) are
# disabled and embeddings are deterministic.
model.eval()
tokenizer = get_tokenizer(_SIGLIP_HUB_ID)
def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Embed a text or image query with the module-level OpenCLIP (SigLIP) model.

    Parameters
    ----------
    query : Union[str, Image.Image]
        A PIL image (run through ``preprocess``) or a text string (run through
        ``tokenizer`` using the model's ``context_length``).

    Returns
    -------
    torch.Tensor
        The model's embedding output. Image queries get an explicit batch
        dimension of 1 via ``unsqueeze(0)``; text queries presumably also yield
        a batch of one from the tokenizer (callers read row 0).

    Raises
    ------
    ValueError
        If ``query`` is neither a ``str`` nor a PIL ``Image``.
    """
    # Pick the model input and the matching encoder, then encode once below.
    if isinstance(query, Image.Image):
        model_input = preprocess(query).unsqueeze(0)
        encoder = model.encode_image
    elif isinstance(query, str):
        model_input = tokenizer(query, context_length=model.context_length)
        encoder = model.encode_text
    else:
        raise ValueError("Query must be either a string or an Image.")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return encoder(model_input)
def load_dataset_with_limit(dataset_name, dataset_subset, search_in_small_objects, limit=1000):
    """
    Load one split of a ``quasara-io`` Hugging Face dataset as a DataFrame,
    optionally randomly sub-sampled.

    Parameters
    ----------
    dataset_name : str
        Dataset name under the ``quasara-io`` organisation (the prefix is added here).
    dataset_subset : str
        Suffix used to build the split name.
    search_in_small_objects : bool
        If true, load split ``Splits_<dataset_subset>``; otherwise ``Main_<dataset_subset>``.
    limit : int or None, default 1000
        Maximum number of rows to sample (``random_state=42`` for reproducibility).
        ``None`` keeps every row. If the split has fewer rows than ``limit``,
        all rows are returned (shuffled) instead of raising.

    Returns
    -------
    tuple[pandas.DataFrame, int]
        The (possibly sampled) rows and the split's total row count before sampling.
    """
    prefix = 'Splits' if search_in_small_objects else 'Main'
    split = f'{prefix}_{dataset_subset}'
    dataset = load_dataset(f"quasara-io/{dataset_name}", split=split)
    total_rows = dataset.num_rows

    df = dataset.to_pandas()
    if limit is not None:
        # DataFrame.sample raises ValueError when n exceeds the number of rows
        # (without replace=True), so cap n at what the split actually has.
        df = df.sample(n=min(limit, len(df)), random_state=42)
    return df, total_rows
def get_image_vectors(df):
    """
    Stack the per-row embeddings stored in ``df['Vector']`` into one float32 tensor.

    Rows are combined with ``np.vstack``, so equal-length 1-D vectors produce a
    2-D tensor of shape (number of rows, vector length).
    """
    per_row_vectors = list(df['Vector'])
    stacked = np.vstack(per_row_vectors)
    return torch.as_tensor(stacked, dtype=torch.float32)
def search(query, df, limit, search_in_images=True):
    """
    Rank the rows of ``df`` by cosine similarity to ``query``.

    Parameters
    ----------
    query : str or PIL.Image.Image
        Query embedded with ``encode_query``.
    df : pandas.DataFrame
        Frame whose 'Vector' column holds the stored embeddings (read via
        ``get_image_vectors``).
    limit : int
        Number of indices to return.
    search_in_images : bool, default True
        Only ``True`` is implemented.

    Returns
    -------
    numpy.ndarray
        Positional indices (for ``df.iloc``) of the ``limit`` most similar rows,
        most similar first.

    Raises
    ------
    NotImplementedError
        If ``search_in_images`` is false. Previously this path fell through and
        failed with an UnboundLocalError on the return statement.
    """
    if not search_in_images:
        raise NotImplementedError("Only search_in_images=True is currently supported.")

    # Row 0 of the encoder output is the single query embedding.
    query_vector = encode_query(query)[0, :].detach().numpy()
    image_vectors = get_image_vectors(df).detach().numpy()

    cosine_similarities = cosine_similarity([query_vector], image_vectors)
    # Negate so argsort yields descending similarity.
    return np.argsort(-cosine_similarities[0])[:limit]
def batch_search(query, df, batch_size=100000, limit=10):
    """
    Return the indices of the ``limit`` rows of ``df`` most similar to ``query``,
    scoring the stored vectors in chunks of ``batch_size`` rows.

    Each chunk contributes its own top-``limit`` candidates. Their indices are
    shifted by the chunk's start so they point into the whole ``df``. The
    candidates are then merged, so the result is the global top ``limit``,
    ordered from most to least similar. With a single chunk the result equals
    a plain top-``limit`` search.

    Parameters
    ----------
    query : str or PIL.Image.Image
        Query embedded with ``encode_query``.
    df : pandas.DataFrame
        Frame whose 'Vector' column holds the stored embeddings.
    batch_size : int, default 100000
        Number of rows scored per chunk.
    limit : int, default 10
        Number of indices to return.

    Returns
    -------
    list
        Positional row indices into ``df`` (usable with ``df.iloc``).
    """
    vectors = get_image_vectors(df).numpy()
    # Row 0 of the encoder output is the single query embedding.
    query_vector = encode_query(query)[0].detach().numpy()

    candidate_indices = []
    candidate_scores = []
    for start in range(0, len(vectors), batch_size):
        batch_similarities = cosine_similarity(
            [query_vector], vectors[start:start + batch_size]
        )[0]
        local_top = np.argsort(-batch_similarities)[:limit]
        # argsort positions are relative to this chunk; shift them to df positions.
        candidate_indices.append(local_top + start)
        candidate_scores.append(batch_similarities[local_top])

    if not candidate_indices:
        return []

    all_indices = np.concatenate(candidate_indices)
    all_scores = np.concatenate(candidate_scores)
    # Stable sort keeps each chunk's internal ranking when scores tie.
    best = np.argsort(-all_scores, kind='stable')[:limit]
    return list(all_indices[best])
def get_file_paths(df, top_k_indices, column_name = 'File_Path'):
    """
    Look up one column's values for the given positional row indices.

    Parameters:
    - df: pandas DataFrame to read from
    - top_k_indices: sequence/array of positional row indices (as used by ``iloc``)
    - column_name: str, column to read (defaults to 'File_Path')

    Returns:
    - list of the column's values, in the same order as ``top_k_indices``
    """
    selected = df[column_name].iloc[top_k_indices]
    return selected.tolist()
def get_cordinates(df, top_k_indices, column_name = 'Coordinate'):
    """
    Retrieve the stored coordinates for the given positional row indices.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: sequence/array of positional row indices (as used by ``iloc``)
    - column_name: str, the column holding the coordinates (defaults to 'Coordinate')

    Returns:
    - list of the values of ``column_name`` for those rows, in index order
      (presumably bounding-box coordinates; the format is whatever the column stores)
    """
    coordinates = df.iloc[top_k_indices][column_name].tolist()
    return coordinates
def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Fetch each image from S3 and render it in the Streamlit app, one at a time.

    Parameters:
    - bucket_name: str, the S3 bucket holding the images
    - file_paths: list of object paths relative to ``folder_name``
    - AWS_ACCESS_KEY_ID: str, access key ID used to build the S3 client
    - AWS_SECRET_ACCESS_KEY: str, secret access key used to build the S3 client
    - folder_name: str, prefix prepended verbatim to each path (no separator is
      inserted, so include a trailing '/' when the prefix is a folder)

    Returns:
    - None; each image is shown with ``st.image`` using its path as the caption
    """
    client = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
    for path in file_paths:
        object_key = f"{folder_name}{path}"
        response = client.get_object(Bucket=bucket_name, Key=object_key)
        raw_bytes = response['Body'].read()
        picture = Image.open(BytesIO(raw_bytes))
        st.image(picture, caption=path, use_column_width=True)
def _normalize_bounding_boxes(box_coords, file_path):
    """
    Return ``box_coords`` as a list of boxes, each an iterable of four numbers.

    Accepted inputs:
    - a single box ``[x_min, y_min, x_max, y_max]``
    - a collection of boxes (list/tuple/array of boxes, or a 2-D array)
    - an empty collection, meaning "no boxes"

    Raises ValueError when ``box_coords`` is not a list, tuple or numpy array.
    """
    if not isinstance(box_coords, (np.ndarray, list, tuple)):
        raise ValueError(f"Bounding box data for {file_path} is not in an iterable format.")
    if len(box_coords) == 0:
        return []
    # A nested sequence as the first element means several boxes.
    if isinstance(box_coords[0], (np.ndarray, list, tuple)):
        return list(box_coords)
    return [box_coords]


def get_images_with_bounding_boxes_from_s3(bucket_name, file_paths, bounding_boxes, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Retrieve images from AWS S3, draw their bounding boxes in red, and display them in Streamlit.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list of file paths (relative to ``folder_name``) to retrieve from S3
    - bounding_boxes: one entry per file path. Each entry is either a single box
      [x_min, y_min, x_max, y_max], a collection of such boxes (list, tuple or
      numpy array), or an empty collection (no boxes drawn).
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
    - folder_name: str, prefix prepended verbatim to each file path to form the S3 key

    Returns:
    - None (directly displays images in the Streamlit app with bounding boxes)

    Raises:
    - ValueError if a bounding-box entry is not a list, tuple or numpy array
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )
    # zip pairs each path with its boxes; extra items in the longer input are ignored.
    for file_path, box_coords in zip(file_paths, bounding_boxes):
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
        img_data = s3_object['Body'].read()
        img = Image.open(BytesIO(img_data))
        # ImageDraw resolves "red" in the image's own mode. On grayscale images
        # it becomes a gray level, and on palette ('P') images it must allocate
        # a palette entry. Drawing in RGB makes the outline actually red.
        if img.mode not in ("RGB", "RGBA"):
            img = img.convert("RGB")
        draw = ImageDraw.Draw(img)
        for box in _normalize_bounding_boxes(box_coords, file_path):
            x_min, y_min, x_max, y_max = map(int, box)  # PIL expects integer pixel coords
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
        st.image(img, caption=file_path, use_column_width=True)
|