Spaces:
Runtime error
Runtime error
Update gradio_app.py
Browse files- gradio_app.py +7 -2
gradio_app.py
CHANGED
@@ -41,7 +41,12 @@ DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
|
|
41 |
if DEVICE != "cpu" and not torch.cuda.is_available():
|
42 |
DEVICE = "cpu"
|
43 |
logger.info(f"DEVICE {DEVICE}")
|
44 |
-
DTYPE =
|
|
|
|
|
|
|
|
|
|
|
45 |
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
|
46 |
MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
|
47 |
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
|
@@ -147,7 +152,7 @@ class TextGeneration:
|
|
147 |
self.model_name_or_path, revision=MODEL_REVISION,
|
148 |
use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
|
149 |
pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
|
150 |
-
torch_dtype=DTYPE, low_cpu_mem_usage=
|
151 |
).to(device=DEVICE, non_blocking=False)
|
152 |
_ = self.model.eval()
|
153 |
device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
|
|
|
41 |
if DEVICE != "cpu" and not torch.cuda.is_available():
|
42 |
DEVICE = "cpu"
|
43 |
logger.info(f"DEVICE {DEVICE}")
|
44 |
+
DTYPE = getattr(
|
45 |
+
torch,
|
46 |
+
os.environ.get("DTYPE", ""),
|
47 |
+
torch.float32 if DEVICE == "cpu" else torch.float16
|
48 |
+
)
|
49 |
+
LOW_CPU_MEM = bool(os.environ.get("LOW_CPU_MEM", False if DEVICE == "cpu" else True))
|
50 |
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
|
51 |
MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
|
52 |
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
|
|
|
152 |
self.model_name_or_path, revision=MODEL_REVISION,
|
153 |
use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
|
154 |
pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
|
155 |
+
torch_dtype=DTYPE, low_cpu_mem_usage=LOW_CPU_MEM,
|
156 |
).to(device=DEVICE, non_blocking=False)
|
157 |
_ = self.model.eval()
|
158 |
device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
|