Skip to content

Commit

Permalink
[DSLX] Fix ability to use xN as bit slice bounds.
Browse files Browse the repository at this point in the history
  • Loading branch information
cdleary committed Jan 3, 2025
1 parent be6d8f4 commit 3c2b6d2
Show file tree
Hide file tree
Showing 18 changed files with 193 additions and 120 deletions.
17 changes: 11 additions & 6 deletions docs_src/tutorials/intro_to_parametrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

This tutorial demonstrates how types and functions can be parameterized to
enable them to work on data of different formats and layouts, e.g., for a
function `foo` to work on both u16 and u32 data types, and anywhere in between.
function `foo` to work on both `u16` and `u32` data types, and anywhere in
between.

It's recommended that you're familiar with the concepts in the previous
tutorial,
Expand All @@ -11,7 +12,7 @@ before following this tutorial.

## Simple parametrics

Consider the simple example of the `umax` function
Consider the simple example of a `umax` function -- similar to the `max` function
[in the DSLX standard library](https://github.com/google/xls/tree/main/xls/dslx/stdlib/std.x):

```dslx
Expand Down Expand Up @@ -40,10 +41,12 @@ infer them:
Explicit specification:

```dslx
import std;
fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] {
if x > y { x } else { y }
}
fn foo(a: u32, b: u16) -> u64 {
std::umax<u32:64>(a as u64, b as u64)
umax<u32:64>(a as u64, b as u64)
}
```

Expand All @@ -53,10 +56,12 @@ are.
Parametric inference:

```dslx
import std;
fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] {
if x > y { x } else { y }
}
fn foo(a: u32, b: u16) -> u64 {
std::umax(a as u64, b as u64)
umax(a as u64, b as u64)
}
```

Expand Down
11 changes: 7 additions & 4 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,12 +960,15 @@ absl::Status FunctionConverter::HandleBuiltinCheckedCast(
int64_t old_bit_count,
std::get<InterpValue>(input_bit_count_ctd.value()).GetBitValueViaSign());

if (dynamic_cast<ArrayType*>(output_type.get()) != nullptr ||
dynamic_cast<ArrayType*>(input_type.get()) != nullptr) {
std::optional<BitsLikeProperties> output_bits_like =
GetBitsLike(*output_type);
std::optional<BitsLikeProperties> input_bits_like = GetBitsLike(*input_type);

if (!output_bits_like.has_value() || !input_bits_like.has_value()) {
return IrConversionErrorStatus(
node->span(),
absl::StrFormat("CheckedCast to and from array "
"is not currently supported for IR conversion; "
absl::StrFormat("CheckedCast is only supported for bits-like types in "
"IR conversion; "
"attempted checked cast from: %s to: %s",
input_type->ToString(), output_type->ToString()),
file_table());
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/stdlib/apfloat.x
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@ fn test_fp_lt_2() {
fn to_signed_or_unsigned_int<RESULT_SZ: u32, RESULT_SIGNED: bool, EXP_SZ: u32, FRACTION_SZ: u32>
(x: APFloat<EXP_SZ, FRACTION_SZ>) -> xN[RESULT_SIGNED][RESULT_SZ] {
const WIDE_FRACTION: u32 = FRACTION_SZ + u32:1;
const MAX_FRACTION_SZ: u32 = std::umax(RESULT_SZ, WIDE_FRACTION);
const MAX_FRACTION_SZ: u32 = std::max(RESULT_SZ, WIDE_FRACTION);

const INT_MIN = if RESULT_SIGNED {
(uN[MAX_FRACTION_SZ]:1 << (RESULT_SZ - u32:1)) // or rather, its negative.
Expand Down
135 changes: 65 additions & 70 deletions xls/dslx/stdlib/std.x
Original file line number Diff line number Diff line change
Expand Up @@ -90,74 +90,68 @@ fn unsigned_max_value_test() {
assert_eq(u32:0xffffffff, unsigned_max_value<u32:32>());
}

// Returns the maximum of two signed integers.
pub fn smax<N: u32>(x: sN[N], y: sN[N]) -> sN[N] { if x > y { x } else { y } }
// Returns the maximum of two (signed or unsigned) integers.
pub fn max<S: bool, N: u32>(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x > y { x } else { y } }

#[test]
fn smax_test() {
assert_eq(s2:0, smax(s2:0, s2:0));
assert_eq(s2:1, smax(s2:-1, s2:1));
assert_eq(s7:-3, smax(s7:-3, s7:-6));
fn max_test_signed() {
assert_eq(s2:0, max(s2:0, s2:0));
assert_eq(s2:1, max(s2:-1, s2:1));
assert_eq(s7:-3, max(s7:-3, s7:-6));
}

// Returns the maximum of two unsigned integers.
pub fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x > y { x } else { y } }

#[test]
fn umax_test() {
assert_eq(u1:1, umax(u1:1, u1:0));
assert_eq(u1:1, umax(u1:1, u1:1));
assert_eq(u2:3, umax(u2:3, u2:2));
fn max_test_unsigned() {
assert_eq(u1:1, max(u1:1, u1:0));
assert_eq(u1:1, max(u1:1, u1:1));
assert_eq(u2:3, max(u2:3, u2:2));
}

// Returns the maximum of two signed integers.
pub fn smin<N: u32>(x: sN[N], y: sN[N]) -> sN[N] { if x < y { x } else { y } }
// Returns the minimum of two (signed or unsigned) integers.
pub fn min<S: bool, N: u32>(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x < y { x } else { y } }

#[test]
fn smin_test() {
assert_eq(s1:0, smin(s1:0, s1:0));
assert_eq(s1:-1, smin(s1:0, s1:1));
assert_eq(s1:-1, smin(s1:1, s1:0));
assert_eq(s1:-1, smin(s1:1, s1:1));

assert_eq(s2:-2, smin(s2:0, s2:-2));
assert_eq(s2:-1, smin(s2:0, s2:-1));
assert_eq(s2:0, smin(s2:0, s2:0));
assert_eq(s2:0, smin(s2:0, s2:1));
fn min_test_unsigned() {
assert_eq(u1:0, min(u1:1, u1:0));
assert_eq(u1:1, min(u1:1, u1:1));
assert_eq(u2:2, min(u2:3, u2:2));
}

assert_eq(s2:-2, smin(s2:1, s2:-2));
assert_eq(s2:-1, smin(s2:1, s2:-1));
assert_eq(s2:0, smin(s2:1, s2:0));
assert_eq(s2:1, smin(s2:1, s2:1));
#[test]
fn min_test_signed() {
assert_eq(s1:0, min(s1:0, s1:0));
assert_eq(s1:-1, min(s1:0, s1:1));
assert_eq(s1:-1, min(s1:1, s1:0));
assert_eq(s1:-1, min(s1:1, s1:1));

assert_eq(s2:-2, smin(s2:-2, s2:-2));
assert_eq(s2:-2, smin(s2:-2, s2:-1));
assert_eq(s2:-2, smin(s2:-2, s2:0));
assert_eq(s2:-2, smin(s2:-2, s2:1));
assert_eq(s2:-2, min(s2:0, s2:-2));
assert_eq(s2:-1, min(s2:0, s2:-1));
assert_eq(s2:0, min(s2:0, s2:0));
assert_eq(s2:0, min(s2:0, s2:1));

assert_eq(s2:-2, smin(s2:-1, s2:-2));
assert_eq(s2:-1, smin(s2:-1, s2:-1));
assert_eq(s2:-1, smin(s2:-1, s2:0));
assert_eq(s2:-1, smin(s2:-1, s2:1));
}
assert_eq(s2:-2, min(s2:1, s2:-2));
assert_eq(s2:-1, min(s2:1, s2:-1));
assert_eq(s2:0, min(s2:1, s2:0));
assert_eq(s2:1, min(s2:1, s2:1));

// Returns the minimum of two unsigned integers.
pub fn umin<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x < y { x } else { y } }
assert_eq(s2:-2, min(s2:-2, s2:-2));
assert_eq(s2:-2, min(s2:-2, s2:-1));
assert_eq(s2:-2, min(s2:-2, s2:0));
assert_eq(s2:-2, min(s2:-2, s2:1));

#[test]
fn umin_test() {
assert_eq(u1:0, umin(u1:1, u1:0));
assert_eq(u1:1, umin(u1:1, u1:1));
assert_eq(u2:2, umin(u2:3, u2:2));
assert_eq(s2:-2, min(s2:-1, s2:-2));
assert_eq(s2:-1, min(s2:-1, s2:-1));
assert_eq(s2:-1, min(s2:-1, s2:0));
assert_eq(s2:-1, min(s2:-1, s2:1));
}

// Returns unsigned add of x (N bits) and y (M bits) as a max(N,M)+1 bit value.
pub fn uadd<N: u32, M: u32, R: u32 = {umax(N, M) + u32:1}>(x: uN[N], y: uN[M]) -> uN[R] {
pub fn uadd<N: u32, M: u32, R: u32 = {max(N, M) + u32:1}>(x: uN[N], y: uN[M]) -> uN[R] {
(x as uN[R]) + (y as uN[R])
}

// Returns signed add of x (N bits) and y (M bits) as a max(N,M)+1 bit value.
pub fn sadd<N: u32, M: u32, R: u32 = {umax(N, M) + u32:1}>(x: sN[N], y: sN[M]) -> sN[R] {
pub fn sadd<N: u32, M: u32, R: u32 = {max(N, M) + u32:1}>(x: sN[N], y: sN[M]) -> sN[R] {
(x as sN[R]) + (y as sN[R])
}

Expand Down Expand Up @@ -773,7 +767,7 @@ fn test_to_unsigned() {
// let result : (bool, u16) = uadd_with_overflow<u32:16>(x, y);
//
pub fn uadd_with_overflow
<V: u32, N: u32, M: u32, MAX_N_M: u32 = {umax(N, M)}, MAX_N_M_V: u32 = {umax(MAX_N_M, V)}>
<V: u32, N: u32, M: u32, MAX_N_M: u32 = {max(N, M)}, MAX_N_M_V: u32 = {max(MAX_N_M, V)}>
(x: uN[N], y: uN[M]) -> (bool, uN[V]) {

let x_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(x);
Expand Down Expand Up @@ -801,47 +795,48 @@ fn test_uadd_with_overflow() {
}

// Extract bits given a fixed-point integer with a constant offset.
// i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + fixed_shift, to_exclusive)];
// (x_extended << fixed_shift)[from_inclusive:to_exclusive]
// i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + FIXED_SHIFT, TO_EXCLUSIVE)];
// (x_extended << FIXED_SHIFT)[FROM_INCLUSIVE:TO_EXCLUSIVE]
//
// This function behaves as-if x has reasonably infinite precision so that
// the result is zero-padded if from_inclusive or to_exclusive are out of
// the result is zero-padded if FROM_INCLUSIVE or TO_EXCLUSIVE are out of
// range of the original x's bitwidth.
//
// If to_exclusive <= from_exclusive, the result will be a zero-bit uN[0].
// If TO_EXCLUSIVE <= FROM_INCLUSIVE, the result will be a zero-bit uN[0].
pub fn extract_bits
<from_inclusive: u32, to_exclusive: u32, fixed_shift: u32, N: u32,
extract_width: u32 = {smax(s32:0, to_exclusive as s32 - from_inclusive as s32) as u32}>
(x: uN[N]) -> uN[extract_width] {
if to_exclusive <= from_inclusive {
uN[extract_width]:0
<FROM_INCLUSIVE: u32, TO_EXCLUSIVE: u32, FIXED_SHIFT: u32, N: u32,
EXTRACT_WIDTH: u32 = {max(s32:0, TO_EXCLUSIVE as s32 - FROM_INCLUSIVE as s32) as u32}>
(x: uN[N]) -> uN[EXTRACT_WIDTH] {
if TO_EXCLUSIVE <= FROM_INCLUSIVE {
uN[EXTRACT_WIDTH]:0
} else {
// With a non-zero fixed width, all lower bits of index < fixed_shift are
// are zero.
let lower_bits =
uN[checked_cast<u32>(smax(s32:0, fixed_shift as s32 - from_inclusive as s32))]:0;
uN[checked_cast<u32>(max(s32:0, FIXED_SHIFT as s32 - FROM_INCLUSIVE as s32))]:0;

// Based on the input of N bits and a fixed shift, there are an effective
// count of N + fixed_shift known bits. All bits of index >
// N + fixed_shift - 1 are zero's.
const UPPER_BIT_COUNT = checked_cast<u32>(
smax(s32:0, N as s32 + fixed_shift as s32 - to_exclusive as s32 - s32:1));
let upper_bits = uN[UPPER_BIT_COUNT]:0;
max(s32:0, N as s32 + FIXED_SHIFT as s32 - TO_EXCLUSIVE as s32 - s32:1));
const UPPER_BITS = uN[UPPER_BIT_COUNT]:0;

if fixed_shift < from_inclusive {
if FIXED_SHIFT < FROM_INCLUSIVE {
// The bits extracted start within or after the middle span.
// upper_bits ++ middle_bits
let middle_bits = upper_bits ++
x[smin(from_inclusive as s32 - fixed_shift as s32, N as s32)
:smin(to_exclusive as s32 - fixed_shift as s32, N as s32)];
(upper_bits ++ middle_bits) as uN[extract_width]
} else if fixed_shift <= to_exclusive {
const FROM: s32 = min(FROM_INCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
let middle_bits = UPPER_BITS ++ x[FROM:TO];
(UPPER_BITS ++ middle_bits) as uN[EXTRACT_WIDTH]
} else if FIXED_SHIFT <= TO_EXCLUSIVE {
// The bits extracted start within the fixed_shift span.
let middle_bits = x[0:smin(to_exclusive as s32 - fixed_shift as s32, N as s32)];
const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
let middle_bits = x[0:TO];

(upper_bits ++ middle_bits ++ lower_bits) as uN[extract_width]
(UPPER_BITS ++ middle_bits ++ lower_bits) as uN[EXTRACT_WIDTH]
} else {
uN[extract_width]:0
uN[EXTRACT_WIDTH]:0
}
}
}
Expand Down Expand Up @@ -928,7 +923,7 @@ pub fn umul_with_overflow
<V: u32, N: u32, M: u32, N_lower_bits: u32 = {N >> u32:1},
N_upper_bits: u32 = {N - N_lower_bits}, M_lower_bits: u32 = {M >> u32:1},
M_upper_bits: u32 = {M - M_lower_bits},
Min_N_M_lower_bits: u32 = {umin(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}>
Min_N_M_lower_bits: u32 = {min(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}>
(x: uN[N], y: uN[M]) -> (bool, uN[V]) {
// Break x and y into two halves.
// x = x1 ++ x0,
Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ dslx_lang_test(name = "xn_type_equivalence")

dslx_lang_test(name = "xn_signedness_properties")

dslx_lang_test(name = "xn_slice_bounds")

dslx_lang_test(
name = "parametric_shift",
# TODO(leary): 2023-08-14 Runs into "cannot translate zero length bitvector
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/tests/errors/error_modules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,8 +873,8 @@ def test_equals_rhs_undefined_nameref(self):

def test_umin_type_mismatch(self):
stderr = self._run('xls/dslx/tests/errors/umin_type_mismatch.x')
self.assertIn('umin_type_mismatch.x:21:12-21:27', stderr)
self.assertIn('XlsTypeError: uN[N] vs uN[8]', stderr)
self.assertIn('umin_type_mismatch.x:21:13-21:28', stderr)
self.assertIn('saw: 42; then: 8', stderr)

def test_diag_block_with_trailing_semi(self):
stderr = self._run(
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/tests/errors/spawn_wrong_argc.x
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub proc foo {
config () { () }

next(state: ()) {
std::umin(u32:1, u32:2);
std::min(u32:1, u32:2);
()
}
}
Expand All @@ -37,7 +37,7 @@ proc test_case {
}

next(state: ()) {
std::umin(u32:1, u32:2);
std::min(u32:1, u32:2);
let tok = send(join(), terminator, true);
()
}
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/tests/errors/umin_type_mismatch.x
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ const MY_U32 = u42:42;
const MY_U8 = u8:42;

fn f() -> u32 {
std::umin(MY_U32, MY_U8)
std::min(MY_U32, MY_U8)
}
33 changes: 33 additions & 0 deletions xls/dslx/tests/xn_slice_bounds.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2025 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

const S = true;
const N = u32:32;

type MyS32 = xN[S][N];

fn from_to(x: u32) -> u8 { x[MyS32:0:MyS32:8] }

fn to(x: u32) -> u8 { x[:MyS32:8] }

fn from(x: u32) -> u8 { x[MyS32:-8:] }

fn main(x: u32) -> u8[3] { [from_to(x), to(x), from(x)] }

#[test]
fn test_main() {
assert_eq(from_to(u32:0x12345678), u8:0x78);
assert_eq(to(u32:0x12345678), u8:0x78);
assert_eq(from(u32:0x12345678), u8:0x12);
}
Loading

0 comments on commit 3c2b6d2

Please sign in to comment.