Merge pull request #18497 from gnecula:shape_poly_refactor
PiperOrigin-RevId: 581901696
jax authors committed Nov 13, 2023
2 parents 6403291 + f9474b2 commit 7498d30
Showing 1 changed file with 94 additions and 91 deletions.
185 changes: 94 additions & 91 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
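The heart of the refactor: shape-polymorphism tests that previously called jax2tf.convert directly and compared results by hand now go through check_shape_poly, which builds a PolyHarness, runs it, and returns the result (or None when an error was expected). A minimal before/after sketch using only names that appear in the diff below; f_jax and x stand in for any test function and argument:

    # Before: convert and compare by hand.
    res_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x)
    self.assertAllClose(f_jax(x), res_tf)

    # After: one helper builds and runs a PolyHarness and hands back the result.
    res_tf = check_shape_poly(self, f_jax, arg_descriptors=[x],
                              polymorphic_shapes=["b"])
    self.assertAllClose(f_jax(x), res_tf)

Several standalone test methods (test_arange, test_argmax, test_gather_1d) are likewise deleted and re-expressed as PolyHarness entries in _POLY_SHAPE_TEST_HARNESSES.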
@@ -560,7 +560,7 @@ def both_enable_and_disable_xla(self) -> tuple["PolyHarness", "PolyHarness"]:
self.name = f"{self.name}_enable_xla_True"
return (self, other)

def run_test(self, tst: tf_test_util.JaxToTfTestCase):
def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> Optional[jax.Array]:
def log_message(extra: str):
return f"[{tst._testMethodName}]: {extra}"

@@ -609,7 +609,7 @@ def log_message(extra: str):
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)

if expect_error_type is not None:
return
return None

if self.expected_output_signature:
# Strangely, output_shapes can be a single shape for a function with a
@@ -649,6 +649,11 @@ def log_message(extra: str):
f"to {custom_assert_lims[0]}"))
custom_assert_lims[0].custom_assert(tst, res_jax, res_tf, args=args, # type: ignore
tol=tol, err_msg=None)
return res_tf
else:
return None
else:
return None
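Note the changed contract in the hunk above: run_test now returns Optional[jax.Array], namely the result of running the converted function when the comparison executed, and None when an expected error short-circuited the test or no result was produced. A sketch of a caller honoring that contract (harness, f, and x are hypothetical stand-ins):

    res = harness.run_test(self)  # Optional[jax.Array]
    if res is not None:
        self.assertAllClose(f(x), res)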


def check_shape_poly(tst, f_jax: Callable, *,
@@ -657,7 +662,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
polymorphic_shapes: Sequence[Optional[str]] = (),
input_signature: Optional[Sequence[tf.TensorSpec]] = None,
expected_output_signature: Optional[tf.TensorSpec] = None,
expect_error=(None, None)):
expect_error=(None, None)) -> Optional[jax.Array]:
# Makes and tests a harness. See PolyHarness documentation.
h = PolyHarness("", "", f_jax,
arg_descriptors=arg_descriptors,
@@ -666,7 +671,7 @@
input_signature=input_signature,
expected_output_signature=expected_output_signature,
expect_error=expect_error)
h.run_test(tst)
return h.run_test(tst)


class ShapePolyTest(tf_test_util.JaxToTfTestCase):
@@ -730,43 +735,6 @@ def f_jax(x, y):
polymorphic_shapes=["h, h", "h, h"],
expected_output_signature=tf.TensorSpec([None, None]))

@jtu.parameterized_filterable(
# make_args invoked with op.shape[0]: start, stop, step, dtype
# b == 6
kwargs=[
# Positive step
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)), # Cannot tell if size >= 0
# Negative step
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
dict(testcase_name="5b+1_0_-2", make_args=lambda b: (5 * b + 1, 0, -2, None)),
dict(testcase_name="5b+2_0_-2", make_args=lambda b: (5 * b + 2, 0, -2, None)),
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)), # Cannot tell if size >= 0
# Symbolic step
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
# Float return type
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
])
def test_arange(self, make_args):
def f_jax(x): # x: i32[b]
return x[0] + jnp.arange(*(make_args(x.shape[0])))
x = np.ones((6,), dtype=np.int32)
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="b")(x),
f_jax(x))

@jtu.parameterized_filterable(
# make_args invoked with op.shape[0]: start, stop, step, dtype
kwargs=[
@@ -792,14 +760,9 @@ def f_jax(x): # x: i32[b]
return x[0] + jnp.arange(*(make_args(x.shape[0])))
x = np.ones((3,), dtype=np.int32)
with self.assertRaisesRegex(expect_error, expect_msg):
jax2tf.convert(f_jax, polymorphic_shapes="b")(x)
check_shape_poly(self, f_jax, arg_descriptors=[x],
polymorphic_shapes=["b"])

def test_argmax(self):
def f_jax(x): # x: f32[b, 4, 5]
return lax.argmax(x, axis=1, index_dtype=np.int32)
x = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
f_jax(x))

@jtu.parameterized_filterable(
kwargs=[
@@ -996,11 +959,13 @@ def shaped_array(shape_spec: str, actual_shape: core.Shape):
expected_shapeenv=dict(a=2, b=3, c=4))

def test_arg_avals_errors(self):
"""Test error reporting for shape polymorpish."""
"""Test error reporting for shape polymorphism."""
def conv_and_run(*, arg_shape: core.Shape,
polymorphic_shape: str):
arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)
check_shape_poly(self, lambda x: x,
arg_descriptors=[arg],
polymorphic_shapes=[polymorphic_shape])

with self.assertRaisesRegex(ValueError,
re.escape("polymorphic shape spec should be")):
@@ -1094,7 +1059,9 @@ def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
_ = jax2tf.convert(f_jax, polymorphic_shapes=[poly_spec])(x)
_ = check_shape_poly(self, f_jax,
arg_descriptors=[x],
polymorphic_shapes=[poly_spec])

def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""
@@ -1372,7 +1339,8 @@ def f(x, y):
res_jax = f(x, y)
self.assertAllClose(
res_jax,
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
check_shape_poly(self, f, arg_descriptors=[x, y],
polymorphic_shapes=["(b, h)", "h"]))

def test_while(self):
def f(x):
@@ -1382,7 +1350,8 @@ def f(x):
(x, 0))

x = np.ones((3,), dtype=np.float32)
res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
res_tf = check_shape_poly(self, f, arg_descriptors=[x],
polymorphic_shapes=["(b,)"])
self.assertAllClose(f(x), res_tf)

@jtu.parameterized_filterable(
@@ -1671,32 +1640,37 @@ def f(x):
return jnp.sum(x, axis=0) * x.shape[0]

x = np.arange(3.)
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(
9.,
jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(9.,
check_shape_poly(self, f,
arg_descriptors=[x],
polymorphic_shapes=["(b,)"]))
self.assertAllClose(
9.,
tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
check_shape_poly(self, jax.jit(f),
arg_descriptors=[x], polymorphic_shapes=["(b,)"]))

res_primal, res_tangent = jax2tf.convert(
res_primal, res_tangent = check_shape_poly(self,
lambda x, xt: jax.jvp(f, (x,), (xt,)),
polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
arg_descriptors=[x, np.array([0.1, 0.2, 0.3])],
polymorphic_shapes=["b", "b"])
self.assertAllClose((9., 1.8), (res_primal, res_tangent))

self.assertAllClose(
np.array([3., 3., 3.]),
jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
check_shape_poly(self, jax.grad(f),
arg_descriptors=[x],
polymorphic_shapes=["b"]))

xv = np.arange(24.).reshape((2, 3, 4))
res_vmap = jax.vmap(f, in_axes=1)(xv)
# Implement by iteration
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)

res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
polymorphic_shapes=["b1, b2, ..."])(xv)
self.assertAllClose(res_iter, res_vmap_tf.numpy())
res_vmap_tf = check_shape_poly(self, jax.vmap(f, in_axes=1),
arg_descriptors=[xv],
polymorphic_shapes=["b1, b2, ..."])
self.assertAllClose(res_iter, res_vmap_tf)

def test_with_hash_collision_vmap(self):
# Batching caches based on Jaxpr, and Jaxpr include _DimExpr. If we have
@@ -1948,33 +1922,6 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
res = jax2tf.convert(f2, polymorphic_shapes=zw_polymorphic_shapes)(z, w)
self.assertAllClose(f2(* f1(x, y)), res)

def test_gather_1d(self):
operand = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], np.float32)
rand_idxs = np.random.randint(0, high=max(operand.shape), size=(3, 1), dtype=np.int32)
slice_x = np.zeros((10,), dtype=jnp.float32)
dnums = lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)
)

@jax.jit
def f_jax(operand, start_indices, x):
return lax.gather(
operand,
start_indices,
dimension_numbers=dnums,
slice_sizes=x.shape,
mode="promise_in_bounds",
)

res = f_jax(operand, rand_idxs, slice_x)
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"],
)
res_tf = f_tf(operand, rand_idxs, slice_x)
self.assertAllClose(res, res_tf)


# List containing either harnesses, or lists of harnesses
_POLY_SHAPE_TEST_HARNESSES = [
Expand All @@ -1986,6 +1933,45 @@ def f_jax(operand, start_indices, x):
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]),
[
# make_args invoked with op.shape[0] and produces the arange args:
# start, stop, step, dtype
PolyHarness("arange", kwargs["testcase_name"], # type: ignore
lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))), # type: ignore
arg_descriptors=[RandArg((6,), np.float32)],
polymorphic_shapes=["b"])
for kwargs in [
# Positive step
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),
# Cannot tell if size >= 0
# Negative step
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
dict(testcase_name="5b+1_0_-2",
make_args=lambda b: (5 * b + 1, 0, -2, None)),
dict(testcase_name="5b+2_0_-2",
make_args=lambda b: (5 * b + 2, 0, -2, None)),
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),
# Cannot tell if size >= 0
# Symbolic step
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
# Float return type
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
]
],
# Reduce the poly dimension
PolyHarness("argmax", "0",
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
@@ -2328,6 +2314,23 @@ def f_jax(operand, start_indices, x):
lambda x: lax.full((x.shape[0], 2), 3.) + x,
arg_descriptors=[RandArg((3, 1), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("gather", "1d",
lambda operand, start_indices, x: lax.gather(
operand,
start_indices,
dimension_numbers=lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(),
start_index_map=(0,)),
slice_sizes=x.shape,
mode="promise_in_bounds"),
arg_descriptors=[
RandArg((10,), np.float32),
np.random.randint(0, high=10, size=(3, 1),
dtype=np.int32),
np.zeros((10,), dtype=jnp.int32),
],
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]),
# operand is non-poly, index is poly
PolyHarness("getitem", "op=static_idx=poly",
lambda a, i: a[i],
