phoebeklett committed
Commit 2bd703d
1 Parent(s): dd6bd5b

Delete blocks.py

Files changed (1)
  1. blocks.py +0 -120
blocks.py DELETED
@@ -1,120 +0,0 @@
-# Adapted from https://github.com/mosaicml/llm-foundry
-# Classes changed: MPTBlock
-# SPDX-License-Identifier: Apache-2.0
-
-"""GPT Blocks used for the GPT Model."""
-
-from typing import Dict, Optional, Tuple
-import torch
-import torch.nn as nn
-from .attention import ATTN_CLASS_REGISTRY
-from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
-
-class MPTMLP(nn.Module):
-
-    def __init__(self,
-                 d_model: int,
-                 expansion_ratio: int,
-                 device: Optional[str] = None):
-        super().__init__()
-        self.up_proj = nn.Linear(d_model,
-                                 expansion_ratio * d_model,
-                                 device=device)
-        self.act = nn.GELU(approximate='none')
-        self.down_proj = nn.Linear(expansion_ratio * d_model,
-                                   d_model,
-                                   device=device)
-        self.down_proj._is_residual = True # type: ignore
-
-    def forward(self, x):
-        return self.down_proj(self.act(self.up_proj(x)))
-
-class MPTBlock(nn.Module):
-    def __init__(
-        self,
-        d_model: int,
-        n_heads: int,
-        expansion_ratio: int,
-        attn_config: Dict = {
-            'attn_type': 'multihead_attention',
-            'attn_pdrop': 0.0,
-            'attn_impl': 'triton',
-            'qk_ln': False,
-            'clip_qkv': None,
-            'softmax_scale': None,
-            'prefix_lm': False,
-            'attn_uses_sequence_id': False,
-            'alibi': False,
-            'alibi_bias_max': 8,
-        },
-        resid_pdrop: float = 0.0,
-        norm_type: str = 'low_precision_layernorm',
-        verbose: int = 0,
-        device: Optional[str] = None,
-        **kwargs):
-        del kwargs # unused, just to capture any extra args from the config
-        super().__init__()
-
-        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
-
-        self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(
-            attn_impl=attn_config['attn_impl'],
-            clip_qkv=attn_config['clip_qkv'],
-            qk_ln=attn_config['qk_ln'],
-            softmax_scale=attn_config['softmax_scale'],
-            attn_pdrop=attn_config['attn_pdrop'],
-            d_model=d_model,
-            n_heads=n_heads,
-            verbose=verbose,
-            device=device,
-        )
-        self.norm_2 = norm_class(d_model, device=device)
-        self.ffn = MPTMLP(
-            d_model=d_model,
-            expansion_ratio=expansion_ratio,
-            device=device,
-        )
-        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
-        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        long_range_past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attn_bias: Optional[torch.Tensor] = None,
-        attn_bias_ae: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.ByteTensor] = None,
-        is_causal: bool = True,
-        topk: int = None,
-        needs_weights: bool = None,
-        faiss_indexes: Tuple = None,
-        n_layers: int = None,
-        current_layer: int = None,
-        mask_by_sim: bool = False,
-        sim_threshold: float = None
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
-        a = self.norm_1(x)
-        b, attn_weights, past_key_value, reshaped_idx = self.attn(
-            a,
-            past_key_value=past_key_value,
-            long_range_past_key_value=long_range_past_key_value,
-            attn_bias=attn_bias,
-            attn_bias_ae=attn_bias_ae,
-            attention_mask=attention_mask,
-            is_causal=is_causal,
-            topk=topk,
-            needs_weights=needs_weights,
-            faiss_indexes=faiss_indexes,
-            n_layers=n_layers,
-            current_layer=current_layer,
-            mask_by_sim=mask_by_sim,
-            sim_threshold=sim_threshold
-        )
-        x = x + self.resid_attn_dropout(b)
-        m = self.norm_2(x)
-        n = self.ffn(m)
-        x = x + self.resid_ffn_dropout(n)
-        return x, attn_weights, past_key_value, reshaped_idx
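
For context, the deleted MPTBlock follows the standard pre-norm residual layout: normalize, attend, add the dropout-regularized result back to the stream, then normalize, apply the MLP, and add again. The sketch below is not the deleted code; it is a minimal illustration of that layout in which torch.nn.MultiheadAttention stands in for the attention classes the file pulled from ATTN_CLASS_REGISTRY, and the long-range / faiss-memory arguments of the real forward() are omitted.

import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Minimal pre-norm block mirroring MPTBlock's forward structure.

    NOTE: illustrative sketch only. nn.MultiheadAttention is a stand-in for
    the ATTN_CLASS_REGISTRY attention classes, and the memory/retrieval
    arguments of the original forward() are not reproduced.
    """

    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int = 4,
                 resid_pdrop: float = 0.0):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_2 = nn.LayerNorm(d_model)
        # Same shape as MPTMLP: up-project, GELU, down-project.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, expansion_ratio * d_model),
            nn.GELU(approximate='none'),
            nn.Linear(expansion_ratio * d_model, d_model),
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.norm_1(x)                              # pre-norm before attention
        b, _ = self.attn(a, a, a, need_weights=False)   # self-attention on the normalized stream
        x = x + self.resid_attn_dropout(b)              # first residual connection
        m = self.norm_2(x)                              # pre-norm before the MLP
        x = x + self.resid_ffn_dropout(self.ffn(m))     # second residual connection
        return x


if __name__ == "__main__":
    block = PreNormBlock(d_model=64, n_heads=4)
    out = block(torch.randn(2, 16, 64))  # (batch, seq_len, d_model)
    print(out.shape)                     # torch.Size([2, 16, 64])

The deleted block differs mainly in its attention module, which additionally consumes cached and retrieved key/value memories (past_key_value, long_range_past_key_value, faiss_indexes) and returns attention weights and retrieval indices alongside the hidden states.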