Skip to content

Commit

Permalink
676 computation of the reproduction number in the seir model (#685)
Browse files Browse the repository at this point in the history
Co-authored-by: Johannssen <[email protected]>
Co-authored-by: HenrZu <[email protected]>
Co-authored-by: Henrik Zunker <[email protected]>
Co-authored-by: Martin Kühn <[email protected]>
  • Loading branch information
5 people authored Aug 29, 2023
1 parent e3076c2 commit 124313a
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 2 deletions.
75 changes: 75 additions & 0 deletions cpp/models/ode_seir/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
#define SEIR_MODEL_H

#include "memilio/compartments/compartmentalmodel.h"
#include "memilio/config.h"
#include "memilio/epidemiology/populations.h"
#include "memilio/epidemiology/contact_matrix.h"
#include "memilio/io/io.h"
#include "memilio/math/interpolation.h"
#include "memilio/utils/time_series.h"
#include "ode_seir/infection_state.h"
#include "ode_seir/parameters.h"
#include <algorithm>
#include <iterator>

namespace mio
{
Expand Down Expand Up @@ -63,6 +69,75 @@ class Model : public CompartmentalModel<InfectionState, Populations<InfectionSta
dydt[(size_t)InfectionState::Recovered] =
(1.0 / params.get<TimeInfected>()) * y[(size_t)InfectionState::Infected];
}

/**
*@brief Computes the reproduction number at a given index time of the Model output obtained by the Simulation.
*@param t_idx The index time at which the reproduction number is computed.
*@param y The TimeSeries obtained from the Model Simulation.
*@returns The computed reproduction number at the provided index time.
*/
IOResult<ScalarType> get_reproduction_number(size_t t_idx, const mio::TimeSeries<ScalarType>& y)
{
if (!(t_idx < static_cast<size_t>(y.get_num_time_points()))) {
return mio::failure(mio::StatusCode::OutOfRange, "t_idx is not a valid index for the TimeSeries");
}

ScalarType TimeInfected = this->parameters.get<mio::oseir::TimeInfected>();

ScalarType coeffStoE = this->parameters.get<mio::oseir::ContactPatterns>().get_matrix_at(
y.get_time(static_cast<Eigen::Index>(t_idx)))(0, 0) *
this->parameters.get<mio::oseir::TransmissionProbabilityOnContact>() /
this->populations.get_total();

ScalarType result =
y.get_value(static_cast<Eigen::Index>(t_idx))[(Eigen::Index)mio::oseir::InfectionState::Susceptible] *
TimeInfected * coeffStoE;

return mio::success(result);
}

/**
*@brief Computes the reproduction number for all time points of the Model output obtained by the Simulation.
*@param y The TimeSeries obtained from the Model Simulation.
*@returns vector containing all reproduction numbers
*/
Eigen::VectorXd get_reproduction_numbers(const mio::TimeSeries<ScalarType>& y)
{
auto num_time_points = y.get_num_time_points();
Eigen::VectorXd temp(num_time_points);
for (size_t i = 0; i < static_cast<size_t>(num_time_points); i++) {
temp[i] = get_reproduction_number(i, y).value();
}
return temp;
}

/**
*@brief Computes the reproduction number at a given time point of the Model output obtained by the Simulation. If the particular time point is not inside the output, a linearly interpolated value is returned.
*@param t_value The time point at which the reproduction number is computed.
*@param y The TimeSeries obtained from the Model Simulation.
*@returns The computed reproduction number at the provided time point, potentially using linear interpolation.
*/
IOResult<ScalarType> get_reproduction_number(ScalarType t_value, const mio::TimeSeries<ScalarType>& y)
{
if (t_value < y.get_time(0) || t_value > y.get_last_time()) {
return mio::failure(mio::StatusCode::OutOfRange,
"Cannot interpolate reproduction number outside computed horizon of the TimeSeries");
}

if (t_value == y.get_time(0)) {
return mio::success(get_reproduction_number((size_t)0, y).value());
}

auto times = std::vector<ScalarType>(y.get_times().begin(), y.get_times().end());

auto time_late = std::distance(times.begin(), std::lower_bound(times.begin(), times.end(), t_value));

ScalarType y1 = get_reproduction_number(static_cast<size_t>(time_late - 1), y).value();
ScalarType y2 = get_reproduction_number(static_cast<size_t>(time_late), y).value();

auto result = linear_interpolation(t_value, y.get_time(time_late - 1), y.get_time(time_late), y1, y2);
return mio::success(static_cast<ScalarType>(result));
}
};

} // namespace oseir
Expand Down
157 changes: 155 additions & 2 deletions cpp/tests/test_odeseir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
* limitations under the License.
*/
#include "load_test_data.h"
#include "memilio/config.h"
#include "memilio/utils/time_series.h"
#include "ode_seir/model.h"
#include "ode_seir/infection_state.h"
#include "ode_seir/parameters.h"
#include "memilio/math/euler.h"
#include "memilio/compartments/simulation.h"
#include <gtest/gtest.h>
#include <iomanip>
#include <vector>

TEST(TestSeir, simulateDefault)
{
Expand Down Expand Up @@ -150,7 +154,6 @@ TEST(TestSeir, check_constraints_parameters)
model.parameters.set<mio::oseir::TimeInfected>(6);
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(10.);
ASSERT_EQ(model.parameters.check_constraints(), 1);

mio::set_log_level(mio::LogLevel::warn);
}

Expand All @@ -176,6 +179,156 @@ TEST(TestSeir, apply_constraints_parameters)
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(10.);
EXPECT_EQ(model.parameters.apply_constraints(), 1);
EXPECT_NEAR(model.parameters.get<mio::oseir::TransmissionProbabilityOnContact>(), 0.0, 1e-14);

mio::set_log_level(mio::LogLevel::warn);
}

TEST(TestSeir, get_reproduction_numbers)
{
mio::oseir::Model model;

double total_population = 10000;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] = 100;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] = 100;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}] = 100;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Susceptible)}] =
total_population -
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] -
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] -
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}];

model.parameters.set<mio::oseir::TimeInfected>(6);
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(0.04);
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 10;

model.apply_constraints();

Eigen::VectorXd checkReproductionNumbers(7);
checkReproductionNumbers << 2.3280000000000002913, 2.3279906878991880603, 2.3279487809434575851,
2.3277601483151548756, 2.3269102025388899158, 2.3230580052413736247, 2.3185400624683065729;

Eigen::VectorXd checkReproductionNumbers2(7);
checkReproductionNumbers2 << 2.0952000000000001734, 2.0951916191092689878, 2.0951539028491117378,
2.0949841334836394324, 2.0942191822850007021, 2.0907522047172362178, 2.086686056221475738;

Eigen::VectorXd checkReproductionNumbers3(7);
checkReproductionNumbers3 << 1.8623999999999998334, 1.8623925503193501374, 1.8623590247547658905,
1.8622081186521235452, 1.8615281620311117106, 1.8584464041930985889, 1.854832049974644903;

mio::TimeSeries<ScalarType> result((int)mio::oseir::InfectionState::Count);
mio::TimeSeries<ScalarType>::Vector result_0(4);
mio::TimeSeries<ScalarType>::Vector result_1(4);
mio::TimeSeries<ScalarType>::Vector result_2(4);
mio::TimeSeries<ScalarType>::Vector result_3(4);
mio::TimeSeries<ScalarType>::Vector result_4(4);
mio::TimeSeries<ScalarType>::Vector result_5(4);
mio::TimeSeries<ScalarType>::Vector result_6(4);

result_0[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9700;
result_1[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.9611995799496071;
result_2[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.7865872644051706;
result_3[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.0006179798110679;
result_4[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9695.4591772453732119;
result_5[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9679.4083551723888377;
result_6[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9660.5835936179428245;

result.add_time_point(0, result_0);
result.add_time_point(0.0010000000000000000208, result_1);
result.add_time_point(0.0055000000000000005482, result_2);
result.add_time_point(0.025750000000000005523, result_3);
result.add_time_point(0.11687500000000002054, result_4);
result.add_time_point(0.52693750000000005862, result_5);
result.add_time_point(1, result_6);

auto reproduction_numbers = model.get_reproduction_numbers(result);

for (int i = 0; i < reproduction_numbers.size(); i++) {
EXPECT_NEAR(reproduction_numbers[i], checkReproductionNumbers[i], 1e-12);
}

model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 9;

auto reproduction_numbers2 = model.get_reproduction_numbers(result);

for (int i = 0; i < reproduction_numbers2.size(); i++) {
EXPECT_NEAR(reproduction_numbers2[i], checkReproductionNumbers2[i], 1e-12);
}

model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 8;

auto reproduction_numbers3 = model.get_reproduction_numbers(result);

for (int i = 0; i < reproduction_numbers2.size(); i++) {
EXPECT_NEAR(reproduction_numbers3[i], checkReproductionNumbers3[i], 1e-12);
}

EXPECT_FALSE(model.get_reproduction_number(static_cast<double>(static_cast<size_t>(result.get_num_time_points())),
result)); //Test for an index that is out of range
}

TEST(TestSeir, get_reproduction_number)
{
mio::oseir::Model model;

double total_population = 10000; //Initialize compartments to get total population of 10000
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] = 100;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] = 100;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}] = 100;
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Susceptible)}] =
total_population -
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Exposed)}] -
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Infected)}] -
model.populations[{mio::Index<mio::oseir::InfectionState>(mio::oseir::InfectionState::Recovered)}];

model.parameters.set<mio::oseir::TimeInfected>(6);
model.parameters.set<mio::oseir::TransmissionProbabilityOnContact>(0.04);
model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 10;

model.apply_constraints();

mio::TimeSeries<ScalarType> result((int)mio::oseir::InfectionState::Count);
mio::TimeSeries<ScalarType>::Vector result_0(4);
mio::TimeSeries<ScalarType>::Vector result_1(4);
mio::TimeSeries<ScalarType>::Vector result_2(4);
mio::TimeSeries<ScalarType>::Vector result_3(4);
mio::TimeSeries<ScalarType>::Vector result_4(4);
mio::TimeSeries<ScalarType>::Vector result_5(4);
mio::TimeSeries<ScalarType>::Vector result_6(4);
mio::TimeSeries<ScalarType>::Vector result_7(4);

result_0[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9700;
result_1[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.9709149074315;
result_2[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.8404009584538;
result_3[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9699.260556488618;
result_4[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9696.800490904101;
result_5[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9687.9435082620021;
result_6[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9679.5436372291661;
result_7[(Eigen::Index)mio::oseir::InfectionState::Susceptible] = 9678.5949381732935;

result.add_time_point(0, result_0);
result.add_time_point(0.001, result_1);
result.add_time_point(0.0055, result_2);
result.add_time_point(0.02575, result_3);
result.add_time_point(0.116875, result_4);
result.add_time_point(0.526938, result_5);
result.add_time_point(0.952226, result_6);
result.add_time_point(1, result_7);

EXPECT_FALSE(model.get_reproduction_number(result.get_time(0) - 0.5, result)); //Test for indices out of range
EXPECT_FALSE(model.get_reproduction_number(result.get_last_time() + 0.5, result));
EXPECT_FALSE(model.get_reproduction_number((size_t)result.get_num_time_points(), result));

EXPECT_EQ(model.get_reproduction_number((size_t)0, result).value(),
model.get_reproduction_number(0.0, result).value());

EXPECT_NEAR(model.get_reproduction_number(0.3, result).value(), 2.3262828383474389859, 1e-12);
EXPECT_NEAR(model.get_reproduction_number(0.7, result).value(), 2.3242860858116172196, 1e-12);
EXPECT_NEAR(model.get_reproduction_number(0.0, result).value(), 2.3280000000000002913, 1e-12);

model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 9;
EXPECT_NEAR(model.get_reproduction_number(0.1, result).value(), 2.0946073086586665113, 1e-12);
EXPECT_NEAR(model.get_reproduction_number(0.3, result).value(), 2.0936545545126947765, 1e-12);

model.parameters.get<mio::oseir::ContactPatterns>().get_baseline()(0, 0) = 8;
EXPECT_NEAR(model.get_reproduction_number(0.2, result).value(), 1.8614409729718137676, 1e-12);
EXPECT_NEAR(model.get_reproduction_number(0.9, result).value(), 1.858670429549998504, 1e-12);
}

0 comments on commit 124313a

Please sign in to comment.