Spaces:
Runtime error
Runtime error
Update logger
Browse files
model.py
CHANGED
@@ -80,10 +80,10 @@ formatter = logging.Formatter(
|
|
80 |
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
81 |
datefmt='%Y-%m-%d %H:%M:%S')
|
82 |
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
83 |
-
stream_handler.setLevel(logging.
|
84 |
stream_handler.setFormatter(formatter)
|
85 |
logger = logging.getLogger(__name__)
|
86 |
-
logger.setLevel(logging.
|
87 |
logger.propagate = False
|
88 |
logger.addHandler(stream_handler)
|
89 |
|
@@ -254,7 +254,7 @@ class Model:
|
|
254 |
self.style = style
|
255 |
self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
|
256 |
self.query_template = self.args.query_template
|
257 |
-
logger.
|
258 |
|
259 |
self.strategy.temperature = self.args.temp_all_gen
|
260 |
|
@@ -296,7 +296,7 @@ class Model:
|
|
296 |
start = time.perf_counter()
|
297 |
|
298 |
text = self.query_template.format(text)
|
299 |
-
logger.
|
300 |
seq = tokenizer.encode(text)
|
301 |
logger.info(f'{len(seq)=}')
|
302 |
if len(seq) > 110:
|
@@ -342,7 +342,7 @@ class Model:
|
|
342 |
output_list.append(coarse_samples)
|
343 |
remaining -= self.max_batch_size
|
344 |
output_tokens = torch.cat(output_list, dim=0)
|
345 |
-
logger.
|
346 |
|
347 |
elapsed = time.perf_counter() - start
|
348 |
logger.info(f'Elapsed: {elapsed}')
|
@@ -360,7 +360,7 @@ class Model:
|
|
360 |
logger.info('--- generate_images ---')
|
361 |
start = time.perf_counter()
|
362 |
|
363 |
-
logger.
|
364 |
res = []
|
365 |
if self.only_first_stage:
|
366 |
for i in range(len(tokens)):
|
@@ -414,6 +414,9 @@ class AppModel(Model):
|
|
414 |
self, text: str, translate: bool, style: str, seed: int,
|
415 |
only_first_stage: bool, num: int
|
416 |
) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
|
|
|
|
|
|
|
417 |
if translate:
|
418 |
text = translated_text = self.translator(text)
|
419 |
else:
|
|
|
80 |
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
81 |
datefmt='%Y-%m-%d %H:%M:%S')
|
82 |
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
83 |
+
stream_handler.setLevel(logging.INFO)
|
84 |
stream_handler.setFormatter(formatter)
|
85 |
logger = logging.getLogger(__name__)
|
86 |
+
logger.setLevel(logging.INFO)
|
87 |
logger.propagate = False
|
88 |
logger.addHandler(stream_handler)
|
89 |
|
|
|
254 |
self.style = style
|
255 |
self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
|
256 |
self.query_template = self.args.query_template
|
257 |
+
logger.debug(f'{self.query_template=}')
|
258 |
|
259 |
self.strategy.temperature = self.args.temp_all_gen
|
260 |
|
|
|
296 |
start = time.perf_counter()
|
297 |
|
298 |
text = self.query_template.format(text)
|
299 |
+
logger.debug(f'{text=}')
|
300 |
seq = tokenizer.encode(text)
|
301 |
logger.info(f'{len(seq)=}')
|
302 |
if len(seq) > 110:
|
|
|
342 |
output_list.append(coarse_samples)
|
343 |
remaining -= self.max_batch_size
|
344 |
output_tokens = torch.cat(output_list, dim=0)
|
345 |
+
logger.debug(f'{output_tokens.shape=}')
|
346 |
|
347 |
elapsed = time.perf_counter() - start
|
348 |
logger.info(f'Elapsed: {elapsed}')
|
|
|
360 |
logger.info('--- generate_images ---')
|
361 |
start = time.perf_counter()
|
362 |
|
363 |
+
logger.debug(f'{self.only_first_stage=}')
|
364 |
res = []
|
365 |
if self.only_first_stage:
|
366 |
for i in range(len(tokens)):
|
|
|
414 |
self, text: str, translate: bool, style: str, seed: int,
|
415 |
only_first_stage: bool, num: int
|
416 |
) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
|
417 |
+
logger.info(
|
418 |
+
f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}'
|
419 |
+
)
|
420 |
if translate:
|
421 |
text = translated_text = self.translator(text)
|
422 |
else:
|