File size: 3,880 Bytes
c81898a
 
ba03fb2
c81898a
b59b1d0
c81898a
 
 
 
 
 
 
 
 
 
 
 
000d238
 
c81898a
 
 
 
 
 
ba03fb2
c81898a
ba03fb2
c81898a
 
 
 
 
 
000d238
c81898a
 
 
 
 
 
 
 
 
 
 
 
ff968d5
 
 
9ea8c8c
ba03fb2
9ea8c8c
c81898a
 
 
 
 
 
 
 
ff968d5
 
 
 
 
 
 
 
c81898a
7600dc3
555584f
 
 
 
7600dc3
586f7e5
55fea56
586f7e5
c81898a
 
 
 
 
 
 
 
 
d2df557
3f8dd94
ff968d5
c81898a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
import pandas as pd, numpy as np
from html import escape
import os
from transformers import CLIPProcessor, CLIPModel

@st.cache(show_spinner=False,
          hash_funcs={CLIPModel: lambda _: None,
                      CLIPProcessor: lambda _: None,
                      dict: lambda _: None})
def load():
    """Load the CLIP model, its processor, the metadata tables and embeddings.

    Returns:
        A ``(model, processor, df, embeddings)`` tuple. ``df`` and
        ``embeddings`` are dicts keyed by 0 and 1: key 0 is read from
        ``data.csv`` / ``embeddings.npy`` and key 1 from ``data2.csv`` /
        ``embeddings2.npy``. Every embedding array is divided by its
        L2 norm taken along axis 1, so each row has unit length.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)

    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}

    embeddings = {}
    for key, npy_path in ((0, 'embeddings.npy'), (1, 'embeddings2.npy')):
        raw = np.load(npy_path)
        # Row-wise L2 norm; keepdims lets it broadcast against the rows.
        row_norms = np.sqrt(np.sum(raw**2, axis=1, keepdims=True))
        embeddings[key] = raw / row_norms

    return model, processor, df, embeddings
# Load the model, processor, metadata tables and embeddings at import time.
model, processor, df, embeddings = load()

# Attribution text appended to each result tooltip; keyed 0/1 like df and embeddings.
source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}

def get_html(url_list, height=200):
    """Render image results as an HTML flexbox gallery.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source, ``title`` the tooltip text and ``link`` an optional
            target the image should point to.
        height: Image height in pixels.

    Returns:
        An HTML string: one wrapping ``<div>`` holding one ``<img>`` per
        entry. An image is wrapped in ``<a ... target='_blank'>`` only when
        its ``link`` is a non-empty string. All three fields are HTML-escaped.
    """
    parts = ["<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"]
    for url, title, link in url_list:
        tag = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        # Only a non-empty string is a usable link. A missing value read from
        # a CSV arrives as NaN (a float), on which len() would raise TypeError.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        parts.append(tag)
    parts.append("</div>")
    # Join once instead of growing the string on every iteration.
    return "".join(parts)

def compute_text_embeddings(list_of_strings):
    """Embed text with the module-level CLIP ``processor`` and ``model``.

    Args:
        list_of_strings: Texts passed to the processor as ``text=``.

    Returns:
        Whatever ``model.get_text_features`` returns for the padded,
        PyTorch-tensor ("pt") encoded batch.
    """
    encoded = processor(text=list_of_strings, padding=True, return_tensors="pt")
    text_features = model.get_text_features(**encoded)
    return text_features

# NOTE(review): this function is deliberately left uncached. A bare
# `st.cache(show_spinner=False)` expression without a leading "@" is not a
# decorator and has no effect, so none is kept here. Enabling caching would
# presumably need the same hash_funcs as load() (the function reads the
# module-level model, processor and dicts) and a bound on cache entries,
# since queries are free text - confirm before adding it.
def image_search(query, corpus, n_results=24):
    """Return the images whose embeddings best match a text query.

    Args:
        query: Text to search for.
        corpus: ``'Unsplash'`` selects dataset 0; any other value selects
            dataset 1.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples ordered from highest to
        lowest score, where ``tooltip`` has the dataset's attribution text
        from ``source`` appended.
    """
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1
    # Dot product of every stored embedding with the single query embedding.
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending, so reverse it and keep the leading n_results.
    best = np.argsort(scores)[::-1][:n_results]
    # Take the selected rows once rather than building a row Series three
    # times per result.
    rows = df[k].iloc[best]
    attribution = source[k]
    return [(path, tooltip + attribution, link)
            for path, tooltip, link in zip(rows['path'], rows['tooltip'], rows['link'])]

# Markdown text rendered in the sidebar by main().
description = '''
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
'''

def main():
    """Build the page: inject CSS, show the sidebar, read inputs, show results.

    The query box sits in the middle of three columns and the corpus is
    chosen with a radio button. When the query is empty nothing is searched
    or rendered below the inputs.
    """
    page_css = '''
              <style>
              .block-container{
                max-width: 1200px;
              }
              div.row-widget.stRadio > div{
                flex-direction:row;
                display: flex;
                justify-content: center;
              }
              div.row-widget.stRadio > div > label{
                margin-left: 5px;
                margin-right: 5px;
              }
              section.main>div:first-child {
                padding-top: 0px;
              }
              section:not(.main)>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Only the middle (widest) column is used, which centres the query box.
    _, center, _ = st.columns((1, 3, 1))
    query = center.text_input('', value='clouds at sunset')
    corpus = st.radio('', ["Unsplash","Movies"])

    if len(query) == 0:
        return

    gallery = get_html(image_search(query, corpus))
    st.markdown(gallery, unsafe_allow_html=True)

# Build the page when this file is executed as the main module.
if __name__ == '__main__':
  main()