From d069af4389b25558bfc9dc19a03d4d92d58d7ac2 Mon Sep 17 00:00:00 2001 From: thelfer1 Date: Thu, 4 Jan 2024 18:16:52 -0500 Subject: [PATCH] made sure that everything is GPU compatible --- GeneralRelativity/TensorAlgebra.py | 8 ++++---- GeneralRelativity/Utils.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/GeneralRelativity/TensorAlgebra.py b/GeneralRelativity/TensorAlgebra.py index 3c80d1f..d9d7662 100644 --- a/GeneralRelativity/TensorAlgebra.py +++ b/GeneralRelativity/TensorAlgebra.py @@ -22,8 +22,8 @@ def compute_christoffel(d1_metric: torch.tensor, h_UU: torch.tensor) -> torch.te """ chris = { - "LLL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype), - "ULL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype), + "LLL": torch.zeros_like(d1_metric), + "ULL": torch.zeros_like(d1_metric), } # Compute Christoffel symbols of the first kind (LLL) @@ -61,8 +61,8 @@ def compute_christoffel_fast( # Initialize the output tensors # batch, x, y, z, i, j, dx = d1_metric.shape chris = { - "LLL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype), - "ULL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype), + "LLL": torch.zeros_like(d1_metric), + "ULL": torch.zeros_like(d1_metric), } # Compute Christoffel symbols of the first kind (LLL) diff --git a/GeneralRelativity/Utils.py b/GeneralRelativity/Utils.py index 263e0fd..996c054 100644 --- a/GeneralRelativity/Utils.py +++ b/GeneralRelativity/Utils.py @@ -37,6 +37,7 @@ def __getitem__(self, key: list) -> torch.tensor: + (3, 3) + self.tensor.shape[(self.num_index + 1) :], dtype=self.tensor.dtype, + device=self.device, ) for i in range(3): for j in range(i, 3):