For the DeepSeek FP8 blockwise recipe, the scaling block size for weights is 128x128. With 2D block scaling, sharding the weight by blocks makes the memory layout non-contiguous. Does that mean we have to run the original [N, K] weight tensor through the following function to "permute" it into a blocked layout? And then, after the parameter all-gather (AG), do we also need to "unpermute" it back so that we can launch a regular FP8 GEMM?
import torch

def to_blocked_128x128(a: torch.Tensor) -> torch.Tensor:
    # View as (N//128, 128, K//128, 128), then swap the two inner axes so that
    # each 128x128 scaling block becomes contiguous in memory.
    N, K = a.shape
    return a.view(N // 128, 128, K // 128, 128).permute(0, 2, 1, 3).contiguous()
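
For reference, the "unpermute" after the all-gather would be the inverse rearrangement. A minimal sketch (the helper name `from_blocked_128x128` is my own, not from the recipe):

```python
import torch

def to_blocked_128x128(a: torch.Tensor) -> torch.Tensor:
    # View as (N//128, 128, K//128, 128), then swap the two inner axes so that
    # each 128x128 scaling block becomes contiguous in memory.
    N, K = a.shape
    return a.view(N // 128, 128, K // 128, 128).permute(0, 2, 1, 3).contiguous()

def from_blocked_128x128(a_blocked: torch.Tensor, N: int, K: int) -> torch.Tensor:
    # Inverse of to_blocked_128x128: undo the block permutation and restore
    # the original row-major [N, K] layout expected by a regular GEMM.
    return (
        a_blocked.view(N // 128, K // 128, 128, 128)
        .permute(0, 2, 1, 3)
        .contiguous()
        .view(N, K)
    )
```

Round-tripping a [256, 256] tensor through `to_blocked_128x128` and then `from_blocked_128x128` reproduces the original tensor bit-for-bit, which is the property the unpermute step would need after the AG.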
