Skip to content

Commit

Permalink
fixed bug in stencil calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 26, 2024
1 parent 61a1887 commit dcd67ec
Showing 1 changed file with 92 additions and 24 deletions.
116 changes: 92 additions & 24 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,21 @@ def calculate_stencils(
"""

dx = 1

# Not sure how to handle odd number of points, so assert that it is even
assert num_points % 2 == 0

# Shift index to center around zero
half = int(np.floor(float(max_degree) / 2.0))

half = int(np.floor(float(num_points) / 2.0)) - 1

# Generate 3D meshgrid for coarse grid points
coarse_grid_x, coarse_grid_y, coarse_grid_z = np.meshgrid(
np.arange(0 - half, num_points - half),
np.arange(0 - half, num_points - half),
np.arange(0 - half, num_points - half),
)

coarse_grid_points_index = np.vstack(
[coarse_grid_x.ravel(), coarse_grid_y.ravel(), coarse_grid_z.ravel()]
).T
Expand Down Expand Up @@ -160,6 +166,7 @@ def __init__(
self.vecvals_array = [] # Vector values for interpolation
self.grid_points_index_array = [] # Grid points indices for interpolation
self.num_channels = num_channels
tol = 1e-10 # Cutting weight values below this tolerance, assumtion is that they are just numerical noise

# Define fixed values for grid alignment
if align_grids_with_lower_dim_values:
Expand All @@ -173,7 +180,7 @@ def __init__(
vecvals, grid_points_index = calculate_stencils(
interp_point, num_points, max_degree
)
vecvals[np.abs(vecvals) < 1e-10] = 0
vecvals[np.abs(vecvals) < tol] = 0
self.vecvals_array.append(vecvals.tolist())
self.grid_points_index_array.append(grid_points_index.tolist())

Expand Down Expand Up @@ -205,7 +212,7 @@ def __init__(

# Create a convolutional kernel with zeros
kernel = torch.zeros(
(num_channels, num_channels, kernel_size, kernel_size, kernel_size)
(num_channels, 1, kernel_size, kernel_size, kernel_size)
)

# Find the minimum index for displacements to adjust kernel indexing
Expand All @@ -218,11 +225,11 @@ def __init__(
) # Adjust index based on minimum displacement
kernel[:, :, index[0], index[1], index[2]] = weight

conv_layer = nn.Conv3d(num_channels, num_channels, kernel_size)
conv_layer = nn.Conv3d(
num_channels, num_channels, kernel_size, groups=num_channels, bias=False
)
conv_layer.weight = nn.Parameter(kernel)
conv_layer.bias = nn.Parameter(torch.zeros(num_channels))
conv_layer.weight.requires_grad = learnable
conv_layer.bias.requires_grad = learnable
self.conv_layers.append(conv_layer)

def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -274,7 +281,6 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self.conv_layers,
):
convoluted_tensor = conv_layer(tensor)

# Update the interpolation tensor with the convolution results
# This is done selectively based on the relative index
interpolation[
Expand Down Expand Up @@ -311,9 +317,9 @@ def get_postion(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
)

# Perform interpolation
for i in range(ghosts - 2, shape[2] - ghosts - 1):
for j in range(ghosts - 2, shape[3] - ghosts - 1):
for k in range(ghosts - 2, shape[4] - ghosts - 1):
for i in range(ghosts - 1, shape[2] - ghosts):
for j in range(ghosts - 1, shape[3] - ghosts):
for k in range(ghosts - 1, shape[4] - ghosts):
index_for_input_array = torch.tensor([i, j, k])
for (
displacements,
Expand All @@ -329,12 +335,12 @@ def get_postion(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
result = 0

# Writing results to the interpolated array
ind = 2 * (index_for_input_array - (ghosts - 2)) + (
ind = 2 * (index_for_input_array - (ghosts - 1)) + (
relative_index
)
# This array gives the position of the interpolated point in the interpolated array relative to the input array
position[ind[0], ind[1], ind[2]] = (
index_for_input_array + relative_position + 1
index_for_input_array + relative_position
)

return position
Expand Down Expand Up @@ -374,9 +380,9 @@ def non_vector_implementation(
)

# Perform interpolation
for i in range(ghosts - 2, shape[2] - ghosts - 1):
for j in range(ghosts - 2, shape[3] - ghosts - 1):
for k in range(ghosts - 2, shape[4] - ghosts - 1):
for i in range(ghosts - 1, shape[2] - ghosts):
for j in range(ghosts - 1, shape[3] - ghosts):
for k in range(ghosts - 1, shape[4] - ghosts):
index_for_input_array = torch.tensor([i, j, k])
for (
displacements,
Expand All @@ -398,7 +404,7 @@ def non_vector_implementation(
)

# Writing results to the interpolated array
ind = 2 * (index_for_input_array - (ghosts - 2)) + (
ind = 2 * (index_for_input_array - (ghosts - 1)) + (
relative_index
)
interpolation[:, :, ind[0], ind[1], ind[2]] = result
Expand Down Expand Up @@ -456,15 +462,77 @@ def plot_grid_position(self):
plt.savefig(f"interpolation_grid.png")
plt.close()


plt.subplot(1,2,1)
plt.imshow(interpolated[0,0,:,:,4])
plt.subplot(1,2,2)
interpolated,_ = self.non_vector_implementation(x)
plt.imshow(interpolated[0,0,:,:,4])
interpolated_old, _ = self.non_vector_implementation(x)
plt.plot(interpolated_old[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4])
print(torch.mean(torch.abs(interpolated_old - interpolated)))
plt.savefig(f"interpolation_results.png")


def sinusoidal_function(x, y, z):
"""
A sinusoidal function of three variables x, y, and z.
"""
return np.sin(x) * np.sin(y) * np.sin(z)


if __name__ == "__main__":
print_grid_lay_out()
interpolation = interp(6, 3, 25)
# print_grid_lay_out()
interpolation = interp(6, 3, 25, align_grids_with_lower_dim_values=True)
interpolation.plot_grid_position()

for centering in [True, False]:
tol = 1e-10
channels = 25
interpolation = interp(
num_points=6,
max_degree=3,
num_channels=channels,
learnable=False,
align_grids_with_lower_dim_values=centering,
)
length = 10
dx = 0.01

# Initializing a tensor of random values to represent the grid
x = torch.rand(2, channels, length, length, length)

# Preparing input positions for the sinusoidal function
input_positions = torch.zeros(length, length, length, 3)
for i in range(x.shape[2]):
for j in range(x.shape[3]):
for k in range(x.shape[4]):
input_positions[i, j, k] = torch.tensor([i, j, k])
pos = dx * np.array([i, j, k])
x[:, :, i, j, k] = sinusoidal_function(*pos)

# Perform interpolation and measure time taken
time1 = time.time()
interpolated, _ = interpolation.non_vector_implementation(x)
print(f"Time taken for interpolation: {(time.time() - time1):.2f} sec")
positions = interpolation.get_postion(x)

# Preparing ground truth for comparison
ghosts = int(math.ceil(6 / 2))
shape = x.shape
ground_truth = torch.zeros(
shape[0],
shape[1],
(shape[2] - 2 * ghosts) * 2 + 2,
(shape[3] - 2 * ghosts) * 2 + 2,
(shape[4] - 2 * ghosts) * 2 + 2,
)

# Applying sinusoidal function to the interpolated positions

shape = ground_truth.shape
# Perform interpolation
for i in range(shape[2]):
for j in range(shape[3]):
for k in range(shape[4]):
pos = dx * (positions[i, j, k])
ground_truth[:, :, i, j, k] = sinusoidal_function(*pos)

plt.plot(ground_truth[0, 0, 4, :, 4] - interpolated[0, 0, 4, :, 4])
plt.savefig(f"interpolation_results{centering}.png")
plt.close()
print(torch.mean(torch.abs(interpolated[0, 0, ...] - ground_truth[0, 0, ...])))

0 comments on commit dcd67ec

Please sign in to comment.