Website Source: blog / sdpa_optim
Summary
Pending synthesis from local website source.
Original source title: This is a clever fix:
Extracted Preview
Note: I wrote this article because I wanted to learn PyTorch internals.
SDPA Optimization - Introduction
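Scaled Dot Product Attention (abbreviated as SDPA) is an attention mechanism in which the dot products between query and key vectors are scaled down by sqrt(d_k), where d_k is the key dimension, before a softmax turns them into weights over the value vectors. The attention output is calculated as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$
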
SDPA (the attention function at the core of self-attention) was a revolutionary idea, first introduced in the "[Attention Is All You Need](https://arxiv.org/pdf/1706.03762)" paper, and it forms the backbone of modern NLP applications. Unlike recurrent models, SDPA processes all positions of the input sequence in parallel, which speeds up computation while still capturing meaningful relationships between tokens.
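To make the computation described above concrete, here is a minimal, deliberately naive sketch in plain Python (no PyTorch). It assumes a single attention head, no masking, and Q, K, V given as lists of row vectors; the helper names (`matmul`, `transpose`, `softmax_rows`, `naive_sdpa`) are hypothetical and only for illustration. Note that it builds the entire query-by-key score matrix before applying the softmax.

```python
import math


def matmul(a, b):
    # (n x k) @ (k x m) -> (n x m), using plain nested lists
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def softmax_rows(m):
    out = []
    for row in m:
        mx = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def naive_sdpa(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single head, no mask (hypothetical helper)."""
    d_k = len(q[0])
    scores = matmul(q, transpose(k))  # full (len(q) x len(k)) score matrix
    scaled = [[x / math.sqrt(d_k) for x in row] for row in scores]
    weights = softmax_rows(scaled)     # each row sums to 1
    return matmul(weights, v)


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.0, 1.0]]
    k = [[1.0, 0.0], [0.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0]]
    print(naive_sdpa(q, k, v))
```

Each output row is a weighted average of the value rows, with weights determined by how strongly that query matches each key.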
Why SDPA Optimization?
In my FlexAttention [blog](https://yash-sri.xyz/blog/flex_attention), I explained in detail how the straightforward implementation of SDPA has quadratic compute and memory complexity with respect to sequence length: for a sequence of length L, it materializes the full L × L matrix of attention scores. Because of these bottlenecks, optimized versions of SDPA such as Flash Attention or Flex Attention are preferred for deployment.
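The quadratic growth is easy to see with a little arithmetic. The sketch below estimates, for a naive forward pass, the memory needed for the materialized score matrices and the FLOPs of the two matrix multiplications. It is a rough estimate under stated assumptions: softmax and scaling costs are ignored, and the configuration used in the example (fp16 scores at 2 bytes each, 32 heads, head dimension 128) is hypothetical. `naive_sdpa_cost` is a hypothetical helper name.

```python
def naive_sdpa_cost(seq_len, head_dim, batch=1, heads=1, bytes_per_elem=2):
    """Rough cost of a naive SDPA forward pass that materializes the scores.

    Returns (score_bytes, matmul_flops). Softmax and scaling are ignored.
    """
    # One (seq_len x seq_len) score matrix per (batch, head) pair.
    score_bytes = batch * heads * seq_len * seq_len * bytes_per_elem
    # Q @ K^T and weights @ V each take about 2 * seq_len^2 * head_dim FLOPs
    # (one multiply and one add per term).
    matmul_flops = batch * heads * 2 * (2 * seq_len * seq_len * head_dim)
    return score_bytes, matmul_flops


if __name__ == "__main__":
    # Hypothetical config: fp16 scores (2 bytes), 32 heads, head_dim 128.
    for L in (1024, 2048, 4096, 8192):
        mem, flops = naive_sdpa_cost(L, head_dim=128, heads=32)
        print(f"L={L:5d}  scores={mem / 2**20:6.0f} MiB  matmul={flops / 1e9:8.1f} GFLOP")
```

Under these assumptions the score matrices alone take 64 MiB at L = 1024 and 4096 MiB at L = 8192; every doubling of the sequence length quadruples both the memory and the FLOPs.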
While working on the FlexAttention blog and beginning to understand how each approach optimizes standard SDPA (or its variants), especially with respect to memory, I identified three directions worth exploring, and I wanted to find out experimentally which of them is the most promising.
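One well-known way optimized kernels such as Flash Attention avoid storing the full score matrix is to visit the keys in blocks while keeping a running ("online") softmax. The sketch below illustrates that general idea for a single query vector in plain Python. It is a teaching sketch of the technique, not the Flash Attention kernel and not necessarily one of the three directions studied here; `online_attention_row` and its `block_size` parameter are hypothetical names.

```python
import math


def online_attention_row(q, keys, values, block_size=2):
    """Attention output for ONE query vector, visiting keys/values in blocks.

    Only a running max (m), a running denominator (denom) and a running
    weighted sum of values (acc) are kept, so the full row of scores is
    never stored at once.
    """
    scale = 1.0 / math.sqrt(len(q))
    m = float("-inf")             # largest scaled score seen so far
    denom = 0.0                   # sum of exp(score - m) seen so far
    acc = [0.0] * len(values[0])  # sum of exp(score - m) * value seen so far
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        # Rescale earlier statistics to the new max; exp(-inf) == 0.0 on the first block.
        correction = math.exp(m - m_new)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - m_new)
            denom += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / denom for a in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    # The result should not depend on the block size.
    for bs in (1, 2, 4):
        print(bs, online_attention_row(q, keys, values, block_size=bs))
```

Because rescaling by exp(m - m_new) is exact, the result matches an ordinary softmax-weighted average for any block size, up to floating-point rounding, while the state kept per query depends only on the head dimension and the current block, not on the sequence length.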
As explained in detail in the next section, my goal with this case study is to compare these approaches against each other and measure how much memory each can save compared to standard SDPA. The case study involves a lot of experiments and testing, supported by code wherever necessary. I've presented it as a work log, so you can follow how I went from idea X to idea Y along the initial chain of thought. Results and further directions are discussed at the end.
Initial Approach
As mentioned in the previous section, I identified three axes along which I hypothesized we could reduce memory overhead, exploring each direction either individually or in combination and measuring how much overhead it removes. My initial approach is available [here](https://yash-sri.xyz/scratchpad) on my scratchpad.
Integration Notes
- Source section: blog
- Local source: /home/yashs/Desktop/Programming/yash_blog/yash-srivastava19.github.io/blog/sdpa_optim.md
- Raw copy: raw/website/yash-srivastava19-github-io/blog/sdpa_optim.md