vinid commited on
Commit
61448a4
·
1 Parent(s): d1b8523

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +9 -51
  2. home.py +11 -0
  3. introduction.md +2 -0
  4. text2image.py +74 -0
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from plip_support import embed_text
4
  import numpy as np
5
  from PIL import Image
6
  import requests
7
  import transformers
 
8
  import tokenizers
9
  from io import BytesIO
10
  import streamlit as st
@@ -18,55 +19,12 @@ from transformers import (
18
  )
19
  from transformers import AutoProcessor
20
 
 
21
 
22
- def embed_texts(model, texts, processor):
23
- inputs = processor(text=texts, padding="longest")
24
- input_ids = torch.tensor(inputs["input_ids"])
25
- attention_mask = torch.tensor(inputs["attention_mask"])
26
 
27
- with torch.no_grad():
28
- embeddings = model.get_text_features(
29
- input_ids=input_ids, attention_mask=attention_mask
30
- )
31
- return embeddings
32
-
33
- @st.cache
34
- def load_embeddings(embeddings_path):
35
- print("loading embeddings")
36
- return np.load(embeddings_path)
37
-
38
- @st.cache(
39
- hash_funcs={
40
- torch.nn.parameter.Parameter: lambda _: None,
41
- tokenizers.Tokenizer: lambda _: None,
42
- tokenizers.AddedToken: lambda _: None
43
- }
44
- )
45
- def load_path_clip():
46
- model = CLIPModel.from_pretrained("vinid/plip")
47
- processor = AutoProcessor.from_pretrained("vinid/plip")
48
- return model, processor
49
-
50
- st.title('PLIP Image Search')
51
-
52
- plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
53
-
54
- model, processor = load_path_clip()
55
-
56
- image_embedding = load_embeddings("tweet_eval_embeddings.npy")
57
-
58
- query = st.text_input('Search Query', '')
59
-
60
-
61
- if query:
62
-
63
- text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
64
-
65
- text_embedding = text_embedding/np.linalg.norm(text_embedding)
66
-
67
- best_id = np.argmax(text_embedding.dot(image_embedding.T))
68
- url = (plip_dataset.iloc[best_id]["imageURL"])
69
-
70
- response = requests.get(url)
71
- img = Image.open(BytesIO(response.content))
72
- st.image(img)
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import home
4
  import numpy as np
5
  from PIL import Image
6
  import requests
7
  import transformers
8
+ import text2image
9
  import tokenizers
10
  from io import BytesIO
11
  import streamlit as st
 
19
  )
20
  from transformers import AutoProcessor
21
 
22
+ st.sidebar.title("Explore our PLIP Demo")
23
 
24
+ PAGES = {
25
+ "Introduction": home,
26
+ "Text to Image": text2image,
27
+ }
28
 
29
+ page = st.sidebar.radio("", list(PAGES.keys()))
30
+ PAGES[page].app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
home.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import streamlit as st
3
+
4
+
5
+ def read_markdown_file(markdown_file):
6
+ return Path(markdown_file).read_text()
7
+
8
+
9
+ def app():
10
+ intro_markdown = read_markdown_file("introduction.md")
11
+ st.markdown(intro_markdown, unsafe_allow_html=True)
introduction.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ # Welcome to our PLIP Demo
text2image.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from plip_support import embed_text
4
+ import numpy as np
5
+ from PIL import Image
6
+ import requests
7
+ import transformers
8
+ import tokenizers
9
+ from io import BytesIO
10
+ import streamlit as st
11
+ from transformers import CLIPModel
12
+ import clip
13
+ import torch
14
+ from transformers import (
15
+ VisionTextDualEncoderModel,
16
+ AutoFeatureExtractor,
17
+ AutoTokenizer
18
+ )
19
+ from transformers import AutoProcessor
20
+
21
+
22
+ def embed_texts(model, texts, processor):
23
+ inputs = processor(text=texts, padding="longest")
24
+ input_ids = torch.tensor(inputs["input_ids"])
25
+ attention_mask = torch.tensor(inputs["attention_mask"])
26
+
27
+ with torch.no_grad():
28
+ embeddings = model.get_text_features(
29
+ input_ids=input_ids, attention_mask=attention_mask
30
+ )
31
+ return embeddings
32
+
33
+ @st.cache
34
+ def load_embeddings(embeddings_path):
35
+ print("loading embeddings")
36
+ return np.load(embeddings_path)
37
+
38
+ @st.cache(
39
+ hash_funcs={
40
+ torch.nn.parameter.Parameter: lambda _: None,
41
+ tokenizers.Tokenizer: lambda _: None,
42
+ tokenizers.AddedToken: lambda _: None
43
+ }
44
+ )
45
+ def load_path_clip():
46
+ model = CLIPModel.from_pretrained("vinid/plip")
47
+ processor = AutoProcessor.from_pretrained("vinid/plip")
48
+ return model, processor
49
+
50
+
51
+ def app():
52
+ st.title('PLIP Image Search')
53
+
54
+ plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
55
+
56
+ model, processor = load_path_clip()
57
+
58
+ image_embedding = load_embeddings("tweet_eval_embeddings.npy")
59
+
60
+ query = st.text_input('Search Query', '')
61
+
62
+
63
+ if query:
64
+
65
+ text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
66
+
67
+ text_embedding = text_embedding/np.linalg.norm(text_embedding)
68
+
69
+ best_id = np.argmax(text_embedding.dot(image_embedding.T))
70
+ url = (plip_dataset.iloc[best_id]["imageURL"])
71
+
72
+ response = requests.get(url)
73
+ img = Image.open(BytesIO(response.content))
74
+ st.image(img)