BanUrsus committed
Commit 8e41ed9 · verified · 1 Parent(s): 8485e4f

Update app.py


Add Counting, PostingList, and InvertedIndex

Files changed (1)
  1. app.py +134 -0
app.py CHANGED
@@ -7,6 +7,140 @@ import gradio as gr
  from typing import TypedDict
  from dataclasses import dataclass

+ from dataclasses import dataclass
+ import pickle
+ import os
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from collections import Counter
+ import tqdm
+ import re
+ import nltk
+ nltk.download("stopwords", quiet=True)
+ from nltk.corpus import stopwords as nltk_stopwords
+
+ LANGUAGE = "english"
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # We ignore lemmatization here for simplicity
+
+ def simple_tokenize(text: str) -> List[str]:
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
+     return tokenized
+
+ T = TypeVar("T", bound="InvertedIndex")
+
+ @dataclass
+ class PostingList:
+     term: str  # The term
+     docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
+     tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting
+
+
+ @dataclass
+ class InvertedIndex:
+     posting_lists: List[PostingList]  # tid -> posting_list
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = []
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(1)  # df starts at 1 for the first document containing this term
+         if store_raw:
+             doc_texts.append(doc.text)
+         else:
+             doc_texts = None
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
  sciq = load_sciq()
  counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
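
A minimal usage sketch of the pieces added in this commit (not part of the diff above): run_counting over a toy corpus, wrap the result in an InvertedIndex, and round-trip it through save()/from_saved(). It assumes Document from nlp4web_codebase.ir.data_loaders.dm can be constructed from collection_id and text, as the counting loop implies; the toy documents and the "toy_index" directory are illustrative only.

# Count term statistics over two hypothetical documents.
docs = [
    Document(collection_id="d1", text="Water boils at 100 degrees Celsius."),
    Document(collection_id="d2", text="Plants release oxygen during photosynthesis."),
]
toy_counting = run_counting(documents=iter(docs), ndocs=len(docs), show_progress_bar=False)

# Reuse the counting output to build an InvertedIndex and persist it.
index = InvertedIndex(
    posting_lists=toy_counting.posting_lists,
    vocab=toy_counting.vocab,
    cid2docid=toy_counting.cid2docid,
    collection_ids=toy_counting.collection_ids,
    doc_texts=toy_counting.doc_texts,
)
index.save("toy_index")
reloaded = InvertedIndex.from_saved("toy_index")
print(reloaded.vocab)    # e.g. {'water': 0, 'boils': 1, ...}
print(toy_counting.dfs)  # document frequency per term id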