versae commited on
Commit
95166f2
1 Parent(s): 2432fd3

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +4 -2
gradio_app.py CHANGED
@@ -43,6 +43,7 @@ if DEVICE != "cpu" and not torch.cuda.is_available():
43
  logger.info(f"DEVICE {DEVICE}")
44
  DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
45
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
 
46
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
47
  HEADER_INFO = """
48
  # BERTIN GPT-J-6B
@@ -140,10 +141,11 @@ class TextGeneration:
140
  def load(self):
141
  logger.info("Loading model...")
142
  self.tokenizer = AutoTokenizer.from_pretrained(
143
- self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
144
  )
145
  self.model = AutoModelForCausalLM.from_pretrained(
146
- self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
 
147
  pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
148
  torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True
149
  ).to(device=DEVICE, non_blocking=False)
 
43
  logger.info(f"DEVICE {DEVICE}")
44
  DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
45
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
46
+ MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
47
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
48
  HEADER_INFO = """
49
  # BERTIN GPT-J-6B
 
141
  def load(self):
142
  logger.info("Loading model...")
143
  self.tokenizer = AutoTokenizer.from_pretrained(
144
+ self.model_name_or_path, revision=MODEL_REVISION, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
145
  )
146
  self.model = AutoModelForCausalLM.from_pretrained(
147
+ self.model_name_or_path, revision=MODEL_REVISION,
148
+ use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
149
  pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
150
  torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True
151
  ).to(device=DEVICE, non_blocking=False)