Skip to content

Commit

Permalink
made sure that everything is GPU compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jan 4, 2024
1 parent 4a36eaf commit d069af4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions GeneralRelativity/TensorAlgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def compute_christoffel(d1_metric: torch.tensor, h_UU: torch.tensor) -> torch.te
"""

chris = {
"LLL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
"ULL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
"LLL": torch.zeros_like(d1_metric),
"ULL": torch.zeros_like(d1_metric),
}

# Compute Christoffel symbols of the first kind (LLL)
Expand Down Expand Up @@ -61,8 +61,8 @@ def compute_christoffel_fast(
# Initialize the output tensors
# batch, x, y, z, i, j, dx = d1_metric.shape
chris = {
"LLL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
"ULL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
"LLL": torch.zeros_like(d1_metric),
"ULL": torch.zeros_like(d1_metric),
}

# Compute Christoffel symbols of the first kind (LLL)
Expand Down
1 change: 1 addition & 0 deletions GeneralRelativity/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __getitem__(self, key: list) -> torch.tensor:
+ (3, 3)
+ self.tensor.shape[(self.num_index + 1) :],
dtype=self.tensor.dtype,
device=self.device,
)
for i in range(3):
for j in range(i, 3):
Expand Down

0 comments on commit d069af4

Please sign in to comment.