Spaces:
Runtime error
Runtime error
letrunglinh
commited on
Commit
•
dd58cce
1
Parent(s):
31f75a0
Update pairwise_model.py
Browse files- pairwise_model.py +8 -3
pairwise_model.py
CHANGED
@@ -8,7 +8,7 @@ from optimum.intel import OVModelForQuestionAnswering
|
|
8 |
import openvino.inference_engine as ie
|
9 |
import os
|
10 |
import gradio as gr
|
11 |
-
|
12 |
AUTH_TOKEN = "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"
|
13 |
|
14 |
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
@@ -35,7 +35,12 @@ class PairwiseModel_modify(nn.Module):
|
|
35 |
|
36 |
def forward(self, ids, masks):
|
37 |
# Export the model to ONNX format
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
39 |
# Specify the input shapes (batch_size, max_sequence_length)
|
40 |
input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
|
41 |
|
@@ -57,7 +62,7 @@ class PairwiseModel_modify(nn.Module):
|
|
57 |
tmp["question"] = question
|
58 |
valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
|
59 |
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
60 |
-
num_workers=
|
61 |
preds = []
|
62 |
with torch.no_grad():
|
63 |
bar = enumerate(valid_loader)
|
|
|
8 |
import openvino.inference_engine as ie
|
9 |
import os
|
10 |
import gradio as gr
|
11 |
+
from multiprocessing import cpu_count
|
12 |
AUTH_TOKEN = "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"
|
13 |
|
14 |
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
|
|
35 |
|
36 |
def forward(self, ids, masks):
|
37 |
# Export the model to ONNX format
|
38 |
+
ids_np = ids.cpu().numpy().astype(np.int64)
|
39 |
+
masks_np = masks.cpu().numpy().astype(np.int64)
|
40 |
+
ids_device = torch.from_numpy(ids_np).to(self.device)
|
41 |
+
masks_device = torch.from_numpy(masks_np).to(self.device)
|
42 |
+
|
43 |
+
input_feed = {"input_ids": ids_device, "attention_mask": masks_device}
|
44 |
# Specify the input shapes (batch_size, max_sequence_length)
|
45 |
input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
|
46 |
|
|
|
62 |
tmp["question"] = question
|
63 |
valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
|
64 |
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
65 |
+
num_workers=cpu_count(), shuffle=False, pin_memory=True)
|
66 |
preds = []
|
67 |
with torch.no_grad():
|
68 |
bar = enumerate(valid_loader)
|