Kainat98 committed on
Commit
57a3662
1 Parent(s): 3ff3d76

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +47 -0
README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import os
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from optimum.onnxruntime import ORTModelForCausalLM
11
+ from transformers import AutoConfig, AutoTokenizer, GenerationConfig
12
+
13
+ device_id = 0
14
+ device = torch.device(f"cuda:{device_id}") # Change to torch.device("cpu") if running on CPU
15
+
16
+ ep = "CUDAExecutionProvider" # change to CPUExecutionProvider if running on CPU
17
+ ep_options = {"device_id": device_id}
18
+
19
+ model_id = "mistralai/Mistral-7B-Instruct-v0.2"
20
+ model_path = "llama-13b-4bit-finetuned-alpaca/Olive/examples/llama2/models/qlora/qlora-conversion-transformers_optimization-bnb_quantization/gpu-cuda_model"
21
+
22
+ model_path = Path(model_path)
23
+
24
+ if not (model_path / "config.json").exists():
25
+ config = AutoConfig.from_pretrained(model_id)
26
+ config.save_pretrained(model_path)
27
+ else:
28
+ config = AutoConfig.from_pretrained(model_path)
29
+
30
+ if not (model_path / "generation_config.json").exists():
31
+ gen_config = GenerationConfig.from_pretrained(model_id)
32
+ gen_config.save_pretrained(model_path)
33
+ else:
34
+ gen_config = GenerationConfig.from_pretrained(model_path)
35
+
36
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
37
+
38
+ model = ORTModelForCausalLM.from_pretrained(
39
+ model_path,
40
+ config=config,
41
+ generation_config=gen_config,
42
+ use_io_binding=True,
43
+ # provider="CUDAExecutionProvider",
44
+ provider=ep,
45
+ provider_options={"device_id": device_id}
46
+ # provider_options={"device_id": str(rank)},
47
+ )