Article

Obsidian Source: SDPA Optimization Case Study - A Worklog

Summary

Pending synthesis from local Obsidian source.

Original source title: This is a clever fix:

Extracted Preview

Introduction

Scaled Dot Product Attention (abbreviated as SDPA) is an attention mechanism in which the dot products between query and key vectors are scaled down by sqrt(d_k), where d_k is the key dimension, before a softmax turns them into weights over the value vectors. Attention is calculated as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
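To make the formula concrete, here is a minimal pure-Python sketch of naive SDPA, assuming a single sequence with no batching, no multiple heads and no masking. The helper names (`matmul`, `transpose`, `softmax_rows`, `sdpa`) are hypothetical and not taken from any library. The sketch deliberately builds the full N x N score matrix, which is the memory cost discussed in the next section.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n x k) @ (k x m) -> (n x m) with plain nested loops."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def softmax_rows(x: Matrix) -> Matrix:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    out = []
    for row in x:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def sdpa(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Naive softmax(Q K^T / sqrt(d_k)) V for a single head."""
    d_k = len(k[0])
    # This N x N score matrix is what makes naive SDPA quadratic in memory.
    scores = matmul(q, transpose(k))
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = softmax_rows(scaled)
    return matmul(weights, v)


# Usage: all-zero queries give every key equal weight,
# so each output row is the mean of the value rows.
print(sdpa([[0.0, 0.0], [0.0, 0.0]],
           [[1.0, 0.0], [0.0, 1.0]],
           [[1.0, 2.0], [3.0, 4.0]]))
```

Real implementations operate on batched tensors, but the data flow (scores, scale, softmax, weighted sum of values) is the same.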

SDPA (the operation at the heart of self-attention) was a revolutionary idea, first introduced in the "[Attention Is All You Need](https://arxiv.org/pdf/1706.03762)" paper, whose Transformer architecture became the backbone of modern NLP applications. Unlike recurrent models, SDPA processes all positions of an input sequence in parallel, which speeds up computation while still capturing meaningful relationships between tokens.

Why SDPA Optimization?

In my FlexAttention [blog](https://yash-sri.xyz/blog/flex_attention), I explained in detail how the straightforward implementation of SDPA has quadratic compute and memory complexity with respect to sequence length, since it materializes an N x N matrix of attention scores. Because of these bottlenecks, optimized versions of SDPA such as FlashAttention or FlexAttention are preferred for deployment.
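A quick back-of-the-envelope calculation shows how fast the score matrix alone grows. This is a sketch under stated assumptions: fp16 storage (2 bytes per element), batch size 1 and a hypothetical 32 attention heads; `score_matrix_bytes` is a hypothetical helper, and it ignores everything else (activations, the softmax output kept for the backward pass, weights).

```python
def score_matrix_bytes(seq_len: int, batch: int = 1, heads: int = 1,
                       bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize a (batch, heads, N, N) attention-score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


# Assumed setup: fp16 (2 bytes), batch 1, 32 heads.
for n in (1024, 2048, 4096, 8192):
    mib = score_matrix_bytes(n, heads=32) / 2**20
    print(f"N={n:5d}: {mib:8.0f} MiB")
```

Under these assumptions the score tensor takes 64, 256, 1024 and 4096 MiB respectively: every doubling of N quadruples it. FlashAttention-style kernels avoid this term by processing the scores in tiles instead of materializing the full N x N tensor.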

While working on the FlexAttention blog, as I began to understand how each approach optimized standard SDPA (or its variants), especially with respect to memory, I identified three directions worth exploring and set out to determine experimentally which of them is the most promising.

As explained in detail in the next section, my goal with this case study is to see how these approaches stack up against each other, and how much memory we can save compared to standard SDPA. The case study involves many experiments and tests, supported by code wherever necessary. I've presented it as a worklog, so you can see how I went from idea X to idea Y, following the initial chain of thought. Results and further directions are discussed at the end.

Initial Approach

As mentioned in the previous section, I identified three axes along which I hypothesized we could explore (individually or in combination) how much memory overhead can be reduced. My initial approach is given [here](https://yash-sri.xyz/scratchpad) on my scratchpad.

Approach 1: Thinking from first principles, the way to reduce the memory footprint is either to reduce the size of the model or to optimize the computation-heavy step. Since SDPA scales quadratically with sequence length, one trivial axis to explore was reducing the sequence length. The initial three directions I explored can be visualized as:

[Figure: the three initial directions explored (Screenshot 2025-01-17 122513.png)]
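The sequence-length axis can be made concrete with a short worked equation. Assuming a naive implementation that materializes the full score matrix, with batch size B, H heads and s bytes per element (symbols introduced here for illustration), the score-matrix memory and the effect of halving the sequence length are:

$$
M(N) = B \cdot H \cdot N^2 \cdot s
\quad\Rightarrow\quad
\frac{M(N/2)}{M(N)} = \frac{(N/2)^2}{N^2} = \frac{1}{4}
$$

So, under this assumption, halving the sequence length cuts this term of the memory footprint by a factor of four, which is why it is the most obvious axis to try first.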

Integration Notes

  • Source folder: /home/yashs/Documents/Docs/Obsidian/Research-Notes
  • Local source: /home/yashs/Documents/Docs/Obsidian/Research-Notes/SDPA Optimization Case Study - A Worklog.md
  • Raw copy: raw/obsidian/research-notes/SDPA Optimization Case Study - A Worklog.md

Links Created Or Updated

Open Questions