|
| 1 | +// Copyright 2025-present the zvec project |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#pragma once |
| 16 | + |
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>
#include "product_quantizer.h"
| 23 | + |
| 24 | +namespace zvec { |
| 25 | +namespace ailego { |
| 26 | + |
| 27 | +/*! Optimized Product Quantization (OPQ) |
| 28 | + * |
| 29 | + * Learns an orthogonal rotation matrix R that minimizes quantization error |
| 30 | + * when applied before PQ encoding. Uses the Orthogonal Procrustes method |
| 31 | + * (SVD-based) to iteratively refine the rotation. |
| 32 | + * |
| 33 | + * Reference: |
| 34 | + * T. Ge, K. He, Q. Ke, J. Sun. "Optimized Product Quantization." |
| 35 | + * IEEE TPAMI, 2014. |
| 36 | + */ |
| 37 | +class OptimizedProductQuantizer { |
| 38 | + public: |
| 39 | + //! Constructor |
| 40 | + //! @param m Number of sub-quantizers. |
| 41 | + //! @param k Number of centroids per sub-quantizer. |
| 42 | + //! @param n_iter Number of outer OPQ iterations. |
| 43 | + //! @param pq_iter Number of inner k-means iterations per PQ training. |
| 44 | + OptimizedProductQuantizer(size_t m, size_t k, size_t n_iter = 20, |
| 45 | + size_t pq_iter = 10) |
| 46 | + : m_(m), k_(k), n_iter_(n_iter), pq_iter_(pq_iter) {} |
| 47 | + |
| 48 | + //! Retrieve the number of sub-quantizers |
| 49 | + size_t m() const { return m_; } |
| 50 | + |
| 51 | + //! Retrieve the number of centroids per sub-quantizer |
| 52 | + size_t k() const { return k_; } |
| 53 | + |
| 54 | + //! Retrieve the vector dimension |
| 55 | + size_t dim() const { return dim_; } |
| 56 | + |
| 57 | + //! Check if the quantizer is trained |
| 58 | + bool is_trained() const { return is_trained_; } |
| 59 | + |
| 60 | + //! Retrieve the learned rotation matrix (dim x dim, row-major) |
| 61 | + const std::vector<float> &rotation_matrix() const { |
| 62 | + return rotation_; |
| 63 | + } |
| 64 | + |
| 65 | + //! Retrieve the underlying PQ quantizer |
| 66 | + const ProductQuantizer &pq() const { return *pq_; } |
| 67 | + |
| 68 | + //! Train the OPQ on a set of vectors. |
| 69 | + //! Alternates between learning the rotation R and retraining PQ codebooks. |
| 70 | + //! @param data Training vectors (n x dim), row-major. |
| 71 | + //! @param n Number of training vectors. |
| 72 | + //! @param dim Vector dimension (must be divisible by m). |
| 73 | + void train(const float *data, size_t n, size_t dim) { |
| 74 | + ailego_check_with(dim % m_ == 0, "Dimension must be divisible by m"); |
| 75 | + dim_ = dim; |
| 76 | + |
| 77 | + // Initialize rotation as identity |
| 78 | + rotation_.assign(dim * dim, 0.0f); |
| 79 | + for (size_t i = 0; i < dim; ++i) { |
| 80 | + rotation_[i * dim + i] = 1.0f; |
| 81 | + } |
| 82 | + |
| 83 | + std::vector<float> rotated(n * dim); |
| 84 | + |
| 85 | + for (size_t iter = 0; iter < n_iter_; ++iter) { |
| 86 | + // Step 1: Rotate vectors: rotated = data * R^T |
| 87 | + MatMul(data, rotation_.data(), rotated.data(), n, dim, dim, false, true); |
| 88 | + |
| 89 | + // Step 2: Train PQ on rotated vectors |
| 90 | + pq_.reset(new ProductQuantizer(m_, k_, pq_iter_)); |
| 91 | + pq_->train(rotated.data(), n, dim); |
| 92 | + |
| 93 | + // Step 3: Decode to get reconstruction |
| 94 | + std::vector<uint8_t> codes(n * m_); |
| 95 | + pq_->encode(rotated.data(), n, codes.data()); |
| 96 | + std::vector<float> decoded(n * dim); |
| 97 | + pq_->decode(codes.data(), n, decoded.data()); |
| 98 | + |
| 99 | + // Step 4: Solve Orthogonal Procrustes for optimal rotation |
| 100 | + // minimize ||X * R^T - Y_hat|| where X = data, Y_hat = decoded |
| 101 | + // M = X^T * Y_hat, SVD(M) = U * S * V^T, then R = V * U^T |
| 102 | + LearnRotation(data, decoded.data(), n, dim); |
| 103 | + } |
| 104 | + |
| 105 | + is_trained_ = true; |
| 106 | + } |
| 107 | + |
| 108 | + //! Rotate vectors using the learned rotation. |
| 109 | + //! @param data Input vectors (n x dim). |
| 110 | + //! @param n Number of vectors. |
| 111 | + //! @param out Output rotated vectors (n x dim). |
| 112 | + void rotate(const float *data, size_t n, float *out) const { |
| 113 | + ailego_check_with(is_trained_, "OPQ not trained"); |
| 114 | + MatMul(data, rotation_.data(), out, n, dim_, dim_, false, true); |
| 115 | + } |
| 116 | + |
| 117 | + //! Inverse-rotate vectors (multiply by R, since R is orthogonal). |
| 118 | + //! @param data Input rotated vectors (n x dim). |
| 119 | + //! @param n Number of vectors. |
| 120 | + //! @param out Output original-space vectors (n x dim). |
| 121 | + void inverse_rotate(const float *data, size_t n, float *out) const { |
| 122 | + ailego_check_with(is_trained_, "OPQ not trained"); |
| 123 | + MatMul(data, rotation_.data(), out, n, dim_, dim_, false, false); |
| 124 | + } |
| 125 | + |
| 126 | + //! Encode vectors using OPQ (rotate then PQ encode). |
| 127 | + //! @param data Input vectors (n x dim). |
| 128 | + //! @param n Number of vectors. |
| 129 | + //! @param codes Output PQ codes (n x m). |
| 130 | + void encode(const float *data, size_t n, uint8_t *codes) const { |
| 131 | + ailego_check_with(is_trained_, "OPQ not trained"); |
| 132 | + std::vector<float> rotated(n * dim_); |
| 133 | + rotate(data, n, rotated.data()); |
| 134 | + pq_->encode(rotated.data(), n, codes); |
| 135 | + } |
| 136 | + |
| 137 | + //! Decode PQ codes back to vectors (PQ decode then inverse rotate). |
| 138 | + //! @param codes Input PQ codes (n x m). |
| 139 | + //! @param n Number of vectors. |
| 140 | + //! @param out Output reconstructed vectors (n x dim). |
| 141 | + void decode(const uint8_t *codes, size_t n, float *out) const { |
| 142 | + ailego_check_with(is_trained_, "OPQ not trained"); |
| 143 | + std::vector<float> decoded(n * dim_); |
| 144 | + pq_->decode(codes, n, decoded.data()); |
| 145 | + inverse_rotate(decoded.data(), n, out); |
| 146 | + } |
| 147 | + |
| 148 | + //! Compute quantization distortion (mean squared error). |
| 149 | + float distortion(const float *data, size_t n) const { |
| 150 | + std::vector<uint8_t> codes(n * m_); |
| 151 | + encode(data, n, codes.data()); |
| 152 | + std::vector<float> decoded(n * dim_); |
| 153 | + decode(codes.data(), n, decoded.data()); |
| 154 | + |
| 155 | + double total = 0.0; |
| 156 | + for (size_t i = 0; i < n * dim_; ++i) { |
| 157 | + double diff = data[i] - decoded[i]; |
| 158 | + total += diff * diff; |
| 159 | + } |
| 160 | + return static_cast<float>(total / n); |
| 161 | + } |
| 162 | + |
| 163 | + private: |
| 164 | + OptimizedProductQuantizer(const OptimizedProductQuantizer &) = delete; |
| 165 | + OptimizedProductQuantizer &operator=(const OptimizedProductQuantizer &) = |
| 166 | + delete; |
| 167 | + |
| 168 | + //! Matrix multiplication: C = A * B (or A * B^T) |
| 169 | + //! A is (M x K), B is (K x N) or (N x K) if transpose_b. |
| 170 | + static void MatMul(const float *A, const float *B, float *C, |
| 171 | + size_t M, size_t K, size_t N, |
| 172 | + bool transpose_a, bool transpose_b) { |
| 173 | + for (size_t i = 0; i < M; ++i) { |
| 174 | + for (size_t j = 0; j < N; ++j) { |
| 175 | + float sum = 0.0f; |
| 176 | + for (size_t p = 0; p < K; ++p) { |
| 177 | + float a = transpose_a ? A[p * M + i] : A[i * K + p]; |
| 178 | + float b = transpose_b ? B[j * K + p] : B[p * N + j]; |
| 179 | + sum += a * b; |
| 180 | + } |
| 181 | + C[i * N + j] = sum; |
| 182 | + } |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + //! Solve Orthogonal Procrustes problem via SVD. |
| 187 | + //! Given original data X and decoded Y_hat, find R such that |
| 188 | + //! ||X * R^T - Y_hat||_F is minimized. |
| 189 | + //! Solution: M = X^T * Y_hat, SVD(M) = U * S * V^T, R = V * U^T. |
| 190 | + void LearnRotation(const float *X, const float *Y, |
| 191 | + size_t n, size_t dim) { |
| 192 | + // Compute M = X^T * Y (dim x dim) |
| 193 | + std::vector<float> M(dim * dim, 0.0f); |
| 194 | + MatMul(X, Y, M.data(), dim, n, dim, true, false); |
| 195 | + |
| 196 | + // SVD via Jacobi one-sided method |
| 197 | + std::vector<float> U(dim * dim); |
| 198 | + std::vector<float> S(dim); |
| 199 | + std::vector<float> Vt(dim * dim); |
| 200 | + JacobiSVD(M.data(), U.data(), S.data(), Vt.data(), dim); |
| 201 | + |
| 202 | + // R = V * U^T (V = Vt^T, so R = Vt^T * U^T = (U * Vt)^T) |
| 203 | + // Compute U * Vt first, then transpose |
| 204 | + std::vector<float> UVt(dim * dim); |
| 205 | + MatMul(U.data(), Vt.data(), UVt.data(), dim, dim, dim, false, false); |
| 206 | + |
| 207 | + // R = (U * Vt)^T |
| 208 | + for (size_t i = 0; i < dim; ++i) { |
| 209 | + for (size_t j = 0; j < dim; ++j) { |
| 210 | + rotation_[i * dim + j] = UVt[j * dim + i]; |
| 211 | + } |
| 212 | + } |
| 213 | + } |
| 214 | + |
| 215 | + //! Jacobi one-sided SVD for a square matrix. |
| 216 | + //! Computes A = U * diag(S) * V^T where A, U, V^T are dim x dim. |
| 217 | + //! Uses cyclic Jacobi rotations for convergence. |
| 218 | + static void JacobiSVD(const float *A, float *U, float *S, float *Vt, |
| 219 | + size_t dim) { |
| 220 | + size_t n = dim; |
| 221 | + |
| 222 | + // Copy A into working matrix (will become U * S) |
| 223 | + std::vector<float> W(n * n); |
| 224 | + std::memcpy(W.data(), A, n * n * sizeof(float)); |
| 225 | + |
| 226 | + // V starts as identity (we build V, then Vt = V^T) |
| 227 | + std::vector<float> V(n * n, 0.0f); |
| 228 | + for (size_t i = 0; i < n; ++i) V[i * n + i] = 1.0f; |
| 229 | + |
| 230 | + // Jacobi iterations |
| 231 | + const size_t max_sweeps = 100; |
| 232 | + const float eps = 1e-10f; |
| 233 | + |
| 234 | + for (size_t sweep = 0; sweep < max_sweeps; ++sweep) { |
| 235 | + float off_norm = 0.0f; |
| 236 | + |
| 237 | + for (size_t p = 0; p < n; ++p) { |
| 238 | + for (size_t q = p + 1; q < n; ++q) { |
| 239 | + // Compute 2x2 Gram matrix entries for columns p and q |
| 240 | + float app = 0.0f, aqq = 0.0f, apq = 0.0f; |
| 241 | + for (size_t i = 0; i < n; ++i) { |
| 242 | + app += W[i * n + p] * W[i * n + p]; |
| 243 | + aqq += W[i * n + q] * W[i * n + q]; |
| 244 | + apq += W[i * n + p] * W[i * n + q]; |
| 245 | + } |
| 246 | + |
| 247 | + off_norm += apq * apq; |
| 248 | + |
| 249 | + if (std::abs(apq) < eps * std::sqrt(app * aqq)) continue; |
| 250 | + |
| 251 | + // Compute Jacobi rotation angle |
| 252 | + float tau = (aqq - app) / (2.0f * apq); |
| 253 | + float t; |
| 254 | + if (tau >= 0.0f) { |
| 255 | + t = 1.0f / (tau + std::sqrt(1.0f + tau * tau)); |
| 256 | + } else { |
| 257 | + t = -1.0f / (-tau + std::sqrt(1.0f + tau * tau)); |
| 258 | + } |
| 259 | + float c = 1.0f / std::sqrt(1.0f + t * t); |
| 260 | + float s = t * c; |
| 261 | + |
| 262 | + // Apply rotation to W columns p and q |
| 263 | + for (size_t i = 0; i < n; ++i) { |
| 264 | + float wp = W[i * n + p]; |
| 265 | + float wq = W[i * n + q]; |
| 266 | + W[i * n + p] = c * wp - s * wq; |
| 267 | + W[i * n + q] = s * wp + c * wq; |
| 268 | + } |
| 269 | + |
| 270 | + // Apply rotation to V columns p and q |
| 271 | + for (size_t i = 0; i < n; ++i) { |
| 272 | + float vp = V[i * n + p]; |
| 273 | + float vq = V[i * n + q]; |
| 274 | + V[i * n + p] = c * vp - s * vq; |
| 275 | + V[i * n + q] = s * vp + c * vq; |
| 276 | + } |
| 277 | + } |
| 278 | + } |
| 279 | + |
| 280 | + if (off_norm < eps * eps) break; |
| 281 | + } |
| 282 | + |
| 283 | + // Extract singular values and normalize columns of W to get U |
| 284 | + for (size_t j = 0; j < n; ++j) { |
| 285 | + float norm = 0.0f; |
| 286 | + for (size_t i = 0; i < n; ++i) { |
| 287 | + norm += W[i * n + j] * W[i * n + j]; |
| 288 | + } |
| 289 | + S[j] = std::sqrt(norm); |
| 290 | + |
| 291 | + float inv_norm = (S[j] > 0.0f) ? (1.0f / S[j]) : 0.0f; |
| 292 | + for (size_t i = 0; i < n; ++i) { |
| 293 | + U[i * n + j] = W[i * n + j] * inv_norm; |
| 294 | + } |
| 295 | + } |
| 296 | + |
| 297 | + // Vt = V^T |
| 298 | + for (size_t i = 0; i < n; ++i) { |
| 299 | + for (size_t j = 0; j < n; ++j) { |
| 300 | + Vt[i * n + j] = V[j * n + i]; |
| 301 | + } |
| 302 | + } |
| 303 | + } |
| 304 | + |
| 305 | + size_t m_; |
| 306 | + size_t k_; |
| 307 | + size_t n_iter_; |
| 308 | + size_t pq_iter_; |
| 309 | + size_t dim_{0}; |
| 310 | + bool is_trained_{false}; |
| 311 | + std::vector<float> rotation_; |
| 312 | + std::unique_ptr<ProductQuantizer> pq_; |
| 313 | +}; |
| 314 | + |
| 315 | +} // namespace ailego |
| 316 | +} // namespace zvec |
0 commit comments