Skip to content

Commit

Permalink
cleaning interpolation script
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Feb 4, 2024
1 parent 7b53e52 commit 57952a0
Showing 1 changed file with 0 additions and 58 deletions.
58 changes: 0 additions & 58 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ def get_postion(self, tensor: torch.Tensor) -> torch.Tensor:
self.relative_positions,
):
result = 0

# Writing results to the interpolated array
ind = 2 * (index_for_input_array - (ghosts - 1)) + (
relative_index
Expand Down Expand Up @@ -455,60 +454,3 @@ def sinusoidal_function(x, y, z):
# print_grid_lay_out()
interpolation = interp(6, 3, 25, align_grids_with_lower_dim_values=True)
interpolation.plot_grid_position()

for centering in [True, False]:
tol = 1e-10
channels = 25
interpolation = interp(
num_points=6,
max_degree=3,
num_channels=channels,
learnable=False,
align_grids_with_lower_dim_values=centering,
)
length = 10
dx = 0.01

# Initializing a tensor of random values to represent the grid
x = torch.rand(2, channels, length, length, length)

# Preparing input positions for the sinusoidal function
input_positions = torch.zeros(length, length, length, 3)
for i in range(x.shape[2]):
for j in range(x.shape[3]):
for k in range(x.shape[4]):
input_positions[i, j, k] = torch.tensor([i, j, k])
pos = dx * np.array([i, j, k])
x[:, :, i, j, k] = sinusoidal_function(*pos)

# Perform interpolation and measure time taken
time1 = time.time()
interpolated, _ = interpolation.non_vector_implementation(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")
positions = interpolation.get_postion(x)

# Preparing ground truth for comparison
ghosts = int(math.ceil(6 / 2))
shape = x.shape
ground_truth = torch.zeros(
shape[0],
shape[1],
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
)

# Applying sinusoidal function to the interpolated positions

shape = ground_truth.shape
# Perform interpolation
for i in range(shape[2]):
for j in range(shape[3]):
for k in range(shape[4]):
pos = dx * (positions[i, j, k])
ground_truth[:, :, i, j, k] = sinusoidal_function(*pos)

plt.plot(ground_truth[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4])
plt.savefig(f"interpolation_results{centering}.png")
plt.close()
print(torch.mean(torch.abs(interpolated[0, 0, ...] - ground_truth[0, 0, ...])))

0 comments on commit 57952a0

Please sign in to comment.