PacketRandomEngine.h
Go to the documentation of this file.
1 
12 #ifndef EIGENRAND_PACKET_RANDOM_ENGINE_H
13 #define EIGENRAND_PACKET_RANDOM_ENGINE_H
14 
#include <array>
#include <cstring>
#include <limits>
#include <random>
#include <type_traits>
#include <utility>
19 
20 namespace Eigen
21 {
22  namespace internal
23  {
/// Trait that is true only for the integer SIMD packet types enabled in this build.
template <class Ty>
struct IsIntPacket : public std::false_type
{
};

/// Maps an integer packet type to the type that holds half of its bits.
/// Only declared here; each enabled packet type gets a specialization below.
template <class Ty>
struct HalfPacket;

#ifdef EIGEN_VECTORIZE_SSE2
// 128-bit integer packets: half of one is a single 64-bit word.
template <>
struct IsIntPacket<Packet4i> : public std::true_type
{
};

template <>
struct HalfPacket<Packet4i>
{
    using type = uint64_t;
};
#endif

#ifdef EIGEN_VECTORIZE_AVX2
// 256-bit integer packets: half of one is a 128-bit packet.
template <>
struct IsIntPacket<Packet8i> : public std::true_type
{
};

template <>
struct HalfPacket<Packet8i>
{
    using type = Packet4i;
};
#endif
50  template<typename Packet>
51  EIGEN_STRONG_INLINE Packet pcmpeq64(const Packet& a, const Packet& b);
52 
53  template<typename Packet>
54  EIGEN_STRONG_INLINE Packet pmuluadd64(const Packet& a, uint64_t b, uint64_t c);
55 
56 #ifdef EIGEN_VECTORIZE_AVX
57  template<>
58  EIGEN_STRONG_INLINE Packet8i pcmpeq64<Packet8i>(const Packet8i& a, const Packet8i& b)
59  {
60 #ifdef EIGEN_VECTORIZE_AVX2
61  return _mm256_cmpeq_epi64(a, b);
62 #else
63  Packet4i a1, a2, b1, b2;
64  split_two(a, a1, a2);
65  split_two(b, b1, b2);
66  return combine_two(_mm_cmpeq_epi64(a1, b1), _mm_cmpeq_epi64(a2, b2));
67 #endif
68  }
69 
70  template<>
71  EIGEN_STRONG_INLINE Packet8i pmuluadd64<Packet8i>(const Packet8i& a, uint64_t b, uint64_t c)
72  {
73  uint64_t u[4];
74  _mm256_storeu_si256((__m256i*)u, a);
75  u[0] = u[0] * b + c;
76  u[1] = u[1] * b + c;
77  u[2] = u[2] * b + c;
78  u[3] = u[3] * b + c;
79  return _mm256_loadu_si256((__m256i*)u);
80  }
81 #endif
82 
83 #ifdef EIGEN_VECTORIZE_SSE2
84  template<>
85  EIGEN_STRONG_INLINE Packet4i pcmpeq64<Packet4i>(const Packet4i& a, const Packet4i& b)
86  {
87 #ifdef EIGEN_VECTORIZE_SSE4_1
88  return _mm_cmpeq_epi64(a, b);
89 #else
90  Packet4i c = _mm_cmpeq_epi32(a, b);
91  return pand(c, _mm_shuffle_epi32(c, _MM_SHUFFLE(2, 3, 0, 1)));
92 #endif
93  }
94 
95  template<>
96  EIGEN_STRONG_INLINE Packet4i pmuluadd64<Packet4i>(const Packet4i& a, uint64_t b, uint64_t c)
97  {
98  uint64_t u[2];
99  _mm_storeu_si128((__m128i*)u, a);
100  u[0] = u[0] * b + c;
101  u[1] = u[1] * b + c;
102  return _mm_loadu_si128((__m128i*)u);
103  }
104 
105 #endif
106  }
107 
108  namespace Rand
109  {
110  namespace detail
111  {
112  template<typename T>
113  auto test_integral_result_type(int) -> std::integral_constant<bool, std::is_integral<typename T::result_type>::value>;
114 
115  template<typename T>
116  auto test_integral_result_type(...) -> std::false_type;
117 
118  template<typename T>
119  auto test_intpacket_result_type(int)->std::integral_constant<bool, internal::IsIntPacket<typename T::result_type>::value>;
120 
121  template<typename T>
122  auto test_intpacket_result_type(...)->std::false_type;
123  }
124 
125  template<typename Ty>
126  struct IsScalarRandomEngine : decltype(detail::test_integral_result_type<Ty>(0))
127  {
128  };
129 
130  template<typename Ty>
131  struct IsPacketRandomEngine : decltype(detail::test_intpacket_result_type<Ty>(0))
132  {
133  };
134 
135  enum class RandomEngineType
136  {
137  none, scalar, packet
138  };
139 
140  template<typename Ty>
141  struct GetRandomEngineType : std::integral_constant <
142  RandomEngineType,
143  IsPacketRandomEngine<Ty>::value ? RandomEngineType::packet :
144  (IsScalarRandomEngine<Ty>::value ? RandomEngineType::scalar : RandomEngineType::none)
145  >
146  {
147  };
148 
149 #ifndef EIGEN_DONT_VECTORIZE
150 
170  template<typename Packet,
171  int _Nx, int _Mx,
172  int _Rx, uint64_t _Px,
173  int _Ux, uint64_t _Dx,
174  int _Sx, uint64_t _Bx,
175  int _Tx, uint64_t _Cx,
176  int _Lx, uint64_t _Fx>
178  {
179  public:
180  using result_type = Packet;
181 
182  static constexpr int word_size = 64;
183  static constexpr int state_size = _Nx;
184  static constexpr int shift_size = _Mx;
185  static constexpr int mask_bits = _Rx;
186  static constexpr uint64_t parameter_a = _Px;
187  static constexpr int output_u = _Ux;
188  static constexpr int output_s = _Sx;
189  static constexpr uint64_t output_b = _Bx;
190  static constexpr int output_t = _Tx;
191  static constexpr uint64_t output_c = _Cx;
192  static constexpr int output_l = _Lx;
193 
194  static constexpr uint64_t default_seed = 5489U;
195 
204  MersenneTwister(uint64_t x0 = default_seed)
205  {
206  using namespace Eigen::internal;
207  std::array<uint64_t, unpacket_traits<Packet>::size / 2> seeds;
208  for (uint64_t i = 0; i < seeds.size(); ++i)
209  {
210  seeds[i] = x0 + i;
211  }
212  seed(ploadu<Packet>((int*)seeds.data()));
213  }
214 
220  MersenneTwister(Packet x0)
221  {
222  seed(x0);
223  }
224 
230  void seed(Packet x0)
231  {
232  using namespace Eigen::internal;
233  Packet prev = state[0] = x0;
234  for (int i = 1; i < _Nx; ++i)
235  {
236  prev = state[i] = pmuluadd64(pxor(prev, psrl64(prev, word_size - 2)), _Fx, i);
237  }
238  stateIdx = _Nx;
239  }
240 
246  uint64_t min() const
247  {
248  return 0;
249  }
250 
256  uint64_t max() const
257  {
258  return _wMask;
259  }
260 
269  result_type operator()()
270  {
271  if (stateIdx == _Nx)
272  refill_upper();
273  else if (2 * _Nx <= stateIdx)
274  refill_lower();
275 
276  using namespace Eigen::internal;
277 
278  Packet res = state[stateIdx++];
279  res = pxor(res, pand(psrl64(res, _Ux), pseti64<Packet>(_Dx)));
280  res = pxor(res, pand(psll64(res, _Sx), pseti64<Packet>(_Bx)));
281  res = pxor(res, pand(psll64(res, _Tx), pseti64<Packet>(_Cx)));
282  res = pxor(res, psrl64(res, _Lx));
283  return res;
284  }
285 
291  void discard(unsigned long long num)
292  {
293  for (; 0 < num; --num)
294  {
295  operator()();
296  }
297  }
298 
299  typename internal::HalfPacket<Packet>::type half()
300  {
301  if (valid)
302  {
303  valid = false;
304  return cache;
305  }
306  typename internal::HalfPacket<Packet>::type a;
307  internal::split_two(operator()(), a, cache);
308  valid = true;
309  return a;
310  }
311 
312  protected:
313 
314  void refill_lower()
315  {
316  using namespace Eigen::internal;
317 
318  auto hmask = pseti64<Packet>(_hMask),
319  lmask = pseti64<Packet>(_lMask),
320  px = pseti64<Packet>(_Px),
321  one = pseti64<Packet>(1);
322 
323  int i;
324  for (i = 0; i < _Nx - _Mx; ++i)
325  {
326  Packet tmp = por(pand(state[i + _Nx], hmask),
327  pand(state[i + _Nx + 1], lmask));
328 
329  state[i] = pxor(pxor(
330  psrl64(tmp, 1),
331  pand(pcmpeq64(pand(tmp, one), one), px)),
332  state[i + _Nx + _Mx]
333  );
334  }
335 
336  for (; i < _Nx - 1; ++i)
337  {
338  Packet tmp = por(pand(state[i + _Nx], hmask),
339  pand(state[i + _Nx + 1], lmask));
340 
341  state[i] = pxor(pxor(
342  psrl64(tmp, 1),
343  pand(pcmpeq64(pand(tmp, one), one), px)),
344  state[i - _Nx + _Mx]
345  );
346  }
347 
348  Packet tmp = por(pand(state[i + _Nx], hmask),
349  pand(state[0], lmask));
350  state[i] = pxor(pxor(
351  psrl64(tmp, 1),
352  pand(pcmpeq64(pand(tmp, one), one), px)),
353  state[_Mx - 1]
354  );
355  stateIdx = 0;
356  }
357 
358  void refill_upper()
359  {
360  using namespace Eigen::internal;
361 
362  auto hmask = pseti64<Packet>(_hMask),
363  lmask = pseti64<Packet>(_lMask),
364  px = pseti64<Packet>(_Px),
365  one = pseti64<Packet>(1);
366 
367  for (int i = _Nx; i < 2 * _Nx; ++i)
368  {
369  Packet tmp = por(pand(state[i - _Nx], hmask),
370  pand(state[i - _Nx + 1], lmask));
371 
372  state[i] = pxor(pxor(
373  psrl64(tmp, 1),
374  pand(pcmpeq64(pand(tmp, one), one), px)),
375  state[i - _Nx + _Mx]
376  );
377  }
378  }
379 
380  std::array<Packet, _Nx * 2> state;
381  size_t stateIdx = 0;
382  typename internal::HalfPacket<Packet>::type cache;
383  bool valid = false;
384 
385  static constexpr uint64_t _wMask = (uint64_t)-1;
386  static constexpr uint64_t _hMask = (_wMask << _Rx) & _wMask;
387  static constexpr uint64_t _lMask = ~_hMask & _wMask;
388  };
389 
395  template<typename Packet>
396  using Pmt19937_64 = MersenneTwister<Packet, 312, 156, 31,
397  0xb5026f5aa96619e9, 29,
398  0x5555555555555555, 17,
399  0x71d67fffeda60000, 37,
400  0xfff7eee000000000, 43, 6364136223846793005>;
401 #endif
402 
409  template<typename UIntType, typename BaseRng>
411  {
412  static_assert(IsPacketRandomEngine<BaseRng>::value, "BaseRNG must be a kind of PacketRandomEngine.");
413  public:
414  using result_type = UIntType;
415 
416  PacketRandomEngineAdaptor(const BaseRng& _rng)
417  : rng{ _rng }
418  {
419  }
420 
421  PacketRandomEngineAdaptor(BaseRng&& _rng)
422  : rng{ _rng }
423  {
424  }
425 
428 
429  static constexpr result_type min()
430  {
431  return std::numeric_limits<result_type>::min();
432  }
433 
434  static constexpr result_type max()
435  {
436  return std::numeric_limits<result_type>::max();
437  }
438 
439  result_type operator()()
440  {
441  if (cnt >= buf_size)
442  {
443  refill_buffer();
444  }
445  return buf[cnt++];
446  }
447 
448  private:
449  static constexpr size_t buf_size = 64 / sizeof(result_type);
450 
451  void refill_buffer()
452  {
453  cnt = 0;
454  const size_t stride = sizeof(typename BaseRng::result_type) / sizeof(result_type);
455  for (size_t i = 0; i < buf_size; i += stride)
456  {
457  *(typename BaseRng::result_type*)&buf[i] = rng();
458  }
459  }
460 
461  BaseRng rng;
462  std::array<result_type, buf_size> buf;
463  size_t cnt = buf_size;
464 
465  };
466 
475  template<typename UIntType, typename Rng>
476  auto makeScalarRng(Rng&& rng) -> typename std::enable_if<
477  IsPacketRandomEngine<typename std::remove_reference<Rng>::type>::value,
479  >::type
480  {
481  return { std::forward<Rng>(rng) };
482  }
483 
484  template<typename UIntType, typename Rng>
485  auto makeScalarRng(Rng&& rng) -> typename std::enable_if<
486  IsScalarRandomEngine<typename std::remove_reference<Rng>::type>::value,
487  typename std::remove_reference<Rng>::type
488  >::type
489  {
490  return std::forward<Rng>(rng);
491  }
492 
#if defined(EIGEN_VECTORIZE_AVX2)
    // AVX2: 256-bit integer packets, i.e. four 64-bit generators per call.
    using Vmt19937_64 = Pmt19937_64<internal::Packet8i>;
#elif defined(EIGEN_VECTORIZE_AVX) || defined(EIGEN_VECTORIZE_SSE2)
    // SSE2, or AVX without AVX2: 128-bit integer packets, i.e. two 64-bit generators per call.
    using Vmt19937_64 = Pmt19937_64<internal::Packet4i>;
#else
    /**
     * @brief Scalar fallback when no SSE2/AVX path is enabled (e.g. EIGEN_DONT_VECTORIZE):
     *        plain std::mt19937_64.
     */
    using Vmt19937_64 = std::mt19937_64;
#endif
506 
507  }
508 }
509 
510 #endif
Eigen::Rand::MersenneTwister::min
uint64_t min() const
minimum value of the result
Definition: PacketRandomEngine.h:246
Eigen::Rand::MersenneTwister
A vectorized version of Mersenne Twister Engine.
Definition: PacketRandomEngine.h:178
MorePacketMath.h
Eigen::Rand::MersenneTwister::MersenneTwister
MersenneTwister(Packet x0)
Construct a new Mersenne Twister engine with a packet seed.
Definition: PacketRandomEngine.h:220
Eigen::Rand::MersenneTwister::discard
void discard(unsigned long long num)
Discards num items being generated.
Definition: PacketRandomEngine.h:291
Eigen::Rand::Vmt19937_64
std::mt19937_64 Vmt19937_64
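Same as std::mt19937_64 when EIGEN_DONT_VECTORIZE is defined (or no SSE2/AVX path is enabled), Pmt19937_64<internal::Packet4i> when SSE2 or AVX (without AVX2) is enabled, and Pmt19937_64<internal::Packet8i> when AVX2 is enabled.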
Definition: PacketRandomEngine.h:504
Eigen::Rand::MersenneTwister::operator()
result_type operator()()
Generates one random packet and advances the internal state.
Definition: PacketRandomEngine.h:269
Eigen::Rand::MersenneTwister::max
uint64_t max() const
maximum value of the result
Definition: PacketRandomEngine.h:256
Eigen::Rand::makeScalarRng
auto makeScalarRng(Rng &&rng) -> typename std::enable_if< IsPacketRandomEngine< typename std::remove_reference< Rng >::type >::value, PacketRandomEngineAdaptor< UIntType, typename std::remove_reference< Rng >::type > >::type
Helper function for making a PacketRandomEngineAdaptor.
Definition: PacketRandomEngine.h:476
Eigen::Rand::PacketRandomEngineAdaptor
Scalar adaptor for random engines that generate packets.
Definition: PacketRandomEngine.h:411
Eigen::Rand::MersenneTwister::seed
void seed(Packet x0)
Initializes the engine with a given seed.
Definition: PacketRandomEngine.h:230
Eigen::Rand::MersenneTwister::MersenneTwister
MersenneTwister(uint64_t x0=default_seed)
Construct a new Mersenne Twister engine with a scalar seed.
Definition: PacketRandomEngine.h:204