From c4f2a560110ee2b03f975fd8ec13c4a78d85e445 Mon Sep 17 00:00:00 2001 From: ThomasHelfer Date: Fri, 26 Jan 2024 11:45:19 -0500 Subject: [PATCH] cleaning up interpolation code --- GeneralRelativity/Interpolation.py | 24 ++---------------------- tests/test_interpolations.py | 6 ++++++ 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/GeneralRelativity/Interpolation.py b/GeneralRelativity/Interpolation.py index a3be31f..be41895 100644 --- a/GeneralRelativity/Interpolation.py +++ b/GeneralRelativity/Interpolation.py @@ -268,16 +268,10 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ) for ( - displacements, - weights, relative_index, - relative_position, conv_layer, ) in zip( - self.grid_points_index_array, - self.vecvals_array, self.relative_index_for_interpolated_array, - self.relative_positions, self.conv_layers, ): convoluted_tensor = conv_layer(tensor) @@ -371,14 +365,6 @@ def non_vector_implementation( (shape[4] - 2 * ghosts) * 2 + 2, ) - # Initialize a tensor to store positions - position = torch.zeros( - (shape[2] - 2 * ghosts) * 2 + 2, - (shape[3] - 2 * ghosts) * 2 + 2, - (shape[4] - 2 * ghosts) * 2 + 2, - 3, - ) - # Perform interpolation for i in range(ghosts - 1, shape[2] - ghosts): for j in range(ghosts - 1, shape[3] - ghosts): @@ -409,11 +395,8 @@ def non_vector_implementation( ) interpolation[:, :, ind[0], ind[1], ind[2]] = result # This array gives the position of the interpolated point in the interpolated array relative to the input array - position[ind[0], ind[1], ind[2]] = ( - index_for_input_array + relative_position - ) - return interpolation, position + return interpolation def sinusoidal_function(self, x, y, z): """ @@ -462,10 +445,7 @@ def plot_grid_position(self): plt.savefig(f"interpolation_grid.png") plt.close() - interpolated_old, _ = self.non_vector_implementation(x) - plt.plot(interpolated_old[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4]) - print(torch.mean(torch.abs(interpolated_old - interpolated))) - plt.savefig(f"interpolation_results.png") + def sinusoidal_function(x, y, z): diff --git a/tests/test_interpolations.py b/tests/test_interpolations.py index 22fad58..e5c5074 100644 --- a/tests/test_interpolations.py +++ b/tests/test_interpolations.py @@ -107,6 +107,7 @@ def test_interpolation_on_grid(): # Perform interpolation and measure time taken time1 = time.time() interpolated = interpolation(x) + interpolated_old = interpolation.non_vector_implementation(x) print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec") positions = interpolation.get_postion(x) @@ -131,10 +132,15 @@ def test_interpolation_on_grid(): pos = dx * (positions[i, j, k]) ground_truth[:, :, i, j, k] = sinusoidal_function(*pos) + # Comparing interpolated and ground truth values assert ( torch.mean(torch.abs(interpolated[0, 0, ...] - ground_truth[0, 0, ...])) ) < tol + # Comparing old and new interpolation + assert(torch.mean(torch.abs(interpolated_old - interpolated))< tol) + + if __name__ == "__main__": test_interpolation_stencils()