Spaces: Build error
Upload weakly_supervised_parser/utils/populate_chart.py
Browse files
weakly_supervised_parser/utils/populate_chart.py
CHANGED
@@ -26,9 +26,9 @@ ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'm
|
|
26 |
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
|
27 |
ptb_most_common_first_token = "the"
|
28 |
|
29 |
-
|
30 |
|
31 |
-
|
32 |
|
33 |
|
34 |
class PopulateCKYChart:
|
@@ -54,20 +54,20 @@ class PopulateCKYChart:
|
|
54 |
|
55 |
if predict_type == "inside":
|
56 |
|
57 |
-
if data.shape[0] > chunks:
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
else:
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
|
72 |
data["inside_scores"] = inside_scores
|
73 |
data.loc[
|
|
|
26 |
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
|
27 |
ptb_most_common_first_token = "the"
|
28 |
|
29 |
+
from pytorch_lightning import Trainer
|
30 |
|
31 |
+
trainer = Trainer(accelerator="auto", enable_progress_bar=False, max_epochs=-1)
|
32 |
|
33 |
|
34 |
class PopulateCKYChart:
|
|
|
54 |
|
55 |
if predict_type == "inside":
|
56 |
|
57 |
+
# if data.shape[0] > chunks:
|
58 |
+
# data_chunks = np.array_split(data, data.shape[0] // chunks)
|
59 |
+
# for data_chunk in data_chunks:
|
60 |
+
# inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
|
61 |
+
# scale_axis=scale_axis,
|
62 |
+
# predict_batch_size=predict_batch_size)[:, 1])
|
63 |
+
# else:
|
64 |
+
# inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
|
65 |
+
# scale_axis=scale_axis,
|
66 |
+
# predict_batch_size=predict_batch_size)[:, 1])
|
67 |
|
68 |
+
test_dataloader = DataModule(model_name_or_path="roberta-base", train_df=None, eval_df=None,
|
69 |
+
test_df=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]])
|
70 |
+
inside_scores.extend(trainer.predict(model, dataloaders=test_dataloader)[0])
|
71 |
|
72 |
data["inside_scores"] = inside_scores
|
73 |
data.loc[
|