"""Run one chat-style generation with a locally saved Aria model.

Optional steps (int8 weight-only quantization, sending an image, saving the
model) are controlled by the flags below instead of by commenting code out.
"""
import io

import requests
import torch
from PIL import Image
from torchao.quantization import (
    quantize_,
)
from torchao.quantization.quant_api import (
    MappingType,
    _get_linear_subclass_inserter,
    _is_linear,
    to_affine_quantized_intx,
)
from transformers import AutoModelForCausalLM, AutoProcessor

# NOTE(review): not referenced in this script; presumably imported for its
# side effects or for the remote-code model classes — TODO confirm before removing.
from moe_lm import GroupedGEMM

# Inductor settings: force fusing int8 matmul with the following mul, and
# enable the FX graph cache so recompiles across runs are cheaper.
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True

model_id_or_path = "./out/aria-torchao-in8wo"
tokenizer_id_or_path = "./"

# Apply int8 weight-only quantization right after loading.
QUANTIZE = False
# Send the image below together with the prompt. When False, no image is
# downloaded at all.
USE_IMAGE = False
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
# If not None, the (possibly quantized) model is saved here after generation,
# e.g. "out/aria-torchao-in8wo".
SAVE_DIR = None


def int8_weight_only(group_size=None):
    """Build a ``quantize_`` transform for int8 weight-only symmetric quantization.

    Args:
        group_size: Number of consecutive elements along the last weight
            dimension that share one scale. ``None`` (the default) uses the
            whole row, i.e. per-channel quantization.

    Returns:
        A callable to pass as the second argument of
        ``torchao.quantization.quantize_``.

    Raises:
        ValueError: (when applied) if the weight's last dimension is not
            divisible by ``group_size``.
    """

    def apply_int8wo_quant(weight, group_size=None):
        # Collapse any leading dimensions into rows so the weight is 2-D.
        # NOTE(review): the quantized tensor keeps this 2-D shape; confirm that
        # consumers of 3-D weights (e.g. stacked experts) expect that.
        weight = weight.reshape(-1, weight.shape[-1])
        in_features = weight.shape[1]
        if group_size is None:
            group_size = in_features
        if group_size <= 0 or in_features % group_size != 0:
            raise ValueError(
                f"group_size={group_size} must be a positive divisor of the "
                f"weight's last dimension ({in_features})"
            )
        # One scale per (row, group): each block spans 1 row x group_size
        # columns. Previously group_size was ignored here.
        block_size = (1, group_size)
        return to_affine_quantized_intx(
            weight,
            MappingType.SYMMETRIC,
            block_size,
            torch.int8,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int64,
        )

    return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size)


def filter_fn(m, *args):
    """Select modules to quantize: MoE expert projections plus linear layers.

    ``args[0]`` is the module's fully-qualified name. Expert projections are
    matched by name (presumably because they are GroupedGEMM modules that
    ``_is_linear`` would not accept — TODO confirm); everything else falls
    back to ``_is_linear``.
    """
    fqn = args[0]
    if "experts.fc1" in fqn or "experts.fc2" in fqn:
        return True
    return _is_linear(m, *args)


model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="cuda",  # already places the model on the GPU; no .to() needed
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

if QUANTIZE:
    # Quantize before compiling so the compiled graph sees quantized weights.
    quantize_(model, int8_weight_only(group_size=128), filter_fn=filter_fn)

# Compile the forward pass that generate() actually calls. Wrapping the whole
# module (torch.compile(model)) returns an OptimizedModule whose .generate is
# forwarded to the original module, so generation ran uncompiled.
# fullgraph=True makes any graph break raise instead of falling back to eager.
model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
print(model)

content = [{"text": "what's in the image?", "type": "text"}]
image = None
if USE_IMAGE:
    # Placeholder telling the chat template where the image goes.
    content.insert(0, {"text": None, "type": "image"})
    response = requests.get(image_path, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
messages = [{"role": "user", "content": content}]

processor = AutoProcessor.from_pretrained(tokenizer_id_or_path, trust_remote_code=True)
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
if inputs.get("pixel_values") is not None:
    # Match the bfloat16 dtype the model was loaded with.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

out = model.generate(
    **inputs,
    max_new_tokens=50,
    # Sampling options belong to generate(); passed to from_pretrained they
    # only update the model config and may not reach generation.
    do_sample=True,
    temperature=0.7,
    tokenizer=processor.tokenizer,  # required by stop_strings
    stop_strings=["<|im_end|>"],
)
# Drop the prompt tokens and keep only the newly generated continuation.
output_ids = out[0][inputs["input_ids"].shape[1] :]
result = processor.decode(output_ids, skip_special_tokens=True)
print(result)

if SAVE_DIR is not None:
    # NOTE(review): safe_serialization=False is presumably needed because
    # torchao tensor subclasses are not safetensors-serializable — TODO confirm.
    model.save_pretrained(SAVE_DIR, safe_serialization=False)