File size: 3,587 Bytes
b1beb2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d48158
b1beb2e
 
 
 
 
1d48158
 
 
 
 
 
 
 
 
 
b1beb2e
 
 
1d48158
b1beb2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d64f08f
b1beb2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5442ad
98c0a6b
 
 
b1beb2e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import lcpfn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure

# LC-PFN model instance, constructed once at import time and reused by
# line_plot_fn for every prediction request.
model = lcpfn.LCPFN()

def line_plot_fn(data, cutoff, ci_form):
    """Extrapolate a (partial) learning curve with LC-PFN and plot the result.

    Args:
        data: ``pandas.DataFrame`` from the ``gr.Dataframe`` input. Row ``i``
            of its ``y`` column is the curve value at step ``t = i + 1``.
            Empty cells (``""`` or NaN) are allowed only as one trailing run,
            i.e. for a curve that is shorter than the table.
        cutoff: Number of leading observations passed to the model as
            context; every later step up to the table length is predicted.
        ci_form: Width of the predictive interval in percent (e.g. 90).

    Returns:
        A matplotlib ``Figure`` with the observed curve, the median
        extrapolation, the interval band and a dashed cutoff marker.

    Raises:
        gr.Error: If the curve, the cutoff or the interval is invalid, so the
            message is shown in the UI.
    """
    # Cleared Number / Dropdown components deliver None; int(None) would crash.
    if cutoff is None:
        raise gr.Error("Please enter a cutoff.")
    if ci_form is None:
        raise gr.Error("Please select a confidence interval.")
    cutoff = int(cutoff)
    ci = int(ci_form)

    # Total number of steps (table rows); predictions extend up to this step.
    horizon = len(data)

    # A cell counts as missing if it is empty text or NaN/None, since numeric
    # columns can deliver blanks either way.
    y_raw = data["y"]
    missing = (y_raw.isna() | (y_raw.astype(str) == "")).to_numpy()
    n_obs = horizon - int(missing.sum())

    # All missing cells must form one trailing run, which holds exactly when
    # the first n_obs cells are all present. A gap in the middle, leading
    # blanks or a fully empty table would misalign y with t.
    if n_obs == 0 or missing[:n_obs].any():
        raise gr.Error("Please enter a valid learning curve.")

    # cutoff < 1 would give the model no context and make y[cutoff - 1]
    # wrap around to the last observation.
    if cutoff < 1:
        raise gr.Error(f"Cutoff ({cutoff}) must be at least 1.")
    if n_obs < cutoff:
        raise gr.Error(f"Cutoff ({cutoff}) cannot be greater than the number of data points ({n_obs}).")

    try:
        observed = y_raw.iloc[:n_obs].astype(float).to_numpy()
    except (TypeError, ValueError) as err:
        raise gr.Error("Please enter a valid learning curve.") from err

    steps = np.arange(1, horizon + 1)
    x = torch.arange(1, horizon + 1).unsqueeze(1)
    y = torch.from_numpy(observed).float().unsqueeze(1)

    # Symmetric interval: half of the excluded probability mass in each tail.
    tail = (1 - ci / 100) / 2
    try:
        predictions = model.predict_quantiles(
            x_train=x[:cutoff],
            y_train=y[:cutoff],
            x_test=x[cutoff:],
            qs=[tail, 0.5, 1 - tail],
        )
    except ValueError as err:
        # Surface model-side ValueErrors (presumably input validation such as
        # out-of-range values) in the UI instead of a generic failure.
        raise gr.Error(str(err)) from err

    # Columns follow the order of qs: lower quantile, median, upper quantile.
    predictions = predictions.numpy()
    # Prepend the last context value to all three columns so the median line
    # and the band start exactly where the observed context ends.
    anchor = np.full((1, 3), y[cutoff - 1].item())
    predictions = np.concatenate([anchor, predictions], axis=0)

    # Build the figure without pyplot. Figures made via plt.subplots() are
    # registered globally and were never closed here, so they accumulated
    # across requests. Pyplot's global state is also not thread-safe.
    fig = Figure()
    ax = fig.subplots()

    ax.plot(steps[:n_obs], observed, "black", label="target")

    # plot extrapolation
    ax.plot(steps[cutoff - 1:], predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        steps[cutoff - 1:], predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {ci}%"
    )

    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")
    ax.set_ylim(0, 1)
    ax.set_xlim(0, horizon)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")

    return fig

def _sample_example_curve(length=50):
    """Return one synthetic learning curve from the LC-PFN prior as a DataFrame.

    The frame has columns ``t`` (steps ``1..n``) and ``y`` (curve values) and
    holds at most ``length`` rows.
    """
    prior = lcpfn.sample_from_prior(np.random)
    curve, alt_curve = prior()
    # The prior yields two curves; show either one with equal probability.
    # NOTE(review): presumably one is the noisy and the other the noise-free
    # curve -- confirm against lcpfn.sample_from_prior.
    if np.random.rand() < 0.5:
        curve = alt_curve
    values = np.asarray(curve[:length])
    return pd.DataFrame({"t": np.arange(1, len(values) + 1), "y": values})


# 14 random example curves for gr.Examples, each stored as [curve_frame, 10].
# NOTE(review): gr.Examples below only lists `dataform` as an input, so the
# trailing 10 (presumably a default cutoff) appears to be unused -- confirm.
examples = [[_sample_example_curve(), 10] for _ in range(14)]

# NOTE(review): this Column and its two Number inputs are created outside the
# `gr.Blocks()` context, and `components` is not referenced anywhere else in
# this file. They therefore do not appear to be part of the rendered UI;
# looks like leftover scaffolding -- confirm and remove.
with gr.Column() as components:
    gr.Number(value=10)
    gr.Number(value=10)

# UI: editable curve table plus cutoff / interval controls on the left, the
# extrapolation plot on the right, and clickable example curves below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            dataform = gr.Dataframe(
                value=examples[0][0],
                headers=["t", "y"],
                datatype=["number", "number"],
                row_count=(50, "fixed"),
                col_count=(2, "fixed"),
                type="pandas",
            )
            with gr.Row():
                # precision=0 makes the component deliver an integer cutoff.
                cutoffform = gr.Number(label="cutoff", value=10, precision=0)
                ci_form = gr.Dropdown(
                    label="Confidence Interval",
                    choices=[("90%", 90), ("95%", 95), ("99%", 99)],
                    value=90,
                )
        outputform = gr.Plot()
    gr.Examples(
        examples,
        inputs=[dataform],
        label="Examples of synthetic learning curves",
        examples_per_page=len(examples),
    )

    # Re-plot whenever any of the three inputs changes.
    plot_inputs = [dataform, cutoffform, ci_form]
    for trigger in plot_inputs:
        trigger.change(fn=line_plot_fn, inputs=plot_inputs, outputs=outputform)
    # Also plot the initial values once on page load; the .change listeners
    # are not expected to fire for a component's initial value.
    demo.load(fn=line_plot_fn, inputs=plot_inputs, outputs=outputform)


if __name__ == "__main__":
    demo.launch()