Refactor quantized bytes representation (#2627)
* Add pack and unpack tests

* Refactor quantized bytes repr + switch to little endian packing representation

* Missing import

* Another one

* Uncomment qparams check

* Add comment for context

* Reinstate correct equality check
laggui authored Dec 19, 2024
1 parent c33ca14 commit 60f70cf
Showing 14 changed files with 504 additions and 302 deletions.
7 changes: 2 additions & 5 deletions crates/burn-import/src/pytorch/reader.rs
@@ -141,11 +141,8 @@ where
.map(ElementConversion::elem)
.collect();

-let TensorData {
-bytes,
-shape,
-dtype,
-} = TensorData::new(data, shape);
+let data = TensorData::new(data, shape.clone());
+let (dtype, bytes) = (data.dtype, data.into_bytes());

// Manually serialize the tensor instead of using the `ParamSerde` struct, such as:
// ParamSerde::new(param_id, TensorData::new(data, shape)).serialize(serializer)
8 changes: 4 additions & 4 deletions crates/burn-jit/src/kernel/quantization/dequantize.rs
@@ -32,10 +32,10 @@ pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
pub(crate) fn extract_i8s(value: u32) -> Line<i32> {
let mut line = Line::empty(4);
// Extract each 8-bit segment
-line[0] = extract_i8(value, 24);
-line[1] = extract_i8(value, 16);
-line[2] = extract_i8(value, 8);
-line[3] = extract_i8(value, 0);
+line[0] = extract_i8(value, 0);
+line[1] = extract_i8(value, 8);
+line[2] = extract_i8(value, 16);
+line[3] = extract_i8(value, 24);

line
}
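For orientation: after this change, element i of the line is read from bits 8*i..8*i + 8 of the packed u32, i.e. from byte i of its little-endian representation. A standalone plain-Rust sketch of that unpacking order (the body of extract_i8 is an assumed stand-in, not the CubeCL kernel code):

fn extract_i8(value: u32, offset: u32) -> i32 {
    // Shift the target byte down to the least significant position, then sign-extend.
    ((value >> offset) as u8) as i8 as i32
}

fn unpack_i8s(value: u32) -> [i32; 4] {
    // Little-endian order: element 0 comes from the lowest byte.
    [
        extract_i8(value, 0),
        extract_i8(value, 8),
        extract_i8(value, 16),
        extract_i8(value, 24),
    ]
}

fn main() {
    // 0x80FF_7F01 stores the bytes [0x01, 0x7F, 0xFF, 0x80] in little-endian order.
    assert_eq!(unpack_i8s(0x80FF_7F01), [1, 127, -1, -128]);
}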
6 changes: 3 additions & 3 deletions crates/burn-jit/src/kernel/quantization/quantize.rs
@@ -72,7 +72,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel(
range_max,
);
// Shift and combine into u32
-v_packed |= (v[0] & 0xFF) << (8 * (num_packed - i - 1));
+v_packed |= (v[0] & 0xFF) << (8 * i);
}
output[ABSOLUTE_POS] = v_packed;
}
@@ -105,7 +105,7 @@ pub(crate) fn pack_i8s_to_u32s(value: Line<u32>) -> u32 {
#[unroll]
for i in 0..line_size {
// Shift and combine into u32
-v_packed |= (value[i] & 0xFF) << (8 * (line_size - i - 1));
+v_packed |= (value[i] & 0xFF) << (8 * i);
}
v_packed
}
@@ -150,7 +150,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel(
range_max,
);
// Shift and combine into u32
-v_packed |= (v[0] & 0xFF) << (8 * (num_packed - i - 1));
+v_packed |= (v[0] & 0xFF) << (8 * i);
}
output[ABSOLUTE_POS] = v_packed;
}
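On the packing side, shifting by 8 * i writes element i into byte i of the u32, which is exactly the little-endian layout that u32::from_le_bytes produces. A small self-contained round-trip check in plain Rust (independent of the CubeCL kernels above, shown only to pin down the byte order):

fn pack_i8s_to_u32(values: [i8; 4]) -> u32 {
    let mut packed = 0u32;
    for (i, v) in values.iter().enumerate() {
        // Mask to 8 bits, then shift into byte position i (little-endian).
        packed |= ((*v as u8) as u32) << (8 * i);
    }
    packed
}

fn main() {
    let values = [1i8, 127, -1, -128];
    let packed = pack_i8s_to_u32(values);
    // Same result as reinterpreting the four bytes in little-endian order.
    assert_eq!(packed, u32::from_le_bytes(values.map(|v| v as u8)));
    assert_eq!(packed, 0x80FF_7F01);
}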
8 changes: 2 additions & 6 deletions crates/burn-jit/src/ops/qtensor.rs
@@ -3,7 +3,7 @@ use std::ops::Range;
use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
-Bytes, DType, Device, Shape, TensorData,
+DType, Device, Shape, TensorData,
};

use crate::{
@@ -82,11 +82,7 @@ where
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;

-TensorData {
-bytes: Bytes::from_bytes_vec(bytes),
-shape: tensor.shape.into(),
-dtype: tensor.dtype,
-}
+TensorData::from_bytes(bytes, tensor.shape, tensor.dtype)
}

fn q_swap_dims(
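TensorData::from_bytes replaces the struct-literal construction removed above (the same pattern repeats in the transaction ops below). Its definition is not part of this diff; the following is a hypothetical, self-contained sketch with stand-in types, inferred only from the removed Bytes::from_bytes_vec call and field names — not burn's actual code:

// Stand-in types (NOT burn's real definitions), only to make the constructor's shape concrete.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
}

struct Bytes(Vec<u8>);

impl Bytes {
    fn from_bytes_vec(v: Vec<u8>) -> Self {
        Bytes(v)
    }
}

struct TensorData {
    bytes: Bytes,
    shape: Vec<usize>,
    dtype: DType,
}

impl TensorData {
    // Hypothetical constructor: wrap raw bytes read back from the device together
    // with the shape and dtype, mirroring the struct literal it replaces.
    fn from_bytes(bytes: Vec<u8>, shape: impl Into<Vec<usize>>, dtype: DType) -> Self {
        Self {
            bytes: Bytes::from_bytes_vec(bytes),
            shape: shape.into(),
            dtype,
        }
    }
}

fn main() {
    let data = TensorData::from_bytes(vec![0u8; 16], [2, 2], DType::F32);
    assert_eq!(data.shape, vec![2, 2]);
    assert_eq!(data.bytes.0.len(), 16);
    assert_eq!(data.dtype, DType::F32);
}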
26 changes: 10 additions & 16 deletions crates/burn-jit/src/ops/transaction.rs
@@ -1,6 +1,6 @@
use burn_tensor::{
ops::{TransactionOps, TransactionPrimitiveResult},
-Bytes, DType, TensorData,
+DType, TensorData,
};

use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
@@ -73,27 +73,21 @@ where
match kind {
Kind::Float(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
-result.read_floats.push(TensorData {
-bytes: Bytes::from_bytes_vec(bytes),
-shape,
-dtype,
-});
+result
+.read_floats
+.push(TensorData::from_bytes(bytes, shape, dtype));
}
Kind::Int(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
-result.read_ints.push(TensorData {
-bytes: Bytes::from_bytes_vec(bytes),
-shape,
-dtype,
-});
+result
+.read_ints
+.push(TensorData::from_bytes(bytes, shape, dtype));
}
Kind::Bool(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
-result.read_bools.push(TensorData {
-bytes: Bytes::from_bytes_vec(bytes),
-shape,
-dtype,
-});
+result
+.read_bools
+.push(TensorData::from_bytes(bytes, shape, dtype));
}
}
}
43 changes: 36 additions & 7 deletions crates/burn-ndarray/src/ops/qtensor.rs
@@ -4,9 +4,9 @@ use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{
AffineQuantization, QParams, QuantizationParametersPrimitive, QuantizationScheme,
-QuantizationStrategy, QuantizationType, SymmetricQuantization,
+QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization,
},
-DType, Shape, TensorData, TensorMetadata,
+DType, ElementConversion, Shape, TensorData, TensorMetadata,
};

use crate::{
@@ -36,12 +36,31 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<S
fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
match data.dtype {
DType::QFloat(scheme) => {
-let qparams = data.get_q_params().unwrap();
-let data = data.convert::<Q>();
-NdArrayQTensor {
-qtensor: NdArrayTensor::<Q>::from_data(data),
+let shape = data.shape.clone();
+let num_elements = data.num_elements();
+let q_bytes = QuantizedBytes {
+bytes: data.into_bytes(),
scheme,
-qparams,
+num_elements,
+};
+
+match scheme {
+QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
+| QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
+let (values, qparams) = q_bytes.into_vec_i8();
+
+let data = TensorData::new(values, shape).convert::<Q>();
+let qparams = QParams {
+scale: qparams.scale,
+offset: qparams.offset.map(|x| x.elem::<Q>()),
+};
+
+NdArrayQTensor {
+qtensor: NdArrayTensor::<Q>::from_data(data),
+scheme,
+qparams,
+}
+}
}
}
_ => panic!(
Expand Down Expand Up @@ -92,7 +111,17 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<S
},
};

+let shape = tensor.shape();
let data = into_data_f(tensor).with_quantization(strategy);
+let num_elements = data.num_elements();
+let q_bytes = QuantizedBytes {
+bytes: data.into_bytes(),
+scheme: *scheme,
+num_elements,
+};
+let (values, _) = q_bytes.into_vec_i8();
+let data = TensorData::new(values, shape).convert::<Q>();

NdArrayQTensor {
qtensor: NdArrayTensor::<Q>::from_data(data),
scheme: *scheme,
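QuantizedBytes stores num_elements next to the raw buffer. A plausible reason, inferred from the packed-u32 kernels earlier in this diff rather than stated by the commit: four i8 values share one u32, so when the element count is not a multiple of four the byte length alone over-counts the tensor:

fn main() {
    // 5 quantized i8 values still occupy two packed u32 words = 8 bytes,
    // so the true element count cannot be recovered from the byte length alone.
    let num_elements: usize = 5;
    let packed_words = num_elements.div_ceil(4);
    let num_bytes = packed_words * core::mem::size_of::<u32>();
    assert_eq!(packed_words, 2);
    assert_eq!(num_bytes, 8);
    assert_ne!(num_bytes, num_elements);
}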
14 changes: 10 additions & 4 deletions crates/burn-tch/src/ops/qtensor.rs
@@ -4,6 +4,7 @@ use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{
QParams, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType,
+QuantizedBytes,
},
DType, Shape, TensorData, TensorMetadata,
};
@@ -46,10 +47,15 @@ impl<E: TchElement, Q: QuantElement> QTensorOps<Self> for LibTorch<E, Q> {
// methods take the values provided when quantizing.
match data.dtype {
DType::QFloat(scheme) => {
-let qparams = data.get_q_params::<E, Q>().unwrap();
-let dequantized = data.dequantize().unwrap();
-let values = dequantized.as_slice::<E>().unwrap();
-let tensor = tch::Tensor::from_slice(values).to(device);
+let num_elements = data.num_elements();
+let q_bytes = QuantizedBytes {
+bytes: data.into_bytes(),
+scheme,
+num_elements,
+};
+
+let (values, qparams) = q_bytes.dequantize();
+let tensor = tch::Tensor::from_slice(&values).to(device);
let tensor = quantize(tensor.reshape(shape_tch.dims), &scheme, &qparams);

TchQTensor {
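q_bytes.dequantize() returns the floating-point values together with the quantization parameters, so the tch backend can re-quantize them through its own quantize path. For reference, the standard int8 dequantization formulas (the usual textbook definitions; not code lifted from burn's implementation):

/// Affine (asymmetric) int8 dequantization: x ≈ scale * (q - offset).
fn dequantize_affine(q: i8, scale: f32, offset: i8) -> f32 {
    scale * (q as i32 - offset as i32) as f32
}

/// Symmetric int8 dequantization: x ≈ scale * q.
fn dequantize_symmetric(q: i8, scale: f32) -> f32 {
    scale * q as f32
}

fn main() {
    assert_eq!(dequantize_affine(-128, 0.5, -128), 0.0);
    assert_eq!(dequantize_symmetric(64, 0.25), 16.0);
}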
2 changes: 1 addition & 1 deletion crates/burn-tensor/Cargo.toml
@@ -35,7 +35,7 @@ burn-common = { path = "../burn-common", version = "0.16.0", default-features =
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true }
cubecl = { workspace = true, optional = true, default-features = true }

-bytemuck = { workspace = true }
+bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
colored = { workspace = true, optional = true }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
27 changes: 20 additions & 7 deletions crates/burn-tensor/src/tensor/bytes.rs
@@ -108,7 +108,7 @@ impl<'de> serde::Deserialize<'de> for Bytes {
impl Clone for Bytes {
fn clone(&self) -> Self {
// unwrap here: the layout is valid as it has the alignment & size of self
-Self::try_from_data(MAX_ALIGN, self.deref()).unwrap()
+Self::try_from_data(self.align(), self.deref()).unwrap()
}
}

@@ -380,7 +380,7 @@ impl Bytes {
}
}

-fn reserve(&mut self, additional: usize) {
+fn reserve(&mut self, additional: usize, align: usize) {
let needs_to_grow = additional > self.capacity().wrapping_sub(self.len());
if !needs_to_grow {
return;
@@ -390,17 +390,20 @@
};
// guarantee exponential growth for amortization
let new_cap = required_cap.max(self.capacity() * 2);
-let new_cap = new_cap.max(MAX_ALIGN); // Small allocations would be pointless
-let Ok(new_layout) = Layout::from_size_align(new_cap, MAX_ALIGN) else {
+let new_cap = new_cap.max(align); // Small allocations would be pointless
+let Ok(new_layout) = Layout::from_size_align(new_cap, align) else {
alloc_overflow()
};
self.alloc.grow(new_layout);
}

-/// Extend the byte buffer from a slice of bytes
-pub fn extend_from_byte_slice(&mut self, bytes: &[u8]) {
+/// Extend the byte buffer from a slice of bytes.
+///
+/// This is used internally to preserve the alignment of the memory layout when matching elements
+/// are extended. Prefer [`Self::extend_from_byte_slice`] otherwise.
+pub(crate) fn extend_from_byte_slice_aligned(&mut self, bytes: &[u8], align: usize) {
let additional = bytes.len();
-self.reserve(additional);
+self.reserve(additional, align);
let len = self.len();
let new_cap = len.wrapping_add(additional); // Can not overflow, as we've just successfully reserved sufficient space for it
let uninit_spare = &mut self.alloc.memory_mut()[len..new_cap];
@@ -412,11 +415,21 @@
self.len = new_cap;
}

+/// Extend the byte buffer from a slice of bytes
+pub fn extend_from_byte_slice(&mut self, bytes: &[u8]) {
+self.extend_from_byte_slice_aligned(bytes, MAX_ALIGN)
+}

/// Get the total capacity, in bytes, of the wrapped allocation.
pub fn capacity(&self) -> usize {
self.alloc.layout.size()
}

+/// Get the alignment of the wrapped allocation.
+pub(crate) fn align(&self) -> usize {
+self.alloc.layout.align()
+}

/// Convert the bytes back into a vector. This requires that the type has the same alignment as the element
/// type this [Bytes] was initialized with.
/// This only returns with Ok(_) if the conversion can be done without a memcopy
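The common thread in these bytes.rs changes is that the allocation's original alignment must survive clone and reserve; the doc comment above notes that converting back into a vector without a copy requires the element type's alignment. A minimal, burn-independent illustration of the constraint:

use std::alloc::Layout;
use std::mem::align_of;

fn main() {
    // A buffer created for u64 elements must keep u64's alignment when grown or
    // cloned; a layout with smaller alignment can no longer back a Vec<u64> without
    // copying the data into a properly aligned allocation.
    let elem_align = align_of::<u64>();
    let original = Layout::from_size_align(32, elem_align).unwrap();
    let byte_aligned = Layout::from_size_align(64, 1).unwrap();

    assert_eq!(original.align(), elem_align);
    assert!(byte_aligned.align() < elem_align);
}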