diff --git a/GeneralRelativity/Interpolation.py b/GeneralRelativity/Interpolation.py index 44a11ea..8747801 100644 --- a/GeneralRelativity/Interpolation.py +++ b/GeneralRelativity/Interpolation.py @@ -5,6 +5,7 @@ import math import time from typing import Tuple +import torch.nn.functional as F def print_grid_lay_out( @@ -161,7 +162,9 @@ def __init__( # Calculate vector values and grid points indices for interp_point in self.relative_positions: - vecvals, grid_points_index = calculate_stencils(interp_point, num_points, max_degree) + vecvals, grid_points_index = calculate_stencils( + interp_point, num_points, max_degree + ) vecvals[np.abs(vecvals) < 1e-10] = 0 self.vecvals_array.append(vecvals.tolist()) self.grid_points_index_array.append(grid_points_index.tolist()) @@ -185,6 +188,140 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform the interpolation on the given tensor. + Parameters: + tensor (torch.Tensor): The input tensor to interpolate. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the interpolated tensor and the position tensor. + """ + + # Calculate the number of ghost points based on num_points + # Ghost points are used for padding or handling edges during interpolation + ghosts = int(math.ceil(self.num_points / 2)) + shape = tensor.shape + + # Initialize the tensor for storing interpolation results + # The output tensor will have modified spatial dimensions based on the number of ghost points + interpolation = torch.zeros( + shape[0], # batch size + shape[1], # number of channels + (shape[2] - 2 * ghosts) * 2 + 2, # modified x dimension + (shape[3] - 2 * ghosts) * 2 + 2, # modified y dimension + (shape[4] - 2 * ghosts) * 2 + 2, # modified z dimension + ) + + # Initialize a tensor to store positions + # This tensor keeps track of the positions in the interpolated space + position = torch.zeros( + (shape[2] - 2 * ghosts) * 2 + 2, # x dimension + (shape[3] - 2 * ghosts) * 2 + 2, # y dimension + (shape[4] - 2 * ghosts) * 2 + 2, # z dimension + 3, # 3D coordinates + ) + + # Iterate over the displacement, weight, and relative position information + for ( + displacements, + weights, + relative_index, + relative_position, + ) in zip( + self.grid_points_index_array, + self.vecvals_array, + self.relative_index_for_interpolated_array, + self.relative_positions, + ): + num_channels = tensor.shape[1] # Number of channels in the input tensor + kernel_size = self.num_points # Size of the convolutional kernel + + # Create a convolutional kernel with zeros + kernel = torch.zeros( + (num_channels, 1, kernel_size, kernel_size, kernel_size) + ) + + # Find the minimum index for displacements to adjust kernel indexing + min_index = torch.min(displacements) + + # Populate the kernel with weights according to displacements + for displacement, weight in zip(displacements, weights): + index = ( + displacement - min_index + ) # Adjust index based on minimum displacement + kernel[:, :, index[0], index[1], index[2]] = weight + + # Perform convolution using the created kernel + convoluted_tensor = F.conv3d(tensor, kernel, padding=0, groups=num_channels) + + # Update the interpolation tensor with the convolution results + # This is done selectively based on the relative index + interpolation[ + :, + :, + relative_index[0] :: 2, + relative_index[1] :: 2, + relative_index[2] :: 2, + ] = convoluted_tensor + + return interpolation + + def get_postion(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the position of the interpolated points. + + Parameters: + tensor (torch.tensor): The input tensor to interpolate. + + Returns: + torch.tensor: The interpolated tensor. + """ + + # Calculate the number of ghost points based on num_points + ghosts = int(math.ceil(self.num_points / 2)) + shape = tensor.shape + + # Initialize a tensor to store positions + position = torch.zeros( + (shape[2] - 2 * ghosts) * 2 + 2, + (shape[3] - 2 * ghosts) * 2 + 2, + (shape[4] - 2 * ghosts) * 2 + 2, + 3, + ) + + # Perform interpolation + for i in range(ghosts - 2, shape[2] - ghosts - 1): + for j in range(ghosts - 2, shape[3] - ghosts - 1): + for k in range(ghosts - 2, shape[4] - ghosts - 1): + index_for_input_array = torch.tensor([i, j, k]) + for ( + displacements, + weights, + relative_index, + relative_position, + ) in zip( + self.grid_points_index_array, + self.vecvals_array, + self.relative_index_for_interpolated_array, + self.relative_positions, + ): + result = 0 + + # Writing results to the interpolated array + ind = 2 * (index_for_input_array - (ghosts - 2)) + ( + relative_index + ) + # This array gives the position of the interpolated point in the interpolated array relative to the input array + position[ind[0], ind[1], ind[2]] = ( + index_for_input_array + relative_position + ) + + return position + + def non_vector_implementation( + self, tensor: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform the interpolation on the given tensor. + Parameters: tensor (torch.tensor): The input tensor to interpolate. @@ -214,9 +351,9 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ) # Perform interpolation - for i in range(ghosts - 1, shape[2] - ghosts-1): - for j in range(ghosts - 1, shape[3] - ghosts-1): - for k in range(ghosts - 1, shape[4] - ghosts-1): + for i in range(ghosts - 2, shape[2] - ghosts - 1): + for j in range(ghosts - 2, shape[3] - ghosts - 1): + for k in range(ghosts - 2, shape[4] - ghosts - 1): index_for_input_array = torch.tensor([i, j, k]) for ( displacements, @@ -238,7 +375,7 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ) # Writing results to the interpolated array - ind = 2 * (index_for_input_array - (ghosts - 1)) + ( + ind = 2 * (index_for_input_array - (ghosts - 2)) + ( relative_index ) interpolation[:, :, ind[0], ind[1], ind[2]] = result @@ -267,8 +404,9 @@ def plot_grid_position(self): pos = dx * np.array([i, j, k]) x[:, :, i, j, k] = self.sinusoidal_function(*pos) time1 = time.time() - interpolated, positions = self(x) + interpolated = self(x) print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec") + positions = self.get_postion(x) ghosts = int(math.ceil(6 / 2)) # Scatter plot plt.scatter(