Skip to content

Commit

Permalink
Merge pull request #766 from dirac-institute/small_fixes
Browse files Browse the repository at this point in the history
Small numpy changes
  • Loading branch information
jeremykubica authored Dec 20, 2024
2 parents dee0077 + 85528dd commit 61c538a
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/kbmod/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def plot_time_series(values, times=None, indices=None, ax=None, figure=None, tit
title : `str` or None, optional
Title of the plot. `None` by default.
"""
y_values = np.array(values)
y_values = np.asarray(values)

# If no axes were given, create a new figure.
if ax is None:
Expand All @@ -486,15 +486,15 @@ def plot_time_series(values, times=None, indices=None, ax=None, figure=None, tit

# If no valid indices are given, use them all.
if indices is None:
indices = np.array([True] * len(values), dtype=bool)
indices = np.full(len(values), True, dtype=bool)
else:
indices = np.array(indices, dtype=bool)
indices = np.asarray(indices, dtype=bool)

# If the times are not given, then use linear spacing.
if times is None:
x_values = np.linspace(0, len(values) - 1, len(values), dtype=int)
else:
x_values = np.array(times)
x_values = np.asarray(times)

# Plot the data with the curve in blue, the valid points as blue dots,
# and the invalid indices as smaller red dots.
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def _get_first_psf_at_time(work_unit, time):
If the time is not found in list of observation times in the work_unit,
raise an error.
"""
obstimes = np.array(work_unit.get_all_obstimes())
obstimes = np.asarray(work_unit.get_all_obstimes())

# if the time isn't in the list of times, raise an error.
if time not in obstimes:
Expand Down
6 changes: 3 additions & 3 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def filter_rows(self, rows, label=""):
return

# Check if we are dealing with a mask of a list of indices.
rows = np.array(rows)
rows = np.asarray(rows)
if rows.dtype == bool:
if len(rows) != len(self.table):
raise ValueError(
Expand Down Expand Up @@ -627,7 +627,7 @@ def revert_filter(self, label=None, add_column=None):

# If we don't have the tracking column yet, add it.
if add_column is not None and add_column not in self.table.colnames:
self.table[add_column] = np.array([""] * len(self.table), dtype=str)
self.table[add_column] = np.full(len(self.table), "", dtype=str)

# Make a list of tables to merge.
table_list = [self.table]
Expand Down Expand Up @@ -748,7 +748,7 @@ def write_column(self, colname, filename):
raise KeyError(f"Column {colname} missing from data.")

# Save the column.
data = np.array(self.table[colname])
data = np.asarray(self.table[colname])
np.save(filename, data, allow_pickle=False)

def load_column(self, filename, colname):
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=Non
"cluster_type": config["cluster_type"],
"cluster_eps": config["cluster_eps"],
"cluster_v_scale": config["cluster_v_scale"],
"times": np.array(mjds),
"times": np.asarray(mjds),
}
apply_clustering(keep, cluster_params)
cluster_timer.stop()
Expand Down
4 changes: 2 additions & 2 deletions src/kbmod/trajectory_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def evaluate_linear_trajectory(self, x, y, vx, vy):
result = Results.from_trajectories([trj])

# Get the psi and phi curves and do the sigma_g filtering.
psi_curve = np.array([self.search.get_psi_curves(trj)])
phi_curve = np.array([self.search.get_phi_curves(trj)])
psi_curve = np.asarray([self.search.get_psi_curves(trj)])
phi_curve = np.asarray([self.search.get_phi_curves(trj)])
obs_valid = np.full(psi_curve.shape, True)
result.add_psi_phi_data(psi_curve, phi_curve, obs_valid)

Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def trajectory_predict_skypos(trj, wcs, times):
result : `astropy.coordinates.SkyCoord`
A SkyCoord with the transformed locations.
"""
dt = np.array(times)
dt = np.asarray(times)
dt -= dt[0]

# Predict locations in pixel space.
Expand Down

0 comments on commit 61c538a

Please sign in to comment.