Add Rotary Positional Embeddings (RoPE) - part 2 of parallel attention blocks#450
Add Rotary Positional Embeddings (RoPE) - part 2 of parallel attention blocks#450lessw2020 wants to merge 9 commits intofacebookresearch:mainfrom
Conversation
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #450 +/- ##
==========================================
+ Coverage 69.11% 69.24% +0.13%
==========================================
Files 170 170
Lines 11524 11580 +56
==========================================
+ Hits 7965 8019 +54
- Misses 3559 3561 +2
☔ View full report in Codecov by Sentry. |
ebsmothers
left a comment
There was a problem hiding this comment.
Looks good! Just a few minor things, mainly around testing and comments
| assert_expected(actual.shape, expected) | ||
|
|
||
|
|
||
| def test_rotary_embeddings_math(): |
There was a problem hiding this comment.
Can we put these unit tests into a class? (Similar to the other tests in this file)
| return cur_freqs.view(*shape, 2) | ||
|
|
||
| def forward( | ||
| self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, float] |
There was a problem hiding this comment.
Do you think it makes sense to have start_pos default to 0? (My assumption is that this would at least be the starting point for most users)
| Maximum expected sequence length for the model, if exceeded the cached freqs will be recomputed | ||
| ratio: int | ||
| The ratio for the geometric progression to compute the rotation angles | ||
| """ |
There was a problem hiding this comment.
It'd be nice to add more in the docstring on the exact details of these embeddings, e.g. at least the [[cos, -sin], [sin, cos]] matrix and maybe even a small example (like the simple 2D one you wrote for the unit test)
| assert_expected(qr[0, :, 1], qr2[1, :, 0]) | ||
|
|
||
| assert_expected(kr[0], kr2[0]) | ||
| assert_expected(kr[0, :, 1], kr2[1, :, 0]) |
There was a problem hiding this comment.
Can we also add a test for updating the cached frequencies? (As far as I can tell this second test is not hitting that block in L262-268, lmk if I'm misunderstanding)
There was a problem hiding this comment.
yes, that's a good idea.
| k_ = k.float().reshape(*k.shape[:-1], -1, 2) # B H L D/2 2 | ||
|
|
||
| if isinstance(start_pos, int): | ||
| if start_pos + seq_len > self.max_seq_len_cached: |
There was a problem hiding this comment.
Some comments here about when the frequencies need to be recomputed might be helpful
There was a problem hiding this comment.
sounds good - offhand should be changing dtype, changing device, and resetting seq len > max_seq_len.
| ) | ||
| self.compute_freqs_cis(max_position_embeddings) | ||
|
|
||
| def compute_freqs_cis( |
There was a problem hiding this comment.
Random q: what does cis mean here?
There was a problem hiding this comment.
it's short form for rotation transform technically doing e^(alpha*i) = cos(alpha) + i * sin(alpha), or shortened, cos + i * sin = cis.
There was a problem hiding this comment.
should probably add that in the docstring actually, otherwise it's too cryptic.
|
@rohan-varma has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
1 similar comment
|
@rohan-varma has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
rohan-varma
left a comment
There was a problem hiding this comment.
high level comment, but let's maybe create a modules/layers/embeddings folder in the future as we might have multiple embedding layers.
Summary:
Adds Rotary Positional Embeddings (RoPE)
Test plan:
two unit tests - one for math, one for padding