diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 0531bff10df..71375fd13a2 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1070,17 +1070,20 @@ def _step(
 class CountingEnvWithString(CountingEnv):
     def __init__(self, *args, **kwargs):
+        self.max_size = kwargs.pop("max_size", 30)
+        self.min_size = kwargs.pop("min_size", 4)
         super().__init__(*args, **kwargs)
         self.observation_spec.set(
             "string",
             NonTensor(
                 shape=self.batch_size,
                 device=self.device,
+                example_data=self.get_random_string(),
             ),
         )

     def get_random_string(self):
-        size = random.randint(4, 30)
+        size = random.randint(self.min_size, self.max_size)
         return "".join(random.choice(string.ascii_lowercase) for _ in range(size))

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/test/test_specs.py b/test/test_specs.py
index a75ff0352c7..07762c7ad30 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
         assert spec2.zero().shape == spec2.shape

     def test_non_tensor(self):
-        spec = NonTensor((3, 4), device="cpu")
+        spec = NonTensor((3, 4), device="cpu", example_data="example_data")
         assert (
             spec.expand(2, 3, 4)
             == spec.expand((2, 3, 4))
-            == NonTensor((2, 3, 4), device="cpu")
+            == NonTensor((2, 3, 4), device="cpu", example_data="example_data")
         )
+        assert spec.expand(2, 3, 4).example_data == "example_data"

     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     @pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
         assert spec is not spec.clone()

     def test_non_tensor(self):
-        spec = NonTensor(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
         assert spec.clone() == spec
         assert spec.clone() is not spec
+        assert spec.clone().example_data == "example_data"

     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     def test_onehot(
@@ -1840,9 +1842,10 @@ def test_multionehot(
             spec.unbind(-1)

     def test_non_tensor(self):
-        spec = NonTensor(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
         assert spec.unbind(1)[0] == spec[:, 0]
         assert spec.unbind(1)[0] is not spec[:, 0]
+        assert spec.unbind(1)[0].example_data == "example_data"

     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
     def test_onehot(
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
         assert spec.to(device).device == device

     def test_non_tensor(self, device):
-        spec = NonTensor(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
         assert spec.to(device).device == device
+        assert spec.to(device).example_data == "example_data"

     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
     def test_onehot(self, shape1, device):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
         assert r.shape == c.shape

     def test_stack_non_tensor(self, shape, stack_dim):
-        spec0 = NonTensor(shape=shape, device="cpu")
-        spec1 = NonTensor(shape=shape, device="cpu")
+        spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
+        spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
         new_spec = torch.stack([spec0, spec1], stack_dim)
         shape_insert = list(shape)
         shape_insert.insert(stack_dim, 2)
         assert new_spec.shape == torch.Size(shape_insert)
         assert new_spec.device == torch.device("cpu")
+        assert new_spec.example_data == "example_data"

     def test_stack_onehot(self, shape, stack_dim):
         n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
 class TestNonTensorSpec:
     def test_sample(self):
-        nts = NonTensor(shape=(3, 4))
+        nts = NonTensor(shape=(3, 4), example_data="example_data")
         assert nts.one((2,)).shape == (2, 3, 4)
         assert nts.rand((2,)).shape == (2, 3, 4)
         assert nts.zero((2,)).shape == (2, 3, 4)
+        assert nts.one((2,)).data == "example_data"
+        assert nts.rand((2,)).data == "example_data"
+        assert nts.zero((2,)).data == "example_data"
+
+    def test_example_data_ineq(self):
+        nts0 = NonTensor(shape=(3, 4), example_data="example_data")
+        nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
+        assert nts0 != nts1


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
diff --git a/test/test_transforms.py b/test/test_transforms.py
index ec413b2b34c..0689edc435c 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -147,6 +147,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     TransformedEnv,
@@ -2420,7 +2421,223 @@ def test_transform_rb(self, rbclass):
         assert ("next", "observation") in td.keys(True)

     def test_transform_inverse(self):
-        raise pytest.skip("No inverse for Hash")
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
+
+
+class TestTokenizer(TransformBase):
+    @pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
+        td_tokenized = t(td)
+        t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
+        td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
+        assert td_tokenized.get("observation") is td.get("observation")
+        assert td_recon["observation"] == td["observation"]
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = CountingEnvWithString(max_size=4, min_size=4)
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Tokenizer(
+            in_keys=["observation"],
+            out_keys=["tokens"],
+            max_length=5,
+        )
+        t = Compose(t)
+        td_tokenized = t(td)
+
+        assert td_tokenized["observation"] is td["observation"]
+        assert td_tokenized["tokens"] == t[0].tokenizer(obs, return_tensors="pt")
+
+    # TODO: the tests below still exercise Hash; port them to Tokenizer
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=hash,
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hashing") in td_out.keys(True)
+        assert ("hashing",) in td_out.keys(True)
+        assert td_out["next", "hashing"] == hash(td["next", "observation"])
+        assert td_out["hashing"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hashing"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hashing" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)


 class TestStack(TransformBase):
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index bb8bebf2db8..3d4198ae234 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -2452,6 +2452,8 @@ class NonTensor(TensorSpec):
         (same will go for :meth:`.zero` and :meth:`.one`).
     """

+    example_data: Any = None
+
     def __init__(
         self,
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
@@ -2470,6 +2472,11 @@ def __init__(
         )
         self.example_data = example_data

+    def __eq__(self, other):
+        eq = super().__eq__(other)
+        eq = eq & (self.example_data == getattr(other, "example_data", None))
+        return eq
+
     def cardinality(self) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
@@ -2555,6 +2562,16 @@ def expand(self, *shape):
             shape=shape, device=self.device, dtype=None, example_data=self.example_data
         )

+    def unsqueeze(self, dim: int) -> NonTensor:
+        unsq = super().unsqueeze(dim=dim)
+        unsq.example_data = self.example_data
+        return unsq
+
+    def squeeze(self, dim: int | None = None) -> NonTensor:
+        sq = super().squeeze(dim=dim)
+        sq.example_data = self.example_data
+        return sq
+
     def _reshape(self, shape):
         return self.__class__(
             shape=shape,
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index 3a4cde38aa2..fed73755502 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -94,6 +94,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index a25c676e378..7ee142fe811 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -55,6 +55,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 50d2762ad0b..2e9b1e7fa69 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -359,7 +359,6 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
                 tensordict.set(out_key, item)
             elif not self.missing_tolerance:
                 raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
-
         return tensordict

     @dispatch(source="in_keys_inv", dest="out_keys_inv")
@@ -4426,12 +4425,15 @@ class UnaryTransform(Transform):
     Args:
         in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
         out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
-        out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during inverse call.

     Keyword Args:
-        fn (Callable): the function to use as the unary operation. If it accepts
+        fn (Callable[[Any], Tensor | TensorDictBase]): the function to use as the unary operation. If it accepts
             a non-tensor input, it must also accept ``None``.
+        inv_fn (Callable[[Any], Any], optional): the function to use as the unary operation during inverse calls.
+            If it accepts a non-tensor input, it must also accept ``None``.
+            Can be omitted, in which case :attr:`fn` will be used for inverse maps.
         use_raw_nontensor (bool, optional): if ``False``, data is extracted from
             :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called on them.
             If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4505,7 +4507,8 @@ def __init__(
         in_keys_inv: Sequence[NestedKey] | None = None,
         out_keys_inv: Sequence[NestedKey] | None = None,
         *,
-        fn: Callable,
+        fn: Callable[[Any], Tensor | TensorDictBase],
+        inv_fn: Callable[[Any], Any] | None = None,
         use_raw_nontensor: bool = False,
     ):
         super().__init__(
@@ -4515,6 +4518,7 @@ def __init__(
             out_keys_inv=out_keys_inv,
         )
         self._fn = fn
+        self._inv_fn = inv_fn
         self._use_raw_nontensor = use_raw_nontensor

     def _apply_transform(self, value):
@@ -4537,6 +4541,8 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
             state = state.tolist()
         elif isinstance(state, NonTensorStack):
             state = state.tolist()
+        if self._inv_fn is not None:
+            return self._inv_fn(state)
         return self._fn(state)

     def _reset(
@@ -4555,7 +4561,17 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
         test_input = zero_input_["full_action_spec"].update(
             zero_input_["full_state_spec"]
         )
-        test_output = self.inv(test_input)
+        # We use forward and not inv because the spec comes from the base env and
+        # we are trying to infer what the spec looks like from the outside.
+        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
+            data = test_input.get(in_key, None)
+            if data is not None:
+                data = self._apply_transform(data)
+                test_input.set(out_key, data)
+            elif not self.missing_tolerance:
+                raise KeyError(f"'{in_key}' not found in tensordict {test_input}")
+        test_output = test_input
+        # test_output = self.inv(test_input)
         test_input_spec = make_composite_from_td(
             test_output, unsqueeze_null_shapes=False
         )
@@ -4569,7 +4585,6 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
             input_spec["full_state_spec"],
             test_input_spec,
         )
-        print(input_spec)
         return input_spec

     def transform_output_spec(self, output_spec: Composite) -> Composite:
@@ -4605,14 +4620,19 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
         return output_spec

     def _transform_spec(
-        self, spec: TensorSpec, test_output_spec: TensorSpec
+        self, spec: TensorSpec, test_output_spec: TensorSpec, inverse: bool = False
     ) -> TensorSpec:
         if not isinstance(spec, Composite):
             raise TypeError(f"{self}: Only specs of type Composite can be transformed")

         spec_keys = set(spec.keys(include_nested=True))
-        for in_key, out_key in zip(self.in_keys, self.out_keys):
+        iterator = (
+            zip(self.in_keys, self.out_keys)
+            if not inverse
+            else zip(self.in_keys_inv, self.out_keys_inv)
+        )
+        for in_key, out_key in iterator:
             if in_key in spec_keys:
                 spec.set(out_key, test_output_spec[out_key])
         return spec
@@ -4635,12 +4655,12 @@ def transform_done_spec(
     def transform_action_spec(
         self, action_spec: TensorSpec, test_input_spec: TensorSpec
     ) -> TensorSpec:
-        return self._transform_spec(action_spec, test_input_spec)
+        return self._transform_spec(action_spec, test_input_spec, inverse=True)

     def transform_state_spec(
         self, state_spec: TensorSpec, test_input_spec: TensorSpec
     ) -> TensorSpec:
-        return self._transform_spec(state_spec, test_input_spec)
+        return self._transform_spec(state_spec, test_input_spec, inverse=True)


 class Hash(UnaryTransform):
@@ -4649,8 +4669,15 @@ class Hash(UnaryTransform):
     Args:
         in_keys (sequence of NestedKey): the keys of the values to hash.
         out_keys (sequence of NestedKey): the keys of the resulting hashes.
-        in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
-        out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
+
+            .. note:: If an inverse map is required, a repertoire ``Dict[Tuple[int], Any]`` mapping hashes to values should be
+                passed alongside the list of keys to let the ``Hash`` transform know how to recover a value from a
+                given hash. This repertoire isn't copied, so it can be modified in the same workspace after the
+                transform instantiation and these modifications will be reflected in the map. Missing hashes will be
+                mapped to ``None``.
+
+        out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.

     Keyword Args:
         hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
@@ -4738,6 +4765,8 @@ class Hash(UnaryTransform):
         [torchrl][INFO] check_env_specs succeeded!
     """

+    _repertoire: Dict[Tuple[int], Any]
+
     def __init__(
         self,
         in_keys: Sequence[NestedKey],
@@ -4748,6 +4777,7 @@ def __init__(
         hash_fn: Callable = None,
         seed: Any | None = None,
         use_raw_nontensor: bool = False,
+        repertoire: Dict[Tuple[int], Any] | None = None,
     ):
         if hash_fn is None:
             hash_fn = Hash.reproducible_hash
@@ -4762,6 +4792,37 @@ def __init__(
             fn=self.call_hash_fn,
             use_raw_nontensor=use_raw_nontensor,
         )
+        if in_keys_inv is not None:
+            self._repertoire = repertoire if repertoire is not None else {}
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        inputs = tensordict.select(*self.in_keys_inv).detach().cpu()
+        tensordict = super()._inv_call(tensordict)
+
+        def register_outcome(td):
+            # We need to treat each hash independently
+            if td.ndim:
+                if td.ndim > 1:
+                    td_r = td.reshape(-1)
+                elif td.ndim == 1:
+                    td_r = td
+                result = torch.stack([register_outcome(_td) for _td in td_r.unbind(0)])
+                if td_r is not td:
+                    return result.reshape(td.shape)
+                return result
+            for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+                inp = inputs.get(in_key)
+                inp = tuple(inp.tolist())
+                outp = self._repertoire.get(inp)
+                td[out_key] = outp
+            return td
+
+        return register_outcome(tensordict)
+
+    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+        if self.in_keys_inv is not None:
+            return {"_repertoire": self._repertoire}
+        return {}

     def call_hash_fn(self, value):
         if self._seed is None:
@@ -4801,6 +4862,120 @@ def reproducible_hash(cls, string, seed=None):
         return torch.frombuffer(hash_bytes, dtype=torch.uint8)


+class Tokenizer(UnaryTransform):
+    r"""Applies a tokenization operation on the specified inputs.
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
+        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.
+
+    Keyword Args:
+        tokenizer (transformers.PreTrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+            "google-bert/bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+            pre-trained tokenizer.
+        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
+            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
+            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
+            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
+        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
+        skip_special_tokens (bool, optional): if ``True``, special tokens are dropped when decoding. Defaults to ``True``.
+        add_special_tokens (bool, optional): if ``True``, special tokens are added when encoding. Defaults to ``False``.
+        padding (bool, optional): if ``True``, batched inputs are padded to the longest sequence in the batch. Defaults to ``True``.
+        max_length (int, optional): if provided, encoded sequences are padded to this fixed length. Defaults to ``None``.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
+        tokenizer: "transformers.PreTrainedTokenizerBase" = None,  # noqa: F821
+        use_raw_nontensor: bool = False,
+        additional_tokens: List[str] | None = None,
+        skip_special_tokens: bool = True,
+        add_special_tokens: bool = False,
+        padding: bool = True,
+        max_length: int | None = None,
+    ):
+        if tokenizer is None:
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        elif isinstance(tokenizer, str):
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+        self.tokenizer = tokenizer
+        self.add_special_tokens = add_special_tokens
+        self.skip_special_tokens = skip_special_tokens
+        self.padding = padding
+        self.max_length = max_length
+        if additional_tokens:
+            self.tokenizer.add_tokens(additional_tokens)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+            fn=self.call_tokenizer_fn,
+            inv_fn=self.call_tokenizer_inv_fn,
+            use_raw_nontensor=use_raw_nontensor,
+        )
+
+    @property
+    def device(self):
+        if "_device" in self.__dict__:
+            return self._device
+        parent = self.parent
+        if parent is None:
+            return None
+        device = parent.device
+        self._device = device
+        return device
+
+    def call_tokenizer_fn(self, value: str | List[str]):
+        device = self.device
+        kwargs = {"add_special_tokens": self.add_special_tokens}
+        if self.max_length is not None:
+            kwargs["padding"] = "max_length"
+            kwargs["max_length"] = self.max_length
+        if isinstance(value, str):
+            out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
+            # TODO: incorporate attention mask
+            attention_mask = torch.ones_like(out, dtype=torch.bool)
+        else:
+            kwargs["padding"] = (
+                self.padding if self.max_length is None else "max_length"
+            )
+            # kwargs["return_attention_mask"] = False
+            # kwargs["return_token_type_ids"] = False
+            out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
+            attention_mask = out["attention_mask"]
+            out = out["input_ids"]
+
+        if device is not None and out.device != device:
+            out = out.to(device)
+        return out
+
+    def call_tokenizer_inv_fn(self, value: Tensor):
+        if value.ndim == 1:
+            out = self.tokenizer.decode(
+                value, skip_special_tokens=self.skip_special_tokens
+            )
+        else:
+            out = self.tokenizer.batch_decode(
+                value, skip_special_tokens=self.skip_special_tokens
+            )
+        if isinstance(out, list):
+            return NonTensorStack(*out)
+        return NonTensorData(out)
+
+
 class Stack(Transform):
     """Stacks tensors and tensordicts.
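
A few usage sketches distilled from the patch follow; all of them assume a torchrl build with this diff applied. First, the `example_data` attribute on `NonTensor` specs, as exercised by the test_specs.py additions: it survives shape and device operations, fills sampled values, and participates in equality.

```python
from torchrl.data import NonTensor

spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")

# example_data is carried through shape/device ops...
assert spec.expand(2, 3, 4).example_data == "example_data"
assert spec.clone().example_data == "example_data"
assert spec.unbind(1)[0].example_data == "example_data"

# ...is used as the payload of sampled values...
assert spec.rand((2,)).data == "example_data"

# ...and participates in equality checks between specs.
assert spec != NonTensor(shape=(3, 4), device="cpu", example_data="other")
```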
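Next, a sketch of the repertoire-based inverse described in the `Hash` docstring. The key names and the registered value are illustrative only; the manual registration below mirrors the docstring's statement that the repertoire is held by reference rather than copied (how entries get registered during collection is not shown in this diff).

```python
from torchrl.envs import Hash

repertoire = {}
t = Hash(
    in_keys=[],
    out_keys=[],
    in_keys_inv=["action"],
    out_keys_inv=["action_hash"],
    repertoire=repertoire,
)

# reproducible_hash returns a uint8 tensor; its values, as a tuple, key the repertoire.
h = Hash.reproducible_hash("some-action")
repertoire[tuple(h.tolist())] = "some-action"

# The transform holds a reference to `repertoire` (no copy), so the entry above
# is visible to later inverse calls; unregistered hashes resolve to None.
```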
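Finally, a round-trip sketch for the new `Tokenizer` transform, mirroring `test_transform_no_env` above (requires the `transformers` package; the default tokenizer is downloaded on first use).

```python
from tensordict import TensorDict
from torchrl.envs import Tokenizer

td = TensorDict({"observation": "abcdefg"})

# Encode: writes integer token ids under "tokens"; "observation" is left untouched.
t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
td_tokenized = t(td)

# Decode: a second Tokenizer wired only with inverse keys recovers the string.
t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
assert td_recon["observation"] == "abcdefg"
```

The same transform can be appended to an environment; the tests above pin the string length of the mocking env (`CountingEnvWithString(max_size=4, min_size=4)`) together with `max_length=5` so that tokenized observations have a fixed shape.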