-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlocal.py
103 lines (87 loc) · 3.16 KB
/
local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from speechbrain.dataio.batch import (
    default_convert,
    mod_default_collate,
    recursive_pin_memory,
    recursive_to,
)
from speechbrain.lobes.models.transformer.Transformer import (
    TransformerInterface,
    get_lookahead_mask,
)
class NoSchedule:
    """Constant learning-rate schedule: the rate given at construction never changes.

    Arguments
    ---------
    lr : float
        The learning rate reported on every call (read from ``self.lr``,
        so reassigning the attribute changes what is reported).
    **kwargs
        Ignored. NOTE(review): presumably accepted so this class can stand
        in for schedulers that take extra configuration; confirm with callers.
    """

    def __init__(self, lr, **kwargs):
        self.lr = lr

    def __call__(self, *args, **kwargs):
        """Return ``(lr, lr)`` no matter which arguments are passed.

        NOTE(review): the two-element shape presumably mirrors schedulers
        that return an ``(old_lr, new_lr)`` pair. Since nothing changes here,
        both entries are the stored rate.
        """
        rate = self.lr
        return rate, rate
class Batch:
    """Collate a list of example dicts into one object with an attribute per key.

    For every key of ``examples[0]``, the per-example values are gathered into
    a list. That list is optionally passed through ``default_convert`` and then
    through ``mod_default_collate``. The result is stored as an attribute named
    after the key, so it can be read as ``batch.key``, ``batch["key"]``, by
    position with ``at_position``, or by unpacking the batch.

    Because keys are stored with ``setattr``, a key equal to a method or
    property name of this class (e.g. ``"to"`` or ``"batchsize"``) would shadow
    that method or fail to be set. Avoid such key names.

    Arguments
    ---------
    examples : sequence of dict
        The examples to batch. It must be non-empty and indexable, because the
        key set is taken from ``examples[0]``. Every example must provide each
        of those keys.
    device_prep_keys : container of str, optional
        Keys whose values ``pin_memory`` and ``to`` should move. When None, a
        key is selected if the first element of its collated value is a
        ``torch.Tensor``.
    apply_default_convert : bool
        Whether to run ``default_convert`` on each key's values before
        collating.

    Example
    -------
    >>> batch = Batch([
    ...     {"id": "ex1", "val": torch.Tensor([1.])},
    ...     {"id": "ex2", "val": torch.Tensor([2., 1.])}])
    >>> len(batch)
    2
    >>> batch["id"]
    ['ex1', 'ex2']
    """

    def __init__(
        self,
        examples,
        device_prep_keys=None,
        apply_default_convert=True,
    ):
        self.__length = len(examples)
        # The key set comes from the first example; an example missing one of
        # these keys raises KeyError in the gathering step below.
        self.__keys = list(examples[0].keys())
        self.__device_prep_keys = []
        for key in self.__keys:
            values = [example[key] for example in examples]
            # Default convert usually does the right thing (numpy2torch etc.)
            if apply_default_convert:
                values = default_convert(values)
            values = mod_default_collate(values)
            setattr(self, key, values)
            if device_prep_keys is None:
                # Auto-detect: values[0] is a tensor both when the collated
                # value is itself a tensor and when it is a list of tensors.
                needs_prep = isinstance(values[0], torch.Tensor)
            else:
                needs_prep = key in device_prep_keys
            if needs_prep:
                self.__device_prep_keys.append(key)

    def __len__(self):
        """Return the number of examples in the batch."""
        return self.__length

    def __getitem__(self, key):
        """Return the collated value for ``key``.

        Raises
        ------
        KeyError
            If ``key`` was not a key of the batched examples.
        """
        if key in self.__keys:
            return getattr(self, key)
        else:
            raise KeyError(f"Batch doesn't have key: {key}")

    def __iter__(self):
        """Iterate over the collated values, in the key order of the first example.

        Example
        -------
        >>> batch = Batch([
        ...     {"id": "ex1", "val": torch.Tensor([1.])},
        ...     {"id": "ex2", "val": torch.Tensor([2., 1.])}])
        >>> ids, vals = batch
        >>> ids
        ['ex1', 'ex2']
        """
        # A generator expression is already an iterator; no iter() wrapper needed.
        return (getattr(self, key) for key in self.__keys)

    def pin_memory(self):
        """In-place, moves relevant elements to pinned memory.

        Only the device-prep keys chosen at construction are touched.
        Returns ``self``.
        """
        for key in self.__device_prep_keys:
            value = getattr(self, key)
            pinned = recursive_pin_memory(value)
            setattr(self, key, pinned)
        return self

    def to(self, *args, **kwargs):
        """In-place move/cast relevant elements.

        Passes all arguments to torch.Tensor.to, see its documentation.
        Only the device-prep keys chosen at construction are touched.
        Returns ``self`` so calls can be chained.
        """
        for key in self.__device_prep_keys:
            value = getattr(self, key)
            moved = recursive_to(value, *args, **kwargs)
            setattr(self, key, moved)
        return self

    def at_position(self, pos):
        """Fetch an item by its position in the batch's key order.

        ``pos`` indexes the key list, so negative positions count from the
        end and an out-of-range position raises IndexError.
        """
        key = self.__keys[pos]
        return getattr(self, key)

    @property
    def batchsize(self):
        """Number of examples in the batch (same as ``len(batch)``)."""
        return self.__length
class TransformerAM(TransformerInterface):
    """Encoder-only model built on SpeechBrain's ``TransformerInterface``.

    The decoder stack is disabled by always passing ``num_decoder_layers=0``.
    Every other positional and keyword argument is forwarded unchanged to
    ``TransformerInterface.__init__``.
    """

    def __init__(self, *args, **kwargs):
        # Accept an explicit ``num_decoder_layers=0`` (e.g. from a shared
        # config), which previously raised a duplicate-keyword TypeError.
        # Reject any other value, since this model never builds a decoder.
        num_decoder_layers = kwargs.pop("num_decoder_layers", 0)
        if num_decoder_layers != 0:
            raise ValueError(
                "TransformerAM is encoder-only; "
                f"num_decoder_layers must be 0, got {num_decoder_layers}"
            )
        super().__init__(*args, num_decoder_layers=0, **kwargs)

    def forward(self, x, src_key_padding_mask=None):
        """Run the encoder stack on ``x`` and return its output.

        When ``self.causal`` is set, the mask produced by SpeechBrain's
        ``get_lookahead_mask(x)`` is passed as ``src_mask`` (presumably
        blocking attention to future positions). Otherwise no attention mask
        is used.

        NOTE(review): no positional encoding is added here; confirm that
        callers add it to ``x`` beforehand if the configuration needs one.

        Arguments
        ---------
        x : torch.Tensor
            Encoder input. NOTE(review): presumably (batch, time, d_model);
            confirm against callers.
        src_key_padding_mask : torch.Tensor, optional
            Forwarded unchanged to the encoder.

        Returns
        -------
        The first element of the encoder's return value. The second element
        is discarded.
        """
        attn_mask = get_lookahead_mask(x) if self.causal else None
        encoder_output, _ = self.encoder(
            src=x,
            src_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        return encoder_output