huangzhii commited on
Commit
e571e8f
·
1 Parent(s): 79c7253

Add text embedding, allowing input to compare with both text and image

Browse files
Files changed (6) hide show
  1. app.py +1 -1
  2. data/twitter.asset +2 -2
  3. helper.py +65 -0
  4. image2image.py +88 -82
  5. plip_support.py +0 -9
  6. text2image.py +86 -80
app.py CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
5
 
6
 
7
 
8
- #st.set_page_config(layout="wide")
9
 
10
  st.sidebar.title("Multi-task Vision–Language AI for Pathology")
11
 
 
5
 
6
 
7
 
8
+ st.set_page_config(layout="wide")
9
 
10
  st.sidebar.title("Multi-task Vision–Language AI for Pathology")
11
 
data/twitter.asset CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:73aa98497b501cef03980d0ff0be5b3a02ff88d377bf5513e4eca8dab0870153
3
- size 145886932
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8804057c2b910dd56a2cde6f02d317fed9dacc51e6e0ace5fa57effdf06f8c34
3
+ size 266592030
helper.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from plip_support import embed_text
4
+ import numpy as np
5
+ from PIL import Image
6
+ import requests
7
+ import tokenizers
8
+ import os
9
+ from io import BytesIO
10
+ import pickle
11
+ import base64
12
+
13
+ import torch
14
+ from transformers import (
15
+ VisionTextDualEncoderModel,
16
+ AutoFeatureExtractor,
17
+ AutoTokenizer,
18
+ CLIPModel,
19
+ AutoProcessor
20
+ )
21
+ import streamlit.components.v1 as components
22
+ from st_clickable_images import clickable_images #pip install st-clickable-images
23
+
24
+
25
+ @st.cache(
26
+ hash_funcs={
27
+ torch.nn.parameter.Parameter: lambda _: None,
28
+ tokenizers.Tokenizer: lambda _: None,
29
+ tokenizers.AddedToken: lambda _: None
30
+ }
31
+ )
32
+ def load_path_clip():
33
+ model = CLIPModel.from_pretrained("vinid/plip")
34
+ processor = AutoProcessor.from_pretrained("vinid/plip")
35
+ return model, processor
36
+
37
+ @st.cache
38
+ def init():
39
+ with open('data/twitter.asset', 'rb') as f:
40
+ data = pickle.load(f)
41
+ meta = data['meta'].reset_index(drop=True)
42
+ image_embedding = data['image_embedding']
43
+ text_embedding = data['text_embedding']
44
+ print(meta.shape, image_embedding.shape)
45
+ validation_subset_index = meta['source'].values == 'Val_Tweets'
46
+ return meta, image_embedding, text_embedding, validation_subset_index
47
+
48
+ def embed_images(model, images, processor):
49
+ inputs = processor(images=images)
50
+ pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
51
+
52
+ with torch.no_grad():
53
+ embeddings = model.get_image_features(pixel_values=pixel_values)
54
+ return embeddings
55
+
56
+ def embed_texts(model, texts, processor):
57
+ inputs = processor(text=texts, padding="longest")
58
+ input_ids = torch.tensor(inputs["input_ids"])
59
+ attention_mask = torch.tensor(inputs["attention_mask"])
60
+
61
+ with torch.no_grad():
62
+ embeddings = model.get_text_features(
63
+ input_ids=input_ids, attention_mask=attention_mask
64
+ )
65
+ return embeddings
image2image.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from plip_support import embed_text
4
  import numpy as np
5
  from PIL import Image
6
  import requests
@@ -21,50 +20,38 @@ from transformers import (
21
  import streamlit.components.v1 as components
22
  from st_clickable_images import clickable_images #pip install st-clickable-images
23
 
24
-
25
- def embed_images(model, images, processor):
26
- inputs = processor(images=images)
27
- pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
28
-
29
- with torch.no_grad():
30
- embeddings = model.get_image_features(pixel_values=pixel_values)
31
- return embeddings
32
-
33
- @st.cache
34
- def load_embeddings(embeddings_path):
35
- print("loading embeddings")
36
- return np.load(embeddings_path)
37
-
38
- @st.cache(
39
- hash_funcs={
40
- torch.nn.parameter.Parameter: lambda _: None,
41
- tokenizers.Tokenizer: lambda _: None,
42
- tokenizers.AddedToken: lambda _: None
43
- }
44
- )
45
- def load_path_clip():
46
- model = CLIPModel.from_pretrained("vinid/plip")
47
- processor = AutoProcessor.from_pretrained("vinid/plip")
48
- return model, processor
49
-
50
- def init():
51
- with open('data/twitter.asset', 'rb') as f:
52
- data = pickle.load(f)
53
- meta = data['meta'].reset_index(drop=True)
54
- image_embedding = data['embedding']
55
- print(meta.shape, image_embedding.shape)
56
- validation_subset_index = meta['source'].values == 'Val_Tweets'
57
- return meta, image_embedding, validation_subset_index
58
 
59
 
60
  def app():
61
  st.title('Image to Image Retrieval')
62
  st.markdown('#### A pathology image search engine that correlate images with images.')
63
 
64
- meta, image_embedding, validation_subset_index = init()
65
  model, processor = load_path_clip()
66
 
67
- st.markdown('Click following examples:')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  example_path = 'data/example_images'
69
  list_of_examples = [os.path.join(example_path, v) for v in os.listdir(example_path)]
70
  example_imgs = []
@@ -86,18 +73,9 @@ def app():
86
 
87
 
88
 
89
- data_options = ["All twitter data (2006-03-21 — 2023-01-15)",
90
- "Twitter validation data (2022-11-16 — 2023-01-15)"]
91
- st.radio(
92
- "Or choose dataset for image retrieval 👉",
93
- key="datapool",
94
- options=data_options,
95
- )
96
-
97
 
98
 
99
-
100
- col1, col2 = st.columns(2)
101
  with col1:
102
  query = st.file_uploader("Choose a file to upload")
103
 
@@ -113,49 +91,77 @@ def app():
113
  with col2:
114
  st.image(image, caption='Your upload')
115
 
116
- single_image = embed_images(model, [image], processor)[0].detach().cpu().numpy()
117
 
118
- single_image = single_image/np.linalg.norm(single_image)
119
 
120
  # Sort IDs by cosine-similarity from high to low
121
- similarity_scores = single_image.dot(image_embedding.T)
122
 
 
 
 
 
 
 
 
 
123
 
 
 
 
 
124
  topn = 5
125
- if st.session_state.datapool == data_options[0]:
126
- #Use all twitter data
127
- id_sorted = np.argsort(similarity_scores)[::-1]
128
- best_ids = id_sorted[:topn]
129
- best_scores = similarity_scores[best_ids]
130
- target_weblinks = meta["weblink"].values[best_ids]
131
- else:
132
- #Use validation twitter data
133
- similarity_scores = similarity_scores[validation_subset_index]
134
- # Sort IDs by cosine-similarity from high to low
135
- id_sorted = np.argsort(similarity_scores)[::-1]
136
- best_ids = id_sorted[:topn]
137
- best_scores = similarity_scores[best_ids]
138
- target_weblinks = meta["weblink"].values[validation_subset_index][best_ids]
139
- #TODO: Avoid duplicated ID
140
 
 
 
 
 
 
141
  topk_options = ['1st', '2nd', '3rd', '4th', '5th']
142
- st.radio(
143
- "Choose the most similar 👉",
144
- key="top_k",
145
- options=topk_options,
146
- horizontal=True
147
- )
148
- topn_txt = st.session_state.top_k
149
- topn_value = int(st.session_state.top_k[0])-1
150
- st.caption(f'The {topn_txt} relevant image (similarity = {best_scores[topn_value]:.4f})')
151
- components.html('''
152
- <blockquote class="twitter-tweet">
153
- <a href="%s"></a>
154
- </blockquote>
155
- <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
156
- </script>
157
- ''' % target_weblinks[topn_value],
158
- height=800)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
 
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  import numpy as np
4
  from PIL import Image
5
  import requests
 
20
  import streamlit.components.v1 as components
21
  from st_clickable_images import clickable_images #pip install st-clickable-images
22
 
23
+ from helper import load_path_clip, init, embed_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def app():
27
  st.title('Image to Image Retrieval')
28
  st.markdown('#### A pathology image search engine that correlate images with images.')
29
 
30
+ meta, image_embedding, text_embedding, validation_subset_index = init()
31
  model, processor = load_path_clip()
32
 
33
+
34
+ col1, col2 = st.columns(2)
35
+ with col1:
36
+ data_options = ["All twitter data (2006-03-21 — 2023-01-15)",
37
+ "Twitter validation data (2022-11-16 — 2023-01-15)"]
38
+ st.radio(
39
+ "Choose dataset for image retrieval 👉",
40
+ key="datapool",
41
+ options=data_options,
42
+ )
43
+ with col2:
44
+ retrieval_options = ["Image only",
45
+ "Text and image (beta)",
46
+ ]
47
+ st.radio(
48
+ "Similarity calcuation 👉",
49
+ key="calculation_option",
50
+ options=retrieval_options,
51
+ )
52
+
53
+
54
+ st.markdown('Try out following examples:')
55
  example_path = 'data/example_images'
56
  list_of_examples = [os.path.join(example_path, v) for v in os.listdir(example_path)]
57
  example_imgs = []
 
73
 
74
 
75
 
 
 
 
 
 
 
 
 
76
 
77
 
78
+ col1, col2, _ = st.columns(3)
 
79
  with col1:
80
  query = st.file_uploader("Choose a file to upload")
81
 
 
91
  with col2:
92
  st.image(image, caption='Your upload')
93
 
94
+ input_image = embed_images(model, [image], processor)[0].detach().cpu().numpy()
95
 
96
+ input_image = input_image/np.linalg.norm(input_image)
97
 
98
  # Sort IDs by cosine-similarity from high to low
 
99
 
100
+ if st.session_state.calculation_option == retrieval_options[0]: # Image only
101
+ similarity_scores = input_image.dot(image_embedding.T)
102
+ else: # Text and Image
103
+ similarity_scores_i = input_image.dot(image_embedding.T)
104
+ similarity_scores_t = input_image.dot(text_embedding.T)
105
+ similarity_scores_i = similarity_scores_i/np.max(similarity_scores_i)
106
+ similarity_scores_t = similarity_scores_t/np.max(similarity_scores_t)
107
+ similarity_scores = (similarity_scores_i + similarity_scores_t)/2
108
 
109
+
110
+ ############################################################
111
+ # Get top results
112
+ ############################################################
113
  topn = 5
114
+ df = pd.DataFrame(np.c_[np.arange(len(meta)), similarity_scores, meta['weblink'].values], columns = ['idx', 'score', 'twitterlink'])
115
+ if st.session_state.datapool == data_options[1]: #Use val twitter data
116
+ df = df.loc[validation_subset_index,:]
117
+ df = df.sort_values('score', ascending=False)
118
+ df = df.drop_duplicates(subset=['twitterlink'])
119
+ best_id_topk = df['idx'].values[:topn]
120
+ target_scores = df['score'].values[:topn]
121
+ target_weblinks = df['twitterlink'].values[:topn]
122
+
123
+
 
 
 
 
 
124
 
125
+ ############################################################
126
+ # Display results
127
+ ############################################################
128
+
129
+ st.markdown('#### Top 5 results:')
130
  topk_options = ['1st', '2nd', '3rd', '4th', '5th']
131
+ tab = {}
132
+ tab[0], tab[1], tab[2] = st.columns(3)
133
+ for i in [0,1,2]:
134
+ with tab[i]:
135
+ topn_value = i
136
+ topn_txt = topk_options[i]
137
+ st.caption(f'The {topn_txt} relevant image (similarity = {target_scores[topn_value]:.4f})')
138
+ components.html('''
139
+ <blockquote class="twitter-tweet">
140
+ <a href="%s"></a>
141
+ </blockquote>
142
+ <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
143
+ </script>
144
+ ''' % target_weblinks[topn_value],
145
+ height=800)
146
+
147
+ tab[3], tab[4], tab[5] = st.columns(3)
148
+ for i in [3,4]:
149
+ with tab[i]:
150
+ topn_value = i
151
+ topn_txt = topk_options[i]
152
+ st.caption(f'The {topn_txt} relevant image (similarity = {target_scores[topn_value]:.4f})')
153
+ components.html('''
154
+ <blockquote class="twitter-tweet">
155
+ <a href="%s"></a>
156
+ </blockquote>
157
+ <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
158
+ </script>
159
+ ''' % target_weblinks[topn_value],
160
+ height=800)
161
+
162
+
163
+
164
+
165
 
166
 
167
 
plip_support.py DELETED
@@ -1,9 +0,0 @@
1
- import clip
2
- import torch
3
-
4
-
5
-
6
-
7
- def embed_text(plip, text, device="cpu"):
8
- idx = clip.tokenize([text], truncate=True).to(device)
9
- return plip.encode_text(idx).detach().cpu().numpy()[0]
 
 
 
 
 
 
 
 
 
 
text2image.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from plip_support import embed_text
4
  import numpy as np
5
  from PIL import Image
6
  import requests
@@ -17,43 +16,9 @@ from transformers import (
17
  )
18
  import streamlit.components.v1 as components
19
 
 
20
 
21
- def embed_texts(model, texts, processor):
22
- inputs = processor(text=texts, padding="longest")
23
- input_ids = torch.tensor(inputs["input_ids"])
24
- attention_mask = torch.tensor(inputs["attention_mask"])
25
 
26
- with torch.no_grad():
27
- embeddings = model.get_text_features(
28
- input_ids=input_ids, attention_mask=attention_mask
29
- )
30
- return embeddings
31
-
32
- @st.cache
33
- def load_embeddings(embeddings_path):
34
- print("loading embeddings")
35
- return np.load(embeddings_path)
36
-
37
- @st.cache(
38
- hash_funcs={
39
- torch.nn.parameter.Parameter: lambda _: None,
40
- tokenizers.Tokenizer: lambda _: None,
41
- tokenizers.AddedToken: lambda _: None
42
- }
43
- )
44
- def load_path_clip():
45
- model = CLIPModel.from_pretrained("vinid/plip")
46
- processor = AutoProcessor.from_pretrained("vinid/plip")
47
- return model, processor
48
-
49
- def init():
50
- with open('data/twitter.asset', 'rb') as f:
51
- data = pickle.load(f)
52
- meta = data['meta'].reset_index(drop=True)
53
- image_embedding = data['embedding']
54
- print(meta.shape, image_embedding.shape)
55
- validation_subset_index = meta['source'].values == 'Val_Tweets'
56
- return meta, image_embedding, validation_subset_index
57
 
58
  def app():
59
 
@@ -61,16 +26,29 @@ def app():
61
  st.markdown('#### A pathology image search engine that correlate texts directly with images.')
62
  st.caption('Note: The searching query matches images only. The twitter text does not used for searching.')
63
 
64
- meta, image_embedding, validation_subset_index = init()
65
  model, processor = load_path_clip()
66
 
67
- data_options = ["All twitter data (2006-03-21 — 2023-01-15)",
68
- "Twitter validation data (2022-11-16 — 2023-01-15)"]
69
- st.radio(
70
- "Choose dataset for image retrieval 👉",
71
- key="datapool",
72
- options=data_options,
73
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  col1, col2 = st.columns(2)
@@ -106,46 +84,74 @@ def app():
106
  else:
107
  query = query_2
108
 
109
- text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
110
- text_embedding = text_embedding/np.linalg.norm(text_embedding)
 
 
111
 
112
- similarity_scores = text_embedding.dot(image_embedding.T)
 
 
 
 
 
 
 
113
 
114
- topn = 5
115
- if st.session_state.datapool == data_options[0]:
116
- #Use all twitter data
117
- id_sorted = np.argsort(similarity_scores)[::-1]
118
- best_ids = id_sorted[:topn]
119
- best_scores = similarity_scores[best_ids]
120
- target_weblinks = meta["weblink"].values[best_ids]
121
- else:
122
- #Use validation twitter data
123
- similarity_scores = similarity_scores[validation_subset_index]
124
- # Sort IDs by cosine-similarity from high to low
125
- id_sorted = np.argsort(similarity_scores)[::-1]
126
- best_ids = id_sorted[:topn]
127
- best_scores = similarity_scores[best_ids]
128
- target_weblinks = meta["weblink"].values[validation_subset_index][best_ids]
129
- #TODO: Avoid duplicated ID
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  topk_options = ['1st', '2nd', '3rd', '4th', '5th']
132
- st.radio(
133
- "Choose the most similar 👉",
134
- key="top_k",
135
- options=topk_options,
136
- horizontal=True
137
- )
138
- topn_txt = st.session_state.top_k
139
- topn_value = int(st.session_state.top_k[0])-1
140
- st.caption(f'The {topn_txt} relevant image (similarity = {best_scores[topn_value]:.4f})')
141
- components.html('''
142
- <blockquote class="twitter-tweet">
143
- <a href="%s"></a>
144
- </blockquote>
145
- <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
146
- </script>
147
- ''' % target_weblinks[topn_value],
148
- height=800)
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
 
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  import numpy as np
4
  from PIL import Image
5
  import requests
 
16
  )
17
  import streamlit.components.v1 as components
18
 
19
+ from helper import load_path_clip, init, embed_texts
20
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def app():
24
 
 
26
  st.markdown('#### A pathology image search engine that correlate texts directly with images.')
27
  st.caption('Note: The searching query matches images only. The twitter text does not used for searching.')
28
 
29
+ meta, image_embedding, text_embedding, validation_subset_index = init()
30
  model, processor = load_path_clip()
31
 
32
+
33
+ col1, col2 = st.columns(2)
34
+ with col1:
35
+ data_options = ["All twitter data (2006-03-21 2023-01-15)",
36
+ "Twitter validation data (2022-11-16 — 2023-01-15)"]
37
+ st.radio(
38
+ "Choose dataset for image retrieval 👉",
39
+ key="datapool",
40
+ options=data_options,
41
+ )
42
+ with col2:
43
+ retrieval_options = ["Image only",
44
+ "text and image (beta)",
45
+ ]
46
+ st.radio(
47
+ "Similarity calcuation Mapping input with 👉",
48
+ key="calculation_option",
49
+ options=retrieval_options,
50
+ )
51
+
52
 
53
 
54
  col1, col2 = st.columns(2)
 
84
  else:
85
  query = query_2
86
 
87
+
88
+ input_text = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
89
+ input_text = input_text/np.linalg.norm(input_text)
90
+
91
 
92
+ if st.session_state.calculation_option == retrieval_options[0]: # Image only
93
+ similarity_scores = input_text.dot(image_embedding.T)
94
+ else: # Text and Image
95
+ similarity_scores_i = input_text.dot(image_embedding.T)
96
+ similarity_scores_t = input_text.dot(text_embedding.T)
97
+ similarity_scores_i = similarity_scores_i/np.max(similarity_scores_i)
98
+ similarity_scores_t = similarity_scores_t/np.max(similarity_scores_t)
99
+ similarity_scores = (similarity_scores_i + similarity_scores_t)/2
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+
103
+
104
+ ############################################################
105
+ # Get top results
106
+ ############################################################
107
+ topn = 5
108
+ df = pd.DataFrame(np.c_[np.arange(len(meta)), similarity_scores, meta['weblink'].values], columns = ['idx', 'score', 'twitterlink'])
109
+ if st.session_state.datapool == data_options[1]: #Use val twitter data
110
+ df = df.loc[validation_subset_index,:]
111
+ df = df.sort_values('score', ascending=False)
112
+ df = df.drop_duplicates(subset=['twitterlink'])
113
+ best_id_topk = df['idx'].values[:topn]
114
+ target_scores = df['score'].values[:topn]
115
+ target_weblinks = df['twitterlink'].values[:topn]
116
+
117
+
118
+ ############################################################
119
+ # Display results
120
+ ############################################################
121
+
122
+ st.markdown('Your input query: %s' % query)
123
+ st.markdown('#### Top 5 results:')
124
  topk_options = ['1st', '2nd', '3rd', '4th', '5th']
125
+ tab = {}
126
+ tab[0], tab[1], tab[2] = st.columns(3)
127
+ for i in [0,1,2]:
128
+ with tab[i]:
129
+ topn_value = i
130
+ topn_txt = topk_options[i]
131
+ st.caption(f'The {topn_txt} relevant image (similarity = {target_scores[topn_value]:.4f})')
132
+ components.html('''
133
+ <blockquote class="twitter-tweet">
134
+ <a href="%s"></a>
135
+ </blockquote>
136
+ <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
137
+ </script>
138
+ ''' % target_weblinks[topn_value],
139
+ height=800)
140
+
141
+ tab[3], tab[4], tab[5] = st.columns(3)
142
+ for i in [3,4]:
143
+ with tab[i]:
144
+ topn_value = i
145
+ topn_txt = topk_options[i]
146
+ st.caption(f'The {topn_txt} relevant image (similarity = {target_scores[topn_value]:.4f})')
147
+ components.html('''
148
+ <blockquote class="twitter-tweet">
149
+ <a href="%s"></a>
150
+ </blockquote>
151
+ <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
152
+ </script>
153
+ ''' % target_weblinks[topn_value],
154
+ height=800)
155
 
156
 
157