Re: Some improvements to numeric sqrt() and ln() - Mailing list pgsql-hackers

From Tom Lane
Subject Re: Some improvements to numeric sqrt() and ln()
Date
Msg-id 4495.1584915409@sss.pgh.pa.us
Whole thread Raw
In response to Re: Some improvements to numeric sqrt() and ln()  (Tels <nospam-pg-abuse@bloodgate.com>)
Responses Re: Some improvements to numeric sqrt() and ln()
List pgsql-hackers
Tels <nospam-pg-abuse@bloodgate.com> writes:
> This can be reformulated as:
> +     *        If r < 0 Then
> +     *            Let r = r + s
> +     *            Let s = s - 1
> +     *            Let r = r + s

Here's a v3 that

* incorporates Tels' idea;

* improves some of the comments (IMO anyway, though some of the changes merely fix clear typos);

* adds some XXX comments about things that could be further improved
and/or need better explanations.

I also ran it through pgindent, just cause I'm like that.

Once the XXX items are resolved, I think this'd be committable.

            regards, tom lane

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 10229eb..afbc2b0 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif

 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
 static int    select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
                     NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+                        NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);

@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
              NumericVar *result, int rscale, bool round)
 {
     int            div_ndigits;
+    int            load_ndigits;
     int            res_sign;
     int            res_weight;
     int           *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
     div_ndigits += DIV_GUARD_DIGITS;
     if (div_ndigits < DIV_GUARD_DIGITS)
         div_ndigits = DIV_GUARD_DIGITS;
-    /* Must be at least var1ndigits, too, to simplify data-loading loop */
-    if (div_ndigits < var1ndigits)
-        div_ndigits = var1ndigits;

     /*
      * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
      * (approximate) quotient digit and stores it into div[], removing one
      * position of dividend space.  A final pass of carry propagation takes
      * care of any mistaken quotient digits.
+     *
+     * Note that div[] doesn't necessarily contain all of the digits from the
+     * dividend --- the desired precision plus guard digits might be less than
+     * the dividend's precision.  This happens, for example, in the square
+     * root algorithm, where we typically divide a 2N-digit number by an
+     * N-digit number, and only require a result with N digits of precision.
      */
     div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-    for (i = 0; i < var1ndigits; i++)
+    load_ndigits = Min(div_ndigits, var1ndigits);
+    for (i = 0; i < load_ndigits; i++)
         div[i + 1] = var1digits[i];

     /*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
             maxdiv += Abs(qdigit);
             if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
             {
-                /* Yes, do it */
+                /*
+                 * Yes, do it.  Note that if var2ndigits is much smaller than
+                 * div_ndigits, we can save a significant amount of effort
+                 * here by noting that we only need to normalise those div[]
+                 * entries that prior iterations touched when subtracting
+                 * multiples of the divisor.
+                 */
                 carry = 0;
-                for (i = div_ndigits; i > qi; i--)
+                for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
                 {
                     newdig = div[i] + carry;
                     if (newdig < 0)
@@ -8095,6 +8098,76 @@ mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)


 /*
+ * div_mod_var() -
+ *
+ *    Calculate the truncated integer quotient and numeric remainder of two
+ *    numeric variables.  The remainder is precise to var2's dscale.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+            NumericVar *quot, NumericVar *rem)
+{
+    NumericVar    q;
+    NumericVar    r;
+
+    init_var(&q);
+    init_var(&r);
+
+    /*
+     * Use div_var_fast() to get an initial estimate for the integer quotient.
+     * This might be inaccurate (per the warning in div_var_fast's comments),
+     * but we can correct it below.
+     */
+    div_var_fast(var1, var2, &q, 0, false);
+
+    /* Compute initial estimate of remainder using the quotient estimate. */
+    mul_var(var2, &q, &r, var2->dscale);
+    sub_var(var1, &r, &r);
+
+    /*
+     * Adjust the results if necessary --- the remainder should have the same
+     * sign as var1, and its absolute value should be less than the absolute
+     * value of var2.
+     */
+    while (r.ndigits != 0 && r.sign != var1->sign)
+    {
+        /* The absolute value of the quotient is too large */
+        if (var1->sign == var2->sign)
+        {
+            sub_var(&q, &const_one, &q);
+            add_var(&r, var2, &r);
+        }
+        else
+        {
+            add_var(&q, &const_one, &q);
+            sub_var(&r, var2, &r);
+        }
+    }
+
+    while (cmp_abs(&r, var2) >= 0)
+    {
+        /* The absolute value of the quotient is too small */
+        if (var1->sign == var2->sign)
+        {
+            add_var(&q, &const_one, &q);
+            sub_var(&r, var2, &r);
+        }
+        else
+        {
+            sub_var(&q, &const_one, &q);
+            add_var(&r, var2, &r);
+        }
+    }
+
+    set_var_from_var(&q, quot);
+    set_var_from_var(&r, rem);
+
+    free_var(&q);
+    free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *    Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8286,30 @@ gcd_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
 /*
  * sqrt_var() -
  *
- *    Compute the square root of x using Newton's algorithm
+ *    Compute the square root of x using the Karatsuba Square Root algorithm.
+ *    NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *    point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-    NumericVar    tmp_arg;
-    NumericVar    tmp_val;
-    NumericVar    last_val;
-    int            local_rscale;
     int            stat;
-
-    local_rscale = rscale + 8;
+    int            res_weight;
+    int            res_ndigits;
+    int            src_ndigits;
+    int            step;
+    int            ndigits[32];
+    int            blen;
+    int64        arg_int64;
+    int            src_idx;
+    int64        s_int64;
+    int64        r_int64;
+    NumericVar    s_var;
+    NumericVar    r_var;
+    NumericVar    a0_var;
+    NumericVar    a1_var;
+    NumericVar    q_var;
+    NumericVar    u_var;

     stat = cmp_var(arg, &const_zero);
     if (stat == 0)
@@ -8243,43 +8328,412 @@ sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
                 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
                  errmsg("cannot take square root of a negative number")));

-    init_var(&tmp_arg);
-    init_var(&tmp_val);
-    init_var(&last_val);
+    init_var(&s_var);
+    init_var(&r_var);
+    init_var(&a0_var);
+    init_var(&a1_var);
+    init_var(&q_var);
+    init_var(&u_var);

-    /* Copy arg in case it is the same var as result */
-    set_var_from_var(arg, &tmp_arg);
+    /*
+     * The result weight is half the input weight, rounded towards minus
+     * infinity.
+     *
+     * XXX do we really need floor(double) for that, rather than plain integer
+     * math?
+     */
+    res_weight = (int) floor((double) arg->weight / 2);

     /*
-     * Initialize the result to the first guess
+     * Number of NBASE digits to compute.  To ensure correct rounding, compute
+     * at least 1 extra decimal digit.  We explicitly allow rscale to be
+     * negative here, but must always compute at least 1 NBASE digit.
+     *
+     * XXX likewise seems like ceil(double) is unnecessary expense.
      */
-    alloc_var(result, 1);
-    result->digits[0] = tmp_arg.digits[0] / 2;
-    if (result->digits[0] == 0)
-        result->digits[0] = 1;
-    result->weight = tmp_arg.weight / 2;
-    result->sign = NUMERIC_POS;
+    res_ndigits = res_weight + 1 + (int) ceil((double) (rscale + 1) / DEC_DIGITS);
+    res_ndigits = Max(res_ndigits, 1);

-    set_var_from_var(result, &last_val);
+    /*
+     * Number of source NBASE digits logically required to produce a result
+     * with this precision --- every digit before the decimal point, plus 2
+     * for each result digit after the decimal point (or minus 2 for each
+     * result digit we round before the decimal point).
+     */
+    src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+    src_ndigits = Max(src_ndigits, 1);

-    for (;;)
+    /* ----------
+     * From this point on, we treat the input and the result as integers and
+     * compute the integer square root and remainder using the Karatsuba
+     * Square Root algorithm, which may be written recursively as follows:
+     *
+     *    SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+     *        [ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+     *          0 <= a0,a1,a2 < b and a3 >= b/4 ]
+     *        Let (s,r) = SqrtRem(a3*b + a2)
+     *        Let (q,u) = DivRem(r*b + a1, 2*s)
+     *        Let s = s*b + q
+     *        Let r = u*b + a0 - q^2
+     *        If r < 0 Then
+     *            Let r = r + s
+     *            Let s = s - 1
+     *            Let r = r + s
+     *        Return (s,r)
+     *
+     * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+     * RR-3805, November 1999.  At the time of writing this was available
+     * on the net at <https://hal.inria.fr/inria-00072854>.
+     *
+     * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
+     * "choose a base b such that n requires at least four base-b digits to
+     * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
+     * than b".  For optimal performance, b should have approximately a
+     * quarter the number of digits in the input, so that the outer square
+     * root computes roughly twice as many digits as the inner one.  For
+     * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
+     *
+     * We implement the algorithm iteratively rather than recursively, to
+     * allow the working variables to be reused.  With this approach, each
+     * digit of the input is read precisely once --- src_idx tracks the number
+     * of input digits used so far.
+     *
+     * The array ndigits[] holds the number of NBASE digits of the input that
+     * will have been used at the end of each iteration, which roughly doubles
+     * each time.  Note that the array elements are stored in reverse order,
+     * so if the final iteration requires src_ndigits = 37 input digits, the
+     * array will contain [37,19,11,7,5,3], and we would start by computing
+     * the square root of the 3 most significant NBASE digits.
+     *
+     * XXX I don't understand how this works.  Why is it correct to consider
+     * arg->digits[0] at every step?  Can we prove rigorously that the ndigits
+     * array won't be overrun?  (I can see that src_ndigits is roughly halved
+     * by each iteration, but only roughly, so it's not entirely clear that
+     * the worst-case situation couldn't involve more than 31 steps.)
+     * ----------
+     */
+    step = 0;
+    while ((ndigits[step] = src_ndigits) > 4)
     {
-        div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+        /* Choose b so that a3 >= b/4 */
+        blen = src_ndigits / 4;
+        if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+            blen--;

-        add_var(result, &tmp_val, result);
-        mul_var(result, &const_zero_point_five, result, local_rscale);
+        /* Number of digits in the next step (inner square root) */
+        src_ndigits -= 2 * blen;
+        step++;
+    }

-        if (cmp_var(&last_val, result) == 0)
-            break;
-        set_var_from_var(result, &last_val);
+    /*
+     * First iteration (innermost square root and remainder):
+     *
+     * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+     * has at most 9 decimal digits, so estimate it using double precision
+     * arithmetic, which will in fact almost certainly return the correct
+     * result with no further correction required.
+     */
+    arg_int64 = arg->digits[0];
+    for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+    {
+        arg_int64 *= NBASE;
+        if (src_idx < arg->ndigits)
+            arg_int64 += arg->digits[src_idx];
     }

-    free_var(&last_val);
-    free_var(&tmp_val);
-    free_var(&tmp_arg);
+    s_int64 = (int64) sqrt((double) arg_int64);
+    r_int64 = arg_int64 - s_int64 * s_int64;
+
+    /* Use Newton's method to correct the result, if necessary */
+    /* XXX is this guaranteed to converge?  integer division truncates... */
+    while (r_int64 < 0 || r_int64 > 2 * s_int64)
+    {
+        s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+        r_int64 = arg_int64 - s_int64 * s_int64;
+    }
+
+    /*
+     * Iterations with src_ndigits <= 8:
+     *
+     * The next 1 or 2 iterations compute larger (outer) square roots with
+     * src_ndigits <= 8, so the result still fits in an int64 (even though the
+     * input no longer does) and we can continue to compute using int64
+     * variables to avoid more expensive numeric computations.
+     *
+     * It is fairly easy to see that there is no risk of the intermediate
+     * values below overflowing 64-bit integers.  In the worst case, the
+     * previous iteration will have computed a 3-digit square root (of a
+     * 6-digit input less than NBASE^6 / 4), so at the start of this
+     * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+     * less than 10^12.  In this case, blen will be 1, so numer will be less
+     * than 10^17, and denom will be less than 10^12 (and hence u will also be
+     * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+     * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+     * in 64-bit integers.
+     */
+    step--;
+    while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+    {
+        int            b;
+        int            a0;
+        int            a1;
+        int            i;
+        int64        numer;
+        int64        denom;
+        int64        q;
+        int64        u;
+
+        blen = (src_ndigits - src_idx) / 2;
+
+        /* Extract a1 and a0, and compute b */
+        a0 = 0;
+        a1 = 0;
+        b = 1;
+
+        for (i = 0; i < blen; i++, src_idx++)
+        {
+            b *= NBASE;
+            a1 *= NBASE;
+            if (src_idx < arg->ndigits)
+                a1 += arg->digits[src_idx];
+        }
+
+        for (i = 0; i < blen; i++, src_idx++)
+        {
+            a0 *= NBASE;
+            if (src_idx < arg->ndigits)
+                a0 += arg->digits[src_idx];
+        }

-    /* Round to requested precision */
+        /* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+        numer = r_int64 * b + a1;
+        denom = 2 * s_int64;
+        q = numer / denom;
+        u = numer - q * denom;
+
+        /* Compute s = s*b + q and r = u*b + a0 - q^2 */
+        s_int64 = s_int64 * b + q;
+        r_int64 = u * b + a0 - q * q;
+
+        if (r_int64 < 0)
+        {
+            /* s is too large by 1; set r += s, s--, r += s */
+            r_int64 += s_int64;
+            s_int64--;
+            r_int64 += s_int64;
+        }
+
+        Assert(src_idx == src_ndigits); /* All input digits consumed */
+        step--;
+    }
+
+    /*
+     * On platforms with 128-bit integer support, we can further delay the
+     * need to use numeric variables.
+     */
+#ifdef HAVE_INT128
+    if (step >= 0)
+    {
+        int128        s_int128;
+        int128        r_int128;
+
+        s_int128 = s_int64;
+        r_int128 = r_int64;
+
+        /*
+         * Iterations with src_ndigits <= 16:
+         *
+         * The result fits in an int128 (even though the input doesn't) so we
+         * use int128 variables to avoid more expensive numeric computations.
+         */
+        while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+        {
+            int64        b;
+            int64        a0;
+            int64        a1;
+            int64        i;
+            int128        numer;
+            int128        denom;
+            int128        q;
+            int128        u;
+
+            blen = (src_ndigits - src_idx) / 2;
+
+            /* Extract a1 and a0, and compute b */
+            a0 = 0;
+            a1 = 0;
+            b = 1;
+
+            for (i = 0; i < blen; i++, src_idx++)
+            {
+                b *= NBASE;
+                a1 *= NBASE;
+                if (src_idx < arg->ndigits)
+                    a1 += arg->digits[src_idx];
+            }
+
+            for (i = 0; i < blen; i++, src_idx++)
+            {
+                a0 *= NBASE;
+                if (src_idx < arg->ndigits)
+                    a0 += arg->digits[src_idx];
+            }
+
+            /* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+            numer = r_int128 * b + a1;
+            denom = 2 * s_int128;
+            q = numer / denom;
+            u = numer - q * denom;
+
+            /* Compute s = s*b + q and r = u*b + a0 - q^2 */
+            s_int128 = s_int128 * b + q;
+            r_int128 = u * b + a0 - q * q;
+
+            if (r_int128 < 0)
+            {
+                /* s is too large by 1; set r += s, s--, r += s */
+                r_int128 += s_int128;
+                s_int128--;
+                r_int128 += s_int128;
+            }
+
+            Assert(src_idx == src_ndigits); /* All input digits consumed */
+            step--;
+        }
+
+        /*
+         * All remaining iterations require numeric variables.  Convert the
+         * integer values to NumericVar and continue.  Note that in the final
+         * iteration we don't need the remainder, so we can save a few cycles
+         * there by not fully computing it.
+         */
+        int128_to_numericvar(s_int128, &s_var);
+        if (step >= 0)
+            int128_to_numericvar(r_int128, &r_var);
+    }
+    else
+    {
+        int64_to_numericvar(s_int64, &s_var);
+        /* step < 0, so we certainly don't need r */
+    }
+#else                            /* !HAVE_INT128 */
+    int64_to_numericvar(s_int64, &s_var);
+    if (step >= 0)
+        int64_to_numericvar(r_int64, &r_var);
+#endif                            /* HAVE_INT128 */
+
+    /*
+     * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
+     * use numeric variables.
+     */
+    while (step >= 0)
+    {
+        int            tmp_len;
+
+        src_ndigits = ndigits[step];
+        blen = (src_ndigits - src_idx) / 2;
+
+        /* Extract a1 and a0 */
+        if (src_idx < arg->ndigits)
+        {
+            tmp_len = Min(blen, arg->ndigits - src_idx);
+            alloc_var(&a1_var, tmp_len);
+            memcpy(a1_var.digits, arg->digits + src_idx,
+                   tmp_len * sizeof(NumericDigit));
+            a1_var.weight = blen - 1;
+            a1_var.sign = NUMERIC_POS;
+            a1_var.dscale = 0;
+            strip_var(&a1_var);
+        }
+        else
+        {
+            zero_var(&a1_var);
+            a1_var.dscale = 0;
+        }
+        src_idx += blen;
+
+        if (src_idx < arg->ndigits)
+        {
+            tmp_len = Min(blen, arg->ndigits - src_idx);
+            alloc_var(&a0_var, tmp_len);
+            memcpy(a0_var.digits, arg->digits + src_idx,
+                   tmp_len * sizeof(NumericDigit));
+            a0_var.weight = blen - 1;
+            a0_var.sign = NUMERIC_POS;
+            a0_var.dscale = 0;
+            strip_var(&a0_var);
+        }
+        else
+        {
+            zero_var(&a0_var);
+            a0_var.dscale = 0;
+        }
+        src_idx += blen;
+
+        /* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+        set_var_from_var(&r_var, &q_var);
+        q_var.weight += blen;
+        add_var(&q_var, &a1_var, &q_var);
+        add_var(&s_var, &s_var, &u_var);
+        div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+        /* Compute s = s*b + q */
+        s_var.weight += blen;
+        add_var(&s_var, &q_var, &s_var);
+
+        /*
+         * Compute r = u*b + a0 - q^2.
+         *
+         * In the final iteration, we don't actually need r; we just need to
+         * know whether it is negative, so that we know whether to adjust s.
+         * So instead of the final subtraction we can just compare.
+         */
+        u_var.weight += blen;
+        add_var(&u_var, &a0_var, &u_var);
+        mul_var(&q_var, &q_var, &q_var, 0);
+
+        if (step > 0)
+        {
+            /* Need r for later iterations */
+            sub_var(&u_var, &q_var, &r_var);
+            if (r_var.sign == NUMERIC_NEG)
+            {
+                /* s is too large by 1; set r += s, s--, r += s */
+                add_var(&r_var, &s_var, &r_var);
+                sub_var(&s_var, &const_one, &s_var);
+                add_var(&r_var, &s_var, &r_var);
+            }
+        }
+        else
+        {
+            /* Don't need r anymore, except to test if s is too large by 1 */
+            if (cmp_var(&u_var, &q_var) < 0)
+                sub_var(&s_var, &const_one, &s_var);
+        }
+
+        Assert(src_idx == src_ndigits); /* All input digits consumed */
+        step--;
+    }
+
+    /*
+     * Construct the final result, rounding it to the requested precision.
+     */
+    set_var_from_var(&s_var, result);
+    result->weight = res_weight;
+    result->sign = NUMERIC_POS;
+
+    /* Round to target rscale (and set result->dscale) */
     round_var(result, rscale);
+
+    /* Strip leading and trailing zeroes */
+    strip_var(result);
+
+    free_var(&s_var);
+    free_var(&r_var);
+    free_var(&a0_var);
+    free_var(&a1_var);
+    free_var(&q_var);
+    free_var(&u_var);
 }


@@ -8530,12 +8984,18 @@ ln_var(const NumericVar *arg, NumericVar *result, int rscale)
      * Each sqrt() will roughly halve the weight of x, so adjust the local
      * rscale as we work so that we keep this many significant digits at each
      * step (plus a few more for good measure).
+     *
+     * Note that we allow local_rscale < 0 during this input reduction
+     * process, which implies rounding before the decimal point.  sqrt_var()
+     * explicitly supports this, and it significantly reduces the work
+     * required to reduce very large inputs to the required range.  Once the
+     * input reduction is complete, x.weight will be 0 and its display scale
+     * will be non-negative again.
      */
     nsqrt = 0;
     while (cmp_var(&x, &const_zero_point_nine) <= 0)
     {
         local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-        local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
         sqrt_var(&x, &x, local_rscale);
         mul_var(&fact, &const_two, &fact, 0);
         nsqrt++;
@@ -8543,7 +9003,6 @@ ln_var(const NumericVar *arg, NumericVar *result, int rscale)
     while (cmp_var(&x, &const_one_point_one) >= 0)
     {
         local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-        local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
         sqrt_var(&x, &x, local_rscale);
         mul_var(&fact, &const_two, &fact, 0);
         nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
index 23a4c6d..c7fe63d 100644
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 (1 row)

 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
index c5c8d76..41475a9 100644
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;

 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --


pgsql-hackers by date:

Previous
From: Andreas Karlsson
Date:
Subject: Re: [PATCH] Incremental sort (was: PoC: Partial sort)
Next
From: Thomas Munro
Date:
Subject: Re: Index Skip Scan