I am new to using intrinsics, but I wanted to write a function that takes a vector of 4 packed doubles and computes, for each element a, a > 1e-5 ? std::sqrt(a) : 0.0. My first instinct was this:
#include <immintrin.h>
__m256d f(__m256d a)
{
    __m256d is_valid = a > _mm256_set1_pd(1e-5);
    __m256d sqrt_val = _mm256_sqrt_pd(a);
    return is_valid * sqrt_val;
}
which, according to gcc.godbolt.com, compiles into the following:
f(double __vector(4)):
vsqrtpd ymm1, ymm0
vcmpgtpd ymm0, ymm0, YMMWORD PTR .LC0[rip]
vmulpd ymm0, ymm1, ymm0
ret
.LC0:
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
But I worry about what will happen if sqrt_val contains NaN. I do not think that 0.0 * NaN will give 0.0. What is the best practice here?
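For reference, one way to sidestep the 0.0 * NaN concern is to avoid the multiply and combine the comparison mask with the square root using a bitwise AND. The following is only a minimal sketch, not a definitive implementation. The names masked_sqrt, masked_sqrt4 and AVX_FN are made up for illustration, the target attribute assumes GCC or Clang, and running it assumes a CPU with AVX.

```cpp
#include <immintrin.h>
#include <cassert>

// Assumption: GCC/Clang. Enables AVX for these functions even without -mavx.
#define AVX_FN __attribute__((target("avx")))

// Per lane: a > 1e-5 ? sqrt(a) : 0.0, using a bit mask instead of a multiply.
AVX_FN static inline __m256d masked_sqrt(__m256d a)
{
    const __m256d threshold = _mm256_set1_pd(1e-5);
    // All-ones bits where a > threshold, all-zero bits elsewhere.
    // The compare is ordered, so a NaN input lane also gives all-zero bits.
    const __m256d is_valid = _mm256_cmp_pd(a, threshold, _CMP_GT_OQ);
    // May be NaN in lanes where a is negative.
    const __m256d sqrt_val = _mm256_sqrt_pd(a);
    // x AND all-ones keeps x; x AND all-zeros gives +0.0, even if x is NaN.
    return _mm256_and_pd(is_valid, sqrt_val);
}

// Hypothetical array wrapper so the lanes can be inspected from scalar code.
AVX_FN void masked_sqrt4(const double* in, double* out)
{
    assert(in != nullptr && out != nullptr);
    _mm256_storeu_pd(out, masked_sqrt(_mm256_loadu_pd(in)));
}
```

Because the AND works on bit patterns, a NaN from the square root of a rejected lane is simply cleared to +0.0, so no extra NaN check should be needed afterwards.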
Edit
After reading the comments from @ChrisCooper and @njuffa, I was pointed to another Stack Overflow answer, so I will compare the result with itself for equality (which is false only in NaN lanes) and then AND this with my result.
__m256d f(__m256d a)
{
    __m256d is_valid = a > _mm256_set1_pd(1e-5);
    __m256d sqrt_val = _mm256_sqrt_pd(a);
    __m256d result = is_valid * sqrt_val;
    __m256d cmpeq = result == result;
    return _mm256_and_pd(cmpeq, result);
}
which compiles to the following:
f(double __vector(4)):
vsqrtpd ymm1, ymm0
vcmpgtpd ymm0, ymm0, YMMWORD PTR .LC0[rip]
vmulpd ymm0, ymm1, ymm0
vcmpeqpd ymm1, ymm0, ymm0
vandpd ymm0, ymm1, ymm0
ret
.LC0:
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
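One thing that may be worth double-checking in both versions: assuming the comparison produces all-ones bits in the lanes where it is true (which is what vcmpgtpd does), that bit pattern read as a double is a NaN rather than 1.0, so multiplying by the mask might not behave like multiplying by 1.0. Below is a small scalar sketch for inspecting that bit pattern. The helper names bits_to_double and true_lane_is_nan are hypothetical, and it assumes IEEE 754 doubles.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Reinterpret a 64-bit pattern as a double (memcpy avoids aliasing problems).
double bits_to_double(std::uint64_t bits)
{
    static_assert(sizeof(double) == sizeof(std::uint64_t),
                  "this sketch assumes 64-bit doubles");
    double d;
    std::memcpy(&d, &bits, sizeof d);
    return d;
}

// True if a "true" compare lane (all 64 bits set) is a NaN when viewed as a double.
bool true_lane_is_nan()
{
    return std::isnan(bits_to_double(~std::uint64_t{0}));
}
```

If true_lane_is_nan() returns true on your platform, the product is_valid * sqrt_val would be NaN in exactly the lanes that passed the comparison, which is another reason to prefer a bitwise AND over a multiply here.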