`FlashAttention is not installed` error on Windows 11

#87
by ocean11 - opened

I am using Windows 11 and successfully installed flash-attn, as shown in the picture below, but I still get this RuntimeError: FlashAttention is not installed. Does that mean flash-attention is not supported on Windows?

(screenshot: flash-attn successfully installed)

It seems you also need to install other dependencies (e.g. triton).
If you look at the rotary.py file, you can see that the RuntimeError: FlashAttention is not installed exception is raised when from flash_attn.ops.triton.rotary import apply_rotary fails.
That import requires both flash-attention and triton, so I guess you should also install triton by running pip install triton.
(screenshot, 2024-10-20: the relevant import in rotary.py)
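
For context, the guard in the model's rotary.py looks roughly like the sketch below (simplified, not the verbatim source): the Triton-backed kernel import is wrapped in a try/except, and on failure apply_rotary is replaced with a stub that raises this RuntimeError. So the error fires whenever that import fails, even if a flash-attn wheel is installed but triton is missing or broken.

# Simplified sketch of the guard in rotary.py (assumed structure, not verbatim)
try:
    # requires both flash-attn and a working triton installation
    from flash_attn.ops.triton.rotary import apply_rotary
except ImportError:
    def apply_rotary(*args, **kwargs):
        raise RuntimeError(
            "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
            "For inference, you have two options: either install FlashAttention or disable it "
            "by setting use_flash_attn=False when loading the model."
        )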

The error is as follows:

"name": "RuntimeError",
    "message": "FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 2
      1 print(len(chunks))
----> 2 chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True, batch_size=1)
      3 # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity
      4 top_k = min(3, len(chunks))

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:623, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
    620 features.update(extra_features)
    622 with torch.no_grad():
--> 623     out_features = self.forward(features, **kwargs)
    624     if self.device.type == \"hpu\":
    625         out_features = copy.deepcopy(out_features)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:690, in SentenceTransformer.forward(self, input, **kwargs)
    688     module_kwarg_keys = self.module_kwargs.get(module_name, [])
    689     module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 690     input = module(input, **module_kwargs)
    691 return input

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\jina-embeddings-v3\\30996fea06f69ecd8382ee4f11e29acaf6b5405e\\custom_st.py:143, in Transformer.forward(self, features, task)
    139 lora_arguments = (
    140     {\"adapter_mask\": adapter_mask} if adapter_mask is not None else {}
    141 )
    142 features.pop('prompt_length', None)
--> 143 output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
    144 output_tokens = output_states[0]
    145 features.update({\"token_embeddings\": output_tokens, \"attention_mask\": features[\"attention_mask\"]})

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\modeling_lora.py:370, in XLMRobertaLoRA.forward(self, *args, **kwargs)
    369 def forward(self, *args, **kwargs):
--> 370     return self.roberta(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\modeling_xlm_roberta.py:709, in XLMRobertaModel.forward(self, input_ids, position_ids, token_type_ids, attention_mask, masked_tokens_mask, return_dict, **kwargs)
    706 else:
    707     subset_mask = None
--> 709 sequence_output = self.encoder(
    710     hidden_states,
    711     key_padding_mask=attention_mask,
    712     subset_mask=subset_mask,
    713     adapter_mask=adapter_mask,
    714 )
    716 if masked_tokens_mask is None:
    717     pooled_output = (
    718         self.pooler(sequence_output, adapter_mask=adapter_mask)
    719         if self.pooler is not None
    720         else None
    721     )

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\modeling_xlm_roberta.py:241, in XLMRobertaEncoder.forward(self, hidden_states, key_padding_mask, subset_mask, adapter_mask)
    234             hidden_states = torch.utils.checkpoint.checkpoint(
    235                 layer,
    236                 hidden_states,
    237                 use_reentrant=self.use_reentrant,
    238                 mixer_kwargs=mixer_kwargs,
    239             )
    240         else:
--> 241             hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
    242     hidden_states = pad_input(hidden_states, indices, batch, seqlen)
    243 else:

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\block.py:201, in Block.forward(self, hidden_states, residual, mixer_subset, mixer_kwargs)
    199 else:
    200     assert residual is None
--> 201     mixer_out = self.mixer(
    202         hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
    203     )
    204     if self.return_residual:  # mixer out is actually a pair here
    205         mixer_out, hidden_states = mixer_out

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\mha.py:732, in MHA.forward(self, x, x_kv, key_padding_mask, cu_seqlens, max_seqlen, mixer_subset, inference_params, adapter_mask, **kwargs)
    725 if (
    726     inference_params is None
    727     or inference_params.seqlen_offset == 0
    728     or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
    729     or not self.use_flash_attn
    730 ):
    731     if self.rotary_emb_dim > 0:
--> 732         qkv = self.rotary_emb(
    733             qkv,
    734             seqlen_offset=seqlen_offset,
    735             cu_seqlens=cu_seqlens,
    736             max_seqlen=rotary_max_seqlen,
    737         )
    738     if inference_params is None:
    739         if not self.checkpointing:

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\
n\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:604, in RotaryEmbedding.forward(self, qkv, kv, seqlen_offset, cu_seqlens, max_seqlen)
    602 if kv is None:
    603     if self.scale is None:
--> 604         return apply_rotary_emb_qkv_(
    605             qkv,
    606             self._cos_cached,
    607             self._sin_cached,
    608             interleaved=self.interleaved,
    609             seqlen_offsets=seqlen_offset,
    610             cu_seqlens=cu_seqlens,
    611             max_seqlen=max_seqlen,
    612             use_flash_attn=self.use_flash_attn,
    613         )
    614     else:
    615         return apply_rotary_emb_qkv_(
    616             qkv,
    617             self._cos_cached,
   (...)
    625             use_flash_attn=self.use_flash_attn,
    626         )

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:327, in apply_rotary_emb_qkv_(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
    297 def apply_rotary_emb_qkv_(
    298     qkv,
    299     cos,
   (...)
    307     use_flash_attn=True,
    308 ):
    309     \"\"\"
    310     Arguments:
    311         qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
   (...)
    325     Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    326     \"\"\"
--> 327     return ApplyRotaryEmbQKV_.apply(
    328         qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
    329     )

File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\autograd\\function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         \"In order to use an autograd.Function with functorch transforms \"
    580         \"(vmap, grad, jvp, jacrev, ...), it must override the setup_context \"
    581         \"staticmethod. For more details, please see \"
    582         \"https://pytorch.org/docs/main/notes/extending.func.html\"
    583     )

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:186, in ApplyRotaryEmbQKV_.forward(ctx, qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
    184     qk = rearrange(qkv[..., :2, :, :], \"... t h d -> ... (t h) d\")
    185     # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
--> 186     apply_rotary(
    187         qk,
    188         cos,
    189         sin,
    190         seqlen_offsets=seqlen_offsets,
    191         interleaved=interleaved,
    192         inplace=True,
    193         cu_seqlens=cu_seqlens,
    194         max_seqlen=max_seqlen,
    195     )
    196 else:
    197     q_rot = apply_rotary_emb_torch(
    198         qkv[:, :, 0],
    199         cos,
    200         sin,
    201         interleaved=interleaved,
    202     )

File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:18, in apply_rotary(*args, **kwargs)
     17 def apply_rotary(*args, **kwargs):
---> 18     raise RuntimeError(
     19         \"FlashAttention is not installed. To proceed with training, please install FlashAttention. \"
     20         \"For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.\"
     21     )

RuntimeError: FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.

Jina AI org

Hi @ocean11, as far as I know flash-attention is not fully supported on Windows, which is why you're running into this issue. You can disable it by setting use_flash_attn=False when loading the model.
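
If you load the model with plain transformers rather than sentence-transformers, the flag can be passed at load time. A minimal sketch, assuming the extra keyword argument is forwarded to the model's config (which is how the custom implementation reads use_flash_attn):

from transformers import AutoModel

# use_flash_attn=False skips the FlashAttention/Triton code path for inference
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    use_flash_attn=False,
)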

How do I set that in sentence-transformers? There seems to be no place to pass use_flash_attn=False.
Previously, if I did not install flash attention, the model would just print a few warning lines about not using it and run anyway.

import torch
from sentence_transformers import SentenceTransformer
embedder = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)

chunks = [...]
queries = [...]
# Use "convert_to_tensor=True" to keep the tensors on GPU (if available)
chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True, batch_size=1)
# Find the closest 3 sentences of the corpus for each query sentence based on cosine similarity
top_k = min(3, len(chunks))
for query in queries:
    query_embedding = embedder.encode(query, convert_to_tensor=True)
    # We use cosine-similarity and torch.topk to find the highest top_k scores
    similarity_scores = embedder.similarity(query_embedding, chunks_embeddings)[0]
    scores, indices = torch.topk(similarity_scores, k=top_k)
    print("\n\033[91m--- Query:", query, "---\033[0m")
    print("Top 3 most similar sentences in corpus:")
...
Jina AI org

For SentenceTransformers you can set it like this:
embedder = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True, model_kwargs={'use_flash_attn': False})
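
(model_kwargs is forwarded by sentence-transformers to the underlying Hugging Face from_pretrained call, so this is the same use_flash_attn=False switch mentioned in the error message, just passed through the SentenceTransformer wrapper.)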


Problem solved as suggested. Thanks a lot!

jupyterjazz changed discussion status to closed
