Spaces:
Runtime error
Runtime error
File size: 9,148 Bytes
7bb8c92 dca8bd0 b938769 e178154 7bb8c92 9c91ab8 b938769 e178154 6a05bf6 e178154 f9831e2 5e74ed8 e178154 0138889 b938769 3a5c5fd a3811cb 5e74ed8 a3811cb b725af2 a3811cb 0138889 b938769 3a5c5fd 5dab160 3a5c5fd a3811cb 6a05bf6 ab1439f 8554db3 ab1439f b938769 7fe2538 5e74ed8 f3beb50 c5fa3d6 b08025b 7c7e65a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 |
import subprocess
import jinja2
import gradio
import matplotlib.pyplot as plt
import numpy as np
import base64
from io import BytesIO
# Download the pretrained Ithaca checkpoint at startup; main() later loads it
# from 'checkpoint.pkl' in the working directory.
# `--fail` stops curl from saving an HTTP error page as the "checkpoint", and
# `check=True` aborts startup on a failed download instead of letting it
# surface later as an obscure unpickling error.
subprocess.run(
    ["curl", "--fail", "--location", "--output", "checkpoint.pkl",
     "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"],
    check=True)
# Copyright 2021 the Ithaca Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for running inference. See also colab."""
import functools
import pickle
from ithaca.eval import inference
from ithaca.models.model import Model
from ithaca.util.alphabet import GreekAlphabet
import jax
def create_time_plot(attribution):
    """Render Ithaca's chronological attribution as an HTML snippet.

    Args:
      attribution: result of ``inference.attribute()``; only its
        ``year_scores`` attribute is read. Assumed to hold one probability per
        10-year bin spanning -800..800 (i.e. 160 bins) -- TODO confirm against
        the checkpoint's training config.

    Returns:
      An HTML string: an ``<h3>`` heading followed by the plot embedded as a
      base64-encoded PNG ``<img>``.
    """

    # Date binning the plot assumes; must agree with the length of year_scores.
    class dataset_config:
        date_interval = 10
        date_max = 800
        date_min = -800

    def bce_ad(d):
        """Format a signed year for tick labels (negative = BCE, positive = AD)."""
        if d < 0:
            return f'{abs(d)} BCE'
        elif d > 0:
            return f'{abs(d)} AD'
        # A string like the other branches, so all tick labels share one type.
        return '0'

    # One probability per bin, plotted at the bin centres.
    date_pred_y = np.array(attribution.year_scores)
    date_pred_x = np.arange(
        dataset_config.date_min + dataset_config.date_interval / 2,
        dataset_config.date_max + dataset_config.date_interval / 2,
        dataset_config.date_interval)
    # Expected date under the predicted distribution; the view is centred on it.
    date_pred_avg = np.dot(date_pred_y, date_pred_x)

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure", and always close it: this runs once per request in a long-lived
    # server, and pyplot keeps every unclosed figure alive.
    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
    try:
        ax.bar(date_pred_x, date_pred_y, color='#f2c852',
               width=float(dataset_config.date_interval),
               label='Ithaca distribution')
        ax.axvline(x=date_pred_avg, color='#67ac5b', linewidth=2.,
                   label='Ithaca average')

        ax.set_ylabel('Probability', fontsize=14)
        yticks = np.arange(0, 1.1, 0.1)
        ax.set_yticks(yticks)
        # Round before int() so float error (e.g. 0.29 * 100 == 28.999...)
        # cannot shave a percent off a label.
        ax.set_yticklabels([f'{int(round(y * 100))}%' for y in yticks],
                           fontsize=12, rotation=0)
        # Cap the y-axis at (roughly) the first tenth above the tallest bar.
        ax.set_ylim(0, int((date_pred_y.max() + 0.1) * 10) / 10)

        ax.set_xlabel('Date', fontsize=14)
        xticks = list(range(dataset_config.date_min,
                            dataset_config.date_max + 1, 25))
        ax.set_xticks(xticks)
        ax.set_xticklabels([bce_ad(x) for x in xticks], fontsize=12, rotation=0)
        # Zoom to +/-100 years around the average date.
        ax.set_xlim(int(date_pred_avg - 100), int(date_pred_avg + 100))
        ax.legend(loc='upper right', fontsize=12)

        # Encode to base64 so the image can be inlined in the HTML output.
        tmpfile = BytesIO()
        fig.savefig(tmpfile, format='png')
    finally:
        plt.close(fig)

    encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
    return ('<h3> Chronological Attribution </h3>' + '<div>'
            + '<img src="data:image/png;charset=utf-8;base64,{}">'.format(encoded)
            + '</div>')
def get_subregion_name(id, region_map):
    """Translate a model sub-region id into its human-readable name.

    The id is first mapped through ``region_map['sub']['ids_inv']``; the key
    obtained there is then looked up in ``region_map['sub']['names_inv']``.
    Missing keys propagate as the underlying lookup error.
    """
    subregions = region_map['sub']
    names_by_key = subregions['names_inv']
    subregion_key = subregions['ids_inv'][id]
    return names_by_key[subregion_key]
def load_checkpoint(path):
    """Load the pickled Ithaca checkpoint and rebuild what inference needs.

    Args:
      path: filesystem path of the checkpoint pickle.

    Returns:
      A 5-tuple ``(model_config, region_map, alphabet, params, forward)``:
      the model-constructor keyword arguments stored in the checkpoint, the
      region-mapping dict of dicts, a ``GreekAlphabet`` whose ``idx2word`` and
      ``word2idx`` are taken from the checkpoint, the parameters placed on a
      JAX device via ``jax.device_put``, and ``forward`` -- ``model.apply``
      with those parameters bound as its first argument (the form expected by
      ``attribute()`` and ``restore()``).
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. It is acceptable only for the checkpoint this app downloads
    # itself; never call this on user-supplied files.
    with open(path, 'rb') as handle:
        ckpt = pickle.load(handle)

    model_config = ckpt['model_config']

    # Rebuild the network with the same constructor arguments used during
    # training and bind its weights to obtain the `forward` callable.
    device_params = jax.device_put(ckpt['params'])
    network = Model(**model_config)
    forward_fn = functools.partial(network.apply, device_params)

    # Mapping between region IDs and names.
    regions = ckpt['region_map']

    # Only the vocabulary comes from the checkpoint; the other alphabet
    # values are fixed class constants (e.g. the padding symbol).
    vocab = ckpt['alphabet']
    greek = GreekAlphabet()
    greek.idx2word = vocab['idx2word']
    greek.word2idx = vocab['word2idx']

    return model_config, regions, greek, device_params, forward_fn
# Restoration table, compiled once at import time rather than on every request.
# autoescape=True HTML-escapes every {{ ... }} value: the characters rendered
# here derive from user-submitted text and model output, so any markup in them
# must be displayed as text, not interpreted by the browser.
_RESTORE_TEMPLATE = jinja2.Template("""<!DOCTYPE html>
<html>
<head>
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet">
  <style>
    body {
      font-family: 'Roboto Mono', monospace;
      font-weight: 400;
    }
    .container {
      overflow-x: scroll;
      scroll-behavior: smooth;
    }
    table {
      table-layout: fixed;
      font-size: 16px;
      padding: 0;
      white-space: nowrap;
    }
    table tr:first-child {
      font-weight: bold;
    }
    table td {
      border-bottom: 1px solid #ccc;
      padding: 3px 0;
    }
    table td.header {
      font-family: Roboto, Helvetica, sans-serif;
      text-align: right;
      position: -webkit-sticky;
      position: sticky;
      background-color: white;
    }
    .header-1 {
      background-color: white;
      width: 120px;
      min-width: 120px;
      max-width: 120px;
      left: 0;
    }
    .header-2 {
      left: 120px;
      width: 50px;
      max-width: 50px;
      min-width: 50px;
      padding-right: 5px;
    }
    table td:not(.header) {
      border-left: 1px solid black;
      padding-left: 5px;
    }
    .header-2col {
      width: 170px;
      min-width: 170px;
      max-width: 170px;
      left: 0;
      padding-right: 5px;
    }
    .pred {
      background: #ddd;
    }
  </style>
</head>
<body>
  <h3> Restoration </h3>
  <div class="container">
    <table cellspacing="0">
      <tr>
        <td colspan="2" class="header header-2col">Input text:</td>
        <td>
        {% for char in restoration_results.input_text -%}
          {%- if loop.index0 in prediction_idx -%}
            <span class="pred">{{char}}</span>
          {%- else -%}
            {{char}}
          {%- endif -%}
        {%- endfor %}
        </td>
      </tr>
      <!-- Predictions: -->
      {% for pred in restoration_results.predictions[:3] %}
      <tr>
        <td class="header header-1">Hypothesis {{ loop.index }}:</td>
        <td class="header header-2">{{ "%.1f%%"|format(100 * pred.score) }}</td>
        <td>
        {% for char in pred.text -%}
          {%- if loop.index0 in prediction_idx -%}
            <span class="pred">{{char}}</span>
          {%- else -%}
            {{char}}
          {%- endif -%}
        {%- endfor %}
        </td>
      </tr>
      {% endfor %}
    </table>
  </div>
  <script>
    (() => {
      // Guard every lookup: this fragment contains no #btn element itself,
      // and scrollIntoViewIfNeeded is a non-standard browser method.
      const btn = document.querySelector('#btn');
      if (btn) {
        btn.addEventListener('click', () => {
          const pred = document.querySelector(".pred");
          if (pred && pred.scrollIntoViewIfNeeded) {
            pred.scrollIntoViewIfNeeded();
          }
        });
      }
    })();
  </script>
</body>
</html>
""", autoescape=True)


@functools.lru_cache(maxsize=None)
def _load_checkpoint_cached(path):
    """Return ``load_checkpoint(path)``, reading each path only once per process.

    Unpickling the checkpoint and moving its params to the device is far too
    expensive to repeat on every request.
    NOTE(review): assumes inference.attribute()/restore() do not mutate the
    returned alphabet/region_map/params -- confirm if that code changes.
    """
    return load_checkpoint(path)


def main(text):
    """Run Ithaca restoration and attribution on one inscription.

    Args:
      text: the inscription, with missing characters written as '-' and
        characters to be predicted written as '?'. Must be 50-750 characters
        long.

    Returns:
      A 3-tuple matching the three Gradio outputs: the restoration HTML table
      (top-3 hypotheses, predicted positions highlighted), a dict mapping the
      top-3 sub-region names to their scores, and the chronological-attribution
      HTML plot.

    Raises:
      ValueError: if ``text`` is shorter than 50 or longer than 750 characters.
    """
    # Reject inputs outside the length range this demo supports.
    if not 50 <= len(text) <= 750:
        raise ValueError(
            f'Text should be between 50 and 750 chars long, but the input was '
            f'{len(text)} characters')

    # The pieces needed for calling attribute() and restore(), loaded once.
    (model_config, region_map, alphabet, params,
     forward) = _load_checkpoint_cached('checkpoint.pkl')
    vocab_char_size = model_config['vocab_char_size']
    vocab_word_size = model_config['vocab_word_size']

    attribution = inference.attribute(
        text,
        forward=forward,
        params=params,
        alphabet=alphabet,
        region_map=region_map,
        vocab_char_size=vocab_char_size,
        vocab_word_size=vocab_word_size)
    restoration = inference.restore(
        text,
        forward=forward,
        params=params,
        alphabet=alphabet,
        vocab_char_size=vocab_char_size,
        vocab_word_size=vocab_word_size)

    # Positions the user marked with '?' are highlighted in every row.
    prediction_idx = {i for i, c in enumerate(restoration.input_text) if c == '?'}
    attrib_dict = {
        get_subregion_name(loc.location_id, region_map): loc.score
        for loc in attribution.locations[:3]
    }
    restoration_html = _RESTORE_TEMPLATE.render(
        restoration_results=restoration,
        prediction_idx=prediction_idx)
    return restoration_html, attrib_dict, create_time_plot(attribution)
# Example inscriptions shown under the input box, one per line of the file.
# Drop the line terminator (and skip blank lines) so each example is exactly
# the text a user would type; otherwise the newline is passed to the model and
# counted against main()'s 50-750 character limit.
with open('example_input.txt', encoding='utf8') as f:
    examples = [line.rstrip('\r\n') for line in f if line.strip()]

demo = gradio.Interface(
    main,
    inputs=gradio.inputs.Textbox(lines=3),
    outputs=['html', gradio.outputs.Label(label='Geographical Attribution'), 'html'],
    examples=examples,
    title='Spaces Demo for Ithaca',
    description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>')
demo.launch(enable_queue=True)
|