Re: BUG #19340: Wrong result from CORR() function - Mailing list pgsql-bugs
| From | Tom Lane |
|---|---|
| Subject | Re: BUG #19340: Wrong result from CORR() function |
| Date | |
| Msg-id | 434484.1764707203@sss.pgh.pa.us |
| In response to | Re: BUG #19340: Wrong result from CORR() function (Dean Rasheed <dean.a.rasheed@gmail.com>) |
| Responses | Re: BUG #19340: Wrong result from CORR() function |
| List | pgsql-bugs |
Dean Rasheed <dean.a.rasheed@gmail.com> writes:
> On Tue, 2 Dec 2025 at 17:22, Tom Lane <tgl@sss.pgh.pa.us> wrote:
>> I wonder whether it'd be worth carrying additional state to
>> check that explicitly (instead of assuming that "if (Sxx == 0 ||
>> Syy == 0)" will catch it).
> I wondered the same thing. It's not nice to have to do that, but
> clearly the existing test for constant inputs is no good. The question
> is, do we really want to spend extra cycles on every query just to
> catch this odd corner case?
I experimented with the attached patch, which is very incomplete;
I just carried it far enough to be able to run performance checks on
the modified code, and so all the binary statistics aggregates except
corr() are broken. I observe about a 2% slowdown on this test case:
SELECT corr( 0.09 , 0.09000001 ) FROM generate_series(1,100000000);
I think that any real-world usage is going to expend more effort
obtaining the input data than this test does, so 2% should be a
conservative upper bound on the cost. Seems to me that getting
NULL-or-not right is probably worth a percent or so.
If anyone feels differently, another idea could be to use a
separate state transition function for corr() that skips the
accumulation steps that corr() doesn't use. But I agree with
the pre-existing decision to use just one transition function
for all the binary aggregates.
If this seems like a reasonable approach, I'll see about finishing
out the patch.
regards, tom lane
diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c
index 7b97d2be6ca..5d8954a34d0 100644
--- a/src/backend/utils/adt/float.c
+++ b/src/backend/utils/adt/float.c
@@ -3319,9 +3319,15 @@ float8_stddev_samp(PG_FUNCTION_ARGS)
* As with the preceding aggregates, we use the Youngs-Cramer algorithm to
* reduce rounding errors in the aggregate final functions.
*
- * The transition datatype for all these aggregates is a 6-element array of
+ * The transition datatype for all these aggregates is a 9-element array of
* float8, holding the values N, Sx=sum(X), Sxx=sum((X-Sx/N)^2), Sy=sum(Y),
- * Syy=sum((Y-Sy/N)^2), Sxy=sum((X-Sx/N)*(Y-Sy/N)) in that order.
+ * Syy=sum((Y-Sy/N)^2), Sxy=sum((X-Sx/N)*(Y-Sy/N)), firstX, firstY, DIFF,
+ * in that order.
+ *
+ * DIFF, like N, is treated as an integer. Bit 0 is set if we saw distinct
+ * X inputs, and bit 1 is set if we saw distinct Y inputs. This allows us
+ * to detect constant inputs exactly, which is important for deciding whether
+ * some outputs should be NULL.
*
* Note that Y is the first argument to all these aggregates!
*
@@ -3345,17 +3351,23 @@ float8_regr_accum(PG_FUNCTION_ARGS)
Sy,
Syy,
Sxy,
+ firstX,
+ firstY,
tmpX,
tmpY,
scale;
+ int diff;
- transvalues = check_float8_array(transarray, "float8_regr_accum", 6);
+ transvalues = check_float8_array(transarray, "float8_regr_accum", 9);
N = transvalues[0];
Sx = transvalues[1];
Sxx = transvalues[2];
Sy = transvalues[3];
Syy = transvalues[4];
Sxy = transvalues[5];
+ firstX = transvalues[6];
+ firstY = transvalues[7];
+ diff = transvalues[8];
/*
* Use the Youngs-Cramer algorithm to incorporate the new values into the
@@ -3373,6 +3385,19 @@ float8_regr_accum(PG_FUNCTION_ARGS)
Syy += tmpY * tmpY * scale;
Sxy += tmpX * tmpY * scale;
+ /*
+ * Check to see if we have seen distinct inputs. In normal use, diff
+ * will reach 3 very soon and then we can stop checking.
+ */
+ if (diff != 3)
+ {
+ /* Need SQL-style comparison of NaNs here */
+ if (float8_ne(newvalX, firstX))
+ diff |= 1;
+ if (float8_ne(newvalY, firstY))
+ diff |= 2;
+ }
+
/*
* Overflow check. We only report an overflow error when finite
* inputs lead to infinite results. Note also that Sxx, Syy and Sxy
@@ -3410,6 +3435,8 @@ float8_regr_accum(PG_FUNCTION_ARGS)
Sxx = Sxy = get_float8_nan();
if (isnan(newvalY) || isinf(newvalY))
Syy = Sxy = get_float8_nan();
+ firstX = newvalX;
+ firstY = newvalY;
}
/*
@@ -3425,12 +3452,15 @@ float8_regr_accum(PG_FUNCTION_ARGS)
transvalues[3] = Sy;
transvalues[4] = Syy;
transvalues[5] = Sxy;
+ transvalues[6] = firstX;
+ transvalues[7] = firstY;
+ transvalues[8] = diff;
PG_RETURN_ARRAYTYPE_P(transarray);
}
else
{
- Datum transdatums[6];
+ Datum transdatums[9];
ArrayType *result;
transdatums[0] = Float8GetDatumFast(N);
@@ -3439,8 +3469,11 @@ float8_regr_accum(PG_FUNCTION_ARGS)
transdatums[3] = Float8GetDatumFast(Sy);
transdatums[4] = Float8GetDatumFast(Syy);
transdatums[5] = Float8GetDatumFast(Sxy);
+ transdatums[6] = Float8GetDatumFast(firstX);
+ transdatums[7] = Float8GetDatumFast(firstY);
+ transdatums[8] = Float8GetDatum(diff);
- result = construct_array_builtin(transdatums, 6, FLOAT8OID);
+ result = construct_array_builtin(transdatums, 9, FLOAT8OID);
PG_RETURN_ARRAYTYPE_P(result);
}
@@ -3730,27 +3763,25 @@ float8_corr(PG_FUNCTION_ARGS)
{
ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
float8 *transvalues;
- float8 N,
- Sxx,
+ float8 Sxx,
Syy,
Sxy;
+ int diff;
- transvalues = check_float8_array(transarray, "float8_corr", 6);
- N = transvalues[0];
+ transvalues = check_float8_array(transarray, "float8_corr", 9);
Sxx = transvalues[2];
Syy = transvalues[4];
Sxy = transvalues[5];
+ diff = transvalues[8];
- /* if N is 0 we should return NULL */
- if (N < 1.0)
+ /*
+ * Per spec, we must return NULL if N is zero, all X inputs are equal, or
+ * all Y inputs are equal. Checking the diff mask covers all three cases.
+ */
+ if (diff != 3)
PG_RETURN_NULL();
/* Note that Sxx and Syy are guaranteed to be non-negative */
-
- /* per spec, return NULL for horizontal and vertical lines */
- if (Sxx == 0 || Syy == 0)
- PG_RETURN_NULL();
-
PG_RETURN_FLOAT8(Sxy / sqrt(Sxx * Syy));
}
diff --git a/src/include/catalog/pg_aggregate.dat b/src/include/catalog/pg_aggregate.dat
index 870769e8f14..68dc1329ea0 100644
--- a/src/include/catalog/pg_aggregate.dat
+++ b/src/include/catalog/pg_aggregate.dat
@@ -505,7 +505,7 @@
aggtranstype => '_float8', agginitval => '{0,0,0,0,0,0}' },
{ aggfnoid => 'corr', aggtransfn => 'float8_regr_accum',
aggfinalfn => 'float8_corr', aggcombinefn => 'float8_regr_combine',
- aggtranstype => '_float8', agginitval => '{0,0,0,0,0,0}' },
+ aggtranstype => '_float8', agginitval => '{0,0,0,0,0,0,0,0,0}' },
# boolean-and and boolean-or
{ aggfnoid => 'bool_and', aggtransfn => 'booland_statefunc',