diff --git a/src/torch_image_lerp/linear_interpolation_2d.py b/src/torch_image_lerp/linear_interpolation_2d.py index 12960b2..d755de5 100644 --- a/src/torch_image_lerp/linear_interpolation_2d.py +++ b/src/torch_image_lerp/linear_interpolation_2d.py @@ -114,7 +114,9 @@ def insert_into_image_2d( coordinates = coordinates.float() # only keep data and coordinates inside the image - in_image_idx = (coordinates >= 0) & (coordinates <= torch.tensor(image.shape) - 1) + in_image_idx = (coordinates >= 0) & ( + coordinates <= torch.tensor(image.shape, device=image.device) - 1 + ) in_image_idx = torch.all(in_image_idx, dim=-1) data, coordinates = data[in_image_idx], coordinates[in_image_idx] diff --git a/src/torch_image_lerp/linear_interpolation_3d.py b/src/torch_image_lerp/linear_interpolation_3d.py index eaef114..ef2dda4 100644 --- a/src/torch_image_lerp/linear_interpolation_3d.py +++ b/src/torch_image_lerp/linear_interpolation_3d.py @@ -114,7 +114,9 @@ def insert_into_image_3d( coordinates = coordinates.float() # only keep data and coordinates inside the volume - inside = (coordinates >= 0) & (coordinates <= torch.tensor(image.shape) - 1) + inside = (coordinates >= 0) & ( + coordinates <= torch.tensor(image.shape, device=image.device) - 1 + ) inside = torch.all(inside, dim=-1) data, coordinates = data[inside], coordinates[inside]