diff --git a/recognition/scoring.py b/recognition/scoring.py index 036d8702..5cdde177 100644 --- a/recognition/scoring.py +++ b/recognition/scoring.py @@ -202,12 +202,27 @@ def calc_wer(self): class Hub5ScoreJob(Job): - def __init__(self, ref, glm, hyp): + __sis_hash_exclude__ = {"sctk_binary_path": None} + + def __init__( + self, + ref: tk.Path, + glm: tk.Path, + hyp: tk.Path, + sctk_binary_path: Optional[tk.Path] = None, + ): + """ + :param ref: reference stm text file + :param glm: text file containing mapping rules for scoring + :param hyp: hypothesis ctm text file + :param sctk_binary_path: set an explicit binary path. + """ self.set_vis_name("HubScore") self.glm = glm self.hyp = hyp self.ref = ref + self.sctk_binary_path = sctk_binary_path self.out_report_dir = self.output_path("reports", True) @@ -235,29 +250,33 @@ def tasks(self): yield Task("run", mini_task=True) def run(self, move_files=True): - hubscr_path = ( - os.path.join(gs.SCTK_PATH, "hubscr.pl") - if hasattr(gs, "SCTK_PATH") - else "hubscr.pl" - ) - sctk_opt = ["-p", gs.SCTK_PATH] if hasattr(gs, "SCTK_PATH") else [] + sctk_path = "" + if self.sctk_binary_path is not None: + sctk_path = self.sctk_binary_path.get_path() + elif hasattr(gs, "SCTK_PATH"): + sctk_path = gs.SCTK_PATH + hubscr_path = os.path.join( + sctk_path, "hubscr.pl" + ) # evaluates to just "hubscr.pl" if sctk_path is empty + + sctk_opt = ["-p", sctk_path] if sctk_path else [] ref = self.ref try: - ref = shutil.copy(tk.uncached_path(ref), ".") + ref = shutil.copy(ref.get_path(), ".") except shutil.SameFileError: pass hyp = self.hyp try: - hyp = shutil.copy(tk.uncached_path(hyp), ".") + hyp = shutil.copy(hyp.get_path(), ".") except shutil.SameFileError: pass sp.check_call( [hubscr_path, "-V", "-l", "english", "-h", "hub5"] + sctk_opt - + ["-g", tk.uncached_path(self.glm), "-r", ref, hyp] + + ["-g", self.glm.get_path(), "-r", ref, hyp] ) if move_files: # run as real job