Skip to content

Commit

Permalink
Allow normal distribution to preserve spare value while changing para…
Browse files Browse the repository at this point in the history
…meters (#1371)
  • Loading branch information
sethrj authored Aug 19, 2024
1 parent e69902f commit e077e9c
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 10 deletions.
93 changes: 89 additions & 4 deletions src/celeritas/random/distribution/NormalDistribution.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,43 @@ class NormalDistribution

public:
// Construct with mean and standard deviation
explicit inline CELER_FUNCTION
NormalDistribution(real_type mean = 0, real_type stddev = 1);
inline CELER_FUNCTION NormalDistribution(real_type mean, real_type stddev);

//! Construct with unit deviation
explicit CELER_FUNCTION NormalDistribution(real_type mean)
: NormalDistribution{mean, 1}
{
}

//! Construct with unit deviation and zero mean
CELER_FUNCTION NormalDistribution() : NormalDistribution{0, 1} {}

// Initialize with parameters but not spare values
inline CELER_FUNCTION NormalDistribution(NormalDistribution const& other);

// Reset spare value of other distribution
inline CELER_FUNCTION NormalDistribution(NormalDistribution&& other);

// Keep spare value but change distribution
inline CELER_FUNCTION NormalDistribution&
operator=(NormalDistribution const&);

// Possibly use spare value, change distribution
inline CELER_FUNCTION NormalDistribution& operator=(NormalDistribution&&);

// Default destructor (rule of 5)
~NormalDistribution() = default;

// Sample a random number according to the distribution
template<class Generator>
inline CELER_FUNCTION result_type operator()(Generator& rng);

private:
real_type const mean_;
real_type const stddev_;
// Distribution properties
real_type mean_;
real_type stddev_;

// Intermediate samples
real_type spare_{};
bool has_spare_{false};
};
Expand All @@ -76,6 +103,64 @@ NormalDistribution<RealType>::NormalDistribution(real_type mean,
CELER_EXPECT(stddev > 0);
}

//---------------------------------------------------------------------------//
/*!
* Initialize with parameters but not spare values.
*/
template<class RealType>
CELER_FUNCTION
NormalDistribution<RealType>::NormalDistribution(NormalDistribution const& other)
: mean_{other.mean}, stddev_{other.stddev}
{
}

//---------------------------------------------------------------------------//
/*!
* Reset spare value of other distribution.
*/
template<class RealType>
CELER_FUNCTION
NormalDistribution<RealType>::NormalDistribution(NormalDistribution&& other)
: mean_{other.mean_}
, stddev_{other.stddev_}
, spare_{other.spare_}
, has_spare_{other.has_spare_}
{
other.has_spare_ = false;
}

//---------------------------------------------------------------------------//
/*!
* Keep spare value but change distribution.
*/
template<class RealType>
CELER_FUNCTION NormalDistribution<RealType>&
NormalDistribution<RealType>::operator=(NormalDistribution const& other)
{
mean_ = other.mean_;
stddev_ = other.stddev_;
return *this;
}

//---------------------------------------------------------------------------//
/*!
* Possibly use spare value, change distribution.
*/
template<class RealType>
CELER_FUNCTION NormalDistribution<RealType>&
NormalDistribution<RealType>::operator=(NormalDistribution&& other)
{
mean_ = other.mean_;
stddev_ = other.stddev_;
if (!has_spare_ && other.has_spare_)
{
spare_ = other.spare_;
has_spare_ = other.has_spare_;
other.has_spare_ = false;
}
return *this;
}

//---------------------------------------------------------------------------//
/*!
* Sample a random number according to the distribution.
Expand Down
44 changes: 38 additions & 6 deletions test/celeritas/random/distribution/NormalDistribution.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@ namespace test
{
//---------------------------------------------------------------------------//

TEST(NormalDistributionTest, bin)
TEST(NormalDistributionTest, normal)
{
DiagnosticRngEngine<std::mt19937> rng;
int num_samples = 10000;

double mean = 0.0;
double stddev = 1.0;

NormalDistribution<double> sample_normal{mean, stddev};

NormalDistribution<double> sample_normal{/* mean = */ 0.0,
/* stddev = */ 1.0};
std::vector<int> counters(6);
for ([[maybe_unused]] int i : range(num_samples))
{
Expand All @@ -53,6 +50,41 @@ TEST(NormalDistributionTest, bin)
EXPECT_EQ(2 * num_samples, rng.count());
}

TEST(NormalDistributionTest, move)
{
DiagnosticRngEngine<std::mt19937> rng;
NormalDistribution<double> sample_normal{/* mean = */ 0,
/* stddev = */ 0.5};

std::vector<double> samples;
for ([[maybe_unused]] int i : range(4))
{
samples.push_back(sample_normal(rng));
}

// Check that resetting RNG gives same results
rng = {};
for ([[maybe_unused]] int i : range(4))
{
EXPECT_DOUBLE_EQ(samples[i], sample_normal(rng));
}

// Replace after 1 sample: should be scaled original (using latent spare)
rng = {};
EXPECT_DOUBLE_EQ(samples[0], sample_normal(rng));
sample_normal = {1.0, 1.0}; // Shift right, double width
EXPECT_DOUBLE_EQ(2 * samples[1] + 1, sample_normal(rng));

// Check that we capture the "spare" value from another distribution
sample_normal = [] {
NormalDistribution<double> sample_other_normal{0, 2.0};
std::mt19937 temp_rng;
sample_other_normal(temp_rng);
return sample_other_normal;
}();
EXPECT_DOUBLE_EQ(4 * samples[1], sample_normal(rng));
}

//---------------------------------------------------------------------------//
} // namespace test
} // namespace celeritas

0 comments on commit e077e9c

Please sign in to comment.