Skip to content

Commit

Permalink
Merge pull request #740 from borglab/refactor/netvlad
Browse files Browse the repository at this point in the history
Initialize NetVLAD in constructor
  • Loading branch information
travisdriver authored Nov 16, 2023
2 parents 799ada9 + fbafd3c commit 6fd70ed
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions gtsfm/frontend/global_descriptor/netvlad_global_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np
import torch
from torch import nn

from gtsfm.common.image import Image
from gtsfm.frontend.global_descriptor.global_descriptor_base import GlobalDescriptorBase
from thirdparty.hloc.netvlad import NetVLAD
Expand All @@ -22,7 +22,7 @@ class NetVLADGlobalDescriptor(GlobalDescriptorBase):

def __init__(self) -> None:
""" """
pass
self._model = NetVLAD().eval()

def describe(self, image: Image) -> np.ndarray:
"""Compute the NetVLAD global descriptor for a single image query.
Expand All @@ -33,16 +33,13 @@ def describe(self, image: Image) -> np.ndarray:
Returns:
img_desc: Array of shape (D,) representing global image descriptor.
"""
# Load model.
# Note: Initializing in the constructor leads to OOM.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model: nn.Module = NetVLAD().to(device)
model.eval()
self._model.to(device)

img_tensor = (
torch.from_numpy(image.value_array).to(device).permute(2, 0, 1).unsqueeze(0).type(torch.float32) / 255
)
with torch.no_grad():
img_desc = model({"image": img_tensor})
img_desc = self._model({"image": img_tensor})

return img_desc["global_descriptor"].detach().squeeze().cpu().numpy()

0 comments on commit 6fd70ed

Please sign in to comment.