Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update in inference.py -- omitting the optimizer's state #165

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 111 additions & 8 deletions src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,87 @@ def visualize_cells(img, cells, out_path):

return

def visualize_recognized_tables(img, rec_tables, out_path):
    """Overlay recognized table-structure objects on ``img`` and save the plot.

    Each entry of ``rec_tables`` is read through its ``'bbox'`` and ``'label'``
    keys. ``bbox`` is indexed as ``[x0, y0, x1, y1]``: the first two values are
    the rectangle origin and the remaining two give width/height by difference.
    Entries whose label is not one of the known classes are skipped.

    The figure (with a legend of all known classes) is written to ``out_path``
    and then closed. Nothing is returned.
    """
    # Shared styling for every class; only the colours differ per label.
    outline_alpha = 0.3
    outline_width = 2
    hatch_pattern = '//////'

    # (label, legend caption, facecolor, edgecolor) -- order fixes legend order.
    class_styles = (
        ('table', 'Table',
         (1, 0, 0.45), (1, 0, 0.45)),
        ('table column', 'Table Column',
         (0.95, 0.6, 0.1), (0.95, 0.6, 0.1)),
        ('table row', 'Table Row',
         (0.3, 0.74, 0.8), (0.3, 0.7, 0.6)),
        ('table column header', 'Table Column Header',
         (0.7, 0.3, 0.5), (0.7, 0.3, 0.5)),
        ('projected row header', 'Projected Row Header',
         (0.2, 0.2, 0.7), (0.2, 0.2, 0.7)),
        ('table spanning cell', 'Table Spanning Cell',
         (0.5, 0.83, 0.4), (0.5, 0.83, 0.4)),
    )

    plt.imshow(img, interpolation="lanczos")
    plt.gcf().set_size_inches(20, 20)
    axes = plt.gca()

    for obj in rec_tables:
        box = obj['bbox']
        obj_label = obj['label']

        # Equality-based lookup, mirroring a plain chain of == comparisons.
        colours = None
        for known_label, _caption, fill, stroke in class_styles:
            if obj_label == known_label:
                colours = (fill, stroke)
                break
        if colours is None:
            continue
        fill, stroke = colours

        # Three stacked rectangles: faint fill, solid outline, hatch texture.
        layers = (
            dict(linewidth=outline_width, edgecolor='none',
                 facecolor=fill, alpha=0.1),
            dict(linewidth=outline_width, edgecolor=stroke,
                 facecolor='none', linestyle='-', alpha=outline_alpha),
            dict(linewidth=0, edgecolor=stroke, facecolor='none',
                 linestyle='-', hatch=hatch_pattern, alpha=0.2),
        )
        for layer_kwargs in layers:
            axes.add_patch(
                patches.Rectangle(box[:2], box[2]-box[0], box[3]-box[1],
                                  **layer_kwargs)
            )

    plt.xticks([], [])
    plt.yticks([], [])

    legend_handles = [
        Patch(facecolor=fill, edgecolor=stroke, label=caption,
              hatch=hatch_pattern, alpha=0.3)
        for _known_label, caption, fill, stroke in class_styles
    ]
    plt.legend(handles=legend_handles, bbox_to_anchor=(0.5, -0.02),
               loc='upper center', borderaxespad=0, fontsize=10, ncol=2)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    plt.savefig(out_path, bbox_inches='tight', dpi=150)
    plt.close()

class TableExtractionPipeline(object):
def __init__(self, det_device=None, str_device=None,
det_model=None, str_model=None,
Expand All @@ -697,8 +778,16 @@ def __init__(self, det_device=None, str_device=None,
print("Detection model initialized.")

if not det_model_path is None:
self.det_model.load_state_dict(torch.load(det_model_path,
map_location=torch.device(det_device)))
loaded_state_dict = torch.load(det_model_path,
map_location=torch.device(det_device))
model_state_dict = self.det_model.state_dict()
pretrained_dict = {
k: v
for k, v in loaded_state_dict.items()
if k in model_state_dict and model_state_dict[k].shape == v.shape
}
model_state_dict.update(pretrained_dict)
self.det_model.load_state_dict(model_state_dict, strict=True)
self.det_model.to(det_device)
self.det_model.eval()
print("Detection model weights loaded.")
Expand All @@ -714,8 +803,16 @@ def __init__(self, det_device=None, str_device=None,
print("Structure model initialized.")

if not str_model_path is None:
self.str_model.load_state_dict(torch.load(str_model_path,
map_location=torch.device(str_device)))
loaded_state_dict = torch.load(str_model_path,
map_location=torch.device(str_device))
model_state_dict = self.str_model.state_dict()
pretrained_dict = {
k: v
for k, v in loaded_state_dict.items()
if k in model_state_dict and model_state_dict[k].shape == v.shape
}
model_state_dict.update(pretrained_dict)
self.str_model.load_state_dict(model_state_dict, strict=True)
self.str_model.to(str_device)
self.str_model.eval()
print("Structure model weights loaded.")
Expand Down Expand Up @@ -830,7 +927,11 @@ def output_result(key, val, args, img, img_file):
if args.visualize:
out_file = img_file.replace(".jpg", "_fig_tables.jpg")
out_path = os.path.join(args.out_dir, out_file)
visualize_detected_tables(img, val, out_path)
if args.mode == 'detect':
visualize_detected_tables(img, val, out_path)
else:
visualize_recognized_tables(img, val, out_path)

elif not key == 'image' and not key == 'tokens':
for idx, elem in enumerate(val):
if key == 'crops':
Expand All @@ -852,6 +953,8 @@ def output_result(key, val, args, img, img_file):
out_path = os.path.join(args.out_dir, out_file)
visualize_cells(img, elem, out_path)
else:
if elem is None:
continue
out_file = img_file.replace(".jpg", "_{}.{}".format(idx, key))
with open(os.path.join(args.out_dir, out_file), 'w') as f:
f.write(elem)
Expand Down Expand Up @@ -910,7 +1013,7 @@ def main():
tokens = []

if args.mode == 'recognize':
extracted_table = pipe.recognize(img, tokens, out_objects=args.objects, out_cells=args.csv,
extracted_table = pipe.recognize(img, tokens, out_objects=args.objects, out_cells=args.cells,
out_html=args.html, out_csv=args.csv)
print("Table(s) recognized.")

Expand All @@ -925,7 +1028,7 @@ def main():
output_result(key, val, args, img, img_file)

if args.mode == 'extract':
extracted_tables = pipe.extract(img, tokens, out_objects=args.objects, out_cells=args.csv,
extracted_tables = pipe.extract(img, tokens, out_objects=args.objects, out_cells=args.cells,
out_html=args.html, out_csv=args.csv,
crop_padding=args.crop_padding)
print("Table(s) extracted.")
Expand All @@ -936,4 +1039,4 @@ def main():
img_file.replace('.jpg', '_{}.jpg'.format(table_idx)))

if __name__ == "__main__":
main()
main()