# Hugging Face Space status: "Running on Zero" (ZeroGPU hardware).
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces

# Example (image, ingredient text) pairs shown in the UI.
# NOTE: the texts are deliberately noisy (e.g. "léci - thine", "natutel",
# "Aprigine") — they look like raw OCR output and are the spellchecker's
# *inputs*. Do not "correct" them.
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],
    ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]

# Fine-tuned causal LM that performs the ingredient spellcheck.
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"

# ZeroGPU idiom: a GPU is only attached while a @spaces.GPU-decorated
# function runs. Generation code reads `zero.device` to move its inputs onto
# the GPU. (Per the ZeroGPU docs this reports the CUDA device only inside
# such functions — confirm if the hardware changes.)
zero = torch.Tensor([0]).cuda()

# Tokenizer: reuse EOS as the padding token (token and id kept in sync).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model: bfloat16 weights with FlashAttention-2 (needs the `flash-attn`
# package and a supported GPU). Device placement is delegated to
# `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU function runs;
# without the decorator, generation would not get a GPU.
@spaces.GPU
def process(text: str) -> str:
    """Generate the spellchecked version of an ingredient list.

    Builds the instruction prompt with ``prepare_instruction``, runs greedy
    decoding with the module-level ``model`` and ``tokenizer``, and returns
    only the newly generated text (the echoed prompt is dropped).

    Args:
        text: Ingredient list to correct (possibly noisy, e.g. OCR output).

    Returns:
        The model's correction with surrounding whitespace stripped. For
        empty or blank input, returns "" without calling the model.
    """
    # Nothing to correct: skip the GPU-bound generation entirely.
    if not text or not text.strip():
        return ""

    prompt = prepare_instruction(text)
    # Keep the full BatchEncoding so the attention mask is passed too:
    # pad == eos here, so generate() cannot reliably infer the mask itself.
    inputs = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    ).to(zero.device)  # GPU
    output = model.generate(
        **inputs,
        do_sample=False,  # greedy, deterministic decoding
        max_new_tokens=512,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Drop the prompt by token count, not by character count. Decoding is
    # not guaranteed to reproduce the prompt text exactly, so slicing the
    # decoded string with len(prompt) can clip or leak characters.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = output[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def prepare_instruction(text: str) -> str:
    """Build the instruction prompt used for fine-tuning and inference.

    The prompt wraps the ingredient text between a "###Correct the list of
    ingredients:" header and a trailing "###Correction:" marker, after which
    the model is expected to write the corrected list.

    Args:
        text: Ingredient list to be corrected.

    Returns:
        The complete instruction prompt.
    """
    return f"###Correct the list of ingredients:\n{text}\n\n###Correction:\n"
##########################
# GRADIO SETUP
##########################
# Build the Gradio interface: image and ingredient text on the left,
# the model's correction on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Ingredients Spellcheck")
    with gr.Row():
        with gr.Column():
            # Display-only: the image is never passed to `process`.
            image = gr.Image(type="pil", label="image_input")
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value='Spellcheck')
        with gr.Column():
            correction = gr.Textbox(label="Correction", interactive=False)
    with gr.Row():
        # Clicking an example only fills `image` and `ingredients`; the user
        # then presses "Spellcheck". No fn/outputs here: `process` takes the
        # text alone, and with fn set Gradio may cache examples by calling
        # fn(image, text), which would raise a TypeError.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[image, ingredients],
        )
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )
if __name__ == "__main__":
    # Start the Gradio server for the interface defined above.
    demo.launch()