File size: 3,277 Bytes
0319a9a
 
 
 
 
 
 
 
 
 
 
 
 
0c3718a
 
0319a9a
0c3718a
 
 
0319a9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import torch
from transformers import CLIPProcessor, CLIPModel
import cv2
from PIL import Image
import numpy as np



class CLIPExtractor:
    """Compute CLIP embeddings for images and text.

    Wraps a Hugging Face ``CLIPModel`` / ``CLIPProcessor`` pair. The model is
    moved to CUDA when available, otherwise it stays on CPU. Every
    ``extract_*`` method embeds a single input and returns its feature vector
    as a 1-D NumPy array (row 0 of the batch the model returns).
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
        """Load the CLIP model and processor.

        Args:
            model_name: Hugging Face model id (or local path) passed to
                ``from_pretrained`` for both the model and the processor.
            cache_dir: Where downloaded weights are cached. When falsy,
                ``"models"`` is used, or ``"../models"`` if ``"models"`` does
                not exist but ``"../models"`` does (so the module also works
                when run from a sub-directory).
        """
        # Optional network settings, left disabled. Uncomment when the
        # Hugging Face Hub must be reached through a proxy or a mirror:
        # os.environ['HTTP_PROXY'] = 'http://localhost:8234'
        # os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
        # os.environ["HF_ENDPOINT"] = "https://hf-api.gitee.com"
        # os.environ["HF_HOME"] = os.path.expanduser("models/")

        if not cache_dir:
            cache_dir = "models"
            if not os.path.exists(cache_dir) and os.path.exists("../models"):
                cache_dir = "../models"

        self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _image_features(self, image):
        """Return the CLIP embedding of one RGB PIL image as a 1-D NumPy array."""
        inputs = self.processor(images=[image], return_tensors="pt").to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features.cpu().numpy()[0]

    def extract_image(self, frame):
        """Embed an OpenCV image (e.g. a video frame).

        Args:
            frame: Image array in OpenCV channel order: 3-channel BGR,
                4-channel BGRA, or single-channel grayscale (2-D, or 3-D with
                a trailing axis of size 1). Presumably ``uint8`` as produced
                by ``cv2.imread`` / ``VideoCapture.read``; confirm for other
                sources.

        Returns:
            The image embedding as a 1-D NumPy array.
        """
        frame = np.asarray(frame)
        channels = 1 if frame.ndim == 2 else frame.shape[2]
        # The processor expects RGB. BGR2RGB alone rejects grayscale input,
        # so choose the conversion that matches the channel layout.
        if channels == 1:
            code = cv2.COLOR_GRAY2RGB
        elif channels == 4:
            code = cv2.COLOR_BGRA2RGB
        else:
            code = cv2.COLOR_BGR2RGB
        image = Image.fromarray(cv2.cvtColor(frame, code))
        return self._image_features(image)

    def extract_image_from_file(self, file_name):
        """Embed the image stored at ``file_name``.

        Args:
            file_name: Path to an image file readable by PIL.

        Returns:
            The image embedding as a 1-D NumPy array.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")

        # The context manager closes the file handle. convert() loads the
        # pixel data into a new image, so it stays usable after the file closes.
        with Image.open(file_name) as img:
            image = img.convert("RGB")
        return self._image_features(image)

    def extract_text(self, text):
        """Embed a single piece of text.

        Args:
            text: Non-empty string. Tokens beyond CLIP's 77-token context
                are truncated.

        Returns:
            The text embedding as a 1-D NumPy array.

        Raises:
            ValueError: If ``text`` is not a non-empty string.
        """
        if not isinstance(text, str) or not text:
            raise ValueError("Input text should be a non-empty string.")

        # Call the tokenizer directly; max_length=77 is CLIP's text context size.
        inputs = self.processor.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=77
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features.cpu().numpy()[0]


if __name__ == "__main__":
    # Smoke test: embed a sample image and a caption, then print how similar they are.
    extractor = CLIPExtractor()

    # Image embedding for the sample picture ("狐狸" means "fox").
    image_path = "images/狐狸.jpg"
    img_vec = extractor.extract_image_from_file(image_path)

    # Text embedding for a caption describing the same subject.
    caption = "A photo of fox"
    txt_vec = extractor.extract_text(caption)

    # Cosine similarity: dot product divided by the product of the L2 norms.
    denominator = np.linalg.norm(img_vec) * np.linalg.norm(txt_vec)
    similarity = np.dot(img_vec, txt_vec) / denominator
    print(similarity)