R-ai-den fabiochiu committed on
Commit
5a7a278
0 Parent(s):

Duplicate from fabiochiu/text-to-kb

Co-authored-by: Fabio Chiusano <fabiochiu@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ myvenv
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Text To Kb
+ emoji: 📉
+ colorFrom: pink
+ colorTo: blue
+ sdk: streamlit
+ sdk_version: 1.9.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: fabiochiu/text-to-kb
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
__pycache__/kb.cpython-37.pyc ADDED
Binary file (4.07 kB)

__pycache__/utils.cpython-37.pyc ADDED
Binary file (5.13 kB)

app.py ADDED
@@ -0,0 +1,197 @@
+ import streamlit as st
+ import streamlit.components.v1 as components
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import utils
+ from kb import KB
+
+ texts = {
+     "Napoleon": "Napoleon Bonaparte (born Napoleone di Buonaparte; 15 August 1769 – 5 May 1821), and later known by his regnal name Napoleon I, was a French military and political leader who rose to prominence during the French Revolution and led several successful campaigns during the Revolutionary Wars. He was the de facto leader of the French Republic as First Consul from 1799 to 1804. As Napoleon I, he was Emperor of the French from 1804 until 1814 and again in 1815. Napoleon's political and cultural legacy has endured, and he has been one of the most celebrated and controversial leaders in world history.",
+     "Kobe Bryant": "Kobe Bean Bryant (August 23, 1978 – January 26, 2020) was an American professional basketball player. A shooting guard, he spent his entire 20-year career with the Los Angeles Lakers in the National Basketball Association (NBA). Widely regarded as one of the greatest basketball players of all time, Bryant won five NBA championships, was an 18-time All-Star, a 15-time member of the All-NBA Team, a 12-time member of the All-Defensive Team, the 2008 NBA Most Valuable Player (MVP), and a two-time NBA Finals MVP. Bryant also led the NBA in scoring twice, and ranks fourth in league all-time regular season and postseason scoring. He was posthumously voted into the Naismith Memorial Basketball Hall of Fame in 2020 and named to the NBA 75th Anniversary Team in 2021.",
+     "Google": "Originally known as BackRub, Google is a search engine whose development was started in 1996 by Sergey Brin and Larry Page as a research project at Stanford University to find files on the Internet. Larry and Sergey later decided that the name of their search engine needed to change and chose Google, which is inspired by the term googol. The company is headquartered in Mountain View, California."
+ }
+
+ urls = {
+     "Crypto": "https://www.investopedia.com/terms/c/cryptocurrency.asp",
+     "Johnny Depp": "https://www.britannica.com/biography/Johnny-Depp",
+     "Rome": "https://www.timeout.com/rome/things-to-do/best-things-to-do-in-rome"
+ }
+
+ st.header("Extracting a Knowledge Base from text")
+
+ # sidebar
+ with st.sidebar:
+     st.markdown("_Read the accompanying article [Building a Knowledge Base from Texts: a Full Practical Example](https://medium.com/nlplanet/building-a-knowledge-base-from-texts-a-full-practical-example-8dbbffb912fa)_")
+     st.header("What is a Knowledge Base")
+     st.markdown("A [**Knowledge Base (KB)**](https://en.wikipedia.org/wiki/Knowledge_base) is information stored as structured data, ready to be used for analysis or inference. Usually a KB is stored as a graph (i.e. a [**Knowledge Graph**](https://www.ibm.com/cloud/learn/knowledge-graph)), where nodes are **entities** and edges are **relations** between entities.")
+     st.markdown("_For example, from the text \"Fabio lives in Italy\" we can extract the relation triplet <Fabio, lives in, Italy>, where \"Fabio\" and \"Italy\" are entities._")
+     st.header("How to build a Knowledge Graph")
+     st.markdown("To build a Knowledge Graph from text, we typically need to perform two steps:\n- Extract entities, a.k.a. **Named Entity Recognition (NER)**, i.e. the nodes.\n- Extract relations between the entities, a.k.a. **Relation Classification (RC)**, i.e. the edges.\n\nRecently, end-to-end approaches have been proposed to tackle both tasks simultaneously. This task is usually referred to as **Relation Extraction (RE)**. In this demo, an end-to-end model called [**REBEL**](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf), trained by [Babelscape](https://babelscape.com/), is used.")
+     st.header("How REBEL works")
+     st.markdown("REBEL is a **text2text** model obtained by fine-tuning [**BART**](https://huggingface.co/docs/transformers/model_doc/bart) to translate a raw input sentence containing entities and implicit relations into a set of triplets that explicitly refer to those relations. You can find [REBEL on the Hugging Face Hub](https://huggingface.co/Babelscape/rebel-large).")
+     st.header("Further steps")
+     st.markdown("Although this information is not visualized, the knowledge graph stores the provenance of each relation (e.g. the articles from which it was extracted, plus other metadata), along with Wikipedia data about each entity.")
+     st.markdown("Other libraries used:\n- [wikipedia](https://pypi.org/project/wikipedia/): for validating extracted entities by checking whether they have a corresponding Wikipedia page.\n- [newspaper](https://github.com/codelucas/newspaper): for parsing articles from URLs.\n- [pyvis](https://pyvis.readthedocs.io/en/latest/index.html): for graph visualizations.\n- [GoogleNews](https://github.com/Iceloof/GoogleNews): for reading the latest Google News articles about a topic.")
+     st.header("Considerations")
+     st.markdown("If you look closely at the extracted knowledge graphs, some of the extracted relations are false. Indeed, relation extraction models are still far from perfect and require further steps in the pipeline to build reliable knowledge graphs. Consider this demo a starting point!")
+
+ # Loading the model
+ st_model_load = st.text('Loading the relation extraction model... It may take a while.')
+
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     print("Loading model...")
+     tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
+     model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+     print("Model loaded!")
+     return tokenizer, model
+
+ tokenizer, model = load_model()
+ st.success('Model loaded!')
+ st_model_load.text("")
+
+ # Choose from where to generate the KB
+ options = [
+     "Text",
+     "Article at URL",
+     "Multiple news articles"
+ ]
+ if 'option' not in st.session_state:
+     st.session_state.option = options[0]
+ option = st.selectbox('Build a Knowledge Base from:', options, index=options.index(st.session_state.option))
+
+ text_option, text = None, None
+ url_option, url = None, None
+ news_option = None
+
+ if option == "Text":
+     text_options = [
+         "Napoleon",
+         "Kobe Bryant",
+         "Google",
+         "Free text"
+     ]
+     if 'text_option' not in st.session_state or st.session_state.text_option is None:
+         st.session_state.text_option = text_options[0]
+     text_option = st.selectbox('Choose text option:', text_options, index=text_options.index(st.session_state.text_option))
+
+     disabled = False
+     if text_option != "Free text":
+         disabled = True
+         text = texts[text_option]
+     else:
+         if 'text' not in st.session_state:
+             st.session_state.text = ""
+         text = st.session_state.text
+     text = st.text_area('Text:', value=text, height=300, disabled=disabled, max_chars=10000)
+ elif option == "Article at URL":
+     url_options = [
+         "Crypto",
+         "Johnny Depp",
+         "Rome",
+         "Free URL"
+     ]
+     if 'url_option' not in st.session_state or st.session_state.url_option is None:
+         st.session_state.url_option = url_options[0]
+     url_option = st.selectbox('Choose URL option:', url_options, index=url_options.index(st.session_state.url_option))
+
+     disabled = False
+     if url_option != "Free URL":
+         disabled = True
+         url = urls[url_option]
+     else:
+         if 'url' not in st.session_state:
+             st.session_state.url = ""
+         url = st.session_state.url
+     url = st.text_input('URL:', value=url, disabled=disabled)
+ else:
+     news_options = [
+         "Google",
+         "Apple",
+         "Elon Musk",
+         "Kobe Bryant"
+     ]
+     if 'news_option' not in st.session_state or st.session_state.news_option is None:
+         st.session_state.news_option = news_options[0]
+     news_option = st.selectbox('Use articles about:', news_options, index=news_options.index(st.session_state.news_option))
+
+ def generate_kb():
+     st.session_state.option = option
+     st.session_state.text_option = text_option
+     st.session_state.text = text
+     st.session_state.url_option = url_option
+     st.session_state.url = url
+     st.session_state.news_option = news_option
+
+     # load the correct kb (precomputed options are loaded from pickles)
+     if option == "Text":
+         if text_option == "Napoleon":
+             kb = utils.load_kb("networks/network_1_napoleon.p")
+         elif text_option == "Kobe Bryant":
+             kb = utils.load_kb("networks/network_1_bryant.p")
+         elif text_option == "Google":
+             kb = utils.load_kb("networks/network_1_google.p")
+         else:
+             kb = utils.from_text_to_kb(text, model, tokenizer, "", verbose=True)
+     elif option == "Article at URL":
+         if url_option == "Crypto":
+             kb = utils.load_kb("networks/network_2_crypto.p")
+         elif url_option == "Johnny Depp":
+             kb = utils.load_kb("networks/network_2_depp.p")
+         elif url_option == "Rome":
+             kb = utils.load_kb("networks/network_2_rome.p")
+         else:
+             try:
+                 kb = utils.from_url_to_kb(url, model, tokenizer)
+             except Exception as e:
+                 print("Couldn't extract article from URL:", e)
+                 st.session_state.error_url = "Couldn't extract article from URL"
+                 return
+     else:
+         if news_option == "Google":
+             kb = utils.load_kb("networks/network_3_google.p")
+         elif news_option == "Apple":
+             kb = utils.load_kb("networks/network_3_apple.p")
+         elif news_option == "Elon Musk":
+             kb = utils.load_kb("networks/network_3_musk.p")
+         elif news_option == "Kobe Bryant":
+             kb = utils.load_kb("networks/network_3_bryant.p")
+
+     # save chart
+     utils.save_network_html(kb, filename="networks/network.html")
+     st.session_state.kb_chart = "networks/network.html"
+     st.session_state.kb_text = kb.get_textual_representation()
+     st.session_state.error_url = None
+
+
+ st.session_state.option = option
+ st.session_state.text_option = text_option
+ st.session_state.text = text
+ st.session_state.url_option = url_option
+ st.session_state.url = url
+ st.session_state.news_option = news_option
+
+ button_text = "Show KB"
+ if (option == "Text" and text_option == "Free text") or (option == "Article at URL" and url_option == "Free URL"):
+     button_text = "Generate KB"
+
+ # generate KB button
+ st.button(button_text, on_click=generate_kb)
+
+ # kb chart session state
+ if 'kb_chart' not in st.session_state:
+     st.session_state.kb_chart = None
+ if 'kb_text' not in st.session_state:
+     st.session_state.kb_text = None
+ if 'error_url' not in st.session_state:
+     st.session_state.error_url = None
+
+ # show graph
+ if st.session_state.error_url:
+     st.markdown(st.session_state.error_url)
+ elif st.session_state.kb_chart:
+     with st.container():
+         st.subheader("Generated KB")
+         st.markdown("*You can interact with the graph and zoom.*")
+         html_source_code = open(st.session_state.kb_chart, 'r', encoding='utf-8').read()
+         components.html(html_source_code, width=700, height=700)
+         st.markdown(st.session_state.kb_text)
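
For readers skimming the diff, here is a minimal sketch of the pipeline `app.py` wires together: load REBEL, generate its linearized triplets for one short sentence, and parse them with the helper from `utils.py`. The example sentence and the printed triplet are illustrative assumptions, not output captured from the app.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import extract_relations_from_model_output

tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

# a single short sentence fits in one span, so no span splitting is needed here
inputs = tokenizer(["Napoleon was the Emperor of the French."], return_tensors="pt")
generated = model.generate(**inputs, max_length=256, num_beams=3)

# keep special tokens: <triplet>, <subj> and <obj> delimit the triplets
decoded = tokenizer.batch_decode(generated, skip_special_tokens=False)[0]
print(extract_relations_from_model_output(decoded))
# e.g. [{'head': 'Napoleon', 'type': 'position held', 'tail': 'Emperor of the French'}]
```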
kb.py ADDED
@@ -0,0 +1,99 @@
+ import wikipedia
+
+ class KB:
+     def __init__(self):
+         self.entities = {}   # { entity_title: {...} }
+         self.relations = []  # [ head: entity_title, type: ..., tail: entity_title,
+                              #   meta: { article_url: { spans: [...] } } ]
+         self.sources = {}    # { article_url: {...} }
+
+     def merge_with_kb(self, kb2):
+         for r in kb2.relations:
+             article_url = list(r["meta"].keys())[0]
+             source_data = kb2.sources[article_url]
+             self.add_relation(r, source_data["article_title"],
+                               source_data["article_publish_date"])
+
+     def are_relations_equal(self, r1, r2):
+         return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
+
+     def exists_relation(self, r1):
+         return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
+
+     def merge_relations(self, r2):
+         r1 = [r for r in self.relations
+               if self.are_relations_equal(r2, r)][0]
+
+         # if the relation comes from a new article, add that article's metadata
+         article_url = list(r2["meta"].keys())[0]
+         if article_url not in r1["meta"]:
+             r1["meta"][article_url] = r2["meta"][article_url]
+
+         # if it comes from an article already in the KB, add only the new spans
+         else:
+             spans_to_add = [span for span in r2["meta"][article_url]["spans"]
+                             if span not in r1["meta"][article_url]["spans"]]
+             r1["meta"][article_url]["spans"] += spans_to_add
+
+     def get_wikipedia_data(self, candidate_entity):
+         try:
+             page = wikipedia.page(candidate_entity, auto_suggest=False)
+             entity_data = {
+                 "title": page.title,
+                 "url": page.url,
+                 "summary": page.summary
+             }
+             return entity_data
+         except Exception:  # e.g. page not found or disambiguation error
+             return None
+
+     def add_entity(self, e):
+         self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}
+
+     def add_relation(self, r, article_title, article_publish_date):
+         # check the entities on wikipedia
+         candidate_entities = [r["head"], r["tail"]]
+         entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
+
+         # if one entity does not exist, stop
+         if any(ent is None for ent in entities):
+             return
+
+         # manage new entities
+         for e in entities:
+             self.add_entity(e)
+
+         # rename relation entities with their wikipedia titles
+         r["head"] = entities[0]["title"]
+         r["tail"] = entities[1]["title"]
+
+         # add source if not in kb
+         article_url = list(r["meta"].keys())[0]
+         if article_url not in self.sources:
+             self.sources[article_url] = {
+                 "article_title": article_title,
+                 "article_publish_date": article_publish_date
+             }
+
+         # manage new relation
+         if not self.exists_relation(r):
+             self.relations.append(r)
+         else:
+             self.merge_relations(r)
+
+     def get_textual_representation(self):
+         res = ""
+         res += "### Entities\n"
+         for e in self.entities.items():
+             # shorten the summary
+             e_temp = (e[0], {k: (v[:100] + "..." if k == "summary" else v) for k, v in e[1].items()})
+             res += f"- {e_temp}\n"
+         res += "\n"
+         res += "### Relations\n"
+         for r in self.relations:
+             res += f"- {r}\n"
+         res += "\n"
+         res += "### Sources\n"
+         for s in self.sources.items():
+             res += f"- {s}\n"
+         return res
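
A hedged usage sketch of the `KB` class above: `add_relation` validates both entities against live Wikipedia, so running this requires network access, and the relation dict below (including the placeholder URL `https://example.com/article`) is hand-written in the shape that `utils.py` produces rather than taken from real model output.

```python
from kb import KB

kb = KB()
relation = {
    "head": "Napoleon",
    "type": "participant in",
    "tail": "French Revolution",
    # meta maps each source article URL to the token spans the relation came from
    "meta": {"https://example.com/article": {"spans": [[0, 128]]}},
}
kb.add_relation(relation, article_title="Example article", article_publish_date=None)

# entities are renamed to their Wikipedia titles, then listed with relations and sources
print(kb.get_textual_representation())
```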
networks/.DS_Store ADDED
Binary file (6.15 kB)

networks/network.html ADDED
@@ -0,0 +1,116 @@
+ <html>
+ <head>
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/vis-network@latest/styles/vis-network.css" type="text/css" />
+ <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/vis-network@latest/dist/vis-network.min.js"> </script>
+ <center>
+ <h1></h1>
+ </center>
+
+ <!-- <link rel="stylesheet" href="../node_modules/vis/dist/vis.min.css" type="text/css" />
+ <script type="text/javascript" src="../node_modules/vis/dist/vis.js"> </script>-->
+
+ <style type="text/css">
+
+ #mynetwork {
+     width: 700px;
+     height: 700px;
+     background-color: #ffffff;
+     border: 1px solid lightgray;
+     position: relative;
+     float: left;
+ }
+
+ </style>
+
+ </head>
+
+ <body>
+ <div id="mynetwork"></div>
+
+ <script type="text/javascript">
+
+ // initialize global variables.
+ var edges;
+ var nodes;
+ var network;
+ var container;
+ var options, data;
+
+ // This method is responsible for drawing the graph; it returns the drawn network
+ function drawGraph() {
+     var container = document.getElementById('mynetwork');
+
+     // parsing and collecting nodes and edges from the python
+     nodes = new vis.DataSet([{"color": "#00FF00", "id": "Napoleon", "label": "Napoleon", "shape": "circle"}, {"color": "#00FF00", "id": "French Revolution", "label": "French Revolution", "shape": "circle"}, {"color": "#00FF00", "id": "France", "label": "France", "shape": "circle"}]);
+     edges = new vis.DataSet([{"arrows": "to", "from": "Napoleon", "label": "participant in", "title": "participant in", "to": "French Revolution"}, {"arrows": "to", "from": "French Revolution", "label": "participant", "title": "participant", "to": "Napoleon"}, {"arrows": "to", "from": "French Revolution", "label": "country", "title": "country", "to": "France"}]);
+
+     // adding nodes and edges to the graph
+     data = {nodes: nodes, edges: edges};
+
+     var options = {
+         "configure": {
+             "enabled": false
+         },
+         "edges": {
+             "color": {
+                 "inherit": true
+             },
+             "smooth": {
+                 "enabled": true,
+                 "type": "dynamic"
+             }
+         },
+         "interaction": {
+             "dragNodes": true,
+             "hideEdgesOnDrag": false,
+             "hideNodesOnDrag": false
+         },
+         "physics": {
+             "enabled": true,
+             "repulsion": {
+                 "centralGravity": 0.2,
+                 "damping": 0.09,
+                 "nodeDistance": 200,
+                 "springConstant": 0.05,
+                 "springLength": 200
+             },
+             "solver": "repulsion",
+             "stabilization": {
+                 "enabled": true,
+                 "fit": true,
+                 "iterations": 1000,
+                 "onlyDynamicEdges": false,
+                 "updateInterval": 50
+             }
+         }
+     };
+
+     network = new vis.Network(container, data, options);
+
+     return network;
+ }
+
+ drawGraph();
+
+ </script>
+ </body>
+ </html>
networks/network_1_bryant.p ADDED
Binary file (20.9 kB)

networks/network_1_google.p ADDED
Binary file (11.2 kB)

networks/network_1_napoleon.p ADDED
Binary file (11.9 kB)

networks/network_2_crypto.p ADDED
Binary file (37.7 kB)

networks/network_2_depp.p ADDED
Binary file (7.83 kB)

networks/network_2_rome.p ADDED
Binary file (4.92 kB)

networks/network_3_amazon.p ADDED
Binary file (153 kB)

networks/network_3_apple.p ADDED
Binary file (227 kB)

networks/network_3_bryant.p ADDED
Binary file (185 kB)

networks/network_3_google.p ADDED
Binary file (190 kB)

networks/network_3_musk.p ADDED
Binary file (113 kB)

requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ transformers
+ pyvis
+ GoogleNews
+ newspaper3k
+ wikipedia
utils.py ADDED
@@ -0,0 +1,198 @@
+ from pyvis.network import Network
+ from GoogleNews import GoogleNews
+ from newspaper import Article, ArticleException
+ import math
+ import torch
+ from kb import KB
+ import pickle
+
+ def extract_relations_from_model_output(text):
+     # REBEL linearizes each triplet as: <triplet> subject <subj> object <obj> relation
+     relations = []
+     subject, relation, object_ = '', '', ''
+     text = text.strip()
+     current = 'x'
+     text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+     for token in text_replaced.split():
+         if token == "<triplet>":
+             current = 't'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+                 relation = ''
+             subject = ''
+         elif token == "<subj>":
+             current = 's'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+             object_ = ''
+         elif token == "<obj>":
+             current = 'o'
+             relation = ''
+         else:
+             if current == 't':
+                 subject += ' ' + token
+             elif current == 's':
+                 object_ += ' ' + token
+             elif current == 'o':
+                 relation += ' ' + token
+     if subject != '' and relation != '' and object_ != '':
+         relations.append({
+             'head': subject.strip(),
+             'type': relation.strip(),
+             'tail': object_.strip()
+         })
+     return relations
+
+ def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
+                     article_publish_date=None, verbose=False):
+     # tokenize the whole text
+     inputs = tokenizer([text], return_tensors="pt")
+
+     # compute span boundaries
+     num_tokens = len(inputs["input_ids"][0])
+     if verbose:
+         print(f"Input has {num_tokens} tokens")
+     num_spans = math.ceil(num_tokens / span_length)
+     if verbose:
+         print(f"Input has {num_spans} spans")
+     overlap = math.ceil((num_spans * span_length - num_tokens) /
+                         max(num_spans - 1, 1))
+     spans_boundaries = []
+     start = 0
+     for i in range(num_spans):
+         spans_boundaries.append([start + span_length * i,
+                                  start + span_length * (i + 1)])
+         start -= overlap
+     if verbose:
+         print(f"Span boundaries are {spans_boundaries}")
+
+     # transform input with spans
+     tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
+                   for boundary in spans_boundaries]
+     tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
+                     for boundary in spans_boundaries]
+     inputs = {
+         "input_ids": torch.stack(tensor_ids),
+         "attention_mask": torch.stack(tensor_masks)
+     }
+
+     # generate relations
+     num_return_sequences = 3
+     gen_kwargs = {
+         "max_length": 256,
+         "length_penalty": 0,
+         "num_beams": 3,
+         "num_return_sequences": num_return_sequences
+     }
+     generated_tokens = model.generate(
+         **inputs,
+         **gen_kwargs,
+     )
+
+     # decode relations
+     decoded_preds = tokenizer.batch_decode(generated_tokens,
+                                            skip_special_tokens=False)
+
+     # create kb
+     kb = KB()
+     i = 0
+     for sentence_pred in decoded_preds:
+         current_span_index = i // num_return_sequences
+         relations = extract_relations_from_model_output(sentence_pred)
+         for relation in relations:
+             relation["meta"] = {
+                 article_url: {
+                     "spans": [spans_boundaries[current_span_index]]
+                 }
+             }
+             kb.add_relation(relation, article_title, article_publish_date)
+         i += 1
+
+     return kb
+
+ def get_article(url):
+     article = Article(url)
+     article.download()
+     article.parse()
+     return article
+
+ def from_url_to_kb(url, model, tokenizer):
+     article = get_article(url)
+     config = {
+         "article_title": article.title,
+         "article_publish_date": article.publish_date
+     }
+     kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
+     return kb
+
+ def get_news_links(query, lang="en", region="US", pages=1):
+     googlenews = GoogleNews(lang=lang, region=region)
+     googlenews.search(query)
+     all_urls = []
+     for page in range(pages):
+         googlenews.get_page(page)
+         all_urls += googlenews.get_links()
+     return list(set(all_urls))
+
+ def from_urls_to_kb(urls, model, tokenizer, verbose=False):
+     kb = KB()
+     if verbose:
+         print(f"{len(urls)} links to visit")
+     for url in urls:
+         if verbose:
+             print(f"Visiting {url}...")
+         try:
+             kb_url = from_url_to_kb(url, model, tokenizer)
+             kb.merge_with_kb(kb_url)
+         except ArticleException:
+             if verbose:
+                 print(f"  Couldn't download article at url {url}")
+     return kb
+
+ def save_network_html(kb, filename="network.html"):
+     # create network
+     net = Network(directed=True, width="700px", height="700px")
+
+     # nodes
+     color_entity = "#00FF00"
+     for e in kb.entities:
+         net.add_node(e, shape="circle", color=color_entity)
+
+     # edges
+     for r in kb.relations:
+         net.add_edge(r["head"], r["tail"],
+                      title=r["type"], label=r["type"])
+
+     # save network
+     net.repulsion(
+         node_distance=200,
+         central_gravity=0.2,
+         spring_length=200,
+         spring_strength=0.05,
+         damping=0.09
+     )
+     net.set_edge_smooth('dynamic')
+     net.show(filename)
+
+ def save_kb(kb, filename):
+     with open(filename, "wb") as f:
+         pickle.dump(kb, f)
+
+ class CustomUnpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         if name == 'KB':
+             return KB
+         return super().find_class(module, name)
+
+ def load_kb(filename):
+     with open(filename, "rb") as f:
+         res = CustomUnpickler(f).load()
+     return res
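
The overlap arithmetic in `from_text_to_kb` is easiest to check with concrete numbers. The sketch below repeats the same formulas with an assumed input of 300 tokens and the default `span_length=128`: the text needs 3 spans, and consecutive spans must overlap by `ceil((3*128 - 300) / 2) = 42` tokens so that together they cover the input exactly.

```python
import math

# hypothetical input size; span_length matches the default in from_text_to_kb
num_tokens, span_length = 300, 128

num_spans = math.ceil(num_tokens / span_length)            # 3
overlap = math.ceil((num_spans * span_length - num_tokens)
                    / max(num_spans - 1, 1))               # 42

# same loop as in from_text_to_kb: each iteration shifts the window back by `overlap`
start = 0
spans_boundaries = []
for i in range(num_spans):
    spans_boundaries.append([start + span_length * i,
                             start + span_length * (i + 1)])
    start -= overlap

print(spans_boundaries)  # [[0, 128], [86, 214], [172, 300]]
```

Each span is exactly 128 tokens wide, the spans overlap by 42 tokens, and the last span ends precisely at token 300.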