jedyang97 commited on
Commit
7cb0f3c
·
1 Parent(s): 15bf65b

fix zero gpu

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. model.py +2 -3
app.py CHANGED
@@ -44,9 +44,10 @@ tokenizer, model, data_loader = load_model_and_dataloader(
44
  load_4bit=load_4bit,
45
  load_bf16=load_bf16,
46
  scene_to_obj_mapping=scene_to_obj_mapping,
47
- )
 
 
48
 
49
- @spaces.GPU
50
  def get_chatbot_response(user_chat_input, scene_id):
51
  # Get the response from the model
52
  prompt, response = get_model_response(
 
44
  load_4bit=load_4bit,
45
  load_bf16=load_bf16,
46
  scene_to_obj_mapping=scene_to_obj_mapping,
47
+ device_map='cpu',
48
+ ) # Huggingface Zero-GPU has to use .to(device) to set the device, otherwise it will fail
49
+ model.to("cuda") # Huggingface Zero-GPU requires explicit device placement
50
 
 
51
  def get_chatbot_response(user_chat_input, scene_id):
52
  # Get the response from the model
53
  prompt, response = get_model_response(
model.py CHANGED
@@ -12,8 +12,7 @@ from llava.mm_utils import get_model_name_from_path
12
  from llava.model.builder import load_pretrained_model
13
  from llava.utils import disable_torch_init
14
 
15
- @spaces.GPU
16
- def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_context_feature_type="text", load_8bit=False, load_4bit=False, load_bf16=False):
17
 
18
  model_name = get_model_name_from_path(model_path)
19
 
@@ -24,6 +23,7 @@ def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_
24
  load_8bit=load_8bit,
25
  load_4bit=load_4bit,
26
  load_bf16=load_bf16,
 
27
  )
28
 
29
  dataset = ObjIdentifierDataset(
@@ -41,7 +41,6 @@ def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_
41
  return tokenizer, model, data_loader
42
 
43
 
44
- @spaces.GPU
45
  def get_model_response(model, tokenizer, data_loader, scene_id, user_input, max_new_tokens=50, temperature=0.2, top_p=0.9):
46
  input_data = [
47
  {
 
12
  from llava.model.builder import load_pretrained_model
13
  from llava.utils import disable_torch_init
14
 
15
+ def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_context_feature_type="text", load_8bit=False, load_4bit=False, load_bf16=False, device_map='auto'):
 
16
 
17
  model_name = get_model_name_from_path(model_path)
18
 
 
23
  load_8bit=load_8bit,
24
  load_4bit=load_4bit,
25
  load_bf16=load_bf16,
26
+ device_map=device_map,
27
  )
28
 
29
  dataset = ObjIdentifierDataset(
 
41
  return tokenizer, model, data_loader
42
 
43
 
 
44
  def get_model_response(model, tokenizer, data_loader, scene_id, user_input, max_new_tokens=50, temperature=0.2, top_p=0.9):
45
  input_data = [
46
  {