metadata
license: apache-2.0
base_model: HuggingFaceH4/mistral-7b-cai
tags:
  - alignment-handbook
  - generated_from_trainer
datasets:
  - HuggingFaceH4/ultrafeedback_binarized_fixed
  - HuggingFaceH4/cai-conversation-harmless
model-index:
  - name: mistral-7b-dpo-v21.0cai.0.2
    results: []

Mistral 7B Constitutional AI

This model is a DPO-aligned version of HuggingFaceH4/mistral-7b-cai (a Mistral 7B model), trained on the HuggingFaceH4/ultrafeedback_binarized_fixed and HuggingFaceH4/cai-conversation-harmless datasets.

It achieves the following results on the evaluation set:

  • Loss: 0.6327
  • Rewards/chosen: -9.8716
  • Rewards/rejected: -14.5465
  • Rewards/accuracies: 0.6725
  • Rewards/margins: 4.6749
  • Logps/rejected: -329.8578
  • Logps/chosen: -294.6768
  • Logits/rejected: -2.1023
  • Logits/chosen: -2.1648
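
The reward metrics above are related to one another. The sketch below assumes the usual DPO convention (as used by common DPO trainers), in which the reported rewards are already scaled by beta, the margin is the chosen reward minus the rejected reward, and the per-pair loss is the negative log-sigmoid of that margin. The helper name `dpo_pair_loss` is illustrative and not part of any library.

```python
import math

# Final evaluation metrics reported in this card.
rewards_chosen = -9.8716
rewards_rejected = -14.5465
reported_margin = 4.6749

# Assumption: margin = chosen reward - rejected reward.
margin = rewards_chosen - rewards_rejected
assert abs(margin - reported_margin) < 1e-4


def dpo_pair_loss(reward_chosen: float, reward_rejected: float) -> float:
    """Sigmoid DPO loss for a single preference pair: -log(sigmoid(m)).

    Written as softplus(-m) in a numerically stable form.
    """
    m = reward_chosen - reward_rejected
    return max(-m, 0.0) + math.log1p(math.exp(-abs(m)))


# With no preference signal (zero margin) the loss is log(2).
untrained_loss = dpo_pair_loss(0.0, 0.0)
```

Under these assumptions, a zero margin gives a loss of log(2), about 0.693, which is close to the first validation loss logged in the training results (0.6873). The reported loss is an average over pairs, so it need not equal the loss evaluated at the average margin: with an accuracy of 0.6725, roughly a third of the evaluation pairs are ranked the wrong way and can contribute large individual losses.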

Model description

More information needed

Intended uses & limitations

More information needed

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-07
  • train_batch_size: 2
  • eval_batch_size: 8
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 8
  • total_train_batch_size: 16
  • total_eval_batch_size: 64
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_ratio: 0.1
  • num_epochs: 3
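
The total batch sizes follow from the per-device sizes and the device count, and the learning rate follows a linear warmup and decay. The sketch below is a minimal illustration, assuming no gradient accumulation (it is not listed above) and the common linear schedule in which the rate rises from zero over the first 10% of steps and then decays linearly to zero. `linear_lr` and the `total_steps` value used to call it are hypothetical, chosen only for illustration.

```python
import math

# Values taken from the hyperparameter list.
train_batch_size = 2   # per device
eval_batch_size = 8    # per device
num_devices = 8
learning_rate = 5e-07
warmup_ratio = 0.1

# Assumption: gradient accumulation is not listed, so it is taken to be 1.
gradient_accumulation_steps = 1

total_train_batch_size = train_batch_size * num_devices * gradient_accumulation_steps
total_eval_batch_size = eval_batch_size * num_devices


def linear_lr(step: int, total_steps: int,
              peak_lr: float = learning_rate,
              ratio: float = warmup_ratio) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay."""
    warmup_steps = math.ceil(total_steps * ratio)
    if step < warmup_steps:
        # Warmup phase: scale linearly from 0 up to peak_lr.
        return peak_lr * step / max(1, warmup_steps)
    # Decay phase: scale linearly from peak_lr down to 0.
    remaining = total_steps - step
    return peak_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


print(total_train_batch_size, total_eval_batch_size)  # 16 64
```

For example, with a hypothetical `total_steps=10_000`, `linear_lr(500, 10_000)` is halfway through warmup and returns about half the peak rate, while `linear_lr(10_000, 10_000)` returns 0.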

Training results

Training Loss Epoch Step Validation Loss Rewards/chosen Rewards/rejected Rewards/accuracies Rewards/margins Logps/rejected Logps/chosen Logits/rejected Logits/chosen
0.6817 0.02 100 0.6873 0.0149 0.0002 0.5150 0.0147 -184.3912 -195.8124 -3.1605 -3.1560
0.6767 0.05 200 0.6614 0.0825 0.0169 0.5575 0.0656 -184.2246 -195.1362 -3.1654 -3.1605
0.6328 0.07 300 0.6246 -0.0374 -0.2112 0.5875 0.1738 -186.5047 -196.3349 -3.1579 -3.1529
0.5919 0.1 400 0.5978 0.2812 -0.0666 0.6125 0.3478 -185.0590 -193.1489 -3.1292 -3.1243
0.5545 0.12 500 0.5800 0.1742 -0.2810 0.6275 0.4552 -187.2035 -194.2191 -3.0819 -3.0788
0.5926 0.14 600 0.5599 0.2410 -0.3076 0.6425 0.5487 -187.4693 -193.5507 -3.0601 -3.0597
0.5326 0.17 700 0.5385 -0.2501 -0.9698 0.6400 0.7197 -194.0914 -198.4624 -2.9076 -2.9090
0.5126 0.19 800 0.5238 -0.3616 -1.1783 0.6525 0.8167 -196.1764 -199.5769 -2.9965 -2.9963
0.5283 0.22 900 0.5289 -0.4142 -1.2542 0.6775 0.8400 -196.9348 -200.1031 -3.0133 -3.0134
0.5303 0.24 1000 0.5214 -0.5949 -1.5888 0.6600 0.9939 -200.2815 -201.9101 -2.9663 -2.9669
0.5969 0.26 1100 0.5235 -0.5924 -1.5222 0.6600 0.9298 -199.6154 -201.8848 -2.9402 -2.9468
0.581 0.29 1200 0.5887 -0.7548 -1.7075 0.6400 0.9527 -201.4678 -203.5091 -2.7065 -2.7227
0.817 0.31 1300 0.6620 -1.5060 -2.4221 0.6500 0.9160 -208.6137 -211.0213 -2.7717 -2.7800
0.6039 0.34 1400 0.5321 -1.6820 -2.8439 0.6425 1.1619 -212.8325 -212.7814 -2.6828 -2.6917
0.6666 0.36 1500 0.5303 -1.3875 -2.6384 0.6475 1.2509 -210.7773 -209.8365 -2.8557 -2.8594
0.6907 0.39 1600 0.5409 -2.0657 -3.2214 0.6650 1.1556 -216.6068 -216.6184 -2.8227 -2.8288
0.5772 0.41 1700 0.5309 -1.9849 -3.2833 0.6875 1.2985 -217.2264 -215.8097 -2.6498 -2.6635
0.5601 0.43 1800 0.5281 -1.7365 -3.0643 0.6575 1.3278 -215.0359 -213.3255 -2.8890 -2.8918
0.576 0.46 1900 0.5266 -1.4822 -2.9294 0.6725 1.4472 -213.6872 -210.7831 -2.7369 -2.7427
1.2064 0.48 2000 0.5538 -2.5493 -3.7625 0.6675 1.2132 -222.0182 -221.4542 -2.6773 -2.6957
0.5751 0.51 2100 0.5465 -1.9246 -3.1480 0.6425 1.2234 -215.8728 -215.2067 -2.6490 -2.6657
0.4757 0.53 2200 0.5297 -1.8443 -3.1553 0.6325 1.3110 -215.9462 -214.4039 -2.6882 -2.7115
0.4771 0.55 2300 0.5386 -2.3340 -3.7443 0.6500 1.4103 -221.8360 -219.3013 -2.6415 -2.6623
0.481 0.58 2400 0.5355 -1.6085 -3.0800 0.6550 1.4715 -215.1930 -212.0460 -2.6073 -2.6293
0.523 0.6 2500 0.5131 -2.6139 -4.2353 0.6625 1.6214 -226.7459 -222.0998 -2.6134 -2.6394
0.6263 0.63 2600 0.5287 -2.6614 -4.0538 0.6450 1.3924 -224.9310 -222.5747 -2.6189 -2.6361
0.5973 0.65 2700 0.5132 -2.7089 -4.1248 0.625 1.4159 -225.6406 -223.0499 -2.6167 -2.6317
0.8209 0.67 2800 0.5165 -2.7085 -4.1871 0.625 1.4786 -226.2637 -223.0462 -2.5605 -2.5803
0.5625 0.7 2900 0.5117 -3.4747 -5.0369 0.6325 1.5622 -234.7624 -230.7079 -2.5891 -2.6163
0.5913 0.72 3000 0.5164 -2.5844 -4.3822 0.6675 1.7978 -228.2149 -221.8051 -2.6421 -2.6632
0.7441 0.75 3100 0.5175 -2.4900 -4.2883 0.6725 1.7983 -227.2762 -220.8608 -2.6254 -2.6465
0.6169 0.77 3200 0.5163 -2.2489 -3.8666 0.6600 1.6176 -223.0589 -218.4503 -2.6517 -2.6775
0.5347 0.79 3300 0.5222 -2.6699 -4.3844 0.6375 1.7145 -228.2368 -222.6600 -2.6712 -2.6909
0.5369 0.82 3400 0.5244 -2.7710 -4.6352 0.6600 1.8642 -230.7449 -223.6711 -2.5304 -2.5595
0.5613 0.84 3500 0.5431 -3.7645 -5.6773 0.6475 1.9128 -241.1664 -233.6063 -2.5348 -2.5604
0.6395 0.87 3600 0.5332 -3.8666 -5.6894 0.6525 1.8227 -241.2867 -234.6274 -2.5479 -2.5778
0.6552 0.89 3700 0.5149 -2.9168 -4.7306 0.6525 1.8138 -231.6990 -225.1294 -2.4580 -2.4901
0.6381 0.91 3800 0.5081 -2.6182 -4.3003 0.6625 1.6821 -227.3964 -222.1432 -2.4730 -2.4991
0.5355 0.94 3900 0.5100 -2.5302 -4.2476 0.6475 1.7173 -226.8689 -221.2634 -2.5875 -2.6065
0.5488 0.96 4000 0.5164 -3.1540 -4.8339 0.6550 1.6798 -232.7318 -227.5013 -2.7017 -2.7215
0.6802 0.99 4100 0.5134 -2.6060 -4.2916 0.6625 1.6856 -227.3087 -222.0207 -2.6010 -2.6250
0.0976 1.01 4200 0.5031 -3.0885 -5.0494 0.6625 1.9609 -234.8874 -226.8463 -2.4721 -2.5028
0.0839 1.03 4300 0.5027 -3.3469 -5.4366 0.6625 2.0897 -238.7592 -229.4302 -2.3886 -2.4238
0.0788 1.06 4400 0.5398 -4.4307 -6.8568 0.6775 2.4261 -252.9614 -240.2679 -2.1805 -2.2275
0.0701 1.08 4500 0.5432 -4.3739 -7.0979 0.6975 2.7240 -255.3717 -239.7001 -2.1935 -2.2437
0.0959 1.11 4600 0.5362 -3.9784 -6.3235 0.6900 2.3451 -247.6284 -235.7450 -2.2860 -2.3272
0.1177 1.13 4700 0.5411 -4.1933 -6.8436 0.6800 2.6504 -252.8295 -237.8937 -2.3259 -2.3682
0.1651 1.16 4800 0.5737 -4.8158 -6.7229 0.6700 1.9071 -251.6221 -244.1190 -2.2753 -2.3139
0.1298 1.18 4900 0.5528 -4.6526 -6.8433 0.6825 2.1907 -252.8262 -242.4874 -2.4856 -2.5188
0.1143 1.2 5000 0.5512 -4.6212 -7.0807 0.6800 2.4595 -255.2000 -242.1734 -2.5190 -2.5542
0.1145 1.23 5100 0.5496 -4.0598 -6.6147 0.6775 2.5548 -250.5396 -236.5594 -2.5737 -2.6008
0.2324 1.25 5200 0.5524 -4.9650 -7.6613 0.6725 2.6962 -261.0058 -245.6115 -2.4382 -2.4737
0.0867 1.28 5300 0.5449 -4.9568 -7.6771 0.6625 2.7203 -261.1645 -245.5292 -2.4367 -2.4702
0.0503 1.3 5400 0.5351 -4.5684 -7.1860 0.6625 2.6176 -256.2527 -241.6449 -2.4235 -2.4557
0.0977 1.32 5500 0.5431 -4.5599 -7.1317 0.6550 2.5718 -255.7096 -241.5597 -2.5311 -2.5614
0.1564 1.35 5600 0.5512 -5.1430 -8.0510 0.6750 2.9080 -264.9027 -247.3911 -2.3498 -2.3976
0.0967 1.37 5700 0.5520 -4.5072 -7.4506 0.6750 2.9433 -258.8989 -241.0335 -2.2110 -2.2631
0.2046 1.4 5800 0.5588 -5.5328 -8.5314 0.6800 2.9986 -269.7068 -251.2888 -2.2155 -2.2677
0.0985 1.42 5900 0.5429 -5.1915 -7.9421 0.6675 2.7505 -263.8138 -247.8765 -2.2606 -2.3077
0.1398 1.44 6000 0.5350 -4.9761 -7.9378 0.6800 2.9616 -263.7706 -245.7224 -2.2291 -2.2809
0.099 1.47 6100 0.5440 -4.6202 -7.4996 0.6650 2.8794 -259.3892 -242.1633 -2.3362 -2.3859
0.1279 1.49 6200 0.5389 -4.9461 -7.7908 0.6625 2.8448 -262.3015 -245.4217 -2.2276 -2.2734
0.0778 1.52 6300 0.5451 -4.9550 -7.8964 0.6625 2.9414 -263.3570 -245.5110 -2.4781 -2.5193
0.0911 1.54 6400 0.5412 -5.4552 -8.3139 0.6675 2.8588 -267.5324 -250.5128 -2.3604 -2.4048
0.2149 1.56 6500 0.5241 -4.4512 -7.3194 0.6725 2.8682 -257.5873 -240.4732 -2.4011 -2.4461
0.1739 1.59 6600 0.5329 -5.0143 -7.7507 0.6825 2.7364 -261.8999 -246.1036 -2.4143 -2.4577
0.0842 1.61 6700 0.5395 -5.1195 -8.0856 0.6800 2.9661 -265.2489 -247.1560 -2.3877 -2.4376
0.105 1.64 6800 0.5423 -4.9379 -7.7557 0.6775 2.8178 -261.9503 -245.3403 -2.3798 -2.4323
0.086 1.66 6900 0.5351 -4.3598 -7.1156 0.6775 2.7559 -255.5494 -239.5588 -2.3870 -2.4383
0.0622 1.68 7000 0.5394 -4.6830 -7.6578 0.6825 2.9747 -260.9710 -242.7915 -2.4276 -2.4779
0.0973 1.71 7100 0.5319 -4.7475 -7.6567 0.6750 2.9091 -260.9596 -243.4364 -2.3010 -2.3564
0.1052 1.73 7200 0.5284 -4.5972 -7.5385 0.6750 2.9413 -259.7779 -241.9329 -2.3696 -2.4201
0.0645 1.76 7300 0.5339 -4.9822 -8.0212 0.6775 3.0390 -264.6048 -245.7831 -2.2857 -2.3440
0.0923 1.78 7400 0.5385 -4.6369 -7.6632 0.6650 3.0263 -261.0246 -242.3295 -2.2563 -2.3150
0.0842 1.81 7500 0.5394 -4.8705 -7.6765 0.6600 2.8060 -261.1580 -244.6661 -2.2808 -2.3287
0.1178 1.83 7600 0.5253 -4.7985 -7.5635 0.6675 2.7650 -260.0276 -243.9457 -2.4022 -2.4463
0.1255 1.85 7700 0.5355 -4.7007 -7.4363 0.6675 2.7355 -258.7556 -242.9684 -2.5073 -2.5501
0.1541 1.88 7800 0.5440 -4.9294 -7.6465 0.6500 2.7172 -260.8584 -245.2547 -2.3551 -2.4036
0.0893 1.9 7900 0.5397 -5.2135 -8.3241 0.6575 3.1106 -267.6339 -248.0959 -2.3214 -2.3784
0.1203 1.93 8000 0.5296 -4.8644 -7.8598 0.6550 2.9954 -262.9913 -244.6054 -2.4509 -2.4969
0.1018 1.95 8100 0.5381 -5.3471 -8.4918 0.6625 3.1447 -269.3113 -249.4323 -2.4193 -2.4671
0.0767 1.97 8200 0.5386 -5.2151 -8.3734 0.6675 3.1582 -268.1267 -248.1124 -2.4873 -2.5329
0.0801 2.0 8300 0.5429 -5.8103 -9.0391 0.6575 3.2288 -274.7842 -254.0639 -2.4348 -2.4867
0.034 2.02 8400 0.5566 -5.7907 -9.2424 0.6625 3.4518 -276.8175 -253.8677 -2.3679 -2.4272
0.0246 2.05 8500 0.5758 -5.6317 -9.1533 0.6625 3.5216 -275.9264 -252.2783 -2.3335 -2.3958
0.0187 2.07 8600 0.5770 -5.5795 -9.2568 0.6725 3.6773 -276.9613 -251.7559 -2.3614 -2.4166
0.0606 2.09 8700 0.6115 -7.1190 -11.2853 0.6750 4.1663 -297.2460 -267.1512 -2.2737 -2.3365
0.0402 2.12 8800 0.6164 -7.0531 -11.1316 0.6600 4.0785 -295.7089 -266.4919 -2.2005 -2.2654
0.0263 2.14 8900 0.6209 -8.1609 -12.3710 0.6650 4.2102 -308.1034 -277.5696 -2.0958 -2.1661
0.0242 2.17 9000 0.6042 -6.7201 -10.7618 0.6725 4.0416 -292.0106 -263.1622 -2.1651 -2.2304
0.0383 2.19 9100 0.6080 -7.7898 -11.9356 0.6750 4.1458 -303.7489 -273.8587 -2.1006 -2.1662
0.0371 2.21 9200 0.6149 -7.5635 -11.7050 0.6675 4.1415 -301.4433 -271.5960 -2.1556 -2.2155
0.0279 2.24 9300 0.6155 -8.1686 -12.4447 0.6775 4.2760 -308.8397 -277.6473 -2.1778 -2.2399
0.021 2.26 9400 0.6137 -7.8294 -12.0416 0.6700 4.2122 -304.8092 -274.2550 -2.2403 -2.2958
0.0374 2.29 9500 0.6238 -7.9227 -12.2842 0.6750 4.3614 -307.2347 -275.1884 -2.2926 -2.3496
0.0412 2.31 9600 0.6126 -7.7094 -11.9775 0.6700 4.2681 -304.1685 -273.0553 -2.2377 -2.2961
0.0413 2.33 9700 0.6130 -7.6030 -11.8721 0.6675 4.2691 -303.1140 -271.9912 -2.2505 -2.3100
0.0361 2.36 9800 0.6248 -8.1273 -12.6010 0.6750 4.4737 -310.4034 -277.2341 -2.2249 -2.2866
0.0289 2.38 9900 0.6192 -7.9924 -12.3825 0.6675 4.3901 -308.2185 -275.8853 -2.2473 -2.3067
0.038 2.41 10000 0.6250 -8.4114 -12.8701 0.6675 4.4586 -313.0937 -280.0753 -2.2312 -2.2938
0.0334 2.43 10100 0.6261 -9.1807 -13.7488 0.6825 4.5681 -321.8813 -287.7679 -2.2303 -2.2947
0.0359 2.45 10200 0.6374 -9.8214 -14.2774 0.6650 4.4560 -327.1667 -294.1750 -2.1817 -2.2452
0.0266 2.48 10300 0.6298 -8.3278 -12.5691 0.6650 4.2413 -310.0836 -279.2391 -2.2947 -2.3521
0.0423 2.5 10400 0.6267 -8.7527 -13.2552 0.6675 4.5025 -316.9453 -283.4879 -2.3034 -2.3620
0.0329 2.53 10500 0.6386 -8.9354 -13.5549 0.6700 4.6195 -319.9424 -285.3152 -2.2819 -2.3423
0.039 2.55 10600 0.6330 -8.3549 -12.8863 0.6775 4.5314 -313.2566 -279.5103 -2.2924 -2.3528
0.0278 2.58 10700 0.6336 -8.6754 -13.1733 0.6675 4.4979 -316.1258 -282.7150 -2.2319 -2.2929
0.0606 2.6 10800 0.6299 -8.7158 -13.0817 0.6700 4.3658 -315.2101 -283.1195 -2.2116 -2.2731
0.0293 2.62 10900 0.6259 -8.9092 -13.2926 0.6725 4.3834 -317.3194 -285.0532 -2.1572 -2.2209
0.0196 2.65 11000 0.6219 -9.1783 -13.5617 0.6700 4.3835 -320.0104 -287.7436 -2.1533 -2.2163
0.0405 2.67 11100 0.6209 -8.9912 -13.3040 0.6700 4.3128 -317.4330 -285.8734 -2.1378 -2.2017
0.0278 2.7 11200 0.6300 -9.8318 -14.2684 0.6700 4.4366 -327.0771 -294.2787 -2.1220 -2.1862
0.0307 2.72 11300 0.6356 -9.7027 -14.1764 0.6700 4.4737 -326.1576 -292.9880 -2.1316 -2.1945
0.0242 2.74 11400 0.6327 -9.8085 -14.2574 0.6625 4.4489 -326.9674 -294.0465 -2.1072 -2.1680
0.0242 2.77 11500 0.6308 -9.3697 -13.8420 0.6650 4.4723 -322.8135 -289.6585 -2.1273 -2.1882
0.0337 2.79 11600 0.6350 -9.2810 -13.7917 0.6700 4.5107 -322.3100 -288.7711 -2.1600 -2.2215
0.0302 2.82 11700 0.6450 -10.2754 -14.9521 0.6675 4.6767 -333.9139 -298.7146 -2.1339 -2.1965
0.0354 2.84 11800 0.6451 -10.3736 -15.0743 0.6725 4.7008 -335.1366 -299.6965 -2.1047 -2.1674
0.0153 2.86 11900 0.6420 -10.2126 -14.9126 0.6700 4.7000 -333.5196 -298.0872 -2.1102 -2.1728
0.0388 2.89 12000 0.6407 -10.2075 -14.9081 0.6725 4.7006 -333.4741 -298.0356 -2.1059 -2.1687
0.0253 2.91 12100 0.6353 -10.0842 -14.7598 0.6650 4.6756 -331.9908 -296.8029 -2.0968 -2.1594
0.0317 2.94 12200 0.6352 -9.9956 -14.6819 0.6750 4.6863 -331.2123 -295.9169 -2.1042 -2.1665
0.0431 2.96 12300 0.6337 -9.8807 -14.5540 0.6675 4.6733 -329.9332 -294.7676 -2.1034 -2.1660
0.0233 2.98 12400 0.6326 -9.8796 -14.5449 0.6675 4.6653 -329.8422 -294.7567 -2.1032 -2.1657
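
Each row of the log above has twelve whitespace-separated values in the order given by the header. As a minimal sketch, the snippet below parses a handful of rows copied verbatim from the table and picks the one with the lowest validation loss. The column keys and the `parse_rows` helper are illustrative names, not part of the training code.

```python
# A few rows copied verbatim from the training results (not the full log).
SAMPLE_ROWS = """\
0.6817 0.02 100 0.6873 0.0149 0.0002 0.5150 0.0147 -184.3912 -195.8124 -3.1605 -3.1560
0.0839 1.03 4300 0.5027 -3.3469 -5.4366 0.6625 2.0897 -238.7592 -229.4302 -2.3886 -2.4238
0.0801 2.0 8300 0.5429 -5.8103 -9.0391 0.6575 3.2288 -274.7842 -254.0639 -2.4348 -2.4867
0.0233 2.98 12400 0.6326 -9.8796 -14.5449 0.6675 4.6653 -329.8422 -294.7567 -2.1032 -2.1657
"""

# Column order as given in the table header.
COLUMNS = [
    "training_loss", "epoch", "step", "validation_loss",
    "rewards_chosen", "rewards_rejected", "rewards_accuracies",
    "rewards_margins", "logps_rejected", "logps_chosen",
    "logits_rejected", "logits_chosen",
]


def parse_rows(text: str) -> list[dict]:
    """Turn whitespace-separated log rows into dictionaries keyed by column."""
    rows = []
    for line in text.strip().splitlines():
        values = [float(v) for v in line.split()]
        if len(values) != len(COLUMNS):
            raise ValueError(f"expected {len(COLUMNS)} values, got {len(values)}")
        row = dict(zip(COLUMNS, values))
        row["step"] = int(row["step"])
        rows.append(row)
    return rows


rows = parse_rows(SAMPLE_ROWS)
best = min(rows, key=lambda r: r["validation_loss"])
print(best["step"], best["validation_loss"])  # 4300 0.5027
```

In the full log, the validation loss also reaches its minimum of 0.5027 at step 4300, early in the second epoch. After that point the training loss keeps falling while the validation loss climbs back above 0.6, even as the reward margin continues to grow.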

Framework versions

  • Transformers 4.36.2
  • PyTorch 2.1.2+cu121
  • Datasets 2.16.1
  • Tokenizers 0.15.0