Update app.py
Browse files
app.py
CHANGED
@@ -9,26 +9,26 @@ from genQC.util import infer_torch_device
|
|
9 |
#--------------------------------
|
10 |
# download model into storage
|
11 |
|
12 |
-
save_destination = "saves/"
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
filename = os.path.join(dst_dir, os.path.basename(url))
|
20 |
-
if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
|
21 |
-
return filename
|
22 |
-
|
23 |
-
config_file = download(url_config, save_destination)
|
24 |
-
weigths_file = download(url_weights, save_destination)
|
25 |
|
26 |
#--------------------------------
|
27 |
# setup
|
28 |
|
29 |
@st.cache_resource
|
30 |
def load_pipeline():
|
31 |
-
pipeline = DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
|
|
|
32 |
pipeline.scheduler.set_timesteps(20)
|
33 |
return pipeline
|
34 |
|
@@ -56,7 +56,7 @@ def get_qcs(srv, num_of_qubits, max_gates, g):
|
|
56 |
|
57 |
for qc,is_svr,ax in zip(qc_list, srv_list, axs.flatten()):
|
58 |
ax.clear()
|
59 |
-
qc.draw("mpl", plot_barriers=False, ax=ax
|
60 |
ax.set_title(f"{'Correct' if is_svr==srv else 'NOT correct'}, is SRV = {is_svr}")
|
61 |
status.update(label="Generation complete!", state="complete", expanded=False)
|
62 |
|
@@ -76,10 +76,10 @@ Generating quantum circuits with diffusion models. Official demo of [[paper-arxi
|
|
76 |
|
77 |
col1, col2 = st.columns(2)
|
78 |
|
79 |
-
srv = col1.text_input('SRV', "[1,1,1,2,2]")
|
80 |
-
num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=
|
81 |
max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
|
82 |
-
g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=
|
83 |
|
84 |
srv_list = ast.literal_eval(srv)
|
85 |
if len(srv_list)!=num_of_qubits:
|
|
|
9 |
#--------------------------------
|
10 |
# download model into storage
|
11 |
|
12 |
+
#save_destination = "saves/"
|
13 |
+
#url_config = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
|
14 |
+
#url_weights = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/model.pt"
|
15 |
|
16 |
+
#def download(url, dst_dir):
|
17 |
+
# if not os.path.exists(dst_dir): os.mkdir(dst_dir)
|
18 |
+
# filename = os.path.join(dst_dir, os.path.basename(url))
|
19 |
+
# if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
|
20 |
+
# return filename
|
21 |
|
22 |
+
#config_file = download(url_config, save_destination)
|
23 |
+
#weigths_file = download(url_weights, save_destination)
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
#--------------------------------
|
26 |
# setup
|
27 |
|
28 |
@st.cache_resource
|
29 |
def load_pipeline():
|
30 |
+
#pipeline = DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
|
31 |
+
pipeline = DiffusionPipeline.from_pretrained("Floki00/qc_srv_3to8qubit", "cpu")
|
32 |
pipeline.scheduler.set_timesteps(20)
|
33 |
return pipeline
|
34 |
|
|
|
56 |
|
57 |
for qc,is_svr,ax in zip(qc_list, srv_list, axs.flatten()):
|
58 |
ax.clear()
|
59 |
+
qc.draw("mpl", plot_barriers=False, ax=ax)
|
60 |
ax.set_title(f"{'Correct' if is_svr==srv else 'NOT correct'}, is SRV = {is_svr}")
|
61 |
status.update(label="Generation complete!", state="complete", expanded=False)
|
62 |
|
|
|
76 |
|
77 |
col1, col2 = st.columns(2)
|
78 |
|
79 |
+
srv = col1.text_input('SRV', "[1,1,1,2,2,2]")
|
80 |
+
num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=3)
|
81 |
max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
|
82 |
+
g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=10)
|
83 |
|
84 |
srv_list = ast.literal_eval(srv)
|
85 |
if len(srv_list)!=num_of_qubits:
|