Skip to content

Commit

Permalink
pre-define conv layers at init
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 25, 2024
1 parent 913b7b0 commit 44ebca2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 27 deletions.
72 changes: 47 additions & 25 deletions GeneralRelativity/Interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import time
from typing import Tuple
import torch.nn.functional as F
import torch.nn as nn


def print_grid_lay_out(
Expand Down Expand Up @@ -138,14 +138,21 @@ def calculate_stencils(

class interp:
def __init__(
self, num_points=6, max_degree=3, align_grids_with_lower_dim_values=False
self,
num_points: int = 6,
max_degree: int = 3,
num_channels: int = 1,
learnable: bool = False,
align_grids_with_lower_dim_values: bool = False,
):
"""
Initialize the interp class.
Initialize the Interp class.
Parameters:
num_points (int): Number of points to use in interpolation.
max_degree (int): The maximum degree of the polynomial used in interpolation.
num_channels (int): Number of channels in the input tensor.
learnable (bool): If True, the interpolation parameters are learnable.
align_grids_with_lower_dim_values (bool): If True, aligns grid points with lower-dimensional values.
"""
self.num_points = num_points
Expand Down Expand Up @@ -184,6 +191,39 @@ def __init__(
self.grid_points_index_array = torch.tensor(self.grid_points_index_array)
self.vecvals_array = torch.tensor(self.vecvals_array)

## Initialize the kernel for convolution

self.conv_layers = []

# Iterate over the displacement, weight, and relative position information
for (
displacements,
weights,
) in zip(self.grid_points_index_array, self.vecvals_array):
kernel_size = self.num_points # Size of the convolutional kernel

# Create a convolutional kernel with zeros
kernel = torch.zeros(
(num_channels, 1, kernel_size, kernel_size, kernel_size)
)

# Find the minimum index for displacements to adjust kernel indexing
min_index = torch.min(displacements)

# Populate the kernel with weights according to displacements
for displacement, weight in zip(displacements, weights):
index = (
displacement - min_index
) # Adjust index based on minimum displacement
kernel[:, :, index[0], index[1], index[2]] = weight

conv_layer = nn.Conv3d(num_channels, num_channels, kernel_size)
conv_layer.weight = nn.Parameter(kernel)
conv_layer.bias = nn.Parameter(torch.zeros(num_channels))
conv_layer.weight.requires_grad = learnable
conv_layer.bias.requires_grad = learnable
self.conv_layers.append(conv_layer)

def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform the interpolation on the given tensor.
Expand Down Expand Up @@ -219,38 +259,20 @@ def __call__(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
3, # 3D coordinates
)

# Iterate over the displacement, weight, and relative position information
for (
displacements,
weights,
relative_index,
relative_position,
conv_layer,
) in zip(
self.grid_points_index_array,
self.vecvals_array,
self.relative_index_for_interpolated_array,
self.relative_positions,
self.conv_layers,
):
num_channels = tensor.shape[1] # Number of channels in the input tensor
kernel_size = self.num_points # Size of the convolutional kernel

# Create a convolutional kernel with zeros
kernel = torch.zeros(
(num_channels, 1, kernel_size, kernel_size, kernel_size)
)

# Find the minimum index for displacements to adjust kernel indexing
min_index = torch.min(displacements)

# Populate the kernel with weights according to displacements
for displacement, weight in zip(displacements, weights):
index = (
displacement - min_index
) # Adjust index based on minimum displacement
kernel[:, :, index[0], index[1], index[2]] = weight

# Perform convolution using the created kernel
convoluted_tensor = F.conv3d(tensor, kernel, padding=0, groups=num_channels)
convoluted_tensor = conv_layer(tensor)

# Update the interpolation tensor with the convolution results
# This is done selectively based on the relative index
Expand Down Expand Up @@ -395,7 +417,7 @@ def sinusoidal_function(self, x, y, z):
def plot_grid_position(self):
length = 10
dx = 0.01
x = torch.rand(2, 25, length, length, length)
x = torch.rand(2, 1, length, length, length)
input_positions = torch.zeros(length, length, length, 3)
for i in range(x.shape[2]):
for j in range(x.shape[3]):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_interpolations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ def test_interpolation_on_grid():
"""
for centering in [True, False]:
tol = 1e-10
interpolation = interp(6, 3, centering)
channels = 25
interpolation = interp(6, 3, centering,channels)
length = 10
dx = 0.01

# Initializing a tensor of random values to represent the grid
x = torch.rand(2, 25, length, length, length)
x = torch.rand(2, channels, length, length, length)

# Preparing input positions for the sinusoidal function
input_positions = torch.zeros(length, length, length, 3)
Expand Down

0 comments on commit 44ebca2

Please sign in to comment.