# pipeline function with default values

# select model from hugging face
_MODEL_NAME = "google/deplot"
# model name -> (processor, model); filled on first use so the checkpoint is
# loaded once per kernel session instead of once per request
_PIPELINE_CACHE = {}


def _load_pipeline(model_name=_MODEL_NAME):
    """Return the (processor, model) pair for `model_name`, loading it on first use."""
    if model_name not in _PIPELINE_CACHE:
        # set preprocessor for current model
        processor = Pix2StructProcessor.from_pretrained(model_name)
        # load pre-trained model
        model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
        _PIPELINE_CACHE[model_name] = (processor, model)
    return _PIPELINE_CACHE[model_name]


def _parse_table(result):
    """Split decoded text into rows on "<0x0A>" and rows into cells on "|"."""
    rows = []
    for line in result.split("<0x0A>"):
        line = line.strip()
        if "|" in line:
            # table row: one stripped string per cell
            rows.append([cell.strip() for cell in line.split("|")])
        else:
            # free-text line (no cell separator): kept as a plain string
            rows.append(line)
    return rows


def query(image, user_question):
    """
    Run the google/deplot model on a chart image and parse its text output.

    image: image handed to the processor as `images`; only the first
           generated sequence (predictions[0]) is decoded;
    user_question: user prompt question, handed to the processor as `text`;
    returns: list with one entry per "<0x0A>"-separated line of the decoded
             output -- a list of stripped cell strings when the line contains
             "|", otherwise the stripped line itself;
    raises: ValueError when `image` is None.
    """
    # fail early with a readable message instead of an error from inside the processor
    if image is None:
        raise ValueError("No image supplied: upload a chart image before submitting.")
    # reuse the cached processor/model instead of reloading them on every call
    processor, model = _load_pipeline()
    # process the inputs for prediction
    inputs = processor(images=image, text=user_question, return_tensors="pt")
    # save the results
    predictions = model.generate(**inputs, max_new_tokens=512)
    # save output
    result = processor.decode(predictions[0], skip_special_tokens=True)
    # return the converted output
    return _parse_table(result)
# Input widgets, in the order `query` receives its arguments (image, question)
chart_input = gr.Image(label="Upload Here", type="pil")
question_input = gr.Textbox(label="Question?")

# (image path, question) pairs offered as examples.
# Note: a third sample ("./samples/sample3.webp", "What are the 2020 net sales?")
# is currently disabled.
sample_queries = [
    ["./samples/sample1.png", "Generate underlying data table of the figure"],
    ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"],
]

# Interface framework to customize the io page
ui = gr.Interface(
    fn=query,
    inputs=[chart_input, question_input],
    outputs="list",
    title="Chart Q/A",
    examples=sample_queries,
    cache_examples=True,
    allow_flagging="never",
)

# configure the queue, then start the app on a local URL (no public share link)
ui.queue(api_open=False)
ui.launch(inline=False, share=False, debug=True)