jupyterjazz committed
Commit dc4080e (parent: 169b7fb)

fix: override use_flash_attn in lora

Files changed (1): modeling_lora.py (+1, -4)
modeling_lora.py CHANGED
@@ -322,12 +322,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
         if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
             )
         else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
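
The change forwards the use_flash_attn flag from the resolved config to super().from_pretrained() on the path where the checkpoint already contains trained LoRA adapters, so a flash-attention override is no longer dropped there. Below is a minimal usage sketch of the intended effect, not taken from this repository: it assumes the jinaai remote-code implementation with its XLMRobertaFlashConfig.use_flash_attn flag, and the model id is purely illustrative.

# Minimal sketch, assuming a checkpoint whose config has load_trained_adapters=True
# and that extra from_pretrained kwargs matching config attributes (use_flash_attn)
# are applied to the loaded config, as in standard transformers behavior.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3",  # illustrative model id, not from this commit
    trust_remote_code=True,
    use_flash_attn=False,         # with this fix, the override reaches the inner from_pretrained call
)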