import os
import base64
import fitz  # PyMuPDF
from io import BytesIO
from PIL import Image
import requests
from llama_index.llms.nvidia import NVIDIA
from dotenv import load_dotenv
load_dotenv()
def set_environment_variables():
    """Ensure the NVIDIA API key is available in the environment."""
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    os.environ["NVIDIA_API_KEY"] = api_key
def get_b64_image_from_content(image_content):
"""Convert image content to base64 encoded string."""
img = Image.open(BytesIO(image_content))
if img.mode != 'RGB':
img = img.convert('RGB')
buffered = BytesIO()
img.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def is_graph(image_content):
"""Determine if an image is a graph, plot, chart, or table."""
res = describe_image(image_content)
return any(keyword in res.lower() for keyword in ["graph", "plot", "chart", "table"])
def process_graph(image_content):
    """Process a graph image and generate a plain-English description."""
    deplot_description = process_graph_deplot(image_content)
    llm = NVIDIA(model="meta/llama-3.1-70b-instruct")
    prompt = (
        "Your responsibility is to explain charts. You are an expert in translating "
        "linearized tables into plain English text for LLMs to use. "
        "Explain the following linearized table. " + deplot_description
    )
    response = llm.complete(prompt)
    return response.text
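# Illustrative usage sketch (requires NVIDIA_API_KEY and network access;
# "chart.png" is a placeholder filename, not part of this module):
#   with open("chart.png", "rb") as f:
#       content = f.read()
#   if is_graph(content):
#       print(process_graph(content))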
def describe_image(image_content):
"""Generate a description of an image using NVIDIA API."""
image_b64 = get_b64_image_from_content(image_content)
invoke_url = "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b"
api_key = os.getenv("NVIDIA_API_KEY")
if not api_key:
raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
headers = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json"
}
payload = {
"messages": [
{
"role": "user",
"content": f"""
Describe what you see in this image:
Also include:
1. Visible text extraction discovering names and description of products(can use ocr).
2. Inferred location or scene type in the image.
4. Date/time information and its location.
"""
}
],
"max_tokens": 1024,
"temperature": 0.20,
"top_p": 0.70,
"seed": 0,
"stream": False
}
    response = requests.post(invoke_url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def process_graph_deplot(image_content):
"""Process a graph image using NVIDIA's Deplot API."""
invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"
image_b64 = get_b64_image_from_content(image_content)
api_key = os.getenv("NVIDIA_API_KEY")
if not api_key:
raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
headers = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json"
}
payload = {
"messages": [
{
"role": "user",
"content": f'Generate underlying data table of the figure below: '
}
],
"max_tokens": 1024,
"temperature": 0.20,
"top_p": 0.20,
"stream": False
}
    response = requests.post(invoke_url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
"""Extract text above and below a given bounding box on a page."""
before_text, after_text = "", ""
vertical_threshold_distance = page_height * threshold_percentage
horizontal_threshold_distance = bbox.width * threshold_percentage
for block in text_blocks:
block_bbox = fitz.Rect(block[:4])
vertical_distance = min(abs(block_bbox.y1 - bbox.y0), abs(block_bbox.y0 - bbox.y1))
        # Signed horizontal overlap: positive when the block and bbox overlap horizontally,
        # negative by the size of the gap otherwise, so near-misses within the threshold still qualify.
        horizontal_overlap = min(block_bbox.x1, bbox.x1) - max(block_bbox.x0, bbox.x0)
        if vertical_distance <= vertical_threshold_distance and horizontal_overlap >= -horizontal_threshold_distance:
if block_bbox.y1 < bbox.y0 and not before_text:
before_text = block[4]
elif block_bbox.y0 > bbox.y1 and not after_text:
after_text = block[4]
break
return before_text, after_text
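# Illustrative usage sketch (assumes a local "sample.pdf"; the bbox would normally
# come from an image or table detected on the page):
#   page = fitz.open("sample.pdf")[0]
#   blocks = page.get_text("blocks")
#   before, after = extract_text_around_item(blocks, fitz.Rect(100, 200, 300, 400), page.rect.height)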
def process_text_blocks(text_blocks, char_count_threshold=500):
"""Group text blocks based on a character count threshold."""
current_group = []
grouped_blocks = []
current_char_count = 0
for block in text_blocks:
if block[-1] == 0: # Check if the block is of text type
block_text = block[4]
block_char_count = len(block_text)
if current_char_count + block_char_count <= char_count_threshold:
current_group.append(block)
current_char_count += block_char_count
else:
if current_group:
grouped_content = "\n".join([b[4] for b in current_group])
grouped_blocks.append((current_group[0], grouped_content))
current_group = [block]
current_char_count = block_char_count
# Append the last group
if current_group:
grouped_content = "\n".join([b[4] for b in current_group])
grouped_blocks.append((current_group[0], grouped_content))
return grouped_blocks
def save_uploaded_file(uploaded_file):
"""Save an uploaded file to a temporary directory."""
temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
os.makedirs(temp_dir, exist_ok=True)
temp_file_path = os.path.join(temp_dir, uploaded_file.name)
with open(temp_file_path, "wb") as temp_file:
temp_file.write(uploaded_file.read())
return temp_file_path
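if __name__ == "__main__":
    # Minimal offline smoke test (illustrative): group the text blocks of the first page
    # of a local PDF. "sample.pdf" is a placeholder; the NVIDIA API helpers above are
    # intentionally not exercised here to avoid network calls.
    doc = fitz.open("sample.pdf")
    blocks = doc[0].get_text("blocks")
    for first_block, content in process_text_blocks(blocks):
        print(f"--- group starting at y={first_block[1]:.0f} ---")
        print(content[:200])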