chicelli commited on
Commit
ed54141
·
verified ·
1 Parent(s): 8752739

Upload predict.py

Browse files
Files changed (1) hide show
  1. img2art_search/models/predict.py +2 -6
img2art_search/models/predict.py CHANGED
@@ -11,13 +11,10 @@ from img2art_search.models.compute_embeddings import search_image
11
 
12
 
13
  def predict(img: Image.Image) -> list:
14
- tmp_img_path = "tmp_img.png"
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  if img:
17
- img.save(tmp_img_path)
18
- pred_img = np.array([[tmp_img_path], [tmp_img_path]])
19
- pred_dataset = ImageRetrievalDataset(pred_img, transform=transform)
20
- pred_image_data = pred_dataset[0][0].unsqueeze(0).to(DEVICE)
21
  indices, distances = search_image(pred_image_data)
22
  results = []
23
  for index, distance in zip(indices, distances):
@@ -31,7 +28,6 @@ def predict(img: Image.Image) -> list:
31
  str(distance),
32
  )
33
  )
34
- os.remove(tmp_img_path)
35
  return results
36
  else:
37
  return []
 
11
 
12
 
13
  def predict(img: Image.Image) -> list:
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  if img:
16
+ img = img.convert("RGB")
17
+ pred_image_data = transform(img).unsqueeze(0).to(DEVICE)
 
 
18
  indices, distances = search_image(pred_image_data)
19
  results = []
20
  for index, distance in zip(indices, distances):
 
28
  str(distance),
29
  )
30
  )
 
31
  return results
32
  else:
33
  return []