File size: 9,148 Bytes
7bb8c92
dca8bd0
b938769
e178154
 
 
 
7bb8c92
 
9c91ab8
b938769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e178154
6a05bf6
 
 
 
 
 
 
 
 
 
 
 
e178154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9831e2
 
5e74ed8
e178154
 
0138889
 
 
b938769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a5c5fd
a3811cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e74ed8
a3811cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b725af2
a3811cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0138889
b938769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a5c5fd
 
5dab160
3a5c5fd
a3811cb
6a05bf6
ab1439f
8554db3
ab1439f
b938769
 
7fe2538
5e74ed8
f3beb50
c5fa3d6
b08025b
7c7e65a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import subprocess
import jinja2
import gradio
import matplotlib.pyplot as plt
import numpy as np
import base64
from io import BytesIO

# Fetch the pretrained Ithaca checkpoint that main() later loads from
# 'checkpoint.pkl'. `--fail` makes curl exit non-zero on an HTTP error instead
# of saving the error page as the checkpoint, and `check=True` raises
# CalledProcessError right here rather than letting the failure resurface as a
# confusing unpickling error on the first request.
subprocess.run(
        ["curl", "--fail", "--output", "checkpoint.pkl",
         "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"],
        check=True)
# Copyright 2021 the Ithaca Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for running inference. See also colab."""
import functools
import pickle

from ithaca.eval import inference
from ithaca.models.model import Model
from ithaca.util.alphabet import GreekAlphabet
import jax


def create_time_plot(attribution):
    """Renders the predicted date distribution as an inline HTML image.

    Args:
      attribution: object exposing `year_scores`, a sequence with one score
        per 10-year bin between -800 and 800 (it is dotted with the bin
        centres, so the lengths must match). Presumably the scores are
        probabilities summing to 1 -- confirm against `inference.attribute`.

    Returns:
      An HTML snippet: a heading followed by the bar chart embedded as a
      base64-encoded PNG data URI.
    """
    class dataset_config:
        # Width of one dating bin and the overall range covered, in years
        # (negative values are rendered as BCE by bce_ad below).
        date_interval = 10
        date_max = 800
        date_min = -800

    def bce_ad(d):
        """Formats a signed year as an 'N BCE' / 'N AD' tick label."""
        if d < 0:
            return f'{abs(d)} BCE'
        if d > 0:
            return f'{abs(d)} AD'
        # Return a string like the other branches (renders the same as 0).
        return '0'

    # Bin centres (x) paired with the score of each bin (y).
    date_pred_y = np.array(attribution.year_scores)
    date_pred_x = np.arange(
        dataset_config.date_min + dataset_config.date_interval / 2,
        dataset_config.date_max + dataset_config.date_interval / 2,
        dataset_config.date_interval)
    # Score-weighted sum of the bin centres; the x-axis is centred on it.
    date_pred_avg = np.dot(date_pred_y, date_pred_x)

    fig = plt.figure(figsize=(10, 5), dpi=100)
    try:
        plt.bar(date_pred_x, date_pred_y, color='#f2c852', width=10.,
                label='Ithaca distribution')
        plt.axvline(x=date_pred_avg, color='#67ac5b', linewidth=2.,
                    label='Ithaca average')

        plt.ylabel('Probability', fontsize=14)
        yticks = np.arange(0, 1.1, 0.1)
        yticks_str = [f'{int(y * 100)}%' for y in yticks]
        plt.yticks(yticks, yticks_str, fontsize=12, rotation=0)
        # Leave roughly 10% headroom above the tallest bar, truncated to a
        # multiple of 0.1.
        plt.ylim(0, int((date_pred_y.max() + 0.1) * 10) / 10)

        plt.xlabel('Date', fontsize=14)
        xticks = list(
            range(dataset_config.date_min, dataset_config.date_max + 1, 25))
        xticks_str = [bce_ad(x) for x in xticks]
        plt.xticks(xticks, xticks_str, fontsize=12, rotation=0)
        # Show only a 200-year window around the average date.
        plt.xlim(int(date_pred_avg - 100), int(date_pred_avg + 100))
        plt.legend(loc='upper right', fontsize=12)

        # Serialise the figure to PNG in memory for embedding in the HTML.
        tmpfile = BytesIO()
        fig.savefig(tmpfile, format='png')
    finally:
        # pyplot keeps every figure registered until it is explicitly closed;
        # without this each call would leak one figure.
        plt.close(fig)

    encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
    html = '<h3> Chronological Attribution </h3>' + '<div>' + '<img src="data:image/png;charset=utf-8;base64,{}">'.format(encoded) + '</div>'

    return html
def get_subregion_name(id, region_map):
  """Resolves a sub-region id to its name via the region map.

  Args:
    id: key (or index) into `region_map['sub']['ids_inv']`.
    region_map: nested mapping whose 'sub' entry holds the 'ids_inv' and
      'names_inv' lookups.

  Returns:
    The entry of `names_inv` selected by the value that `ids_inv` holds for
    `id`.
  """
  sub_regions = region_map['sub']
  names_by_key = sub_regions['names_inv']
  name_key = sub_regions['ids_inv'][id]
  return names_by_key[name_key]

def load_checkpoint(path):
  """Restores the pieces needed for inference from a pickled checkpoint.

  Args:
    path: path to the checkpoint pickle file.

  Returns:
    A 5-tuple `(model_config, region_map, alphabet, params, forward)`:
    the keyword arguments the model was constructed with, the region mapping
    stored in the checkpoint, a GreekAlphabet whose `idx2word` / `word2idx`
    come from the checkpoint, the parameters after `jax.device_put`, and
    `model.apply` with those parameters already bound.
  """
  # NOTE(review): pickle.load can execute arbitrary code contained in the
  # file, so only point this at checkpoints from a trusted source.
  with open(path, 'rb') as ckpt_file:
    ckpt = pickle.load(ckpt_file)

  # Rebuild the model from the constructor arguments saved at training time
  # and bind the restored parameters, giving the `forward` callable that
  # attribute() and restore() expect.
  params = jax.device_put(ckpt['params'])
  model_config = ckpt['model_config']
  forward = functools.partial(Model(**model_config).apply, params)

  # Mapping between region ids and names.
  region_map = ckpt['region_map']

  # Only the vocabulary tables are taken from the checkpoint; every other
  # attribute of the alphabet keeps its class default.
  alphabet = GreekAlphabet()
  vocab = ckpt['alphabet']
  alphabet.idx2word = vocab['idx2word']
  alphabet.word2idx = vocab['word2idx']

  return model_config, region_map, alphabet, params, forward

@functools.lru_cache(maxsize=1)
def _restore_template():
  """Compiles the restoration HTML template once and reuses it afterwards."""
  return jinja2.Template("""<!DOCTYPE html>
    <html>
    <head>
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet">
    <style>
    body {
      font-family: 'Roboto Mono', monospace;
      font-weight: 400;
    }
    .container {
      overflow-x: scroll;
      scroll-behavior: smooth;
    }
    table {
      table-layout: fixed;
      font-size: 16px;
      padding: 0;
      white-space: nowrap;
    }
    table tr:first-child {
      font-weight: bold;
    }
    table td {
      border-bottom: 1px solid #ccc;
      padding: 3px 0;
    }
    table td.header {
      font-family: Roboto, Helvetica, sans-serif;
      text-align: right;
      position: -webkit-sticky;
      position: sticky;
      background-color: white;
    }
    .header-1 {
      background-color: white;
      width: 120px;
      min-width: 120px;
      max-width: 120px;
      left: 0;
    }
    .header-2 {
      left: 120px;
      width: 50px;
      max-width: 50px;
      min-width: 50px;
      padding-right: 5px;
    }
    table td:not(.header) {
      border-left: 1px solid black;
      padding-left: 5px;
    }
    .header-2col {
      width: 170px;
      min-width: 170px;
      max-width: 170px;
      left: 0;
      padding-right: 5px;
    }
    .pred {
      background: #ddd;
    }
    </style>
    </head>
    <body>
    <h3> Restoration </h3>
    <div class="container">
    <table cellspacing="0">
      <tr>
        <td colspan="2" class="header header-2col">Input text:</td>
        <td>
        {% for char in restoration_results.input_text -%}
          {%- if loop.index0 in prediction_idx -%}
            <span class="pred">{{char}}</span>
          {%- else -%}
            {{char}}
          {%- endif -%}
        {%- endfor %}
        </td>
      </tr>
      <!-- Predictions: -->
      {% for pred in restoration_results.predictions[:3] %}
      <tr>
        <td class="header header-1">Hypothesis {{ loop.index }}:</td>
        <td class="header header-2">{{ "%.1f%%"|format(100 * pred.score) }}</td>
        <td>
          {% for char in pred.text -%}
            {%- if loop.index0 in prediction_idx -%}
              <span class="pred">{{char}}</span>
            {%- else -%}
              {{char}}
            {%- endif -%}
          {%- endfor %}
        </td>
      </tr>
      {% endfor %}
    </table>
    </div>
    <script>
    document.querySelector('#btn').addEventListener('click', () => {
      const pred = document.querySelector(".pred");
      pred.scrollIntoViewIfNeeded();
    });
    </script>
    </body>
    </html>
    """)


@functools.lru_cache(maxsize=1)
def _cached_checkpoint(path):
  """Loads the checkpoint at `path` once instead of on every request."""
  return load_checkpoint(path)


def main(text):
  """Restores and attributes an inscription for the Gradio interface.

  Args:
    text: the inscription to analyse; must be 50 to 750 characters long.
      Characters equal to '?' are the positions highlighted as predictions.

  Returns:
    A 3-tuple `(restoration_html, attrib_dict, time_plot_html)`: the rendered
    restoration table showing up to three hypotheses, a dict mapping the
    sub-region names of the first three attributed locations to their
    scores, and the HTML produced by `create_time_plot`.

  Raises:
    ValueError: if `text` is shorter than 50 or longer than 750 characters.
  """
  # Validate before doing any expensive work. The original branch referenced
  # the undefined names `app` and `input_text`, so it raised NameError
  # instead of reporting the problem.
  if not 50 <= len(text) <= 750:
    raise ValueError(
        f'Text should be between 50 and 750 chars long, but the input was '
        f'{len(text)} characters')

  # Load the checkpoint pickle (cached after the first call) and extract from
  # it the pieces needed for calling the attribute() and restore() functions:
  (model_config, region_map, alphabet, params,
   forward) = _cached_checkpoint('checkpoint.pkl')
  vocab_char_size = model_config['vocab_char_size']
  vocab_word_size = model_config['vocab_word_size']

  attribution = inference.attribute(
      text,
      forward=forward,
      params=params,
      alphabet=alphabet,
      region_map=region_map,
      vocab_char_size=vocab_char_size,
      vocab_word_size=vocab_word_size)

  restoration = inference.restore(
      text,
      forward=forward,
      params=params,
      alphabet=alphabet,
      vocab_char_size=vocab_char_size,
      vocab_word_size=vocab_word_size)

  # Positions the user asked to have predicted; the template highlights them.
  prediction_idx = {
      i for i, c in enumerate(restoration.input_text) if c == '?'}
  attrib_dict = {
      get_subregion_name(loc.location_id, region_map): loc.score
      for loc in attribution.locations[:3]}
  restoration_html = _restore_template().render(
      restoration_results=restoration,
      prediction_idx=prediction_idx)
  return restoration_html, attrib_dict, create_time_plot(attribution)

# One example inscription per line. Drop the line terminators (they would
# otherwise be handed to main() and counted towards its length limit) and skip
# blank lines, which could never satisfy the 50-character minimum.
with open('example_input.txt', encoding='utf8') as f:
    examples = [line.rstrip('\r\n') for line in f if line.strip()]
# Build the demo UI: one text input, three outputs matching main()'s return
# tuple (restoration HTML, attribution label, date plot HTML).
gradio.Interface(
        main,
        inputs=gradio.inputs.Textbox(lines=3),
        outputs=['html', gradio.outputs.Label(label='Geographical Attribution'), 'html'],
        examples=examples,
        title='Spaces Demo for Ithaca',
        description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>').launch(enable_queue=True)