Skip to content

Commit

Permalink
minor changes for new pyinterpx version
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jun 13, 2024
1 parent c7fe52a commit f762049
Showing 1 changed file with 36 additions and 57 deletions.
93 changes: 36 additions & 57 deletions learn_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,66 +63,26 @@ def main():
writer = SummaryWriter(f"{folder_name}")

# Loading small testdata
filenamesX = (
"/home/thelfer1/scr4_tedwar42/thelfer1/data_gen_binary/outputXdata_level1_*"
)
filenamesX = "/home/thelfer1/scr4_tedwar42/thelfer1/data_gen_binary/outputXdata_level1_step0050.dat"
num_varsX = 100
dataX = get_box_format(filenamesX, num_varsX)
# Cutting out extra values added for validation
dataX = dataX[:, :, :, :, :25]

class SuperResolution3DNet(torch.nn.Module):
def __init__(self):
super(SuperResolution3DNet, self).__init__()
self.points = 6
self.power = 3
self.channels = 25
self.interpolation = interp(
self.points, self.power, self.channels, False, True, torch.double
)

# Encoder
# The encoder consists of two 3D convolutional layers.
# The first conv layer expands the channel size from 25 to 64.
# The second conv layer further expands the channel size from 64 to 128.
# ReLU activation functions are used for non-linearity.
self.encoder = torch.nn.Sequential(
torch.nn.Conv3d(25, 64, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True),
torch.nn.Conv3d(64, 25, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True),
)

# Decoder
# The decoder uses a transposed 3D convolution (or deconvolution) to upsample the feature maps.
# The channel size is reduced from 128 back to 64.
# A final 3D convolution reduces the channel size back to the original size of 25.
self.decoder = torch.nn.Sequential(
torch.nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),
torch.nn.ReLU(inplace=True),
torch.nn.Conv3d(64, 25, kernel_size=3, padding=1),
)

# Initialize only the weights in self.encoder and self.decoder
self.initialize_encoder_decoder_weights()

def forward(self, x):
# Reusing the input data for faster learning

x = self.interpolation(x)
tmp = x
x = x + self.encoder(tmp)

return x

class SuperResolution3DNet(torch.nn.Module):
def __init__(self):
def __init__(self, factor):
super(SuperResolution3DNet, self).__init__()
self.points = 6
self.power = 3
self.channels = 25
self.interpolation = interp(
self.points, self.power, self.channels, False, True, torch.double
num_points=self.points,
max_degree=self.power,
num_channels=self.channels,
learnable=False,
align_corners=True,
factor=factor,
dtype=torch.double,
)

# Encoder
Expand Down Expand Up @@ -176,8 +136,9 @@ def forward(self, x):

return x

factor = 2
# Instantiate the model
net = SuperResolution3DNet().to(torch.double)
net = SuperResolution3DNet(factor).to(torch.double)

# Create a random 3D low-resolution input tensor (batch size, channels, depth, height, width)
input_tensor = torch.randn(
Expand Down Expand Up @@ -273,15 +234,15 @@ def __call__(self, output: torch.tensor, dummy: torch.tensor) -> torch.tensor:

# Note: it will slow down signficantly with BFGS steps, they are 10x slower, just be aware!
ADAMsteps = (
1000 # Will perform # steps of ADAM steps and then switch over to BFGS-L
1000000 # Will perform # steps of ADAM steps and then switch over to BFGS-L
)
n_steps = 1 # Total amount of steps
n_steps = 0 # Total amount of steps

net.train()
net.to(device)
net.interpolation.to(device)

my_loss = torch.nn.L1Loss()
# my_loss = torch.nn.L1Loss()
print("training")
pbar = trange(n_steps)
for i in pbar:
Expand Down Expand Up @@ -383,13 +344,20 @@ def closure():
writer.close()

# Get comparison with classical methods

(y_batch,) = next(iter(test_loader))
y_batch = y_batch.to(device)
X_batch = y_batch[:, :, ::2, ::2, ::2].clone()
y_batch = y_batch[
:, :25, diff - 1 : -diff - 1, diff - 1 : -diff - 1, diff - 1 : -diff - 1
]
# Interpolation compared to what is used typically in codes ( we interpolate between 6 values with polynomials x^i y^k z^k containing powers up to 3)
points = 6
power = 3
channels = 25
shape = X_batch.shape
interpolation = interp(points, power, channels, False, True, torch.double)
interpolation = interp(
points, power, channels, False, True, dtype=torch.double, factor=factor
)
ghosts = int(math.ceil(points / 2))
shape_higher_order = (shape[-1] - 2 * ghosts) * 2 + 2

Expand All @@ -399,7 +367,7 @@ def closure():

box = 0
channel = 0
slice = 4
slice = 5
# Note we remove some part of the grid as the interpolation needs space
max_val = torch.max(y_batch[box, channel, :, :, slice]).cpu().numpy()
min_val = torch.min(y_batch[box, channel, :, :, slice]).cpu().numpy()
Expand Down Expand Up @@ -455,7 +423,7 @@ def closure():

box = 0
channel = 0
slice = 4
slice = 5

net.eval()
y_pred = net(X_batch.detach())
Expand Down Expand Up @@ -525,6 +493,17 @@ def closure():
)
file.write("--------------------\n")

print(
f"Reference data L2 Ham {my_loss(y_batch[:, :, :, :, :], torch.tensor([])).detach().cpu().numpy()}\n"
)
print(
f"Neural Network L2 Ham {my_loss(y_pred[:, :, :, :, :], torch.tensor([])).detach().cpu().numpy()}\n"
)
print(
f"Interpolation L2 Ham {my_loss(y_interpolated, torch.tensor([])).detach().numpy()}\n"
)
print("--------------------\n")

# Calculate L1 performance
my_loss = torch.nn.L1Loss()

Expand Down

0 comments on commit f762049

Please sign in to comment.