EigenRand  0.4.0-alpha
arch/NEON/MorePacketMath.h
Go to the documentation of this file.
1 
12 #ifndef EIGENRAND_MORE_PACKET_MATH_NEON_H
13 #define EIGENRAND_MORE_PACKET_MATH_NEON_H
14 
15 #include <arm_neon.h>
16 
17 // device func of casting for Eigen ~3.3.9
18 #ifdef EIGENRAND_EIGEN_33_MODE
19 namespace Eigen
20 {
21  namespace internal
22  {
23  template<>
24  EIGEN_DEVICE_FUNC inline Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a)
25  {
26  return vcvtq_f32_s32(a);
27  }
28 
29  template<>
30  EIGEN_DEVICE_FUNC inline Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a)
31  {
32  return vcvtq_s32_f32(a);
33  }
34 
35  }
36 }
37 #endif
38 
39 namespace Eigen
40 {
41  namespace internal
42  {
43  template<>
44  struct IsIntPacket<Packet4i> : std::true_type {};
45 
46  template<>
47  struct IsFloatPacket<Packet4f> : std::true_type {};
48 
49  template<>
50  struct HalfPacket<Packet4i>
51  {
52  using type = uint64_t;
53  };
54 
55  template<>
56  struct reinterpreter<Packet4i>
57  {
58  EIGEN_STRONG_INLINE Packet4f to_float(const Packet4i& x)
59  {
60  return (Packet4f)vreinterpretq_f32_s32(x);
61  }
62 
63  EIGEN_STRONG_INLINE Packet4i to_int(const Packet4i& x)
64  {
65  return x;
66  }
67  };
68 
69  template<>
70  struct reinterpreter<Packet4f>
71  {
72  EIGEN_STRONG_INLINE Packet4f to_float(const Packet4f& x)
73  {
74  return x;
75  }
76 
77  EIGEN_STRONG_INLINE Packet4i to_int(const Packet4f& x)
78  {
79  return (Packet4i)vreinterpretq_s32_f32(x);
80  }
81  };
82 
83  template<>
84  EIGEN_STRONG_INLINE Packet4i pcmpeq<Packet4i>(const Packet4i& a, const Packet4i& b)
85  {
86  return vreinterpretq_s32_u32(vceqq_s32(a, b));
87  }
88 
89  template<>
90  EIGEN_STRONG_INLINE Packet4f pcmpeq<Packet4f>(const Packet4f& a, const Packet4f& b)
91  {
92  return vreinterpretq_f32_u32(vceqq_f32(a, b));
93  }
94 
95  template<>
96  EIGEN_STRONG_INLINE Packet4i pbitnot<Packet4i>(const Packet4i& a)
97  {
98  return vmvnq_s32(a);
99  }
100 
101  template<>
102  EIGEN_STRONG_INLINE Packet4f pbitnot<Packet4f>(const Packet4f& a)
103  {
104  return (Packet4f)vreinterpretq_f32_s32(pbitnot((Packet4i)vreinterpretq_s32_f32(a)));
105  }
106 
107  template<>
108  struct BitShifter<Packet4i>
109  {
110  template<int b>
111  EIGEN_STRONG_INLINE Packet4i sll(const Packet4i& a)
112  {
113  return vreinterpretq_s32_u32(vshlq_n_u32(vreinterpretq_u32_s32(a), b));
114  }
115 
116  template<int b>
117  EIGEN_STRONG_INLINE Packet4i srl(const Packet4i& a, int _b = b)
118  {
119  if (b > 0)
120  {
121  return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), b > 0 ? b : 1));
122  }
123  else
124  {
125  switch (_b)
126  {
127  case 0: return a;
128  case 1: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 1));
129  case 2: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 2));
130  case 3: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 3));
131  case 4: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 4));
132  case 5: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 5));
133  case 6: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 6));
134  case 7: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 7));
135  case 8: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 8));
136  case 9: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 9));
137  case 10: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 10));
138  case 11: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 11));
139  case 12: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 12));
140  case 13: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 13));
141  case 14: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 14));
142  case 15: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 15));
143  case 16: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 16));
144  case 17: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 17));
145  case 18: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 18));
146  case 19: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 19));
147  case 20: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 20));
148  case 21: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 21));
149  case 22: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 22));
150  case 23: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 23));
151  case 24: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 24));
152  case 25: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 25));
153  case 26: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 26));
154  case 27: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 27));
155  case 28: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 28));
156  case 29: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 29));
157  case 30: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 30));
158  case 31: return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a), 31));
159  }
160  return vdupq_n_s32(0);
161  }
162  }
163 
164  template<int b>
165  EIGEN_STRONG_INLINE Packet4i sll64(const Packet4i& a)
166  {
167  return vreinterpretq_s32_u64(vshlq_n_u64(vreinterpretq_u64_s32(a), b));
168  }
169 
170  template<int b>
171  EIGEN_STRONG_INLINE Packet4i srl64(const Packet4i& a)
172  {
173  return vreinterpretq_s32_u64(vshrq_n_u64(vreinterpretq_u64_s32(a), b));
174  }
175  };
176 
177  template<>
178  EIGEN_STRONG_INLINE Packet4i pcmplt<Packet4i>(const Packet4i& a, const Packet4i& b)
179  {
180  return vreinterpretq_s32_u32(vcltq_s32(a, b));
181  }
182 
183  template<>
184  EIGEN_STRONG_INLINE Packet4f pcmplt<Packet4f>(const Packet4f& a, const Packet4f& b)
185  {
186  return vreinterpretq_f32_u32(vcltq_f32(a, b));
187  }
188 
189  template<>
190  EIGEN_STRONG_INLINE Packet4f pcmple<Packet4f>(const Packet4f& a, const Packet4f& b)
191  {
192  return vreinterpretq_f32_u32(vcleq_f32(a, b));
193  }
194 
195  template<>
196  EIGEN_STRONG_INLINE Packet4f pblendv(const Packet4f& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket)
197  {
198  return vbslq_f32(vreinterpretq_u32_f32(ifPacket), thenPacket, elsePacket);
199  }
200 
201  template<>
202  EIGEN_STRONG_INLINE Packet4f pblendv(const Packet4i& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket)
203  {
204  return vbslq_f32(vreinterpretq_u32_s32(ifPacket), thenPacket, elsePacket);
205  }
206 
207  template<>
208  EIGEN_STRONG_INLINE Packet4i pblendv(const Packet4i& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket)
209  {
210  return vbslq_s32(vreinterpretq_u32_s32(ifPacket), thenPacket, elsePacket);
211  }
212 
213  template<>
214  EIGEN_STRONG_INLINE Packet4i pgather<Packet4i>(const int* addr, const Packet4i& index)
215  {
216  int32_t u[4];
217  vst1q_s32(u, index);
218  int32_t t[4];
219  t[0] = addr[u[0]];
220  t[1] = addr[u[1]];
221  t[2] = addr[u[2]];
222  t[3] = addr[u[3]];
223  return vld1q_s32(t);
224  }
225 
226  template<>
227  EIGEN_STRONG_INLINE Packet4f pgather<Packet4i>(const float* addr, const Packet4i& index)
228  {
229  int32_t u[4];
230  vst1q_s32(u, index);
231  float t[4];
232  t[0] = addr[u[0]];
233  t[1] = addr[u[1]];
234  t[2] = addr[u[2]];
235  t[3] = addr[u[3]];
236  return vld1q_f32(t);
237  }
238 
239  template<>
240  EIGEN_STRONG_INLINE int pmovemask<Packet4f>(const Packet4f& a)
241  {
242  int32_t bits[4] = { 1, 2, 4, 8 };
243  auto r = vbslq_s32(vreinterpretq_u32_f32(a), vld1q_s32(bits), vdupq_n_s32(0));
244  auto s = vadd_s32(vget_low_s32(r), vget_high_s32(r));
245  return vget_lane_s32(vpadd_s32(s, s), 0);
246  }
247 
248  template<>
249  EIGEN_STRONG_INLINE int pmovemask<Packet4i>(const Packet4i& a)
250  {
251  return pmovemask((Packet4f)vreinterpretq_f32_s32(a));
252  }
253 
254  template<>
255  EIGEN_STRONG_INLINE Packet4f ptruncate<Packet4f>(const Packet4f& a)
256  {
257  return vrndq_f32(a);
258  }
259 
260  template<>
261  EIGEN_STRONG_INLINE Packet4i pseti64<Packet4i>(uint64_t a)
262  {
263  return vreinterpretq_s32_u64(vdupq_n_u64(a));
264  }
265 
266  template<>
267  EIGEN_STRONG_INLINE Packet4i pcmpeq64<Packet4i>(const Packet4i& a, const Packet4i& b)
268  {
269  return vreinterpretq_s32_u64(vceqq_s64(vreinterpretq_s64_s32(a), vreinterpretq_s64_s32(b)));
270  }
271 
272  template<>
273  EIGEN_STRONG_INLINE Packet4i pmuluadd64<Packet4i>(const Packet4i& a, uint64_t b, uint64_t c)
274  {
275  uint64_t u[2];
276  vst1q_u64(u, vreinterpretq_u64_s32(a));
277  u[0] = u[0] * b + c;
278  u[1] = u[1] * b + c;
279  return vreinterpretq_s32_u64(vld1q_u64(u));
280  }
281 
282  #ifdef EIGENRAND_EIGEN_33_MODE
283  template<>
284  EIGEN_STRONG_INLINE Packet4f plog<Packet4f>(const Packet4f& _x)
285  {
286  Packet4f x = _x;
287  _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
288  _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
289  _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
290 
291  const Packet4f p4f_inv_mant_mask = (Packet4f)vreinterpretq_f32_s32(pset1<Packet4i>(~0x7f800000));
292 
293  /* the smallest non denormalized float number */
294  const Packet4f p4f_min_norm_pos = (Packet4f)vreinterpretq_f32_s32(pset1<Packet4i>(0x00800000));
295  const Packet4f p4f_minus_inf = (Packet4f)vreinterpretq_f32_s32(pset1<Packet4i>(0xff800000));
296 
297  /* natural logarithm computed for 4 simultaneous float
298  return NaN for x <= 0
299  */
300  _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
301  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f);
302  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, -1.1514610310E-1f);
303  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f);
304  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, -1.2420140846E-1f);
305  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, +1.4249322787E-1f);
306  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, -1.6668057665E-1f);
307  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, +2.0000714765E-1f);
308  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, -2.4999993993E-1f);
309  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, +3.3333331174E-1f);
310  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
311  _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
312 
313 
314  Packet4i emm0;
315 
316  Packet4f invalid_mask = pbitnot(pcmple(pset1<Packet4f>(0), x)); // not greater equal is true if x is NaN
317  Packet4f iszero_mask = pcmpeq(x, pset1<Packet4f>(0));
318 
319  x = pmax(x, p4f_min_norm_pos); /* cut off denormalized stuff */
320  emm0 = BitShifter<Packet4i>{}.template srl<23>((Packet4i)vreinterpretq_s32_f32(x));
321 
322  /* keep only the fractional part */
323  x = pand(x, p4f_inv_mant_mask);
324  x = por(x, p4f_half);
325 
326  emm0 = psub(emm0, p4i_0x7f);
327  Packet4f e = padd(Packet4f(vcvtq_f32_s32(emm0)), p4f_1);
328 
329  /* part2:
330  if( x < SQRTHF ) {
331  e -= 1;
332  x = x + x - 1.0;
333  } else { x = x - 1.0; }
334  */
335  Packet4f mask = pcmplt(x, p4f_cephes_SQRTHF);
336  Packet4f tmp = pand(x, mask);
337  x = psub(x, p4f_1);
338  e = psub(e, pand(p4f_1, mask));
339  x = padd(x, tmp);
340 
341  Packet4f x2 = pmul(x, x);
342  Packet4f x3 = pmul(x2, x);
343 
344  Packet4f y, y1, y2;
345  y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
346  y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
347  y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
348  y = pmadd(y, x, p4f_cephes_log_p2);
349  y1 = pmadd(y1, x, p4f_cephes_log_p5);
350  y2 = pmadd(y2, x, p4f_cephes_log_p8);
351  y = pmadd(y, x3, y1);
352  y = pmadd(y, x3, y2);
353  y = pmul(y, x3);
354 
355  y1 = pmul(e, p4f_cephes_log_q1);
356  tmp = pmul(x2, p4f_half);
357  y = padd(y, y1);
358  x = psub(x, tmp);
359  y2 = pmul(e, p4f_cephes_log_q2);
360  x = padd(x, y);
361  x = padd(x, y2);
362  // negative arg will be NAN, 0 will be -INF
363  return pblendv(iszero_mask, p4f_minus_inf, por(x, invalid_mask));
364  }
365 
366  template<>
367  EIGEN_STRONG_INLINE Packet4f psqrt<Packet4f>(const Packet4f& x)
368  {
369  return vsqrtq_f32(x);
370  }
371 
372  template<>
373  EIGEN_STRONG_INLINE Packet4f psin<Packet4f>(const Packet4f& _x)
374  {
375  Packet4f x = _x;
376  _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
377  _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
378 
379  _EIGEN_DECLARE_CONST_Packet4i(1, 1);
380  _EIGEN_DECLARE_CONST_Packet4i(not1, ~1);
381  _EIGEN_DECLARE_CONST_Packet4i(2, 2);
382  _EIGEN_DECLARE_CONST_Packet4i(4, 4);
383 
384  const Packet4f p4f_sign_mask = (Packet4f)vreinterpretq_f32_s32(pset1<Packet4i>(0x80000000));
385 
386  _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1, -0.78515625f);
387  _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
388  _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
389  _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891E-4f);
390  _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736E-3f);
391  _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611E-1f);
392  _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948E-005f);
393  _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765E-003f);
394  _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827E-002f);
395  _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
396 
397  Packet4f xmm1, xmm2, xmm3, sign_bit, y;
398 
399  Packet4i emm0, emm2;
400  sign_bit = x;
401  /* take the absolute value */
402  x = pabs(x);
403 
404  /* take the modulo */
405 
406  /* extract the sign bit (upper one) */
407  sign_bit = pand(sign_bit, p4f_sign_mask);
408 
409  /* scale by 4/Pi */
410  y = pmul(x, p4f_cephes_FOPI);
411 
412  /* store the integer part of y in mm0 */
413  emm2 = vcvtq_s32_f32(y);
414  /* j=(j+1) & (~1) (see the cephes sources) */
415  emm2 = padd(emm2, p4i_1);
416  emm2 = pand(emm2, p4i_not1);
417  y = vcvtq_f32_s32(emm2);
418  /* get the swap sign flag */
419  emm0 = pand(emm2, p4i_4);
420  emm0 = BitShifter<Packet4i>{}.template sll<29>(emm0);
421  /* get the polynom selection mask
422  there is one polynom for 0 <= x <= Pi/4
423  and another one for Pi/4<x<=Pi/2
424 
425  Both branches will be computed.
426  */
427  emm2 = pand(emm2, p4i_2);
428  emm2 = pcmpeq(emm2, pset1<Packet4i>(0));
429 
430  Packet4f swap_sign_bit = (Packet4f)vreinterpretq_f32_s32(emm0);
431  Packet4f poly_mask = (Packet4f)vreinterpretq_f32_s32(emm2);
432  sign_bit = pxor(sign_bit, swap_sign_bit);
433 
434  /* The magic pass: "Extended precision modular arithmetic"
435  x = ((x - y * DP1) - y * DP2) - y * DP3; */
436  xmm1 = pmul(y, p4f_minus_cephes_DP1);
437  xmm2 = pmul(y, p4f_minus_cephes_DP2);
438  xmm3 = pmul(y, p4f_minus_cephes_DP3);
439  x = padd(x, xmm1);
440  x = padd(x, xmm2);
441  x = padd(x, xmm3);
442 
443  /* Evaluate the first polynom (0 <= x <= Pi/4) */
444  y = p4f_coscof_p0;
445  Packet4f z = pmul(x, x);
446 
447  y = pmadd(y, z, p4f_coscof_p1);
448  y = pmadd(y, z, p4f_coscof_p2);
449  y = pmul(y, z);
450  y = pmul(y, z);
451  Packet4f tmp = pmul(z, p4f_half);
452  y = psub(y, tmp);
453  y = padd(y, p4f_1);
454 
455  /* Evaluate the second polynom (Pi/4 <= x <= 0) */
456 
457  Packet4f y2 = p4f_sincof_p0;
458  y2 = pmadd(y2, z, p4f_sincof_p1);
459  y2 = pmadd(y2, z, p4f_sincof_p2);
460  y2 = pmul(y2, z);
461  y2 = pmul(y2, x);
462  y2 = padd(y2, x);
463 
464  /* select the correct result from the two polynoms */
465  y = pblendv(poly_mask, y2, y);
466  /* update the sign */
467  return pxor(y, sign_bit);
468  }
469  #endif
470  }
471 }
472 
473 #endif