87flowers ~ $

faster pshufb inverse

simd

  1. what is a permutation
  2. permutation bitmatrices
  3. transposing a 16×16 bitmatrix on the cpu
  4. the solution
  5. further work

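CPUs have in-vector gather instructions (pshufb, vpermb and family). Unfortunately, unlike GPUs, they do not have in-vector scatter instructions, which is a very annoying gap. Scattering through a permutation is the same as gathering through its inverse, so a fast way to invert a permutation vector lets us emulate the missing scatter with a gather. Most previous solutions use key-value sorting, but we can do better.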

I have wanted to do this performantly for a while. Inspiration suddenly struck me today, and I had to document it.

what is a permutation

Let us start with a four-element example. Say we have the permutation vector:

   idx = [2 1 3 0]

Permutation is the operation:

foreach i:
    result[i] = source[idx[i]]

In other words, this vector represents the mapping from each result index to the source index it reads from:

   0 -> 2
   1 -> 1
   2 -> 3
   3 -> 0

We want to find the inverse permutation [3 1 0 2].

This is the permutation that reverses every arrow of the mapping:

   0 <- 2
   1 <- 1
   2 <- 3
   3 <- 0

Sorting the pairs by the key (the right-hand side), the left-hand column reads off the inverse:

   3 <- 0
   1 <- 1
   0 <- 2
   2 <- 3
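A scalar reference helps pin down both operations. The sketch below is illustrative only: the function names (gather, scatter, invert_permutation) are hypothetical and not from any library. It builds the inverse with inv[idx[i]] = i, and it shows that scattering through idx gives the same result as gathering through its inverse, which is why a fast inverse lets pshufb stand in for the missing scatter:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Gather, what pshufb/vpermb do in-vector: result[i] = source[idx[i]].
static void gather(const uint8_t *source, const uint8_t *idx,
                   uint8_t *result, size_t n) {
    for (size_t i = 0; i < n; i++)
        result[i] = source[idx[i]];
}

// Scatter, which has no in-vector instruction: result[idx[i]] = source[i].
static void scatter(const uint8_t *source, const uint8_t *idx,
                    uint8_t *result, size_t n) {
    for (size_t i = 0; i < n; i++)
        result[idx[i]] = source[i];
}

// Inverse permutation: inv[idx[i]] = i, so gathering through inv undoes idx.
static void invert_permutation(const uint8_t *idx, uint8_t *inv, size_t n) {
    for (size_t i = 0; i < n; i++) {
        assert(idx[i] < n);  // idx must be a permutation of 0..n-1
        inv[idx[i]] = (uint8_t)i;
    }
}
```

For idx = [2 1 3 0] this produces inv = [3 1 0 2], and scatter(source, idx, ...) matches gather(source, inv, ...) element for element.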

permutation bitmatrices

One way we can do this efficiently is by switching representations and considering the permutation matrix.

The permutation [2 1 3 0] can be thought of as the permutation matrix below, in which column i has its single 1 in row idx[i]:

   [ 0 0 0 1 ]
   [ 0 1 0 0 ]
   [ 1 0 0 0 ]
   [ 0 0 1 0 ]

To invert the permutation, we invert its matrix.

Since permutation matrices are orthogonal, the inverse is simply the transpose. And transposing a bitmatrix is something we can do!
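To make the bitmatrix view concrete, here is a scalar sketch for up to 16 elements (the function names are hypothetical). It stores one 16-bit lane per element with a single bit set at position idx[i], the same layout the vector code uses later, transposes with a naive double loop, and reads the position of each set bit back out:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Lane i gets a single set bit at position idx[i] (one-hot encoding).
static void perm_to_bitmatrix(const uint8_t *idx, uint16_t *lanes, size_t n) {
    assert(n <= 16);
    for (size_t i = 0; i < n; i++) {
        assert(idx[i] < n);
        lanes[i] = (uint16_t)(1u << idx[i]);
    }
}

// Naive transpose: bit j of out[i] is bit i of in[j].
static void transpose_bitmatrix(const uint16_t *in, uint16_t *out, size_t n) {
    for (size_t i = 0; i < n; i++) {
        uint16_t lane = 0;
        for (size_t j = 0; j < n; j++)
            lane |= (uint16_t)(((in[j] >> i) & 1u) << j);
        out[i] = lane;
    }
}

// Read back the position of the single set bit in each lane.
static void bitmatrix_to_perm(const uint16_t *lanes, uint8_t *idx, size_t n) {
    for (size_t i = 0; i < n; i++) {
        assert(lanes[i] != 0);
        uint8_t pos = 0;
        while (!((lanes[i] >> pos) & 1u))
            pos++;
        idx[i] = pos;
    }
}
```

Running the three steps on idx = [2 1 3 0] yields [3 1 0 2]. The naive transpose costs O(n²) bit operations; the rest of the post replaces it with a few vector instructions.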

transposing a 16×16 bitmatrix on the cpu

Fortunately, we have a way of computing the transpose of an 8×8 bitmatrix efficiently on the CPU, and we can build a 16×16 transpose out of it.

Introducing, yet again, our venerable gf2p8affineqb instruction! Here is pseudocode showing what the instruction does, where x holds the data bytes and A holds one 8×8 bitmatrix per 64-bit lane:

// y = A · x + b (under GF(2))
v128 gf2p8affineqb(v128 x, v128 A, u8 b) {
  v128 result;
  for i = 0..16 {
    result.byte[i] = matrix_multiply(A.qword[i / 8], x.byte[i]) ^ b;
  }
  return result;
}

u8 matrix_multiply(u64 A, u8 x) {
   u8 result = 0;
   for i = 0..8 {
      result.bit[i] = bit_parity(A.byte[7 - i] & x);
   }
   return result;
}

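Notice how bit i of each result byte depends only on byte 7 - i of the matrix A. If we pass in an x whose bytes each have a single bit set, x.byte[k] = 1 << k (that is the constant 0x8040201008040201), each parity collapses to a single bit lookup: result.byte[k].bit[i] = A.byte[7 - i].bit[k].

Hiding in here is a result.byte[k].bit[i] = A.byte[i].bit[k] operation, which is exactly an 8×8 matrix transposition, but it comes with an unwanted bit reversal (the 7 - i). We can undo the unwanted bit reversal with (you guessed it) a second gf2p8affineqb instruction, this time using 0x8040201008040201 as the matrix, which reverses the bits within every byte.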
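This bit juggling is easy to get wrong, so here is a portable C model of one 64-bit lane of gf2p8affineqb, written straight from the pseudocode, plus a check that the two chained calls really produce a transpose. It is a verification sketch, not a replacement for the instruction; the helper names and the macro name IDENTITY_BYTES are hypothetical, though its value 0x8040201008040201 is the one used in the code:

```c
#include <assert.h>
#include <stdint.h>

// Parity (XOR of all bits) of a byte: 0 or 1.
static unsigned parity8(uint8_t v) {
    v ^= (uint8_t)(v >> 4);
    v ^= (uint8_t)(v >> 2);
    v ^= (uint8_t)(v >> 1);
    return v & 1u;
}

// One byte of gf2p8affineqb: result.bit[i] = parity(A.byte[7 - i] & x) ^ b.bit[i].
static uint8_t affine_byte(uint64_t A, uint8_t x, uint8_t b) {
    uint8_t result = 0;
    for (int i = 0; i < 8; i++) {
        uint8_t row = (uint8_t)(A >> (8 * (7 - i)));
        result |= (uint8_t)(parity8(row & x) << i);
    }
    return result ^ b;
}

// One 64-bit lane: all eight bytes of x are multiplied by the same matrix A.
static uint64_t affine_qword(uint64_t x, uint64_t A, uint8_t b) {
    uint64_t result = 0;
    for (int k = 0; k < 8; k++)
        result |= (uint64_t)affine_byte(A, (uint8_t)(x >> (8 * k)), b) << (8 * k);
    return result;
}

// Byte k has only bit k set (hypothetical name for the constant in the post).
#define IDENTITY_BYTES 0x8040201008040201ull

// Two chained affine steps, mirroring the intrinsic code: the first gives
// out.byte[k].bit[i] = m.byte[7 - i].bit[k], the second bit-reverses every byte.
static uint64_t transpose8x8_affine(uint64_t m) {
    uint64_t t = affine_qword(IDENTITY_BYTES, m, 0);
    return affine_qword(t, IDENTITY_BYTES, 0);
}

// Reference transpose from the definition: out.byte[k].bit[i] = m.byte[i].bit[k].
static uint64_t transpose8x8_reference(uint64_t m) {
    uint64_t out = 0;
    for (int i = 0; i < 8; i++)
        for (int k = 0; k < 8; k++)
            if ((m >> (8 * i + k)) & 1u)
                out |= 1ull << (8 * k + i);
    return out;
}

// Assert that the two-instruction trick agrees with the definition.
static void check_transpose8x8(uint64_t m) {
    assert(transpose8x8_affine(m) == transpose8x8_reference(m));
}
```

For example, the first block of bm2 in the example values, bytes 01 00 00 00 00 00 10 08 read as a little-endian qword, is 0x0810000000000001, and transpose8x8_affine maps it to 0x0000004080000001, which is the first block of trans.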

Now that we have an 8×8 bitmatrix transpose, we can break the 16×16 bitmatrix into four 8×8 blocks, transpose each block, and recompose them into our result. The only subtlety is that the two off-diagonal blocks trade places: [A B; C D]^T = [A^T C^T; B^T D^T].
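The block decomposition can be checked in scalar code before touching intrinsics. This sketch (all names hypothetical) stores the 16×16 matrix as sixteen uint16_t rows, cuts it into four 8×8 blocks packed one row per byte, transposes each block with a naive loop standing in for the two gf2p8affineqb calls, and writes the transposed off-diagonal blocks back into each other's positions:

```c
#include <assert.h>
#include <stdint.h>

// Naive 8×8 transpose of a block packed one row per byte:
// bit k of byte i moves to bit i of byte k.
static uint64_t transpose8x8(uint64_t m) {
    uint64_t out = 0;
    for (int i = 0; i < 8; i++)
        for (int k = 0; k < 8; k++)
            if ((m >> (8 * i + k)) & 1u)
                out |= 1ull << (8 * k + i);
    return out;
}

// Pack rows r0..r0+7 of a 16×16 matrix into one 8×8 block, taking the low
// byte (columns 0-7, half == 0) or the high byte (columns 8-15, half == 1).
static uint64_t get_block(const uint16_t rows[16], int r0, int half) {
    assert((r0 == 0 || r0 == 8) && (half == 0 || half == 1));
    uint64_t block = 0;
    for (int i = 0; i < 8; i++)
        block |= (uint64_t)(uint8_t)(rows[r0 + i] >> (8 * half)) << (8 * i);
    return block;
}

// OR an 8×8 block back into the given quadrant of a 16×16 matrix.
static void put_block(uint16_t rows[16], int r0, int half, uint64_t block) {
    assert((r0 == 0 || r0 == 8) && (half == 0 || half == 1));
    for (int i = 0; i < 8; i++)
        rows[r0 + i] |= (uint16_t)((uint8_t)(block >> (8 * i)) << (8 * half));
}

// [A B; C D]^T = [A^T C^T; B^T D^T]: transpose every block and swap B and C.
static void transpose16x16(const uint16_t in[16], uint16_t out[16]) {
    uint64_t a = get_block(in, 0, 0);  // rows 0-7,  columns 0-7
    uint64_t b = get_block(in, 0, 1);  // rows 0-7,  columns 8-15
    uint64_t c = get_block(in, 8, 0);  // rows 8-15, columns 0-7
    uint64_t d = get_block(in, 8, 1);  // rows 8-15, columns 8-15
    for (int i = 0; i < 16; i++)
        out[i] = 0;
    put_block(out, 0, 0, transpose8x8(a));
    put_block(out, 0, 1, transpose8x8(c));
    put_block(out, 8, 0, transpose8x8(b));
    put_block(out, 8, 1, transpose8x8(d));
}
```

In the AVX code the swap costs nothing extra: the first byte shuffle lays the blocks out as A, C, B, D in qwords 0 to 3, so when the second shuffle interleaves qwords 0 and 1 into rows 0 to 7 (and qwords 2 and 3 into rows 8 to 15), A^T and C^T land in the low and high halves of the top rows.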

the solution

  1. Convert permutation vector to a 16×16 bitmatrix
  2. Shuffle 16×16 bitmatrix into four 8×8 bitmatrices
  3. Transpose 8×8 bitmatrices
  4. Shuffle four 8×8 bitmatrices back into a 16×16 bitmatrix
  5. Convert from 16×16 bitmatrix to a permutation vector
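
#include <immintrin.h>

// Requires AVX-512BW, AVX-512VL, AVX-512VBMI, AVX-512BITALG and GFNI
// (e.g. -mavx512bw -mavx512vl -mavx512vbmi -mavx512bitalg -mgfni).
// The wrapper name is illustrative. idx must be a permutation of 0..15;
// the result inv satisfies inv[idx[i]] == i.
__m128i inverse_permutation_16(__m128i idx) {
  // convert from permutation vector to a 16×16 bitmatrix: lane i = 1 << idx[i]
  __m256i bm = _mm256_sllv_epi16(_mm256_set1_epi16(1), _mm256_cvtepu8_epi16(idx));

  // shuffle to four 8×8 matrices: low bytes of all lanes, then high bytes
  __m256i bm2 = _mm256_permutexvar_epi8(
    _mm256_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
                     1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31),
    bm);

  // transpose the four 8×8 matrices: the inner call transposes with the bits
  // of each byte reversed, the outer call reverses them back
  __m256i trans = _mm256_gf2p8affine_epi64_epi8(
    _mm256_gf2p8affine_epi64_epi8(
      _mm256_set1_epi64x(0x8040201008040201),
      bm2,
      0),
    _mm256_set1_epi64x(0x8040201008040201),
    0);

  // shuffle back to 16×16 bitmatrix
  __m256i trans2 = _mm256_permutexvar_epi8(
    _mm256_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15,
                     16, 24, 17, 25, 18, 26, 19, 27, 20, 28, 21, 29, 22, 30, 23, 31),
    trans);

  // Convert from 16×16 bitmatrix to permutation vector.
  // AVX-512 has no 16-bit ctz instruction, but for a power of two x,
  // popcount(x - 1) equals ctz(x).
  __m128i inv = _mm256_cvtepi16_epi8(
    _mm256_popcnt_epi16(_mm256_sub_epi16(trans2, _mm256_set1_epi16(1))));
  return inv;
}

// The example permutation vector used as input for the values below:
// __m128i idx = _mm_set_epi8(1, 6, 14, 15, 2, 5, 7, 10, 3, 4, 11, 13, 8, 9, 12, 0);
// __m128i inv = inverse_permutation_16(idx);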

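Example intermediate values for your reference (bytes and 16-bit lanes are listed lowest first, so byte 0 or lane 0 comes first, and | separates the four 64-bit blocks):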

idx:    00 0c 09 08 0d 0b 04 03 0a 07 05 02 0f 0e 06 01
bm:     0001 1000 0200 0100 2000 0800 0010 0008 0400 0080 0020 0004 8000 4000 0040 0002
bm2:    01 00 00 00 00 00 10 08 | 00 80 20 04 00 00 40 02 | 00 10 02 01 20 08 00 00 | 04 00 00 00 80 40 00 00
trans:  01 00 00 80 40 00 00 00 | 00 80 08 00 00 04 40 02 | 08 04 00 20 02 10 00 00 | 00 00 01 00 00 00 20 10
trans2: 0001 8000 0800 0080 0040 0400 4000 0200 0008 0004 0100 0020 0002 0010 2000 1000
inv:    00 0f 0b 07 06 0a 0e 09 03 02 08 05 01 04 0d 0c
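Without AVX-512 hardware at hand, the exact shuffle tables can still be checked with a scalar model. The sketch below (names hypothetical) emulates _mm256_permutexvar_epi8 as out[i] = in[sel[i] & 31] on a little-endian byte array, uses a naive transpose in place of the two gf2p8affineqb calls, and finishes with the popcount(x - 1) trick; on the example idx it should reproduce the inv row above:

```c
#include <assert.h>
#include <stdint.h>

// Scalar model of _mm256_permutexvar_epi8(sel, in): out[i] = in[sel[i] & 31].
static void permute_bytes32(const uint8_t sel[32], const uint8_t in[32], uint8_t out[32]) {
    for (int i = 0; i < 32; i++)
        out[i] = in[sel[i] & 31];
}

// Stand-in for the two chained gf2p8affineqb calls on one 8-byte block:
// bit k of in[i] becomes bit i of out[k].
static void transpose8x8_bytes(const uint8_t in[8], uint8_t out[8]) {
    for (int k = 0; k < 8; k++) {
        uint8_t row = 0;
        for (int i = 0; i < 8; i++)
            row |= (uint8_t)(((in[i] >> k) & 1u) << i);
        out[k] = row;
    }
}

// Scalar walk through the five steps, using the same shuffle tables as the post.
static void inverse16_model(const uint8_t idx[16], uint8_t inv[16]) {
    static const uint8_t to_blocks[32] = {
        0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
        1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
    static const uint8_t from_blocks[32] = {
        0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15,
        16, 24, 17, 25, 18, 26, 19, 27, 20, 28, 21, 29, 22, 30, 23, 31};
    uint8_t bm[32], bm2[32], trans[32], trans2[32];

    // 1. lane i = 1 << idx[i], stored little-endian as in the vector register
    for (int i = 0; i < 16; i++) {
        assert(idx[i] < 16);
        uint16_t lane = (uint16_t)(1u << idx[i]);
        bm[2 * i] = (uint8_t)lane;
        bm[2 * i + 1] = (uint8_t)(lane >> 8);
    }
    // 2. gather the low bytes, then the high bytes, into four 8×8 blocks
    permute_bytes32(to_blocks, bm, bm2);
    // 3. transpose each 8-byte block independently
    for (int q = 0; q < 4; q++)
        transpose8x8_bytes(&bm2[8 * q], &trans[8 * q]);
    // 4. interleave the blocks back into sixteen 16-bit lanes
    permute_bytes32(from_blocks, trans, trans2);
    // 5. popcount(x - 1) equals ctz(x) when x is a power of two
    for (int i = 0; i < 16; i++) {
        uint16_t below = (uint16_t)((trans2[2 * i] | (trans2[2 * i + 1] << 8)) - 1);
        int count = 0;
        for (; below != 0; below >>= 1)
            count += (int)(below & 1u);
        inv[i] = (uint8_t)count;
    }
}
```

Since the values above are printed lowest byte first, the example idx in array order is {0, 12, 9, 8, 13, 11, 4, 3, 10, 7, 5, 2, 15, 14, 6, 1}, the reverse of the _mm_set_epi8 argument list.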

further work

Extending this to 32-element and 64-element permutes is straightforward and is left as an exercise for the reader.

© 2024. All rights reserved. 87flowers