"""Gradio app that predicts the location shown in an uploaded image.

On start-up it clones the model repository (if not already present), loads
the label encoder and ConvNeXt-based classifier from it, and serves a
single-image prediction interface.
"""
import io
import os
import subprocess
import warnings

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm

# Retrieve the access token used to clone the repository.
token = os.environ.get("token")

# NOTE(review): this URL carries credentials but no host or repository path
# (e.g. "@<host>/<owner>/<repo>"); it looks truncated/redacted. Restore the
# full URL before deploying.
repo_url = f"https://robocan:{token}"
# Use a directory in the home directory; expanduser resolves "~" here because
# open()-based loaders (joblib.load, torch.load) never expand it themselves.
destination_dir = os.path.expanduser("~/SVD")

# Clone only when no checkout exists yet: `git clone` refuses a non-empty
# destination, so an unconditional clone would crash every restart.
if not os.path.isdir(os.path.join(destination_dir, ".git")):
    if not token:
        raise RuntimeError(
            "Environment variable 'token' is not set; cannot clone the model repository."
        )
    try:
        subprocess.run(["git", "clone", repo_url, destination_dir], check=True)
    except subprocess.CalledProcessError as err:
        # Re-raise without the original exception: its message includes the
        # command line, which embeds the access token.
        raise RuntimeError(f"git clone failed with exit code {err.returncode}") from None
# NOTE(review): git stores repo_url (including the token) in the clone's
# .git/config; consider rewriting the remote URL after cloning.

device = "cpu"

# Label encoder loaded from the cloned repository; predict() uses its
# inverse_transform to map the winning class index back to a label.
le = joblib.load(os.path.join(destination_dir, "le.gz"))


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a small MLP classification head.

    All children of torchvision's ``convnext_small`` except the last one (its
    classifier head) are kept. Their output is flattened and passed through
    Linear(768 -> 512) -> ReLU -> Linear(512 -> len(le.classes_) + 1).
    """

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            *list(
                models.convnext_small(
                    weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
                ).children()
            )[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            # NOTE(review): one more output than le.classes_; if argmax ever
            # selects that last logit, le.inverse_transform will raise. Confirm
            # what the extra class is meant to represent.
            torch.nn.Linear(in_features=512, out_features=len(le.classes_) + 1),
        )

    def forward(self, data):
        """Return the raw class logits for a batch of images."""
        return self.embedding(data)


# Checkpoint dict whose "model" entry is the ModelPre state_dict.
# torch.load unpickles the file, so it must come from a trusted source.
model = torch.load(
    os.path.join(destination_dir, "GeoG.pth"), map_location=torch.device(device)
)
modelm = ModelPre()
modelm.load_state_dict(model["model"])
modelm.to(device)
# Switch to inference behaviour (e.g. disables stochastic depth / dropout);
# torch.inference_mode() only turns off autograd, not training-mode layers.
modelm.eval()

warnings.filterwarnings(
    "ignore", category=RuntimeWarning, module="multiprocessing.popen_fork"
)

# Preprocessing: tensor conversion, resize to 224x224, ImageNet normalisation.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(input_img):
    """Return the predicted location label for an uploaded image.

    Args:
        input_img: PIL image from the Gradio ``Image`` component
            (``type="pil"``), or ``None`` when the user submitted nothing.

    Returns:
        The label of the highest-scoring class, or ``None`` if no image was
        provided.
    """
    if input_img is None:
        return None
    with torch.inference_mode():
        # Force 3 channels so RGBA/grayscale uploads fit the 3-channel Normalize.
        img = cmp(input_img.convert("RGB")).unsqueeze(0).to(device)
        res = modelm(img)
        # Batch of one: argmax over the class dimension gives a shape-(1,) index.
        class_idx = res.argmax(dim=1).cpu().numpy()
    return le.inverse_transform(class_idx)[0]


gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload an Image", type="pil"),
    outputs=gr.Label(label="Location"),
    title="Predict the Location of this Image",
)

if __name__ == "__main__":
    gradio_app.launch()