diff --git a/stellarphot/gui_tools/photometry_widget_functions.py b/stellarphot/gui_tools/photometry_widget_functions.py
index e89f6376..7391fd13 100644
--- a/stellarphot/gui_tools/photometry_widget_functions.py
+++ b/stellarphot/gui_tools/photometry_widget_functions.py
@@ -2,17 +2,137 @@
import ipywidgets as ipw
from ccdproc import ImageFileCollection
+from ipyautoui.custom import FileChooser
+from stellarphot import PhotometryData
from stellarphot.settings import (
PhotometryApertures,
PhotometryFileSettings,
ui_generator,
)
+from stellarphot.settings.custom_widgets import Spinner
-__all__ = ["PhotometrySettings"]
+__all__ = ["TessAnalysisInputControls", "PhotometrySettingsOLDBAD"]
-class PhotometrySettings:
+class TessAnalysisInputControls(ipw.VBox):
+ """
+ A class to hold the widgets for choosing TESS input
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ hidden = ipw.Layout(display="none")
+
+ self.phot_chooser = FileChooser(filter_pattern=["*.csv", "*.fits", "*.ecsv"])
+ self._fits_openr = ipw.VBox(
+ children=[
+ ipw.HTML(value="
Select your photometry/flux file
"),
+ self.phot_chooser,
+ ]
+ )
+ self.tic_file_chooser = FileChooser(filter_pattern=["*.json"])
+ fits_openr2 = ipw.VBox(
+ children=[
+ ipw.HTML(value="Select your TESS info file
"),
+ self.tic_file_chooser,
+ ],
+ layout=hidden,
+ )
+ self._passband = ipw.Dropdown(
+ description="Ccoose Filter",
+ options=["gp", "ip"],
+ disabled=True,
+ layout=hidden,
+ )
+
+ spinner = Spinner(message="Loading photometry...
")
+
+ self.phot_data = None
+
+ def update_filter_list(_):
+ spinner.start()
+ self.phot_data = PhotometryData.read(self.phot_chooser.value)
+ passband_data = self.phot_data["passband"]
+ fits_openr2.layout.display = "flex"
+ self._passband.layout.display = "flex"
+ self._passband.options = sorted(set(passband_data))
+ self._passband.disabled = False
+ self._passband.value = self._passband.options[0]
+ spinner.stop()
+
+ self.phot_chooser.observe(update_filter_list, names="_value")
+ self.children = [self._fits_openr, spinner, fits_openr2, self._passband]
+
+ @property
+ def tic_info_file(self):
+ p = Path(self.tic_file_chooser.value)
+ selected_file = p.name
+ if not selected_file:
+ raise ValueError("No TIC info json file selected")
+ return p
+
+ @property
+ def photometry_data_file(self):
+ p = Path(self.phot_chooser.value)
+ selected_file = p.name
+ if not selected_file:
+ raise ValueError("No photometry data file selected")
+ return p
+
+ @property
+ def passband(self):
+ return self._passband.value
+
+
+def filter_by_dates(
+ phot_times=None,
+ use_no_data_before=None,
+ use_no_data_between=None,
+ use_no_data_after=None,
+):
+ n_dropped = 0
+
+ bad_data = phot_times < use_no_data_before
+
+ n_dropped = bad_data.sum()
+
+ if n_dropped > 0:
+ print(
+ f"👉👉👉👉 Dropping {n_dropped} data points before "
+ f"BJD {use_no_data_before}"
+ )
+
+ bad_data = bad_data | (
+ (use_no_data_between[0][0] < phot_times)
+ & (phot_times < use_no_data_between[0][1])
+ )
+
+ new_dropped = bad_data.sum() - n_dropped
+
+ if new_dropped:
+ print(
+ f"👉👉👉👉 Dropping {new_dropped} data points between "
+ f"BJD {use_no_data_between[0][0]} and {use_no_data_between[0][1]}"
+ )
+
+ n_dropped += new_dropped
+
+ bad_data = bad_data | (phot_times > use_no_data_after)
+
+ new_dropped = bad_data.sum() - n_dropped
+
+ if new_dropped:
+ print(
+ f"👉👉👉👉 Dropping {new_dropped} data points after "
+ f"BJD {use_no_data_after}"
+ )
+
+ n_dropped += new_dropped
+ return bad_data
+
+
+class PhotometrySettingsOLDBAD:
"""
A class to hold the widgets for photometry settings.
diff --git a/stellarphot/io/tess.py b/stellarphot/io/tess.py
index 09215d66..bf1f5651 100644
--- a/stellarphot/io/tess.py
+++ b/stellarphot/io/tess.py
@@ -1,4 +1,5 @@
import re
+import warnings
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile
@@ -356,6 +357,41 @@ def from_tic_id(cls, tic_id):
tess_mag_error=toi_table["TESS Mag err"],
)
+ def transit_time_for_observation(self, obs_times):
+ """
+ Calculate the transit time for a set of observation times.
+
+ Parameters
+ ----------
+
+ obs_times : `astropy.time.Time`
+ The times of the observations.
+
+ Returns
+ -------
+
+ `astropy.time.Time`
+ The transit times for the observations.
+ """
+ first_obs = obs_times[0]
+ # Three possible cases here. Either the first time is close to, but before, a
+ # transit, or it is close to, but just after a transit, or it is nowhere close
+ # to a transit.
+ # Assume that the first time is just before a transit
+ cycle_number = int((first_obs - self.epoch) / self.period + 1)
+ that_transit = cycle_number * self.period + self.epoch
+
+ # Check -- is the first time closer to this transit or the one before it?
+ previous_transit = that_transit - self.period
+ if abs(first_obs - previous_transit) < abs(first_obs - that_transit):
+ that_transit = previous_transit
+
+ # Check -- are we way, way, way off from a transit?
+ if abs(first_obs - that_transit) > 3 * self.duration:
+ warnings.warn("Observation times are far from a transit.", stacklevel=2)
+
+ return that_transit
+
def tess_photometry_setup(tic_id=None, TOI_object=None, overwrite=False):
"""
diff --git a/stellarphot/io/tests/test_tess.py b/stellarphot/io/tests/test_tess.py
index cbffd51c..ae7df114 100644
--- a/stellarphot/io/tests/test_tess.py
+++ b/stellarphot/io/tests/test_tess.py
@@ -3,8 +3,10 @@
import warnings
from pathlib import Path
+import numpy as np
import pytest
from astropy.coordinates import SkyCoord
+from astropy.time import Time
from requests import ConnectionError, ReadTimeout
from stellarphot.io.tess import (
@@ -187,6 +189,25 @@ def test_from_tic_id(self, tess_tic_expected_values):
assert toi_info.coord.separation(new_toi.coord).arcsecond < 0.01
+ @pytest.mark.parametrize("start_before_midpoint", [True, False])
+ def test_transit_time_for_observation(self, sample_toi, start_before_midpoint):
+ # For this test we are checking that the correct transit time is identified
+ # for a given observation time.
+ # For the sake of the test, we are going to use the 124th transit of the TOI
+ # as the reference transit.
+ reference_midpoint = sample_toi.epoch + 124 * sample_toi.period
+ if start_before_midpoint:
+ # Start the observation before the midpoint (and before the transit)
+ test_time_start = Time(reference_midpoint - 0.8 * sample_toi.duration)
+ else:
+ # Start the observation after the midpoint
+ test_time_start = Time(reference_midpoint + 0.1 * sample_toi.duration)
+ obs_times = test_time_start + np.linspace(0, 2) * sample_toi.duration
+
+ assert sample_toi.transit_time_for_observation(obs_times).jd == pytest.approx(
+ reference_midpoint.jd
+ )
+
class TestTessPhotometrySetup:
# This auto-used fixture changes the working directory to the temporary directory
diff --git a/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb b/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb
index 166951f7..6340166d 100644
--- a/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb
+++ b/stellarphot/notebooks/photometry/06-transit-fit-template.ipynb
@@ -9,59 +9,40 @@
"outputs": [],
"source": [
"from itertools import product\n",
- "import pickle\n",
"\n",
- "import ipywidgets as ipw\n",
"import numpy as np\n",
"\n",
+ "from astropy.timeseries import TimeSeries, aggregate_downsample\n",
+ "from astropy.time import Time\n",
+ "from astropy.table import Table, Column\n",
+ "from astropy import units as u\n",
"from matplotlib import pyplot as plt\n",
- "from astropy.table import Table\n",
"\n",
- "from stellarphot.transit_fitting import TransitModelFit\n",
- "from stellarphot.transit_fitting.gui import *\n",
+ "from stellarphot.transit_fitting import TransitModelFit, TransitModelOptions\n",
"from stellarphot.io import TOI\n",
- "from stellarphot.settings.fits_opener import FitsOpener\n",
- "from stellarphot.plotting import plot_many_factors\n",
- "from stellarphot import PhotometryData\n",
- "from astropy.timeseries import BinnedTimeSeries, TimeSeries, aggregate_downsample\n",
- "from astropy.time import Time\n",
- "from astropy.table import Table, Column\n",
- "from astropy import units as u"
+ "from stellarphot.plotting import plot_transit_lightcurve\n",
+ "from stellarphot.gui_tools.photometry_widget_functions import TessAnalysisInputControls, filter_by_dates\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 0. Get some data"
+ "### 0. Get some data\n",
+ "\n",
+ "+ Select photometry file with relative flux\n",
+ "+ Select passband\n",
+ "+ Select TESS info file"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
- "fits_openr = FitsOpener(\n",
- " title=\"Select your photometry/flux file\",\n",
- " filter_pattern=[\"*.csv\", \"*.fits\", \"*.ecsv\"],\n",
- ")\n",
- "fits_openr2 = FitsOpener(title=\"Select your TESS info file\", filter_pattern=[\"*.json\"])\n",
- "passband = ipw.Dropdown(description=\"Filter\", options=[\"gp\", \"ip\"], disabled=True)\n",
- "box = ipw.VBox()\n",
- "\n",
- "def update_filter_list(change):\n",
- " tab = Table.read(fits_openr.path)[\"passband\"]\n",
- " passband.options = sorted(set(tab))\n",
- " passband.disabled = False\n",
- " passband.value = passband.options[0]\n",
- "\n",
- "\n",
- "fits_openr.file_chooser.observe(update_filter_list, names=\"_value\")\n",
- "box.children = [fits_openr.file_chooser, fits_openr2.file_chooser, passband]\n",
- "box"
+ "taic = TessAnalysisInputControls()\n",
+ "taic"
]
},
{
@@ -73,17 +54,19 @@
"outputs": [],
"source": [
"# 👉 File with photometry, including flux\n",
- "photometry_file = fits_openr.path\n",
+ "photometry_file = taic.photometry_data_file\n",
+ "inp_photometry = taic.phot_data\n",
"\n",
"# 👉 File with exoplanet info in\n",
- "tess_info_output_file = fits_openr2.path"
+ "tess_info_output_file = taic.tic_info_file\n",
+ "tess_info = TOI.model_validate_json(tess_info_output_file.read_text())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 👇👇👇 use this to exclude some data (only if needed!) 👇👇👇"
+ "### Get just the target star and some information about it"
]
},
{
@@ -92,70 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
- "use_no_data_before = Time(2400000, format=\"jd\", scale=\"tdb\")\n",
- "\n",
- "use_no_data_between = [\n",
- " [Time(2400000, format=\"jd\", scale=\"tdb\"), Time(2400000, format=\"jd\", scale=\"tdb\")]\n",
- "]\n",
- "\n",
- "use_no_data_after = Time(2499999, format=\"jd\", scale=\"tdb\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "photometry = PhotometryData.read(photometry_file)\n",
- "\n",
- "tess_info = TOI.model_validate_json(tess_info_output_file.read_text())\n",
- "# with open(tess_info_output_file, \"rb\") as f:\n",
- "# tess_info = pickle.load(f)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "phot_times = Time(photometry[\"bjd\"], format=\"jd\", scale=\"tdb\")\n",
- "\n",
- "n_dropped = 0\n",
- "\n",
- "bad_data = phot_times < use_no_data_before\n",
- "\n",
- "n_dropped = bad_data.sum()\n",
- "\n",
- "if n_dropped > 0:\n",
- " print(f\"👉👉👉👉 Dropping {n_dropped} data points before BJD {use_no_data_before}\")\n",
- "\n",
- "bad_data = bad_data | (\n",
- " (use_no_data_between[0][0] < phot_times) & (phot_times < use_no_data_between[0][1])\n",
- ")\n",
- "\n",
- "new_dropped = bad_data.sum() - n_dropped\n",
- "\n",
- "if new_dropped:\n",
- " print(\n",
- " f\"👉👉👉👉 Dropping {new_dropped} data points between BJD {use_no_data_between[0][0]} and {use_no_data_between[0][1]}\"\n",
- " )\n",
- "\n",
- "n_dropped += new_dropped\n",
- "\n",
- "bad_data = bad_data | (phot_times > use_no_data_after)\n",
- "\n",
- "new_dropped = bad_data.sum() - n_dropped\n",
- "\n",
- "if new_dropped:\n",
- " print(f\"👉👉👉👉 Dropping {new_dropped} data points after BJD {use_no_data_after}\")\n",
- "\n",
- "n_dropped += new_dropped\n",
- "\n",
- "photometry = photometry[~bad_data]"
+ "photometry = inp_photometry.lightcurve_for(1, flux_column=\"relative_flux\", passband=taic.passband).remove_nans()"
]
},
{
@@ -166,89 +86,15 @@
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# These affect the fitting that is done\n",
- "\n",
- "# bin size in minutes\n",
- "bin_size = 5 * u.min\n",
- "\n",
- "# Keep the time of transit fixed?\n",
- "keep_fixed_transit_time = True\n",
- "transit_time_range = 60 * u.min\n",
- "\n",
- "# Keep radius of planet fixed?\n",
- "\n",
- "keep_fixed_radius_planet = False\n",
- "\n",
- "# Keep radius of orbit fixed?\n",
- "\n",
- "keep_fixed_radius_orbit = False\n",
- "\n",
- "# Remove effects of airmas?\n",
- "fit_airmass = False\n",
- "\n",
- "# Remove effects of sky background?\n",
- "fit_spp = False\n",
- "\n",
- "# Remove effects of change in focus?\n",
- "fit_width = False"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "# Enter your object's period here\n",
- "period = tess_info.period\n",
- "\n",
- "# Enter the epoch here\n",
- "epoch = tess_info.epoch # Time(2458761.602894, scale='tdb', format='jd')\n",
- "\n",
- "# Enter the duration below\n",
- "duration = tess_info.duration\n",
+ "### Fit settings\n",
"\n",
- "# Enter the transit depth here -- get the \"ppm\" value from ExoFOP-TESS\n",
- "depth = tess_info.depth_ppt\n",
- "\n",
- "# Enter object name\n",
- "obj = f\"TIC {tess_info.tic_id}\"\n",
- "\n",
- "# Enter filter\n",
- "phot_filter = \"rp\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# These affect spacing of lines on final plot\n",
- "high = 1.06\n",
- "low = 0.82\n",
- "scale = 0.15 * (high - low)\n",
- "shift = -0.72 * (high - low)"
+ "+ Do any detrending by a covariate?\n",
+ "+ Which parameters are fixed?"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "code",
"execution_count": null,
@@ -257,57 +103,16 @@
},
"outputs": [],
"source": [
- "target_star = photometry[\"star_id\"] == 1\n",
- "\n",
- "# No changes to the line below, it is grabbing the first time in the data series\n",
- "then = Time(photometry[\"bjd\"][target_star][0], scale=\"tdb\", format=\"jd\")\n",
+ "# These affect the fitting that is done\n",
"\n",
- "date_obs = photometry[\"date-obs\"][target_star][0]\n",
- "exposure_time = photometry[\"exposure\"][target_star][0]"
+ "model_options = TransitModelOptions()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Looks like we need to normalize the data first....."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "band_filter = photometry[\"passband\"] == phot_filter\n",
- "\n",
- "target_and_filter = target_star & band_filter"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "photometry = photometry[target_and_filter]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "cycle_number = int((then - epoch) / period + 1)\n",
- "that_transit = cycle_number * period + epoch\n",
- "that_transit"
+ "### Find the OOT region and use it to get normalization factor"
]
},
{
@@ -318,16 +123,19 @@
},
"outputs": [],
"source": [
- "start = that_transit - duration / 2\n",
+ "that_transit = tess_info.transit_time_for_observation(photometry.time)\n",
+ "start = that_transit - tess_info.duration / 2\n",
"mid = that_transit\n",
- "end = that_transit + duration / 2\n",
+ "end = that_transit + tess_info.duration / 2\n",
"\n",
"after_transit = (photometry[\"bjd\"] - 2400000 * u.day) > end\n",
"\n",
"outside_transit = (photometry[\"bjd\"] < start) | (photometry[\"bjd\"] > end)\n",
"\n",
"normalization_factor = np.nanmean(1 / photometry[\"relative_flux\"][outside_transit])\n",
- "normalized_flux = Column(photometry[\"relative_flux\"] * normalization_factor, name=\"relative_flux\")"
+ "normalized_flux = Column(photometry[\"relative_flux\"] * normalization_factor, name=\"normalized_flux\")\n",
+ "norm_flux_error = Column(normalization_factor * photometry[\"relative_flux_error\"].value, name=\"normalized_flux_error\")\n",
+ "photometry.add_columns([normalized_flux, norm_flux_error])\n"
]
},
{
@@ -340,16 +148,11 @@
"* data table\n",
"* start\n",
"* end\n",
- "* bin_size"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "len(normalization_factor * photometry[\"relative_flux_error\"].value)"
+ "* bin_size\n",
+ "\n",
+ "Data is binned twice because one finds means and the other errors\n",
+ "\n",
+ "**Also make times smaller**"
]
},
{
@@ -363,7 +166,7 @@
"t_ob = Time(photometry[\"bjd\"], scale=\"tdb\", format=\"jd\")\n",
"ts = TimeSeries(\n",
" [\n",
- " normalized_flux,\n",
+ " photometry[\"normalized_flux\"],\n",
" photometry[\"airmass\"],\n",
" photometry[\"xcenter\"],\n",
" photometry[\"sky_per_pix_avg\"],\n",
@@ -373,8 +176,8 @@
")\n",
"ts2 = TimeSeries(\n",
" [Column(\n",
- " data=normalization_factor * photometry[\"relative_flux_error\"].value,\n",
- " name=\"relative_flux_error\"\n",
+ " data=photometry[\"normalized_flux_error\"],\n",
+ " name=\"normalized_flux_error\"\n",
" )],\n",
" time=t_ob\n",
")\n",
@@ -391,9 +194,10 @@
" return np.sqrt(np.nansum(x**2)) / n\n",
"\n",
"\n",
- "binned = aggregate_downsample(ts, time_bin_size=bin_size)\n",
- "binned2 = aggregate_downsample(ts2, time_bin_size=bin_size, aggregate_func=add_quad)\n",
+ "binned = aggregate_downsample(ts, time_bin_size=model_options.bin_size * u.min)\n",
+ "binned2 = aggregate_downsample(ts2, time_bin_size=model_options.bin_size * u.min, aggregate_func=add_quad)\n",
"\n",
+ "binned[\"normalized_flux_error\"] = binned2[\"normalized_flux_error\"]\n",
"# binned_time = BinnedTimeSeries(photometry['bjd'], time_bin_start=first_time, time_bin_end=last_time, time_bin_size=bin_size)"
]
},
@@ -401,31 +205,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 1. Create the model object"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "mod = TransitModelFit()"
+ "## Model, fit, plot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 2. Load some data\n",
- "\n",
- "Here we will just load times and normalized flux. You can also set width, spp (sky per pixel) and airmass. The only two that must be set are times and flux.\n",
- "\n",
- "If you have set `mod.spp`, `mod.width` or `mod.airmass` then those things will be included in the fit. Otherwise, they are ignored.\n",
- "\n",
- "THE WEIGHTS ARE IMPORTANT TO INCLUDE"
+ "### Create the model "
]
},
{
@@ -436,132 +223,25 @@
},
"outputs": [],
"source": [
- "not_empty = ~np.isnan(binned[\"relative_flux\"])\n",
- "\n",
- "mod.times = (np.array(binned[\"time_bin_start\"].value) - 2400000)[not_empty]\n",
- "mod.data = binned[\"relative_flux\"].value[not_empty]\n",
- "mod.weights = 1 / (binned2[\"relative_flux_error\"].value)[not_empty]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3. Set up the model\n",
+ "# Make the model\n",
+ "mod = TransitModelFit()\n",
"\n",
- "You should be able to get the parameters for this from TTF. There are more parameters you can set; `shift-Tab` in the arguments to pull up the docstring, which lists and explains them all."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
+ "# Setup the model\n",
"mod.setup_model(\n",
+ " binned_data=binned,\n",
" t0=mid.jd - 2400000, # midpoint, BJD\n",
- " depth=depth, # parts per thousand\n",
- " duration=duration.to(\"day\").value, # days\n",
- " period=period.to(\"day\").value, # days\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3.25 Set up airmass, etc"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "mod.airmass = np.array(binned[\"airmass\"])[not_empty]\n",
- "mod.width = np.array(binned[\"width\"])[not_empty]\n",
- "mod.spp = np.array(binned[\"sky_per_pix_avg\"])[not_empty]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3.5 Constrain the fits if you want\n",
- "\n",
- "#### Exoplanet parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "mod.model.t0.bounds = [\n",
- " mid.jd - 2400000 - transit_time_range.to(\"day\").value / 2,\n",
- " mid.jd - 2400000 + transit_time_range.to(\"day\").value / 2,\n",
- "]\n",
- "mod.model.t0.fixed = keep_fixed_transit_time\n",
- "mod.model.a.fixed = keep_fixed_radius_orbit\n",
- "mod.model.rp.fixed = keep_fixed_radius_planet"
+ " depth=tess_info.depth_ppt, # parts per thousand\n",
+ " duration=tess_info.duration.to(\"day\").value, # days\n",
+ " period=tess_info.period.to(\"day\").value, # days\n",
+ " model_options=model_options,\n",
+ ")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Detrending parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "mod.model.spp_trend.fixed = not fit_spp\n",
- "mod.model.airmass_trend.fixed = not fit_airmass\n",
- "mod.model.width_trend.fixed = not fit_width"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "detrended_by = []\n",
- "if fit_airmass:\n",
- " detrended_by.append(\"Airmass\")\n",
- "\n",
- "if fit_spp:\n",
- " detrended_by.append(\"SPP\")\n",
- "\n",
- "if fit_width:\n",
- " detrended_by.append(\"Wdith\")\n",
- "\n",
- "detrended_by = (\n",
- " (\"Detrended by: \" + \",\".join(detrended_by)) if detrended_by else \"No detrending\"\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4. Run the fit"
+ "### Run the fit"
]
},
{
@@ -579,7 +259,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 5. Let's try a plot...."
+ "### Look at the results"
]
},
{
@@ -599,156 +279,31 @@
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "mod.model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# mod._fitter.fit_info"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "flux_full_detrend = mod.data_light_curve(detrend_by=\"all\")\n",
- "flux_full_detrend_model = mod.model_light_curve(detrend_by=\"all\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "rel_detrended_flux = flux_full_detrend / np.mean(flux_full_detrend)\n",
- "\n",
- "rel_detrended_flux_rms = np.std(rel_detrended_flux)\n",
- "rel_model_rms = np.std(flux_full_detrend_model - rel_detrended_flux)\n",
- "\n",
- "rel_flux_rms = np.std(mod.data)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "grid_y_ticks = np.arange(low, high, 0.02)"
+ "### Exclude data by date *if needed*\n"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
- "# (RMS={rel_flux_rms:.5f})\n",
- "\n",
- "plt.figure(figsize=(8, 11))\n",
- "fig, ax = plt.subplots(1, 1, figsize=(8, 11))\n",
- "\n",
- "plt.plot(\n",
- " (photometry[\"bjd\"] - 2400000 * u.day).jd,\n",
- " normalized_flux,\n",
- " \"b.\",\n",
- " label=f\"rel_flux_T1 (RMS={rel_flux_rms:.5f})\",\n",
- " ms=4,\n",
- ")\n",
- "\n",
- "plt.plot(\n",
- " mod.times,\n",
- " flux_full_detrend - 0.04,\n",
- " \".\",\n",
- " c=\"r\",\n",
- " ms=4,\n",
- " label=f\"rel_flux_T1 ({detrended_by})(RMS={rel_detrended_flux_rms:.5f}), (bin size={bin_size} min)\",\n",
- ")\n",
- "\n",
- "plt.plot(\n",
- " mod.times,\n",
- " flux_full_detrend - 0.08,\n",
- " \".\",\n",
- " c=\"g\",\n",
- " ms=4,\n",
- " label=f\"rel_flux_T1 ({detrended_by} with transit fit)(RMS={rel_model_rms:.5f}), (bin size={bin_size})\",\n",
- ")\n",
- "plt.plot(\n",
- " mod.times,\n",
- " flux_full_detrend_model - 0.08,\n",
- " c=\"g\",\n",
- " ms=4,\n",
- " label=f\"rel_flux_T1 Transit Model ([P={mod.model.period.value:.4f}], \"\n",
- " f\"(Rp/R*)^2={(mod.model.rp.value)**2:.4f}, \\na/R*={mod.model.a.value:.4f}, \"\n",
- " f\"[Tc={mod.model.t0.value + 2400000:.4f}], \"\n",
- " f\"[u1={mod.model.limb_u1.value:.1f}, u2={mod.model.limb_u2.value:.1f})\",\n",
- ")\n",
- "\n",
- "plot_many_factors(photometry, shift, scale)\n",
- "\n",
- "plt.vlines(start.jd - 2400000, low, 1.025, colors=\"r\", linestyle=\"--\", alpha=0.5)\n",
- "plt.vlines(end.jd - 2400000, low, 1.025, colors=\"r\", linestyle=\"--\", alpha=0.5)\n",
- "plt.text(\n",
- " start.jd - 2400000,\n",
- " low + 0.0005,\n",
- " f\"Predicted\\nIngress\\n{start.jd-2400000-int(start.jd - 2400000):.3f}\",\n",
- " horizontalalignment=\"center\",\n",
- " c=\"r\",\n",
- ")\n",
- "plt.text(\n",
- " end.jd - 2400000,\n",
- " low + 0.0005,\n",
- " f\"Predicted\\nEgress\\n{end.jd-2400000-int(end.jd - 2400000):.3f}\",\n",
- " horizontalalignment=\"center\",\n",
- " c=\"r\",\n",
- ")\n",
- "\n",
- "# plt.vlines(start + 0.005, low, 1, colors='darkgray', linestyle='--', alpha=0.5)\n",
- "# plt.text(start + 0.005, low+0.001, f'Left\\n{start-int(start)+0.005:.3f}', horizontalalignment='center',c='darkgray')\n",
- "# plt.vlines(end - 0.005, low, 1, colors='darkgray', linestyle='--', alpha=0.5)\n",
- "# plt.text(end - 0.005, low+0.001, f'Rght\\n{end-int(end)-0.005:.3f}', horizontalalignment='center',c='darkgray')\n",
- "\n",
- "\n",
- "plt.ylim(low, high)\n",
- "plt.xlabel(\"Barycentric Julian Date (TDB)\", fontname=\"Arial\")\n",
- "plt.ylabel(\"Relative Flux (normalized)\", fontname=\"Arial\")\n",
- "plt.title(\n",
- " f\"{obj}.01 UT{date_obs}\\nPaul P. Feder Observatory 0.4m ({phot_filter} filter, {exposure_time} exp, fap 10-25-40)\",\n",
- " fontsize=14,\n",
- " fontname=\"Arial\",\n",
+ "bad_time = filter_by_dates(\n",
+ " phot_times=photometry[\"bjd\"],\n",
+ " use_no_data_before=Time(2400000, format=\"jd\", scale=\"tdb\"),\n",
+ " use_no_data_between=[\n",
+ " [\n",
+ " Time(2400000, format=\"jd\", scale=\"tdb\"),\n",
+ " Time(2400000, format=\"jd\", scale=\"tdb\"),\n",
+ " ]\n",
+ " ],\n",
+ " use_no_data_after=Time(2499999, format=\"jd\", scale=\"tdb\"),\n",
")\n",
- "plt.legend(loc=\"upper center\", frameon=False, fontsize=8, bbox_to_anchor=(0.6, 1.0))\n",
- "ax.set_yticks(grid_y_ticks)\n",
- "plt.grid()\n",
"\n",
- "plt.savefig(\n",
- " f\"TIC{tess_info.tic_id}-01_20200701_Paul-P-Feder-0.4m_gp_lightcurve.png\",\n",
- " facecolor=\"w\",\n",
- ")"
+ "photometry = photometry[~bad_time]"
]
},
{
@@ -759,18 +314,14 @@
},
"outputs": [],
"source": [
- "mod.n_fit_parameters"
+ "mod.model"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "mod._all_detrend_params"
+ "### Attempt to calculate BIC, but...this seems to have side effects on the rest of notebook "
]
},
{
@@ -824,7 +375,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "stelldev-pyd2",
"language": "python",
"name": "python3"
},
diff --git a/stellarphot/notebooks/photometry/06b-transit-fit-template-fancy-plot.ipynb b/stellarphot/notebooks/photometry/06b-transit-fit-template-fancy-plot.ipynb
new file mode 100644
index 00000000..527778fb
--- /dev/null
+++ b/stellarphot/notebooks/photometry/06b-transit-fit-template-fancy-plot.ipynb
@@ -0,0 +1,420 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from itertools import product\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "from astropy.timeseries import TimeSeries, aggregate_downsample\n",
+ "from astropy.time import Time\n",
+ "from astropy.table import Table, Column\n",
+ "from astropy import units as u\n",
+ "from matplotlib import pyplot as plt\n",
+ "\n",
+ "from stellarphot.transit_fitting import TransitModelFit, TransitModelOptions\n",
+ "from stellarphot.io import TOI\n",
+ "from stellarphot.plotting import plot_transit_lightcurve\n",
+ "from stellarphot.gui_tools.photometry_widget_functions import TessAnalysisInputControls, filter_by_dates\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 0. Get some data\n",
+ "\n",
+ "+ Select photometry file with relative flux\n",
+ "+ Select passband\n",
+ "+ Select TESS info file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "taic = TessAnalysisInputControls()\n",
+ "taic"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 👉 File with photometry, including flux\n",
+ "photometry_file = taic.photometry_data_file\n",
+ "inp_photometry = taic.phot_data\n",
+ "\n",
+ "# 👉 File with exoplanet info in\n",
+ "tess_info_output_file = taic.tic_info_file\n",
+ "tess_info = TOI.model_validate_json(tess_info_output_file.read_text())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Get just the target star and some information about it"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "photometry = inp_photometry.lightcurve_for(1, flux_column=\"relative_flux\", passband=taic.passband).remove_nans()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### You may need to alter some of the settings here"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fit settings\n",
+ "\n",
+ "+ Do any detrending by a covariate?\n",
+ "+ Which parameters are fixed?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# These affect the fitting that is done\n",
+ "\n",
+ "model_options = TransitModelOptions()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Find the OOT region and use it to get normalization factor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "that_transit = tess_info.transit_time_for_observation(photometry.time)\n",
+ "start = that_transit - tess_info.duration / 2\n",
+ "mid = that_transit\n",
+ "end = that_transit + tess_info.duration / 2\n",
+ "\n",
+ "after_transit = (photometry[\"bjd\"] - 2400000 * u.day) > end\n",
+ "\n",
+ "outside_transit = (photometry[\"bjd\"] < start) | (photometry[\"bjd\"] > end)\n",
+ "\n",
+ "normalization_factor = np.nanmean(1 / photometry[\"relative_flux\"][outside_transit])\n",
+ "normalized_flux = Column(photometry[\"relative_flux\"] * normalization_factor, name=\"normalized_flux\")\n",
+ "norm_flux_error = Column(normalization_factor * photometry[\"relative_flux_error\"].value, name=\"normalized_flux_error\")\n",
+ "photometry.add_columns([normalized_flux, norm_flux_error])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Bin Data\n",
+ "\n",
+ "Need\n",
+ "* data table\n",
+ "* start\n",
+ "* end\n",
+ "* bin_size\n",
+ "\n",
+ "Data is binned twice because one finds means and the other errors\n",
+ "\n",
+ "**Also make times smaller**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "t_ob = Time(photometry[\"bjd\"], scale=\"tdb\", format=\"jd\")\n",
+ "ts = TimeSeries(\n",
+ " [\n",
+ " photometry[\"normalized_flux\"],\n",
+ " photometry[\"airmass\"],\n",
+ " photometry[\"xcenter\"],\n",
+ " photometry[\"sky_per_pix_avg\"],\n",
+ " photometry[\"width\"],\n",
+ " ],\n",
+ " time=t_ob,\n",
+ ")\n",
+ "ts2 = TimeSeries(\n",
+ " [Column(\n",
+ " data=photometry[\"normalized_flux_error\"],\n",
+ " name=\"normalized_flux_error\"\n",
+ " )],\n",
+ " time=t_ob\n",
+ ")\n",
+ "\n",
+ "first_time = photometry[\"bjd\"][0] - 2400000\n",
+ "last_time = photometry[\"bjd\"][-1] - 2400000\n",
+ "\n",
+ "\n",
+ "def add_quad(x):\n",
+ " try:\n",
+ " n = len(x)\n",
+ " except TypeError:\n",
+ " n = 1\n",
+ " return np.sqrt(np.nansum(x**2)) / n\n",
+ "\n",
+ "\n",
+ "binned = aggregate_downsample(ts, time_bin_size=model_options.bin_size * u.min)\n",
+ "binned2 = aggregate_downsample(ts2, time_bin_size=model_options.bin_size * u.min, aggregate_func=add_quad)\n",
+ "\n",
+ "binned[\"normalized_flux_error\"] = binned2[\"normalized_flux_error\"]\n",
+ "# binned_time = BinnedTimeSeries(photometry['bjd'], time_bin_start=first_time, time_bin_end=last_time, time_bin_size=bin_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model, fit, plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create the model "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Make the model\n",
+ "mod = TransitModelFit()\n",
+ "\n",
+ "# Setup the model\n",
+ "mod.setup_model(\n",
+ " binned_data=binned,\n",
+ " t0=mid.jd - 2400000, # midpoint, BJD\n",
+ " depth=tess_info.depth_ppt, # parts per thousand\n",
+ " duration=tess_info.duration.to(\"day\").value, # days\n",
+ " period=tess_info.period.to(\"day\").value, # days\n",
+ " model_options=model_options,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Run the fit"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "mod.fit()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Look at the results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "plt.plot(mod.times, mod.data, \".\")\n",
+ "plt.plot(mod.times, mod.model_light_curve())\n",
+ "plt.vlines(start.jd - 2400000, 0.98, 1.02, colors=\"r\", linestyle=\"--\", alpha=0.5)\n",
+ "plt.vlines(end.jd - 2400000, 0.98, 1.02, colors=\"r\", linestyle=\"--\", alpha=0.5)\n",
+ "plt.title(\"Data and fit\")\n",
+ "plt.grid()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Exclude data by date *if needed*\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bad_time = filter_by_dates(\n",
+ " phot_times=photometry[\"bjd\"],\n",
+ " use_no_data_before=Time(2400000, format=\"jd\", scale=\"tdb\"),\n",
+ " use_no_data_between=[\n",
+ " [\n",
+ " Time(2400000, format=\"jd\", scale=\"tdb\"),\n",
+ " Time(2400000, format=\"jd\", scale=\"tdb\"),\n",
+ " ]\n",
+ " ],\n",
+ " use_no_data_after=Time(2499999, format=\"jd\", scale=\"tdb\"),\n",
+ ")\n",
+ "\n",
+ "photometry = photometry[~bad_time]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "mod.model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Make the big plot"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "plot_transit_lightcurve(\n",
+ " photometry,\n",
+ " mod,\n",
+ " tess_info,\n",
+ " model_options.bin_size * u.min\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Attempt to calculate BIC, but...this seems to have side effects on the rest of notebook "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def evaluate_fits(mod):\n",
+ " BICs = []\n",
+ " settings = []\n",
+ " all_trendable = mod._all_detrend_params\n",
+ " tf_sequence = product([True, False], repeat=len(all_trendable))\n",
+ " for fixed in tf_sequence:\n",
+ " this_summary = []\n",
+ " for param, fix in zip(all_trendable, fixed):\n",
+ " trend_mod = getattr(mod.model, f\"{param}_trend\")\n",
+ " if fix:\n",
+ " setattr(mod.model, f\"{param}_trend\", 0.0)\n",
+ " trend_mod.fixed = fix\n",
+ " this_summary.append(f\"{param}: {not fix}\")\n",
+ "\n",
+ " settings.append(\", \".join(this_summary))\n",
+ " mod.fit()\n",
+ " BICs.append(mod.BIC)\n",
+ " return Table(data=[settings, BICs], names=[\"Fit this param?\", \"BIC\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "bic_table = evaluate_fits(mod)\n",
+ "bic_table.sort(\"BIC\")\n",
+ "bic_table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "stelldev-pyd2",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/stellarphot/plotting/transit_plots.py b/stellarphot/plotting/transit_plots.py
index 3fed492d..34851b7e 100644
--- a/stellarphot/plotting/transit_plots.py
+++ b/stellarphot/plotting/transit_plots.py
@@ -2,7 +2,12 @@
from astropy import units as u
from matplotlib import pyplot as plt
-__all__ = ["plot_many_factors", "bin_data", "scale_and_shift"]
+__all__ = [
+ "plot_many_factors",
+ "bin_data",
+ "scale_and_shift",
+ "plot_transit_lightcurve",
+]
def plot_many_factors(photometry, shift, scale, ax=None):
@@ -50,7 +55,6 @@ def plot_many_factors(photometry, shift, scale, ax=None):
if ax is None:
ax = plt.gca()
- print(f"{scale_airmass.min()} {scale_airmass.max()}")
ax.plot(
x_times,
scale_counts,
@@ -170,3 +174,123 @@ def scale_and_shift(data_set, scale, shift, pos=True):
data_set += shift
return data_set
+
+
+def plot_transit_lightcurve(
+ photometry,
+ mod,
+ tess_info,
+ bin_size,
+ low=0.82,
+ high=1.06,
+):
+ # These affect spacing of lines on final plot
+ scale = 0.15 * (high - low)
+ shift = -0.72 * (high - low)
+ # (RMS={rel_flux_rms:.5f})
+ grid_y_ticks = np.arange(low, high, 0.02)
+
+ date_obs = photometry["date-obs"][0]
+ phot_filter = photometry["passband"][0]
+ exposure_time = photometry["exposure"][0]
+
+ midpoint = tess_info.transit_time_for_observation(photometry["bjd"])
+ start = midpoint - 0.5 * tess_info.duration
+ end = midpoint + 0.5 * tess_info.duration
+
+ detrended_by = []
+ if not mod.model.airmass_trend.fixed:
+ detrended_by.append("Airmass")
+ if not mod.model.spp_trend.fixed:
+ detrended_by.append("SPP")
+ if not mod.model.width_trend.fixed:
+ detrended_by.append("Width")
+
+ flux_full_detrend = mod.data_light_curve(detrend_by="all")
+ flux_full_detrend_model = mod.model_light_curve(detrend_by="all")
+ rel_detrended_flux = flux_full_detrend / np.mean(flux_full_detrend)
+
+ rel_detrended_flux_rms = np.std(rel_detrended_flux)
+ rel_model_rms = np.std(flux_full_detrend_model - rel_detrended_flux)
+
+ rel_flux_rms = np.std(mod.data)
+
+ fig, ax = plt.subplots(1, 1, figsize=(8, 11))
+
+ plt.plot(
+ (photometry["bjd"] - 2400000 * u.day).jd,
+ photometry["normalized_flux"],
+ "b.",
+ label=f"rel_flux_T1 (RMS={rel_flux_rms:.5f})",
+ ms=4,
+ )
+
+ plt.plot(
+ mod.times,
+ flux_full_detrend - 0.04,
+ ".",
+ c="r",
+ ms=4,
+ label=f"rel_flux_T1 ({detrended_by})(RMS={rel_detrended_flux_rms:.5f}), "
+ f"(bin size={bin_size})",
+ )
+
+ plt.plot(
+ mod.times,
+ flux_full_detrend - 0.08,
+ ".",
+ c="g",
+ ms=4,
+ label=f"rel_flux_T1 ({detrended_by} with transit fit)(RMS={rel_model_rms:.5f}),"
+ f" (bin size={bin_size})",
+ )
+ plt.plot(
+ mod.times,
+ flux_full_detrend_model - 0.08,
+ c="g",
+ ms=4,
+ label=f"rel_flux_T1 Transit Model ([P={mod.model.period.value:.4f}], "
+ f"(Rp/R*)^2={(mod.model.rp.value)**2:.4f}, \na/R*={mod.model.a.value:.4f}, "
+ f"[Tc={mod.model.t0.value + 2400000:.4f}], "
+ f"[u1={mod.model.limb_u1.value:.1f}, u2={mod.model.limb_u2.value:.1f})",
+ )
+
+ plot_many_factors(photometry, shift, scale)
+
+ plt.vlines(start.jd - 2400000, low, 1.025, colors="r", linestyle="--", alpha=0.5)
+ plt.vlines(end.jd - 2400000, low, 1.025, colors="r", linestyle="--", alpha=0.5)
+ plt.text(
+ start.jd - 2400000,
+ low + 0.0005,
+ f"Predicted\nIngress\n{start.jd-2400000-int(start.jd - 2400000):.3f}",
+ horizontalalignment="center",
+ c="r",
+ )
+ plt.text(
+ end.jd - 2400000,
+ low + 0.0005,
+ f"Predicted\nEgress\n{end.jd-2400000-int(end.jd - 2400000):.3f}",
+ horizontalalignment="center",
+ c="r",
+ )
+
+ plt.ylim(low, high)
+ plt.xlabel("Barycentric Julian Date (TDB)", fontname="Arial")
+ plt.ylabel("Relative Flux (normalized)", fontname="Arial")
+ plt.title(
+ f"TIC {tess_info.tic_id}.01 UT{date_obs}\nPaul P. Feder Observatory 0.4m "
+ f"({phot_filter} filter, {exposure_time} exp, "
+ f"fap {photometry['aperture'][0].value:.0f}"
+ f"-{photometry['annulus_inner'][0].value:.0f}"
+ f"-{photometry['annulus_outer'][0].value:.0f})\n",
+ fontsize=14,
+ fontname="Arial",
+ )
+ plt.legend(loc="upper center", frameon=False, fontsize=8, bbox_to_anchor=(0.6, 1.0))
+ ax.set_yticks(grid_y_ticks)
+ plt.grid()
+
+ plt.savefig(
+ f"TIC{tess_info.tic_id}-01_{date_obs}_Paul-P-Feder-0.4m_{phot_filter}_lightcurve.png",
+ facecolor="w",
+ )
diff --git a/stellarphot/transit_fitting/core.py b/stellarphot/transit_fitting/core.py
index 5817db2b..72d55019 100644
--- a/stellarphot/transit_fitting/core.py
+++ b/stellarphot/transit_fitting/core.py
@@ -1,8 +1,10 @@
import warnings
import numpy as np
+from astropy import units as u
from astropy.modeling.fitting import LevMarLSQFitter, _validate_model
from astropy.modeling.models import custom_model
+from pydantic import BaseModel
# Functions below changed from private to public in astropy 5
try:
@@ -30,7 +32,7 @@
"pip install batman-package"
)
-__all__ = ["VariableArgsFitter", "TransitModelFit"]
+__all__ = ["VariableArgsFitter", "TransitModelOptions", "TransitModelFit"]
class VariableArgsFitter(LevMarLSQFitter):
@@ -120,6 +122,17 @@ def __call__(
return model_copy
+class TransitModelOptions(BaseModel):
+ bin_size: float = 5.0
+ keep_transit_time_fixed: bool = True
+ transit_time_range: float = 60.0
+ keep_radius_planet_fixed: bool = False
+ keep_radius_orbit_fixed: bool = False
+ fit_airmass: bool = False
+ fit_width: bool = False
+ fit_spp: bool = False
+
+
class TransitModelFit:
"""
Transit model fits to observed light curves.
@@ -385,6 +398,7 @@ def transit_model_with_trends(
def setup_model(
self,
+ binned_data=None,
t0=0,
depth=0,
duration=0,
@@ -393,6 +407,7 @@ def setup_model(
airmass_trend=0.0,
width_trend=0.0,
spp_trend=0.0,
+ model_options=None,
):
"""
Configure a transit model for fitting. The ``duration`` and ``depth``
@@ -429,11 +444,28 @@ def setup_model(
spp_trend : float
Coefficient for a linear trend in sky per pixel.
+ options : TransitModelOptions, optional
+ Options for the transit model fit.
+
Returns
-------
None
Sets values for the model parameters.
"""
+ if binned_data:
+ self.times = (
+ np.array(
+ (
+ binned_data["time_bin_start"] + binned_data["time_bin_size"] / 2
+ ).value
+ )
+ - 2400000
+ )
+ self.data = binned_data["normalized_flux"].value
+ self.weights = 1 / (binned_data["normalized_flux_error"].value)
+ self.airmass = np.array(binned_data["airmass"])
+ self.width = np.array(binned_data["width"])
+ self.spp = np.array(binned_data["sky_per_pix_avg"])
self._setup_transit_model()
# rp is related to depth in a straightforward way
@@ -461,6 +493,20 @@ def setup_model(
except ValueError:
pass
+ if model_options is not None:
+ # Setup the model more 🙄
+ self.model.t0.bounds = [
+ t0 - (model_options.transit_time_range * u.min).to("day").value / 2,
+ t0 + (model_options.transit_time_range * u.min).to("day").value / 2,
+ ]
+ self.model.t0.fixed = model_options.keep_transit_time_fixed
+ self.model.a.fixed = model_options.keep_radius_orbit_fixed
+ self.model.rp.fixed = model_options.keep_radius_planet_fixed
+
+ self.model.spp_trend.fixed = not model_options.fit_spp
+ self.model.airmass_trend.fixed = not model_options.fit_airmass
+ self.model.width_trend.fixed = not model_options.fit_width
+
def fit(self):
"""
Perform a fit and update the model with best-fit values.