import streamlit as st
import torch
from PIL import Image
import gc
import tempfile
import os
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from byaldi import RAGMultiModalModel
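
# Streamlit demo: the uploaded image is indexed with Byaldi (a wrapper around
# the ColPali visual retriever), any text surfaced for it is shown, and the
# user can then ask a question that Qwen2-VL answers from the image plus the
# extracted text. Both models are pinned to CPU, so no GPU is required.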

# Function to load Byaldi model
@st.cache_resource
def load_byaldi_model():
    model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device="cpu")
    return model

# Function to load Qwen2-VL model
@st.cache_resource
def load_qwen_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float32, device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    return model, processor
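
# Note: Qwen2-VL-7B-Instruct in float32 needs on the order of 30 GB of RAM on
# CPU. If memory is tight, torch_dtype=torch.bfloat16 or a quantized checkpoint
# is an option (assuming the local PyTorch build supports it).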

# Function to free memory between model loads (the app runs on CPU, but the
# CUDA cache is also cleared in case a GPU happens to be available)
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Streamlit Interface
st.title("OCR and Visual Language Model Demo")
st.write("Upload an image for OCR extraction and then ask a question about the image.")

# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # OCR Extraction with Byaldi
    st.write("Extracting text from image...")
    byaldi_model = load_byaldi_model()
    
    # Save the image to a temporary file (convert to RGB first so PNG uploads
    # with an alpha channel can still be written as JPEG)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        img.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Create a temporary index for the uploaded image
    with st.spinner("Processing image..."):
        byaldi_model.index(temp_file_path, index_name="temp_index", overwrite=True)
    
    # Query the temporary index for the page content. Note that ColPali is a
    # visual retriever rather than an OCR engine, so the "ocr_text" metadata
    # key is only populated if such metadata was attached at indexing time;
    # otherwise the fallback string is shown.
    ocr_results = byaldi_model.search("Extract all text from the image", k=1)

    # Extract the OCR text from the top result, if any
    if ocr_results and ocr_results[0].metadata:
        extracted_text = ocr_results[0].metadata.get("ocr_text", "No text extracted")
    else:
        extracted_text = "No text extracted"
    
    st.write("Extracted Text:")
    st.write(extracted_text)
    
    # Clear Byaldi model from memory
    del byaldi_model
    clear_memory()

    # Remove the temporary file
    os.unlink(temp_file_path)

    # Text input field for question
    question = st.text_input("Ask a question about the image and extracted text")

    if question:
        st.write("Processing with Qwen2-VL...")
        qwen_model, qwen_processor = load_qwen_model()

        # Prepare inputs for Qwen2-VL
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": f"Extracted text: {extracted_text}\n\nQuestion: {question}"},
                ],
            }
        ]
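
        # Qwen2-VL expects this chat-style message structure: apply_chat_template
        # renders it into the text prompt, while process_vision_info extracts the
        # image(s) referenced in the "content" entries.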

        # Prepare for inference
        text_input = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = qwen_processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")

        # Move tensors to CPU
        inputs = inputs.to("cpu")

        # Run the model and generate output
        with torch.no_grad():
            generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)

        # Strip the prompt tokens so only the newly generated answer is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        generated_text = qwen_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0]

        # Display the response
        st.write("Model's response:", generated_text)

        # Clear Qwen model from memory
        del qwen_model, qwen_processor
        clear_memory()