Skip to content

Commit

Permalink
Fast flint_mpn_mullow_n for large n; use in preinvn functions (#1933)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrik-johansson authored Apr 27, 2024
1 parent d448165 commit c96635b
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 80 deletions.
28 changes: 28 additions & 0 deletions doc/source/mpn_extras.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,34 @@ More generally, we can define `n`-limb high products of `m`-limb and
On x86-64 machines with ADX, the basecase version currently assumes
that `n \ge 8`.

.. function:: void _flint_mpn_mullow_n_mulders_recursive(mp_ptr rp, mp_srcptr u, mp_srcptr v, mp_size_t n)
mp_limb_t flint_mpn_mullow_basecase(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
mp_limb_t _flint_mpn_mullow_n_mulders(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
mp_limb_t _flint_mpn_mullow_n_mul(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
mp_limb_t _flint_mpn_mullow_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
mp_limb_t flint_mpn_mullow_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)

Compute the low `n` limbs of the product.

The `(n + 1)`-th limb is also computed and returned.
Warning: this extra limb of output may be removed in the future.

.. function:: void flint_mpn_mul_or_mullow_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)

Write the low `n + 1` limbs of the product `uv` to ``res``.
The output is assumed to have space for `2n` limbs so that the high
limbs can be used as scratch space or to write the whole product
when this is the fastest method.

Warning: the one extra limb of output may be removed in the future.

.. function:: void flint_mpn_mul_or_mulhigh_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)

Write the high `n + 1` limbs of the product `uv` to ``res + (n - 1)``
(with possible error of a few ulps as for :func:`flint_mpn_mulhigh_n`).
The low `n - 1` limbs of the output may be used as scratch space or
to write the whole product when this is the fastest method.

Divisibility
--------------------------------------------------------------------------------

Expand Down
87 changes: 49 additions & 38 deletions src/mpn_extras.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,84 +396,81 @@ flint_mpn_sqr(mp_ptr r, mp_srcptr x, mp_size_t n)
flint_mpn_mul((_z), (_y), (_yn), (_x), (_xn)); \
}

/* low multiplication ********************************************************/
/* High and low multiplication *******************************************************/

#define FLINT_HAVE_MULLOW_FUNC(n) ((n) <= FLINT_MPN_MULLOW_FUNC_TAB_WIDTH)

FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mullow_func_tab[];

mp_limb_t flint_mpn_mullow_basecase(mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);

#if FLINT_HAVE_ASSEMBLY_x86_64_adx
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 8
# define FLINT_HAVE_NATIVE_mpn_mullow_basecase 1
#else
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 0
#endif

/* TODO: Fix higher stuff */
MPN_EXTRAS_INLINE
mp_limb_t flint_mpn_mullow_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
FLINT_ASSERT(n >= 1);

if (FLINT_HAVE_MULLOW_FUNC(n))
{
FLINT_ASSERT(rp != xp);
return flint_mpn_mullow_func_tab[n](rp, xp, yp);
}
else
return flint_mpn_mullow_basecase(rp, xp, yp, n);
}

/* high multiplication *******************************************************/

#define FLINT_HAVE_MULHIGH_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH)
#define FLINT_HAVE_SQRHIGH_FUNC(n) ((n) <= FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH)
#define FLINT_HAVE_MULHIGH_NORMALISED_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH)

typedef struct { mp_limb_t m1; mp_limb_t m2; } mp_limb_pair_t;
typedef mp_limb_pair_t (* flint_mpn_mulhigh_normalised_func_t)(mp_ptr, mp_srcptr, mp_srcptr);

FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mullow_func_tab[];
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[];
FLINT_DLL extern const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[];
FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[];

#if FLINT_HAVE_ASSEMBLY_x86_64_adx
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 8
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 9
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 8
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 9

# define FLINT_HAVE_NATIVE_mpn_mullow_basecase 1
/* NOTE: This function only works for n >= 6 */
# define FLINT_HAVE_NATIVE_mpn_mulhigh_basecase 1

/* NOTE: The x86_64_adx versions of these functions only works for n >= 6 */
/* NOTE: This function only works for n >= 6 */
# define FLINT_HAVE_NATIVE_mpn_sqrhigh_basecase 1

#elif FLINT_HAVE_ASSEMBLY_armv8
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 0
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 8
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 8
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 0

/* NOTE: This function only works for n > 8 */
# define FLINT_HAVE_NATIVE_mpn_mulhigh_basecase 1

#else
/* TODO: generic hardcoded mullows */
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 0
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 16
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 2
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 0

#endif

/* FIXME: this tuning is for x86_64_adx with fft_small */
/* NOTE: we assume that the same cutoff is optimal for both mulhigh and mullow */
#define FLINT_MPN_MULHIGH_MULDERS_CUTOFF 50
#define FLINT_MPN_MULHIGH_MUL_CUTOFF 2000
#define FLINT_MPN_MULHIGH_K_TAB_SIZE 2048

void _flint_mpn_mulhigh_n_mulders_recursive(mp_ptr rp, mp_srcptr np, mp_srcptr mp, mp_size_t n);
FLINT_DLL extern const signed short flint_mpn_mulhigh_k_tab[FLINT_MPN_MULHIGH_K_TAB_SIZE];

mp_limb_t _flint_mpn_mulhigh_basecase(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
mp_limb_t flint_mpn_mullow_basecase(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
void _flint_mpn_mullow_n_mulders_recursive(mp_ptr rp, mp_srcptr np, mp_srcptr mp, mp_size_t n);
mp_limb_t _flint_mpn_mullow_n_mulders(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
mp_limb_t _flint_mpn_mullow_n_mul(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
mp_limb_t _flint_mpn_mullow_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);

mp_limb_t _flint_mpn_mulhigh_basecase(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
void _flint_mpn_mulhigh_n_mulders_recursive(mp_ptr rp, mp_srcptr np, mp_srcptr mp, mp_size_t n);
mp_limb_t _flint_mpn_mulhigh_n_mulders(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
mp_limb_t _flint_mpn_mulhigh_n_mul(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);
mp_limb_t _flint_mpn_mulhigh_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n);

MPN_EXTRAS_INLINE
mp_limb_t flint_mpn_mullow_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
FLINT_ASSERT(n >= 1);
FLINT_ASSERT(rp != xp);

if (FLINT_HAVE_MULLOW_FUNC(n))
return flint_mpn_mullow_func_tab[n](rp, xp, yp);
else
return _flint_mpn_mullow_n(rp, xp, yp, n);
}

MPN_EXTRAS_INLINE
mp_limb_t flint_mpn_mulhigh_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
Expand All @@ -485,9 +482,23 @@ mp_limb_t flint_mpn_mulhigh_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n
return _flint_mpn_mulhigh_n(rp, xp, yp, n);
}

/* We just want the high n limbs, but rp has low limbs available
/* We just want the high or low n limbs, but rp has 2n limbs available
which can be used for scratch space or for doing a full multiply
without temporary allocations. */
without temporary allocations. TODO: exploit this in the Mulders range
by calling Mulders directly. */
MPN_EXTRAS_INLINE
void flint_mpn_mul_or_mullow_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
FLINT_ASSERT(n >= 1);

if (FLINT_HAVE_MULLOW_FUNC(n))
rp[n] = flint_mpn_mullow_func_tab[n](rp, xp, yp);
else if (n < FLINT_MPN_MULHIGH_MUL_CUTOFF)
rp[n] = _flint_mpn_mullow_n(rp, xp, yp, n);
else
flint_mpn_mul_n(rp, xp, yp, n);
}

MPN_EXTRAS_INLINE
void flint_mpn_mul_or_mulhigh_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
Expand Down
8 changes: 7 additions & 1 deletion src/mpn_extras/divrem_preinvn.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ mp_limb_t flint_mpn_divrem_preinvn(mp_ptr qp, mp_ptr rp, mp_srcptr ap, mp_size_t
flint_mpn_mul_or_mulhigh_n(t, dinv, r + n, n);
cy = mpn_add_n(q, t + n, r + n, n);

flint_mpn_mul_n(t, d, q, n);
/* note: we rely on the fact that mul_or_mullow_n actually
writes at least n + 1 limbs */
flint_mpn_mul_or_mullow_n(t, d, q, n);
cy = r[n] - t[n] - mpn_sub_n(r, a, t, n);

while (cy > 0)
Expand All @@ -65,6 +67,8 @@ mp_limb_t flint_mpn_divrem_preinvn(mp_ptr qp, mp_ptr rp, mp_srcptr ap, mp_size_t
mpn_add_1(q, q, n, 1);
}

FLINT_ASSERT(mpn_cmp(r, d, n) < 0);

m -= n;
r -= n;
a -= n;
Expand Down Expand Up @@ -99,6 +103,8 @@ mp_limb_t flint_mpn_divrem_preinvn(mp_ptr qp, mp_ptr rp, mp_srcptr ap, mp_size_t
mpn_sub_n(rp, rp, d, n);
mpn_add_1(qp, qp, size, 1);
}

FLINT_ASSERT(mpn_cmp(rp, d, n) < 0);
}

TMP_END;
Expand Down
8 changes: 7 additions & 1 deletion src/mpn_extras/mod_preinvn.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ void flint_mpn_mod_preinvn(mp_ptr rp, mp_srcptr ap, mp_size_t m,
flint_mpn_mul_or_mulhigh_n(t, dinv, r + n, n);
cy = mpn_add_n(t + 2*n, t + n, r + n, n);

flint_mpn_mul_n(t, d, t + 2*n, n);
/* note: we rely on the fact that mul_or_mullow_n actually
writes at least n + 1 limbs */
flint_mpn_mul_or_mullow_n(t, d, t + 2*n, n);
cy = r[n] - t[n] - mpn_sub_n(r, a, t, n);

while (cy > 0)
Expand All @@ -55,6 +57,8 @@ void flint_mpn_mod_preinvn(mp_ptr rp, mp_srcptr ap, mp_size_t m,
if (mpn_cmp(r, d, n) >= 0)
mpn_sub_n(r, r, d, n);

FLINT_ASSERT(mpn_cmp(r, d, n) < 0);

m -= n;
r -= n;
a -= n;
Expand Down Expand Up @@ -82,6 +86,8 @@ void flint_mpn_mod_preinvn(mp_ptr rp, mp_srcptr ap, mp_size_t m,

if (mpn_cmp(rp, d, n) >= 0)
mpn_sub_n(rp, rp, d, n);

FLINT_ASSERT(mpn_cmp(rp, d, n) < 0);
}

TMP_END;
Expand Down
2 changes: 1 addition & 1 deletion src/mpn_extras/mulhigh.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "mpn_extras.h"

/* Generated by tune-mulhigh.c */
static const signed short flint_mpn_mulhigh_k_tab[FLINT_MPN_MULHIGH_K_TAB_SIZE] =
const signed short flint_mpn_mulhigh_k_tab[FLINT_MPN_MULHIGH_K_TAB_SIZE] =
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 14, 14, 16, 15, 15, 18, 18,
18, 19, 20, 18, 22, 22, 20, 20, 26, 22, 22, 22, 24, 24, 24, 26, 25, 26, 30, 30, 28, 30, 31, 32, 32, 30, 36, 36, 36, 36,
38, 39, 39, 38, 39, 40, 40, 40, 44, 40, 44, 44, 40, 44, 44, 48, 44, 48, 44, 48, 48, 52, 52, 52, 44, 52, 52, 52, 52, 56,
Expand Down
90 changes: 90 additions & 0 deletions src/mpn_extras/mullow.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
Copyright (C) 2024 Albin Ahlbäck
Copyright (C) 2024 Fredrik Johansson
This file is part of FLINT.
Expand All @@ -9,6 +10,7 @@
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include <string.h>
#include "mpn_extras.h"

#if FLINT_HAVE_ASSEMBLY_x86_64_adx
Expand Down Expand Up @@ -55,3 +57,91 @@ flint_mpn_mullow_basecase(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
return ret;
}
#endif

void
_flint_mpn_mullow_n_mulders_recursive(mp_ptr rp, mp_srcptr np, mp_srcptr mp, mp_size_t n)
{
mp_size_t l;
slong k;

if (FLINT_HAVE_MULLOW_FUNC(n))
{
flint_mpn_mullow_func_tab[n](rp, np, mp);
return;
}

if (n < FLINT_MPN_MULHIGH_K_TAB_SIZE)
k = flint_mpn_mulhigh_k_tab[n];
else
k = 3 * (n / 4);

if (k == 0)
{
flint_mpn_mullow_basecase(rp, np, mp, n);
return;
}

if (k == n)
{
flint_mpn_mul_n(rp, np, mp, n);
return;
}

FLINT_ASSERT(k >= (n + 1) / 2);

l = n - k;

flint_mpn_mul_n(rp, np, mp, k);
_flint_mpn_mullow_n_mulders_recursive(rp + n, np, mp + k, l);
mpn_add_n(rp + k, rp + k, rp + n, l);
_flint_mpn_mullow_n_mulders_recursive(rp + n, np + k, mp, l);
mpn_add_n(rp + k, rp + k, rp + n, l);
}

mp_limb_t
_flint_mpn_mullow_n_mulders(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
{
mp_ptr tmp, tr, tu, tv;
mp_limb_t bot;
TMP_INIT;

TMP_START;
tmp = TMP_ALLOC(sizeof(mp_limb_t) * (4 * (n + 1)));
tu = tmp;
tv = tu + (n + 1);
tr = tv + (n + 1);
tu[n] = 0;
tv[n] = 0;
flint_mpn_copyi(tu, u, n);
flint_mpn_copyi(tv, v, n);
_flint_mpn_mullow_n_mulders_recursive(tr, tu, tv, n + 1);
flint_mpn_copyi(res, tr, n);
bot = tr[n];

TMP_END;
return bot;
}

mp_limb_t
_flint_mpn_mullow_n_mul(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
{
mp_ptr tmp;
mp_limb_t cy;
tmp = flint_malloc(sizeof(mp_limb_t) * (2 * n));
flint_mpn_mul_n(tmp, u, v, n);
memcpy(res, tmp, sizeof(mp_limb_t) * n);
cy = tmp[n];
flint_free(tmp);
return cy;
}

mp_limb_t
_flint_mpn_mullow_n(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
{
if (n <= FLINT_MPN_MULHIGH_MULDERS_CUTOFF)
return flint_mpn_mullow_basecase(res, u, v, n);
else if (n <= FLINT_MPN_MULHIGH_MUL_CUTOFF)
return _flint_mpn_mullow_n_mulders(res, u, v, n);
else
return _flint_mpn_mullow_n_mul(res, u, v, n);
}
6 changes: 5 additions & 1 deletion src/mpn_extras/mulmod_preinvn.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ void flint_mpn_mulmod_preinvn(mp_ptr r,
flint_mpn_mul_or_mulhigh_n(t + 3*n, t + n, dinv, n);
mpn_add_n(t + 4*n, t + 4*n, t + n, n);

flint_mpn_mul_n(t + 2*n, t + 4*n, d, n);
/* note: we rely on the fact that mul_or_mullow_n actually
writes at least n + 1 limbs */
flint_mpn_mul_or_mullow_n(t + 2*n, t + 4*n, d, n);
cy = t[n] - t[3*n] - mpn_sub_n(r, t, t + 2*n, n);

while (cy > 0)
Expand All @@ -111,6 +113,8 @@ void flint_mpn_mulmod_preinvn(mp_ptr r,
if (mpn_cmp(r, d, n) >= 0)
mpn_sub_n(r, r, d, n);

FLINT_ASSERT(mpn_cmp(r, d, n) < 0);

TMP_END;
}
}
Loading

0 comments on commit c96635b

Please sign in to comment.