Commit ba0e651 by khaerens · 1 parent: 376bc0c
.vscode/settings.json CHANGED
@@ -1,7 +1,7 @@
 {
     "workbench.colorCustomizations": {
-        "activityBar.background": "#630018",
-        "titleBar.activeBackground": "#8A0121",
-        "titleBar.activeForeground": "#FFFBFC"
+        "activityBar.background": "#09323E",
+        "titleBar.activeBackground": "#0C4656",
+        "titleBar.activeForeground": "#F6FCFE"
     }
 }
__pycache__/app.cpython-38.pyc CHANGED
Binary files a/__pycache__/app.cpython-38.pyc and b/__pycache__/app.cpython-38.pyc differ
 
__pycache__/rebel.cpython-38.pyc CHANGED
Binary files a/__pycache__/rebel.cpython-38.pyc and b/__pycache__/rebel.cpython-38.pyc differ
 
app.py CHANGED
@@ -14,7 +14,7 @@ network_filename = "test.html"
 
 state_variables = {
     'has_run':False,
-    'wiki_suggestions': "",
+    'wiki_suggestions': [],
     'wiki_text' : [],
     'nodes':[]
 }
@@ -23,11 +23,10 @@ for k, v in state_variables.items():
     if k not in st.session_state:
         st.session_state[k] = v
 
-def clip_text(t, lenght = 5):
+def clip_text(t, lenght = 10):
     return ".".join(t.split(".")[:lenght]) + "."
 
 
-
 def generate_graph():
     if 'wiki_text' not in st.session_state:
         return
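Note on `clip_text` (the `lenght` spelling is reproduced as committed): it truncates a summary to the first `lenght` period-delimited chunks, so abbreviations like "U.S." count as sentence breaks. A minimal usage sketch:

```python
# Minimal sketch of clip_text as committed; `lenght` is verbatim.
def clip_text(t, lenght=10):
    return ".".join(t.split(".")[:lenght]) + "."

summary = "One. Two. Three. Four. Five."
print(clip_text(summary, lenght=3))  # -> "One. Two. Three."
```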
@@ -42,12 +41,14 @@ def generate_graph():
     st.success('Done!')
 
 def show_suggestion():
-    reset_session()
+    st.session_state['wiki_suggestions'] = []
     with st.spinner(text="fetching wiki topics..."):
         if st.session_state['input_method'] == "wikipedia":
             text = st.session_state.text
             if text is not None:
-                st.session_state['wiki_suggestions'] = wikipedia.search(text, results = 3)
+                subjects = text.split(",")
+                for subj in subjects:
+                    st.session_state['wiki_suggestions'] += wikipedia.search(subj, results = 3)
 
 def show_wiki_text(page_title):
     with st.spinner(text="fetching wiki page..."):
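The hunk above replaces the single-term lookup with a comma-separated multi-subject search. A self-contained sketch of the same pattern, assuming the `wikipedia` PyPI package; the `.strip()` is an addition the committed code does not perform:

```python
import wikipedia

def suggest(query: str, per_subject: int = 3) -> list:
    """Collect up to `per_subject` page suggestions per comma-separated term."""
    suggestions = []
    for subject in query.split(","):
        subject = subject.strip()  # the commit's bare split(",") keeps spaces
        if subject:
            suggestions += wikipedia.search(subject, results=per_subject)
    return suggestions

print(suggest("Napoleon, Waterloo"))
```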
@@ -64,7 +65,8 @@ def add_text(term):
     try:
         extra_text = clip_text(wikipedia.page(title=term, auto_suggest=True).summary)
         st.session_state['wiki_text'].append(extra_text)
-    except wikipedia.DisambiguationError as e:
+    except wikipedia.WikipediaException:
+        st.error("Woops, no wikipedia page for this node")
         st.session_state["nodes"].remove(term)
 
 def reset_session():
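Switching the except clause from `DisambiguationError` to `wikipedia.WikipediaException` broadens the handler: in the `wikipedia` package, `WikipediaException` is the base class of `DisambiguationError` and `PageError`, so missing pages are now caught as well. A sketch of the resulting behavior:

```python
import wikipedia

def safe_summary(term: str):
    """Return a page summary, or None on any wikipedia-package error."""
    try:
        return wikipedia.page(title=term, auto_suggest=True).summary
    except wikipedia.WikipediaException:
        return None  # DisambiguationError, PageError, HTTP errors, ...
```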
@@ -74,6 +76,17 @@ def reset_session():
 st.title('REBELious knowledge graph generation')
 st.session_state['input_method'] = "wikipedia"
 
+st.sidebar.markdown(
+    """
+    # how to
+    - Enter wikipedia search terms, separated by comma's
+    - Choose one or more of the suggested pages
+    - Click generate!
+    """
+)
+
+st.sidebar.button("Reset", on_click=reset_session, key="reset_key")
+
 # st.selectbox(
 #     'input method',
 #     ('wikipedia', 'free text'), key="input_method")
@@ -82,13 +95,25 @@ if st.session_state['input_method'] != "wikipedia":
     # st.text_area("Your text", key="text")
     pass
 else:
-    st.text_input("wikipedia search term",on_change=show_suggestion, key="text")
+    cols = st.columns([8, 1])
+    with cols[0]:
+        st.text_input("wikipedia search term", on_change=show_suggestion, key="text")
+    with cols[1]:
+        st.text('')
+        st.text('')
+        st.button("Search", on_click=show_suggestion, key="show_suggestion_key")
 
 if len(st.session_state['wiki_suggestions']) != 0:
-    columns = st.columns([1] * len(st.session_state['wiki_suggestions']))
-    for i, (c, s) in enumerate(zip(columns, st.session_state['wiki_suggestions'])):
-        with c:
-            st.button(s, on_click=show_wiki_text, args=(s,), key=str(i)+s)
+
+    num_cols = 10
+    num_buttons = len(st.session_state['wiki_suggestions'])
+    columns = st.columns([1] * num_cols + [1])
+    print(st.session_state['wiki_suggestions'])
+
+    for q in range(1 + num_buttons//num_cols):
+        for i, (c, s) in enumerate(zip(columns, st.session_state['wiki_suggestions'][q*num_cols: (q+1)*num_cols])):
+            with c:
+                st.button(s, on_click=show_wiki_text, args=(s,), key=str(i)+s)
 
 if len(st.session_state['wiki_text']) != 0:
     for i, t in enumerate(st.session_state['wiki_text']):
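The suggestion buttons now render in rows of ten. Two details worth flagging: `st.columns([1] * num_cols + [1])` creates eleven equal-width columns, so the last one never receives a button, and the `print` is debug output left in the commit. A trimmed sketch of the row-chunking pattern under those observations:

```python
# Row-chunked button grid, assuming Streamlit and placeholder data.
# Columns are created once and reused, so later rows stack beneath
# earlier ones inside each column.
import streamlit as st

suggestions = ["Paris", "Paris Hilton", "Paris, Texas"]  # placeholder data
NUM_COLS = 10

columns = st.columns(NUM_COLS)
for start in range(0, len(suggestions), NUM_COLS):
    for i, (col, label) in enumerate(zip(columns, suggestions[start:start + NUM_COLS])):
        with col:
            st.button(label, key=f"{start + i}-{label}")  # unique key per button
```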
@@ -102,17 +127,26 @@ if st.session_state['input_method'] != "wikipedia":
     # st.button("generate", on_click=generate_graph, key="gen_graph")
     pass
 else:
-    st.button("generate", on_click=generate_graph, key="gen_graph")
+    if len(st.session_state['wiki_text']) > 0:
+        st.button("Generate", on_click=generate_graph, key="gen_graph")
 
 
 if st.session_state['has_run']:
-    cols = st.columns([4, 1])
+    st.sidebar.markdown(
+        """
+        # How to expand the graph
+        - Click a button on the right to expand that node
+        - Only nodes that have wiki pages will be expanded
+        - Hit the Generate button again to expand your graph!
+        """
+    )
+
+    cols = st.columns([5, 1])
     with cols[0]:
         HtmlFile = open(network_filename, 'r', encoding='utf-8')
         source_code = HtmlFile.read()
         components.html(source_code, height=2000,width=2000)
     with cols[1]:
-        st.text("expand")
         for i,s in enumerate(st.session_state["nodes"]):
             st.button(s, on_click=add_text, args=(s,), key=s+str(i))
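For the graph pane, the app inlines the pyvis-generated HTML file with `components.html`. A compact sketch of that embed step, using a context manager instead of the bare `open` in the committed code:

```python
# Embed a pyvis-generated HTML network in Streamlit, as the hunk above does.
import streamlit.components.v1 as components

network_filename = "test.html"  # written earlier by pyvis
with open(network_filename, "r", encoding="utf-8") as f:
    components.html(f.read(), height=2000, width=2000)
```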
rebel.py CHANGED
@@ -30,7 +30,7 @@ DEFAULT_LABEL_COLORS = {
 
 def generate_knowledge_graph(texts: List[str], filename: str):
     nlp = spacy.load("en_core_web_sm")
-    doc = nlp("\n".join(texts))
+    doc = nlp("\n".join(texts).lower())
     NERs = [ent.text for ent in doc.ents]
     NER_types = [ent.label_ for ent in doc.ents]
     for nr, nrt in zip(NERs, NER_types):
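Lowercasing the joined text before NER makes differently-cased mentions collapse into one node, but spaCy's statistical NER uses capitalization as a feature, so recall on lowercased input can drop. A hedged alternative (not what the commit does) is to run NER on the original casing and lowercase only the node keys:

```python
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp("Napoleon was exiled to Elba.")  # original casing kept for NER
node_types = {ent.text.lower(): ent.label_ for ent in doc.ents}
print(node_types)  # e.g. {'napoleon': 'PERSON', 'elba': 'GPE'}
```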
@@ -40,8 +40,8 @@ def generate_knowledge_graph(texts: List[str], filename: str):
     for triplet in texts:
         triplets.extend(generate_partial_graph(triplet))
     print(generate_partial_graph.cache_info())
-    heads = [ t["head"] for t in triplets]
-    tails = [ t["tail"] for t in triplets]
+    heads = [ t["head"].lower() for t in triplets]
+    tails = [ t["tail"].lower() for t in triplets]
 
     nodes = set(heads + tails)
     net = Network(directed=True)
@@ -55,10 +55,10 @@ def generate_knowledge_graph(texts: List[str], filename: str):
         net.add_node(n, shape="circle")
 
     unique_triplets = set()
-    stringify_trip = lambda x : x["tail"] + x["head"] + x["type"]
+    stringify_trip = lambda x : x["tail"] + x["head"] + x["type"].lower()
     for triplet in triplets:
         if stringify_trip(triplet) not in unique_triplets:
-            net.add_edge(triplet["tail"], triplet["head"], title=triplet["type"], label=triplet["type"])
+            net.add_edge(triplet["head"].lower(), triplet["tail"].lower(), title=triplet["type"], label=triplet["type"])
             unique_triplets.add(stringify_trip(triplet))
 
     net.repulsion(
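Two notes on this hunk: the edge direction flips from tail→head to head→tail, and the dedup key concatenates tail + head + type, which can collide ('ab'+'c' equals 'a'+'bc'). A tuple key, sketched below as a suggestion rather than the committed code, cannot collide:

```python
# Collision-free triplet dedup keys (a suggestion, not the committed code).
def trip_key(t: dict) -> tuple:
    return (t["head"].lower(), t["tail"].lower(), t["type"].lower())

seen, unique = set(), []
for t in [{"head": "Napoleon", "tail": "Elba", "type": "exiled to"},
          {"head": "napoleon", "tail": "elba", "type": "exiled to"}]:
    k = trip_key(t)
    if k not in seen:
        seen.add(k)
        unique.append(t)
print(len(unique))  # 1: the case-insensitive duplicate is dropped
```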
@@ -74,7 +74,8 @@ def generate_knowledge_graph(texts: List[str], filename: str):
 
 
 @lru_cache
-def generate_partial_graph(text):
+def generate_partial_graph(text: str):
+    print(text[0:20], hash(text))
     triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')
     a = triplet_extractor(text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
     extracted_text = triplet_extractor.tokenizer.batch_decode(a)
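`@lru_cache` memoizes `generate_partial_graph` per input text (the bare, unparenthesized decorator form requires Python 3.8+, which matches the cpython-38 bytecode in this commit), so regenerating an expanded graph reuses earlier extractions. Because the commit constructs the HuggingFace pipeline inside the function, every cache miss reloads REBEL; a sketch that hoists the pipeline to module scope, which is an assumption rather than the committed layout:

```python
from functools import lru_cache
from transformers import pipeline

# Load REBEL once at import time instead of on every cache miss.
triplet_extractor = pipeline('text2text-generation',
                             model='Babelscape/rebel-large',
                             tokenizer='Babelscape/rebel-large')

@lru_cache(maxsize=None)
def generate_partial_graph_cached(text: str):
    # Same decoding path as the committed function, minus the model reload.
    ids = triplet_extractor(text, return_tensors=True,
                            return_text=False)[0]["generated_token_ids"]["output_ids"]
    return triplet_extractor.tokenizer.batch_decode(ids)
```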