|
import streamlit as st |
|
from PIL import Image |
|
import inference |
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
from PIL import Image |
|
import requests |
|
import copy |
|
import os |
|
from unittest.mock import patch |
|
from transformers.dynamic_module_utils import get_imports |
|
import torch |
|
|
|
|
|
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around Florence-2's hard `flash_attn` import requirement.

    transformers' dynamic-module loader scans the downloaded modeling file
    for its imports and verifies each one is installed.  Florence-2's
    `modeling_florence2.py` imports `flash_attn`, which is not installed
    here (the model is loaded with the "sdpa" attention implementation
    instead), so that entry is stripped from the reported list.

    Args:
        filename: Path of the dynamic module being inspected.

    Returns:
        The import list from ``get_imports``, with ``"flash_attn"`` removed
        when inspecting the Florence-2 modeling file.
    """
    imports = get_imports(filename)
    # Guard the removal: list.remove raises ValueError when the item is
    # absent, which would crash model loading if upstream ever drops the
    # flash_attn import.
    if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
|
|
|
|
|
# Streamlit re-executes this script top-to-bottom on every interaction, so
# only seed the flag when it is missing from the session.
if 'model_loaded' not in st.session_state:
    st.session_state['model_loaded'] = False
|
|
|
|
|
def load_model():
    """Load Florence-2-large, dynamically quantize it, and cache it in session state.

    Downloads the processor and model from the Hugging Face Hub (with remote
    code enabled), patching transformers' dynamic-module import scanner so the
    model's hard `flash_attn` dependency is ignored — the "sdpa" attention
    implementation is used instead.  The float model's Linear layers are then
    dynamically quantized to int8 to reduce memory, and the results are cached
    on `st.session_state` (`processor`, `model`) with `model_loaded` set True.
    """
    model_id = "microsoft/Florence-2-large"

    # NOTE: a torch_dtype=torch.qint8 argument was removed here — processors
    # do not accept a torch_dtype, and qint8 is not a loadable weight dtype,
    # so the argument was at best silently ignored.
    st.session_state.processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True
    )

    # Patch the import scan so the missing `flash_attn` package does not
    # abort loading the remote modeling code.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation="sdpa", trust_remote_code=True
        )

    # Dynamic quantization: Linear weights stored as int8, activations
    # quantized on the fly at inference time.
    Qmodel = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    del model  # release the float copy before caching the quantized model

    st.session_state.model = Qmodel
    st.session_state.model_loaded = True
    st.write("model loaded complete")
|
|
|
# First run of this session: download + quantize the model behind a spinner.
if not st.session_state['model_loaded']:
    with st.spinner('Loading model...'):
        load_model()

# One-shot flag: the generate action below runs at most once per session.
st.session_state.setdefault('has_run', False)
|
|
|
|
|
# Page title.
st.markdown('<h3><center><b>VQA</b></center></h3>', unsafe_allow_html=True)

# Sidebar: image upload; preview the image in the main pane when provided.
upload = st.sidebar.file_uploader("Upload your image here", type=["jpg", "jpeg", "png"])
if upload is not None:
    preview = Image.open(upload)
    st.image(preview, caption="Uploaded Image", use_column_width=True)

# Sidebar: task prompt and free-form question input.
prompt = st.sidebar.text_input("Task Prompt", value="Describe the image in detail:")
questions = st.sidebar.text_area("Input Questions", height=20)

# Render the button unconditionally, then gate the action on the one-shot
# has_run flag so it fires only once per session.
clicked = st.sidebar.button("Generate Caption", key="Generate")
if clicked and not st.session_state.has_run:
    st.session_state.has_run = True
    st.write(prompt, "\n\n", questions)
    # NOTE(review): inference.demo() is called with no arguments — presumably
    # it picks up its inputs elsewhere; verify it actually consumes the
    # uploaded image and prompt text.
    inference.demo()
|
|
|
|
|
|