Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix baselines utils pyre fix me issues #1478

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
ModuleOrModuleList,
SliceIntType,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
Expand Down Expand Up @@ -775,8 +776,11 @@ def compute_layer_gradients_and_eval(

def construct_neuron_grad_fn(
layer: Module,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
device_ids: Union[None, List[int]] = None,
attribute_to_neuron_input: bool = False,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
Expand Down
7 changes: 7 additions & 0 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
TensorLikeList5D,
]

try:
# Subscripted slice syntax is not supported in previous Python versions,
# falling back to slice type.
SliceIntType = slice[int, int, int]
except TypeError:
# pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
SliceIntType = slice # type: ignore

# Necessary for Python >=3.7 and <3.9!
if TYPE_CHECKING:
Expand Down
22 changes: 3 additions & 19 deletions captum/attr/_core/layer/grad_cam.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -54,8 +54,7 @@ class LayerGradCam(LayerAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
) -> None:
Expand Down Expand Up @@ -201,7 +200,7 @@ def attribute(
# hidden layer and hidden layer evaluated at each input.
layer_gradients, layer_evals = compute_layer_gradients_and_eval(
self.forward_func,
self.layer,
cast(Module, self.layer),
inputs,
target,
additional_forward_args,
Expand All @@ -213,10 +212,7 @@ def attribute(
summed_grads = tuple(
(
torch.mean(
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Tuple[Tensor, ...]`.
layer_grad,
# pyre-fixme[16]: `tuple` has no attribute `shape`.
dim=tuple(x for x in range(2, len(layer_grad.shape))),
keepdim=True,
)
Expand All @@ -228,27 +224,15 @@ def attribute(

if attr_dim_summation:
scaled_acts = tuple(
# pyre-fixme[58]: `*` is not supported for operand types
# `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
# `Tuple[Tensor, ...]`.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Tuple[Tensor, ...]`.
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
)
else:
scaled_acts = tuple(
# pyre-fixme[58]: `*` is not supported for operand types
# `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
# `Tuple[Tensor, ...]`.
summed_grad * layer_eval
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
)

if relu_attributions:
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[tuple[Tensor], Tensor]`.
scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
# `Tuple[Union[tuple[Tensor], Tensor], ...]`.
return _format_output(len(scaled_acts) > 1, scaled_acts)
12 changes: 4 additions & 8 deletions captum/attr/_core/layer/internal_influence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -41,8 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
) -> None:
Expand Down Expand Up @@ -293,7 +292,7 @@ def _attribute(
# Returns gradient of output with respect to hidden layer.
layer_gradients, _ = compute_layer_gradients_and_eval(
forward_fn=self.forward_func,
layer=self.layer,
layer=cast(Module, self.layer),
inputs=scaled_features_tpl,
target_ind=expanded_target,
additional_forward_args=input_additional_args,
Expand All @@ -304,9 +303,7 @@ def _attribute(
# flattening grads so that we can multiply it with step-size
# calling contiguous to avoid `memory whole` problems
scaled_grads = tuple(
# pyre-fixme[16]: `tuple` has no attribute `contiguous`.
layer_grad.contiguous().view(n_steps, -1)
# pyre-fixme[16]: `tuple` has no attribute `device`.
* torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
for layer_grad in layer_gradients
)
Expand All @@ -317,8 +314,7 @@ def _attribute(
scaled_grad,
n_steps,
inputs[0].shape[0],
# pyre-fixme[16]: `tuple` has no attribute `shape`.
layer_grad.shape[1:],
tuple(layer_grad.shape[1:]),
)
for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
)
Expand Down
5 changes: 1 addition & 4 deletions captum/attr/_core/layer/layer_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class LayerActivation(LayerAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Union[int, float, Tensor]],
layer: ModuleOrModuleList,
device_ids: Union[None, List[int]] = None,
) -> None:
Expand Down Expand Up @@ -132,8 +131,6 @@ def attribute(
)
else:
return [
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
# got `Tensor`.
_format_output(len(single_layer_eval) > 1, single_layer_eval)
for single_layer_eval in layer_eval
]
Expand Down
16 changes: 4 additions & 12 deletions captum/attr/_core/layer/layer_conductance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -44,8 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Tensor],
layer: Module,
device_ids: Union[None, List[int]] = None,
) -> None:
Expand Down Expand Up @@ -73,8 +72,6 @@ def has_convergence_delta(self) -> bool:
return True

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `75`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -91,8 +88,6 @@ def attribute(
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `91`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand Down Expand Up @@ -376,7 +371,7 @@ def _attribute(
layer_evals,
) = compute_layer_gradients_and_eval(
forward_fn=self.forward_func,
layer=self.layer,
layer=cast(Module, self.layer),
inputs=scaled_features_tpl,
additional_forward_args=input_additional_args,
target_ind=expanded_target,
Expand All @@ -389,8 +384,6 @@ def _attribute(
# This approximates the total input gradient of each step multiplied
# by the step size.
grad_diffs = tuple(
# pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor,
# ...]` and `Tuple[Tensor, ...]`.
layer_eval[num_examples:] - layer_eval[:-num_examples]
for layer_eval in layer_evals
)
Expand All @@ -403,8 +396,7 @@ def _attribute(
grad_diff * layer_gradient[:-num_examples],
n_steps,
num_examples,
# pyre-fixme[16]: `tuple` has no attribute `shape`.
layer_eval.shape[1:],
tuple(layer_eval.shape[1:]),
)
for layer_gradient, layer_eval, grad_diff in zip(
layer_gradients, layer_evals, grad_diffs
Expand Down
33 changes: 18 additions & 15 deletions captum/attr/_core/layer/layer_deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ def attribute(
additional_forward_args,
)

# pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter.
def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
def chunk_output_fn(
out: TensorOrTupleOfTensorsGeneric,
) -> Sequence[Union[Tensor, Sequence[Tensor]]]:
if isinstance(out, Tensor):
return out.chunk(2)
return tuple(out_sub.chunk(2) for out_sub in out)
Expand Down Expand Up @@ -434,8 +435,6 @@ def __init__(

# Ignoring mypy error for inconsistent signature with DeepLiftShap
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `453`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -450,9 +449,7 @@ def attribute(
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `439`.
@typing.overload # type: ignore
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand Down Expand Up @@ -654,7 +651,7 @@ def attribute(
) = DeepLiftShap._expand_inputs_baselines_targets(
self, baselines, inputs, target, additional_forward_args
)
attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore
attribs_layer_deeplift = LayerDeepLift.attribute.__wrapped__( # type: ignore
self,
exp_inp,
exp_base,
Expand All @@ -667,8 +664,12 @@ def attribute(
attribute_to_layer_input=attribute_to_layer_input,
custom_attribution_func=custom_attribution_func,
)
delta: Tensor
attributions: Union[Tensor, Tuple[Tensor, ...]]
if return_convergence_delta:
attributions, delta = attributions
attributions, delta = attribs_layer_deeplift
else:
attributions = attribs_layer_deeplift
if isinstance(attributions, tuple):
attributions = tuple(
DeepLiftShap._compute_mean_across_baselines(
Expand All @@ -681,15 +682,17 @@ def attribute(
self, inp_bsz, base_bsz, attributions
)
if return_convergence_delta:
# pyre-fixme[61]: `delta` is undefined, or not always defined.
return attributions, delta
else:
# pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor,
# typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]`
# but got `Union[tuple[Tensor], Tensor]`.
return attributions
return cast(
Union[
Tensor,
Tuple[Tensor, ...],
Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor],
],
attributions,
)

@property
# pyre-fixme[3]: Return type must be annotated.
def multiplies_by_inputs(self) -> bool:
return self._multiply_by_inputs
Loading
Loading