From 70ff30533f36d04fa7e52909eecf78f98c2ff9d0 Mon Sep 17 00:00:00 2001 From: Vaibhav Hiwase Date: Wed, 27 Dec 2023 17:46:47 +0530 Subject: [PATCH 1/2] Update inference.py When loading a trained model for inference, copy only the checkpoint entries whose key and tensor shape match the model's own state dict, and ignore everything else stored in the checkpoint, such as the optimizer's state. This is useful when a pre-trained model is used for inference or transfer learning, where the optimizer's state is not needed. --- src/inference.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/inference.py b/src/inference.py index 564dd76..2a46056 100644 --- a/src/inference.py +++ b/src/inference.py @@ -697,8 +697,16 @@ def __init__(self, det_device=None, str_device=None, print("Detection model initialized.") if not det_model_path is None: - self.det_model.load_state_dict(torch.load(det_model_path, - map_location=torch.device(det_device))) + loaded_state_dict = torch.load(det_model_path, + map_location=torch.device(det_device)) + model_state_dict = self.det_model.state_dict() + pretrained_dict = { + k: v + for k, v in loaded_state_dict.items() + if k in model_state_dict and model_state_dict[k].shape == v.shape + } + model_state_dict.update(pretrained_dict) + self.det_model.load_state_dict(model_state_dict, strict=True) self.det_model.to(det_device) self.det_model.eval() print("Detection model weights loaded.") @@ -714,8 +722,16 @@ def __init__(self, det_device=None, str_device=None, print("Structure model initialized.") if not str_model_path is None: - self.str_model.load_state_dict(torch.load(str_model_path, - map_location=torch.device(str_device))) + loaded_state_dict = torch.load(str_model_path, + map_location=torch.device(str_device)) + model_state_dict = self.str_model.state_dict() + pretrained_dict = { + k: v + for k, v in loaded_state_dict.items() + if k in model_state_dict and 
model_state_dict[k].shape == v.shape + } + model_state_dict.update(pretrained_dict) + self.str_model.load_state_dict(model_state_dict, strict=True) self.str_model.to(str_device) self.str_model.eval() print("Structure model weights loaded.") @@ -936,4 +952,4 @@ def main(): img_file.replace('.jpg', '_{}.jpg'.format(table_idx))) if __name__ == "__main__": - main() \ No newline at end of file + main() From 50be367f76618099a0ac0ae5b522b498250335c2 Mon Sep 17 00:00:00 2001 From: Vaibhav Hiwase Date: Wed, 3 Jan 2024 13:14:03 +0530 Subject: [PATCH 2/2] adding visualization for args.mode == 'recognize' --- src/inference.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/src/inference.py b/src/inference.py index 2a46056..95dbad1 100644 --- a/src/inference.py +++ b/src/inference.py @@ -671,6 +671,87 @@ def visualize_cells(img, cells, out_path): return +def visualize_recognized_tables(img, rec_tables, out_path): + plt.imshow(img, interpolation="lanczos") + plt.gcf().set_size_inches(20, 20) + ax = plt.gca() + + for rec_table in rec_tables: + bbox = rec_table['bbox'] + + if rec_table['label'] == 'table': + facecolor = (1, 0, 0.45) + edgecolor = (1, 0, 0.45) + alpha = 0.3 + linewidth = 2 + hatch='//////' + elif rec_table['label'] == 'table column': + facecolor = (0.95, 0.6, 0.1) + edgecolor = (0.95, 0.6, 0.1) + alpha = 0.3 + linewidth = 2 + hatch='//////' + elif rec_table['label'] == 'table row': + facecolor = (0.3, 0.74, 0.8) + edgecolor = (0.3, 0.7, 0.6) + alpha = 0.3 + linewidth = 2 + hatch='//////' + elif rec_table['label'] == 'table column header': + facecolor = (0.7, 0.3, 0.5) + edgecolor = (0.7, 0.3, 0.5) + alpha = 0.3 + linewidth = 2 + hatch='//////' + elif rec_table['label'] == 'projected row header': + facecolor = (0.2, 0.2, 0.7) + edgecolor = (0.2, 0.2, 0.7) + alpha = 0.3 + linewidth = 2 + hatch='//////' + elif rec_table['label'] == 'table spanning cell': + facecolor = (0.5, 0.83, 0.4) + edgecolor = 
(0.5, 0.83, 0.4) + alpha = 0.3 + linewidth = 2 + hatch='//////' + else: + continue + + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor='none',facecolor=facecolor, alpha=0.1) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor=edgecolor,facecolor='none',linestyle='-', alpha=alpha) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0, + edgecolor=edgecolor,facecolor='none',linestyle='-', hatch=hatch, alpha=0.2) + ax.add_patch(rect) + + plt.xticks([], []) + plt.yticks([], []) + + legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45), + label='Table', hatch='//////', alpha=0.3), + Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1), + label='Table Column', hatch='//////', alpha=0.3), + Patch(facecolor=(0.3, 0.74, 0.8), edgecolor=(0.3, 0.7, 0.6), + label='Table Row', hatch='//////', alpha=0.3), + Patch(facecolor=(0.7, 0.3, 0.5), edgecolor=(0.7, 0.3, 0.5), + label='Table Column Header', hatch='//////', alpha=0.3), + Patch(facecolor=(0.2, 0.2, 0.7), edgecolor=(0.2, 0.2, 0.7), + label='Projected Row Header', hatch='//////', alpha=0.3), + Patch(facecolor=(0.5, 0.83, 0.4), edgecolor=(0.5, 0.83, 0.4), + label='Table Spanning Cell', hatch='//////', alpha=0.3)] + plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0, + fontsize=10, ncol=2) + plt.gcf().set_size_inches(10, 10) + plt.axis('off') + plt.savefig(out_path, bbox_inches='tight', dpi=150) + plt.close() + + return + class TableExtractionPipeline(object): def __init__(self, det_device=None, str_device=None, det_model=None, str_model=None, @@ -846,7 +927,11 @@ def output_result(key, val, args, img, img_file): if args.visualize: out_file = img_file.replace(".jpg", "_fig_tables.jpg") out_path = os.path.join(args.out_dir, out_file) - visualize_detected_tables(img, val, out_path) 
+ if args.mode == 'detect': + visualize_detected_tables(img, val, out_path) + else: + visualize_recognized_tables(img, val, out_path) + elif not key == 'image' and not key == 'tokens': for idx, elem in enumerate(val): if key == 'crops': @@ -868,6 +953,8 @@ def output_result(key, val, args, img, img_file): out_path = os.path.join(args.out_dir, out_file) visualize_cells(img, elem, out_path) else: + if elem is None: + continue out_file = img_file.replace(".jpg", "_{}.{}".format(idx, key)) with open(os.path.join(args.out_dir, out_file), 'w') as f: f.write(elem) @@ -926,7 +1013,7 @@ def main(): tokens = [] if args.mode == 'recognize': - extracted_table = pipe.recognize(img, tokens, out_objects=args.objects, out_cells=args.csv, + extracted_table = pipe.recognize(img, tokens, out_objects=args.objects, out_cells=args.cells, out_html=args.html, out_csv=args.csv) print("Table(s) recognized.") @@ -941,7 +1028,7 @@ def main(): output_result(key, val, args, img, img_file) if args.mode == 'extract': - extracted_tables = pipe.extract(img, tokens, out_objects=args.objects, out_cells=args.csv, + extracted_tables = pipe.extract(img, tokens, out_objects=args.objects, out_cells=args.cells, out_html=args.html, out_csv=args.csv, crop_padding=args.crop_padding) print("Table(s) extracted.")