Skip to content

Commit

Permalink
cleaning up interpolation code
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 26, 2024
1 parent dcd67ec commit c4f2a56
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
24 changes: 2 additions & 22 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,10 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
)

for (
displacements,
weights,
relative_index,
relative_position,
conv_layer,
) in zip(
self.grid_points_index_array,
self.vecvals_array,
self.relative_index_for_interpolated_array,
self.relative_positions,
self.conv_layers,
):
convoluted_tensor = conv_layer(tensor)
Expand Down Expand Up @@ -371,14 +365,6 @@ def non_vector_implementation(
(shape[4] - 2 * ghosts) * 2 + 2,
)

# Initialize a tensor to store positions
position = torch.zeros(
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
3,
)

# Perform interpolation
for i in range(ghosts - 1, shape[2] - ghosts):
for j in range(ghosts - 1, shape[3] - ghosts):
Expand Down Expand Up @@ -409,11 +395,8 @@ def non_vector_implementation(
)
interpolation[:, :, ind[0], ind[1], ind[2]] = result
# This array gives the position of the interpolated point in the interpolated array relative to the input array
position[ind[0], ind[1], ind[2]] = (
index_for_input_array + relative_position
)

return interpolation, position
return interpolation

def sinusoidal_function(self, x, y, z):
"""
Expand Down Expand Up @@ -462,10 +445,7 @@ def plot_grid_position(self):
plt.savefig(f"interpolation_grid.png")
plt.close()

interpolated_old, _ = self.non_vector_implementation(x)
plt.plot(interpolated_old[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4])
print(torch.mean(torch.abs(interpolated_old - interpolated)))
plt.savefig(f"interpolation_results.png")



def sinusoidal_function(x, y, z):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_interpolations.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_interpolation_on_grid():
# Perform interpolation and measure time taken
time1 = time.time()
interpolated = interpolation(x)
interpolated_old = interpolation.non_vector_implementation(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")
positions = interpolation.get_postion(x)

Expand All @@ -131,10 +132,15 @@ def test_interpolation_on_grid():
pos = dx * (positions[i, j, k])
ground_truth[:, :, i, j, k] = sinusoidal_function(*pos)

# Comparing interpolated and ground truth values
assert (
torch.mean(torch.abs(interpolated[0, 0, ...] - ground_truth[0, 0, ...]))
) < tol

# Comparing old and new interpolation
assert(torch.mean(torch.abs(interpolated_old - interpolated))< tol)



if __name__ == "__main__":
test_interpolation_stencils()
Expand Down

0 comments on commit c4f2a56

Please sign in to comment.