Spaces:
Sleeping
Sleeping
Upload predict.py
Browse files
img2art_search/models/predict.py
CHANGED
@@ -11,13 +11,10 @@ from img2art_search.models.compute_embeddings import search_image
|
|
11 |
|
12 |
|
13 |
def predict(img: Image.Image) -> list:
|
14 |
-
tmp_img_path = "tmp_img.png"
|
15 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
if img:
|
17 |
-
img.
|
18 |
-
|
19 |
-
pred_dataset = ImageRetrievalDataset(pred_img, transform=transform)
|
20 |
-
pred_image_data = pred_dataset[0][0].unsqueeze(0).to(DEVICE)
|
21 |
indices, distances = search_image(pred_image_data)
|
22 |
results = []
|
23 |
for index, distance in zip(indices, distances):
|
@@ -31,7 +28,6 @@ def predict(img: Image.Image) -> list:
|
|
31 |
str(distance),
|
32 |
)
|
33 |
)
|
34 |
-
os.remove(tmp_img_path)
|
35 |
return results
|
36 |
else:
|
37 |
return []
|
|
|
11 |
|
12 |
|
13 |
def predict(img: Image.Image) -> list:
|
|
|
14 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
if img:
|
16 |
+
img = img.convert("RGB")
|
17 |
+
pred_image_data = transform(img).unsqueeze(0).to(DEVICE)
|
|
|
|
|
18 |
indices, distances = search_image(pred_image_data)
|
19 |
results = []
|
20 |
for index, distance in zip(indices, distances):
|
|
|
28 |
str(distance),
|
29 |
)
|
30 |
)
|
|
|
31 |
return results
|
32 |
else:
|
33 |
return []
|