File size: 2,814 Bytes
92913ed
 
 
 
 
0704015
 
2edb029
92913ed
 
 
 
941c44c
2edb029
941c44c
 
2edb029
941c44c
 
2edb029
 
 
 
 
 
 
 
 
 
 
aab6fe4
941c44c
 
 
 
 
 
 
 
92913ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0704015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c997f1d
aab6fe4
 
 
c997f1d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import requests.exceptions
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import metadata_load

# Root Gradio Blocks container. Its layout and event wiring are declared in the
# `with app:` section at the bottom of the file, and it is started by app.launch().
app = gr.Blocks()

def load_agent(model_id_1, model_id_2):
    """
    Load the replay video and evaluation result of two agents for side-by-side comparison.

    :param model_id_1: Hub repo id of the first model (e.g. "sb3/a2c-AntBulletEnv-v0")
    :param model_id_2: Hub repo id of the second model
    :return: tuple ``(model_id_1, video_path_1, results_1, model_id_2, video_path_2, results_2)``
        where ``video_path_*`` is the local path of the repo's downloaded ``replay.mp4`` and
        ``results_*`` is the value returned by parse_metrics_accuracy (possibly None).
        The order matches the six output components wired to the "Compare models" button.
    """
    video_path_1, results_1 = _load_single_agent(model_id_1)
    video_path_2, results_2 = _load_single_agent(model_id_2)
    return model_id_1, video_path_1, results_1, model_id_2, video_path_2, results_2


def _load_single_agent(model_id):
    """
    Fetch one model's card result and replay video.

    :param model_id: Hub repo id of the model
    :return: ``(video_path, result)`` for that model
    """
    # Metadata first, then the video, preserving the original request order.
    metadata = get_metadata(model_id)
    result = parse_metrics_accuracy(metadata)
    video_path = hf_hub_download(model_id, filename="replay.mp4")
    return video_path, result

def parse_metrics_accuracy(meta):
    """
    Extract the first metric value from a model card's ``model-index`` metadata.

    :param meta: metadata dict as returned by get_metadata, or None
    :return: ``meta["model-index"][0]["results"][0]["metrics"][0]["value"]``,
        or None when the metadata is missing or does not have that shape
    """
    # get_metadata returns None when the README cannot be downloaded, and
    # ``"model-index" in None`` would raise TypeError, so guard first.
    if not meta or "model-index" not in meta:
        return None
    try:
        return meta["model-index"][0]["results"][0]["metrics"][0]["value"]
    except (KeyError, IndexError, TypeError):
        # Malformed or partially filled model-index: report "no result"
        # instead of crashing the whole comparison.
        return None

def get_metadata(model_id):
    """
    Get the model-card metadata of a model repo from its README.md.

    :param model_id: Hub repo id of the model
    :return: metadata as returned by metadata_load, or None if README.md could
        not be downloaded (HTTP error). NOTE(review): metadata_load may itself
        return None when the README has no metadata header - confirm against
        huggingface_hub docs.
    """
    try:
        readme_path = hf_hub_download(model_id, filename="README.md")
    except requests.exceptions.HTTPError:
        # Missing repo or README: treat as "no metadata" rather than failing the UI.
        return None
    return metadata_load(readme_path)




# Build the comparison UI: two model-id inputs, a compare button, and
# side-by-side name / video / score panels for each model.
with app:
    gr.Markdown(
    """
    # Compare Deep Reinforcement Learning Agents 🤖
    
    Type two models id you want to compare or check examples below.
    """)
    with gr.Row():
        model1_input = gr.Textbox(label="Model 1")
        model2_input = gr.Textbox(label="Model 2")
    with gr.Row():
        app_button = gr.Button("Compare models")
    with gr.Row():
        with gr.Column():
            model1_name = gr.Markdown()
            model1_video_output = gr.Video()
            model1_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
        with gr.Column():
            model2_name = gr.Markdown()
            model2_video_output = gr.Video()
            model2_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")

    # The outputs order must match the 6-tuple returned by load_agent.
    app_button.click(
        load_agent,
        inputs=[model1_input, model2_input],
        outputs=[
            model1_name,
            model1_video_output,
            model1_score_output,
            model2_name,
            model2_video_output,
            model2_score_output,
        ],
    )

    # Example pairs of model ids offered for the two input textboxes.
    # (The examples list was previously left unclosed, which was a SyntaxError.)
    examples = gr.Examples(
        examples=[
            ["sb3/a2c-AntBulletEnv-v0", "sb3/ppo-AntBulletEnv-v0"],
            ["ThomasSimonini/a2c-AntBulletEnv-v0", "sb3/a2c-AntBulletEnv-v0"],
            ["sb3/dqn-SpaceInvadersNoFrameskip-v4", "sb3/a2c-SpaceInvadersNoFrameskip-v4"],
            ["ThomasSimonini/ppo-QbertNoFrameskip-v4", "sb3/ppo-QbertNoFrameskip-v4"],
        ],
        inputs=[model1_input, model2_input],
    )


app.launch()