multimodalart (HF staff) committed
Commit 8079453
1 Parent(s): a5e4f9a

Update lora.py

Files changed (1)
  1. lora.py +7 -2
lora.py CHANGED
@@ -114,7 +114,7 @@ class LoRAModule(torch.nn.Module):
 
         lx = self.lora_up(lx)
 
-        return org_forwarded + lx * self.multiplier #* scale
+        return org_forwarded + lx * self.multiplier * scale
 
 
 class LoRAInfModule(LoRAModule):
@@ -219,7 +219,12 @@ class LoRAInfModule(LoRAModule):
 
     def default_forward(self, x):
         # print("default_forward", self.lora_name, x.size())
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier #* self.scale
+        org_forward = self.org_forward(x)
+        lora_up_down = self.lora_up(self.lora_down(x))
+        print(org_forward)
+        print(lora_up_down)
+        print(self.multiplier)
+        return org_forward + lora_up_down * self.multiplier #* self.scale
 
     def forward(self, x):
         if not self.enabled:
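
For context, a minimal sketch of the pattern both hunks touch: a LoRA module wraps a frozen base layer and adds the low-rank correction lora_up(lora_down(x)), weighted by multiplier and, after the first hunk, by scale as well. MiniLoRA, rank, and alpha below are illustrative names and assumptions, not taken from this repository's lora.py.

    import torch

    class MiniLoRA(torch.nn.Module):
        # Hypothetical stand-in for LoRAModule: output is the frozen base
        # layer's result plus the low-rank delta, weighted by multiplier * scale.
        def __init__(self, org_module, rank=4, multiplier=1.0, alpha=4.0):
            super().__init__()
            self.org_module = org_module  # frozen base layer (e.g. an nn.Linear)
            self.lora_down = torch.nn.Linear(org_module.in_features, rank, bias=False)
            self.lora_up = torch.nn.Linear(rank, org_module.out_features, bias=False)
            self.multiplier = multiplier
            self.scale = alpha / rank  # common LoRA convention; assumed, not from this file

        def forward(self, x):
            # Mirrors the patched return in the first hunk:
            # base output + lora_up(lora_down(x)) * multiplier * scale
            return self.org_module(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    base = torch.nn.Linear(16, 16)
    lora = MiniLoRA(base, rank=4)
    y = lora(torch.randn(2, 16))  # same shape as base(x), with the LoRA delta added

Against this shape, the first hunk re-enables the scale factor on the training-time forward, while the second hunk leaves the inference-time math unchanged (#* self.scale stays commented out) and only splits the return into intermediates with debug prints.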