
Add support for encoding and sparse data in RasrAlignmentDumpHDFJob (#434)

* add handling of encoding

* Add support for sparse alignments

* Add filter_list_keep
michelwi authored Jul 31, 2023
1 parent df08050 commit 43dfdef
Showing 2 changed files with 53 additions and 25 deletions.
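
The three commit-message items surface as new constructor arguments. A minimal usage sketch (the import path and the tk.Path inputs are assumptions for illustration, not part of this commit):

    import numpy as np
    from i6_core.returnn.hdf import RasrAlignmentDumpHDFJob  # module path assumed

    # alignment_caches, allophone_file, state_tying_file, segment_list are
    # hypothetical tk.Path objects produced by earlier jobs in the graph
    job = RasrAlignmentDumpHDFJob(
        alignment_caches=alignment_caches,    # e.g. output of an AlignmentJob
        allophone_file=allophone_file,        # e.g. output of a StoreAllophonesJob
        state_tying_file=state_tying_file,    # e.g. output of a DumpStateTyingJob
        data_type=np.uint16,
        encoding="utf-8",               # new: segment names may contain non-ASCII bytes
        filter_list_keep=segment_list,  # new: dump only the listed segments
        sparse=True,                    # new: store class indices instead of dense frames
    )
    hdf_files = job.out_hdf_files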
38 changes: 20 additions & 18 deletions lib/rasr_cache.py
@@ -13,7 +13,6 @@
import mmap
import numpy
import os
import sys
import typing
import zlib
from struct import pack, unpack
@@ -51,7 +50,8 @@ class FileArchive:
start_recovery_tag = 0xAA55AA55
end_recovery_tag = 0x55AA55AA

def __init__(self, filename, must_exists=False):
def __init__(self, filename, must_exists=False, encoding="ascii"):
self.encoding = encoding

self.ft = {} # type: typing.Dict[str,FileInfo]
if os.path.exists(filename):
@@ -182,12 +182,12 @@ def read_v(self, typ, size):
return res

# write routines
def write_str(self, s):
def write_str(self, s, enc="ascii"):
"""
:param str s:
:rtype: int
"""
return self.f.write(pack("%ds" % len(s), s.encode("ascii")))
return self.f.write(pack("%ds" % len(s.encode(enc)), s.encode(enc)))

def write_char(self, i):
"""
Expand Down Expand Up @@ -256,7 +256,7 @@ def readFileInfoTable(self):
return
for i in range(count):
str_len = self.read_u32()
name = self.read_str(str_len)
name = self.read_str(str_len, self.encoding)
pos = self.read_u64()
size = self.read_u32()
comp = self.read_u32()
@@ -271,8 +271,8 @@ def writeFileInfoTable(self):
self.write_u32(len(self.ft))

for fi in self.ft.values():
self.write_u32(len(fi.name))
self.write_str(fi.name)
self.write_u32(len(fi.name.encode(self.encoding)))
self.write_str(fi.name, self.encoding)
self.write_u64(fi.pos)
self.write_u32(fi.size)
self.write_u32(fi.compressed)
@@ -293,7 +293,7 @@ def scanArchive(self):
continue

fn_len = self.read_u32()
name = self.read_str(fn_len)
name = self.read_str(fn_len, self.encoding)
pos = self.f.tell()
size = self.read_u32()
comp = self.read_u32()
@@ -322,7 +322,7 @@ def _raw_read(self, size, typ):
"""

if typ == "str":
return self.read_str(size)
return self.read_str(size, self.encoding)

elif typ == "feat":
type_len = self.read_U32()
@@ -496,8 +496,8 @@ def addFeatureCache(self, filename, features, times):
:param times:
"""
self.write_U32(self.start_recovery_tag)
self.write_u32(len(filename))
self.write_str(filename)
self.write_u32(len(filename.encode(self.encoding)))
self.write_str(filename, self.encoding)
pos = self.f.tell()
if len(features) > 0:
dim = len(features[0])
@@ -542,8 +542,8 @@ def addAttributes(self, filename, dim, duration):
) % (dim, duration)
self.write_U32(self.start_recovery_tag)
filename = "%s.attribs" % filename
self.write_u32(len(filename))
self.write_str(filename)
self.write_u32(len(filename.encode(self.encoding)))
self.write_str(filename, self.encoding)
pos = self.f.tell()
size = len(data)
self.write_u32(size)
@@ -559,17 +559,18 @@ class FileArchiveBundle:
File archive bundle.
"""

def __init__(self, filename):
def __init__(self, filename, encoding="ascii"):
"""
:param str filename: .bundle file
:param str encoding: encoding used in the files
"""
# filename -> FileArchive
self.archives = {} # type: typing.Dict[str,FileArchive]
# archive content file -> FileArchive
self.files = {} # type: typing.Dict[str,FileArchive]
self._short_seg_names = {}
for line in open(filename).read().splitlines():
self.archives[line] = a = FileArchive(line, must_exists=True)
self.archives[line] = a = FileArchive(line, must_exists=True, encoding=encoding)
for f in a.ft.keys():
self.files[f] = a
# noinspection PyProtectedMember
@@ -616,17 +617,18 @@ def setAllophones(self, filename):
a.setAllophones(filename)


def open_file_archive(archive_filename, must_exists=True):
def open_file_archive(archive_filename, must_exists=True, encoding="ascii"):
"""
:param str archive_filename:
:param bool must_exists:
:param str encoding:
:rtype: FileArchiveBundle|FileArchive
"""
if archive_filename.endswith(".bundle"):
assert must_exists
return FileArchiveBundle(archive_filename)
return FileArchiveBundle(archive_filename, encoding=encoding)
else:
return FileArchive(archive_filename, must_exists=must_exists)
return FileArchive(archive_filename, must_exists=must_exists, encoding=encoding)


def is_rasr_cache_file(filename):
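
With the encoding threaded through FileArchiveBundle and open_file_archive, a bundle of caches with non-ASCII segment names can be opened end to end. A sketch using only names visible in this diff (paths are hypothetical):

    from lib.rasr_cache import open_file_archive  # import path assumed

    # .bundle files get a FileArchiveBundle, everything else a FileArchive;
    # the encoding is forwarded to every contained archive either way
    archive = open_file_archive("alignment.cache.bundle", encoding="utf-8")
    archive.setAllophones("allophones.txt")

    # FileArchiveBundle.files maps each archived entry to its backing FileArchive
    for seg_name, sub_archive in archive.files.items():
        if seg_name.endswith(".attribs"):
            continue  # attribute entries carry no alignment data
        alignment = sub_archive.read(seg_name, "align")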
40 changes: 33 additions & 7 deletions returnn/hdf.py
@@ -311,28 +311,41 @@ class RasrAlignmentDumpHDFJob(Job):
This Job reads RASR alignment caches and dumps them into HDF files.
"""

__sis_hash_exclude__ = {"encoding": "ascii", "filter_list_keep": None, "sparse": False}

def __init__(
self,
alignment_caches: List[tk.Path],
allophone_file: tk.Path,
state_tying_file: tk.Path,
data_type: type = np.uint16,
returnn_root: Optional[tk.Path] = None,
encoding: str = "ascii",
filter_list_keep: Optional[tk.Path] = None,
sparse: bool = False,
):
"""
:param alignment_caches: e.g. output of an AlignmentJob
:param allophone_file: e.g. output of a StoreAllophonesJob
:param state_tying_file: e.g. output of a DumpStateTyingJob
:param data_type: type that is used to store the data
:param returnn_root: file path to the RETURNN repository root folder
:param encoding: encoding of the segment names in the cache
:param filter_list_keep: list of segment names to dump
:param sparse: writes the data to hdf in sparse format
"""
self.alignment_caches = alignment_caches
self.allophone_file = allophone_file
self.state_tying_file = state_tying_file
self.data_type = data_type
self.returnn_root = returnn_root
self.encoding = encoding
self.filter_list_keep = filter_list_keep
self.sparse = sparse

self.out_hdf_files = [self.output_path(f"data.hdf.{d}") for d in range(len(alignment_caches))]
self.out_excluded_segments = self.output_path(f"excluded.segments")
self.returnn_root = returnn_root
self.data_type = data_type

self.rqmt = {"cpu": 1, "mem": 8, "time": 0.5}

def tasks(self):
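
The __sis_hash_exclude__ entry above keeps pre-existing job hashes stable: a parameter listed there is omitted from the Sisyphus job hash whenever its value equals the listed default, so setups created before this commit resolve to the same jobs. A toy illustration of the rule (not the Sisyphus implementation):

    def hash_relevant_kwargs(kwargs, sis_hash_exclude):
        # drop any kwarg whose value equals its excluded default
        return {
            k: v
            for k, v in kwargs.items()
            if not (k in sis_hash_exclude and sis_hash_exclude[k] == v)
        }

    assert hash_relevant_kwargs({"sparse": False}, {"sparse": False}) == {}
    assert hash_relevant_kwargs({"sparse": True}, {"sparse": False}) == {"sparse": True}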
@@ -354,22 +367,35 @@ def run(self, task_id):
state_tying = dict(
(k, int(v)) for l in open(self.state_tying_file.get_path()) for k, v in [l.strip().split()[0:2]]
)
num_classes = max(state_tying.values()) + 1

alignment_cache = FileArchive(self.alignment_caches[task_id - 1].get_path())
alignment_cache = FileArchive(self.alignment_caches[task_id - 1].get_path(), encoding=self.encoding)
alignment_cache.setAllophones(self.allophone_file.get_path())
if self.filter_list_keep is not None:
keep_segments = set(open(self.filter_list_keep.get_path()).read().splitlines())
else:
keep_segments = None

returnn_root = None if self.returnn_root is None else self.returnn_root.get_path()
SimpleHDFWriter = get_returnn_simple_hdf_writer(returnn_root)
out_hdf = SimpleHDFWriter(filename=self.out_hdf_files[task_id - 1], dim=1)
out_hdf = SimpleHDFWriter(
filename=self.out_hdf_files[task_id - 1],
dim=num_classes if self.sparse else 1,
ndim=1 if self.sparse else 2,
)

excluded_segments = []

for file in alignment_cache.ft:
info = alignment_cache.ft[file]
if info.name.endswith(".attribs"):
continue
seq_name = info.name

if seq_name.endswith(".attribs"):
continue
if keep_segments is not None and seq_name not in keep_segments:
excluded_segments.append(seq_name)
continue

# alignment
targets = []
alignment = alignment_cache.read(file, "align")
@@ -382,7 +408,7 @@

data = np.array(targets).astype(np.dtype(self.data_type))
out_hdf.insert_batch(
inputs=data.reshape(1, -1, 1),
inputs=data.reshape(1, -1) if self.sparse else data.reshape(1, -1, 1),
seq_len=[data.shape[0]],
seq_tag=[seq_name],
)
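
The sparse path changes both the writer setup and the batch shape: dense output stays a (batch, time, 1) tensor with dim=1, while sparse output is a (batch, time) matrix of class indices with dim=num_classes, where num_classes is the number of tied states (max state-tying value + 1). A toy illustration of the two shapes (not part of the commit):

    import numpy as np

    targets = [3, 3, 7, 7, 1]  # one tied-state class index per frame
    data = np.array(targets).astype(np.uint16)

    dense = data.reshape(1, -1, 1)  # (batch, time, 1), written with dim=1, ndim=2
    sparse = data.reshape(1, -1)    # (batch, time) indices, dim=num_classes, ndim=1

    assert dense.shape == (1, 5, 1)
    assert sparse.shape == (1, 5)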
