import os
import torch
from transformers import CLIPProcessor, CLIPModel
import cv2
from PIL import Image
import numpy as np


class CLIPExtractor:
    def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
        # Route Hugging Face traffic through a local proxy
        os.environ['HTTP_PROXY'] = 'http://localhost:8234'
        os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
        # Point Hugging Face at a mirror endpoint and a local model cache
        os.environ["HF_ENDPOINT"] = "https://hf-api.gitee.com"
        os.environ["HF_HOME"] = os.path.expanduser("models/")
        if not cache_dir:
            # Default cache directory, falling back to the parent directory
            cache_dir = "models"
            if not os.path.exists(cache_dir) and os.path.exists("../models"):
                cache_dir = "../models"
        # Initialize the model and processor with the specified values
        self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def extract_image(self, frame):
        # Convert an OpenCV frame (BGR) to a PIL Image (RGB)
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        images = [image]
        # Process the image and extract features
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)
        ans = outputs.cpu().numpy()
        return ans[0]
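
    # Hypothetical usage of extract_image with an OpenCV video frame
    # ("sample.mp4" is an assumed local file, not part of this repo):
    #     cap = cv2.VideoCapture("sample.mp4")
    #     ok, frame = cap.read()
    #     if ok:
    #         feat = CLIPExtractor().extract_image(frame)
    #     cap.release()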

    def extract_image_from_file(self, file_name):
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")
        images = [Image.open(file_name).convert("RGB")]
        # Process the image and extract features
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)
        ans = outputs.cpu().numpy()
        return ans[0]

    def extract_text(self, text):
        if not isinstance(text, str) or not text:
            raise ValueError("Input text should be a non-empty string.")
        # Tokenize the text (CLIP's context window is 77 tokens)
        inputs = self.processor.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=77
        ).to(self.device)
        # Alternative: let the processor handle tokenization itself
        # inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            outputs = self.model.get_text_features(**inputs)
        ans = outputs.cpu().numpy()
        return ans[0]


if __name__ == "__main__":
    clip_extractor = CLIPExtractor()
    sample_image = "images/狐狸.jpg"
    # Extract image features
    image_feature = clip_extractor.extract_image_from_file(sample_image)
    # Extract text features
    sample_text = "A photo of fox"
    text_feature = clip_extractor.extract_text(sample_text)
    # Cosine similarity between the image and text features
    cosine_similarity = np.dot(image_feature, text_feature) / (
        np.linalg.norm(image_feature) * np.linalg.norm(text_feature)
    )
    print(cosine_similarity)
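
    # Optional sanity check (a minimal sketch, not part of the original script):
    # CLIP's joint forward pass computes this same cosine similarity internally,
    # scaled by the learned temperature logit_scale, so dividing the scale back
    # out should reproduce the value printed above up to floating-point noise.
    image = Image.open(sample_image).convert("RGB")
    inputs = clip_extractor.processor(
        text=[sample_text], images=image, return_tensors="pt", padding=True
    ).to(clip_extractor.device)
    with torch.no_grad():
        outputs = clip_extractor.model(**inputs)
    scale = clip_extractor.model.logit_scale.exp()
    print((outputs.logits_per_image / scale).item())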