diff --git a/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb b/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb index 7de81868..c5ca34d2 100644 --- a/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb +++ b/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -270,6 +270,7 @@ "binned = aggregate_downsample(ts, time_bin_size=model_options.bin_size * u.min)\n", "binned2 = aggregate_downsample(ts2, time_bin_size=model_options.bin_size * u.min, aggregate_func=add_quad)\n", "\n", + "binned[\"normalized_flux_error\"] = binned2[\"normalized_flux_error\"]\n", "# binned_time = BinnedTimeSeries(photometry['bjd'], time_bin_start=first_time, time_bin_end=last_time, time_bin_size=bin_size)" ] }, @@ -299,34 +300,15 @@ "# Make the model\n", "mod = TransitModelFit()\n", "\n", - "# Load data\n", - "mod.times = (np.array((binned[\"time_bin_start\"] + binned[\"time_bin_size\"]/2).value) - 2400000)\n", - "mod.data = binned[\"normalized_flux\"].value\n", - "mod.weights = 1 / (binned2[\"normalized_flux_error\"].value)\n", - "mod.airmass = np.array(binned[\"airmass\"])\n", - "mod.width = np.array(binned[\"width\"])\n", - "mod.spp = np.array(binned[\"sky_per_pix_avg\"])\n", - "\n", "# Setup the model\n", "mod.setup_model(\n", + " binned_data=binned,\n", " t0=mid.jd - 2400000, # midpoint, BJD\n", " depth=tess_info.depth_ppt, # parts per thousand\n", " duration=tess_info.duration.to(\"day\").value, # days\n", " period=tess_info.period.to(\"day\").value, # days\n", - ")\n", - "\n", - "# Setup the model more 🙄\n", - "mod.model.t0.bounds = [\n", - " mid.jd - 2400000 - (model_options.transit_time_range * u.min).to(\"day\").value / 2,\n", - " mid.jd - 2400000 + (model_options.transit_time_range * u.min).to(\"day\").value / 2,\n", - "]\n", - "mod.model.t0.fixed = model_options.keep_transit_time_fixed\n", - "mod.model.a.fixed = model_options.keep_radius_orbit_fixed\n", - "mod.model.rp.fixed = model_options.keep_radius_planet_fixed\n", - "\n", - "mod.model.spp_trend.fixed = not model_options.fit_spp\n", - "mod.model.airmass_trend.fixed = not model_options.fit_airmass\n", - "mod.model.width_trend.fixed = not model_options.fit_width\n" + " model_options=model_options,\n", + ")\n" ] }, { diff --git a/stellarphot/transit_fitting/core.py b/stellarphot/transit_fitting/core.py index 63a2c576..72d55019 100644 --- a/stellarphot/transit_fitting/core.py +++ b/stellarphot/transit_fitting/core.py @@ -1,6 +1,7 @@ import warnings import numpy as np +from astropy import units as u from astropy.modeling.fitting import LevMarLSQFitter, _validate_model from astropy.modeling.models import custom_model from pydantic import BaseModel @@ -397,6 +398,7 @@ def transit_model_with_trends( def setup_model( self, + binned_data=None, t0=0, depth=0, duration=0, @@ -405,6 +407,7 @@ def setup_model( airmass_trend=0.0, width_trend=0.0, spp_trend=0.0, + model_options=None, ): """ Configure a transit model for fitting. The ``duration`` and ``depth`` @@ -441,11 +444,28 @@ def setup_model( spp_trend : float Coefficient for a linear trend in sky per pixel. + options : TransitModelOptions, optional + Options for the transit model fit. + Returns ------- None Sets values for the model parameters. """ + if binned_data: + self.times = ( + np.array( + ( + binned_data["time_bin_start"] + binned_data["time_bin_size"] / 2 + ).value + ) + - 2400000 + ) + self.data = binned_data["normalized_flux"].value + self.weights = 1 / (binned_data["normalized_flux_error"].value) + self.airmass = np.array(binned_data["airmass"]) + self.width = np.array(binned_data["width"]) + self.spp = np.array(binned_data["sky_per_pix_avg"]) self._setup_transit_model() # rp is related to depth in a straightforward way @@ -473,6 +493,20 @@ def setup_model( except ValueError: pass + if model_options is not None: + # Setup the model more 🙄 + self.model.t0.bounds = [ + t0 - (model_options.transit_time_range * u.min).to("day").value / 2, + t0 + (model_options.transit_time_range * u.min).to("day").value / 2, + ] + self.model.t0.fixed = model_options.keep_transit_time_fixed + self.model.a.fixed = model_options.keep_radius_orbit_fixed + self.model.rp.fixed = model_options.keep_radius_planet_fixed + + self.model.spp_trend.fixed = not model_options.fit_spp + self.model.airmass_trend.fixed = not model_options.fit_airmass + self.model.width_trend.fixed = not model_options.fit_width + def fit(self): """ Perform a fit and update the model with best-fit values.