File size: 1,911 Bytes
34d1458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import tokenizers
import os
from io import BytesIO
import pickle
import base64

import torch
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    CLIPModel,
    AutoProcessor
)
import streamlit.components.v1 as components
from st_clickable_images import clickable_images #pip install st-clickable-images


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    }
)
def load_path_clip():
    """Load the "vinid/plip" CLIP model together with its processor.

    The result is wrapped in ``st.cache``; the ``hash_funcs`` entries map
    torch parameters and tokenizer objects to a constant hash of ``None``.

    Returns:
        A ``(model, processor)`` tuple, both created via ``from_pretrained``
        from the same checkpoint name.
    """
    checkpoint = "vinid/plip"
    # Model first, processor second -- same order as the returned tuple.
    return (
        CLIPModel.from_pretrained(checkpoint),
        AutoProcessor.from_pretrained(checkpoint),
    )

@st.cache
def init():
    """Load the pickled asset at ``data/twitter.asset`` and unpack it.

    The pickle is expected to hold a mapping with the keys ``'meta'``,
    ``'image_embedding'`` and ``'text_embedding'``.

    NOTE(review): ``meta`` is presumably a pandas DataFrame (it is used via
    ``reset_index`` and column access) -- confirm against the asset builder.
    NOTE(review): unpickling runs arbitrary code, so the asset file must come
    from a trusted source.

    Returns:
        A 4-tuple ``(meta, image_embedding, text_embedding,
        validation_subset_index)`` where ``meta`` has had its index reset and
        ``validation_subset_index`` is the element-wise comparison of the
        ``'source'`` column values against ``'Val_Tweets'``.
    """
    with open('data/twitter.asset', 'rb') as asset_file:
        asset = pickle.load(asset_file)

    # Drop the stored index so rows are numbered from zero.
    tweet_meta = asset['meta'].reset_index(drop=True)
    image_vectors = asset['image_embedding']
    text_vectors = asset['text_embedding']

    # Shape printout kept from the original for start-up diagnostics.
    print(tweet_meta.shape, image_vectors.shape)

    # Marks the rows whose source is the validation tweet set.
    is_validation = tweet_meta['source'].values == 'Val_Tweets'
    return tweet_meta, image_vectors, text_vectors, is_validation

def embed_images(model, images, processor):
    """Embed a batch of images with the model's image encoder.

    Args:
        model: Object exposing ``get_image_features(pixel_values=...)``.
        images: Images handed to ``processor`` as ``images=``.
        processor: Callable returning a mapping that contains
            ``"pixel_values"``.

    Returns:
        The value returned by ``model.get_image_features``, computed with
        gradient tracking disabled.
    """
    processed = processor(images=images)

    # Go through a single ndarray before building the tensor.
    pixel_array = np.array(processed["pixel_values"])
    pixel_batch = torch.tensor(pixel_array)

    with torch.no_grad():
        return model.get_image_features(pixel_values=pixel_batch)

def embed_texts(model, texts, processor):
    """Embed a batch of texts with the model's text encoder.

    Args:
        model: Object exposing
            ``get_text_features(input_ids=..., attention_mask=...)``
            (presumably the CLIP model from ``load_path_clip`` -- confirm at
            the call sites).
        texts: Texts handed to ``processor`` as ``text=``.
        processor: Callable that tokenizes the texts and returns a mapping
            containing ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        The value returned by ``model.get_text_features``, computed with
        gradient tracking disabled.
    """
    # Pad to the longest text so the rows form a rectangular batch.
    # truncation=True asks the tokenizer to cut over-long texts down to its
    # configured maximum length; without it such texts reach the encoder
    # untruncated and are expected to fail there, because CLIP-style text
    # encoders have a fixed number of position embeddings.
    inputs = processor(text=texts, padding="longest", truncation=True)
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings