Skip to content

Commit

Permalink
Mover much more model setup into code and out of notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
mwcraig committed Nov 25, 2024
1 parent 0d18fda commit bf086ee
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
28 changes: 5 additions & 23 deletions stellarphot/notebooks/photometry/06-transit-fit-template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -270,6 +270,7 @@
"binned = aggregate_downsample(ts, time_bin_size=model_options.bin_size * u.min)\n",
"binned2 = aggregate_downsample(ts2, time_bin_size=model_options.bin_size * u.min, aggregate_func=add_quad)\n",
"\n",
"binned[\"normalized_flux_error\"] = binned2[\"normalized_flux_error\"]\n",
"# binned_time = BinnedTimeSeries(photometry['bjd'], time_bin_start=first_time, time_bin_end=last_time, time_bin_size=bin_size)"
]
},
Expand Down Expand Up @@ -299,34 +300,15 @@
"# Make the model\n",
"mod = TransitModelFit()\n",
"\n",
"# Load data\n",
"mod.times = (np.array((binned[\"time_bin_start\"] + binned[\"time_bin_size\"]/2).value) - 2400000)\n",
"mod.data = binned[\"normalized_flux\"].value\n",
"mod.weights = 1 / (binned2[\"normalized_flux_error\"].value)\n",
"mod.airmass = np.array(binned[\"airmass\"])\n",
"mod.width = np.array(binned[\"width\"])\n",
"mod.spp = np.array(binned[\"sky_per_pix_avg\"])\n",
"\n",
"# Setup the model\n",
"mod.setup_model(\n",
" binned_data=binned,\n",
" t0=mid.jd - 2400000, # midpoint, BJD\n",
" depth=tess_info.depth_ppt, # parts per thousand\n",
" duration=tess_info.duration.to(\"day\").value, # days\n",
" period=tess_info.period.to(\"day\").value, # days\n",
")\n",
"\n",
"# Setup the model more 🙄\n",
"mod.model.t0.bounds = [\n",
" mid.jd - 2400000 - (model_options.transit_time_range * u.min).to(\"day\").value / 2,\n",
" mid.jd - 2400000 + (model_options.transit_time_range * u.min).to(\"day\").value / 2,\n",
"]\n",
"mod.model.t0.fixed = model_options.keep_transit_time_fixed\n",
"mod.model.a.fixed = model_options.keep_radius_orbit_fixed\n",
"mod.model.rp.fixed = model_options.keep_radius_planet_fixed\n",
"\n",
"mod.model.spp_trend.fixed = not model_options.fit_spp\n",
"mod.model.airmass_trend.fixed = not model_options.fit_airmass\n",
"mod.model.width_trend.fixed = not model_options.fit_width\n"
" model_options=model_options,\n",
")\n"
]
},
{
Expand Down
34 changes: 34 additions & 0 deletions stellarphot/transit_fitting/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings

import numpy as np
from astropy import units as u
from astropy.modeling.fitting import LevMarLSQFitter, _validate_model
from astropy.modeling.models import custom_model
from pydantic import BaseModel
Expand Down Expand Up @@ -397,6 +398,7 @@ def transit_model_with_trends(

def setup_model(
self,
binned_data=None,
t0=0,
depth=0,
duration=0,
Expand All @@ -405,6 +407,7 @@ def setup_model(
airmass_trend=0.0,
width_trend=0.0,
spp_trend=0.0,
model_options=None,
):
"""
Configure a transit model for fitting. The ``duration`` and ``depth``
Expand Down Expand Up @@ -441,11 +444,28 @@ def setup_model(
spp_trend : float
Coefficient for a linear trend in sky per pixel.
options : TransitModelOptions, optional
Options for the transit model fit.
Returns
-------
None
Sets values for the model parameters.
"""
if binned_data:
self.times = (
np.array(
(
binned_data["time_bin_start"] + binned_data["time_bin_size"] / 2
).value
)
- 2400000
)
self.data = binned_data["normalized_flux"].value
self.weights = 1 / (binned_data["normalized_flux_error"].value)
self.airmass = np.array(binned_data["airmass"])
self.width = np.array(binned_data["width"])
self.spp = np.array(binned_data["sky_per_pix_avg"])
self._setup_transit_model()

# rp is related to depth in a straightforward way
Expand Down Expand Up @@ -473,6 +493,20 @@ def setup_model(
except ValueError:
pass

if model_options is not None:
# Setup the model more 🙄
self.model.t0.bounds = [
t0 - (model_options.transit_time_range * u.min).to("day").value / 2,
t0 + (model_options.transit_time_range * u.min).to("day").value / 2,
]
self.model.t0.fixed = model_options.keep_transit_time_fixed
self.model.a.fixed = model_options.keep_radius_orbit_fixed
self.model.rp.fixed = model_options.keep_radius_planet_fixed

self.model.spp_trend.fixed = not model_options.fit_spp
self.model.airmass_trend.fixed = not model_options.fit_airmass
self.model.width_trend.fixed = not model_options.fit_width

def fit(self):
"""
Perform a fit and update the model with best-fit values.
Expand Down

0 comments on commit bf086ee

Please sign in to comment.