21 #ifndef MIO_RANDOM_NUMBER_GENERATOR_H
22 #define MIO_RANDOM_NUMBER_GENERATOR_H
34 #include "Random123/array.h"
35 #include "Random123/threefry.h"
45 #include <type_traits>
// Base class for counter-based random number generators, parameterized on the derived generator type
// (CRTP): the out-of-class operator() casts `this` to Derived* and uses its get_key(), get_counter()
// and increment_counter(). Provides result_type, min(), max() and operator(), i.e. the member set
// required of a UniformRandomBitGenerator.
115 template <
class Derived>
116 class RandomNumberGeneratorBase
// Type of the values produced by operator().
119 using result_type = uint64_t;
// Smallest value operator() can return (body not visible in this excerpt).
125 static constexpr result_type
min()
// Largest value operator() can return (body not visible in this excerpt).
134 static constexpr result_type
max()
// Produce the next random value and advance the generator; defined out of class.
144 result_type operator()();
// Pack a Random123 2x32 array into one 64-bit integer by copying its raw bytes into the local `i`
// (declaration and return are on lines not shown). std::memcpy avoids the strict-aliasing problems of a
// pointer cast; the result depends on the host byte order.
152 inline uint64_t to_uint64(r123array2x32 tf_array)
155 std::memcpy(&
i, tf_array.data(),
sizeof(uint64_t));
// Inverse of to_uint64: copy the raw bytes of a 64-bit integer into a Threefry 2x32 counter array
// (the return statement is on a line not shown).
162 inline r123array2x32 to_r123_array(uint64_t
i)
164 threefry2x32_ctr_t c;
165 std::memcpy(c.data(), &
i,
sizeof(uint64_t));
// Key<T>: strongly typed wrapper for a generator key; inherits the constructors of TypeSafe.
177 using TypeSafe<T, Key<T>>::TypeSafe;
// Guard that the wrapper adds no storage on top of the wrapped integer.
179 static_assert(
sizeof(Key<uint32_t>) ==
sizeof(uint32_t),
"Empty Base Optimization isn't working.");
// Counter<T>: strongly typed wrapper for a generator counter. As the base names indicate, comparison and
// addition/subtraction operators are mixed in through OperatorComparison / OperatorAdditionSubtraction.
187 OperatorComparison<Counter<T>>,
188 OperatorAdditionSubtraction<Counter<T>> {
// Inherit the constructors of TypeSafe.
190 using TypeSafe<T, Counter<T>>::TypeSafe;
// Guard that the operator mix-in bases add no storage on top of the wrapped integer.
192 static_assert(
sizeof(Counter<uint32_t>) ==
sizeof(uint32_t),
"Empty Base Optimization isn't working.");
// Generate one 64-bit random value: apply the Threefry-2x32 function to the current key and counter of
// the derived generator, then advance its counter (the `return r;` is on a line not shown).
194 template <
class Derived>
195 auto RandomNumberGeneratorBase<Derived>::operator()() -> result_type
// CRTP: access the key/counter accessors of the derived generator.
199 auto self =
static_cast<Derived*
>(
this);
200 auto c =
static_cast<uint64_t
>(
self->get_counter().get());
201 auto k =
static_cast<uint64_t
>(
self->get_key().get());
// NOTE(review): Random123 declares threefry2x32(counter, key), but here the key is passed as the first
// (counter) argument and the counter as the key. Both are r123array2x32 so it compiles. Confirm this is
// intended; changing it would change every generated stream.
202 auto r = details::to_uint64(threefry2x32(details::to_r123_array(k), details::to_r123_array(c)));
// Each generated value consumes one counter step.
203 self->increment_counter();
213 template <
class SeedSeq>
214 Key<uint64_t> seed_rng_key(SeedSeq& seed_seq)
216 auto tf_key = threefry2x32_key_t::seed(seed_seq);
217 return Key<uint64_t>(details::to_uint64(tf_key));
236 template <
class UIntC,
class UIntN,
class CounterS>
237 Counter<UIntC> rng_totalsequence_counter(UIntN subsequence_idx, CounterS counter)
240 static const UIntC BITS_PER_BYTE = 8;
241 static const UIntC C_BITS =
sizeof(UIntC) * BITS_PER_BYTE;
242 static const UIntC S_BITS =
sizeof(CounterS) * BITS_PER_BYTE;
243 static const UIntC N_BITS = C_BITS - S_BITS;
245 static_assert(S_BITS < C_BITS,
"Subsequence counter must be smaller than total sequence counter.");
246 static_assert(N_BITS <= C_BITS,
"Subsequence index must not be bigger than total sequence counter.");
247 static_assert(N_BITS <=
sizeof(UIntN) * BITS_PER_BYTE,
"Subsequence index must be at least N bits");
249 assert(UIntC(subsequence_idx) <= (UIntC(1) << N_BITS) &&
250 "Subsequence index is too large.");
259 const auto i =
static_cast<UIntC
>(subsequence_idx);
260 const auto s =
static_cast<UIntC
>(counter.get());
261 const auto c = (
i << S_BITS) + s;
262 return Counter<UIntC>{c};
275 template <
class UIntS,
class CounterC>
276 Counter<UIntS> rng_subsequence_counter(CounterC counter)
278 using UIntC =
typename CounterC::ValueType;
279 static const UIntC C_BYTES =
sizeof(UIntC);
280 static const UIntC S_BYTES =
sizeof(UIntS);
282 static_assert(S_BYTES < C_BYTES,
"Subsequence counter must be smaller than total sequence counter.");
287 return Counter<UIntS>(
static_cast<UIntS
>(counter.get()));
// Default random number generator of the library: a RandomNumberGeneratorBase with a Key<uint64_t>
// and a Counter<uint64_t>.
295 class RandomNumberGenerator :
public RandomNumberGeneratorBase<RandomNumberGenerator>
// Default constructor: seed with fresh seeds from generate_seeds() (std::random_device).
298 RandomNumberGenerator()
301 seed(generate_seeds());
// Current key of the generator; read by RandomNumberGeneratorBase::operator().
304 Key<uint64_t> get_key()
const
// Current counter of the generator; read by RandomNumberGeneratorBase::operator().
308 Counter<uint64_t> get_counter()
const
// Overwrite the current counter (body not shown).
312 void set_counter(Counter<uint64_t> counter)
// Advance the counter; called once per generated value by RandomNumberGeneratorBase::operator().
316 void increment_counter()
/**
 * @brief Draw a fresh set of seeds from std::random_device.
 * @return Six 32-bit seed values.
 */
static std::vector<uint32_t> generate_seeds()
{
    constexpr std::size_t num_seeds = 6;
    std::random_device entropy_source;
    std::vector<uint32_t> fresh_seeds;
    fresh_seeds.reserve(num_seeds);
    for (std::size_t k = 0; k < num_seeds; ++k) {
        fresh_seeds.push_back(entropy_source());
    }
    return fresh_seeds;
}
// Re-seed the generator: the key is derived via seed_rng_key from a std::seed_seq built over m_seeds.
// NOTE(review): the lines that store `seeds` into m_seeds (and any counter reset) are not visible in
// this excerpt.
326 void seed(
const std::vector<uint32_t>& seeds)
329 std::seed_seq seed_seq(m_seeds.begin(), m_seeds.end());
330 m_key = seed_rng_key(seed_seq);
333 const std::vector<uint32_t> get_seeds()
const
// MPI-only part (MEMILIO_ENABLE_MPI): the seeds are broadcast from rank 0 over the world communicator.
// Presumably the seed count is broadcast first on lines not shown; the receivers then resize m_seeds
// to that count before receiving.
// NOTE(review): MPI_UNSIGNED describes unsigned int; this assumes uint32_t == unsigned int
// (MPI_UINT32_T would match exactly) — confirm.
343 #ifdef MEMILIO_ENABLE_MPI
348 num_seeds = int(m_seeds.size());
352 m_seeds.assign(num_seeds, 0);
354 MPI_Bcast(m_seeds.data(), num_seeds, MPI_UNSIGNED, 0,
mpi::get_world());
// Hook for the default serialization feature: exposes the key, counter and seeds as named members of a
// "RandomNumberGenerator" object.
362 auto default_serialize()
364 return Members(
"RandomNumberGenerator").add(
"key", m_key).add(
"counter", m_counter).add(
"seeds", m_seeds);
// Counter of the generator; advanced once per generated value.
369 Counter<uint64_t> m_counter;
// Seeds the key was derived from (see seed()); also serialized and, with MPI, broadcast.
370 std::vector<uint32_t> m_seeds;
/**
 * Log the seeds of a generator in one message of the form "Using RNG with seeds: ...".
 * The seeds are written into a string stream first (loop body not shown in this excerpt).
 * @param rng Generator whose seeds are logged.
 * @param level Log level of the message.
 */
383 inline void log_rng_seeds(
const RandomNumberGenerator& rng,
LogLevel level)
385 const auto& seeds = rng.get_seeds();
386 std::stringstream ss;
388 for (
auto& s : seeds) {
395 log(level,
"Using RNG with seeds: {0}.", ss.str());
// Log the seeds of the thread-local generator at the given level (body not shown; presumably forwards
// thread_local_rng() to log_rng_seeds).
401 inline void log_thread_local_rng_seeds(
LogLevel level)
/**
 * Adapter around a standard-library-style distribution DistT.
 * Provides:
 * - a nested ParamType wrapping DistT::param_type;
 * - a shared instance per DistT via get_instance();
 * - an optional replaceable GeneratorFunction (get_generator/set_generator) that is used instead of
 *   DistT to produce samples (presumably for mocking in tests).
 */
414 template <
class DistT>
415 class DistributionAdapter
// Type of the values produced by the adapted distribution.
421 using ResultType =
typename DistT::result_type;
// (Inside the nested ParamType, whose header is not shown:) the adapter these parameters belong to.
428 using DistType = DistributionAdapter<DistT>;
// Build the parameters from any arguments DistT::param_type is constructible from.
430 template <
typename... Ps>
431 requires std::is_constructible_v<
typename DistT::param_type, Ps...>
432 ParamType(Ps&&... ps)
433 : params(std::forward<Ps>(ps)...)
// Shared adapter instance of the distribution these parameters belong to.
441 static DistributionAdapter& get_distribution_instance()
443 return DistType::get_instance();
// The wrapped standard parameters.
446 typename DistT::param_type params;
// Replacement sampling function: maps a parameter set to a sample (it does not receive the RNG).
451 using GeneratorFunction = std::function<ResultType(
const typename DistT::param_type& p)>;
// Default construction and copy/move operations are all defaulted.
459 DistributionAdapter() =
default;
460 DistributionAdapter(
const DistributionAdapter&) =
default;
461 DistributionAdapter& operator=(
const DistributionAdapter&) =
default;
462 DistributionAdapter(DistributionAdapter&&) =
default;
463 DistributionAdapter& operator=(DistributionAdapter&&) =
default;
// Draw one sample with parameters built from `params...`. The two return paths are shown below; the
// condition choosing between them is on lines not shown (presumably "m_generator is set").
// Note that rng is not used, and therefore not advanced, when m_generator is called.
473 template <
class RNG,
class... T>
474 ResultType operator()(RNG& rng, T&&... params)
478 return m_generator(
typename DistT::param_type{std::forward<T>(params)...});
481 return DistT(std::forward<T>(params)...)(rng);
// Current replacement generator function (empty if none was set).
488 GeneratorFunction get_generator()
const
// Install a replacement generator function.
496 void set_generator(GeneratorFunction g)
// Function-local static: one shared adapter per DistT, initialized on first use.
507 static DistributionAdapter& get_instance()
509 static DistributionAdapter instance;
// Optional replacement generator; default-constructed (empty) unless set via set_generator.
514 GeneratorFunction m_generator;
/**
 * Discrete distribution over the indices 0, ..., n-1. The probability of each index is proportional to
 * its weight.
 * The weights are held as a Span<ScalarType>. Presumably this is a non-owning view, in which case the
 * underlying weights must outlive the distribution and its parameters — confirm.
 */
525 class DiscreteDistributionInPlace
// Type of the sampled index.
531 using result_type = Int;
// (Inside the nested param_type, whose header is not shown:) the distribution these parameters belong to.
539 using distribution = DiscreteDistributionInPlace;
541 param_type() =
default;
// Parameters referencing the given weights.
543 param_type(Span<ScalarType> weights)
// The referenced weights.
548 Span<ScalarType> weights()
const
554 Span<ScalarType> m_weights;
561 DiscreteDistributionInPlace() =
default;
// Distribution over the given weights.
566 DiscreteDistributionInPlace(Span<ScalarType> weights)
// Distribution with the given parameters.
574 DiscreteDistributionInPlace(param_type params)
// Replace the stored parameters.
598 void param(param_type p)
// Weights of the stored parameters.
606 Span<ScalarType> weights()
608 return m_params.weights();
// Sample using the stored parameters.
616 result_type operator()(RNG& rng)
618 return (*
this)(rng, m_params);
627 result_type operator()(RNG& rng, param_type p)
629 auto weights = p.weights();
630 if (weights.size() <= 1) {
633 auto sum = std::accumulate(weights.begin(), weights.end(), 0.0);
634 auto u = std::uniform_real_distribution<ScalarType>()(
635 rng, std::uniform_real_distribution<ScalarType>::param_type{0.0, sum});
636 auto intermediate_sum = 0.0;
637 for (
size_t i = 0;
i < weights.size(); ++
i) {
638 intermediate_sum += weights.get_ptr()[
i];
639 if (u < intermediate_sum) {
643 assert(
false &&
"this should never happen.");
644 return result_type(-1);
// DistributionAdapter over DiscreteDistributionInPlace<Int> (template header on a line not shown).
656 using DiscreteDistribution = DistributionAdapter<DiscreteDistributionInPlace<Int>>;
// DistributionAdapter over std::exponential_distribution<Real>.
662 template <
class Real>
663 using ExponentialDistribution = DistributionAdapter<std::exponential_distribution<Real>>;
// DistributionAdapter over std::normal_distribution<Real>.
669 template <
class Real>
670 using NormalDistribution = DistributionAdapter<std::normal_distribution<Real>>;
// DistributionAdapter over std::uniform_int_distribution<Int> (template header on a line not shown).
677 using UniformIntDistribution = DistributionAdapter<std::uniform_int_distribution<Int>>;
// DistributionAdapter over std::uniform_real_distribution<Real>.
683 template <
class Real>
684 using UniformDistribution = DistributionAdapter<std::uniform_real_distribution<Real>>;
// Serialize UniformDistribution parameters as an object "UniformDistributionParams" with the elements
// "a" and "b" (the lower and upper bound of the wrapped std::uniform_real_distribution parameters).
// Only enabled for UniformDistribution<Real>::ParamType; Real defaults to the distribution's
// ResultType. The function name line is not shown; presumably serialize_internal(io, p).
686 template <
class IOContext,
class UniformDistributionParams,
687 class Real =
typename UniformDistributionParams::DistType::ResultType>
688 requires std::is_same_v<UniformDistributionParams, typename UniformDistribution<Real>::ParamType>
691 auto obj = io.create_object(
"UniformDistributionParams");
692 obj.add_element(
"a", p.params.a());
693 obj.add_element(
"b", p.params.b());
// Deserialize UniformDistribution parameters in the format written by the matching serializer:
// reads the elements "a" and "b" of type Real from a "UniformDistributionParams" object. The lambda
// below builds UniformDistributionParams{a, b}; presumably it is passed to apply(io, ..., a, b) on lines
// not shown, so it only runs if both reads succeeded.
696 template <
class IOContext,
class UniformDistributionParams,
697 class Real =
typename UniformDistributionParams::DistType::ResultType>
698 requires std::is_same_v<UniformDistributionParams, typename UniformDistribution<Real>::ParamType>
699 IOResult<UniformDistributionParams>
deserialize_internal(IOContext& io, Tag<UniformDistributionParams>)
701 auto obj = io.expect_object(
"UniformDistributionParams");
702 auto a = obj.expect_element(
"a", Tag<Real>{});
703 auto b = obj.expect_element(
"b", Tag<Real>{});
706 [](
auto&& a_,
auto&& b_) {
707 return UniformDistributionParams{a_, b_};
// DistributionAdapter over std::poisson_distribution<Int> (template header on a line not shown).
717 using PoissonDistribution = DistributionAdapter<std::poisson_distribution<Int>>;
// DistributionAdapter over std::lognormal_distribution<Real>.
723 template <
class Real>
724 using LogNormalDistribution = DistributionAdapter<std::lognormal_distribution<Real>>;
// DistributionAdapter over std::gamma_distribution<Real>.
730 template <
class Real>
731 using GammaDistribution = DistributionAdapter<std::gamma_distribution<Real>>;
// NormalDistribution<Real> is already declared earlier in this header as
// DistributionAdapter<std::normal_distribution<Real>>. The identical second alias template that stood here
// was removed: it added nothing, and some compilers diagnose a repeated alias template as a redefinition.
#define MEMILIO_ENABLE_EBO
Definition: compiler_diagnostics.h:75
#define MSVC_WARNING_POP()
Definition: compiler_diagnostics.h:44
#define GCC_CLANG_DIAGNOSTIC(...)
Definition: compiler_diagnostics.h:62
static min_max_return_type< ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 >, ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 > >::type min(const ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 > &a, const ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 > &b)
Definition: ad.hpp:2599
static min_max_return_type< ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 >, ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 > >::type max(const ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 > &a, const ad::internal::active_type< AD_TAPE_REAL, DATA_HANDLER_1 > &b)
Definition: ad.hpp:2596
trait_value< T >::RETURN_TYPE & value(T &x)
Definition: ad.hpp:3308
Comm get_world()
Get the global MPI communicator.
Definition: miompi.cpp:32
int rank(Comm comm)
Return the rank of the calling process on the given communicator.
Definition: miompi.cpp:63
A collection of classes to simplify handling of matrix shapes in meta programming.
Definition: models/abm/analyze_result.h:30
auto i
Definition: io.h:809
details::ApplyResultT< F, T... > apply(IOContext &io, F f, const IOResult< T > &... rs)
Evaluate a function with zero or more unpacked IOResults as arguments.
Definition: io.h:481
requires(!std::is_trivial_v< T >) void BinarySerializerObject
Definition: binary_serializer.h:333
LogLevel
Definition: logging.h:40
void log(LogLevel level, spdlog::string_view_t fmt, const Args &... args)
Definition: logging.h:128
auto max(const Eigen::MatrixBase< A > &a, B &&b)
coefficient wise maximum of two matrices.
Definition: eigen_util.h:171
RandomNumberGenerator & thread_local_rng()
Definition: random_number_generator.cpp:25
IOResult< T > deserialize_internal(IOContext &io, Tag< T > tag)
Deserialization implementation for the default serialization feature.
Definition: default_serialize.h:236
void serialize_internal(IOContext &io, const T &a)
Serialization implementation for the default serialization feature.
Definition: default_serialize.h:213
MSVC_WARNING_DISABLE_PUSH(4127) GCC_CLANG_DIAGNOSTIC(ignored "-Wexpansion-to-defined") namespace mio
Definition: random_number_generator.h:31