hhhwmws's picture
Update src/CLIPExtractor.py
0c3718a verified
raw
history blame
No virus
3.28 kB
import os
import torch
from transformers import CLIPProcessor, CLIPModel
import cv2
from PIL import Image
import numpy as np
class CLIPExtractor:
    """Extract CLIP feature vectors for images and text.

    Loads a ``CLIPModel`` and its ``CLIPProcessor`` once, moves the model to
    CUDA when it is available (CPU otherwise), and exposes helpers that return
    the feature vector of a single image or a single string as a NumPy array.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
        """Load the model and processor.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the processor.
            cache_dir: Directory passed to ``from_pretrained`` as the download
                cache. When empty or ``None`` it defaults to ``"models"``, or
                to ``"../models"`` if ``"models"`` does not exist but
                ``"../models"`` does.
        """
        # Optional proxy configuration (disabled):
        # os.environ['HTTP_PROXY'] = 'http://localhost:8234'
        # os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
        # Optional mirror endpoint / cache home configuration (disabled):
        # os.environ["HF_ENDPOINT"] = "https://hf-api.gitee.com"
        # os.environ["HF_HOME"] = os.path.expanduser("models/")
        if not cache_dir:
            # Default cache directory; fall back to the parent directory's
            # "models" folder when running from a sub-directory.
            cache_dir = "models"
            if not os.path.exists(cache_dir) and os.path.exists("../models"):
                cache_dir = "../models"

        # Model and processor share the same checkpoint name and cache.
        self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _embed_images(self, images):
        """Run a list of PIL images through the image tower; return row 0."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)
        return outputs.cpu().numpy()[0]

    def extract_image(self, frame):
        """Return the image feature vector for an OpenCV frame.

        Args:
            frame: Image array in OpenCV's BGR channel order; it is converted
                to RGB before being handed to the processor.

        Returns:
            The first row of the model's image features as a NumPy array.
        """
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return self._embed_images([image])

    def extract_image_from_file(self, file_name):
        """Return the image feature vector for an image file on disk.

        Args:
            file_name: Path of the image file to read.

        Returns:
            The first row of the model's image features as a NumPy array.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")
        # convert() decodes the pixels into a new image, so the underlying
        # file handle can be released as soon as the with-block exits.
        with Image.open(file_name) as source:
            image = source.convert("RGB")
        return self._embed_images([image])

    def extract_text(self, text):
        """Return the text feature vector for a single string.

        Args:
            text: Non-empty string to embed. It is tokenized with truncation
                to at most 77 tokens.

        Returns:
            The first row of the model's text features as a NumPy array.

        Raises:
            ValueError: If ``text`` is not a string or is empty.
        """
        if not isinstance(text, str) or not text:
            raise ValueError("Input text should be a non-empty string.")
        # Call the tokenizer directly so truncation/max_length can be set.
        inputs = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.get_text_features(**inputs)
        return outputs.cpu().numpy()[0]
if __name__ == "__main__":
    # Smoke test: embed one image and one caption, then compare the two.
    extractor = CLIPExtractor()

    # Extract the image feature.
    image_path = "images/狐狸.jpg"
    image_vec = extractor.extract_image_from_file(image_path)

    # Extract the text feature.
    caption = "A photo of fox"
    text_vec = extractor.extract_text(caption)

    # Cosine similarity: dot product over the product of the L2 norms.
    dot_product = np.dot(image_vec, text_vec)
    norm_product = np.linalg.norm(image_vec) * np.linalg.norm(text_vec)
    print(dot_product / norm_product)