File size: 3,678 Bytes
b585c7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""
Loader that uses Pix2Struct models to caption images (extract text from them).

Based upon ImageCaptionLoader in LangChain (langchain/document_loaders/image_captions.py),
but keeps the loaded model on the loader instance so it can be reused across calls,
avoiding repeated model-load slowness and CUDA forking issues.
"""
from typing import List, Union, Any, Tuple
from langchain.docstore.document import Document
from langchain.document_loaders import ImageCaptionLoader
from utils import get_device, clear_torch_cache
from PIL import Image
class H2OPix2StructLoader(ImageCaptionLoader):
    """Loader that extracts text from images using a Pix2Struct model.

    The processor and model are loaded lazily on the first ``load`` (or
    explicitly via ``load_model``). They are then cached on the instance so
    later calls reuse them instead of reloading from ``model_type``.
    """

    def __init__(self, path_images: Union[str, List[str], None] = None,
                 model_type="google/pix2struct-textcaps-base",
                 max_new_tokens=50):
        """
        :param path_images: a single image path or a list of image paths; may
            also be supplied later through ``set_image_paths``.
        :param model_type: model identifier passed to ``from_pretrained`` for
            both the processor and the model.
        :param max_new_tokens: maximum number of tokens generated per image.
        """
        super().__init__(path_images)
        self._pix2struct_model = None
        # Initialised here so the attribute always exists, even before load_model().
        self._pix2struct_processor = None
        self._model_type = model_type
        self._max_new_tokens = max_new_tokens

    def set_context(self):
        """Choose the inference device and store it in ``self.device``.

        Uses ``'cuda'`` only when ``get_device()`` reports CUDA and torch sees
        at least one GPU; otherwise falls back to ``'cpu'``.
        """
        if get_device() == 'cuda':
            import torch
            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = 'cuda'
            else:
                self.device = 'cpu'
        else:
            self.device = 'cpu'

    def load_model(self):
        """Load the Pix2Struct processor and model, or re-place an existing model.

        If a model is already cached it is only moved to ``self.device``, for
        example back onto the GPU after ``unload_model``. Otherwise the device
        is chosen via ``set_context``, and the processor and model are loaded
        from ``self._model_type``.

        :return: ``self``, to allow chaining.
        :raises ValueError: if ``transformers`` cannot be imported.
        """
        try:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
        except ImportError as e:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from e
        if self._pix2struct_model is not None:
            self._pix2struct_model = self._pix2struct_model.to(self.device)
            return self
        self.set_context()
        self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
        self._pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
            self._model_type).to(self.device)
        return self

    def unload_model(self):
        """Move the cached model to CPU and call ``clear_torch_cache()``.

        The model stays cached on the instance. The next ``load`` (or
        ``load_model``) moves it back to ``self.device``.
        """
        if hasattr(self._pix2struct_model, 'cpu'):
            self._pix2struct_model.cpu()
        clear_torch_cache()

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Set the image file(s) to caption: a single path or a list of paths.
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every image in ``self.image_paths``.

        :param prompt: unused; accepted for call compatibility.
        :return: one ``Document`` per image, with the generated text as
            ``page_content`` and ``{"image_path": <path>}`` as metadata.
        """
        # Always go through load_model(). It loads on first use; otherwise it
        # only makes sure the model is on self.device. After unload_model() the
        # model sits on CPU while inputs are still sent to self.device, and
        # skipping this step would cause a device mismatch.
        self.load_model()
        results = []
        for path_image in self.image_paths:
            caption, metadata = self._get_captions_and_metadata(
                processor=self._pix2struct_processor, model=self._pix2struct_model, path_image=path_image
            )
            results.append(Document(page_content=caption, metadata=metadata))
        return results

    def _get_captions_and_metadata(
            self, processor: Any, model: Any, path_image: str) -> Tuple[str, dict]:
        """Caption a single image.

        :param processor: Pix2Struct processor used to prepare the inputs and
            decode the generated ids.
        :param model: Pix2Struct model used for generation.
        :param path_image: path of the image file to caption.
        :return: ``(generated_text, {"image_path": path_image})``.
        :raises ValueError: if the image cannot be opened.
        """
        try:
            image = Image.open(path_image)
        except Exception as e:
            raise ValueError(f"Could not get image data for {path_image}") from e
        # PIL opens images lazily and keeps the file open; close it once the
        # processor has turned the pixels into tensors.
        with image:
            inputs = processor(images=image, return_tensors="pt")
        inputs = inputs.to(self.device)
        generated_ids = model.generate(**inputs, max_new_tokens=self._max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        metadata: dict = {"image_path": path_image}
        return generated_text, metadata
|