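"""Streamlit demo: index an uploaded image with Byaldi (ColPali), surface
any extracted text, then answer questions about the image with Qwen2-VL,
running entirely on CPU."""
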
import streamlit as st
import torch
from PIL import Image
import gc
import tempfile
import os
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from byaldi import RAGMultiModalModel
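
# Both loaders below are cached with st.cache_resource so the heavyweight
# model weights are loaded once per server process and reused across
# Streamlit reruns.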
# Function to load Byaldi model
@st.cache_resource
def load_byaldi_model():
    model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device="cpu")
    return model


# Function to load Qwen2-VL model
@st.cache_resource
def load_qwen_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float32, device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    return model, processor
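
# Note: the 7B model in float32 costs about 4 bytes per parameter (roughly
# 30 GB of RAM); switching torch_dtype to torch.bfloat16 halves that where
# the hardware supports it.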

# Function to free memory between stages; the CUDA call is a no-op on this
# CPU-only setup but harmless if a GPU is present
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Streamlit Interface
st.title("OCR and Visual Language Model Demo")
st.write("Upload an image for OCR extraction and then ask a question about the image.")
# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
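
# st.file_uploader returns a file-like UploadedFile (or None before an
# upload), which Image.open can read directly.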
if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # OCR Extraction with Byaldi
    st.write("Extracting text from image...")
    byaldi_model = load_byaldi_model()

    # Save the image to a temporary file; convert to RGB first so PNGs
    # with an alpha channel can be written as JPEG
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        img.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name
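
    # The file is created with delete=False so it survives the with-block
    # and Byaldi can read it by path; os.unlink removes it later.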

    # Create a temporary index for the uploaded image
    with st.spinner("Processing image..."):
        byaldi_model.index(temp_file_path, index_name="temp_index", overwrite=True)

    # Search the index with a text prompt and take the top-ranked page
    ocr_results = byaldi_model.search("Extract all text from the image", k=1)
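
    # ColPali (via Byaldi) is a visual retrieval model rather than an OCR
    # engine: each search hit exposes retrieval fields (e.g. doc_id,
    # page_num, score, metadata), not recognized text.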

    # Extract the OCR text from the results, guarding against missing metadata
    if ocr_results:
        extracted_text = (ocr_results[0].metadata or {}).get("ocr_text", "No text extracted")
    else:
        extracted_text = "No text extracted"

    st.write("Extracted Text:")
    st.write(extracted_text)

    # Clear Byaldi model from memory
    del byaldi_model
    clear_memory()

    # Remove the temporary file
    os.unlink(temp_file_path)
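
    # Streamlit reruns this script top-to-bottom on every interaction, so
    # the index is rebuilt whenever the widgets below change; note also that
    # st.cache_resource keeps its own reference to the model, so the del
    # above only drops the local name.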

    # Text input field for question
    question = st.text_input("Ask a question about the image and extracted text")

    if question:
        st.write("Processing with Qwen2-VL...")
        qwen_model, qwen_processor = load_qwen_model()

        # Prepare inputs for Qwen2-VL
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": f"Extracted text: {extracted_text}\n\nQuestion: {question}"},
                ],
            }
        ]
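
        # qwen_vl_utils.process_vision_info (used below) walks these chat
        # messages and extracts the image/video inputs in the format the
        # Qwen2-VL processor expects.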

        # Prepare for inference
        text_input = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = qwen_processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")

        # Keep tensors on CPU
        inputs = inputs.to("cpu")

        # Run the model and generate output
        with torch.no_grad():
            generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
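
        # max_new_tokens=128 caps the answer length; decoding otherwise
        # follows the model's default generation config.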

        # The generate output includes the prompt tokens, so trim them
        # before decoding and keep only the newly generated answer
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        generated_text = qwen_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0]

        # Display the response
        st.write("Model's response:", generated_text)

        # Clear Qwen model from memory
        del qwen_model, qwen_processor
        clear_memory()