/* VocabTree.cpp */
/* Build a vocabulary tree from a set of vectors */

// Enabling SSE gives around a 2x speedup in the vec_diff_normsq function.
// USE_SSE is switched on automatically when the compiler targets SSE2
// (GCC/Clang define __SSE2__; MSVC defines _M_X64 or _M_IX86_FP >= 2).
// On every other target the portable scalar loop is used on its own.
#if !defined(USE_SSE) && \
    (defined(__SSE2__) || defined(_M_X64) || \
     (defined(_M_IX86_FP) && _M_IX86_FP >= 2))
#define USE_SSE
#endif

#ifdef USE_SSE
#include <emmintrin.h> /* SSE2 integer intrinsics */
#endif

/* Useful utility function for computing the squared distance between
 * two vectors a and b of length dim
 *
 *   dim  number of elements to compare; values <= 0 yield 0
 *   a,b  the two vectors, each with at least dim readable bytes; no
 *        alignment is required
 *
 * Returns the sum over i in [0, dim) of (a[i] - b[i])^2.
 *
 * When USE_SSE is defined, full groups of 16 elements are handled with
 * SSE2 and the remaining (dim % 16) elements with the scalar loop, so dim
 * does not have to be a multiple of 16. */
static unsigned long vec_diff_normsq(int dim,
                                     unsigned char *a,
                                     unsigned char *b)
{
    unsigned long normsq = 0;

    /* Index of the first element the scalar loop still has to process. */
    int i = 0;

#ifdef USE_SSE
    /* One 16-element block adds at most 4 * 255^2 = 260100 to each 32-bit
     * lane, so 8192 blocks stay below 2^31.  The lanes are drained into the
     * wide total at least that often so they can never wrap. */
    const int kBlocksPerFlush = 8192;

    const __m128i zero = _mm_setzero_si128();
    const int num_blocks = dim / 16;

    int block = 0;
    while (block < num_blocks) {
        const int remaining = num_blocks - block;
        const int flush_end =
            block + (remaining < kBlocksPerFlush ? remaining : kBlocksPerFlush);

        /* Four 32-bit partial sums for the current run of blocks. */
        __m128i sum32 = zero;

        for (; block < flush_end; ++block) {
            /* Load 16, 8-bit values from each vector without assuming
             * anything about the alignment of the underlying pointers. */
            const __m128i a8 = _mm_loadu_si128(
                reinterpret_cast<const __m128i *>(a + 16 * block));
            const __m128i b8 = _mm_loadu_si128(
                reinterpret_cast<const __m128i *>(b + 16 * block));

            /* Interleave with zeros to widen the bytes to 16 bits, then
             * subtract: every lane holds a difference in [-255, 255]. */
            const __m128i diff_lo = _mm_sub_epi16(_mm_unpacklo_epi8(a8, zero),
                                                  _mm_unpacklo_epi8(b8, zero));
            const __m128i diff_hi = _mm_sub_epi16(_mm_unpackhi_epi8(a8, zero),
                                                  _mm_unpackhi_epi8(b8, zero));

            /* madd squares each signed 16-bit lane and adds neighbouring
             * products, giving four 32-bit sums of two squares each. */
            sum32 = _mm_add_epi32(sum32, _mm_madd_epi16(diff_lo, diff_lo));
            sum32 = _mm_add_epi32(sum32, _mm_madd_epi16(diff_hi, diff_hi));
        }

        /* Copy the four partial sums out (unaligned store) and add them in
         * the wide type so the horizontal sum itself cannot wrap at 32 bits. */
        unsigned int lanes[4];
        _mm_storeu_si128(reinterpret_cast<__m128i *>(lanes), sum32);
        normsq += static_cast<unsigned long>(lanes[0]) + lanes[1] +
                  lanes[2] + lanes[3];
    }

    if (num_blocks > 0) {
        i = num_blocks * 16;
    }
#endif // USE_SSE

    /* Scalar path: the whole vector without SSE, otherwise only the tail. */
    for (; i < dim; ++i) {
        const int d = static_cast<int>(a[i]) - static_cast<int>(b[i]);
        normsq += static_cast<unsigned long>(d * d);
    }

    return normsq;
}