Re: Optimize numeric.c mul_var() using the Karatsuba algorithm - Mailing list pgsql-hackers

From Joel Jacobson
Subject Re: Optimize numeric.c mul_var() using the Karatsuba algorithm
Date
Msg-id f47879b7-6cbc-48e3-aa5b-6386b6293e84@app.fastmail.com
In response to Re: Optimize numeric.c mul_var() using the Karatsuba algorithm  (Dean Rasheed <dean.a.rasheed@gmail.com>)
Responses Re: Optimize numeric.c mul_var() using the Karatsuba algorithm  (Alvaro Herrera <alvherre@alvh.no-ip.org>)
List pgsql-hackers
On Sat, Jun 29, 2024, at 14:22, Dean Rasheed wrote:
> On Sun, Jun 23, 2024 at 09:00:29AM +0200, Joel Jacobson wrote:
>> Attached, rebased version of the patch that implements the Karatsuba algorithm in numeric.c's mul_var().
>>
>
> Something to watch out for is that not all callers of mul_var() want
> an exact result. Several internal callers request an approximate
> result by passing it an rscale value less than the sum of the input
> dscales. The schoolbook algorithm handles that by computing up to
> rscale plus MUL_GUARD_DIGITS extra digits and then rounding, whereas
> the new Karatsuba code always computes the full result and then
> rounds. That impacts the performance of various functions, for
> example:
>
> select sum(exp(x)) from generate_series(5999.9, 5950.0, -0.1) x;
>
> Time: 1790.825 ms (00:01.791)  [HEAD]
> Time: 2161.244 ms (00:02.161)  [with patch]

Oops. Thanks for spotting this and clarifying.

I read Tom's reply and note that we should only do Karatsuba
when an exact (full-precision) result is requested.
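
A minimal sketch of such a guard, assuming (as Dean describes) that a
result is approximate whenever rscale is less than the sum of the input
dscales. The function name is made up for illustration and is not in the patch:

```c
#include <assert.h>
#include <stdbool.h>

/*
 * Hypothetical helper: returns true when the caller asks for the exact
 * product, i.e. rscale is not smaller than the sum of the input dscales.
 * Only in that case would the Karatsuba path be considered.
 */
static bool
wants_exact_product(int var1_dscale, int var2_dscale, int rscale)
{
    assert(var1_dscale >= 0 && var2_dscale >= 0);

    return rscale >= var1_dscale + var2_dscale;
}
```

Callers that pass a reduced rscale would then keep using the schoolbook
code, which stops after rscale plus MUL_GUARD_DIGITS digits.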

> Looking at mul_var_karatsuba_half(), I don't really like the approach
> it takes. The whole correctness proof using the Karatsuba formula
> seems to somewhat miss the point that this function isn't actually
> implementing the Karatsuba algorithm, it is implementing the
> schoolbook algorithm in two steps, by splitting the longer input into
> two pieces.

The surprising realization here is that there are actually (var1ndigits, var2ndigits)
combinations where *only* doing mul_var_karatsuba_half() recursively
all the way down to schoolbook *is* a performance win,
even though we don't do any mul_var_karatsuba_full().

mul_var_karatsuba_half() *is* actually implementing the exact
Karatsuba formula; it's just taking a shortcut that exploits the fact,
known in advance, that splitting `var1` at `m2` would result in `high1`
being zero. This allows provably correct substitutions to be made,
which avoid the meaningless computations.
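
To illustrate the substitution, here is a toy sketch on machine integers
(not the patch code; the function names and the `split` parameter, which
plays the role of the base raised to the power m2, are made up for this
example). With high1 = 0, the z2 term vanishes and the middle term
collapses to low1 * high2:

```c
#include <assert.h>
#include <stdint.h>

/* One level of the textbook Karatsuba formula, splitting both inputs. */
static uint64_t
karatsuba_full_step(uint64_t a, uint64_t b, uint64_t split)
{
    uint64_t    high1 = a / split;
    uint64_t    low1 = a % split;
    uint64_t    high2 = b / split;
    uint64_t    low2 = b % split;
    uint64_t    z0 = low1 * low2;
    uint64_t    z2 = high1 * high2;
    uint64_t    z1 = (low1 + high1) * (low2 + high2) - z2 - z0;

    return z2 * split * split + z1 * split + z0;
}

/*
 * Same formula when a < split, so high1 == 0 and low1 == a:
 *   z2 = 0
 *   z1 = a * (low2 + high2) - 0 - a * low2 = a * high2
 * which leaves only two multiplications.
 */
static uint64_t
karatsuba_half_step(uint64_t a, uint64_t b, uint64_t split)
{
    uint64_t    high2 = b / split;
    uint64_t    low2 = b % split;

    assert(a < split);

    return (a * high2) * split + a * low2;
}
```

Both functions return the same product whenever a < split; the inputs must
of course be small enough not to overflow 64 bits in this toy version.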

> But why just split it into two pieces? That will just lead
> to a lot of unnecessary recursion for very unbalanced inputs. Instead,
> why not split the longer input into N roughly equal sized pieces, each
> around the same length as the shorter input, multiplying and adding
> them at the appropriate offsets?

The approach you're describing is implemented by e.g. CPython
and is called "lopsided" in their code base. It has some different
performance characteristics, compared to the recursive Half-Karatsuba
approach.

What I didn't like about lopsided is the degenerate case where the
last chunk is much shorter than var1. For example, if we pretend
we would be doing Karatsuba all the way down to ndigits 2,
and consider var1ndigits = 3 and var2ndigits = 10,
then lopsided would do
var1ndigits=3 var2ndigits=3
var1ndigits=3 var2ndigits=3
var1ndigits=3 var2ndigits=3
var1ndigits=3 var2ndigits=1

whereas Half-Karatsuba would do
var1ndigits=3 var2ndigits=5
var1ndigits=3 var2ndigits=5
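
The two chunking schemes can be sketched like this (illustrative helpers
only, with made-up names; the half split shows a single level and does not
model any further recursion or the threshold condition):

```c
#include <assert.h>

/*
 * Lopsided: cut var2 into pieces of at most var1ndigits digits.
 * Stores the piece lengths in chunks[] and returns how many there are.
 */
static int
lopsided_chunks(int var1ndigits, int var2ndigits, int *chunks, int maxchunks)
{
    int         n = 0;
    int         remaining = var2ndigits;

    assert(var1ndigits > 0);

    while (remaining > 0 && n < maxchunks)
    {
        int         length = remaining < var1ndigits ? remaining : var1ndigits;

        chunks[n++] = length;
        remaining -= length;
    }
    return n;
}

/* Half split: cut var2 into two pieces of (nearly) equal length. */
static void
half_chunks(int var2ndigits, int *low, int *high)
{
    assert(var2ndigits > 1);

    *low = var2ndigits / 2;
    *high = var2ndigits - *low;
}
```

For var1ndigits = 3 and var2ndigits = 10, the first helper yields the
lengths 3, 3, 3, 1 and the second yields 5 and 5, matching the lists above.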

You can of course find contrary examples too, where lopsided
is better than Half-Karatsuba; neither of them seems substantially better
than the other.

My measurements indicated that Half-Karatsuba was the marginal overall
winner on the architectures I tested, but the two were very similar,
i.e. almost the same number of "wins" and "losses"
for different (var1ndigits, var2ndigits) combinations.

Note that even with lopsided, there will still be recursion, since Karatsuba
is a recursive algorithm, so to satisfy Tom Lane's request about
proving the recursion is limited, we will still need to prove the same
thing for lopsided+Karatsuba.
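
As a rough idea of what such a bound looks like, here is a simplified
sketch. It assumes a single hypothetical cutoff below which we fall back to
schoolbook, and that each level recurses on inputs of about half the length;
it ignores the real two-dimensional KARATSUBA_CONDITION and any extra carry
digit in the sum terms, so it is not a proof for the patch:

```c
#include <assert.h>

/*
 * Count how many times ndigits can be halved (rounding up) before it
 * drops below base_limit. Under the simplifying assumptions above, this
 * bounds the recursion depth of a balanced divide-and-conquer multiply.
 */
static int
halving_depth(int ndigits, int base_limit)
{
    int         depth = 0;

    assert(base_limit > 1);

    while (ndigits >= base_limit)
    {
        ndigits = (ndigits + 1) / 2;
        depth++;
    }
    return depth;
}
```

With 32768 digits and a cutoff of 384, this gives a depth of 7, since the
length goes 16384, 8192, 4096, 2048, 1024, 512, 256.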

Here is some old code from my experiments, if we want to evaluate lopsided:

```
static void slice_var(const NumericVar *var, int start, int length,
                      NumericVar *slice);

static void mul_var_lopsided(const NumericVar *var1, const NumericVar *var2,
                             NumericVar *result);

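/*
 * slice_var() -
 *
 * Extract a slice of a NumericVar starting at a specified position
 * and with a specified length.
 *
 * The slice points into var's digit array (buf is NULL, so the slice does
 * not own any memory). Its weight and dscale are set as if it were the
 * least significant slice of var; the caller adjusts the weight of the
 * product for slices taken from other positions.
 */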
static void
slice_var(const NumericVar *var, int start, int length,
          NumericVar *slice)
{
    Assert(start >= 0);
    Assert(start + length <= var->ndigits);

    init_var(slice);

    slice->ndigits = length;
    slice->digits = var->digits + start;
    slice->buf = NULL;
    slice->weight = var->weight - var->ndigits + length;
    slice->sign = var->sign;
    slice->dscale = (var->ndigits - var->weight - 1) * DEC_DIGITS;
}

/*
 * mul_var_lopsided() -
 *
 * Lopsided Multiplication for unequal-length factors.
 *
 * This function handles the case where var1 has significantly fewer digits
 * than var2. In such a scenario, splitting var1 for a balanced multiplication
 * algorithm would be inefficient, as the high part would be zero.
 *
 * To overcome this inefficiency, the function divides var2 into a series of
 * slices, each containing at most as many digits as var1 (the final, most
 * significant slice may be shorter), and multiplies var1 with each slice one
 * at a time. As a result, the recursive calls to mul_var() will have mostly
 * balanced inputs, which improves the performance of divide-and-conquer
 * algorithms such as Karatsuba.
 */
static void
mul_var_lopsided(const NumericVar *var1, const NumericVar *var2,
                 NumericVar *result)
{
    int            var1ndigits = var1->ndigits;
    int            var2ndigits = var2->ndigits;
    int            processed = 0;
    int            remaining = var2ndigits;
    int            length;
    NumericVar    slice;
    NumericVar    product;
    NumericVar    sum;

    Assert(var1ndigits <= var2ndigits);
    Assert(var1ndigits > MUL_SMALL);
    Assert(var1ndigits * 2 <= var2ndigits);

    init_var(&slice);
    init_var(&product);
    init_var(&sum);

    while (remaining > 0)
    {
        length = Min(remaining, var1ndigits);
        slice_var(var2, var2ndigits - processed - length, length, &slice);
        mul_var(var1, &slice, &product, var1->dscale + slice.dscale);
        product.weight += processed;
        add_var(&sum, &product, &sum);
        remaining -= length;
        processed += length;
    }

    set_var_from_var(&sum, result);

    free_var(&slice);
    free_var(&product);
    free_var(&sum);
}
```
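
For anyone who wants to play with the slice-multiply-shift-add idea outside
of numeric.c, here is a self-contained toy version. It is only a sketch
under stated assumptions: digits are stored least significant first (unlike
NumericVar), there is no sign, weight or dscale handling, each slice is
multiplied with plain schoolbook, and all names are made up for the example:

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define NBASE 10000

/*
 * result += a * b, where all arrays hold base-NBASE digits, least
 * significant first. The caller must make sure result has room for the
 * accumulated value.
 */
static void
schoolbook_accumulate(const int *a, int alen, const int *b, int blen,
                      int *result)
{
    for (int i = 0; i < alen; i++)
    {
        int64_t     carry = 0;
        int         j;

        for (j = 0; j < blen; j++)
        {
            int64_t     t = (int64_t) a[i] * b[j] + result[i + j] + carry;

            result[i + j] = (int) (t % NBASE);
            carry = t / NBASE;
        }
        /* propagate any remaining carry into the higher digits */
        for (int k = i + j; carry != 0; k++)
        {
            int64_t     t = result[k] + carry;

            result[k] = (int) (t % NBASE);
            carry = t / NBASE;
        }
    }
}

/*
 * Multiply var1 by var2 by cutting var2 into slices of at most var1len
 * digits and accumulating each partial product at its digit offset.
 * result must have room for var1len + var2len digits.
 */
static void
lopsided_multiply(const int *var1, int var1len,
                  const int *var2, int var2len, int *result)
{
    assert(var1len > 0 && var1len <= var2len);

    memset(result, 0, sizeof(int) * (var1len + var2len));

    for (int processed = 0; processed < var2len; processed += var1len)
    {
        int         remaining = var2len - processed;
        int         length = remaining < var1len ? remaining : var1len;

        schoolbook_accumulate(var1, var1len, var2 + processed, length,
                              result + processed);
    }
}
```

Adding `processed` to the result pointer here plays the same role as
`product.weight += processed` in mul_var_lopsided() above.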

> As an example, given inputs with
> var1ndigits = 1000 and var2ndigits = 10000, mul_var() will invoke
> mul_var_karatsuba_half(), which will then recursively invoke mul_var()
> twice with var1ndigits = 1000 and var2ndigits = 5000, which no longer
> satisfies KARATSUBA_CONDITION(), so it will just invoke the schoolbook
> algorithm on each half, which stands no chance of being any faster. On
> the other hand, if it divided var2 into 10 chunks of length 1000, it
> would invoke the Karatsuba algorithm on each chunk, which would at
> least stand a chance of being faster.

Interesting example!

Indeed only mul_var_karatsuba_half() will be called with the inputs:
var1ndigits=1000 var2ndigits=10000
var1ndigits=1000 var2ndigits=5000
It will never call mul_var_karatsuba_full().

Surprisingly, this still gives a 13% speed-up on an Intel Core i9-14900K.

This performance gain comes from the splitting of the larger factor.

Here is how I benchmarked using pg-timeit [1] and the
mul_var-karatsuba-benchmark.patch [2] from my original post:

To test the patch, you have to edit pg_proc.dat for
numeric_mul_karatsuba and give it a new unique oid.

```
SELECT
   timeit.pretty_time(total_time_a / 1e6 / executions,3) AS execution_time_a,
   timeit.pretty_time(total_time_b / 1e6 / executions,3) AS execution_time_b,
   total_time_a::numeric/total_time_b - 1 AS execution_time_difference
FROM timeit.cmp(
   'numeric_mul',
   'numeric_mul_karatsuba',
   input_values := ARRAY[
      random_ndigits(1000)::TEXT,
      random_ndigits(10000)::TEXT
   ],
   min_time := 1000000,
   timeout := '10 s'
);
-[ RECORD 1 ]-------------+-------------------
execution_time_a          | 976 µs
execution_time_b          | 864 µs
execution_time_difference | 0.1294294200936600
```

The KARATSUBA_CONDITION tries to capture the interesting region of
performance gains, as observed from measurements [3], in an expression
that is not too complex.

In the image [3], the purple to black area is performance regressions,
and the rainbow colors are performance gains.

The black colored line segment is the KARATSUBA_CONDITION,
which tries to capture the three performance gain regions,
as defined by:
KARATSUBA_LOW_RANGE_CONDITION
KARATSUBA_MIDDLE_RANGE_CONDITION
KARATSUBA_HIGH_RANGE_CONDITION

> Related to that, KARATSUBA_HIGH_RANGE_CONDITION() doesn't appear to
> make a lot of sense. For inputs with var1ndigits between 128 and 2000,
> and var2ndigits > 9000, this condition will pass and it will
> recursively break up the longer input into smaller and smaller pieces
> until eventually that condition no longer passes, but none of the
> other conditions in KARATSUBA_CONDITION() will pass either, so it'll
> just invoke the schoolbook algorithm on each piece, which is bound to
> be slower once all the overheads are taken into account. For example,
> given var1ndigits = 200 and var2ndigits = 30000, KARATSUBA_CONDITION()
> will pass due to KARATSUBA_HIGH_RANGE_CONDITION(), and it will recurse
> with var1ndigits = 200 and var2ndigits = 15000, and then again with
> var1ndigits = 200 and var2ndigits = 7500, at which point
> KARATSUBA_CONDITION() no longer passes. With mul_var_karatsuba_half()
> implemented as it is, that is bound to happen, because each half will
> end up having var2ndigits between 4500 and 9000, which fails
> KARATSUBA_CONDITION() if var1ndigits < 2000. If
> mul_var_karatsuba_half() was replaced by something that recursed with
> more balanced chunks, then it might make more sense, though allowing
> values of var1ndigits down to 128 doesn't make sense, since the
> Karatsuba algorithm will never be invoked for inputs shorter than 384.

As explained above, mul_var_karatsuba_half() is not meaningless,
even for cases where we never reach mul_var_karatsuba_full().

Regarding the last comment on 128 and 384, I think you're reading the
conditions wrong. Note that 128 is for var1ndigits while 384 is for var2ndigits:

+#define KARATSUBA_BASE_LIMIT 384
+#define KARATSUBA_VAR1_MIN1 128
...
+    ((var2ndigits) >= KARATSUBA_BASE_LIMIT && \
...
+     (var1ndigits) > KARATSUBA_VAR1_MIN1)

> Looking at KARATSUBA_MIDDLE_RANGE_CONDITION(), the test that
> var2ndigits > 2500 seems to be redundant. If var1ndigits > 2000 and
> var2ndigits < 2500, then KARATSUBA_LOW_RANGE_CONDITION() is satisfied,
> so these tests could be simplified, eliminating some of those magic
> constants.

Yes, I realized that myself too, but chose to keep the start and end
boundaries for each of the three ranges. Since it's just boolean logic
with constants, I think the compiler should be smart enough to optimize
away the redundancy. Maybe it would be better to keep the redundant
condition as a comment instead of actual code; I have no strong opinion
on what's best.

> However, I really don't like having these magic constants at all,
> because in practice the threshold above which the Karatsuba algorithm
> is a win can vary depending on a number of factors, such as whether
> it's running on 32-bit or 64-bit, whether or not SIMD instructions are
> available, the relative timings of CPU instructions, the compiler
> options used, and probably a bunch of other things. The last time I
> looked at the Java source code, for example, they had separate
> thresholds for 32-bit and 64-bit platforms, and even that's probably
> too crude. Some numeric libraries tune the thresholds for a large
> number of different platforms, but that takes a lot of effort. I think
> a better approach would be to have a configurable threshold. Ideally,
> this would be just one number, with all other numbers being derived
> from it, possibly using some simple heuristic to reduce the effective
> threshold for more balanced inputs, for which the Karatsuba algorithm
> is more efficient.
>
> Having a configurable threshold would allow people to tune for best
> performance on their own platforms, and also it would make it easier
> to write tests that hit the new code. As it stands, it's not obvious
> how much of the new code is being hit by the existing tests.
>
> Doing a quick test on my machine, using random equal-length inputs of
> various sizes, I got the following performance results:
>
>  digits | rate (HEAD)   | rate (patch)  | change
> --------+---------------+---------------+--------
>      10 | 6.060014e+06  | 6.0189365e+06 | -0.7%
>     100 | 2.7038752e+06 | 2.7287925e+06 | +0.9%
>    1000 | 88640.37      | 90504.82      | +2.1%
>    1500 | 39885.23      | 41041.504     | +2.9%
>    1600 | 36355.24      | 33368.28      | -8.2%
>    2000 | 23308.582     | 23105.932     | -0.9%
>    3000 | 10765.185     | 11360.11      | +5.5%
>    4000 | 6118.2554     | 6645.4116     | +8.6%
>    5000 | 3928.4985     | 4639.914      | +18.1%
>   10000 | 1003.80164    | 1431.9335     | +42.7%
>   20000 | 255.46135     | 456.23462     | +78.6%
>   30000 | 110.69313     | 226.53398     | +104.7%
>   40000 | 62.29333      | 148.12916     | +137.8%
>   50000 | 39.867493     | 95.16788      | +138.7%
>   60000 | 27.7672       | 74.01282      | +166.5%
>
> The Karatsuba algorithm kicks in at 384*4 = 1536 decimal digits, so
> presumably the variations below that are just noise, but this does
> seem to suggest that KARATSUBA_BASE_LIMIT = 384 is too low for me, and
> I'd probably want it to be something like 500-700.

I've tried hard to reduce the magic part of these constants,
by benchmarking on numerous architectures, and picking them
manually by making a balanced judgement about what complexity
could possibly be acceptable for the threshold function,
and which performance gains are important to try to capture.

I think this approach is actually less magical than the hard-coded
single value constants I've seen in many other numeric libraries,
where it's not clear at all what the full two dimensional performance
image looks like.

I considered whether my initial post on this should propose a patch with a
simple threshold function, that just checks if var1ndigits is larger
than some constant, like many other numeric libraries do.
However, I decided I should at least try to do something smarter,
since it seemed possible.

> There's another complication though (if the threshold is made
> configurable): the various numeric functions that use mul_var() are
> immutable, which means that the results from the Karatsuba algorithm
> must match those from the schoolbook algorithm exactly, for all
> inputs. That's not currently the case when computing approximate
> results with a reduced rscale. That's fixable, but it's not obvious
> whether or not the Karatsuba algorithm can actually be made beneficial
> when computing such approximate results.

I read Tom's reply on this part, and understand we can only do Karatsuba
if full rscale is desired.

> There's a wider question as to how many people use such big numeric
> values -- i.e., how many people are actually going to benefit from
> this? I don't have a good feel for that.

Personally, I started working on this because I wanted to use numeric
to implement the ECDSA verify algorithm in PL/pgSQL, to avoid a
dependency on a C extension that I discovered could segfault.

Unfortunately, Karatsuba didn't help much for this particular case,
since the ECDSA factors are not big enough.
I ended up implementing ECDSA verify as a pgrx extension instead.

However, I can imagine other crypto algorithms might require larger factors,
as well as other scientific research use cases, such as astronomy and physics,
that could desire storage of numeric values of very high precision.

Toom-3 is probably overkill, since we "only" support up to 32768 base digits,
but I think we should at least consider optimizing for the range of numeric values
that are supported by numeric.

Regards,
Joel

[1] https://github.com/joelonsql/pg-timeit
[2] https://www.postgresql.org/message-id/attachment/159528/mul_var-karatsuba-benchmark.patch
[3] https://gist.githubusercontent.com/joelonsql/e9d06cdbcdf56cd8ffa673f499880b0d/raw/69df06e95bc254090f8397765079e1a8145eb5ac/derive_threshold_function_using_dynamic_programming.png


