Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize MFAS #717

Merged
merged 15 commits into from
Nov 8, 2023
80 changes: 65 additions & 15 deletions gtsfm/averaging/translation/averaging_1dsfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
Authors: Jing Wu, Ayush Baid, Akshay Krishnan
"""
import time
import timeit
from collections import defaultdict
from enum import Enum
from typing import DefaultDict, Dict, List, Optional, Set, Tuple
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Any

import dask
import gtsam
import numpy as np
from distributed.worker import get_client
from gtsam import (
BinaryMeasurementsPoint3,
BinaryMeasurementPoint3,
Expand Down Expand Up @@ -54,15 +57,21 @@

# Minimum number of measurements required for a track to be used for averaging.
MIN_TRACK_MEASUREMENTS_FOR_AVERAGING = 3

# Number of track measurements to be added for each camera. Can be reduced to 8 for speed at the cost of some accuracy.
TRACKS_MEASUREMENTS_PER_CAMERA = 12

# Heuristically set to limit the number of delayed tasks, as recommended by Dask:
# https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-too-many-tasks
MAX_DELAYED_CALLS = 16

logger = logger_utils.get_logger()

C = symbol_shorthand.A # for camera translation variables
L = symbol_shorthand.B # for track (landmark) translation variables

RelativeDirectionsDict = Dict[Tuple[int, int], Unit3]
DUMMY_NOISE_MODEL = gtsam.noiseModel.Isotropic.Sigma(3, 1e-2) # MFAS does not use this.


class TranslationAveraging1DSFM(TranslationAveragingBase):
Expand All @@ -86,13 +95,16 @@ def __init__(
use_tracks_for_averaging: bool = True,
reject_outliers: bool = True,
projection_sampling_method: ProjectionSamplingMethod = ProjectionSamplingMethod.SAMPLE_WITH_UNIFORM_DENSITY,
max_delayed_calls: int = MAX_DELAYED_CALLS,
) -> None:
"""Initializes the 1DSFM averaging instance.

Args:
robust_measurement_noise: Whether to use a robust noise model for the measurements, defaults to true.
use_tracks_for_averaging:
reject_outliers: whether to perform outlier rejection with MFAS algorithm (default True).
projection_sampling_method: ProjectionSamplingMethod to be used for directions to run 1DSfM.
max_delayed_calls: Maximum number of concurrent delayed tasks to create.
"""
super().__init__(robust_measurement_noise)

Expand All @@ -101,6 +113,7 @@ def __init__(
self._reject_outliers = reject_outliers
self._projection_sampling_method = projection_sampling_method
self._use_tracks_for_averaging = use_tracks_for_averaging
self._max_delayed_calls = max_delayed_calls

def __sample_projection_directions(
self,
Expand Down Expand Up @@ -132,8 +145,8 @@ def __sample_projection_directions(

return projections

@staticmethod
def _binary_measurements_from_dict(
self,
w_i2Ui1_dict: RelativeDirectionsDict,
w_iUj_dict_tracks: RelativeDirectionsDict,
noise_model: gtsam.noiseModel,
Expand Down Expand Up @@ -191,6 +204,24 @@ def get_prior_in_world_frame(i2, i2Ti1_prior):
)
return w_i1ti2_prior_measurements

@staticmethod
def run_mfas(
w_i2Ui1_dict: RelativeDirectionsDict,
w_iUj_dict_tracks: RelativeDirectionsDict,
directions: List[Unit3],
) -> Dict[Tuple[int, int], float]:
"""Runs MFAS on a batch of directions."""
w_i1Ui2_measurements = TranslationAveraging1DSFM._binary_measurements_from_dict(
w_i2Ui1_dict, w_iUj_dict_tracks, DUMMY_NOISE_MODEL
)
results = []
for dir in directions:
# Note: Have to convert output of MFAS::computeOutlierWeights to Dict, as Dask has no instructions to pickle
# KeyPairDoubleMap objects.
results.append(dict(MFAS(w_i1Ui2_measurements, dir).computeOutlierWeights()))

return results

def compute_inliers(
self,
w_i2Ui1_dict: RelativeDirectionsDict,
Expand All @@ -214,24 +245,43 @@ def compute_inliers(
projection_directions = self.__sample_projection_directions(combined_measurements)

# Convert to measurements: map indexes to symbols.
dummy_noise_model = gtsam.noiseModel.Isotropic.Sigma(3, 1e-2) # MFAS does not use this.
w_i1Ui2_measurements = self._binary_measurements_from_dict(w_i2Ui1_dict, w_iUj_dict_tracks, dummy_noise_model)
w_i1Ui2_measurements = self._binary_measurements_from_dict(w_i2Ui1_dict, w_iUj_dict_tracks, DUMMY_NOISE_MODEL)

# Scatter data to all workers if client available.
try:
client = get_client()
future_w_i2Ui1_dict = client.scatter(w_i2Ui1_dict, broadcast=True)
future_w_iUj_dict_tracks = client.scatter(w_iUj_dict_tracks, broadcast=True)
except ValueError: # this should only happen for unit tests.
logger.info("No Dask client found... Running without scattering.")
future_w_i2Ui1_dict = w_i2Ui1_dict
future_w_iUj_dict_tracks = w_iUj_dict_tracks

# Loop through tracks and and generate delayed MFAS tasks.
batch_size = int(np.ceil(len(projection_directions) / self._max_delayed_calls))
batched_outlier_weights: List[Any] = []
for j in range(0, len(projection_directions), batch_size):
batched_outlier_weights.append(
dask.delayed(self.run_mfas)(
future_w_i2Ui1_dict,
future_w_iUj_dict_tracks,
projection_directions[j : j + batch_size],
)
)

# Compute outlier weights using MFAS.
# TODO(ayush): parallelize this step.
outlier_weights: List[Dict[Tuple[int, int], float]] = []
for direction in projection_directions:
mfas_instance = MFAS(w_i1Ui2_measurements, direction)
outlier_weights.append(mfas_instance.computeOutlierWeights())
logger.debug("Computed outlier weights using MFAS.")
# Compute outlier weights in parallel.
_t2 = timeit.default_timer()
batched_outlier_weights = dask.compute(*batched_outlier_weights)
johnwlambert marked this conversation as resolved.
Show resolved Hide resolved
logger.info("Computed outlier weights using MFAS in %.2f seconds." % (timeit.default_timer() - _t2))

# Compute average outlier weight.
outlier_weights_sum: DefaultDict[Tuple[int, int], float] = defaultdict(float)
inliers = set()
for outlier_weight_dict in outlier_weights:
for w_i1Ui2 in w_i1Ui2_measurements:
i1, i2 = w_i1Ui2.key1(), w_i1Ui2.key2()
outlier_weights_sum[(i1, i2)] += outlier_weight_dict[(i1, i2)]
for batch_outlier_weights in batched_outlier_weights:
for outlier_weight_dict in batch_outlier_weights:
for w_i1Ui2 in w_i1Ui2_measurements:
i1, i2 = w_i1Ui2.key1(), w_i1Ui2.key2()
outlier_weights_sum[(i1, i2)] += outlier_weight_dict[(i1, i2)]
for (i1, i2), weight_sum in outlier_weights_sum.items():
if weight_sum / len(projection_directions) < OUTLIER_WEIGHT_THRESHOLD:
inliers.add((i1, i2))
Expand Down
Loading