diff --git a/src/vila/predictors.py b/src/vila/predictors.py index c96e46f..dbc52e5 100644 --- a/src/vila/predictors.py +++ b/src/vila/predictors.py @@ -177,7 +177,8 @@ def predict_page( visual_group_detector: Optional[Any] = None, page_size=None, batch_size=None, - ) -> lp.Layout: + return_type: Optional[str] = "layout", + ) -> Union[lp.Layout, List]: """The predict_page function is used for running the model on a single page in the vila page_token objects. @@ -187,8 +188,8 @@ def predict_page( visual_group_detector: The visual group model to use for detecting the required visual groups. page_size (Tuple): - A tuple of (width, height) for this page. By default it will use the - page_size from the page_tokens directly unless the page_size is explicitly + A tuple of (width, height) for this page. By default it will use the + page_size from the page_tokens directly unless the page_size is explicitly specified. batch_size (Optional[int]): Specifying the maximum number of batches for each model run. @@ -198,7 +199,7 @@ def predict_page( required_agg_level = self.preprocessor.config.agg_level required_group = AGG_LEVEL_TO_GROUP_NAME[required_agg_level] - if not getattr(page_tokens, required_group + "s"): # either none or empty + if not getattr(page_tokens, required_group + "s"): # either none or empty if page_image is not None and visual_group_detector is not None: warnings.warn( f"The required_group {required_group} is missing in page_tokens." @@ -216,7 +217,7 @@ def predict_page( page_data=pdf_data, page_size=page_tokens.page_size if page_size is None else page_size, batch_size=batch_size, - return_type="layout", + return_type=return_type, ) return predicted_tokens