diff --git a/GeneralRelativity/Interpolation.py b/GeneralRelativity/Interpolation.py index 0789088..a3be31f 100644 --- a/GeneralRelativity/Interpolation.py +++ b/GeneralRelativity/Interpolation.py @@ -79,8 +79,13 @@ def calculate_stencils( """ dx = 1 + + # Not sure how to handle odd number of points, so assert that it is even + assert num_points % 2 == 0 + # Shift index to center around zero - half = int(np.floor(float(max_degree) / 2.0)) + + half = int(np.floor(float(num_points) / 2.0)) - 1 # Generate 3D meshgrid for coarse grid points coarse_grid_x, coarse_grid_y, coarse_grid_z = np.meshgrid( @@ -88,6 +93,7 @@ def calculate_stencils( np.arange(0 - half, num_points - half), np.arange(0 - half, num_points - half), ) + coarse_grid_points_index = np.vstack( [coarse_grid_x.ravel(), coarse_grid_y.ravel(), coarse_grid_z.ravel()] ).T @@ -160,6 +166,7 @@ def __init__( self.vecvals_array = [] # Vector values for interpolation self.grid_points_index_array = [] # Grid points indices for interpolation self.num_channels = num_channels + tol = 1e-10 # Cutting weight values below this tolerance, assumtion is that they are just numerical noise # Define fixed values for grid alignment if align_grids_with_lower_dim_values: @@ -173,7 +180,7 @@ def __init__( vecvals, grid_points_index = calculate_stencils( interp_point, num_points, max_degree ) - vecvals[np.abs(vecvals) < 1e-10] = 0 + vecvals[np.abs(vecvals) < tol] = 0 self.vecvals_array.append(vecvals.tolist()) self.grid_points_index_array.append(grid_points_index.tolist()) @@ -205,7 +212,7 @@ def __init__( # Create a convolutional kernel with zeros kernel = torch.zeros( - (num_channels, num_channels, kernel_size, kernel_size, kernel_size) + (num_channels, 1, kernel_size, kernel_size, kernel_size) ) # Find the minimum index for displacements to adjust kernel indexing @@ -218,11 +225,11 @@ def __init__( ) # Adjust index based on minimum displacement kernel[:, :, index[0], index[1], index[2]] = weight - conv_layer = nn.Conv3d(num_channels, num_channels, kernel_size) + conv_layer = nn.Conv3d( + num_channels, num_channels, kernel_size, groups=num_channels, bias=False + ) conv_layer.weight = nn.Parameter(kernel) - conv_layer.bias = nn.Parameter(torch.zeros(num_channels)) conv_layer.weight.requires_grad = learnable - conv_layer.bias.requires_grad = learnable self.conv_layers.append(conv_layer) def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -274,7 +281,6 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: self.conv_layers, ): convoluted_tensor = conv_layer(tensor) - # Update the interpolation tensor with the convolution results # This is done selectively based on the relative index interpolation[ @@ -311,9 +317,9 @@ def get_postion(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] ) # Perform interpolation - for i in range(ghosts - 2, shape[2] - ghosts - 1): - for j in range(ghosts - 2, shape[3] - ghosts - 1): - for k in range(ghosts - 2, shape[4] - ghosts - 1): + for i in range(ghosts - 1, shape[2] - ghosts): + for j in range(ghosts - 1, shape[3] - ghosts): + for k in range(ghosts - 1, shape[4] - ghosts): index_for_input_array = torch.tensor([i, j, k]) for ( displacements, @@ -329,12 +335,12 @@ def get_postion(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] result = 0 # Writing results to the interpolated array - ind = 2 * (index_for_input_array - (ghosts - 2)) + ( + ind = 2 * (index_for_input_array - (ghosts - 1)) + ( relative_index ) # This array gives the position of the interpolated point in the interpolated array relative to the input array position[ind[0], ind[1], ind[2]] = ( - index_for_input_array + relative_position + 1 + index_for_input_array + relative_position ) return position @@ -374,9 +380,9 @@ def non_vector_implementation( ) # Perform interpolation - for i in range(ghosts - 2, shape[2] - ghosts - 1): - for j in range(ghosts - 2, shape[3] - ghosts - 1): - for k in range(ghosts - 2, shape[4] - ghosts - 1): + for i in range(ghosts - 1, shape[2] - ghosts): + for j in range(ghosts - 1, shape[3] - ghosts): + for k in range(ghosts - 1, shape[4] - ghosts): index_for_input_array = torch.tensor([i, j, k]) for ( displacements, @@ -398,7 +404,7 @@ def non_vector_implementation( ) # Writing results to the interpolated array - ind = 2 * (index_for_input_array - (ghosts - 2)) + ( + ind = 2 * (index_for_input_array - (ghosts - 1)) + ( relative_index ) interpolation[:, :, ind[0], ind[1], ind[2]] = result @@ -456,15 +462,77 @@ def plot_grid_position(self): plt.savefig(f"interpolation_grid.png") plt.close() - - plt.subplot(1,2,1) - plt.imshow(interpolated[0,0,:,:,4]) - plt.subplot(1,2,2) - interpolated,_ = self.non_vector_implementation(x) - plt.imshow(interpolated[0,0,:,:,4]) + interpolated_old, _ = self.non_vector_implementation(x) + plt.plot(interpolated_old[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4]) + print(torch.mean(torch.abs(interpolated_old - interpolated))) plt.savefig(f"interpolation_results.png") + +def sinusoidal_function(x, y, z): + """ + A sinusoidal function of three variables x, y, and z. + """ + return np.sin(x) * np.sin(y) * np.sin(z) + + if __name__ == "__main__": - print_grid_lay_out() - interpolation = interp(6, 3, 25) + # print_grid_lay_out() + interpolation = interp(6, 3, 25, align_grids_with_lower_dim_values=True) interpolation.plot_grid_position() + + for centering in [True, False]: + tol = 1e-10 + channels = 25 + interpolation = interp( + num_points=6, + max_degree=3, + num_channels=channels, + learnable=False, + align_grids_with_lower_dim_values=centering, + ) + length = 10 + dx = 0.01 + + # Initializing a tensor of random values to represent the grid + x = torch.rand(2, channels, length, length, length) + + # Preparing input positions for the sinusoidal function + input_positions = torch.zeros(length, length, length, 3) + for i in range(x.shape[2]): + for j in range(x.shape[3]): + for k in range(x.shape[4]): + input_positions[i, j, k] = torch.tensor([i, j, k]) + pos = dx * np.array([i, j, k]) + x[:, :, i, j, k] = sinusoidal_function(*pos) + + # Perform interpolation and measure time taken + time1 = time.time() + interpolated, _ = interpolation.non_vector_implementation(x) + print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec") + positions = interpolation.get_postion(x) + + # Preparing ground truth for comparison + ghosts = int(math.ceil(6 / 2)) + shape = x.shape + ground_truth = torch.zeros( + shape[0], + shape[1], + (shape[2] - 2 * ghosts) * 2 + 2, + (shape[3] - 2 * ghosts) * 2 + 2, + (shape[4] - 2 * ghosts) * 2 + 2, + ) + + # Applying sinusoidal function to the interpolated positions + + shape = ground_truth.shape + # Perform interpolation + for i in range(shape[2]): + for j in range(shape[3]): + for k in range(shape[4]): + pos = dx * (positions[i, j, k]) + ground_truth[:, :, i, j, k] = sinusoidal_function(*pos) + + plt.plot(ground_truth[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4]) + plt.savefig(f"interpolation_results{centering}.png") + plt.close() + print(torch.mean(torch.abs(interpolated[0, 0, ...] - ground_truth[0, 0, ...])))