samusander commited on
Commit
c866f44
·
1 Parent(s): 0a0209c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py CHANGED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ from tqdm.auto import tqdm
4
+
5
+ from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
6
+ from point_e.diffusion.sampler import PointCloudSampler
7
+ from point_e.models.download import load_checkpoint
8
+ from point_e.models.configs import MODEL_CONFIGS, model_from_config
9
+ from point_e.util.plotting import plot_point_cloud
10
+ import streamlit as st
11
+
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ st.write('creating base model...')
14
+ base_name = 'base40M' # use base300M or base1B for better results
15
+ base_model = model_from_config(MODEL_CONFIGS[base_name], device)
16
+ base_model.eval()
17
+ base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
18
+ st.write('creating upsample model...')
19
+ upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
20
+ upsampler_model.eval()
21
+ upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])
22
+ st.write('downloading base checkpoint...')
23
+ base_model.load_state_dict(load_checkpoint(base_name, device))
24
+ st.write('downloading upsampler checkpoint...')
25
+ upsampler_model.load_state_dict(load_checkpoint('upsample', device))
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+