Skip to content

Commit

Permalink
check basenames only, split unit tests into separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
emolter committed Jan 14, 2025
1 parent 71db069 commit 5c74ff3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 25 deletions.
8 changes: 6 additions & 2 deletions jwst/skymatch/skymatch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from stdatamodels.jwst.datamodels.dqflags import pixel

from jwst.datamodels import ModelLibrary
from jwst.lib.suffix import remove_suffix
import os.path as op

from ..stpipe import Step

Expand Down Expand Up @@ -237,21 +239,23 @@ def _user_sky(self, library):
log.info("Setting sky background of input images to user-provided values "
f"from `skylist` ({self.skylist}).")

# read the comma separated file
# read the comma separated file and get just the stem of the filename
skylist = np.genfromtxt(
self.skylist,
dtype=[("fname", "<S128"), ("sky", "f")],
)
skyfnames, skyvals = skylist['fname'], skylist['sky']
skyfnames = skyfnames.astype(str)
skyfnames = [remove_suffix(op.splitext(fname)[0])[0] for fname in skyfnames]
skyfnames = np.array(skyfnames)

if len(skyvals) != len(library):
raise ValueError(f"Number of entries in skylist ({len(self.skylist)}) does not match "
f"number of input images ({len(library)}).")

with library:
for model in library:
fname = model.meta.filename
fname, _ = remove_suffix(op.splitext(model.meta.filename)[0])
sky = skyvals[np.where(skyfnames == fname)]
if len(sky) == 0:
raise ValueError(f"Image '{fname}' not found in the skylist.")
Expand Down
71 changes: 48 additions & 23 deletions jwst/skymatch/tests/test_skymatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _add_bad_pixels(im, sat_val, dont_use_val):
'skymethod, subtract, skystat, match_down, grouped',
tuple(
product(
['local', 'match', 'global', 'global+match', 'user'],
['local', 'match', 'global', 'global+match'],
[False, True],
['median', 'mean', 'midpt', 'mode'],
[False, True],
Expand All @@ -175,11 +175,8 @@ def test_skymatch(tmp_cwd, nircam_rate, skymethod, subtract, skystat, match_down
# test basic functionality and correctness of sky computations
np.random.seed(1)
im1 = nircam_rate.copy()
im1.meta.filename = "one.fits"
im2 = im1.copy()
im2.meta.filename = "two.fits"
im3 = im1.copy()
im3.meta.filename = "three.fits"

# add "bad" data
im1, dq_mask = _add_bad_pixels(im1, 1e6, 1e9)
Expand All @@ -204,14 +201,6 @@ def test_skymatch(tmp_cwd, nircam_rate, skymethod, subtract, skystat, match_down
scale=0.1,
size=im.data.shape
)

# put levels into the skylist file for when skylist='user'
fnames = [model.meta.filename for model in container]
print("fnames", fnames)
skyfile = "skylist.txt"
with open(skyfile, "w") as f:
for fname, lev in zip(fnames, levels):
f.write(f"{fname} {lev}\n")

# exclude central DO_NOT_USE and corner SATURATED pixels
result = SkyMatchStep.call(
Expand All @@ -223,7 +212,6 @@ def test_skymatch(tmp_cwd, nircam_rate, skymethod, subtract, skystat, match_down
binwidth=0.2,
nclip=0,
dqbits='~DO_NOT_USE+SATURATED',
skylist=skyfile,
)

if skymethod == 'match' and grouped:
Expand Down Expand Up @@ -255,9 +243,6 @@ def test_skymatch(tmp_cwd, nircam_rate, skymethod, subtract, skystat, match_down
elif skymethod == 'global':
ref_levels = len(levels) * [min(levels)]

elif skymethod == 'user':
ref_levels = levels

sub_levels = np.subtract(levels, ref_levels)

with result:
Expand Down Expand Up @@ -564,32 +549,72 @@ def test_skymatch_2x(tmp_cwd, nircam_rate, tmp_path, skymethod, subtract):
result2.shelve(im2)


def test_user_sky_bad_inputs(tmp_cwd, nircam_rate):
@pytest.mark.parametrize("subtract", (False, True))
def test_user_skyfile(tmp_cwd, nircam_rate, subtract):

# give them all different suffixes to ensure they get stripped properly
im1 = nircam_rate.copy()
im1.meta.filename = "one.fits"
im1.meta.filename = "one_tweakregstep.fits"
im2 = im1.copy()
im2.meta.filename = "two.fits"
im2.meta.filename = "two_unknown.fits"
im3 = im1.copy()
im3.meta.filename = "three.fits"


# give filenames in skyfile same stems but different suffix
fnames_skyfile = ["one_cal.fits", "two_unknown_cal.fits", "three_cal.fits"]

container = [im1, im2, im3]

# define some background:
# put levels into the skylist file for when skylist='user'
levels = [9.12, 8.28, 2.56]
fnames = [model.meta.filename for model in container]

for im, lev in zip(container, levels):
im.data += lev

skyfile = "skylist.txt"
with open(skyfile, "w") as f:
for fname, lev in zip(fnames_skyfile, levels):
f.write(f"{fname} {lev}\n")

#test good inputs
result = SkyMatchStep.call(
container,
subtract=subtract,
skymethod="user",
skylist=skyfile,
)

ref_levels = levels
sub_levels = np.subtract(levels, ref_levels)

with result:
for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
# check that meta was set correctly:
assert im.meta.background.method == "user"
assert im.meta.background.subtracted == subtract

# test computed/measured sky values:
assert abs(im.meta.background.level - rlev) < 0.01

# test
if subtract:
assert abs(np.mean(im.data) - slev) < 0.01
else:
assert abs(np.mean(im.data) - lev) < 0.01
result.shelve(im, modify=False)


# test failures
# no skylist file
with pytest.raises(ValueError):
# skylist must be provided
SkyMatchStep.call(
container,
skymethod='user',
)

# test skylist file doesn't have right number of lines
# skylist file doesn't have right number of lines
skyfile = "skylist_short.txt"
with open(skyfile, "w") as f:
for fname, lev in zip(fnames[1:], levels[1:]):
Expand All @@ -602,7 +627,7 @@ def test_user_sky_bad_inputs(tmp_cwd, nircam_rate):
skylist=skyfile
)

# test skylist file does not contain all filenames
# skylist file does not contain all filenames
skyfile = "skylist_missing.txt"
fnames_wrong = ["two.fits"] + fnames[1:]
with open(skyfile, "w") as f:
Expand Down

0 comments on commit 5c74ff3

Please sign in to comment.