ithaca / app.py
badongtakla's picture
added examples to file
ab1439f
raw
history blame
6.82 kB
import subprocess
import jinja2
import gradio
# Download the model checkpoint next to this script before anything tries to
# load it. `--fail` makes curl exit non-zero on HTTP errors instead of saving
# the error page as checkpoint.pkl, and `check=True` turns a non-zero exit into
# a CalledProcessError so a broken download stops start-up immediately rather
# than surfacing later as an unpickling error.
subprocess.run(
    ["curl", "--fail", "--output", "checkpoint.pkl",
     "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"],
    check=True)
# Copyright 2021 the Ithaca Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for running inference. See also colab."""
import functools
import pickle
from ithaca.eval import inference
from ithaca.models.model import Model
from ithaca.util.alphabet import GreekAlphabet
import jax
def get_subregion_name(id, region_map):
    """Resolves a sub-region id to its name.

    Args:
      id: key into `region_map['sub']['ids_inv']`.
      region_map: dict of dicts; only the 'sub' entry is read, which must
        provide the 'ids_inv' and 'names_inv' mappings.

    Returns:
      The entry of `region_map['sub']['names_inv']` selected by the value that
      `region_map['sub']['ids_inv']` holds for `id`.
    """
    sub_regions = region_map['sub']
    name_key = sub_regions['ids_inv'][id]
    return sub_regions['names_inv'][name_key]
def load_checkpoint(path):
    """Reads a pickled checkpoint and rebuilds the pieces needed for inference.

    Args:
      path: path to the checkpoint pickle.

    Returns:
      A 5-tuple `(model_config, region_map, alphabet, params, forward)`:
      the model config dictionary (the keyword arguments passed to `Model`),
      the region mapping stored in the checkpoint, a `GreekAlphabet` whose
      `idx2word` / `word2idx` are taken from the checkpoint, the checkpoint
      params after `jax.device_put`, and `model.apply` partially applied to
      those params.
    """
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # checkpoint files from a trusted source.
    with open(path, 'rb') as ckpt_file:
        ckpt = pickle.load(ckpt_file)

    # Rebuild the model with the constructor arguments saved at training time
    # (the "model_config" entry) and bind the params, giving a `forward`
    # callable of the form attribute() and restore() expect.
    params = jax.device_put(ckpt['params'])
    model_config = ckpt['model_config']
    forward = functools.partial(Model(**model_config).apply, params)

    # Mapping between region IDs and names.
    region_map = ckpt['region_map']

    # Only the vocabulary mapping comes from the checkpoint; everything else on
    # the alphabet (e.g. the padding symbol) is fixed by the class itself.
    vocab = ckpt['alphabet']
    alphabet = GreekAlphabet()
    alphabet.idx2word = vocab['idx2word']
    alphabet.word2idx = vocab['word2idx']

    return model_config, region_map, alphabet, params, forward
def main(text):
    """Runs restoration and geographical attribution on an input text.

    The input length is validated first. Then the checkpoint is loaded from
    'checkpoint.pkl', `inference.attribute` and `inference.restore` are run on
    `text`, and the restoration result is rendered as an HTML table in which
    the positions of '?' in the restoration's input text are highlighted.

    Args:
      text: the input text; must be between 50 and 750 characters long
        (inclusive).

    Returns:
      A tuple `(html, attrib_dict)`: the rendered HTML string showing the input
      text and at most the first three restoration predictions with their
      scores, and a dict mapping sub-region name to score for at most the
      first three attribution locations.

    Raises:
      ValueError: if `text` is shorter than 50 or longer than 750 characters.
    """
    # Validate before doing any expensive work. The original code referenced
    # the undefined names `app` and `input_text` here, so bad input surfaced as
    # a NameError instead of this message.
    if not 50 <= len(text) <= 750:
        raise ValueError(
            f'Text should be between 50 and 750 chars long, but the input was '
            f'{len(text)} characters')

    restore_template = jinja2.Template("""<!DOCTYPE html>
<html>
<head>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet">
<style>
body {
font-family: 'Roboto Mono', monospace;
font-weight: 400;
}
.container {
overflow-x: scroll;
scroll-behavior: smooth;
}
table {
table-layout: fixed;
font-size: 16px;
padding: 0;
white-space: nowrap;
}
table tr:first-child {
font-weight: bold;
}
table td {
border-bottom: 1px solid #ccc;
padding: 3px 0;
}
table td.header {
font-family: Roboto, Helvetica, sans-serif;
text-align: right;
position: -webkit-sticky;
position: sticky;
background-color: white;
}
.header-1 {
background-color: white;
width: 120px;
min-width: 120px;
max-width: 120px;
left: 0;
}
.header-2 {
left: 120px;
width: 50px;
max-width: 50px;
min-width: 50px;
padding-right: 5px;
}
table td:not(.header) {
border-left: 1px solid black;
padding-left: 5px;
}
.header-2col {
width: 170px;
min-width: 170px;
max-width: 170px;
left: 0;
padding-right: 5px;
}
.pred {
background: #ddd;
}
</style>
</head>
<body>
<div class="container">
<table cellspacing="0">
<tr>
<td colspan="2" class="header header-2col">Input text:</td>
<td>
{% for char in restoration_results.input_text -%}
{%- if loop.index0 in prediction_idx -%}
<span class="pred">{{char}}</span>
{%- else -%}
{{char}}
{%- endif -%}
{%- endfor %}
</td>
</tr>
<!-- Predictions: -->
{% for pred in restoration_results.predictions[:3] %}
<tr>
<td class="header header-1">Hypothesis {{ loop.index }}:</td>
<td class="header header-2">{{ "%.1f%%"|format(100 * pred.score) }}</td>
<td>
{% for char in pred.text -%}
{%- if loop.index0 in prediction_idx -%}
<span class="pred">{{char}}</span>
{%- else -%}
{{char}}
{%- endif -%}
{%- endfor %}
</td>
</tr>
{% endfor %}
</table>
</div>
<script>
document.querySelector('#btn').addEventListener('click', () => {
const pred = document.querySelector(".pred");
pred.scrollIntoViewIfNeeded();
});
</script>
</body>
</html>
""")

    # Load the checkpoint pickle and extract from it the pieces needed for
    # calling the attribute() and restore() functions.
    # NOTE(review): the checkpoint is re-read on every call; consider loading
    # it once at start-up if request latency matters.
    (model_config, region_map, alphabet, params,
     forward) = load_checkpoint('checkpoint.pkl')
    vocab_char_size = model_config['vocab_char_size']
    vocab_word_size = model_config['vocab_word_size']

    attribution = inference.attribute(
        text,
        forward=forward,
        params=params,
        alphabet=alphabet,
        region_map=region_map,
        vocab_char_size=vocab_char_size,
        vocab_word_size=vocab_word_size)

    restoration = inference.restore(
        text,
        forward=forward,
        params=params,
        alphabet=alphabet,
        vocab_char_size=vocab_char_size,
        vocab_word_size=vocab_word_size)

    # Character positions marked '?' in the input; the template highlights
    # these positions in the input row and in every hypothesis row.
    prediction_idx = {
        i for i, c in enumerate(restoration.input_text) if c == '?'
    }

    # Keep only the first three attribution locations, keyed by sub-region
    # name.
    attrib_dict = {
        get_subregion_name(loc.location_id, region_map): loc.score
        for loc in attribution.locations[:3]
    }

    html = restore_template.render(
        restoration_results=restoration,
        prediction_idx=prediction_idx)
    return html, attrib_dict
# Every line of the examples file (line endings included) becomes one example
# input offered in the demo UI.
with open('example_input.txt', encoding='utf8') as examples_file:
    examples = examples_file.readlines()

# Build the demo around `main`: one text input, and two outputs matching the
# (HTML, attribution dict) tuple that `main` returns. Then start serving.
demo = gradio.Interface(
    main,
    inputs="text",
    outputs=["html", gradio.outputs.Label(label="Geographical Attribution")],
    examples=examples,
    description='spaces demo for Ithaca')
demo.launch(enable_queue=True)