Allows DimensionSizes to be used in InputInfo
Introduces a new `IntLike` type alias that covers both plain integers and
`DimensionSize`s, and updates `InputInfo` to accept shape arguments of this type.
Note that `DimensionSize`s must be evaluated in order to populate the
optimization profile correctly.

This makes it possible to use the shapes of eager-mode tensors when compiling.
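For example, the pattern the new test exercises (a condensed sketch of that test, not additional API surface):

    import tripy as tp

    # An eager-mode tensor; its `.shape` is a sequence of `DimensionSize`s.
    inp = tp.iota((2, 2), dtype=tp.float32) - 1

    # The shape can now be passed straight to `InputInfo` when compiling.
    compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo(inp.shape, inp.dtype)])
    out = compiled_relu(inp)
    assert tp.equal(out, tp.relu(inp))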
pranavm-nvidia committed Dec 4, 2024
1 parent 913afcd commit b71a868
Showing 9 changed files with 81 additions and 45 deletions.
19 changes: 15 additions & 4 deletions tripy/tests/backend/api/test_compile.py
@@ -20,28 +20,39 @@
 from tests.backend.api.conftest import *
 
 import tripy as tp
+from tripy.frontend.trace.ops.storage import Storage
 
 
 class TestCompile:
     # TODO (#246): Verify that it's actually compiling somehow here and below.
     # Need to return something programmatically queryable from compile to do this.
     def test_function(self):
-        compiled_gelu = tp.compile(tp.relu, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
+        compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
 
-        inp = tp.ones((2, 2), dtype=tp.float32)
-        out = compiled_gelu(inp)
+        inp = tp.iota((2, 2), dtype=tp.float32) - 1
+        out = compiled_relu(inp)
 
         assert tp.equal(out, tp.relu(inp))
 
     def test_module(self):
         layernorm = tp.LayerNorm(2)
         compiled_layernorm = tp.compile(layernorm, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
 
-        inp = tp.ones((2, 2), dtype=tp.float32)
+        inp = tp.iota((2, 2), dtype=tp.float32) - 1
         out = compiled_layernorm(inp)
 
         assert tp.equal(out, layernorm(inp))
 
+    def test_can_compile_using_shape_of_tensor(self):
+        # Since InputInfo allows `DimensionSize`s, we should be able to use the shape of a tensor as
+        # the shape of the InputInfo.
+        inp = tp.iota((2, 2), dtype=tp.float32) - 1
+        shape = inp.shape
+
+        compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo(shape, inp.dtype)])
+        out = compiled_relu(inp)
+        assert tp.equal(out, tp.relu(inp))
+
     def test_compile_arg_order_irrelevant(self):
         # The order of arguments we specify to `compile` should not affect the order
         # of the arguments in the compiled function, which should just follow the order
14 changes: 10 additions & 4 deletions tripy/tests/backend/api/test_input_info.py
@@ -22,10 +22,16 @@ class TestInput:
     @pytest.mark.parametrize(
         "shape, expected_min, expected_opt, expected_max",
         [
+            # int:
             # min/opt/max explicitly specified
-            ([(1, 2, 3)], (1,), (2,), (3,)),
+            ([(1, 2, 3)], [1], [2], [3]),
             # Only one value specified
-            ([1], (1,), (1,), (1,)),
+            ([1], [1], [1], [1]),
+            # `DimensionSize`s:
+            # min/opt/max explicitly specified
+            ([(tp.DimensionSize(1), tp.DimensionSize(2), tp.DimensionSize(3))], [1], [2], [3]),
+            # Only one value specified
+            ([tp.DimensionSize(1)], [1], [1], [1]),
         ],
     )
     def test_shapes_normalized(self, shape, expected_min, expected_opt, expected_max):
@@ -41,14 +47,14 @@ def test_shapes_normalized(self, shape, expected_min, expected_opt, expected_max):
             # Not a number
             (
                 (tp.int32, 1),
-                "Shape values should be either a single number or a Tuple specifying min/opt/max bounds.",
+                "Shape values should be either a single integer-like value or a 3-element tuple specifying min/opt/max bounds.",
             ),
             # Too few elements in dimension
             (((1, 1), 1), "Incorrect number of shape values provided"),
             # Too many elements in dimension
             (((1, 1, 1, 1), 1), "Incorrect number of shape values provided"),
             # Tuple containing a non-number
-            (((tp.int32, 1, 1), 1), "Shape values must be numbers"),
+            (((tp.int32, 1, 1), 1), "Shape values must be integers or `DimensionSize`s."),
         ],
     )
     def test_invalid_shape(self, shape, expected_error):
21 changes: 12 additions & 9 deletions tripy/tests/helper.py
@@ -498,16 +498,19 @@ def process_code_block_for_outputs_and_locals(
     code_start, code_end = get_code_bounds(block_lines)
     code = dedent("\n".join(block_lines[code_start:code_end]))
 
-    with capture_output() as outfile:
-        try:
+    try:
+        with capture_output() as outfile:
             code_locals = exec_code(code, local_vars)
-        except Exception as e:
-            if allow_exception:
-                print(f"Exception occurred: {str(e)}")
-                code_locals = local_vars
-            else:
-                print(err_msg)
-                raise
+    except Exception as e:
+        print(
+            f"Exception occurred while executing code block: {type(e).__name__}: {e}\n"
+            f"Note: Code block was:\n\n{block}"
+        )
+        if allow_exception:
+            code_locals = local_vars
+        else:
+            print(err_msg)
+            raise
 
     new_locals = {
         key: value for key, value in code_locals.items() if key not in local_vars or value is not local_vars[key]
10 changes: 5 additions & 5 deletions tripy/tests/integration/test_conv_transpose.py
@@ -280,14 +280,14 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled):
         output = eager_or_compiled(conv_layer, input)
         output_transpose = eager_or_compiled(conv_transpose_layer, input)
 
-        rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
-        assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
+        rtol = 2e-7 if tp_dtype == tp.float32 else 9e-4
+        assert tp.allclose(output, tp.Tensor(expected), rtol=rtol, atol=1e-5)
         assert output.shape == list(expected.shape)
-        assert tp.allclose(output_transpose, tp.Tensor(expected_transpose), rtol=rtol_)
+        assert tp.allclose(output_transpose, tp.Tensor(expected_transpose), rtol=rtol, atol=1e-5)
         assert output_transpose.shape == list(expected_transpose.shape)
-        assert tp.allclose(output, output_transpose, rtol=rtol_)
+        assert tp.allclose(output, output_transpose, rtol=rtol, atol=1e-5)
         assert output.shape == output_transpose.shape
-        assert tp.allclose(tp.Tensor(expected), tp.Tensor(expected_transpose), rtol=rtol_)
+        assert tp.allclose(tp.Tensor(expected), tp.Tensor(expected_transpose), rtol=rtol, atol=1e-5)
         assert list(expected.shape) == list(expected_transpose.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
3 changes: 1 addition & 2 deletions tripy/tests/integration/test_sequential.py
@@ -42,8 +42,7 @@ def test_basic_forward_pass_accuracy(self, eager_or_compiled):
         with torch.no_grad():
             torch_output = torch_model(input_tensor)
 
-        rtol_ = 2e-6
-        assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=rtol_)
+        assert torch.allclose(torch.from_dlpack(tp_output), torch_output, atol=1e-5, rtol=2e-6)
 
     def test_dict_forward_pass_accuracy(self, eager_or_compiled):
         torch_model = torch.nn.Sequential(
31 changes: 17 additions & 14 deletions tripy/tripy/backend/api/input_info.py
@@ -12,12 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numbers
 from typing import Sequence, Tuple, Union
 
 from tripy import export
 from tripy.common.exception import raise_error
 from tripy.common.shape_bounds import ShapeBounds
+from tripy.frontend.dimension_size import DimensionSize
+from tripy.types import IntLike
 
 
 @export.public_api(document_under="compiling_code")
@@ -26,7 +27,7 @@ class InputInfo:
     Captures information about an input to a compiled function.
     """
 
-    def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tripy.dtype") -> None:
+    def __init__(self, shape: Sequence[Union[IntLike, Tuple[IntLike, IntLike, IntLike]]], dtype: "tripy.dtype") -> None:
         """
         Args:
             shape: The shape of the input.
@@ -38,9 +39,9 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tripy.dtype") -> None:
             :caption: Example
 
             inp = tp.InputInfo((2, 4), dtype=tp.float32)
-            assert inp.shape_bounds.min == (2, 4)
-            assert inp.shape_bounds.opt == (2, 4)
-            assert inp.shape_bounds.max == (2, 4)
+            assert inp.shape_bounds.min == [2, 4]
+            assert inp.shape_bounds.opt == [2, 4]
+            assert inp.shape_bounds.max == [2, 4]
 
         .. code-block:: python
             :linenos:
@@ -49,22 +50,24 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tripy.dtype") -> None:
            # The first dimension will support values in the range [1, 3],
            # optimizing for a size of 2.
            inp = tp.InputInfo(((1, 2, 3), 4), dtype=tp.float32)
-            assert inp.shape_bounds.min == (1, 4)
-            assert inp.shape_bounds.opt == (2, 4)
-            assert inp.shape_bounds.max == (3, 4)
+            assert inp.shape_bounds.min == [1, 4]
+            assert inp.shape_bounds.opt == [2, 4]
+            assert inp.shape_bounds.max == [3, 4]
        """
+        is_int_like = lambda arg: any(isinstance(arg, typ) for typ in {int, DimensionSize})
+
        # TODO (#252): Allow `shape` to be a shape tensor
        min_shape = []
        opt_shape = []
        max_shape = []
        for elem in shape:
-            if isinstance(elem, numbers.Number):
+            if is_int_like(elem):
                elem = (elem,) * 3
            elif isinstance(elem, Sequence):
-                if not all(isinstance(val, numbers.Number) for val in elem):
+                if not all(is_int_like(val) for val in elem):
                    raise_error(
-                        "Shape values must be numbers.",
-                        [f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"],
+                        "Shape values must be integers or `DimensionSize`s.",
+                        [f"Shape: {shape} contains an element of incorrect type: {repr(elem)}"],
                    )
                if len(elem) != 3:
                    raise_error(
@@ -76,15 +79,15 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tripy.dtype") -> None:
            else:
                raise_error(
-                    "Shape values should be either a single number or a Tuple specifying min/opt/max bounds.",
+                    "Shape values should be either a single integer-like value or a 3-element tuple specifying min/opt/max bounds.",
                    [f"Shape: {shape} contains an invalid element: {elem}"],
                )
 
            min_shape.append(elem[0])
            opt_shape.append(elem[1])
            max_shape.append(elem[2])
 
-        self.shape_bounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
+        self.shape_bounds = ShapeBounds(min_shape, opt_shape, max_shape)
        self.dtype = dtype
 
    def __str__(self) -> str:
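As a usage sketch of the new validation (the error text is taken from the diff above; the exact exception type produced by `raise_error` is not shown here, so the broad `except` is an assumption):

    import tripy as tp

    try:
        # `tp.int32` is a data type, not an integer-like value.
        tp.InputInfo((tp.int32, 4), dtype=tp.float32)
    except Exception as err:
        print(err)  # "Shape values should be either a single integer-like value or a 3-element tuple..."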
8 changes: 5 additions & 3 deletions tripy/tripy/common/shape_bounds.py
@@ -18,12 +18,14 @@
 from dataclasses import dataclass
 from typing import Sequence
 
+from tripy.types import IntLike
+
 
 @dataclass
 class ShapeBounds:
-    min: Sequence[int]
-    opt: Sequence[int]
-    max: Sequence[int]
+    min: Sequence[IntLike]
+    opt: Sequence[IntLike]
+    max: Sequence[IntLike]
 
     def is_static(self):
         return self.min == self.opt == self.max
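A brief sketch of how the loosened annotations behave (`ShapeBounds` is an internal dataclass, so the import below mirrors the diff rather than a public API):

    from tripy.common.shape_bounds import ShapeBounds

    static = ShapeBounds(min=[2, 4], opt=[2, 4], max=[2, 4])
    dynamic = ShapeBounds(min=[1, 4], opt=[2, 4], max=[3, 4])

    assert static.is_static()       # min == opt == max
    assert not dynamic.is_static()  # dimension 0 spans [1, 3]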
3 changes: 2 additions & 1 deletion tripy/tripy/frontend/ops/repeat.py
@@ -18,6 +18,7 @@
 from tripy import constraints, export
 from tripy.common.exception import raise_error
 from tripy.frontend import utils as frontend_utils
+from tripy.types import IntLike
 
 
 @export.public_api(document_under="operations/functions")
@@ -27,7 +28,7 @@
         "T1": ["float32", "float16", "bfloat16", "int4", "float8", "int8", "int32", "int64", "bool"],
     },
 )
-def repeat(input: "tripy.Tensor", repeats: Union[int, "tripy.DimensionSize"], dim: int) -> "tripy.Tensor":
+def repeat(input: "tripy.Tensor", repeats: IntLike, dim: int) -> "tripy.Tensor":
     """
     Repeats each element of a tensor after itself along the specified dimension.
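A usage sketch of the relaxed signature (the output values follow from the repeat-interleave semantics described in the docstring and are illustrative):

    import tripy as tp

    inp = tp.iota((2,), dtype=tp.float32)  # [0, 1]
    num = tp.DimensionSize(2)

    # `repeats` may now be a `DimensionSize` rather than a plain int.
    out = tp.repeat(inp, repeats=num, dim=0)  # expected: [0, 0, 1, 1]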
17 changes: 14 additions & 3 deletions tripy/tripy/types.py
@@ -33,15 +33,26 @@
     module=sys.modules[__name__],
     symbol="TensorLike",
     doc="""
-        Type annotation for a parameter that is either a Tripy :class:`Tensor` or a Python number that can be automatically converted into one.
+        A Tripy :class:`Tensor` or a Python number that can be automatically converted into one.
        """,
 )(Union["tripy.Tensor", numbers.Number])
 
 
+IntLike = export.public_api(
+    document_under="types.rst",
+    module=sys.modules[__name__],
+    symbol="IntLike",
+    doc="""
+        An integer-like object.
+        """,
+)(Union[int, "tripy.DimensionSize"])
+
+
 ShapeLike = export.public_api(
     document_under="types.rst",
     module=sys.modules[__name__],
     symbol="ShapeLike",
     doc="""
-        Type annotation for a parameter that represents a shape.
+        A shape of a :class:`Tensor`.
        """,
-)(Sequence[Union[int, "tripy.DimensionSize"]])
+)(Sequence[IntLike])
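Putting the new aliases together, a shape argument can now mix plain ints, `DimensionSize`s, and (min, opt, max) tuples (a sketch based on the updated tests above; the printed form comes from `InputInfo.__str__`):

    import tripy as tp

    dim = tp.DimensionSize(2)

    # Integer-like values are normalized to identical min/opt/max bounds;
    # 3-element tuples specify the bounds explicitly.
    inp = tp.InputInfo((dim, (1, 2, 3), 4), dtype=tp.float32)
    print(inp)  # shows the resulting min/opt/max shape bounds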
