# mmE5-mllama-11b-instruct / custom_st.py
from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoProcessor, MllamaForConditionalGeneration


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Initialize the multimodal processor (tokenizer + image processor)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )
    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        # Load the Mllama vision-language model in bfloat16
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )
    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Run the full model, keeping all hidden states so we can pool the last layer
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )

        # Pool the last hidden state at the final non-padding token, then normalize
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)

        features.update({"sentence_embedding": sentence_embedding})
        return features
    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Apply last-token pooling and L2 normalization."""
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)
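
    # A minimal illustration (not part of the original module) of the pooling
    # above, assuming a right-padded batch: `attention_mask.sum(dim=1) - 1`
    # gives the index of the last non-padding token per sequence, e.g.
    #
    #     mask = torch.tensor([[1, 1, 1, 0],
    #                          [1, 1, 1, 1]])
    #     mask.sum(dim=1) - 1  # -> tensor([2, 3])
    #
    # so `reps` gathers the hidden state at position 2 for the first sequence
    # and position 3 for the second, which is then L2-normalized.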
    def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are text-only inputs with no images
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    # Insert the image placeholder token where the image appears
                    text += "<|image|>"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # Run everything through the processor; the image-free branch avoids
        # passing an empty image list to the processor
        if all_images:
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
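

# Hypothetical usage sketch (not part of the original module), assuming this
# file is shipped as the first module of the intfloat/mmE5-mllama-11b-instruct
# Sentence Transformers checkpoint and loaded with trust_remote_code=True;
# the nested-dict format matches what tokenize() above expects, and "cat.jpg"
# is a placeholder path, not a file from the repository.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "intfloat/mmE5-mllama-11b-instruct", trust_remote_code=True
    )

    # Text-only input: a plain list of strings
    text_emb = model.encode(["A cat sitting on a mat."])

    # Interleaved image + text input: a list of {"type", "content"} dicts
    mixed_emb = model.encode(
        [
            [
                {"type": "image_path", "content": "cat.jpg"},
                {"type": "text", "content": "A photo of a cat."},
            ]
        ]
    )
    print(text_emb.shape, mixed_emb.shape)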