Skip to content

Commit

Permalink
add options for extra loops and the cmap value
Browse files Browse the repository at this point in the history
  • Loading branch information
ieee8023 committed Jan 3, 2023
1 parent bd87691 commit 7c16b42
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion captum/attr/_core/latent_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def generate_video(
temp_path: str = "/tmp/gifsplanation",
show: bool = True,
verbose: bool = True,
extra_loops: int = 0,
cmap: str = None,
):
"""Generate a video from the generated images.
Expand All @@ -260,6 +262,10 @@ def generate_video(
temp_path: A temp path to write images.
show: To try and show the video in a jupyter notebook.
verbose: True to print debug text
extra_loops: The video does one loop by default. This will repeat
those loops to make it easier to watch.
cmap: The cmap value passed to matplotlib. e.g. 'gray' for a
grayscale image.
Returns:
The filename of the video if show=False, otherwise it will
Expand All @@ -277,12 +283,16 @@ def generate_video(
# Add reversed so we have an animation cycle
towrite = list(reversed(imgs)) + list(imgs)
ys = list(reversed(params["preds"])) + list(params["preds"])

for n in range(extra_loops):
towrite += towrite
ys += ys

for idx, img in enumerate(towrite):

px = 1 / plt.rcParams["figure.dpi"]
full_frame(img[0].shape[0] * px, img[0].shape[1] * px)
plt.imshow(img[0], interpolation="none")
plt.imshow(img[0], interpolation="none", cmap=cmap)

if watermark:
# Write prob output in upper left
Expand Down

0 comments on commit 7c16b42

Please sign in to comment.