Rick7799 committed
Commit 491e665
1 parent: 3dad239

Update app.py

Files changed (1)
  1. app.py +9 -27
app.py CHANGED
@@ -1,35 +1,17 @@
 import gradio as gr
-import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-from PIL import Image
+from byaldi import RAGMultiModalModel  # Importing the ColPali model
 
-# Load the ColPali model and tokenizer from Hugging Face
-model_name = "vidore/colpali-v1.2" # Use the correct model identifier
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+# Initialize the ColPali model
+model = RAGMultiModalModel.from_pretrained("vidore/colpali")
 
 def extract_and_search(image, keyword):
-    try:
-        # Convert image to RGB if it's not already in that format
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
+    # Use the model to extract text from the image
+    extracted_text = model.predict(image)  # Replace with actual prediction method
 
-        # Preprocess image: convert to tensor format required by the model
-        inputs = tokenizer(images=image, return_tensors="pt") # Adjust as necessary for your input requirements
-
-        # Extract text from image using ColPali model
-        with torch.no_grad(): # Disable gradient calculation for inference
-            outputs = model.generate(**inputs)
-
-        # Decode outputs to text
-        extracted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Perform keyword search
-        matching_lines = [line for line in extracted_text.splitlines() if keyword.lower() in line.lower()]
-
-        return extracted_text, matching_lines
-    except Exception as e:
-        return f"Error during extraction: {str(e)}", []
+    # Perform keyword search
+    matching_lines = [line for line in extracted_text.splitlines() if keyword.lower() in line.lower()]
+
+    return extracted_text, matching_lines
 
 # Create Gradio interface
 interface = gr.Interface(
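
Note: byaldi's RAGMultiModalModel wraps ColPali as a visual document retrieval model, so to my knowledge it does not expose a predict() method that returns OCR-style text; the committed model.predict(image) call is marked as a placeholder by its own comment. Below is a minimal sketch of the usual byaldi flow, assuming an index()/search() API and using illustrative paths, index names, and queries rather than anything from the commit.

from byaldi import RAGMultiModalModel

# Load the ColPali checkpoint through byaldi (same call as in the commit).
model = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Build an index over a folder of documents/images.
# input_path and index_name are illustrative values, not from the commit.
model.index(
    input_path="docs/",
    index_name="demo_index",
    store_collection_with_index=False,
    overwrite=True,
)

# Retrieve the pages most relevant to a keyword, instead of
# string-matching lines of OCR output as the committed function does.
results = model.search("invoice total", k=3)
for result in results:
    # Result fields shown here (doc_id, page_num, score) should be
    # verified against the installed byaldi version.
    print(result.doc_id, result.page_num, result.score)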