# File size: 3,346 Bytes
# a7f6bfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import streamlit as st
from persist import persist, load_widget_state
from modelcards import CardData, ModelCard
from huggingface_hub import create_repo


def is_float(value: object) -> bool:
    """Return True if ``float(value)`` succeeds, False otherwise.

    Only the errors ``float()`` raises for unconvertible input are treated as
    "not a float"; interrupts such as ``KeyboardInterrupt`` propagate.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None; ValueError: unparsable string;
        # OverflowError: an int too large to represent as a float.
        return False
    return True

def _split_csv(raw):
    """Split a comma-separated string into stripped, non-empty items.

    ``None`` or an empty string yields an empty list.
    """
    return [item.strip() for item in (raw or "").split(",") if item.strip()]


def _parse_emissions(raw):
    """Return *raw* as a finite float, or ``None`` if it is missing or invalid.

    Values such as ``"nan"`` or ``"inf"`` parse as floats but are rejected,
    since they are not meaningful CO2 amounts for the card metadata.
    """
    import math

    if not is_float(raw):
        return None
    value = float(raw)
    return value if math.isfinite(value) else None


def get_card():
    """Build a ModelCard from the form values stored in ``st.session_state``.

    If a required field (languages, license) is empty, an error listing the
    missing fields is shown and ``st.stop()`` is called.

    Returns:
        ModelCard: the card rendered from ``template.md``.
    """
    state = st.session_state
    languages = state.languages or None
    license_ = state.license or None  # trailing underscore: don't shadow the builtin
    library_name = state.library_name or None
    tags = _split_csv(state.tags)
    if "autogenerated-modelcard" not in tags:
        tags.append("autogenerated-modelcard")
    datasets = _split_csv(state.datasets) or None
    metrics = state.metrics or None
    model_name = state.model_name or None
    model_description = state.model_description or None
    authors = state.authors or None
    paper_url = state.paper_url or None
    github_url = state.github_url or None
    bibtex_citations = state.bibtex_citations or None
    emissions = _parse_emissions(state.emissions)

    # Refuse to render a card while required metadata is missing.
    missing = [
        label
        for label, value in (("Languages", languages), ("License", license_))
        if not value
    ]
    if missing:
        st.error(
            "Warning: The following fields are required but have not been filled in: "
            + "".join(f"\n- {label}" for label in missing)
        )
        st.stop()

    # Generate and display card
    card_data = CardData(
        language=languages,
        license=license_,
        library_name=library_name,
        tags=tags,
        datasets=datasets,
        metrics=metrics,
    )
    # `is not None` so an explicit 0 is recorded in the metadata too; the
    # template below receives the same value either way.
    if emissions is not None:
        card_data.co2_eq_emissions = {'emissions': emissions}

    card = ModelCard.from_template(
        card_data,
        template_path='template.md',
        model_id=model_name,
        # Template kwargs:
        model_description=model_description,
        license=license_,
        authors=authors,
        paper_url=paper_url,
        github_url=github_url,
        bibtex_citations=bibtex_citations,
        emissions=emissions
    )
    return card


def main():
    """Render the generated model card and offer a Hub upload form.

    The card is saved to ``current_card.md`` on every run and shown as raw
    text or rendered markdown, depending on the "View Raw" sidebar checkbox.
    Submitting the sidebar form creates the repo (if needed) and pushes the
    card using the token the user supplied.
    """
    card = get_card()
    card.save('current_card.md')
    view_raw = st.sidebar.checkbox("View Raw")
    if view_raw:
        st.text(card)
    else:
        # NOTE(review): unsafe_allow_html renders any HTML the user typed into
        # the form fields; acceptable only while the viewer is the author.
        st.markdown(card.text, unsafe_allow_html=True)

    with st.sidebar:
        with st.form("Upload to 🤗 Hub"):
            st.markdown("Use a token with write access from [here](https://hf.co/settings/tokens)")
            token = st.text_input("Token", type='password')
            repo_id = st.text_input("Repo ID")
            submit = st.form_submit_button('Upload to 🤗 Hub')

        if submit:
            repo_id = repo_id.strip()
            token = token.strip()
            owner_and_name = repo_id.split('/')
            # Both halves must be non-empty: "user/" or "/repo" are not valid ids.
            if len(owner_and_name) != 2 or not all(owner_and_name):
                st.error("Repo ID invalid. It should be username/repo-name. For example: nateraw/food")
            elif not token:
                # Require the user's own token explicitly. Presumably an empty
                # token could let huggingface_hub fall back to credentials
                # stored on the host — confirm for the installed version.
                st.error("A token with write access is required to upload.")
            else:
                repo_url = create_repo(repo_id, exist_ok=True, token=token)
                card.push_to_hub(repo_id, token=token)
                st.success(f"Pushed the card to the repo [here]({repo_url})!")


if __name__ == "__main__":
    # Order matters: main() -> get_card() reads the form values from
    # st.session_state. NOTE(review): load_widget_state (from the local
    # `persist` module) presumably restores those values saved on the form
    # page, so it must run first — confirm against persist.py.
    load_widget_state()
    main()