rishiraj commited on
Commit
7047a96
·
1 Parent(s): 9f309ad

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +2 -59
model_utils.py CHANGED
@@ -14,47 +14,12 @@
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
 
17
- from typing import Dict
18
-
19
- import torch
20
- from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
21
-
22
- from accelerate import Accelerator
23
- from huggingface_hub import list_repo_files
24
- from peft import LoraConfig, PeftConfig
25
 
26
  from .configs import DataArguments, ModelArguments
27
  from .data import DEFAULT_CHAT_TEMPLATE
28
 
29
 
30
- def get_current_device() -> int:
31
- """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
32
- return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
33
-
34
-
35
- def get_kbit_device_map() -> Dict[str, int] | None:
36
- """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
37
- return {"": get_current_device()} if torch.cuda.is_available() else None
38
-
39
-
40
- def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
41
- if model_args.load_in_4bit:
42
- quantization_config = BitsAndBytesConfig(
43
- load_in_4bit=True,
44
- bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
45
- bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
46
- bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
47
- )
48
- elif model_args.load_in_8bit:
49
- quantization_config = BitsAndBytesConfig(
50
- load_in_8bit=True,
51
- )
52
- else:
53
- quantization_config = None
54
-
55
- return quantization_config
56
-
57
-
58
  def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
59
  """Get the tokenizer for the model."""
60
  tokenizer = AutoTokenizer.from_pretrained(
@@ -76,26 +41,4 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr
76
  elif tokenizer.chat_template is None:
77
  tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
78
 
79
- return tokenizer
80
-
81
-
82
- def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
83
- if model_args.use_peft is False:
84
- return None
85
-
86
- peft_config = LoraConfig(
87
- r=model_args.lora_r,
88
- lora_alpha=model_args.lora_alpha,
89
- lora_dropout=model_args.lora_dropout,
90
- bias="none",
91
- task_type="CAUSAL_LM",
92
- target_modules=model_args.lora_target_modules,
93
- modules_to_save=model_args.lora_modules_to_save,
94
- )
95
-
96
- return peft_config
97
-
98
-
99
- def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
100
- repo_files = list_repo_files(model_name_or_path, revision=revision)
101
- return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
 
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
 
17
+ from transformers import AutoTokenizer, PreTrainedTokenizer
 
 
 
 
 
 
 
18
 
19
  from .configs import DataArguments, ModelArguments
20
  from .data import DEFAULT_CHAT_TEMPLATE
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
24
  """Get the tokenizer for the model."""
25
  tokenizer = AutoTokenizer.from_pretrained(
 
41
  elif tokenizer.chat_template is None:
42
  tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
43
 
44
+ return tokenizer