From 144d641739ccd1109055d13b5b96e4e76607305d Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 9 Jan 2024 10:29:12 -0500 Subject: [PATCH] Stan 2.34: Fix parsing of unit_e output files --- cmdstanpy/stanfit/mcmc.py | 6 ++++++ cmdstanpy/utils/stancsv.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index 53fb4de1..4d908033 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -441,6 +441,12 @@ def _assemble_draws(self) -> None: self._metric[chain, i, :] = [ float(x) for x in xs ] + else: # unit_e changed in 2.34 to have an extra line + pos = fd.tell() + line = fd.readline().strip() + if not line.startswith('#'): + fd.seek(pos) + # process draws for i in range(sampling_iter_start, num_draws): line = fd.readline().strip() diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index d08181d2..1481cb5d 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -289,10 +289,16 @@ def scan_hmc_params( raise ValueError( 'line {}: invalid step size: {}'.format(lineno, step_size) ) from e - if metric == 'unit_e': - return lineno + before_metric = fd.tell() line = fd.readline().strip() lineno += 1 + if metric == 'unit_e': + if line.startswith("# No free parameters"): + return lineno + else: + fd.seek(before_metric) + return lineno - 1 + if not ( ( metric == 'diag_e'