diff --git a/api/_modules/captum/attr/_core/feature_ablation.html b/api/_modules/captum/attr/_core/feature_ablation.html
index 6864961bb..73d90c2b7 100644
--- a/api/_modules/captum/attr/_core/feature_ablation.html
+++ b/api/_modules/captum/attr/_core/feature_ablation.html
@@ -33,7 +33,7 @@
 #!/usr/bin/env python3
 import math
-from typing import Any, Callable, cast, Tuple, Union
+from typing import Any, Callable, cast, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -51,6 +51,7 @@
 from captum.attr._utils.common import _format_input_baseline
 from captum.log import log_usage
 from torch import dtype, Tensor
+from torch.futures import collect_all, Future
@@ -96,6 +97,7 @@
         # input grow as expected. Once it turns to True, we will assume the model's
         # behavior stays consistent and no longer check again
         self._is_output_shape_valid = False
+        self.use_futures = False
@@ -110,7 +112,7 @@
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
         **kwargs: Any,
-    ) -> TensorOrTupleOfTensorsGeneric:
+    ) -> Union[TensorOrTupleOfTensorsGeneric, Future[TensorOrTupleOfTensorsGeneric]]:
         r"""
         Args:
@@ -322,42 +324,55 @@
         # Computes initial evaluation with all features, which is compared
         # to each ablated result.
-        initial_eval = self._strict_run_forward(
+        initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
             self.forward_func, inputs, target, additional_forward_args
         )
 
         if show_progress:
             attr_progress.update()
 
-        # number of elements in the output of forward_func
-        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
-
-        # flatten eval outputs into 1D (n_outputs)
-        # add the leading dim for n_feature_perturbed
-        flattened_initial_eval = initial_eval.reshape(1, -1)
-
-        # Initialize attribution totals and counts
-        attrib_type = cast(dtype, flattened_initial_eval.dtype)
-
-        total_attrib = [
-            # attribute w.r.t each output element
-            torch.zeros(
-                (n_outputs,) + input.shape[1:],
-                dtype=attrib_type,
-                device=input.device,
-            )
-            for input in inputs
-        ]
-
-        # Weights are used in cases where ablations may be overlapping.
-        if self.use_weights:
-            weights = [
-                torch.zeros(
-                    (n_outputs,) + input.shape[1:], device=input.device
-                ).float()
-                for input in inputs
-            ]
+        processed_initial_eval_fut: Optional[
+            Future[Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]]
+        ] = None
+        total_attrib: List[Tensor] = []
+        weights: List[Tensor] = []
+        flattened_initial_eval: Tensor
+        n_outputs: int
+        attrib_type: dtype
+
+        if self.use_futures:
+            assert isinstance(initial_eval, torch.Future), (
+                "when use_futures is True, initial_eval should have "
+                f"Future type rather than {type(initial_eval)}"
+            )
+            processed_initial_eval_fut = initial_eval.then(
+                lambda x: self._process_initial_eval(
+                    x.value(),
+                    inputs,
+                )
+            )
+        else:
+            assert not isinstance(initial_eval, torch.Future), (
+                "when use_futures is False, initial_eval should have "
+                f"non-Future type rather than {type(initial_eval)}"
+            )
+            (
+                total_attrib,
+                weights,
+                initial_eval,
+                flattened_initial_eval,
+                n_outputs,
+                attrib_type,
+            ) = self._process_initial_eval(
+                initial_eval,
+                inputs,
+            )
+
+        # There will be as many futures as there are modified_eval calls below,
+        # since we cannot add up the evaluation results ad hoc in async mode.
+        all_futures: List[List[Future]] = [[] for _ in range(len(inputs))]
 
         # Iterate through each feature tensor for ablation
         for i in range(len(inputs)):
             # Skip any empty input tensors
@@ -384,7 +399,7 @@
             # agg mode: (*initial_eval.shape)
             # non-agg mode:
             # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-            modified_eval = self._strict_run_forward(
+            modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
                 self.forward_func,
                 current_inputs,
                 current_target,
@@ -394,71 +409,75 @@
             if show_progress:
                 attr_progress.update()
 
-            # if perturbations_per_eval > 1, the output shape must grow with
-            # input and not be aggregated
-            if perturbations_per_eval > 1 and not self._is_output_shape_valid:
-                current_batch_size = current_inputs[0].shape[0]
-
-                # number of perturbation, which is not the same as
-                # perturbations_per_eval when not enough features to perturb
-                n_perturb = current_batch_size / num_examples
-
-                current_output_shape = modified_eval.shape
-
-                # use initial_eval as the forward of perturbations_per_eval = 1
-                initial_output_shape = initial_eval.shape
-
-                assert (
-                    # check if the output is not a scalar
-                    current_output_shape
-                    and initial_output_shape
-                    # check if the output grow in same ratio, i.e., not agg
-                    and current_output_shape[0]
-                    == n_perturb * initial_output_shape[0]
-                ), (
-                    "When perturbations_per_eval > 1, forward_func's output "
-                    "should be a tensor whose 1st dim grow with the input "
-                    f"batch size: when input batch size is {num_examples}, "
-                    f"the output shape is {initial_output_shape}; "
-                    f"when input batch size is {current_batch_size}, "
-                    f"the output shape is {current_output_shape}"
-                )
-
-                self._is_output_shape_valid = True
-
-            # reshape the leading dim for n_feature_perturbed
-            # flatten each feature's eval outputs into 1D of (n_outputs)
-            modified_eval = modified_eval.reshape(-1, n_outputs)
-            # eval_diff in shape (n_feature_perturbed, n_outputs)
-            eval_diff = flattened_initial_eval - modified_eval
-
-            # append the shape of one input example
-            # to make it broadcastable to mask
-            eval_diff = eval_diff.reshape(
-                eval_diff.shape + (inputs[i].dim() - 1) * (1,)
-            )
-            eval_diff = eval_diff.to(total_attrib[i].device)
-
-            if self.use_weights:
-                weights[i] += current_mask.float().sum(dim=0)
-
-            total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
-                dim=0
-            )
+            if self.use_futures:
+                assert isinstance(modified_eval, torch.Future), (
+                    "when use_futures is True, modified_eval should have "
+                    f"Future type rather than {type(modified_eval)}"
+                )
+                assert (
+                    processed_initial_eval_fut is not None
+                ), "processed_initial_eval_fut should not be None"
+
+                # Need to collect both initial_eval and modified_eval
+                eval_futs: Future[List[Future[Tensor]]] = collect_all(
+                    [
+                        processed_initial_eval_fut,
+                        modified_eval,
+                    ]
+                )
+
+                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
+                    eval_futs.then(
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._process_ablated_out(  # type: ignore # noqa: E501 line too long
+                            eval_futs.value()[1].value(),
+                            current_inputs,
+                            current_mask,
+                            perturbations_per_eval,
+                            num_examples,
+                            # initial_eval
+                            eval_futs.value()[0].value()[2],
+                            # flattened_initial_eval
+                            eval_futs.value()[0].value()[3],
+                            inputs,
+                            # n_outputs
+                            eval_futs.value()[0].value()[4],
+                            # total_attrib
+                            eval_futs.value()[0].value()[0],
+                            # weights
+                            eval_futs.value()[0].value()[1],
+                            i,
+                            # attrib_type
+                            eval_futs.value()[0].value()[5],
+                        )
+                    )
+                )
+
+                all_futures[i].append(ablated_out_fut)
+            else:
+                total_attrib, weights = self._process_ablated_out(
+                    modified_eval,
+                    current_inputs,
+                    current_mask,
+                    perturbations_per_eval,
+                    num_examples,
+                    initial_eval,
+                    flattened_initial_eval,
+                    inputs,
+                    n_outputs,
+                    total_attrib,
+                    weights,
+                    i,
+                    attrib_type,
+                )
 
         if show_progress:
             attr_progress.close()
 
-        # Divide total attributions by counts and return formatted attributions
-        if self.use_weights:
-            attrib = tuple(
-                single_attrib.float() / weight
-                for single_attrib, weight in zip(total_attrib, weights)
-            )
-        else:
-            attrib = tuple(total_attrib)
-        _result = _format_output(is_inputs_tuple, attrib)
-        return _result
+        if len(all_futures) > 0 and len(all_futures[0]) > 0:
+            return self._generate_async_result(all_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+        else:
+            return self._generate_result(total_attrib, weights, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
 
     def _ith_input_ablation_generator(
@@ -630,13 +649,12 @@
             for inp, mask in zip(inputs, feature_mask)
         )
 
-    def _strict_run_forward(self, *args, **kwargs) -> Tensor:
+    def _parse_forward_out(self, forward_output) -> Tensor:
         """
         A temp wrapper for global _run_forward util to force forward output
         type assertion & conversion.
         Remove after the strict logic is supported by all attr classes
         """
-        forward_output = _run_forward(*args, **kwargs)
         if isinstance(forward_output, Tensor):
             return forward_output
@@ -649,7 +667,177 @@
         # using python built-in type as torch dtype
         # int -> torch.int64, float -> torch.float64
         # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor(forward_output, dtype=output_type)
+        return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+
+    def _process_initial_eval(
+        self,
+        initial_eval: Tensor,
+        inputs: TensorOrTupleOfTensorsGeneric,
+    ) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
+        initial_eval = self._parse_forward_out(initial_eval)
+
+        # number of elements in the output of forward_func
+        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+
+        # flatten eval outputs into 1D (n_outputs)
+        # add the leading dim for n_feature_perturbed
+        flattened_initial_eval = initial_eval.reshape(1, -1)
+
+        # Initialize attribution totals and counts
+        attrib_type = flattened_initial_eval.dtype
+
+        total_attrib = [
+            # attribute w.r.t each output element
+            torch.zeros(
+                (n_outputs,) + input.shape[1:],
+                dtype=attrib_type,
+                device=input.device,
+            )
+            for input in inputs
+        ]
+
+        # Weights are used in cases where ablations may be overlapping.
+        weights = []
+        if self.use_weights:
+            weights = [
+                torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
+                for input in inputs
+            ]
+
+        return (
+            total_attrib,
+            weights,
+            initial_eval,
+            flattened_initial_eval,
+            n_outputs,
+            attrib_type,
+        )
+
+    def _process_ablated_out(
+        self,
+        modified_eval,
+        current_inputs,
+        current_mask,
+        perturbations_per_eval,
+        num_examples,
+        initial_eval,
+        flattened_initial_eval,
+        inputs,
+        n_outputs,
+        total_attrib,
+        weights,
+        i,
+        attrib_type,
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+        modified_eval = self._parse_forward_out(modified_eval)
+
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # input and not be aggregated
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+            current_batch_size = current_inputs[0].shape[0]
+
+            # number of perturbations, which is not the same as
+            # perturbations_per_eval when there are not enough features to perturb
+            n_perturb = current_batch_size / num_examples
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward of perturbations_per_eval = 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check if the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check if the output grows in the same ratio, i.e., not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
+
+        # reshape the leading dim for n_feature_perturbed
+        # flatten each feature's eval outputs into 1D of (n_outputs)
+        modified_eval = modified_eval.reshape(-1, n_outputs)
+        # eval_diff in shape (n_feature_perturbed, n_outputs)
+        eval_diff = flattened_initial_eval - modified_eval
+
+        # append the shape of one input example
+        # to make it broadcastable to mask
+        eval_diff = eval_diff.reshape(eval_diff.shape + (inputs[i].dim() - 1) * (1,))
+        eval_diff = eval_diff.to(total_attrib[i].device)
+
+        if self.use_weights:
+            weights[i] += current_mask.float().sum(dim=0)
+
+        total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0)
+        return total_attrib, weights
+
+    def _generate_async_result(
+        self,
+        futs: List[List[Future[Tuple[List[Tensor], List[Tensor]]]]],
+        is_inputs_tuple: bool,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        # Each element of the 2d list contains evaluation results for a feature
+        # Need to add up all the results for each input
+        accumulate_fut_list: List[Future] = []
+        total_attrib: List[Tensor] = []
+        weights: List[Tensor] = []
+
+        for i, fut_tuples in enumerate(futs):
+            for fut_tuple in fut_tuples:
+                accumulate_fut_list.append(
+                    fut_tuple.then(
+                        lambda x, i=i: self._accumulate_for_single_input(  # type: ignore # noqa: E501 line too long
+                            total_attrib, weights, i, x.value()[0], x.value()[1]
+                        )
+                    )
+                )
+
+        result_fut = collect_all(accumulate_fut_list).then(
+            lambda x: self._generate_result(total_attrib, weights, is_inputs_tuple)
+        )
+
+        return result_fut
+
+    def _accumulate_for_single_input(
+        self,
+        total_attrib: List[Tensor],
+        weights: List[Tensor],
+        idx: int,
+        attrib: List[Tensor],
+        weight: List[Tensor],
+    ) -> None:
+        if total_attrib:
+            total_attrib[idx] += attrib
+        else:
+            total_attrib.extend(attrib)
+        if self.use_weights:
+            if weights:
+                weights[idx] += weight
+            else:
+                weights.extend(weight)
+
+    def _generate_result(
+        self,
+        total_attrib: List[Tensor],
+        weights: List[Tensor],
+        is_inputs_tuple: bool,
+    ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        # Divide total attributions by counts and return formatted attributions
+        if self.use_weights:
+            attrib = tuple(
+                single_attrib.float() / weight
+                for single_attrib, weight in zip(total_attrib, weights)
+            )
+        else:
+            attrib = tuple(total_attrib)
+        return _format_output(is_inputs_tuple, attrib)
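Note: the rewrite above threads results through torch.futures rather than computing eval diffs inline. Below is a minimal, self-contained sketch of the chaining pattern it relies on; the toy async_forward is a stand-in for a forward_func that returns a Future, as assumed when use_futures is True.

    import torch
    from torch.futures import Future, collect_all

    def async_forward(x: torch.Tensor) -> Future:
        # Toy stand-in: resolve immediately with the summed output.
        fut: Future = Future()
        fut.set_result(x.sum().reshape(1))
        return fut

    initial = async_forward(torch.ones(4))
    modified = async_forward(torch.zeros(4))

    # collect_all joins futures; .then() runs once the collection resolves and
    # receives the completed future, whose value() is a list of child futures.
    diff_fut = collect_all([initial, modified]).then(
        lambda evals: evals.value()[0].value() - evals.value()[1].value()
    )
    print(diff_fut.wait())  # tensor([4.])

This mirrors how eval_futs.value()[0] and eval_futs.value()[1] are unpacked in the hunk above.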
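On the caller side, the synchronous path is unchanged and the async path is opt-in via the new use_futures flag. A hypothetical usage sketch (the flag and the Future return type come from the diff above; the attribute() call itself is standard Captum):

    import torch
    from captum.attr import FeatureAblation

    def forward(x: torch.Tensor) -> torch.Tensor:
        return x.sum(dim=1)  # one scalar output per example

    fa = FeatureAblation(forward)
    attr = fa.attribute(torch.ones(2, 3))  # sync path: returns a Tensor
    print(attr.shape)  # torch.Size([2, 3])

    # Async path (sketch only): with fa.use_futures = True and a forward_func
    # that returns a torch.futures.Future, attribute() would instead return a
    # Future to be wait()-ed on.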
diff --git a/api/_modules/captum/attr/_core/layer/layer_feature_permutation.html b/api/_modules/captum/attr/_core/layer/layer_feature_permutation.html
index 4b31b821c..bae1a63b0 100644
--- a/api/_modules/captum/attr/_core/layer/layer_feature_permutation.html
+++ b/api/_modules/captum/attr/_core/layer/layer_feature_permutation.html
@@ -31,7 +31,7 @@
 #!/usr/bin/env python3
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, cast, List, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -233,7 +233,11 @@
             finally:
                 if hook is not None:
                     hook.remove()
-            return eval
+
+            # _run_forward may return a Future of Tensor,
+            # but we don't support that here yet,
+            # and it would fail before reaching this point.
+            return cast(Tensor, eval)
 
         with torch.no_grad():
             inputs = _format_tensor_into_tuples(inputs)
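The cast(Tensor, eval) pattern recurs throughout this changeset. typing.cast is a runtime no-op that only narrows the static type; a minimal sketch of the idiom (the narrowed helper is hypothetical, not Captum code):

    from typing import Union, cast

    import torch
    from torch import Tensor
    from torch.futures import Future

    def narrowed(out: Union[Tensor, Future]) -> Tensor:
        # These call sites never enable futures, so the Tensor branch always
        # holds; cast records that assumption for the type checker only.
        return cast(Tensor, out)

    print(narrowed(torch.tensor([1.0])))  # tensor([1.])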

diff --git a/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html b/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html
index 20589d6c8..bdea47362 100644
--- a/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html
+++ b/api/_modules/captum/attr/_core/layer/layer_integrated_gradients.html
@@ -33,7 +33,7 @@
 #!/usr/bin/env python3
 import functools
 import warnings
-from typing import Any, Callable, List, overload, Tuple, Union
+from typing import Any, Callable, cast, List, overload, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -136,7 +136,8 @@
                 "Multiple layers provided. Please ensure that each layer is"
                 "**not** solely dependent on the outputs of"
                 "another layer. Please refer to the documentation for more"
-                "detail."
+                "detail.",
+                stacklevel=2,
             )
 
     @overload
@@ -503,13 +504,17 @@
                 # the inputs is an empty tuple
                 # coz it is prepended into additional_forward_args
                 output = _run_forward(
-                    self.forward_func, tuple(), target_ind, additional_forward_args
+                    self.forward_func, (), target_ind, additional_forward_args
                 )
             finally:
                 for hook in hooks:
                     if hook is not None:
                         hook.remove()
 
+            # _run_forward may return a Future of Tensor,
+            # but we don't support that here yet,
+            # and it would fail before reaching this point.
+            output = cast(Tensor, output)
             assert output[0].numel() == 1, (
                 "Target not provided when necessary, cannot"
                 " take gradient with respect to multiple outputs."
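The added stacklevel=2 makes the multiple-layers warning point at the caller's line instead of the warnings.warn() call inside Captum; a quick illustration:

    import warnings

    def helper() -> None:
        warnings.warn("example warning", UserWarning, stacklevel=2)

    helper()  # the reported file/line is this call site, not the line inside helper()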

diff --git a/api/_modules/captum/attr/_core/lrp.html b/api/_modules/captum/attr/_core/lrp.html
index 700bba376..81dcf2ae9 100644
--- a/api/_modules/captum/attr/_core/lrp.html
+++ b/api/_modules/captum/attr/_core/lrp.html
@@ -401,7 +401,11 @@

Source code for captum.attr._core.lrp

         # adjustments as inputs to the layers with adjusted weights. This procedure
         # is important for graph generation in the 2nd forward pass.
         self._register_pre_hooks()
-        return output
+
+        # _run_forward may return a Future of Tensor,
+        # but we don't support that here yet,
+        # and it would fail before reaching this point.
+        return cast(Tensor, output)
 
     def _remove_forward_hooks(self) -> None:
         for forward_handle in self.forward_handles:
diff --git a/api/_modules/captum/attr/_core/shapley_value.html b/api/_modules/captum/attr/_core/shapley_value.html
index dad7d28e5..c9a9c96ad 100644
--- a/api/_modules/captum/attr/_core/shapley_value.html
+++ b/api/_modules/captum/attr/_core/shapley_value.html
@@ -35,7 +35,7 @@
 import itertools
 import math
 import warnings
-from typing import Any, Callable, Iterable, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -59,7 +59,7 @@
     _tensorize_baseline,
 )
 from captum.log import log_usage
-from torch import Tensor
+from torch import dtype, Tensor
 
 
 def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
@@ -588,7 +588,7 @@
         # using python built-in type as torch dtype
         # int -> torch.int64, float -> torch.float64
         # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor([forward_output], dtype=output_type)
+        return torch.tensor([forward_output], dtype=cast(dtype, output_type))
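The cast(dtype, output_type) change is a typing fix only: output_type here is a Python builtin type (int or float), which PyTorch accepts wherever a dtype is expected, per the PR referenced in the comment above:

    import torch

    # int maps to torch.int64, float to torch.float64
    print(torch.tensor([5], dtype=int).dtype)      # torch.int64
    print(torch.tensor([5.0], dtype=float).dtype)  # torch.float64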
diff --git a/api/_modules/captum/attr/_utils/attribution.html b/api/_modules/captum/attr/_utils/attribution.html
index 2c5848ea1..2f2c71702 100644
--- a/api/_modules/captum/attr/_utils/attribution.html
+++ b/api/_modules/captum/attr/_utils/attribution.html
@@ -321,17 +321,22 @@
         _validate_target(num_samples, target)
 
         with torch.no_grad():
-            start_out_sum = _sum_rows(
-                _run_forward(
-                    self.forward_func, start_point, target, additional_forward_args
-                )
+            start_out_eval = _run_forward(
+                self.forward_func, start_point, target, additional_forward_args
             )
+            # _run_forward may return a Future of Tensor,
+            # but we don't support that here yet,
+            # and it would fail before reaching this point.
+            start_out_sum = _sum_rows(cast(Tensor, start_out_eval))
 
-            end_out_sum = _sum_rows(
-                _run_forward(
-                    self.forward_func, end_point, target, additional_forward_args
-                )
+            end_out_eval = _run_forward(
+                self.forward_func, end_point, target, additional_forward_args
             )
+            # _run_forward may return a Future of Tensor,
+            # but we don't support that here yet,
+            # and it would fail before reaching this point.
+            end_out_sum = _sum_rows(cast(Tensor, end_out_eval))
+
             row_sums = [_sum_rows(attribution) for attribution in attributions]
             attr_sum = torch.stack(
                 [cast(Tensor, sum(row_sum)) for row_sum in zip(*row_sums)]
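For context, _sum_rows (from captum._utils.common) reduces each example's outputs to a single scalar, which is what the completeness check above compares; an equivalent sketch, assuming a tensor reshapeable to (batch, -1):

    import torch
    from torch import Tensor

    def sum_rows(t: Tensor) -> Tensor:
        return t.reshape(t.shape[0], -1).sum(dim=1)

    print(sum_rows(torch.ones(2, 3)))  # tensor([3., 3.])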

diff --git a/api/_modules/captum/metrics/_core/infidelity.html b/api/_modules/captum/metrics/_core/infidelity.html
index fee298954..b7b19d44f 100644
--- a/api/_modules/captum/metrics/_core/infidelity.html
+++ b/api/_modules/captum/metrics/_core/infidelity.html
@@ -530,6 +530,10 @@
         additional_forward_args_expanded,
     )
 
     inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args)
+    # _run_forward may return a Future of Tensor,
+    # but we don't support that here yet,
+    # and it would fail before reaching this point.
+    inputs_fwd = cast(Tensor, inputs_fwd)
     inputs_fwd = torch.repeat_interleave(
         inputs_fwd, current_n_perturb_samples, dim=0
     )
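The hunk above repeats each example's unperturbed forward output once per perturbation sample; torch.repeat_interleave along dim=0 does exactly that:

    import torch

    out = torch.tensor([[1.0], [2.0]])  # forward outputs for 2 examples
    print(torch.repeat_interleave(out, 3, dim=0).flatten())
    # tensor([1., 1., 1., 2., 2., 2.])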
