Skip to content

Commit

Permalink
[mosaic_gpu] Fixed unbounded recursion in FragmentedArray._pointwise
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700265616
  • Loading branch information
superbobry authored and Google-ML-Automation committed Nov 26, 2024
1 parent 16a5607 commit b6566c8
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,10 @@ def to_layout(self, new_layout: FragmentedLayout):
reg, self.shape, new_layout, is_signed=self.is_signed
)

def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False):
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
# If our layout is a splat, then we should either dispatch to a non-splat
# layout, or broadcast ourselves to the output shape first.
if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout):
if isinstance(self.layout, WGSplatFragLayout):
output_shape = self.shape
for i, o in enumerate(other):
if not isinstance(o, FragmentedArray):
Expand All @@ -641,9 +641,10 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_
else:
output_shape = np.broadcast_shapes(output_shape, o.shape)
# If we get here then we haven't found any non-splat layout.
return self.broadcast(output_shape)._pointwise(
op, *other, output_is_signed=output_is_signed, force_no_dispatch=True,
)
if self.shape != output_shape:
return self.broadcast(output_shape)._pointwise(
op, *other, output_is_signed=output_is_signed
)

other_arrs = []
for o in other:
Expand Down

0 comments on commit b6566c8

Please sign in to comment.