This post explains flash attention [1] [2]. Other references [3] [4] [5] are also helpful for understanding flash attention.

## Backgrounds

### Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This equation can be implemented as:

```python
class OPTAttention(nn.Module):
    def forward(...):
        # hidden_states is an input tensor of the Attention layer.
        # Calculate Q, K, and V with linear projections of the input.
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
```
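
For reference, the equation above can also be written as a standalone, non-fused computation. The sketch below is illustrative only: the function name `naive_attention` and the tensor shapes are my own assumptions, not part of the OPTAttention code above.

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    # q, k, v: (batch, seq_len, d_k); illustrative shapes, not the OPT layout
    d_k = q.size(-1)
    # Scores = QK^T / sqrt(d_k), shape (batch, seq_len, seq_len)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # Row-wise softmax over keys, then weight the values
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v)

# Usage with random tensors
q = torch.randn(2, 128, 64)
k = torch.randn(2, 128, 64)
v = torch.randn(2, 128, 64)
out = naive_attention(q, k, v)  # (2, 128, 64)
```

Note that this materializes the full (seq_len, seq_len) score matrix, which is exactly the memory cost that flash attention avoids.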