Re: BUG #19340: Wrong result from CORR() function - Mailing list pgsql-bugs

From Tom Lane
Subject Re: BUG #19340: Wrong result from CORR() function
Date
Msg-id 434484.1764707203@sss.pgh.pa.us
In response to Re: BUG #19340: Wrong result from CORR() function  (Dean Rasheed <dean.a.rasheed@gmail.com>)
Responses Re: BUG #19340: Wrong result from CORR() function
List pgsql-bugs
Dean Rasheed <dean.a.rasheed@gmail.com> writes:
> On Tue, 2 Dec 2025 at 17:22, Tom Lane <tgl@sss.pgh.pa.us> wrote:
>> I wonder whether it'd be worth carrying additional state to
>> check that explicitly (instead of assuming that "if (Sxx == 0 ||
>> Syy == 0)" will catch it).

> I wondered the same thing. It's not nice to have to do that, but
> clearly the existing test for constant inputs is no good. The question
> is, do we really want to spend extra cycles on every query just to
> catch this odd corner case?

I experimented with the attached patch, which is very incomplete;
I just carried it far enough to be able to run performance checks on
the modified code, and so all the binary statistics aggregates except
corr() are broken.  I observe about 2% slowdown on this test case:

SELECT corr( 0.09 , 0.09000001 ) FROM generate_series(1,100000000);

I think that any real-world usage is going to expend more effort
obtaining the input data than this test does, so 2% should be a
conservative upper bound on the cost.  Seems to me that getting
NULL-or-not right is probably worth a percent or so.
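
For readers who want to see the extra bookkeeping in isolation, here is a minimal standalone sketch of the transition step with the first-value and mask tracking added. It is not the backend code: RegrState, sql_float8_ne and regr_accum are hypothetical names, and the overflow and NaN/Inf handling done by the real float8_regr_accum is left out. The point is only that the mask answers "were the inputs constant?" exactly, whereas Sxx and Syy are subject to rounding and may or may not come out as exactly zero.

```c
#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical standalone state, mirroring the nine slots in the patch. */
typedef struct RegrState
{
    double N, Sx, Sxx, Sy, Syy, Sxy;
    double firstX, firstY;
    int    diff;    /* bit 0: distinct X seen; bit 1: distinct Y seen */
} RegrState;

/* SQL-style inequality: NaN compares equal to NaN (assumed stand-in for float8_ne). */
static bool
sql_float8_ne(double a, double b)
{
    if (isnan(a))
        return !isnan(b);
    return isnan(b) || a != b;
}

/* One transition step; as in the aggregates, Y is the first argument. */
static void
regr_accum(RegrState *st, double y, double x)
{
    assert(st != NULL);

    st->N += 1.0;
    st->Sx += x;
    st->Sy += y;

    if (st->N > 1.0)
    {
        /* Youngs-Cramer update of the centered sums */
        double tmpX = x * st->N - st->Sx;
        double tmpY = y * st->N - st->Sy;
        double scale = 1.0 / (st->N * (st->N - 1.0));

        st->Sxx += tmpX * tmpX * scale;
        st->Syy += tmpY * tmpY * scale;
        st->Sxy += tmpX * tmpY * scale;

        /* Once both bits are set there is nothing left to learn. */
        if (st->diff != 3)
        {
            if (sql_float8_ne(x, st->firstX))
                st->diff |= 1;
            if (sql_float8_ne(y, st->firstY))
                st->diff |= 2;
        }
    }
    else
    {
        /* First row: remember it as the reference for later comparisons. */
        st->firstX = x;
        st->firstY = y;
    }
}
```

Starting from a zero-initialized RegrState and feeding it the same (0.09, 0.09000001) pair repeatedly leaves diff at 0 no matter how many rows are processed, so a final function testing "diff != 3" would return NULL without looking at Sxx or Syy at all.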

If anyone feels differently, another idea could be to use a
separate state transition function for corr() that skips the
accumulation steps that corr() doesn't use.  But I agree with
the pre-existing decision to use just one transition function
for all the binary aggregates.

If this seems like a reasonable approach, I'll see about finishing
out the patch.

            regards, tom lane

diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c
index 7b97d2be6ca..5d8954a34d0 100644
--- a/src/backend/utils/adt/float.c
+++ b/src/backend/utils/adt/float.c
@@ -3319,9 +3319,15 @@ float8_stddev_samp(PG_FUNCTION_ARGS)
  * As with the preceding aggregates, we use the Youngs-Cramer algorithm to
  * reduce rounding errors in the aggregate final functions.
  *
- * The transition datatype for all these aggregates is a 6-element array of
+ * The transition datatype for all these aggregates is a 9-element array of
  * float8, holding the values N, Sx=sum(X), Sxx=sum((X-Sx/N)^2), Sy=sum(Y),
- * Syy=sum((Y-Sy/N)^2), Sxy=sum((X-Sx/N)*(Y-Sy/N)) in that order.
+ * Syy=sum((Y-Sy/N)^2), Sxy=sum((X-Sx/N)*(Y-Sy/N)), firstX, firstY, DIFF,
+ * in that order.
+ *
+ * DIFF, like N, is treated as an integer.  Bit 0 is set if we saw distinct
+ * X inputs, and bit 1 is set if we saw distinct Y inputs.  This allows us
+ * to detect constant inputs exactly, which is important for deciding whether
+ * some outputs should be NULL.
  *
  * Note that Y is the first argument to all these aggregates!
  *
@@ -3345,17 +3351,23 @@ float8_regr_accum(PG_FUNCTION_ARGS)
                 Sy,
                 Syy,
                 Sxy,
+                firstX,
+                firstY,
                 tmpX,
                 tmpY,
                 scale;
+    int            diff;

-    transvalues = check_float8_array(transarray, "float8_regr_accum", 6);
+    transvalues = check_float8_array(transarray, "float8_regr_accum", 9);
     N = transvalues[0];
     Sx = transvalues[1];
     Sxx = transvalues[2];
     Sy = transvalues[3];
     Syy = transvalues[4];
     Sxy = transvalues[5];
+    firstX = transvalues[6];
+    firstY = transvalues[7];
+    diff = transvalues[8];

     /*
      * Use the Youngs-Cramer algorithm to incorporate the new values into the
@@ -3373,6 +3385,19 @@ float8_regr_accum(PG_FUNCTION_ARGS)
         Syy += tmpY * tmpY * scale;
         Sxy += tmpX * tmpY * scale;

+        /*
+         * Check to see if we have seen distinct inputs.  In normal use, diff
+         * will reach 3 very soon and then we can stop checking.
+         */
+        if (diff != 3)
+        {
+            /* Need SQL-style comparison of NaNs here */
+            if (float8_ne(newvalX, firstX))
+                diff |= 1;
+            if (float8_ne(newvalY, firstY))
+                diff |= 2;
+        }
+
         /*
          * Overflow check.  We only report an overflow error when finite
          * inputs lead to infinite results.  Note also that Sxx, Syy and Sxy
@@ -3410,6 +3435,8 @@ float8_regr_accum(PG_FUNCTION_ARGS)
             Sxx = Sxy = get_float8_nan();
         if (isnan(newvalY) || isinf(newvalY))
             Syy = Sxy = get_float8_nan();
+        firstX = newvalX;
+        firstY = newvalY;
     }

     /*
@@ -3425,12 +3452,15 @@ float8_regr_accum(PG_FUNCTION_ARGS)
         transvalues[3] = Sy;
         transvalues[4] = Syy;
         transvalues[5] = Sxy;
+        transvalues[6] = firstX;
+        transvalues[7] = firstY;
+        transvalues[8] = diff;

         PG_RETURN_ARRAYTYPE_P(transarray);
     }
     else
     {
-        Datum        transdatums[6];
+        Datum        transdatums[9];
         ArrayType  *result;

         transdatums[0] = Float8GetDatumFast(N);
@@ -3439,8 +3469,11 @@ float8_regr_accum(PG_FUNCTION_ARGS)
         transdatums[3] = Float8GetDatumFast(Sy);
         transdatums[4] = Float8GetDatumFast(Syy);
         transdatums[5] = Float8GetDatumFast(Sxy);
+        transdatums[6] = Float8GetDatumFast(firstX);
+        transdatums[7] = Float8GetDatumFast(firstY);
+        transdatums[8] = Float8GetDatum(diff);

-        result = construct_array_builtin(transdatums, 6, FLOAT8OID);
+        result = construct_array_builtin(transdatums, 9, FLOAT8OID);

         PG_RETURN_ARRAYTYPE_P(result);
     }
@@ -3730,27 +3763,25 @@ float8_corr(PG_FUNCTION_ARGS)
 {
     ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
     float8       *transvalues;
-    float8        N,
-                Sxx,
+    float8        Sxx,
                 Syy,
                 Sxy;
+    int            diff;

-    transvalues = check_float8_array(transarray, "float8_corr", 6);
-    N = transvalues[0];
+    transvalues = check_float8_array(transarray, "float8_corr", 9);
     Sxx = transvalues[2];
     Syy = transvalues[4];
     Sxy = transvalues[5];
+    diff = transvalues[8];

-    /* if N is 0 we should return NULL */
-    if (N < 1.0)
+    /*
+     * Per spec, we must return NULL if N is zero, all X inputs are equal, or
+     * all Y inputs are equal.  Checking the diff mask covers all three cases.
+     */
+    if (diff != 3)
         PG_RETURN_NULL();

     /* Note that Sxx and Syy are guaranteed to be non-negative */
-
-    /* per spec, return NULL for horizontal and vertical lines */
-    if (Sxx == 0 || Syy == 0)
-        PG_RETURN_NULL();
-
     PG_RETURN_FLOAT8(Sxy / sqrt(Sxx * Syy));
 }

diff --git a/src/include/catalog/pg_aggregate.dat b/src/include/catalog/pg_aggregate.dat
index 870769e8f14..68dc1329ea0 100644
--- a/src/include/catalog/pg_aggregate.dat
+++ b/src/include/catalog/pg_aggregate.dat
@@ -505,7 +505,7 @@
   aggtranstype => '_float8', agginitval => '{0,0,0,0,0,0}' },
 { aggfnoid => 'corr', aggtransfn => 'float8_regr_accum',
   aggfinalfn => 'float8_corr', aggcombinefn => 'float8_regr_combine',
-  aggtranstype => '_float8', agginitval => '{0,0,0,0,0,0}' },
+  aggtranstype => '_float8', agginitval => '{0,0,0,0,0,0,0,0,0}' },

 # boolean-and and boolean-or
 { aggfnoid => 'bool_and', aggtransfn => 'booland_statefunc',
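
The patch above does not touch float8_regr_combine, which corr() still names as its combine function. As a purely hypothetical sketch (none of this is in the posted patch, and DiffState, diff_combine and corr_is_null are made-up names), two partial states could be merged by OR-ing their masks and additionally comparing their first values, since each partial state may be constant on its own while differing from the other. Only the fields relevant to constant detection are shown.

```c
#include <assert.h>
#include <math.h>
#include <stdbool.h>

/* Hypothetical reduced state: just what constant detection needs. */
typedef struct DiffState
{
    double N;               /* number of rows seen */
    double firstX, firstY;  /* first inputs seen by this partial state */
    int    diff;            /* bit 0: distinct X seen; bit 1: distinct Y seen */
} DiffState;

/* SQL-style inequality: NaN compares equal to NaN. */
static bool
sql_float8_ne(double a, double b)
{
    if (isnan(a))
        return !isnan(b);
    return isnan(b) || a != b;
}

/* Merge two partial states (assumed logic, not taken from the patch). */
static DiffState
diff_combine(DiffState a, DiffState b)
{
    DiffState r;

    assert(a.N >= 0.0 && b.N >= 0.0);

    /* An empty side contributes nothing, including no first values. */
    if (a.N == 0.0)
        return b;
    if (b.N == 0.0)
        return a;

    r.N = a.N + b.N;
    r.firstX = a.firstX;
    r.firstY = a.firstY;
    r.diff = a.diff | b.diff;

    /* Each side may be constant internally yet differ from the other. */
    if (sql_float8_ne(a.firstX, b.firstX))
        r.diff |= 1;
    if (sql_float8_ne(a.firstY, b.firstY))
        r.diff |= 2;

    return r;
}

/* The NULL test used by the patched float8_corr. */
static bool
corr_is_null(DiffState s)
{
    return s.diff != 3;
}
```

For example, merging a state that saw only X = 1, Y = 5 with one that saw only X = 1, Y = 7 sets the Y bit but not the X bit, so corr_is_null still reports true for the merged state.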
