aria-dev committed
Commit feb88b1
1 Parent(s): 39d9c7b

Update README.md

Files changed (1)
  1. README.md +68 -66
README.md CHANGED

---
license: apache-2.0
base_model:
- rhymes-ai/Aria
---

This repository provides int8 weight-only quantized weights of the [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) model, produced with the [TorchAO](https://github.com/pytorch/ao) quantization framework. The quantized model can run inference within 30 GB of GPU memory.

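The exact recipe used to produce this checkpoint is not documented in this README. As a rough sketch, assuming the standard TorchAO 0.6 `quantize_` / `int8_weight_only` API, int8 weight-only quantization of the bfloat16 base model would look roughly like this (the loading arguments below are assumptions that mirror the inference example further down):

```python
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int8_weight_only

# Load the original bfloat16 Aria checkpoint (assumes enough memory to hold
# the unquantized weights).
model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Replace the weights of eligible linear layers with int8 weight-only
# quantized tensors, in place.
quantize_(model, int8_weight_only())
```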

## Quick Start
### Installation
```
pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torch==2.5.0 torchao==0.6.1 torchvision requests Pillow
pip install flash-attn --no-build-isolation
```

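Optionally, a quick sanity check that the pinned packages import cleanly and that a CUDA device is visible before loading the quantized weights:

```python
# Print the installed versions and confirm a GPU is available.
import torch
import torchao
import transformers

print(torch.__version__, torchao.__version__, transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```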

### Inference

```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id_or_path = "rhymes-ai/Aria-torchao-int8wo"

# Load the int8 weight-only quantized model; flash attention 2 and bfloat16
# activations keep the memory footprint low.
model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)

image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"

image = Image.open(requests.get(image_path, stream=True).raw)

# Single-turn chat message containing one image and one text prompt.
messages = [
    {
        "role": "user",
        "content": [
            {"text": None, "type": "image"},
            {"text": "what is the image?", "type": "text"},
        ],
    }
]

text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    output = model.generate(
        **inputs,
        max_new_tokens=500,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )

# Strip the prompt tokens and decode only the newly generated text.
output_ids = output[0][inputs["input_ids"].shape[1]:]
result = processor.decode(output_ids, skip_special_tokens=True)

print(result)
```
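
To confirm the 30 GB budget on your own hardware, you can read PyTorch's peak-allocation counters right after running the example above:

```python
# Report the peak GPU memory PyTorch allocated on each visible device during
# the generation run.
for i in range(torch.cuda.device_count()):
    peak_gib = torch.cuda.max_memory_allocated(i) / 1024**3
    print(f"cuda:{i} peak allocated: {peak_gib:.1f} GiB")
```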