Spaces:
Runtime error
Runtime error
add gradio dir
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- gradio-modified/gradio/.dockerignore +2 -0
- gradio-modified/gradio/__init__.py +86 -0
- gradio-modified/gradio/blocks.py +1673 -0
- gradio-modified/gradio/components.py +0 -0
- gradio-modified/gradio/context.py +14 -0
- gradio-modified/gradio/data_classes.py +55 -0
- gradio-modified/gradio/deprecation.py +45 -0
- gradio-modified/gradio/documentation.py +193 -0
- gradio-modified/gradio/encryptor.py +31 -0
- gradio-modified/gradio/events.py +723 -0
- gradio-modified/gradio/examples.py +327 -0
- gradio-modified/gradio/exceptions.py +23 -0
- gradio-modified/gradio/external.py +462 -0
- gradio-modified/gradio/external_utils.py +186 -0
- gradio-modified/gradio/flagging.py +560 -0
- gradio-modified/gradio/helpers.py +792 -0
- gradio-modified/gradio/inputs.py +473 -0
- gradio-modified/gradio/interface.py +844 -0
- gradio-modified/gradio/interpretation.py +255 -0
- gradio-modified/gradio/ipython_ext.py +17 -0
- gradio-modified/gradio/launches.json +1 -0
- gradio-modified/gradio/layouts.py +377 -0
- gradio-modified/gradio/media_data.py +0 -0
- gradio-modified/gradio/mix.py +128 -0
- gradio-modified/gradio/networking.py +185 -0
- gradio-modified/gradio/outputs.py +334 -0
- gradio-modified/gradio/pipelines.py +191 -0
- gradio-modified/gradio/processing_utils.py +755 -0
- gradio-modified/gradio/queueing.py +446 -0
- gradio-modified/gradio/reload.py +59 -0
- gradio-modified/gradio/routes.py +622 -0
- gradio-modified/gradio/serializing.py +208 -0
- gradio-modified/gradio/strings.py +41 -0
- gradio-modified/gradio/templates.py +563 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/BlockLabel.37da86a3.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.cc0aed40.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.e110d966.css +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Column.06c172ac.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/File.60a988f4.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Image.4a41f1aa.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Image.95fa511c.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Model3D.b44fd6f2.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/ModifyUpload.2cfe71e4.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Tabs.6b500f1a.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Upload.5d0148e8.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/Webcam.8816836e.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/_commonjsHelpers.88e99c8f.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/color.509e5f03.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/csv.27f5436c.js +0 -0
- gradio-modified/{templates → gradio/templates}/frontend/assets/dsv.7fe76a93.js +0 -0
gradio-modified/gradio/.dockerignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
templates/frontend
|
2 |
+
templates/frontend/**/*
|
gradio-modified/gradio/__init__.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pkgutil
|
2 |
+
|
3 |
+
import gradio.components as components
|
4 |
+
import gradio.inputs as inputs
|
5 |
+
import gradio.outputs as outputs
|
6 |
+
import gradio.processing_utils
|
7 |
+
import gradio.templates
|
8 |
+
from gradio.blocks import Blocks
|
9 |
+
from gradio.components import (
|
10 |
+
HTML,
|
11 |
+
JSON,
|
12 |
+
Audio,
|
13 |
+
Button,
|
14 |
+
Carousel,
|
15 |
+
Chatbot,
|
16 |
+
Checkbox,
|
17 |
+
Checkboxgroup,
|
18 |
+
CheckboxGroup,
|
19 |
+
ColorPicker,
|
20 |
+
DataFrame,
|
21 |
+
Dataframe,
|
22 |
+
Dataset,
|
23 |
+
Dropdown,
|
24 |
+
File,
|
25 |
+
Gallery,
|
26 |
+
Highlight,
|
27 |
+
Highlightedtext,
|
28 |
+
HighlightedText,
|
29 |
+
Image,
|
30 |
+
Interpretation,
|
31 |
+
Json,
|
32 |
+
Label,
|
33 |
+
LinePlot,
|
34 |
+
Markdown,
|
35 |
+
Model3D,
|
36 |
+
Number,
|
37 |
+
Plot,
|
38 |
+
Radio,
|
39 |
+
ScatterPlot,
|
40 |
+
Slider,
|
41 |
+
State,
|
42 |
+
StatusTracker,
|
43 |
+
Text,
|
44 |
+
Textbox,
|
45 |
+
TimeSeries,
|
46 |
+
Timeseries,
|
47 |
+
UploadButton,
|
48 |
+
Variable,
|
49 |
+
Video,
|
50 |
+
component,
|
51 |
+
)
|
52 |
+
from gradio.exceptions import Error
|
53 |
+
from gradio.flagging import (
|
54 |
+
CSVLogger,
|
55 |
+
FlaggingCallback,
|
56 |
+
HuggingFaceDatasetJSONSaver,
|
57 |
+
HuggingFaceDatasetSaver,
|
58 |
+
SimpleCSVLogger,
|
59 |
+
)
|
60 |
+
from gradio.helpers import Progress
|
61 |
+
from gradio.helpers import create_examples as Examples
|
62 |
+
from gradio.helpers import make_waveform, skip, update
|
63 |
+
from gradio.interface import Interface, TabbedInterface, close_all
|
64 |
+
from gradio.ipython_ext import load_ipython_extension
|
65 |
+
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
|
66 |
+
from gradio.mix import Parallel, Series
|
67 |
+
from gradio.routes import Request, mount_gradio_app
|
68 |
+
from gradio.templates import (
|
69 |
+
Files,
|
70 |
+
ImageMask,
|
71 |
+
ImagePaint,
|
72 |
+
List,
|
73 |
+
Matrix,
|
74 |
+
Mic,
|
75 |
+
Microphone,
|
76 |
+
Numpy,
|
77 |
+
Paint,
|
78 |
+
Pil,
|
79 |
+
PlayableVideo,
|
80 |
+
Sketchpad,
|
81 |
+
TextArea,
|
82 |
+
Webcam,
|
83 |
+
)
|
84 |
+
|
85 |
+
current_pkg_version = pkgutil.get_data(__name__, "version.txt").decode("ascii").strip()
|
86 |
+
__version__ = current_pkg_version
|
gradio-modified/gradio/blocks.py
ADDED
@@ -0,0 +1,1673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import copy
|
4 |
+
import getpass
|
5 |
+
import inspect
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import pkgutil
|
9 |
+
import random
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
import warnings
|
13 |
+
import webbrowser
|
14 |
+
from abc import abstractmethod
|
15 |
+
from pathlib import Path
|
16 |
+
from types import ModuleType
|
17 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
|
18 |
+
|
19 |
+
import anyio
|
20 |
+
import requests
|
21 |
+
from anyio import CapacityLimiter
|
22 |
+
from typing_extensions import Literal
|
23 |
+
|
24 |
+
from gradio import (
|
25 |
+
components,
|
26 |
+
encryptor,
|
27 |
+
external,
|
28 |
+
networking,
|
29 |
+
queueing,
|
30 |
+
routes,
|
31 |
+
strings,
|
32 |
+
utils,
|
33 |
+
)
|
34 |
+
from gradio.context import Context
|
35 |
+
from gradio.deprecation import check_deprecated_parameters
|
36 |
+
from gradio.documentation import document, set_documentation_group
|
37 |
+
from gradio.exceptions import DuplicateBlockError, InvalidApiName
|
38 |
+
from gradio.helpers import create_tracker, skip, special_args
|
39 |
+
from gradio.tunneling import CURRENT_TUNNELS
|
40 |
+
from gradio.utils import (
|
41 |
+
TupleNoPrint,
|
42 |
+
check_function_inputs_match,
|
43 |
+
component_or_layout_class,
|
44 |
+
delete_none,
|
45 |
+
get_cancel_function,
|
46 |
+
get_continuous_fn,
|
47 |
+
)
|
48 |
+
|
49 |
+
set_documentation_group("blocks")
|
50 |
+
|
51 |
+
|
52 |
+
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
53 |
+
import comet_ml
|
54 |
+
from fastapi.applications import FastAPI
|
55 |
+
|
56 |
+
from gradio.components import Component
|
57 |
+
|
58 |
+
|
59 |
+
class Block:
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
*,
|
63 |
+
render: bool = True,
|
64 |
+
elem_id: str | None = None,
|
65 |
+
visible: bool = True,
|
66 |
+
root_url: str | None = None, # URL that is prepended to all file paths
|
67 |
+
_skip_init_processing: bool = False, # Used for loading from Spaces
|
68 |
+
**kwargs,
|
69 |
+
):
|
70 |
+
self._id = Context.id
|
71 |
+
Context.id += 1
|
72 |
+
self.visible = visible
|
73 |
+
self.elem_id = elem_id
|
74 |
+
self.root_url = root_url
|
75 |
+
self._skip_init_processing = _skip_init_processing
|
76 |
+
self._style = {}
|
77 |
+
self.parent: BlockContext | None = None
|
78 |
+
|
79 |
+
if render:
|
80 |
+
self.render()
|
81 |
+
check_deprecated_parameters(self.__class__.__name__, **kwargs)
|
82 |
+
|
83 |
+
def render(self):
|
84 |
+
"""
|
85 |
+
Adds self into appropriate BlockContext
|
86 |
+
"""
|
87 |
+
if Context.root_block is not None and self._id in Context.root_block.blocks:
|
88 |
+
raise DuplicateBlockError(
|
89 |
+
f"A block with id: {self._id} has already been rendered in the current Blocks."
|
90 |
+
)
|
91 |
+
if Context.block is not None:
|
92 |
+
Context.block.add(self)
|
93 |
+
if Context.root_block is not None:
|
94 |
+
Context.root_block.blocks[self._id] = self
|
95 |
+
if isinstance(self, components.TempFileManager):
|
96 |
+
Context.root_block.temp_file_sets.append(self.temp_files)
|
97 |
+
return self
|
98 |
+
|
99 |
+
def unrender(self):
|
100 |
+
"""
|
101 |
+
Removes self from BlockContext if it has been rendered (otherwise does nothing).
|
102 |
+
Removes self from the layout and collection of blocks, but does not delete any event triggers.
|
103 |
+
"""
|
104 |
+
if Context.block is not None:
|
105 |
+
try:
|
106 |
+
Context.block.children.remove(self)
|
107 |
+
except ValueError:
|
108 |
+
pass
|
109 |
+
if Context.root_block is not None:
|
110 |
+
try:
|
111 |
+
del Context.root_block.blocks[self._id]
|
112 |
+
except KeyError:
|
113 |
+
pass
|
114 |
+
return self
|
115 |
+
|
116 |
+
def get_block_name(self) -> str:
|
117 |
+
"""
|
118 |
+
Gets block's class name.
|
119 |
+
|
120 |
+
If it is template component it gets the parent's class name.
|
121 |
+
|
122 |
+
@return: class name
|
123 |
+
"""
|
124 |
+
return (
|
125 |
+
self.__class__.__base__.__name__.lower()
|
126 |
+
if hasattr(self, "is_template")
|
127 |
+
else self.__class__.__name__.lower()
|
128 |
+
)
|
129 |
+
|
130 |
+
def get_expected_parent(self) -> Type[BlockContext] | None:
|
131 |
+
return None
|
132 |
+
|
133 |
+
def set_event_trigger(
|
134 |
+
self,
|
135 |
+
event_name: str,
|
136 |
+
fn: Callable | None,
|
137 |
+
inputs: Component | List[Component] | Set[Component] | None,
|
138 |
+
outputs: Component | List[Component] | None,
|
139 |
+
preprocess: bool = True,
|
140 |
+
postprocess: bool = True,
|
141 |
+
scroll_to_output: bool = False,
|
142 |
+
show_progress: bool = True,
|
143 |
+
api_name: str | None = None,
|
144 |
+
js: str | None = None,
|
145 |
+
no_target: bool = False,
|
146 |
+
queue: bool | None = None,
|
147 |
+
batch: bool = False,
|
148 |
+
max_batch_size: int = 4,
|
149 |
+
cancels: List[int] | None = None,
|
150 |
+
every: float | None = None,
|
151 |
+
) -> Dict[str, Any]:
|
152 |
+
"""
|
153 |
+
Adds an event to the component's dependencies.
|
154 |
+
Parameters:
|
155 |
+
event_name: event name
|
156 |
+
fn: Callable function
|
157 |
+
inputs: input list
|
158 |
+
outputs: output list
|
159 |
+
preprocess: whether to run the preprocess methods of components
|
160 |
+
postprocess: whether to run the postprocess methods of components
|
161 |
+
scroll_to_output: whether to scroll to output of dependency on trigger
|
162 |
+
show_progress: whether to show progress animation while running.
|
163 |
+
api_name: Defining this parameter exposes the endpoint in the api docs
|
164 |
+
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
|
165 |
+
no_target: if True, sets "targets" to [], used for Blocks "load" event
|
166 |
+
batch: whether this function takes in a batch of inputs
|
167 |
+
max_batch_size: the maximum batch size to send to the function
|
168 |
+
cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
169 |
+
Returns: None
|
170 |
+
"""
|
171 |
+
# Support for singular parameter
|
172 |
+
if isinstance(inputs, set):
|
173 |
+
inputs_as_dict = True
|
174 |
+
inputs = sorted(inputs, key=lambda x: x._id)
|
175 |
+
else:
|
176 |
+
inputs_as_dict = False
|
177 |
+
if inputs is None:
|
178 |
+
inputs = []
|
179 |
+
elif not isinstance(inputs, list):
|
180 |
+
inputs = [inputs]
|
181 |
+
|
182 |
+
if isinstance(outputs, set):
|
183 |
+
outputs = sorted(outputs, key=lambda x: x._id)
|
184 |
+
else:
|
185 |
+
if outputs is None:
|
186 |
+
outputs = []
|
187 |
+
elif not isinstance(outputs, list):
|
188 |
+
outputs = [outputs]
|
189 |
+
|
190 |
+
if fn is not None and not cancels:
|
191 |
+
check_function_inputs_match(fn, inputs, inputs_as_dict)
|
192 |
+
|
193 |
+
if Context.root_block is None:
|
194 |
+
raise AttributeError(
|
195 |
+
f"{event_name}() and other events can only be called within a Blocks context."
|
196 |
+
)
|
197 |
+
if every is not None and every <= 0:
|
198 |
+
raise ValueError("Parameter every must be positive or None")
|
199 |
+
if every and batch:
|
200 |
+
raise ValueError(
|
201 |
+
f"Cannot run {event_name} event in a batch and every {every} seconds. "
|
202 |
+
"Either batch is True or every is non-zero but not both."
|
203 |
+
)
|
204 |
+
|
205 |
+
if every and fn:
|
206 |
+
fn = get_continuous_fn(fn, every)
|
207 |
+
elif every:
|
208 |
+
raise ValueError("Cannot set a value for `every` without a `fn`.")
|
209 |
+
|
210 |
+
Context.root_block.fns.append(
|
211 |
+
BlockFunction(fn, inputs, outputs, preprocess, postprocess, inputs_as_dict)
|
212 |
+
)
|
213 |
+
if api_name is not None:
|
214 |
+
api_name_ = utils.append_unique_suffix(
|
215 |
+
api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
|
216 |
+
)
|
217 |
+
if not (api_name == api_name_):
|
218 |
+
warnings.warn(
|
219 |
+
"api_name {} already exists, using {}".format(api_name, api_name_)
|
220 |
+
)
|
221 |
+
api_name = api_name_
|
222 |
+
|
223 |
+
dependency = {
|
224 |
+
"targets": [self._id] if not no_target else [],
|
225 |
+
"trigger": event_name,
|
226 |
+
"inputs": [block._id for block in inputs],
|
227 |
+
"outputs": [block._id for block in outputs],
|
228 |
+
"backend_fn": fn is not None,
|
229 |
+
"js": js,
|
230 |
+
"queue": False if fn is None else queue,
|
231 |
+
"api_name": api_name,
|
232 |
+
"scroll_to_output": scroll_to_output,
|
233 |
+
"show_progress": show_progress,
|
234 |
+
"every": every,
|
235 |
+
"batch": batch,
|
236 |
+
"max_batch_size": max_batch_size,
|
237 |
+
"cancels": cancels or [],
|
238 |
+
}
|
239 |
+
Context.root_block.dependencies.append(dependency)
|
240 |
+
return dependency
|
241 |
+
|
242 |
+
def get_config(self):
|
243 |
+
return {
|
244 |
+
"visible": self.visible,
|
245 |
+
"elem_id": self.elem_id,
|
246 |
+
"style": self._style,
|
247 |
+
"root_url": self.root_url,
|
248 |
+
}
|
249 |
+
|
250 |
+
@staticmethod
|
251 |
+
@abstractmethod
|
252 |
+
def update(**kwargs) -> Dict:
|
253 |
+
return {}
|
254 |
+
|
255 |
+
@classmethod
|
256 |
+
def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
|
257 |
+
del generic_update["__type__"]
|
258 |
+
specific_update = cls.update(**generic_update)
|
259 |
+
return specific_update
|
260 |
+
|
261 |
+
|
262 |
+
class BlockContext(Block):
|
263 |
+
def __init__(
|
264 |
+
self,
|
265 |
+
visible: bool = True,
|
266 |
+
render: bool = True,
|
267 |
+
**kwargs,
|
268 |
+
):
|
269 |
+
"""
|
270 |
+
Parameters:
|
271 |
+
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
|
272 |
+
render: If False, this will not be included in the Blocks config file at all.
|
273 |
+
"""
|
274 |
+
self.children: List[Block] = []
|
275 |
+
super().__init__(visible=visible, render=render, **kwargs)
|
276 |
+
|
277 |
+
def __enter__(self):
|
278 |
+
self.parent = Context.block
|
279 |
+
Context.block = self
|
280 |
+
return self
|
281 |
+
|
282 |
+
def add(self, child: Block):
|
283 |
+
child.parent = self
|
284 |
+
self.children.append(child)
|
285 |
+
|
286 |
+
def fill_expected_parents(self):
|
287 |
+
children = []
|
288 |
+
pseudo_parent = None
|
289 |
+
for child in self.children:
|
290 |
+
expected_parent = child.get_expected_parent()
|
291 |
+
if not expected_parent or isinstance(self, expected_parent):
|
292 |
+
pseudo_parent = None
|
293 |
+
children.append(child)
|
294 |
+
else:
|
295 |
+
if pseudo_parent is not None and isinstance(
|
296 |
+
pseudo_parent, expected_parent
|
297 |
+
):
|
298 |
+
pseudo_parent.children.append(child)
|
299 |
+
else:
|
300 |
+
pseudo_parent = expected_parent(render=False)
|
301 |
+
children.append(pseudo_parent)
|
302 |
+
pseudo_parent.children = [child]
|
303 |
+
if Context.root_block:
|
304 |
+
Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
|
305 |
+
child.parent = pseudo_parent
|
306 |
+
self.children = children
|
307 |
+
|
308 |
+
def __exit__(self, *args):
|
309 |
+
if getattr(self, "allow_expected_parents", True):
|
310 |
+
self.fill_expected_parents()
|
311 |
+
Context.block = self.parent
|
312 |
+
|
313 |
+
def postprocess(self, y):
|
314 |
+
"""
|
315 |
+
Any postprocessing needed to be performed on a block context.
|
316 |
+
"""
|
317 |
+
return y
|
318 |
+
|
319 |
+
|
320 |
+
class BlockFunction:
|
321 |
+
def __init__(
|
322 |
+
self,
|
323 |
+
fn: Callable | None,
|
324 |
+
inputs: List[Component],
|
325 |
+
outputs: List[Component],
|
326 |
+
preprocess: bool,
|
327 |
+
postprocess: bool,
|
328 |
+
inputs_as_dict: bool,
|
329 |
+
):
|
330 |
+
self.fn = fn
|
331 |
+
self.inputs = inputs
|
332 |
+
self.outputs = outputs
|
333 |
+
self.preprocess = preprocess
|
334 |
+
self.postprocess = postprocess
|
335 |
+
self.total_runtime = 0
|
336 |
+
self.total_runs = 0
|
337 |
+
self.inputs_as_dict = inputs_as_dict
|
338 |
+
|
339 |
+
def __str__(self):
|
340 |
+
return str(
|
341 |
+
{
|
342 |
+
"fn": getattr(self.fn, "__name__", "fn")
|
343 |
+
if self.fn is not None
|
344 |
+
else None,
|
345 |
+
"preprocess": self.preprocess,
|
346 |
+
"postprocess": self.postprocess,
|
347 |
+
}
|
348 |
+
)
|
349 |
+
|
350 |
+
def __repr__(self):
|
351 |
+
return str(self)
|
352 |
+
|
353 |
+
|
354 |
+
class class_or_instancemethod(classmethod):
|
355 |
+
def __get__(self, instance, type_):
|
356 |
+
descr_get = super().__get__ if instance is None else self.__func__.__get__
|
357 |
+
return descr_get(instance, type_)
|
358 |
+
|
359 |
+
|
360 |
+
def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
|
361 |
+
"""
|
362 |
+
Converts a dictionary of updates into a format that can be sent to the frontend.
|
363 |
+
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
|
364 |
+
Into -> {"__type__": "update", "value": 2.0, "mode": "static"}
|
365 |
+
|
366 |
+
Parameters:
|
367 |
+
block: The Block that is being updated with this update dictionary.
|
368 |
+
update_dict: The original update dictionary
|
369 |
+
postprocess: Whether to postprocess the "value" key of the update dictionary.
|
370 |
+
"""
|
371 |
+
if update_dict.get("__type__", "") == "generic_update":
|
372 |
+
update_dict = block.get_specific_update(update_dict)
|
373 |
+
if update_dict.get("value") is components._Keywords.NO_VALUE:
|
374 |
+
update_dict.pop("value")
|
375 |
+
prediction_value = delete_none(update_dict, skip_value=True)
|
376 |
+
if "value" in prediction_value and postprocess:
|
377 |
+
assert isinstance(
|
378 |
+
block, components.IOComponent
|
379 |
+
), f"Component {block.__class__} does not support value"
|
380 |
+
prediction_value["value"] = block.postprocess(prediction_value["value"])
|
381 |
+
return prediction_value
|
382 |
+
|
383 |
+
|
384 |
+
def convert_component_dict_to_list(
|
385 |
+
outputs_ids: List[int], predictions: Dict
|
386 |
+
) -> List | Dict:
|
387 |
+
"""
|
388 |
+
Converts a dictionary of component updates into a list of updates in the order of
|
389 |
+
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
|
390 |
+
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
|
391 |
+
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
|
392 |
+
"""
|
393 |
+
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
|
394 |
+
if all(keys_are_blocks):
|
395 |
+
reordered_predictions = [skip() for _ in outputs_ids]
|
396 |
+
for component, value in predictions.items():
|
397 |
+
if component._id not in outputs_ids:
|
398 |
+
raise ValueError(
|
399 |
+
f"Returned component {component} not specified as output of function."
|
400 |
+
)
|
401 |
+
output_index = outputs_ids.index(component._id)
|
402 |
+
reordered_predictions[output_index] = value
|
403 |
+
predictions = utils.resolve_singleton(reordered_predictions)
|
404 |
+
elif any(keys_are_blocks):
|
405 |
+
raise ValueError(
|
406 |
+
"Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
|
407 |
+
)
|
408 |
+
return predictions
|
409 |
+
|
410 |
+
|
411 |
+
@document("load")
|
412 |
+
class Blocks(BlockContext):
|
413 |
+
"""
|
414 |
+
Blocks is Gradio's low-level API that allows you to create more custom web
|
415 |
+
applications and demos than Interfaces (yet still entirely in Python).
|
416 |
+
|
417 |
+
|
418 |
+
Compared to the Interface class, Blocks offers more flexibility and control over:
|
419 |
+
(1) the layout of components (2) the events that
|
420 |
+
trigger the execution of functions (3) data flows (e.g. inputs can trigger outputs,
|
421 |
+
which can trigger the next level of outputs). Blocks also offers ways to group
|
422 |
+
together related demos such as with tabs.
|
423 |
+
|
424 |
+
|
425 |
+
The basic usage of Blocks is as follows: create a Blocks object, then use it as a
|
426 |
+
context (with the "with" statement), and then define layouts, components, or events
|
427 |
+
within the Blocks context. Finally, call the launch() method to launch the demo.
|
428 |
+
|
429 |
+
Example:
|
430 |
+
import gradio as gr
|
431 |
+
def update(name):
|
432 |
+
return f"Welcome to Gradio, {name}!"
|
433 |
+
|
434 |
+
with gr.Blocks() as demo:
|
435 |
+
gr.Markdown("Start typing below and then click **Run** to see the output.")
|
436 |
+
with gr.Row():
|
437 |
+
inp = gr.Textbox(placeholder="What is your name?")
|
438 |
+
out = gr.Textbox()
|
439 |
+
btn = gr.Button("Run")
|
440 |
+
btn.click(fn=update, inputs=inp, outputs=out)
|
441 |
+
|
442 |
+
demo.launch()
|
443 |
+
Demos: blocks_hello, blocks_flipper, blocks_speech_text_sentiment, generate_english_german, sound_alert
|
444 |
+
Guides: blocks_and_event_listeners, controlling_layout, state_in_blocks, custom_CSS_and_JS, custom_interpretations_with_blocks, using_blocks_like_functions
|
445 |
+
"""
|
446 |
+
|
447 |
+
def __init__(
|
448 |
+
self,
|
449 |
+
theme: str = "default",
|
450 |
+
analytics_enabled: bool | None = None,
|
451 |
+
mode: str = "blocks",
|
452 |
+
title: str = "Gradio",
|
453 |
+
css: str | None = None,
|
454 |
+
**kwargs,
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Parameters:
|
458 |
+
theme: which theme to use - right now, only "default" is supported.
|
459 |
+
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
460 |
+
mode: a human-friendly name for the kind of Blocks or Interface being created.
|
461 |
+
title: The tab title to display when this is opened in a browser window.
|
462 |
+
css: custom css or path to custom css file to apply to entire Blocks
|
463 |
+
"""
|
464 |
+
# Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
|
465 |
+
self.limiter = None
|
466 |
+
self.save_to = None
|
467 |
+
self.theme = theme
|
468 |
+
self.encrypt = False
|
469 |
+
self.share = False
|
470 |
+
self.enable_queue = None
|
471 |
+
self.max_threads = 40
|
472 |
+
self.show_error = True
|
473 |
+
if css is not None and Path(css).exists():
|
474 |
+
with open(css) as css_file:
|
475 |
+
self.css = css_file.read()
|
476 |
+
else:
|
477 |
+
self.css = css
|
478 |
+
|
479 |
+
# For analytics_enabled and allow_flagging: (1) first check for
|
480 |
+
# parameter, (2) check for env variable, (3) default to True/"manual"
|
481 |
+
self.analytics_enabled = (
|
482 |
+
analytics_enabled
|
483 |
+
if analytics_enabled is not None
|
484 |
+
else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
|
485 |
+
)
|
486 |
+
|
487 |
+
super().__init__(render=False, **kwargs)
|
488 |
+
self.blocks: Dict[int, Block] = {}
|
489 |
+
self.fns: List[BlockFunction] = []
|
490 |
+
self.dependencies = []
|
491 |
+
self.mode = mode
|
492 |
+
|
493 |
+
self.is_running = False
|
494 |
+
self.local_url = None
|
495 |
+
self.share_url = None
|
496 |
+
self.width = None
|
497 |
+
self.height = None
|
498 |
+
self.api_open = True
|
499 |
+
|
500 |
+
self.ip_address = ""
|
501 |
+
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
|
502 |
+
self.favicon_path = None
|
503 |
+
self.auth = None
|
504 |
+
self.dev_mode = True
|
505 |
+
self.app_id = random.getrandbits(64)
|
506 |
+
self.temp_file_sets = []
|
507 |
+
self.title = title
|
508 |
+
self.show_api = True
|
509 |
+
|
510 |
+
# Only used when an Interface is loaded from a config
|
511 |
+
self.predict = None
|
512 |
+
self.input_components = None
|
513 |
+
self.output_components = None
|
514 |
+
self.__name__ = None
|
515 |
+
self.api_mode = None
|
516 |
+
|
517 |
+
if self.analytics_enabled:
|
518 |
+
self.ip_address = utils.get_local_ip_address()
|
519 |
+
data = {
|
520 |
+
"mode": self.mode,
|
521 |
+
"ip_address": self.ip_address,
|
522 |
+
"custom_css": self.css is not None,
|
523 |
+
"theme": self.theme,
|
524 |
+
"version": (pkgutil.get_data(__name__, "version.txt") or b"")
|
525 |
+
.decode("ascii")
|
526 |
+
.strip(),
|
527 |
+
}
|
528 |
+
utils.initiated_analytics(data)
|
529 |
+
|
530 |
+
@classmethod
|
531 |
+
def from_config(
|
532 |
+
cls, config: dict, fns: List[Callable], root_url: str | None = None
|
533 |
+
) -> Blocks:
|
534 |
+
"""
|
535 |
+
Factory method that creates a Blocks from a config and list of functions.
|
536 |
+
|
537 |
+
Parameters:
|
538 |
+
config: a dictionary containing the configuration of the Blocks.
|
539 |
+
fns: a list of functions that are used in the Blocks. Must be in the same order as the dependencies in the config.
|
540 |
+
root_url: an optional root url to use for the components in the Blocks. Allows serving files from an external URL.
|
541 |
+
"""
|
542 |
+
config = copy.deepcopy(config)
|
543 |
+
components_config = config["components"]
|
544 |
+
original_mapping: Dict[int, Block] = {}
|
545 |
+
|
546 |
+
def get_block_instance(id: int) -> Block:
|
547 |
+
for block_config in components_config:
|
548 |
+
if block_config["id"] == id:
|
549 |
+
break
|
550 |
+
else:
|
551 |
+
raise ValueError("Cannot find block with id {}".format(id))
|
552 |
+
cls = component_or_layout_class(block_config["type"])
|
553 |
+
block_config["props"].pop("type", None)
|
554 |
+
block_config["props"].pop("name", None)
|
555 |
+
style = block_config["props"].pop("style", None)
|
556 |
+
if block_config["props"].get("root_url") is None and root_url:
|
557 |
+
block_config["props"]["root_url"] = root_url + "/"
|
558 |
+
# Any component has already processed its initial value, so we skip that step here
|
559 |
+
block = cls(**block_config["props"], _skip_init_processing=True)
|
560 |
+
if style and isinstance(block, components.IOComponent):
|
561 |
+
block.style(**style)
|
562 |
+
return block
|
563 |
+
|
564 |
+
def iterate_over_children(children_list):
|
565 |
+
for child_config in children_list:
|
566 |
+
id = child_config["id"]
|
567 |
+
block = get_block_instance(id)
|
568 |
+
|
569 |
+
original_mapping[id] = block
|
570 |
+
|
571 |
+
children = child_config.get("children")
|
572 |
+
if children is not None:
|
573 |
+
assert isinstance(
|
574 |
+
block, BlockContext
|
575 |
+
), f"Invalid config, Block with id {id} has children but is not a BlockContext."
|
576 |
+
with block:
|
577 |
+
iterate_over_children(children)
|
578 |
+
|
579 |
+
with Blocks(theme=config["theme"], css=config["theme"]) as blocks:
|
580 |
+
# ID 0 should be the root Blocks component
|
581 |
+
original_mapping[0] = Context.root_block or blocks
|
582 |
+
|
583 |
+
iterate_over_children(config["layout"]["children"])
|
584 |
+
|
585 |
+
first_dependency = None
|
586 |
+
|
587 |
+
# add the event triggers
|
588 |
+
for dependency, fn in zip(config["dependencies"], fns):
|
589 |
+
# We used to add a "fake_event" to the config to cache examples
|
590 |
+
# without removing it. This was causing bugs in calling gr.Interface.load
|
591 |
+
# We fixed the issue by removing "fake_event" from the config in examples.py
|
592 |
+
# but we still need to skip these events when loading the config to support
|
593 |
+
# older demos
|
594 |
+
if dependency["trigger"] == "fake_event":
|
595 |
+
continue
|
596 |
+
targets = dependency.pop("targets")
|
597 |
+
trigger = dependency.pop("trigger")
|
598 |
+
dependency.pop("backend_fn")
|
599 |
+
dependency.pop("documentation", None)
|
600 |
+
dependency["inputs"] = [
|
601 |
+
original_mapping[i] for i in dependency["inputs"]
|
602 |
+
]
|
603 |
+
dependency["outputs"] = [
|
604 |
+
original_mapping[o] for o in dependency["outputs"]
|
605 |
+
]
|
606 |
+
dependency.pop("status_tracker", None)
|
607 |
+
dependency["preprocess"] = False
|
608 |
+
dependency["postprocess"] = False
|
609 |
+
|
610 |
+
for target in targets:
|
611 |
+
dependency = original_mapping[target].set_event_trigger(
|
612 |
+
event_name=trigger, fn=fn, **dependency
|
613 |
+
)
|
614 |
+
if first_dependency is None:
|
615 |
+
first_dependency = dependency
|
616 |
+
|
617 |
+
# Allows some use of Interface-specific methods with loaded Spaces
|
618 |
+
if first_dependency and Context.root_block:
|
619 |
+
blocks.predict = [fns[0]]
|
620 |
+
blocks.input_components = [
|
621 |
+
Context.root_block.blocks[i] for i in first_dependency["inputs"]
|
622 |
+
]
|
623 |
+
blocks.output_components = [
|
624 |
+
Context.root_block.blocks[o] for o in first_dependency["outputs"]
|
625 |
+
]
|
626 |
+
blocks.__name__ = "Interface"
|
627 |
+
blocks.api_mode = True
|
628 |
+
|
629 |
+
return blocks
|
630 |
+
|
631 |
+
def __str__(self):
|
632 |
+
return self.__repr__()
|
633 |
+
|
634 |
+
def __repr__(self):
|
635 |
+
num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
|
636 |
+
repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
|
637 |
+
repr += "\n" + "-" * len(repr)
|
638 |
+
for d, dependency in enumerate(self.dependencies):
|
639 |
+
if dependency["backend_fn"]:
|
640 |
+
repr += f"\nfn_index={d}"
|
641 |
+
repr += "\n inputs:"
|
642 |
+
for input_id in dependency["inputs"]:
|
643 |
+
block = self.blocks[input_id]
|
644 |
+
repr += "\n |-{}".format(str(block))
|
645 |
+
repr += "\n outputs:"
|
646 |
+
for output_id in dependency["outputs"]:
|
647 |
+
block = self.blocks[output_id]
|
648 |
+
repr += "\n |-{}".format(str(block))
|
649 |
+
return repr
|
650 |
+
|
651 |
+
    def render(self):
        """
        Renders this Blocks into the Blocks currently being built
        (``Context.root_block``), merging its blocks, functions and dependencies.

        Raises:
            DuplicateBlockError: if this Blocks (or any of its child blocks) has
                already been rendered into the current root Blocks.
        """
        if Context.root_block is not None:
            if self._id in Context.root_block.blocks:
                raise DuplicateBlockError(
                    f"A block with id: {self._id} has already been rendered in the current Blocks."
                )
            if not set(Context.root_block.blocks).isdisjoint(self.blocks):
                raise DuplicateBlockError(
                    "At least one block in this Blocks has already been rendered."
                )

            Context.root_block.blocks.update(self.blocks)
            Context.root_block.fns.extend(self.fns)
            # Dependencies from this Blocks are appended after the root's existing
            # ones, so indices referring into them must be shifted by this offset.
            dependency_offset = len(Context.root_block.dependencies)
            for i, dependency in enumerate(self.dependencies):
                api_name = dependency["api_name"]
                if api_name is not None:
                    # De-duplicate api_name against those already registered on the root
                    api_name_ = utils.append_unique_suffix(
                        api_name,
                        [dep["api_name"] for dep in Context.root_block.dependencies],
                    )
                    if not (api_name == api_name_):
                        warnings.warn(
                            "api_name {} already exists, using {}".format(
                                api_name, api_name_
                            )
                        )
                        dependency["api_name"] = api_name_
                dependency["cancels"] = [
                    c + dependency_offset for c in dependency["cancels"]
                ]
                # Recreate the cancel function so that it has the latest
                # dependency fn indices. This is necessary to properly cancel
                # events in the backend
                if dependency["cancels"]:
                    updated_cancels = [
                        Context.root_block.dependencies[i]
                        for i in dependency["cancels"]
                    ]
                    new_fn = BlockFunction(
                        get_cancel_function(updated_cancels)[0],
                        [],
                        [],
                        False,
                        True,
                        False,
                    )
                    Context.root_block.fns[dependency_offset + i] = new_fn
                Context.root_block.dependencies.append(dependency)
            Context.root_block.temp_file_sets.extend(self.temp_file_sets)

        if Context.block is not None:
            Context.block.children.extend(self.children)
        return self
|
705 |
+
|
706 |
+
def is_callable(self, fn_index: int = 0) -> bool:
|
707 |
+
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
|
708 |
+
block_fn = self.fns[fn_index]
|
709 |
+
dependency = self.dependencies[fn_index]
|
710 |
+
|
711 |
+
if inspect.isasyncgenfunction(block_fn.fn):
|
712 |
+
return False
|
713 |
+
if inspect.isgeneratorfunction(block_fn.fn):
|
714 |
+
return False
|
715 |
+
for input_id in dependency["inputs"]:
|
716 |
+
block = self.blocks[input_id]
|
717 |
+
if getattr(block, "stateful", False):
|
718 |
+
return False
|
719 |
+
for output_id in dependency["outputs"]:
|
720 |
+
block = self.blocks[output_id]
|
721 |
+
if getattr(block, "stateful", False):
|
722 |
+
return False
|
723 |
+
|
724 |
+
return True
|
725 |
+
|
726 |
+
def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
|
727 |
+
"""
|
728 |
+
Allows Blocks objects to be called as functions. Supply the parameters to the
|
729 |
+
function as positional arguments. To choose which function to call, use the
|
730 |
+
fn_index parameter, which must be a keyword argument.
|
731 |
+
|
732 |
+
Parameters:
|
733 |
+
*inputs: the parameters to pass to the function
|
734 |
+
fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
|
735 |
+
api_name: The api_name of the dependency to call. Will take precedence over fn_index.
|
736 |
+
"""
|
737 |
+
if api_name is not None:
|
738 |
+
inferred_fn_index = next(
|
739 |
+
(
|
740 |
+
i
|
741 |
+
for i, d in enumerate(self.dependencies)
|
742 |
+
if d.get("api_name") == api_name
|
743 |
+
),
|
744 |
+
None,
|
745 |
+
)
|
746 |
+
if inferred_fn_index is None:
|
747 |
+
raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
|
748 |
+
fn_index = inferred_fn_index
|
749 |
+
if not (self.is_callable(fn_index)):
|
750 |
+
raise ValueError(
|
751 |
+
"This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
|
752 |
+
)
|
753 |
+
|
754 |
+
inputs = list(inputs)
|
755 |
+
processed_inputs = self.serialize_data(fn_index, inputs)
|
756 |
+
batch = self.dependencies[fn_index]["batch"]
|
757 |
+
if batch:
|
758 |
+
processed_inputs = [[inp] for inp in processed_inputs]
|
759 |
+
|
760 |
+
outputs = utils.synchronize_async(
|
761 |
+
self.process_api,
|
762 |
+
fn_index=fn_index,
|
763 |
+
inputs=processed_inputs,
|
764 |
+
request=None,
|
765 |
+
state={},
|
766 |
+
)
|
767 |
+
outputs = outputs["data"]
|
768 |
+
|
769 |
+
if batch:
|
770 |
+
outputs = [out[0] for out in outputs]
|
771 |
+
|
772 |
+
processed_outputs = self.deserialize_data(fn_index, outputs)
|
773 |
+
processed_outputs = utils.resolve_singleton(processed_outputs)
|
774 |
+
|
775 |
+
return processed_outputs
|
776 |
+
|
777 |
+
    async def call_function(
        self,
        fn_index: int,
        processed_input: List[Any],
        iterator: Iterator[Any] | None = None,
        requests: routes.Request | List[routes.Request] | None = None,
        event_id: str | None = None,
    ):
        """
        Calls function with given index and preprocessed input, and measures process time.
        Parameters:
            fn_index: index of function to call
            processed_input: preprocessed input to pass to function
            iterator: iterator to use if function is a generator
            requests: requests to pass to function
            event_id: id of event in queue
        Returns:
            dict with keys "prediction", "duration", "is_generating", "iterator"
        """
        block_fn = self.fns[fn_index]
        assert block_fn.fn, f"function with index {fn_index} not defined."
        is_generating = False

        if block_fn.inputs_as_dict:
            # Pack all inputs into a single {component: value} dict argument
            processed_input = [
                {
                    input_component: data
                    for input_component, data in zip(block_fn.inputs, processed_input)
                }
            ]

        # A list of requests can arrive (batch mode); the fn only receives the first
        if isinstance(requests, list):
            request = requests[0]
        else:
            request = requests
        processed_input, progress_index = special_args(
            block_fn.fn,
            processed_input,
            request,
        )
        progress_tracker = (
            processed_input[progress_index] if progress_index is not None else None
        )

        start = time.time()

        if iterator is None:  # If not a generator function that has already run
            if progress_tracker is not None and progress_index is not None:
                # Wrap the fn so progress updates (and optionally tqdm) are tracked
                progress_tracker, fn = create_tracker(
                    self, event_id, block_fn.fn, progress_tracker.track_tqdm
                )
                processed_input[progress_index] = progress_tracker
            else:
                fn = block_fn.fn

            if inspect.iscoroutinefunction(fn):
                prediction = await fn(*processed_input)
            else:
                # Run sync functions in a worker thread so the event loop isn't blocked
                prediction = await anyio.to_thread.run_sync(
                    fn, *processed_input, limiter=self.limiter
                )
        else:
            prediction = None

        if inspect.isasyncgenfunction(block_fn.fn):
            raise ValueError("Gradio does not support async generators.")
        if inspect.isgeneratorfunction(block_fn.fn):
            if not self.enable_queue:
                raise ValueError("Need to enable queue to use generators.")
            try:
                if iterator is None:
                    # First call: the "prediction" is actually the generator object
                    iterator = prediction
                # Pull the next yielded value off the generator in a worker thread
                prediction = await anyio.to_thread.run_sync(
                    utils.async_iteration, iterator, limiter=self.limiter
                )
                is_generating = True
            except StopAsyncIteration:
                # Generator exhausted: emit FINISHED_ITERATING sentinel(s) and drop the iterator
                n_outputs = len(self.dependencies[fn_index].get("outputs"))
                prediction = (
                    components._Keywords.FINISHED_ITERATING
                    if n_outputs == 1
                    else (components._Keywords.FINISHED_ITERATING,) * n_outputs
                )
                iterator = None

        duration = time.time() - start

        return {
            "prediction": prediction,
            "duration": duration,
            "is_generating": is_generating,
            "iterator": iterator,
        }
|
868 |
+
|
869 |
+
def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
|
870 |
+
dependency = self.dependencies[fn_index]
|
871 |
+
processed_input = []
|
872 |
+
|
873 |
+
for i, input_id in enumerate(dependency["inputs"]):
|
874 |
+
block = self.blocks[input_id]
|
875 |
+
assert isinstance(
|
876 |
+
block, components.IOComponent
|
877 |
+
), f"{block.__class__} Component with id {input_id} not a valid input component."
|
878 |
+
serialized_input = block.serialize(inputs[i])
|
879 |
+
processed_input.append(serialized_input)
|
880 |
+
|
881 |
+
return processed_input
|
882 |
+
|
883 |
+
def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
|
884 |
+
dependency = self.dependencies[fn_index]
|
885 |
+
predictions = []
|
886 |
+
|
887 |
+
for o, output_id in enumerate(dependency["outputs"]):
|
888 |
+
block = self.blocks[output_id]
|
889 |
+
assert isinstance(
|
890 |
+
block, components.IOComponent
|
891 |
+
), f"{block.__class__} Component with id {output_id} not a valid output component."
|
892 |
+
deserialized = block.deserialize(outputs[o])
|
893 |
+
predictions.append(deserialized)
|
894 |
+
|
895 |
+
return predictions
|
896 |
+
|
897 |
+
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
|
898 |
+
block_fn = self.fns[fn_index]
|
899 |
+
dependency = self.dependencies[fn_index]
|
900 |
+
|
901 |
+
if block_fn.preprocess:
|
902 |
+
processed_input = []
|
903 |
+
for i, input_id in enumerate(dependency["inputs"]):
|
904 |
+
block = self.blocks[input_id]
|
905 |
+
assert isinstance(
|
906 |
+
block, components.Component
|
907 |
+
), f"{block.__class__} Component with id {input_id} not a valid input component."
|
908 |
+
if getattr(block, "stateful", False):
|
909 |
+
processed_input.append(state.get(input_id))
|
910 |
+
else:
|
911 |
+
processed_input.append(block.preprocess(inputs[i]))
|
912 |
+
else:
|
913 |
+
processed_input = inputs
|
914 |
+
return processed_input
|
915 |
+
|
916 |
+
    def postprocess_data(
        self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
    ):
        """
        Runs each output component's postprocess step on the function's predictions,
        writing stateful outputs into ``state`` and expanding component-update dicts.
        Returns the list of frontend-ready values (None placeholders for stateful /
        finished-generator outputs).
        """
        block_fn = self.fns[fn_index]
        dependency = self.dependencies[fn_index]
        batch = dependency["batch"]

        if type(predictions) is dict and len(predictions) > 0:
            # A dict return maps components -> values; flatten to match outputs order
            predictions = convert_component_dict_to_list(
                dependency["outputs"], predictions
            )

        if len(dependency["outputs"]) == 1 and not (batch):
            # Single output: wrap so the loop below can index uniformly
            predictions = [
                predictions,
            ]

        output = []
        for i, output_id in enumerate(dependency["outputs"]):
            if predictions[i] is components._Keywords.FINISHED_ITERATING:
                # Generator finished for this output; emit a placeholder
                output.append(None)
                continue
            block = self.blocks[output_id]
            if getattr(block, "stateful", False):
                # Stateful outputs are stored in session state, not sent to the frontend
                if not utils.is_update(predictions[i]):
                    state[output_id] = predictions[i]
                output.append(None)
            else:
                prediction_value = predictions[i]
                if utils.is_update(prediction_value):
                    # gr.update(...) dicts are expanded into component config changes
                    assert isinstance(prediction_value, dict)
                    prediction_value = postprocess_update_dict(
                        block=block,
                        update_dict=prediction_value,
                        postprocess=block_fn.postprocess,
                    )
                elif block_fn.postprocess:
                    assert isinstance(
                        block, components.Component
                    ), f"{block.__class__} Component with id {output_id} not a valid output component."
                    prediction_value = block.postprocess(prediction_value)
                output.append(prediction_value)
        return output
|
959 |
+
|
960 |
+
    async def process_api(
        self,
        fn_index: int,
        inputs: List[Any],
        state: Dict[int, Any],
        request: routes.Request | List[routes.Request] | None = None,
        iterators: Dict[int, Any] | None = None,
        event_id: str | None = None,
    ) -> Dict[str, Any]:
        """
        Processes API calls from the frontend. First preprocesses the data,
        then runs the relevant function, then postprocesses the output.
        Parameters:
            fn_index: Index of function to run.
            inputs: input data received from the frontend
            state: data stored from stateful components for session (key is input block id)
            request: the request(s) that triggered this call, if any
            iterators: the in-progress iterators for each generator function (key is function index)
            event_id: id of event in the queue, if any
        Returns:
            dict with keys "data", "is_generating", "iterator", "duration", "average_duration"
        """
        block_fn = self.fns[fn_index]
        batch = self.dependencies[fn_index]["batch"]

        if batch:
            max_batch_size = self.dependencies[fn_index]["max_batch_size"]
            batch_sizes = [len(inp) for inp in inputs]
            batch_size = batch_sizes[0]
            if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
                block_fn.fn
            ):
                raise ValueError("Gradio does not support generators in batch mode.")
            if not all(x == batch_size for x in batch_sizes):
                raise ValueError(
                    f"All inputs to a batch function must have the same length but instead have sizes: {batch_sizes}."
                )
            if batch_size > max_batch_size:
                raise ValueError(
                    f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
                )

            # Transpose [component][sample] -> per-sample rows for preprocessing
            inputs = [
                self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
            ]
            result = await self.call_function(
                fn_index, list(zip(*inputs)), None, request
            )
            preds = result["prediction"]
            data = [
                self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
            ]
            # Transpose back to [component][sample] for the frontend
            data = list(zip(*data))
            is_generating, iterator = None, None
        else:
            inputs = self.preprocess_data(fn_index, inputs, state)
            # Resume an in-progress generator for this fn, if one exists
            iterator = iterators.get(fn_index, None) if iterators else None
            result = await self.call_function(
                fn_index, inputs, iterator, request, event_id
            )
            data = self.postprocess_data(fn_index, result["prediction"], state)
            is_generating, iterator = result["is_generating"], result["iterator"]

        # Running totals feed the average_duration stat below
        block_fn.total_runtime += result["duration"]
        block_fn.total_runs += 1

        return {
            "data": data,
            "is_generating": is_generating,
            "iterator": iterator,
            "duration": result["duration"],
            "average_duration": block_fn.total_runtime / block_fn.total_runs,
        }
|
1031 |
+
|
1032 |
+
async def create_limiter(self):
|
1033 |
+
self.limiter = (
|
1034 |
+
None
|
1035 |
+
if self.max_threads == 40
|
1036 |
+
else CapacityLimiter(total_tokens=self.max_threads)
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
def get_config(self):
|
1040 |
+
return {"type": "column"}
|
1041 |
+
|
1042 |
+
def get_config_file(self):
|
1043 |
+
config = {
|
1044 |
+
"version": routes.VERSION,
|
1045 |
+
"mode": self.mode,
|
1046 |
+
"dev_mode": self.dev_mode,
|
1047 |
+
"components": [],
|
1048 |
+
"theme": self.theme,
|
1049 |
+
"css": self.css,
|
1050 |
+
"title": self.title or "Gradio",
|
1051 |
+
"is_space": self.is_space,
|
1052 |
+
"enable_queue": getattr(self, "enable_queue", False), # launch attributes
|
1053 |
+
"show_error": getattr(self, "show_error", False),
|
1054 |
+
"show_api": self.show_api,
|
1055 |
+
"is_colab": utils.colab_check(),
|
1056 |
+
}
|
1057 |
+
|
1058 |
+
def getLayout(block):
|
1059 |
+
if not isinstance(block, BlockContext):
|
1060 |
+
return {"id": block._id}
|
1061 |
+
children_layout = []
|
1062 |
+
for child in block.children:
|
1063 |
+
children_layout.append(getLayout(child))
|
1064 |
+
return {"id": block._id, "children": children_layout}
|
1065 |
+
|
1066 |
+
config["layout"] = getLayout(self)
|
1067 |
+
|
1068 |
+
for _id, block in self.blocks.items():
|
1069 |
+
config["components"].append(
|
1070 |
+
{
|
1071 |
+
"id": _id,
|
1072 |
+
"type": (block.get_block_name()),
|
1073 |
+
"props": utils.delete_none(block.get_config())
|
1074 |
+
if hasattr(block, "get_config")
|
1075 |
+
else {},
|
1076 |
+
}
|
1077 |
+
)
|
1078 |
+
config["dependencies"] = self.dependencies
|
1079 |
+
return config
|
1080 |
+
|
1081 |
+
def __enter__(self):
|
1082 |
+
if Context.block is None:
|
1083 |
+
Context.root_block = self
|
1084 |
+
self.parent = Context.block
|
1085 |
+
Context.block = self
|
1086 |
+
return self
|
1087 |
+
|
1088 |
+
    def __exit__(self, *args):
        """Exit the Blocks context: finalize layout, restore the previous context, and build the app."""
        super().fill_expected_parents()
        Context.block = self.parent
        # Configure the load events before root_block is reset
        self.attach_load_events()
        if self.parent is None:
            # This was the outermost Blocks: the build is complete
            Context.root_block = None
        else:
            # Nested Blocks: hand our children to the enclosing block
            self.parent.children.extend(self.children)
        self.config = self.get_config_file()
        self.app = routes.App.create_app(self)
|
1099 |
+
|
1100 |
+
    @class_or_instancemethod
    def load(
        self_or_cls,
        fn: Callable | None = None,
        inputs: List[Component] | None = None,
        outputs: List[Component] | None = None,
        api_name: str | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue=None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        every: float | None = None,
        _js: str | None = None,
        *,
        name: str | None = None,
        src: str | None = None,
        api_key: str | None = None,
        alias: str | None = None,
        **kwargs,
    ) -> Blocks | Dict[str, Any] | None:
        """
        For reverse compatibility reasons, this is both a class method and an instance
        method, the two of which, confusingly, do two completely different things.


        Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Equivalent to gradio.Interface.load()


        Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
        Parameters:
            name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
            src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
            api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
            alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
            fn: Instance Method - the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: Instance Method - List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: Instance Method - List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
            api_name: Instance Method - Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: Instance Method - If True, will scroll to output component on completion
            show_progress: Instance Method - If True, will show progress animation while pending
            queue: Instance Method - If True, will place the request on the queue, if the queue exists
            batch: Instance Method - If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Instance Method - Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: Instance Method - If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: Instance Method - If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            every: Instance Method - Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled.
        Example:
            import gradio as gr
            import datetime
            with gr.Blocks() as demo:
                def get_time():
                    return datetime.datetime.now().time()
                dt = gr.Textbox(label="Current time")
                demo.load(get_time, inputs=None, outputs=dt)
            demo.launch()
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        if isinstance(self_or_cls, type):
            # Called on the class: load a model/space from the Hub
            if name is None:
                raise ValueError(
                    "Blocks.load() requires passing parameters as keyword arguments"
                )
            return external.load_blocks_from_repo(name, src, api_key, alias, **kwargs)
        else:
            # Called on an instance: register an event fired when the page loads
            return self_or_cls.set_event_trigger(
                event_name="load",
                fn=fn,
                inputs=inputs,
                outputs=outputs,
                api_name=api_name,
                preprocess=preprocess,
                postprocess=postprocess,
                scroll_to_output=scroll_to_output,
                show_progress=show_progress,
                js=_js,
                queue=queue,
                batch=batch,
                max_batch_size=max_batch_size,
                every=every,
                no_target=True,
            )
|
1184 |
+
|
1185 |
+
def clear(self):
|
1186 |
+
"""Resets the layout of the Blocks object."""
|
1187 |
+
self.blocks = {}
|
1188 |
+
self.fns = []
|
1189 |
+
self.dependencies = []
|
1190 |
+
self.children = []
|
1191 |
+
return self
|
1192 |
+
|
1193 |
+
    @document()
    def queue(
        self,
        concurrency_count: int = 1,
        status_update_rate: float | Literal["auto"] = "auto",
        client_position_to_load_data: int | None = None,
        default_enabled: bool | None = None,
        api_open: bool = True,
        max_size: int | None = None,
    ):
        """
        You can control the rate of processed requests by creating a queue. This will allow you to set the number of requests to be processed at one time, and will let users know their position in the queue.
        Parameters:
            concurrency_count: Number of worker threads that will be processing requests from the queue concurrently. Increasing this number will increase the rate at which requests are processed, but will also increase the memory usage of the queue.
            status_update_rate: If "auto", Queue will send status estimations to all clients whenever a job is finished. Otherwise Queue will send status at regular intervals set by this parameter as the number of seconds.
            client_position_to_load_data: DEPRECATED. This parameter is deprecated and has no effect.
            default_enabled: Deprecated and has no effect.
            api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
            max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited.
        Example:
            demo = gr.Interface(gr.Textbox(), gr.Image(), image_generator)
            demo.queue(concurrency_count=3)
            demo.launch()
        """
        if default_enabled is not None:
            warnings.warn(
                "The default_enabled parameter of queue has no effect and will be removed "
                "in a future version of gradio."
            )
        self.enable_queue = True
        self.api_open = api_open
        if client_position_to_load_data is not None:
            warnings.warn("The client_position_to_load_data parameter is deprecated.")
        self._queue = queueing.Queue(
            live_updates=status_update_rate == "auto",
            concurrency_count=concurrency_count,
            update_intervals=status_update_rate if status_update_rate != "auto" else 1,
            max_size=max_size,
            blocks_dependencies=self.dependencies,
        )
        # Rebuild the config so the frontend sees the new queue settings
        self.config = self.get_config_file()
        return self
|
1235 |
+
|
1236 |
+
def launch(
    self,
    inline: bool | None = None,
    inbrowser: bool = False,
    share: bool | None = None,
    debug: bool = False,
    enable_queue: bool | None = None,
    max_threads: int = 40,
    auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
    auth_message: str | None = None,
    prevent_thread_lock: bool = False,
    show_error: bool = False,
    server_name: str | None = None,
    server_port: int | None = None,
    show_tips: bool = False,
    height: int = 500,
    width: int | str = "100%",
    encrypt: bool = False,
    favicon_path: str | None = None,
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    ssl_keyfile_password: str | None = None,
    quiet: bool = False,
    show_api: bool = True,
    _frontend: bool = True,
) -> Tuple[FastAPI, str, str]:
    """
    Launches a simple web server that serves the demo. Can also be used to create a
    public link used by anyone to access the demo from their browser by setting share=True.

    Parameters:
        inline: whether to display in the interface inline in an iframe. Defaults to True in python notebooks; False otherwise.
        inbrowser: whether to automatically launch the interface in a new tab on the default browser.
        share: whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. If not provided, it is set to False by default every time, except when running in Google Colab. When localhost is not accessible (e.g. Google Colab), setting share=False is not supported.
        debug: if True, blocks the main thread from running. If running in Google Colab, this is needed to print the errors in the cell output.
        auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
        auth_message: If provided, HTML message provided on login page.
        prevent_thread_lock: If True, the interface will block the main thread while the server is running.
        show_error: If True, any errors in the interface will be displayed in an alert modal and printed in the browser console log
        server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. If None, will search for an available port starting at 7860.
        server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
        show_tips: if True, will occasionally show tips about new Gradio features
        enable_queue: DEPRECATED (use .queue() method instead.) if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
        max_threads: the maximum number of total threads that the Gradio app can generate in parallel. The default is inherited from the starlette library (currently 40). Applies whether the queue is enabled or not. But if queuing is enabled, this parameter is increaseed to be at least the concurrency_count of the queue.
        width: The width in pixels of the iframe element containing the interface (used if inline=True)
        height: The height in pixels of the iframe element containing the interface (used if inline=True)
        encrypt: If True, flagged data will be encrypted by key provided by creator at launch
        favicon_path: If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
        quiet: If True, suppresses most print statements.
        show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
    Returns:
        app: FastAPI app object that is running the demo
        local_url: Locally accessible link to the demo
        share_url: Publicly accessible link to the demo (if share=True, otherwise None)
    Example:
        import gradio as gr
        def reverse(text):
            return text[::-1]
        demo = gr.Interface(reverse, "text", "text")
        demo.launch(share=True, auth=("username", "password"))
    """
    self.dev_mode = False
    # Normalize auth: a single (user, pass) tuple becomes a one-element list,
    # so downstream code can always treat auth as a callable or list of pairs.
    if (
        auth
        and not callable(auth)
        and not isinstance(auth[0], tuple)
        and not isinstance(auth[0], list)
    ):
        self.auth = [auth]
    else:
        self.auth = auth
    self.auth_message = auth_message
    self.show_tips = show_tips
    self.show_error = show_error
    self.height = height
    self.width = width
    self.favicon_path = favicon_path
    # True if any registered event function declares a gr.Progress argument
    # (special_args()[1] is the progress parameter's index, per its usage here).
    self.progress_tracking = any(
        block_fn.fn is not None and special_args(block_fn.fn)[1] is not None
        for block_fn in self.fns
    )

    if enable_queue is not None:
        self.enable_queue = enable_queue
        warnings.warn(
            "The `enable_queue` parameter has been deprecated. Please use the `.queue()` method instead.",
            DeprecationWarning,
        )

    # Resolve the tri-state enable_queue (None/True/False):
    # on Spaces the queue defaults to on; elsewhere it defaults to off.
    if self.is_space:
        self.enable_queue = self.enable_queue is not False
    else:
        self.enable_queue = self.enable_queue is True
    if self.enable_queue and not hasattr(self, "_queue"):
        self.queue()
    self.show_api = self.api_open if self.enable_queue else show_api

    if not self.enable_queue and self.progress_tracking:
        raise ValueError("Progress tracking requires queuing to be enabled.")

    # Validate per-dependency queue requirements (cancellation and batching
    # both need the queue) before starting the server.
    for dep in self.dependencies:
        for i in dep["cancels"]:
            if not self.queue_enabled_for_fn(i):
                raise ValueError(
                    "In order to cancel an event, the queue for that event must be enabled! "
                    "You may get this error by either 1) passing a function that uses the yield keyword "
                    "into an interface without enabling the queue or 2) defining an event that cancels "
                    "another event without enabling the queue. Both can be solved by calling .queue() "
                    "before .launch()"
                )
        if dep["batch"] and (
            dep["queue"] is False
            or (dep["queue"] is None and not self.enable_queue)
        ):
            raise ValueError("In order to use batching, the queue must be enabled.")

    self.config = self.get_config_file()
    self.encrypt = encrypt
    # Queue workers need threads too, so never go below its concurrency count.
    self.max_threads = max(
        self._queue.max_thread_count if self.enable_queue else 0, max_threads
    )
    if self.encrypt:
        # Key is derived interactively from a passphrase at launch time.
        self.encryption_key = encryptor.get_key(
            getpass.getpass("Enter key for encryption: ")
        )

    if self.is_running:
        # Server already started by a previous launch(); reuse it.
        assert isinstance(
            self.local_url, str
        ), f"Invalid local_url: {self.local_url}"
        if not (quiet):
            print(
                "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
            )
    else:
        server_name, server_port, local_url, app, server = networking.start_server(
            self,
            server_name,
            server_port,
            ssl_keyfile,
            ssl_certfile,
            ssl_keyfile_password,
        )
        self.server_name = server_name
        self.local_url = local_url
        self.server_port = server_port
        self.server_app = app
        self.server = server
        self.is_running = True
        self.is_colab = utils.colab_check()
        self.protocol = (
            "https"
            if self.local_url.startswith("https") or self.is_colab
            else "http"
        )

        if self.enable_queue:
            self._queue.set_url(self.local_url)

        # Cannot run async functions in background other than app's scope.
        # Workaround by triggering the app endpoint
        requests.get(f"{self.local_url}startup-events")

    if self.enable_queue:
        if self.encrypt:
            raise ValueError("Cannot queue with encryption enabled.")
    utils.launch_counter()

    # share defaults to True only on Colab with the queue enabled; otherwise False.
    self.share = (
        share
        if share is not None
        else True
        if self.is_colab and self.enable_queue
        else False
    )

    # If running in a colab or not able to access localhost,
    # a shareable link must be created.
    if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
        raise ValueError(
            "When localhost is not accessible, a shareable link must be created. Please set share=True."
        )

    if self.is_colab:
        if not quiet:
            if debug:
                print(strings.en["COLAB_DEBUG_TRUE"])
            else:
                print(strings.en["COLAB_DEBUG_FALSE"])
            if not self.share:
                print(strings.en["COLAB_WARNING"].format(self.server_port))
        if self.enable_queue and not self.share:
            raise ValueError(
                "When using queueing in Colab, a shareable link must be created. Please set share=True."
            )
    else:
        print(
            strings.en["RUNNING_LOCALLY_SEPARATED"].format(
                self.protocol, self.server_name, self.server_port
            )
        )

    if self.share:
        if self.is_space:
            raise RuntimeError("Share is not supported when you are in Spaces")
        try:
            if self.share_url is None:
                self.share_url = networking.setup_tunnel(
                    self.server_name, self.server_port
                )
            print(strings.en["SHARE_LINK_DISPLAY"].format(self.share_url))
            if not (quiet):
                print(strings.en["SHARE_LINK_MESSAGE"])
        except RuntimeError:
            # Tunnel setup failed: fall back to local-only serving.
            if self.analytics_enabled:
                utils.error_analytics(self.ip_address, "Not able to set up tunnel")
            self.share_url = None
            self.share = False
            print(strings.en["COULD_NOT_GET_SHARE_LINK"])
    else:
        if not (quiet):
            print(strings.en["PUBLIC_SHARE_TRUE"])
        self.share_url = None

    if inbrowser:
        link = self.share_url if self.share and self.share_url else self.local_url
        webbrowser.open(link)

    # Check if running in a Python notebook in which case, display inline
    if inline is None:
        inline = utils.ipython_check() and (self.auth is None)
    if inline:
        if self.auth is not None:
            print(
                "Warning: authentication is not supported inline. Please"
                "click the link to access the interface in a new tab."
            )
        try:
            from IPython.display import HTML, Javascript, display  # type: ignore

            if self.share and self.share_url:
                # Wait until the tunnel actually serves before embedding it.
                while not networking.url_ok(self.share_url):
                    time.sleep(0.25)
                display(
                    HTML(
                        f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
                    )
                )
            elif self.is_colab:
                # modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
                code = """(async (port, path, width, height, cache, element) => {
                    if (!google.colab.kernel.accessAllowed && !cache) {
                        return;
                    }
                    element.appendChild(document.createTextNode(''));
                    const url = await google.colab.kernel.proxyPort(port, {cache});

                    const external_link = document.createElement('div');
                    external_link.innerHTML = `
                        <div style="font-family: monospace; margin-bottom: 0.5rem">
                            Running on <a href=${new URL(path, url).toString()} target="_blank">
                                https://localhost:${port}${path}
                            </a>
                        </div>
                    `;
                    element.appendChild(external_link);

                    const iframe = document.createElement('iframe');
                    iframe.src = new URL(path, url).toString();
                    iframe.height = height;
                    iframe.allow = "autoplay; camera; microphone; clipboard-read; clipboard-write;"
                    iframe.width = width;
                    iframe.style.border = 0;
                    element.appendChild(iframe);
                })""" + "({port}, {path}, {width}, {height}, {cache}, window.element)".format(
                    port=json.dumps(self.server_port),
                    path=json.dumps("/"),
                    width=json.dumps(self.width),
                    height=json.dumps(self.height),
                    cache=json.dumps(False),
                )

                display(Javascript(code))
            else:
                display(
                    HTML(
                        f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
                    )
                )
        except ImportError:
            # Not running inside IPython/Jupyter: nothing to display inline.
            pass

    if getattr(self, "analytics_enabled", False):
        data = {
            "launch_method": "browser" if inbrowser else "inline",
            "is_google_colab": self.is_colab,
            "is_sharing_on": self.share,
            "share_url": self.share_url,
            "ip_address": self.ip_address,
            "enable_queue": self.enable_queue,
            "show_tips": self.show_tips,
            "server_name": server_name,
            "server_port": server_port,
            "is_spaces": self.is_space,
            "mode": self.mode,
        }
        utils.launch_analytics(data)

    utils.show_tip(self)

    # Block main thread if debug==True
    if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
        self.block_thread()
    # Block main thread if running in a script to stop script from exiting
    is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))

    if not prevent_thread_lock and not is_in_interactive_mode:
        self.block_thread()

    return TupleNoPrint((self.server_app, self.local_url, self.share_url))
|
1559 |
+
|
1560 |
+
def integrate(
    self,
    comet_ml: comet_ml.Experiment | None = None,
    wandb: ModuleType | None = None,
    mlflow: ModuleType | None = None,
) -> None:
    """
    A catch-all method for integrating with other libraries. This method should be run after launch()
    Parameters:
        comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
        wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
        mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
    """
    # Records which integration was used (last one wins) for analytics below.
    analytics_integration = ""
    if comet_ml is not None:
        analytics_integration = "CometML"
        comet_ml.log_other("Created from", "Gradio")
        # Prefer the public share URL; fall back to the local one.
        if self.share_url is not None:
            comet_ml.log_text("gradio: " + self.share_url)
            comet_ml.end()
        elif self.local_url:
            comet_ml.log_text("gradio: " + self.local_url)
            comet_ml.end()
        else:
            # No URL at all means launch() was never called.
            raise ValueError("Please run `launch()` first.")
    if wandb is not None:
        analytics_integration = "WandB"
        if self.share_url is not None:
            # Embed the shared demo as an iframe panel on the WandB run page.
            wandb.log(
                {
                    "Gradio panel": wandb.Html(
                        '<iframe src="'
                        + self.share_url
                        + '" width="'
                        + str(self.width)
                        + '" height="'
                        + str(self.height)
                        + '" frameBorder="0"></iframe>'
                    )
                }
            )
        else:
            print(
                "The WandB integration requires you to "
                "`launch(share=True)` first."
            )
    if mlflow is not None:
        analytics_integration = "MLFlow"
        if self.share_url is not None:
            mlflow.log_param("Gradio Interface Share Link", self.share_url)
        else:
            mlflow.log_param("Gradio Interface Local Link", self.local_url)
    if self.analytics_enabled and analytics_integration:
        data = {"integration": analytics_integration}
        utils.integration_analytics(data)
|
1615 |
+
|
1616 |
+
def close(self, verbose: bool = True) -> None:
    """
    Closes the Interface that was launched and frees the port.

    Parameters:
        verbose: If True, prints a confirmation with the freed port number.
    """
    # Best-effort shutdown: missing attributes (never launched) or an
    # already-dead socket are treated as "nothing to close".
    try:
        if self.enable_queue:
            self._queue.close()
        self.server.close()
        self.is_running = False
        if verbose:
            print("Closing server running on port: {}".format(self.server_port))
    except (AttributeError, OSError):  # can't close if not running
        pass
|
1629 |
+
|
1630 |
+
def block_thread(
    self,
) -> None:
    """Block main thread until interrupted by user."""
    try:
        # Sleep-poll forever; only Ctrl+C (KeyboardInterrupt) or an OSError
        # breaks out of this loop.
        while True:
            time.sleep(0.1)
    except (KeyboardInterrupt, OSError):
        print("Keyboard interruption in main thread... closing server.")
        self.server.close()
        # Also tear down any share-link tunnel processes spawned by this app.
        for tunnel in CURRENT_TUNNELS:
            tunnel.kill()
|
1642 |
+
|
1643 |
+
def attach_load_events(self):
    """Add a load event for every component whose initial value should be randomized."""
    if Context.root_block:
        for component in Context.root_block.blocks.values():
            # Only IOComponents can carry a pending load event
            # (stored as a (fn, every) pair on load_event_to_attach).
            if (
                isinstance(component, components.IOComponent)
                and component.load_event_to_attach
            ):
                load_fn, every = component.load_event_to_attach
                # Use set_event_trigger to avoid ambiguity between load class/instance method
                self.set_event_trigger(
                    "load",
                    load_fn,
                    None,
                    component,
                    no_target=True,
                    queue=False,
                    every=every,
                )
|
1662 |
+
|
1663 |
+
def startup_events(self):
    """Events that should be run when the app containing this block starts up."""

    if self.enable_queue:
        # Start the request queue worker loop in the background.
        utils.run_coro_in_background(self._queue.start, (self.progress_tracking,))
    utils.run_coro_in_background(self.create_limiter)
|
1669 |
+
|
1670 |
+
def queue_enabled_for_fn(self, fn_index: int):
    """Return whether the event at *fn_index* is served through the queue.

    A dependency's "queue" entry may be True/False to override the app-level
    setting, or None to inherit ``self.enable_queue``.
    """
    configured = self.dependencies[fn_index]["queue"]
    return self.enable_queue if configured is None else configured
|
gradio-modified/gradio/components.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
gradio-modified/gradio/context.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Defines the Context class, which is used to store the state of all Blocks that are being rendered.
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from typing import TYPE_CHECKING
|
6 |
+
|
7 |
+
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
8 |
+
from gradio.blocks import BlockContext, Blocks
|
9 |
+
|
10 |
+
|
11 |
+
class Context:
    """Process-wide registry of the Blocks tree currently being rendered."""

    root_block: Blocks | None = None  # The current root block that holds all blocks.
    block: BlockContext | None = None  # The current block that children are added to.
    id: int = 0  # Running id to uniquely refer to any block that gets defined
|
gradio-modified/gradio/data_classes.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Pydantic data models and other dataclasses. This is the only file that uses Optional[]
|
2 |
+
typing syntax instead of | None syntax to work with pydantic"""
|
3 |
+
|
4 |
+
from enum import Enum, auto
|
5 |
+
from typing import Any, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from pydantic import BaseModel
|
8 |
+
|
9 |
+
|
10 |
+
class PredictBody(BaseModel):
    """Request body for prediction calls."""

    session_hash: Optional[str]  # identifies the client session, if any
    event_id: Optional[str]  # id of the associated queue event, if any
    data: List[Any]  # the input payload passed to the function
    fn_index: Optional[int]  # index of the dependency/function to run
    batched: Optional[
        bool
    ] = False  # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
    request: Optional[
        Union[Dict, List[Dict]]
    ] = None  # dictionary of request headers, query parameters, url, etc. (used to pass in request for queuing)
|
21 |
+
|
22 |
+
|
23 |
+
class ResetBody(BaseModel):
    """Request body for resetting a pending call in a given session."""

    session_hash: str  # client session to reset
    fn_index: int  # index of the dependency whose state is reset
|
26 |
+
|
27 |
+
|
28 |
+
class InterfaceTypes(Enum):
    """The layout variants an Interface can take.

    NOTE(review): member semantics inferred from names only — STANDARD
    presumably means both inputs and outputs; confirm against Interface usage.
    """

    STANDARD = auto()
    INPUT_ONLY = auto()
    OUTPUT_ONLY = auto()
    UNIFIED = auto()
+
|
34 |
+
|
35 |
+
class Estimation(BaseModel):
    """Queue-position / ETA estimate message (msg type "estimation")."""

    msg: Optional[str] = "estimation"  # message-type discriminator
    rank: Optional[int] = None  # this client's position in the queue, if known
    queue_size: int  # total number of events currently queued
    avg_event_process_time: Optional[float]  # average processing time per event
    avg_event_concurrent_process_time: Optional[float]  # average time accounting for concurrency
    rank_eta: Optional[float] = None  # estimated wait for this client, if known
    queue_eta: float  # estimated time to drain the queue
+
|
44 |
+
|
45 |
+
class ProgressUnit(BaseModel):
    """One tracked unit of progress (all fields optional: senders fill what they know)."""

    index: Optional[int]  # current step, if known
    length: Optional[int]  # total number of steps, if known
    unit: Optional[str]  # label for the steps (e.g. "steps")
    progress: Optional[float]  # fractional completion, when index/length are absent
    desc: Optional[str]  # human-readable description of the current stage
|
51 |
+
|
52 |
+
|
53 |
+
class Progress(BaseModel):
    """Progress-update message (msg type "progress") carrying zero or more units."""

    msg: str = "progress"  # message-type discriminator
    progress_data: List[ProgressUnit] = []
|
gradio-modified/gradio/deprecation.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
|
4 |
+
def simple_deprecated_notice(term: str) -> str:
    """Build the warning text for a parameter that is deprecated and ignored."""
    return "`{}` parameter is deprecated, and it has no effect".format(term)
|
6 |
+
|
7 |
+
|
8 |
+
def use_in_launch(term: str) -> str:
    """Build the warning text for a parameter that moved from Interface() to launch()."""
    template = "`{}` is deprecated in `Interface()`, please use it within `launch()` instead."
    return template.format(term)
|
10 |
+
|
11 |
+
|
12 |
+
# Maps each deprecated keyword-argument name to the warning text emitted
# when check_deprecated_parameters() encounters it.
DEPRECATION_MESSAGE = {
    "optional": simple_deprecated_notice("optional"),
    "keep_filename": simple_deprecated_notice("keep_filename"),
    "numeric": simple_deprecated_notice("numeric"),
    "verbose": simple_deprecated_notice("verbose"),
    "allow_screenshot": simple_deprecated_notice("allow_screenshot"),
    "layout": simple_deprecated_notice("layout"),
    "show_input": simple_deprecated_notice("show_input"),
    "show_output": simple_deprecated_notice("show_output"),
    "capture_session": simple_deprecated_notice("capture_session"),
    "api_mode": simple_deprecated_notice("api_mode"),
    "show_tips": use_in_launch("show_tips"),
    "encrypt": use_in_launch("encrypt"),
    "enable_queue": use_in_launch("enable_queue"),
    "server_name": use_in_launch("server_name"),
    "server_port": use_in_launch("server_port"),
    "width": use_in_launch("width"),
    "height": use_in_launch("height"),
    "plot": "The 'plot' parameter has been deprecated. Use the new Plot component instead",
    "type": "The 'type' parameter has been deprecated. Use the Number component instead.",
}
|
33 |
+
|
34 |
+
|
35 |
+
def check_deprecated_parameters(cls: str, **kwargs) -> None:
    """Warn about each deprecated kwarg, then about any leftover unknown kwargs.

    Parameters:
        cls: name of the class whose constructor received the kwargs (used in the warning)
        kwargs: the keyword arguments to inspect
    """
    for name, message in DEPRECATION_MESSAGE.items():
        if name in kwargs:
            del kwargs[name]
            # Interestingly, using DeprecationWarning causes warning to not appear.
            warnings.warn(message)

    if kwargs:
        warnings.warn(
            f"You have unused kwarg parameters in {cls}, please remove them: {kwargs}"
        )
|
gradio-modified/gradio/documentation.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Contains methods that generate documentation for Gradio functions and classes."""
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import inspect
|
6 |
+
from typing import Callable, Dict, List, Tuple
|
7 |
+
|
8 |
+
# Registry mapping each documentation group name to its registered (cls, fns) pairs.
classes_to_document = {}
# Name of the group that subsequent @document() decorations register into.
documentation_group = None


def set_documentation_group(m):
    """Select the active documentation group, creating its registry entry if missing."""
    global documentation_group
    documentation_group = m
    classes_to_document.setdefault(m, [])


def document(*fns):
    """
    Defines the @document decorator which adds classes or functions to the Gradio
    documentation at www.gradio.app/docs.

    Usage examples:
    - Put @document() above a class to document the class and its constructor.
    - Put @document(fn1, fn2) above a class to also document the class methods fn1 and fn2.
    """

    def _register(cls):
        # Record the class (plus any named methods) under the currently
        # active documentation group, then return it unchanged.
        classes_to_document[documentation_group].append((cls, fns))
        return cls

    return _register
|
35 |
+
|
36 |
+
|
37 |
+
def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, str | None]:
    """
    Generates documentation for any function.
    Parameters:
        fn: Function to document
    Returns:
        description: General description of fn
        parameters: A list of dicts for each parameter, storing data for the parameter name, annotation and doc
        return: A dict storing data for the returned annotation and doc
        example: Code for an example use of the fn
    """
    doc_str = inspect.getdoc(fn) or ""
    doc_lines = doc_str.split("\n")
    signature = inspect.signature(fn)
    description, parameters, returns, examples = [], {}, [], []
    # Line-oriented state machine: section headers switch the parsing mode.
    mode = "description"
    for line in doc_lines:
        line = line.rstrip()
        if line == "Parameters:":
            mode = "parameter"
        elif line == "Example:":
            mode = "example"
        elif line == "Returns:":
            mode = "return"
        else:
            if mode == "description":
                # Blank lines become explicit HTML breaks in the rendered docs.
                description.append(line if line.strip() else "<br>")
                continue
            # Section bodies must be indented by 4 spaces (or blank).
            assert (
                line.startswith("    ") or line.strip() == ""
            ), f"Documentation format for {fn.__name__} has format error in line: {line}"
            line = line[4:]
            if mode == "parameter":
                # NOTE(review): str.index raises ValueError when ": " is absent,
                # so the assert below can never fire with a negative index.
                colon_index = line.index(": ")
                assert (
                    colon_index > -1
                ), f"Documentation format for {fn.__name__} has format error in line: {line}"
                parameter = line[:colon_index]
                parameter_doc = line[colon_index + 2 :]
                parameters[parameter] = parameter_doc
            elif mode == "return":
                returns.append(line)
            elif mode == "example":
                examples.append(line)
    description_doc = " ".join(description)
    parameter_docs = []
    for param_name, param in signature.parameters.items():
        # Private/internal parameters are not documented.
        if param_name.startswith("_"):
            continue
        # An undocumented **kwargs catch-all is skipped entirely.
        if param_name == "kwargs" and param_name not in parameters:
            continue
        parameter_doc = {
            "name": param_name,
            "annotation": param.annotation,
            "doc": parameters.get(param_name),
        }
        if param_name in parameters:
            # Consume the entry so leftovers can be flagged as nonexistent below.
            del parameters[param_name]
        if param.default != inspect.Parameter.empty:
            default = param.default
            if type(default) == str:
                default = '"' + default + '"'
            # Non-builtin defaults are shown as a bare constructor call.
            if default.__class__.__module__ != "builtins":
                default = f"{default.__class__.__name__}()"
            parameter_doc["default"] = default
        elif parameter_doc["doc"] is not None and "kwargs" in parameter_doc["doc"]:
            parameter_doc["kwargs"] = True
        parameter_docs.append(parameter_doc)
    assert (
        len(parameters) == 0
    ), f"Documentation format for {fn.__name__} documents nonexistent parameters: {''.join(parameters.keys())}"
    if len(returns) == 0:
        return_docs = {}
    elif len(returns) == 1:
        return_docs = {"annotation": signature.return_annotation, "doc": returns[0]}
    else:
        # Multiple return lines are currently ignored rather than rejected.
        return_docs = {}
        # raise ValueError("Does not support multiple returns yet.")
    examples_doc = "\n".join(examples) if len(examples) > 0 else None
    return description_doc, parameter_docs, return_docs, examples_doc
|
117 |
+
|
118 |
+
|
119 |
+
def document_cls(cls):
    """Parse a class docstring into (description, tags, example).

    Tag sections are introduced by a single-word "Name:" line (multi-line body)
    or a "name: value" line (single-line value); an "Example:" section is
    extracted separately and returned as the third element.
    """
    doc_str = inspect.getdoc(cls)
    if doc_str is None:
        return "", {}, ""
    tags = {}
    description_lines = []
    mode = "description"
    for line in doc_str.split("\n"):
        line = line.rstrip()
        if line.endswith(":") and " " not in line:
            # Section header such as "Example:" — start collecting its lines.
            mode = line[:-1].lower()
            tags[mode] = []
        elif line.split(" ")[0].endswith(":") and not line.startswith("    "):
            # Inline "tag: value" line at top level.
            tag = line[: line.index(":")].lower()
            value = line[line.index(":") + 2 :]
            tags[tag] = value
        else:
            if mode == "description":
                description_lines.append(line if line.strip() else "<br>")
            else:
                # Section bodies must be indented by 4 spaces (or blank).
                assert (
                    line.startswith("    ") or not line.strip()
                ), f"Documentation format for {cls.__name__} has format error in line: {line}"
                tags[mode].append(line[4:])
    if "example" in tags:
        example = "\n".join(tags["example"])
        del tags["example"]
    else:
        example = None
    # Flatten any remaining multi-line sections into single <br>-joined strings.
    for key, val in tags.items():
        if isinstance(val, list):
            tags[key] = "<br>".join(val)
    description = " ".join(description_lines).replace("\n", "<br>")
    return description, tags, example
|
153 |
+
|
154 |
+
|
155 |
+
def generate_documentation():
    """Build the full documentation tree from everything registered via @document().

    Returns a dict mapping each documentation group name to a list of per-class
    entries (description, tags, parameters, returns, example, and documented methods).
    """
    documentation = {}
    for mode, class_list in classes_to_document.items():
        documentation[mode] = []
        for cls, fns in class_list:
            # For classes, document the constructor; plain functions document themselves.
            fn_to_document = cls if inspect.isfunction(cls) else cls.__init__
            _, parameter_doc, return_doc, _ = document_fn(fn_to_document)
            cls_description, cls_tags, cls_example = document_cls(cls)
            cls_documentation = {
                "class": cls,
                "name": cls.__name__,
                "description": cls_description,
                "tags": cls_tags,
                "parameters": parameter_doc,
                "returns": return_doc,
                "example": cls_example,
                "fns": [],
            }
            # Additionally document each method named in the @document(...) call.
            for fn_name in fns:
                fn = getattr(cls, fn_name)
                (
                    description_doc,
                    parameter_docs,
                    return_docs,
                    examples_doc,
                ) = document_fn(fn)
                cls_documentation["fns"].append(
                    {
                        "fn": fn,
                        "name": fn_name,
                        "description": description_doc,
                        "tags": {},
                        "parameters": parameter_docs,
                        "returns": return_docs,
                        "example": examples_doc,
                    }
                )
            documentation[mode].append(cls_documentation)
    return documentation
|
gradio-modified/gradio/encryptor.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Crypto import Random
|
2 |
+
from Crypto.Cipher import AES
|
3 |
+
from Crypto.Hash import SHA256
|
4 |
+
|
5 |
+
|
6 |
+
def get_key(password: str) -> bytes:
    """Generates an encryption key based on the password provided.

    The key is the SHA-256 digest of the UTF-8 encoded password: 32 bytes,
    suitable as an AES-256 key. Uses the standard-library ``hashlib`` rather
    than ``Crypto.Hash.SHA256`` — the digests are byte-identical, and this
    keeps key derivation free of the third-party dependency.
    """
    import hashlib  # local import; the rest of this module still uses pycryptodome

    return hashlib.sha256(password.encode()).digest()
|
10 |
+
|
11 |
+
|
12 |
+
def encrypt(key: bytes, source: bytes) -> bytes:
    """Encrypt *source* with AES-CBC under *key*, prepending a random IV.

    PKCS#7-style padding is appended so the plaintext length becomes a
    multiple of the AES block size (a full padding block when it already
    is). The IV is stored as the first block of the returned ciphertext so
    that ``decrypt`` can recover it.
    """
    iv = Random.new().read(AES.block_size)  # fresh random IV per message
    cipher = AES.new(key, AES.MODE_CBC, iv)
    pad_len = AES.block_size - len(source) % AES.block_size
    padded = source + bytes([pad_len]) * pad_len
    return iv + cipher.encrypt(padded)
|
20 |
+
|
21 |
+
|
22 |
+
def decrypt(key: bytes, source: bytes) -> bytes:
    """Decrypt AES-CBC data produced by ``encrypt``.

    The first AES block of *source* is the IV; the remainder is ciphertext
    whose final plaintext byte encodes the padding length.

    Raises:
        ValueError: if the recovered padding bytes are inconsistent
            (wrong key or corrupted data).
    """
    iv, body = source[: AES.block_size], source[AES.block_size :]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    plain = cipher.decrypt(body)
    pad_len = plain[-1]  # padding length is stored in the last byte
    if plain[-pad_len:] != bytes([pad_len]) * pad_len:
        raise ValueError("Invalid padding...")
    return plain[:-pad_len]
|
gradio-modified/gradio/events.py
ADDED
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Contains all of the events that can be triggered in a gr.Blocks() app, with the exception
|
2 |
+
of the on-page-load event, which is defined in gr.Blocks().load()."""
|
3 |
+
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
import warnings
|
7 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set
|
8 |
+
|
9 |
+
from gradio.blocks import Block
|
10 |
+
from gradio.utils import get_cancel_function
|
11 |
+
|
12 |
+
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
13 |
+
from gradio.components import Component, StatusTracker
|
14 |
+
|
15 |
+
|
16 |
+
def set_cancel_events(
    block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
):
    """Attach a cancellation trigger for *event_name* on *block*.

    No-op when *cancels* is None or empty. A single dependency dict is
    normalized to a one-element list; the generated cancel function is then
    registered on *block* (unqueued, no pre-processing) so that firing
    *event_name* cancels the listed in-flight events.
    """
    if not cancels:
        return
    dep_list = cancels if isinstance(cancels, list) else [cancels]
    cancel_fn, indices_to_cancel = get_cancel_function(dep_list)
    block.set_event_trigger(
        event_name,
        cancel_fn,
        inputs=None,
        outputs=None,
        queue=False,
        preprocess=False,
        cancels=indices_to_cancel,
    )
|
32 |
+
|
33 |
+
|
34 |
+
class EventListener(Block):
    # Marker base class for the event-listener mixins below (Changeable,
    # Clickable, Submittable, ...). Adds no behavior of its own; it only
    # tags a Block subclass as able to register event triggers.
    pass
|
36 |
+
|
37 |
+
|
38 |
+
class Changeable(EventListener):
    def change(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the component's input value changes (e.g. when the user types in a textbox
        or uploads an image). This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )
        # Register the "change" listener; the returned dependency dict is what
        # callers may pass as `cancels` to other listeners.
        dep = self.set_event_trigger(
            "change",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "change" fires.
        set_cancel_events(self, "change", cancels)
        return dep
|
99 |
+
|
100 |
+
|
101 |
+
class Clickable(EventListener):
    def click(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        # NOTE(review): `queue` is unannotated here, unlike the sibling
        # listeners where it is `bool | None` — presumably the same contract.
        queue=None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the component (e.g. a button) is clicked.
        This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # Register the "click" listener; the returned dependency dict is what
        # callers may pass as `cancels` to other listeners.
        dep = self.set_event_trigger(
            "click",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "click" fires.
        set_cancel_events(self, "click", cancels)
        return dep
|
163 |
+
|
164 |
+
|
165 |
+
class Submittable(EventListener):
    def submit(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user presses the Enter key while the component (e.g. a textbox) is focused.
        This method can be used when this component is in a Gradio Blocks.


        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # Register the "submit" listener; the returned dependency dict is what
        # callers may pass as `cancels` to other listeners.
        dep = self.set_event_trigger(
            "submit",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "submit" fires.
        set_cancel_events(self, "submit", cancels)
        return dep
|
228 |
+
|
229 |
+
|
230 |
+
class Editable(EventListener):
    def edit(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user edits the component (e.g. image) using the
        built-in editor. This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # Register the "edit" listener; the returned dependency dict is what
        # callers may pass as `cancels` to other listeners.
        dep = self.set_event_trigger(
            "edit",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "edit" fires.
        set_cancel_events(self, "edit", cancels)
        return dep
|
292 |
+
|
293 |
+
|
294 |
+
class Clearable(EventListener):
    def clear(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user clears the component (e.g. image or audio)
        using the X button for the component. This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # BUG FIX: this listener previously registered under the event name
        # "submit" (both here and in set_cancel_events below), so clear
        # handlers were attached to the submit event rather than the clear
        # event. Every other listener in this module registers under its own
        # event name, and the docstring documents this as the clear event —
        # use "clear".
        dep = self.set_event_trigger(
            "clear",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "clear" fires.
        set_cancel_events(self, "clear", cancels)
        return dep
|
356 |
+
|
357 |
+
|
358 |
+
class Playable(EventListener):
|
359 |
+
    def play(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user plays the component (e.g. audio or video).
        This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # Register the "play" listener; the returned dependency dict is what
        # callers may pass as `cancels` to other listeners.
        dep = self.set_event_trigger(
            "play",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "play" fires.
        set_cancel_events(self, "play", cancels)
        return dep
|
420 |
+
|
421 |
+
    def pause(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user pauses the component (e.g. audio or video).
        This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
        # status_tracker is accepted only for backward compatibility.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # Register the "pause" listener; the returned dependency dict is what
        # callers may pass as `cancels` to other listeners.
        dep = self.set_event_trigger(
            "pause",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        # Wire up any events that should be cancelled whenever "pause" fires.
        set_cancel_events(self, "pause", cancels)
        return dep
|
482 |
+
|
483 |
+
def stop(
    self,
    fn: Callable | None,
    inputs: Component | List[Component] | Set[Component] | None = None,
    outputs: Component | List[Component] | None = None,
    api_name: str | None = None,
    status_tracker: StatusTracker | None = None,
    scroll_to_output: bool = False,
    show_progress: bool = True,
    queue: bool | None = None,
    batch: bool = False,
    max_batch_size: int = 4,
    preprocess: bool = True,
    postprocess: bool = True,
    cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
    every: float | None = None,
    _js: str | None = None,
):
    """
    This event is triggered when the user stops the component (e.g. audio or video).
    This method can be used when this component is in a Gradio Blocks.

    Parameters:
        fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
        inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
        outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
        api_name: Defining this parameter exposes the endpoint in the api docs
        scroll_to_output: If True, will scroll to output component on completion
        show_progress: If True, will show progress animation while pending
        queue: If True, will place the request on the queue, if the queue exists
        batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
        max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
        preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
        postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
        cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
        every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
    """
    # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.

    # `status_tracker` is accepted only for backwards compatibility; it is ignored.
    if status_tracker:
        warnings.warn(
            "The 'status_tracker' parameter has been deprecated and has no effect."
        )

    # Register the handler on the Blocks graph under the "stop" event, then
    # wire up any cancellation targets before handing the dependency back.
    dependency = self.set_event_trigger(
        "stop",
        fn,
        inputs,
        outputs,
        api_name=api_name,
        js=_js,
        preprocess=preprocess,
        postprocess=postprocess,
        scroll_to_output=scroll_to_output,
        show_progress=show_progress,
        queue=queue,
        batch=batch,
        max_batch_size=max_batch_size,
        every=every,
    )
    set_cancel_events(self, "stop", cancels)
    return dependency
|
544 |
+
|
545 |
+
|
546 |
+
class Streamable(EventListener):
    def stream(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        status_tracker: StatusTracker | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = False,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user streams the component (e.g. a live webcam
        component). This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.

        # Mark the component itself as streaming; note `show_progress`
        # defaults to False for streams (unlike other event listeners).
        self.streaming = True

        # `status_tracker` is accepted only for backwards compatibility; it is ignored.
        if status_tracker:
            warnings.warn(
                "The 'status_tracker' parameter has been deprecated and has no effect."
            )

        # Register the handler under the "stream" event and wire up any
        # cancellation targets before returning the dependency to the caller.
        dependency = self.set_event_trigger(
            "stream",
            fn,
            inputs,
            outputs,
            api_name=api_name,
            js=_js,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        set_cancel_events(self, "stream", cancels)
        return dependency
|
610 |
+
|
611 |
+
|
612 |
+
class Blurrable(EventListener):
    def blur(
        self,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None = None,
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the component's is unfocused/blurred (e.g. when the user clicks outside of a textbox). This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: Callable function
            inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
            outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.

        # Fix: capture and return the event dependency, matching the sibling
        # listeners (e.g. `stop`, `stream` return it). Without the return,
        # callers could not pass this event to another event's `cancels=[...]`.
        # Backward-compatible: callers that ignored the (previously None)
        # return value are unaffected.
        dep = self.set_event_trigger(
            "blur",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        set_cancel_events(self, "blur", cancels)
        return dep
|
667 |
+
|
668 |
+
|
669 |
+
class Uploadable(EventListener):
    def upload(
        self,
        fn: Callable | None,
        inputs: List[Component],
        outputs: Component | List[Component] | None = None,
        api_name: str | None = None,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        preprocess: bool = True,
        postprocess: bool = True,
        cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
        every: float | None = None,
        _js: str | None = None,
    ):
        """
        This event is triggered when the user uploads a file into the component (e.g. when the user uploads a video into a video component). This method can be used when this component is in a Gradio Blocks.

        Parameters:
            fn: Callable function
            inputs: List of inputs
            outputs: List of outputs
            api_name: Defining this parameter exposes the endpoint in the api docs
            scroll_to_output: If True, will scroll to output component on completion
            show_progress: If True, will show progress animation while pending
            queue: If True, will place the request on the queue, if the queue exists
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
            postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
            cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
        """
        # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.

        # Fixes:
        # 1. Capture and return the event dependency, matching the sibling
        #    listeners (e.g. `stop`, `stream` return it), so this event can be
        #    used in another event's `cancels=[...]`. Backward-compatible.
        # 2. Widened the `cancels` annotation to accept a single event dict as
        #    well as a list, consistent with the other listeners in this file.
        dep = self.set_event_trigger(
            "upload",
            fn,
            inputs,
            outputs,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            api_name=api_name,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
        )
        set_cancel_events(self, "upload", cancels)
        return dep
|
gradio-modified/gradio/examples.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Defines helper methods useful for loading and caching Interface examples.
|
3 |
+
"""
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
import ast
|
7 |
+
import csv
|
8 |
+
import os
|
9 |
+
import warnings
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import TYPE_CHECKING, Any, Callable, List
|
12 |
+
|
13 |
+
from gradio import utils
|
14 |
+
from gradio.components import Dataset
|
15 |
+
from gradio.context import Context
|
16 |
+
from gradio.documentation import document, set_documentation_group
|
17 |
+
from gradio.flagging import CSVLogger
|
18 |
+
|
19 |
+
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
20 |
+
from gradio.components import IOComponent
|
21 |
+
|
22 |
+
CACHED_FOLDER = "gradio_cached_examples"
|
23 |
+
LOG_FILE = "log.csv"
|
24 |
+
|
25 |
+
set_documentation_group("component-helpers")
|
26 |
+
|
27 |
+
|
28 |
+
def create_examples(
    examples: List[Any] | List[List[Any]] | str,
    inputs: IOComponent | List[IOComponent],
    outputs: IOComponent | List[IOComponent] | None = None,
    fn: Callable | None = None,
    cache_examples: bool = False,
    examples_per_page: int = 10,
    _api_mode: bool = False,
    label: str | None = None,
    elem_id: str | None = None,
    run_on_click: bool = False,
    preprocess: bool = True,
    postprocess: bool = True,
    batch: bool = False,
):
    """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
    # Instantiate without the "initiated directly" warning, since this helper
    # is the sanctioned entry point.
    instance = Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=fn,
        cache_examples=cache_examples,
        examples_per_page=examples_per_page,
        _api_mode=_api_mode,
        label=label,
        elem_id=elem_id,
        run_on_click=run_on_click,
        preprocess=preprocess,
        postprocess=postprocess,
        batch=batch,
        _initiated_directly=False,
    )
    # Drive the async `create` step (dataset wiring + optional caching) to
    # completion so callers receive a fully initialized component.
    utils.synchronize_async(instance.create)
    return instance
|
62 |
+
|
63 |
+
|
64 |
+
@document()
class Examples:
    """
    This class is a wrapper over the Dataset component and can be used to create Examples
    for Blocks / Interfaces. Populates the Dataset component with examples and
    assigns event listener so that clicking on an example populates the input/output
    components. Optionally handles example caching for fast inference.

    Demos: blocks_inputs, fake_gan
    Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
    """

    def __init__(
        self,
        examples: List[Any] | List[List[Any]] | str,
        inputs: IOComponent | List[IOComponent],
        outputs: IOComponent | List[IOComponent] | None = None,
        fn: Callable | None = None,
        cache_examples: bool = False,
        examples_per_page: int = 10,
        _api_mode: bool = False,
        label: str | None = "Examples",
        elem_id: str | None = None,
        run_on_click: bool = False,
        preprocess: bool = True,
        postprocess: bool = True,
        batch: bool = False,
        _initiated_directly: bool = True,
    ):
        """
        Parameters:
            examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
            inputs: the component or list of components corresponding to the examples
            outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
            fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
            cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
            examples_per_page: how many examples to show per page.
            label: the label to use for the examples component (by default, "Examples")
            elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
            run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
            preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
            postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
            batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
        """
        if _initiated_directly:
            warnings.warn(
                "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
            )

        if cache_examples and (fn is None or outputs is None):
            raise ValueError("If caching examples, `fn` and `outputs` must be provided")

        # Normalize single components to lists so the rest of the code can
        # iterate uniformly.
        if not isinstance(inputs, list):
            inputs = [inputs]
        if outputs and not isinstance(outputs, list):
            outputs = [outputs]

        working_directory = Path().absolute()

        if examples is None:
            raise ValueError("The parameter `examples` cannot be None")
        elif isinstance(examples, list) and (
            len(examples) == 0 or isinstance(examples[0], list)
        ):
            pass
        elif (
            isinstance(examples, list) and len(inputs) == 1
        ):  # If there is only one input component, examples can be provided as a regular list instead of a list of lists
            examples = [[e] for e in examples]
        elif isinstance(examples, str):
            if not Path(examples).exists():
                raise FileNotFoundError(
                    "Could not find examples directory: " + examples
                )
            working_directory = examples
            if not (Path(examples) / LOG_FILE).exists():
                # No log file: with a single input, every file in the
                # directory is treated as one example.
                if len(inputs) == 1:
                    examples = [[e] for e in os.listdir(examples)]
                else:
                    raise FileNotFoundError(
                        "Could not find log file (required for multiple inputs): "
                        + LOG_FILE
                    )
            else:
                with open(Path(examples) / LOG_FILE) as logs:
                    examples = list(csv.reader(logs))
                    examples = [
                        examples[i][: len(inputs)] for i in range(1, len(examples))
                    ]  # remove header and unnecessary columns

        else:
            # Fix: added the missing space before "(if" in this message
            # (previously rendered as "...or a list(if there is only...").
            raise ValueError(
                "The parameter `examples` must either be a string directory or a list "
                "(if there is only 1 input component) or (more generally), a nested "
                "list, where each sublist represents a set of inputs."
            )

        # Track which input components actually have at least one example value.
        input_has_examples = [False] * len(inputs)
        for example in examples:
            for idx, example_for_input in enumerate(example):
                if example_for_input is not None:
                    try:
                        input_has_examples[idx] = True
                    except IndexError:
                        pass  # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)

        inputs_with_examples = [
            inp for (inp, keep) in zip(inputs, input_has_examples) if keep
        ]
        non_none_examples = [
            [ex for (ex, keep) in zip(example, input_has_examples) if keep]
            for example in examples
        ]

        self.examples = examples
        self.non_none_examples = non_none_examples
        self.inputs = inputs
        self.inputs_with_examples = inputs_with_examples
        self.outputs = outputs
        self.fn = fn
        self.cache_examples = cache_examples
        self._api_mode = _api_mode
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.batch = batch

        # Postprocess relative to the examples directory so file-path examples
        # resolve correctly.
        with utils.set_directory(working_directory):
            self.processed_examples = [
                [
                    component.postprocess(sample)
                    for component, sample in zip(inputs, example)
                ]
                for example in examples
            ]
        self.non_none_processed_examples = [
            [ex for (ex, keep) in zip(example, input_has_examples) if keep]
            for example in self.processed_examples
        ]
        if cache_examples:
            for example in self.examples:
                if len([ex for ex in example if ex is not None]) != len(self.inputs):
                    warnings.warn(
                        "Examples are being cached but not all input components have "
                        "example values. This may result in an exception being thrown by "
                        "your function. If you do get an error while caching examples, make "
                        "sure all of your inputs have example values for all of your examples "
                        "or you provide default values for those particular parameters in your function."
                    )
                    break

        with utils.set_directory(working_directory):
            self.dataset = Dataset(
                components=inputs_with_examples,
                samples=non_none_examples,
                type="index",
                label=label,
                samples_per_page=examples_per_page,
                elem_id=elem_id,
            )

        self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
        self.cached_file = Path(self.cached_folder) / "log.csv"
        # Fix: removed a redundant second `self.cache_examples = cache_examples`
        # assignment (it was already set above).
        self.run_on_click = run_on_click

    async def create(self) -> None:
        """Caches the examples if self.cache_examples is True and creates the Dataset
        component to hold the examples"""

        async def load_example(example_id):
            # When caching, the stored outputs are appended to the (processed)
            # input values so a single click fills inputs and outputs together.
            if self.cache_examples:
                processed_example = self.non_none_processed_examples[
                    example_id
                ] + await self.load_from_cache(example_id)
            else:
                processed_example = self.non_none_processed_examples[example_id]
            return utils.resolve_singleton(processed_example)

        if Context.root_block:
            # NOTE(review): when caching, `load_example` returns input AND
            # output values, but `targets` here only contains input components
            # — verify the output components should not also be included.
            if self.cache_examples and self.outputs:
                targets = self.inputs_with_examples
            else:
                targets = self.inputs
            self.dataset.click(
                load_example,
                inputs=[self.dataset],
                outputs=targets,  # type: ignore
                postprocess=False,
                queue=False,
            )
            if self.run_on_click and not self.cache_examples:
                if self.fn is None:
                    raise ValueError("Cannot run_on_click if no function is provided")
                self.dataset.click(
                    self.fn,
                    inputs=self.inputs,  # type: ignore
                    outputs=self.outputs,  # type: ignore
                )

        if self.cache_examples:
            await self.cache()

    async def cache(self) -> None:
        """
        Caches all of the examples so that their predictions can be shown immediately.
        """
        if Path(self.cached_file).exists():
            print(
                f"Using cache from '{Path(self.cached_folder).resolve()}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
            )
        else:
            if Context.root_block is None:
                raise ValueError("Cannot cache examples if not in a Blocks context")

            print(f"Caching examples at: '{Path(self.cached_file).resolve()}'")
            cache_logger = CSVLogger()

            # create a fake dependency to process the examples and get the predictions
            dependency = Context.root_block.set_event_trigger(
                event_name="fake_event",
                fn=self.fn,
                inputs=self.inputs_with_examples,  # type: ignore
                outputs=self.outputs,  # type: ignore
                preprocess=self.preprocess and not self._api_mode,
                postprocess=self.postprocess and not self._api_mode,
                batch=self.batch,
            )

            fn_index = Context.root_block.dependencies.index(dependency)
            assert self.outputs is not None
            cache_logger.setup(self.outputs, self.cached_folder)
            for example_id, _ in enumerate(self.examples):
                processed_input = self.processed_examples[example_id]
                if self.batch:
                    # Batched functions expect a list of values per parameter.
                    processed_input = [[value] for value in processed_input]
                prediction = await Context.root_block.process_api(
                    fn_index=fn_index, inputs=processed_input, request=None, state={}
                )
                output = prediction["data"]
                if self.batch:
                    output = [value[0] for value in output]
                cache_logger.flag(output)
            # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
            Context.root_block.dependencies.remove(dependency)
            Context.root_block.fns.pop(fn_index)

    async def load_from_cache(self, example_id: int) -> List[Any]:
        """Loads a particular cached example for the interface.
        Parameters:
            example_id: The id of the example to process (zero-indexed).
        """
        with open(self.cached_file) as cache:
            examples = list(csv.reader(cache))
        example = examples[example_id + 1]  # +1 to adjust for header
        output = []
        assert self.outputs is not None
        for component, value in zip(self.outputs, example):
            try:
                # Cached values that are gr.update(...) dicts are passed
                # through verbatim; everything else is serialized by its
                # output component.
                value_as_dict = ast.literal_eval(value)
                assert utils.is_update(value_as_dict)
                output.append(value_as_dict)
            except (ValueError, TypeError, SyntaxError, AssertionError):
                output.append(component.serialize(value, self.cached_folder))
        return output
|
gradio-modified/gradio/exceptions.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class DuplicateBlockError(ValueError):
    """Raised when a Blocks contains more than one Block with the same id"""
|
6 |
+
|
7 |
+
class TooManyRequestsError(Exception):
    """Raised when the Hugging Face API returns a 429 status code."""
|
12 |
+
|
13 |
+
class InvalidApiName(ValueError):
    """Raised when an event's `api_name` does not resolve to a valid endpoint."""
|
16 |
+
|
17 |
+
class Error(Exception):
    """Exception carrying a `message` attribute; its string form is the
    `repr` of that message."""

    def __init__(self, message: str):
        # Keep the message accessible to handlers, and pass it to the base
        # Exception so `args` is populated as usual.
        self.message = message
        super().__init__(message)

    def __str__(self):
        return repr(self.message)
|
gradio-modified/gradio/external.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module should not be used directly as its API is subject to change. Instead,
|
2 |
+
use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""
|
3 |
+
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
import json
|
7 |
+
import re
|
8 |
+
import uuid
|
9 |
+
import warnings
|
10 |
+
from copy import deepcopy
|
11 |
+
from typing import TYPE_CHECKING, Callable, Dict
|
12 |
+
|
13 |
+
import requests
|
14 |
+
|
15 |
+
import gradio
|
16 |
+
from gradio import components, utils
|
17 |
+
from gradio.exceptions import TooManyRequestsError
|
18 |
+
from gradio.external_utils import (
|
19 |
+
cols_to_rows,
|
20 |
+
encode_to_base64,
|
21 |
+
get_tabular_examples,
|
22 |
+
get_ws_fn,
|
23 |
+
postprocess_label,
|
24 |
+
rows_to_cols,
|
25 |
+
streamline_spaces_interface,
|
26 |
+
use_websocket,
|
27 |
+
)
|
28 |
+
from gradio.processing_utils import to_binary
|
29 |
+
|
30 |
+
if TYPE_CHECKING:
|
31 |
+
from gradio.blocks import Blocks
|
32 |
+
from gradio.interface import Interface
|
33 |
+
|
34 |
+
|
35 |
+
def load_blocks_from_repo(
    name: str,
    src: str | None = None,
    api_key: str | None = None,
    alias: str | None = None,
    **kwargs,
) -> Blocks:
    """Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
    if src is None:
        # Infer the repo type from a name of the form "{src}/{repo name}".
        prefix, separator, remainder = name.partition("/")
        assert (
            separator
        ), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
        src, name = prefix, remainder

    # One factory per repo type; each returns the demo given the model name
    # and, optionally, an api_key.
    factory_methods: Dict[str, Callable] = {
        "huggingface": from_model,
        "models": from_model,
        "spaces": from_spaces,
    }
    assert (
        src.lower() in factory_methods
    ), "parameter: src must be one of {}".format(factory_methods.keys())

    blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
    return blocks
|
64 |
+
|
65 |
+
|
66 |
+
def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs):
    """Creates a gradio.Interface that proxies inference for `model_name` to the
    Hugging Face Inference API.

    Parameters:
        model_name: repo id of the model on the Hugging Face Hub (e.g. "gpt2").
        api_key: optional Hugging Face access token (needed for private/gated models).
        alias: optional name for the generated prediction function.
        kwargs: extra keyword arguments forwarded to the gradio.Interface constructor.
    Returns:
        A gradio.Interface wired to the model's Inference API endpoint.
    Raises:
        AssertionError: if the model cannot be found on the Hub.
        ValueError: if the model's pipeline type is not supported.
    """
    model_url = "https://huggingface.co/{}".format(model_name)
    api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
    print("Fetching model from: {}".format(model_url))

    headers = {"Authorization": f"Bearer {api_key}"} if api_key is not None else {}

    # Checking if model exists, and if so, it gets the pipeline
    response = requests.request("GET", api_url, headers=headers)
    assert (
        response.status_code == 200
    ), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
    p = response.json().get("pipeline_tag")

    # Maps each supported pipeline tag to its input/output components plus the
    # preprocess (gradio value -> API payload) and postprocess (API response ->
    # gradio value) functions.
    pipelines = {
        "audio-classification": {
            # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
            "inputs": components.Audio(source="upload", type="filepath", label="Input"),
            "outputs": components.Label(label="Class"),
            # Bug fix: was `lambda i: to_binary`, which returned the to_binary
            # function itself instead of the file's bytes, so the API received
            # a function object. Match the other audio pipelines.
            "preprocess": to_binary,
            "postprocess": lambda r: postprocess_label(
                {i["label"].split(", ")[0]: i["score"] for i in r.json()}
            ),
        },
        "audio-to-audio": {
            # example model: facebook/xm_transformer_sm_all-en
            "inputs": components.Audio(source="upload", type="filepath", label="Input"),
            "outputs": components.Audio(label="Output"),
            "preprocess": to_binary,
            "postprocess": encode_to_base64,
        },
        "automatic-speech-recognition": {
            # example model: facebook/wav2vec2-base-960h
            "inputs": components.Audio(source="upload", type="filepath", label="Input"),
            "outputs": components.Textbox(label="Output"),
            "preprocess": to_binary,
            "postprocess": lambda r: r.json()["text"],
        },
        "feature-extraction": {
            # example model: julien-c/distilbert-feature-extraction
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Dataframe(label="Output"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r.json()[0],
        },
        "fill-mask": {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: postprocess_label(
                {i["token_str"]: i["score"] for i in r.json()}
            ),
        },
        "image-classification": {
            # Example: google/vit-base-patch16-224
            "inputs": components.Image(type="filepath", label="Input Image"),
            "outputs": components.Label(label="Classification"),
            "preprocess": to_binary,
            "postprocess": lambda r: postprocess_label(
                {i["label"].split(", ")[0]: i["score"] for i in r.json()}
            ),
        },
        "question-answering": {
            # Example: deepset/xlm-roberta-base-squad2
            "inputs": [
                components.Textbox(lines=7, label="Context"),
                components.Textbox(label="Question"),
            ],
            "outputs": [
                components.Textbox(label="Answer"),
                components.Label(label="Score"),
            ],
            "preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
            "postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}),
        },
        "summarization": {
            # Example: facebook/bart-large-cnn
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Summary"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r.json()[0]["summary_text"],
        },
        "text-classification": {
            # Example: distilbert-base-uncased-finetuned-sst-2-english
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: postprocess_label(
                {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
            ),
        },
        "text-generation": {
            # Example: gpt2
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Output"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r.json()[0]["generated_text"],
        },
        "text2text-generation": {
            # Example: valhalla/t5-small-qa-qg-hl
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Generated Text"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r.json()[0]["generated_text"],
        },
        "translation": {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Translation"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r.json()[0]["translation_text"],
        },
        "zero-shot-classification": {
            # Example: facebook/bart-large-mnli
            "inputs": [
                components.Textbox(label="Input"),
                components.Textbox(label="Possible class names (" "comma-separated)"),
                components.Checkbox(label="Allow multiple true classes"),
            ],
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda i, c, m: {
                "inputs": i,
                "parameters": {"candidate_labels": c, "multi_class": m},
            },
            "postprocess": lambda r: postprocess_label(
                {
                    r.json()["labels"][i]: r.json()["scores"][i]
                    for i in range(len(r.json()["labels"]))
                }
            ),
        },
        "sentence-similarity": {
            # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens
            "inputs": [
                components.Textbox(
                    value="That is a happy person", label="Source Sentence"
                ),
                components.Textbox(
                    lines=7,
                    placeholder="Separate each sentence by a newline",
                    label="Sentences to compare to",
                ),
            ],
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda src, sentences: {
                "inputs": {
                    "source_sentence": src,
                    "sentences": [s for s in sentences.splitlines() if s != ""],
                }
            },
            "postprocess": lambda r: postprocess_label(
                {f"sentence {i}": v for i, v in enumerate(r.json())}
            ),
        },
        "text-to-speech": {
            # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Audio(label="Audio"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": encode_to_base64,
        },
        "text-to-image": {
            # example model: osanseviero/BigGAN-deep-128
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Image(label="Output"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": encode_to_base64,
        },
        "token-classification": {
            # example model: huggingface-course/bert-finetuned-ner
            "inputs": components.Textbox(label="Input"),
            "outputs": components.HighlightedText(label="Output"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r,  # Handled as a special case in query_huggingface_api()
        },
    }

    if p in ["tabular-classification", "tabular-regression"]:
        # Tabular pipelines build their components dynamically from the example
        # data declared in the model's README.
        example_data = get_tabular_examples(model_name)
        col_names, example_data = cols_to_rows(example_data)
        example_data = [[example_data]] if example_data else None

        pipelines[p] = {
            "inputs": components.Dataframe(
                label="Input Rows",
                type="pandas",
                headers=col_names,
                col_count=(len(col_names), "fixed"),
            ),
            "outputs": components.Dataframe(
                label="Predictions", type="array", headers=["prediction"]
            ),
            "preprocess": rows_to_cols,
            "postprocess": lambda r: {
                "headers": ["prediction"],
                "data": [[pred] for pred in json.loads(r.text)],
            },
            "examples": example_data,
        }

    if p is None or p not in pipelines:
        raise ValueError("Unsupported pipeline type: {}".format(p))

    pipeline = pipelines[p]

    def query_huggingface_api(*params):
        # Convert to a list of input components
        data = pipeline["preprocess"](*params)
        if isinstance(
            data, dict
        ):  # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
            data.update({"options": {"wait_for_model": True}})
            data = json.dumps(data)
        response = requests.request("POST", api_url, headers=headers, data=data)
        if not (response.status_code == 200):
            errors_json = response.json()
            errors, warns = "", ""
            if errors_json.get("error"):
                errors = f", Error: {errors_json.get('error')}"
            if errors_json.get("warnings"):
                warns = f", Warnings: {errors_json.get('warnings')}"
            raise ValueError(
                f"Could not complete request to HuggingFace API, Status Code: {response.status_code}"
                + errors
                + warns
            )
        if (
            p == "token-classification"
        ):  # Handle as a special case since HF API only returns the named entities and we need the input as well
            ner_groups = response.json()
            input_string = params[0]
            response = utils.format_ner_list(input_string, ner_groups)
        output = pipeline["postprocess"](response)
        return output

    if alias is None:
        query_huggingface_api.__name__ = model_name
    else:
        query_huggingface_api.__name__ = alias

    interface_info = {
        "fn": query_huggingface_api,
        "inputs": pipeline["inputs"],
        "outputs": pipeline["outputs"],
        "title": model_name,
        "examples": pipeline.get("examples"),
    }

    kwargs = dict(interface_info, **kwargs)
    kwargs["_api_mode"] = True  # So interface doesn't run pre/postprocess.
    interface = gradio.Interface(**kwargs)
    return interface
|
317 |
+
|
318 |
+
|
319 |
+
def from_spaces(
    space_name: str, api_key: str | None, alias: str | None, **kwargs
) -> Blocks:
    """Loads a Hugging Face Space as a local demo: a Blocks for Gradio 3.x
    Spaces, or an Interface for legacy Gradio 2.x Spaces."""
    space_url = "https://huggingface.co/spaces/{}".format(space_name)

    print("Fetching Space from: {}".format(space_url))

    headers = {}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    host_response = requests.get(
        f"https://huggingface.co/api/spaces/{space_name}/host", headers=headers
    )
    iframe_url = host_response.json().get("host")
    if iframe_url is None:
        raise ValueError(
            f"Could not find Space: {space_name}. If it is a private or gated Space, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
        )

    page = requests.get(iframe_url, headers=headers)

    # some basic regex to extract the config embedded in the Space's HTML
    match = re.search(r"window.gradio_config = (.*?);[\s]*</script>", page.text)
    try:
        config = json.loads(match.group(1))  # type: ignore
    except AttributeError:
        raise ValueError("Could not load the Space: {}".format(space_name))

    # Only Gradio 2.x configs carry "allow_flagging"; use it to pick the branch.
    if "allow_flagging" in config:  # Create an Interface for Gradio 2.x Spaces
        return from_spaces_interface(
            space_name, config, alias, api_key, iframe_url, **kwargs
        )

    # Create a Blocks for Gradio 3.x Spaces
    if kwargs:
        warnings.warn(
            "You cannot override parameters for this Space by passing in kwargs. "
            "Instead, please load the Space as a function and use it to create a "
            "Blocks or Interface locally. You may find this Guide helpful: "
            "https://gradio.app/using_blocks_like_functions/"
        )
    return from_spaces_blocks(config, api_key, iframe_url)
|
365 |
+
|
366 |
+
|
367 |
+
def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Blocks:
    """Rebuilds a Gradio 3.x Space locally as a Blocks whose backend functions
    proxy every prediction to the remote Space (via HTTP POST or, when the
    Space's queue is enabled, its websocket protocol)."""
    api_url = "{}/api/predict/".format(iframe_url)

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    # Queue-enabled Spaces serve predictions over a websocket at /queue/join.
    ws_url = "{}/queue/join".format(iframe_url).replace("https", "wss")

    ws_fn = get_ws_fn(ws_url, headers)

    fns = []
    for d, dependency in enumerate(config["dependencies"]):
        if dependency["backend_fn"]:

            # Factory closure: binds outputs/fn_index/use_ws per dependency so
            # all fns don't share the final loop iteration's values
            # (the classic late-binding-closure pitfall).
            def get_fn(outputs, fn_index, use_ws):
                def fn(*data):
                    data = json.dumps({"data": data, "fn_index": fn_index})
                    # Fresh session hash per call identifies us in the queue.
                    hash_data = json.dumps(
                        {"fn_index": fn_index, "session_hash": str(uuid.uuid4())}
                    )
                    if use_ws:
                        # Run the async websocket protocol to completion.
                        result = utils.synchronize_async(ws_fn, data, hash_data)
                        output = result["data"]
                    else:
                        response = requests.post(api_url, headers=headers, data=data)
                        result = json.loads(response.content.decode("utf-8"))
                        try:
                            output = result["data"]
                        except KeyError:
                            if "error" in result and "429" in result["error"]:
                                raise TooManyRequestsError(
                                    "Too many requests to the Hugging Face API"
                                )
                            raise KeyError(
                                f"Could not find 'data' key in response from external Space. Response received: {result}"
                            )
                    # Single-output dependencies return the bare value, not a list.
                    if len(outputs) == 1:
                        output = output[0]
                    return output

                return fn

            fn = get_fn(
                deepcopy(dependency["outputs"]), d, use_websocket(config, dependency)
            )
            fns.append(fn)
        else:
            # Frontend-only dependency: no backend function to proxy.
            fns.append(None)
    return gradio.Blocks.from_config(config, fns, iframe_url)
|
416 |
+
|
417 |
+
|
418 |
+
def from_spaces_interface(
    model_name: str,
    config: Dict,
    alias: str | None,
    api_key: str | None,
    iframe_url: str,
    **kwargs,
) -> Interface:
    """Rebuilds a legacy Gradio 2.x Space locally as an Interface whose
    prediction function forwards calls to the Space's /api/predict/ endpoint."""

    config = streamline_spaces_interface(config)
    api_url = "{}/api/predict/".format(iframe_url)
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    # The function should call the API with preprocessed data
    def fn(*data):
        data = json.dumps({"data": data})
        response = requests.post(api_url, headers=headers, data=data)
        result = json.loads(response.content.decode("utf-8"))
        try:
            output = result["data"]
        except KeyError:
            if "error" in result and "429" in result["error"]:
                raise TooManyRequestsError("Too many requests to the Hugging Face API")
            raise KeyError(
                f"Could not find 'data' key in response from external Space. Response received: {result}"
            )
        if (
            len(config["outputs"]) == 1
        ):  # if the fn is supposed to return a single value, pop it
            output = output[0]
        if len(config["outputs"]) == 1 and isinstance(
            output, list
        ):  # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
            output = output[0]
        return output

    fn.__name__ = alias if (alias is not None) else model_name
    config["fn"] = fn

    kwargs = dict(config, **kwargs)
    # Payloads are already processed by the remote Space, so skip local
    # pre/postprocessing on this Interface.
    kwargs["_api_mode"] = True
    interface = gradio.Interface(**kwargs)
    return interface
|
gradio-modified/gradio/external_utils.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility function for gradio/external.py"""
|
2 |
+
|
3 |
+
import base64
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import operator
|
7 |
+
import re
|
8 |
+
import warnings
|
9 |
+
from typing import Any, Dict, List, Tuple
|
10 |
+
|
11 |
+
import requests
|
12 |
+
import websockets
|
13 |
+
import yaml
|
14 |
+
from packaging import version
|
15 |
+
from websockets.legacy.protocol import WebSocketCommonProtocol
|
16 |
+
|
17 |
+
from gradio import components, exceptions
|
18 |
+
|
19 |
+
##################
|
20 |
+
# Helper functions for processing tabular data
|
21 |
+
##################
|
22 |
+
|
23 |
+
|
24 |
+
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
    """Pulls structured example rows for a tabular model from the YAML front
    matter of its README on the Hub; raises ValueError when none are found."""
    readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    examples: Dict[str, List[float]] = {}
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
    else:
        front_matter = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if front_matter is not None:
            parsed = next(yaml.safe_load_all(readme.text[: front_matter.span()[-1]]))
            examples = parsed.get("widget", {}).get("structuredData", {})
    if not examples:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference API
    for column in examples.values():
        for i, cell in enumerate(column):
            if isinstance(cell, float) and math.isnan(cell):
                column[i] = "NaN"
    return examples
|
52 |
+
|
53 |
+
|
54 |
+
def cols_to_rows(
    example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
    """Transposes column-major example data into (headers, rows).

    Columns may be None or ragged (different lengths); missing cells are padded
    with the string "NaN" so every row has one entry per header.

    Parameters:
        example_data: mapping of column name -> list of values (or None).
    Returns:
        A (headers, rows) tuple, where rows is row-major.
    """
    headers = list(example_data.keys())
    if not headers:
        # Robustness fix: max() below raises ValueError on an empty mapping.
        return [], []
    n_rows = max(len(example_data[header] or []) for header in headers)
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            # Pad short columns so the result is rectangular.
            if row_index >= len(col):
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
|
70 |
+
|
71 |
+
|
72 |
+
def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    """Transposes a row-major Dataframe payload ({"headers": [...], "data": [rows]})
    into the column-major, stringified dict the HF Inference API expects."""
    headers = incoming_data["headers"]
    rows = incoming_data["data"]
    columns = {
        header: [str(row[i]) for row in rows] for i, header in enumerate(headers)
    }
    return {"inputs": {"data": columns}}
|
77 |
+
|
78 |
+
|
79 |
+
##################
|
80 |
+
# Helper functions for processing other kinds of data
|
81 |
+
##################
|
82 |
+
|
83 |
+
|
84 |
+
def postprocess_label(scores: Dict) -> Dict:
    """Converts a {label: score} mapping into gradio's Label payload, with the
    top-scoring label first and all confidences sorted descending."""
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    top_label, _ = ranked[0]
    return {
        "label": top_label,
        "confidences": [
            {"label": name, "confidence": score} for name, score in ranked
        ],
    }
|
92 |
+
|
93 |
+
|
94 |
+
def encode_to_base64(r: requests.Response) -> str:
    """Converts an HF Inference API response into a base64 data-URI string.

    Handles the different ways the API returns a prediction:
    1. the body is already a full data URI (contains ";base64,");
    2. a JSON body whose first element carries "content-type" and "blob" keys;
    3. a raw binary body whose type comes from the Content-Type header.

    Raises:
        ValueError: if a JSON body lacks the expected content-type/blob keys.
    """
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    data_prefix = ";base64,"
    # Case 1: base64 representation already includes data prefix
    if data_prefix in base64_repr:
        return base64_repr
    else:
        content_type = r.headers.get("content-type")
        # Case 2: the data prefix is a key in the response
        if content_type == "application/json":
            try:
                content_type = r.json()[0]["content-type"]
                base64_repr = r.json()[0]["blob"]
            except KeyError:
                # Bug fix: the two adjacent literals previously concatenated to
                # "returnedby external API." (missing space).
                raise ValueError(
                    "Cannot determine content type returned by external API."
                )
        # Case 3: the data prefix is included in the response headers
        else:
            pass
        new_base64 = "data:{};base64,".format(content_type) + base64_repr
        return new_base64
|
117 |
+
|
118 |
+
|
119 |
+
##################
|
120 |
+
# Helper functions for connecting to websockets
|
121 |
+
##################
|
122 |
+
|
123 |
+
|
124 |
+
async def get_pred_from_ws(
    websocket: WebSocketCommonProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
    """Drives the gradio queue websocket protocol until a prediction completes.

    Parameters:
        websocket: an open connection to the Space's /queue/join endpoint.
        data: JSON-encoded payload carrying the input data and fn_index.
        hash_data: JSON-encoded payload identifying this session in the queue.
    Returns:
        The "output" field of the final "process_completed" message.
    Raises:
        exceptions.Error: if the remote queue reports that it is full.
    """
    completed = False
    resp = {}
    while not completed:
        msg = await websocket.recv()
        resp = json.loads(msg)
        if resp["msg"] == "queue_full":
            raise exceptions.Error("Queue is full! Please try again.")
        if resp["msg"] == "send_hash":
            # The server asks us to identify ourselves before queueing the job.
            await websocket.send(hash_data)
        elif resp["msg"] == "send_data":
            # Our turn in the queue: submit the actual input payload.
            await websocket.send(data)
        completed = resp["msg"] == "process_completed"
    return resp["output"]
|
140 |
+
|
141 |
+
|
142 |
+
def get_ws_fn(ws_url, headers):
    """Returns an async callable that opens a fresh websocket to `ws_url`
    (with the given extra headers) and runs the queue protocol once per call."""

    async def ws_fn(data, hash_data):
        async with websockets.connect(  # type: ignore
            ws_url, open_timeout=10, extra_headers=headers
        ) as websocket:
            return await get_pred_from_ws(websocket, data, hash_data)

    return ws_fn
|
150 |
+
|
151 |
+
|
152 |
+
def use_websocket(config, dependency):
    """Decides whether predictions for `dependency` should go over the queue
    websocket: the Space must have the queue enabled, run a gradio version
    (>= 3.2) whose queue speaks websockets, and the dependency must not have
    opted out of the queue."""
    queue_on = config.get("enable_queue", False)
    ws_capable = version.parse(config.get("version", "2.0")) >= version.Version("3.2")
    dep_queued = dependency.get("queue", False) is not False
    return queue_on and ws_capable and dep_queued
|
159 |
+
|
160 |
+
|
161 |
+
##################
|
162 |
+
# Helper function for cleaning up an Interface loaded from HF Spaces
|
163 |
+
##################
|
164 |
+
|
165 |
+
|
166 |
+
def streamline_spaces_interface(config: Dict) -> Dict:
    """Streamlines the interface config dictionary to remove unnecessary keys."""
    # Instantiate component objects from the serialized component configs.
    for new_key, old_key in (
        ("inputs", "input_components"),
        ("outputs", "output_components"),
    ):
        config[new_key] = [
            components.get_component_instance(c) for c in config[old_key]
        ]
    # Keep only the keys that gradio.Interface accepts as constructor kwargs.
    keep = {
        "article",
        "description",
        "flagging_options",
        "inputs",
        "outputs",
        "theme",
        "title",
    }
    return {k: config[k] for k in keep}
|
gradio-modified/gradio/flagging.py
ADDED
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import csv
|
4 |
+
import datetime
|
5 |
+
import io
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import uuid
|
9 |
+
from abc import ABC, abstractmethod
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import TYPE_CHECKING, Any, List
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
from gradio import encryptor, utils
|
15 |
+
from gradio.documentation import document, set_documentation_group
|
16 |
+
|
17 |
+
if TYPE_CHECKING:
|
18 |
+
from gradio.components import IOComponent
|
19 |
+
|
20 |
+
# Register subsequently @document-ed classes under the "flagging" docs group.
set_documentation_group("flagging")
|
21 |
+
|
22 |
+
|
23 |
+
def _get_dataset_features_info(is_new, components):
    """
    Takes in a list of components and returns a dataset features info

    Parameters:
    is_new: boolean, whether the dataset is new or not
    components: list of components

    Returns:
    infos: a dictionary of the dataset features
    file_preview_types: dictionary mapping of gradio components to appropriate string.
    header: list of header strings
    """
    infos = {"flagged": {"features": {}}}
    # Components whose flagged value is a file the Hub can preview inline.
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
    headers = []

    # Headers and dataset_infos only need to be generated for a fresh dataset.
    if is_new:
        for component in components:
            headers.append(component.label)
            infos["flagged"]["features"][component.label] = {
                "dtype": "string",
                "_type": "Value",
            }
            if isinstance(component, tuple(file_preview_types)):
                # File-backed components get an extra "<label> file" column.
                headers.append(component.label + " file")
                for preview_class, preview_type in file_preview_types.items():
                    if isinstance(component, preview_class):
                        infos["flagged"]["features"][
                            (component.label or "") + " file"
                        ] = {"_type": preview_type}
                        break

        headers.append("flag")
        infos["flagged"]["features"]["flag"] = {
            "dtype": "string",
            "_type": "Value",
        }

    return infos, file_preview_types, headers
|
67 |
+
|
68 |
+
|
69 |
+
class FlaggingCallback(ABC):
    """
    An abstract class for defining the methods that any FlaggingCallback should have.
    """

    @abstractmethod
    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        This method should be overridden and ensure that everything is set up correctly for flag().
        This method gets called once at the beginning of the Interface.launch() method.
        Parameters:
        components: Set of components that will provide flagged data.
        flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """
        pass

    @abstractmethod
    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
        This gets called every time the <flag> button is pressed.
        Parameters:
        flag_data: The data to be flagged, one value per component passed to setup().
        flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
        flag_index (optional): The index of the sample that is being flagged.
        username (optional): The username of the user that is flagging the data, if logged in.
        Returns:
        (int) The total number of samples that have been flagged.
        """
        pass
|
106 |
+
|
107 |
+
|
108 |
+
@document()
class SimpleCSVLogger(FlaggingCallback):
    """
    A simplified implementation of the FlaggingCallback abstract class
    provided for illustrative purposes. Each flagged sample (both the input and output data)
    is logged to a CSV file on the machine running the gradio app.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=SimpleCSVLogger())
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        # Remember what to log and where; make sure the target directory exists.
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        # Append one header-less row to <flagging_dir>/log.csv and return the
        # resulting row count.
        log_filepath = Path(self.flagging_dir) / "log.csv"

        row = []
        for component, sample in zip(self.components, flag_data):
            # Each component saves any file payloads under its own subdirectory.
            save_dir = Path(self.flagging_dir) / utils.strip_invalid_filename_characters(
                component.label or ""
            )
            row.append(component.deserialize(sample, save_dir, None))

        with open(log_filepath, "a", newline="") as csvfile:
            csv.writer(csvfile).writerow(utils.sanitize_list_for_csv(row))

        # Line count minus one — same counting convention as the other loggers.
        with open(log_filepath, "r") as csvfile:
            return sum(1 for _ in csv.reader(csvfile)) - 1
|
160 |
+
|
161 |
+
|
162 |
+
@document()
class CSVLogger(FlaggingCallback):
    """
    The default implementation of the FlaggingCallback abstract class. Each flagged
    sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=CSVLogger())
    Guides: using_flagging
    """

    def __init__(self):
        pass

    def setup(
        self,
        components: List[IOComponent],
        flagging_dir: str | Path,
        encryption_key: bytes | None = None,
    ):
        # Store what to log, where, and (optionally) the key used to encrypt
        # the log file; create the directory up front.
        self.components = components
        self.flagging_dir = flagging_dir
        self.encryption_key = encryption_key
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Append one flagged sample to log.csv (writing headers on first use) and
        return the number of data rows. If `flag_index` is given, the "flag"
        column of that existing row is rewritten with `flag_option` instead of
        appending a new row.
        """
        flagging_dir = self.flagging_dir
        log_filepath = Path(flagging_dir) / "log.csv"
        is_new = not Path(log_filepath).exists()
        headers = [
            component.label or f"component {idx}"
            for idx, component in enumerate(self.components)
        ] + [
            "flag",
            "username",
            "timestamp",
        ]

        csv_data = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
                component.label or f"component {idx}"
            )
            if utils.is_update(sample):
                # gr.update(...) dicts are logged as their string repr.
                csv_data.append(str(sample))
            else:
                csv_data.append(
                    component.deserialize(
                        sample,
                        save_dir=save_dir,
                        encryption_key=self.encryption_key,
                    )
                    if sample is not None
                    else ""
                )
        csv_data.append(flag_option if flag_option is not None else "")
        csv_data.append(username if username is not None else "")
        csv_data.append(str(datetime.datetime.now()))

        def replace_flag_at_index(file_content: str, flag_index: int):
            # Rewrite the "flag" column of row `flag_index` and return the
            # whole CSV as a string.
            file_content_ = io.StringIO(file_content)
            content = list(csv.reader(file_content_))
            header = content[0]
            flag_col_index = header.index("flag")
            content[flag_index][flag_col_index] = flag_option  # type: ignore
            output = io.StringIO()
            writer = csv.writer(output)
            writer.writerows(utils.sanitize_list_for_csv(content))
            return output.getvalue()

        if self.encryption_key:
            output = io.StringIO()
            if not is_new:
                # The file on disk is encrypted bytes, so it must be opened in
                # binary mode WITHOUT an encoding argument (passing encoding
                # with "rb"/"wb" raises ValueError — this was a bug here).
                with open(log_filepath, "rb") as csvfile:
                    encrypted_csv = csvfile.read()
                    decrypted_csv = encryptor.decrypt(
                        self.encryption_key, encrypted_csv
                    )
                    file_content = decrypted_csv.decode()
                    if flag_index is not None:
                        file_content = replace_flag_at_index(file_content, flag_index)
                    output.write(file_content)
            writer = csv.writer(output)
            if flag_index is None:
                if is_new:
                    writer.writerow(utils.sanitize_list_for_csv(headers))
                writer.writerow(utils.sanitize_list_for_csv(csv_data))
            with open(log_filepath, "wb") as csvfile:
                csvfile.write(
                    encryptor.encrypt(self.encryption_key, output.getvalue().encode())
                )
        else:
            if flag_index is None:
                with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
                    writer = csv.writer(csvfile)
                    if is_new:
                        writer.writerow(utils.sanitize_list_for_csv(headers))
                    writer.writerow(utils.sanitize_list_for_csv(csv_data))
            else:
                with open(log_filepath, encoding="utf-8") as csvfile:
                    file_content = csvfile.read()
                    file_content = replace_flag_at_index(file_content, flag_index)
                with open(
                    log_filepath, "w", newline="", encoding="utf-8"
                ) as csvfile:  # newline parameter needed for Windows
                    csvfile.write(file_content)
        with open(log_filepath, "r", encoding="utf-8") as csvfile:
            line_count = len([None for row in csv.reader(csvfile)]) - 1
        return line_count
|
280 |
+
|
281 |
+
|
282 |
+
@document()
class HuggingFaceDatasetSaver(FlaggingCallback):
    """
    A callback that saves each flagged sample (both the input and output data)
    to a HuggingFace dataset.
    Example:
        import gradio as gr
        hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            allow_flagging="manual", flagging_callback=hf_writer)
    Guides: using_flagging
    """

    def __init__(
        self,
        hf_token: str,
        dataset_name: str,
        organization: str | None = None,
        private: bool = False,
    ):
        """
        Parameters:
            hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
            dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
            organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
            private: Whether the dataset should be private (defaults to False).
        """
        self.hf_token = hf_token
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Params:
        flagging_dir (str): local directory where the dataset is cloned,
        updated, and pushed from.
        """
        try:
            import huggingface_hub
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
            )
        # Create the dataset repo (no-op if it already exists) and clone it
        # into flagging_dir so rows/files can be committed locally.
        path_to_dataset_repo = huggingface_hub.create_repo(
            name=self.dataset_name,
            token=self.hf_token,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = Path(flagging_dir) / self.dataset_name
        self.repo = huggingface_hub.Repository(
            local_dir=str(self.dataset_dir),
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_token,
        )
        self.repo.git_pull(lfs=True)

        # Should filename be user-specified?
        self.log_file = Path(self.dataset_dir) / "data.csv"
        self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """Append one flagged sample to data.csv, push to the Hub, and return the data-row count."""
        self.repo.git_pull(lfs=True)

        is_new = not Path(self.log_file).exists()

        with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)

            # File previews for certain input and output types
            infos, file_preview_types, headers = _get_dataset_features_info(
                is_new, self.components
            )

            # Generate the headers and dataset_infos
            if is_new:
                writer.writerow(utils.sanitize_list_for_csv(headers))

            # Generate the row corresponding to the flagged sample
            csv_data = []
            for component, sample in zip(self.components, flag_data):
                save_dir = Path(
                    self.dataset_dir
                ) / utils.strip_invalid_filename_characters(component.label or "")
                filepath = component.deserialize(sample, save_dir, None)
                csv_data.append(filepath)
                if isinstance(component, tuple(file_preview_types)):
                    # Extra column with a Hub URL that renders a file preview.
                    csv_data.append(
                        "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
                    )
            csv_data.append(flag_option if flag_option is not None else "")
            writer.writerow(utils.sanitize_list_for_csv(csv_data))

        if is_new:
            # Fix: the original `json.dump(infos, open(...))` leaked the file
            # handle; a context manager guarantees the file is closed (and
            # flushed) before the push below.
            with open(self.infos_file, "w") as infos_f:
                json.dump(infos, infos_f)

        with open(self.log_file, "r", encoding="utf-8") as csvfile:
            line_count = len([None for row in csv.reader(csvfile)]) - 1

        self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))

        return line_count
|
398 |
+
|
399 |
+
|
400 |
+
class HuggingFaceDatasetJSONSaver(FlaggingCallback):
    """
    A FlaggingCallback that saves flagged data to a Hugging Face dataset in JSONL format.

    Each data sample is saved in a different JSONL file,
    allowing multiple users to use flagging simultaneously.
    Saving to a single CSV would cause errors as only one user can edit at the same time.

    """

    def __init__(
        self,
        hf_foken: str,  # NOTE(review): typo for "hf_token", kept as-is — renaming would break keyword callers
        dataset_name: str,
        organization: str | None = None,
        private: bool = False,
        verbose: bool = True,
    ):
        """
        Params:
        hf_token (str): The token to use to access the huggingface API.
        dataset_name (str): The name of the dataset to save the data to, e.g.
            "image-classifier-1"
        organization (str): The name of the organization to which to attach
            the datasets. If None, the dataset attaches to the user only.
        private (bool): If the dataset does not already exist, whether it
            should be created as a private dataset or public. Private datasets
            may require paid huggingface.co accounts
        verbose (bool): Whether to print out the status of the dataset
            creation.
        """
        self.hf_foken = hf_foken
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private
        self.verbose = verbose

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Params:
        components List[Component]: list of components for flagging
        flagging_dir (str): local directory where the dataset is cloned,
        updated, and pushed from.
        """
        try:
            import huggingface_hub
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
            )
        # Create the dataset repo (no-op if it already exists) and clone it locally.
        path_to_dataset_repo = huggingface_hub.create_repo(
            name=self.dataset_name,
            token=self.hf_foken,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = Path(flagging_dir) / self.dataset_name
        self.repo = huggingface_hub.Repository(
            local_dir=str(self.dataset_dir),
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_foken,
        )
        self.repo.git_pull(lfs=True)

        self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> str:
        """Save one flagged sample into its own folder, push to the Hub, and return the folder name."""
        self.repo.git_pull(lfs=True)

        # Generate unique folder for the flagged sample
        unique_name = self.get_unique_name()  # unique name for folder
        folder_name = (
            Path(self.dataset_dir) / unique_name
        )  # unique folder for specific example
        os.makedirs(folder_name)

        # Now uses the existence of `dataset_infos.json` to determine if new
        is_new = not Path(self.infos_file).exists()

        # File previews for certain input and output types
        infos, file_preview_types, _ = _get_dataset_features_info(
            is_new, self.components
        )

        # Generate the row and header corresponding to the flagged sample
        csv_data = []
        headers = []

        for component, sample in zip(self.components, flag_data):
            headers.append(component.label)

            try:
                save_dir = Path(folder_name) / utils.strip_invalid_filename_characters(
                    component.label or ""
                )
                filepath = component.deserialize(sample, save_dir, None)
            except Exception:
                # Could not parse 'sample' (mostly) because it was None and `component.save_flagged`
                # does not handle None cases.
                # for example: Label (line 3109 of components.py raises an error if data is None)
                filepath = None

            if isinstance(component, tuple(file_preview_types)):
                # Fix: the original wrote `component.label or "" + " file"`, which
                # parses as `component.label or (" file")` and drops the " file"
                # suffix whenever a label is set. Parenthesized to match the
                # header naming in _get_dataset_features_info.
                headers.append((component.label or "") + " file")

                csv_data.append(
                    "{}/resolve/main/{}/{}".format(
                        self.path_to_dataset_repo, unique_name, filepath
                    )
                    if filepath is not None
                    else None
                )

            csv_data.append(filepath)
        headers.append("flag")
        csv_data.append(flag_option if flag_option is not None else "")

        # Creates metadata dict from row data and dumps it
        metadata_dict = {
            header: _csv_data for header, _csv_data in zip(headers, csv_data)
        }
        self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")

        if is_new:
            # Fix: the original `json.dump(infos, open(...))` leaked the file
            # handle; close (and flush) the infos file before pushing.
            with open(self.infos_file, "w") as infos_f:
                json.dump(infos, infos_f)

        self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name))
        return unique_name

    def get_unique_name(self):
        # Random UUID per sample; avoids shadowing the builtin `id` as a local.
        return str(uuid.uuid4())

    def dump_json(self, thing: dict, file_path: str | Path) -> None:
        """Serialize `thing` as JSON to `file_path`."""
        with open(file_path, "w+", encoding="utf8") as f:
            json.dump(thing, f)
|
547 |
+
|
548 |
+
|
549 |
+
class FlagMethod:
    """
    Callable helper binding one flag-button option to a flagging callback.
    Calling an instance forwards the received component values to
    ``flagging_callback.flag`` with the stored ``flag_option``.
    """

    def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
        self.flagging_callback = flagging_callback
        self.flag_option = flag_option
        # Event handlers are expected to expose a name, like a function would.
        self.__name__ = "Flag"

    def __call__(self, *flag_data):
        samples = list(flag_data)
        self.flagging_callback.flag(samples, flag_option=self.flag_option)
|
gradio-modified/gradio/helpers.py
ADDED
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Defines helper methods useful for loading and caching Interface examples.
|
3 |
+
"""
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
import ast
|
7 |
+
import csv
|
8 |
+
import inspect
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import tempfile
|
12 |
+
import threading
|
13 |
+
import warnings
|
14 |
+
from pathlib import Path
|
15 |
+
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple
|
16 |
+
|
17 |
+
import matplotlib
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import numpy as np
|
20 |
+
import PIL
|
21 |
+
|
22 |
+
from gradio import processing_utils, routes, utils
|
23 |
+
from gradio.context import Context
|
24 |
+
from gradio.documentation import document, set_documentation_group
|
25 |
+
from gradio.flagging import CSVLogger
|
26 |
+
|
27 |
+
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
28 |
+
from gradio.components import IOComponent
|
29 |
+
|
30 |
+
CACHED_FOLDER = "gradio_cached_examples"
|
31 |
+
LOG_FILE = "log.csv"
|
32 |
+
|
33 |
+
set_documentation_group("helpers")
|
34 |
+
|
35 |
+
|
36 |
+
def create_examples(
    examples: List[Any] | List[List[Any]] | str,
    inputs: IOComponent | List[IOComponent],
    outputs: IOComponent | List[IOComponent] | None = None,
    fn: Callable | None = None,
    cache_examples: bool = False,
    examples_per_page: int = 10,
    _api_mode: bool = False,
    label: str | None = None,
    elem_id: str | None = None,
    run_on_click: bool = False,
    preprocess: bool = True,
    postprocess: bool = True,
    batch: bool = False,
):
    """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
    # Instantiate without the direct-construction warning, then drive the async
    # `create` step to completion so callers get a fully-initialized object.
    instance = Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=fn,
        cache_examples=cache_examples,
        examples_per_page=examples_per_page,
        _api_mode=_api_mode,
        label=label,
        elem_id=elem_id,
        run_on_click=run_on_click,
        preprocess=preprocess,
        postprocess=postprocess,
        batch=batch,
        _initiated_directly=False,
    )
    utils.synchronize_async(instance.create)
    return instance
|
70 |
+
|
71 |
+
|
72 |
+
@document()
|
73 |
+
class Examples:
|
74 |
+
"""
|
75 |
+
This class is a wrapper over the Dataset component and can be used to create Examples
|
76 |
+
for Blocks / Interfaces. Populates the Dataset component with examples and
|
77 |
+
assigns event listener so that clicking on an example populates the input/output
|
78 |
+
components. Optionally handles example caching for fast inference.
|
79 |
+
|
80 |
+
Demos: blocks_inputs, fake_gan
|
81 |
+
Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
examples: List[Any] | List[List[Any]] | str,
|
87 |
+
inputs: IOComponent | List[IOComponent],
|
88 |
+
outputs: Optional[IOComponent | List[IOComponent]] = None,
|
89 |
+
fn: Optional[Callable] = None,
|
90 |
+
cache_examples: bool = False,
|
91 |
+
examples_per_page: int = 10,
|
92 |
+
_api_mode: bool = False,
|
93 |
+
label: str = "Examples",
|
94 |
+
elem_id: Optional[str] = None,
|
95 |
+
run_on_click: bool = False,
|
96 |
+
preprocess: bool = True,
|
97 |
+
postprocess: bool = True,
|
98 |
+
batch: bool = False,
|
99 |
+
_initiated_directly: bool = True,
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
Parameters:
|
103 |
+
examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
|
104 |
+
inputs: the component or list of components corresponding to the examples
|
105 |
+
outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
|
106 |
+
fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
|
107 |
+
cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
|
108 |
+
examples_per_page: how many examples to show per page.
|
109 |
+
label: the label to use for the examples component (by default, "Examples")
|
110 |
+
elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
|
111 |
+
run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
|
112 |
+
preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
|
113 |
+
postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
|
114 |
+
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
|
115 |
+
"""
|
116 |
+
if _initiated_directly:
|
117 |
+
warnings.warn(
|
118 |
+
"Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
|
119 |
+
)
|
120 |
+
|
121 |
+
if cache_examples and (fn is None or outputs is None):
|
122 |
+
raise ValueError("If caching examples, `fn` and `outputs` must be provided")
|
123 |
+
|
124 |
+
if not isinstance(inputs, list):
|
125 |
+
inputs = [inputs]
|
126 |
+
if not isinstance(outputs, list):
|
127 |
+
outputs = [outputs]
|
128 |
+
|
129 |
+
working_directory = Path().absolute()
|
130 |
+
|
131 |
+
if examples is None:
|
132 |
+
raise ValueError("The parameter `examples` cannot be None")
|
133 |
+
elif isinstance(examples, list) and (
|
134 |
+
len(examples) == 0 or isinstance(examples[0], list)
|
135 |
+
):
|
136 |
+
pass
|
137 |
+
elif (
|
138 |
+
isinstance(examples, list) and len(inputs) == 1
|
139 |
+
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
|
140 |
+
examples = [[e] for e in examples]
|
141 |
+
elif isinstance(examples, str):
|
142 |
+
if not os.path.exists(examples):
|
143 |
+
raise FileNotFoundError(
|
144 |
+
"Could not find examples directory: " + examples
|
145 |
+
)
|
146 |
+
working_directory = examples
|
147 |
+
if not os.path.exists(os.path.join(examples, LOG_FILE)):
|
148 |
+
if len(inputs) == 1:
|
149 |
+
examples = [[e] for e in os.listdir(examples)]
|
150 |
+
else:
|
151 |
+
raise FileNotFoundError(
|
152 |
+
"Could not find log file (required for multiple inputs): "
|
153 |
+
+ LOG_FILE
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
with open(os.path.join(examples, LOG_FILE)) as logs:
|
157 |
+
examples = list(csv.reader(logs))
|
158 |
+
examples = [
|
159 |
+
examples[i][: len(inputs)] for i in range(1, len(examples))
|
160 |
+
] # remove header and unnecessary columns
|
161 |
+
|
162 |
+
else:
|
163 |
+
raise ValueError(
|
164 |
+
"The parameter `examples` must either be a string directory or a list"
|
165 |
+
"(if there is only 1 input component) or (more generally), a nested "
|
166 |
+
"list, where each sublist represents a set of inputs."
|
167 |
+
)
|
168 |
+
|
169 |
+
input_has_examples = [False] * len(inputs)
|
170 |
+
for example in examples:
|
171 |
+
for idx, example_for_input in enumerate(example):
|
172 |
+
if not (example_for_input is None):
|
173 |
+
try:
|
174 |
+
input_has_examples[idx] = True
|
175 |
+
except IndexError:
|
176 |
+
pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
|
177 |
+
|
178 |
+
inputs_with_examples = [
|
179 |
+
inp for (inp, keep) in zip(inputs, input_has_examples) if keep
|
180 |
+
]
|
181 |
+
non_none_examples = [
|
182 |
+
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
|
183 |
+
for example in examples
|
184 |
+
]
|
185 |
+
|
186 |
+
self.examples = examples
|
187 |
+
self.non_none_examples = non_none_examples
|
188 |
+
self.inputs = inputs
|
189 |
+
self.inputs_with_examples = inputs_with_examples
|
190 |
+
self.outputs = outputs
|
191 |
+
self.fn = fn
|
192 |
+
self.cache_examples = cache_examples
|
193 |
+
self._api_mode = _api_mode
|
194 |
+
self.preprocess = preprocess
|
195 |
+
self.postprocess = postprocess
|
196 |
+
self.batch = batch
|
197 |
+
|
198 |
+
with utils.set_directory(working_directory):
|
199 |
+
self.processed_examples = [
|
200 |
+
[
|
201 |
+
component.postprocess(sample)
|
202 |
+
for component, sample in zip(inputs, example)
|
203 |
+
]
|
204 |
+
for example in examples
|
205 |
+
]
|
206 |
+
self.non_none_processed_examples = [
|
207 |
+
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
|
208 |
+
for example in self.processed_examples
|
209 |
+
]
|
210 |
+
if cache_examples:
|
211 |
+
for example in self.examples:
|
212 |
+
if len([ex for ex in example if ex is not None]) != len(self.inputs):
|
213 |
+
warnings.warn(
|
214 |
+
"Examples are being cached but not all input components have "
|
215 |
+
"example values. This may result in an exception being thrown by "
|
216 |
+
"your function. If you do get an error while caching examples, make "
|
217 |
+
"sure all of your inputs have example values for all of your examples "
|
218 |
+
"or you provide default values for those particular parameters in your function."
|
219 |
+
)
|
220 |
+
break
|
221 |
+
|
222 |
+
from gradio.components import Dataset
|
223 |
+
|
224 |
+
with utils.set_directory(working_directory):
|
225 |
+
self.dataset = Dataset(
|
226 |
+
components=inputs_with_examples,
|
227 |
+
samples=non_none_examples,
|
228 |
+
type="index",
|
229 |
+
label=label,
|
230 |
+
samples_per_page=examples_per_page,
|
231 |
+
elem_id=elem_id,
|
232 |
+
)
|
233 |
+
|
234 |
+
self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id))
|
235 |
+
self.cached_file = os.path.join(self.cached_folder, "log.csv")
|
236 |
+
self.cache_examples = cache_examples
|
237 |
+
self.run_on_click = run_on_click
|
238 |
+
|
239 |
+
async def create(self) -> None:
|
240 |
+
"""Caches the examples if self.cache_examples is True and creates the Dataset
|
241 |
+
component to hold the examples"""
|
242 |
+
|
243 |
+
async def load_example(example_id):
|
244 |
+
if self.cache_examples:
|
245 |
+
processed_example = self.non_none_processed_examples[
|
246 |
+
example_id
|
247 |
+
] + await self.load_from_cache(example_id)
|
248 |
+
else:
|
249 |
+
processed_example = self.non_none_processed_examples[example_id]
|
250 |
+
return utils.resolve_singleton(processed_example)
|
251 |
+
|
252 |
+
if Context.root_block:
|
253 |
+
self.dataset.click(
|
254 |
+
load_example,
|
255 |
+
inputs=[self.dataset],
|
256 |
+
outputs=self.inputs_with_examples
|
257 |
+
+ (self.outputs if self.cache_examples else []),
|
258 |
+
postprocess=False,
|
259 |
+
queue=False,
|
260 |
+
)
|
261 |
+
if self.run_on_click and not self.cache_examples:
|
262 |
+
self.dataset.click(
|
263 |
+
self.fn,
|
264 |
+
inputs=self.inputs,
|
265 |
+
outputs=self.outputs,
|
266 |
+
)
|
267 |
+
|
268 |
+
if self.cache_examples:
|
269 |
+
await self.cache()
|
270 |
+
|
271 |
+
async def cache(self) -> None:
|
272 |
+
"""
|
273 |
+
Caches all of the examples so that their predictions can be shown immediately.
|
274 |
+
"""
|
275 |
+
if os.path.exists(self.cached_file):
|
276 |
+
print(
|
277 |
+
f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
if Context.root_block is None:
|
281 |
+
raise ValueError("Cannot cache examples if not in a Blocks context")
|
282 |
+
|
283 |
+
print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
|
284 |
+
cache_logger = CSVLogger()
|
285 |
+
|
286 |
+
# create a fake dependency to process the examples and get the predictions
|
287 |
+
dependency = Context.root_block.set_event_trigger(
|
288 |
+
event_name="fake_event",
|
289 |
+
fn=self.fn,
|
290 |
+
inputs=self.inputs_with_examples,
|
291 |
+
outputs=self.outputs,
|
292 |
+
preprocess=self.preprocess and not self._api_mode,
|
293 |
+
postprocess=self.postprocess and not self._api_mode,
|
294 |
+
batch=self.batch,
|
295 |
+
)
|
296 |
+
|
297 |
+
fn_index = Context.root_block.dependencies.index(dependency)
|
298 |
+
cache_logger.setup(self.outputs, self.cached_folder)
|
299 |
+
for example_id, _ in enumerate(self.examples):
|
300 |
+
processed_input = self.processed_examples[example_id]
|
301 |
+
if self.batch:
|
302 |
+
processed_input = [[value] for value in processed_input]
|
303 |
+
prediction = await Context.root_block.process_api(
|
304 |
+
fn_index=fn_index, inputs=processed_input, request=None, state={}
|
305 |
+
)
|
306 |
+
output = prediction["data"]
|
307 |
+
if self.batch:
|
308 |
+
output = [value[0] for value in output]
|
309 |
+
cache_logger.flag(output)
|
310 |
+
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
|
311 |
+
Context.root_block.dependencies.remove(dependency)
|
312 |
+
Context.root_block.fns.pop(fn_index)
|
313 |
+
|
314 |
+
async def load_from_cache(self, example_id: int) -> List[Any]:
|
315 |
+
"""Loads a particular cached example for the interface.
|
316 |
+
Parameters:
|
317 |
+
example_id: The id of the example to process (zero-indexed).
|
318 |
+
"""
|
319 |
+
with open(self.cached_file) as cache:
|
320 |
+
examples = list(csv.reader(cache))
|
321 |
+
example = examples[example_id + 1] # +1 to adjust for header
|
322 |
+
output = []
|
323 |
+
for component, value in zip(self.outputs, example):
|
324 |
+
try:
|
325 |
+
value_as_dict = ast.literal_eval(value)
|
326 |
+
assert utils.is_update(value_as_dict)
|
327 |
+
output.append(value_as_dict)
|
328 |
+
except (ValueError, TypeError, SyntaxError, AssertionError):
|
329 |
+
output.append(component.serialize(value, self.cached_folder))
|
330 |
+
return output
|
331 |
+
|
332 |
+
|
333 |
+
class TrackedIterable:
|
334 |
+
def __init__(
|
335 |
+
self,
|
336 |
+
iterable: Iterable,
|
337 |
+
index: int | None,
|
338 |
+
length: int | None,
|
339 |
+
desc: str | None,
|
340 |
+
unit: str | None,
|
341 |
+
_tqdm=None,
|
342 |
+
progress: float = None,
|
343 |
+
) -> None:
|
344 |
+
self.iterable = iterable
|
345 |
+
self.index = index
|
346 |
+
self.length = length
|
347 |
+
self.desc = desc
|
348 |
+
self.unit = unit
|
349 |
+
self._tqdm = _tqdm
|
350 |
+
self.progress = progress
|
351 |
+
|
352 |
+
|
353 |
+
@document("__call__", "tqdm")
|
354 |
+
class Progress(Iterable):
|
355 |
+
"""
|
356 |
+
The Progress class provides a custom progress tracker that is used in a function signature.
|
357 |
+
To attach a Progress tracker to a function, simply add a parameter right after the input parameters that has a default value set to a `gradio.Progress()` instance.
|
358 |
+
The Progress tracker can then be updated in the function by calling the Progress object or using the `tqdm` method on an Iterable.
|
359 |
+
The Progress tracker is currently only available with `queue()`.
|
360 |
+
Example:
|
361 |
+
import gradio as gr
|
362 |
+
import time
|
363 |
+
def my_function(x, progress=gr.Progress()):
|
364 |
+
progress(0, desc="Starting...")
|
365 |
+
time.sleep(1)
|
366 |
+
for i in progress.tqdm(range(100)):
|
367 |
+
time.sleep(0.1)
|
368 |
+
return x
|
369 |
+
gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue().launch()
|
370 |
+
Demos: progress
|
371 |
+
"""
|
372 |
+
|
373 |
+
def __init__(
|
374 |
+
self,
|
375 |
+
track_tqdm: bool = False,
|
376 |
+
_active: bool = False,
|
377 |
+
_callback: Callable = None,
|
378 |
+
_event_id: str = None,
|
379 |
+
):
|
380 |
+
"""
|
381 |
+
Parameters:
|
382 |
+
track_tqdm: If True, the Progress object will track any tqdm.tqdm iterations with the tqdm library in the function.
|
383 |
+
"""
|
384 |
+
self.track_tqdm = track_tqdm
|
385 |
+
self._active = _active
|
386 |
+
self._callback = _callback
|
387 |
+
self._event_id = _event_id
|
388 |
+
self.iterables: List[TrackedIterable] = []
|
389 |
+
|
390 |
+
def __len__(self):
|
391 |
+
return self.iterables[-1].length
|
392 |
+
|
393 |
+
def __iter__(self):
|
394 |
+
return self
|
395 |
+
|
396 |
+
def __next__(self):
|
397 |
+
"""
|
398 |
+
Updates progress tracker with next item in iterable.
|
399 |
+
"""
|
400 |
+
if self._active:
|
401 |
+
current_iterable = self.iterables[-1]
|
402 |
+
while (
|
403 |
+
not hasattr(current_iterable.iterable, "__next__")
|
404 |
+
and len(self.iterables) > 0
|
405 |
+
):
|
406 |
+
current_iterable = self.iterables.pop()
|
407 |
+
self._callback(
|
408 |
+
event_id=self._event_id,
|
409 |
+
iterables=self.iterables,
|
410 |
+
)
|
411 |
+
current_iterable.index += 1
|
412 |
+
try:
|
413 |
+
return next(current_iterable.iterable)
|
414 |
+
except StopIteration:
|
415 |
+
self.iterables.pop()
|
416 |
+
raise StopIteration
|
417 |
+
else:
|
418 |
+
return self
|
419 |
+
|
420 |
+
def __call__(
|
421 |
+
self,
|
422 |
+
progress: float | Tuple[int, int | None] | None,
|
423 |
+
desc: str | None = None,
|
424 |
+
total: float | None = None,
|
425 |
+
unit: str = "steps",
|
426 |
+
_tqdm=None,
|
427 |
+
):
|
428 |
+
"""
|
429 |
+
Updates progress tracker with progress and message text.
|
430 |
+
Parameters:
|
431 |
+
progress: If float, should be between 0 and 1 representing completion. If Tuple, first number represents steps completed, and second value represents total steps or None if unknown. If None, hides progress bar.
|
432 |
+
desc: description to display.
|
433 |
+
total: estimated total number of steps.
|
434 |
+
unit: unit of iterations.
|
435 |
+
"""
|
436 |
+
if self._active:
|
437 |
+
if isinstance(progress, tuple):
|
438 |
+
index, total = progress
|
439 |
+
progress = None
|
440 |
+
else:
|
441 |
+
index = None
|
442 |
+
self._callback(
|
443 |
+
event_id=self._event_id,
|
444 |
+
iterables=self.iterables
|
445 |
+
+ [TrackedIterable(None, index, total, desc, unit, _tqdm, progress)],
|
446 |
+
)
|
447 |
+
else:
|
448 |
+
return progress
|
449 |
+
|
450 |
+
def tqdm(
|
451 |
+
self,
|
452 |
+
iterable: Iterable | None,
|
453 |
+
desc: str = None,
|
454 |
+
total: float = None,
|
455 |
+
unit: str = "steps",
|
456 |
+
_tqdm=None,
|
457 |
+
*args,
|
458 |
+
**kwargs,
|
459 |
+
):
|
460 |
+
"""
|
461 |
+
Attaches progress tracker to iterable, like tqdm.
|
462 |
+
Parameters:
|
463 |
+
iterable: iterable to attach progress tracker to.
|
464 |
+
desc: description to display.
|
465 |
+
total: estimated total number of steps.
|
466 |
+
unit: unit of iterations.
|
467 |
+
"""
|
468 |
+
if iterable is None:
|
469 |
+
new_iterable = TrackedIterable(None, 0, total, desc, unit, _tqdm)
|
470 |
+
self.iterables.append(new_iterable)
|
471 |
+
self._callback(event_id=self._event_id, iterables=self.iterables)
|
472 |
+
return
|
473 |
+
length = len(iterable) if hasattr(iterable, "__len__") else None
|
474 |
+
self.iterables.append(
|
475 |
+
TrackedIterable(iter(iterable), 0, length, desc, unit, _tqdm)
|
476 |
+
)
|
477 |
+
return self
|
478 |
+
|
479 |
+
def update(self, n=1):
|
480 |
+
"""
|
481 |
+
Increases latest iterable with specified number of steps.
|
482 |
+
Parameters:
|
483 |
+
n: number of steps completed.
|
484 |
+
"""
|
485 |
+
if self._active and len(self.iterables) > 0:
|
486 |
+
current_iterable = self.iterables[-1]
|
487 |
+
current_iterable.index += n
|
488 |
+
self._callback(
|
489 |
+
event_id=self._event_id,
|
490 |
+
iterables=self.iterables,
|
491 |
+
)
|
492 |
+
else:
|
493 |
+
return
|
494 |
+
|
495 |
+
def close(self, _tqdm):
|
496 |
+
"""
|
497 |
+
Removes iterable with given _tqdm.
|
498 |
+
"""
|
499 |
+
if self._active:
|
500 |
+
for i in range(len(self.iterables)):
|
501 |
+
if id(self.iterables[i]._tqdm) == id(_tqdm):
|
502 |
+
self.iterables.pop(i)
|
503 |
+
break
|
504 |
+
self._callback(
|
505 |
+
event_id=self._event_id,
|
506 |
+
iterables=self.iterables,
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
return
|
510 |
+
|
511 |
+
|
512 |
+
def create_tracker(root_blocks, event_id, fn, track_tqdm):
|
513 |
+
|
514 |
+
progress = Progress(
|
515 |
+
_active=True, _callback=root_blocks._queue.set_progress, _event_id=event_id
|
516 |
+
)
|
517 |
+
if not track_tqdm:
|
518 |
+
return progress, fn
|
519 |
+
|
520 |
+
try:
|
521 |
+
_tqdm = __import__("tqdm")
|
522 |
+
except ModuleNotFoundError:
|
523 |
+
return progress, fn
|
524 |
+
if not hasattr(root_blocks, "_progress_tracker_per_thread"):
|
525 |
+
root_blocks._progress_tracker_per_thread = {}
|
526 |
+
|
527 |
+
def init_tqdm(self, iterable=None, desc=None, *args, **kwargs):
|
528 |
+
self._progress = root_blocks._progress_tracker_per_thread.get(
|
529 |
+
threading.get_ident()
|
530 |
+
)
|
531 |
+
if self._progress is not None:
|
532 |
+
self._progress.event_id = event_id
|
533 |
+
self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
|
534 |
+
kwargs["file"] = open(os.devnull, "w")
|
535 |
+
self.__init__orig__(iterable, desc, *args, **kwargs)
|
536 |
+
|
537 |
+
def iter_tqdm(self):
|
538 |
+
if self._progress is not None:
|
539 |
+
return self._progress
|
540 |
+
else:
|
541 |
+
return self.__iter__orig__()
|
542 |
+
|
543 |
+
def update_tqdm(self, n=1):
|
544 |
+
if self._progress is not None:
|
545 |
+
self._progress.update(n)
|
546 |
+
return self.__update__orig__(n)
|
547 |
+
|
548 |
+
def close_tqdm(self):
|
549 |
+
if self._progress is not None:
|
550 |
+
self._progress.close(self)
|
551 |
+
return self.__close__orig__()
|
552 |
+
|
553 |
+
def exit_tqdm(self, exc_type, exc_value, traceback):
|
554 |
+
if self._progress is not None:
|
555 |
+
self._progress.close(self)
|
556 |
+
return self.__exit__orig__(exc_type, exc_value, traceback)
|
557 |
+
|
558 |
+
if not hasattr(_tqdm.tqdm, "__init__orig__"):
|
559 |
+
_tqdm.tqdm.__init__orig__ = _tqdm.tqdm.__init__
|
560 |
+
_tqdm.tqdm.__init__ = init_tqdm
|
561 |
+
if not hasattr(_tqdm.tqdm, "__update__orig__"):
|
562 |
+
_tqdm.tqdm.__update__orig__ = _tqdm.tqdm.update
|
563 |
+
_tqdm.tqdm.update = update_tqdm
|
564 |
+
if not hasattr(_tqdm.tqdm, "__close__orig__"):
|
565 |
+
_tqdm.tqdm.__close__orig__ = _tqdm.tqdm.close
|
566 |
+
_tqdm.tqdm.close = close_tqdm
|
567 |
+
if not hasattr(_tqdm.tqdm, "__exit__orig__"):
|
568 |
+
_tqdm.tqdm.__exit__orig__ = _tqdm.tqdm.__exit__
|
569 |
+
_tqdm.tqdm.__exit__ = exit_tqdm
|
570 |
+
if not hasattr(_tqdm.tqdm, "__iter__orig__"):
|
571 |
+
_tqdm.tqdm.__iter__orig__ = _tqdm.tqdm.__iter__
|
572 |
+
_tqdm.tqdm.__iter__ = iter_tqdm
|
573 |
+
if hasattr(_tqdm, "auto") and hasattr(_tqdm.auto, "tqdm"):
|
574 |
+
_tqdm.auto.tqdm = _tqdm.tqdm
|
575 |
+
|
576 |
+
def tracked_fn(*args):
|
577 |
+
thread_id = threading.get_ident()
|
578 |
+
root_blocks._progress_tracker_per_thread[thread_id] = progress
|
579 |
+
response = fn(*args)
|
580 |
+
del root_blocks._progress_tracker_per_thread[thread_id]
|
581 |
+
return response
|
582 |
+
|
583 |
+
return progress, tracked_fn
|
584 |
+
|
585 |
+
|
586 |
+
def special_args(
|
587 |
+
fn: Callable,
|
588 |
+
inputs: List[Any] | None = None,
|
589 |
+
request: routes.Request | None = None,
|
590 |
+
):
|
591 |
+
"""
|
592 |
+
Checks if function has special arguments Request (via annotation) or Progress (via default value).
|
593 |
+
If inputs is provided, these values will be loaded into the inputs array.
|
594 |
+
Parameters:
|
595 |
+
block_fn: function to check.
|
596 |
+
inputs: array to load special arguments into.
|
597 |
+
request: request to load into inputs.
|
598 |
+
Returns:
|
599 |
+
updated inputs, request index, progress index
|
600 |
+
"""
|
601 |
+
signature = inspect.signature(fn)
|
602 |
+
positional_args = []
|
603 |
+
for i, param in enumerate(signature.parameters.values()):
|
604 |
+
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
605 |
+
break
|
606 |
+
positional_args.append(param)
|
607 |
+
progress_index = None
|
608 |
+
for i, param in enumerate(positional_args):
|
609 |
+
if isinstance(param.default, Progress):
|
610 |
+
progress_index = i
|
611 |
+
if inputs is not None:
|
612 |
+
inputs.insert(i, param.default)
|
613 |
+
elif param.annotation == routes.Request:
|
614 |
+
if inputs is not None:
|
615 |
+
inputs.insert(i, request)
|
616 |
+
if inputs is not None:
|
617 |
+
while len(inputs) < len(positional_args):
|
618 |
+
i = len(inputs)
|
619 |
+
param = positional_args[i]
|
620 |
+
if param.default == param.empty:
|
621 |
+
warnings.warn("Unexpected argument. Filling with None.")
|
622 |
+
inputs.append(None)
|
623 |
+
else:
|
624 |
+
inputs.append(param.default)
|
625 |
+
return inputs or [], progress_index
|
626 |
+
|
627 |
+
|
628 |
+
@document()
|
629 |
+
def update(**kwargs) -> dict:
|
630 |
+
"""
|
631 |
+
Updates component properties. When a function passed into a Gradio Interface or a Blocks events returns a typical value, it updates the value of the output component. But it is also possible to update the properties of an output component (such as the number of lines of a `Textbox` or the visibility of an `Image`) by returning the component's `update()` function, which takes as parameters any of the constructor parameters for that component.
|
632 |
+
This is a shorthand for using the update method on a component.
|
633 |
+
For example, rather than using gr.Number.update(...) you can just use gr.update(...).
|
634 |
+
Note that your editor's autocompletion will suggest proper parameters
|
635 |
+
if you use the update method on the component.
|
636 |
+
Demos: blocks_essay, blocks_update, blocks_essay_update
|
637 |
+
|
638 |
+
Parameters:
|
639 |
+
kwargs: Key-word arguments used to update the component's properties.
|
640 |
+
Example:
|
641 |
+
# Blocks Example
|
642 |
+
import gradio as gr
|
643 |
+
with gr.Blocks() as demo:
|
644 |
+
radio = gr.Radio([1, 2, 4], label="Set the value of the number")
|
645 |
+
number = gr.Number(value=2, interactive=True)
|
646 |
+
radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number)
|
647 |
+
demo.launch()
|
648 |
+
|
649 |
+
# Interface example
|
650 |
+
import gradio as gr
|
651 |
+
def change_textbox(choice):
|
652 |
+
if choice == "short":
|
653 |
+
return gr.Textbox.update(lines=2, visible=True)
|
654 |
+
elif choice == "long":
|
655 |
+
return gr.Textbox.update(lines=8, visible=True)
|
656 |
+
else:
|
657 |
+
return gr.Textbox.update(visible=False)
|
658 |
+
gr.Interface(
|
659 |
+
change_textbox,
|
660 |
+
gr.Radio(
|
661 |
+
["short", "long", "none"], label="What kind of essay would you like to write?"
|
662 |
+
),
|
663 |
+
gr.Textbox(lines=2),
|
664 |
+
live=True,
|
665 |
+
).launch()
|
666 |
+
"""
|
667 |
+
kwargs["__type__"] = "generic_update"
|
668 |
+
return kwargs
|
669 |
+
|
670 |
+
|
671 |
+
def skip() -> dict:
|
672 |
+
return update()
|
673 |
+
|
674 |
+
|
675 |
+
@document()
|
676 |
+
def make_waveform(
|
677 |
+
audio: str | Tuple[int, np.ndarray],
|
678 |
+
*,
|
679 |
+
bg_color: str = "#f3f4f6",
|
680 |
+
bg_image: str = None,
|
681 |
+
fg_alpha: float = 0.75,
|
682 |
+
bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
|
683 |
+
bar_count: int = 50,
|
684 |
+
bar_width: float = 0.6,
|
685 |
+
):
|
686 |
+
"""
|
687 |
+
Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
|
688 |
+
Parameters:
|
689 |
+
audio: Audio file path or tuple of (sample_rate, audio_data)
|
690 |
+
bg_color: Background color of waveform (ignored if bg_image is provided)
|
691 |
+
bg_image: Background image of waveform
|
692 |
+
fg_alpha: Opacity of foreground waveform
|
693 |
+
bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
|
694 |
+
bar_count: Number of bars in waveform
|
695 |
+
bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
|
696 |
+
Returns:
|
697 |
+
A filepath to the output video.
|
698 |
+
"""
|
699 |
+
if isinstance(audio, str):
|
700 |
+
audio_file = audio
|
701 |
+
audio = processing_utils.audio_from_file(audio)
|
702 |
+
else:
|
703 |
+
tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
704 |
+
processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name)
|
705 |
+
audio_file = tmp_wav.name
|
706 |
+
duration = round(len(audio[1]) / audio[0], 4)
|
707 |
+
|
708 |
+
# Helper methods to create waveform
|
709 |
+
def hex_to_RGB(hex_str):
|
710 |
+
return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
|
711 |
+
|
712 |
+
def get_color_gradient(c1, c2, n):
|
713 |
+
assert n > 1
|
714 |
+
c1_rgb = np.array(hex_to_RGB(c1)) / 255
|
715 |
+
c2_rgb = np.array(hex_to_RGB(c2)) / 255
|
716 |
+
mix_pcts = [x / (n - 1) for x in range(n)]
|
717 |
+
rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
|
718 |
+
return [
|
719 |
+
"#" + "".join([format(int(round(val * 255)), "02x") for val in item])
|
720 |
+
for item in rgb_colors
|
721 |
+
]
|
722 |
+
|
723 |
+
# Reshape audio to have a fixed number of bars
|
724 |
+
samples = audio[1]
|
725 |
+
if len(samples.shape) > 1:
|
726 |
+
samples = np.mean(samples, 1)
|
727 |
+
bins_to_pad = bar_count - (len(samples) % bar_count)
|
728 |
+
samples = np.pad(samples, [(0, bins_to_pad)])
|
729 |
+
samples = np.reshape(samples, (bar_count, -1))
|
730 |
+
samples = np.abs(samples)
|
731 |
+
samples = np.max(samples, 1)
|
732 |
+
|
733 |
+
matplotlib.use("Agg")
|
734 |
+
plt.clf()
|
735 |
+
# Plot waveform
|
736 |
+
color = (
|
737 |
+
bars_color
|
738 |
+
if isinstance(bars_color, str)
|
739 |
+
else get_color_gradient(bars_color[0], bars_color[1], bar_count)
|
740 |
+
)
|
741 |
+
plt.bar(
|
742 |
+
np.arange(0, bar_count),
|
743 |
+
samples * 2,
|
744 |
+
bottom=(-1 * samples),
|
745 |
+
width=bar_width,
|
746 |
+
color=color,
|
747 |
+
)
|
748 |
+
plt.axis("off")
|
749 |
+
plt.margins(x=0)
|
750 |
+
tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
751 |
+
savefig_kwargs = {"bbox_inches": "tight"}
|
752 |
+
if bg_image is not None:
|
753 |
+
savefig_kwargs["transparent"] = True
|
754 |
+
else:
|
755 |
+
savefig_kwargs["facecolor"] = bg_color
|
756 |
+
plt.savefig(tmp_img.name, **savefig_kwargs)
|
757 |
+
waveform_img = PIL.Image.open(tmp_img.name)
|
758 |
+
waveform_img = waveform_img.resize((1000, 200))
|
759 |
+
|
760 |
+
# Composite waveform with background image
|
761 |
+
if bg_image is not None:
|
762 |
+
waveform_array = np.array(waveform_img)
|
763 |
+
waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
|
764 |
+
waveform_img = PIL.Image.fromarray(waveform_array)
|
765 |
+
|
766 |
+
bg_img = PIL.Image.open(bg_image)
|
767 |
+
waveform_width, waveform_height = waveform_img.size
|
768 |
+
bg_width, bg_height = bg_img.size
|
769 |
+
if waveform_width != bg_width:
|
770 |
+
bg_img = bg_img.resize(
|
771 |
+
(waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2))
|
772 |
+
)
|
773 |
+
bg_width, bg_height = bg_img.size
|
774 |
+
composite_height = max(bg_height, waveform_height)
|
775 |
+
composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF")
|
776 |
+
composite.paste(bg_img, (0, composite_height - bg_height))
|
777 |
+
composite.paste(
|
778 |
+
waveform_img, (0, composite_height - waveform_height), waveform_img
|
779 |
+
)
|
780 |
+
composite.save(tmp_img.name)
|
781 |
+
img_width, img_height = composite.size
|
782 |
+
else:
|
783 |
+
img_width, img_height = waveform_img.size
|
784 |
+
waveform_img.save(tmp_img.name)
|
785 |
+
|
786 |
+
# Convert waveform to video with ffmpeg
|
787 |
+
output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
788 |
+
|
789 |
+
ffmpeg_cmd = f"""ffmpeg -loop 1 -i {tmp_img.name} -i {audio_file} -vf "color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1" -t {duration} -y {output_mp4.name}"""
|
790 |
+
|
791 |
+
subprocess.call(ffmpeg_cmd, shell=True)
|
792 |
+
return output_mp4.name
|
gradio-modified/gradio/inputs.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# type: ignore
|
2 |
+
"""
|
3 |
+
This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
|
4 |
+
`InputComponent`, and each class must define a path to its template. All of the subclasses of `InputComponent` are
|
5 |
+
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from __future__ import annotations
|
9 |
+
|
10 |
+
import warnings
|
11 |
+
from typing import Any, List, Optional, Tuple
|
12 |
+
|
13 |
+
from gradio import components
|
14 |
+
|
15 |
+
|
16 |
+
class Textbox(components.Textbox):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
lines: int = 1,
|
20 |
+
placeholder: Optional[str] = None,
|
21 |
+
default: str = "",
|
22 |
+
numeric: Optional[bool] = False,
|
23 |
+
type: Optional[str] = "text",
|
24 |
+
label: Optional[str] = None,
|
25 |
+
optional: bool = False,
|
26 |
+
):
|
27 |
+
warnings.warn(
|
28 |
+
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
29 |
+
)
|
30 |
+
super().__init__(
|
31 |
+
value=default,
|
32 |
+
lines=lines,
|
33 |
+
placeholder=placeholder,
|
34 |
+
label=label,
|
35 |
+
numeric=numeric,
|
36 |
+
type=type,
|
37 |
+
optional=optional,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
class Number(components.Number):
|
42 |
+
"""
|
43 |
+
Component creates a field for user to enter numeric input. Provides a number as an argument to the wrapped function.
|
44 |
+
Input type: float
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
default: Optional[float] = None,
|
50 |
+
label: Optional[str] = None,
|
51 |
+
optional: bool = False,
|
52 |
+
):
|
53 |
+
"""
|
54 |
+
Parameters:
|
55 |
+
default (float): default value.
|
56 |
+
label (str): component name in interface.
|
57 |
+
optional (bool): If True, the interface can be submitted with no value for this component.
|
58 |
+
"""
|
59 |
+
warnings.warn(
|
60 |
+
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
61 |
+
)
|
62 |
+
super().__init__(value=default, label=label, optional=optional)
|
63 |
+
|
64 |
+
|
65 |
+
class Slider(components.Slider):
|
66 |
+
"""
|
67 |
+
Component creates a slider that ranges from `minimum` to `maximum`. Provides number as an argument to the wrapped function.
|
68 |
+
Input type: float
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
minimum: float = 0,
|
74 |
+
maximum: float = 100,
|
75 |
+
step: Optional[float] = None,
|
76 |
+
default: Optional[float] = None,
|
77 |
+
label: Optional[str] = None,
|
78 |
+
optional: bool = False,
|
79 |
+
):
|
80 |
+
"""
|
81 |
+
Parameters:
|
82 |
+
minimum (float): minimum value for slider.
|
83 |
+
maximum (float): maximum value for slider.
|
84 |
+
step (float): increment between slider values.
|
85 |
+
default (float): default value.
|
86 |
+
label (str): component name in interface.
|
87 |
+
optional (bool): this parameter is ignored.
|
88 |
+
"""
|
89 |
+
warnings.warn(
|
90 |
+
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
91 |
+
)
|
92 |
+
|
93 |
+
super().__init__(
|
94 |
+
value=default,
|
95 |
+
minimum=minimum,
|
96 |
+
maximum=maximum,
|
97 |
+
step=step,
|
98 |
+
label=label,
|
99 |
+
optional=optional,
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
+
class Checkbox(components.Checkbox):
|
104 |
+
"""
|
105 |
+
Component creates a checkbox that can be set to `True` or `False`. Provides a boolean as an argument to the wrapped function.
|
106 |
+
Input type: bool
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
default: bool = False,
|
112 |
+
label: Optional[str] = None,
|
113 |
+
optional: bool = False,
|
114 |
+
):
|
115 |
+
"""
|
116 |
+
Parameters:
|
117 |
+
label (str): component name in interface.
|
118 |
+
default (bool): if True, checked by default.
|
119 |
+
optional (bool): this parameter is ignored.
|
120 |
+
"""
|
121 |
+
warnings.warn(
|
122 |
+
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
123 |
+
)
|
124 |
+
super().__init__(value=default, label=label, optional=optional)
|
125 |
+
|
126 |
+
|
127 |
+
class CheckboxGroup(components.CheckboxGroup):
|
128 |
+
"""
|
129 |
+
Component creates a set of checkboxes of which a subset can be selected. Provides a list of strings representing the selected choices as an argument to the wrapped function.
|
130 |
+
Input type: Union[List[str], List[int]]
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
choices: List[str],
|
136 |
+
default: List[str] = [],
|
137 |
+
type: str = "value",
|
138 |
+
label: Optional[str] = None,
|
139 |
+
optional: bool = False,
|
140 |
+
):
|
141 |
+
"""
|
142 |
+
Parameters:
|
143 |
+
choices (List[str]): list of options to select from.
|
144 |
+
default (List[str]): default selected list of options.
|
145 |
+
type (str): Type of value to be returned by component. "value" returns the list of strings of the choices selected, "index" returns the list of indicies of the choices selected.
|
146 |
+
label (str): component name in interface.
|
147 |
+
optional (bool): this parameter is ignored.
|
148 |
+
"""
|
149 |
+
warnings.warn(
|
150 |
+
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
151 |
+
)
|
152 |
+
super().__init__(
|
153 |
+
value=default,
|
154 |
+
choices=choices,
|
155 |
+
type=type,
|
156 |
+
label=label,
|
157 |
+
optional=optional,
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
class Radio(components.Radio):
|
162 |
+
"""
|
163 |
+
Component creates a set of radio buttons of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
|
164 |
+
Input type: Union[str, int]
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
choices: List[str],
|
170 |
+
type: str = "value",
|
171 |
+
default: Optional[str] = None,
|
172 |
+
label: Optional[str] = None,
|
173 |
+
optional: bool = False,
|
174 |
+
):
|
175 |
+
"""
|
176 |
+
Parameters:
|
177 |
+
choices (List[str]): list of options to select from.
|
178 |
+
type (str): Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
|
179 |
+
default (str): the button selected by default. If None, no button is selected by default.
|
180 |
+
label (str): component name in interface.
|
181 |
+
optional (bool): this parameter is ignored.
|
182 |
+
"""
|
183 |
+
warnings.warn(
|
184 |
+
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
185 |
+
)
|
186 |
+
super().__init__(
|
187 |
+
choices=choices,
|
188 |
+
type=type,
|
189 |
+
value=default,
|
190 |
+
label=label,
|
191 |
+
optional=optional,
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
class Dropdown(components.Dropdown):
    """
    Deprecated alias for gradio.components.Dropdown.

    Renders a dropdown menu from which a single option may be chosen; the
    selected choice is handed to the wrapped function as a string.
    Input type: Union[str, int]
    """

    def __init__(
        self,
        choices: List[str],
        type: str = "value",
        default: Optional[str] = None,
        label: Optional[str] = None,
        optional: bool = False,
    ):
        """
        Parameters:
        choices (List[str]): the options the user may pick from.
        type (str): Kind of value returned by the component. "value" yields the selected string, "index" yields its position in `choices`.
        default (str): value pre-selected in the dropdown. If None, nothing is pre-selected.
        label (str): component name in interface.
        optional (bool): this parameter is ignored.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
        )
        # `default` maps onto the modern component's `value` keyword.
        forwarded = dict(
            choices=choices,
            type=type,
            value=default,
            label=label,
            optional=optional,
        )
        super().__init__(**forwarded)
|
227 |
+
|
228 |
+
|
229 |
+
class Image(components.Image):
    """
    Deprecated alias for gradio.components.Image.

    Renders an image upload box with editing capabilities.
    Input type: Union[numpy.array, PIL.Image, file-object]
    """

    def __init__(
        self,
        shape: Tuple[int, int] = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str = "editor",
        type: str = "numpy",
        label: str = None,
        optional: bool = False,
    ):
        """
        Parameters:
        shape (Tuple[int, int]): (width, height) to crop and resize the image to; if None, matches the input image size.
        image_mode (str): how to process the upload. Accepts any PIL image mode, e.g. "RGB" for color, "RGBA" to keep the transparency mask, "L" for black-and-white.
        invert_colors (bool): whether to invert the image as a preprocessing step.
        source (str): source of the image. "upload" creates a drop box, "webcam" takes a snapshot from the user's webcam, "canvas" starts from a white image that can be drawn on.
        tool (str): editing tools. "editor" opens a full-screen editor, "select" provides cropping and zoom.
        type (str): Kind of value returned by the component. "numpy" yields a (width, height, 3) array of 0-255 values, "pil" a PIL image object, "file" a temporary file object (path via file_obj.name), "filepath" the path itself.
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
        )
        forwarded = dict(
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            optional=optional,
        )
        super().__init__(**forwarded)
|
270 |
+
|
271 |
+
|
272 |
+
class Video(components.Video):
    """
    Deprecated alias for gradio.components.Video.

    Renders a video file upload; the upload is converted to a file path.
    Input type: filepath
    """

    def __init__(
        self,
        type: Optional[str] = None,
        source: str = "upload",
        label: Optional[str] = None,
        optional: bool = False,
    ):
        """
        Parameters:
        type (str): video format to convert the upload to, such as 'avi' or 'mp4'. If None, the uploaded format is kept.
        source (str): source of the video. "upload" creates a drop box, "webcam" records from the user's webcam.
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded video, in which case the input value is None.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        # The legacy `type` argument corresponds to the modern `format` keyword.
        super().__init__(
            format=type,
            source=source,
            label=label,
            optional=optional,
        )
|
297 |
+
|
298 |
+
|
299 |
+
class Audio(components.Audio):
    """
    Deprecated alias for gradio.components.Audio.

    Accepts audio input files.
    Input type: Union[Tuple[int, numpy.array], file-object, numpy.array]
    """

    def __init__(
        self,
        source: str = "upload",
        type: str = "numpy",
        label: str = None,
        optional: bool = False,
    ):
        """
        Parameters:
        source (str): source of the audio. "upload" creates a drop box, "microphone" records from the user's microphone.
        type (str): Kind of value returned by the component. "numpy" yields an (sample_rate, data) tuple where data has shape (samples, 2), "file" a temporary file object (path via file_obj.name), "filepath" the path itself.
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded audio, in which case the input value is None.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        super().__init__(
            source=source,
            type=type,
            label=label,
            optional=optional,
        )
|
323 |
+
|
324 |
+
|
325 |
+
class File(components.File):
    """
    Deprecated alias for gradio.components.File.

    Accepts generic file uploads.
    Input type: Union[file-object, bytes, List[Union[file-object, bytes]]]
    """

    def __init__(
        self,
        file_count: str = "single",
        type: str = "file",
        label: Optional[str] = None,
        keep_filename: bool = True,
        optional: bool = False,
    ):
        """
        Parameters:
        file_count (str): "single" lets the user upload one file; "multiple" several files; "directory" every file in a chosen directory. The return value is a list for "multiple" and "directory".
        type (str): Kind of value returned by the component. "file" yields a temporary file object (path via file_obj.name), "binary" yields a bytes object.
        label (str): component name in interface.
        keep_filename (bool): DEPRECATED. Original filename always kept.
        optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        forwarded = dict(
            file_count=file_count,
            type=type,
            label=label,
            keep_filename=keep_filename,
            optional=optional,
        )
        super().__init__(**forwarded)
|
357 |
+
|
358 |
+
|
359 |
+
class Dataframe(components.Dataframe):
    """
    Deprecated alias for gradio.components.Dataframe.

    Accepts 2D input through a spreadsheet interface.
    Input type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
    """

    def __init__(
        self,
        headers: Optional[List[str]] = None,
        row_count: int = 3,
        col_count: Optional[int] = 3,
        datatype: str | List[str] = "str",
        col_width: int | List[int] = None,
        default: Optional[List[List[Any]]] = None,
        type: str = "pandas",
        label: Optional[str] = None,
        optional: bool = False,
    ):
        """
        Parameters:
        headers (List[str]): header names for the dataframe. If None, no headers are shown.
        row_count (int): maximum number of input rows.
        col_count (int): maximum number of input columns. If 1, the returned data is one-dimensional. Ignored if `headers` is provided.
        datatype (Union[str, List[str]]): datatype of sheet values, either per column (list of strings) or for the whole sheet (single string). Valid datatypes are "str", "number", "bool", and "date".
        col_width (Union[int, List[int]]): column width in pixels, as a single value or one value per column.
        default (List[List[Any]]): Default value
        type (str): Kind of value returned by the component. "pandas" for a pandas dataframe, "numpy" for a numpy array, "array" for a Python list.
        label (str): component name in interface.
        optional (bool): this parameter is ignored.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        # The legacy `default` argument maps onto the modern `value` keyword.
        forwarded = dict(
            value=default,
            headers=headers,
            row_count=row_count,
            col_count=col_count,
            datatype=datatype,
            col_width=col_width,
            type=type,
            label=label,
            optional=optional,
        )
        super().__init__(**forwarded)
|
403 |
+
|
404 |
+
|
405 |
+
class Timeseries(components.Timeseries):
    """
    Deprecated alias for gradio.components.Timeseries.

    Accepts a pandas.DataFrame uploaded as a timeseries csv file.
    Input type: pandas.DataFrame
    """

    def __init__(
        self,
        x: Optional[str] = None,
        y: str | List[str] = None,
        label: Optional[str] = None,
        optional: bool = False,
    ):
        """
        Parameters:
        x (str): column name of the x (time) series. None if the csv has no headers, in which case the first column is the x series.
        y (Union[str, List[str]]): column name of the y series, or list of names for multiple series. None if the csv has no headers, in which case every column after the first is a y series.
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded csv file, in which case the input value is None.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        super().__init__(
            x=x,
            y=y,
            label=label,
            optional=optional,
        )
|
429 |
+
|
430 |
+
|
431 |
+
class State(components.State):
    """
    Deprecated alias for gradio.components.State.

    Special hidden component that stores state across runs of the interface.
    Input type: Any
    """

    def __init__(
        self,
        label: str = None,
        default: Any = None,
    ):
        """
        Parameters:
        label (str): component name in interface (not used).
        default (Any): the initial value of the state.
        optional (bool): this parameter is ignored.
        """
        # Emit the legacy-module notice, then delegate to the maintained component.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
        )
        # The legacy `default` argument maps onto the modern `value` keyword.
        super().__init__(
            value=default,
            label=label,
        )
|
452 |
+
|
453 |
+
|
454 |
+
class Image3D(components.Model3D):
    """
    Deprecated alias for gradio.components.Model3D.

    Component creates a 3D model upload box.
    Input type: File object of type (.obj, glb, or .gltf)
    """

    def __init__(
        self,
        label: Optional[str] = None,
        optional: bool = False,
    ):
        """
        Parameters:
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
        """
        # Fixed: this class lives in gradio.inputs, but the deprecation message
        # previously said "gradio.outputs" (copy-pasted from the outputs module),
        # pointing users at the wrong legacy module.
        warnings.warn(
            "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        super().__init__(label=label, optional=optional)
|
gradio-modified/gradio/interface.py
ADDED
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is the core file in the `gradio` package, and defines the Interface class,
|
3 |
+
including various methods for constructing an interface and then launching it.
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
import inspect
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import pkgutil
|
12 |
+
import re
|
13 |
+
import warnings
|
14 |
+
import weakref
|
15 |
+
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
16 |
+
|
17 |
+
from markdown_it import MarkdownIt
|
18 |
+
from mdit_py_plugins.dollarmath.index import dollarmath_plugin
|
19 |
+
from mdit_py_plugins.footnote.index import footnote_plugin
|
20 |
+
|
21 |
+
from gradio import Examples, interpretation, utils
|
22 |
+
from gradio.blocks import Blocks
|
23 |
+
from gradio.components import (
|
24 |
+
Button,
|
25 |
+
Interpretation,
|
26 |
+
IOComponent,
|
27 |
+
Markdown,
|
28 |
+
State,
|
29 |
+
get_component_instance,
|
30 |
+
)
|
31 |
+
from gradio.data_classes import InterfaceTypes
|
32 |
+
from gradio.documentation import document, set_documentation_group
|
33 |
+
from gradio.events import Changeable, Streamable
|
34 |
+
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
|
35 |
+
from gradio.layouts import Column, Row, Tab, Tabs
|
36 |
+
from gradio.pipelines import load_from_pipeline
|
37 |
+
|
38 |
+
set_documentation_group("interface")
|
39 |
+
|
40 |
+
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
41 |
+
from transformers.pipelines.base import Pipeline
|
42 |
+
|
43 |
+
|
44 |
+
@document("launch", "load", "from_pipeline", "integrate", "queue")
|
45 |
+
class Interface(Blocks):
|
46 |
+
"""
|
47 |
+
Interface is Gradio's main high-level class, and allows you to create a web-based GUI / demo
|
48 |
+
around a machine learning model (or any Python function) in a few lines of code.
|
49 |
+
You must specify three parameters: (1) the function to create a GUI for (2) the desired input components and
|
50 |
+
(3) the desired output components. Additional parameters can be used to control the appearance
|
51 |
+
and behavior of the demo.
|
52 |
+
|
53 |
+
Example:
|
54 |
+
import gradio as gr
|
55 |
+
|
56 |
+
def image_classifier(inp):
|
57 |
+
return {'cat': 0.3, 'dog': 0.7}
|
58 |
+
|
59 |
+
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
|
60 |
+
demo.launch()
|
61 |
+
Demos: hello_world, hello_world_3, gpt_j
|
62 |
+
Guides: quickstart, key_features, sharing_your_app, interface_state, reactive_interfaces, advanced_interface_features, setting_up_a_gradio_demo_for_maximum_performance
|
63 |
+
"""
|
64 |
+
|
65 |
+
# stores references to all currently existing Interface instances
|
66 |
+
instances: weakref.WeakSet = weakref.WeakSet()
|
67 |
+
|
68 |
+
@classmethod
|
69 |
+
def get_instances(cls) -> List[Interface]:
|
70 |
+
"""
|
71 |
+
:return: list of all current instances.
|
72 |
+
"""
|
73 |
+
return list(Interface.instances)
|
74 |
+
|
75 |
+
@classmethod
|
76 |
+
def load(
|
77 |
+
cls,
|
78 |
+
name: str,
|
79 |
+
src: str | None = None,
|
80 |
+
api_key: str | None = None,
|
81 |
+
alias: str | None = None,
|
82 |
+
**kwargs,
|
83 |
+
) -> Interface:
|
84 |
+
"""
|
85 |
+
Class method that constructs an Interface from a Hugging Face repo. Can accept
|
86 |
+
model repos (if src is "models") or Space repos (if src is "spaces"). The input
|
87 |
+
and output components are automatically loaded from the repo.
|
88 |
+
Parameters:
|
89 |
+
name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
|
90 |
+
src: the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
|
91 |
+
api_key: optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
|
92 |
+
alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
|
93 |
+
Returns:
|
94 |
+
a Gradio Interface object for the given model
|
95 |
+
Example:
|
96 |
+
import gradio as gr
|
97 |
+
description = "Story generation with GPT"
|
98 |
+
examples = [["An adventurer is approached by a mysterious stranger in the tavern for a new quest."]]
|
99 |
+
demo = gr.Interface.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples)
|
100 |
+
demo.launch()
|
101 |
+
"""
|
102 |
+
return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface:
|
106 |
+
"""
|
107 |
+
Class method that constructs an Interface from a Hugging Face transformers.Pipeline object.
|
108 |
+
The input and output components are automatically determined from the pipeline.
|
109 |
+
Parameters:
|
110 |
+
pipeline: the pipeline object to use.
|
111 |
+
Returns:
|
112 |
+
a Gradio Interface object from the given Pipeline
|
113 |
+
Example:
|
114 |
+
import gradio as gr
|
115 |
+
from transformers import pipeline
|
116 |
+
pipe = pipeline("image-classification")
|
117 |
+
gr.Interface.from_pipeline(pipe).launch()
|
118 |
+
"""
|
119 |
+
interface_info = load_from_pipeline(pipeline)
|
120 |
+
kwargs = dict(interface_info, **kwargs)
|
121 |
+
interface = cls(**kwargs)
|
122 |
+
return interface
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
fn: Callable,
|
127 |
+
inputs: str | IOComponent | List[str | IOComponent] | None,
|
128 |
+
outputs: str | IOComponent | List[str | IOComponent] | None,
|
129 |
+
examples: List[Any] | List[List[Any]] | str | None = None,
|
130 |
+
cache_examples: bool | None = None,
|
131 |
+
examples_per_page: int = 10,
|
132 |
+
live: bool = False,
|
133 |
+
interpretation: Callable | str | None = None,
|
134 |
+
num_shap: float = 2.0,
|
135 |
+
title: str | None = None,
|
136 |
+
description: str | None = None,
|
137 |
+
article: str | None = None,
|
138 |
+
thumbnail: str | None = None,
|
139 |
+
theme: str = "default",
|
140 |
+
css: str | None = None,
|
141 |
+
allow_flagging: str | None = None,
|
142 |
+
flagging_options: List[str] | None = None,
|
143 |
+
flagging_dir: str = "flagged",
|
144 |
+
flagging_callback: FlaggingCallback = CSVLogger(),
|
145 |
+
analytics_enabled: bool | None = None,
|
146 |
+
batch: bool = False,
|
147 |
+
max_batch_size: int = 4,
|
148 |
+
_api_mode: bool = False,
|
149 |
+
**kwargs,
|
150 |
+
):
|
151 |
+
"""
|
152 |
+
Parameters:
|
153 |
+
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
|
154 |
+
inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
|
155 |
+
outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
|
156 |
+
examples: sample inputs for the function; if provided, appear below the UI components and can be clicked to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided, but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
|
157 |
+
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
158 |
+
examples_per_page: If examples are provided, how many to display per page.
|
159 |
+
live: whether the interface should automatically rerun if any of the inputs change.
|
160 |
+
interpretation: function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. For more information on the different interpretation methods, see the Advanced Interface Features guide.
|
161 |
+
num_shap: a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap".
|
162 |
+
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
163 |
+
description: a description for the interface; if provided, appears above the input and output components and beneath the title in regular font. Accepts Markdown and HTML content.
|
164 |
+
article: an expanded article explaining the interface; if provided, appears below the input and output components in regular font. Accepts Markdown and HTML content.
|
165 |
+
thumbnail: path or url to image to use as display image when the web demo is shared on social media.
|
166 |
+
theme: Theme to use - right now, only "default" is supported. Can be set with the GRADIO_THEME environment variable.
|
167 |
+
css: custom css or path to custom css file to use with interface.
|
168 |
+
allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
|
169 |
+
flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual".
|
170 |
+
flagging_dir: what to name the directory where flagged data is stored.
|
171 |
+
flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
|
172 |
+
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
|
173 |
+
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
|
174 |
+
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
|
175 |
+
"""
|
176 |
+
super().__init__(
|
177 |
+
analytics_enabled=analytics_enabled,
|
178 |
+
mode="interface",
|
179 |
+
css=css,
|
180 |
+
title=title or "Gradio",
|
181 |
+
theme=theme,
|
182 |
+
**kwargs,
|
183 |
+
)
|
184 |
+
|
185 |
+
if isinstance(fn, list):
|
186 |
+
raise DeprecationWarning(
|
187 |
+
"The `fn` parameter only accepts a single function, support for a list "
|
188 |
+
"of functions has been deprecated. Please use gradio.mix.Parallel "
|
189 |
+
"instead."
|
190 |
+
)
|
191 |
+
|
192 |
+
self.interface_type = InterfaceTypes.STANDARD
|
193 |
+
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
|
194 |
+
raise ValueError("Must provide at least one of `inputs` or `outputs`")
|
195 |
+
elif outputs is None or outputs == []:
|
196 |
+
outputs = []
|
197 |
+
self.interface_type = InterfaceTypes.INPUT_ONLY
|
198 |
+
elif inputs is None or inputs == []:
|
199 |
+
inputs = []
|
200 |
+
self.interface_type = InterfaceTypes.OUTPUT_ONLY
|
201 |
+
|
202 |
+
assert isinstance(inputs, (str, list, IOComponent))
|
203 |
+
assert isinstance(outputs, (str, list, IOComponent))
|
204 |
+
|
205 |
+
if not isinstance(inputs, list):
|
206 |
+
inputs = [inputs]
|
207 |
+
if not isinstance(outputs, list):
|
208 |
+
outputs = [outputs]
|
209 |
+
|
210 |
+
if self.is_space and cache_examples is None:
|
211 |
+
self.cache_examples = True
|
212 |
+
else:
|
213 |
+
self.cache_examples = cache_examples or False
|
214 |
+
|
215 |
+
state_input_indexes = [
|
216 |
+
idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State)
|
217 |
+
]
|
218 |
+
state_output_indexes = [
|
219 |
+
idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State)
|
220 |
+
]
|
221 |
+
|
222 |
+
if len(state_input_indexes) == 0 and len(state_output_indexes) == 0:
|
223 |
+
pass
|
224 |
+
elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1:
|
225 |
+
raise ValueError(
|
226 |
+
"If using 'state', there must be exactly one state input and one state output."
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
state_input_index = state_input_indexes[0]
|
230 |
+
state_output_index = state_output_indexes[0]
|
231 |
+
if inputs[state_input_index] == "state":
|
232 |
+
default = utils.get_default_args(fn)[state_input_index]
|
233 |
+
state_variable = State(value=default) # type: ignore
|
234 |
+
else:
|
235 |
+
state_variable = inputs[state_input_index]
|
236 |
+
|
237 |
+
inputs[state_input_index] = state_variable
|
238 |
+
outputs[state_output_index] = state_variable
|
239 |
+
|
240 |
+
if cache_examples:
|
241 |
+
warnings.warn(
|
242 |
+
"Cache examples cannot be used with state inputs and outputs."
|
243 |
+
"Setting cache_examples to False."
|
244 |
+
)
|
245 |
+
self.cache_examples = False
|
246 |
+
|
247 |
+
self.input_components = [
|
248 |
+
get_component_instance(i, render=False) for i in inputs
|
249 |
+
]
|
250 |
+
self.output_components = [
|
251 |
+
get_component_instance(o, render=False) for o in outputs
|
252 |
+
]
|
253 |
+
|
254 |
+
for component in self.input_components + self.output_components:
|
255 |
+
if not (isinstance(component, IOComponent)):
|
256 |
+
raise ValueError(
|
257 |
+
f"{component} is not a valid input/output component for Interface."
|
258 |
+
)
|
259 |
+
|
260 |
+
if len(self.input_components) == len(self.output_components):
|
261 |
+
same_components = [
|
262 |
+
i is o for i, o in zip(self.input_components, self.output_components)
|
263 |
+
]
|
264 |
+
if all(same_components):
|
265 |
+
self.interface_type = InterfaceTypes.UNIFIED
|
266 |
+
|
267 |
+
if self.interface_type in [
|
268 |
+
InterfaceTypes.STANDARD,
|
269 |
+
InterfaceTypes.OUTPUT_ONLY,
|
270 |
+
]:
|
271 |
+
for o in self.output_components:
|
272 |
+
assert isinstance(o, IOComponent)
|
273 |
+
o.interactive = False # Force output components to be non-interactive
|
274 |
+
|
275 |
+
if (
|
276 |
+
interpretation is None
|
277 |
+
or isinstance(interpretation, list)
|
278 |
+
or callable(interpretation)
|
279 |
+
):
|
280 |
+
self.interpretation = interpretation
|
281 |
+
elif isinstance(interpretation, str):
|
282 |
+
self.interpretation = [
|
283 |
+
interpretation.lower() for _ in self.input_components
|
284 |
+
]
|
285 |
+
else:
|
286 |
+
raise ValueError("Invalid value for parameter: interpretation")
|
287 |
+
|
288 |
+
self.api_mode = _api_mode
|
289 |
+
self.fn = fn
|
290 |
+
self.fn_durations = [0, 0]
|
291 |
+
self.__name__ = getattr(fn, "__name__", "fn")
|
292 |
+
self.live = live
|
293 |
+
self.title = title
|
294 |
+
|
295 |
+
CLEANER = re.compile("<.*?>")
|
296 |
+
|
297 |
+
def clean_html(raw_html):
|
298 |
+
cleantext = re.sub(CLEANER, "", raw_html)
|
299 |
+
return cleantext
|
300 |
+
|
301 |
+
md = (
|
302 |
+
MarkdownIt(
|
303 |
+
"js-default",
|
304 |
+
{
|
305 |
+
"linkify": True,
|
306 |
+
"typographer": True,
|
307 |
+
"html": True,
|
308 |
+
},
|
309 |
+
)
|
310 |
+
.use(dollarmath_plugin)
|
311 |
+
.use(footnote_plugin)
|
312 |
+
.enable("table")
|
313 |
+
)
|
314 |
+
|
315 |
+
simple_description = None
|
316 |
+
if description is not None:
|
317 |
+
description = md.render(description)
|
318 |
+
simple_description = clean_html(description)
|
319 |
+
self.simple_description = simple_description
|
320 |
+
self.description = description
|
321 |
+
if article is not None:
|
322 |
+
article = utils.readme_to_html(article)
|
323 |
+
article = md.render(article)
|
324 |
+
self.article = article
|
325 |
+
|
326 |
+
self.thumbnail = thumbnail
|
327 |
+
self.theme = theme or os.getenv("GRADIO_THEME", "default")
|
328 |
+
if not (self.theme == "default"):
|
329 |
+
warnings.warn("Currently, only the 'default' theme is supported.")
|
330 |
+
|
331 |
+
self.examples = examples
|
332 |
+
self.num_shap = num_shap
|
333 |
+
self.examples_per_page = examples_per_page
|
334 |
+
|
335 |
+
self.simple_server = None
|
336 |
+
|
337 |
+
# For analytics_enabled and allow_flagging: (1) first check for
|
338 |
+
# parameter, (2) check for env variable, (3) default to True/"manual"
|
339 |
+
self.analytics_enabled = (
|
340 |
+
analytics_enabled
|
341 |
+
if analytics_enabled is not None
|
342 |
+
else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
|
343 |
+
)
|
344 |
+
if allow_flagging is None:
|
345 |
+
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
|
346 |
+
if allow_flagging is True:
|
347 |
+
warnings.warn(
|
348 |
+
"The `allow_flagging` parameter in `Interface` now"
|
349 |
+
"takes a string value ('auto', 'manual', or 'never')"
|
350 |
+
", not a boolean. Setting parameter to: 'manual'."
|
351 |
+
)
|
352 |
+
self.allow_flagging = "manual"
|
353 |
+
elif allow_flagging == "manual":
|
354 |
+
self.allow_flagging = "manual"
|
355 |
+
elif allow_flagging is False:
|
356 |
+
warnings.warn(
|
357 |
+
"The `allow_flagging` parameter in `Interface` now"
|
358 |
+
"takes a string value ('auto', 'manual', or 'never')"
|
359 |
+
", not a boolean. Setting parameter to: 'never'."
|
360 |
+
)
|
361 |
+
self.allow_flagging = "never"
|
362 |
+
elif allow_flagging == "never":
|
363 |
+
self.allow_flagging = "never"
|
364 |
+
elif allow_flagging == "auto":
|
365 |
+
self.allow_flagging = "auto"
|
366 |
+
else:
|
367 |
+
raise ValueError(
|
368 |
+
"Invalid value for `allow_flagging` parameter."
|
369 |
+
"Must be: 'auto', 'manual', or 'never'."
|
370 |
+
)
|
371 |
+
|
372 |
+
self.flagging_options = flagging_options
|
373 |
+
self.flagging_callback = flagging_callback
|
374 |
+
self.flagging_dir = flagging_dir
|
375 |
+
self.batch = batch
|
376 |
+
self.max_batch_size = max_batch_size
|
377 |
+
|
378 |
+
self.save_to = None # Used for selenium tests
|
379 |
+
self.share = None
|
380 |
+
self.share_url = None
|
381 |
+
self.local_url = None
|
382 |
+
|
383 |
+
self.favicon_path = None
|
384 |
+
|
385 |
+
if self.analytics_enabled:
|
386 |
+
data = {
|
387 |
+
"mode": self.mode,
|
388 |
+
"fn": fn,
|
389 |
+
"inputs": inputs,
|
390 |
+
"outputs": outputs,
|
391 |
+
"live": live,
|
392 |
+
"ip_address": self.ip_address,
|
393 |
+
"interpretation": interpretation,
|
394 |
+
"allow_flagging": allow_flagging,
|
395 |
+
"custom_css": self.css is not None,
|
396 |
+
"theme": self.theme,
|
397 |
+
"version": (pkgutil.get_data(__name__, "version.txt") or b"")
|
398 |
+
.decode("ascii")
|
399 |
+
.strip(),
|
400 |
+
}
|
401 |
+
utils.initiated_analytics(data)
|
402 |
+
|
403 |
+
utils.version_check()
|
404 |
+
Interface.instances.add(self)
|
405 |
+
|
406 |
+
param_names = inspect.getfullargspec(self.fn)[0]
|
407 |
+
for component, param_name in zip(self.input_components, param_names):
|
408 |
+
assert isinstance(component, IOComponent)
|
409 |
+
if component.label is None:
|
410 |
+
component.label = param_name
|
411 |
+
for i, component in enumerate(self.output_components):
|
412 |
+
assert isinstance(component, IOComponent)
|
413 |
+
if component.label is None:
|
414 |
+
if len(self.output_components) == 1:
|
415 |
+
component.label = "output"
|
416 |
+
else:
|
417 |
+
component.label = "output " + str(i)
|
418 |
+
|
419 |
+
if self.allow_flagging != "never":
|
420 |
+
if (
|
421 |
+
self.interface_type == InterfaceTypes.UNIFIED
|
422 |
+
or self.allow_flagging == "auto"
|
423 |
+
):
|
424 |
+
self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
|
425 |
+
elif self.interface_type == InterfaceTypes.INPUT_ONLY:
|
426 |
+
pass
|
427 |
+
else:
|
428 |
+
self.flagging_callback.setup(
|
429 |
+
self.input_components + self.output_components, self.flagging_dir # type: ignore
|
430 |
+
)
|
431 |
+
|
432 |
+
# Render the Gradio UI
|
433 |
+
with self:
|
434 |
+
self.render_title_description()
|
435 |
+
|
436 |
+
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
|
437 |
+
interpretation_btn, interpretation_set = None, None
|
438 |
+
input_component_column, interpret_component_column = None, None
|
439 |
+
|
440 |
+
with Row().style(equal_height=False):
|
441 |
+
if self.interface_type in [
|
442 |
+
InterfaceTypes.STANDARD,
|
443 |
+
InterfaceTypes.INPUT_ONLY,
|
444 |
+
InterfaceTypes.UNIFIED,
|
445 |
+
]:
|
446 |
+
(
|
447 |
+
submit_btn,
|
448 |
+
clear_btn,
|
449 |
+
stop_btn,
|
450 |
+
flag_btns,
|
451 |
+
input_component_column,
|
452 |
+
interpret_component_column,
|
453 |
+
interpretation_set,
|
454 |
+
) = self.render_input_column()
|
455 |
+
if self.interface_type in [
|
456 |
+
InterfaceTypes.STANDARD,
|
457 |
+
InterfaceTypes.OUTPUT_ONLY,
|
458 |
+
]:
|
459 |
+
(
|
460 |
+
submit_btn_out,
|
461 |
+
clear_btn_2_out,
|
462 |
+
stop_btn_2_out,
|
463 |
+
flag_btns_out,
|
464 |
+
interpretation_btn,
|
465 |
+
) = self.render_output_column(submit_btn)
|
466 |
+
submit_btn = submit_btn or submit_btn_out
|
467 |
+
clear_btn = clear_btn or clear_btn_2_out
|
468 |
+
stop_btn = stop_btn or stop_btn_2_out
|
469 |
+
flag_btns = flag_btns or flag_btns_out
|
470 |
+
|
471 |
+
assert clear_btn is not None, "Clear button not rendered"
|
472 |
+
|
473 |
+
self.attach_submit_events(submit_btn, stop_btn)
|
474 |
+
self.attach_clear_events(
|
475 |
+
clear_btn, input_component_column, interpret_component_column
|
476 |
+
)
|
477 |
+
self.attach_interpretation_events(
|
478 |
+
interpretation_btn,
|
479 |
+
interpretation_set,
|
480 |
+
input_component_column,
|
481 |
+
interpret_component_column,
|
482 |
+
)
|
483 |
+
|
484 |
+
self.render_flagging_buttons(flag_btns)
|
485 |
+
self.render_examples()
|
486 |
+
self.render_article()
|
487 |
+
|
488 |
+
self.config = self.get_config_file()
|
489 |
+
|
490 |
+
def render_title_description(self) -> None:
|
491 |
+
if self.title:
|
492 |
+
Markdown(
|
493 |
+
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
494 |
+
+ self.title
|
495 |
+
+ "</h1>"
|
496 |
+
)
|
497 |
+
if self.description:
|
498 |
+
Markdown(self.description)
|
499 |
+
|
500 |
+
def render_flag_btns(self) -> List[Tuple[Button, str | None]]:
|
501 |
+
if self.flagging_options is None:
|
502 |
+
return [(Button("Flag"), None)]
|
503 |
+
else:
|
504 |
+
return [
|
505 |
+
(
|
506 |
+
Button("Flag as " + flag_option),
|
507 |
+
flag_option,
|
508 |
+
)
|
509 |
+
for flag_option in self.flagging_options
|
510 |
+
]
|
511 |
+
|
512 |
+
def render_input_column(
|
513 |
+
self,
|
514 |
+
) -> Tuple[
|
515 |
+
Button | None,
|
516 |
+
Button | None,
|
517 |
+
Button | None,
|
518 |
+
List | None,
|
519 |
+
Column,
|
520 |
+
Column | None,
|
521 |
+
List[Interpretation] | None,
|
522 |
+
]:
|
523 |
+
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
|
524 |
+
interpret_component_column, interpretation_set = None, None
|
525 |
+
|
526 |
+
with Column(variant="panel"):
|
527 |
+
input_component_column = Column()
|
528 |
+
with input_component_column:
|
529 |
+
for component in self.input_components:
|
530 |
+
component.render()
|
531 |
+
if self.interpretation:
|
532 |
+
interpret_component_column = Column(visible=False)
|
533 |
+
interpretation_set = []
|
534 |
+
with interpret_component_column:
|
535 |
+
for component in self.input_components:
|
536 |
+
interpretation_set.append(Interpretation(component))
|
537 |
+
with Row():
|
538 |
+
if self.interface_type in [
|
539 |
+
InterfaceTypes.STANDARD,
|
540 |
+
InterfaceTypes.INPUT_ONLY,
|
541 |
+
]:
|
542 |
+
clear_btn = Button("Clear")
|
543 |
+
if not self.live:
|
544 |
+
submit_btn = Button("Submit", variant="primary")
|
545 |
+
# Stopping jobs only works if the queue is enabled
|
546 |
+
# We don't know if the queue is enabled when the interface
|
547 |
+
# is created. We use whether a generator function is provided
|
548 |
+
# as a proxy of whether the queue will be enabled.
|
549 |
+
# Using a generator function without the queue will raise an error.
|
550 |
+
if inspect.isgeneratorfunction(self.fn):
|
551 |
+
stop_btn = Button("Stop", variant="stop")
|
552 |
+
elif self.interface_type == InterfaceTypes.UNIFIED:
|
553 |
+
clear_btn = Button("Clear")
|
554 |
+
submit_btn = Button("Submit", variant="primary")
|
555 |
+
if inspect.isgeneratorfunction(self.fn) and not self.live:
|
556 |
+
stop_btn = Button("Stop", variant="stop")
|
557 |
+
if self.allow_flagging == "manual":
|
558 |
+
flag_btns = self.render_flag_btns()
|
559 |
+
elif self.allow_flagging == "auto":
|
560 |
+
flag_btns = [(submit_btn, None)]
|
561 |
+
return (
|
562 |
+
submit_btn,
|
563 |
+
clear_btn,
|
564 |
+
stop_btn,
|
565 |
+
flag_btns,
|
566 |
+
input_component_column,
|
567 |
+
interpret_component_column,
|
568 |
+
interpretation_set,
|
569 |
+
)
|
570 |
+
|
571 |
+
def render_output_column(
|
572 |
+
self,
|
573 |
+
submit_btn_in: Button | None,
|
574 |
+
) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
|
575 |
+
submit_btn = submit_btn_in
|
576 |
+
interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
|
577 |
+
|
578 |
+
with Column(variant="panel"):
|
579 |
+
for component in self.output_components:
|
580 |
+
if not (isinstance(component, State)):
|
581 |
+
component.render()
|
582 |
+
with Row():
|
583 |
+
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
584 |
+
clear_btn = Button("Clear")
|
585 |
+
submit_btn = Button("Generate", variant="primary")
|
586 |
+
if inspect.isgeneratorfunction(self.fn) and not self.live:
|
587 |
+
# Stopping jobs only works if the queue is enabled
|
588 |
+
# We don't know if the queue is enabled when the interface
|
589 |
+
# is created. We use whether a generator function is provided
|
590 |
+
# as a proxy of whether the queue will be enabled.
|
591 |
+
# Using a generator function without the queue will raise an error.
|
592 |
+
stop_btn = Button("Stop", variant="stop")
|
593 |
+
if self.allow_flagging == "manual":
|
594 |
+
flag_btns = self.render_flag_btns()
|
595 |
+
elif self.allow_flagging == "auto":
|
596 |
+
assert submit_btn is not None, "Submit button not rendered"
|
597 |
+
flag_btns = [(submit_btn, None)]
|
598 |
+
if self.interpretation:
|
599 |
+
interpretation_btn = Button("Interpret")
|
600 |
+
|
601 |
+
return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn
|
602 |
+
|
603 |
+
def render_article(self):
|
604 |
+
if self.article:
|
605 |
+
Markdown(self.article)
|
606 |
+
|
607 |
+
def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
|
608 |
+
if self.live:
|
609 |
+
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
610 |
+
assert submit_btn is not None, "Submit button not rendered"
|
611 |
+
super().load(self.fn, None, self.output_components)
|
612 |
+
# For output-only interfaces, the user probably still want a "generate"
|
613 |
+
# button even if the Interface is live
|
614 |
+
submit_btn.click(
|
615 |
+
self.fn,
|
616 |
+
None,
|
617 |
+
self.output_components,
|
618 |
+
api_name="predict",
|
619 |
+
preprocess=not (self.api_mode),
|
620 |
+
postprocess=not (self.api_mode),
|
621 |
+
batch=self.batch,
|
622 |
+
max_batch_size=self.max_batch_size,
|
623 |
+
)
|
624 |
+
else:
|
625 |
+
for component in self.input_components:
|
626 |
+
if isinstance(component, Streamable) and component.streaming:
|
627 |
+
component.stream(
|
628 |
+
self.fn,
|
629 |
+
self.input_components,
|
630 |
+
self.output_components,
|
631 |
+
api_name="predict",
|
632 |
+
preprocess=not (self.api_mode),
|
633 |
+
postprocess=not (self.api_mode),
|
634 |
+
)
|
635 |
+
continue
|
636 |
+
if isinstance(component, Changeable):
|
637 |
+
component.change(
|
638 |
+
self.fn,
|
639 |
+
self.input_components,
|
640 |
+
self.output_components,
|
641 |
+
api_name="predict",
|
642 |
+
preprocess=not (self.api_mode),
|
643 |
+
postprocess=not (self.api_mode),
|
644 |
+
)
|
645 |
+
else:
|
646 |
+
assert submit_btn is not None, "Submit button not rendered"
|
647 |
+
pred = submit_btn.click(
|
648 |
+
self.fn,
|
649 |
+
self.input_components,
|
650 |
+
self.output_components,
|
651 |
+
api_name="predict",
|
652 |
+
scroll_to_output=True,
|
653 |
+
preprocess=not (self.api_mode),
|
654 |
+
postprocess=not (self.api_mode),
|
655 |
+
batch=self.batch,
|
656 |
+
max_batch_size=self.max_batch_size,
|
657 |
+
)
|
658 |
+
if stop_btn:
|
659 |
+
stop_btn.click(
|
660 |
+
None,
|
661 |
+
inputs=None,
|
662 |
+
outputs=None,
|
663 |
+
cancels=[pred],
|
664 |
+
)
|
665 |
+
|
666 |
+
def attach_clear_events(
|
667 |
+
self,
|
668 |
+
clear_btn: Button,
|
669 |
+
input_component_column: Column | None,
|
670 |
+
interpret_component_column: Column | None,
|
671 |
+
):
|
672 |
+
clear_btn.click(
|
673 |
+
None,
|
674 |
+
[],
|
675 |
+
(
|
676 |
+
self.input_components
|
677 |
+
+ self.output_components
|
678 |
+
+ ([input_component_column] if input_component_column else [])
|
679 |
+
+ ([interpret_component_column] if self.interpretation else [])
|
680 |
+
), # type: ignore
|
681 |
+
_js=f"""() => {json.dumps(
|
682 |
+
[getattr(component, "cleared_value", None)
|
683 |
+
for component in self.input_components + self.output_components] + (
|
684 |
+
[Column.update(visible=True)]
|
685 |
+
if self.interface_type
|
686 |
+
in [
|
687 |
+
InterfaceTypes.STANDARD,
|
688 |
+
InterfaceTypes.INPUT_ONLY,
|
689 |
+
InterfaceTypes.UNIFIED,
|
690 |
+
]
|
691 |
+
else []
|
692 |
+
)
|
693 |
+
+ ([Column.update(visible=False)] if self.interpretation else [])
|
694 |
+
)}
|
695 |
+
""",
|
696 |
+
)
|
697 |
+
|
698 |
+
def attach_interpretation_events(
|
699 |
+
self,
|
700 |
+
interpretation_btn: Button | None,
|
701 |
+
interpretation_set: List[Interpretation] | None,
|
702 |
+
input_component_column: Column | None,
|
703 |
+
interpret_component_column: Column | None,
|
704 |
+
):
|
705 |
+
if interpretation_btn:
|
706 |
+
interpretation_btn.click(
|
707 |
+
self.interpret_func,
|
708 |
+
inputs=self.input_components + self.output_components,
|
709 |
+
outputs=interpretation_set
|
710 |
+
or [] + [input_component_column, interpret_component_column], # type: ignore
|
711 |
+
preprocess=False,
|
712 |
+
)
|
713 |
+
|
714 |
+
def render_flagging_buttons(self, flag_btns: List | None):
|
715 |
+
if flag_btns:
|
716 |
+
if self.interface_type in [
|
717 |
+
InterfaceTypes.STANDARD,
|
718 |
+
InterfaceTypes.OUTPUT_ONLY,
|
719 |
+
InterfaceTypes.UNIFIED,
|
720 |
+
]:
|
721 |
+
if (
|
722 |
+
self.interface_type == InterfaceTypes.UNIFIED
|
723 |
+
or self.allow_flagging == "auto"
|
724 |
+
):
|
725 |
+
flag_components = self.input_components
|
726 |
+
else:
|
727 |
+
flag_components = self.input_components + self.output_components
|
728 |
+
for flag_btn, flag_option in flag_btns:
|
729 |
+
flag_method = FlagMethod(self.flagging_callback, flag_option)
|
730 |
+
flag_btn.click(
|
731 |
+
flag_method,
|
732 |
+
inputs=flag_components,
|
733 |
+
outputs=[],
|
734 |
+
preprocess=False,
|
735 |
+
queue=False,
|
736 |
+
)
|
737 |
+
|
738 |
+
def render_examples(self):
|
739 |
+
if self.examples:
|
740 |
+
non_state_inputs = [
|
741 |
+
c for c in self.input_components if not isinstance(c, State)
|
742 |
+
]
|
743 |
+
non_state_outputs = [
|
744 |
+
c for c in self.output_components if not isinstance(c, State)
|
745 |
+
]
|
746 |
+
self.examples_handler = Examples(
|
747 |
+
examples=self.examples,
|
748 |
+
inputs=non_state_inputs, # type: ignore
|
749 |
+
outputs=non_state_outputs, # type: ignore
|
750 |
+
fn=self.fn,
|
751 |
+
cache_examples=self.cache_examples,
|
752 |
+
examples_per_page=self.examples_per_page,
|
753 |
+
_api_mode=self.api_mode,
|
754 |
+
batch=self.batch,
|
755 |
+
)
|
756 |
+
|
757 |
+
def __str__(self):
|
758 |
+
return self.__repr__()
|
759 |
+
|
760 |
+
def __repr__(self):
|
761 |
+
repr = f"Gradio Interface for: {self.__name__}"
|
762 |
+
repr += "\n" + "-" * len(repr)
|
763 |
+
repr += "\ninputs:"
|
764 |
+
for component in self.input_components:
|
765 |
+
repr += "\n|-{}".format(str(component))
|
766 |
+
repr += "\noutputs:"
|
767 |
+
for component in self.output_components:
|
768 |
+
repr += "\n|-{}".format(str(component))
|
769 |
+
return repr
|
770 |
+
|
771 |
+
async def interpret_func(self, *args):
|
772 |
+
return await self.interpret(list(args)) + [
|
773 |
+
Column.update(visible=False),
|
774 |
+
Column.update(visible=True),
|
775 |
+
]
|
776 |
+
|
777 |
+
async def interpret(self, raw_input: List[Any]) -> List[Any]:
|
778 |
+
return [
|
779 |
+
{"original": raw_value, "interpretation": interpretation}
|
780 |
+
for interpretation, raw_value in zip(
|
781 |
+
(await interpretation.run_interpret(self, raw_input))[0], raw_input
|
782 |
+
)
|
783 |
+
]
|
784 |
+
|
785 |
+
def test_launch(self) -> None:
|
786 |
+
"""
|
787 |
+
Deprecated.
|
788 |
+
"""
|
789 |
+
warnings.warn("The Interface.test_launch() function is deprecated.")
|
790 |
+
|
791 |
+
|
792 |
+
@document()
|
793 |
+
class TabbedInterface(Blocks):
|
794 |
+
"""
|
795 |
+
A TabbedInterface is created by providing a list of Interfaces, each of which gets
|
796 |
+
rendered in a separate tab.
|
797 |
+
Demos: stt_or_tts
|
798 |
+
"""
|
799 |
+
|
800 |
+
def __init__(
|
801 |
+
self,
|
802 |
+
interface_list: List[Interface],
|
803 |
+
tab_names: List[str] | None = None,
|
804 |
+
title: str | None = None,
|
805 |
+
theme: str = "default",
|
806 |
+
analytics_enabled: bool | None = None,
|
807 |
+
css: str | None = None,
|
808 |
+
):
|
809 |
+
"""
|
810 |
+
Parameters:
|
811 |
+
interface_list: a list of interfaces to be rendered in tabs.
|
812 |
+
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
813 |
+
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
814 |
+
theme: which theme to use - right now, only "default" is supported.
|
815 |
+
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
816 |
+
css: custom css or path to custom css file to apply to entire Blocks
|
817 |
+
Returns:
|
818 |
+
a Gradio Tabbed Interface for the given interfaces
|
819 |
+
"""
|
820 |
+
super().__init__(
|
821 |
+
title=title or "Gradio",
|
822 |
+
theme=theme,
|
823 |
+
analytics_enabled=analytics_enabled,
|
824 |
+
mode="tabbed_interface",
|
825 |
+
css=css,
|
826 |
+
)
|
827 |
+
if tab_names is None:
|
828 |
+
tab_names = ["Tab {}".format(i) for i in range(len(interface_list))]
|
829 |
+
with self:
|
830 |
+
if title:
|
831 |
+
Markdown(
|
832 |
+
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
833 |
+
+ title
|
834 |
+
+ "</h1>"
|
835 |
+
)
|
836 |
+
with Tabs():
|
837 |
+
for (interface, tab_name) in zip(interface_list, tab_names):
|
838 |
+
with Tab(label=tab_name):
|
839 |
+
interface.render()
|
840 |
+
|
841 |
+
|
842 |
+
def close_all(verbose: bool = True) -> None:
|
843 |
+
for io in Interface.get_instances():
|
844 |
+
io.close(verbose)
|
gradio-modified/gradio/interpretation.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from gradio import utils
|
7 |
+
from gradio.components import Label, Number
|
8 |
+
|
9 |
+
|
10 |
+
async def run_interpret(interface, raw_input):
|
11 |
+
"""
|
12 |
+
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
13 |
+
interpretation for a certain set of UI component types, as well as the custom interpretation case.
|
14 |
+
Parameters:
|
15 |
+
raw_input: a list of raw inputs to apply the interpretation(s) on.
|
16 |
+
"""
|
17 |
+
if isinstance(interface.interpretation, list): # Either "default" or "shap"
|
18 |
+
processed_input = [
|
19 |
+
input_component.preprocess(raw_input[i])
|
20 |
+
for i, input_component in enumerate(interface.input_components)
|
21 |
+
]
|
22 |
+
original_output = await interface.call_function(0, processed_input)
|
23 |
+
original_output = original_output["prediction"]
|
24 |
+
|
25 |
+
if len(interface.output_components) == 1:
|
26 |
+
original_output = [original_output]
|
27 |
+
|
28 |
+
scores, alternative_outputs = [], []
|
29 |
+
|
30 |
+
for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
|
31 |
+
if interp == "default":
|
32 |
+
input_component = interface.input_components[i]
|
33 |
+
neighbor_raw_input = list(raw_input)
|
34 |
+
if input_component.interpret_by_tokens:
|
35 |
+
tokens, neighbor_values, masks = input_component.tokenize(x)
|
36 |
+
interface_scores = []
|
37 |
+
alternative_output = []
|
38 |
+
for neighbor_input in neighbor_values:
|
39 |
+
neighbor_raw_input[i] = neighbor_input
|
40 |
+
processed_neighbor_input = [
|
41 |
+
input_component.preprocess(neighbor_raw_input[i])
|
42 |
+
for i, input_component in enumerate(
|
43 |
+
interface.input_components
|
44 |
+
)
|
45 |
+
]
|
46 |
+
|
47 |
+
neighbor_output = await interface.call_function(
|
48 |
+
0, processed_neighbor_input
|
49 |
+
)
|
50 |
+
neighbor_output = neighbor_output["prediction"]
|
51 |
+
if len(interface.output_components) == 1:
|
52 |
+
neighbor_output = [neighbor_output]
|
53 |
+
processed_neighbor_output = [
|
54 |
+
output_component.postprocess(neighbor_output[i])
|
55 |
+
for i, output_component in enumerate(
|
56 |
+
interface.output_components
|
57 |
+
)
|
58 |
+
]
|
59 |
+
|
60 |
+
alternative_output.append(processed_neighbor_output)
|
61 |
+
interface_scores.append(
|
62 |
+
quantify_difference_in_label(
|
63 |
+
interface, original_output, neighbor_output
|
64 |
+
)
|
65 |
+
)
|
66 |
+
alternative_outputs.append(alternative_output)
|
67 |
+
scores.append(
|
68 |
+
input_component.get_interpretation_scores(
|
69 |
+
raw_input[i],
|
70 |
+
neighbor_values,
|
71 |
+
interface_scores,
|
72 |
+
masks=masks,
|
73 |
+
tokens=tokens,
|
74 |
+
)
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
(
|
78 |
+
neighbor_values,
|
79 |
+
interpret_kwargs,
|
80 |
+
) = input_component.get_interpretation_neighbors(x)
|
81 |
+
interface_scores = []
|
82 |
+
alternative_output = []
|
83 |
+
for neighbor_input in neighbor_values:
|
84 |
+
neighbor_raw_input[i] = neighbor_input
|
85 |
+
processed_neighbor_input = [
|
86 |
+
input_component.preprocess(neighbor_raw_input[i])
|
87 |
+
for i, input_component in enumerate(
|
88 |
+
interface.input_components
|
89 |
+
)
|
90 |
+
]
|
91 |
+
neighbor_output = await interface.call_function(
|
92 |
+
0, processed_neighbor_input
|
93 |
+
)
|
94 |
+
neighbor_output = neighbor_output["prediction"]
|
95 |
+
if len(interface.output_components) == 1:
|
96 |
+
neighbor_output = [neighbor_output]
|
97 |
+
processed_neighbor_output = [
|
98 |
+
output_component.postprocess(neighbor_output[i])
|
99 |
+
for i, output_component in enumerate(
|
100 |
+
interface.output_components
|
101 |
+
)
|
102 |
+
]
|
103 |
+
|
104 |
+
alternative_output.append(processed_neighbor_output)
|
105 |
+
interface_scores.append(
|
106 |
+
quantify_difference_in_label(
|
107 |
+
interface, original_output, neighbor_output
|
108 |
+
)
|
109 |
+
)
|
110 |
+
alternative_outputs.append(alternative_output)
|
111 |
+
interface_scores = [-score for score in interface_scores]
|
112 |
+
scores.append(
|
113 |
+
input_component.get_interpretation_scores(
|
114 |
+
raw_input[i],
|
115 |
+
neighbor_values,
|
116 |
+
interface_scores,
|
117 |
+
**interpret_kwargs
|
118 |
+
)
|
119 |
+
)
|
120 |
+
elif interp == "shap" or interp == "shapley":
|
121 |
+
try:
|
122 |
+
import shap # type: ignore
|
123 |
+
except (ImportError, ModuleNotFoundError):
|
124 |
+
raise ValueError(
|
125 |
+
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
|
126 |
+
)
|
127 |
+
input_component = interface.input_components[i]
|
128 |
+
if not (input_component.interpret_by_tokens):
|
129 |
+
raise ValueError(
|
130 |
+
"Input component {} does not support `shap` interpretation".format(
|
131 |
+
input_component
|
132 |
+
)
|
133 |
+
)
|
134 |
+
|
135 |
+
tokens, _, masks = input_component.tokenize(x)
|
136 |
+
|
137 |
+
# construct a masked version of the input
|
138 |
+
def get_masked_prediction(binary_mask):
|
139 |
+
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
|
140 |
+
preds = []
|
141 |
+
for masked_x in masked_xs:
|
142 |
+
processed_masked_input = copy.deepcopy(processed_input)
|
143 |
+
processed_masked_input[i] = input_component.preprocess(masked_x)
|
144 |
+
new_output = utils.synchronize_async(
|
145 |
+
interface.call_function, 0, processed_masked_input
|
146 |
+
)
|
147 |
+
new_output = new_output["prediction"]
|
148 |
+
if len(interface.output_components) == 1:
|
149 |
+
new_output = [new_output]
|
150 |
+
pred = get_regression_or_classification_value(
|
151 |
+
interface, original_output, new_output
|
152 |
+
)
|
153 |
+
preds.append(pred)
|
154 |
+
return np.array(preds)
|
155 |
+
|
156 |
+
num_total_segments = len(tokens)
|
157 |
+
explainer = shap.KernelExplainer(
|
158 |
+
get_masked_prediction, np.zeros((1, num_total_segments))
|
159 |
+
)
|
160 |
+
shap_values = explainer.shap_values(
|
161 |
+
np.ones((1, num_total_segments)),
|
162 |
+
nsamples=int(interface.num_shap * num_total_segments),
|
163 |
+
silent=True,
|
164 |
+
)
|
165 |
+
scores.append(
|
166 |
+
input_component.get_interpretation_scores(
|
167 |
+
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens
|
168 |
+
)
|
169 |
+
)
|
170 |
+
alternative_outputs.append([])
|
171 |
+
elif interp is None:
|
172 |
+
scores.append(None)
|
173 |
+
alternative_outputs.append([])
|
174 |
+
else:
|
175 |
+
raise ValueError("Unknown intepretation method: {}".format(interp))
|
176 |
+
return scores, alternative_outputs
|
177 |
+
else: # custom interpretation function
|
178 |
+
processed_input = [
|
179 |
+
input_component.preprocess(raw_input[i])
|
180 |
+
for i, input_component in enumerate(interface.input_components)
|
181 |
+
]
|
182 |
+
interpreter = interface.interpretation
|
183 |
+
interpretation = interpreter(*processed_input)
|
184 |
+
if len(raw_input) == 1:
|
185 |
+
interpretation = [interpretation]
|
186 |
+
return interpretation, []
|
187 |
+
|
188 |
+
|
189 |
+
def diff(original, perturbed):
|
190 |
+
try: # try computing numerical difference
|
191 |
+
score = float(original) - float(perturbed)
|
192 |
+
except ValueError: # otherwise, look at strict difference in label
|
193 |
+
score = int(not (original == perturbed))
|
194 |
+
return score
|
195 |
+
|
196 |
+
|
197 |
+
def quantify_difference_in_label(interface, original_output, perturbed_output):
|
198 |
+
output_component = interface.output_components[0]
|
199 |
+
post_original_output = output_component.postprocess(original_output[0])
|
200 |
+
post_perturbed_output = output_component.postprocess(perturbed_output[0])
|
201 |
+
|
202 |
+
if isinstance(output_component, Label):
|
203 |
+
original_label = post_original_output["label"]
|
204 |
+
perturbed_label = post_perturbed_output["label"]
|
205 |
+
|
206 |
+
# Handle different return types of Label interface
|
207 |
+
if "confidences" in post_original_output:
|
208 |
+
original_confidence = original_output[0][original_label]
|
209 |
+
perturbed_confidence = perturbed_output[0][original_label]
|
210 |
+
score = original_confidence - perturbed_confidence
|
211 |
+
else:
|
212 |
+
score = diff(original_label, perturbed_label)
|
213 |
+
return score
|
214 |
+
|
215 |
+
elif isinstance(output_component, Number):
|
216 |
+
score = diff(post_original_output, post_perturbed_output)
|
217 |
+
return score
|
218 |
+
|
219 |
+
else:
|
220 |
+
raise ValueError(
|
221 |
+
"This interpretation method doesn't support the Output component: {}".format(
|
222 |
+
output_component
|
223 |
+
)
|
224 |
+
)
|
225 |
+
|
226 |
+
|
227 |
+
def get_regression_or_classification_value(
|
228 |
+
interface, original_output, perturbed_output
|
229 |
+
):
|
230 |
+
"""Used to combine regression/classification for Shap interpretation method."""
|
231 |
+
output_component = interface.output_components[0]
|
232 |
+
post_original_output = output_component.postprocess(original_output[0])
|
233 |
+
post_perturbed_output = output_component.postprocess(perturbed_output[0])
|
234 |
+
|
235 |
+
if type(output_component) == Label:
|
236 |
+
original_label = post_original_output["label"]
|
237 |
+
perturbed_label = post_perturbed_output["label"]
|
238 |
+
|
239 |
+
# Handle different return types of Label interface
|
240 |
+
if "confidences" in post_original_output:
|
241 |
+
if math.isnan(perturbed_output[0][original_label]):
|
242 |
+
return 0
|
243 |
+
return perturbed_output[0][original_label]
|
244 |
+
else:
|
245 |
+
score = diff(
|
246 |
+
perturbed_label, original_label
|
247 |
+
) # Intentionally inverted order of arguments.
|
248 |
+
return score
|
249 |
+
|
250 |
+
else:
|
251 |
+
raise ValueError(
|
252 |
+
"This interpretation method doesn't support the Output component: {}".format(
|
253 |
+
output_component
|
254 |
+
)
|
255 |
+
)
|
gradio-modified/gradio/ipython_ext.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from IPython.core.magic import needs_local_scope, register_cell_magic
|
3 |
+
except ImportError:
|
4 |
+
pass
|
5 |
+
|
6 |
+
import gradio
|
7 |
+
|
8 |
+
|
9 |
+
def load_ipython_extension(ipython):
|
10 |
+
__demo = gradio.Blocks()
|
11 |
+
|
12 |
+
@register_cell_magic
|
13 |
+
@needs_local_scope
|
14 |
+
def blocks(line, cell, local_ns=None):
|
15 |
+
with __demo.clear():
|
16 |
+
exec(cell, None, local_ns)
|
17 |
+
__demo.launch(quiet=True)
|
gradio-modified/gradio/launches.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"launches": 145}
|
gradio-modified/gradio/layouts.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import warnings
|
4 |
+
from typing import TYPE_CHECKING, Callable, List, Type
|
5 |
+
|
6 |
+
from gradio.blocks import BlockContext
|
7 |
+
from gradio.documentation import document, set_documentation_group
|
8 |
+
|
9 |
+
set_documentation_group("layout")
|
10 |
+
|
11 |
+
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
12 |
+
from gradio.components import Component
|
13 |
+
|
14 |
+
|
15 |
+
@document()
class Row(BlockContext):
    """
    Row is a layout element within Blocks that renders all children horizontally.
    Example:
        with gradio.Blocks() as demo:
            with gradio.Row():
                gr.Image("lion.jpg")
                gr.Image("tiger.jpg")
        demo.launch()
    Guides: controlling_layout
    """

    def __init__(
        self,
        *,
        variant: str = "default",
        visible: bool = True,
        elem_id: str | None = None,
        **kwargs,
    ):
        """
        Parameters:
            variant: row type, 'default' (no background), 'panel' (gray background color and rounded corners), or 'compact' (rounded corners and no internal gap).
            visible: If False, row will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        self.variant = variant
        # Compact rows opt out of expected-parent handling.
        if variant == "compact":
            self.allow_expected_parents = False
        super().__init__(visible=visible, elem_id=elem_id, **kwargs)

    def get_config(self):
        # Row-specific keys first; parent config may override them.
        row_config = {"type": "row", "variant": self.variant}
        return {**row_config, **super().get_config()}

    @staticmethod
    def update(
        visible: bool | None = None,
    ):
        # Partial-update payload consumed by the frontend.
        payload = {"visible": visible}
        payload["__type__"] = "update"
        return payload

    def style(
        self,
        *,
        equal_height: bool | None = None,
        mobile_collapse: bool | None = None,
        **kwargs,
    ):
        """
        Styles the Row.
        Parameters:
            equal_height: If True, makes every child element have equal height
            mobile_collapse: DEPRECATED.
        """
        if equal_height is not None:
            self._style["equal_height"] = equal_height
        if mobile_collapse is not None:
            warnings.warn("mobile_collapse is no longer supported.")
        return self
|
77 |
+
|
78 |
+
|
79 |
+
@document()
class Column(BlockContext):
    """
    Column is a layout element within Blocks that renders all children vertically. The widths of columns can be set through the `scale` and `min_width` parameters.
    If a certain scale results in a column narrower than min_width, the min_width parameter will win.
    Example:
        with gradio.Blocks() as demo:
            with gradio.Row():
                with gradio.Column(scale=1):
                    text1 = gr.Textbox()
                    text2 = gr.Textbox()
                with gradio.Column(scale=4):
                    btn1 = gr.Button("Button 1")
                    btn2 = gr.Button("Button 2")
    Guides: controlling_layout
    """

    def __init__(
        self,
        *,
        scale: int = 1,
        min_width: int = 320,
        variant: str = "default",
        visible: bool = True,
        elem_id: str | None = None,
        **kwargs,
    ):
        """
        Parameters:
            scale: relative width compared to adjacent Columns. For example, if Column A has scale=2, and Column B has scale=1, A will be twice as wide as B.
            min_width: minimum pixel width of Column, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in a column narrower than min_width, the min_width parameter will be respected first.
            variant: column type, 'default' (no background), 'panel' (gray background color and rounded corners), or 'compact' (rounded corners and no internal gap).
            visible: If False, column will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        self.scale = scale
        self.min_width = min_width
        self.variant = variant
        # NOTE(review): presumably opts compact columns out of automatic
        # expected-parent wrapping — confirm against BlockContext.
        if variant == "compact":
            self.allow_expected_parents = False
        super().__init__(visible=visible, elem_id=elem_id, **kwargs)

    def get_config(self):
        return {
            "type": "column",
            "variant": self.variant,
            "scale": self.scale,
            "min_width": self.min_width,
            **super().get_config(),
        }

    @staticmethod
    def update(
        variant: str | None = None,
        visible: bool | None = None,
    ):
        # Partial-update payload consumed by the frontend.
        return {
            "variant": variant,
            "visible": visible,
            "__type__": "update",
        }
|
140 |
+
|
141 |
+
|
142 |
+
class Tabs(BlockContext):
    """
    Tabs is a layout element within Blocks that can contain multiple "Tab" Components.
    """

    def __init__(
        self,
        *,
        selected: int | str | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        **kwargs,
    ):
        """
        Parameters:
            selected: The currently selected tab. Must correspond to an id passed to the one of the child TabItems. Defaults to the first TabItem.
            visible: If False, Tabs will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        super().__init__(visible=visible, elem_id=elem_id, **kwargs)
        self.selected = selected

    def get_config(self):
        return {"selected": self.selected, **super().get_config()}

    @staticmethod
    def update(
        selected: int | str | None = None,
    ):
        # Partial-update payload; lets a predict function switch the active tab.
        return {
            "selected": selected,
            "__type__": "update",
        }

    def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
        """
        Registers `fn` to run when the selected tab changes.
        Parameters:
            fn: Callable function
            inputs: List of inputs
            outputs: List of outputs
        Returns: None
        """
        self.set_event_trigger("change", fn, inputs, outputs)
|
185 |
+
|
186 |
+
|
187 |
+
@document()
class Tab(BlockContext):
    """
    Tab (or its alias TabItem) is a layout element. Components defined within the Tab will be visible when this tab is selected tab.
    Example:
        with gradio.Blocks() as demo:
            with gradio.Tab("Lion"):
                gr.Image("lion.jpg")
                gr.Button("New Lion")
            with gradio.Tab("Tiger"):
                gr.Image("tiger.jpg")
                gr.Button("New Tiger")
    Guides: controlling_layout
    """

    def __init__(
        self,
        label: str,
        *,
        id: int | str | None = None,
        elem_id: str | None = None,
        **kwargs,
    ):
        """
        Parameters:
            label: The visual label for the tab
            id: An optional identifier for the tab, required if you wish to control the selected tab from a predict function.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        super().__init__(elem_id=elem_id, **kwargs)
        self.label = label
        self.id = id

    def get_config(self):
        return {
            "label": self.label,
            "id": self.id,
            **super().get_config(),
        }

    def select(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
        """
        Registers `fn` to run when this tab is selected.
        Parameters:
            fn: Callable function
            inputs: List of inputs
            outputs: List of outputs
        Returns: None
        """
        self.set_event_trigger("select", fn, inputs, outputs)

    def get_expected_parent(self) -> Type[Tabs]:
        # A Tab must be nested inside a Tabs container.
        return Tabs

    def get_block_name(self):
        # The frontend component is registered as "tabitem", not "tab".
        return "tabitem"


# Backwards-compatible alias for Tab.
TabItem = Tab
|
245 |
+
|
246 |
+
|
247 |
+
class Group(BlockContext):
    """
    Group is a layout element within Blocks which groups together children so that
    they do not have any padding or margin between them.
    Example:
        with gradio.Group():
            gr.Textbox(label="First")
            gr.Textbox(label="Last")
    """

    def __init__(self, *, visible: bool = True, elem_id: str | None = None, **kwargs):
        """
        Parameters:
            visible: If False, group will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        super().__init__(visible=visible, elem_id=elem_id, **kwargs)

    def get_config(self):
        # Group-specific key first; parent config may override.
        base_config = super().get_config()
        return {"type": "group", **base_config}

    @staticmethod
    def update(visible: bool | None = None):
        # Partial-update payload consumed by the frontend.
        payload = {"visible": visible}
        payload["__type__"] = "update"
        return payload
|
282 |
+
|
283 |
+
|
284 |
+
@document()
class Box(BlockContext):
    """
    Box is a a layout element which places children in a box with rounded corners and
    some padding around them.
    Example:
        with gradio.Box():
            gr.Textbox(label="First")
            gr.Textbox(label="Last")
    """

    def __init__(self, *, visible: bool = True, elem_id: str | None = None, **kwargs):
        """
        Parameters:
            visible: If False, box will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        super().__init__(visible=visible, elem_id=elem_id, **kwargs)

    def get_config(self):
        # Box-specific key first; parent config may override.
        base_config = super().get_config()
        return {"type": "box", **base_config}

    @staticmethod
    def update(visible: bool | None = None):
        # Partial-update payload consumed by the frontend.
        payload = {"visible": visible}
        payload["__type__"] = "update"
        return payload

    def style(self, **kwargs):
        # Box accepts but ignores styling kwargs; fluent no-op.
        return self
|
323 |
+
|
324 |
+
|
325 |
+
class Form(BlockContext):
    # Internal layout container (note: no @document(), unlike the public
    # layouts in this module); only contributes its type to the config.
    def get_config(self):
        return {"type": "form", **super().get_config()}
|
328 |
+
|
329 |
+
|
330 |
+
@document()
class Accordion(BlockContext):
    """
    Accordion is a layout element which can be toggled to show/hide the contained content.
    Example:
        with gradio.Accordion("See Details"):
            gr.Markdown("lorem ipsum")
    """

    def __init__(
        self,
        label,
        *,
        open: bool = True,
        visible: bool = True,
        elem_id: str | None = None,
        **kwargs,
    ):
        """
        Parameters:
            label: name of accordion section.
            open: if True, accordion is open by default.
            visible: If False, accordion will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
        """
        self.label = label
        self.open = open
        super().__init__(visible=visible, elem_id=elem_id, **kwargs)

    def get_config(self):
        return {
            "type": "accordion",
            "open": self.open,
            "label": self.label,
            **super().get_config(),
        }

    @staticmethod
    def update(
        open: bool | None = None,
        label: str | None = None,
        visible: bool | None = None,
    ):
        # Partial-update payload consumed by the frontend.
        return {
            "visible": visible,
            "label": label,
            "open": open,
            "__type__": "update",
        }
|
gradio-modified/gradio/media_data.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
gradio-modified/gradio/mix.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Ways to transform interfaces to produce new interfaces
|
3 |
+
"""
|
4 |
+
import asyncio
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
import gradio
|
8 |
+
from gradio.documentation import document, set_documentation_group
|
9 |
+
|
10 |
+
set_documentation_group("mix_interface")
|
11 |
+
|
12 |
+
|
13 |
+
@document()
class Parallel(gradio.Interface):
    """
    Creates a new Interface consisting of multiple Interfaces in parallel (comparing their outputs).
    The Interfaces to put in Parallel must share the same input components (but can have different output components).

    Demos: interface_parallel, interface_parallel_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects that are to be compared in parallel
            options: additional kwargs that are passed into the new Interface object to customize it
        Returns:
            an Interface object comparing the given models
        """
        outputs = []

        for interface in interfaces:
            if not (isinstance(interface, gradio.Interface)):
                warnings.warn(
                    "Parallel requires all inputs to be of type Interface. "
                    "May not work as expected."
                )
            # Concatenate every interface's outputs into the combined panel.
            outputs.extend(interface.output_components)

        async def parallel_fn(*args):
            # Invoke all wrapped interfaces concurrently on the same inputs.
            return_values_with_durations = await asyncio.gather(
                *[interface.call_function(0, list(args)) for interface in interfaces]
            )
            return_values = [rv["prediction"] for rv in return_values_with_durations]
            # Flatten per-interface results into one flat list matching `outputs`:
            # a single-output interface contributes a bare value, a multi-output
            # one contributes a sequence.
            combined_list = []
            for interface, return_value in zip(interfaces, return_values):
                if len(interface.output_components) == 1:
                    combined_list.append(return_value)
                else:
                    combined_list.extend(return_value)
            if len(outputs) == 1:
                return combined_list[0]
            return combined_list

        # Displayed function name, e.g. "model_a | model_b".
        parallel_fn.__name__ = " | ".join([io.__name__ for io in interfaces])

        kwargs = {
            "fn": parallel_fn,
            "inputs": interfaces[0].input_components,
            "outputs": outputs,
        }
        kwargs.update(options)
        super().__init__(**kwargs)
|
65 |
+
|
66 |
+
|
67 |
+
@document()
class Series(gradio.Interface):
    """
    Creates a new Interface from multiple Interfaces in series (the output of one is fed as the input to the next,
    and so the input and output components must agree between the interfaces).

    Demos: interface_series, interface_series_load
    Guides: advanced_interface_features
    """

    def __init__(self, *interfaces: gradio.Interface, **options):
        """
        Parameters:
            interfaces: any number of Interface objects that are to be connected in series
            options: additional kwargs that are passed into the new Interface object to customize it
        Returns:
            an Interface object connecting the given models
        """

        async def connected_fn(*data):
            # Thread `data` through each interface in order: preprocess,
            # predict, postprocess, handing the result to the next stage.
            for idx, interface in enumerate(interfaces):
                # skip preprocessing for first interface since the Series interface will include it
                if idx > 0 and not (interface.api_mode):
                    data = [
                        input_component.preprocess(data[i])
                        for i, input_component in enumerate(interface.input_components)
                    ]

                # run all of predictions sequentially
                data = (await interface.call_function(0, list(data)))["prediction"]
                if len(interface.output_components) == 1:
                    # Normalize a bare single output into a list so subsequent
                    # stages can index it uniformly.
                    data = [data]

                # skip postprocessing for final interface since the Series interface will include it
                if idx < len(interfaces) - 1 and not (interface.api_mode):
                    data = [
                        output_component.postprocess(data[i])
                        for i, output_component in enumerate(
                            interface.output_components
                        )
                    ]

            # `interface` is the last interface from the loop above.
            if len(interface.output_components) == 1:  # type: ignore
                return data[0]
            return data

        for interface in interfaces:
            if not (isinstance(interface, gradio.Interface)):
                warnings.warn(
                    "Series requires all inputs to be of type Interface. May "
                    "not work as expected."
                )
        # Displayed function name, e.g. "model_a => model_b".
        connected_fn.__name__ = " => ".join([io.__name__ for io in interfaces])

        kwargs = {
            "fn": connected_fn,
            "inputs": interfaces[0].input_components,
            "outputs": interfaces[-1].output_components,
            "_api_mode": interfaces[0].api_mode,  # TODO: set api_mode per-interface
        }
        kwargs.update(options)
        super().__init__(**kwargs)
|
gradio-modified/gradio/networking.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Defines helper methods useful for setting up ports, launching servers, and
|
3 |
+
creating tunnels.
|
4 |
+
"""
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import os
|
8 |
+
import socket
|
9 |
+
import threading
|
10 |
+
import time
|
11 |
+
import warnings
|
12 |
+
from typing import TYPE_CHECKING, Tuple
|
13 |
+
|
14 |
+
import requests
|
15 |
+
import uvicorn
|
16 |
+
|
17 |
+
from gradio.routes import App
|
18 |
+
from gradio.tunneling import Tunnel
|
19 |
+
|
20 |
+
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
21 |
+
from gradio.blocks import Blocks
|
22 |
+
|
23 |
+
# By default, the local server will try to open on localhost, port 7860.
|
24 |
+
# If that is not available, then it will try 7861, 7862, ... 7959.
|
25 |
+
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
|
26 |
+
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
|
27 |
+
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
|
28 |
+
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
|
29 |
+
|
30 |
+
|
31 |
+
class Server(uvicorn.Server):
    # uvicorn.Server subclass that can run in a background thread.

    def install_signal_handlers(self):
        # No-op override: uvicorn installs signal handlers by default, which
        # is only possible from the main thread; this server runs in a
        # daemon thread (see run_in_thread below).
        pass

    def run_in_thread(self):
        """Start the server in a daemon thread and block until it reports started."""
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()
        # Poll uvicorn's `started` flag so callers can use the server
        # immediately after this method returns.
        while not self.started:
            time.sleep(1e-3)

    def close(self):
        """Signal the server to exit and wait for its thread to terminate."""
        self.should_exit = True
        self.thread.join()
|
44 |
+
|
45 |
+
|
46 |
+
def get_first_available_port(initial: int, final: int) -> int:
    """
    Gets the first open port in a specified range of port numbers
    Parameters:
        initial: the initial value in the range of port numbers
        final: final (exclusive) value in the range of port numbers, should be greater than `initial`
    Returns:
        port: the first open port in the range
    Raises:
        OSError: if no port in [initial, final) could be bound.
    """
    for port in range(initial, final):
        try:
            # Context manager guarantees the probe socket is closed even if
            # setsockopt/bind raises (the previous version leaked it then).
            with socket.socket() as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind((LOCALHOST_NAME, port))  # Bind to probe availability
            return port
        except OSError:
            # Port is taken (or otherwise unbindable); try the next one.
            pass
    raise OSError(
        "All ports from {} to {} are in use. Please close a port.".format(
            initial, final - 1
        )
    )
|
69 |
+
|
70 |
+
|
71 |
+
def configure_app(app: App, blocks: Blocks) -> App:
    """Attach auth, the Blocks instance, and runtime state onto the FastAPI app.

    Returns the same `app` object for convenience.
    """
    auth = blocks.auth
    if auth is not None:
        if not callable(auth):
            # auth is a sequence of (username, password) pairs; normalize
            # into a {username: password} lookup dict.
            app.auth = {account[0]: account[1] for account in auth}
        else:
            # auth is a callable(username, password) -> bool.
            app.auth = auth
    else:
        app.auth = None
    app.blocks = blocks
    # Remember the working directory at configure time (routes resolve
    # relative file paths against it).
    app.cwd = os.getcwd()
    app.favicon_path = blocks.favicon_path
    # Per-session auth tokens, populated at login time.
    app.tokens = {}
    return app
|
85 |
+
|
86 |
+
|
87 |
+
def start_server(
    blocks: Blocks,
    server_name: str | None = None,
    server_port: int | None = None,
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    ssl_keyfile_password: str | None = None,
) -> Tuple[str, int, str, App, Server]:
    """Launches a local server running the provided Interface
    Parameters:
        blocks: The Blocks object to run on the server
        server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
        server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
    Returns:
        server_name: the hostname the server is bound to
        port: the port number the server is running on
        path_to_local_server: the complete address that the local server can be accessed at
        app: the FastAPI app object
        server: the server object that is a subclass of uvicorn.Server (used to close the server)
    Raises:
        OSError: if the requested port is already in use.
        ValueError: if ssl_keyfile is given without ssl_certfile.
    """
    server_name = server_name or LOCALHOST_NAME
    # if port is not specified, search for first available port
    if server_port is None:
        port = get_first_available_port(
            INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
        )
    else:
        try:
            # Probe the requested port by binding to it.
            # NOTE(review): if bind() raises, this probe socket is never
            # closed; a `with socket.socket() as s:` block would avoid that.
            s = socket.socket()
            s.bind((LOCALHOST_NAME, server_port))
            s.close()
        except OSError:
            raise OSError(
                "Port {} is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all().".format(
                    server_port
                )
            )
        port = server_port

    # "0.0.0.0" binds all interfaces but is not a browsable hostname.
    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name

    if ssl_keyfile is not None:
        if ssl_certfile is None:
            raise ValueError(
                "ssl_certfile must be provided if ssl_keyfile is provided."
            )
        path_to_local_server = "https://{}:{}/".format(url_host_name, port)
    else:
        path_to_local_server = "http://{}:{}/".format(url_host_name, port)

    app = App.create_app(blocks)

    if blocks.save_to is not None:  # Used for selenium tests
        blocks.save_to["port"] = port
    config = uvicorn.Config(
        app=app,
        port=port,
        host=server_name,
        log_level="warning",
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
        ssl_keyfile_password=ssl_keyfile_password,
        ws_max_size=1024 * 1024 * 1024,  # Setting max websocket size to be 1 GB
    )
    server = Server(config=config)
    server.run_in_thread()
    return server_name, port, path_to_local_server, app, server
|
157 |
+
|
158 |
+
|
159 |
+
def setup_tunnel(local_host: str, local_port: int) -> str:
    """Create a public share tunnel for a locally running server.

    Asks the Gradio API server for a tunnel endpoint, then starts a Tunnel
    from (local_host, local_port) to it.
    Returns:
        the public address of the tunnel.
    Raises:
        RuntimeError: if the API server request fails or tunnel setup fails.
    """
    response = requests.get(GRADIO_API_SERVER)
    if response and response.status_code == 200:
        try:
            # API returns a list; the first entry holds the tunnel endpoint.
            payload = response.json()[0]
            remote_host, remote_port = payload["host"], int(payload["port"])
            tunnel = Tunnel(remote_host, remote_port, local_host, local_port)
            address = tunnel.start_tunnel()
            return address
        except Exception as e:
            # Chain the cause so the underlying failure (bad JSON, missing
            # key, tunnel error) stays visible in the traceback.
            raise RuntimeError(str(e)) from e
    else:
        raise RuntimeError("Could not get share link from Gradio API Server.")
|
172 |
+
|
173 |
+
|
174 |
+
def url_ok(url: str) -> bool:
    """Probe `url` with up to 5 HEAD requests and report whether it responds."""
    try:
        for _attempt in range(5):
            with warnings.catch_warnings():
                # Silence warnings raised by the unverified (verify=False) request.
                warnings.filterwarnings("ignore")
                response = requests.head(url, timeout=3, verify=False)
            if response.status_code in (200, 401, 302):  # 401 or 302 if auth is set
                return True
            time.sleep(0.500)
    except (ConnectionError, requests.exceptions.ConnectionError):
        return False
    return False
|
gradio-modified/gradio/outputs.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# type: ignore
|
2 |
+
"""
|
3 |
+
This module defines various classes that can serve as the `output` to an interface. Each class must inherit from
|
4 |
+
`OutputComponent`, and each class must define a path to its template. All of the subclasses of `OutputComponent` are
|
5 |
+
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from __future__ import annotations
|
9 |
+
|
10 |
+
import warnings
|
11 |
+
from typing import Dict, List, Optional
|
12 |
+
|
13 |
+
from gradio import components
|
14 |
+
|
15 |
+
|
16 |
+
class Textbox(components.Textbox):
    """Deprecated output wrapper; use gradio.components.Textbox instead."""

    def __init__(self, type: str = "text", label: Optional[str] = None):
        # Emit the standard gradio.outputs deprecation notice before
        # delegating to the modern component.
        warnings.warn(
            "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        super().__init__(label=label, type=type)
|
26 |
+
|
27 |
+
|
28 |
+
class Image(components.Image):
    """
    Component displays an output image.
    Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
    """

    def __init__(
        self, type: str = "auto", plot: bool = False, label: Optional[str] = None
    ):
        """
        Parameters:
            type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
            plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
            label (str): component name in interface.
        """
        warnings.warn(
            "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        # Legacy `plot=True` flag is mapped onto the newer type="plot".
        if plot:
            type = "plot"
        super().__init__(type=type, label=label)
|
49 |
+
|
50 |
+
|
51 |
+
class Video(components.Video):
    """
    Used for video output.
    Output type: filepath
    """

    def __init__(self, type: Optional[str] = None, label: Optional[str] = None):
        """
        Parameters:
            type (str): Type of video format to be passed to component, such as 'avi' or 'mp4'. Use 'mp4' to ensure browser playability. If set to None, video will keep returned format.
            label (str): component name in interface.
        """
        warnings.warn(
            "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        # Legacy `type` parameter maps onto the modern component's `format`.
        super().__init__(format=type, label=label)
|
67 |
+
|
68 |
+
|
69 |
+
class Audio(components.Audio):
    """
    Deprecated output wrapper for gradio.components.Audio.
    Output type: Union[Tuple[int, numpy.array], str]
    """

    def __init__(self, type: str = "auto", label: Optional[str] = None):
        """
        Parameters:
        type (str): "numpy" (sample_rate, int16 array), "file" (path to wav), or "auto".
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(type=type, label=label)
|
85 |
+
|
86 |
+
|
87 |
+
class File(components.File):
    """
    Deprecated output wrapper for gradio.components.File.
    Output type: Union[file-like, str]
    """

    def __init__(self, label: Optional[str] = None):
        """
        Parameters:
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(label=label)
|
102 |
+
|
103 |
+
|
104 |
+
class Dataframe(components.Dataframe):
    """
    Deprecated output wrapper for gradio.components.Dataframe.
    Output type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
    """

    def __init__(
        self,
        headers: Optional[List[str]] = None,
        max_rows: Optional[int] = 20,
        max_cols: Optional[int] = None,
        overflow_row_behaviour: str = "paginate",
        type: str = "auto",
        label: Optional[str] = None,
    ):
        """
        Parameters:
        headers (List[str]): header names; only used for "numpy"/"array" types.
        max_rows (int): maximum rows displayed at once; None for unlimited.
        max_cols (int): maximum columns displayed at once; None for unlimited.
        overflow_row_behaviour (str): "paginate" or "show_ends".
        type (str): "pandas", "numpy", "array", or "auto".
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(
            headers=headers,
            type=type,
            label=label,
            max_rows=max_rows,
            max_cols=max_cols,
            overflow_row_behaviour=overflow_row_behaviour,
        )
|
139 |
+
|
140 |
+
|
141 |
+
class Timeseries(components.Timeseries):
    """
    Deprecated output wrapper for gradio.components.Timeseries.
    Output type: pandas.DataFrame
    """

    def __init__(
        self, x: str = None, y: str | List[str] = None, label: Optional[str] = None
    ):
        """
        Parameters:
        x (str): column name of the x (time) series; None if the csv has no headers.
        y (Union[str, List[str]]): column name(s) of the y series; None if the csv has no headers.
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(x=x, y=y, label=label)
|
160 |
+
|
161 |
+
|
162 |
+
class State(components.State):
    """
    Deprecated output wrapper for gradio.components.State.
    Output type: Any
    """

    def __init__(self, label: Optional[str] = None):
        """
        Parameters:
        label (str): component name in interface (not used).
        """
        # State uses its own, slightly different deprecation message.
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components"
        warnings.warn(message)
        super().__init__(label=label)
|
177 |
+
|
178 |
+
|
179 |
+
class Label(components.Label):
    """
    Deprecated output wrapper for gradio.components.Label.
    Output type: Union[Dict[str, float], str, int, float]
    """

    def __init__(
        self,
        num_top_classes: Optional[int] = None,
        type: str = "auto",
        label: Optional[str] = None,
    ):
        """
        Parameters:
        num_top_classes (int): number of most confident classes to show.
        type (str): "value" (single label), "confidences" (label->score dict), or "auto".
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(num_top_classes=num_top_classes, type=type, label=label)
|
201 |
+
|
202 |
+
|
203 |
+
class KeyValues:
    """
    Removed component that displayed a table of field/value pairs.
    Output type: Union[Dict, List[Tuple[str, Union[str, int, float]]]]
    """

    def __init__(self, value: str = " ", *, label: Optional[str] = None, **kwargs):
        """
        Parameters:
        value (str): IGNORED
        label (str): component name in interface.

        Raises:
            DeprecationWarning: always; this component no longer exists.
        """
        raise DeprecationWarning(
            "The KeyValues component is deprecated. Please use the DataFrame or JSON "
            "components instead."
        )
|
219 |
+
|
220 |
+
|
221 |
+
class HighlightedText(components.HighlightedText):
    """
    Deprecated output wrapper for gradio.components.HighlightedText.
    Output is a list of (span_text, category_or_value) tuples.
    Output type: List[Tuple[str, Union[float, str]]]
    """

    def __init__(
        self,
        color_map: Dict[str, str] = None,
        label: Optional[str] = None,
        show_legend: bool = False,
    ):
        """
        Parameters:
        color_map (Dict[str, str]): map between category and respective colors.
        label (str): component name in interface.
        show_legend (bool): whether to show span categories in a separate legend or inline.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(color_map=color_map, label=label, show_legend=show_legend)
|
244 |
+
|
245 |
+
|
246 |
+
class JSON(components.JSON):
    """
    Deprecated output wrapper for gradio.components.JSON.
    Expects a JSON string or a JSON-serializable Python object.
    Output type: Union[str, Any]
    """

    def __init__(self, label: Optional[str] = None):
        """
        Parameters:
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(label=label)
|
261 |
+
|
262 |
+
|
263 |
+
class HTML(components.HTML):
    """
    Used for HTML output. Expects an HTML valid string.
    Output type: str
    """

    def __init__(self, label: Optional[str] = None):
        """
        Parameters:
        label (str): component name in interface.
        """
        # Consistency fix: every other class in gradio.outputs emits this
        # deprecation notice; HTML previously omitted it.
        warnings.warn(
            "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
        )
        super().__init__(label=label)
|
275 |
+
|
276 |
+
|
277 |
+
class Carousel(components.Carousel):
    """
    Deprecated output wrapper: displays a set of output components that can be
    scrolled through.
    """

    def __init__(
        self,
        components: components.Component | List[components.Component],
        label: Optional[str] = None,
    ):
        """
        Parameters:
        components (Union[List[Component], Component]): component class(es) to scroll through.
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(components=components, label=label)
|
296 |
+
|
297 |
+
|
298 |
+
class Chatbot(components.Chatbot):
    """
    Deprecated output wrapper for gradio.components.Chatbot: shows both user
    submitted messages and responses.
    Output type: List[Tuple[str, str]]
    """

    def __init__(self, label: Optional[str] = None):
        """
        Parameters:
        label (str): component name in interface (not used).
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(label=label)
|
313 |
+
|
314 |
+
|
315 |
+
class Image3D(components.Model3D):
    """
    Deprecated output wrapper for gradio.components.Model3D (3D model output).
    Input type: File object of type (.obj, glb, or .gltf)
    """

    def __init__(
        self,
        clear_color=None,
        label: Optional[str] = None,
    ):
        """
        Parameters:
        clear_color: presumably the background color of the 3D scene — confirm against Model3D.
        label (str): component name in interface.
        """
        message = "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components"
        warnings.warn(message)
        super().__init__(clear_color=clear_color, label=label)
|
gradio-modified/gradio/pipelines.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module should not be used directly as its API is subject to change. Instead,
|
2 |
+
please use the `gr.Interface.from_pipeline()` function."""
|
3 |
+
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
from typing import TYPE_CHECKING, Dict
|
7 |
+
|
8 |
+
from gradio import components
|
9 |
+
|
10 |
+
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
11 |
+
from transformers import pipelines
|
12 |
+
|
13 |
+
|
14 |
+
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
    pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object
    """
    # transformers is an optional dependency; import lazily so gradio itself
    # does not require it.
    try:
        import transformers
        from transformers import pipelines
    except ImportError:
        raise ImportError(
            "transformers not installed. Please try `pip install transformers`"
        )
    if not isinstance(pipeline, pipelines.base.Pipeline):
        raise ValueError("pipeline must be a transformers.Pipeline")

    # Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
    # version of the transformers library that the user has installed.
    if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
        pipeline, pipelines.audio_classification.AudioClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone", type="filepath", label="Input"
            ),
            "outputs": components.Label(label="Class"),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
        pipeline,
        pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone", type="filepath", label="Input"
            ),
            "outputs": components.Textbox(label="Output"),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": lambda r: r["text"],
        }
    elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
        pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Dataframe(label="Output"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0],
        }
    elif hasattr(transformers, "FillMaskPipeline") and isinstance(
        pipeline, pipelines.fill_mask.FillMaskPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
        pipeline, pipelines.image_classification.ImageClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Image(type="filepath", label="Input Image"),
            "outputs": components.Label(type="confidences", label="Classification"),
            "preprocess": lambda i: {"images": i},
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
        pipeline, pipelines.question_answering.QuestionAnsweringPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(lines=7, label="Context"),
                components.Textbox(label="Question"),
            ],
            "outputs": [
                components.Textbox(label="Answer"),
                components.Label(label="Score"),
            ],
            "preprocess": lambda c, q: {"context": c, "question": q},
            "postprocess": lambda r: (r["answer"], r["score"]),
        }
    elif hasattr(transformers, "SummarizationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.SummarizationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(lines=7, label="Input"),
            "outputs": components.Textbox(label="Summary"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0]["summary_text"],
        }
    elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
        pipeline, pipelines.text_classification.TextClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Label(label="Classification"),
            # NOTE: preprocess returns a list here (positional call in fn below).
            "preprocess": lambda x: [x],
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text_generation.TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Output"),
            "preprocess": lambda x: {"text_inputs": x},
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "TranslationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.TranslationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Translation"),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["translation_text"],
        }
    elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Generated Text"),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
        pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(label="Input"),
                components.Textbox(label="Possible class names (" "comma-separated)"),
                components.Checkbox(label="Allow multiple true classes"),
            ],
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda i, c, m: {
                "sequences": i,
                "candidate_labels": c,
                "multi_label": m,
            },
            "postprocess": lambda r: {
                r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
            },
        }
    else:
        raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))

    # define the function that will be called by the Interface
    def fn(*params):
        data = pipeline_info["preprocess"](*params)
        # special cases that needs to be handled differently
        if isinstance(
            pipeline,
            (
                pipelines.text_classification.TextClassificationPipeline,
                pipelines.text2text_generation.Text2TextGenerationPipeline,
                pipelines.text2text_generation.TranslationPipeline,
            ),
        ):
            # these pipelines take positional arguments (preprocess returned a list)
            data = pipeline(*data)
        else:
            data = pipeline(**data)
        output = pipeline_info["postprocess"](data)
        return output

    # "preprocess"/"postprocess" are captured by fn's closure; the Interface
    # constructor only needs fn + inputs/outputs.
    interface_info = pipeline_info.copy()
    interface_info["fn"] = fn
    del interface_info["preprocess"]
    del interface_info["postprocess"]

    # define the title/description of the Interface
    interface_info["title"] = pipeline.model.__class__.__name__

    return interface_info
|
gradio-modified/gradio/processing_utils.py
ADDED
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import base64
|
4 |
+
import hashlib
|
5 |
+
import json
|
6 |
+
import mimetypes
|
7 |
+
import os
|
8 |
+
import pathlib
|
9 |
+
import shutil
|
10 |
+
import subprocess
|
11 |
+
import tempfile
|
12 |
+
import urllib.request
|
13 |
+
import warnings
|
14 |
+
from io import BytesIO
|
15 |
+
from pathlib import Path
|
16 |
+
from typing import Dict, Tuple
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import requests
|
20 |
+
from ffmpy import FFmpeg, FFprobe, FFRuntimeError
|
21 |
+
from PIL import Image, ImageOps, PngImagePlugin
|
22 |
+
|
23 |
+
from gradio import encryptor, utils
|
24 |
+
|
25 |
+
with warnings.catch_warnings():
|
26 |
+
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
|
27 |
+
from pydub import AudioSegment
|
28 |
+
|
29 |
+
|
30 |
+
#########################
|
31 |
+
# GENERAL
|
32 |
+
#########################
|
33 |
+
|
34 |
+
|
35 |
+
def to_binary(x: str | Dict) -> bytes:
    """Converts a base64 string or dictionary to a binary string that can be sent in a POST."""
    if isinstance(x, dict):
        # Prefer an inline base64 payload; otherwise encode the referenced
        # file or URL.
        base64str = x["data"] if x.get("data") else encode_url_or_file_to_base64(x["name"])
    else:
        base64str = x
    # Drop the "data:<mime>;base64," prefix before decoding.
    return base64.b64decode(base64str.split(",")[1])
|
45 |
+
|
46 |
+
|
47 |
+
#########################
|
48 |
+
# IMAGE PRE-PROCESSING
|
49 |
+
#########################
|
50 |
+
|
51 |
+
|
52 |
+
def decode_base64_to_image(encoding: str) -> Image.Image:
    """Decode a "data:<mime>;base64,<payload>" string into a PIL image."""
    payload = encoding.split(";")[1].split(",")[1]
    return Image.open(BytesIO(base64.b64decode(payload)))
|
56 |
+
|
57 |
+
|
58 |
+
def encode_url_or_file_to_base64(path: str | Path, encryption_key: bytes | None = None):
    """Encode *path* (local file or URL) to a base64 data string."""
    target = str(path)
    if utils.validate_url(target):
        return encode_url_to_base64(target, encryption_key=encryption_key)
    return encode_file_to_base64(target, encryption_key=encryption_key)
|
63 |
+
|
64 |
+
|
65 |
+
def get_mimetype(filename: str) -> str | None:
    """Guess the MIME type of *filename*, normalizing legacy "x-" audio subtypes."""
    guessed = mimetypes.guess_type(filename)[0]
    if guessed is None:
        return None
    return guessed.replace("x-wav", "wav").replace("x-flac", "flac")
|
70 |
+
|
71 |
+
|
72 |
+
def get_extension(encoding: str) -> str | None:
    """Derive a file extension (without the dot) from a data-URL encoding string."""
    # mimetypes only recognizes the "x-wav" spelling.
    mime = mimetypes.guess_type(encoding.replace("audio/wav", "audio/x-wav"))[0]
    if mime == "audio/flac":  # flac is not supported by mimetypes
        return "flac"
    if mime is None:
        return None
    ext = mimetypes.guess_extension(mime)
    if ext is not None and ext.startswith("."):
        ext = ext[1:]
    return ext
|
83 |
+
|
84 |
+
|
85 |
+
def encode_file_to_base64(f, encryption_key=None):
    """Read file *f* and return it as a "data:<mime>;base64," string.

    When *encryption_key* is given, the stored content is decrypted first.
    """
    with open(f, "rb") as fh:
        payload = base64.b64encode(fh.read())
    if encryption_key:
        payload = encryptor.decrypt(encryption_key, payload)
    mimetype = get_mimetype(f)
    return "data:" + (mimetype or "") + ";base64," + str(payload, "utf-8")
|
98 |
+
|
99 |
+
|
100 |
+
def encode_url_to_base64(url, encryption_key=None):
    """Fetch *url* and return its content as a "data:<mime>;base64," string."""
    payload = base64.b64encode(requests.get(url).content)
    if encryption_key:
        payload = encryptor.decrypt(encryption_key, payload)
    mimetype = get_mimetype(url)
    return "data:" + (mimetype or "") + ";base64," + str(payload, "utf-8")
|
109 |
+
|
110 |
+
|
111 |
+
def encode_plot_to_base64(plt):
    """Render a matplotlib pyplot/figure object to a PNG data URL."""
    with BytesIO() as buffer:
        plt.savefig(buffer, format="png")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
|
117 |
+
|
118 |
+
|
119 |
+
def save_array_to_file(image_array, dir=None):
    """Write *image_array* to a temporary PNG file and return the file object."""
    image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
    image.save(tmp)
    return tmp
|
124 |
+
|
125 |
+
|
126 |
+
def save_pil_to_file(pil_image, dir=None):
    """Write a PIL image to a temporary PNG file and return the file object."""
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
    pil_image.save(tmp)
    return tmp
|
130 |
+
|
131 |
+
|
132 |
+
def encode_pil_to_base64(pil_image):
    """Encode a PIL image as a PNG data URL, preserving text-only PNG metadata."""
    # Copy any text-only metadata into a PngInfo chunk.
    metadata = PngImagePlugin.PngInfo()
    has_text_metadata = False
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            has_text_metadata = True

    with BytesIO() as buffer:
        pil_image.save(buffer, "PNG", pnginfo=metadata if has_text_metadata else None)
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
|
149 |
+
|
150 |
+
|
151 |
+
def encode_array_to_base64(image_array):
    """Encode a numpy image array as a PNG data URL."""
    with BytesIO() as buffer:
        image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
        image.save(buffer, "PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
|
158 |
+
|
159 |
+
|
160 |
+
def resize_and_crop(img, size, crop_type="center"):
    """
    Resize and crop an image to fit the specified size.
    args:
        size: `(width, height)` tuple. Pass `None` for either width or height
        to only crop and resize the other.
        crop_type: can be 'top' or 'center'. With 'top' the image is cropped
        from the top/left; with 'center' it is cropped around the middle.
        (The previous docstring advertised 'middle'/'bottom', which the code
        never supported.)
    raises:
        ValueError: if an invalid `crop_type` is provided.
    """
    if crop_type == "top":
        center = (0, 0)
    elif crop_type == "center":
        center = (0.5, 0.5)
    else:
        # Fix: previously raised a bare ValueError with no message.
        raise ValueError(
            f"Invalid crop_type: {crop_type!r} (expected 'top' or 'center')"
        )

    resize = list(size)
    # A None dimension means "keep the image's own size" on that axis.
    if size[0] is None:
        resize[0] = img.size[0]
    if size[1] is None:
        resize[1] = img.size[1]
    return ImageOps.fit(img, resize, centering=center)  # type: ignore
|
185 |
+
|
186 |
+
|
187 |
+
##################
|
188 |
+
# Audio
|
189 |
+
##################
|
190 |
+
|
191 |
+
|
192 |
+
def audio_from_file(filename, crop_min=0, crop_max=100):
    # Load an audio file as (sample_rate, samples-ndarray), optionally cropping
    # to the [crop_min, crop_max] percentage window of its duration.
    try:
        audio = AudioSegment.from_file(filename)
    except FileNotFoundError as e:
        # pydub raises FileNotFoundError both when the input file is missing and
        # when the ffprobe binary is missing; distinguish them for the message.
        isfile = Path(filename).is_file()
        # NOTE(review): when the file itself is missing (isfile False), the
        # conditional expression below evaluates to "" and the RuntimeError has
        # no message — this looks unintended; verify.
        msg = (
            f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found."
            + " Please install `ffmpeg` in your system to use non-WAV audio file formats"
            " and make sure `ffprobe` is in your PATH."
            if isfile
            else ""
        )
        raise RuntimeError(msg) from e
    if crop_min != 0 or crop_max != 100:
        # pydub slices by milliseconds; convert the percentage bounds.
        audio_start = len(audio) * crop_min / 100
        audio_end = len(audio) * crop_max / 100
        audio = audio[audio_start:audio_end]
    data = np.array(audio.get_array_of_samples())
    if audio.channels > 1:
        # Interleaved samples -> shape (n_frames, n_channels).
        data = data.reshape(-1, audio.channels)
    return audio.frame_rate, data
|
213 |
+
|
214 |
+
|
215 |
+
def audio_to_file(sample_rate, data, filename):
    """Write a numpy sample array to ``filename`` as a 16-bit WAV file."""
    pcm = convert_to_16_bit_wav(data)
    n_channels = 1 if len(pcm.shape) == 1 else pcm.shape[1]
    segment = AudioSegment(
        pcm.tobytes(),
        frame_rate=sample_rate,
        sample_width=pcm.dtype.itemsize,
        channels=n_channels,
    )
    handle = segment.export(filename, format="wav")
    handle.close()  # type: ignore
|
225 |
+
|
226 |
+
|
227 |
+
def convert_to_16_bit_wav(data):
    """Convert an audio numpy array of a supported dtype to 16-bit int PCM.

    Parameters:
        data: numpy array of dtype float16/32/64, int16/32, uint8, or uint16.
    Returns:
        numpy int16 array with the same shape as ``data``.
    Raises:
        ValueError: if ``data`` has a dtype with no defined conversion.
    """
    # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
    warning = "Trying to convert audio automatically from {} to 16-bit int format."
    if data.dtype in [np.float64, np.float32, np.float16]:
        warnings.warn(warning.format(data.dtype))
        # Normalize by the peak, guarding against an all-zero (silent) signal,
        # which previously produced a divide-by-zero and NaN samples.
        peak = np.abs(data).max()
        if peak > 0:
            data = data / peak
        data = data * 32767
        data = data.astype(np.int16)
    elif data.dtype == np.int32:
        warnings.warn(warning.format(data.dtype))
        # 2**16 = 65536: an int32 sample carries 16 more bits than an int16
        # sample. (Was 65538, which slightly under-scaled the full range.)
        data = data / 65536
        data = data.astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint16:
        warnings.warn(warning.format(data.dtype))
        data = data - 32768
        data = data.astype(np.int16)
    elif data.dtype == np.uint8:
        warnings.warn(warning.format(data.dtype))
        # 257 maps 0..255 onto 0..65535, then shift to the signed range.
        data = data * 257 - 32768
        data = data.astype(np.int16)
    else:
        raise ValueError(
            "Audio data cannot be converted automatically from "
            f"{data.dtype} to 16-bit int format."
        )
    return data
|
255 |
+
|
256 |
+
|
257 |
+
##################
|
258 |
+
# OUTPUT
|
259 |
+
##################
|
260 |
+
|
261 |
+
|
262 |
+
def decode_base64_to_binary(encoding) -> Tuple[bytes, str | None]:
    """Decode a base64 data URI into raw bytes plus its inferred extension."""
    extension = get_extension(encoding)
    payload = encoding.split(",", 1)[1]
    return base64.b64decode(payload), extension
|
266 |
+
|
267 |
+
|
268 |
+
def decode_base64_to_file(
    encoding, encryption_key=None, file_path=None, dir=None, prefix=None
):
    """Write a base64 data URI to a persistent (delete=False) temp file.

    The file name prefix/extension are taken from ``file_path`` when given
    (unless ``prefix`` is supplied), otherwise from the data URI header.
    If ``encryption_key`` is set, the payload is encrypted before writing.
    Returns the open NamedTemporaryFile object.
    """
    if dir is not None:
        os.makedirs(dir, exist_ok=True)
    data, extension = decode_base64_to_binary(encoding)
    if file_path is not None and prefix is None:
        filename = Path(file_path).name
        if "." in filename:
            dot = filename.index(".")
            prefix, extension = filename[:dot], filename[dot + 1 :]
        else:
            prefix = filename

    if prefix is not None:
        prefix = utils.strip_invalid_filename_characters(prefix)

    if extension is None:
        file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
    else:
        file_obj = tempfile.NamedTemporaryFile(
            delete=False, prefix=prefix, suffix="." + extension, dir=dir
        )
    if encryption_key is not None:
        data = encryptor.encrypt(encryption_key, data)
    file_obj.write(data)
    file_obj.flush()
    return file_obj
|
298 |
+
|
299 |
+
|
300 |
+
def dict_or_str_to_json_file(jsn, dir=None):
    """Serialize a dict (or a JSON-encoded string) into a temp ``.json`` file.

    Returns the open NamedTemporaryFile (delete=False) holding the JSON.
    """
    if dir is not None:
        os.makedirs(dir, exist_ok=True)

    file_obj = tempfile.NamedTemporaryFile(
        delete=False, suffix=".json", dir=dir, mode="w+"
    )
    payload = json.loads(jsn) if isinstance(jsn, str) else jsn
    json.dump(payload, file_obj)
    file_obj.flush()
    return file_obj
|
312 |
+
|
313 |
+
|
314 |
+
def file_to_json(file_path: str | Path) -> Dict:
    """Read the file at ``file_path`` and parse it as JSON."""
    with open(file_path) as json_file:
        return json.load(json_file)
|
317 |
+
|
318 |
+
|
319 |
+
class TempFileManager:
    """
    A class that should be inherited by any Component that needs to manage temporary files.
    It should be instantiated in the __init__ method of the component.
    """

    def __init__(self) -> None:
        # Stores the absolute paths of all temporary files created by this component.
        self.temp_files = set()

    def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
        """Return the SHA-1 hex digest of the file at ``file_path``, read in chunks."""
        sha1 = hashlib.sha1()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
                sha1.update(chunk)
        return sha1.hexdigest()

    def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
        """Return the SHA-1 hex digest of a remote resource (first ~100MB only)."""
        sha1 = hashlib.sha1()
        max_file_size = 100 * 1024 * 1024  # 100MB
        total_read = 0
        remote = urllib.request.urlopen(url)
        try:
            while True:
                data = remote.read(chunk_num_blocks * sha1.block_size)
                total_read += chunk_num_blocks * sha1.block_size
                if not data or total_read > max_file_size:
                    break
                sha1.update(data)
        finally:
            # Fix: the urlopen response was previously never closed.
            remote.close()
        return sha1.hexdigest()

    def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
        """Split a path/URL basename into a sanitized prefix and a ``.ext`` suffix.

        The extension is "" when the basename contains no dot; everything after
        the FIRST dot counts as the extension (e.g. ``a.tar.gz`` -> ``.tar.gz``).
        """
        file_name = Path(file_path_or_url).name
        if "." in file_name:
            dot = file_name.index(".")
            prefix = file_name[:dot]
            extension = "." + file_name[dot + 1 :]
        else:
            prefix, extension = file_name, ""
        prefix = utils.strip_invalid_filename_characters(prefix)
        return prefix, extension

    def get_temp_file_path(self, file_path: str) -> str:
        """Deterministic temp file name for ``file_path``: prefix + content hash + ext."""
        prefix, extension = self.get_prefix_and_extension(file_path)
        file_hash = self.hash_file(file_path)
        return prefix + file_hash + extension

    def get_temp_url_path(self, url: str) -> str:
        """Deterministic temp file name for ``url``: prefix + content hash + ext."""
        prefix, extension = self.get_prefix_and_extension(url)
        file_hash = self.hash_url(url)
        return prefix + file_hash + extension

    def make_temp_copy_if_needed(self, file_path: str) -> str:
        """Returns a temporary file path for a copy of the given file path if it does
        not already exist. Otherwise returns the path to the existing temp file."""
        # Fix: a NamedTemporaryFile used to be opened (and its handle leaked)
        # solely to discover the temp directory; ask tempfile directly instead.
        temp_dir = Path(tempfile.gettempdir())
        full_temp_file_path = str(
            (temp_dir / self.get_temp_file_path(file_path)).resolve()
        )

        if not Path(full_temp_file_path).exists():
            shutil.copy2(file_path, full_temp_file_path)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path

    def download_temp_copy_if_needed(self, url: str) -> str:
        """Downloads a file and makes a temporary file path for a copy if does not already
        exist. Otherwise returns the path to the existing temp file."""
        # Same fix as make_temp_copy_if_needed: no throwaway NamedTemporaryFile.
        temp_dir = Path(tempfile.gettempdir())
        full_temp_file_path = str((temp_dir / self.get_temp_url_path(url)).resolve())

        if not Path(full_temp_file_path).exists():
            with requests.get(url, stream=True) as r:
                with open(full_temp_file_path, "wb") as f:
                    shutil.copyfileobj(r.raw, f)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path
|
403 |
+
|
404 |
+
|
405 |
+
def create_tmp_copy_of_file(file_path, dir=None):
    """Copy ``file_path`` into a new named temp file (delete=False) and return it.

    The temp file keeps a sanitized version of the original basename as its
    prefix and the original extension (everything after the first dot) as its
    suffix.
    """
    if dir is not None:
        os.makedirs(dir, exist_ok=True)
    file_name = Path(file_path).name
    if "." in file_name:
        dot = file_name.index(".")
        prefix, extension = file_name[:dot], file_name[dot + 1 :]
    else:
        prefix, extension = file_name, None
    prefix = utils.strip_invalid_filename_characters(prefix)
    tmp_kwargs = dict(delete=False, prefix=prefix, dir=dir)
    if extension is not None:
        tmp_kwargs["suffix"] = "." + extension
    file_obj = tempfile.NamedTemporaryFile(**tmp_kwargs)
    shutil.copy2(file_path, file_obj.name)
    return file_obj
|
425 |
+
|
426 |
+
|
427 |
+
def _convert(image, dtype, force_copy=False, uniform=False):
    """
    Adapted from: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/dtype.py#L510-L531

    Convert an image to the requested data-type.
    Warnings are issued in case of precision loss, or when negative values
    are clipped during conversion to unsigned integer types (sign loss).
    Floating point values are expected to be normalized and will be clipped
    to the range [0.0, 1.0] or [-1.0, 1.0] when converting to unsigned or
    signed integers respectively.
    Numbers are not shifted to the negative side when converting from
    unsigned to signed integer types. Negative values will be clipped when
    converting to unsigned integers.
    Parameters
    ----------
    image : ndarray
        Input image.
    dtype : dtype
        Target data-type.
    force_copy : bool, optional
        Force a copy of the data, irrespective of its current dtype.
    uniform : bool, optional
        Uniformly quantize the floating point range to the integer range.
        By default (uniform=False) floating point values are scaled and
        rounded to the nearest integers, which minimizes back and forth
        conversion errors.
    .. versionchanged :: 0.15
        ``_convert`` no longer warns about possible precision or sign
        information loss. See discussions on these warnings at:
        https://github.com/scikit-image/scikit-image/issues/2602
        https://github.com/scikit-image/scikit-image/issues/543#issuecomment-208202228
        https://github.com/scikit-image/scikit-image/pull/3575
    References
    ----------
    .. [1] DirectX data conversion rules.
        https://msdn.microsoft.com/en-us/library/windows/desktop/dd607323%28v=vs.85%29.aspx
    .. [2] Data Conversions. In "OpenGL ES 2.0 Specification v2.0.25",
        pp 7-8. Khronos Group, 2010.
    .. [3] Proper treatment of pixels as integers. A.W. Paeth.
        In "Graphics Gems I", pp 249-256. Morgan Kaufmann, 1990.
    .. [4] Dirty Pixels. J. Blinn. In "Jim Blinn's corner: Dirty Pixels",
        pp 47-57. Morgan Kaufmann, 1998.
    """
    # Value bounds each dtype is assumed to span (floats are in [-1, 1]).
    # NOTE(review): np.bool8, np.float_ and np.obj2sctype (used below) are
    # removed in NumPy 2.0 — TODO confirm the pinned numpy version is < 2.0.
    dtype_range = {
        bool: (False, True),
        np.bool_: (False, True),
        np.bool8: (False, True),
        float: (-1, 1),
        np.float_: (-1, 1),
        np.float16: (-1, 1),
        np.float32: (-1, 1),
        np.float64: (-1, 1),
    }

    def _dtype_itemsize(itemsize, *dtypes):
        """Return first of `dtypes` with itemsize greater than `itemsize`
        Parameters
        ----------
        itemsize: int
            The data type object element size.
        Other Parameters
        ----------------
        *dtypes:
            Any Object accepted by `np.dtype` to be converted to a data
            type object
        Returns
        -------
        dtype: data type object
            First of `dtypes` with itemsize greater than `itemsize`.
        """
        return next(dt for dt in dtypes if np.dtype(dt).itemsize >= itemsize)

    def _dtype_bits(kind, bits, itemsize=1):
        """Return dtype of `kind` that can store a `bits` wide unsigned int
        Parameters:
        kind: str
            Data type kind.
        bits: int
            Desired number of bits.
        itemsize: int
            The data type object element size.
        Returns
        -------
        dtype: data type object
            Data type of `kind` that can store a `bits` wide unsigned int
        """

        # Smallest candidate itemsize (in bytes) whose bit width covers `bits`;
        # unsigned kinds may use the exact width, others need one bit of slack.
        s = next(
            i
            for i in (itemsize,) + (2, 4, 8)
            if bits < (i * 8) or (bits == (i * 8) and kind == "u")
        )

        return np.dtype(kind + str(s))

    def _scale(a, n, m, copy=True):
        """Scale an array of unsigned/positive integers from `n` to `m` bits.
        Numbers can be represented exactly only if `m` is a multiple of `n`.
        Parameters
        ----------
        a : ndarray
            Input image array.
        n : int
            Number of bits currently used to encode the values in `a`.
        m : int
            Desired number of bits to encode the values in `out`.
        copy : bool, optional
            If True, allocates and returns new array. Otherwise, modifies
            `a` in place.
        Returns
        -------
        out : array
            Output image array. Has the same kind as `a`.
        """
        kind = a.dtype.kind
        # If all values already fit in `m` bits, a plain cast is lossless.
        if n > m and a.max() < 2**m:
            return a.astype(_dtype_bits(kind, m))
        elif n == m:
            return a.copy() if copy else a
        elif n > m:
            # downscale with precision loss
            if copy:
                b = np.empty(a.shape, _dtype_bits(kind, m))
                np.floor_divide(a, 2 ** (n - m), out=b, dtype=a.dtype, casting="unsafe")
                return b
            else:
                a //= 2 ** (n - m)
                return a
        elif m % n == 0:
            # exact upscale to a multiple of `n` bits
            if copy:
                b = np.empty(a.shape, _dtype_bits(kind, m))
                np.multiply(a, (2**m - 1) // (2**n - 1), out=b, dtype=b.dtype)
                return b
            else:
                a = a.astype(_dtype_bits(kind, m, a.dtype.itemsize), copy=False)
                a *= (2**m - 1) // (2**n - 1)
                return a
        else:
            # upscale to a multiple of `n` bits,
            # then downscale with precision loss
            o = (m // n + 1) * n
            if copy:
                b = np.empty(a.shape, _dtype_bits(kind, o))
                np.multiply(a, (2**o - 1) // (2**n - 1), out=b, dtype=b.dtype)
                b //= 2 ** (o - m)
                return b
            else:
                a = a.astype(_dtype_bits(kind, o, a.dtype.itemsize), copy=False)
                a *= (2**o - 1) // (2**n - 1)
                a //= 2 ** (o - m)
                return a

    image = np.asarray(image)
    dtypeobj_in = image.dtype
    if dtype is np.floating:
        dtypeobj_out = np.dtype("float64")
    else:
        dtypeobj_out = np.dtype(dtype)
    dtype_in = dtypeobj_in.type
    dtype_out = dtypeobj_out.type
    kind_in = dtypeobj_in.kind
    kind_out = dtypeobj_out.kind
    itemsize_in = dtypeobj_in.itemsize
    itemsize_out = dtypeobj_out.itemsize

    # Below, we do an `issubdtype` check. Its purpose is to find out
    # whether we can get away without doing any image conversion. This happens
    # when:
    #
    # - the output and input dtypes are the same or
    # - when the output is specified as a type, and the input dtype
    #   is a subclass of that type (e.g. `np.floating` will allow
    #   `float32` and `float64` arrays through)

    if np.issubdtype(dtype_in, np.obj2sctype(dtype)):
        if force_copy:
            image = image.copy()
        return image

    # Integer min/max bounds are only defined for integer kinds ('u'/'i').
    if kind_in in "ui":
        imin_in = np.iinfo(dtype_in).min
        imax_in = np.iinfo(dtype_in).max
    if kind_out in "ui":
        imin_out = np.iinfo(dtype_out).min  # type: ignore
        imax_out = np.iinfo(dtype_out).max  # type: ignore

    # any -> binary
    if kind_out == "b":
        return image > dtype_in(dtype_range[dtype_in][1] / 2)

    # binary -> any
    if kind_in == "b":
        result = image.astype(dtype_out)
        if kind_out != "f":
            result *= dtype_out(dtype_range[dtype_out][1])
        return result

    # float -> any
    if kind_in == "f":
        if kind_out == "f":
            # float -> float
            return image.astype(dtype_out)

        if np.min(image) < -1.0 or np.max(image) > 1.0:
            raise ValueError("Images of type float must be between -1 and 1.")
        # floating point -> integer
        # use float type that can represent output integer type
        computation_type = _dtype_itemsize(
            itemsize_out, dtype_in, np.float32, np.float64
        )

        if not uniform:
            # Round-to-nearest scaling (minimizes round-trip error).
            if kind_out == "u":
                image_out = np.multiply(image, imax_out, dtype=computation_type)  # type: ignore
            else:
                image_out = np.multiply(
                    image, (imax_out - imin_out) / 2, dtype=computation_type  # type: ignore
                )
                image_out -= 1.0 / 2.0
            np.rint(image_out, out=image_out)
            np.clip(image_out, imin_out, imax_out, out=image_out)  # type: ignore
        elif kind_out == "u":
            # Uniform quantization of [0, 1] onto the unsigned range.
            image_out = np.multiply(image, imax_out + 1, dtype=computation_type)  # type: ignore
            np.clip(image_out, 0, imax_out, out=image_out)  # type: ignore
        else:
            # Uniform quantization of [-1, 1] onto the signed range.
            image_out = np.multiply(
                image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type  # type: ignore
            )
            np.floor(image_out, out=image_out)
            np.clip(image_out, imin_out, imax_out, out=image_out)  # type: ignore
        return image_out.astype(dtype_out)

    # signed/unsigned int -> float
    if kind_out == "f":
        # use float type that can exactly represent input integers
        computation_type = _dtype_itemsize(
            itemsize_in, dtype_out, np.float32, np.float64
        )

        if kind_in == "u":
            # using np.divide or np.multiply doesn't copy the data
            # until the computation time
            image = np.multiply(image, 1.0 / imax_in, dtype=computation_type)  # type: ignore
            # DirectX uses this conversion also for signed ints
            # if imin_in:
            #     np.maximum(image, -1.0, out=image)
        else:
            image = np.add(image, 0.5, dtype=computation_type)
            image *= 2 / (imax_in - imin_in)  # type: ignore

        return np.asarray(image, dtype_out)

    # unsigned int -> signed/unsigned int
    if kind_in == "u":
        if kind_out == "i":
            # unsigned int -> signed int
            image = _scale(image, 8 * itemsize_in, 8 * itemsize_out - 1)
            return image.view(dtype_out)
        else:
            # unsigned int -> unsigned int
            return _scale(image, 8 * itemsize_in, 8 * itemsize_out)

    # signed int -> unsigned int
    if kind_out == "u":
        image = _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out)
        result = np.empty(image.shape, dtype_out)
        # Negative inputs are clipped to 0 (sign loss, as documented above).
        np.maximum(image, 0, out=result, dtype=image.dtype, casting="unsafe")
        return result

    # signed int -> signed int
    if itemsize_in > itemsize_out:
        return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1)

    # Widen first so the shift by imin_in cannot overflow, then rescale.
    image = image.astype(_dtype_bits("i", itemsize_out * 8))
    image -= imin_in  # type: ignore
    image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False)
    image += imin_out  # type: ignore
    return image.astype(dtype_out)
|
706 |
+
|
707 |
+
|
708 |
+
def ffmpeg_installed() -> bool:
    """True if an ``ffmpeg`` executable can be located on the PATH."""
    located = shutil.which("ffmpeg")
    return located is not None
|
710 |
+
|
711 |
+
|
712 |
+
def video_is_playable(video_filepath: str) -> bool:
    """Determines if a video is playable in the browser.

    A video is playable if it has a playable container and codec.
    .mp4 -> h264
    .webm -> vp9
    .ogg -> theora
    """
    playable_combinations = {
        (".mp4", "h264"),
        (".ogg", "theora"),
        (".webm", "vp9"),
    }
    try:
        container = pathlib.Path(video_filepath).suffix.lower()
        probe = FFprobe(
            global_options="-show_format -show_streams -select_streams v -print_format json",
            inputs={video_filepath: None},
        )
        probe_output = probe.run(stderr=subprocess.PIPE, stdout=subprocess.PIPE)
        stream_info = json.loads(probe_output[0])
        video_codec = stream_info["streams"][0]["codec_name"]
        return (container, video_codec) in playable_combinations
    # If anything goes wrong, assume the video can be played to not convert downstream
    except (FFRuntimeError, IndexError, KeyError):
        return True
|
737 |
+
|
738 |
+
|
739 |
+
def convert_video_to_playable_mp4(video_path: str) -> str:
    """Convert the video to mp4. If something goes wrong return the original video.

    The input is first copied to a scratch temp file so ffmpeg can write the
    ``.mp4`` output next to the original path. Returns the output path (or the
    original ``video_path`` on conversion failure).
    """
    try:
        output_path = pathlib.Path(video_path).with_suffix(".mp4")
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            shutil.copy2(video_path, tmp_file.name)
        try:
            # ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
            ff = FFmpeg(
                inputs={str(tmp_file.name): None},
                outputs={str(output_path): None},
                global_options="-y -loglevel quiet",
            )
            ff.run()
        finally:
            # Fix: the delete=False scratch copy used to be leaked on disk.
            os.remove(tmp_file.name)
    except FFRuntimeError as e:
        print(f"Error converting video to browser-playable format {str(e)}")
        output_path = video_path
    return str(output_path)
|
gradio-modified/gradio/queueing.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
import copy
|
5 |
+
import sys
|
6 |
+
import time
|
7 |
+
from collections import deque
|
8 |
+
from typing import Any, Deque, Dict, List, Tuple
|
9 |
+
|
10 |
+
import fastapi
|
11 |
+
|
12 |
+
from gradio.data_classes import Estimation, PredictBody, Progress, ProgressUnit
|
13 |
+
from gradio.helpers import TrackedIterable
|
14 |
+
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
|
15 |
+
|
16 |
+
|
17 |
+
class Event:
    """One queued prediction request from a connected websocket client."""

    def __init__(
        self,
        websocket: fastapi.WebSocket,
        session_hash: str,
        fn_index: int,
    ):
        self.websocket = websocket
        self.session_hash: str = session_hash
        self.fn_index: int = fn_index
        # Unique id combining the client session and the target function.
        self._id = f"{session_hash}_{fn_index}"
        # The fields below are filled in later by the queue machinery.
        self.data: PredictBody | None = None
        self.lost_connection_time: float | None = None
        self.token: str | None = None
        self.progress: Progress | None = None
        self.progress_pending: bool = False

    async def disconnect(self, code: int = 1000):
        """Close this client's websocket with the given close code."""
        await self.websocket.close(code=code)
|
36 |
+
|
37 |
+
|
38 |
+
class Queue:
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
live_updates: bool,
|
42 |
+
concurrency_count: int,
|
43 |
+
update_intervals: float,
|
44 |
+
max_size: int | None,
|
45 |
+
blocks_dependencies: List,
|
46 |
+
):
|
47 |
+
self.event_queue: Deque[Event] = deque()
|
48 |
+
self.events_pending_reconnection = []
|
49 |
+
self.stopped = False
|
50 |
+
self.max_thread_count = concurrency_count
|
51 |
+
self.update_intervals = update_intervals
|
52 |
+
self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
|
53 |
+
self.delete_lock = asyncio.Lock()
|
54 |
+
self.server_path = None
|
55 |
+
self.duration_history_total = 0
|
56 |
+
self.duration_history_count = 0
|
57 |
+
self.avg_process_time = 0
|
58 |
+
self.avg_concurrent_process_time = None
|
59 |
+
self.queue_duration = 1
|
60 |
+
self.live_updates = live_updates
|
61 |
+
self.sleep_when_free = 0.05
|
62 |
+
self.progress_update_sleep_when_free = 0.1
|
63 |
+
self.max_size = max_size
|
64 |
+
self.blocks_dependencies = blocks_dependencies
|
65 |
+
self.access_token = ""
|
66 |
+
|
67 |
+
async def start(self, progress_tracking=False):
|
68 |
+
run_coro_in_background(self.start_processing)
|
69 |
+
if progress_tracking:
|
70 |
+
run_coro_in_background(self.start_progress_tracking)
|
71 |
+
if not self.live_updates:
|
72 |
+
run_coro_in_background(self.notify_clients)
|
73 |
+
|
74 |
+
def close(self):
|
75 |
+
self.stopped = True
|
76 |
+
|
77 |
+
def resume(self):
|
78 |
+
self.stopped = False
|
79 |
+
|
80 |
+
def set_url(self, url: str):
|
81 |
+
self.server_path = url
|
82 |
+
|
83 |
+
def set_access_token(self, token: str):
|
84 |
+
self.access_token = token
|
85 |
+
|
86 |
+
def get_active_worker_count(self) -> int:
|
87 |
+
count = 0
|
88 |
+
for worker in self.active_jobs:
|
89 |
+
if worker is not None:
|
90 |
+
count += 1
|
91 |
+
return count
|
92 |
+
|
93 |
+
def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
|
94 |
+
if not (self.event_queue):
|
95 |
+
return None, False
|
96 |
+
|
97 |
+
first_event = self.event_queue.popleft()
|
98 |
+
events = [first_event]
|
99 |
+
|
100 |
+
event_fn_index = first_event.fn_index
|
101 |
+
batch = self.blocks_dependencies[event_fn_index]["batch"]
|
102 |
+
|
103 |
+
if batch:
|
104 |
+
batch_size = self.blocks_dependencies[event_fn_index]["max_batch_size"]
|
105 |
+
rest_of_batch = [
|
106 |
+
event for event in self.event_queue if event.fn_index == event_fn_index
|
107 |
+
][: batch_size - 1]
|
108 |
+
events.extend(rest_of_batch)
|
109 |
+
[self.event_queue.remove(event) for event in rest_of_batch]
|
110 |
+
|
111 |
+
return events, batch
|
112 |
+
|
113 |
+
async def start_processing(self) -> None:
|
114 |
+
while not self.stopped:
|
115 |
+
if not self.event_queue:
|
116 |
+
await asyncio.sleep(self.sleep_when_free)
|
117 |
+
continue
|
118 |
+
|
119 |
+
if not (None in self.active_jobs):
|
120 |
+
await asyncio.sleep(self.sleep_when_free)
|
121 |
+
continue
|
122 |
+
# Using mutex to avoid editing a list in use
|
123 |
+
async with self.delete_lock:
|
124 |
+
events, batch = self.get_events_in_batch()
|
125 |
+
|
126 |
+
if events:
|
127 |
+
self.active_jobs[self.active_jobs.index(None)] = events
|
128 |
+
task = run_coro_in_background(self.process_events, events, batch)
|
129 |
+
run_coro_in_background(self.broadcast_live_estimations)
|
130 |
+
set_task_name(task, events[0].session_hash, events[0].fn_index, batch)
|
131 |
+
|
132 |
+
async def start_progress_tracking(self) -> None:
|
133 |
+
while not self.stopped:
|
134 |
+
if not any(self.active_jobs):
|
135 |
+
await asyncio.sleep(self.progress_update_sleep_when_free)
|
136 |
+
continue
|
137 |
+
|
138 |
+
for job in self.active_jobs:
|
139 |
+
if job is None:
|
140 |
+
continue
|
141 |
+
for event in job:
|
142 |
+
if event.progress_pending and event.progress:
|
143 |
+
event.progress_pending = False
|
144 |
+
client_awake = await self.send_message(
|
145 |
+
event, event.progress.dict()
|
146 |
+
)
|
147 |
+
if not client_awake:
|
148 |
+
await self.clean_event(event)
|
149 |
+
|
150 |
+
await asyncio.sleep(self.progress_update_sleep_when_free)
|
151 |
+
|
152 |
+
def set_progress(
|
153 |
+
self,
|
154 |
+
event_id: str,
|
155 |
+
iterables: List[TrackedIterable] | None,
|
156 |
+
):
|
157 |
+
if iterables is None:
|
158 |
+
return
|
159 |
+
for job in self.active_jobs:
|
160 |
+
if job is None:
|
161 |
+
continue
|
162 |
+
for evt in job:
|
163 |
+
if evt._id == event_id:
|
164 |
+
progress_data: List[ProgressUnit] = []
|
165 |
+
for iterable in iterables:
|
166 |
+
progress_unit = ProgressUnit(
|
167 |
+
index=iterable.index,
|
168 |
+
length=iterable.length,
|
169 |
+
unit=iterable.unit,
|
170 |
+
progress=iterable.progress,
|
171 |
+
desc=iterable.desc,
|
172 |
+
)
|
173 |
+
progress_data.append(progress_unit)
|
174 |
+
evt.progress = Progress(progress_data=progress_data)
|
175 |
+
evt.progress_pending = True
|
176 |
+
|
177 |
+
def push(self, event: Event) -> int | None:
|
178 |
+
"""
|
179 |
+
Add event to queue, or return None if Queue is full
|
180 |
+
Parameters:
|
181 |
+
event: Event to add to Queue
|
182 |
+
Returns:
|
183 |
+
rank of submitted Event
|
184 |
+
"""
|
185 |
+
queue_len = len(self.event_queue)
|
186 |
+
if self.max_size is not None and queue_len >= self.max_size:
|
187 |
+
return None
|
188 |
+
self.event_queue.append(event)
|
189 |
+
return queue_len
|
190 |
+
|
191 |
+
async def clean_event(self, event: Event) -> None:
|
192 |
+
if event in self.event_queue:
|
193 |
+
async with self.delete_lock:
|
194 |
+
self.event_queue.remove(event)
|
195 |
+
|
196 |
+
async def broadcast_live_estimations(self) -> None:
|
197 |
+
"""
|
198 |
+
Runs 2 functions sequentially instead of concurrently. Otherwise dced clients are tried to get deleted twice.
|
199 |
+
"""
|
200 |
+
if self.live_updates:
|
201 |
+
await self.broadcast_estimations()
|
202 |
+
|
203 |
+
async def gather_event_data(self, event: Event) -> bool:
|
204 |
+
"""
|
205 |
+
Gather data for the event
|
206 |
+
|
207 |
+
Parameters:
|
208 |
+
event:
|
209 |
+
"""
|
210 |
+
if not event.data:
|
211 |
+
client_awake = await self.send_message(event, {"msg": "send_data"})
|
212 |
+
if not client_awake:
|
213 |
+
return False
|
214 |
+
event.data = await self.get_message(event)
|
215 |
+
return True
|
216 |
+
|
217 |
+
async def notify_clients(self) -> None:
|
218 |
+
"""
|
219 |
+
Notify clients about events statuses in the queue periodically.
|
220 |
+
"""
|
221 |
+
while not self.stopped:
|
222 |
+
await asyncio.sleep(self.update_intervals)
|
223 |
+
if self.event_queue:
|
224 |
+
await self.broadcast_estimations()
|
225 |
+
|
226 |
+
async def broadcast_estimations(self) -> None:
|
227 |
+
estimation = self.get_estimation()
|
228 |
+
# Send all messages concurrently
|
229 |
+
await asyncio.gather(
|
230 |
+
*[
|
231 |
+
self.send_estimation(event, estimation, rank)
|
232 |
+
for rank, event in enumerate(self.event_queue)
|
233 |
+
]
|
234 |
+
)
|
235 |
+
|
236 |
+
async def send_estimation(
|
237 |
+
self, event: Event, estimation: Estimation, rank: int
|
238 |
+
) -> Estimation:
|
239 |
+
"""
|
240 |
+
Send estimation about ETA to the client.
|
241 |
+
|
242 |
+
Parameters:
|
243 |
+
event:
|
244 |
+
estimation:
|
245 |
+
rank:
|
246 |
+
"""
|
247 |
+
estimation.rank = rank
|
248 |
+
|
249 |
+
if self.avg_concurrent_process_time is not None:
|
250 |
+
estimation.rank_eta = (
|
251 |
+
estimation.rank * self.avg_concurrent_process_time
|
252 |
+
+ self.avg_process_time
|
253 |
+
)
|
254 |
+
if None not in self.active_jobs:
|
255 |
+
# Add estimated amount of time for a thread to get empty
|
256 |
+
estimation.rank_eta += self.avg_concurrent_process_time
|
257 |
+
client_awake = await self.send_message(event, estimation.dict())
|
258 |
+
if not client_awake:
|
259 |
+
await self.clean_event(event)
|
260 |
+
return estimation
|
261 |
+
|
262 |
+
def update_estimation(self, duration: float) -> None:
    """Fold one finished job's duration into the running timing averages.

    Parameters:
        duration: wall-clock seconds the last completed job took.
    """
    self.duration_history_total += duration
    self.duration_history_count += 1
    total, count = self.duration_history_total, self.duration_history_count
    self.avg_process_time = total / count
    # Per-slot time: concurrency divides the average until the history
    # is still shorter than the thread pool.
    effective_threads = min(self.max_thread_count, count)
    self.avg_concurrent_process_time = self.avg_process_time / effective_threads
    self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)
|
278 |
+
|
279 |
+
def get_estimation(self) -> Estimation:
    """Snapshot the queue's current size and timing statistics.

    Returns:
        An Estimation built from the current queue length and the running
        averages maintained by ``update_estimation``.
    """
    return Estimation(
        queue_size=len(self.event_queue),
        avg_event_process_time=self.avg_process_time,
        avg_event_concurrent_process_time=self.avg_concurrent_process_time,
        queue_eta=self.queue_duration,
    )
|
286 |
+
|
287 |
+
def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
    """Serialize the websocket handshake into a plain request-info dict."""
    client = websocket.client
    return {
        "url": str(websocket.url),
        "headers": dict(websocket.headers),
        "query_params": dict(websocket.query_params),
        "path_params": dict(websocket.path_params),
        "client": dict(host=client.host, port=client.port),  # type: ignore
    }
|
295 |
+
|
296 |
+
async def call_prediction(self, events: List[Event], batch: bool):
    """POST the gathered event payload(s) to this app's /api/predict route.

    Parameters:
        events: one event, or several when *batch* is True.
        batch: whether the target dependency runs in batched mode.
    Returns:
        The AsyncRequest response object.
    """
    first = events[0]
    data = first.data
    assert data is not None, "No event data"
    token = first.token
    data.event_id = None if batch else first._id
    try:
        data.request = self.get_request_params(first.websocket)
    except ValueError:
        # The websocket may already be unusable; proceed without request info.
        pass

    if batch:
        live = [event for event in events if event.data]
        # Transpose per-event inputs into per-component batched columns.
        data.data = list(zip(*[event.data.data for event in live]))
        data.request = [self.get_request_params(event.websocket) for event in live]
        data.batched = True

    return await AsyncRequest(
        method=AsyncRequest.Method.POST,
        url=f"{self.server_path}api/predict",
        json=dict(data),
        headers={"Authorization": f"Bearer {self.access_token}"},
        cookies={"access-token": token} if token is not None else None,
    )
|
323 |
+
|
324 |
+
async def process_events(self, events: List[Event], batch: bool) -> None:
    """Run one prediction job (or one batched job) for the given events.

    Gathers each client's input, calls /api/predict (repeatedly while the
    backend reports ``is_generating``), streams intermediate and final
    results to every still-connected client, updates the timing averages on
    success, and always disconnects/cleans the events and resets the
    backend iterator state in the ``finally`` block.

    Parameters:
        events: the event(s) popped from the queue for this job.
        batch: whether these events run as a single batched call.
    """
    awake_events: List[Event] = []
    try:
        # Collect the events whose clients are still connected and have
        # acknowledged the "process_starts" message.
        for event in events:
            client_awake = await self.gather_event_data(event)
            if client_awake:
                client_awake = await self.send_message(
                    event, {"msg": "process_starts"}
                )
            if client_awake:
                awake_events.append(event)
        if not awake_events:
            return
        begin_time = time.time()
        response = await self.call_prediction(awake_events, batch)
        if response.has_exception:
            # The request itself failed: report the error to every client.
            for event in awake_events:
                await self.send_message(
                    event,
                    {
                        "msg": "process_completed",
                        "output": {"error": str(response.exception)},
                        "success": False,
                    },
                )
        elif response.json.get("is_generating", False):
            # Generator/streaming endpoint: keep polling until done.
            old_response = response
            while response.json.get("is_generating", False):
                # Python 3.7 doesn't have named tasks.
                # In order to determine if a task was cancelled, we
                # ping the websocket to see if it was closed mid-iteration.
                # NOTE(review): ``event`` here is the loop variable left over
                # from the gathering loop above — it pings only the last
                # event, not all of them; confirm intended.
                if sys.version_info < (3, 8):
                    is_alive = await self.send_message(event, {"msg": "alive?"})
                    if not is_alive:
                        return
                old_response = response
                open_ws = []
                for event in awake_events:
                    # NOTE(review): ``open`` shadows the builtin within this loop.
                    open = await self.send_message(
                        event,
                        {
                            "msg": "process_generating",
                            "output": old_response.json,
                            "success": old_response.status == 200,
                        },
                    )
                    open_ws.append(open)
                # Keep only the events whose clients accepted the update.
                awake_events = [
                    e for e, is_open in zip(awake_events, open_ws) if is_open
                ]
                if not awake_events:
                    return
                response = await self.call_prediction(awake_events, batch)
            for event in awake_events:
                # On a failed final call, report it; otherwise the last
                # successful intermediate response is the final output.
                if response.status != 200:
                    relevant_response = response
                else:
                    relevant_response = old_response

                await self.send_message(
                    event,
                    {
                        "msg": "process_completed",
                        "output": relevant_response.json,
                        "success": relevant_response.status == 200,
                    },
                )
        else:
            # Plain (non-generating) call: deliver the result, un-batching
            # the data column for each client when necessary.
            output = copy.deepcopy(response.json)
            for e, event in enumerate(awake_events):
                if batch and "data" in output:
                    output["data"] = list(zip(*response.json.get("data")))[e]
                await self.send_message(
                    event,
                    {
                        "msg": "process_completed",
                        "output": output,
                        "success": response.status == 200,
                    },
                )
        end_time = time.time()
        if response.status == 200:
            self.update_estimation(end_time - begin_time)
    finally:
        for event in awake_events:
            try:
                await event.disconnect()
            except Exception:
                pass
        # Free this job's worker slot.
        self.active_jobs[self.active_jobs.index(events)] = None
        for event in awake_events:
            await self.clean_event(event)
            # Always reset the state of the iterator
            # If the job finished successfully, this has no effect
            # If the job is cancelled, this will enable future runs
            # to start "from scratch"
            await self.reset_iterators(event.session_hash, event.fn_index)
|
421 |
+
|
422 |
+
async def send_message(self, event, data: Dict) -> bool:
    """Send *data* over the event's websocket as JSON.

    Parameters:
        event: the queued event owning the websocket.
        data: a JSON-serializable message payload.
    Returns:
        True on success; False if the send failed (the event is then
        cleaned up and removed from the queue).
    """
    try:
        await event.websocket.send_json(data=data)
        return True
    except Exception:
        # A failed send means the client is gone; drop it from the queue.
        # (Narrowed from a bare ``except:`` so asyncio.CancelledError and
        # KeyboardInterrupt are no longer silently swallowed here.)
        await self.clean_event(event)
        return False
|
429 |
+
|
430 |
+
async def get_message(self, event) -> PredictBody | None:
    """Receive one JSON message from the event's websocket.

    Parameters:
        event: the queued event owning the websocket.
    Returns:
        The parsed PredictBody, or None when receiving or parsing fails
        (the event is then cleaned up and removed from the queue).
    """
    try:
        data = await event.websocket.receive_json()
        return PredictBody(**data)
    except Exception:
        # Narrowed from a bare ``except:`` so asyncio.CancelledError and
        # KeyboardInterrupt propagate instead of being swallowed.
        await self.clean_event(event)
        return None
|
437 |
+
|
438 |
+
async def reset_iterators(self, session_hash: str, fn_index: int):
    """Ask the app's /reset route to clear the generator state for one job.

    Parameters:
        session_hash: the client session that owns the iterator.
        fn_index: the dependency whose iterator should be reset.
    """
    payload = {"session_hash": session_hash, "fn_index": fn_index}
    await AsyncRequest(
        method=AsyncRequest.Method.POST,
        url=f"{self.server_path}reset",
        json=payload,
    )
|
gradio-modified/gradio/reload.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
|
3 |
+
Contains the functions that run when `gradio` is called from the command line. Specifically, allows
|
4 |
+
|
5 |
+
$ gradio app.py, to run app.py in reload mode where any changes in the app.py file or Gradio library reloads the demo.
|
6 |
+
$ gradio app.py my_demo, to use variable names other than "demo"
|
7 |
+
"""
|
8 |
+
import inspect
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import gradio
|
14 |
+
from gradio import networking
|
15 |
+
|
16 |
+
|
17 |
+
def run_in_reload_mode():
    """Entry point for ``gradio app.py [demo_name]``: launch the app under
    uvicorn in reload mode, watching both the app's directory and the
    installed gradio library for changes.

    Raises:
        ValueError: when no script path is given on the command line.
    """
    args = sys.argv[1:]
    if len(args) == 0:
        raise ValueError("No file specified.")
    # Optional second positional argument overrides the default Blocks
    # variable name ("demo").
    demo_name = "demo" if len(args) == 1 else args[1]

    original_path = args[0]
    abs_original_path = Path(original_path).name
    # Convert the filesystem path into a dotted form so its stem can serve
    # as the module name uvicorn imports.
    path = str(Path(original_path).resolve())
    path = path.replace("/", ".")
    path = path.replace("\\", ".")
    filename = Path(path).stem

    gradio_folder = Path(inspect.getfile(gradio)).parent

    port = networking.get_first_available_port(
        networking.INITIAL_PORT_VALUE,
        networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
    )
    print(
        f"\nLaunching in *reload mode* on: http://{networking.LOCALHOST_NAME}:{port} (Press CTRL+C to quit)\n"
    )
    # Fix: target the computed module name. The original line contained a
    # corrupted "(unknown)" placeholder even though ``filename`` was
    # computed above for exactly this purpose.
    command = f"uvicorn {filename}:{demo_name}.app --reload --port {port} --log-level warning "
    message = "Watching:"

    message_change_count = 0
    if str(gradio_folder).strip():
        command += f'--reload-dir "{gradio_folder}" '
        message += f" '{gradio_folder}'"
        message_change_count += 1

    abs_parent = Path(abs_original_path).parent
    if str(abs_parent).strip():
        command += f'--reload-dir "{abs_parent}"'
        if message_change_count == 1:
            message += ","
        message += f" '{abs_parent}'"

    print(message + "\n")
    os.system(command)
|
gradio-modified/gradio/routes.py
ADDED
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implements a FastAPI server to run the gradio interface. Note that some types in this
|
2 |
+
module use the Optional/Union notation so that they work correctly with pydantic."""
|
3 |
+
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
import asyncio
|
7 |
+
import inspect
|
8 |
+
import json
|
9 |
+
import mimetypes
|
10 |
+
import os
|
11 |
+
import posixpath
|
12 |
+
import secrets
|
13 |
+
import traceback
|
14 |
+
from collections import defaultdict
|
15 |
+
from copy import deepcopy
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import Any, Dict, List, Optional, Type
|
18 |
+
from urllib.parse import urlparse
|
19 |
+
|
20 |
+
import fastapi
|
21 |
+
import markupsafe
|
22 |
+
import orjson
|
23 |
+
import pkg_resources
|
24 |
+
from fastapi import Depends, FastAPI, HTTPException, WebSocket, status
|
25 |
+
from fastapi.middleware.cors import CORSMiddleware
|
26 |
+
from fastapi.responses import (
|
27 |
+
FileResponse,
|
28 |
+
HTMLResponse,
|
29 |
+
JSONResponse,
|
30 |
+
PlainTextResponse,
|
31 |
+
)
|
32 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
33 |
+
from fastapi.templating import Jinja2Templates
|
34 |
+
from jinja2.exceptions import TemplateNotFound
|
35 |
+
from starlette.responses import RedirectResponse
|
36 |
+
from starlette.websockets import WebSocketState
|
37 |
+
|
38 |
+
import gradio
|
39 |
+
from gradio import utils
|
40 |
+
from gradio.data_classes import PredictBody, ResetBody
|
41 |
+
from gradio.documentation import document, set_documentation_group
|
42 |
+
from gradio.exceptions import Error
|
43 |
+
from gradio.queueing import Estimation, Event
|
44 |
+
from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
|
45 |
+
|
46 |
+
mimetypes.init()

# Filesystem locations of the bundled frontend assets, resolved from the
# installed gradio package, plus the library version string.
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
BUILD_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/assets")
VERSION_FILE = pkg_resources.resource_filename("gradio", "version.txt")
with open(VERSION_FILE) as version_file:
    VERSION = version_file.read()
|
54 |
+
|
55 |
+
|
56 |
+
class ORJSONResponse(JSONResponse):
    """A JSONResponse that serializes with orjson.

    Objects orjson cannot serialize natively fall back to ``str`` via
    ``default=str``.
    """

    media_type = "application/json"

    @staticmethod
    def _render(content: Any) -> bytes:
        # OPT_SERIALIZE_NUMPY: serialize numpy scalars/arrays natively.
        # OPT_PASSTHROUGH_DATETIME: hand datetimes to ``default`` (here: str).
        return orjson.dumps(
            content,
            option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
            default=str,
        )

    def render(self, content: Any) -> bytes:
        return ORJSONResponse._render(content)

    @staticmethod
    def _render_str(content: Any) -> str:
        # Text form of the rendering; used by the Jinja ``toorjson`` filter.
        return ORJSONResponse._render(content).decode("utf-8")
|
73 |
+
|
74 |
+
|
75 |
+
def toorjson(value):
    """Jinja filter: JSON-encode *value* and escape it for inline embedding.

    HTML-sensitive characters are replaced with their \\uXXXX escapes so the
    JSON can be placed directly inside a template without breaking markup.
    """
    encoded = ORJSONResponse._render_str(value)
    for char, escape in (
        ("<", "\\u003c"),
        (">", "\\u003e"),
        ("&", "\\u0026"),
        ("'", "\\u0027"),
    ):
        encoded = encoded.replace(char, escape)
    return markupsafe.Markup(encoded)
|
83 |
+
|
84 |
+
|
85 |
+
# Jinja environment serving the bundled frontend templates; register the
# orjson-based filter used to inline the app config into the pages.
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
templates.env.filters["toorjson"] = toorjson
|
87 |
+
|
88 |
+
|
89 |
+
###########
|
90 |
+
# Auth
|
91 |
+
###########
|
92 |
+
|
93 |
+
|
94 |
+
class App(FastAPI):
    """
    FastAPI App Wrapper

    Holds per-app server state (auth credentials, session state, queue
    token) and, via ``create_app``, registers every HTTP and websocket
    route that Gradio serves.
    """

    def __init__(self, **kwargs):
        self.tokens = {}  # access-token -> username for logged-in users
        self.auth = None  # credentials dict, a callable, or None (no auth)
        self.blocks: gradio.Blocks | None = None
        self.state_holder = {}  # session_hash -> stateful-component values
        self.iterators = defaultdict(dict)  # session_hash -> fn_index -> iterator
        self.lock = asyncio.Lock()
        # Bearer token that lets the queue worker bypass the queue check.
        self.queue_token = secrets.token_urlsafe(32)
        self.startup_events_triggered = False
        super().__init__(**kwargs)

    def configure_app(self, blocks: gradio.Blocks) -> None:
        """Attach a Blocks app: normalize its auth setting and reset state."""
        auth = blocks.auth
        if auth is not None:
            # A list of (username, password) pairs becomes a lookup dict;
            # a callable is kept as-is and called with (username, password).
            if not callable(auth):
                self.auth = {account[0]: account[1] for account in auth}
            else:
                self.auth = auth
        else:
            self.auth = None

        self.blocks = blocks
        if hasattr(self.blocks, "_queue"):
            self.blocks._queue.set_access_token(self.queue_token)
        self.cwd = os.getcwd()
        self.favicon_path = blocks.favicon_path
        self.tokens = {}

    def get_blocks(self) -> gradio.Blocks:
        """Return the configured Blocks, raising if none has been attached."""
        if self.blocks is None:
            raise ValueError("No Blocks has been configured for this app.")
        return self.blocks

    @staticmethod
    def create_app(blocks: gradio.Blocks) -> App:
        """Build the App for *blocks* and register all of its routes."""
        app = App(default_response_class=ORJSONResponse)
        app.configure_app(blocks)

        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @app.get("/user")
        @app.get("/user/")
        def get_current_user(request: fastapi.Request) -> Optional[str]:
            # Resolve the logged-in username from the access-token cookie.
            token = request.cookies.get("access-token")
            return app.tokens.get(token)

        @app.get("/login_check")
        @app.get("/login_check/")
        def login_check(user: str = Depends(get_current_user)):
            # Passes when auth is disabled or the user is logged in.
            if app.auth is None or not (user is None):
                return
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
            )

        async def ws_login_check(websocket: WebSocket) -> Optional[str]:
            token = websocket.cookies.get("access-token")
            return token  # token is returned to allow request in queue

        @app.get("/token")
        @app.get("/token/")
        def get_token(request: fastapi.Request) -> dict:
            token = request.cookies.get("access-token")
            return {"token": token, "user": app.tokens.get(token)}

        @app.get("/app_id")
        @app.get("/app_id/")
        def app_id(request: fastapi.Request) -> dict:
            return {"app_id": app.get_blocks().app_id}

        @app.post("/login")
        @app.post("/login/")
        def login(form_data: OAuth2PasswordRequestForm = Depends()):
            # Validate against either the credentials dict or the callable,
            # then issue an access-token cookie and redirect home.
            username, password = form_data.username, form_data.password
            if app.auth is None:
                return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
            if (
                not callable(app.auth)
                and username in app.auth
                and app.auth[username] == password
            ) or (callable(app.auth) and app.auth.__call__(username, password)):
                token = secrets.token_urlsafe(16)
                app.tokens[token] = username
                response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
                response.set_cookie(key="access-token", value=token, httponly=True)
                return response
            else:
                raise HTTPException(status_code=400, detail="Incorrect credentials.")

        ###############
        # Main Routes
        ###############

        @app.head("/", response_class=HTMLResponse)
        @app.get("/", response_class=HTMLResponse)
        def main(request: fastapi.Request, user: str = Depends(get_current_user)):
            # Serve the frontend; unauthenticated users only get the login shell.
            mimetypes.add_type("application/javascript", ".js")
            blocks = app.get_blocks()

            if app.auth is None or not (user is None):
                config = app.get_blocks().config
            else:
                config = {
                    "auth_required": True,
                    "auth_message": blocks.auth_message,
                }

            try:
                template = (
                    "frontend/share.html" if blocks.share else "frontend/index.html"
                )
                return templates.TemplateResponse(
                    template, {"request": request, "config": config}
                )
            except TemplateNotFound:
                if blocks.share:
                    raise ValueError(
                        "Did you install Gradio from source files? Share mode only "
                        "works when Gradio is installed through the pip package."
                    )
                else:
                    raise ValueError(
                        "Did you install Gradio from source files? You need to build "
                        "the frontend by running /scripts/build_frontend.sh"
                    )

        @app.get("/config/", dependencies=[Depends(login_check)])
        @app.get("/config", dependencies=[Depends(login_check)])
        def get_config():
            return app.get_blocks().config

        @app.get("/static/{path:path}")
        def static_resource(path: str):
            static_file = safe_join(STATIC_PATH_LIB, path)
            if static_file is not None:
                return FileResponse(static_file)
            raise HTTPException(status_code=404, detail="Static file not found")

        @app.get("/assets/{path:path}")
        def build_resource(path: str):
            build_file = safe_join(BUILD_PATH_LIB, path)
            if build_file is not None:
                return FileResponse(build_file)
            raise HTTPException(status_code=404, detail="Build file not found")

        @app.get("/favicon.ico")
        async def favicon():
            blocks = app.get_blocks()
            if blocks.favicon_path is None:
                return static_resource("img/logo.svg")
            else:
                return FileResponse(blocks.favicon_path)

        @app.get("/file={path:path}", dependencies=[Depends(login_check)])
        def file(path: str):
            # Only serve files under the app's working directory or files the
            # app itself created as temp files; redirect external URLs.
            blocks = app.get_blocks()
            if utils.validate_url(path):
                return RedirectResponse(url=path, status_code=status.HTTP_302_FOUND)
            if Path(app.cwd).resolve() in Path(path).resolve().parents or Path(
                path
            ).resolve() in set().union(*blocks.temp_file_sets):
                return FileResponse(
                    Path(path).resolve(), headers={"Accept-Ranges": "bytes"}
                )
            else:
                raise ValueError(
                    f"File cannot be fetched: {path}. All files must contained within the Gradio python app working directory, or be a temp file created by the Gradio python app."
                )

        @app.get("/file/{path:path}", dependencies=[Depends(login_check)])
        def file_deprecated(path: str):
            return file(path)

        @app.post("/reset/")
        @app.post("/reset")
        async def reset_iterator(body: ResetBody):
            # Clear a session's iterator for one fn_index and remember the
            # reset so an in-flight job does not overwrite it (see run_predict).
            if body.session_hash not in app.iterators:
                return {"success": False}
            async with app.lock:
                app.iterators[body.session_hash][body.fn_index] = None
                app.iterators[body.session_hash]["should_reset"].add(body.fn_index)
            return {"success": True}

        async def run_predict(
            body: PredictBody,
            request: Request | List[Request],
            fn_index_inferred: int,
            username: str = Depends(get_current_user),
        ):
            # Core prediction driver shared by /api and the queue worker.
            if hasattr(body, "session_hash"):
                if body.session_hash not in app.state_holder:
                    # First request of this session: seed stateful components
                    # with copies of their default values.
                    app.state_holder[body.session_hash] = {
                        _id: deepcopy(getattr(block, "value", None))
                        for _id, block in app.get_blocks().blocks.items()
                        if getattr(block, "stateful", False)
                    }
                session_state = app.state_holder[body.session_hash]
                iterators = app.iterators[body.session_hash]
                # The should_reset set keeps track of the fn_indices
                # that have been cancelled. When a job is cancelled,
                # the /reset route will mark the jobs as having been reset.
                # That way if the cancel job finishes BEFORE the job being cancelled
                # the job being cancelled will not overwrite the state of the iterator.
                # In all cases, should_reset will be the empty set the next time
                # the fn_index is run.
                app.iterators[body.session_hash]["should_reset"] = set([])
            else:
                session_state = {}
                iterators = {}
            event_id = getattr(body, "event_id", None)
            raw_input = body.data
            fn_index = body.fn_index
            batch = app.get_blocks().dependencies[fn_index_inferred]["batch"]
            if not (body.batched) and batch:
                # Batched dependency called with a single sample: wrap it.
                raw_input = [raw_input]
            try:
                output = await app.get_blocks().process_api(
                    fn_index=fn_index_inferred,
                    inputs=raw_input,
                    request=request,
                    state=session_state,
                    iterators=iterators,
                    event_id=event_id,
                )
                iterator = output.pop("iterator", None)
                if hasattr(body, "session_hash"):
                    if fn_index in app.iterators[body.session_hash]["should_reset"]:
                        app.iterators[body.session_hash][fn_index] = None
                    else:
                        app.iterators[body.session_hash][fn_index] = iterator
                if isinstance(output, Error):
                    raise output
            except BaseException as error:
                show_error = app.get_blocks().show_error or isinstance(error, Error)
                traceback.print_exc()
                return JSONResponse(
                    content={"error": str(error) if show_error else None},
                    status_code=500,
                )

            if not (body.batched) and batch:
                # Unwrap the single sample that was wrapped above.
                output["data"] = output["data"][0]
            return output

        # had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility
        @app.post("/run/{api_name}", dependencies=[Depends(login_check)])
        @app.post("/run/{api_name}/", dependencies=[Depends(login_check)])
        @app.post("/api/{api_name}", dependencies=[Depends(login_check)])
        @app.post("/api/{api_name}/", dependencies=[Depends(login_check)])
        async def predict(
            api_name: str,
            body: PredictBody,
            request: fastapi.Request,
            username: str = Depends(get_current_user),
        ):
            # Resolve api_name to a dependency index unless one was given.
            fn_index_inferred = None
            if body.fn_index is None:
                for i, fn in enumerate(app.get_blocks().dependencies):
                    if fn["api_name"] == api_name:
                        fn_index_inferred = i
                        break
                if fn_index_inferred is None:
                    return JSONResponse(
                        content={
                            "error": f"This app has no endpoint /api/{api_name}/."
                        },
                        status_code=500,
                    )
            else:
                fn_index_inferred = body.fn_index
            # Queued-only endpoints may only be hit directly by the queue
            # worker itself (identified by the queue bearer token).
            if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn(
                fn_index_inferred
            ):
                if f"Bearer {app.queue_token}" != request.headers.get("Authorization"):
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Not authorized to skip the queue",
                    )

            # If this fn_index cancels jobs, then the only input we need is the
            # current session hash
            if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
                body.data = [body.session_hash]
            if body.request:
                if body.batched:
                    gr_request = [Request(**req) for req in body.request]
                else:
                    assert isinstance(body.request, dict)
                    gr_request = Request(**body.request)
            else:
                gr_request = Request(request)
            result = await run_predict(
                body=body,
                fn_index_inferred=fn_index_inferred,
                username=username,
                request=gr_request,
            )
            return result

        @app.websocket("/queue/join")
        async def join_queue(
            websocket: WebSocket,
            token: Optional[str] = Depends(ws_login_check),
        ):
            blocks = app.get_blocks()
            if app.auth is not None and token is None:
                await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                return
            if blocks._queue.server_path is None:
                # First connection: teach the queue its own HTTP base URL.
                app_url = get_server_url_from_ws_url(str(websocket.url))
                blocks._queue.set_url(app_url)
            await websocket.accept()
            # In order to cancel jobs, we need the session_hash and fn_index
            # to create a unique id for each job
            await websocket.send_json({"msg": "send_hash"})
            session_info = await websocket.receive_json()
            event = Event(
                websocket, session_info["session_hash"], session_info["fn_index"]
            )
            # set the token into Event to allow using the same token for call_prediction
            event.token = token
            event.session_hash = session_info["session_hash"]

            # Continuous events are not put in the queue so that they do not
            # occupy the queue's resource as they are expected to run forever
            if blocks.dependencies[event.fn_index].get("every", 0):
                await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"]))
                await blocks._queue.reset_iterators(event.session_hash, event.fn_index)
                task = run_coro_in_background(
                    blocks._queue.process_events, [event], False
                )
                set_task_name(task, event.session_hash, event.fn_index, batch=False)
            else:
                rank = blocks._queue.push(event)

                if rank is None:
                    await blocks._queue.send_message(event, {"msg": "queue_full"})
                    await event.disconnect()
                    return
                estimation = blocks._queue.get_estimation()
                await blocks._queue.send_estimation(event, estimation, rank)
            # Keep the handler alive until the client disconnects; the queue
            # worker owns all further traffic on this websocket.
            while True:
                await asyncio.sleep(60)
                if websocket.application_state == WebSocketState.DISCONNECTED:
                    return

        @app.get(
            "/queue/status",
            dependencies=[Depends(login_check)],
            response_model=Estimation,
        )
        async def get_queue_status():
            return app.get_blocks()._queue.get_estimation()

        @app.get("/startup-events")
        async def startup_events():
            # Idempotent: fire the Blocks startup hooks at most once.
            if not app.startup_events_triggered:
                app.get_blocks().startup_events()
                app.startup_events_triggered = True
                return True
            return False

        @app.get("/robots.txt", response_class=PlainTextResponse)
        def robots_txt():
            # Shared (public-tunnel) apps opt out of crawling entirely.
            if app.get_blocks().share:
                return "User-agent: *\nDisallow: /"
            else:
                return "User-agent: *\nDisallow: "

        return app
|
474 |
+
|
475 |
+
|
476 |
+
########
|
477 |
+
# Helper functions
|
478 |
+
########
|
479 |
+
|
480 |
+
|
481 |
+
def safe_join(directory: str, path: str) -> str | None:
    """Join *path* onto *directory*, refusing anything that could escape it.

    Returns the joined path, or None when *path* is absolute, contains an
    alternate OS separator, or attempts upward traversal.
    (Borrowed from: werkzeug.security.safe_join)
    """
    alt_seps: List[str] = [
        sep for sep in (os.path.sep, os.path.altsep) if sep is not None and sep != "/"
    ]

    if path == "":
        return directory
    filename = posixpath.normpath(path)

    unsafe = (
        any(sep in filename for sep in alt_seps)
        or os.path.isabs(filename)
        or filename == ".."
        or filename.startswith("../")
    )
    return None if unsafe else posixpath.join(directory, filename)
|
501 |
+
|
502 |
+
|
503 |
+
def get_types(cls_set: List[Type]):
    """Extract short descriptions and value types from component docstrings.

    For each class, the trailing text of the second docstring line becomes
    its description, and every ``value (<type>)`` annotation found in the
    docstring contributes a type entry.
    """
    descriptions = []
    value_types = []
    for cls in cls_set:
        doc_lines = (inspect.getdoc(cls) or "").split("\n")
        for line in doc_lines:
            if "value (" in line:
                value_types.append(line.split("value (")[1].split(")")[0])
        descriptions.append(doc_lines[1].split(":")[-1])
    return descriptions, value_types
|
514 |
+
|
515 |
+
|
516 |
+
def get_server_url_from_ws_url(ws_url: str):
    """Derive the app's HTTP(S) base URL from its websocket queue URL."""
    parsed = urlparse(ws_url)
    scheme = "http" if parsed.scheme == "ws" else "https"
    port_part = "" if not parsed.port else f":{parsed.port}"
    base_path = parsed.path.replace("queue/join", "")
    return f"{scheme}://{parsed.hostname}{port_part}{base_path}"
|
521 |
+
|
522 |
+
|
523 |
+
# Register subsequent @document()-decorated classes under the "routes" group.
set_documentation_group("routes")
|
524 |
+
|
525 |
+
|
526 |
+
class Obj:
    """
    Wraps a dictionary so its keys become object attributes. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
    """

    def __init__(self, dict1):
        # Copy every key/value pair directly into the instance namespace.
        self.__dict__.update(dict1)

    def __str__(self) -> str:
        return str(vars(self))

    def __repr__(self) -> str:
        return str(vars(self))
|
540 |
+
|
541 |
+
|
542 |
+
@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`,
    Example:
        import gradio as gr
        def echo(name, request: gr.Request):
            print("Request headers dictionary:", request.headers)
            print("IP address:", request.client.host)
            return name
        io = gr.Interface(echo, "textbox", "textbox").launch()
    """

    def __init__(self, request: fastapi.Request | None = None, **kwargs):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for websocket-based queueing).
        Parameters:
            request: A fastapi.Request
        """
        self.request = request
        self.kwargs: Dict = kwargs

    def dict_to_obj(self, d):
        # Round-trip through JSON so nested dicts become attribute-accessible
        # Obj instances; non-dict values pass through untouched.
        if not isinstance(d, dict):
            return d
        return json.loads(json.dumps(d), object_hook=Obj)

    def __getattr__(self, name):
        # Prefer the wrapped fastapi.Request; otherwise fall back to the
        # manually supplied kwargs.
        if self.request:
            return self.dict_to_obj(getattr(self.request, name))
        if name not in self.kwargs:
            raise AttributeError(f"'Request' object has no attribute '{name}'")
        return self.dict_to_obj(self.kwargs[name])
|
583 |
+
|
584 |
+
|
585 |
+
@document()
def mount_gradio_app(
    app: fastapi.FastAPI,
    blocks: gradio.Blocks,
    path: str,
    gradio_api_url: str | None = None,
) -> fastapi.FastAPI:
    """Mount a gradio.Blocks to an existing FastAPI application.

    Parameters:
        app: The parent FastAPI application.
        blocks: The blocks object we want to mount to the parent app.
        path: The path at which the gradio application will be mounted.
        gradio_api_url: The full url at which the gradio app will run. This is only needed if deploying to Huggingface spaces or if the websocket endpoints of your deployed app are on a different network location than the gradio app. If deploying to spaces, set gradio_api_url to 'http://localhost:7860/'
    Example:
        from fastapi import FastAPI
        import gradio as gr
        app = FastAPI()
        @app.get("/")
        def read_main():
            return {"message": "This is your main app"}
        io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
        app = gr.mount_gradio_app(app, io, path="/gradio")
        # Then run `uvicorn run:app` from the terminal and navigate to http://localhost:8000/gradio.
    """
    blocks.dev_mode = False
    blocks.config = blocks.get_config_file()
    gradio_app = App.create_app(blocks)

    @app.on_event("startup")
    async def start_queue():
        # Configure and start the websocket queue once the parent app boots.
        gradio_blocks = gradio_app.get_blocks()
        if not gradio_blocks.enable_queue:
            return
        if gradio_api_url:
            gradio_blocks._queue.set_url(gradio_api_url)
        gradio_blocks.startup_events()

    app.mount(path, gradio_app)
    return app
|
gradio-modified/gradio/serializing.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Dict
|
6 |
+
|
7 |
+
from gradio import processing_utils, utils
|
8 |
+
|
9 |
+
|
10 |
+
class Serializable(ABC):
    """Abstract interface for converting component data to and from the
    serialized form sent to/from a browser."""

    @abstractmethod
    def serialize(
        self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
    ):
        """
        Convert data from human-readable format to serialized format for a browser.
        """

    @abstractmethod
    def deserialize(
        self,
        x: Any,
        save_dir: str | Path | None = None,
        encryption_key: bytes | None = None,
    ):
        """
        Convert data from serialized format for a browser to human-readable format.
        """
|
31 |
+
|
32 |
+
|
33 |
+
class SimpleSerializable(Serializable):
    """Serializer for values that are already browser-compatible: both
    directions are identity functions."""

    def serialize(
        self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
    ) -> Any:
        """
        No-op serialization: the value needs no conversion for the browser.
        Parameters:
            x: Input data to serialize
            load_dir: Ignored
            encryption_key: Ignored
        """
        return x

    def deserialize(
        self,
        x: Any,
        save_dir: str | Path | None = None,
        encryption_key: bytes | None = None,
    ):
        """
        No-op deserialization: the value is already human-readable.
        Parameters:
            x: Input data to deserialize
            save_dir: Ignored
            encryption_key: Ignored
        """
        return x
|
60 |
+
|
61 |
+
|
62 |
+
class ImgSerializable(Serializable):
    """Serializes images between string filepaths and base64 strings."""

    def serialize(
        self,
        x: str | None,
        load_dir: str | Path = "",
        encryption_key: bytes | None = None,
    ) -> str | None:
        """
        Convert from the human-friendly version of an image file (string filepath)
        to a serialized representation (base64).
        Parameters:
            x: String path to file to serialize
            load_dir: Path to directory containing x
            encryption_key: Used to encrypt the file
        """
        if not x:
            return None
        return processing_utils.encode_url_or_file_to_base64(
            Path(load_dir) / x, encryption_key=encryption_key
        )

    def deserialize(
        self,
        x: str | None,
        save_dir: str | Path | None = None,
        encryption_key: bytes | None = None,
    ) -> str | None:
        """
        Convert from a serialized representation of an image (base64) to a
        human-friendly version (string filepath). Optionally, save the file to
        the directory specified by save_dir.
        Parameters:
            x: Base64 representation of image to deserialize into a string filepath
            save_dir: Path to directory to save the deserialized image to
            encryption_key: Used to decrypt the file
        """
        if not x:
            return None
        decoded = processing_utils.decode_base64_to_file(
            x, dir=save_dir, encryption_key=encryption_key
        )
        return decoded.name
|
103 |
+
|
104 |
+
|
105 |
+
class FileSerializable(Serializable):
    """Serializes files between string filepaths and base64 dictionaries."""

    def serialize(
        self,
        x: str | None,
        load_dir: str | Path = "",
        encryption_key: bytes | None = None,
    ) -> Dict | None:
        """
        Convert from human-friendly version of a file (string filepath) to a
        serialized representation (base64)
        Parameters:
            x: String path to file to serialize
            load_dir: Path to directory containing x
            encryption_key: Used to encrypt the file
        """
        if x is None or x == "":
            return None
        filename = Path(load_dir) / x
        # NOTE(review): "name" is a Path object here, not a str — confirm
        # downstream JSON encoding handles it.
        return {
            "name": filename,
            "data": processing_utils.encode_url_or_file_to_base64(
                filename, encryption_key=encryption_key
            ),
            "orig_name": Path(filename).name,
            "is_file": False,
        }

    def deserialize(
        self,
        x: str | Dict | None,
        save_dir: Path | str | None = None,
        encryption_key: bytes | None = None,
    ) -> str | None:
        """
        Convert from serialized representation of a file (base64) to a human-friendly
        version (string filepath). Optionally, save the file to the directory specified by `save_dir`
        Parameters:
            x: Base64 representation of file to deserialize into a string filepath
            save_dir: Path to directory to save the deserialized file to
            encryption_key: Used to decrypt the file
        Raises:
            ValueError: if x is neither a string nor a dict.
        """
        if x is None:
            return None
        if isinstance(save_dir, Path):
            save_dir = str(save_dir)
        if isinstance(x, str):
            file_name = processing_utils.decode_base64_to_file(
                x, dir=save_dir, encryption_key=encryption_key
            ).name
        elif isinstance(x, dict):
            if x.get("is_file", False):
                if utils.validate_url(x["name"]):
                    # Remote files are passed through by URL rather than copied.
                    file_name = x["name"]
                else:
                    file_name = processing_utils.create_tmp_copy_of_file(
                        x["name"], dir=save_dir
                    ).name
            else:
                file_name = processing_utils.decode_base64_to_file(
                    x["data"], dir=save_dir, encryption_key=encryption_key
                ).name
        else:
            # Bug fix: the message previously read "cannot only deserialize",
            # which inverted the intended meaning.
            raise ValueError(
                f"A FileSerializable component can only deserialize a string or a dict, not a: {type(x)}"
            )
        return file_name
|
171 |
+
|
172 |
+
|
173 |
+
class JSONSerializable(Serializable):
    """Serializes between JSON file paths and JSON strings/dicts."""

    def serialize(
        self,
        x: str | None,
        load_dir: str | Path = "",
        encryption_key: bytes | None = None,
    ) -> Dict | None:
        """
        Convert from a human-friendly version (string path to a json file) to a
        serialized representation (json string).
        Parameters:
            x: String path to json file to read to get json string
            load_dir: Path to directory containing x
            encryption_key: Ignored
        """
        if not x:
            return None
        return processing_utils.file_to_json(Path(load_dir) / x)

    def deserialize(
        self,
        x: str | Dict,
        save_dir: str | Path | None = None,
        encryption_key: bytes | None = None,
    ) -> str | None:
        """
        Convert from a serialized representation (json string) to a human-friendly
        version (string path to a json file). Optionally, save the file to the
        directory specified by `save_dir`.
        Parameters:
            x: Json string
            save_dir: Path to save the deserialized json file to
            encryption_key: Ignored
        """
        if x is None:
            return None
        json_file = processing_utils.dict_or_str_to_json_file(x, dir=save_dir)
        return json_file.name
|
gradio-modified/gradio/strings.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json

import requests

# Endpoint that may serve updated/localized replacements for the strings below.
MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"

# Default English console/UI strings. Keys are message identifiers used across
# the codebase; many values contain `str.format` placeholders ({}).
en = {
    "RUNNING_LOCALLY": "Running on local URL: {}",
    "RUNNING_LOCALLY_SEPARATED": "Running on local URL: {}://{}:{}",
    "SHARE_LINK_DISPLAY": "Running on public URL: {}",
    "COULD_NOT_GET_SHARE_LINK": "\nCould not create share link, please check your internet connection.",
    "COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
    "PUBLIC_SHARE_TRUE": "\nTo create a public link, set `share=True` in `launch()`.",
    "MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
    "GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
    "BETA_INVITE": "\nThanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB",
    "COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
    "To turn off, set debug=False in launch().",
    "COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
    "COLAB_WARNING": "Note: opening Chrome Inspector may crash demo inside Colab notebooks.",
    "SHARE_LINK_MESSAGE": "\nThis share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces",
    "INLINE_DISPLAY_BELOW": "Interface loading below...",
    # Rotating usage tips shown to users in the console.
    "TIPS": [
        "You can add authentication to your app with the `auth=` kwarg in the `launch()` command; for example: `gr.Interface(...).launch(auth=('username', 'password'))`",
        "Let users specify why they flagged input with the `flagging_options=` kwarg; for example: `gr.Interface(..., flagging_options=['too slow', 'incorrect output', 'other'])`",
        "You can show or hide the button for flagging with the `allow_flagging=` kwarg; for example: gr.Interface(..., allow_flagging=False)",
        "The inputs and outputs flagged by the users are stored in the flagging directory, specified by the flagging_dir= kwarg. You can view this data through the interface by setting the examples= kwarg to the flagging directory; for example gr.Interface(..., examples='flagged')",
        "You can add a title and description to your interface using the `title=` and `description=` kwargs. The `article=` kwarg can be used to add a description under the interface; for example gr.Interface(..., title='My app', description='Lorem ipsum'). Try using Markdown!",
        "For a classification or regression model, set `interpretation='default'` to see why the model made a prediction.",
    ],
}

# Best-effort refresh at import time: overlay any updated strings served by
# the messaging API; silently keep the defaults on any network/parse failure.
try:
    updated_messaging = requests.get(MESSAGING_API_ENDPOINT, timeout=3).json()
    en.update(updated_messaging)
except (
    requests.ConnectionError,
    requests.exceptions.ReadTimeout,
    json.decoder.JSONDecodeError,
):  # Use default messaging
    pass
|
gradio-modified/gradio/templates.py
ADDED
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import typing
|
4 |
+
from typing import Any, Callable, Tuple
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL.Image import Image
|
8 |
+
|
9 |
+
from gradio import components
|
10 |
+
|
11 |
+
|
12 |
+
# Template shortcut: a Textbox preconfigured as a taller multi-line input.
class TextArea(components.Textbox):
    """
    Sets: lines=7
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Callable | None = "",
        *,
        lines: int = 7,
        max_lines: int = 20,
        placeholder: str | None = None,
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        **kwargs,
    ):
        # Pure pass-through to Textbox; only the defaults differ.
        super().__init__(
            value=value,
            lines=lines,
            max_lines=max_lines,
            placeholder=placeholder,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            **kwargs,
        )
|
45 |
+
|
46 |
+
|
47 |
+
# Template shortcut: an Image component that captures from the webcam.
class Webcam(components.Image):
    """
    Sets: source="webcam", interactive=True
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Image | np.ndarray | None = None,
        *,
        shape: Tuple[int, int] | None = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "webcam",
        tool: str | None = None,
        type: str = "numpy",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = True,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        **kwargs,
    ):
        # Pure pass-through to Image; only the defaults differ.
        super().__init__(
            value=value,
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            **kwargs,
        )
|
90 |
+
|
91 |
+
|
92 |
+
# Template shortcut: a small grayscale canvas for freehand drawing (e.g. MNIST).
class Sketchpad(components.Image):
    """
    Sets: image_mode="L", source="canvas", shape=(28, 28), invert_colors=True, interactive=True
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Image | np.ndarray | None = None,
        *,
        shape: Tuple[int, int] = (28, 28),
        image_mode: str = "L",
        invert_colors: bool = True,
        source: str = "canvas",
        tool: str | None = None,
        type: str = "numpy",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = True,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        **kwargs,
    ):
        # Pure pass-through to Image; only the defaults differ.
        super().__init__(
            value=value,
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            **kwargs,
        )
|
135 |
+
|
136 |
+
|
137 |
+
# Template shortcut: a color drawing canvas.
class Paint(components.Image):
    """
    Sets: source="canvas", tool="color-sketch", interactive=True
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Image | np.ndarray | None = None,
        *,
        shape: Tuple[int, int] | None = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "canvas",
        tool: str = "color-sketch",
        type: str = "numpy",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = True,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        **kwargs,
    ):
        # Pure pass-through to Image; only the defaults differ.
        super().__init__(
            value=value,
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            **kwargs,
        )
|
180 |
+
|
181 |
+
|
182 |
+
# Template shortcut: an uploaded image with a sketch tool for masking regions.
class ImageMask(components.Image):
    """
    Sets: source="upload", tool="sketch", interactive=True
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Image | np.ndarray | None = None,
        *,
        shape: Tuple[int, int] | None = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str = "sketch",
        type: str = "numpy",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = True,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        **kwargs,
    ):
        # Pure pass-through to Image; only the defaults differ.
        super().__init__(
            value=value,
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            **kwargs,
        )
|
225 |
+
|
226 |
+
|
227 |
+
# Template shortcut: an uploaded image with a color-sketch tool for painting on it.
class ImagePaint(components.Image):
    """
    Sets: source="upload", tool="color-sketch", interactive=True
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Image | np.ndarray | None = None,
        *,
        shape: Tuple[int, int] | None = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str = "color-sketch",
        type: str = "numpy",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = True,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        **kwargs,
    ):
        # Pure pass-through to Image; only the defaults differ.
        super().__init__(
            value=value,
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            **kwargs,
        )
|
270 |
+
|
271 |
+
|
272 |
+
# Template shortcut: an Image component that yields PIL images.
class Pil(components.Image):
    """
    Sets: type="pil"
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Image | np.ndarray | None = None,
        *,
        shape: Tuple[int, int] | None = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str | None = None,
        type: str = "pil",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        **kwargs,
    ):
        # Pure pass-through to Image; only the defaults differ.
        super().__init__(
            value=value,
            shape=shape,
            image_mode=image_mode,
            invert_colors=invert_colors,
            source=source,
            tool=tool,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            **kwargs,
        )
|
315 |
+
|
316 |
+
|
317 |
+
# Template shortcut: a Video component forced to browser-playable mp4.
class PlayableVideo(components.Video):
    """
    Sets: format="mp4"
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Callable | None = None,
        *,
        format: str | None = "mp4",
        source: str = "upload",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        mirror_webcam: bool = True,
        include_audio: bool | None = None,
        **kwargs,
    ):
        # Pure pass-through to Video; only the defaults differ.
        super().__init__(
            value=value,
            format=format,
            source=source,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            mirror_webcam=mirror_webcam,
            include_audio=include_audio,
            **kwargs,
        )
|
352 |
+
|
353 |
+
|
354 |
+
# Template shortcut: an Audio component that records from the microphone.
class Microphone(components.Audio):
    """
    Sets: source="microphone"
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | Tuple[int, np.ndarray] | Callable | None = None,
        *,
        source: str = "microphone",
        type: str = "numpy",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        streaming: bool = False,
        elem_id: str | None = None,
        **kwargs,
    ):
        # Pure pass-through to Audio; only the defaults differ.
        super().__init__(
            value=value,
            source=source,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            streaming=streaming,
            elem_id=elem_id,
            **kwargs,
        )
|
387 |
+
|
388 |
+
|
389 |
+
# Template shortcut: a File component accepting multiple files.
class Files(components.File):
    """
    Sets: file_count="multiple"
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: str | typing.List[str] | Callable | None = None,
        *,
        file_count: str = "multiple",
        type: str = "file",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        **kwargs,
    ):
        # Pure pass-through to File; only the defaults differ.
        super().__init__(
            value=value,
            file_count=file_count,
            type=type,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            **kwargs,
        )
|
420 |
+
|
421 |
+
|
422 |
+
# Template shortcut: a Dataframe whose values are numpy arrays.
class Numpy(components.Dataframe):
    """
    Sets: type="numpy"
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: typing.List[typing.List[Any]] | Callable | None = None,
        *,
        headers: typing.List[str] | None = None,
        row_count: int | Tuple[int, str] = (1, "dynamic"),
        col_count: int | Tuple[int, str] | None = None,
        datatype: str | typing.List[str] = "str",
        type: str = "numpy",
        max_rows: int | None = 20,
        max_cols: int | None = None,
        overflow_row_behaviour: str = "paginate",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        wrap: bool = False,
        **kwargs,
    ):
        # Pure pass-through to Dataframe; only the defaults differ.
        super().__init__(
            value=value,
            headers=headers,
            row_count=row_count,
            col_count=col_count,
            datatype=datatype,
            type=type,
            max_rows=max_rows,
            max_cols=max_cols,
            overflow_row_behaviour=overflow_row_behaviour,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            wrap=wrap,
            **kwargs,
        )
|
467 |
+
|
468 |
+
|
469 |
+
# Template shortcut: a Dataframe whose values are nested Python lists.
class Matrix(components.Dataframe):
    """
    Sets: type="array"
    """

    # Marks this class as a template (preset) rather than a base component.
    is_template = True

    def __init__(
        self,
        value: typing.List[typing.List[Any]] | Callable | None = None,
        *,
        headers: typing.List[str] | None = None,
        row_count: int | Tuple[int, str] = (1, "dynamic"),
        col_count: int | Tuple[int, str] | None = None,
        datatype: str | typing.List[str] = "str",
        type: str = "array",
        max_rows: int | None = 20,
        max_cols: int | None = None,
        overflow_row_behaviour: str = "paginate",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        wrap: bool = False,
        **kwargs,
    ):
        # Pure pass-through to Dataframe; only the defaults differ.
        super().__init__(
            value=value,
            headers=headers,
            row_count=row_count,
            col_count=col_count,
            datatype=datatype,
            type=type,
            max_rows=max_rows,
            max_cols=max_cols,
            overflow_row_behaviour=overflow_row_behaviour,
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            wrap=wrap,
            **kwargs,
        )
|
514 |
+
|
515 |
+
|
516 |
+
class List(components.Dataframe):
    """
    Sets: type="array", col_count=1
    """

    is_template = True

    def __init__(
        self,
        value: typing.List[typing.List[Any]] | Callable | None = None,
        *,
        headers: typing.List[str] | None = None,
        row_count: int | Tuple[int, str] = (1, "dynamic"),
        col_count: int | Tuple[int, str] = 1,
        datatype: str | typing.List[str] = "str",
        type: str = "array",
        max_rows: int | None = 20,
        max_cols: int | None = None,
        overflow_row_behaviour: str = "paginate",
        label: str | None = None,
        show_label: bool = True,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        wrap: bool = False,
        **kwargs,
    ):
        # Single-column Dataframe template: identical to Dataframe except
        # for the defaults ``type="array"`` and ``col_count=1`` set in the
        # signature above; everything is forwarded as-is.
        super().__init__(
            # data & table shape
            value=value,
            headers=headers,
            datatype=datatype,
            type=type,
            row_count=row_count,
            col_count=col_count,
            max_rows=max_rows,
            max_cols=max_cols,
            overflow_row_behaviour=overflow_row_behaviour,
            wrap=wrap,
            # presentation / interactivity
            label=label,
            show_label=show_label,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            **kwargs,
        )
+
# Module-level alias for the Microphone template — presumably kept so that
# code referring to the shorter name `Mic` keeps working (TODO: confirm
# against callers/older releases).
Mic = Microphone
|
gradio-modified/{templates → gradio/templates}/frontend/assets/BlockLabel.37da86a3.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.cc0aed40.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/CarouselItem.svelte_svelte_type_style_lang.e110d966.css
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Column.06c172ac.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/File.60a988f4.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Image.4a41f1aa.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Image.95fa511c.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Model3D.b44fd6f2.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/ModifyUpload.2cfe71e4.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Tabs.6b500f1a.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Upload.5d0148e8.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/Webcam.8816836e.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/_commonjsHelpers.88e99c8f.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/color.509e5f03.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/csv.27f5436c.js
RENAMED
File without changes
|
gradio-modified/{templates → gradio/templates}/frontend/assets/dsv.7fe76a93.js
RENAMED
File without changes
|