diff --git a/src/sqllogictest/src/parser.rs b/src/sqllogictest/src/parser.rs index 0164795854861..b75bbcfab5c55 100644 --- a/src/sqllogictest/src/parser.rs +++ b/src/sqllogictest/src/parser.rs @@ -259,6 +259,10 @@ impl<'a> Parser<'a> { &DOUBLE_LINE_REGEX })? .trim_start(); + // We don't want to advance the expected output past the column names so rewriting works, + // but need to be able to parse past them, so remember the position before possible column + // names. + let query_output_str = output_str; let column_names = if check_column_names { Some( split_at(&mut output_str, &LINE_REGEX)? @@ -321,7 +325,7 @@ impl<'a> Parser<'a> { column_names, mode: self.mode, output, - output_str, + output_str: query_output_str, }), location, }) diff --git a/src/sqllogictest/src/runner.rs b/src/sqllogictest/src/runner.rs index b237b84b67c2d..de1f8b6298c4c 100644 --- a/src/sqllogictest/src/runner.rs +++ b/src/sqllogictest/src/runner.rs @@ -122,6 +122,7 @@ pub enum Outcome<'a> { WrongColumnNames { expected_column_names: &'a Vec, actual_column_names: Vec, + actual_output: Output, location: Location, }, OutputFailure { @@ -246,6 +247,7 @@ impl fmt::Display for Outcome<'_> { WrongColumnNames { expected_column_names, actual_column_names, + actual_output: _, location, } => write!( f, @@ -1391,25 +1393,6 @@ impl<'a> RunnerInner<'a> { Ok(query_output) => query_output, }; - // Various checks as long as there are returned rows. - if let Some(row) = rows.get(0) { - // check column names - if let Some(expected_column_names) = expected_column_names { - let actual_column_names = row - .columns() - .iter() - .map(|t| ColumnName::from(t.name())) - .collect::>(); - if expected_column_names != &actual_column_names { - return Ok(Outcome::WrongColumnNames { - expected_column_names, - actual_column_names, - location, - }); - } - } - } - // format output let mut formatted_rows = vec![]; for row in &rows { @@ -1433,6 +1416,26 @@ impl<'a> RunnerInner<'a> { values.sort(); } + // Various checks as long as there are returned rows. + if let Some(row) = rows.get(0) { + // check column names + if let Some(expected_column_names) = expected_column_names { + let actual_column_names = row + .columns() + .iter() + .map(|t| ColumnName::from(t.name())) + .collect::>(); + if expected_column_names != &actual_column_names { + return Ok(Outcome::WrongColumnNames { + expected_column_names, + actual_column_names, + actual_output: Output::Values(values), + location, + }); + } + } + } + // check output match expected_output { Output::Values(expected_values) => { @@ -1901,6 +1904,41 @@ pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<() writeln!(runner.config.stdout, "==> {}", filename.display()); let mut in_transaction = false; + fn append_values_output( + buf: &mut RewriteBuffer, + input: &String, + expected_output: &str, + mode: &Mode, + types: &Vec, + column_names: Option<&Vec>, + actual_output: &Vec, + ) { + buf.append_header(input, expected_output, column_names); + + for (i, row) in actual_output.chunks(types.len()).enumerate() { + match mode { + // In Cockroach mode, output each row on its own line, with + // two spaces between each column. + Mode::Cockroach => { + if i != 0 { + buf.append("\n"); + } + buf.append(&row.join(" ")); + } + // In standard mode, output each value on its own line, + // and ignore row boundaries. + Mode::Standard => { + for (j, col) in row.iter().enumerate() { + if i != 0 || j != 0 { + buf.append("\n"); + } + buf.append(col); + } + } + } + } + } + for record in parser.parse_records()? { let outcome = runner.run_record(&record, &mut in_transaction).await?; @@ -1915,6 +1953,7 @@ pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<() output: Output::Values(_), output_str: expected_output, types, + column_names, .. }), .. @@ -1924,32 +1963,43 @@ pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<() .. }, ) => { - { - buf.append_header(&input, expected_output); - - for (i, row) in actual_output.chunks(types.len()).enumerate() { - match mode { - // In Cockroach mode, output each row on its own line, with - // two spaces between each column. - Mode::Cockroach => { - if i != 0 { - buf.append("\n"); - } - buf.append(&row.join(" ")); - } - // In standard mode, output each value on its own line, - // and ignore row boundaries. - Mode::Standard => { - for (j, col) in row.iter().enumerate() { - if i != 0 || j != 0 { - buf.append("\n"); - } - buf.append(col); - } - } - } - } - } + append_values_output( + &mut buf, + &input, + expected_output, + mode, + types, + column_names.as_ref(), + actual_output, + ); + } + ( + Record::Query { + output: + Ok(QueryOutput { + mode, + output: Output::Values(_), + output_str: expected_output, + types, + .. + }), + .. + }, + Outcome::WrongColumnNames { + actual_column_names, + actual_output: Output::Values(actual_output), + .. + }, + ) => { + append_values_output( + &mut buf, + &input, + expected_output, + mode, + types, + Some(actual_column_names), + actual_output, + ); } ( Record::Query { @@ -1957,6 +2007,7 @@ pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<() Ok(QueryOutput { output: Output::Hashed { .. }, output_str: expected_output, + column_names, .. }), .. @@ -1966,7 +2017,7 @@ pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<() .. }, ) => { - buf.append_header(&input, expected_output); + buf.append_header(&input, expected_output, column_names.as_ref()); buf.append(format!("{} values hashing to {}\n", num_values, md5).as_str()) } @@ -1980,7 +2031,7 @@ pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<() .. }, ) => { - buf.append_header(&input, expected_output); + buf.append_header(&input, expected_output, None); for (i, row) in actual_output.iter().enumerate() { if i != 0 { @@ -2059,7 +2110,12 @@ impl<'a> RewriteBuffer<'a> { self.output.push_str(s); } - fn append_header(&mut self, input: &String, expected_output: &str) { + fn append_header( + &mut self, + input: &String, + expected_output: &str, + column_names: Option<&Vec>, + ) { // Output everything before this record. // TODO(benesch): is it possible to rewrite this to avoid `as`? #[allow(clippy::as_conversions)] @@ -2074,6 +2130,18 @@ impl<'a> RewriteBuffer<'a> { } else if self.peek_last(6) != "\n----\n" { self.append("\n----\n"); } + + let Some(names) = column_names else { + return; + }; + self.append( + &names + .iter() + .map(|name| name.as_str().replace('␠', " ")) + .collect::>() + .join(" "), + ); + self.append("\n"); } fn rewrite_expected_error(