Skip to content

Commit

Permalink
Schedule Compaction Minimizes Par Threads (#1774)
Browse files Browse the repository at this point in the history
* modify pass ordering

* more compact compaction

* another small change

* dumb mistake

* another silly mistake

* documentation

* rewrite tests

* rewrite test

* code cleaning
  • Loading branch information
calebmkim authored Nov 11, 2023
1 parent 2f18cb0 commit a43513a
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 150 deletions.
112 changes: 82 additions & 30 deletions calyx-opt/src/passes/schedule_compaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,54 +58,106 @@ impl Visitor for ScheduleCompaction {

if let Ok(order) = algo::toposort(&total_order, None) {
let mut total_time: u64 = 0;
let mut stmts: Vec<ir::StaticControl> = Vec::new();

// First we build the schedule.

for i in order {
let mut start: u64 = 0;
for node in dependency.get(&i).unwrap() {
let allow_start = schedule[node] + latency_map[node];
if allow_start > start {
start = allow_start;
}
}
// Start time is when the latest dependency finishes
let start = dependency
.get(&i)
.unwrap()
.iter()
.map(|node| schedule[node] + latency_map[node])
.max()
.unwrap_or(0);
schedule.insert(i, start);
total_time = std::cmp::max(start + latency_map[&i], total_time);
}

// We sort the schedule by start time.
let mut sorted_schedule: Vec<(NodeIndex, u64)> =
schedule.into_iter().collect();
sorted_schedule
.sort_by(|(k1, v1), (k2, v2)| (v1, k1).cmp(&(v2, k2)));
// Threads for the static par, where each entry is (thread, thread_latency)
let mut par_threads: Vec<(Vec<ir::StaticControl>, u64)> =
Vec::new();

// We encode the schedule attempting to minimize the number of
// par threads.
'outer: for (i, start) in sorted_schedule {
let control = total_order[i].take().unwrap();
let mut st_seq_stmts: Vec<ir::StaticControl> = Vec::new();
for (thread, thread_latency) in par_threads.iter_mut() {
if *thread_latency <= start {
if *thread_latency < start {
// Might need a no-op group so the schedule starts correctly
let no_op = builder.add_static_group(
"no-op",
start - *thread_latency,
);
thread.push(ir::StaticControl::Enable(
ir::StaticEnable {
group: no_op,
attributes: ir::Attributes::default(),
},
));
*thread_latency = start;
}
thread.push(control);
*thread_latency += latency_map[&i];
continue 'outer;
}
}
// We must create a new par thread.
if start > 0 {
// If start > 0, then we must add a delay to the start of the
// group.
let no_op = builder.add_static_group("no-op", start);

st_seq_stmts.push(ir::StaticControl::Enable(
ir::StaticEnable {
let no_op_enable =
ir::StaticControl::Enable(ir::StaticEnable {
group: no_op,
attributes: ir::Attributes::default(),
},
});
par_threads.push((
vec![no_op_enable, control],
start + latency_map[&i],
));
} else {
par_threads.push((vec![control], latency_map[&i]));
}
if start + latency_map[&i] > total_time {
total_time = start + latency_map[&i];
}
}

st_seq_stmts.push(control);
stmts.push(ir::StaticControl::Seq(ir::StaticSeq {
stmts: st_seq_stmts,
// Turn Vec<ir::StaticControl> -> StaticSeq
let mut par_control_threads: Vec<ir::StaticControl> = Vec::new();
for (thread, thread_latency) in par_threads {
par_control_threads.push(ir::StaticControl::Seq(
ir::StaticSeq {
stmts: thread,
attributes: ir::Attributes::default(),
latency: thread_latency,
},
));
}
// Double checking that we have built the static par correctly.
let max = par_control_threads.iter().map(|c| c.get_latency()).max();
assert!(max.unwrap() == total_time, "The schedule expects latency {}. The static par that was built has latency {}", total_time, max.unwrap());

if par_control_threads.len() == 1 {
let c = Vec::pop(&mut par_control_threads).unwrap();
Ok(Action::static_change(c))
} else {
let s_par = ir::StaticControl::Par(ir::StaticPar {
stmts: par_control_threads,
attributes: ir::Attributes::default(),
latency: start + latency_map[&i],
}));
latency: total_time,
});
Ok(Action::static_change(s_par))
}

let s_par = ir::StaticControl::Par(ir::StaticPar {
stmts,
attributes: ir::Attributes::default(),
latency: total_time,
});
return Ok(Action::static_change(s_par));
} else {
println!(
panic!(
"Error when producing topo sort. Dependency graph has a cycle."
);
}
Ok(Action::Continue)
}

fn finish_static_repeat(
Expand Down
78 changes: 39 additions & 39 deletions examples/futil/dot-product.expect
Original file line number Diff line number Diff line change
Expand Up @@ -28,92 +28,92 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
@generated invoke0_done = std_wire(1);
@generated early_reset_cond00_go = std_wire(1);
@generated early_reset_cond00_done = std_wire(1);
@generated early_reset_static_par_go = std_wire(1);
@generated early_reset_static_par_done = std_wire(1);
@generated early_reset_static_seq_go = std_wire(1);
@generated early_reset_static_seq_done = std_wire(1);
@generated wrapper_early_reset_cond00_go = std_wire(1);
@generated wrapper_early_reset_cond00_done = std_wire(1);
@generated while_wrapper_early_reset_static_par_go = std_wire(1);
@generated while_wrapper_early_reset_static_par_done = std_wire(1);
@generated while_wrapper_early_reset_static_seq_go = std_wire(1);
@generated while_wrapper_early_reset_static_seq_done = std_wire(1);
@generated tdcc_go = std_wire(1);
@generated tdcc_done = std_wire(1);
}
wires {
i0.write_en = invoke0_go.out | fsm.out == 4'd1 & early_reset_static_par_go.out ? 1'd1;
i0.write_en = invoke0_go.out | fsm.out == 4'd1 & early_reset_static_seq_go.out ? 1'd1;
i0.clk = clk;
i0.reset = reset;
i0.in = fsm.out == 4'd1 & early_reset_static_par_go.out ? add1.out;
i0.in = fsm.out == 4'd1 & early_reset_static_seq_go.out ? add1.out;
i0.in = invoke0_go.out ? const0.out;
early_reset_cond00_go.in = wrapper_early_reset_cond00_go.out ? 1'd1;
add1.left = fsm.out == 4'd1 & early_reset_static_par_go.out ? i0.out;
add1.right = fsm.out == 4'd1 & early_reset_static_par_go.out ? const3.out;
add1.left = fsm.out == 4'd1 & early_reset_static_seq_go.out ? i0.out;
add1.right = fsm.out == 4'd1 & early_reset_static_seq_go.out ? const3.out;
done = tdcc_done.out ? 1'd1;
fsm.write_en = early_reset_cond00_go.out | early_reset_static_par_go.out ? 1'd1;
fsm.write_en = early_reset_cond00_go.out | early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 4'd0 & early_reset_cond00_go.out ? adder.out;
fsm.in = fsm.out == 4'd0 & early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? 4'd0;
fsm.in = fsm.out != 4'd7 & early_reset_static_par_go.out ? adder0.out;
fsm.in = fsm.out == 4'd0 & early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? 4'd0;
fsm.in = fsm.out != 4'd7 & early_reset_static_seq_go.out ? adder0.out;
adder.left = early_reset_cond00_go.out ? fsm.out;
adder.right = early_reset_cond00_go.out ? 4'd1;
add0.left = fsm.out == 4'd6 & early_reset_static_par_go.out ? v0.read_data;
add0.right = fsm.out == 4'd6 & early_reset_static_par_go.out ? B_read0_0.out;
v0.write_en = fsm.out == 4'd6 & early_reset_static_par_go.out ? 1'd1;
add0.left = fsm.out == 4'd6 & early_reset_static_seq_go.out ? v0.read_data;
add0.right = fsm.out == 4'd6 & early_reset_static_seq_go.out ? B_read0_0.out;
v0.write_en = fsm.out == 4'd6 & early_reset_static_seq_go.out ? 1'd1;
v0.clk = clk;
v0.addr0 = fsm.out == 4'd6 & early_reset_static_par_go.out ? const2.out;
v0.addr0 = fsm.out == 4'd6 & early_reset_static_seq_go.out ? const2.out;
v0.reset = reset;
v0.write_data = fsm.out == 4'd6 & early_reset_static_par_go.out ? add0.out;
comb_reg.write_en = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? 1'd1;
v0.write_data = fsm.out == 4'd6 & early_reset_static_seq_go.out ? add0.out;
comb_reg.write_en = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? 1'd1;
comb_reg.clk = clk;
comb_reg.reset = reset;
comb_reg.in = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? le0.out;
comb_reg.in = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? le0.out;
early_reset_cond00_done.in = ud.out;
while_wrapper_early_reset_static_par_go.in = !while_wrapper_early_reset_static_par_done.out & fsm0.out == 2'd2 & tdcc_go.out ? 1'd1;
while_wrapper_early_reset_static_seq_go.in = !while_wrapper_early_reset_static_seq_done.out & fsm0.out == 2'd2 & tdcc_go.out ? 1'd1;
invoke0_go.in = !invoke0_done.out & fsm0.out == 2'd0 & tdcc_go.out ? 1'd1;
while_wrapper_early_reset_static_par_done.in = !comb_reg.out & fsm.out == 4'd0 ? 1'd1;
tdcc_go.in = go;
A0.clk = clk;
A0.addr0 = fsm.out == 4'd0 & early_reset_static_par_go.out ? i0.out;
A0.addr0 = fsm.out == 4'd0 & early_reset_static_seq_go.out ? i0.out;
A0.reset = reset;
fsm0.write_en = fsm0.out == 2'd3 | fsm0.out == 2'd0 & invoke0_done.out & tdcc_go.out | fsm0.out == 2'd1 & wrapper_early_reset_cond00_done.out & tdcc_go.out | fsm0.out == 2'd2 & while_wrapper_early_reset_static_par_done.out & tdcc_go.out ? 1'd1;
fsm0.write_en = fsm0.out == 2'd3 | fsm0.out == 2'd0 & invoke0_done.out & tdcc_go.out | fsm0.out == 2'd1 & wrapper_early_reset_cond00_done.out & tdcc_go.out | fsm0.out == 2'd2 & while_wrapper_early_reset_static_seq_done.out & tdcc_go.out ? 1'd1;
fsm0.clk = clk;
fsm0.reset = reset;
fsm0.in = fsm0.out == 2'd0 & invoke0_done.out & tdcc_go.out ? 2'd1;
fsm0.in = fsm0.out == 2'd3 ? 2'd0;
fsm0.in = fsm0.out == 2'd2 & while_wrapper_early_reset_static_par_done.out & tdcc_go.out ? 2'd3;
fsm0.in = fsm0.out == 2'd2 & while_wrapper_early_reset_static_seq_done.out & tdcc_go.out ? 2'd3;
fsm0.in = fsm0.out == 2'd1 & wrapper_early_reset_cond00_done.out & tdcc_go.out ? 2'd2;
mult_pipe0.clk = clk;
mult_pipe0.left = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_par_go.out ? A_read0_0.out;
mult_pipe0.go = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_par_go.out ? 1'd1;
mult_pipe0.left = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_seq_go.out ? A_read0_0.out;
mult_pipe0.go = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_seq_go.out ? 1'd1;
mult_pipe0.reset = reset;
mult_pipe0.right = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_par_go.out ? B_read0_0.out;
adder0.left = early_reset_static_par_go.out ? fsm.out;
adder0.right = early_reset_static_par_go.out ? 4'd1;
mult_pipe0.right = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_seq_go.out ? B_read0_0.out;
adder0.left = early_reset_static_seq_go.out ? fsm.out;
adder0.right = early_reset_static_seq_go.out ? 4'd1;
invoke0_done.in = i0.done;
early_reset_static_par_done.in = ud0.out;
le0.left = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? i0.out;
le0.right = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? const1.out;
early_reset_static_seq_go.in = while_wrapper_early_reset_static_seq_go.out ? 1'd1;
le0.left = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? i0.out;
le0.right = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? const1.out;
signal_reg.write_en = fsm.out == 4'd0 & signal_reg.out | fsm.out == 4'd0 & !signal_reg.out & wrapper_early_reset_cond00_go.out ? 1'd1;
signal_reg.clk = clk;
signal_reg.reset = reset;
signal_reg.in = fsm.out == 4'd0 & !signal_reg.out & wrapper_early_reset_cond00_go.out ? 1'd1;
signal_reg.in = fsm.out == 4'd0 & signal_reg.out ? 1'd0;
B0.clk = clk;
B0.addr0 = fsm.out == 4'd0 & early_reset_static_par_go.out ? i0.out;
B0.addr0 = fsm.out == 4'd0 & early_reset_static_seq_go.out ? i0.out;
B0.reset = reset;
B_read0_0.write_en = (fsm.out == 4'd0 | fsm.out == 4'd5 & fsm.out < 4'd6) & early_reset_static_par_go.out ? 1'd1;
B_read0_0.write_en = (fsm.out == 4'd0 & fsm.out < 4'd7 | fsm.out == 4'd5 & fsm.out < 4'd7) & early_reset_static_seq_go.out ? 1'd1;
B_read0_0.clk = clk;
B_read0_0.reset = reset;
B_read0_0.in = fsm.out == 4'd0 & early_reset_static_par_go.out ? B0.read_data;
B_read0_0.in = fsm.out == 4'd5 & early_reset_static_par_go.out ? A_read0_0.out;
B_read0_0.in = fsm.out == 4'd0 & early_reset_static_seq_go.out ? B0.read_data;
B_read0_0.in = fsm.out == 4'd5 & early_reset_static_seq_go.out ? A_read0_0.out;
wrapper_early_reset_cond00_go.in = !wrapper_early_reset_cond00_done.out & fsm0.out == 2'd1 & tdcc_go.out ? 1'd1;
wrapper_early_reset_cond00_done.in = fsm.out == 4'd0 & signal_reg.out ? 1'd1;
early_reset_static_seq_done.in = ud0.out;
tdcc_done.in = fsm0.out == 2'd3 ? 1'd1;
early_reset_static_par_go.in = while_wrapper_early_reset_static_par_go.out ? 1'd1;
A_read0_0.write_en = (fsm.out == 4'd0 | fsm.out == 4'd4 & fsm.out >= 4'd1 & fsm.out < 4'd5 & fsm.out < 4'd5) & early_reset_static_par_go.out ? 1'd1;
while_wrapper_early_reset_static_seq_done.in = !comb_reg.out & fsm.out == 4'd0 ? 1'd1;
A_read0_0.write_en = (fsm.out == 4'd0 & fsm.out < 4'd7 | fsm.out == 4'd4 & fsm.out < 4'd7) & early_reset_static_seq_go.out ? 1'd1;
A_read0_0.clk = clk;
A_read0_0.reset = reset;
A_read0_0.in = fsm.out == 4'd0 & early_reset_static_par_go.out ? A0.read_data;
A_read0_0.in = fsm.out == 4'd4 & early_reset_static_par_go.out ? mult_pipe0.out;
A_read0_0.in = fsm.out == 4'd0 & early_reset_static_seq_go.out ? A0.read_data;
A_read0_0.in = fsm.out == 4'd4 & early_reset_static_seq_go.out ? mult_pipe0.out;
}
control {}
}
32 changes: 16 additions & 16 deletions examples/futil/simple.expect
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ static<5> component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done
@generated ud = undef(1);
@generated adder = std_add(3);
@generated signal_reg = std_reg(1);
@generated early_reset_static_par_go = std_wire(1);
@generated early_reset_static_par_done = std_wire(1);
@generated wrapper_early_reset_static_par_go = std_wire(1);
@generated wrapper_early_reset_static_par_done = std_wire(1);
@generated early_reset_static_seq_go = std_wire(1);
@generated early_reset_static_seq_done = std_wire(1);
@generated wrapper_early_reset_static_seq_go = std_wire(1);
@generated wrapper_early_reset_static_seq_done = std_wire(1);
}
wires {
done = wrapper_early_reset_static_par_done.out ? 1'd1;
fsm.write_en = early_reset_static_par_go.out ? 1'd1;
done = wrapper_early_reset_static_seq_done.out ? 1'd1;
fsm.write_en = early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 3'd4 & early_reset_static_par_go.out ? adder.out;
fsm.in = fsm.out == 3'd4 & early_reset_static_par_go.out ? 3'd0;
adder.left = early_reset_static_par_go.out ? fsm.out;
adder.right = early_reset_static_par_go.out ? 3'd1;
wrapper_early_reset_static_par_go.in = go;
wrapper_early_reset_static_par_done.in = fsm.out == 3'd0 & signal_reg.out ? 1'd1;
early_reset_static_par_done.in = ud.out;
signal_reg.write_en = fsm.out == 3'd0 & signal_reg.out | fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_par_go.out ? 1'd1;
fsm.in = fsm.out != 3'd4 & early_reset_static_seq_go.out ? adder.out;
fsm.in = fsm.out == 3'd4 & early_reset_static_seq_go.out ? 3'd0;
adder.left = early_reset_static_seq_go.out ? fsm.out;
adder.right = early_reset_static_seq_go.out ? 3'd1;
wrapper_early_reset_static_seq_done.in = fsm.out == 3'd0 & signal_reg.out ? 1'd1;
early_reset_static_seq_go.in = wrapper_early_reset_static_seq_go.out ? 1'd1;
signal_reg.write_en = fsm.out == 3'd0 & signal_reg.out | fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_seq_go.out ? 1'd1;
signal_reg.clk = clk;
signal_reg.reset = reset;
signal_reg.in = fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_par_go.out ? 1'd1;
signal_reg.in = fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_seq_go.out ? 1'd1;
signal_reg.in = fsm.out == 3'd0 & signal_reg.out ? 1'd0;
early_reset_static_par_go.in = wrapper_early_reset_static_par_go.out ? 1'd1;
early_reset_static_seq_done.in = ud.out;
wrapper_early_reset_static_seq_go.in = go;
}
control {}
}
Loading

0 comments on commit a43513a

Please sign in to comment.