"""Dense text/image embeddings with the DSE Qwen2-VL model (MrLight/dse-qwen2-2b-mrl-v1).

Queries and document images are both run through the same vision-language
model. The last-layer hidden state of the final token is used as the
embedding, truncated to the requested dimension and L2-normalised.
"""

from typing import Sequence, Union

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

MODEL_NAME = "MrLight/dse-qwen2-2b-mrl-v1"
# Default embedding width. Smaller values (e.g. 512) trade accuracy for
# speed/storage; the vector is truncated before normalisation in get_embedding.
EMBEDDING_DIM = 1536

# Set device = "cpu" to force CPU inference.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Bounds on the number of pixels the processor resizes images to.
# A larger max (e.g. 2560 * 28 * 28) keeps more detail at higher cost.
min_pixels = 1 * 28 * 28
max_pixels = 256 * 28 * 28

processor = AutoProcessor.from_pretrained(
    MODEL_NAME, min_pixels=min_pixels, max_pixels=max_pixels
)
model = (
    Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # flash_attn is required on GPU but is a pain to install on Spaces,
        # so CPU runs fall back to the eager attention implementation.
        attn_implementation="flash_attention_2" if device == "cuda" else "eager",
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    )
    .to(device)
    .eval()
)
# Left padding puts every sequence's real last token at index -1, which is
# the position get_embedding pools from.
processor.tokenizer.padding_side = "left"
# NOTE(review): kept from the original; it is not visibly read anywhere in
# this file — confirm whether anything relies on it.
model.padding_side = "left"


def get_embedding(last_hidden_state: torch.Tensor, dimension: int):
    """Pool, truncate and L2-normalise the final-token hidden state.

    Args:
        last_hidden_state: Last-layer hidden states, indexed as
            ``[batch, position, ...]``; position -1 is taken as the embedding.
        dimension: Number of leading features to keep before normalising.

    Returns:
        A float32 NumPy array with one unit-norm row per batch item.
    """
    reps = last_hidden_state[:, -1]
    reps = torch.nn.functional.normalize(reps[:, :dimension], p=2, dim=-1)
    # NumPy has no bfloat16, so cast before leaving torch.
    return reps.to(torch.float32).cpu().numpy()


def _encode_messages(messages: list, dimension: int):
    """Run chat-formatted messages through the model and return embeddings."""
    # "<|endoftext|>" is appended so the final token, whose hidden state is
    # pooled by get_embedding, is the same marker for every input.
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        + "<|endoftext|>"
        for msg in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    inputs = model.prepare_inputs_for_generation(**inputs, use_cache=False)
    # Inference only: a single forward pass without autograd bookkeeping.
    with torch.no_grad():
        output = model(**inputs, return_dict=True, output_hidden_states=True)
    return get_embedding(output.hidden_states[-1], dimension)


def _empty_embeddings(dimension: int):
    """Return a (0, dimension) float32 array for an empty input batch."""
    return torch.empty((0, dimension), dtype=torch.float32).numpy()


def encode_queries(queries: Union[str, Sequence[str]], dimension: int = EMBEDDING_DIM):
    """Embed one or more text queries.

    Args:
        queries: A single query string or a sequence of query strings.
        dimension: Output embedding width (leading features kept).

    Returns:
        A float32 NumPy array with one L2-normalised row per query, or an
        empty ``(0, dimension)`` array if no queries are given.
    """
    if isinstance(queries, str):
        queries = [queries]
    if not queries:
        return _empty_embeddings(dimension)

    query_messages = [
        [
            {
                "role": "user",
                "content": [
                    # The pipeline expects an image in every message, so
                    # queries carry a tiny blank placeholder image.
                    {
                        "type": "image",
                        "image": Image.new("RGB", (28, 28)),
                        "resized_height": 1,
                        "resized_width": 1,
                    },
                    {"type": "text", "text": f"Query: {query}"},
                ],
            }
        ]
        for query in queries
    ]
    return _encode_messages(query_messages, dimension)


def encode_images(
    images: Union[Image.Image, Sequence[Image.Image]], dimension: int = EMBEDDING_DIM
):
    """Embed one or more document images.

    Args:
        images: A single PIL image or a sequence of PIL images.
        dimension: Output embedding width (leading features kept).

    Returns:
        A float32 NumPy array with one L2-normalised row per image, or an
        empty ``(0, dimension)`` array if no images are given.
    """
    if isinstance(images, Image.Image):
        images = [images]
    if not images:
        return _empty_embeddings(dimension)

    doc_messages = [
        [
            {
                "role": "user",
                "content": [
                    # Optionally add "resized_height"/"resized_width" (e.g. 680)
                    # here to trade accuracy for speed.
                    {"type": "image", "image": image},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ]
        for image in images
    ]
    return _encode_messages(doc_messages, dimension)