{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mahnaz/mlprojects/bloom_classifier/ven_bloom_gradio/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import gradio as gr\n", "import json\n", "from transformers import pipeline\n", "from transformers import AutoImageProcessor\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n", "import numpy as np\n", "\n", "def preprocess_input(input_data, image_processor):\n", " \"\"\"\n", " Preprocesses the input image for inference.\n", "\n", " Parameters:\n", " input_data (str or np.ndarray): Path to the image file in .jpg format or a NumPy array.\n", " image_processor (AutoImageProcessor): An instance of AutoImageProcessor from the model's checkpoint.\n", "\n", " Returns:\n", " processed_img (torch.Tensor): Preprocessed image ready for inference.\n", " \"\"\"\n", " # Load the image based on the input type\n", " if isinstance(input_data, str):\n", " img = Image.open(input_data).convert('RGB')\n", " elif isinstance(input_data, np.ndarray):\n", " img = Image.fromarray(input_data.astype('uint8'), 'RGB')\n", " else:\n", " raise ValueError(\"Unsupported input type. Only str and np.ndarray are supported.\")\n", " \n", " # Obtain the mean and std from image_processor\n", " mean = image_processor.image_mean\n", " std = image_processor.image_std\n", " \n", " # Obtain the image size from image_processor\n", " size = (\n", " image_processor.size[\"shortest_edge\"]\n", " if \"shortest_edge\" in image_processor.size\n", " else (image_processor.size[\"height\"], image_processor.size[\"width\"])\n", " )\n", " \n", " # Define the transformations\n", " preprocess = Compose([\n", " Resize(size), # Resizing to the same size used during training\n", " CenterCrop(size), # Center cropping to the same size used during training\n", " ToTensor(),\n", " Normalize(mean=mean, std=std)\n", " ])\n", " \n", " # Apply the transformations\n", " processed_img = preprocess(img)\n", " \n", " # Add a batch dimension\n", " processed_img = processed_img.unsqueeze(0) # This is necessary because the model expects a batch\n", " to_pil = ToPILImage()\n", " processed_img = to_pil(processed_img)\n", "\n", " return processed_img\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n", "\n", "def preprocess_input(image_path, image_processor):\n", " \"\"\"\n", " Preprocesses the input image for inference.\n", "\n", " Parameters:\n", " image_path (str): Path to the image file in .jpg format.\n", " image_processor (AutoImageProcessor): An instance of AutoImageProcessor from the model's checkpoint.\n", "\n", " Returns:\n", " processed_img (torch.Tensor): Preprocessed image ready for inference.\n", " \"\"\"\n", " # Load the image\n", " img = Image.open(image_path).convert('RGB')\n", " \n", " # Obtain the mean and std from image_processor\n", " mean = image_processor.image_mean\n", " std = image_processor.image_std\n", " \n", " # Obtain the image size from 
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mahnaz/mlprojects/bloom_classifier/ven_bloom_gradio/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import gradio as gr\n", "import json\n", "from transformers import pipeline\n", "\n", "\n", "def load_label_to_name_mapping(json_file_path):\n", "    \"\"\"Load the label-to-name mapping from a JSON file.\"\"\"\n", "    with open(json_file_path, 'r') as f:\n", "        mapping = json.load(f)\n", "    return {int(k): v for k, v in mapping.items()}\n", "\n", "\n", "def infer_flower_name(classifier, image, label_to_name):\n", "    \"\"\"Perform inference on an image and return the predicted flower names with their scores.\"\"\"\n", "    results = classifier(image)\n", "    # Each label is in the format 'LABEL_#', so extract the integer id and map it to a name\n", "    named_scores = {}\n", "    for result in results:\n", "        label = int(result['label'].split('_')[-1])\n", "        flower_name = label_to_name.get(label, \"Unknown\")\n", "        named_scores[flower_name] = result['score']\n", "    return named_scores\n", "\n", "\n", "# Load the model checkpoint and the label mapping once, instead of on every prediction\n", "classifier = pipeline(\"image-classification\", model=\"checkpoint-160\")\n", "label_to_name = load_label_to_name_mapping('label_to_name.json')\n", "\n", "\n", "def predict(prompt_img):\n", "    \"\"\"Classify an input image and return the top flower names for gr.Label.\"\"\"\n", "    # The pipeline applies the checkpoint's own preprocessing, so the PIL image\n", "    # coming from Gradio is passed in directly (no manual preprocess_input needed)\n", "    return infer_flower_name(classifier, prompt_img, label_to_name)\n", "\n", "\n", "demo = gr.Interface(fn=predict,\n", "                    inputs=gr.Image(type=\"pil\"),\n", "                    outputs=gr.Label(num_top_classes=3),\n", "                    examples=[\"example.jpg\"])\n", "\n", "demo.launch()" ] }
 ], "metadata": { "kernelspec": { "display_name": "venv_bloom-classifier", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }