Skip to content

Commit

Permalink
Fix input_dim usage in element dim check
Browse files Browse the repository at this point in the history
Also test that error is correctly raised
  • Loading branch information
waltsims committed Oct 19, 2023
1 parent 9252d4e commit 97605c9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
4 changes: 2 additions & 2 deletions kwave/utils/kwave_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def add_custom_element(self, integration_points, measure, element_dim, label):

assert isinstance(integration_points, (np.ndarray)), "'integration_points' must be a numpy array"
assert isinstance(measure, (int, float)), "'measure' must be an integer or float"
assert isinstance(element_dim, (int)) and element_dim in [2, 1], "'element_dim' must be an integer and either 2 or 3"
assert isinstance(element_dim, (int)) and element_dim in [1, 2, 3], "'element_dim' must be an integer and either 1, 2 or 3"
assert isinstance(label, (str)), "'label' must be a string"

# check the dimensionality of the integration points
Expand All @@ -244,7 +244,7 @@ def add_custom_element(self, integration_points, measure, element_dim, label):
# check that the element is being added to an array with the
# correct dimensions
if self.dim != input_dim:
raise ValueError(f"{element_dim}D custom element cannot be added to an array with {self.dim}D elements.")
raise ValueError(f"{input_dim}D custom element cannot be added to an array with {self.dim}D elements.")

self.number_elements += 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import numpy as np
import pytest
from kwave.kgrid import kWaveGrid

from kwave.utils.kwave_array import kWaveArray
Expand Down Expand Up @@ -45,10 +46,19 @@ def test_kwave_array():
reader.increment()

kwave_array.add_custom_element(
np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3], [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32),
9, 2, label='custom_3d'
integration_points=np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3], [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32),
measure=9,
element_dim=2, label='custom_3d'
)
check_kwave_array_equality(kwave_array, reader.expected_value_of('kwave_array'))

with pytest.raises(ValueError):
kwave_array.add_custom_element(
integration_points=np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3]], dtype=np.float32),
measure=9,
element_dim=2, label='custom_3d'
)

reader.increment()

kwave_array.add_rect_element([12, -8, 0.3], 3, 4, [2, 4, 5])
Expand Down

0 comments on commit 97605c9

Please sign in to comment.