rphrp1985 committed
Commit 4ad0753
1 Parent(s): bd8e143

Update app.py

Files changed (1): app.py +10 -1
app.py CHANGED
@@ -8,6 +8,9 @@ import os
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
 
+from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
+
+
 
 subprocess.run(
     "pip install flash-attn --no-build-isolation",
@@ -35,7 +38,8 @@ tokenizer = AutoTokenizer.from_pretrained(
     model_id
     , token= token,)
 
-model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
+with init_empty_weights():
+    model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
     # torch_dtype= torch.uint8,
     torch_dtype=torch.float16,
     # torch_dtype=torch.fl,
@@ -50,6 +54,11 @@ model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
 #
 
 
+device_map = infer_auto_device_map(model, max_memory={0: "80GB", 1: "80GB", "cpu": "65GB"})
+
+# Load the model with the inferred device map
+model = load_checkpoint_and_dispatch(model, "path_to_checkpoint", device_map=device_map, no_split_module_classes=["GPTJBlock"])
+
 
 @spaces.GPU(duration=60)
 def respond(
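For context, the change follows accelerate's big-model-inference recipe: build the model skeleton on the meta device, infer a device map under explicit memory caps, then stream the checkpoint weights onto their assigned devices. Below is a minimal, self-contained sketch of that recipe, not the Space's exact code: MODEL_ID and CHECKPOINT_DIR are placeholders (the diff itself leaves the checkpoint path as the literal string "path_to_checkpoint"), the memory budget and the GPTJBlock class come from the diff, and the sketch instantiates from a config, which is the documented way to combine init_empty_weights() with transformers.

import torch
from accelerate import (
    infer_auto_device_map,
    init_empty_weights,
    load_checkpoint_and_dispatch,
)
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_ID = "EleutherAI/gpt-j-6b"       # hypothetical model; the Space defines its own model_id elsewhere
CHECKPOINT_DIR = "path_to_checkpoint"  # placeholder kept from the diff; point this at a real weights folder

# 1. Materialize only the module tree, on the "meta" device: no weight
#    storage is allocated, so this works even for very large models.
config = AutoConfig.from_pretrained(MODEL_ID)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# 2. Split the module tree across devices under explicit memory caps
#    (the budget mirrors the diff: two 80GB GPUs plus 65GB of CPU RAM).
#    Keeping each transformer block whole avoids cutting residual paths.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "80GB", 1: "80GB", "cpu": "65GB"},
    no_split_module_classes=["GPTJBlock"],
)

# 3. Load the real weights from disk directly onto their assigned devices.
model = load_checkpoint_and_dispatch(
    model,
    CHECKPOINT_DIR,
    device_map=device_map,
    no_split_module_classes=["GPTJBlock"],
)

Note that transformers can also wrap this whole flow in one step: passing device_map="auto" (optionally with max_memory=...) to AutoModelForCausalLM.from_pretrained infers the map and dispatches the weights in a single call.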