From eb8b8c19a81a52d9cf705d90a597a78cdaf2b6f6 Mon Sep 17 00:00:00 2001 From: ChangLiu Date: Sat, 12 Feb 2022 21:47:30 +0800 Subject: [PATCH] Fix doubling points on the x axis bug, add unittest (#18) --- ecc/curve.py | 21 +++++---------------- tests/test_curve.py | 8 +++++++- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/ecc/curve.py b/ecc/curve.py index a4d9f58..8558a8f 100644 --- a/ecc/curve.py +++ b/ecc/curve.py @@ -100,10 +100,10 @@ def add_point(self, P: Point, Q: Point) -> Point: elif Q.is_at_infinity(): return P - if P == Q: - return self._double_point(P) if P == -Q: return self.INF + if P == Q: + return self._double_point(P) return self._add_point(P, Q) @@ -111,14 +111,6 @@ def add_point(self, P: Point, Q: Point) -> Point: def _add_point(self, P: Point, Q: Point) -> Point: pass - def double_point(self, P: Point) -> Point: - if not self.is_on_curve(P): - raise ValueError("The point is not on the curve.") - if P.is_at_infinity(): - return self.INF - - return self._double_point(P) - @abstractmethod def _double_point(self, P: Point) -> Point: pass @@ -134,17 +126,14 @@ def mul_point(self, d: int, P: Point) -> Point: if d == 0: return self.INF - res = None + res = self.INF is_negative_scalar = d < 0 d = -d if is_negative_scalar else d tmp = P while d: if d & 0x1 == 1: - if res: - res = self.add_point(res, tmp) - else: - res = tmp - tmp = self.double_point(tmp) + res = self.add_point(res, tmp) + tmp = self.add_point(tmp, tmp) d >>= 1 if is_negative_scalar: return -res diff --git a/tests/test_curve.py b/tests/test_curve.py index a7931e4..9eb2ad4 100644 --- a/tests/test_curve.py +++ b/tests/test_curve.py @@ -1,7 +1,7 @@ import unittest from ecc.curve import ( - P256, secp256k1, Curve25519, M383, E222, E382 + P256, secp256k1, Curve25519, M383, E222, E382, Point ) CURVES = [P256, secp256k1, Curve25519, M383, E222, E382] @@ -24,3 +24,9 @@ def test_operator(self): self.assertEqual(curve.INF + curve.INF, curve.INF) self.assertEqual(0 * P, curve.INF) self.assertEqual(1000 * curve.INF, curve.INF) + + def test_double_points_y_equals_to_0(self): + P = Point(x=0, y=0, curve=Curve25519) + self.assertEqual(P + P, Curve25519.INF) + self.assertEqual(2 * P, Curve25519.INF) + self.assertEqual(-2 * P, Curve25519.INF)