Skip to content

Commit

Permalink
Optimised interpolation by 50x
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 25, 2024
1 parent 681b21c commit 5a7afbc
Showing 1 changed file with 144 additions and 6 deletions.
150 changes: 144 additions & 6 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import time
from typing import Tuple
import torch.nn.functional as F


def print_grid_lay_out(
Expand Down Expand Up @@ -161,7 +162,9 @@ def __init__(

# Calculate vector values and grid points indices
for interp_point in self.relative_positions:
vecvals, grid_points_index = calculate_stencils(interp_point, num_points, max_degree)
vecvals, grid_points_index = calculate_stencils(
interp_point, num_points, max_degree
)
vecvals[np.abs(vecvals) < 1e-10] = 0
self.vecvals_array.append(vecvals.tolist())
self.grid_points_index_array.append(grid_points_index.tolist())
Expand All @@ -185,6 +188,140 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform the interpolation on the given tensor.
Parameters:
tensor (torch.Tensor): The input tensor to interpolate.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the interpolated tensor and the position tensor.
"""

# Calculate the number of ghost points based on num_points
# Ghost points are used for padding or handling edges during interpolation
ghosts = int(math.ceil(self.num_points / 2))
shape = tensor.shape

# Initialize the tensor for storing interpolation results
# The output tensor will have modified spatial dimensions based on the number of ghost points
interpolation = torch.zeros(
shape[0], # batch size
shape[1], # number of channels
(shape[2] - 2 * ghosts) * 2 + 2, # modified x dimension
(shape[3] - 2 * ghosts) * 2 + 2, # modified y dimension
(shape[4] - 2 * ghosts) * 2 + 2, # modified z dimension
)

# Initialize a tensor to store positions
# This tensor keeps track of the positions in the interpolated space
position = torch.zeros(
(shape[2] - 2 * ghosts) * 2 + 2, # x dimension
(shape[3] - 2 * ghosts) * 2 + 2, # y dimension
(shape[4] - 2 * ghosts) * 2 + 2, # z dimension
3, # 3D coordinates
)

# Iterate over the displacement, weight, and relative position information
for (
displacements,
weights,
relative_index,
relative_position,
) in zip(
self.grid_points_index_array,
self.vecvals_array,
self.relative_index_for_interpolated_array,
self.relative_positions,
):
num_channels = tensor.shape[1] # Number of channels in the input tensor
kernel_size = self.num_points # Size of the convolutional kernel

# Create a convolutional kernel with zeros
kernel = torch.zeros(
(num_channels, 1, kernel_size, kernel_size, kernel_size)
)

# Find the minimum index for displacements to adjust kernel indexing
min_index = torch.min(displacements)

# Populate the kernel with weights according to displacements
for displacement, weight in zip(displacements, weights):
index = (
displacement - min_index
) # Adjust index based on minimum displacement
kernel[:, :, index[0], index[1], index[2]] = weight

# Perform convolution using the created kernel
convoluted_tensor = F.conv3d(tensor, kernel, padding=0, groups=num_channels)

# Update the interpolation tensor with the convolution results
# This is done selectively based on the relative index
interpolation[
:,
:,
relative_index[0] :: 2,
relative_index[1] :: 2,
relative_index[2] :: 2,
] = convoluted_tensor

return interpolation

def get_postion(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the position of the interpolated points.
Parameters:
tensor (torch.tensor): The input tensor to interpolate.
Returns:
torch.tensor: The interpolated tensor.
"""

# Calculate the number of ghost points based on num_points
ghosts = int(math.ceil(self.num_points / 2))
shape = tensor.shape

# Initialize a tensor to store positions
position = torch.zeros(
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
3,
)

# Perform interpolation
for i in range(ghosts - 2, shape[2] - ghosts - 1):
for j in range(ghosts - 2, shape[3] - ghosts - 1):
for k in range(ghosts - 2, shape[4] - ghosts - 1):
index_for_input_array = torch.tensor([i, j, k])
for (
displacements,
weights,
relative_index,
relative_position,
) in zip(
self.grid_points_index_array,
self.vecvals_array,
self.relative_index_for_interpolated_array,
self.relative_positions,
):
result = 0

# Writing results to the interpolated array
ind = 2 * (index_for_input_array - (ghosts - 2)) + (
relative_index
)
# This array gives the position of the interpolated point in the interpolated array relative to the input array
position[ind[0], ind[1], ind[2]] = (
index_for_input_array + relative_position
)

return position

def non_vector_implementation(
self, tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform the interpolation on the given tensor.
Parameters:
tensor (torch.tensor): The input tensor to interpolate.
Expand Down Expand Up @@ -214,9 +351,9 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
)

# Perform interpolation
for i in range(ghosts - 1, shape[2] - ghosts-1):
for j in range(ghosts - 1, shape[3] - ghosts-1):
for k in range(ghosts - 1, shape[4] - ghosts-1):
for i in range(ghosts - 2, shape[2] - ghosts - 1):
for j in range(ghosts - 2, shape[3] - ghosts - 1):
for k in range(ghosts - 2, shape[4] - ghosts - 1):
index_for_input_array = torch.tensor([i, j, k])
for (
displacements,
Expand All @@ -238,7 +375,7 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
)

# Writing results to the interpolated array
ind = 2 * (index_for_input_array - (ghosts - 1)) + (
ind = 2 * (index_for_input_array - (ghosts - 2)) + (
relative_index
)
interpolation[:, :, ind[0], ind[1], ind[2]] = result
Expand Down Expand Up @@ -267,8 +404,9 @@ def plot_grid_position(self):
pos = dx * np.array([i, j, k])
x[:, :, i, j, k] = self.sinusoidal_function(*pos)
time1 = time.time()
interpolated, positions = self(x)
interpolated = self(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")
positions = self.get_postion(x)
ghosts = int(math.ceil(6 / 2))
# Scatter plot
plt.scatter(
Expand Down

0 comments on commit 5a7afbc

Please sign in to comment.