Skip to content

Commit

Permalink
sctk_binary_path variable in Hub5ScoreJob (#292)
Browse files Browse the repository at this point in the history
* Add sctk_binary_path variable to Hub5ScoreJob and add type annotations.

* Reformat.

* Update type annotations

Co-authored-by: michelwi <[email protected]>

* Update docstring

Co-authored-by: michelwi <[email protected]>

* Update imports.

Co-authored-by: michelwi <[email protected]>

* Remove tk.uncached_path.

Co-authored-by: Daniel Mann <[email protected]>
Co-authored-by: michelwi <[email protected]>
  • Loading branch information
3 people authored Jul 25, 2022
1 parent 092b141 commit cb0e2a4
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions recognition/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,27 @@ def calc_wer(self):


class Hub5ScoreJob(Job):
def __init__(self, ref, glm, hyp):
__sis_hash_exclude__ = {"sctk_binary_path": None}

def __init__(
self,
ref: tk.Path,
glm: tk.Path,
hyp: tk.Path,
sctk_binary_path: Optional[tk.Path] = None,
):
"""
:param ref: reference stm text file
:param glm: text file containing mapping rules for scoring
:param hyp: hypothesis ctm text file
:param sctk_binary_path: set an explicit binary path.
"""
self.set_vis_name("HubScore")

self.glm = glm
self.hyp = hyp
self.ref = ref
self.sctk_binary_path = sctk_binary_path

self.out_report_dir = self.output_path("reports", True)

Expand Down Expand Up @@ -235,29 +250,33 @@ def tasks(self):
yield Task("run", mini_task=True)

def run(self, move_files=True):
hubscr_path = (
os.path.join(gs.SCTK_PATH, "hubscr.pl")
if hasattr(gs, "SCTK_PATH")
else "hubscr.pl"
)
sctk_opt = ["-p", gs.SCTK_PATH] if hasattr(gs, "SCTK_PATH") else []
sctk_path = ""
if self.sctk_binary_path is not None:
sctk_path = self.sctk_binary_path.get_path()
elif hasattr(gs, "SCTK_PATH"):
sctk_path = gs.SCTK_PATH
hubscr_path = os.path.join(
sctk_path, "hubscr.pl"
) # evaluates to just "hubscr.pl" if sctk_path is empty

sctk_opt = ["-p", sctk_path] if sctk_path else []

ref = self.ref
try:
ref = shutil.copy(tk.uncached_path(ref), ".")
ref = shutil.copy(ref.get_path(), ".")
except shutil.SameFileError:
pass

hyp = self.hyp
try:
hyp = shutil.copy(tk.uncached_path(hyp), ".")
hyp = shutil.copy(hyp.get_path(), ".")
except shutil.SameFileError:
pass

sp.check_call(
[hubscr_path, "-V", "-l", "english", "-h", "hub5"]
+ sctk_opt
+ ["-g", tk.uncached_path(self.glm), "-r", ref, hyp]
+ ["-g", self.glm.get_path(), "-r", ref, hyp]
)

if move_files: # run as real job
Expand Down

0 comments on commit cb0e2a4

Please sign in to comment.