diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c index 5510a203b0..c40b062c9a 100644 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -459,6 +459,30 @@ static const NumericVar const_ninf = static const int round_powers[4] = {0, 1000, 100, 10}; #endif +#define KARATSUBA_BASE_LIMIT 384 +#define KARATSUBA_VAR1_MIN1 128 +#define KARATSUBA_VAR1_MIN2 2000 +#define KARATSUBA_VAR2_MIN1 2500 +#define KARATSUBA_VAR2_MIN2 9000 +#define KARATSUBA_SLOPE 0.764 +#define KARATSUBA_INTERCEPT 90.737 + +#define KARATSUBA_LOW_RANGE_CONDITION(var1ndigits, var2ndigits) \ + ((var1ndigits) > (KARATSUBA_SLOPE) * (var2ndigits) + KARATSUBA_INTERCEPT) + +#define KARATSUBA_MIDDLE_RANGE_CONDITION(var1ndigits, var2ndigits) \ + ((var2ndigits) > KARATSUBA_VAR2_MIN1 && \ + (var1ndigits) > KARATSUBA_VAR1_MIN2) + +#define KARATSUBA_HIGH_RANGE_CONDITION(var1ndigits, var2ndigits) \ + ((var2ndigits) > KARATSUBA_VAR2_MIN2 && \ + (var1ndigits) > KARATSUBA_VAR1_MIN1) + +#define KARATSUBA_CONDITION(var1ndigits, var2ndigits) \ + ((var2ndigits) >= KARATSUBA_BASE_LIMIT && \ + (KARATSUBA_LOW_RANGE_CONDITION(var1ndigits, var2ndigits) || \ + KARATSUBA_MIDDLE_RANGE_CONDITION(var1ndigits, var2ndigits) || \ + KARATSUBA_HIGH_RANGE_CONDITION(var1ndigits, var2ndigits))) /* ---------- * Local functions @@ -551,6 +575,17 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2, static void mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale); +static void mul_var_karatsuba(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, + int rscale); +inline static void split_var_at(const NumericVar *var, int split_point, + NumericVar *low, NumericVar *high); +static void mul_var_karatsuba_full(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, + int rscale); +static void mul_var_karatsuba_half(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, + int rscale); static void div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale, bool round); @@ -3115,6 +3150,130 @@ numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error) } +/* + * numeric_mul_karatsuba() - + * + * This function multiplies two numeric values using the Karatsuba algorithm, + * designed for efficient handling of large numbers. It's introduced to allow + * direct benchmark comparisons with the standard numeric_mul() function. + */ +Datum +numeric_mul_karatsuba(PG_FUNCTION_ARGS) +{ + Numeric num1 = PG_GETARG_NUMERIC(0); + Numeric num2 = PG_GETARG_NUMERIC(1); + Numeric res; + + res = numeric_mul_karatsuba_opt_error(num1, num2, NULL); + + PG_RETURN_NUMERIC(res); +} + + +/* + * numeric_mul_karatsuba_opt_error() - + * + * Internal version of numeric_mul_karatsuba(). + * If "*have_error" flag is provided, on error it's set to true, NULL returned. + * This is helpful when caller need to handle errors by itself. + */ +Numeric +numeric_mul_karatsuba_opt_error(Numeric num1, Numeric num2, bool *have_error) +{ + NumericVar arg1; + NumericVar arg2; + NumericVar result; + Numeric res; + + /* + * Handle NaN and infinities + */ + if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2)) + { + if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2)) + return make_result(&const_nan); + if (NUMERIC_IS_PINF(num1)) + { + switch (numeric_sign_internal(num2)) + { + case 0: + return make_result(&const_nan); /* Inf * 0 */ + case 1: + return make_result(&const_pinf); + case -1: + return make_result(&const_ninf); + } + Assert(false); + } + if (NUMERIC_IS_NINF(num1)) + { + switch (numeric_sign_internal(num2)) + { + case 0: + return make_result(&const_nan); /* -Inf * 0 */ + case 1: + return make_result(&const_ninf); + case -1: + return make_result(&const_pinf); + } + Assert(false); + } + /* by here, num1 must be finite, so num2 is not */ + if (NUMERIC_IS_PINF(num2)) + { + switch (numeric_sign_internal(num1)) + { + case 0: + return make_result(&const_nan); /* 0 * Inf */ + case 1: + return make_result(&const_pinf); + case -1: + return make_result(&const_ninf); + } + Assert(false); + } + Assert(NUMERIC_IS_NINF(num2)); + switch (numeric_sign_internal(num1)) + { + case 0: + return make_result(&const_nan); /* 0 * -Inf */ + case 1: + return make_result(&const_ninf); + case -1: + return make_result(&const_pinf); + } + Assert(false); + } + + /* + * Unpack the values, let mul_var() compute the result and return it. + * Unlike add_var() and sub_var(), mul_var() will round its result. In the + * case of numeric_mul(), which is invoked for the * operator on numerics, + * we request exact representation for the product (rscale = sum(dscale of + * arg1, dscale of arg2)). If the exact result has more digits after the + * decimal point than can be stored in a numeric, we round it. Rounding + * after computing the exact result ensures that the final result is + * correctly rounded (rounding in mul_var() using a truncated product + * would not guarantee this). + */ + init_var_from_num(num1, &arg1); + init_var_from_num(num2, &arg2); + + init_var(&result); + + mul_var_karatsuba(&arg1, &arg2, &result, arg1.dscale + arg2.dscale); + + if (result.dscale > NUMERIC_DSCALE_MAX) + round_var(&result, NUMERIC_DSCALE_MAX); + + res = make_result_opt_error(&result, have_error); + + free_var(&result); + + return res; +} + + /* * numeric_div() - * @@ -8659,6 +8818,37 @@ sub_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result) } +/* + * split_var_at() - + * + * Split a NumericVar into two parts at a specified position. + */ +inline static void +split_var_at(const NumericVar *var, int split_point, + NumericVar *low, NumericVar *high) +{ + int high_ndigits = var->ndigits - split_point; + int low_ndigits = split_point; + + init_var(high); + init_var(low); + + high->ndigits = high_ndigits; + high->digits = var->digits; + high->buf = NULL; + high->weight = var->weight - low_ndigits; + high->sign = var->sign; + high->dscale = (var->ndigits - var->weight - 1) * DEC_DIGITS; + + low->ndigits = low_ndigits; + low->digits = var->digits + high_ndigits; + low->buf = NULL; + low->weight = var->weight - high_ndigits; + low->sign = var->sign; + low->dscale = var->dscale; +} + + /* * mul_var() - * @@ -8865,6 +9055,411 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, } +/* + * mul_var_karatsuba_full() - + * + * Multiplication using the Karatsuba algorithm. + * + * The algorithm normally starts by checking if any of the inputs + * are smaller than the NBASE, the base case for the recursion, + * and if so, fall back to traditional multiplication. + * + * That part is handled by the caller in our code, so when this function + * is called, we know that var1 and var2 are large enough for Karatsuba + * to be used. We also know that var1 is shorter or of equal length as var2, + * which has been arranged by the caller by swapping them if necessary. + * + * The algorithm then proceeds by splitting var1 and var2 into + * two high and low parts, at half the length of the longer input: + * + * m = max(size_NBASE(var1), size_NBASE(var2)) + * m2 = floor(m / 2) + * + * high1, low1 = split_var_at(var1, m2) + * high2, low2 = split_var_at(var2, m2) + * + * z0 = (low1 * low2) + * z1 = ((low1 + high1) * (low2 + high2)) + * z2 = (high1 * high2) + * + * return (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0 + */ +static void +mul_var_karatsuba_full(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, int rscale) +{ + NumericVar high1, low1; + NumericVar high2, low2; + NumericVar z0, z1, z2; + NumericVar temp1, temp2; + int m2 = var2->ndigits / 2; + + init_var(&low1); + init_var(&low2); + init_var(&high1); + init_var(&high2); + init_var(&z0); + init_var(&z1); + init_var(&z2); + init_var(&temp1); + init_var(&temp2); + + split_var_at(var1, m2, &low1, &high1); + split_var_at(var2, m2, &low2, &high2); + + mul_var_karatsuba(&low1, &low2, &z0, low1.dscale + low2.dscale); + + add_var(&low1, &high1, &temp1); + add_var(&low2, &high2, &temp2); + mul_var_karatsuba(&temp1, &temp2, &z1, temp1.dscale + temp2.dscale); + + mul_var_karatsuba(&high1, &high2, &z2, high1.dscale + high2.dscale); + + set_var_from_var(&z2, &temp1); + temp1.weight += m2 * 2; + + sub_var(&z1, &z2, &z1); + sub_var(&z1, &z0, &temp2); + temp2.weight += m2; + + add_var(&temp1, &temp2, &temp2); + add_var(&temp2, &z0, result); + + free_var(&low1); + free_var(&low2); + free_var(&high1); + free_var(&high2); + free_var(&z0); + free_var(&z1); + free_var(&z2); + free_var(&temp1); + free_var(&temp2); + + /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); + + return; +} + + +/* + * mul_var_karatsuba_half() - + * + * Karatsuba Multiplication for factors with significant length disparity. + * + * The Half-Karatsuba Multiplication Algorithm is a specialized case of + * the normal Karatsuba multiplication algorithm, designed for the scenario + * where var2 has at least twice as many base digits as var1. + * + * In this case var2 (the longer input) is split into high2 and low1, + * at m2 (half the length of var2) and var1 (the shorter input), + * is used directly without splitting. + * + * The algorithm then proceeds as follows: + * + * 1. Compute the product z0 = var1 * low2. + * 2. Compute the product temp2 = var1 * high2. + * 3. Adjust the weight of temp2 by adding m2 (* NBASE ^ m2) + * 4. Add temp2 and z0 to obtain the final result. + * + * Proof: + * + * The algorithm can be derived from the original Karatsuba algorithm by + * simplifying the formula when the shorter factor var1 is not split into + * high and low parts, as shown below. + * + * Original Karatsuba formula: + * + * result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0 + * + * Substitutions: + * + * low1 = var1 + * high1 = 0 + * + * Applying substitutions: + * + * z0 = (low1 * low2) + * = (var1 * low2) + * + * z1 = ((low1 + high1) * (low2 + high2)) + * = ((var1 + 0) * (low2 + high2)) + * = (var1 * low2) + (var1 * high2) + * + * z2 = (high1 * high2) + * = (0 * high2) + * = 0 + * + * Simplified using the above substitutions: + * + * result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0 + * = (0 * NBASE ^ (m2 × 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0 + * = ((z1 - z0) * NBASE ^ m2) + z0 + * = ((z1 - z0) * NBASE ^ m2) + z0 + * = (var1 * high2) * NBASE ^ m2 + z0 + */ +static void +mul_var_karatsuba_half(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, int rscale) +{ + NumericVar high2, low2; + NumericVar z0; + NumericVar temp2; + int m2 = var2->ndigits / 2; + + init_var(&high2); + init_var(&low2); + init_var(&z0); + init_var(&temp2); + + split_var_at(var2, m2, &low2, &high2); + + mul_var_karatsuba(var1, &low2, &z0, var1->dscale + low2.dscale); + mul_var_karatsuba(var1, &high2, &temp2, var1->dscale + high2.dscale); + temp2.weight += m2; + add_var(&temp2, &z0, result); + + free_var(&high2); + free_var(&low2); + free_var(&z0); + free_var(&temp2); + + /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); + + return; +} + + +/* + * mul_var_karatsuba() - + * + * Implements Karatsuba multiplication for large numbers, introduced + * alongside the unchanged original mul_var(). This function is part of + * an optimization effort, allowing direct benchmark comparisons with + * mul_var(). It selects full or half Karatsuba based on input size. + * This is a temporary measure before considering its replacement of + * mul_var() based on benchmark outcomes. + */ +static void +mul_var_karatsuba(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, int rscale) +{ + int res_ndigits; + int res_sign; + int res_weight; + int maxdigits; + int *dig; + int carry; + int maxdig; + int newdig; + int var1ndigits; + int var2ndigits; + NumericDigit *var1digits; + NumericDigit *var2digits; + NumericDigit *res_digits; + int i, + i1, + i2; + + /* + * Arrange for var1 to be the shorter of the two numbers. This improves + * performance because the inner multiplication loop is much simpler than + * the outer loop, so it's better to have a smaller number of iterations + * of the outer loop. This also reduces the number of times that the + * accumulator array needs to be normalized. + */ + if (var1->ndigits > var2->ndigits) + { + const NumericVar *tmp = var1; + + var1 = var2; + var2 = tmp; + } + + /* copy these values into local vars for speed in inner loop */ + var1ndigits = var1->ndigits; + var2ndigits = var2->ndigits; + var1digits = var1->digits; + var2digits = var2->digits; + + if (var1ndigits == 0 || var2ndigits == 0) + { + /* one or both inputs is zero; so is result */ + zero_var(result); + result->dscale = rscale; + return; + } + + /* Determine result sign and (maximum possible) weight */ + if (var1->sign == var2->sign) + res_sign = NUMERIC_POS; + else + res_sign = NUMERIC_NEG; + res_weight = var1->weight + var2->weight + 2; + + /* + * Determine the number of result digits to compute. If the exact result + * would have more than rscale fractional digits, truncate the computation + * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that + * would only contribute to the right of that. (This will give the exact + * rounded-to-rscale answer unless carries out of the ignored positions + * would have propagated through more than MUL_GUARD_DIGITS digits.) + * + * Note: an exact computation could not produce more than var1ndigits + + * var2ndigits digits, but we allocate one extra output digit in case + * rscale-driven rounding produces a carry out of the highest exact digit. + */ + res_ndigits = var1ndigits + var2ndigits + 1; + maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS + + MUL_GUARD_DIGITS; + res_ndigits = Min(res_ndigits, maxdigits); + + if (res_ndigits < 3) + { + /* All input digits will be ignored; so result is zero */ + zero_var(result); + result->dscale = rscale; + return; + } + + /* + * Use the Karatsuba algorithm for sufficiently large factors. + */ + if (KARATSUBA_CONDITION(var1ndigits, var2ndigits)) + { + if (var1ndigits * 2 > var2ndigits) + mul_var_karatsuba_full(var1, var2, result, rscale); + else + mul_var_karatsuba_half(var1, var2, result, rscale); + return; + } + + /* + * We do the arithmetic in an array "dig[]" of signed int's. Since + * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom + * to avoid normalizing carries immediately. + * + * maxdig tracks the maximum possible value of any dig[] entry; when this + * threatens to exceed INT_MAX, we take the time to propagate carries. + * Furthermore, we need to ensure that overflow doesn't occur during the + * carry propagation passes either. The carry values could be as much as + * INT_MAX/NBASE, so really we must normalize when digits threaten to + * exceed INT_MAX - INT_MAX/NBASE. + * + * To avoid overflow in maxdig itself, it actually represents the max + * possible value divided by NBASE-1, ie, at the top of the loop it is + * known that no dig[] entry exceeds maxdig * (NBASE-1). + */ + dig = (int *) palloc0(res_ndigits * sizeof(int)); + maxdig = 0; + + /* + * The least significant digits of var1 should be ignored if they don't + * contribute directly to the first res_ndigits digits of the result that + * we are computing. + * + * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit + * i1+i2+2 of the accumulator array, so we need only consider digits of + * var1 for which i1 <= res_ndigits - 3. + */ + for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--) + { + NumericDigit var1digit = var1digits[i1]; + + if (var1digit == 0) + continue; + + /* Time to normalize? */ + maxdig += var1digit; + if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1)) + { + /* Yes, do it */ + carry = 0; + for (i = res_ndigits - 1; i >= 0; i--) + { + newdig = dig[i] + carry; + if (newdig >= NBASE) + { + carry = newdig / NBASE; + newdig -= carry * NBASE; + } + else + carry = 0; + dig[i] = newdig; + } + Assert(carry == 0); + /* Reset maxdig to indicate new worst-case */ + maxdig = 1 + var1digit; + } + + /* + * Add the appropriate multiple of var2 into the accumulator. + * + * As above, digits of var2 can be ignored if they don't contribute, + * so we only include digits for which i1+i2+2 < res_ndigits. + * + * This inner loop is the performance bottleneck for multiplication, + * so we want to keep it simple enough so that it can be + * auto-vectorized. Accordingly, process the digits left-to-right + * even though schoolbook multiplication would suggest right-to-left. + * Since we aren't propagating carries in this loop, the order does + * not matter. + */ + { + int i2limit = Min(var2ndigits, res_ndigits - i1 - 2); + int *dig_i1_2 = &dig[i1 + 2]; + + for (i2 = 0; i2 < i2limit; i2++) + dig_i1_2[i2] += var1digit * var2digits[i2]; + } + } + + /* + * Now we do a final carry propagation pass to normalize the result, which + * we combine with storing the result digits into the output. Note that + * this is still done at full precision w/guard digits. + */ + alloc_var(result, res_ndigits); + res_digits = result->digits; + carry = 0; + for (i = res_ndigits - 1; i >= 0; i--) + { + newdig = dig[i] + carry; + if (newdig >= NBASE) + { + carry = newdig / NBASE; + newdig -= carry * NBASE; + } + else + carry = 0; + res_digits[i] = newdig; + } + Assert(carry == 0); + + pfree(dig); + + /* + * Finally, round the result to the requested precision. + */ + result->weight = res_weight; + result->sign = res_sign; + + /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); + +} + + /* * div_var() - * diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat index 153d816a05..cab6fb8238 100644 --- a/src/include/catalog/pg_proc.dat +++ b/src/include/catalog/pg_proc.dat @@ -4465,6 +4465,9 @@ { oid => '1726', proname => 'numeric_mul', prorettype => 'numeric', proargtypes => 'numeric numeric', prosrc => 'numeric_mul' }, +{ oid => '6312', + proname => 'numeric_mul_karatsuba', prorettype => 'numeric', + proargtypes => 'numeric numeric', prosrc => 'numeric_mul_karatsuba' }, { oid => '1727', proname => 'numeric_div', prorettype => 'numeric', proargtypes => 'numeric numeric', prosrc => 'numeric_div' }, diff --git a/src/include/utils/numeric.h b/src/include/utils/numeric.h index 43c75c436f..2b214a7700 100644 --- a/src/include/utils/numeric.h +++ b/src/include/utils/numeric.h @@ -97,6 +97,8 @@ extern Numeric numeric_sub_opt_error(Numeric num1, Numeric num2, bool *have_error); extern Numeric numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error); +extern Numeric numeric_mul_karatsuba_opt_error(Numeric num1, Numeric num2, + bool *have_error); extern Numeric numeric_div_opt_error(Numeric num1, Numeric num2, bool *have_error); extern Numeric numeric_mod_opt_error(Numeric num1, Numeric num2,