Spaces:
Runtime error
Runtime error
Commit
·
e178154
1
Parent(s):
b08025b
add chronological attribution output
Browse files
app.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
import subprocess
|
2 |
import jinja2
|
3 |
import gradio
|
|
|
|
|
|
|
|
|
4 |
|
5 |
subprocess.run(
|
6 |
["curl", "--output", "checkpoint.pkl", "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"])
|
@@ -26,6 +30,44 @@ from ithaca.models.model import Model
|
|
26 |
from ithaca.util.alphabet import GreekAlphabet
|
27 |
import jax
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def get_subregion_name(id, region_map):
|
30 |
return region_map['sub']['names_inv'][region_map['sub']['ids_inv'][id]]
|
31 |
|
@@ -64,10 +106,7 @@ def load_checkpoint(path):
|
|
64 |
|
65 |
return checkpoint['model_config'], region_map, alphabet, params, forward
|
66 |
|
67 |
-
|
68 |
def main(text):
|
69 |
-
|
70 |
-
|
71 |
restore_template = jinja2.Template("""<!DOCTYPE html>
|
72 |
<html>
|
73 |
<head>
|
@@ -207,18 +246,21 @@ def main(text):
|
|
207 |
vocab_word_size=vocab_word_size)
|
208 |
|
209 |
prediction_idx = set(i for i, c in enumerate(restoration.input_text) if c == '?')
|
210 |
-
|
211 |
attrib_dict = {get_subregion_name(l.location_id, region_map): l.score for l in attribution.locations[:3]}
|
|
|
|
|
|
|
|
|
212 |
return restore_template.render(
|
213 |
restoration_results=restoration,
|
214 |
-
prediction_idx=prediction_idx), attrib_dict
|
215 |
|
216 |
with open('example_input.txt', encoding='utf8') as f:
|
217 |
examples = [line for line in f]
|
218 |
gradio.Interface(
|
219 |
main,
|
220 |
inputs=gradio.inputs.Textbox(lines=3),
|
221 |
-
outputs=[
|
222 |
examples=examples,
|
223 |
title='Spaces Demo for Ithaca',
|
224 |
description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>').launch(enable_queue=True)
|
|
|
1 |
import subprocess
|
2 |
import jinja2
|
3 |
import gradio
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import base64
|
7 |
+
from io import BytesIO
|
8 |
|
9 |
subprocess.run(
|
10 |
["curl", "--output", "checkpoint.pkl", "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"])
|
|
|
30 |
from ithaca.util.alphabet import GreekAlphabet
|
31 |
import jax
|
32 |
|
33 |
+
|
34 |
+
def create_time_plot(attribution, dataset_config):
|
35 |
+
#compute scores
|
36 |
+
date_pred_y = np.array(attribution.year_scores)
|
37 |
+
date_pred_x = np.arange(
|
38 |
+
dataset_config.date_min + dataset_config.date_interval / 2,
|
39 |
+
dataset_config.date_max + dataset_config.date_interval / 2,
|
40 |
+
dataset_config.date_interval)
|
41 |
+
date_pred_argmax = date_pred_y.argmax(
|
42 |
+
) * dataset_config.date_interval + dataset_config.date_min + dataset_config.date_interval // 2
|
43 |
+
date_pred_avg = np.dot(date_pred_y, date_pred_x)
|
44 |
+
|
45 |
+
# Plot figure
|
46 |
+
fig = plt.figure(figsize=(10, 5), dpi=100)
|
47 |
+
|
48 |
+
plt.bar(date_pred_x, date_pred_y, color='#f2c852', width=10., label='Ithaca distribution')
|
49 |
+
plt.axvline(x=date_pred_avg, color='#67ac5b', linewidth=2., label='Ithaca average')
|
50 |
+
|
51 |
+
|
52 |
+
plt.ylabel('Probability', fontsize=14)
|
53 |
+
yticks = np.arange(0, 1.1, 0.1)
|
54 |
+
yticks_str = list(map(lambda x: f'{int(x*100)}%', yticks))
|
55 |
+
plt.yticks(yticks, yticks_str, fontsize=12, rotation=0)
|
56 |
+
plt.ylim(0, int((date_pred_y.max()+0.1)*10)/10)
|
57 |
+
|
58 |
+
plt.xlabel('Date', fontsize=14)
|
59 |
+
xticks = list(range(dataset_config.date_min, dataset_config.date_max + 1, 25))
|
60 |
+
xticks_str = list(map(bce_ad, xticks))
|
61 |
+
plt.xticks(xticks, xticks_str, fontsize=12, rotation=0)
|
62 |
+
plt.xlim(int(date_pred_avg - 100), int(date_pred_avg + 100))
|
63 |
+
plt.legend(loc='upper right', fontsize=12)
|
64 |
+
|
65 |
+
#encode to base64 for html parsing
|
66 |
+
tmpfile = BytesIO()
|
67 |
+
fig.savefig(tmpfile.getvalue()).decode('utf-8')
|
68 |
+
html = '<div>' + '<img src="data:image/png;base64,{}">'.format(encoded.decode('utf-8')) + '</div>'
|
69 |
+
|
70 |
+
return html
|
71 |
def get_subregion_name(id, region_map):
|
72 |
return region_map['sub']['names_inv'][region_map['sub']['ids_inv'][id]]
|
73 |
|
|
|
106 |
|
107 |
return checkpoint['model_config'], region_map, alphabet, params, forward
|
108 |
|
|
|
109 |
def main(text):
|
|
|
|
|
110 |
restore_template = jinja2.Template("""<!DOCTYPE html>
|
111 |
<html>
|
112 |
<head>
|
|
|
246 |
vocab_word_size=vocab_word_size)
|
247 |
|
248 |
prediction_idx = set(i for i, c in enumerate(restoration.input_text) if c == '?')
|
|
|
249 |
attrib_dict = {get_subregion_name(l.location_id, region_map): l.score for l in attribution.locations[:3]}
|
250 |
+
class dataset_config:
|
251 |
+
date_interval = 10
|
252 |
+
date_max = 800
|
253 |
+
date_min = -800
|
254 |
return restore_template.render(
|
255 |
restoration_results=restoration,
|
256 |
+
prediction_idx=prediction_idx), attrib_dict, create_time_plot(attribution, dataset_config)
|
257 |
|
258 |
with open('example_input.txt', encoding='utf8') as f:
|
259 |
examples = [line for line in f]
|
260 |
gradio.Interface(
|
261 |
main,
|
262 |
inputs=gradio.inputs.Textbox(lines=3),
|
263 |
+
outputs=['html', gradio.outputs.Label(label='Geographical Attribution'), 'html'],
|
264 |
examples=examples,
|
265 |
title='Spaces Demo for Ithaca',
|
266 |
description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>').launch(enable_queue=True)
|