Skip to content

Commit

Permalink
added whole gird interpolation and test
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 12, 2024
1 parent 2d7bff4 commit ee8a65d
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 68 deletions.
135 changes: 67 additions & 68 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import itertools
import torch
import math
import time
from typing import Tuple


def print_grid_lay_out(
Expand Down Expand Up @@ -38,7 +40,7 @@ def print_grid_lay_out(
)
ax.scatter(interp_point[0], interp_point[1], label="interpolation point", c="black")
plt.legend()
plt.savefig("overview.png")
plt.savefig("layout_interpolation_grid.png")
plt.close()


Expand Down Expand Up @@ -133,17 +135,9 @@ def calculate_stencils(
return vecvals, coarse_grid_points_index


import numpy as np
import torch
import itertools
import math

class interp:
def __init__(
self,
num_points: int = 6,
max_degree: int = 3,
align_grids_with_lower_dim_values: bool = False,
self, num_points=6, max_degree=3, align_grids_with_lower_dim_values=False
):
"""
Initialize the interp class.
Expand Down Expand Up @@ -180,7 +174,7 @@ def __init__(
self.grid_points_index_array = torch.tensor(self.grid_points_index_array)
self.vecvals_array = torch.tensor(self.vecvals_array)

def __call__(self, tensor: torch.tensor) -> torch.tensor:
def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform the interpolation on the given tensor.
Expand All @@ -190,6 +184,7 @@ def __call__(self, tensor: torch.tensor) -> torch.tensor:
Returns:
torch.tensor: The interpolated tensor.
"""

# Calculate the number of ghost points based on num_points
ghosts = int(math.ceil(self.num_points / 2))
shape = tensor.shape
Expand All @@ -198,23 +193,23 @@ def __call__(self, tensor: torch.tensor) -> torch.tensor:
interpolation = torch.zeros(
shape[0],
shape[1],
(shape[2] - 2 * ghosts) * 2,
(shape[3] - 2 * ghosts) * 2,
(shape[4] - 2 * ghosts) * 2,
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
)

# Initialize a tensor to store positions
position = torch.zeros(
(shape[2] - 2 * ghosts) * 2,
(shape[3] - 2 * ghosts) * 2,
(shape[4] - 2 * ghosts) * 2,
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
3,
)

# Perform interpolation
for i in range(ghosts, shape[2] - ghosts):
for j in range(ghosts, shape[3] - ghosts):
for k in range(ghosts, shape[4] - ghosts):
for i in range(ghosts - 1, shape[2] - ghosts):
for j in range(ghosts - 1, shape[3] - ghosts):
for k in range(ghosts - 1, shape[4] - ghosts):
index_for_input_array = torch.tensor([i, j, k])
for (
displacements,
Expand All @@ -227,71 +222,75 @@ def __call__(self, tensor: torch.tensor) -> torch.tensor:
self.relative_index_for_interpolated_array,
self.relative_positions,
):
test = 0
result = 0
for displacement, weight in zip(displacements, weights):
# Ensure indices are scalar values
index = index_for_input_array + displacement
result += (
weight * tensor[:, :, index[0], index[1], index[2]]
)
test += weight

# Writing results to the interpolated array
ind = 2 * (index_for_input_array - ghosts) + (relative_index)
ind = 2 * (index_for_input_array - (ghosts - 1)) + (
relative_index
)
interpolation[:, :, ind[0], ind[1], ind[2]] = result
# This array gives the position of the interpolated point in the interpolated array relative to the input array
position[ind[0], ind[1], ind[2]] = (
index_for_input_array + relative_position
)

return interpolation, position

def sinusoidal_function(self, x, y, z):
"""
A sinusoidal function of three variables x, y, and z.
"""
return np.sin(x) * np.sin(y) * np.sin(z)

def plot_grid_position(self):
length = 10
dx = 0.01
x = torch.rand(2, 25, length, length, length)
input_positions = torch.zeros(length, length, length, 3)
for i in range(x.shape[2]):
for j in range(x.shape[3]):
for k in range(x.shape[4]):
input_positions[i, j, k] = torch.tensor([i, j, k])
pos = dx * np.array([i, j, k])
x[:, :, i, j, k] = self.sinusoidal_function(*pos)
time1 = time.time()
interpolated, positions = self(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")
ghosts = int(math.ceil(6 / 2))
# Scatter plot
plt.scatter(
input_positions[:, :, 4, 0],
input_positions[:, :, 4, 1],
label="Input",
color="blue",
marker="o",
)

def sinusoidal_function(x, y, z):
"""
A sinusoidal function of three variables x, y, and z.
"""
return np.sin(x) * np.sin(y) * np.sin(z)
def sinusoidal_function(x, y, z):
"""
A sinusoidal function of three variables x, y, and z.
"""
return (x)
plt.scatter(
positions[:, :, 4, 0],
positions[:, :, 4, 1],
label="Interpolated",
color="red",
marker="x",
)

plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.10), ncol=2)
plt.xticks(input_positions[:, 0, 4, 0])
plt.yticks(input_positions[0, :, 4, 1])
plt.xlabel("X Position")
plt.ylabel("Y Position")
plt.grid(True)
plt.savefig(f"interpolation_grid.png")
plt.close()

import time

if __name__ == "__main__":
print_grid_lay_out()
interpolation = interp(6, 3)
length = 10
dx = 0.01
x = torch.rand(
2, 25, length, length, length
) # .random(10,25,length, length, length)
for i in range(x.shape[2]):
for j in range(x.shape[3]):
for k in range(x.shape[4]):
pos = dx * np.array([i, j, k])
x[:, :, i, j, k] = sinusoidal_function(*pos)
time1 = time.time()
interpolated, positions = interpolation(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")
ghosts = int(math.ceil(6 / 2))

print(interpolated.shape)
plt.subplot(1, 2, 1)
plt.title("interpolated")
plt.imshow(interpolated[0, 0, :, :, 4].numpy())
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("original")
plt.imshow(x[0, 0, ghosts:-ghosts, ghosts:-ghosts, 2].numpy())
plt.colorbar()
plt.show()
print(positions.shape)
plt.scatter(positions[:,4, 4,0].numpy(),interpolated[0, 0, :, 4, 4].numpy(), label="interpolated")
plt.scatter(np.arange(ghosts,length-ghosts),x[0, 0, ghosts:-ghosts, 4, 4].numpy(), label="original")

plt.legend()
plt.show()

plt.scatter(positions[:, :, 4, 0].numpy(),positions[:, :, 4, 1].numpy())
plt.show()
interpolation.plot_grid_position()
64 changes: 64 additions & 0 deletions tests/test_interpolations.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,69 @@ def test_interpolation_stencils():
assert error < tol, f"Interpolation error {error} exceeds tolerance {tol}"


def test_interpolation_on_grid():
"""
Test the interpolation on a 3D grid.
This function initializes an interpolation object, creates a 3D grid of random values,
and applies sinusoidal function to populate the grid. It then interpolates these values
using the `interp` class and compares the interpolated values with the ground truth
obtained by directly applying the sinusoidal function to the interpolated positions.
An assertion is used to check if the interpolation error is within the specified tolerance.
Attributes:
tol (float): Tolerance level for the difference between interpolated and ground truth values.
length (int): Length of each dimension in the grid.
dx (float): Differential step to scale the grid positions.
"""
tol = 1e-10
interpolation = interp(6, 3)
length = 10
dx = 0.01

# Initializing a tensor of random values to represent the grid
x = torch.rand(2, 25, length, length, length)

# Preparing input positions for the sinusoidal function
input_positions = torch.zeros(length, length, length, 3)
for i in range(x.shape[2]):
for j in range(x.shape[3]):
for k in range(x.shape[4]):
input_positions[i, j, k] = torch.tensor([i, j, k])
pos = dx * np.array([i, j, k])
x[:, :, i, j, k] = sinusoidal_function(*pos)

# Perform interpolation and measure time taken
time1 = time.time()
interpolated, positions = interpolation(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")

# Preparing ground truth for comparison
ghosts = int(math.ceil(6 / 2))
shape = x.shape
ground_truth = torch.zeros(
shape[0],
shape[1],
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
)

# Applying sinusoidal function to the interpolated positions

shape = ground_truth.shape
# Perform interpolation
for i in range(shape[2]):
for j in range(shape[3]):
for k in range(shape[4]):
pos = dx * (positions[i, j, k])
ground_truth[:, :, i, j, k] = sinusoidal_function(*pos)

assert (
torch.mean(torch.abs(interpolated[0, 0, ...] - ground_truth[0, 0, ...]))
) < tol


if __name__ == "__main__":
test_interpolation_stencils()
test_interpolation_on_grid()

0 comments on commit ee8a65d

Please sign in to comment.