|
import asyncio |
|
import json |
|
import os |
|
import shutil |
|
import subprocess |
|
import tempfile |
|
import time |
|
from pathlib import Path |
|
from unittest.mock import patch |
|
|
|
import gradio_client as grc |
|
import pytest |
|
from gradio_client import media_data |
|
from gradio_client import utils as client_utils |
|
from pydub import AudioSegment |
|
from starlette.testclient import TestClient |
|
from tqdm import tqdm |
|
|
|
import gradio as gr |
|
from gradio import utils |
|
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
    """Tests for how ``gr.Examples`` turns its ``examples`` argument into rows.

    The class-level patch replaces ``gradio.utils.get_cache_folder`` with a
    mock returning a temporary directory; the mock is handed to every test
    as ``patched_cache_folder``.
    """

    def test_handle_single_input(self, patched_cache_folder):
        """Flat lists, nested lists and a single image path are all accepted."""
        flat = gr.Examples(["hello", "hi"], gr.Textbox())
        assert flat.processed_examples == [["hello"], ["hi"]]

        nested = gr.Examples([["hello"]], gr.Textbox())
        assert nested.processed_examples == [["hello"]]

        image_examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
        image_path = image_examples.processed_examples[0][0]["path"]
        encoded = client_utils.encode_file_to_base64(image_path)
        assert encoded == media_data.BASE64_IMAGE

    def test_handle_multiple_inputs(self, patched_cache_folder):
        """A row with a text value and an image path keeps both, in order."""
        examples = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        first_row = examples.processed_examples[0]
        assert first_row[0] == "hello"
        encoded = client_utils.encode_file_to_base64(first_row[1]["path"])
        assert encoded == media_data.BASE64_IMAGE

    def test_handle_directory(self, patched_cache_folder):
        """A directory path yields one example row per image it contains."""
        examples = gr.Examples("test/test_files/images", gr.Image())
        assert len(examples.processed_examples) == 2
        encoded_cells = [
            client_utils.encode_file_to_base64(cell["path"])
            for row in examples.processed_examples
            for cell in row
        ]
        assert all(cell == media_data.BASE64_IMAGE for cell in encoded_cells)

    def test_handle_directory_with_log_file(self, patched_cache_folder):
        """A directory with a log file produces (image, text) rows."""
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )

        def is_existing_file(x):
            return isinstance(x, dict) and Path(x["path"]).exists()

        def to_base64(s):
            return client_utils.encode_file_to_base64(s["path"])

        # Swap every file payload for its base64 form so rows compare by content.
        encoded_rows = client_utils.traverse(
            examples.processed_examples, to_base64, is_existing_file
        )
        assert encoded_rows == [
            [media_data.BASE64_IMAGE, "hello"],
            [media_data.BASE64_IMAGE, "hi"],
        ]
        assert all(
            os.path.isabs(sample[0]["path"]) for sample in examples.dataset.samples
        )

    def test_examples_per_page(self, patched_cache_folder):
        """``examples_per_page`` is exposed as the dataset's ``samples_per_page``."""
        examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
        dataset_config = examples.dataset.get_config()
        assert dataset_config["samples_per_page"] == 2

    def test_no_preprocessing(self, patched_cache_folder):
        """With ``preprocess=False`` the function receives the raw file payload."""
        with gr.Blocks():
            image = gr.Image()
            textbox = gr.Textbox()

            examples = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image,
                outputs=textbox,
                fn=lambda x: x["path"],
                cache_examples=True,
                preprocess=False,
            )

        prediction = examples.load_from_cache(0)
        encoded = client_utils.encode_file_to_base64(prediction[0])
        assert encoded == media_data.BASE64_IMAGE

    def test_no_postprocessing(self, patched_cache_folder):
        """With ``postprocess=False`` a raw gallery payload is cached."""

        def im(x):
            return [
                {
                    "image": {
                        "path": "test/test_files/bus.png",
                    },
                    "caption": "hi",
                }
            ]

        with gr.Blocks():
            text = gr.Textbox()
            gall = gr.Gallery()

            examples = gr.Examples(
                examples=["hi"],
                inputs=text,
                outputs=gall,
                fn=im,
                cache_examples=True,
                postprocess=False,
            )

        prediction = examples.load_from_cache(0)
        cached_path = prediction[0].root[0].image.path
        cached = client_utils.encode_url_or_file_to_base64(cached_path)
        source = client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
        assert cached == source
|
|
|
|
|
def test_setting_cache_dir_env_variable(monkeypatch):
    """Cached example files land inside the ``GRADIO_EXAMPLES_CACHE`` directory."""
    cache_dir = tempfile.mkdtemp()
    monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", cache_dir)

    with gr.Blocks():
        image = gr.Image()
        image2 = gr.Image()

        examples = gr.Examples(
            examples=["test/test_files/bus.png"],
            inputs=image,
            outputs=image2,
            fn=lambda x: x,
            cache_examples=True,
        )

    cached_output = examples.load_from_cache(0)[0]
    assert utils.is_in_or_equal(Path(cached_output.path), cache_dir)
    monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)
|
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
    """Tests for the headers of the dataset built by ``gr.Examples``.

    Every test receives the class-level mock of
    ``gradio.utils.get_cache_folder`` as ``patched_cache_folder``.
    """

    def test_no_headers(self, patched_cache_folder):
        """Unlabelled components give an empty header list."""
        components = [gr.Image(), gr.Text()]
        examples = gr.Examples("test/test_files/images_log", components)
        assert examples.dataset.headers == []

    def test_all_headers(self, patched_cache_folder):
        """When every component is labelled, the labels are the headers."""
        components = [gr.Image(label="im"), gr.Text(label="your text")]
        examples = gr.Examples("test/test_files/images_log", components)
        assert examples.dataset.headers == ["im", "your text"]

    def test_some_headers(self, patched_cache_folder):
        """An unlabelled component contributes an empty-string header."""
        components = [gr.Image(label="im"), gr.Text()]
        examples = gr.Examples("test/test_files/images_log", components)
        assert examples.dataset.headers == ["im", ""]
|
|
|
|
|
def test_example_caching_relaunch(connect):
    """Cached examples return the same result when the demo is connected twice.

    ``connect`` is a fixture defined outside this file; it is used here as a
    context manager that yields a client for ``demo``.
    """

    def combine(a, b):
        return a + " " + b

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Input")
        txt_2 = gr.Textbox(label="Input 2")
        txt_3 = gr.Textbox(value="", label="Output")
        btn = gr.Button(value="Submit")
        btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
        gr.Examples(
            [["hi", "Adam"], ["hello", "Eve"]],
            [txt, txt_2],
            txt_3,
            combine,
            cache_examples=True,
            api_name="examples",
        )

    # Both inputs of example #1 followed by its cached output.
    expected = ("hello", "Eve", "hello Eve")

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == expected

    # Pause before reconnecting (presumably so the first server can shut down).
    time.sleep(1)

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == expected
|
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
    """Tests for computing, caching and reloading example outputs.

    The class-level patch replaces ``gradio.utils.get_cache_folder`` with a
    mock returning a temporary directory; the mock is passed to every test
    as ``patched_cache_folder``.
    """

    def test_caching(self, patched_cache_folder):
        """Cached text output can be loaded back by example index."""
        io = gr.Interface(
            lambda x: f"Hello {x}",
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    def test_example_caching_relaunch(self, patched_cache_folder, connect):
        """Cached examples give the same result across two connections.

        ``connect`` is a fixture defined outside this file; it is used here as
        a context manager that yields a client for ``demo``.
        """

        def combine(a, b):
            return a + " " + b

        with gr.Blocks() as demo:
            txt = gr.Textbox(label="Input")
            txt_2 = gr.Textbox(label="Input 2")
            txt_3 = gr.Textbox(value="", label="Output")
            btn = gr.Button(value="Submit")
            btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
            gr.Examples(
                [["hi", "Adam"], ["hello", "Eve"]],
                [txt, txt_2],
                txt_3,
                combine,
                cache_examples=True,
                api_name="examples",
            )

        # Both inputs of example #1 followed by its cached output.
        expected = ("hello", "Eve", "hello Eve")

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == expected

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == expected

    def test_caching_image(self, patched_cache_folder):
        """A cached image output is a file that encodes to an image data URI."""
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        encoded = client_utils.encode_url_or_file_to_base64(prediction[0].path)
        # An empty prefix matches every string and would make this check
        # vacuous. Only the "image" media type is asserted; the concrete
        # format of the cached file is not pinned here.
        assert encoded.startswith("data:image/")

    def test_caching_audio(self, patched_cache_folder):
        """A cached audio output encodes to a WAV data URI."""
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        encoded = client_utils.encode_url_or_file_to_base64(prediction[0].path)
        assert encoded.startswith("data:audio/wav;base64,UklGRgA/")

    def test_caching_with_update(self, patched_cache_folder):
        """A function returning only ``gr.update`` is cached as an update dict."""
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "visible": False,
            "__type__": "update",
        }

    def test_caching_with_mix_update(self, patched_cache_folder):
        """An update and a plain value can be returned and cached together."""
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "lines": 4,
            "value": "hello",
            "__type__": "update",
        }

    def test_caching_with_dict(self, patched_cache_folder):
        """A function may return a dict keyed by output component."""
        text = gr.Textbox()
        out = gr.Label()

        io = gr.Interface(
            lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == [
            {"lines": 4, "__type__": "update", "interactive": False},
            gr.Label.data_model(label="lion", confidences=None),
        ]

    def test_caching_with_generators(self, patched_cache_folder):
        """For a generator function, the last yielded value is what is cached."""

        def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
        """Streamed audio chunks are joined; the number keeps its last value."""
        audio = str(Path(__file__).parent / "test_files" / "audio_sample.wav")

        def test_generator(x):
            for y in range(int(x)):
                yield audio, y * 5

        io = gr.Interface(
            test_generator,
            "number",
            [gr.Audio(streaming=True), "number"],
            examples=[3],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        input_length = len(AudioSegment.from_wav(audio))
        output_length = len(AudioSegment.from_wav(prediction[0].path))
        # The same clip is yielded three times, so the cached audio should be
        # about three times as long as the source.
        assert round(output_length / input_length, 1) == 3.0
        assert float(prediction[1]) == 10.0

    def test_caching_with_async_generators(self, patched_cache_folder):
        """For an async generator, the last yielded value is what is cached."""

        async def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    def test_raise_helpful_error_message_if_providing_partial_examples(
        self, patched_cache_folder, tmp_path
    ):
        """Caching partial examples warns, and fails when the function needs them."""
        warning_pattern = "^Examples are being cached but not all input components have"

        def foo(a, b):
            return a + b

        # Second column missing from every row.
        with pytest.warns(UserWarning, match=warning_pattern), pytest.raises(
            Exception
        ):
            gr.Interface(
                foo,
                inputs=["text", "text"],
                outputs=["text"],
                examples=[["foo"], ["bar"]],
                cache_examples=True,
            )

        # Second column present but ``None`` in one row.
        with pytest.warns(UserWarning, match=warning_pattern), pytest.raises(
            Exception
        ):
            gr.Interface(
                foo,
                inputs=["text", "text"],
                outputs=["text"],
                examples=[["foo", "bar"], ["bar", None]],
                cache_examples=True,
            )

        def foo_no_exception(a, b=2):
            return a * b

        # The missing input has a default in the function, so this must not raise.
        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(UserWarning, match=warning_pattern), pytest.raises(
            Exception
        ):
            gr.Interface(
                many_missing,
                inputs=["text", "number", "number"],
                outputs=["text"],
                examples=[["foo", None, None], ["bar", 2, 3]],
                cache_examples=True,
            )

    def test_caching_with_batch(self, patched_cache_folder):
        """A batched function is cached one example at a time."""

        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
        """A batched function with two outputs caches both values per example."""

        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel", "3"]

    def test_caching_with_non_io_component(self, patched_cache_folder):
        """A layout component (``gr.Column``) can be an example output."""

        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks():
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()

            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],
                cache_examples=True,
            )

        prediction = examples.load_from_cache(0)
        assert prediction == ["John", {"visible": True, "__type__": "update"}]

    def test_end_to_end(self, patched_cache_folder):
        """Uncached examples are served through the ``load_example`` route."""

        def concatenate(str1, str2):
            return str1 + str2

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                [["Hello,", None], ["Michael", None]],
                inputs=[t1, t2],
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        # Everything in the expected payload except "value" is the same for
        # both examples.
        textbox_update = {
            "lines": 1,
            "max_lines": 20,
            "show_label": True,
            "container": True,
            "min_width": 160,
            "autofocus": False,
            "autoscroll": True,
            "elem_classes": [],
            "rtl": False,
            "show_copy_button": False,
            "__type__": "update",
            "visible": True,
            "type": "text",
        }

        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == [{**textbox_update, "value": "Hello,"}]

        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == [{**textbox_update, "value": "Michael"}]

    def test_end_to_end_cache_examples(self, patched_cache_folder):
        """Cached examples return their inputs followed by the cached output."""

        def concatenate(str1, str2):
            return f"{str1} {str2}"

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                examples=[["Hello,", "World"], ["Michael", "Jordan"]],
                inputs=[t1, t2],
                outputs=[t2],
                fn=concatenate,
                cache_examples=True,
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == ["Hello,", "World", "Hello, World"]

        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
|
|
|
|
|
def test_multiple_file_flagging(tmp_path):
    """Two image inputs cached into one ``gr.Files`` output keep both files."""
    with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
        frame_inputs = [
            gr.Image(type="filepath", label="frame 1"),
            gr.Image(type="filepath", label="frame 2"),
        ]
        io = gr.Interface(
            fn=lambda *x: list(x),
            inputs=frame_inputs,
            outputs=[gr.Files()],
            examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
            cache_examples=True,
        )
        cached_files = io.examples_handler.load_from_cache(0)[0].root

        assert len(cached_files) == 2
        assert all(isinstance(d, gr.FileData) for d in cached_files)
|
|
|
|
|
def test_examples_keep_all_suffixes(tmp_path):
    """Cached files keep a multi-dot name, even when two examples share it."""
    with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
        # Two different files with the same name, in different directories.
        file_1 = tmp_path / "foo.bar.txt"
        file_1.write_text("file 1")
        nested_dir = tmp_path / "file_2"
        nested_dir.mkdir(parents=True)
        file_2 = nested_dir / "foo.bar.txt"
        file_2.write_text("file 2")

        io = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(file_1)], [str(file_2)]],
            cache_examples=True,
        )

        for index, expected_text in enumerate(["file 1", "file 2"]):
            cached = io.examples_handler.load_from_cache(index)[0]
            assert Path(cached.path).read_text() == expected_text
            assert cached.orig_name == "foo.bar.txt"
            assert cached.path.endswith("foo.bar.txt")
|
|
|
|
|
def test_make_waveform_with_spaces_in_filename():
    """``gr.make_waveform`` handles an audio path that contains a space.

    The output must be an ``.mp4`` file. Its resolution is then read with
    ``ffprobe``; if that command exits with an error the failure is only
    printed and the resolution is not checked.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        audio = os.path.join(tmpdirname, "test audio.wav")
        shutil.copy("test/test_files/audio_sample.wav", audio)
        waveform = gr.make_waveform(audio)
        assert waveform.endswith(".mp4")

        # Ask ffprobe for the width/height of the first video stream as JSON.
        command = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=width,height",
            "-of",
            "json",
            waveform,
        ]

        try:
            result = subprocess.run(command, capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as e:
            print("Error retrieving resolution of output waveform video:", e)
        else:
            video_stream = json.loads(result.stdout)["streams"][0]
            width = video_stream["width"]
            height = video_stream["height"]
            assert width == 1000
            assert height == 400
|
|
|
|
|
def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
    """
    Test that make_waveform raises an exception if ffmpeg fails,
    instead of returning a path to a non-existent or empty file.
    """
    source_audio = "test/test_files/audio_sample.wav"
    audio = tmp_path / "test audio.wav"
    shutil.copy(source_audio, audio)

    def _raise_called_process_error(*args, **kwargs):
        # Stand-in for ``subprocess.call`` that always reports a failed ffmpeg run.
        raise subprocess.CalledProcessError(1, "ffmpeg")

    monkeypatch.setattr(subprocess, "call", _raise_called_process_error)

    with pytest.raises(Exception):
        gr.make_waveform(str(audio))
|
|
|
|
|
class TestProgressBar:
    """Tests for the progress and log updates a client sees while a job runs.

    Each test launches a demo, submits one job through ``grc.Client`` and polls
    ``job.status()`` every 0.05 s until the job is done, recording each update
    that differs from the previous one.
    """

    @pytest.mark.asyncio
    async def test_progress_bar(self):
        """Only ``gr.Progress`` updates are reported; the plain tqdm loop is not."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        seen = []
        while not job.done():
            progress = job.status().progress_data
            if progress:
                update = (progress[0].index, progress[0].desc)
            else:
                update = (None, None)
            is_new = not seen or seen[-1] != update
            if update != (None, None) and is_new:
                seen.append(update)
            time.sleep(0.05)

        assert seen == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        """With ``track_tqdm=True`` the plain tqdm loop is reported as well."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        seen = []
        while not job.done():
            progress = job.status().progress_data
            if progress:
                update = (progress[0].index, progress[0].desc)
            else:
                update = (None, None)
            is_new = not seen or seen[-1] != update
            if update != (None, None) and is_new:
                seen.append(update)
            time.sleep(0.05)

        assert seen == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
            (0, "alphabet"),
            (1, "alphabet"),
            (2, "alphabet"),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm_without_iterable(self):
        """A manually updated ``tqdm(total=...)`` bar is tracked step by step."""

        def greet(s, _=gr.Progress(track_tqdm=True)):
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        seen = []
        while not job.done():
            progress = job.status().progress_data
            if progress:
                update = (progress[0].index, progress[0].unit)
            else:
                update = (None, None)
            is_new = not seen or seen[-1] != update
            if update != (None, None) and is_new:
                seen.append(update)
            time.sleep(0.05)

        assert seen == [
            (1, "steps"),
            (2, "steps"),
            (3, "steps"),
            (4, "steps"),
            (5, "steps"),
            (6, "steps"),
        ]

    @pytest.mark.asyncio
    async def test_info_and_warning_alerts(self):
        """``gr.Info`` and ``gr.Warning`` calls appear, in order, as job logs."""

        def greet(s):
            for _c in s:
                gr.Info(f"Letter {_c}")
                time.sleep(0.15)
            if len(s) < 5:
                gr.Warning("Too short!")
                time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Jon")

        seen_logs = []
        while not job.done():
            log = job.status().log
            is_new = not seen_logs or seen_logs[-1] != log
            if log is not None and is_new:
                seen_logs.append(log)
            time.sleep(0.05)

        assert seen_logs == [
            ("Letter J", "info"),
            ("Letter o", "info"),
            ("Letter n", "info"),
            ("Too short!", "warning"),
        ]
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
    """Check that each session only receives the ``gr.Info`` it triggered.

    Two clients submit to the same demo (``concurrency_limit=2``), the second
    one a second after the first, and each must end up with its own greeting
    as its last log message.

    Args:
        async_handler: Use the coroutine handler when true, the plain
            function handler otherwise.
    """

    async def greet_async(name):
        await asyncio.sleep(2)
        gr.Info(f"Hello {name}")
        await asyncio.sleep(1)
        return name

    def greet_sync(name):
        time.sleep(2)
        gr.Info(f"Hello {name}")
        time.sleep(1)
        return name

    demo = gr.Interface(
        greet_async if async_handler else greet_sync,
        "text",
        "text",
        concurrency_limit=2,
    )
    demo.launch(prevent_thread_lock=True)

    async def session_interaction(name, delay=0):
        """Submit ``name`` after ``delay`` seconds; return the last log text seen."""
        # Honour the stagger requested by the caller; ``delay`` was previously
        # accepted but never used.
        await asyncio.sleep(delay)

        client = grc.Client(demo.local_url)
        job = client.submit(name)

        seen_logs = []
        while not job.done():
            log = job.status().log
            if log is not None and (not seen_logs or seen_logs[-1] != log):
                seen_logs.append(log)
            # Yield to the event loop between polls. A blocking time.sleep here
            # would stop asyncio.gather from interleaving the two sessions, so
            # they would run one after the other.
            await asyncio.sleep(0.05)
        return seen_logs[-1][0] if seen_logs else None

    alice_logs, bob_logs = await asyncio.gather(
        session_interaction("Alice"),
        session_interaction("Bob", delay=1),
    )

    assert alice_logs == "Hello Alice"
    assert bob_logs == "Hello Bob"
|
|