import os

import gradio as gr
import huggingface_hub as hf
import torch
from PIL import Image
from torchvision import transforms


# DeiT-Tiny classifier fetched through torch.hub. The hub entry point builds
# a randomly initialised network unless pretrained=True is passed, so the
# trained checkpoint has to be requested explicitly.
model = torch.hub.load(
    'facebookresearch/deit:main',
    'deit_tiny_patch16_224',
    pretrained=True,
)
# The model is only used for inference; torch modules start out in training
# mode, so switch to evaluation mode once at load time.
model.eval()

# Hugging Face repository name used by upload_to_hf() as the upload target.
repo = "uploader"


# Input widget: file-upload control through which the user submits an image.
# Gradio exposes this component as gr.File (there is no gr.FileInput) and it
# has no `preview` option. `type` is left at the library default because the
# accepted values changed between Gradio major versions.
file = gr.File(label="Upload Image File")

# Output widget: displays the image returned by the handler.
image = gr.Image(label="Uploaded Image")


def upload_to_hf(filename):
    """Upload a local file to the Hugging Face Hub repository ``repo``.

    Args:
        filename: Path of an existing local file to upload.

    Returns:
        The string ``"/<repo>/<filename>"``. This is only a label for the
        uploaded object; it is not a local filesystem path.

    Raises:
        OSError: If ``filename`` cannot be opened for reading.
    """
    # NOTE(review): Hub repo ids usually look like "namespace/name" -- confirm
    # that the bare value of `repo` resolves to the intended repository.
    with open(filename, 'rb') as stream:
        # upload_file takes keyword arguments: the repository goes in
        # `repo_id`, and `path_in_repo` is the file's location inside it.
        # Handing over the open stream avoids loading the file into memory.
        hf.upload_file(
            path_or_fileobj=stream,
            path_in_repo=os.path.basename(filename),
            repo_id=repo,
        )

    return f"/{repo}/{filename}"


def run(file):
    """Open the submitted image, pass it through the model and return it.

    Args:
        file: Value delivered by the file input: either a path/URL string or
            an object that exposes the path as ``.name``.

    Returns:
        The submitted picture decoded as an RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If no file was supplied.
    """
    if file is None:
        raise ValueError("No image file was provided")

    # The file widget may hand over a plain string or a temp-file wrapper
    # whose path lives in `.name`; normalise to a string either way.
    source = file if isinstance(file, str) else file.name

    if source.startswith("http"):
        # NOTE(review): the URL itself is never downloaded; this branch only
        # works if a local file named after the last URL segment already
        # exists -- confirm the intended behaviour.
        filename = source.split("/")[-1]
        upload_to_hf(filename)
        # Open the local file that was just uploaded; the value returned by
        # upload_to_hf() is a repository label, not an openable path.
        filepath = filename
    else:
        filepath = source

    picture = Image.open(filepath).convert('RGB')

    # NOTE(review): the DeiT reference preprocessing normalises with the
    # ImageNet mean/std rather than 0.5 -- confirm which is intended.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Add a leading batch dimension: the model is called on a batch of one.
    batch = preprocess(picture).unsqueeze(0)

    with torch.no_grad():
        # The forward pass is kept, but its result is not shown by the
        # interface.
        model(batch)

    return picture


# Bind the handler to its widgets: the file upload feeds run(), and whatever
# run() returns is rendered by the image component.
app = gr.Interface(run, inputs=[file], outputs=[image])

# Start serving the interface.
app.launch()