diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py
index b8c442602..47cee0dc7 100644
--- a/tripy/tests/backend/api/test_compile.py
+++ b/tripy/tests/backend/api/test_compile.py
@@ -20,16 +20,17 @@
 from tests.backend.api.conftest import *
 
 import tripy as tp
+from tripy.frontend.trace.ops.storage import Storage
 
 
 class TestCompile:
     # TODO (#246): Verify that it's actually compiling somehow here and below.
     # Need to return something programatically queriable from compile to do this.
     def test_function(self):
-        compiled_gelu = tp.compile(tp.relu, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
+        compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
 
-        inp = tp.ones((2, 2), dtype=tp.float32)
-        out = compiled_gelu(inp)
+        inp = tp.iota((2, 2), dtype=tp.float32) - 1
+        out = compiled_relu(inp)
 
         assert tp.equal(out, tp.relu(inp))
 
@@ -37,11 +38,21 @@ def test_module(self):
         layernorm = tp.LayerNorm(2)
         compiled_layernorm = tp.compile(layernorm, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
 
-        inp = tp.ones((2, 2), dtype=tp.float32)
+        inp = tp.iota((2, 2), dtype=tp.float32) - 1
         out = compiled_layernorm(inp)
 
         assert tp.equal(out, layernorm(inp))
 
+    def test_can_compile_using_shape_of_tensor(self):
+        # Since InputInfo allows `DimensionSize`s, we should be able to use the shape of a tensor as
+        # the shape of the InputInfo.
+        inp = tp.iota((2, 2), dtype=tp.float32) - 1
+        shape = inp.shape
+
+        compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo(shape, inp.dtype)])
+        out = compiled_relu(inp)
+        assert tp.equal(out, tp.relu(inp))
+
     def test_compile_arg_order_irrelevant(self):
         # The order of arguments we specify to `compile` should not affect the order
         # of the arguments in the compiled function, which should just follow the order
diff --git a/tripy/tests/backend/api/test_input_info.py b/tripy/tests/backend/api/test_input_info.py
index ef4986b12..f69ce76e9 100644
--- a/tripy/tests/backend/api/test_input_info.py
+++ b/tripy/tests/backend/api/test_input_info.py
@@ -22,10 +22,16 @@ class TestInput:
     @pytest.mark.parametrize(
         "shape, expected_min, expected_opt, expected_max",
         [
+            # int:
             # min/opt/max explicitly specified
-            ([(1, 2, 3)], (1,), (2,), (3,)),
+            ([(1, 2, 3)], [1], [2], [3]),
             # Only one value specified
-            ([1], (1,), (1,), (1,)),
+            ([1], [1], [1], [1]),
+            # `DimensionSize`s:
+            # min/opt/max explicitly specified
+            ([(tp.DimensionSize(1), tp.DimensionSize(2), tp.DimensionSize(3))], [1], [2], [3]),
+            # Only one value specified
+            ([tp.DimensionSize(1)], [1], [1], [1]),
         ],
     )
     def test_shapes_normalized(self, shape, expected_min, expected_opt, expected_max):
@@ -41,14 +47,14 @@ def test_shapes_normalized(self, shape, expected_min, expected_opt, expected_max
             # Not a number
             (
                 (tp.int32, 1),
-                "Shape values should be either a single number or a Tuple specifying min/opt/max bounds.",
+                "Shape values should be either a single integer-like value or a 3-element tuple specifying min/opt/max bounds.",
             ),
             # Too few elements in dimension
             (((1, 1), 1), "Incorrect number of shape values provided"),
             # Too many elements in dimension
             (((1, 1, 1, 1), 1), "Incorrect number of shape values provided"),
             # Tuple containing a non-number
-            (((tp.int32, 1, 1), 1), "Shape values must be numbers"),
+            (((tp.int32, 1, 1), 1), "Shape values must be integers or `DimensionSize`s."),
         ],
     )
     def test_invalid_shape(self, shape, expected_error):
diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py
index 858f35a8d..a799ea523 100644
--- a/tripy/tests/helper.py
+++ b/tripy/tests/helper.py
@@ -498,16 +498,19 @@ def process_code_block_for_outputs_and_locals(
     code_start, code_end = get_code_bounds(block_lines)
     code = dedent("\n".join(block_lines[code_start:code_end]))
 
-    with capture_output() as outfile:
-        try:
+    try:
+        with capture_output() as outfile:
             code_locals = exec_code(code, local_vars)
-        except Exception as e:
-            if allow_exception:
-                print(f"Exception occurred: {str(e)}")
-                code_locals = local_vars
-            else:
-                print(err_msg)
-                raise
+    except Exception as e:
+        print(
+            f"Exception occurred while executing code block: {type(e).__name__}: {e}\n"
+            f"Note: Code block was:\n\n{block}"
+        )
+        if allow_exception:
+            code_locals = local_vars
+        else:
+            print(err_msg)
+            raise
 
     new_locals = {
         key: value for key, value in code_locals.items() if key not in local_vars or value is not local_vars[key]
diff --git a/tripy/tests/integration/test_conv_transpose.py b/tripy/tests/integration/test_conv_transpose.py
index 2245d024b..434abb86b 100644
--- a/tripy/tests/integration/test_conv_transpose.py
+++ b/tripy/tests/integration/test_conv_transpose.py
@@ -280,14 +280,14 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled):
         output = eager_or_compiled(conv_layer, input)
         output_transpose = eager_or_compiled(conv_transpose_layer, input)
 
-        rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
-        assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
+        rtol = 2e-7 if tp_dtype == tp.float32 else 9e-4
+        assert tp.allclose(output, tp.Tensor(expected), rtol=rtol, atol=1e-5)
         assert output.shape == list(expected.shape)
-        assert tp.allclose(output_transpose, tp.Tensor(expected_transpose), rtol=rtol_)
+        assert tp.allclose(output_transpose, tp.Tensor(expected_transpose), rtol=rtol, atol=1e-5)
         assert output_transpose.shape == list(expected_transpose.shape)
-        assert tp.allclose(output, output_transpose, rtol=rtol_)
+        assert tp.allclose(output, output_transpose, rtol=rtol, atol=1e-5)
         assert output.shape == output_transpose.shape
-        assert tp.allclose(tp.Tensor(expected), tp.Tensor(expected_transpose), rtol=rtol_)
+        assert tp.allclose(tp.Tensor(expected), tp.Tensor(expected_transpose), rtol=rtol, atol=1e-5)
         assert list(expected.shape) == list(expected_transpose.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
diff --git a/tripy/tests/integration/test_sequential.py b/tripy/tests/integration/test_sequential.py
index dd5784de0..1f8962741 100644
--- a/tripy/tests/integration/test_sequential.py
+++ b/tripy/tests/integration/test_sequential.py
@@ -42,8 +42,7 @@ def test_basic_forward_pass_accuracy(self, eager_or_compiled):
         with torch.no_grad():
             torch_output = torch_model(input_tensor)
 
-        rtol_ = 2e-6
-        assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=rtol_)
+        assert torch.allclose(torch.from_dlpack(tp_output), torch_output, atol=1e-5, rtol=2e-6)
 
     def test_dict_forward_pass_accuracy(self, eager_or_compiled):
         torch_model = torch.nn.Sequential(
diff --git a/tripy/tripy/backend/api/input_info.py b/tripy/tripy/backend/api/input_info.py
index 2a2457cf8..28eb9b207 100644
--- a/tripy/tripy/backend/api/input_info.py
+++ b/tripy/tripy/backend/api/input_info.py
@@ -12,12 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numbers
 from typing import Sequence, Tuple, Union
 
 from tripy import export
 from tripy.common.exception import raise_error
 from tripy.common.shape_bounds import ShapeBounds
+from tripy.frontend.dimension_size import DimensionSize
+from tripy.types import IntLike
 
 
 @export.public_api(document_under="compiling_code")
@@ -26,7 +27,7 @@ class InputInfo:
     Captures information about an input to a compiled function.
     """
 
-    def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tripy.dtype") -> None:
+    def __init__(self, shape: Sequence[Union[IntLike, Tuple[IntLike, IntLike, IntLike]]], dtype: "tripy.dtype") -> None:
         """
         Args:
             shape: The shape of the input.
@@ -38,9 +39,9 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tr
             :caption: Example
 
             inp = tp.InputInfo((2, 4), dtype=tp.float32)
-            assert inp.shape_bounds.min == (2, 4)
-            assert inp.shape_bounds.opt == (2, 4)
-            assert inp.shape_bounds.max == (2, 4)
+            assert inp.shape_bounds.min == [2, 4]
+            assert inp.shape_bounds.opt == [2, 4]
+            assert inp.shape_bounds.max == [2, 4]
 
         .. code-block:: python
             :linenos:
@@ -49,22 +50,24 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tr
            # The first dimension will support values in the range [1, 3],
            # optimizing for a size of 2.
            inp = tp.InputInfo(((1, 2, 3), 4), dtype=tp.float32)
-           assert inp.shape_bounds.min == (1, 4)
-           assert inp.shape_bounds.opt == (2, 4)
-           assert inp.shape_bounds.max == (3, 4)
+           assert inp.shape_bounds.min == [1, 4]
+           assert inp.shape_bounds.opt == [2, 4]
+           assert inp.shape_bounds.max == [3, 4]
         """
+        is_int_like = lambda arg: any(isinstance(arg, typ) for typ in {int, DimensionSize})
+
        # TODO (#252): Allow `shape` to be a shape tensor
        min_shape = []
        opt_shape = []
        max_shape = []
        for elem in shape:
-            if isinstance(elem, numbers.Number):
+            if is_int_like(elem):
                elem = (elem,) * 3
            elif isinstance(elem, Sequence):
-                if not all(isinstance(val, numbers.Number) for val in elem):
+                if not all(is_int_like(val) for val in elem):
                    raise_error(
-                        "Shape values must be numbers.",
-                        [f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"],
+                        "Shape values must be integers or `DimensionSize`s.",
+                        [f"Shape: {shape} contains an element of incorrect type: {repr(elem)}"],
                    )
                if len(elem) != 3:
                    raise_error(
@@ -76,7 +79,7 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tr
                    )
            else:
                raise_error(
-                    "Shape values should be either a single number or a Tuple specifying min/opt/max bounds.",
+                    "Shape values should be either a single integer-like value or a 3-element tuple specifying min/opt/max bounds.",
                    [f"Shape: {shape} contains an invalid element: {elem}"],
                )
 
@@ -84,7 +87,7 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tr
            opt_shape.append(elem[1])
            max_shape.append(elem[2])
 
-        self.shape_bounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
+        self.shape_bounds = ShapeBounds(min_shape, opt_shape, max_shape)
        self.dtype = dtype
 
    def __str__(self) -> str:
diff --git a/tripy/tripy/common/shape_bounds.py b/tripy/tripy/common/shape_bounds.py
index 800292c23..cf3dd1342 100644
--- a/tripy/tripy/common/shape_bounds.py
+++ b/tripy/tripy/common/shape_bounds.py
@@ -18,12 +18,14 @@
 from dataclasses import dataclass
 from typing import Sequence
 
+from tripy.types import IntLike
+
 
 @dataclass
 class ShapeBounds:
-    min: Sequence[int]
-    opt: Sequence[int]
-    max: Sequence[int]
+    min: Sequence[IntLike]
+    opt: Sequence[IntLike]
+    max: Sequence[IntLike]
 
     def is_static(self):
         return self.min == self.opt == self.max
diff --git a/tripy/tripy/frontend/ops/repeat.py b/tripy/tripy/frontend/ops/repeat.py
index a152d7d22..6a08a3f85 100644
--- a/tripy/tripy/frontend/ops/repeat.py
+++ b/tripy/tripy/frontend/ops/repeat.py
@@ -18,6 +18,7 @@
 from tripy import constraints, export
 from tripy.common.exception import raise_error
 from tripy.frontend import utils as frontend_utils
+from tripy.types import IntLike
 
 
 @export.public_api(document_under="operations/functions")
@@ -27,7 +28,7 @@
         "T1": ["float32", "float16", "bfloat16", "int4", "float8", "int8", "int32", "int64", "bool"],
     },
 )
-def repeat(input: "tripy.Tensor", repeats: Union[int, "tripy.DimensionSize"], dim: int) -> "tripy.Tensor":
+def repeat(input: "tripy.Tensor", repeats: IntLike, dim: int) -> "tripy.Tensor":
     """
     Repeats each element of a tensor after itself along the specified dimension.
 
diff --git a/tripy/tripy/types.py b/tripy/tripy/types.py
index 8f9d3a834..cc7e24b2c 100644
--- a/tripy/tripy/types.py
+++ b/tripy/tripy/types.py
@@ -33,15 +33,26 @@
     module=sys.modules[__name__],
     symbol="TensorLike",
     doc="""
-    Type annotation for a parameter that is either a Tripy :class:`Tensor` or a Python number that can be automatically converted into one.
+    A Tripy :class:`Tensor` or a Python number that can be automatically converted into one.
     """,
 )(Union["tripy.Tensor", numbers.Number])
 
+
+IntLike = export.public_api(
+    document_under="types.rst",
+    module=sys.modules[__name__],
+    symbol="IntLike",
+    doc="""
+    An integer-like object.
+    """,
+)(Union[int, "tripy.DimensionSize"])
+
+
 ShapeLike = export.public_api(
     document_under="types.rst",
     module=sys.modules[__name__],
     symbol="ShapeLike",
     doc="""
-    Type annotation for a parameter that represents a shape.
+    A shape of a :class:`Tensor` .
     """,
-)(Sequence[Union[int, "tripy.DimensionSize"]])
+)(Sequence[IntLike])
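
A minimal usage sketch of what this patch enables, mirroring the new `test_can_compile_using_shape_of_tensor` test: because `InputInfo` now accepts `DimensionSize` values, a tensor's shape and dtype can be passed directly to `tp.compile`.

import tripy as tp

# Input with negative and non-negative values so relu is not a no-op.
inp = tp.iota((2, 2), dtype=tp.float32) - 1

# `inp.shape` is a sequence of `DimensionSize`s, which InputInfo now accepts as shape values.
compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo(inp.shape, inp.dtype)])

out = compiled_relu(inp)
assert tp.equal(out, tp.relu(inp))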