Spaces:
Sleeping
Sleeping
import functools
import os

import gradio as gr
from PIL import Image
from transformers import pipeline
def load_images_from_current_directory(directory=None, extensions=(".jpg", ".png")):
    """Load every readable image file from a directory.

    Files are matched by extension case-insensitively, so ``scan.JPG`` is
    included. They are visited in sorted order so the example gallery is
    stable between restarts; ``os.listdir`` order is arbitrary.

    Each image's pixel data is read eagerly with ``Image.load()``. For
    single-frame files this lets Pillow close the underlying file handle.
    Files that Pillow cannot open or decode are skipped instead of aborting
    start-up.

    Args:
        directory: Folder to scan. ``None`` (the default) means the current
            working directory.
        extensions: Accepted filename suffixes, compared case-insensitively.

    Returns:
        A list of ``PIL.Image.Image`` objects. Each keeps the ``filename``
        attribute set by ``Image.open``, which is the full path it was read from.
    """
    if directory is None:
        directory = os.getcwd()
    suffixes = tuple(ext.lower() for ext in extensions)
    images = []
    for filename in sorted(os.listdir(directory)):
        if not filename.lower().endswith(suffixes):
            continue
        img_path = os.path.join(directory, filename)
        try:
            img = Image.open(img_path)
            img.load()
        except OSError:
            # Unreadable, corrupt or unrecognised file (Pillow's
            # UnidentifiedImageError subclasses OSError): skip it.
            # Image.open signals failure by raising, never by returning None.
            continue
        images.append(img)
    return images


# Load the example gallery once, at start-up, from the working directory.
example_images = load_images_from_current_directory()
# Define the image classification function.

# Hub model that performs the prediction.
_MODEL_ID = "AMfeta99/vit-base-oxford-brain-tumor"

# The pipeline reports labels as strings. When the model only exposes generic
# class ids, class 1 means "tumor" and class 0 means "normal". This is the
# mapping the integer check `== 1` intended. Entries are compared upper-cased.
_TUMOR_LABELS = frozenset({"1", "LABEL_1"})
_NORMAL_LABELS = frozenset({"0", "LABEL_0"})


@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Return the image-classification pipeline, building it on first use.

    Loading the model is expensive, so it is created once and reused for
    every request instead of being rebuilt on each call to classify_image.
    If construction raises, nothing is cached and the next call retries.
    """
    return pipeline("image-classification", model=_MODEL_ID)


def _label_to_text(label):
    """Translate the model's top label into the text shown to the user.

    Generic ids (``1``, ``"1"``, ``"LABEL_1"`` and their ``0`` counterparts)
    map to ``'Tumor'`` / ``'Normal'``. Any other label is returned unchanged,
    so a descriptive class name is never silently reported as ``'Normal'``.
    """
    key = str(label).strip().upper()
    if key in _TUMOR_LABELS:
        return 'Tumor'
    if key in _NORMAL_LABELS:
        return 'Normal'
    return str(label)


def classify_image(image):
    """Classify a brain-scan image and return a short text verdict.

    Args:
        image: The Gradio image input. This is a NumPy array (``gr.Image``'s
            default) or an already-built ``PIL.Image.Image``. It is ``None``
            when the user submits without an image.

    Returns:
        One of the following strings:
        - ``'Tumor'`` or ``'Normal'``.
        - The model's own label name.
        - A string starting with ``"Error: "`` that describes the failure.
        Errors are returned rather than raised so they appear in the output
        label.
    """
    if image is None:
        return "Error: no image provided"
    try:
        # Gradio hands over a NumPy array; the pipeline expects a PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        results = _get_classifier()(image)
        if not results:
            return "Error: the model returned no predictions"
        # Keep only the highest-scoring class.
        best = max(results, key=lambda result: result['score'])
        return _label_to_text(best['label'])
    except Exception as e:
        # UI boundary: any failure (bad array, model download, inference) is
        # shown to the user instead of crashing the request.
        return f"Error: {str(e)}"
# Define the Gradio interface.
image = gr.Image()
label = gr.Label(num_top_classes=1)
title = "Brain Tumor X-ray Classification"
# This is a demo model, so the text must not promise an expert or clinical
# opinion.
description = (
    "Worried about whether your brain scan is normal or not? Upload your image "
    "and the demo model will give you its prediction. This is not a medical "
    "diagnosis; always consult a qualified clinician. Check out "
    "[the original algorithm](https://huggingface.co/AMfeta99/vit-base-oxford-brain-tumor) "
    "that this demo is based on."
)
article = "<p style='text-align: center'>Image Classification | Demo Model</p>"

# Each example row needs exactly one value per input component (one image
# here). gr.Image examples are given as file paths, so no array conversion
# is needed. Only images that were read from a file carry a filename.
examples = [
    [img.filename] for img in example_images if getattr(img, "filename", None)
]

demo = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    title=title,
    description=description,
    article=article,
    # None (Gradio's default) rather than an empty list when no example
    # images were found.
    examples=examples or None,
)

# Launch the Gradio interface.
demo.launch()