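"""Gradio app for running a Giskard ML worker on a Hugging Face Space.

The UI lets you point the worker at a Giskard Hub URL, supply an API key
(and an HF token for private Spaces), start and stop the worker process,
and watch its log alongside live CPU/GPU usage.
"""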
import gradio as gr
import logging
import subprocess
import threading
import psutil
import sys
import os
from giskard.settings import settings

logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("giskard").setLevel(logging.INFO)

# Names of the environment variables used to pre-configure this Space.
GSK_HUB_URL = 'GSK_HUB_URL'
GSK_API_KEY = 'GSK_API_KEY'
HF_SPACE_HOST = 'SPACE_HOST'
HF_SPACE_TOKEN = 'GSK_HUB_HFS'
# Any non-empty READONLY env var puts the UI into read-only demo mode.
READONLY = bool(os.environ.get("READONLY"))
LOG_FILE = "output.log"
def read_logs():
    sys.stdout.flush()
    try:
        with open(LOG_FILE, "r") as f:
            return f.read()
    except Exception:
        return "ML worker not running"

def detect_gpu():
    # Report whether common DL frameworks can see a GPU; warn when a framework is absent.
    try:
        import torch
        logger.info(f"PyTorch GPU: {torch.cuda.is_available()}")
    except ImportError:
        logger.warning("No PyTorch installed")
    try:
        import tensorflow as tf
        logger.info(f"TensorFlow GPU: {len(tf.config.list_physical_devices('GPU')) > 0}")
    except ImportError:
        logger.warning("No TensorFlow installed")
threading.Thread(target=detect_gpu).start()
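
# Shared state: the managed worker process and the last Hub URL it served.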
previous_url = ""
ml_worker = None

def read_status():
    if ml_worker:
        return f"ML worker serving {previous_url}"
    elif previous_url:
        return f"ML worker exited for {previous_url}"
    else:
        return "ML worker not started"

def run_ml_worker(url, api_key, hf_token):
    global ml_worker, previous_url
    previous_url = url
    # Make sure no stale worker daemon is left running before starting a new one.
    subprocess.run(["giskard", "worker", "stop"])
    # The f-strings coerce None to "None" when an env var is unset.
    ml_worker = subprocess.Popen(
        [
            "giskard", "worker", "start",
            "-u", f"{url}", "-k", f"{api_key}", "--hf-token", f"{hf_token}"
        ],
        stdout=open(LOG_FILE, "w"), stderr=subprocess.STDOUT
    )
    # Log only the command prefix: the full argv contains the API key.
    args = ml_worker.args[:3]
    logging.info(f"Process {args} exited with {ml_worker.wait()}")
    ml_worker = None

def stop_ml_worker():
    global ml_worker
    if ml_worker is not None:
        logging.info(f"Stopping ML worker for {previous_url}")
        ml_worker.terminate()
        ml_worker = None
        logging.info("ML worker stopped")
        return "ML worker stopped"
    return "ML worker not started"

def start_ml_worker(url, api_key, hf_token):
    if not url:
        return "Please provide the Giskard Hub URL"
    if ml_worker is not None:
        return f"ML worker is still running for {previous_url}"
    # Always run the worker as an external process, supervised from a thread.
    stop_ml_worker()
    logging.info(f"Starting ML worker for {url}")
    thread = threading.Thread(target=run_ml_worker, args=(url, api_key, hf_token))
    thread.start()
    return f"ML worker running for {url}"

def get_gpu_usage():
    # See: https://stackoverflow.com/questions/67707828/how-to-get-every-seconds-gpu-usage-in-python
    command = "nvidia-smi --query-gpu=utilization.gpu --format=csv"
    try:
        output = subprocess.check_output(command.split(), stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError:
        return "Unavailable"
    # Drop the trailing empty line and the CSV header; each remaining line looks like "42 %".
    gpu_use_info = output.decode('ascii').split('\n')[:-1][1:]
    gpu_use_values = [int(x.split()[0]) for x in gpu_use_info]
    return f"{gpu_use_values[0]} %" if gpu_use_values else "Unavailable"
def get_usage():
    from shutil import which
    gpu_usage_info = ""
    # Only query nvidia-smi when it is actually available on this machine.
    if which("nvidia-smi") is not None:
        gpu_usage_info = f" GPU: {get_gpu_usage()}"
    return f"CPU: {psutil.cpu_percent()} %{gpu_usage_info}"

theme = gr.themes.Soft(
    primary_hue="green",
)
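
# Build the UI: configuration inputs on the left, status and usage on the right.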
with gr.Blocks(theme=theme) as iface:
    with gr.Row():
        with gr.Column():
            url = os.environ.get(GSK_HUB_URL) or f"http://{settings.host}:{settings.ws_port}"
            url_input = gr.Textbox(
                label="Giskard Hub URL",
                interactive=not READONLY,
                value=url,
            )
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(GSK_API_KEY),
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            )
            hf_token_input = gr.Textbox(
                label="Hugging Face Spaces Token",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(HF_SPACE_TOKEN),
                info="if using a private Giskard Hub on Hugging Face Spaces",
            )
        with gr.Column():
            output = gr.Textbox(label="Status")
            gr.Textbox(value=get_usage, label="Usage", every=1.0)
            if READONLY:
                gr.Textbox("You are browsing a read-only 🐢 Giskard ML worker instance.", container=False)
                gr.Textbox("Please duplicate this Space to configure your own Giskard ML worker.", container=False)
                gr.DuplicateButton(value="Duplicate Space for 🐢 Giskard ML worker", size="lg", variant="primary")
    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        run_btn.click(start_ml_worker, [url_input, api_key_input, hf_token_input], output)
        stop_btn = gr.Button("Stop", variant="stop", interactive=not READONLY)
        stop_btn.click(stop_ml_worker, None, output)

    logs = gr.Textbox(label="Giskard ML worker log:")
    iface.load(read_logs, None, logs, every=0.5)
    iface.load(read_status, None, output, every=5)
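
# Auto-start the worker when the Space is fully configured through env vars.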
if os.environ.get(GSK_HUB_URL) and os.environ.get(GSK_API_KEY):
    start_ml_worker(os.environ.get(GSK_HUB_URL), os.environ.get(GSK_API_KEY), os.environ.get(HF_SPACE_TOKEN))

iface.queue()
iface.launch()