File size: 3,649 Bytes
1acf699
fa0e40e
1acf699
 
 
 
 
 
 
 
 
1ff7cbe
 
 
1acf699
1ff7cbe
 
 
 
 
1acf699
1ff7cbe
 
1acf699
 
 
 
f26c118
 
1ff7cbe
 
f26c118
 
 
 
 
a9c34c4
1acf699
f26c118
1acf699
 
adc0e7a
 
 
 
 
8f69aac
adc0e7a
 
 
0a95580
 
 
 
 
 
8f69aac
0a95580
1ff7cbe
adc0e7a
 
 
47a9ca5
 
 
 
 
1acf699
 
 
 
 
0a95580
 
 
1acf699
 
 
1ff7cbe
 
0a95580
8b45707
1acf699
adc0e7a
 
0a95580
adc0e7a
865ea06
adc0e7a
47a9ca5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import streamlit as st
import os, io, ast #wget
import matplotlib.pyplot as plt
from PIL import Image
from genQC.pipeline.diffusion_pipeline import DiffusionPipeline
from genQC.inference.infer_srv import generate_srv_tensors, convert_tensors_to_srvs
from genQC.util import infer_torch_device

#--------------------------------
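# download model config/weights into local storage (disabled: load_pipeline() below now uses DiffusionPipeline.from_pretrained instead)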

#save_destination = "saves/"
#url_config  = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
#url_weights = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/model.pt"

#def download(url, dst_dir):
#    if not os.path.exists(dst_dir): os.mkdir(dst_dir)
#    filename = os.path.join(dst_dir, os.path.basename(url))
#    if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
#    return filename

#config_file  = download(url_config, save_destination)
#weigths_file = download(url_weights, save_destination)

#--------------------------------
# setup

@st.cache_resource
def load_pipeline():
    """Create the pretrained SRV diffusion pipeline used by the demo.

    Decorated with ``st.cache_resource`` so Streamlit builds it once and reuses
    the same object on every rerun instead of reloading the model each time.
    """
    # Earlier variant (kept for reference): load from a local save directory
    # on the inferred torch device.
    #pipeline = DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
    source = "Floki00/qc_srv_3to8qubit"
    device = "cpu"
    diffusion = DiffusionPipeline.from_pretrained(source, device)
    # NOTE(review): presumably the number of sampling/denoising steps used at
    # generation time — confirm against genQC's scheduler API.
    diffusion.scheduler.set_timesteps(20)
    return diffusion

pipeline = load_pipeline()


# NOTE(review): not read or written anywhere in this file; kept only so any
# external reference to the module attribute keeps working.
is_gpu_busy = False


def get_qcs(srv, num_of_qubits, max_gates, g, samples=6):
    """Generate circuits for a target SRV and plot them in a two-column grid.

    Progress is reported through an ``st.status`` panel while the pipeline runs.

    Args:
        srv: Target SRV as a list, e.g. ``[1, 1, 1, 2, 2, 2]``. It is embedded in
            the text prompt and compared with ``==`` against each generated
            circuit's SRV to label the subplot "Correct"/"NOT correct".
        num_of_qubits: Passed to the generator as both ``system_size`` and
            ``num_of_qubits``.
        max_gates: Maximum number of gates, passed to the generator.
        g: Guidance scale, passed to the generator.
        samples: Number of circuits to request (default 6). The grid has
            ``ceil(samples / 2)`` rows of two plots.

    Returns:
        The matplotlib Figure. Cells that did not receive a converted circuit
        show a placeholder message. The caller owns the figure and should
        ``plt.close`` it once it has been rendered.

    Raises:
        ValueError: if ``samples`` is smaller than 1.
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    with st.status("Generation started", expanded=True) as status:
        st.write("Generating tensors...")
        out_tensor = generate_srv_tensors(
            pipeline,
            f"Generate SRV: {srv}",
            samples=samples,
            system_size=num_of_qubits,
            num_of_qubits=num_of_qubits,
            max_gates=max_gates,
            g=g,
        )

        st.write("Converting to circuits...")
        qc_list, _, srv_list = convert_tensors_to_srvs(out_tensor, pipeline.gate_pool)

        st.write("Plotting...")
        n_rows = (samples + 1) // 2
        # squeeze=False keeps `axs` 2-D even for a single row, so flatten() always
        # yields an array of Axes. Height scales so 3 rows keep the original 10in.
        fig, axs = plt.subplots(
            n_rows, 2,
            figsize=(7, 10 * n_rows / 3),
            constrained_layout=True,
            dpi=120,
            squeeze=False,
        )
        cells = axs.flatten()

        # Pre-fill every requested cell with a placeholder; cells that receive a
        # circuit are cleared below. NOTE(review): assumes convert_tensors_to_srvs
        # may return fewer circuits than requested (failed decodes) — confirm.
        for ax in cells[:samples]:
            ax.axis('off')
            ax.text(0.5, 0.5, "Circuit generated with errors",
                    ha="center", va="center", transform=ax.transAxes)
        # An odd sample count leaves one unused cell in the last row; hide it.
        for ax in cells[samples:]:
            ax.set_visible(False)

        for qc, qc_srv, ax in zip(qc_list, srv_list, cells):
            ax.clear()
            qc.draw("mpl", plot_barriers=False, ax=ax)
            ax.set_title(f"{'Correct' if qc_srv==srv else 'NOT correct'}, is SRV = {qc_srv}")
        status.update(label="Generation complete!", state="complete", expanded=False)

    return fig

#--------------------------------
# run

st.title("genQC · Generative Quantum Circuits")
st.write("""
Generating quantum circuits with diffusion models. Official demo of [[paper-arxiv]](https://arxiv.org/abs/2311.02041) [[code-repo]](https://github.com/FlorianFuerrutter/genQC).
""")

col1, col2 = st.columns(2)

srv           = col1.text_input('SRV', "[1,1,1,2,2,2]")
num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=3)
max_gates     = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
g             = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=10.0)

# Parse the free-form SRV text. ast.literal_eval only evaluates Python literals
# (no code execution), but malformed input would otherwise abort the whole page
# with a traceback, and non-sequences (e.g. "5") would fail at len() below.
try:
    srv_list = ast.literal_eval(srv)
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    srv_list = None
if not (isinstance(srv_list, (list, tuple)) and all(isinstance(x, int) for x in srv_list)):
    st.error(f"Could not parse SRV {srv!r}: expected a list of integers such as [1,1,1,2,2,2].")
    st.stop()
# Normalise tuples to a list so the prompt text and the equality check in
# get_qcs behave the same as for list input.
srv_list = list(srv_list)

if len(srv_list)!=num_of_qubits:
    st.warning(f'Number of qubits does not match with given SRV {srv_list}. This could result in error-circuits!', icon="⚠️")

if col1.button('Generate circuits'):
    fig = get_qcs(srv_list, num_of_qubits, max_gates, g)
    col2.pyplot(fig)
    # pyplot keeps a global reference to every figure it creates. Release this one
    # after rendering so reruns don't accumulate figures in the server process.
    plt.close(fig)