ChenyangSi commited on
Commit
dd97a63
·
1 Parent(s): 97dc735

Update free_lunch_utils.py

Browse files
Files changed (1) hide show
  1. free_lunch_utils.py +25 -2
free_lunch_utils.py CHANGED
@@ -93,13 +93,36 @@ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
93
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
94
  #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
95
 
 
 
 
 
 
 
 
 
 
 
96
  # --------------- FreeU code -----------------------
97
  # Only operate on the first two stages
98
  if hidden_states.shape[1] == 1280:
99
- hidden_states[:,:640] = hidden_states[:,:640] * self.b1
 
 
 
 
 
 
 
100
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
  if hidden_states.shape[1] == 640:
102
- hidden_states[:,:320] = hidden_states[:,:320] * self.b2
 
 
 
 
 
 
103
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
  # ---------------------------------------------------------
105
 
 
93
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
94
  #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
95
 
96
+ # # --------------- FreeU code -----------------------
97
+ # # Only operate on the first two stages
98
+ # if hidden_states.shape[1] == 1280:
99
+ # hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
+ # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
+ # if hidden_states.shape[1] == 640:
102
+ # hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
+ # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
+ # # ---------------------------------------------------------
105
+
106
  # --------------- FreeU code -----------------------
107
  # Only operate on the first two stages
108
  if hidden_states.shape[1] == 1280:
109
+ hidden_mean = hidden_states.mean(1).unsqueeze(1)
110
+ B = hidden_mean.shape[0]
111
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
112
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
113
+
114
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
115
+
116
+ hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
117
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
118
  if hidden_states.shape[1] == 640:
119
+ hidden_mean = hidden_states.mean(1).unsqueeze(1)
120
+ B = hidden_mean.shape[0]
121
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
122
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
123
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
124
+
125
+ hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
126
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
127
  # ---------------------------------------------------------
128