ArneBinder committed
Commit 148e0d6
1 Parent(s): 86277c0

Upload 9 files

Files changed (2):
  1. app.py +20 -5
  2. document_store.py +219 -85
app.py CHANGED
@@ -134,7 +134,7 @@ def upload_processed_documents(
     file_name: str,
     document_store: DocumentStore,
 ) -> pd.DataFrame:
-    document_store.add_from_json(file_name)
+    document_store.add_documents_from_json(file_name)
     return document_store.overview()


@@ -175,7 +175,9 @@ def main():
     }

     with gr.Blocks() as demo:
-        document_store_state = gr.State(DocumentStore())
+        document_store_state = gr.State(
+            DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
+        )
         # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
         models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
         with gr.Row():
@@ -342,7 +344,10 @@ def main():
         )

         retrieve_relevant_adus_event_kwargs = dict(
-            fn=partial(DocumentStore.get_relevant_adus_df, columns=relevant_adus.headers),
+            fn=partial(
+                DocumentStore.get_related_annotations_from_other_documents_df,
+                columns=relevant_adus.headers,
+            ),
             inputs=[
                 document_store_state,
                 selected_adu_id,
@@ -355,13 +360,23 @@ def main():
         )

         selected_adu_id.change(
-            fn=partial(get_annotation_from_document, annotation_layer="labeled_spans"),
+            fn=partial(
+                get_annotation_from_document,
+                annotation_layer="labeled_spans",
+                use_predictions=True,
+            ),
             inputs=[document_state, selected_adu_id],
             outputs=[selected_adu_text],
         ).success(**retrieve_relevant_adus_event_kwargs)

         retrieve_similar_adus_btn.click(
-            fn=DocumentStore.get_similar_adus_df,
+            fn=lambda document_store, ann_id, document, min_sim, k: document_store.get_similar_annotations_df(
+                ref_annotation_id=ann_id,
+                ref_document=document,
+                min_similarity=min_sim,
+                top_k=k,
+                annotation_layer="labeled_spans",
+            ),
             inputs=[
                 document_store_state,
                 selected_adu_id,
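
For context, a minimal sketch of the renamed upload path, assuming a JSON file that maps document ids to serialized documents; the file name is a hypothetical example, not from this commit:

from document_store import DocumentStore

# Captions as configured in app.py above.
store = DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
# "processed_documents.json" is an illustrative file name.
store.add_documents_from_json("processed_documents.json")
# With the captions above, overview() yields the columns doc_id, num_adus, num_relations.
print(store.overview())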
document_store.py CHANGED
@@ -6,21 +6,45 @@ from typing import Dict, List, Optional, Tuple
 import gradio as gr
 import pandas as pd
 from annotation_utils import labeled_span_to_id
-from pytorch_ie.annotations import LabeledSpan
-from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from pytorch_ie import Annotation
+from pytorch_ie.documents import (
+    TextBasedDocument,
+    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+)
 from vector_store import SimpleVectorStore, VectorStore

 logger = logging.getLogger(__name__)


 def get_annotation_from_document(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    document: TextBasedDocument,
     annotation_id: str,
     annotation_layer: str,
-) -> LabeledSpan:
-    # use predictions
-    annotations = document[annotation_layer].predictions
-    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
+    use_predictions: bool,
+) -> Annotation:
+    """Get an annotation from a document by its id. Note that the annotation id is constructed
+    from the annotation itself, so it is unique within the document.
+
+    Args:
+        document: The document to get the annotation from.
+        annotation_id: The id of the annotation.
+        annotation_layer: The name of the annotation layer.
+        use_predictions: Whether to use the predictions of the annotation layer.
+
+    Returns:
+        The annotation with the given id.
+    """
+
+    annotations = document[annotation_layer]
+    if use_predictions:
+        annotations = annotations.predictions
+
+    if annotation_layer == "labeled_spans":
+        annotation_to_id_func = labeled_span_to_id
+    else:
+        raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")
+
+    id2annotation = {annotation_to_id_func(annotation): annotation for annotation in annotations}
     annotation = id2annotation.get(annotation_id)
     if annotation is None:
         raise gr.Error(
@@ -30,53 +54,165 @@ def get_annotation_from_document(
     return annotation


+def get_related_annotation_records_from_document(
+    document: TextBasedDocument,
+    reference_annotation: Annotation,
+    relation_layer_name: str,
+    use_predictions: bool,
+    annotation_caption: str,
+    relation_types: Optional[List[str]] = None,
+    additional_static_columns: Optional[Dict[str, str]] = None,
+) -> List[Dict[str, str]]:
+    """Get related annotations from a document for a given reference annotation. The related
+    annotations are all annotations that are targets (tails) of relations with the reference
+    annotation as source (head).
+
+    Args:
+        document: The document to get the related annotations from.
+        reference_annotation: The reference annotation. Should be an annotation from the document.
+        relation_layer_name: The name of the relation layer.
+        use_predictions: Whether to use the predictions of the relation layer.
+        annotation_caption: The caption for the related annotations in the result.
+        relation_types: The types of relations to consider. If None, all relation types are
+            considered.
+        additional_static_columns: Additional static columns to add to the result.
+
+    Returns:
+        A list of dictionaries with the related annotations and additional columns.
+    """
+
+    result = []
+
+    # get the relation layer
+    relation_layer = document[relation_layer_name]
+    if use_predictions:
+        relation_layer = relation_layer.predictions
+
+    # create helper dictionaries to quickly find related annotations
+    tail2rels = defaultdict(list)
+    head2rels = defaultdict(list)
+    for rel in relation_layer:
+        # skip non-argumentative relations
+        if relation_types is not None and rel.label not in relation_types:
+            continue
+        head2rels[rel.head].append(rel)
+        tail2rels[rel.tail].append(rel)
+
+    # get the related annotations: all annotations that are targets (tails) of relations with
+    # the reference annotation as source (head)
+    for rel in head2rels.get(reference_annotation, []):
+        result.append(
+            {
+                "doc_id": document.id,
+                f"reference_{annotation_caption}": str(reference_annotation),
+                "rel_score": rel.score,
+                "relation": rel.label,
+                annotation_caption: str(rel.tail),
+                **(additional_static_columns or {}),
+            }
+        )
+    return result


 class DocumentStore:
+    """A document store that allows adding, retrieving, and searching for documents and
+    annotations.
+
+    The store keeps the documents in memory and stores the embeddings of the labeled spans in a
+    vector store to efficiently retrieve similar or related spans.

-    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    Args:
+        vector_store: The vector store to use. If None, a new SimpleVectorStore is created.
+        document_type: The type of the documents to store. Should be a subclass of
+            TextBasedDocument with a span and a relation layer (see below).
+        span_layer_name: The name of the span annotation layer. This should be a valid annotation
+            layer of type LabeledSpan in the document type.
+        relation_layer_name: The name of the argumentative relation annotation layer. This should
+            be a valid annotation layer of type BinaryRelation in the document type.
+        span_annotation_caption: The caption for the span annotations (e.g. in the statistical
+            overview).
+        relation_annotation_caption: The caption for the relation annotations (e.g. in the
+            statistical overview).
+        use_predictions: Whether to use the predictions of the annotation layers. If True, the
+            predictions are used, otherwise the gold annotations are used.
+    """

-    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None):
+    def __init__(
+        self,
+        vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None,
+        document_type: type[
+            TextBasedDocument
+        ] = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        span_layer_name: str = "labeled_spans",
+        relation_layer_name: str = "binary_relations",
+        span_annotation_caption: str = "span",
+        relation_annotation_caption: str = "relation",
+        use_predictions: bool = True,
+    ):
         # The annotated documents. As key, we use the document id. All documents keep the embeddings
-        # of the ADUs in the metadata.
-        self.documents: Dict[
-            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-        ] = {}
-        # The vector store to efficiently retrieve similar ADUs. Can be constructed from the
+        # of the spans in the metadata.
+        self.documents: Dict[str, TextBasedDocument] = {}
+        # The vector store to efficiently retrieve similar spans. Can be constructed from the
         # documents.
         self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
             vector_store or SimpleVectorStore()
         )
+        # the document type (to create new documents from dicts)
+        self.document_type = document_type
+        self.span_layer_name = span_layer_name
+        self.relation_layer_name = relation_layer_name
+        self.use_predictions = use_predictions
+        self.layer_captions = {
+            self.span_layer_name: span_annotation_caption,
+            self.relation_layer_name: relation_annotation_caption,
+        }

     def get_annotation(
         self,
         doc_id: str,
         annotation_id: str,
         annotation_layer: str,
-    ) -> LabeledSpan:
+        use_predictions: bool,
+    ) -> Annotation:
         document = self.documents.get(doc_id)
         if document is None:
             raise gr.Error(
                 f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
             )
-        return get_annotation_from_document(document, annotation_id, annotation_layer)
+        return get_annotation_from_document(
+            document, annotation_id, annotation_layer, use_predictions=use_predictions
+        )

-    def get_similar_adus_df(
+    def get_similar_annotations_df(
         self,
         ref_annotation_id: str,
-        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-        min_similarity: float,
-        top_k: int,
+        ref_document: TextBasedDocument,
+        annotation_layer: str,
+        **similarity_kwargs,
     ) -> pd.DataFrame:
+        """Get similar annotations from documents in the store sorted by similarity. Usually, the
+        reference annotation is returned as the most similar annotation.
+
+        Args:
+            ref_annotation_id: The id of the reference annotation.
+            ref_document: The document of the reference annotation.
+            annotation_layer: The name of the annotation layer to consider.
+            **similarity_kwargs: Additional keyword arguments that will be passed to the vector
+                store to retrieve similar entries (see VectorStore.retrieve_similar()).
+
+        Returns:
+            A DataFrame with the similar annotations with columns: doc_id, annotation_id,
+                sim_score, and text.
+        """
+
         similar_entries = self.vector_store.retrieve_similar(
             ref_id=(ref_document.id, ref_annotation_id),
-            min_similarity=min_similarity,
-            top_k=top_k,
+            **similarity_kwargs,
         )

         similar_annotations = [
             self.get_annotation(
                 doc_id=doc_id,
                 annotation_id=annotation_id,
-                annotation_layer="labeled_spans",
+                annotation_layer=annotation_layer,
+                use_predictions=self.use_predictions,
             )
             for (doc_id, annotation_id), _ in similar_entries
         ]
@@ -89,20 +225,38 @@ class DocumentStore:
                     similar_entries, similar_annotations
                 )
             ],
-            columns=["doc_id", "adu_id", "sim_score", "text"],
+            columns=["doc_id", "annotation_id", "sim_score", "text"],
         )

         return df

-    def get_relevant_adus_df(
+    def get_related_annotations_from_other_documents_df(
         self,
         ref_annotation_id: str,
-        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        ref_document: TextBasedDocument,
         min_similarity: float,
         top_k: int,
         relation_types: List[str],
         columns: List[str],
     ) -> pd.DataFrame:
+        """Get related annotations from documents in the store for a given reference annotation.
+        First, similar annotations are retrieved from the vector store. Then, annotations that are
+        linked to them via relations are returned. Only annotations from other documents are
+        considered.
+
+        Args:
+            ref_annotation_id: The id of the reference annotation.
+            ref_document: The document of the reference annotation.
+            min_similarity: The minimum similarity score to consider.
+            top_k: The number of related annotations to return.
+            relation_types: The types of relations to consider.
+            columns: The columns to include in the result DataFrame.
+
+        Returns:
+            A DataFrame with the columns that contain: the related annotation, the relation type,
+                the similar annotation, the similarity score, and the relation score.
+        """
+
         similar_entries = self.vector_store.retrieve_similar(
             ref_id=(ref_document.id, ref_annotation_id),
             min_similarity=min_similarity,
@@ -114,41 +268,29 @@ class DocumentStore:
             if doc_id == ref_document.id:
                 continue
             document = self.documents[doc_id]
-            tail2rels = defaultdict(list)
-            head2rels = defaultdict(list)
-            for rel in document.binary_relations.predictions:
-                # skip non-argumentative relations
-                if rel.label not in relation_types:
-                    continue
-                head2rels[rel.head].append(rel)
-                tail2rels[rel.tail].append(rel)
-
-            id2annotation = {
-                labeled_span_to_id(annotation): annotation
-                for annotation in document.labeled_spans.predictions
-            }
-            annotation = id2annotation.get(annotation_id)
-            # note: we do not need to check if the annotation is different from the reference annotation,
-            # because they come from different documents and we already skip entries from the same document
-            for rel in head2rels.get(annotation, []):
-                result.append(
-                    {
-                        "doc_id": doc_id,
-                        "reference_adu": str(annotation),
-                        "sim_score": score,
-                        "rel_score": rel.score,
-                        "relation": rel.label,
-                        "adu": str(rel.tail),
-                    }
-                )
+            reference_annotation = get_annotation_from_document(
+                document=document,
+                annotation_id=annotation_id,
+                annotation_layer=self.span_layer_name,
+                use_predictions=self.use_predictions,
+            )
+
+            new_entries = get_related_annotation_records_from_document(
+                document=document,
+                reference_annotation=reference_annotation,
+                relation_types=relation_types,
+                relation_layer_name=self.relation_layer_name,
+                use_predictions=self.use_predictions,
+                annotation_caption=self.layer_captions[self.span_layer_name],
+                additional_static_columns={"sim_score": str(score)},
+            )
+            result.extend(new_entries)

         # define column order
         df = pd.DataFrame(result, columns=columns)
         return df

-    def add_document(
-        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ) -> None:
+    def add_document(self, document: TextBasedDocument) -> None:
         try:
             if document.id in self.documents:
                 gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
@@ -157,61 +299,53 @@ class DocumentStore:
             self.documents[document.id] = document

             # save the embeddings to the vector store
-            for adu_id, embedding in document.metadata["embeddings"].items():
-                self.vector_store.save((document.id, adu_id), embedding)
+            for annotation_id, embedding in document.metadata["embeddings"].items():
+                self.vector_store.save((document.id, annotation_id), embedding)

         except Exception as e:
             raise gr.Error(f"Failed to add document {document.id} to index: {e}")

     def add_document_from_dict(self, document_dict: dict) -> None:
-        document = self.DOCUMENT_TYPE.fromdict(document_dict)
+        document = self.document_type.fromdict(document_dict)
         # metadata is not automatically deserialized, so we need to set it manually
         document.metadata = document_dict["metadata"]
         self.add_document(document)

-    def add_documents(
-        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
-    ) -> None:
-        size_before = len(self.documents)
+    def add_documents(self, documents: List[TextBasedDocument]) -> None:
         for document in documents:
             self.add_document(document)
-        size_after = len(self.documents)
         gr.Info(
-            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
+            f"Added {len(documents)} documents to the index ({len(self.documents)} documents in total)."
         )

-    def add_from_json(self, file_path: str) -> None:
-        size_before = len(self.documents)
+    def add_documents_from_json(self, file_path: str) -> None:
         with open(file_path, "r", encoding="utf-8") as f:
-            processed_documents_json = json.load(f)
-        for _, document_json in processed_documents_json.items():
+            documents_json = json.load(f)
+        for _, document_json in documents_json.items():
             self.add_document_from_dict(document_dict=document_json)
-        size_after = len(self.documents)
         gr.Info(
-            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
+            f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
         )

     def save_to_json(self, file_path: str, **kwargs) -> None:
         with open(file_path, "w", encoding="utf-8") as f:
             json.dump(self.as_dict(), f, **kwargs)

-    def get_document(
-        self, doc_id: str
-    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+    def get_document(self, doc_id: str) -> TextBasedDocument:
         return self.documents[doc_id]

     def overview(self) -> pd.DataFrame:
-        df = pd.DataFrame(
-            [
-                (
-                    doc_id,
-                    len(document.labeled_spans.predictions),
-                    len(document.binary_relations.predictions),
-                )
-                for doc_id, document in self.documents.items()
-            ],
-            columns=["doc_id", "num_adus", "num_relations"],
-        )
+        rows = []
+        for doc_id, document in self.documents.items():
+            layers = {
+                caption: document[layer_name]
+                for layer_name, caption in self.layer_captions.items()
+            }
+            if self.use_predictions:
+                layers = {caption: layer.predictions for caption, layer in layers.items()}
+            layer_sizes = {f"num_{caption}s": len(layer) for caption, layer in layers.items()}
+            rows.append({"doc_id": doc_id, **layer_sizes})
+        df = pd.DataFrame(rows)
         return df

     def as_dict(self) -> dict:
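
For reference, a minimal usage sketch of the reworked retrieval API, assuming a store populated as above; the document id, annotation id, similarity threshold, and relation labels are illustrative values, not from this commit:

doc = store.get_document("doc-0")  # hypothetical document id
adu_id = "some-adu-id"  # hypothetical id as built by labeled_span_to_id
similar = store.get_similar_annotations_df(
    ref_annotation_id=adu_id,
    ref_document=doc,
    annotation_layer="labeled_spans",
    # forwarded to VectorStore.retrieve_similar() via **similarity_kwargs
    min_similarity=0.8,
    top_k=10,
)
related = store.get_related_annotations_from_other_documents_df(
    ref_annotation_id=adu_id,
    ref_document=doc,
    min_similarity=0.8,
    top_k=10,
    relation_types=["supports", "contradicts"],  # illustrative relation labels
    # with span_annotation_caption="adu", the record keys are named accordingly
    columns=["doc_id", "reference_adu", "sim_score", "rel_score", "relation", "adu"],
)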