# Hugging Face Space source by reichenbach (Rishav Chandra Varma).
# Commit 1dd4fe3: "Example changes" (file size 2.34 kB).
import functools
import json
import os

# NOTE(review): installing packages at import time is slow and unpinned;
# prefer declaring tensorflow in requirements.txt.
os.system('pip install tensorflow')

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub.keras_mixin import from_pretrained_keras
from tensorflow import keras
class CustomNonPaddingTokenLoss(keras.losses.Loss):
    """Sparse categorical cross-entropy averaged over labelled tokens only.

    Positions whose label is not > 0 (label id 0 — presumably the padding
    tag; confirm against mapping.json) are excluded from both the summed
    loss and the count it is divided by, so padding does not dilute the
    average.
    """

    def __init__(self, name="custom_ner_loss", **kwargs):
        # **kwargs (e.g. ``reduction``) lets Keras rebuild this loss from the
        # config dict that ``Loss.get_config()`` produces.
        super().__init__(name=name, **kwargs)
        # Built once instead of on every call. Per-token (unreduced) losses
        # are required so padding can be masked before averaging; "none" is
        # the string value of the reduction enum and is accepted by both
        # tf.keras (Keras 2) and Keras 3.
        self._per_token_loss = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction="none"
        )

    def call(self, y_true, y_pred):
        """Return the mean cross-entropy over positions with ``y_true > 0``.

        Args:
            y_true: Integer class ids.
            y_pred: Unnormalised logits (``from_logits=True``).

        Returns:
            Scalar loss; 0 when every position is masked out.
        """
        loss = self._per_token_loss(y_true, y_pred)
        # Cast to the loss dtype so the product also works under mixed
        # precision (e.g. float16 losses).
        mask = tf.cast(y_true > 0, dtype=loss.dtype)
        loss = loss * mask
        # divide_no_nan yields 0 instead of NaN for an all-padding batch.
        return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))
def lowercase_and_convert_to_ids(tokens):
    """Lower-case *tokens* and map each one to its vocabulary id.

    The id lookup uses the module-level ``lookup_layer`` (a
    ``keras.layers.StringLookup`` built from ``vocab.json``). It must be
    defined before this function is called.
    """
    return lookup_layer(tf.strings.lower(tokens))
def tokenize_and_convert_to_ids(text):
    """Split *text* on whitespace and return the vocabulary ids of the words.

    Lower-casing and the id lookup are delegated to
    ``lowercase_and_convert_to_ids``.
    """
    words = text.split()
    ids = lowercase_and_convert_to_ids(words)
    return ids
@functools.lru_cache(maxsize=None)
def _load_tag_mapping():
    """Read the class-id -> tag mapping from ``mapping.json`` once and cache it."""
    with open("mapping.json", "r", encoding="utf-8") as f:
        return json.load(f)


@functools.lru_cache(maxsize=None)
def _load_ner_model():
    """Load the pretrained NER model from the Hugging Face Hub once and cache it."""
    return from_pretrained_keras(
        "keras-io/ner-with-transformers",
        custom_objects={"CustomNonPaddingTokenLoss": CustomNonPaddingTokenLoss},
        compile=False,
    )


def ner_tagging(text_1):
    """Predict a NER tag for each whitespace-separated token of *text_1*.

    The tag mapping and the model are loaded on the first call and reused
    afterwards, instead of being reloaded for every request.

    Args:
        text_1: Raw text from the Gradio input textbox.

    Returns:
        A list of tag strings, taken from ``mapping.json`` by predicted
        class id (presumably one per input token). Returns ``[]`` when
        *text_1* is empty or whitespace-only.
    """
    # Nothing to tag: avoid loading the model or running inference.
    if not text_1 or not text_1.split():
        return []
    mapping = _load_tag_mapping()
    ner_model = _load_ner_model()
    sample_input = tokenize_and_convert_to_ids(text_1)
    # Add a batch dimension of 1: a single sentence is tagged per call.
    sample_input = tf.reshape(sample_input, shape=[1, -1])
    output = ner_model.predict(sample_input)
    # argmax over the last axis picks a class id per position; [0] drops the
    # batch dimension. mapping.json keys are strings, hence str(i).
    prediction = np.argmax(output, axis=-1)[0]
    return [mapping[str(i)] for i in prediction]
# Gradio input/output components.
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces
# (deprecated in Gradio 3, removed in Gradio 4); switch to gr.Textbox if the
# Space's Gradio version is upgraded.
text_1 = gr.inputs.Textbox(lines=5)
ner_tag = gr.outputs.Textbox()

# Vocabulary for the module-level StringLookup used by
# lowercase_and_convert_to_ids. Read explicitly as UTF-8 so non-ASCII tokens
# decode the same way regardless of the platform's default encoding.
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
lookup_layer = keras.layers.StringLookup(vocabulary=vocab['tokens'])

iface = gr.Interface(
    ner_tagging,
    inputs=text_1,
    outputs=ner_tag,
    examples=[
        ['EU rejects German call to boycott British lamb .'],
        ["He said further scientific study was required and if it was found that action was needed it should be taken by the European Union ."],
    ],
    title="Named Entity Recognition with Transformers",
    description="Named Entity Recognition with Transformers on CoNLL2003 Dataset",
    article="Author: <a href=\"https://huggingface.co/reichenbach\">Rishav Chandra Varma</a>",
)

# Only start the web server when run as a script, so the module can be
# imported (e.g. for testing) without launching the UI.
if __name__ == "__main__":
    iface.launch()