summed_grads = tuple(
(
torch.mean(
- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
- # `Tuple[Tensor, ...]`.
layer_grad,
- # pyre-fixme[16]: `tuple` has no attribute `shape`.
dim=tuple(x for x in range(2, len(layer_grad.shape))),
keepdim=True,
)
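For reference, the hunk above computes the GradCAM channel weights by averaging each layer gradient over its spatial dimensions. A minimal standalone sketch of that step, assuming a 4-D NCHW gradient tensor (the `layer_grad` name comes from the diff; the shape is illustrative):

```python
import torch

# Hypothetical gradient of the target output w.r.t. one conv layer's
# activations, shaped (batch, channels, height, width) -- an assumption.
layer_grad = torch.randn(2, 8, 7, 7)

# Average over every dimension after batch and channel, keeping dims so the
# result can broadcast against the layer activations later.
summed_grad = torch.mean(
    layer_grad,
    dim=tuple(range(2, len(layer_grad.shape))),
    keepdim=True,
)  # shape: (2, 8, 1, 1)
```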
@@ -264,29 +260,17 @@ Source code for captum.attr._core.layer.grad_cam
if attr_dim_summation:
scaled_acts = tuple(
- # pyre-fixme[58]: `*` is not supported for operand types
- # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
- # `Tuple[Tensor, ...]`.
- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
- # `Tuple[Tensor, ...]`.
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
)
else:
scaled_acts = tuple(
- # pyre-fixme[58]: `*` is not supported for operand types
- # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
- # `Tuple[Tensor, ...]`.
summed_grad * layer_eval
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
)
if relu_attributions:
- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
- # `Union[tuple[Tensor], Tensor]`.
scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
- # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
- # `Tuple[Union[tuple[Tensor], Tensor], ...]`.
return _format_output(len(scaled_acts) > 1, scaled_acts)
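Taken together, the hunks above weight each activation map by its averaged gradient, optionally collapse the channel dimension, and optionally pass the result through a ReLU. A hedged, self-contained sketch of that flow; `layer_eval`, `attr_dim_summation`, and `relu_attributions` mirror the names in the diff, while the tensor shapes and values are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical layer gradients and activations, shaped (batch, channels, H, W).
layer_grad = torch.randn(2, 8, 7, 7)
layer_eval = torch.randn(2, 8, 7, 7)
attr_dim_summation = True
relu_attributions = True

# Channel weights: gradients averaged over the spatial dimensions.
summed_grad = torch.mean(
    layer_grad, dim=tuple(range(2, len(layer_grad.shape))), keepdim=True
)

# Weight each activation map by its averaged gradient.
scaled_act = summed_grad * layer_eval
if attr_dim_summation:
    # Collapse the channel dimension into a single attribution map.
    scaled_act = torch.sum(scaled_act, dim=1, keepdim=True)
if relu_attributions:
    # Keep only positively contributing regions, as in the GradCAM formulation.
    scaled_act = F.relu(scaled_act)

print(scaled_act.shape)  # torch.Size([2, 1, 7, 7]) with these settings
```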
diff --git a/api/_modules/captum/attr/_core/layer/grad_cam/index.html b/api/_modules/captum/attr/_core/layer/grad_cam/index.html
index 850fb67f06..61914e1279 100644
--- a/api/_modules/captum/attr/_core/layer/grad_cam/index.html
+++ b/api/_modules/captum/attr/_core/layer/grad_cam/index.html
@@ -33,7 +33,7 @@