File size: 2,123 Bytes
9b889da
 
 
 
 
 
 
 
955fc23
 
 
 
 
 
 
 
 
 
 
 
 
6d49cf1
 
955fc23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d49cf1
 
955fc23
 
 
 
 
6d49cf1
 
955fc23
 
 
 
6d49cf1
 
 
955fc23
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import subprocess

# Access token for huggingface.co, read from the "token" environment variable
# (None when the variable is unset).
token = os.environ.get("token")

# Clone the robocan/GeoG_City repository into /SVD; the model checkpoint and the
# label encoder are loaded from that checkout below.
# subprocess is used instead of IPython's `!git clone ...` shell magic so this file
# also runs as a plain Python script (the magic is a SyntaxError outside IPython).
# Skip the clone when /SVD already has content: git refuses to clone into a
# non-empty directory, which is the common case when the script is re-run.
if not (os.path.isdir("/SVD") and os.listdir("/SVD")):
    repo_url = (
        f"https://robocan:{token}@huggingface.co/robocan/GeoG_City"
        if token
        else "https://huggingface.co/robocan/GeoG_City"
    )
    # argv list, no shell: the token is never parsed by a shell.
    clone = subprocess.run(["git", "clone", repo_url, "/SVD"])
    if clone.returncode != 0:
        # Deliberately omit the URL from the message: it may embed the access token.
        raise RuntimeError(
            f"git clone of robocan/GeoG_City into /SVD failed "
            f"(exit code {clone.returncode})"
        )

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import io
import joblib
import requests
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torchvision import models
import gradio as gr

device = 'cpu'
# Fitted label encoder from the /SVD checkout cloned above (the same checkout the
# model checkpoint is read from). ``predict`` uses its ``inverse_transform`` to turn
# the model's argmax index into a location label.
# NOTE(review): the previous path (/kaggle/working/SVD/le.gz) looks like a leftover
# from a Kaggle notebook and does not match the /SVD clone target.
# NOTE(review): joblib.load unpickles arbitrary objects — only load files from a
# trusted repository.
le = joblib.load("/SVD/le.gz")

class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a two-layer MLP classification head.

    Args:
        num_classes: Number of outputs of the final linear layer. Defaults to
            ``len(le.classes_) + 1``, read from the module-level label encoder at
            construction time (the size used by the original implementation).
        pretrained: Keyword-only. When True (default), initialise the backbone from
            torchvision's ``ConvNeXt_Small_Weights.IMAGENET1K_V1`` weights, which may
            require a download. Pass False when all weights will be replaced via
            ``load_state_dict`` anyway.
    """

    def __init__(self, num_classes=None, *, pretrained=True):
        super().__init__()
        if num_classes is None:
            # NOTE(review): one more output than label-encoder classes; the purpose of
            # the extra output is not visible here — confirm against the training code.
            num_classes = len(le.classes_) + 1
        weights = models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.convnext_small(weights=weights)
        self.embedding = torch.nn.Sequential(
            # Every child of the torchvision model except the last one (its built-in
            # classifier); the custom head below replaces it.
            *list(backbone.children())[:-1],
            torch.nn.Flatten(),
            # 768 = channel width of ConvNeXt-Small's final stage.
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=num_classes),
        )

    def forward(self, data):
        """Return the head's raw outputs for a batch of images (no softmax applied)."""
        return self.embedding(data)

# Checkpoint dict from the cloned repository; its 'model' entry holds the
# ModelPre state dict.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model = torch.load("/SVD/GeoG.pth", map_location=torch.device(device))

modelm = ModelPre()
modelm.load_state_dict(model['model'])
# Keep the weights on the same device that ``predict`` sends its inputs to.
modelm.to(device)
# A freshly constructed nn.Module is in training mode. eval() disables training-only
# behaviour (e.g. the stochastic depth in torchvision's ConvNeXt blocks), so
# predictions are deterministic; torch.inference_mode() alone does not do this.
modelm.eval()

import warnings

# Ignore every RuntimeWarning attributed to the multiprocessing.popen_fork module.
# NOTE(review): presumably fork-related noise from worker processes — confirm the source.
warnings.filterwarnings(
    "ignore",
    module="multiprocessing.popen_fork",
    category=RuntimeWarning,
)

# Preprocessing applied to each uploaded image before it reaches the model:
#   1. PIL image -> channels-first float tensor (scaled to [0, 1] for 8-bit images);
#   2. resize to 224x224 with antialiasing;
#   3. normalise each channel with the standard ImageNet mean/std, matching the
#      ImageNet-pretrained backbone.
cmp = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224), antialias=True),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict(input_img):
    """Predict the location label for a single uploaded image.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The label (decoded with ``le``) whose model output is highest.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if input_img is None:
        raise gr.Error("Please upload an image.")
    # Uploads can be RGBA (PNG), palette, or grayscale images; the Normalize step in
    # ``cmp`` has exactly three channel statistics, so force three channels.
    rgb_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(rgb_img).unsqueeze(0)
        res = modelm(img.to(device))
        # The model has one more output than ``le`` has classes; restrict the argmax
        # to indices the encoder can decode so inverse_transform cannot reject it.
        n_labels = len(le.classes_)
        best = torch.argmax(res[0, :n_labels].cpu()).unsqueeze(0).numpy()
        prediction = le.inverse_transform(best)[0]
    return prediction

# Web UI: one PIL image in, the predicted location label out.
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Label(label="Location"),
    title="Predict the Location of this Image",
)

if __name__ == "__main__":
    # Start the server only when this module is the entry point, not when imported.
    gradio_app.launch()