import gradio as gr
from transformers import pipeline

# Build the Hugging Face image-classification pipeline once at module level so
# every request handled by `image_classifier` reuses the same model instance.
pipe = pipeline("image-classification", "dima806/medicinal_plants_image_detection")


def image_classifier(image):
    """Classify *image* and return a ``{label: score}`` mapping for ``gr.Label``.

    Args:
        image: The uploaded picture as a PIL image (the input component below
            is ``gr.Image(type="pil")``), or ``None`` when the user submits
            without uploading anything (Gradio passes ``None`` for an empty
            image input).

    Returns:
        A dict mapping each predicted class label to its confidence score,
        taken directly from the pipeline output.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a pipeline traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    # The pipeline returns a list of {"label": ..., "score": ...} dicts.
    return {result["label"]: result["score"] for result in pipe(image)}


# App title and description shown above the interface. Gradio renders the
# description as Markdown/HTML, so plain text is sufficient here.
title = "Image Classification"
# NOTE(review): the previous description claimed this app classifies skin-lesion
# images by skin-cancer type with 86% validation accuracy, which does not match
# the medicinal-plants model loaded above. Re-add an accuracy figure only once
# it has been confirmed for this model.
description = (
    "This application identifies medicinal plants from images using the "
    "dima806/medicinal_plants_image_detection model."
)

# Custom CSS for the Gradio app.
# NOTE(review): nothing in this file emits elements with the `title-container`
# or `description-container` classes; confirm these selectors still match
# Gradio's generated markup, or drop them.
custom_css = """
.gradio-interface {
    max-width: 600px;
    margin: auto;
    border-radius: 10px;
    box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
}
.title-container {
    padding: 20px;
    background-color: #f0f0f0;
    border-top-left-radius: 10px;
    border-top-right-radius: 10px;
}
.description-container {
    padding: 20px;
}
"""

# Build and launch the Gradio interface with the custom title, description,
# theme and CSS.
demo = gr.Interface(
    fn=image_classifier,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title=title,
    description=description,
    theme="gstaff/sketch",
    css=custom_css,
)
demo.launch()

# NOTE(review): legacy ResNet50 "Leaf or Plant" app, kept for reference only.
# It relies on the `gradio.inputs` / `gradio.outputs` modules (deprecated in
# Gradio 3, removed in Gradio 4), so it will not run on current Gradio as
# written.
#
# import torch
# from torchvision import transforms
# from PIL import Image
# from torchvision import models
# import gradio.inputs as gi
# import gradio.outputs as go
# import gradio as gr
#
# # Define the ResNet50 model
# class ResNet50(torch.nn.Module):
#     def __init__(self):
#         super(ResNet50, self).__init__()
#         self.resnet = models.resnet50(pretrained=True)
#         for param in self.resnet.parameters():
#             param.requires_grad = False
#         self.resnet.fc = torch.nn.Sequential(
#             torch.nn.Linear(2048, 2)
#         )
#
#     def forward(self, x):
#         x = self.resnet(x)
#         return x
#
# # Load the pre-trained model
# model = ResNet50()
# model.load_state_dict(torch.load('best_modelv2.pth', map_location=torch.device('cpu')))
# model.eval()
#
# # Define transform for input images
# data_transforms = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])
#
# # Function to predict image label
# def predict_image_label(image):
#     # Preprocess the image
#     image = data_transforms(image).unsqueeze(0)
#     # Make prediction
#     with torch.no_grad():
#         output = model(image)
#         _, predicted = torch.max(output, 1)
#     label = 'Leaf' if predicted.item() == 0 else 'Plant'
#     return label
#
# # Create Gradio interface
# # image = gi.Image(shape=(224, 224))
# label = go.Label(num_top_classes=2)
# gr.Interface(fn=predict_image_label, inputs="image", outputs=label, title="Leaf or Plant Classifier").launch()