File size: 4,752 Bytes
df1aa82
66e8b15
6401503
df1aa82
59166be
66e8b15
 
6348ca6
e05b343
 
7a1d097
 
6401503
 
 
 
6348ca6
ff04f99
 
 
 
1bbdc1e
ff04f99
6348ca6
 
 
 
28e9187
 
 
 
 
6348ca6
44ac6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
6348ca6
66e8b15
 
 
bfb27a7
 
 
 
 
 
 
 
 
59166be
 
 
c0e96e4
2ecd2bd
59166be
 
 
 
29d69db
59166be
6c63d42
6401503
629a811
66e8b15
 
6348ca6
66e8b15
 
6401503
59166be
629a811
6401503
6348ca6
 
 
 
 
e05b343
 
 
c0e96e4
 
 
6348ca6
 
66e8b15
6401503
59166be
66e8b15
59166be
df1aa82
69b1728
 
 
6348ca6
69b1728
6348ca6
 
ff04f99
7a1d097
 
1bbdc1e
ff04f99
7a1d097
e05b343
 
1bbdc1e
4ad28b6
1bbdc1e
e05b343
 
97e37df
 
1bbdc1e
4ad28b6
1bbdc1e
97e37df
 
6348ca6
175b3fd
 
 
 
 
 
 
6348ca6
69b1728
6348ca6
 
1bbdc1e
6348ca6
 
 
 
bfb27a7
6348ca6
ff04f99
 
 
6348ca6
df1aa82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr

import logging

import subprocess
import threading

import sys
import os

from giskard.settings import settings

logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("giskard").setLevel(logging.INFO)


def parse_bool_flag(raw):
    """Interpret an environment-variable string as a boolean.

    Args:
        raw: The raw value, or None when the variable is not set.

    Returns:
        False when the value is missing, blank, or one of the explicit
        negatives "0", "false", "no", "off" (case-insensitive, surrounding
        whitespace ignored); True for any other value.
    """
    if raw is None:
        return False
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Names of the environment variables this app reads.
GSK_HUB_URL = 'GSK_HUB_URL'
GSK_API_KEY = 'GSK_API_KEY'
HF_SPACE_HOST = 'SPACE_HOST'
HF_SPACE_TOKEN = 'GSK_HUB_HFS'
# Parsed to a real bool so that READONLY=false / READONLY=0 do not enable
# read-only mode merely because the string is non-empty.
READONLY = parse_bool_flag(os.environ.get("READONLY"))

# File that receives the worker subprocess output and is shown in the UI.
LOG_FILE = "output.log"

def read_logs():
    """Return the full text of the worker log file.

    Standard output of this process is flushed first. If the log cannot be
    opened or read for any reason, a placeholder message is returned instead
    of raising.
    """
    sys.stdout.flush()
    try:
        with open(LOG_FILE, "r") as log_handle:
            contents = log_handle.read()
    except Exception:
        # Best effort: any failure is reported to the UI as "not running".
        return "ML worker not running"
    else:
        return contents

def detect_gpu():
    """Log whether PyTorch and TensorFlow report an available GPU.

    Each framework is imported inside the function; when a framework is not
    installed, a warning is logged and the check for it is skipped. Returns
    nothing.
    """
    try:
        import torch
        logger.info(f"PyTorch GPU: {torch.cuda.is_available()}")
    except ImportError:
        # Logger.warn is a deprecated alias; warning() is the supported name.
        logger.warning("No PyTorch installed")

    try:
        import tensorflow as tf
        logger.info(f"Tensorflow GPU: {len(tf.config.list_physical_devices('GPU')) > 0}")
    except ImportError:
        logger.warning("No Tensorflow installed")

# Report GPU availability once, when the module is loaded.
detect_gpu()

# URL given to the most recent worker launch; empty until the first launch.
previous_url = ""
# Handle of the worker subprocess while one is tracked, otherwise None.
ml_worker = None

def read_status():
    """Describe the worker state as a short status line.

    Returns one of three messages: serving (a worker handle is set), exited
    (no handle, but a URL was used before), or not started (neither).
    """
    if ml_worker:
        return f"ML worker serving {previous_url}"
    if len(previous_url) == 0:
        # No handle and no URL recorded: nothing has been launched yet.
        return "ML worker not started"
    return f"ML worker exited for {previous_url}"


def run_ml_worker(url, api_key, hf_token):
    """Launch the Giskard ML worker and block until it exits.

    Records ``url`` as the most recent target, asks any already-running
    worker to stop, then starts ``giskard worker start`` with its stdout and
    stderr redirected to LOG_FILE. While the process runs, its handle is
    published in the module-level ``ml_worker``. The call blocks until the
    process exits, so it is meant to run on a background thread.

    Args:
        url: Giskard Hub URL passed as ``-u``.
        api_key: Giskard Hub API key passed as ``-k``.
        hf_token: Hugging Face Spaces token passed as ``-t``; the option is
            omitted when the token is None or empty.
    """
    global ml_worker, previous_url
    previous_url = url
    subprocess.run(["giskard", "worker", "stop"])

    command = ["giskard", "worker", "start", "-u", f"{url}", "-k", f"{api_key}"]
    if hf_token:
        # Formatting a missing token would pass the literal string "None".
        command += ["-t", f"{hf_token}"]

    # The context manager closes our copy of the log handle once the worker
    # has exited; the original never closed it.
    with open(LOG_FILE, "w") as log_file:
        process = subprocess.Popen(
            command,
            stdout=log_file, stderr=subprocess.STDOUT
        )
        ml_worker = process
        # Work on the local handle: stop_ml_worker() may reset the global to
        # None at any time from another thread.
        exit_code = process.wait()

    logging.info(f"Process {process.args[:3]} exited with {exit_code}")
    # Only clear the global if it still refers to this process, so a worker
    # started in the meantime is not forgotten.
    if ml_worker is process:
        ml_worker = None


def stop_ml_worker():
    """Terminate the tracked worker process, if there is one.

    Returns:
        "ML worker stopped" after a terminate signal was sent and the handle
        cleared, or "ML worker not started" when no worker handle is set.
    """
    global ml_worker
    if ml_worker is None:
        return "ML worker not started"

    logging.info(f"Stopping ML worker for {previous_url}")
    ml_worker.terminate()
    ml_worker = None
    outcome = "ML worker stopped"
    logging.info(outcome)
    return outcome


def start_ml_worker(url, api_key, hf_token):
    """Start the ML worker on a background thread and report the outcome.

    Args:
        url: Giskard Hub URL; must be non-empty.
        api_key: Giskard Hub API key forwarded to the worker.
        hf_token: Hugging Face Spaces token forwarded to the worker.

    Returns:
        A status message: a request for the URL when it is missing, a notice
        that a worker is still running, or a confirmation that a worker was
        launched for ``url``.
    """
    has_url = url and len(url) >= 1
    if not has_url:
        return "Please provide URL of Giskard"

    if ml_worker is not None:
        return f"ML worker is still running for {previous_url}"

    # Always run an external ML worker
    stop_ml_worker()

    logging.info(f"Starting ML worker for {url}")
    threading.Thread(
        target=run_ml_worker,
        args=(url, api_key, hf_token),
    ).start()
    return f"ML worker running for {url}"

theme = gr.themes.Soft(
    primary_hue="green",
)

# Control panel: connection settings on the left, status on the right,
# Run/Stop buttons below, and a periodically refreshed log view at the bottom.
with gr.Blocks(theme=theme) as iface:
    with gr.Row():
        with gr.Column():
            # Prefer the hub URL from the environment; otherwise fall back to
            # the host and port from the giskard settings.
            url = os.environ.get(GSK_HUB_URL) or f"http://{settings.host}:{settings.ws_port}"
            url_input = gr.Textbox(
                label="Giskard Hub URL",
                interactive=not READONLY,
                value=url,
            )
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(GSK_API_KEY),
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            )
            hf_token_input = gr.Textbox(
                label="Hugging Face Spaces Token",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(HF_SPACE_TOKEN),
                info="if using a private Giskard Hub on Hugging Face Spaces",
            )

        with gr.Column():
            output = gr.Textbox(label="Status")
            if READONLY:
                # Read-only instances explain themselves and offer duplication.
                gr.Textbox("You are browsing a read-only 🐢 Giskard ML worker instance. ", container=False)
                gr.Textbox("Please duplicate this space to configure your own Giskard ML worker.", container=False)
                gr.DuplicateButton(value="Duplicate Space for 🐢 Giskard ML worker", size='lg', variant="primary")

    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        run_btn.click(start_ml_worker, [url_input, api_key_input, hf_token_input], output)

        stop_btn = gr.Button("Stop", variant="stop", interactive=not READONLY)
        stop_btn.click(stop_ml_worker, None, output)

    logs = gr.Textbox(label="Giskard ML worker log:")
    # Refresh the log view every 0.5 s and the status line every 5 s.
    iface.load(read_logs, None, logs, every=0.5)
    iface.load(read_status, None, output, every=5)

# Auto-start the worker when both the hub URL and the API key are provided
# through the environment.
if all(os.environ.get(name) for name in (GSK_HUB_URL, GSK_API_KEY)):
    start_ml_worker(
        os.environ.get(GSK_HUB_URL),
        os.environ.get(GSK_API_KEY),
        os.environ.get(HF_SPACE_TOKEN),
    )

# Enable the request queue, then serve the interface.
iface.queue()
iface.launch()