Spaces:
Runtime error
Runtime error
Update gradio_app.py
Browse files
- gradio_app.py +4 -2
gradio_app.py
CHANGED
@@ -43,6 +43,7 @@ if DEVICE != "cpu" and not torch.cuda.is_available():
|
|
43 |
logger.info(f"DEVICE {DEVICE}")
|
44 |
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
|
45 |
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
|
|
|
46 |
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
|
47 |
HEADER_INFO = """
|
48 |
# BERTIN GPT-J-6B
|
@@ -140,10 +141,11 @@ class TextGeneration:
|
|
140 |
def load(self):
|
141 |
logger.info("Loading model...")
|
142 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
143 |
-
self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
|
144 |
)
|
145 |
self.model = AutoModelForCausalLM.from_pretrained(
|
146 |
-
self.model_name_or_path,
|
|
|
147 |
pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
|
148 |
torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True
|
149 |
).to(device=DEVICE, non_blocking=False)
|
|
|
43 |
logger.info(f"DEVICE {DEVICE}")
|
44 |
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
|
45 |
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
|
46 |
+
MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
|
47 |
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
|
48 |
HEADER_INFO = """
|
49 |
# BERTIN GPT-J-6B
|
|
|
141 |
def load(self):
|
142 |
logger.info("Loading model...")
|
143 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
144 |
+
self.model_name_or_path, revision=MODEL_REVISION, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
|
145 |
)
|
146 |
self.model = AutoModelForCausalLM.from_pretrained(
|
147 |
+
self.model_name_or_path, revision=MODEL_REVISION,
|
148 |
+
use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
|
149 |
pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
|
150 |
torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True
|
151 |
).to(device=DEVICE, non_blocking=False)
|