Vivien
commited on
Commit
•
000d238
1
Parent(s):
3f8dd94
Simplify code
Browse files
app.py
CHANGED
@@ -10,14 +10,13 @@ from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
|
|
10 |
dict: lambda _: None})
|
11 |
def load():
|
12 |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
13 |
-
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
14 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
15 |
df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
|
16 |
embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
|
17 |
for k in [0, 1]:
|
18 |
embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
|
19 |
-
return model,
|
20 |
-
model,
|
21 |
|
22 |
source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}
|
23 |
|
@@ -33,7 +32,7 @@ def get_html(url_list, height=200):
|
|
33 |
|
34 |
def compute_text_embeddings(list_of_strings):
|
35 |
inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
|
36 |
-
return model.
|
37 |
|
38 |
st.cache(show_spinner=False)
|
39 |
def image_search(query, corpus, n_results=24):
|
|
|
10 |
dict: lambda _: None})
|
11 |
def load():
|
12 |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
13 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
14 |
df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
|
15 |
embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
|
16 |
for k in [0, 1]:
|
17 |
embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
|
18 |
+
return model, processor, df, embeddings
|
19 |
+
model, processor, df, embeddings = load()
|
20 |
|
21 |
source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}
|
22 |
|
|
|
32 |
|
33 |
def compute_text_embeddings(list_of_strings):
|
34 |
inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
|
35 |
+
return model.get_text_features(**inputs)
|
36 |
|
37 |
st.cache(show_spinner=False)
|
38 |
def image_search(query, corpus, n_results=24):
|