Commit f90734f

Merge branch 'main' into refactor/jit/qtensor

laggui committed Dec 11, 2024
2 parents 52edc88 + ebd7649 commit f90734f
Showing 12 changed files with 187 additions and 271 deletions.
130 changes: 65 additions & 65 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -82,10 +82,10 @@ spin = { version = "0.9.8", features = [
] }
strum = "0.26.3"
strum_macros = "0.26.4"
syn = { version = "2.0.89", features = ["full", "extra-traits"] }
syn = { version = "2.0.90", features = ["full", "extra-traits"] }
tempfile = "3.14.0"
thiserror = "2.0.6"
tokio = { version = "1.41.1", features = ["rt", "macros"] }
tokio = { version = "1.42.0", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.33"
tracing-subscriber = "0.3.18"
2 changes: 1 addition & 1 deletion README.md
@@ -324,7 +324,7 @@ WGPU (WebGPU): Cross-Platform GPU Backend 🌐

Based on the most popular and well-supported Rust graphics library, [WGPU](https://wgpu.rs), this
backend automatically targets Vulkan, OpenGL, Metal, Direct X11/12, and WebGPU, by using the WebGPU
- shading language [WGSL](https://www.w3.org/TR/WGSL/https://www.w3.org/TR/WGSL/), or optionally
+ shading language [WGSL](https://www.w3.org/TR/WGSL/), or optionally
[SPIR-V](https://www.khronos.org/spir/) when targeting Vulkan. It can also be compiled to Web
Assembly to run in the browser while leveraging the GPU, see
[this demo](https://antimora.github.io/image-classification/). For more information on the benefits
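As a rough illustration of the backend this README excerpt describes — a minimal sketch, assuming a binary crate with burn's `wgpu` feature enabled; `Wgpu`, `from_floats`, and `matmul` are burn's public tensor API, but check the current docs for the exact signatures:

use burn::backend::Wgpu;
use burn::tensor::Tensor;

fn main() {
    // The Wgpu backend resolves to Vulkan, Metal, DX12, or WebGPU at runtime,
    // depending on the platform.
    type B = Wgpu;
    let device = Default::default();

    // Two 2x2 tensors multiplied on the GPU through the WGPU backend.
    let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let b = Tensor::<B, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    println!("{}", a.matmul(b));
}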
@@ -76,6 +76,7 @@ where
) {
let k_step = SMM::K;
let range = k_range.1 - k_range.0;
+ #[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
#[allow(clippy::manual_div_ceil)]
let num_loops = (range + k_step - 1) / k_step;

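The added attribute keeps older toolchains happy: `clippy::manual_div_ceil` suggests the standard-library `div_ceil`, but the lint itself is unknown to pre-1.83 clippy, hence the outer `#[allow(unknown_lints)]`. A standalone sketch of the equivalence, assuming unsigned integers where the addition cannot overflow:

fn main() {
    let (range, k_step) = (10u32, 4u32);
    // Manual ceiling division, as written in the kernel loop above.
    let manual = (range + k_step - 1) / k_step;
    // Equivalent standard-library form, stable since Rust 1.73.
    let builtin = range.div_ceil(k_step);
    assert_eq!(manual, builtin); // both evaluate to 3
}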
213 changes: 77 additions & 136 deletions crates/burn-jit/src/kernel/matmul/tune/base.rs
@@ -1,9 +1,10 @@
use core::marker::PhantomData;

use burn_tensor::{Element, ElementConversion};
use cubecl::{
ir::{Elem, FloatKind},
linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy},
- tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner},
+ tune,
+ tune::{local_tuner, tune_with, LocalTuner},
Feature,
};

@@ -15,73 +16,45 @@ use crate::{
JitRuntime, JitTuneId,
};

- use super::key::MatmulAutotuneKey;
+ use super::key::create_key;

- /// Set of matmul implementations available for autotune
- /// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
- pub struct MatmulAutotuneOperationSet<R: JitRuntime, E: FloatElement> {
+ #[tune(
+ operations(matmul_tiling2d, matmul_accelerated, matmul_simple),
+ create_key = create_key::<R, E>,
+ should_run = should_run
+ )]
+ fn matmul_ops<R: JitRuntime, E: FloatElement>(
key: JitAutotuneKey,
lhs: JitTensor<R>,
rhs: JitTensor<R>,
out: JitTensor<R>,
_e: PhantomData<E>,
- }
- impl<R: JitRuntime, E: FloatElement> MatmulAutotuneOperationSet<R, E> {
- fn new(lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>) -> Self {
- Self {
- key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())),
- lhs,
- rhs,
- out,
- _e: PhantomData,
- }
- }
- }
-
- impl<R: JitRuntime, E: FloatElement> AutotuneOperationSet<JitAutotuneKey>
- for MatmulAutotuneOperationSet<R, E>
- {
- fn key(&self) -> JitAutotuneKey {
- self.key.clone()
- }
-
- fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
- let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
- let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1);
- let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1);
+ ) {
+ let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
+ let lhs = random_like_uniform(lhs, random_bounds.0, random_bounds.1);
+ let rhs = random_like_uniform(rhs, random_bounds.0, random_bounds.1);

- let out = empty_device::<R, E>(
- self.out.client.clone(),
- self.out.device.clone(),
- self.out.shape.clone(),
- );
+ let out = empty_device::<R, E>(out.client.clone(), out.device.clone(), out.shape.clone());

- vec![
- Box::new(MatmulTiling2d::<R, E>::new(
- lhs.clone(),
- rhs.clone(),
- out.clone(),
- )),
- Box::new(MatmulAccelerated::<R, E>::new(
- lhs.clone(),
- rhs.clone(),
- out.clone(),
- )),
- Box::new(MatmulSimple::<R, E>::new(
- lhs.clone(),
- rhs.clone(),
- out.clone(),
- )),
- ]
- }
+ tune_with!(lhs, rhs, out)
+ }

- fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
- match fastest_index {
- 0 => Box::new(MatmulTiling2d::<R, E>::new(self.lhs, self.rhs, self.out)),
- 1 => Box::new(MatmulAccelerated::<R, E>::new(self.lhs, self.rhs, self.out)),
- 2 => Box::new(MatmulSimple::<R, E>::new(self.lhs, self.rhs, self.out)),
- _ => panic!("Fastest index is out of bound"),
- }
+ fn should_run<R: JitRuntime, E: FloatElement>(
+ op: &MatmulOps<R, E>,
+ _key: &JitAutotuneKey,
+ index: usize,
+ ) -> bool {
+ match index {
+ // Accelerated
+ // TODO: Add way to query actual requirements from cubecl
+ 1 => op.lhs.client.properties().feature_enabled(Feature::Cmma {
+ a: Elem::Float(FloatKind::F16),
+ b: Elem::Float(FloatKind::F16),
+ c: Elem::Float(FloatKind::F32),
+ m: 16,
+ k: 16,
+ n: 16,
+ }),
+ _ => true,
}
}
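The hunk above replaces the hand-written `AutotuneOperationSet` implementation with cubecl's `#[tune]` attribute, which generates the equivalent set (here named `MatmulOps`) from a plain function, while `should_run` filters out candidates the hardware cannot execute — such as the CMMA-accelerated kernel on devices without 16x16x16 f16/f32 tensor-core support. The core idea, stripped of the cubecl machinery — an illustrative sketch only, with invented names, not the cubecl API:

use std::collections::HashMap;
use std::time::{Duration, Instant};

// Illustrative only: a bare-bones autotuner in the spirit of the tuner used
// above. Time each runnable candidate once per key, cache the winner, and
// dispatch to it on later calls.
struct ToyTuner {
    // Maps an autotune key to the index of the fastest candidate found.
    fastest: HashMap<String, usize>,
}

impl ToyTuner {
    // Assumes at least one candidate passes `should_run` for the given key.
    fn execute(&mut self, key: &str, candidates: &[fn()], should_run: fn(usize) -> bool) {
        let index = *self.fastest.entry(key.to_string()).or_insert_with(|| {
            // First call for this key: time every runnable candidate once.
            // (Real autotuners warm up and aggregate several samples.)
            let mut best = (0, Duration::MAX);
            for (i, run) in candidates.iter().enumerate() {
                if !should_run(i) {
                    continue;
                }
                let start = Instant::now();
                run();
                let elapsed = start.elapsed();
                if elapsed < best.1 {
                    best = (i, elapsed);
                }
            }
            best.0
        });
        // Later calls with the same key dispatch straight to the cached winner.
        candidates[index]();
    }
}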

Expand All @@ -100,85 +73,53 @@ pub fn matmul_autotune<R: JitRuntime, E: FloatElement + Element>(
TUNER.execute(
&JitTuneId::new::<R>(&lhs.device),
&client,
- Box::new(MatmulAutotuneOperationSet::<R, E>::new(
- lhs,
- rhs,
- output.clone(),
- )),
+ Box::new(MatmulOps::<R, E>::new(lhs, rhs, output.clone())),
);

output
}

- macro_rules! matmul_tune_ops {
- ($name:ident, $func:expr) => {
- #[derive(new, Debug)]
- pub(crate) struct $name<R: JitRuntime, E: FloatElement> {
- lhs: JitTensor<R>,
- rhs: JitTensor<R>,
- out: JitTensor<R>,
- _e: PhantomData<E>,
- }
-
- impl<R: JitRuntime, E: FloatElement> AutotuneOperation for $name<R, E> {
- fn execute(self: Box<Self>) {
- #[allow(clippy::redundant_closure_call)]
- $func(self.lhs, self.rhs, self.out);
- }
-
- fn clone(&self) -> Box<dyn AutotuneOperation> {
- Box::new(Self {
- lhs: self.lhs.clone(),
- rhs: self.rhs.clone(),
- out: self.out.clone(),
- _e: self._e,
- })
- }
- }
- };
+ fn matmul_accelerated<R: JitRuntime, E: FloatElement>(
+ lhs: JitTensor<R>,
+ rhs: JitTensor<R>,
+ out: JitTensor<R>,
+ ) {
+ cubecl::linalg::matmul::launch_ref::<R, E>(
+ &Strategy::Accelerated,
+ &lhs.client,
+ &lhs.as_handle_ref(),
+ &rhs.as_handle_ref(),
+ &out.as_handle_ref(),
+ )
+ .unwrap();
+ }

- // Probably the fastest in the general case.
- matmul_tune_ops!(
- MatmulAccelerated,
- |lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
- cubecl::linalg::matmul::launch_ref::<R, E>(
- &Strategy::Accelerated,
- &lhs.client,
- &lhs.as_handle_ref(),
- &rhs.as_handle_ref(),
- &out.as_handle_ref(),
- )
- .unwrap();
- }
- );

- // Probably the fastest when tensor cores are not available.
- matmul_tune_ops!(
- MatmulTiling2d,
- |lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
- cubecl::linalg::matmul::launch_ref::<R, E>(
- &Strategy::Tiling2D(Tiling2dConfig::default()),
- &lhs.client,
- &lhs.as_handle_ref(),
- &rhs.as_handle_ref(),
- &out.as_handle_ref(),
- )
- .unwrap();
- }
- );
+ fn matmul_tiling2d<R: JitRuntime, E: FloatElement>(
+ lhs: JitTensor<R>,
+ rhs: JitTensor<R>,
+ out: JitTensor<R>,
+ ) {
+ cubecl::linalg::matmul::launch_ref::<R, E>(
+ &Strategy::Tiling2D(Tiling2dConfig::default()),
+ &lhs.client,
+ &lhs.as_handle_ref(),
+ &rhs.as_handle_ref(),
+ &out.as_handle_ref(),
+ )
+ .unwrap();
+ }

- // Probably the fastest for small matrices.
- matmul_tune_ops!(
- MatmulSimple,
- |lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
- cubecl::linalg::matmul::launch_ref::<R, E>(
- &Strategy::Simple,
- &lhs.client,
- &lhs.as_handle_ref(),
- &rhs.as_handle_ref(),
- &out.as_handle_ref(),
- )
- .unwrap();
- }
- );
+ fn matmul_simple<R: JitRuntime, E: FloatElement>(
+ lhs: JitTensor<R>,
+ rhs: JitTensor<R>,
+ out: JitTensor<R>,
+ ) {
+ cubecl::linalg::matmul::launch_ref::<R, E>(
+ &Strategy::Simple,
+ &lhs.client,
+ &lhs.as_handle_ref(),
+ &rhs.as_handle_ref(),
+ &out.as_handle_ref(),
+ )
+ .unwrap();
+ }
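The three free functions replace the old macro-generated operations: `Accelerated` maps to tensor-core (CMMA) kernels, `Tiling2D` to a shared-memory tiled kernel, and `Simple` to a straightforward kernel that tends to win on small shapes. A hypothetical heuristic capturing the comments attached to the old macro invocations — thresholds invented for illustration; the autotuner exists precisely to replace this kind of guesswork with measurements:

// Hypothetical, for illustration only; the size cutoff is invented.
fn pick_strategy(has_cmma: bool, m: usize, n: usize, k: usize) -> &'static str {
    if m * n * k < 32 * 32 * 32 {
        "simple" // probably the fastest for small matrices
    } else if has_cmma {
        "accelerated" // probably the fastest in the general case
    } else {
        "tiling2d" // probably the fastest when tensor cores are not available
    }
}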