Skip to content

Commit

Permalink
resolve the size of override-sized arrays in backends
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Dec 19, 2024
1 parent 53f4079 commit d0e92aa
Show file tree
Hide file tree
Showing 15 changed files with 196 additions and 130 deletions.
21 changes: 9 additions & 12 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ pub enum Error {
/// [`crate::Sampling::First`] is unsupported.
#[error("`{:?}` sampling is unsupported", crate::Sampling::First)]
FirstSamplingNotSupported,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

/// Binary operation with a different logic on the GLSL side.
Expand Down Expand Up @@ -973,13 +975,12 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, "[")?;

// Write the array size
// Writes nothing if `ArraySize::Dynamic`
match size {
crate::ArraySize::Constant(size) => {
// Writes nothing if `ResolvedSize::Runtime`
match proc::resolve_array_size(size, self.module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => (),
proc::ResolvedSize::Runtime => (),
}

write!(self.out, "]")?;
Expand Down Expand Up @@ -4455,13 +4456,9 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, ")")?;
}
TypeInner::Array { base, size, .. } => {
let count = match size
.to_indexable_length(self.module)
.expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => return Ok(()),
let count = match proc::resolve_array_size(size, self.module.to_ctx())? {
proc::ResolvedSize::Constant(size) => size,
proc::ResolvedSize::Runtime => return Ok(()),
};
self.write_type(base)?;
self.write_array_size(base, size)?;
Expand Down
19 changes: 9 additions & 10 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl crate::TypeInner {
}
}

pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 {
pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> Result<u32, Error> {
match *self {
Self::Matrix {
columns,
Expand All @@ -62,19 +62,18 @@ impl crate::TypeInner {
} => {
let stride = Alignment::from(rows) * scalar.width as u32;
let last_row_size = rows as u32 * scalar.width as u32;
((columns as u32 - 1) * stride) + last_row_size
Ok(((columns as u32 - 1) * stride) + last_row_size)
}
Self::Array { base, size, stride } => {
let count = match size {
crate::ArraySize::Constant(size) => size.get(),
// A dynamically-sized array has to have at least one element
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => 1,
let count = match crate::proc::resolve_array_size(size, gctx)? {
crate::proc::ResolvedSize::Constant(size) => size,
// A runtime-sized array has to have at least one element
crate::proc::ResolvedSize::Runtime => 1,
};
let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
((count - 1) * stride) + last_el_size
let last_el_size = gctx.types[base].inner.size_hlsl(gctx)?;
Ok(((count - 1) * stride) + last_el_size)
}
_ => self.size(gctx),
_ => Ok(self.size(gctx)),
}
}

Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ pub enum Error {
Custom(String),
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

#[derive(Default)]
Expand Down
10 changes: 4 additions & 6 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,12 +984,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
) -> BackendResult {
write!(self.out, "[")?;

match size {
crate::ArraySize::Constant(size) => {
match proc::resolve_array_size(size, module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => unreachable!(),
proc::ResolvedSize::Runtime => unreachable!(),
}

write!(self.out, "]")?;
Expand Down Expand Up @@ -1034,7 +1033,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
let ty_inner = &module.types[member.ty].inner;
last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());
last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

// The indentation is only for readability
write!(self.out, "{}", back::INDENT)?;
Expand Down Expand Up @@ -2635,7 +2634,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => unreachable!(),
}
write!(self.out, ")")?;
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ pub enum Error {
UnsupportedRayTracing,
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down
24 changes: 8 additions & 16 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, Transl
use crate::{
arena::{Handle, HandleSet},
back::{self, Baked},
proc::index,
proc::{self, NameKey, TypeResolution},
proc::{self, index, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
#[cfg(test)]
Expand Down Expand Up @@ -2409,7 +2408,6 @@ impl<W: Write> Writer<W> {
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Expand Down Expand Up @@ -2546,7 +2544,7 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
let accessing_wrapped_array = match *base_ty {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(_),
size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
..
} => true,
_ => false,
Expand All @@ -2572,7 +2570,6 @@ impl<W: Write> Writer<W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
Expand Down Expand Up @@ -3733,8 +3730,8 @@ impl<W: Write> Writer<W> {
first_time: false,
};

match size {
crate::ArraySize::Constant(size) => {
match proc::resolve_array_size(size, module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
writeln!(self.out, "struct {name} {{")?;
writeln!(
self.out,
Expand All @@ -3746,10 +3743,7 @@ impl<W: Write> Writer<W> {
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Pending(_) => {
unreachable!()
}
crate::ArraySize::Dynamic => {
proc::ResolvedSize::Runtime => {
writeln!(self.out, "typedef {base_name} {name}[1];")?;
}
}
Expand Down Expand Up @@ -6147,11 +6141,9 @@ mod workgroup_mem_init {
writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
}
crate::TypeInner::Array { base, size, .. } => {
let count = match size.to_indexable_length(module).expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => unreachable!(),
let count = match proc::resolve_array_size(size, module.to_ctx())? {
proc::ResolvedSize::Constant(size) => size,
proc::ResolvedSize::Runtime => unreachable!(),
};

access_stack.enter_array(|access_stack, array_depth| {
Expand Down
87 changes: 45 additions & 42 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
Span, Statement, TypeInner, WithSpan,
Span, Statement, Type, TypeInner, UniqueArena, WithSpan,
};
use std::{borrow::Cow, collections::HashSet, mem};
use thiserror::Error;
Expand Down Expand Up @@ -51,6 +51,7 @@ pub fn process_overrides<'a>(
return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
}

let original_module_types = &module.types;
let mut module = module.clone();

// A map from override handles to the handles of the constants
Expand Down Expand Up @@ -196,7 +197,13 @@ pub fn process_overrides<'a>(
}
module.entry_points = entry_points;

process_pending(&mut module, &override_map, &adjusted_global_expressions)?;
process_pending(
&mut module.types,
original_module_types,
&module.constants,
&override_map,
&adjusted_global_expressions,
);

// Now that we've rewritten all the expressions, we need to
// recompute their types and other metadata. For the time being,
Expand All @@ -208,61 +215,57 @@ pub fn process_overrides<'a>(
}

fn process_pending(
module: &mut Module,
types: &mut UniqueArena<Type>,
original_module_types: &UniqueArena<Type>,
constants: &Arena<Constant>,
override_map: &HandleVec<Override, Handle<Constant>>,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
) -> Result<(), PipelineConstantError> {
for (handle, ty) in module.types.clone().iter() {
) {
for (handle, ty) in original_module_types.iter() {
if let TypeInner::Array {
base,
size: crate::ArraySize::Pending(size),
stride,
} = ty.inner
{
let expr = match size {
match size {
crate::PendingArraySize::Expression(size_expr) => {
adjusted_global_expressions[size_expr]
let expr = adjusted_global_expressions[size_expr];
if expr != size_expr {
types.replace(
handle,
Type {
name: ty.name.clone(),
inner: TypeInner::Array {
base,
size: crate::ArraySize::Pending(
crate::PendingArraySize::Expression(expr),
),
stride,
},
},
);
}
}
crate::PendingArraySize::Override(size_override) => {
module.constants[override_map[size_override]].init
let expr = constants[override_map[size_override]].init;
types.replace(
handle,
Type {
name: ty.name.clone(),
inner: TypeInner::Array {
base,
size: crate::ArraySize::Pending(
crate::PendingArraySize::Expression(expr),
),
stride,
},
},
);
}
};
let value = module
.to_ctx()
.eval_expr_to_u32(expr)
.map(|n| {
if n == 0 {
Err(PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(
module.global_expressions.get_span(expr),
"evaluated to zero",
),
))
} else {
Ok(std::num::NonZeroU32::new(n).unwrap())
}
})
.map_err(|_| {
PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(module.global_expressions.get_span(expr), "negative"),
)
})??;
module.types.replace(
handle,
crate::Type {
name: None,
inner: TypeInner::Array {
base,
size: crate::ArraySize::Constant(value),
stride,
},
},
);
}
}
Ok(())
}

fn process_workgroup_size_override(
Expand Down
7 changes: 3 additions & 4 deletions naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,12 @@ impl BlockContext<'_> {
block: &mut Block,
) -> Result<MaybeKnown<u32>, Error> {
let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
match sequence_ty.indexable_length(self.ir_module) {
match sequence_ty
.indexable_length(self.ir_module, crate::ArraySize::indexable_length_resolved)
{
Ok(crate::proc::IndexableLength::Known(known_length)) => {
Ok(MaybeKnown::Known(known_length))
}
Ok(crate::proc::IndexableLength::Pending) => {
unreachable!()
}
Ok(crate::proc::IndexableLength::Dynamic) => {
let length_id = self.write_runtime_array_length(sequence, block)?;
Ok(MaybeKnown::Computed(length_id))
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub enum Error {
Validation(&'static str),
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
}

#[derive(Default)]
Expand Down
Loading

0 comments on commit d0e92aa

Please sign in to comment.