Alexander Becker commited on
Commit
1df2a85
·
1 Parent(s): 43a3559

Print JAX device

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import gradio as gr
6
  from PIL import Image
7
  import numpy as np
 
8
 
9
  from gradio_dualvision import DualVisionApp
10
  from gradio_dualvision.gradio_patches.radio import Radio
@@ -14,6 +15,9 @@ from super_resolve import process
14
 
15
  REPO_ID = "prs-eth/thera-edsr-plus"
16
 
 
 
 
17
  # load model
18
  model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
19
  with open(model_path, 'rb') as fh:
 
5
  import gradio as gr
6
  from PIL import Image
7
  import numpy as np
8
+ import jax
9
 
10
  from gradio_dualvision import DualVisionApp
11
  from gradio_dualvision.gradio_patches.radio import Radio
 
15
 
16
  REPO_ID = "prs-eth/thera-edsr-plus"
17
 
18
+ print(f"JAX devices: {jax.devices()}")
19
+ print(f"JAX device type: {jax.devices()[0].device_kind}")
20
+
21
  # load model
22
  model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
23
  with open(model_path, 'rb') as fh: