Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions VocabLib/VocabTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
/* VocabTree.cpp */
/* Build a vocabulary tree from a set of vectors */

// Enabling SSE gives around a 2x speedup in the vec_diff_normsq function.
#define USE_SSE

#include <assert.h>
#include <float.h>
#include <limits.h>
Expand All @@ -50,12 +53,115 @@
#include "qsort.h"
#include "util.h"

#ifdef USE_SSE
#include <emmintrin.h>
#endif

/* Useful utility function for computing the squared distance between
* two vectors a and b of length dim */
static unsigned long vec_diff_normsq(int dim,
unsigned char *a,
unsigned char *b)
{
#ifdef USE_SSE

// When using SSE, values are processed 16 at a time, so we must enforce
// that the dimension is a multiple of 16.
assert(dim % 16 == 0);

// Set 16, 8-bit values to zero.
__m128i zero8 = _mm_set1_epi8(0);

// Set 8, 16-bit values to zero.
__m128i zero16 = _mm_set1_epi16(0);

// Set 4, 32-bit values to zero.
__m128i sum32 = _mm_set1_epi32(0);

__m128i a8;
__m128i b8;
__m128i a16;
__m128i b16;
__m128i sub16;
__m128i mul16;
__m128i mul32;

// We will process 16 values at a time.
const int dim_16 = dim / 16;
for (int i = 0; i < dim_16; ++i)
{
// Load 16, 8-bit values from each descriptor without assuming anything
// about the alignment of the underlying pointers.
a8 = _mm_loadu_si128((__m128i *)(a + 16 * i));
b8 = _mm_loadu_si128((__m128i *)(b + 16 * i));

/////////////////////////////////////////////////
// Process the lower 8 values of the descriptors.

// Interleave the lower 8, 8-bit values of the descriptors with zeros,
// effectively converting the 8 values to 16-bit values.
a16 = _mm_unpacklo_epi8(a8, zero8);
b16 = _mm_unpacklo_epi8(b8, zero8);

// Compute the difference between the 8, 16-bit values.
sub16 = _mm_sub_epi16(a16, b16);

// Compute the squared difference between the 8, 16-bit values.
mul16 = _mm_mullo_epi16(sub16, sub16);

// Interleave the lower 4, 16-bit values with zeros to effectively
// convert to 4, 32-bit values.
mul32 = _mm_unpacklo_epi16(mul16, zero16);

// Add the current 4 values to the running sum that stores 4, 32-bit values.
sum32 = _mm_add_epi32(sum32, mul32);

// Interleave the upper 4, 16-bit values with zeros to effectively
// convert to 4, 32-bit values.
mul32 = _mm_unpackhi_epi16(mul16, zero16);

// Add the current 4 values to the running sum that stores 4, 32-bit values.
sum32 = _mm_add_epi32(sum32, mul32);

/////////////////////////////////////////////////
// Process the upper 8 values of the descriptors.

// Interleave the upper 8, 8-bit values of the descriptors with zeros,
// effectively converting the 8 values to 16-bit values.
a16 = _mm_unpackhi_epi8(a8, zero8);
b16 = _mm_unpackhi_epi8(b8, zero8);

// Compute the difference between the 8, 16-bit values.
sub16 = _mm_sub_epi16(a16, b16);

// Compute the squared difference between the 8, 16-bit values.
mul16 = _mm_mullo_epi16(sub16, sub16);

// Interleave the lower 4, 16-bit values with zeros to effectively
// convert to 4, 32-bit values.
mul32 = _mm_unpacklo_epi16(mul16, zero16);

// Add the current 4 values to the running sum that stores 4, 32-bit values.
sum32 = _mm_add_epi32(sum32, mul32);

// Interleave the upper 4, 16-bit values with zeros to effectively
// convert to 4, 32-bit values.
mul32 = _mm_unpackhi_epi16(mul16, zero16);

// Add the current 4 values to the running sum that stores 4, 32-bit values.
sum32 = _mm_add_epi32(sum32, mul32);
}

// Copy the running sum of the squared differences to 4, 32-bit values without
// assuming anything about the alignment of the pointer.
unsigned int sum[4];
_mm_storeu_si128((__m128i *)sum, sum32);

// Manually sum the 4, 32-bit values, and return the result.
return sum[0] + sum[1] + sum[2] + sum[3];

#else // USE_SSE

int i;
unsigned long normsq = 0;

Expand All @@ -65,6 +171,8 @@ static unsigned long vec_diff_normsq(int dim,
}

return normsq;

#endif // USE_SSE
}

void VocabTreeInteriorNode::Clear(int bf)
Expand Down