Skip to content

Commit

Permalink
allow setting return_type in predict_page
Browse files Browse the repository at this point in the history
  • Loading branch information
lolipopshock committed Jul 6, 2022
1 parent f6cd94b commit 5509e3b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/vila/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def predict_page(
visual_group_detector: Optional[Any] = None,
page_size=None,
batch_size=None,
) -> lp.Layout:
return_type: Optional[str] = "layout",
) -> Union[lp.Layout, List]:
"""The predict_page function is used for running the model on a single page
in the vila page_token objects.
Expand All @@ -187,8 +188,8 @@ def predict_page(
visual_group_detector:
The visual group model to use for detecting the required visual groups.
page_size (Tuple):
A tuple of (width, height) for this page. By default it will use the
page_size from the page_tokens directly unless the page_size is explicitly
A tuple of (width, height) for this page. By default it will use the
page_size from the page_tokens directly unless the page_size is explicitly
specified.
batch_size (Optional[int]):
Specifying the maximum number of batches for each model run.
Expand All @@ -198,7 +199,7 @@ def predict_page(
required_agg_level = self.preprocessor.config.agg_level
required_group = AGG_LEVEL_TO_GROUP_NAME[required_agg_level]

if not getattr(page_tokens, required_group + "s"): # either none or empty
if not getattr(page_tokens, required_group + "s"): # either none or empty
if page_image is not None and visual_group_detector is not None:
warnings.warn(
f"The required_group {required_group} is missing in page_tokens."
Expand All @@ -216,7 +217,7 @@ def predict_page(
page_data=pdf_data,
page_size=page_tokens.page_size if page_size is None else page_size,
batch_size=batch_size,
return_type="layout",
return_type=return_type,
)

return predicted_tokens
Expand Down

0 comments on commit 5509e3b

Please sign in to comment.