{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ViT image classification: top-k ImageNet predictions\n",
    "\n",
    "Classify a COCO validation image with the pretrained `google/vit-base-patch16-224` checkpoint and report the `TOP_K` highest-scoring labels (best first) together with their raw logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import ViTImageProcessor, ViTForImageClassification, FlaxViTForImageClassification\n",
    "from PIL import Image\n",
    "import requests\n",
    "import torch\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_URL = 'http://images.cocodataset.org/val2017/000000039769.jpg'\n",
    "MODEL_NAME = 'google/vit-base-patch16-224'\n",
    "TOP_K = 3  # number of highest-scoring classes to report\n",
    "REQUEST_TIMEOUT_S = 30  # seconds; avoid hanging forever on a stalled download"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load image and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail loudly on HTTP errors instead of handing PIL an error page to decode.\n",
    "response = requests.get(IMAGE_URL, stream=True, timeout=REQUEST_TIMEOUT_S)\n",
    "response.raise_for_status()\n",
    "image = Image.open(response.raw)\n",
    "\n",
    "processor = ViTImageProcessor.from_pretrained(MODEL_NAME)\n",
    "model = ViTForImageClassification.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference and top-k labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = processor(images=image, return_tensors=\"pt\")\n",
    "\n",
    "# Inference only: no autograd bookkeeping needed.\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "logits = outputs.logits\n",
    "\n",
    "# Row 0 holds the scores for the single input image.\n",
    "logits_np = logits.cpu().numpy()\n",
    "# argsort is ascending: the last TOP_K indices are the best classes; reverse them so the top prediction comes first.\n",
    "logits_args = logits_np.argsort()[0][-TOP_K:][::-1]\n",
    "\n",
    "prediction_classes = [model.config.id2label[int(class_idx)] for class_idx in logits_args]\n",
    "print(prediction_classes, logits_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each label with the logit of *its own* class index.\n",
    "# Indexing by loop position (logits_np[0][i]) would instead read classes 0..TOP_K-1.\n",
    "result = {\n",
    "    label: float(logits_np[0][class_idx])\n",
    "    for label, class_idx in zip(prediction_classes, logits_args)\n",
    "}\n",
    "\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model predicts one of the 1000 ImageNet classes; with best-first ordering\n",
    "# the argmax class must be the first entry of prediction_classes.\n",
    "top1_label = model.config.id2label[int(logits_np[0].argmax())]\n",
    "assert top1_label == prediction_classes[0], (top1_label, prediction_classes)\n",
    "\n",
    "top1_label"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py_llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}