12 #ifndef EIGENRAND_MVDISTS_MVNORMAL_H
13 #define EIGENRAND_MVDISTS_MVNORMAL_H
// Computes a factor of the symmetric matrix `mat` for use by the generators.
// Step 1: tries a Cholesky (LLT) decomposition of `mat`.
// Step 2: if llt.info() is not Eigen::Success, falls back to a self-adjoint
//   eigen-decomposition mat = V * D * V^T. It returns V * diag(sqrt(max(D, 0))),
//   whose product with its own transpose reproduces `mat`, with negative
//   eigenvalues clamped to zero.
// Throws std::runtime_error("The matrix cannot be solved!") if the
// eigen-solver also fails.
// NOTE(review): the fallback factor V * sqrt(D) is in general NOT lower
// triangular. MvNormalGen and WishartGen only use the factor in general matrix
// products. InvWishartGen::generate, however, applies triangularView<Lower>()
// to (factor * A) and would silently ignore the upper part.
// Confirm this is acceptable, or make the fallback return a triangular factor.
21 template<
typename _Scalar, Index Dim,
typename _Mat>
22 Matrix<_Scalar, Dim, Dim> get_lt(
const MatrixBase<_Mat>& mat)
24 LLT<Matrix<_Scalar, Dim, Dim>> llt(mat);
25 if (llt.info() == Eigen::Success)
// NOTE(review): the branch taken on LLT success (original lines 26-30) is not
// present in this listing; presumably it returns the LLT lower factor. Verify
// against the full header.
31 SelfAdjointEigenSolver<Matrix<_Scalar, Dim, Dim>> solver(mat);
32 if (solver.info() != Eigen::Success)
34 throw std::runtime_error{
"The matrix cannot be solved!" };
// Scale each eigenvector column by sqrt of its (non-negative-clamped) eigenvalue.
36 return solver.eigenvectors() * solver.eigenvalues().cwiseMax(0).cwiseSqrt().asDiagonal();
// Empty tag type: selects the generator constructors whose matrix argument is
// already a factor (`_lt`) and is stored as-is.
struct LowerTriangular {};
// Empty tag type: selects the InvWishartGen constructor whose matrix argument
// (`_ilt`) is the factor of the inverse of the scale matrix.
struct InvLowerTriangular {};
45 constexpr detail::FullMatrix full_matrix;
46 constexpr detail::LowerTriangular lower_triangular;
47 constexpr detail::InvLowerTriangular inv_lower_triangular;
// Multivariate normal generator: template head and data members.
// NOTE(review): the class-head line(s) and any further members (e.g. the
// `stdnorm` generator used by generate()) are not present in this listing.
56 template<
typename _Scalar, Index Dim = -1>
// Mean vector (Dim x 1).
59 Matrix<_Scalar, Dim, 1> mean;
// Factor applied to standard-normal draws in generate(): x = lt * z + mean,
// so lt * lt^T is the covariance.
60 Matrix<_Scalar, Dim, Dim> lt;
// Constructs the generator from a mean vector `_mean` and a factor `_lt` of the
// covariance. The overload is selected by the detail::LowerTriangular tag, and
// `_lt` is stored as-is.
// Debug-asserts that `_mean` is a column vector and `_lt` is square with the same
// number of rows as `_mean`.
// NOTE(review): triangularity of `_lt` is not checked. generate() uses `lt` only
// in a general matrix product, so any factor with lt * lt^T == covariance
// appears to work.
72 template<
typename MeanTy,
typename LTTy>
73 MvNormalGen(
const MatrixBase<MeanTy>& _mean,
const MatrixBase<LTTy>& _lt, detail::LowerTriangular)
74 : mean{ _mean }, lt{ _lt }
76 eigen_assert(_mean.cols() == 1 && _mean.rows() == _lt.rows() && _lt.rows() == _lt.cols());
// Constructs the generator from a mean vector `_mean` and a full covariance
// matrix `_cov` (default tag detail::FullMatrix).
// The covariance is factored with detail::get_lt, and construction is delegated
// to the detail::LowerTriangular constructor, which performs the shape checks.
// get_lt throws std::runtime_error if neither of its decompositions succeeds.
87 template<
typename MeanTy,
typename CovTy>
88 MvNormalGen(
const MatrixBase<MeanTy>& _mean,
const MatrixBase<CovTy>& _cov, detail::FullMatrix = {})
89 :
MvNormalGen{ _mean, detail::template get_lt<_Scalar, Dim>(_cov), lower_triangular }
96 Index dims()
const {
return mean.rows(); }
// Draws `samples` vectors at once.
// A (dims x samples) matrix of standard normals is drawn from `stdnorm` using
// `urng`, left-multiplied by `lt`, and `mean` is added to every column.
// The return type is deduced from the same expression (trailing decltype), so
// it may be an unevaluated Eigen expression.
// NOTE(review): confirm whether the stdnorm expression holds a reference to
// `urng`. If it does, the result must be evaluated while `urng` is alive.
98 template<
typename Urng>
99 inline auto generate(Urng&& urng, Index samples)
100 -> decltype((lt * stdnorm.template generate<Matrix<_Scalar, Dim, -1>>(mean.rows(), samples, std::forward<Urng>(urng))).colwise() + mean)
102 return (lt * stdnorm.template generate<Matrix<_Scalar, Dim, -1>>(mean.rows(), samples, std::forward<Urng>(urng))).colwise() + mean;
// Draws a single vector.
// This is the one-column form of the batch overload: a (dims x 1) standard-normal
// vector is left-multiplied by `lt`, and `mean` is added.
// As in the batch overload, the return type is the decltype of the expression
// and may be unevaluated.
105 template<
typename Urng>
106 inline auto generate(Urng&& urng)
107 -> decltype((lt * stdnorm.template generate<Matrix<_Scalar, Dim, 1>>(mean.rows(), 1, std::forward<Urng>(urng))).colwise() + mean)
109 return (lt * stdnorm.template generate<Matrix<_Scalar, Dim, 1>>(mean.rows(), 1, std::forward<Urng>(urng))).colwise() + mean;
// Factory building a multivariate normal generator from a mean vector and a full
// covariance matrix. The returned object is brace-initialized from { mean, cov }.
// NOTE(review): the trailing return type and the `static_assert(` openers are not
// present in this listing. Presumably the return type is
// MvNormalGen<Scalar, Dim> deduced from the arguments; confirm.
// Compile-time checks:
//   - `mean` and `cov` share the same Scalar type;
//   - mean rows == cov rows == cov cols (compile-time sizes).
121 template<
typename MeanTy,
typename CovTy>
122 inline auto makeMvNormGen(
const MatrixBase<MeanTy>& mean,
const MatrixBase<CovTy>& cov)
126 std::is_same<
typename MatrixBase<MeanTy>::Scalar,
typename MatrixBase<CovTy>::Scalar>::value,
127 "Derived::Scalar must be the same with `mean` and `cov`'s Scalar."
130 MatrixBase<MeanTy>::RowsAtCompileTime == MatrixBase<CovTy>::RowsAtCompileTime &&
131 MatrixBase<CovTy>::RowsAtCompileTime == MatrixBase<CovTy>::ColsAtCompileTime,
132 "assert: mean.RowsAtCompileTime == cov.RowsAtCompileTime && cov.RowsAtCompileTime == cov.ColsAtCompileTime"
134 return { mean, cov };
// Factory building a multivariate normal generator from a mean vector and a
// covariance factor `lt`. It forwards the lower_triangular tag, so `lt` is
// stored as-is.
// NOTE(review): the function signature lines are not present in this listing.
// Compile-time checks:
//   - `mean` and `lt` share the same Scalar type;
//   - mean rows == lt rows == lt cols (compile-time sizes).
145 template<
typename MeanTy,
typename LTTy>
150 std::is_same<
typename MatrixBase<MeanTy>::Scalar,
typename MatrixBase<LTTy>::Scalar>::value,
151 "Derived::Scalar must be the same with `mean` and `lt`'s Scalar."
154 MatrixBase<MeanTy>::RowsAtCompileTime == MatrixBase<LTTy>::RowsAtCompileTime &&
155 MatrixBase<LTTy>::RowsAtCompileTime == MatrixBase<LTTy>::ColsAtCompileTime,
156 "assert: mean.RowsAtCompileTime == lt.RowsAtCompileTime && lt.RowsAtCompileTime == lt.ColsAtCompileTime"
158 return { mean, lt, lower_triangular };
// Wishart generator: template head and data members.
// NOTE(review): the class-head line(s) and the `df` / `stdnorm` members used
// elsewhere in the class are not present in this listing.
167 template<
typename _Scalar, Index Dim>
// Factor of the scale matrix; generate() left-multiplies the random matrix by it.
171 Matrix<_Scalar, Dim, Dim> chol;
// One chi-squared generator per diagonal entry (degrees of freedom df - i).
173 std::vector<ChiSquaredGen<_Scalar>> chisqs;
// Constructs the generator from degrees of freedom `_df` and a factor `_lt` of the
// scale matrix. Selected by the detail::LowerTriangular tag; `_lt` is stored as-is.
// Debug-asserts df > dims - 1 and that the factor is square.
// Builds chi-squared generators with df, df - 1, ..., df - dims + 1 degrees of
// freedom. generate() takes the square roots of their draws as the diagonal
// entries of the random lower-triangular matrix (Bartlett-style construction).
182 template<
typename ScaleTy>
183 WishartGen(Index _df,
const MatrixBase<ScaleTy>& _lt, detail::LowerTriangular)
184 : df{ _df }, chol{ _lt }
186 eigen_assert(df > chol.rows() - 1);
187 eigen_assert(chol.rows() == chol.cols());
189 for (Index i = 0; i < chol.rows(); ++i)
191 chisqs.emplace_back(df - i);
// Constructs the generator from degrees of freedom `_df` and a full scale matrix
// `_scale` (default tag detail::FullMatrix).
// The scale is factored with detail::get_lt, and construction is delegated to the
// detail::LowerTriangular constructor.
// The square-shape assert below runs only after the delegated constructor has
// finished.
202 template<
typename ScaleTy>
203 WishartGen(Index _df,
const MatrixBase<ScaleTy>& _scale, detail::FullMatrix = {})
204 :
WishartGen{ _df, detail::template get_lt<_Scalar, Dim>(_scale), lower_triangular }
206 eigen_assert(_scale.rows() == _scale.cols());
212 Index dims()
const {
return chol.rows(); }
// Draws `samples` Wishart matrices and returns them side by side in a
// (dims x dims*samples) matrix. Sample j occupies columns [j*dims, (j+1)*dims).
// Steps:
//   1. draw samples*dims*(dims-1)/2 standard normals into `tmp` (scratch space)
//      and copy them into the strictly-lower part of each dims x dims block of
//      rand_mat;
//   2. set diagonal entry i of every block to sqrt of a chi-squared(df - i) draw;
//   3. zero the strictly-upper part of every block;
//   4. left-multiply all blocks by `chol` (into tmp);
//   5. overwrite each block of rand_mat with t * t^T.
// NOTE(review): the statement advancing `ptr` after each column copy, and the
// final return, are not present in this listing. Without such an advance, every
// column would reuse the same normals. Verify against the full header.
214 template<
typename Urng>
215 inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples)
217 const Index dim = chol.rows();
// Number of strictly-lower entries over all samples.
218 const Index normSamples = samples * dim * (dim - 1) / 2;
219 using ArrayXs = Array<_Scalar, -1, 1>;
220 Matrix<_Scalar, Dim, -1> rand_mat(dim, dim * samples), tmp(dim, dim * samples);
// Step 1: normals are generated into tmp's storage, then scattered column by column.
222 _Scalar* ptr = tmp.data();
223 Map<ArrayXs>{ ptr, normSamples } = stdnorm.template generate<ArrayXs>(normSamples, 1, urng);
224 for (Index j = 0; j < samples; ++j)
226 for (Index i = 0; i < dim - 1; ++i)
228 rand_mat.col(i + j * dim).tail(dim - 1 - i) = Map<ArrayXs>{ ptr, dim - 1 - i };
// Step 2: for each diagonal position i, one chi-squared(df - i) batch covers all samples.
233 for (Index i = 0; i < dim; ++i)
235 _Scalar* ptr = tmp.data();
236 Map<ArrayXs>{ ptr, samples } = chisqs[i].template generate<ArrayXs>(samples, 1, urng).sqrt();
237 for (Index j = 0; j < samples; ++j)
239 rand_mat(i, i + j * dim) = *ptr++;
// Step 3: clear the (uninitialized) strictly-upper part of every block.
243 for (Index j = 0; j < samples; ++j)
245 rand_mat.middleCols(j * dim, dim).template triangularView<StrictlyUpper>().setZero();
// Step 4: one product for all blocks at once.
247 tmp.noalias() = chol * rand_mat;
// Step 5: W_j = t_j * t_j^T.
249 for (Index j = 0; j < samples; ++j)
251 auto t = tmp.middleCols(j * dim, dim);
252 rand_mat.middleCols(j * dim, dim).noalias() = t * t.transpose();
// Draws a single dims x dims Wishart matrix.
// How the random lower-triangular matrix A is built without a scratch buffer:
//   - The dims*(dims-1)/2 standard normals are written to the leading entries
//     of rand_mat's storage. This assumes Eigen's default column-major layout,
//     which fills the first floor(dims/2) columns, or part of the last one for
//     even dims.
//   - The loop copies rand_mat.col(i).head(i+1) into
//     rand_mat.col(dims-2-i).tail(i+1). Destination and source regions never
//     overlap.
//   - Result: every strictly-lower entry holds a distinct normal draw.
// The diagonal is then overwritten with sqrt(chi-squared(df - i)) draws and the
// strictly-upper part is zeroed. The function returns (chol*A) * (chol*A)^T.
257 template<
typename Urng>
258 inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng)
260 const Index dim = chol.rows();
261 const Index normSamples = dim * (dim - 1) / 2;
262 using ArrayXs = Array<_Scalar, -1, 1>;
263 Matrix<_Scalar, Dim, Dim> rand_mat(dim, dim);
264 Map<ArrayXs>{ rand_mat.data(), normSamples } = stdnorm.template generate<ArrayXs>(normSamples, 1, urng);
// Move the normals sitting on/above the diagonal of the left columns into the
// strictly-lower part of the mirrored right columns.
266 for (Index i = 0; i < dim / 2; ++i)
268 rand_mat.col(dim - 2 - i).tail(i + 1) = rand_mat.col(i).head(i + 1);
271 for (Index i = 0; i < dim; ++i)
273 rand_mat(i, i) = chisqs[i].template generate<Array<_Scalar, 1, 1>>(1, 1, urng).sqrt()(0);
275 rand_mat.template triangularView<StrictlyUpper>().setZero();
277 auto t = (chol * rand_mat).eval();
278 return (t * t.transpose()).eval();
// Factory building a Wishart generator from degrees of freedom `df` and a full
// scale matrix. Brace-initializes the result from { df, scale }, which uses the
// FullMatrix constructor.
// NOTE(review): the signature and `static_assert(` opener are not present in this
// listing.
// Compile-time check: `scale` is square (compile-time sizes).
289 template<
typename ScaleTy>
294 MatrixBase<ScaleTy>::RowsAtCompileTime == MatrixBase<ScaleTy>::ColsAtCompileTime,
295 "assert: scale.RowsAtCompileTime == scale.ColsAtCompileTime"
297 return { df, scale };
// Factory building a Wishart generator from degrees of freedom `df` and a factor
// `lt` of the scale matrix. Passes the lower_triangular tag, so `lt` is stored
// as-is.
// NOTE(review): the signature and `static_assert(` opener are not present in this
// listing.
// Compile-time check: `lt` is square (compile-time sizes).
307 template<
typename LTTy>
312 MatrixBase<LTTy>::RowsAtCompileTime == MatrixBase<LTTy>::ColsAtCompileTime,
313 "assert: lt.RowsAtCompileTime == lt.ColsAtCompileTime"
315 return { df, lt, lower_triangular };
// Inverse-Wishart generator: template head and data members.
// NOTE(review): the class-head line(s) and the `df` / `stdnorm` members used
// elsewhere in the class are not present in this listing.
324 template<
typename _Scalar, Index Dim>
// Factor of the INVERSE of the scale matrix (see the FullMatrix constructor).
328 Matrix<_Scalar, Dim, Dim> chol;
// One chi-squared generator per diagonal entry (degrees of freedom df - i).
330 std::vector<ChiSquaredGen<_Scalar>> chisqs;
// Constructs the generator from degrees of freedom `_df` and `_ilt`, a factor of
// the inverse of the scale matrix (detail::InvLowerTriangular tag). `_ilt` is
// stored as-is.
// Debug-asserts df > dims - 1 and that the factor is square.
// NOTE(review): generate() solves with triangularView<Lower>(), so `_ilt` should
// be lower triangular. This is not checked here.
339 template<
typename ScaleTy>
340 InvWishartGen(Index _df,
const MatrixBase<ScaleTy>& _ilt, detail::InvLowerTriangular)
341 : df{ _df }, chol{ _ilt }
343 eigen_assert(df > chol.rows() - 1);
344 eigen_assert(chol.rows() == chol.cols());
346 for (Index i = 0; i < chol.rows(); ++i)
348 chisqs.emplace_back(df - i);
// Constructs the generator from degrees of freedom `_df` and a full scale matrix
// `_scale` (default tag detail::FullMatrix).
// `_scale` is explicitly inverted and factored with detail::get_lt. Construction
// is then delegated to the detail::InvLowerTriangular constructor.
// NOTE(review): if LLT of _scale.inverse() fails, get_lt falls back to an
// eigenvector-based factor that is generally not lower triangular. generate()'s
// triangularView<Lower>() solve would then ignore that factor's upper part and
// give wrong samples.
// The square-shape assert below runs only after the delegated constructor.
359 template<
typename ScaleTy>
360 InvWishartGen(Index _df,
const MatrixBase<ScaleTy>& _scale, detail::FullMatrix = {})
361 :
InvWishartGen{ _df, detail::template get_lt<_Scalar, Dim>(_scale.inverse()), inv_lower_triangular }
363 eigen_assert(_scale.rows() == _scale.cols());
369 Index dims()
const {
return chol.rows(); }
// Draws `samples` inverse-Wishart matrices, side by side in a
// (dims x dims*samples) matrix (sample j occupies columns [j*dims, (j+1)*dims)).
// Steps 1-4 match WishartGen::generate: a Bartlett-style lower-triangular A per
// block, then t = chol * A.
// Step 5: instead of t * t^T, it computes u = t^{-1} (triangular solve against
//   the identity) and stores u^T * u = (t * t^T)^{-1} back into tmp.
// NOTE(review): the solve reads only the lower triangle of t, which is correct
// only if `chol` is lower triangular (see the get_lt fallback). The statement
// advancing `ptr` and the final return are not present in this listing.
371 template<
typename Urng>
372 inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples)
374 const Index dim = chol.rows();
375 const Index normSamples = samples * dim * (dim - 1) / 2;
376 using ArrayXs = Array<_Scalar, -1, 1>;
377 Matrix<_Scalar, Dim, -1> rand_mat(dim, dim * samples), tmp(dim, dim * samples);
379 _Scalar* ptr = tmp.data();
380 Map<ArrayXs>{ ptr, normSamples } = stdnorm.template generate<ArrayXs>(normSamples, 1, urng);
381 for (Index j = 0; j < samples; ++j)
383 for (Index i = 0; i < dim - 1; ++i)
385 rand_mat.col(i + j * dim).tail(dim - 1 - i) = Map<ArrayXs>{ ptr, dim - 1 - i };
390 for (Index i = 0; i < dim; ++i)
392 _Scalar* ptr = tmp.data();
393 Map<ArrayXs>{ ptr, samples } = chisqs[i].template generate<ArrayXs>(samples, 1, urng).sqrt();
394 for (Index j = 0; j < samples; ++j)
396 rand_mat(i, i + j * dim) = *ptr++;
400 for (Index j = 0; j < samples; ++j)
402 rand_mat.middleCols(j * dim, dim).template triangularView<StrictlyUpper>().setZero();
404 tmp.noalias() = chol * rand_mat;
406 auto id = Eigen::Matrix<_Scalar, Dim, Dim>::Identity(dim, dim);
407 for (Index j = 0; j < samples; ++j)
409 auto t = tmp.middleCols(j * dim, dim);
// rand_mat's block is reused as storage for u = t^{-1}.
410 auto u = rand_mat.middleCols(j * dim, dim);
411 u.noalias() = t.template triangularView<Lower>().solve(
id);
// (t t^T)^{-1} = t^{-T} t^{-1} = u^T u, written back over t in tmp.
412 t.noalias() = u.transpose() * u;
// Draws a single dims x dims inverse-Wishart matrix.
// Builds the Bartlett-style lower-triangular A exactly as
// WishartGen::generate(urng) does:
//   - the normals are packed into rand_mat's leading storage, then mirrored into
//     the strictly-lower part of the right-hand columns;
//   - the diagonal is overwritten with sqrt(chi-squared(df - i)) draws;
//   - the strictly-upper part is zeroed.
// It then forms t = chol * A and returns (t * t^T)^{-1}, computed as u^T * u
// with u = t^{-1}.
// NOTE(review): triangularView<Lower>().solve assumes t is lower triangular,
// which holds only if `chol` is (see the get_lt eigen-decomposition fallback).
417 template<
typename Urng>
418 inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng)
420 const Index dim = chol.rows();
421 const Index normSamples = dim * (dim - 1) / 2;
422 using ArrayXs = Array<_Scalar, -1, 1>;
423 Matrix<_Scalar, Dim, Dim> rand_mat(dim, dim);
424 Map<ArrayXs>{ rand_mat.data(), normSamples } = stdnorm.template generate<ArrayXs>(normSamples, 1, urng);
426 for (Index i = 0; i < dim / 2; ++i)
428 rand_mat.col(dim - 2 - i).tail(i + 1) = rand_mat.col(i).head(i + 1);
431 for (Index i = 0; i < dim; ++i)
433 rand_mat(i, i) = chisqs[i].template generate<Array<_Scalar, 1, 1>>(1, 1, urng).sqrt()(0);
435 rand_mat.template triangularView<StrictlyUpper>().setZero();
437 auto t = (chol * rand_mat).eval();
438 auto id = Eigen::Matrix<_Scalar, Dim, Dim>::Identity(dim, dim);
// rand_mat (no longer needed as A) is reused to hold u = t^{-1}.
439 rand_mat.noalias() = t.template triangularView<Lower>().solve(
id);
441 return (rand_mat.transpose() * rand_mat).eval();
// Factory building an inverse-Wishart generator from degrees of freedom `df` and a
// full scale matrix. Brace-initializes the result from { df, scale }, which uses
// the FullMatrix constructor (inverts and factors `scale`).
// NOTE(review): the signature and `static_assert(` opener are not present in this
// listing.
// Compile-time check: `scale` is square (compile-time sizes).
452 template<
typename ScaleTy>
457 MatrixBase<ScaleTy>::RowsAtCompileTime == MatrixBase<ScaleTy>::ColsAtCompileTime,
458 "assert: scale.RowsAtCompileTime == scale.ColsAtCompileTime"
460 return { df, scale };
// Factory building an inverse-Wishart generator from degrees of freedom `df` and
// `ilt`, a factor of the inverse of the scale matrix. Passes the
// inv_lower_triangular tag, so `ilt` is stored as-is.
// NOTE(review): the signature and `static_assert(` opener are not present in this
// listing.
// Compile-time check: `ilt` is square (compile-time sizes).
470 template<
typename ILTTy>
475 MatrixBase<ILTTy>::RowsAtCompileTime == MatrixBase<ILTTy>::ColsAtCompileTime,
476 "assert: ilt.RowsAtCompileTime == ilt.ColsAtCompileTime"
478 return { df, ilt, inv_lower_triangular };