jiwan-chung committed · 5a61cb9
Parent: 61a945a

running on cpu

Files changed:
- arguments.py +0 -7
- load.py +1 -1
- run.py +2 -3
arguments.py CHANGED

@@ -37,8 +37,6 @@ def get_args():
         '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
     parser.add_argument(
         '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
-    parser.add_argument(
-        '--num-gpus', type=int, default=None, help='number of gpus. use all available if none')
     parser.add_argument(
         '--port', type=int, default=None, help="port for the demo server")
 
@@ -47,11 +45,6 @@ def get_args():
 
     if args.use_label_prefix:
         log.info(f'using label prefix')
-    num_gpus = torch.cuda.device_count()
-    if args.num_gpus is None:
-        args.num_gpus = num_gpus
-    else:
-        args.num_gpus = min(num_gpus, args.num_gpus)
 
     if args.checkpoint is not None:
         args.checkpoint = str(Path(args.checkpoint).resolve())
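With the Space pinned to a single device, the removed `--num-gpus` flag and the GPU-count clamping in get_args() have no remaining callers. A standalone sketch (hypothetical, not code from this repo) of how the surviving options parse after the removal; stale launch commands that still pass the flag now fail fast:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
parser.add_argument(
    '--port', type=int, default=None, help="port for the demo server")

# '--num-gpus' is no longer registered, so e.g. `python run.py --num-gpus 2`
# exits with "error: unrecognized arguments: --num-gpus 2" instead of
# silently doing GPU bookkeeping.
args = parser.parse_args(['--response-length', '30'])
print(args.response_length)  # 30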
load.py CHANGED

@@ -42,7 +42,7 @@ def load_model(args, device, finetune=False):
         use_transformer_mapper=args.use_transformer_mapper,
         model_weight='None', use_label_prefix=args.use_label_prefix)
     ckpt = args.checkpoint + '.ckpt'
-    state = torch.load(ckpt)
+    state = torch.load(ckpt, map_location=torch.device('cpu'))
     policy_key = 'policy_model'
     if policy_key in state:
         policy.model.load_state_dict(state[policy_key])
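Why map_location matters here: torch.save records the device of every tensor, so a checkpoint written during GPU training tries to deserialize back onto CUDA and raises a RuntimeError on a CPU-only Space; map_location=torch.device('cpu') remaps the storages at load time and is harmless on GPU hosts. A minimal self-contained sketch (the /tmp path and dict contents are stand-ins, not this repo's checkpoint layout):

import torch

# Stand-in for a real checkpoint: the same shape load_model() expects,
# with a 'policy_model' state dict inside.
state = {'policy_model': {'weight': torch.randn(4, 4)}}
torch.save(state, '/tmp/demo.ckpt')

# Without map_location, a CPU-only machine fails on CUDA-saved tensors with
# "RuntimeError: Attempting to deserialize object on a CUDA device ...".
loaded = torch.load('/tmp/demo.ckpt', map_location=torch.device('cpu'))
print(loaded['policy_model']['weight'].device)  # cpu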
run.py CHANGED

@@ -22,16 +22,15 @@ log = logging.getLogger(__name__)
 
 
 def prepare(args):
-    num_gpus = torch.cuda.device_count()
-    log.info(f'Detect {num_gpus} GPUS')
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    log.info(f'Device: {device}')
     args = load_model_args(args)
 
     def load_style(args, checkpoint):
         model = AutoModelForCausalLM.from_pretrained(args.init_model)
         if checkpoint is not None and Path(checkpoint).is_file():
             log.info("joint model: loading pretrained style generator")
-            state = torch.load(checkpoint)
+            state = torch.load(checkpoint, map_location=torch.device('cpu'))
             if 'global_step' in state:
                 step = state['global_step']
                 log.info(f'trained for {step} steps')
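Taken together, run.py now follows the usual device-agnostic pattern: pick the device once from torch.cuda.is_available(), log it, and let the CPU fallback carry the Space. A hedged sketch of the same pattern in isolation (pick_device and the model move are illustrative, not functions from this repo):

import logging

import torch

log = logging.getLogger(__name__)


def pick_device() -> torch.device:
    # Mirrors prepare(): use the first GPU when present, otherwise CPU,
    # and log the choice rather than counting GPUs.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    log.info(f'Device: {device}')
    return device

# Usage: load weights onto CPU first, then move the module to the chosen device.
# model.load_state_dict(torch.load('model.ckpt', map_location='cpu'))
# model.to(pick_device())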