cutechicken committed on
Commit
a188372
1 Parent(s): 4c60e0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -16
app.py CHANGED
@@ -14,18 +14,24 @@ import spaces
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  MODEL_ID = "CohereForAI/c4ai-command-r-plus-08-2024"
16
 
 
 
 
 
17
  class ModelManager:
18
  def __init__(self):
 
19
  self.model = None
20
  self.tokenizer = None
21
  self.setup_model()
22
-
23
  def setup_model(self):
24
  try:
25
  self.tokenizer = AutoTokenizer.from_pretrained(
26
  MODEL_ID,
27
  token=HF_TOKEN,
28
- trust_remote_code=True
 
29
  )
30
  self.model = AutoModelForCausalLM.from_pretrained(
31
  MODEL_ID,
@@ -33,22 +39,11 @@ class ModelManager:
33
  torch_dtype=torch.float16,
34
  device_map="auto",
35
  trust_remote_code=True,
36
- low_cpu_mem_usage=True
 
37
  )
38
  except Exception as e:
39
- print(f"Error loading model: {e}")
40
- # Fallback to basic loading without device_map
41
- try:
42
- self.model = AutoModelForCausalLM.from_pretrained(
43
- MODEL_ID,
44
- token=HF_TOKEN,
45
- torch_dtype=torch.float16,
46
- trust_remote_code=True
47
- )
48
- except Exception as e:
49
- raise Exception(f"Model loading failed completely: {e}")
50
-
51
-
52
 
53
  class ChatHistory:
54
  def __init__(self):
 
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  MODEL_ID = "CohereForAI/c4ai-command-r-plus-08-2024"
16
 
17
+ os.environ["TRANSFORMERS_CACHE"] = "/persistent/transformers_cache"
18
+ os.environ["TORCH_HOME"] = "/persistent/torch_cache"
19
+ os.environ["HF_HOME"] = "/persistent/huggingface"
20
+
21
  class ModelManager:
22
  def __init__(self):
23
+ self.cache_dir = "/persistent/model_cache"
24
  self.model = None
25
  self.tokenizer = None
26
  self.setup_model()
27
+
28
  def setup_model(self):
29
  try:
30
  self.tokenizer = AutoTokenizer.from_pretrained(
31
  MODEL_ID,
32
  token=HF_TOKEN,
33
+ trust_remote_code=True,
34
+ cache_dir=self.cache_dir
35
  )
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_ID,
 
39
  torch_dtype=torch.float16,
40
  device_map="auto",
41
  trust_remote_code=True,
42
+ low_cpu_mem_usage=True,
43
+ cache_dir=self.cache_dir
44
  )
45
  except Exception as e:
46
+ raise Exception(f"Model loading failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  class ChatHistory:
49
  def __init__(self):