Source code for captum.attr._core.layer.layer_feature_permutation
#!/usr/bin/env python3
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, cast, List, Tuple, Union
import torch
from captum._utils.common import (
@@ -233,7 +233,11 @@ Source code for captum.attr._core.layer.layer_feature_permutation
finally:
if hook is not None:
hook.remove()
- return eval
+
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ return cast(Tensor, eval)
with torch.no_grad():
inputs = _format_tensor_into_tuples(inputs)
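The hunks above (and the matching ones further down) all follow the same pattern: `_run_forward` is annotated as possibly returning a Future of a Tensor, but these call sites are synchronous, so a `typing.cast` is added to narrow the static type for mypy without changing runtime behaviour. A minimal sketch of the pattern, assuming a simplified `Union[Tensor, Future]` return type rather than Captum's actual signature:

    # Sketch only: the Union return annotation below is an assumption for illustration.
    from typing import Union, cast

    import torch
    from torch import Tensor
    from torch.futures import Future


    def run_forward_stub(x: Tensor) -> Union[Tensor, "Future[Tensor]"]:
        # Synchronous path: always returns a concrete Tensor here.
        return x.sum()


    out = run_forward_stub(torch.ones(3))
    # cast() is a no-op at runtime; it only narrows the static type so that
    # Tensor-only operations below type-check.
    out_t = cast(Tensor, out)
    print(out_t.item())  # 3.0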
diff --git a/api/_modules/captum/attr/_core/layer/layer_feature_permutation/index.html b/api/_modules/captum/attr/_core/layer/layer_feature_permutation/index.html
index 4b31b821c..bae1a63b0 100644
--- a/api/_modules/captum/attr/_core/layer/layer_feature_permutation/index.html
+++ b/api/_modules/captum/attr/_core/layer/layer_feature_permutation/index.html
@@ -31,7 +31,7 @@ Source code for captum.attr._core.layer.layer_feature_permutation
#!/usr/bin/env python3
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, cast, List, Tuple, Union
import torch
from captum._utils.common import (
@@ -233,7 +233,11 @@ Source code for captum.attr._core.layer.layer_feature_permutation
finally:
if hook is not None:
hook.remove()
- return eval
+
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ return cast(Tensor, eval)
with torch.no_grad():
inputs = _format_tensor_into_tuples(inputs)
diff --git a/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html b/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html
index 20589d6c8..bdea47362 100644
--- a/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html
+++ b/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html
@@ -33,7 +33,7 @@ Source code for captum.attr._core.layer.layer_integrated_gradients
#!/usr/bin/env python3
import functools
import warnings
-from typing import Any, Callable, List, overload, Tuple, Union
+from typing import Any, Callable, cast, List, overload, Tuple, Union
import torch
from captum._utils.common import (
@@ -136,7 +136,8 @@ Source code for captum.attr._core.layer.layer_integrated_gradients
"Multiple layers provided. Please ensure that each layer is"
"**not** solely dependent on the outputs of"
"another layer. Please refer to the documentation for more"
- "detail."
+ "detail.",
+ stacklevel=2,
)
@overload
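The added `stacklevel=2` makes the warning point at the code that called into this constructor rather than at the `warnings.warn` line itself. A small self-contained illustration of that effect (the function name is made up for the example):

    import warnings


    def check_layers(layers):
        if len(layers) > 1:
            # stacklevel=2 attributes the warning to the caller of
            # check_layers(), which is usually the more useful location.
            warnings.warn("Multiple layers provided. ...", stacklevel=2)


    check_layers(["conv1", "conv2"])  # the warning is reported at this line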
@@ -503,13 +504,17 @@ Source code for captum.attr._core.layer.layer_integrated_gradients
# the inputs is an empty tuple
# coz it is prepended into additional_forward_args
output = _run_forward(
- self.forward_func, tuple(), target_ind, additional_forward_args
+ self.forward_func, (), target_ind, additional_forward_args
)
finally:
for hook in hooks:
if hook is not None:
hook.remove()
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ output = cast(Tensor, output)
assert output[0].numel() == 1, (
"Target not provided when necessary, cannot"
" take gradient with respect to multiple outputs."
diff --git a/api/_modules/captum/attr/_core/layer/layer_integrated_gradients/index.html b/api/_modules/captum/attr/_core/layer/layer_integrated_gradients/index.html
index 20589d6c8..bdea47362 100644
--- a/api/_modules/captum/attr/_core/layer/layer_integrated_gradients/index.html
+++ b/api/_modules/captum/attr/_core/layer/layer_integrated_gradients/index.html
@@ -33,7 +33,7 @@ Source code for captum.attr._core.layer.layer_integrated_gradients
#!/usr/bin/env python3
import functools
import warnings
-from typing import Any, Callable, List, overload, Tuple, Union
+from typing import Any, Callable, cast, List, overload, Tuple, Union
import torch
from captum._utils.common import (
@@ -136,7 +136,8 @@ Source code for captum.attr._core.layer.layer_integrated_gradients
"Multiple layers provided. Please ensure that each layer is"
"**not** solely dependent on the outputs of"
"another layer. Please refer to the documentation for more"
- "detail."
+ "detail.",
+ stacklevel=2,
)
@overload
@@ -503,13 +504,17 @@ Source code for captum.attr._core.layer.layer_integrated_gradients
# the inputs is an empty tuple
# coz it is prepended into additional_forward_args
output = _run_forward(
- self.forward_func, tuple(), target_ind, additional_forward_args
+ self.forward_func, (), target_ind, additional_forward_args
)
finally:
for hook in hooks:
if hook is not None:
hook.remove()
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ output = cast(Tensor, output)
assert output[0].numel() == 1, (
"Target not provided when necessary, cannot"
" take gradient with respect to multiple outputs."
diff --git a/api/_modules/captum/attr/_core/lrp.html b/api/_modules/captum/attr/_core/lrp.html
index 700bba376..81dcf2ae9 100644
--- a/api/_modules/captum/attr/_core/lrp.html
+++ b/api/_modules/captum/attr/_core/lrp.html
@@ -401,7 +401,11 @@ Source code for captum.attr._core.lrp
# adjustments as inputs to the layers with adjusted weights. This procedure
# is important for graph generation in the 2nd forward pass.
self._register_pre_hooks()
- return output
+
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ return cast(Tensor, output)
def _remove_forward_hooks(self) -> None:
for forward_handle in self.forward_handles:
diff --git a/api/_modules/captum/attr/_core/lrp/index.html b/api/_modules/captum/attr/_core/lrp/index.html
index 700bba376..81dcf2ae9 100644
--- a/api/_modules/captum/attr/_core/lrp/index.html
+++ b/api/_modules/captum/attr/_core/lrp/index.html
@@ -401,7 +401,11 @@ Source code for captum.attr._core.lrp
# adjustments as inputs to the layers with adjusted weights. This procedure
# is important for graph generation in the 2nd forward pass.
self._register_pre_hooks()
- return output
+
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ return cast(Tensor, output)
def _remove_forward_hooks(self) -> None:
for forward_handle in self.forward_handles:
diff --git a/api/_modules/captum/attr/_core/shapley_value.html b/api/_modules/captum/attr/_core/shapley_value.html
index dad7d28e5..c9a9c96ad 100644
--- a/api/_modules/captum/attr/_core/shapley_value.html
+++ b/api/_modules/captum/attr/_core/shapley_value.html
@@ -35,7 +35,7 @@ Source code for captum.attr._core.shapley_value
import itertools
import math
import warnings
-from typing import Any, Callable, Iterable, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union
import torch
from captum._utils.common import (
@@ -59,7 +59,7 @@ Source code for captum.attr._core.shapley_value
_tensorize_baseline,
)
from captum.log import log_usage
-from torch import Tensor
+from torch import dtype, Tensor
def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
@@ -588,7 +588,7 @@ Source code for captum.attr._core.shapley_value
# using python built-in type as torch dtype
# int -> torch.int64, float -> torch.float64
# ref: https://github.com/pytorch/pytorch/pull/21215
- return torch.tensor([forward_output], dtype=output_type)
+ return torch.tensor([forward_output], dtype=cast(dtype, output_type))
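The `cast(dtype, output_type)` change works around a typing mismatch: at runtime `torch.tensor` accepts Python built-in types such as `int` and `float` where a dtype is expected (int maps to torch.int64, float to torch.float64, per the referenced PR), but the stubs only allow `torch.dtype`, so the cast silences mypy without altering behaviour. A minimal sketch of the same idea:

    from typing import cast

    import torch
    from torch import dtype

    output_type = type(3)  # the Python type of a scalar forward output, here int
    t = torch.tensor([3], dtype=cast(dtype, output_type))
    print(t.dtype)  # torch.int64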
diff --git a/api/_modules/captum/attr/_core/shapley_value/index.html b/api/_modules/captum/attr/_core/shapley_value/index.html
index dad7d28e5..c9a9c96ad 100644
--- a/api/_modules/captum/attr/_core/shapley_value/index.html
+++ b/api/_modules/captum/attr/_core/shapley_value/index.html
@@ -35,7 +35,7 @@ Source code for captum.attr._core.shapley_value
import itertools
import math
import warnings
-from typing import Any, Callable, Iterable, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union
import torch
from captum._utils.common import (
@@ -59,7 +59,7 @@ Source code for captum.attr._core.shapley_value
_tensorize_baseline,
)
from captum.log import log_usage
-from torch import Tensor
+from torch import dtype, Tensor
def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
@@ -588,7 +588,7 @@ Source code for captum.attr._core.shapley_value
# using python built-in type as torch dtype
# int -> torch.int64, float -> torch.float64
# ref: https://github.com/pytorch/pytorch/pull/21215
- return torch.tensor([forward_output], dtype=output_type)
+ return torch.tensor([forward_output], dtype=cast(dtype, output_type))
diff --git a/api/_modules/captum/attr/_utils/attribution.html b/api/_modules/captum/attr/_utils/attribution.html
index 2c5848ea1..2f2c71702 100644
--- a/api/_modules/captum/attr/_utils/attribution.html
+++ b/api/_modules/captum/attr/_utils/attribution.html
@@ -321,17 +321,22 @@ Source code for captum.attr._utils.attribution
_validate_target(num_samples, target)
with torch.no_grad():
- start_out_sum = _sum_rows(
- _run_forward(
- self.forward_func, start_point, target, additional_forward_args
- )
+ start_out_eval = _run_forward(
+ self.forward_func, start_point, target, additional_forward_args
)
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ start_out_sum = _sum_rows(cast(Tensor, start_out_eval))
- end_out_sum = _sum_rows(
- _run_forward(
- self.forward_func, end_point, target, additional_forward_args
- )
+ end_out_eval = _run_forward(
+ self.forward_func, end_point, target, additional_forward_args
)
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ end_out_sum = _sum_rows(cast(Tensor, end_out_eval))
+
row_sums = [_sum_rows(attribution) for attribution in attributions]
attr_sum = torch.stack(
[cast(Tensor, sum(row_sum)) for row_sum in zip(*row_sums)]
diff --git a/api/_modules/captum/attr/_utils/attribution/index.html b/api/_modules/captum/attr/_utils/attribution/index.html
index 2c5848ea1..2f2c71702 100644
--- a/api/_modules/captum/attr/_utils/attribution/index.html
+++ b/api/_modules/captum/attr/_utils/attribution/index.html
@@ -321,17 +321,22 @@ Source code for captum.attr._utils.attribution
_validate_target(num_samples, target)
with torch.no_grad():
- start_out_sum = _sum_rows(
- _run_forward(
- self.forward_func, start_point, target, additional_forward_args
- )
+ start_out_eval = _run_forward(
+ self.forward_func, start_point, target, additional_forward_args
)
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ start_out_sum = _sum_rows(cast(Tensor, start_out_eval))
- end_out_sum = _sum_rows(
- _run_forward(
- self.forward_func, end_point, target, additional_forward_args
- )
+ end_out_eval = _run_forward(
+ self.forward_func, end_point, target, additional_forward_args
)
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ end_out_sum = _sum_rows(cast(Tensor, end_out_eval))
+
row_sums = [_sum_rows(attribution) for attribution in attributions]
attr_sum = torch.stack(
[cast(Tensor, sum(row_sum)) for row_sum in zip(*row_sums)]
diff --git a/api/_modules/captum/metrics/_core/infidelity.html b/api/_modules/captum/metrics/_core/infidelity.html
index fee298954..b7b19d44f 100644
--- a/api/_modules/captum/metrics/_core/infidelity.html
+++ b/api/_modules/captum/metrics/_core/infidelity.html
@@ -530,6 +530,10 @@ Source code for captum.metrics._core.infidelity
additional_forward_args_expanded,
)
inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args)
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ inputs_fwd = cast(Tensor, inputs_fwd)
inputs_fwd = torch.repeat_interleave(
inputs_fwd, current_n_perturb_samples, dim=0
)
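The cast above is applied before `torch.repeat_interleave`, which repeats each forward output `current_n_perturb_samples` times along the batch dimension so that it lines up one-to-one with the perturbed inputs. A short, self-contained demonstration of that repeat behaviour:

    import torch

    outputs = torch.tensor([[1.0], [2.0]])   # forward outputs for a batch of 2
    expanded = torch.repeat_interleave(outputs, 3, dim=0)
    print(expanded.shape)                    # torch.Size([6, 1])
    print(expanded.squeeze(1).tolist())      # [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]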
diff --git a/api/_modules/captum/metrics/_core/infidelity/index.html b/api/_modules/captum/metrics/_core/infidelity/index.html
index fee298954..b7b19d44f 100644
--- a/api/_modules/captum/metrics/_core/infidelity/index.html
+++ b/api/_modules/captum/metrics/_core/infidelity/index.html
@@ -530,6 +530,10 @@ Source code for captum.metrics._core.infidelity
additional_forward_args_expanded,
)
inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args)
+ # _run_forward may return future of Tensor,
+ # but we don't support it here now
+ # And it will fail before here.
+ inputs_fwd = cast(Tensor, inputs_fwd)
inputs_fwd = torch.repeat_interleave(
inputs_fwd, current_n_perturb_samples, dim=0
)
diff --git a/tutorials/CIFAR_TorchVision_Captum_Insights.html b/tutorials/CIFAR_TorchVision_Captum_Insights.html
index 4c943ecf5..8cca8fdc2 100644
--- a/tutorials/CIFAR_TorchVision_Captum_Insights.html
+++ b/tutorials/CIFAR_TorchVision_Captum_Insights.html
@@ -234,10 +234,10 @@
-
+