Skip to content

Commit

Permalink
add template fieldOps, ellipticOps to simplify operations
Browse files Browse the repository at this point in the history
Slightly more type safe version of the templates included in some of
the procs previously.
  • Loading branch information
Vindaar committed Sep 24, 2024
1 parent 844bc95 commit 842255e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 77 deletions.
96 changes: 27 additions & 69 deletions constantine/math_compiler/pub_curves_jacobian.nim
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,35 @@ proc store*(dst: EcPointJac, src: EcPointJac) =
store(dst.getY(), src.getY())
store(dst.getZ(), src.getZ())

proc fromAffine_impl*(asy: Assembler_LLVM, ed: CurveDescriptor, jac: var EcPointJac, aff: EcPointAff) =
template x(ec: EcPointJac | EcPointAff): Field = ec.getX()
template y(ec: EcPointJac | EcPointAff): Field = ec.getY()
template z(ec: EcPointJac): Field = ec.getZ()

template setOne(x): untyped = asy.setOne_internal(ed.fd, x.buf)
template derefBool(x): untyped = asy.load2(asy.ctx.int1_t(), x)
template csetZero(x, c): untyped = asy.csetZero_internal(ed.fd, x.buf, derefBool c)
template ellipticOps*(asy: Assembler_LLVM, ed: CurveDescriptor): untyped =
## This template can be used to make operations on `Field` elements
## more convenient.
## XXX: extend to include all ops
# Boolean checks
template isNeutral(res, x): untyped = asy.isNeutral_internal(ed, res, x.buf)
template isNeutral(x): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
asy.isNeutral_internal(ed, res, x.buf)
res

# Conditional ops
template ccopy(x, y: EcPointJac, c): untyped = asy.ccopy_internal(ed, x.buf, y.buf, derefBool c)

# Accessors
template x(ec: EcPointJac | EcPointAff): Field = ec.getX()
template y(ec: EcPointJac | EcPointAff): Field = ec.getY()
template z(ec: EcPointJac): Field = ec.getZ()

proc fromAffine_impl*(asy: Assembler_LLVM, ed: CurveDescriptor, jac: var EcPointJac, aff: EcPointAff) =
# Inject templates for convenient access
fieldOps(asy, ed.fd)
ellipticOps(asy, ed)

jac.x.store(aff.x)
jac.y.store(aff.y)
jac.z.setOne()
jac.z.csetZero(aff.isNeutral())


proc fromAffine_internal*(asy: Assembler_LLVM, ed: CurveDescriptor, j, a: ValueRef) =
## Given an EC point in affine coordinates, converts the point to
## Jacobian coordinates as `jac`.
Expand Down Expand Up @@ -112,7 +122,6 @@ proc genEcFromAffine*(asy: Assembler_LLVM, ed: CurveDescriptor): string =

return name


proc isNeutral_internal*(asy: Assembler_LLVM, ed: CurveDescriptor, r, a: ValueRef) {.used.} =
## Generate an internal elliptic curve point isNeutral proc
## with signature
Expand Down Expand Up @@ -356,46 +365,10 @@ proc sum_internal*(asy: Assembler_LLVM, ed: CurveDescriptor, r, p, q: ValueRef)
## unless we absorb not only the `Builder` in the `Field` / `EcPointJac` objects, but also
## the full `asy`/`ed` types as refs. It is an option though.

# For finite field points
template square(res, y): untyped = asy.nsqr_internal(ed.fd, res.buf, y.buf, count = 1)
template prod(res, x, y): untyped = asy.mul_internal(ed.fd, res.buf, x.buf, y.buf)
template diff(res, x, y): untyped = asy.sub_internal(ed.fd, res.buf, x.buf, y.buf)
template add(res, x, y): untyped = asy.add_internal(ed.fd, res.buf, x.buf, y.buf)
template double(res, x): untyped = asy.double_internal(ed.fd, res.buf, x.buf)
template isZero(res, x): untyped = asy.isZero_internal(ed.fd, res, x.buf)
template isZero(x): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
asy.isZero_internal(ed.fd, res, x.buf)
res
template ccopy(x, y: Field, c): untyped = asy.ccopy_internal(ed.fd, x.buf, y.buf, c)
template div2(x): untyped = asy.div2_internal(ed.fd, x.buf)
template csub(x, y, c): untyped = asy.csub_internal(ed.fd, x.buf, y.buf, c)

template `not`(x: ValueRef): untyped = asy.br.`not`(x)

template `*=`(x, y: Field): untyped = x.prod(x, y)
template `+=`(x, y: Field): untyped = x.add(x, y)
template `-=`(x, y: Field): untyped = x.diff(x, y)

template derefBool(x): untyped = asy.load2(asy.ctx.int1_t(), x)

template `and`(x, y): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
res = asy.br.`and`(derefBool x, derefBool y)
res

# For EC points
template isNeutral(res, x): untyped = asy.isNeutral_internal(ed, res, x.buf)
template isNeutral(x): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
asy.isNeutral_internal(ed, res, x.buf)
res

template ccopy(x, y: EcPointJac, c): untyped = asy.ccopy_internal(ed, x.buf, y.buf, derefBool c)

template x(ec: EcPointJac): Field = ec.getX()
template y(ec: EcPointJac): Field = ec.getY()
template z(ec: EcPointJac): Field = ec.getZ()
# Make finite field point operations nicer
fieldOps(asy, ed.fd)
# And EC points
ellipticOps(asy, ed)

## XXX: Required to extent for coefA != 0!
when false:
Expand Down Expand Up @@ -586,25 +559,10 @@ proc double_internal*(asy: Assembler_LLVM, ed: CurveDescriptor, r, p: ValueRef)
## unless we absorb not only the `Builder` in the `Field` / `EcPointJac` objects, but also
## the full `asy`/`ed` types as refs. It is an option though.

# For finite field points
template square(res, y): untyped = asy.nsqr_internal(ed.fd, res.buf, y.buf, count = 1)
template square(x): untyped = square(x, x)
template prod(res, x, y): untyped = asy.mul_internal(ed.fd, res.buf, x.buf, y.buf)
template diff(res, x, y): untyped = asy.sub_internal(ed.fd, res.buf, x.buf, y.buf)
template double(res, x): untyped = asy.double_internal(ed.fd, res.buf, x.buf)
template double(x): untyped = double(x, x)
template add(res, x, y): untyped = asy.add_internal(ed.fd, res.buf, x.buf, y.buf)

template `*=`(x, y: Field): untyped = x.prod(x, y)
template `+=`(x, y: Field): untyped = x.add(x, y)
template `-=`(x, y: Field): untyped = x.diff(x, y)

template `*=`(x: Field, b: static int): untyped = asy.scalarMul_internal(ed.fd, x.buf, b)

# For EC points
template x(ec: EcPointJac): Field = ec.getX()
template y(ec: EcPointJac): Field = ec.getY()
template z(ec: EcPointJac): Field = ec.getZ()
# Make operations more convenient, for fields:
fieldOps(asy, ed.fd)
# and for EC points
ellipticOps(asy, ed)

var
A = asy.newField(ed.fd)
Expand Down
72 changes: 64 additions & 8 deletions constantine/math_compiler/pub_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,67 @@ import
## Section name used for `llvmInternalFnDef`
const SectionName = "ctt.pub_fields"

template fieldOps*(asy: Assembler_LLVM, fd: FieldDescriptor): untyped {.dirty.} =
## This template can be used to make operations on `Field` elements
## more convenient.

## Note: This is handled via these templates, due to the Assembler_LLVM and
## (to a lesser extent) `FieldDescriptor` dependency.
## We could partially solve that by having `_impl` procs for every operation,
## which _only_ contains the inner code for the `llvmInternalFnDef` code.
## Thay way we could call _that_ function in these templates, which would
## be independent of the `asy`.
## For the `FieldDescriptor` we could anyhow (mostly?) resuse the `BuilderRef`
## that is part of the `Array` object (which `Field` is a `distinct` to).

## XXX: extend to include all ops
# Boolean checks
template isZero(res, x: Field): untyped = asy.isZero_internal(fd, res, x.buf)
template isZero(x: Field): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
asy.isZero_internal(fd, res, x.buf)
res

# Boolean logic
template `not`(x: ValueRef): untyped = asy.br.`not`(x)
template derefBool(x: ValueRef): untyped = asy.load2(asy.ctx.int1_t(), x)
template `and`(x, y): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
res = asy.br.`and`(derefBool x, derefBool y)
res

# Mutators
template setZero(x: Field): untyped = asy.setZero_internal(fd, x.buf)
template setOne(x: Field): untyped = asy.setZero_internal(fd, x.buf)
template neg(res, y: Field): untyped = asy.neg_internal(fd, res.buf, y.buf)

# Conditional setters
template csetZero(x: Field, c): untyped = asy.csetZero_internal(fd, x.buf, derefBool c)

# Basic arithmetic
template sum(res, x, y: Field): untyped = asy.add_internal(fd, res.buf, x.buf, y.buf)
template add(res, x, y: Field): untyped = asy.add_internal(fd, res.buf, x.buf, y.buf)
template diff(res, x, y: Field): untyped = asy.sub_internal(fd, res.buf, x.buf, y.buf)
template prod(res, x, y: Field): untyped = asy.mul_internal(fd, res.buf, x.buf, y.buf)

# Conditional arithmetic
template cadd(x, y: Field, c): untyped = asy.cadd_internal(fd, x.buf, y.buf, c)
template csub(x, y: Field, c): untyped = asy.csub_internal(fd, x.buf, y.buf, c)
template ccopy(x, y: Field, c): untyped = asy.ccopy_internal(fd, x.buf, y.buf, c)

# Extended arithmetic
template square(res, y: Field): untyped = asy.nsqr_internal(fd, res.buf, y.buf, count = 1)
template square(x): untyped = square(x, x)
template double(res, x: Field): untyped = asy.double_internal(fd, res.buf, x.buf)
template double(x: Field): untyped = asy.double_internal(fd, x.buf, x.buf)
template div2(x: Field): untyped = asy.div2_internal(fd, x.buf)

# Mutating assignment ops
template `*=`(x, y: Field): untyped = x.prod(x, y)
template `+=`(x, y: Field): untyped = x.add(x, y)
template `-=`(x, y: Field): untyped = x.diff(x, y)
template `*=`(x: Field, b: static int): untyped = asy.scalarMul_internal(fd, x.buf, b)

proc setZero_internal*(asy: Assembler_LLVM, fd: FieldDescriptor, r: ValueRef) {.used.} =
## Generate an internal field setZero
## with signature
Expand Down Expand Up @@ -855,15 +916,10 @@ proc scalarMul_internal*(asy: Assembler_LLVM, fd: FieldDescriptor, a: ValueRef,
{kHot}):
let ai = llvmParams

let a = asy.asField(fd, ai) # shadow `a` argument of proc
# Make field ops convenient:
fieldOps(asy, fd)

template neg(res, y): untyped = asy.neg_internal(fd, res.buf, y.buf)
template setZero(x): untyped = asy.setZero_internal(fd, x.buf)
template double(res, x): untyped = asy.double_internal(fd, res.buf, x.buf)
template double(x): untyped = asy.double_internal(fd, x.buf, x.buf)
template diff(res, x, y): untyped = asy.sub_internal(fd, res.buf, x.buf, y.buf)
template sum(res, x, y): untyped = asy.add_internal(fd, res.buf, x.buf, y.buf)
template `+=`(res, x): untyped = asy.add_internal(fd, res.buf, res.buf, x.buf)
let a = asy.asField(fd, ai) # shadow `a` argument of proc

const negate = b < 0
const b = if negate: -b
Expand Down

0 comments on commit 842255e

Please sign in to comment.