rphrp1985 committed
Commit fdd9d54
Parent: e18bb87

Update app.py

Files changed (1): app.py (+5 -2)
app.py CHANGED
@@ -70,6 +70,7 @@ tokenizer = AutoTokenizer.from_pretrained(
 , token= token,)
 
 
+accelerator = Accelerator()
 
 model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
     # torch_dtype= torch.uint8,
@@ -78,12 +79,14 @@ model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
     attn_implementation="flash_attention_2",
     low_cpu_mem_usage=True,
 
-    device_map='cuda',
+    # device_map='cuda',
+    device_map=accelerator.device_map,
 
 )
 
 
 #
+model = accelerator.prepare(model)
 
 
 # device_map = infer_auto_device_map(model, max_memory={0: "79GB", "cpu":"65GB" })
@@ -104,7 +107,7 @@ def respond(
     top_p,
 ):
     messages = [{"role": "user", "content": "Hello, how are you?"}]
-    input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda')
+    input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device) #.to('cuda')
     ## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
     # with autocast():
     gen_tokens = model.generate(
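
For readers following the change: the commit replaces hard-coded CUDA placement (device_map='cuda' and .to('cuda')) with accelerate's Accelerator object. Below is a minimal, self-contained sketch of that pattern, assuming the standard accelerate API (Accelerator().device and Accelerator.prepare()); model_id, the token, and the generation settings are placeholders, not values taken from this Space.

# Minimal sketch of Accelerator-based device placement.
# Assumption: standard accelerate API; model_id and token are placeholders.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-model"   # placeholder, not the model used in this Space
token = "hf_..."                   # placeholder Hugging Face access token

accelerator = Accelerator()  # resolves the right device for this process (GPU if available)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,
    low_cpu_mem_usage=True,
)
model = accelerator.prepare(model)  # moves the model onto accelerator.device

messages = [{"role": "user", "content": "Hello, how are you?"}]
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(accelerator.device)  # device-agnostic replacement for .to('cuda')

gen_tokens = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(gen_tokens[0]))

The practical difference from device_map='cuda' is portability: accelerator.device resolves to whatever device the process was launched on, so the same code runs on a CPU-only machine or under accelerate launch in multi-GPU setups. The commented-out infer_auto_device_map(model, max_memory={0: "79GB", "cpu": "65GB"}) line visible in the diff is accelerate's other placement option, which shards a large model across per-device memory budgets instead of placing it on a single device.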