diff --git a/gtsfm/averaging/translation/averaging_1dsfm.py b/gtsfm/averaging/translation/averaging_1dsfm.py
index 50e4cf325..79fd3eea6 100644
--- a/gtsfm/averaging/translation/averaging_1dsfm.py
+++ b/gtsfm/averaging/translation/averaging_1dsfm.py
@@ -11,12 +11,15 @@
 Authors: Jing Wu, Ayush Baid, Akshay Krishnan
 """
 import time
+import timeit
 from collections import defaultdict
 from enum import Enum
-from typing import DefaultDict, Dict, List, Optional, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple
 
+import dask
 import gtsam
 import numpy as np
+from distributed.worker import get_client
 from gtsam import (
     BinaryMeasurementsPoint3,
     BinaryMeasurementPoint3,
@@ -54,15 +57,21 @@
 # Minimum number of measurements required for a track to be used for averaging.
 MIN_TRACK_MEASUREMENTS_FOR_AVERAGING = 3
 
 # Number of track measurements to be added for each camera. Can be reduced to 8 for speed at the cost of some accuracy.
 TRACKS_MEASUREMENTS_PER_CAMERA = 12
 
+# Heuristically set to limit the number of delayed tasks, as recommended by Dask:
+# https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-too-many-tasks
+MAX_DELAYED_CALLS = 16
+
 logger = logger_utils.get_logger()
 
 C = symbol_shorthand.A  # for camera translation variables
 L = symbol_shorthand.B  # for track (landmark) translation variables
 
 RelativeDirectionsDict = Dict[Tuple[int, int], Unit3]
 
+DUMMY_NOISE_MODEL = gtsam.noiseModel.Isotropic.Sigma(3, 1e-2)  # MFAS does not use this.
+
 
 class TranslationAveraging1DSFM(TranslationAveragingBase):
@@ -86,13 +95,16 @@ def __init__(
         use_tracks_for_averaging: bool = True,
         reject_outliers: bool = True,
         projection_sampling_method: ProjectionSamplingMethod = ProjectionSamplingMethod.SAMPLE_WITH_UNIFORM_DENSITY,
+        max_delayed_calls: int = MAX_DELAYED_CALLS,
    ) -> None:
         """Initializes the 1DSFM averaging instance.
 
         Args:
             robust_measurement_noise: Whether to use a robust noise model for the measurements, defaults to true.
+            use_tracks_for_averaging: whether to use direction measurements from point tracks (default True).
             reject_outliers: whether to perform outlier rejection with MFAS algorithm (default True).
             projection_sampling_method: ProjectionSamplingMethod to be used for directions to run 1DSfM.
+            max_delayed_calls: maximum number of concurrent delayed tasks to create (default MAX_DELAYED_CALLS).
         """
         super().__init__(robust_measurement_noise)
@@ -101,6 +113,7 @@ def __init__(
         self._reject_outliers = reject_outliers
         self._projection_sampling_method = projection_sampling_method
         self._use_tracks_for_averaging = use_tracks_for_averaging
+        self._max_delayed_calls = max_delayed_calls
 
     def __sample_projection_directions(
         self,
@@ -132,8 +145,8 @@
         return projections
 
+    @staticmethod
     def _binary_measurements_from_dict(
-        self,
         w_i2Ui1_dict: RelativeDirectionsDict,
         w_iUj_dict_tracks: RelativeDirectionsDict,
         noise_model: gtsam.noiseModel,
@@ -191,6 +204,24 @@ def get_prior_in_world_frame(i2, i2Ti1_prior):
             )
         return w_i1ti2_prior_measurements
 
+    @staticmethod
+    def run_mfas(
+        w_i2Ui1_dict: RelativeDirectionsDict,
+        w_iUj_dict_tracks: RelativeDirectionsDict,
+        directions: List[Unit3],
+    ) -> List[Dict[Tuple[int, int], float]]:
+        """Runs MFAS on a batch of projection directions, returning one outlier-weight dict per direction."""
+        w_i1Ui2_measurements = TranslationAveraging1DSFM._binary_measurements_from_dict(
+            w_i2Ui1_dict, w_iUj_dict_tracks, DUMMY_NOISE_MODEL
+        )
+        results = []
+        for direction in directions:
+            # Note: the output of MFAS::computeOutlierWeights must be converted to a Dict, since Dask
+            # does not know how to pickle KeyPairDoubleMap objects.
+            results.append(dict(MFAS(w_i1Ui2_measurements, direction).computeOutlierWeights()))
+
+        return results
+
     def compute_inliers(
         self,
         w_i2Ui1_dict: RelativeDirectionsDict,
@@ -214,24 +245,43 @@
         projection_directions = self.__sample_projection_directions(combined_measurements)
 
         # Convert to measurements: map indexes to symbols.
-        dummy_noise_model = gtsam.noiseModel.Isotropic.Sigma(3, 1e-2)  # MFAS does not use this.
-        w_i1Ui2_measurements = self._binary_measurements_from_dict(w_i2Ui1_dict, w_iUj_dict_tracks, dummy_noise_model)
+        w_i1Ui2_measurements = self._binary_measurements_from_dict(w_i2Ui1_dict, w_iUj_dict_tracks, DUMMY_NOISE_MODEL)
+
+        # Scatter the (large, read-only) measurement dicts to all workers if a client is available.
+        try:
+            client = get_client()
+            future_w_i2Ui1_dict = client.scatter(w_i2Ui1_dict, broadcast=True)
+            future_w_iUj_dict_tracks = client.scatter(w_iUj_dict_tracks, broadcast=True)
+        except ValueError:  # allows use without initializing a client.
+            logger.info("No Dask client found; running without scattering.")
+            future_w_i2Ui1_dict = w_i2Ui1_dict
+            future_w_iUj_dict_tracks = w_iUj_dict_tracks
+
+        # Batch the projection directions and generate one delayed MFAS task per batch.
+        batch_size = int(np.ceil(len(projection_directions) / self._max_delayed_calls))
+        batched_outlier_weights: List[Any] = []
+        for j in range(0, len(projection_directions), batch_size):
+            batched_outlier_weights.append(
+                dask.delayed(self.run_mfas)(
+                    future_w_i2Ui1_dict,
+                    future_w_iUj_dict_tracks,
+                    projection_directions[j : j + batch_size],
+                )
+            )
 
-        # Compute outlier weights using MFAS.
-        # TODO(ayush): parallelize this step.
-        outlier_weights: List[Dict[Tuple[int, int], float]] = []
-        for direction in projection_directions:
-            mfas_instance = MFAS(w_i1Ui2_measurements, direction)
-            outlier_weights.append(mfas_instance.computeOutlierWeights())
-        logger.debug("Computed outlier weights using MFAS.")
+        # Compute outlier weights in parallel.
+        start_time = timeit.default_timer()
+        batched_outlier_weights = dask.compute(*batched_outlier_weights)
+        logger.info("Computed outlier weights using MFAS in %.2f seconds.", timeit.default_timer() - start_time)
 
         # Compute average outlier weight.
         outlier_weights_sum: DefaultDict[Tuple[int, int], float] = defaultdict(float)
         inliers = set()
-        for outlier_weight_dict in outlier_weights:
-            for w_i1Ui2 in w_i1Ui2_measurements:
-                i1, i2 = w_i1Ui2.key1(), w_i1Ui2.key2()
-                outlier_weights_sum[(i1, i2)] += outlier_weight_dict[(i1, i2)]
+        for batch_outlier_weights in batched_outlier_weights:
+            for outlier_weight_dict in batch_outlier_weights:
+                for w_i1Ui2 in w_i1Ui2_measurements:
+                    i1, i2 = w_i1Ui2.key1(), w_i1Ui2.key2()
+                    outlier_weights_sum[(i1, i2)] += outlier_weight_dict[(i1, i2)]
         for (i1, i2), weight_sum in outlier_weights_sum.items():
             if weight_sum / len(projection_directions) < OUTLIER_WEIGHT_THRESHOLD:
                 inliers.add((i1, i2))
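
Note: the change above replaces the serial MFAS loop with roughly MAX_DELAYED_CALLS Dask tasks,
each handling a batch of projection directions, and broadcasts the shared measurement dicts to
every worker once via Client.scatter rather than embedding them in each task graph. Below is a
minimal, self-contained sketch of that batching-plus-scatter pattern; the process_batch function,
the local Client(), and the toy shared/items data are hypothetical stand-ins for run_mfas and the
measurement dicts, not part of this change.

    import dask
    import numpy as np
    from distributed import Client


    def process_batch(shared: np.ndarray, items: list) -> list:
        # Hypothetical stand-in for run_mfas: each item is cheap on its own,
        # so a whole batch is processed per task to amortize scheduler overhead.
        return [float(shared[i]) * 2.0 for i in items]


    if __name__ == "__main__":
        client = Client()  # local cluster; the real code falls back gracefully without one
        max_delayed_calls = 16  # mirrors MAX_DELAYED_CALLS above

        shared = np.arange(10_000, dtype=np.float64)  # large read-only input
        items = list(range(100))

        # Broadcast the shared input to every worker once, instead of serializing
        # it into each task graph (mirrors client.scatter(..., broadcast=True) above).
        shared_future = client.scatter(shared, broadcast=True)

        # One delayed call per batch, keeping the task count near max_delayed_calls.
        batch_size = int(np.ceil(len(items) / max_delayed_calls))
        delayed_batches = [
            dask.delayed(process_batch)(shared_future, items[j : j + batch_size])
            for j in range(0, len(items), batch_size)
        ]
        batched_results = dask.compute(*delayed_batches)

        # Flatten per-batch outputs, as compute_inliers does with its nested loop.
        results = [r for batch in batched_results for r in batch]
        assert results == [2.0 * i for i in range(100)]
        client.close()

Batching keeps the number of tasks small, as recommended by the Dask best-practices page linked
above, and scattering avoids re-shipping the measurement dicts with every task.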