pedramyazdipoor
committed on
Commit
•
ed4a795
1
Parent(s):
2f59c7f
Update README.md
Browse files
README.md
CHANGED
@@ -58,7 +58,7 @@ There are some considerations for inference:
|
|
58 |
3) The selected span must be the most probable choice among N pairs of candidates.
|
59 |
|
60 |
```python
|
61 |
-
def generate_indexes(start_logits, end_logits, N,
|
62 |
|
63 |
output_start = start_logits
|
64 |
output_end = end_logits
|
@@ -79,7 +79,7 @@ def generate_indexes(start_logits, end_logits, N, min_index_list):
|
|
79 |
for a in range(0,N):
|
80 |
for b in range(0,N):
|
81 |
if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
|
82 |
-
if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (
|
83 |
prob = sorted_start_list[a][1] + sorted_end_list[b][1]
|
84 |
start_idx = sorted_start_list[a][0]
|
85 |
end_idx = sorted_end_list[b][0]
|
@@ -104,7 +104,7 @@ encoding = tokenizer(text,question,add_special_tokens = True,
|
|
104 |
out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
|
105 |
#we had to change some pieces of code to make it compatible with one answer generation at a time
|
106 |
#If you have unanswerable questions, use out['start_logits'][0][0:] and out['end_logits'][0][0:] because <s> (the 1st token) is for this situation and must be compared with other tokens.
|
107 |
-
#you can initialize
|
108 |
answer_start_index, answer_end_index = generate_indexes(out['start_logits'][0][1:], out['end_logits'][0][1:], 5, 0)
|
109 |
print(tokenizer.tokenize(text + question))
|
110 |
print(tokenizer.tokenize(text + question)[answer_start_index : (answer_end_index + 1)])
|
|
|
58 |
3) The selected span must be the most probable choice among N pairs of candidates.
|
59 |
|
60 |
```python
|
61 |
+
def generate_indexes(start_logits, end_logits, N, max_index):
|
62 |
|
63 |
output_start = start_logits
|
64 |
output_end = end_logits
|
|
|
79 |
for a in range(0,N):
|
80 |
for b in range(0,N):
|
81 |
if (sorted_start_list[a][1] + sorted_end_list[b][1]) > prob :
|
82 |
+
if (sorted_start_list[a][0] <= sorted_end_list[b][0]) and (sorted_end_list[b][0] < max_index) :
|
83 |
prob = sorted_start_list[a][1] + sorted_end_list[b][1]
|
84 |
start_idx = sorted_start_list[a][0]
|
85 |
end_idx = sorted_end_list[b][0]
|
|
|
104 |
out = model(encoding['input_ids'].to(device),encoding['attention_mask'].to(device), encoding['token_type_ids'].to(device))
|
105 |
#we had to change some pieces of code to make it compatible with one answer generation at a time
|
106 |
#If you have unanswerable questions, use out['start_logits'][0][0:] and out['end_logits'][0][0:] because <s> (the 1st token) is for this situation and must be compared with other tokens.
|
107 |
+
#you can initialize max_index in generate_indexes() to force the chosen tokens to lie within the context (the end index must be less than the separator token's index).
|
108 |
answer_start_index, answer_end_index = generate_indexes(out['start_logits'][0][1:], out['end_logits'][0][1:], 5, 0)
|
109 |
print(tokenizer.tokenize(text + question))
|
110 |
print(tokenizer.tokenize(text + question)[answer_start_index : (answer_end_index + 1)])
|