Skip to content

Commit

Permalink
fix: regr_count now returns UInt64 (apache#11731)
Browse files: browse the repository at this point in the history
  • Loading branch information
Michael-J-Ward authored Jul 30, 2024
1 parent 66a8570 commit cd786e2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
8 changes: 6 additions & 2 deletions datafusion/functions-aggregate/src/regr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ impl AggregateUDFImpl for Regr {
return plan_err!("Covariance requires numeric input types");
}

Ok(DataType::Float64)
if matches!(self.regr_type, RegrType::Count) {
Ok(DataType::UInt64)
} else {
Ok(DataType::Float64)
}
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Expand Down Expand Up @@ -480,7 +484,7 @@ impl Accumulator for RegrAccumulator {
let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
}
RegrType::Count => Ok(ScalarValue::Float64(Some(self.count as f64))),
RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
RegrType::R2 => {
// Only 0/1 point or all x(or y) is the same
let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
Expand Down
32 changes: 16 additions & 16 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4742,35 +4742,35 @@ select regr_sxy(NULL, 'bar');


# regr_*() NULL results
query RRRRRRRRR
query RRIRRRRRR
select regr_slope(1,1), regr_intercept(1,1), regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), regr_sxy(1,1);
----
NULL NULL 1 NULL 1 1 0 0 0

query RRRRRRRRR
query RRIRRRRRR
select regr_slope(1, NULL), regr_intercept(1, NULL), regr_count(1, NULL), regr_r2(1, NULL), regr_avgx(1, NULL), regr_avgy(1, NULL), regr_sxx(1, NULL), regr_syy(1, NULL), regr_sxy(1, NULL);
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

query RRRRRRRRR
query RRIRRRRRR
select regr_slope(NULL, 1), regr_intercept(NULL, 1), regr_count(NULL, 1), regr_r2(NULL, 1), regr_avgx(NULL, 1), regr_avgy(NULL, 1), regr_sxx(NULL, 1), regr_syy(NULL, 1), regr_sxy(NULL, 1);
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

query RRRRRRRRR
query RRIRRRRRR
select regr_slope(NULL, NULL), regr_intercept(NULL, NULL), regr_count(NULL, NULL), regr_r2(NULL, NULL), regr_avgx(NULL, NULL), regr_avgy(NULL, NULL), regr_sxx(NULL, NULL), regr_syy(NULL, NULL), regr_sxy(NULL, NULL);
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

query RRRRRRRRR
query RRIRRRRRR
select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), regr_r2(column2, column1), regr_avgx(column2, column1), regr_avgy(column2, column1), regr_sxx(column2, column1), regr_syy(column2, column1), regr_sxy(column2, column1) from (values (1,2), (1,4), (1,6));
----
NULL NULL 3 NULL 1 4 0 8 0



# regr_*() basic tests
query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
Expand All @@ -4785,7 +4785,7 @@ from (values (1,2), (2,4), (3,6));
----
2 0 3 1 2 4 2 8 4

query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
Expand All @@ -4803,7 +4803,7 @@ from aggregate_test_100;


# regr_*() functions ignore NULLs
query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
Expand All @@ -4818,7 +4818,7 @@ from (values (1,NULL), (2,4), (3,6));
----
2 0 2 1 2.5 5 0.5 2 1

query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
Expand All @@ -4833,7 +4833,7 @@ from (values (1,NULL), (NULL,4), (3,6));
----
NULL NULL 1 NULL 3 6 0 0 0

query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(column2, column1),
regr_intercept(column2, column1),
Expand All @@ -4848,7 +4848,7 @@ from (values (1,NULL), (NULL,4), (NULL,NULL));
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL

query TRRRRRRRRR rowsort
query TRRIRRRRRR rowsort
select
column3,
regr_slope(column2, column1),
Expand All @@ -4873,7 +4873,7 @@ c NULL NULL 1 NULL 1 10 0 0 0
statement ok
set datafusion.execution.batch_size = 1;

query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
Expand All @@ -4891,7 +4891,7 @@ from aggregate_test_100;
statement ok
set datafusion.execution.batch_size = 2;

query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
Expand All @@ -4909,7 +4909,7 @@ from aggregate_test_100;
statement ok
set datafusion.execution.batch_size = 3;

query RRRRRRRRR
query RRIRRRRRR
select
regr_slope(c12, c11),
regr_intercept(c12, c11),
Expand All @@ -4930,7 +4930,7 @@ set datafusion.execution.batch_size = 8192;


# regr_*() testing retract_batch() from RegrAccumulator's internal implementation
query RRRRRRRRR
query RRIRRRRRR
SELECT
regr_slope(column2, column1) OVER w AS slope,
regr_intercept(column2, column1) OVER w AS intercept,
Expand All @@ -4951,7 +4951,7 @@ NULL NULL 1 NULL 1 2 0 0 0
4.5 -7 3 0.964285714286 4 11 2 42 9
3 0 3 1 5 15 2 18 6

query RRRRRRRRR
query RRIRRRRRR
SELECT
regr_slope(column2, column1) OVER w AS slope,
regr_intercept(column2, column1) OVER w AS intercept,
Expand Down

0 comments on commit cd786e2

Please sign in to comment.