Ceyda Cinarel commited on
Commit
9fbe234
·
1 Parent(s): b0b9e1f

add nearest neighbor

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +34 -8
  3. beit_index.faiss +3 -0
  4. demo.py +23 -4
.gitattributes CHANGED
@@ -26,3 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ *.faiss filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,5 +1,6 @@
 
1
  import streamlit as st # HF spaces at v1.2.0
2
- from demo import load_model,generate
3
 
4
  # TODOs
5
  # Add markdown short readme project intro
@@ -21,21 +22,46 @@ def load_model_intocache(model_name):
21
 
22
  return gan
23
 
 
 
 
 
 
24
  model_name='ceyda/butterfly_cropped_uniq1K_512'
25
  model=load_model_intocache(model_name)
 
26
 
27
  st.write(f"Model {model_name} is loaded")
28
  st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
29
 
30
- run=st.button("Generate")
31
- if run:
 
 
 
 
32
  with st.spinner("Generating..."):
33
-
34
- batch_size=4 #generate 4 butterflies
35
  ims=generate(model,batch_size)
 
36
 
37
- cols=st.columns(batch_size)
38
- for i,im in enumerate(ims):
39
- cols[i].image(im)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
 
1
+ import re
2
  import streamlit as st # HF spaces at v1.2.0
3
+ from demo import load_model,generate,get_dataset,embed
4
 
5
  # TODOs
6
  # Add markdown short readme project intro
 
22
 
23
  return gan
24
 
25
+ @st.experimental_singleton
26
+ def load_dataset():
27
+ dataset=get_dataset()
28
+ return dataset
29
+
30
  model_name='ceyda/butterfly_cropped_uniq1K_512'
31
  model=load_model_intocache(model_name)
32
+ dataset=load_dataset()
33
 
34
  st.write(f"Model {model_name} is loaded")
35
  st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
36
 
37
+ if 'ims' not in st.session_state:
38
+ st.session_state['ims'] = None
39
+
40
+ ims=st.session_state["ims"]
41
+ batch_size=4 #generate 4 butterflies
42
+ def run():
43
  with st.spinner("Generating..."):
 
 
44
  ims=generate(model,batch_size)
45
+ st.session_state['ims'] = ims
46
 
47
+ runb=st.button("Generate", on_click=run)
48
+ if ims is not None:
49
+ cols=st.columns(batch_size)
50
+ picks=[False]*batch_size
51
+ for i,im in enumerate(ims):
52
+ cols[i].image(im)
53
+ picks[i]=cols[i].button("Find Nearest",key="pick_"+str(i))
54
+ # if picks[i]:
55
+ # scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(im), k=5)
56
+ # for r in retrieved_examples["image"]:
57
+ # st.image(r)
58
+
59
+ if any(picks):
60
+ # st.write("Nearest butterflies:")
61
+ for i,pick in enumerate(picks):
62
+ if pick:
63
+ scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
64
+ for r in retrieved_examples["image"]:
65
+ cols[i].image(r)
66
 
67
 
beit_index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d56496f69d06d78867ab39298a5354c0419056000824d82b06db343449c4518d
3
+ size 3072045
demo.py CHANGED
@@ -7,15 +7,34 @@ def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropp
7
  dataset=dataset.sort("sim_score")
8
  score_thresh = dataset["train"][data_limit]['sim_score']
9
  dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
10
-
11
- dataset = dataset.map(lambda x: x.convert("RGB"))
12
  return dataset["train"]
 
 
 
 
 
 
 
 
 
 
 
13
 
14
-
 
 
 
 
 
 
 
 
 
15
 
16
  def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
17
  gan = LightweightGAN.from_pretrained(model_name)
18
- gan.eval();
19
  return gan
20
 
21
  def generate(gan,batch_size=1):
 
7
  dataset=dataset.sort("sim_score")
8
  score_thresh = dataset["train"][data_limit]['sim_score']
9
  dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
10
+ dataset = dataset.map(lambda x: {'image' : x['image'].convert("RGB")})
 
11
  return dataset["train"]
12
+
13
+ from transformers import BeitFeatureExtractor, BeitForImageClassification
14
+ feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
15
+ model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
16
+ def embed(images):
17
+ inputs = feature_extractor(images=images, return_tensors="pt")
18
+ outputs = model(**inputs,output_hidden_states= True)
19
+ last_hidden=outputs.hidden_states[-1]
20
+ pooler=model.base_model.pooler
21
+ final_emb=pooler(last_hidden).detach().numpy()
22
+ return final_emb
23
 
24
+ def build_index():
25
+ dataset=get_train_data()
26
+ ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
27
+ ds_with_embeddings.add_faiss_index(column='beit_embeddings')
28
+ ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
29
+
30
+ def get_dataset():
31
+ dataset=get_train_data()
32
+ dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
33
+ return dataset
34
 
35
  def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
36
  gan = LightweightGAN.from_pretrained(model_name)
37
+ gan.eval()
38
  return gan
39
 
40
  def generate(gan,batch_size=1):