badongtakla committed on
Commit
e178154
·
1 Parent(s): b08025b

add chronological attribution output

Browse files
Files changed (1) hide show
  1. app.py +48 -6
app.py CHANGED
@@ -1,6 +1,10 @@
1
  import subprocess
2
  import jinja2
3
  import gradio
 
 
 
 
4
 
5
  subprocess.run(
6
  ["curl", "--output", "checkpoint.pkl", "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"])
@@ -26,6 +30,44 @@ from ithaca.models.model import Model
26
  from ithaca.util.alphabet import GreekAlphabet
27
  import jax
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_subregion_name(id, region_map):
30
  return region_map['sub']['names_inv'][region_map['sub']['ids_inv'][id]]
31
 
@@ -64,10 +106,7 @@ def load_checkpoint(path):
64
 
65
  return checkpoint['model_config'], region_map, alphabet, params, forward
66
 
67
-
68
  def main(text):
69
-
70
-
71
  restore_template = jinja2.Template("""<!DOCTYPE html>
72
  <html>
73
  <head>
@@ -207,18 +246,21 @@ def main(text):
207
  vocab_word_size=vocab_word_size)
208
 
209
  prediction_idx = set(i for i, c in enumerate(restoration.input_text) if c == '?')
210
-
211
  attrib_dict = {get_subregion_name(l.location_id, region_map): l.score for l in attribution.locations[:3]}
 
 
 
 
212
  return restore_template.render(
213
  restoration_results=restoration,
214
- prediction_idx=prediction_idx), attrib_dict
215
 
216
  with open('example_input.txt', encoding='utf8') as f:
217
  examples = [line for line in f]
218
  gradio.Interface(
219
  main,
220
  inputs=gradio.inputs.Textbox(lines=3),
221
- outputs=["html", gradio.outputs.Label(label="Geographical Attribution")],
222
  examples=examples,
223
  title='Spaces Demo for Ithaca',
224
  description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>').launch(enable_queue=True)
 
1
  import subprocess
2
  import jinja2
3
  import gradio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import base64
7
+ from io import BytesIO
8
 
9
  subprocess.run(
10
  ["curl", "--output", "checkpoint.pkl", "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"])
 
30
  from ithaca.util.alphabet import GreekAlphabet
31
  import jax
32
 
33
+
34
def create_time_plot(attribution, dataset_config):
    """Plot the chronological attribution and return it as embeddable HTML.

    Draws the per-bin date scores as a bar chart, marks their score-weighted
    average with a vertical line, renders the figure to PNG in memory and
    inlines it as a base64 data URI.

    Args:
        attribution: Object exposing ``year_scores``, one score per date bin.
            Presumably a probability distribution (the y axis is labelled
            "Probability") -- confirm against the model output.
        dataset_config: Object exposing ``date_min``, ``date_max`` and
            ``date_interval`` (the bin width).

    Returns:
        An HTML string of the form
        ``<div><img src="data:image/png;base64,..."></div>``.
    """
    # Compute scores: one x position (the bin centre) per score.
    date_pred_y = np.array(attribution.year_scores)
    date_pred_x = np.arange(
        dataset_config.date_min + dataset_config.date_interval / 2,
        dataset_config.date_max + dataset_config.date_interval / 2,
        dataset_config.date_interval)
    # Score-weighted sum of the bin centres (the mean if scores sum to 1).
    date_pred_avg = np.dot(date_pred_y, date_pred_x)

    # Plot figure.
    fig = plt.figure(figsize=(10, 5), dpi=100)
    try:
        plt.bar(date_pred_x, date_pred_y, color='#f2c852', width=10., label='Ithaca distribution')
        plt.axvline(x=date_pred_avg, color='#67ac5b', linewidth=2., label='Ithaca average')

        plt.ylabel('Probability', fontsize=14)
        yticks = np.arange(0, 1.1, 0.1)
        yticks_str = [f'{int(x*100)}%' for x in yticks]
        plt.yticks(yticks, yticks_str, fontsize=12, rotation=0)
        # Upper limit: tallest bar plus 0.1, truncated to one decimal place.
        plt.ylim(0, int((date_pred_y.max()+0.1)*10)/10)

        plt.xlabel('Date', fontsize=14)
        xticks = list(range(dataset_config.date_min, dataset_config.date_max + 1, 25))
        # NOTE(review): bce_ad is not defined in this block; it has to be
        # available at module level (presumably it formats a year as a
        # BCE/AD label -- confirm).
        xticks_str = list(map(bce_ad, xticks))
        plt.xticks(xticks, xticks_str, fontsize=12, rotation=0)
        # Show only a 200-unit window centred on the average.
        plt.xlim(int(date_pred_avg - 100), int(date_pred_avg + 100))
        plt.legend(loc='upper right', fontsize=12)

        # Render into an in-memory buffer; savefig needs the buffer itself
        # (not its current bytes) and an explicit format for file-like targets.
        tmpfile = BytesIO()
        fig.savefig(tmpfile, format='png')
    finally:
        # pyplot keeps a reference to every open figure; close it so repeated
        # calls do not accumulate figures.
        plt.close(fig)

    # Encode to base64 for HTML embedding.
    encoded = base64.b64encode(tmpfile.getvalue())
    html = '<div>' + '<img src="data:image/png;base64,{}">'.format(encoded.decode('utf-8')) + '</div>'

    return html
71
  def get_subregion_name(id, region_map):
72
  return region_map['sub']['names_inv'][region_map['sub']['ids_inv'][id]]
73
 
 
106
 
107
  return checkpoint['model_config'], region_map, alphabet, params, forward
108
 
 
109
  def main(text):
 
 
110
  restore_template = jinja2.Template("""<!DOCTYPE html>
111
  <html>
112
  <head>
 
246
  vocab_word_size=vocab_word_size)
247
 
248
  prediction_idx = set(i for i, c in enumerate(restoration.input_text) if c == '?')
 
249
  attrib_dict = {get_subregion_name(l.location_id, region_map): l.score for l in attribution.locations[:3]}
250
+ class dataset_config:
251
+ date_interval = 10
252
+ date_max = 800
253
+ date_min = -800
254
  return restore_template.render(
255
  restoration_results=restoration,
256
+ prediction_idx=prediction_idx), attrib_dict, create_time_plot(attribution, dataset_config)
257
 
258
  with open('example_input.txt', encoding='utf8') as f:
259
  examples = [line for line in f]
260
  gradio.Interface(
261
  main,
262
  inputs=gradio.inputs.Textbox(lines=3),
263
+ outputs=['html', gradio.outputs.Label(label='Geographical Attribution'), 'html'],
264
  examples=examples,
265
  title='Spaces Demo for Ithaca',
266
  description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>').launch(enable_queue=True)