Add Essential Matrix Optimization
travisdriver committed Oct 15, 2024
1 parent 53aa1b2 commit 853cfdf
Showing 7 changed files with 147 additions and 13 deletions.
14 changes: 14 additions & 0 deletions 2024_04_16_15_23_29_545171_aggregated_results_table.tsv

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions 2024_04_18_18_10_46_790577_aggregated_results_table.tsv

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions 2024_04_20_15_47_13_477451_aggregated_results_table.tsv

Large diffs are not rendered by default.

104 changes: 98 additions & 6 deletions gtsfm/view_graph_estimator/cycle_consistent_rotation_estimator.py
@@ -9,16 +9,17 @@
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
from gtsam import Cal3Bundler, Rot3, Unit3

import gtsam
import gtsfm.utils.geometry_comparisons as comp_utils
import gtsfm.utils.graph as graph_utils
import gtsfm.utils.logger as logger_utils
import matplotlib.pyplot as plt
import numpy as np
from gtsam import Cal3Bundler, Rot3, Unit3
from gtsfm.common.keypoints import Keypoints
from gtsfm.two_view_estimator import TwoViewEstimationReport
from gtsfm.view_graph_estimator.view_graph_estimator_base import ViewGraphEstimatorBase
from gtsfm.view_graph_estimator.view_graph_estimator_base import \
ViewGraphEstimatorBase

logger = logger_utils.get_logger()

@@ -28,6 +29,8 @@
# threshold for evaluation w.r.t. GT
MAX_INLIER_MEASUREMENT_ERROR_DEG = 5.0

E = gtsam.symbol_shorthand.E # essential matrix


class EdgeErrorAggregationCriterion(str, Enum):
"""Aggregate cycle errors over each edge by choosing one of the following summary statistics:
@@ -108,6 +111,10 @@ def run(

logger.info("Number of triplets: %d" % len(triplets))

i2Ri1_dict, i2Ui1_dict = self.optimize_essential_matrices(
triplets, i2Ri1_dict, i2Ui1_dict, calibrations, corr_idxs_i1i2, keypoints
)

per_edge_errors = defaultdict(list)
cycle_errors: List[float] = []
max_gt_error_in_cycle = []
@@ -149,7 +156,7 @@ def run(
duration_sec,
)

return valid_edges
return valid_edges, i2Ri1_dict, i2Ui1_dict

def __save_plots(
self,
@@ -232,3 +239,88 @@ def __aggregate_errors_for_edge(self, edge_errors: List[float]) -> float:
return np.amin(edge_errors)
elif self._edge_error_aggregation_criterion == EdgeErrorAggregationCriterion.MEDIAN_EDGE_ERROR:
return np.median(edge_errors)


def optimize_essential_matrices(
    self,
    triplets: List[Tuple[int, int, int]],
    i2Ri1_dict: Dict[Tuple[int, int], Rot3],
    i2Ui1_dict: Dict[Tuple[int, int], Unit3],
    calibrations: List[Cal3Bundler],
    corr_idxs_i1i2: Dict[Tuple[int, int], np.ndarray],
    keypoints: List[Keypoints],
) -> Tuple[Dict[Tuple[int, int], Rot3], Dict[Tuple[int, int], Unit3]]:
    """Refine relative rotations and translation directions over all triplet edges.

    Each edge that appears in at least one triplet gets an EssentialMatrix variable in a
    factor graph, constrained by one EssentialMatrixFactor per calibrated correspondence,
    and the jointly optimized rotations and translation directions are written back.

    Returns:
        Updated i2Ri1_dict and i2Ui1_dict containing the optimized values.
    """
    # Create a factor graph container.
    graph = gtsam.NonlinearFactorGraph()

    # Add essential matrix factors for every edge of every triplet, reusing the variable
    # key when an edge is shared by multiple triplets.
    noise_model = gtsam.noiseModel.Isotropic.Sigma(1, 1.0)
    pair_to_key = {}
    for i0, i1, i2 in triplets:
        for pair in ((i0, i1), (i0, i2), (i1, i2)):
            corr_idxs = corr_idxs_i1i2[pair]
            mkps_a = keypoints[pair[0]].coordinates[corr_idxs[:, 0]]
            mkps_b = keypoints[pair[1]].coordinates[corr_idxs[:, 1]]
            if pair in pair_to_key:
                key = pair_to_key[pair]
            else:
                key = E(len(pair_to_key))
                pair_to_key[pair] = key
            for kp_a, kp_b in zip(mkps_a, mkps_b):
                graph.add(
                    gtsam.EssentialMatrixFactor(
                        key, calibrations[pair[1]].calibrate(kp_b), calibrations[pair[0]].calibrate(kp_a), noise_model
                    )
                )

    # Create initial estimates from the current two-view geometry. Only pairs that appear
    # in at least one triplet have a variable in the graph.
    initial = gtsam.Values()
    for pair, key in pair_to_key.items():
        initial.insert(key, gtsam.EssentialMatrix(i2Ri1_dict[pair], i2Ui1_dict[pair]))

    # Optimize!
    params = gtsam.LevenbergMarquardtParams()
    params.setVerbosity("ERROR")
    optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
    result = optimizer.optimize()

    # Write the optimized rotations and translation directions back to the dictionaries.
    for pair, key in pair_to_key.items():
        E_opt = result.atEssentialMatrix(key)
        i2Ri1_dict[pair] = E_opt.rotation()
        i2Ui1_dict[pair] = E_opt.direction()

    return i2Ri1_dict, i2Ui1_dict
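
For readers unfamiliar with the GTSAM pattern used in optimize_essential_matrices, here is a minimal standalone sketch of the same idea for a single image pair: build one EssentialMatrixFactor per calibrated correspondence, initialize a single EssentialMatrix variable, and refine it with Levenberg-Marquardt. It is not part of the commit; the relative geometry, synthetic points, and noise sigma below are invented for illustration.

import gtsam
import numpy as np

E_KEY = gtsam.symbol_shorthand.E(0)

# Invented ground-truth relative geometry: rotation aRb and translation aTb (frame b -> frame a).
aRb = gtsam.Rot3.RzRyRx(0.05, -0.10, 0.02)
aTb = np.array([1.0, 0.1, -0.2])

# Synthesize calibrated correspondences from 3D points expressed in frame b.
rng = np.random.default_rng(0)
points_b = rng.uniform(low=[-1.0, -1.0, 4.0], high=[1.0, 1.0, 8.0], size=(12, 3))
graph = gtsam.NonlinearFactorGraph()
noise_model = gtsam.noiseModel.Isotropic.Sigma(1, 1e-3)
for P_b in points_b:
    P_a = aRb.matrix() @ P_b + aTb  # transfer the point into frame a
    p_a = P_a[:2] / P_a[2]  # normalized (calibrated) coordinates in camera a
    p_b = P_b[:2] / P_b[2]  # normalized (calibrated) coordinates in camera b
    graph.add(gtsam.EssentialMatrixFactor(E_KEY, p_a, p_b, noise_model))

# Initialize with a perturbed essential matrix and refine with Levenberg-Marquardt.
initial = gtsam.Values()
R_init = aRb.compose(gtsam.Rot3.Expmap(np.array([0.02, -0.01, 0.03])))
initial.insert(E_KEY, gtsam.EssentialMatrix(R_init, gtsam.Unit3(aTb + 0.05)))
result = gtsam.LevenbergMarquardtOptimizer(graph, initial).optimize()
E_refined = result.atEssentialMatrix(E_KEY)
print(E_refined.rotation().ypr(), E_refined.direction().point3())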
11 changes: 5 additions & 6 deletions gtsfm/view_graph_estimator/view_graph_estimator_base.py
@@ -7,19 +7,18 @@
Authors: Akshay Krishnan, Ayush Baid, John Lambert
"""
import abc
import os
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import dask
import os
import numpy as np
from dask.delayed import Delayed
from gtsam import Cal3Bundler, Rot3, Unit3

import gtsfm.common.types as gtsfm_types
import gtsfm.utils.graph as graph_utils
import gtsfm.utils.logger as logger_utils
import gtsfm.utils.metrics as metrics_utils
import numpy as np
from dask.delayed import Delayed
from gtsam import Cal3Bundler, Rot3, Unit3
from gtsfm.common.keypoints import Keypoints
from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup
from gtsfm.two_view_estimator import TwoViewEstimationReport
@@ -302,7 +301,7 @@ def create_computation_graph(
)

# Run view graph estimation.
view_graph_edges = dask.delayed(self.run)(
view_graph_edges, i2Ri1_valid_dict, i2Ui1_valid_dict = dask.delayed(self.run, nout=3)(
i2Ri1_dict=i2Ri1_valid_dict,
i2Ui1_dict=i2Ui1_valid_dict,
calibrations=calibrations,
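
Aside on the dask change above: dask.delayed(fn, nout=N) declares that the wrapped call returns N values, so the single Delayed result can be unpacked into N separate Delayed objects before compute() runs; that is why the call now uses nout=3, since run also returns the refined rotation and translation-direction dictionaries. A tiny self-contained sketch with made-up names:

import dask

def run_view_graph(x):
    # Toy stand-in for a task that, like the updated run() above, returns three values.
    return {"edges": {(0, 1)}}, {"i2Ri1": x}, {"i2Ui1": -x}

# nout=3 declares the number of outputs, so the single delayed call can be
# unpacked lazily into three separate Delayed objects before compute().
edges, rotations, directions = dask.delayed(run_view_graph, nout=3)(7)
print(dask.compute(edges, rotations, directions))
# -> ({'edges': {(0, 1)}}, {'i2Ri1': 7}, {'i2Ui1': -7})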
1 change: 1 addition & 0 deletions thirdparty/RoMa
Submodule RoMa added at fa66db
