DavidGF committed on
Commit
cd357c3
1 Parent(s): 0a2cee8

Update modeling_phi.py

Browse files
Files changed (1) hide show
  1. modeling_phi.py +65 -18
modeling_phi.py CHANGED
@@ -928,27 +928,74 @@ class PhiModel(PhiPreTrainedModel):
928
  return hidden_states
929
 
930
  group_definitions = [
931
- list(range(0, 4)), # Indices for the first group
932
- list(range(4, 8)), # Indices for the second group, with overlap
933
- list(range(8, 12)), # Indices for the second group, with overlap
934
- list(range(12, 16)), # Indices for the second group, with overlap
935
- list(range(16, 20)), # Indices for the second group, with overlap
936
- list(range(20, 24)), # Indices for the second group, with overlap
937
- list(range(24, 28)), # Indices for the second group, with overlap
938
- list(range(28, 32)), # Indices for the second group, with overlap
939
- # Add more groups as needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
940
  ]
941
 
942
  repetitions = [
943
- 1, # Repetitions for the first group
944
- 2, # Repetitions for the second group
945
- 2, # Repetitions for the second group
946
- 2, # Repetitions for the second group
947
- 2, # Repetitions for the second group
948
- 2, # Repetitions for the second group
949
- 2, # Repetitions for the second group
950
- 1, # Repetitions for the second group
951
- # Add more repetition counts as needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
  ]
953
 
954
  class AdvancedSharedLayerModule(nn.Module):
 
928
  return hidden_states
929
 
930
  group_definitions = [
931
+ list(range(0, 1)), # Indices for the 1st group
932
+ list(range(1, 2)),
933
+ list(range(2, 3)),
934
+ list(range(3, 4)),
935
+ list(range(4, 5)),
936
+ list(range(5, 6)),
937
+ list(range(6, 7)),
938
+ list(range(7, 8)),
939
+ list(range(8, 9)),
940
+ list(range(9, 10)),
941
+ list(range(10, 11)),
942
+ list(range(11, 12)),
943
+ list(range(12, 13)),
944
+ list(range(13, 14)),
945
+ list(range(14, 15)),
946
+ list(range(15, 16)),
947
+ list(range(16, 17)),
948
+ list(range(17, 18)),
949
+ list(range(18, 19)),
950
+ list(range(19, 20)),
951
+ list(range(20, 21)),
952
+ list(range(21, 22)),
953
+ list(range(22, 23)),
954
+ list(range(23, 24)),
955
+ list(range(24, 25)),
956
+ list(range(25, 26)),
957
+ list(range(26, 27)),
958
+ list(range(27, 28)),
959
+ list(range(28, 29)),
960
+ list(range(29, 30)),
961
+ list(range(30, 31)),
962
+ list(range(31, 32)),
963
  ]
964
 
965
  repetitions = [
966
+ 1,
967
+ 2,
968
+ 2,
969
+ 2,
970
+ 2,
971
+ 2,
972
+ 2,
973
+ 2,
974
+ 2,
975
+ 2,
976
+ 2,
977
+ 2,
978
+ 2,
979
+ 2,
980
+ 2,
981
+ 2,
982
+ 2,
983
+ 2,
984
+ 2,
985
+ 2,
986
+ 2,
987
+ 2,
988
+ 2,
989
+ 2,
990
+ 2,
991
+ 2,
992
+ 2,
993
+ 2,
994
+ 2,
995
+ 2,
996
+ 2,
997
+ 1,
998
+
999
  ]
1000
 
1001
  class AdvancedSharedLayerModule(nn.Module):