From beb534df19ec00f535610aeb816681d30123ac19 Mon Sep 17 00:00:00 2001 From: Seth Siegel Date: Mon, 26 Aug 2024 09:33:15 -0700 Subject: [PATCH] fix(derivative): fix copy-paste bugs in mixed partial derivatives Fix errors in the calculation of the (burst_width, dm) and (scattering_timescale, scattering_timescale) mixed partial derivative. Impacts the calculation of the analytical hessian for models with scattering. --- fitburst/routines/derivative.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fitburst/routines/derivative.py b/fitburst/routines/derivative.py index a12f7d9..b3f0565 100644 --- a/fitburst/routines/derivative.py +++ b/fitburst/routines/derivative.py @@ -1780,6 +1780,7 @@ def deriv2_model_wrt_burst_width_dm(parameters: dict, model: float, component: i sc_time = parameters["scattering_timescale"][0] # global parameter. amp_pbf = amplitude_pbf(model.freqs, parameters, component) + deriv_first = deriv_model_wrt_dm(parameters, model, component, add_all = False) deriv_mod = np.zeros(current_model.shape) # now loop over each frequency and compute mixed-derivative array per channel. @@ -1806,7 +1807,7 @@ def deriv2_model_wrt_burst_width_dm(parameters: dict, model: float, component: i ) # now define terms that contriubte to mixed-partial derivative. - term1 = burst_width * current_model[freq, :] / sc_time_freq ** 2 + term1 = burst_width * deriv_first[freq, :] / sc_time_freq ** 2 term2 = current_amplitude[freq, :] * amp_pbf[freq] * np.exp(arg_exp) * deriv2_arg_erf * 2 / np.sqrt(np.pi) term3 = current_amplitude[freq, :] * amp_pbf[freq] * np.exp(arg_exp) * deriv_arg_erf * deriv_arg_exp * 2 / np.sqrt(np.pi) deriv_mod[freq, :] = term1 + term2 + term3 @@ -2296,7 +2297,7 @@ def deriv2_model_wrt_scattering_timescale_scattering_timescale(parameters: dict, burst_width = parameters["burst_width"][current_component] current_timediff = model.timediff_per_component[:, : , current_component] current_amplitude = model.amplitude_per_component[:, :, current_component] - deriv_first = deriv_model_wrt_dm_index(parameters, model, current_component, add_all = False) + deriv_first = deriv_model_wrt_scattering_timescale(parameters, model, current_component, add_all = False) ref_freq = parameters["ref_freq"][current_component] amp_pbf = amplitude_pbf(model.freqs, parameters, current_component)