vinid committed on
Commit
e66a8fa
·
1 Parent(s): 34d1458
Files changed (3) hide show
  1. helper.py +0 -65
  2. image2image.py +41 -2
  3. text2image.py +41 -1
helper.py DELETED
@@ -1,65 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- from plip_support import embed_text
4
- import numpy as np
5
- from PIL import Image
6
- import requests
7
- import tokenizers
8
- import os
9
- from io import BytesIO
10
- import pickle
11
- import base64
12
-
13
- import torch
14
- from transformers import (
15
- VisionTextDualEncoderModel,
16
- AutoFeatureExtractor,
17
- AutoTokenizer,
18
- CLIPModel,
19
- AutoProcessor
20
- )
21
- import streamlit.components.v1 as components
22
- from st_clickable_images import clickable_images #pip install st-clickable-images
23
-
24
-
25
- @st.cache(
26
- hash_funcs={
27
- torch.nn.parameter.Parameter: lambda _: None,
28
- tokenizers.Tokenizer: lambda _: None,
29
- tokenizers.AddedToken: lambda _: None
30
- }
31
- )
32
- def load_path_clip():
33
- model = CLIPModel.from_pretrained("vinid/plip")
34
- processor = AutoProcessor.from_pretrained("vinid/plip")
35
- return model, processor
36
-
37
- @st.cache
38
- def init():
39
- with open('data/twitter.asset', 'rb') as f:
40
- data = pickle.load(f)
41
- meta = data['meta'].reset_index(drop=True)
42
- image_embedding = data['image_embedding']
43
- text_embedding = data['text_embedding']
44
- print(meta.shape, image_embedding.shape)
45
- validation_subset_index = meta['source'].values == 'Val_Tweets'
46
- return meta, image_embedding, text_embedding, validation_subset_index
47
-
48
- def embed_images(model, images, processor):
49
- inputs = processor(images=images)
50
- pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
51
-
52
- with torch.no_grad():
53
- embeddings = model.get_image_features(pixel_values=pixel_values)
54
- return embeddings
55
-
56
- def embed_texts(model, texts, processor):
57
- inputs = processor(text=texts, padding="longest")
58
- input_ids = torch.tensor(inputs["input_ids"])
59
- attention_mask = torch.tensor(inputs["attention_mask"])
60
-
61
- with torch.no_grad():
62
- embeddings = model.get_text_features(
63
- input_ids=input_ids, attention_mask=attention_mask
64
- )
65
- return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
image2image.py CHANGED
@@ -20,9 +20,48 @@ from transformers import (
20
  import streamlit.components.v1 as components
21
  from st_clickable_images import clickable_images #pip install st-clickable-images
22
 
23
- from helper import load_path_clip, init, embed_images
24
-
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def app():
27
  st.title('Image to Image Retrieval')
28
  st.markdown('#### A pathology image search engine that correlate images with images.')
 
20
  import streamlit.components.v1 as components
21
  from st_clickable_images import clickable_images #pip install st-clickable-images
22
 
 
 
23
 
24
+ @st.cache(
25
+ hash_funcs={
26
+ torch.nn.parameter.Parameter: lambda _: None,
27
+ tokenizers.Tokenizer: lambda _: None,
28
+ tokenizers.AddedToken: lambda _: None
29
+ }
30
+ )
31
+ def load_path_clip():
32
+ model = CLIPModel.from_pretrained("vinid/plip")
33
+ processor = AutoProcessor.from_pretrained("vinid/plip")
34
+ return model, processor
35
+
36
+ @st.cache
37
+ def init():
38
+ with open('data/twitter.asset', 'rb') as f:
39
+ data = pickle.load(f)
40
+ meta = data['meta'].reset_index(drop=True)
41
+ image_embedding = data['image_embedding']
42
+ text_embedding = data['text_embedding']
43
+ print(meta.shape, image_embedding.shape)
44
+ validation_subset_index = meta['source'].values == 'Val_Tweets'
45
+ return meta, image_embedding, text_embedding, validation_subset_index
46
+
47
+ def embed_images(model, images, processor):
48
+ inputs = processor(images=images)
49
+ pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
50
+
51
+ with torch.no_grad():
52
+ embeddings = model.get_image_features(pixel_values=pixel_values)
53
+ return embeddings
54
+
55
+ def embed_texts(model, texts, processor):
56
+ inputs = processor(text=texts, padding="longest")
57
+ input_ids = torch.tensor(inputs["input_ids"])
58
+ attention_mask = torch.tensor(inputs["attention_mask"])
59
+
60
+ with torch.no_grad():
61
+ embeddings = model.get_text_features(
62
+ input_ids=input_ids, attention_mask=attention_mask
63
+ )
64
+ return embeddings
65
  def app():
66
  st.title('Image to Image Retrieval')
67
  st.markdown('#### A pathology image search engine that correlate images with images.')
text2image.py CHANGED
@@ -16,8 +16,48 @@ from transformers import (
16
  )
17
  import streamlit.components.v1 as components
18
 
19
- from helper import load_path_clip, init, embed_texts
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def app():
 
16
  )
17
  import streamlit.components.v1 as components
18
 
 
19
 
20
+ @st.cache(
21
+ hash_funcs={
22
+ torch.nn.parameter.Parameter: lambda _: None,
23
+ tokenizers.Tokenizer: lambda _: None,
24
+ tokenizers.AddedToken: lambda _: None
25
+ }
26
+ )
27
+ def load_path_clip():
28
+ model = CLIPModel.from_pretrained("vinid/plip")
29
+ processor = AutoProcessor.from_pretrained("vinid/plip")
30
+ return model, processor
31
+
32
+ @st.cache
33
+ def init():
34
+ with open('data/twitter.asset', 'rb') as f:
35
+ data = pickle.load(f)
36
+ meta = data['meta'].reset_index(drop=True)
37
+ image_embedding = data['image_embedding']
38
+ text_embedding = data['text_embedding']
39
+ print(meta.shape, image_embedding.shape)
40
+ validation_subset_index = meta['source'].values == 'Val_Tweets'
41
+ return meta, image_embedding, text_embedding, validation_subset_index
42
+
43
+ def embed_images(model, images, processor):
44
+ inputs = processor(images=images)
45
+ pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
46
+
47
+ with torch.no_grad():
48
+ embeddings = model.get_image_features(pixel_values=pixel_values)
49
+ return embeddings
50
+
51
+ def embed_texts(model, texts, processor):
52
+ inputs = processor(text=texts, padding="longest")
53
+ input_ids = torch.tensor(inputs["input_ids"])
54
+ attention_mask = torch.tensor(inputs["attention_mask"])
55
+
56
+ with torch.no_grad():
57
+ embeddings = model.get_text_features(
58
+ input_ids=input_ids, attention_mask=attention_mask
59
+ )
60
+ return embeddings
61
 
62
 
63
  def app():