import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces

# Example images and texts
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],
    ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]

MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"

# CPU/GPU device: on ZeroGPU this tensor lives on CPU at import time and is
# moved to the GPU inside @spaces.GPU-decorated calls.
zero = torch.Tensor([0]).cuda()

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)


@spaces.GPU
def process(text: str) -> str:
    """Generate the correction for a list of ingredients using the tokenizer and the causal model."""
    prompt = prepare_instruction(text)
    input_ids = tokenizer(
        prompt, add_special_tokens=True, return_tensors="pt"
    ).input_ids
    output = model.generate(
        input_ids.to(zero.device),  # GPU
        do_sample=False,
        max_new_tokens=512,
    )
    # Decode the full sequence, then strip the prompt to keep only the correction.
    return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt):].strip()


def prepare_instruction(text: str) -> str:
    """Prepare the instruction prompt for fine-tuning and inference.

    Args:
        text (str): List of ingredients.

    Returns:
        str: Instruction prompt.
    """
    instruction = (
        "###Correct the list of ingredients:\n" + text + "\n\n###Correction:\n"
    )
    return instruction


##########################
# GRADIO SETUP
##########################

# Creating the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Ingredients Spellcheck")
    gr.Markdown("")
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="image_input")
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value="Spellcheck")
        with gr.Column():
            correction = gr.Textbox(label="Correction", interactive=False)
    with gr.Row():
        # Clicking an example only fills the inputs (run_on_click=False);
        # the correction is produced by the Spellcheck button.
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
            outputs=[correction],
            run_on_click=False,
        )
    spellcheck_button.click(
        fn=process, inputs=[ingredients], outputs=[correction]
    )


if __name__ == "__main__":
    # Launch the demo
    demo.launch()