Skip to content

Commit

Permalink
performance improvements
Browse files Browse the repository at this point in the history
alleviates malb#75 somewhat
  • Loading branch information
grhkm21 committed Jul 3, 2023
1 parent 2048e1a commit 51d13e3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
16 changes: 13 additions & 3 deletions estimator/lwe_primal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from .conf import red_cost_model as red_cost_model_default
from .conf import red_shape_model as red_shape_model_default
from .conf import red_simulator as red_simulator_default
from fpylll.util import gaussian_heuristic


class PrimalUSVP:
Expand Down Expand Up @@ -84,7 +83,6 @@ def cost_gsa(
red_cost_model=red_cost_model_default,
log_level=None,
):

delta = deltaf(beta)
xi = PrimalUSVP._xi_factor(params.Xs, params.Xe)
m = min(2 * ceil(sqrt(params.n * log(params.q) / log(delta))), m)
Expand Down Expand Up @@ -262,9 +260,21 @@ def svp_dimension(cls, r, D):
:param r: squared Gram-Schmidt norms
"""
from math import lgamma, log, exp, pi

def ball_log_vol(n):
return (n / 2.0) * log(pi) - lgamma(n / 2.0 + 1)

def gaussian_heuristic_log_input(r):
n = len(list(r))
log_vol = sum(r)
log_gh = 1.0 / n * (log_vol - 2 * ball_log_vol(n))
return exp(log_gh)

d = len(r)
r = [log(x) for x in r]
for i, _ in enumerate(r):
if gaussian_heuristic(r[i:]) < D.stddev**2 * (d - i):
if gaussian_heuristic_log_input(r[i:]) < D.stddev**2 * (d - i):
return ZZ(d - (i - 1))
return ZZ(2)

Expand Down
8 changes: 4 additions & 4 deletions estimator/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def babai(r, norm):
Babai probability following [EPRINT:Wun16]_.
"""
R = [RR(sqrt(t) / (2 * norm)) for t in r]
denom = float(2 * norm) ** 2
T = RealDistribution("beta", ((len(r) - 1) / 2, 1.0 / 2))
probs = [1 - T.cum_distribution_function(1 - s ** 2) for s in R]
probs = [1 - T.cum_distribution_function(1 - r_ / denom) for r_ in r]
return prod(probs)


Expand Down Expand Up @@ -103,7 +103,7 @@ def amplify(target_success_probability, success_probability, majority=False):
try:
if majority:
eps = success_probability / 2
return ceil(2 * log(2 - 2 * target_success_probability) / log(1 - 4 * eps ** 2))
return ceil(2 * log(2 - 2 * target_success_probability) / log(1 - 4 * eps**2))
else:
# target_success_probability = 1 - (1-success_probability)^trials
return ceil(log(1 - target_success_probability) / log(1 - success_probability))
Expand All @@ -121,7 +121,7 @@ def amplify_sigma(target_advantage, sigma, q):
"""
try:
sigma = sum(sigma_ ** 2 for sigma_ in sigma).sqrt()
sigma = sum(sigma_**2 for sigma_ in sigma).sqrt()
except TypeError:
pass
advantage = float(exp(-float(pi) * (float(sigma / q) ** 2)))
Expand Down

0 comments on commit 51d13e3

Please sign in to comment.