File size: 1,351 Bytes
0dce0bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/bin/bash
#
# Launch ${main_dir}/cli_sat.py (main_dir = parent of this script's directory)
# under torchrun with the model and generation flags defined below.
#
# Usage: <this script> [extra cli_sat.py args...]
#   Every argument given to this script is forwarded verbatim to cli_sat.py,
#   placed after the defaults below (presumably later flags override earlier
#   ones in cli_sat.py's parser — TODO confirm).
#
# The script's exit status is torchrun's exit status.

set -euo pipefail

# Resolve paths relative to this script so it can be invoked from any cwd.
script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")
main_dir=$(dirname "$script_dir")

# Number of processes torchrun starts on this node (--nproc_per_node).
MP_SIZE=1
# MODEL_NAME="MSAGPT-"
# MODEL_NAME="MSAGPT-dpo"


SEED=12345
MAX_GEN_LENGTH=128
MIN_GEN_LENGTH=0

# BeamSearchStrategy args
NUM_BEAMS=4
LENGTH_PENALTY=1.0
NO_REPEAT_NGRAM=0

# BaseStrategy args
TEMP=0.8
TOPK=0
TOPP=0.9


# Port passed to torchrun as --master_port.
PORT=19865

# Flags are kept in arrays (not flat strings) so each one reaches cli_sat.py
# as exactly one argument, with no re-splitting or re-parsing.
MODEL_ARGS=(
  --bf16
  --skip-init
  --mode finetune
  --rotary-embedding-2d
)
# Alternative: --mode inference   # TODO: sat ds_config bug?

GENERATION_ARGS=(
  --seed "$SEED"
  --sampling-strategy BaseStrategy
  --max-gen-length "$MAX_GEN_LENGTH"
  --min-gen-length "$MIN_GEN_LENGTH"
  --num-beams "$NUM_BEAMS"
  --length-penalty "$LENGTH_PENALTY"
  --no-repeat-ngram-size "$NO_REPEAT_NGRAM"
  --multiline_stream
  --temperature "$TEMP"
  --top_k "$TOPK"
  --top_p "$TOPP"
)
# Alternatives:
#   --sampling-strategy BeamSearchStrategy
#   --no-gap


# Environment variables set only for the torchrun process (applied via `env`).
OPTIONS_NCCL=(
  NCCL_DEBUG=VERSION
  NCCL_IB_DISABLE=0
  NCCL_NET_GDR_LEVEL=2
  CUDA_LAUNCH_BLOCKING=0
)

# "$@" (not $*) keeps each caller argument intact, even with spaces/quotes;
# without eval, caller input is never re-interpreted as shell code.
run_cmd=(
  env "${OPTIONS_NCCL[@]}"
  torchrun --nproc_per_node "$MP_SIZE" --master_port="$PORT"
  "${main_dir}/cli_sat.py"
  "${MODEL_ARGS[@]}"
  "${GENERATION_ARGS[@]}"
  "$@"
)

# Show the exact command, shell-quoted so it can be copy-pasted.
printf '%q ' "${run_cmd[@]}"
printf '\n'

# exec: torchrun replaces this shell, so its exit status (and signals) pass
# straight through to the caller.
exec "${run_cmd[@]}"