strongpear commited on
Commit
6ac36ef
1 Parent(s): 7bbedb1

update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -74,10 +74,10 @@ class inferSSCL():
74
  self.base_models['embedding'].weight = torch.nn.Parameter(emb_para)
75
 
76
  self.base_models['asp_weight'] = torch.nn.Linear(emb_size, self.batch_data['n_aspects']).to(device)
77
- self.base_models['asp_weight'].load_state_dict(torch.load('./asp_weight.model'))
78
 
79
  self.base_models['attn_kernel'] = torch.nn.Linear(emb_size, emb_size).to(device)
80
- self.base_models['attn_kernel'].load_state_dict(torch.load('./attn_kernel.model'), strict=False)
81
 
82
 
83
  def build_pipe(self):
 
74
  self.base_models['embedding'].weight = torch.nn.Parameter(emb_para)
75
 
76
  self.base_models['asp_weight'] = torch.nn.Linear(emb_size, self.batch_data['n_aspects']).to(device)
77
+ self.base_models['asp_weight'].load_state_dict(torch.load('./asp_weight.model', map_location=torch.device('cpu')))
78
 
79
  self.base_models['attn_kernel'] = torch.nn.Linear(emb_size, emb_size).to(device)
80
+ self.base_models['attn_kernel'].load_state_dict(torch.load('./attn_kernel.model', map_location=torch.device('cpu')), strict=False)
81
 
82
 
83
  def build_pipe(self):