"""Streamlit app: document VQA with a fine-tuned Idefics2-8B model.

Upload an image, enter a prompt, and the model answers; "Refresh" resets
the UI and frees cached GPU memory.
"""
import streamlit as st
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

# Fall back to CPU when no GPU is present so the app still starts
# (the original hard-coded "cuda:0" and crashed on CPU-only hosts).
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

st.set_page_config(
    page_title="idefics",
    page_icon="🌀",
    layout="wide",
    initial_sidebar_state="auto",
)


@st.cache_resource
def load_my_model():
    """Load the Idefics2 processor and model once; cached across reruns.

    Returns:
        (processor, model): the HF processor and the Vision2Seq model.
        device_map="auto" lets accelerate choose weight placement, so
        callers must not assume the model sits on cuda:0.
    """
    model_name_or_path = "SalmanFaroz/idefics2-8b-DocVQA-SP"
    # model_name_or_path = "HuggingFaceM4/idefics2-8b"  # base checkpoint
    processor = AutoProcessor.from_pretrained(
        model_name_or_path, do_image_splitting=True
    )
    model = AutoModelForVision2Seq.from_pretrained(
        model_name_or_path, device_map="auto"
    )
    return processor, model


processor, model = load_my_model()


def generate_text(image, prompt):
    """Run the VQA model on *image* with *prompt* and return the answer text.

    Falls back to the full decoded output when the expected
    "User: ...\\nAssistant: <answer>" layout is absent, instead of raising
    IndexError on unexpected generations (bug in the original parsing).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=chat_prompt, images=[image], return_tensors="pt")
    # Move inputs to wherever accelerate actually placed the model, not to a
    # hard-coded cuda:0 — the two can differ under device_map="auto".
    target_device = getattr(model, "device", DEVICE)
    inputs = {k: v.to(target_device) for k, v in inputs.items()}

    generated_ids = model.generate(**inputs, max_new_tokens=500)
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    full_text = generated_texts[0]
    lines = full_text.split("\n")
    # Expected layout: line 0 = echoed user turn, line 1 = "Assistant: <answer>".
    # maxsplit=1 keeps answers that themselves contain ": " intact.
    if len(lines) > 1 and ": " in lines[1]:
        return lines[1].split(": ", 1)[1]
    return full_text


st.title("Idefics-2 Fine-tuned Model")

st.header("Upload Image")
# Explicit keys let the Refresh button clear these widgets via session_state.
uploaded_image = st.file_uploader("", type=["jpg", "jpeg", "png"], key="uploaded_image")

st.header("Prompt")
prompt = st.text_input("", key="prompt")

if st.button("Generate Text"):
    if uploaded_image is None:
        st.error("Please upload an image.")
    elif prompt == "":
        st.error("Please enter prompt text.")
    else:
        image = Image.open(uploaded_image)
        generated_text = generate_text(image, prompt)
        st.write("Prompt:\n\n", prompt)
        st.write("Generated Text:")
        st.write(generated_text)

if st.button("Refresh"):
    torch.cuda.empty_cache()
    # Rebinding local variables (the original approach) has no effect in
    # Streamlit's rerun model; clearing the widgets' session_state entries
    # and rerunning is the supported way to reset the UI.
    # NOTE(review): st.rerun() requires Streamlit >= 1.27; use
    # st.experimental_rerun() on older versions.
    st.session_state.pop("uploaded_image", None)
    st.session_state.pop("prompt", None)
    st.rerun()

if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image")

torch.cuda.empty_cache()