[QST] how to set correct swizzle with different smem tile to avoid static_assert failure in thr_copy.partition_S/D() #1876

Closed
CalebDu opened this issue Oct 16, 2024 · 11 comments

CalebDu commented Oct 16, 2024

I encounter a static_assert failure ("Static shape_div failure") in thr_copy.partition_S/D(smem_tensor) when I adjust the SmemLayout (tile_m, tile_k, stage) for sA with a swizzle in my customized uint1b_t GEMM.

The customized GEMM is for the LLM decoding phase, so tile_m for sA is fixed to 8, while tile_n for sB is larger than tile_m (32, 64, or more). The static_assert failure is only triggered by thr_copy.partition_S/D(sA).
To simplify the code, I created a sample with only the G2S part.

using namespace cute;
__global__ void test_case(void *ptr) {
  int m = 8;
  int k = 1024;
  constexpr int tile_m = 8;
  constexpr int tile_k = 128;
  constexpr int stage = 3;
  static constexpr int G2S_SwizzleB = 3;
  static constexpr int G2S_SwizzleM = 7;
  static constexpr int G2S_SwizzleS = 3;

  auto A = make_tensor(make_gmem_ptr<uint1_t>(ptr), make_shape(m, k),
                       make_stride(k, _1{}));
  auto gA = local_tile(A, make_shape(Int<tile_m>{}, Int<tile_k>{}), make_coord(0, _));

  __shared__ int8_t smem[tile_m * tile_k * stage / 8];

  // SmemLayoutAtom [8, tile_k]
  using SmemABLayoutAtom =
      decltype(composition(Swizzle<G2S_SwizzleB, G2S_SwizzleM, G2S_SwizzleS>{},
                           make_layout(make_shape(Int<8>{}, Int<tile_k>{}),
                                       make_stride(Int<tile_k>{}, Int<1>{}))));

  // SmemLayout [tile_m, tile_k, stage]
  using SmemALayout = decltype(tile_to_shape(
      SmemABLayoutAtom{},
      make_shape(Int<tile_m>{}, Int<tile_k>{}, Int<stage>{})));

  auto sA = make_tensor(make_smem_ptr<uint1_t>(smem), SmemALayout{});

  using G2SCopyOp =
      SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>; // load 128 uint1b_t per thread
  using G2SCopyTraits = Copy_Traits<G2SCopyOp>;
  using G2SCopyAtom = Copy_Atom<G2SCopyTraits, uint1_t>;

  // thread [8, tile_k/128]
  using G2SCopyA = decltype(make_tiled_copy(
      G2SCopyAtom{},
      make_layout(make_shape(Int<tile_m>{}, Int<tile_k / 128>{}),
                  make_stride(Int<tile_k / 128>{}, _1{})),
      make_layout(make_shape(_1{}, _128{}))));

  G2SCopyA copy;
  auto thr_copy = copy.get_slice(threadIdx.x);
  auto src = thr_copy.partition_S(gA);
  auto dst = thr_copy.partition_D(sA);
}

There is no bank conflict with Swizzle<3, 7, 3>, tile_k = 128, and stage = 1 in the customized GEMM.
Keeping Swizzle<3, 7, 3> but adjusting stage to 3 or 5 causes the static_assert failure; changing to Swizzle<0, 7, 0> removes it.
So how do I choose a correct swizzle for different tile_m, tile_k, and stage values while keeping the layout bank-conflict free?
I can manually try different swizzle combinations, but is there a more efficient way?
Later I want to adjust tile_k to 256, 384, 512, 1024, etc. and stage to 2, 3, 4, 5, etc.

ccecka commented Oct 18, 2024

I suspect this is simply due to your layout not being large enough. Swizzle<3,7,3> does have a base size of 128b, but it is a function that repeats only after every 64x128b values, so the SMEM Atom expects to be at least that size. It also concerns me that your TiledCopy appears to only be using 8 threads.

For example,

#include <cute/tensor.hpp>
using namespace cute;

int main()
{
  // SmemLayoutAtom 64x128
  auto smem_layout = composition(Swizzle<3,7,3>{},
                                 make_layout(make_shape ( _64{},_128{}),
                                             make_stride(_128{},  _1{})));

  Tensor gA = make_tensor(make_gmem_ptr<uint1_t>((void*)0), make_shape(_64{}, _128{}), LayoutRight{});
  Tensor sA = make_tensor(make_smem_ptr<uint1_t>((void*)0), smem_layout);

  auto copy = make_tiled_copy(
      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, uint1b_t>{},  // load 128 uint1b_t as uint128_t each
      make_layout(make_shape(_32{}, _1{})),                               // 32x1 threads
      make_layout(make_shape(_1{}, _128{})));                             // 1x128 values per thread;
  auto thr_copy = copy.get_slice(0);

  Tensor src = thr_copy.partition_S(gA);
  Tensor dst = thr_copy.partition_D(sA);

  print(src); print("\n");
  print(dst); print("\n");
}

successfully partitions and prints

gmem_subptr[1b]((nil).0) o ((_128,_1),_2,_1):((_1,_0),_4096,_0)
smem_subptr[1b]((nil).0) o ((_128,_1),_2,_1):((_1,_0),4608,_0)

CalebDu commented Oct 18, 2024

I suspect this is simply due to your layout not being large enough. Swizzle<3,7,3> does have a base size of 128b, but it is a function that repeats only after every 64x128b values, so the SMEM Atom expects to be at least that size. It also concerns me that your TiledCopy appears to only be using 8 threads.

Where does the 64 in 64x128b come from? Is it $2^B \times 2^S = 64$ from Swizzle<B=3, M=7, S=3>? So for an smem layout [m, n*128b] with Swizzle<B, 7, S>, does $m \times n$ have to be a power of two like $2^B \times 2^S$?

ccecka commented Oct 18, 2024

Precisely, B is the number of bits in the swizzle operation and S is the number of bits in the shift.
https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp#L43

For Swizzle Atom Layouts, m and n will almost always be powers of two, yes, otherwise the swizzle isn't actually useful to avoid bank conflicts. These Atoms can be tile_to_shape-ed to any multiple of that atom shape though -- so a 32x32 swizzle layout can be used to construct a 96x96 layout, for example.
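
For instance, a minimal sketch of that tiling (not from the original reply; the target shape and stage count are chosen only for illustration, reusing the 64x128 Swizzle<3,7,3> atom from the example above):

#include <cute/tensor.hpp>
using namespace cute;

int main()
{
  // The 64x128 Swizzle<3,7,3> atom from the example above
  auto atom = composition(Swizzle<3,7,3>{},
                          make_layout(make_shape ( _64{},_128{}),
                                      make_stride(_128{},  _1{})));

  // tile_to_shape repeats the atom over any multiple of its shape;
  // here 192x256 (3x2 atoms) plus a third mode for 3 pipeline stages
  auto smem_layout = tile_to_shape(atom,
                                   make_shape(_192{}, _256{}, _3{}));

  print(smem_layout); print("\n");
}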

CalebDu commented Oct 18, 2024

Precisely, B is the number of bits in the swizzle operation and S is the number of bits in the shift. https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp#L43

For Swizzle Layouts, m and n will almost always be powers of two, yes, otherwise the swizzle isn't actually useful to avoid bank conflicts.

I get it. Thanks for your help!

CalebDu closed this as completed Oct 18, 2024
CalebDu commented Oct 18, 2024

@ccecka

Sorry to bother you again; I ran into a strange problem in the same case.

  constexpr int tile_m = 8;
  constexpr int tile_k = 128;
  using G2SCopyA = decltype(make_tiled_copy(
      G2SCopyAtom{},
      make_layout(make_shape(Int<tile_m>{}, Int<tile_k / 128>{}),
                  make_stride(Int<tile_k / 128>{}, _1{})),
      make_layout(make_shape(_1{}, _128{}))));

According to the example code above, I only use 8 threads to copy gA to sA with an [8, 128] CTA tile. But I find that the remaining 24 threads in the warp also copy data. They copy [8, 128] tiles belonging to other CTAs into shared memory beyond sA's 8*128 bits (into sB's storage), so I get a wrong GEMM result.

As in the following image, A is a [9, 128] uint1b tensor with random data and B is a [32, 128] uint1b tensor of all zeros.
[image omitted]
And I used the predicate-tensor code from #1873 to deal with the OOB situation.

  copy_if(
      g2s_copy_a,
      [&](auto... coords) {
        return elem_less(g2s_tAgA_copy_pred(coords...), shape(A));
      },
      g2s_tAgA_copy(_, _, _, k_main_loop), g2s_tAsA_copy(_, _, _, 0));
  copy_if(
      g2s_copy_b,
      [&](auto... coords) {
        return elem_less(g2s_tBgB_copy_pred(coords...), shape(B));
      },
      g2s_tBgB_copy(_, _, _, k_main_loop), g2s_tBsB_copy(_, _, _, 0));

Then I print the sA and sB data.
[image omitted]
The 9th row of A (outside the [8, 128] tile) is loaded into the first row of sB.
So how can I keep the remaining threads in the warp from loading OOB data into memory outside the sA buffer?
When I increase the [8, 128] tile to [8, 256], the tiled copy uses 16 threads (still less than a warp), but the problem disappears and the GEMM result is correct.

ccecka commented Oct 18, 2024

I wouldn't use less than a warp in any copy, that's just going to force inefficiency.

Perhaps something like this?

  auto my_layout = make_layout(make_shape(_8{}, _128{}), LayoutRight{});   // 8x128 row-major
  Tensor gA = make_tensor(make_gmem_ptr<uint1_t>(gptr), my_layout);
  Tensor sA = make_tensor(make_smem_ptr<uint1_t>(sptr), my_layout);

  auto cpy = make_tiled_copy(
      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint32_t>, uint1b_t>{},  // load 32 uint1b_t as uint32_t each
      make_layout(make_shape(_8{},  _4{}), LayoutRight{}),               // 8x4 threads, row-major
      make_layout(make_shape(_1{}, _32{})));                             // 1x32 values per thread

  auto thr_cpy = cpy.get_slice(threadIdx.x);
  Tensor src = thr_cpy.partition_S(gA);
  Tensor dst = thr_cpy.partition_D(sA);

  print(src); print("\n");
  print(dst); print("\n");

CalebDu commented Oct 19, 2024

copy_if(
    g2s_copy_a,
    [&](auto... coords) {
      return elem_less(g2s_tAgA_copy_pred(coords...), shape(A));
    },
    g2s_tAgA_copy(_, _, _, k_main_loop), g2s_tAsA_copy(_, _, _, 0));

In my case, shape(A) is [9, 128], the g2s_copy_a tile shape is [32, 128], and the sA shape is [8, 128]. So this predicate tensor treats the 9th row as legal, because its coordinates are inside shape(A) = [9, 128] (and inside the [32, 128] copy tile), but that row is actually out of bounds for the [8, 128] sA tile.

So I have a question: if the tiled-copy shape is larger than the CTA tile, as in the case above where the tiled copy is [32, 128] (32 threads, 128 elements per thread) and the CTA tile is [8, 128], is there any way to automatically strip the OOB data during the G2S copy? This situation is common for small-tile GEMM, or whenever the kernel launches more threads than the tiled copy uses, like warp specialization.

In the following test code, I launch the kernel with grid(1) and block(64), and a tiled copy with 32 threads and a [32, 128] tile. A's shape is [36, 128], sA is [32, 128], and sA_view is a [64, 128] view over the allocated smem. Printing sA_view shows that all 64 threads copy all of A's [36, 128] data into sA_view, writing beyond the [32, 128] bound. So how do I set up a proper predicate tensor to strip the OOB data, or do I have to manually use an if branch to keep threads from copying OOB data?

// grid(1,1,1) block(64, 1, 1)
__global__ void test(void *ptr) {
  using G2SCopyOp = UniversalCopy<cute::uint128_t>;
  using G2SCopyTraits = Copy_Traits<G2SCopyOp>;
  using G2SCopyAtom = Copy_Atom<G2SCopyTraits, cute::uint1_t>;

  using G2SCopyA = decltype(make_tiled_copy(
      G2SCopyAtom{},
      make_layout(make_shape(Int<32>{}, Int<1>{}), make_stride(Int<1>{}, _1{})),
      make_layout(make_shape(_1{}, _128{}))));

  Tensor A = make_tensor(make_gmem_ptr<cute::uint1_t>(ptr),
                         make_layout(make_shape(36, 128), make_stride(128, _1{})));
  Tensor gA = local_tile(A, make_tile(_32{}, _128{}), make_coord(0, 0));

  __shared__ uint8_t smem[64*128/8];
  Tensor sA = make_tensor(
      make_smem_ptr<cute::uint1_t>(smem),
      make_layout(make_shape(_32{}, _128{}), make_stride(_128{}, _1{})));
      
  Tensor sA_view = make_tensor(
      make_smem_ptr<cute::uint1_t>(smem),
      make_layout(make_shape(_64{}, _128{}), make_stride(_128{}, _1{})));
  G2SCopyA copy;

  auto thr_copy = copy.get_slice(threadIdx.x);
  auto tAgA = thr_copy.partition_S(gA);
  auto tAsA = thr_copy.partition_D(sA);

  cute::copy(copy, tAgA, tAsA);
  __syncthreads();
  if (cutlass::thread0()) {
    print_tensor(sA_view);
  }
}

ccecka commented Oct 19, 2024

Every thread has to do something, which means you need to tell every thread what to do. Then predicate everything that's out-of-bounds. If you have 64 threads in your kernel, your tiled_copy should have 64 threads. This is why it's better to use all of the threads and a smaller vectorization than simply have threads doing nothing, as I suggested in the previous reply.

Here is some more code for predication.

// grid(1,1,1) block(64, 1, 1)
__global__ void test(void *ptr)
{
  Tensor A = make_tensor(make_gmem_ptr<cute::uint1_t>(ptr), make_shape(36, 128), LayoutRight{});
  Tensor C = make_identity_tensor(shape(A));  // (m,n) -> (m,n)

  Tensor gA = local_tile(A, make_tile(_32{}, _128{}), make_coord(0, 0));
  Tensor cA = local_tile(C, make_tile(_32{}, _128{}), make_coord(0, 0));
  __shared__ uint8_t smem[64*128/8];
  Tensor sA = make_tensor(make_smem_ptr<cute::uint1_t>(smem), shape(gA), LayoutRight{});

  auto cpy = make_tiled_copy(
      Copy_Atom<UniversalCopy<cute::uint128_t>, cute::uint1_t>{},
      make_layout(make_shape(Int<64>{}, Int<  1>{})),
      make_layout(make_shape(Int< 1>{}, Int<128>{})));

  auto thr_cpy = cpy.get_slice(threadIdx.x);
  Tensor tAgA = thr_cpy.partition_S(gA);   // (CPY, CPY_M, CPY_N)
  Tensor tAcA = thr_cpy.partition_S(cA);   // (CPY, CPY_M, CPY_N) -> (m,n)
  Tensor tAsA = thr_cpy.partition_D(sA);   // (CPY, CPY_M, CPY_N)

  // Slice out the CPY mode with the assumption that the predication is constant across the vector
  Tensor tAcA_ = tAcA(0,_,_);
  // Compare the generated coordinates with the original shape to determine predication
  auto pred = [&](auto... coord){ return elem_less(tAcA_(coord...), shape(A)); };
  // Copy
  cute::copy_if(cpy, pred, tAgA, tAsA);

  __syncthreads();

  Tensor sA_view = make_tensor(
      make_smem_ptr<cute::uint1_t>(smem),
      make_layout(make_shape(_64{}, _128{}), make_stride(_128{}, _1{})));

  if (cutlass::thread0()) {
    print_tensor(sA_view);
  }
}

CalebDu commented Oct 19, 2024

Every thread has to do something, which means you need to tell every thread what to do. Then predicate everything that's out-of-bounds. If you have 64 threads in your kernel, your tiled_copy should have 64 threads. This is why it's better to use all of the threads and a smaller vectorization than simply have threads doing nothing, as I suggested in the previous reply.


I found the root cause and a solution. In the common case, the tiled-copy tile is equal to the CTA tile, or the copy loops several times to cover the CTA tile, so the predicate tensor only needs to check the shape(A) bound.

 auto pred = [&](auto... coord){ return elem_less(tAcA_(coord...), shape(A)); };

As in the following figure, the copy tile is larger than the CTA tile. The predicate above only checks the yellow bound, not the blue bound, so the predicate should check both bounds.

auto blue_bound = make_tuple((bid_m + 1) * 8, (k_loop + 1) * 128);
auto pred = [&](auto... coord){ return elem_less(tAcA_(coord...), shape(A))
                                    && elem_less(tAcA_(coord...), blue_bound); };

[image omitted]

In addition to this, I have a curious question. As your comment mentioned, the tiled copy should use all of the threads in the kernel launch, and

"This is why it's better to use all of the threads and a smaller vectorization than simply have threads doing nothing".

But CUDA is moving to a new programming model (warp specialization) that separates the warps into consumers (compute) and producers (copy) on new architectures like Hopper. How do I implement a tiled_copy in CuTe where only the producer warp's threads copy data? Does it also use a predicate tensor to mask the consumer threads?

ccecka commented Oct 20, 2024

If you have one warp specialized to perform the load, then your tiled_copy had better contain all threads of that one warp. It's better to use all of the threads available in your context (and design the tiled_copy to copy exactly the tile of data you're interested in, avoiding your extra predication step) and use a smaller vectorization than simply have threads doing nothing.
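
For illustration, here is a minimal sketch along those lines (not code from this thread; it assumes gA is already the exact [64, 4] int32_t CTA tile, e.g. taken from local_tile, and that the global pointer is 16-byte aligned for the 128-bit loads):

#include <cute/tensor.hpp>
using namespace cute;

// grid(1,1,1) block(64,1,1): warp0 = producer (G2S copy), warp1 = consumer
__global__ void producer_warp_copy(int32_t const* gptr)
{
  Tensor gA = make_tensor(make_gmem_ptr(gptr),
                          make_layout(make_shape(_64{}, _4{}), LayoutRight{}));

  __shared__ int32_t smem[64 * 4];
  Tensor sA = make_tensor(make_smem_ptr(smem),
                          make_layout(make_shape(_64{}, _4{}), LayoutRight{}));

  // TiledCopy built for exactly the 32 producer threads; it covers a 32x4
  // tile per iteration, so the 64x4 tensors are copied in two iterations
  auto g2s_copy = make_tiled_copy(
      Copy_Atom<UniversalCopy<cute::uint128_t>, int32_t>{},  // 4 x int32_t per thread
      make_layout(make_shape(_32{}, _1{})),                  // 32 threads = 1 producer warp
      make_layout(make_shape(_1{}, _4{})));                  // 1x4 values per thread

  if (threadIdx.x / 32 == 0) {                               // only the producer warp copies
    auto thr_copy = g2s_copy.get_slice(threadIdx.x);         // thread ids 0..31 only
    Tensor tAgA = thr_copy.partition_S(gA);                  // (CPY, CPY_M, CPY_N)
    Tensor tAsA = thr_copy.partition_D(sA);                  // (CPY, CPY_M, CPY_N)
    cute::copy(g2s_copy, tAgA, tAsA);
  }
  __syncthreads();   // consumer warp (threads 32..63) reads sA after this
}

Because the TiledCopy covers exactly the tile held by gA, no data predication is needed; the only guard is the warp-id check that keeps the consumer warp out of the copy.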

CalebDu commented Oct 20, 2024

I don't quite follow your point about

avoiding your extra predication step

For example, launch the kernel with 2 warps (64 threads), where warp0 (threads 0-31) is the producer that copies a [64, 4] tile of int32_t elements G2S.

Tensor A = make_tensor(make_gmem_ptr<int32_t>(ptr),
                       make_layout(make_shape(65, 4), make_stride(4, _1{})));
Tensor gA = local_tile(A, make_tile(_64{}, _4{}), make_coord(0, 0));
using G2SCopyOp = UniversalCopy<cute::uint128_t>;
using G2SCopyTraits = Copy_Traits<G2SCopyOp>;
using G2SCopyAtom = Copy_Atom<G2SCopyTraits, int32_t>;
using G2SCopyA = decltype(make_tiled_copy(
    G2SCopyAtom{},
    make_layout(make_shape(Int<32>{}, Int<1>{}), make_stride(Int<1>{}, _1{})),
    make_layout(make_shape(_1{}, _4{}))));
G2SCopyA copy;
auto thr_copy = copy.get_slice(0);
auto tAgA = thr_copy.partition_S(gA);
print(tAgA);  // ((_4,_1),2,1):((_1,_0),128,_4)

Each thread copies ((_4,_1),2,1):((_1,_0),128,_4) G2S, so warp0 can copy all of the [64, 4] data, but I still need a predicate so that the other 32 threads do not copy OOB data. Then the consumer warp1 (threads 32-63) only needs to do the S2R copy of the whole [64, 4] data, while warp0 does no S2R copy. How do I implement that?
