Commit d8ceaa8

Add additional gradient-based attribution methods to LLM Attribution (#1337)

Summary:
Add `LayerGradientXActivation` and `LayerGradientShap` to the supported gradient-based LLM attribution methods.
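
Below is a minimal usage sketch (not part of this commit) showing how one of the newly supported methods plugs into `LLMGradientAttribution`, following the same pattern as the updated tests. The model name ("gpt2"), the prompt/target strings, and the choice of `model.get_input_embeddings()` as the attributed layer are illustrative assumptions, not taken from this change.

```python
# Illustrative sketch -- not from this commit. Assumed: a HuggingFace-style
# causal LM ("gpt2") whose input embedding layer is the attributed layer,
# mirroring `llm.emb` in the tests.
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import LayerGradientXActivation, LLMGradientAttribution, TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Wrap the layer attribution method so it attributes generated tokens' log probs
layer_attr = LayerGradientXActivation(model, model.get_input_embeddings())
llm_attr = LLMGradientAttribution(layer_attr, tokenizer)

inp = TextTokenInput("The capital of France is", tokenizer)
res = llm_attr.attribute(inp, " Paris")  # attribute the target continuation

print(res.seq_attr.shape)  # one attribution score per input token
```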

Pull Request resolved: #1337

Test Plan: `pytest tests/attr -k TestLLMGradAttr` with new test cases via parameterized library

Reviewed By: cyrjano

Differential Revision: D62221000

Pulled By: craymichael

fbshipit-source-id: fb5f170e13a62355357d46d3ef7a2464e8eb80ab
craymichael authored and facebook-github-bot committed Sep 9, 2024
1 parent d89243b commit d8ceaa8
Showing 4 changed files with 121 additions and 56 deletions.
5 changes: 3 additions & 2 deletions captum/attr/_core/deep_lift.py
@@ -110,7 +110,7 @@ def __init__(
                     Default: 1e-10
         """
         GradientAttribution.__init__(self, model)
-        self.model = model
+        self.model: nn.Module = model
         self.eps = eps
         self.forward_handles: List[RemovableHandle] = []
         self.backward_handles: List[RemovableHandle] = []
@@ -324,7 +324,8 @@ def attribute(  # type: ignore
         warnings.warn(
             """Setting forward, backward hooks and attributes on non-linear
                activations. The hooks and attributes will be removed
-            after the attribution is finished"""
+            after the attribution is finished""",
+            stacklevel=2,
         )
         # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
         # `TensorOrTupleOfTensorsGeneric`.
6 changes: 3 additions & 3 deletions captum/attr/_core/layer/layer_deep_lift.py
@@ -351,9 +351,9 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
             grad_kwargs=grad_kwargs,
         )
 
-        attr_inputs = tuple(map(lambda attr: attr[0], attrs))
-        attr_baselines = tuple(map(lambda attr: attr[1], attrs))
-        gradients = tuple(map(lambda grad: grad[0], gradients))
+        attr_inputs = tuple(attr[0] for attr in attrs)
+        attr_baselines = tuple(attr[1] for attr in attrs)
+        gradients = tuple(grad[0] for grad in gradients)
 
         if custom_attribution_func is None:
             if self.multiplies_by_inputs:
93 changes: 56 additions & 37 deletions captum/attr/_core/llm_attr.py
@@ -10,6 +10,8 @@
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
+from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
+from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
@@ -452,7 +454,11 @@ class LLMGradientAttribution(Attribution):
     and returns LLMAttributionResult
     """
 
-    SUPPORTED_METHODS = (LayerIntegratedGradients,)
+    SUPPORTED_METHODS = (
+        LayerGradientShap,
+        LayerGradientXActivation,
+        LayerIntegratedGradients,
+    )
     SUPPORTED_INPUTS = (TextTokenInput,)
 
     def __init__(
@@ -473,53 +479,21 @@ class created with the llm model that follows huggingface style
 
         super().__init__(attr_method.forward_func)
 
-        # shallow copy is enough to avoid modifying original instance
-        self.attr_method: GradientAttribution = copy(attr_method)
-        self.attr_method.forward_func = self._forward_func
-
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
         self.model: nn.Module = cast(nn.Module, self.forward_func)
 
+        # shallow copy is enough to avoid modifying original instance
+        self.attr_method: GradientAttribution = copy(attr_method)
+        self.attr_method.forward_func = GradientForwardFunc(self)
+
         self.tokenizer: TokenizerLike = tokenizer
         self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
         )
 
-    def _forward_func(
-        self,
-        perturbed_tensor: Tensor,
-        inp: InterpretableInput,
-        target_tokens: Tensor,  # 1D tensor of target token ids
-        cur_target_idx: int,  # current target index
-    ) -> Tensor:
-        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
-
-        if cur_target_idx:
-            # the input batch size can be expanded by attr method
-            output_token_tensor = (
-                target_tokens[:cur_target_idx]
-                .unsqueeze(0)
-                .expand(perturbed_input.size(0), -1)
-                .to(self.device)
-            )
-            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
-        else:
-            new_input_tensor = perturbed_input
-
-        output_logits = self.model(new_input_tensor)
-
-        new_token_logits = output_logits.logits[:, -1]
-        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
-
-        target_token = target_tokens[cur_target_idx]
-        token_log_probs = log_probs[..., target_token]
-
-        # the attribution target is limited to the log probability
-        return token_log_probs
-
     def _format_model_input(self, model_input: Tensor) -> Tensor:
         """
         Convert str to tokenized tensor
@@ -643,3 +617,48 @@ def attribute_future(self) -> Callable:
         raise NotImplementedError(
             "attribute_future is not implemented for LLMGradientAttribution"
         )
+
+
+class GradientForwardFunc(nn.Module):
+    """
+    A wrapper class for the forward function of a model in LLMGradientAttribution
+    """
+
+    def __init__(self, attr: LLMGradientAttribution) -> None:
+        super().__init__()
+        self.attr = attr
+        self.model: nn.Module = attr.model
+
+    def forward(
+        self,
+        perturbed_tensor: Tensor,
+        inp: InterpretableInput,
+        target_tokens: Tensor,  # 1D tensor of target token ids
+        cur_target_idx: int,  # current target index
+    ) -> Tensor:
+        perturbed_input = self.attr._format_model_input(
+            inp.to_model_input(perturbed_tensor)
+        )
+
+        if cur_target_idx:
+            # the input batch size can be expanded by attr method
+            output_token_tensor = (
+                target_tokens[:cur_target_idx]
+                .unsqueeze(0)
+                .expand(perturbed_input.size(0), -1)
+                .to(self.attr.device)
+            )
+            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
+        else:
+            new_input_tensor = perturbed_input
+
+        output_logits = self.model(new_input_tensor)
+
+        new_token_logits = output_logits.logits[:, -1]
+        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
+
+        target_token = target_tokens[cur_target_idx]
+        token_log_probs = log_probs[..., target_token]
+
+        # the attribution target is limited to the log probability
+        return token_log_probs
73 changes: 59 additions & 14 deletions tests/attr/test_llm_attr.py
@@ -3,19 +3,19 @@
 # pyre-strict
 
 import copy
-from typing import Any, cast, Dict, List, NamedTuple, Optional, Type, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Type, Union
 
 import torch
-from captum._utils.models.linear_model import (  # @manual=//pytorch/captum/captum/_utils/models/linear_model:linear_model # noqa: E501
-    SkLearnLasso,
-)
+from captum._utils.models.linear_model import SkLearnLasso
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
+from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
+from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
-from captum.attr._utils.attribution import PerturbationAttribution
+from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized, parameterized_class
 from tests.helpers import BaseTest
@@ -379,15 +379,30 @@ def test_futures_not_implemented(self) -> None:
 class TestLLMGradAttr(BaseTest):
     device: str
 
-    def test_llm_attr(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -402,15 +417,30 @@ def test_llm_attr(self) -> None:
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
 
-    def test_llm_attr_without_target(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr_without_target(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws)
 
         self.assertEqual(res.seq_attr.shape, (4,))
         assert res.token_attr is not None  # make pyre/mypy happy
@@ -424,15 +454,30 @@ def test_llm_attr_without_target(self) -> None:
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
 
-    def test_llm_attr_with_skip_tokens(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1]]),)),
+        ]
+    )
+    def test_llm_attr_with_skip_tokens(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
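
As the parameterized cases above show, extra keyword arguments to `LLMGradientAttribution.attribute` are forwarded to the wrapped method, which is how `LayerGradientShap` receives its required `baselines`. Below is a hedged sketch of the same pattern against a real model; the model name ("gpt2") and the zero token-id baseline are illustrative assumptions, sized to the tokenized prompt just as `torch.tensor([[1, 0, 1, 0]])` matches the `DummyTokenizer` output in the tests.

```python
# Illustrative sketch -- not from this commit. Assumed: "gpt2" as the model and
# a zero token-id baseline shaped like the tokenized prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import LayerGradientShap, LLMGradientAttribution, TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMGradientAttribution(
    LayerGradientShap(model, model.get_input_embeddings()), tokenizer
)

prompt = "a b c"
inp = TextTokenInput(prompt, tokenizer)

# One baseline sample in token-id space, with one position per prompt token
# (the tests use torch.tensor([[1, 0, 1, 0]]) for the 4-token dummy input).
n_positions = len(tokenizer(prompt)["input_ids"])
baselines = (torch.zeros(1, n_positions, dtype=torch.long),)

# Extra kwargs such as `baselines` are passed through to LayerGradientShap.
res = llm_attr.attribute(inp, "m n o p q", baselines=baselines)
print(res.seq_attr)
```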
