diff options
Diffstat (limited to 'src/crypto/internal/edwards25519/field')
-rw-r--r-- | src/crypto/internal/edwards25519/field/_asm/fe_amd64_asm.go | 294 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/_asm/go.mod | 12 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/_asm/go.sum | 32 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe.go | 420 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_alias_test.go | 140 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_amd64.go | 15 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_amd64.s | 378 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_amd64_noasm.go | 11 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_arm64.go | 15 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_arm64.s | 42 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_arm64_noasm.go | 11 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_bench_test.go | 49 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_generic.go | 266 | ||||
-rw-r--r-- | src/crypto/internal/edwards25519/field/fe_test.go | 560 |
14 files changed, 2245 insertions, 0 deletions
diff --git a/src/crypto/internal/edwards25519/field/_asm/fe_amd64_asm.go b/src/crypto/internal/edwards25519/field/_asm/fe_amd64_asm.go new file mode 100644 index 0000000..411399c --- /dev/null +++ b/src/crypto/internal/edwards25519/field/_asm/fe_amd64_asm.go @@ -0,0 +1,294 @@ +// Copyright (c) 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + + . "github.com/mmcloughlin/avo/build" + . "github.com/mmcloughlin/avo/gotypes" + . "github.com/mmcloughlin/avo/operand" + . "github.com/mmcloughlin/avo/reg" +) + +//go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field + +func main() { + Package("crypto/internal/edwards25519/field") + ConstraintExpr("amd64,gc,!purego") + feMul() + feSquare() + Generate() +} + +type namedComponent struct { + Component + name string +} + +func (c namedComponent) String() string { return c.name } + +type uint128 struct { + name string + hi, lo GPVirtual +} + +func (c uint128) String() string { return c.name } + +func feSquare() { + TEXT("feSquare", NOSPLIT, "func(out, a *Element)") + Doc("feSquare sets out = a * a. It works like feSquareGeneric.") + Pragma("noescape") + + a := Dereference(Param("a")) + l0 := namedComponent{a.Field("l0"), "l0"} + l1 := namedComponent{a.Field("l1"), "l1"} + l2 := namedComponent{a.Field("l2"), "l2"} + l3 := namedComponent{a.Field("l3"), "l3"} + l4 := namedComponent{a.Field("l4"), "l4"} + + // r0 = l0×l0 + 19×2×(l1×l4 + l2×l3) + r0 := uint128{"r0", GP64(), GP64()} + mul64(r0, 1, l0, l0) + addMul64(r0, 38, l1, l4) + addMul64(r0, 38, l2, l3) + + // r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3 + r1 := uint128{"r1", GP64(), GP64()} + mul64(r1, 2, l0, l1) + addMul64(r1, 38, l2, l4) + addMul64(r1, 19, l3, l3) + + // r2 = = 2×l0×l2 + l1×l1 + 19×2×l3×l4 + r2 := uint128{"r2", GP64(), GP64()} + mul64(r2, 2, l0, l2) + addMul64(r2, 1, l1, l1) + addMul64(r2, 38, l3, l4) + + // r3 = = 2×l0×l3 + 2×l1×l2 + 19×l4×l4 + r3 := uint128{"r3", GP64(), GP64()} + mul64(r3, 2, l0, l3) + addMul64(r3, 2, l1, l2) + addMul64(r3, 19, l4, l4) + + // r4 = = 2×l0×l4 + 2×l1×l3 + l2×l2 + r4 := uint128{"r4", GP64(), GP64()} + mul64(r4, 2, l0, l4) + addMul64(r4, 2, l1, l3) + addMul64(r4, 1, l2, l2) + + Comment("First reduction chain") + maskLow51Bits := GP64() + MOVQ(Imm((1<<51)-1), maskLow51Bits) + c0, r0lo := shiftRightBy51(&r0) + c1, r1lo := shiftRightBy51(&r1) + c2, r2lo := shiftRightBy51(&r2) + c3, r3lo := shiftRightBy51(&r3) + c4, r4lo := shiftRightBy51(&r4) + maskAndAdd(r0lo, maskLow51Bits, c4, 19) + maskAndAdd(r1lo, maskLow51Bits, c0, 1) + maskAndAdd(r2lo, maskLow51Bits, c1, 1) + maskAndAdd(r3lo, maskLow51Bits, c2, 1) + maskAndAdd(r4lo, maskLow51Bits, c3, 1) + + Comment("Second reduction chain (carryPropagate)") + // c0 = r0 >> 51 + MOVQ(r0lo, c0) + SHRQ(Imm(51), c0) + // c1 = r1 >> 51 + MOVQ(r1lo, c1) + SHRQ(Imm(51), c1) + // c2 = r2 >> 51 + MOVQ(r2lo, c2) + SHRQ(Imm(51), c2) + // c3 = r3 >> 51 + MOVQ(r3lo, c3) + SHRQ(Imm(51), c3) + // c4 = r4 >> 51 + MOVQ(r4lo, c4) + SHRQ(Imm(51), c4) + maskAndAdd(r0lo, maskLow51Bits, c4, 19) + maskAndAdd(r1lo, maskLow51Bits, c0, 1) + maskAndAdd(r2lo, maskLow51Bits, c1, 1) + maskAndAdd(r3lo, maskLow51Bits, c2, 1) + maskAndAdd(r4lo, maskLow51Bits, c3, 1) + + Comment("Store output") + out := Dereference(Param("out")) + Store(r0lo, out.Field("l0")) + Store(r1lo, out.Field("l1")) + Store(r2lo, out.Field("l2")) + Store(r3lo, out.Field("l3")) + Store(r4lo, out.Field("l4")) + + RET() +} + +func feMul() { + TEXT("feMul", NOSPLIT, "func(out, a, b *Element)") + Doc("feMul sets out = a * b. It works like feMulGeneric.") + Pragma("noescape") + + a := Dereference(Param("a")) + a0 := namedComponent{a.Field("l0"), "a0"} + a1 := namedComponent{a.Field("l1"), "a1"} + a2 := namedComponent{a.Field("l2"), "a2"} + a3 := namedComponent{a.Field("l3"), "a3"} + a4 := namedComponent{a.Field("l4"), "a4"} + + b := Dereference(Param("b")) + b0 := namedComponent{b.Field("l0"), "b0"} + b1 := namedComponent{b.Field("l1"), "b1"} + b2 := namedComponent{b.Field("l2"), "b2"} + b3 := namedComponent{b.Field("l3"), "b3"} + b4 := namedComponent{b.Field("l4"), "b4"} + + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + r0 := uint128{"r0", GP64(), GP64()} + mul64(r0, 1, a0, b0) + addMul64(r0, 19, a1, b4) + addMul64(r0, 19, a2, b3) + addMul64(r0, 19, a3, b2) + addMul64(r0, 19, a4, b1) + + // r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2) + r1 := uint128{"r1", GP64(), GP64()} + mul64(r1, 1, a0, b1) + addMul64(r1, 1, a1, b0) + addMul64(r1, 19, a2, b4) + addMul64(r1, 19, a3, b3) + addMul64(r1, 19, a4, b2) + + // r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3) + r2 := uint128{"r2", GP64(), GP64()} + mul64(r2, 1, a0, b2) + addMul64(r2, 1, a1, b1) + addMul64(r2, 1, a2, b0) + addMul64(r2, 19, a3, b4) + addMul64(r2, 19, a4, b3) + + // r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4 + r3 := uint128{"r3", GP64(), GP64()} + mul64(r3, 1, a0, b3) + addMul64(r3, 1, a1, b2) + addMul64(r3, 1, a2, b1) + addMul64(r3, 1, a3, b0) + addMul64(r3, 19, a4, b4) + + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + r4 := uint128{"r4", GP64(), GP64()} + mul64(r4, 1, a0, b4) + addMul64(r4, 1, a1, b3) + addMul64(r4, 1, a2, b2) + addMul64(r4, 1, a3, b1) + addMul64(r4, 1, a4, b0) + + Comment("First reduction chain") + maskLow51Bits := GP64() + MOVQ(Imm((1<<51)-1), maskLow51Bits) + c0, r0lo := shiftRightBy51(&r0) + c1, r1lo := shiftRightBy51(&r1) + c2, r2lo := shiftRightBy51(&r2) + c3, r3lo := shiftRightBy51(&r3) + c4, r4lo := shiftRightBy51(&r4) + maskAndAdd(r0lo, maskLow51Bits, c4, 19) + maskAndAdd(r1lo, maskLow51Bits, c0, 1) + maskAndAdd(r2lo, maskLow51Bits, c1, 1) + maskAndAdd(r3lo, maskLow51Bits, c2, 1) + maskAndAdd(r4lo, maskLow51Bits, c3, 1) + + Comment("Second reduction chain (carryPropagate)") + // c0 = r0 >> 51 + MOVQ(r0lo, c0) + SHRQ(Imm(51), c0) + // c1 = r1 >> 51 + MOVQ(r1lo, c1) + SHRQ(Imm(51), c1) + // c2 = r2 >> 51 + MOVQ(r2lo, c2) + SHRQ(Imm(51), c2) + // c3 = r3 >> 51 + MOVQ(r3lo, c3) + SHRQ(Imm(51), c3) + // c4 = r4 >> 51 + MOVQ(r4lo, c4) + SHRQ(Imm(51), c4) + maskAndAdd(r0lo, maskLow51Bits, c4, 19) + maskAndAdd(r1lo, maskLow51Bits, c0, 1) + maskAndAdd(r2lo, maskLow51Bits, c1, 1) + maskAndAdd(r3lo, maskLow51Bits, c2, 1) + maskAndAdd(r4lo, maskLow51Bits, c3, 1) + + Comment("Store output") + out := Dereference(Param("out")) + Store(r0lo, out.Field("l0")) + Store(r1lo, out.Field("l1")) + Store(r2lo, out.Field("l2")) + Store(r3lo, out.Field("l3")) + Store(r4lo, out.Field("l4")) + + RET() +} + +// mul64 sets r to i * aX * bX. +func mul64(r uint128, i int, aX, bX namedComponent) { + switch i { + case 1: + Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX)) + Load(aX, RAX) + case 2: + Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX)) + Load(aX, RAX) + SHLQ(Imm(1), RAX) + default: + panic("unsupported i value") + } + MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX + MOVQ(RAX, r.lo) + MOVQ(RDX, r.hi) +} + +// addMul64 sets r to r + i * aX * bX. +func addMul64(r uint128, i uint64, aX, bX namedComponent) { + switch i { + case 1: + Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX)) + Load(aX, RAX) + default: + Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX)) + IMUL3Q(Imm(i), Load(aX, GP64()), RAX) + } + MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX + ADDQ(RAX, r.lo) + ADCQ(RDX, r.hi) +} + +// shiftRightBy51 returns r >> 51 and r.lo. +// +// After this function is called, the uint128 may not be used anymore. +func shiftRightBy51(r *uint128) (out, lo GPVirtual) { + out = r.hi + lo = r.lo + SHLQ(Imm(64-51), r.lo, r.hi) + r.lo, r.hi = nil, nil // make sure the uint128 is unusable + return +} + +// maskAndAdd sets r = r&mask + c*i. +func maskAndAdd(r, mask, c GPVirtual, i uint64) { + ANDQ(mask, r) + if i != 1 { + IMUL3Q(Imm(i), c, c) + } + ADDQ(c, r) +} + +func mustAddr(c Component) Op { + b, err := c.Resolve() + if err != nil { + panic(err) + } + return b.Addr +} diff --git a/src/crypto/internal/edwards25519/field/_asm/go.mod b/src/crypto/internal/edwards25519/field/_asm/go.mod new file mode 100644 index 0000000..1ce2b5e --- /dev/null +++ b/src/crypto/internal/edwards25519/field/_asm/go.mod @@ -0,0 +1,12 @@ +module asm + +go 1.19 + +require github.com/mmcloughlin/avo v0.4.0 + +require ( + golang.org/x/mod v0.4.2 // indirect + golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 // indirect + golang.org/x/tools v0.1.7 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect +) diff --git a/src/crypto/internal/edwards25519/field/_asm/go.sum b/src/crypto/internal/edwards25519/field/_asm/go.sum new file mode 100644 index 0000000..b4b5914 --- /dev/null +++ b/src/crypto/internal/edwards25519/field/_asm/go.sum @@ -0,0 +1,32 @@ +github.com/mmcloughlin/avo v0.4.0 h1:jeHDRktVD+578ULxWpQHkilor6pkdLF7u7EiTzDbfcU= +github.com/mmcloughlin/avo v0.4.0/go.mod h1:RW9BfYA3TgO9uCdNrKU2h6J8cPD8ZLznvfgHAeszb1s= +github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 h1:giLT+HuUP/gXYrG2Plg9WTjj4qhfgaW424ZIFog3rlk= +golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ= +golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/src/crypto/internal/edwards25519/field/fe.go b/src/crypto/internal/edwards25519/field/fe.go new file mode 100644 index 0000000..5518ef2 --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe.go @@ -0,0 +1,420 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package field implements fast arithmetic modulo 2^255-19. +package field + +import ( + "crypto/subtle" + "encoding/binary" + "errors" + "math/bits" +) + +// Element represents an element of the field GF(2^255-19). Note that this +// is not a cryptographically secure group, and should only be used to interact +// with edwards25519.Point coordinates. +// +// This type works similarly to math/big.Int, and all arguments and receivers +// are allowed to alias. +// +// The zero value is a valid zero element. +type Element struct { + // An element t represents the integer + // t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204 + // + // Between operations, all limbs are expected to be lower than 2^52. + l0 uint64 + l1 uint64 + l2 uint64 + l3 uint64 + l4 uint64 +} + +const maskLow51Bits uint64 = (1 << 51) - 1 + +var feZero = &Element{0, 0, 0, 0, 0} + +// Zero sets v = 0, and returns v. +func (v *Element) Zero() *Element { + *v = *feZero + return v +} + +var feOne = &Element{1, 0, 0, 0, 0} + +// One sets v = 1, and returns v. +func (v *Element) One() *Element { + *v = *feOne + return v +} + +// reduce reduces v modulo 2^255 - 19 and returns it. +func (v *Element) reduce() *Element { + v.carryPropagate() + + // After the light reduction we now have a field element representation + // v < 2^255 + 2^13 * 19, but need v < 2^255 - 19. + + // If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1, + // generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise. + c := (v.l0 + 19) >> 51 + c = (v.l1 + c) >> 51 + c = (v.l2 + c) >> 51 + c = (v.l3 + c) >> 51 + c = (v.l4 + c) >> 51 + + // If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's + // effectively applying the reduction identity to the carry. + v.l0 += 19 * c + + v.l1 += v.l0 >> 51 + v.l0 = v.l0 & maskLow51Bits + v.l2 += v.l1 >> 51 + v.l1 = v.l1 & maskLow51Bits + v.l3 += v.l2 >> 51 + v.l2 = v.l2 & maskLow51Bits + v.l4 += v.l3 >> 51 + v.l3 = v.l3 & maskLow51Bits + // no additional carry + v.l4 = v.l4 & maskLow51Bits + + return v +} + +// Add sets v = a + b, and returns v. +func (v *Element) Add(a, b *Element) *Element { + v.l0 = a.l0 + b.l0 + v.l1 = a.l1 + b.l1 + v.l2 = a.l2 + b.l2 + v.l3 = a.l3 + b.l3 + v.l4 = a.l4 + b.l4 + // Using the generic implementation here is actually faster than the + // assembly. Probably because the body of this function is so simple that + // the compiler can figure out better optimizations by inlining the carry + // propagation. + return v.carryPropagateGeneric() +} + +// Subtract sets v = a - b, and returns v. +func (v *Element) Subtract(a, b *Element) *Element { + // We first add 2 * p, to guarantee the subtraction won't underflow, and + // then subtract b (which can be up to 2^255 + 2^13 * 19). + v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0 + v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1 + v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2 + v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3 + v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4 + return v.carryPropagate() +} + +// Negate sets v = -a, and returns v. +func (v *Element) Negate(a *Element) *Element { + return v.Subtract(feZero, a) +} + +// Invert sets v = 1/z mod p, and returns v. +// +// If z == 0, Invert returns v = 0. +func (v *Element) Invert(z *Element) *Element { + // Inversion is implemented as exponentiation with exponent p − 2. It uses the + // same sequence of 255 squarings and 11 multiplications as [Curve25519]. + var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element + + z2.Square(z) // 2 + t.Square(&z2) // 4 + t.Square(&t) // 8 + z9.Multiply(&t, z) // 9 + z11.Multiply(&z9, &z2) // 11 + t.Square(&z11) // 22 + z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0 + + t.Square(&z2_5_0) // 2^6 - 2^1 + for i := 0; i < 4; i++ { + t.Square(&t) // 2^10 - 2^5 + } + z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0 + + t.Square(&z2_10_0) // 2^11 - 2^1 + for i := 0; i < 9; i++ { + t.Square(&t) // 2^20 - 2^10 + } + z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0 + + t.Square(&z2_20_0) // 2^21 - 2^1 + for i := 0; i < 19; i++ { + t.Square(&t) // 2^40 - 2^20 + } + t.Multiply(&t, &z2_20_0) // 2^40 - 2^0 + + t.Square(&t) // 2^41 - 2^1 + for i := 0; i < 9; i++ { + t.Square(&t) // 2^50 - 2^10 + } + z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0 + + t.Square(&z2_50_0) // 2^51 - 2^1 + for i := 0; i < 49; i++ { + t.Square(&t) // 2^100 - 2^50 + } + z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0 + + t.Square(&z2_100_0) // 2^101 - 2^1 + for i := 0; i < 99; i++ { + t.Square(&t) // 2^200 - 2^100 + } + t.Multiply(&t, &z2_100_0) // 2^200 - 2^0 + + t.Square(&t) // 2^201 - 2^1 + for i := 0; i < 49; i++ { + t.Square(&t) // 2^250 - 2^50 + } + t.Multiply(&t, &z2_50_0) // 2^250 - 2^0 + + t.Square(&t) // 2^251 - 2^1 + t.Square(&t) // 2^252 - 2^2 + t.Square(&t) // 2^253 - 2^3 + t.Square(&t) // 2^254 - 2^4 + t.Square(&t) // 2^255 - 2^5 + + return v.Multiply(&t, &z11) // 2^255 - 21 +} + +// Set sets v = a, and returns v. +func (v *Element) Set(a *Element) *Element { + *v = *a + return v +} + +// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is +// not of the right length, SetBytes returns nil and an error, and the +// receiver is unchanged. +// +// Consistent with RFC 7748, the most significant bit (the high bit of the +// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1) +// are accepted. Note that this is laxer than specified by RFC 8032, but +// consistent with most Ed25519 implementations. +func (v *Element) SetBytes(x []byte) (*Element, error) { + if len(x) != 32 { + return nil, errors.New("edwards25519: invalid field element input size") + } + + // Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51). + v.l0 = binary.LittleEndian.Uint64(x[0:8]) + v.l0 &= maskLow51Bits + // Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51). + v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3 + v.l1 &= maskLow51Bits + // Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51). + v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6 + v.l2 &= maskLow51Bits + // Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51). + v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1 + v.l3 &= maskLow51Bits + // Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51). + // Note: not bytes 25:33, shift 4, to avoid overread. + v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12 + v.l4 &= maskLow51Bits + + return v, nil +} + +// Bytes returns the canonical 32-byte little-endian encoding of v. +func (v *Element) Bytes() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [32]byte + return v.bytes(&out) +} + +func (v *Element) bytes(out *[32]byte) []byte { + t := *v + t.reduce() + + var buf [8]byte + for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} { + bitsOffset := i * 51 + binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8)) + for i, bb := range buf { + off := bitsOffset/8 + i + if off >= len(out) { + break + } + out[off] |= bb + } + } + + return out[:] +} + +// Equal returns 1 if v and u are equal, and 0 otherwise. +func (v *Element) Equal(u *Element) int { + sa, sv := u.Bytes(), v.Bytes() + return subtle.ConstantTimeCompare(sa, sv) +} + +// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise. +func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) } + +// Select sets v to a if cond == 1, and to b if cond == 0. +func (v *Element) Select(a, b *Element, cond int) *Element { + m := mask64Bits(cond) + v.l0 = (m & a.l0) | (^m & b.l0) + v.l1 = (m & a.l1) | (^m & b.l1) + v.l2 = (m & a.l2) | (^m & b.l2) + v.l3 = (m & a.l3) | (^m & b.l3) + v.l4 = (m & a.l4) | (^m & b.l4) + return v +} + +// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v. +func (v *Element) Swap(u *Element, cond int) { + m := mask64Bits(cond) + t := m & (v.l0 ^ u.l0) + v.l0 ^= t + u.l0 ^= t + t = m & (v.l1 ^ u.l1) + v.l1 ^= t + u.l1 ^= t + t = m & (v.l2 ^ u.l2) + v.l2 ^= t + u.l2 ^= t + t = m & (v.l3 ^ u.l3) + v.l3 ^= t + u.l3 ^= t + t = m & (v.l4 ^ u.l4) + v.l4 ^= t + u.l4 ^= t +} + +// IsNegative returns 1 if v is negative, and 0 otherwise. +func (v *Element) IsNegative() int { + return int(v.Bytes()[0] & 1) +} + +// Absolute sets v to |u|, and returns v. +func (v *Element) Absolute(u *Element) *Element { + return v.Select(new(Element).Negate(u), u, u.IsNegative()) +} + +// Multiply sets v = x * y, and returns v. +func (v *Element) Multiply(x, y *Element) *Element { + feMul(v, x, y) + return v +} + +// Square sets v = x * x, and returns v. +func (v *Element) Square(x *Element) *Element { + feSquare(v, x) + return v +} + +// Mult32 sets v = x * y, and returns v. +func (v *Element) Mult32(x *Element, y uint32) *Element { + x0lo, x0hi := mul51(x.l0, y) + x1lo, x1hi := mul51(x.l1, y) + x2lo, x2hi := mul51(x.l2, y) + x3lo, x3hi := mul51(x.l3, y) + x4lo, x4hi := mul51(x.l4, y) + v.l0 = x0lo + 19*x4hi // carried over per the reduction identity + v.l1 = x1lo + x0hi + v.l2 = x2lo + x1hi + v.l3 = x3lo + x2hi + v.l4 = x4lo + x3hi + // The hi portions are going to be only 32 bits, plus any previous excess, + // so we can skip the carry propagation. + return v +} + +// mul51 returns lo + hi * 2⁵¹ = a * b. +func mul51(a uint64, b uint32) (lo uint64, hi uint64) { + mh, ml := bits.Mul64(a, uint64(b)) + lo = ml & maskLow51Bits + hi = (mh << 13) | (ml >> 51) + return +} + +// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3. +func (v *Element) Pow22523(x *Element) *Element { + var t0, t1, t2 Element + + t0.Square(x) // x^2 + t1.Square(&t0) // x^4 + t1.Square(&t1) // x^8 + t1.Multiply(x, &t1) // x^9 + t0.Multiply(&t0, &t1) // x^11 + t0.Square(&t0) // x^22 + t0.Multiply(&t1, &t0) // x^31 + t1.Square(&t0) // x^62 + for i := 1; i < 5; i++ { // x^992 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1 + t1.Square(&t0) // 2^11 - 2 + for i := 1; i < 10; i++ { // 2^20 - 2^10 + t1.Square(&t1) + } + t1.Multiply(&t1, &t0) // 2^20 - 1 + t2.Square(&t1) // 2^21 - 2 + for i := 1; i < 20; i++ { // 2^40 - 2^20 + t2.Square(&t2) + } + t1.Multiply(&t2, &t1) // 2^40 - 1 + t1.Square(&t1) // 2^41 - 2 + for i := 1; i < 10; i++ { // 2^50 - 2^10 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // 2^50 - 1 + t1.Square(&t0) // 2^51 - 2 + for i := 1; i < 50; i++ { // 2^100 - 2^50 + t1.Square(&t1) + } + t1.Multiply(&t1, &t0) // 2^100 - 1 + t2.Square(&t1) // 2^101 - 2 + for i := 1; i < 100; i++ { // 2^200 - 2^100 + t2.Square(&t2) + } + t1.Multiply(&t2, &t1) // 2^200 - 1 + t1.Square(&t1) // 2^201 - 2 + for i := 1; i < 50; i++ { // 2^250 - 2^50 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // 2^250 - 1 + t0.Square(&t0) // 2^251 - 2 + t0.Square(&t0) // 2^252 - 4 + return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3) +} + +// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion. +var sqrtM1 = &Element{1718705420411056, 234908883556509, + 2233514472574048, 2117202627021982, 765476049583133} + +// SqrtRatio sets r to the non-negative square root of the ratio of u and v. +// +// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio +// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00, +// and returns r and 0. +func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) { + t0 := new(Element) + + // r = (u * v3) * (u * v7)^((p-5)/8) + v2 := new(Element).Square(v) + uv3 := new(Element).Multiply(u, t0.Multiply(v2, v)) + uv7 := new(Element).Multiply(uv3, t0.Square(v2)) + rr := new(Element).Multiply(uv3, t0.Pow22523(uv7)) + + check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2 + + uNeg := new(Element).Negate(u) + correctSignSqrt := check.Equal(u) + flippedSignSqrt := check.Equal(uNeg) + flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1)) + + rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r + // r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r) + rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI) + + r.Absolute(rr) // Choose the nonnegative square root. + return r, correctSignSqrt | flippedSignSqrt +} diff --git a/src/crypto/internal/edwards25519/field/fe_alias_test.go b/src/crypto/internal/edwards25519/field/fe_alias_test.go new file mode 100644 index 0000000..bf1efdc --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_alias_test.go @@ -0,0 +1,140 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import ( + "testing" + "testing/quick" +) + +func checkAliasingOneArg(f func(v, x *Element) *Element) func(v, x Element) bool { + return func(v, x Element) bool { + x1, v1 := x, x + + // Calculate a reference f(x) without aliasing. + if out := f(&v, &x); out != &v && isInBounds(out) { + return false + } + + // Test aliasing the argument and the receiver. + if out := f(&v1, &v1); out != &v1 || v1 != v { + return false + } + + // Ensure the arguments was not modified. + return x == x1 + } +} + +func checkAliasingTwoArgs(f func(v, x, y *Element) *Element) func(v, x, y Element) bool { + return func(v, x, y Element) bool { + x1, y1, v1 := x, y, Element{} + + // Calculate a reference f(x, y) without aliasing. + if out := f(&v, &x, &y); out != &v && isInBounds(out) { + return false + } + + // Test aliasing the first argument and the receiver. + v1 = x + if out := f(&v1, &v1, &y); out != &v1 || v1 != v { + return false + } + // Test aliasing the second argument and the receiver. + v1 = y + if out := f(&v1, &x, &v1); out != &v1 || v1 != v { + return false + } + + // Calculate a reference f(x, x) without aliasing. + if out := f(&v, &x, &x); out != &v { + return false + } + + // Test aliasing the first argument and the receiver. + v1 = x + if out := f(&v1, &v1, &x); out != &v1 || v1 != v { + return false + } + // Test aliasing the second argument and the receiver. + v1 = x + if out := f(&v1, &x, &v1); out != &v1 || v1 != v { + return false + } + // Test aliasing both arguments and the receiver. + v1 = x + if out := f(&v1, &v1, &v1); out != &v1 || v1 != v { + return false + } + + // Ensure the arguments were not modified. + return x == x1 && y == y1 + } +} + +// TestAliasing checks that receivers and arguments can alias each other without +// leading to incorrect results. That is, it ensures that it's safe to write +// +// v.Invert(v) +// +// or +// +// v.Add(v, v) +// +// without any of the inputs getting clobbered by the output being written. +func TestAliasing(t *testing.T) { + type target struct { + name string + oneArgF func(v, x *Element) *Element + twoArgsF func(v, x, y *Element) *Element + } + for _, tt := range []target{ + {name: "Absolute", oneArgF: (*Element).Absolute}, + {name: "Invert", oneArgF: (*Element).Invert}, + {name: "Negate", oneArgF: (*Element).Negate}, + {name: "Set", oneArgF: (*Element).Set}, + {name: "Square", oneArgF: (*Element).Square}, + {name: "Pow22523", oneArgF: (*Element).Pow22523}, + { + name: "Mult32", + oneArgF: func(v, x *Element) *Element { + return v.Mult32(x, 0xffffffff) + }, + }, + {name: "Multiply", twoArgsF: (*Element).Multiply}, + {name: "Add", twoArgsF: (*Element).Add}, + {name: "Subtract", twoArgsF: (*Element).Subtract}, + { + name: "SqrtRatio", + twoArgsF: func(v, x, y *Element) *Element { + r, _ := v.SqrtRatio(x, y) + return r + }, + }, + { + name: "Select0", + twoArgsF: func(v, x, y *Element) *Element { + return v.Select(x, y, 0) + }, + }, + { + name: "Select1", + twoArgsF: func(v, x, y *Element) *Element { + return v.Select(x, y, 1) + }, + }, + } { + var err error + switch { + case tt.oneArgF != nil: + err = quick.Check(checkAliasingOneArg(tt.oneArgF), &quick.Config{MaxCountScale: 1 << 8}) + case tt.twoArgsF != nil: + err = quick.Check(checkAliasingTwoArgs(tt.twoArgsF), &quick.Config{MaxCountScale: 1 << 8}) + } + if err != nil { + t.Errorf("%v: %v", tt.name, err) + } + } +} diff --git a/src/crypto/internal/edwards25519/field/fe_amd64.go b/src/crypto/internal/edwards25519/field/fe_amd64.go new file mode 100644 index 0000000..70c5416 --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_amd64.go @@ -0,0 +1,15 @@ +// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. + +//go:build amd64 && gc && !purego + +package field + +// feMul sets out = a * b. It works like feMulGeneric. +// +//go:noescape +func feMul(out *Element, a *Element, b *Element) + +// feSquare sets out = a * a. It works like feSquareGeneric. +// +//go:noescape +func feSquare(out *Element, a *Element) diff --git a/src/crypto/internal/edwards25519/field/fe_amd64.s b/src/crypto/internal/edwards25519/field/fe_amd64.s new file mode 100644 index 0000000..60817ac --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_amd64.s @@ -0,0 +1,378 @@ +// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. + +//go:build amd64 && gc && !purego + +#include "textflag.h" + +// func feMul(out *Element, a *Element, b *Element) +TEXT ·feMul(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), CX + MOVQ b+16(FP), BX + + // r0 = a0×b0 + MOVQ (CX), AX + MULQ (BX) + MOVQ AX, DI + MOVQ DX, SI + + // r0 += 19×a1×b4 + MOVQ 8(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a2×b3 + MOVQ 16(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a3×b2 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 16(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a4×b1 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 8(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r1 = a0×b1 + MOVQ (CX), AX + MULQ 8(BX) + MOVQ AX, R9 + MOVQ DX, R8 + + // r1 += a1×b0 + MOVQ 8(CX), AX + MULQ (BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a2×b4 + MOVQ 16(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a3×b3 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a4×b2 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 16(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r2 = a0×b2 + MOVQ (CX), AX + MULQ 16(BX) + MOVQ AX, R11 + MOVQ DX, R10 + + // r2 += a1×b1 + MOVQ 8(CX), AX + MULQ 8(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += a2×b0 + MOVQ 16(CX), AX + MULQ (BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += 19×a3×b4 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += 19×a4×b3 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r3 = a0×b3 + MOVQ (CX), AX + MULQ 24(BX) + MOVQ AX, R13 + MOVQ DX, R12 + + // r3 += a1×b2 + MOVQ 8(CX), AX + MULQ 16(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += a2×b1 + MOVQ 16(CX), AX + MULQ 8(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += a3×b0 + MOVQ 24(CX), AX + MULQ (BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += 19×a4×b4 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r4 = a0×b4 + MOVQ (CX), AX + MULQ 32(BX) + MOVQ AX, R15 + MOVQ DX, R14 + + // r4 += a1×b3 + MOVQ 8(CX), AX + MULQ 24(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a2×b2 + MOVQ 16(CX), AX + MULQ 16(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a3×b1 + MOVQ 24(CX), AX + MULQ 8(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a4×b0 + MOVQ 32(CX), AX + MULQ (BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // First reduction chain + MOVQ $0x0007ffffffffffff, AX + SHLQ $0x0d, DI, SI + SHLQ $0x0d, R9, R8 + SHLQ $0x0d, R11, R10 + SHLQ $0x0d, R13, R12 + SHLQ $0x0d, R15, R14 + ANDQ AX, DI + IMUL3Q $0x13, R14, R14 + ADDQ R14, DI + ANDQ AX, R9 + ADDQ SI, R9 + ANDQ AX, R11 + ADDQ R8, R11 + ANDQ AX, R13 + ADDQ R10, R13 + ANDQ AX, R15 + ADDQ R12, R15 + + // Second reduction chain (carryPropagate) + MOVQ DI, SI + SHRQ $0x33, SI + MOVQ R9, R8 + SHRQ $0x33, R8 + MOVQ R11, R10 + SHRQ $0x33, R10 + MOVQ R13, R12 + SHRQ $0x33, R12 + MOVQ R15, R14 + SHRQ $0x33, R14 + ANDQ AX, DI + IMUL3Q $0x13, R14, R14 + ADDQ R14, DI + ANDQ AX, R9 + ADDQ SI, R9 + ANDQ AX, R11 + ADDQ R8, R11 + ANDQ AX, R13 + ADDQ R10, R13 + ANDQ AX, R15 + ADDQ R12, R15 + + // Store output + MOVQ out+0(FP), AX + MOVQ DI, (AX) + MOVQ R9, 8(AX) + MOVQ R11, 16(AX) + MOVQ R13, 24(AX) + MOVQ R15, 32(AX) + RET + +// func feSquare(out *Element, a *Element) +TEXT ·feSquare(SB), NOSPLIT, $0-16 + MOVQ a+8(FP), CX + + // r0 = l0×l0 + MOVQ (CX), AX + MULQ (CX) + MOVQ AX, SI + MOVQ DX, BX + + // r0 += 38×l1×l4 + MOVQ 8(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, SI + ADCQ DX, BX + + // r0 += 38×l2×l3 + MOVQ 16(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 24(CX) + ADDQ AX, SI + ADCQ DX, BX + + // r1 = 2×l0×l1 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 8(CX) + MOVQ AX, R8 + MOVQ DX, DI + + // r1 += 38×l2×l4 + MOVQ 16(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, R8 + ADCQ DX, DI + + // r1 += 19×l3×l3 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(CX) + ADDQ AX, R8 + ADCQ DX, DI + + // r2 = 2×l0×l2 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 16(CX) + MOVQ AX, R10 + MOVQ DX, R9 + + // r2 += l1×l1 + MOVQ 8(CX), AX + MULQ 8(CX) + ADDQ AX, R10 + ADCQ DX, R9 + + // r2 += 38×l3×l4 + MOVQ 24(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, R10 + ADCQ DX, R9 + + // r3 = 2×l0×l3 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 24(CX) + MOVQ AX, R12 + MOVQ DX, R11 + + // r3 += 2×l1×l2 + MOVQ 8(CX), AX + IMUL3Q $0x02, AX, AX + MULQ 16(CX) + ADDQ AX, R12 + ADCQ DX, R11 + + // r3 += 19×l4×l4 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(CX) + ADDQ AX, R12 + ADCQ DX, R11 + + // r4 = 2×l0×l4 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 32(CX) + MOVQ AX, R14 + MOVQ DX, R13 + + // r4 += 2×l1×l3 + MOVQ 8(CX), AX + IMUL3Q $0x02, AX, AX + MULQ 24(CX) + ADDQ AX, R14 + ADCQ DX, R13 + + // r4 += l2×l2 + MOVQ 16(CX), AX + MULQ 16(CX) + ADDQ AX, R14 + ADCQ DX, R13 + + // First reduction chain + MOVQ $0x0007ffffffffffff, AX + SHLQ $0x0d, SI, BX + SHLQ $0x0d, R8, DI + SHLQ $0x0d, R10, R9 + SHLQ $0x0d, R12, R11 + SHLQ $0x0d, R14, R13 + ANDQ AX, SI + IMUL3Q $0x13, R13, R13 + ADDQ R13, SI + ANDQ AX, R8 + ADDQ BX, R8 + ANDQ AX, R10 + ADDQ DI, R10 + ANDQ AX, R12 + ADDQ R9, R12 + ANDQ AX, R14 + ADDQ R11, R14 + + // Second reduction chain (carryPropagate) + MOVQ SI, BX + SHRQ $0x33, BX + MOVQ R8, DI + SHRQ $0x33, DI + MOVQ R10, R9 + SHRQ $0x33, R9 + MOVQ R12, R11 + SHRQ $0x33, R11 + MOVQ R14, R13 + SHRQ $0x33, R13 + ANDQ AX, SI + IMUL3Q $0x13, R13, R13 + ADDQ R13, SI + ANDQ AX, R8 + ADDQ BX, R8 + ANDQ AX, R10 + ADDQ DI, R10 + ANDQ AX, R12 + ADDQ R9, R12 + ANDQ AX, R14 + ADDQ R11, R14 + + // Store output + MOVQ out+0(FP), AX + MOVQ SI, (AX) + MOVQ R8, 8(AX) + MOVQ R10, 16(AX) + MOVQ R12, 24(AX) + MOVQ R14, 32(AX) + RET diff --git a/src/crypto/internal/edwards25519/field/fe_amd64_noasm.go b/src/crypto/internal/edwards25519/field/fe_amd64_noasm.go new file mode 100644 index 0000000..9da280d --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_amd64_noasm.go @@ -0,0 +1,11 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !amd64 || !gc || purego + +package field + +func feMul(v, x, y *Element) { feMulGeneric(v, x, y) } + +func feSquare(v, x *Element) { feSquareGeneric(v, x) } diff --git a/src/crypto/internal/edwards25519/field/fe_arm64.go b/src/crypto/internal/edwards25519/field/fe_arm64.go new file mode 100644 index 0000000..075fe9b --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_arm64.go @@ -0,0 +1,15 @@ +// Copyright (c) 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build arm64 && gc && !purego + +package field + +//go:noescape +func carryPropagate(v *Element) + +func (v *Element) carryPropagate() *Element { + carryPropagate(v) + return v +} diff --git a/src/crypto/internal/edwards25519/field/fe_arm64.s b/src/crypto/internal/edwards25519/field/fe_arm64.s new file mode 100644 index 0000000..751ab2a --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_arm64.s @@ -0,0 +1,42 @@ +// Copyright (c) 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm64,gc,!purego + +#include "textflag.h" + +// carryPropagate works exactly like carryPropagateGeneric and uses the +// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but +// avoids loading R0-R4 twice and uses LDP and STP. +// +// See https://golang.org/issues/43145 for the main compiler issue. +// +// func carryPropagate(v *Element) +TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8 + MOVD v+0(FP), R20 + + LDP 0(R20), (R0, R1) + LDP 16(R20), (R2, R3) + MOVD 32(R20), R4 + + AND $0x7ffffffffffff, R0, R10 + AND $0x7ffffffffffff, R1, R11 + AND $0x7ffffffffffff, R2, R12 + AND $0x7ffffffffffff, R3, R13 + AND $0x7ffffffffffff, R4, R14 + + ADD R0>>51, R11, R11 + ADD R1>>51, R12, R12 + ADD R2>>51, R13, R13 + ADD R3>>51, R14, R14 + // R4>>51 * 19 + R10 -> R10 + LSR $51, R4, R21 + MOVD $19, R22 + MADD R22, R10, R21, R10 + + STP (R10, R11), 0(R20) + STP (R12, R13), 16(R20) + MOVD R14, 32(R20) + + RET diff --git a/src/crypto/internal/edwards25519/field/fe_arm64_noasm.go b/src/crypto/internal/edwards25519/field/fe_arm64_noasm.go new file mode 100644 index 0000000..fc029ac --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_arm64_noasm.go @@ -0,0 +1,11 @@ +// Copyright (c) 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !arm64 || !gc || purego + +package field + +func (v *Element) carryPropagate() *Element { + return v.carryPropagateGeneric() +} diff --git a/src/crypto/internal/edwards25519/field/fe_bench_test.go b/src/crypto/internal/edwards25519/field/fe_bench_test.go new file mode 100644 index 0000000..84fdf05 --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_bench_test.go @@ -0,0 +1,49 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import "testing" + +func BenchmarkAdd(b *testing.B) { + x := new(Element).One() + y := new(Element).Add(x, x) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Add(x, y) + } +} + +func BenchmarkMultiply(b *testing.B) { + x := new(Element).One() + y := new(Element).Add(x, x) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Multiply(x, y) + } +} + +func BenchmarkSquare(b *testing.B) { + x := new(Element).Add(feOne, feOne) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Square(x) + } +} + +func BenchmarkInvert(b *testing.B) { + x := new(Element).Add(feOne, feOne) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Invert(x) + } +} + +func BenchmarkMult32(b *testing.B) { + x := new(Element).One() + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Mult32(x, 0xaa42aa42) + } +} diff --git a/src/crypto/internal/edwards25519/field/fe_generic.go b/src/crypto/internal/edwards25519/field/fe_generic.go new file mode 100644 index 0000000..d6667b2 --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_generic.go @@ -0,0 +1,266 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import "math/bits" + +// uint128 holds a 128-bit number as two 64-bit limbs, for use with the +// bits.Mul64 and bits.Add64 intrinsics. +type uint128 struct { + lo, hi uint64 +} + +// mul64 returns a * b. +func mul64(a, b uint64) uint128 { + hi, lo := bits.Mul64(a, b) + return uint128{lo, hi} +} + +// addMul64 returns v + a * b. +func addMul64(v uint128, a, b uint64) uint128 { + hi, lo := bits.Mul64(a, b) + lo, c := bits.Add64(lo, v.lo, 0) + hi, _ = bits.Add64(hi, v.hi, c) + return uint128{lo, hi} +} + +// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits. +func shiftRightBy51(a uint128) uint64 { + return (a.hi << (64 - 51)) | (a.lo >> 51) +} + +func feMulGeneric(v, a, b *Element) { + a0 := a.l0 + a1 := a.l1 + a2 := a.l2 + a3 := a.l3 + a4 := a.l4 + + b0 := b.l0 + b1 := b.l1 + b2 := b.l2 + b3 := b.l3 + b4 := b.l4 + + // Limb multiplication works like pen-and-paper columnar multiplication, but + // with 51-bit limbs instead of digits. + // + // a4 a3 a2 a1 a0 x + // b4 b3 b2 b1 b0 = + // ------------------------ + // a4b0 a3b0 a2b0 a1b0 a0b0 + + // a4b1 a3b1 a2b1 a1b1 a0b1 + + // a4b2 a3b2 a2b2 a1b2 a0b2 + + // a4b3 a3b3 a2b3 a1b3 a0b3 + + // a4b4 a3b4 a2b4 a1b4 a0b4 = + // ---------------------------------------------- + // r8 r7 r6 r5 r4 r3 r2 r1 r0 + // + // We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to + // reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5, + // r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc. + // + // Reduction can be carried out simultaneously to multiplication. For + // example, we do not compute r5: whenever the result of a multiplication + // belongs to r5, like a1b4, we multiply it by 19 and add the result to r0. + // + // a4b0 a3b0 a2b0 a1b0 a0b0 + + // a3b1 a2b1 a1b1 a0b1 19×a4b1 + + // a2b2 a1b2 a0b2 19×a4b2 19×a3b2 + + // a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 + + // a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 = + // -------------------------------------- + // r4 r3 r2 r1 r0 + // + // Finally we add up the columns into wide, overlapping limbs. + + a1_19 := a1 * 19 + a2_19 := a2 * 19 + a3_19 := a3 * 19 + a4_19 := a4 * 19 + + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + r0 := mul64(a0, b0) + r0 = addMul64(r0, a1_19, b4) + r0 = addMul64(r0, a2_19, b3) + r0 = addMul64(r0, a3_19, b2) + r0 = addMul64(r0, a4_19, b1) + + // r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2) + r1 := mul64(a0, b1) + r1 = addMul64(r1, a1, b0) + r1 = addMul64(r1, a2_19, b4) + r1 = addMul64(r1, a3_19, b3) + r1 = addMul64(r1, a4_19, b2) + + // r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3) + r2 := mul64(a0, b2) + r2 = addMul64(r2, a1, b1) + r2 = addMul64(r2, a2, b0) + r2 = addMul64(r2, a3_19, b4) + r2 = addMul64(r2, a4_19, b3) + + // r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4 + r3 := mul64(a0, b3) + r3 = addMul64(r3, a1, b2) + r3 = addMul64(r3, a2, b1) + r3 = addMul64(r3, a3, b0) + r3 = addMul64(r3, a4_19, b4) + + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + r4 := mul64(a0, b4) + r4 = addMul64(r4, a1, b3) + r4 = addMul64(r4, a2, b2) + r4 = addMul64(r4, a3, b1) + r4 = addMul64(r4, a4, b0) + + // After the multiplication, we need to reduce (carry) the five coefficients + // to obtain a result with limbs that are at most slightly larger than 2⁵¹, + // to respect the Element invariant. + // + // Overall, the reduction works the same as carryPropagate, except with + // wider inputs: we take the carry for each coefficient by shifting it right + // by 51, and add it to the limb above it. The top carry is multiplied by 19 + // according to the reduction identity and added to the lowest limb. + // + // The largest coefficient (r0) will be at most 111 bits, which guarantees + // that all carries are at most 111 - 51 = 60 bits, which fits in a uint64. + // + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + // r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²) + // r0 < (1 + 19 × 4) × 2⁵² × 2⁵² + // r0 < 2⁷ × 2⁵² × 2⁵² + // r0 < 2¹¹¹ + // + // Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most + // 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and + // allows us to easily apply the reduction identity. + // + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + // r4 < 5 × 2⁵² × 2⁵² + // r4 < 2¹⁰⁷ + // + + c0 := shiftRightBy51(r0) + c1 := shiftRightBy51(r1) + c2 := shiftRightBy51(r2) + c3 := shiftRightBy51(r3) + c4 := shiftRightBy51(r4) + + rr0 := r0.lo&maskLow51Bits + c4*19 + rr1 := r1.lo&maskLow51Bits + c0 + rr2 := r2.lo&maskLow51Bits + c1 + rr3 := r3.lo&maskLow51Bits + c2 + rr4 := r4.lo&maskLow51Bits + c3 + + // Now all coefficients fit into 64-bit registers but are still too large to + // be passed around as a Element. We therefore do one last carry chain, + // where the carries will be small enough to fit in the wiggle room above 2⁵¹. + *v = Element{rr0, rr1, rr2, rr3, rr4} + v.carryPropagate() +} + +func feSquareGeneric(v, a *Element) { + l0 := a.l0 + l1 := a.l1 + l2 := a.l2 + l3 := a.l3 + l4 := a.l4 + + // Squaring works precisely like multiplication above, but thanks to its + // symmetry we get to group a few terms together. + // + // l4 l3 l2 l1 l0 x + // l4 l3 l2 l1 l0 = + // ------------------------ + // l4l0 l3l0 l2l0 l1l0 l0l0 + + // l4l1 l3l1 l2l1 l1l1 l0l1 + + // l4l2 l3l2 l2l2 l1l2 l0l2 + + // l4l3 l3l3 l2l3 l1l3 l0l3 + + // l4l4 l3l4 l2l4 l1l4 l0l4 = + // ---------------------------------------------- + // r8 r7 r6 r5 r4 r3 r2 r1 r0 + // + // l4l0 l3l0 l2l0 l1l0 l0l0 + + // l3l1 l2l1 l1l1 l0l1 19×l4l1 + + // l2l2 l1l2 l0l2 19×l4l2 19×l3l2 + + // l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 + + // l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 = + // -------------------------------------- + // r4 r3 r2 r1 r0 + // + // With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with + // only three Mul64 and four Add64, instead of five and eight. + + l0_2 := l0 * 2 + l1_2 := l1 * 2 + + l1_38 := l1 * 38 + l2_38 := l2 * 38 + l3_38 := l3 * 38 + + l3_19 := l3 * 19 + l4_19 := l4 * 19 + + // r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3) + r0 := mul64(l0, l0) + r0 = addMul64(r0, l1_38, l4) + r0 = addMul64(r0, l2_38, l3) + + // r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3 + r1 := mul64(l0_2, l1) + r1 = addMul64(r1, l2_38, l4) + r1 = addMul64(r1, l3_19, l3) + + // r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4 + r2 := mul64(l0_2, l2) + r2 = addMul64(r2, l1, l1) + r2 = addMul64(r2, l3_38, l4) + + // r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4 + r3 := mul64(l0_2, l3) + r3 = addMul64(r3, l1_2, l2) + r3 = addMul64(r3, l4_19, l4) + + // r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2 + r4 := mul64(l0_2, l4) + r4 = addMul64(r4, l1_2, l3) + r4 = addMul64(r4, l2, l2) + + c0 := shiftRightBy51(r0) + c1 := shiftRightBy51(r1) + c2 := shiftRightBy51(r2) + c3 := shiftRightBy51(r3) + c4 := shiftRightBy51(r4) + + rr0 := r0.lo&maskLow51Bits + c4*19 + rr1 := r1.lo&maskLow51Bits + c0 + rr2 := r2.lo&maskLow51Bits + c1 + rr3 := r3.lo&maskLow51Bits + c2 + rr4 := r4.lo&maskLow51Bits + c3 + + *v = Element{rr0, rr1, rr2, rr3, rr4} + v.carryPropagate() +} + +// carryPropagate brings the limbs below 52 bits by applying the reduction +// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry. +func (v *Element) carryPropagateGeneric() *Element { + c0 := v.l0 >> 51 + c1 := v.l1 >> 51 + c2 := v.l2 >> 51 + c3 := v.l3 >> 51 + c4 := v.l4 >> 51 + + // c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and + // the final l0 will be at most 52 bits. Similarly for the rest. + v.l0 = v.l0&maskLow51Bits + c4*19 + v.l1 = v.l1&maskLow51Bits + c0 + v.l2 = v.l2&maskLow51Bits + c1 + v.l3 = v.l3&maskLow51Bits + c2 + v.l4 = v.l4&maskLow51Bits + c3 + + return v +} diff --git a/src/crypto/internal/edwards25519/field/fe_test.go b/src/crypto/internal/edwards25519/field/fe_test.go new file mode 100644 index 0000000..945a024 --- /dev/null +++ b/src/crypto/internal/edwards25519/field/fe_test.go @@ -0,0 +1,560 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "io" + "math/big" + "math/bits" + mathrand "math/rand" + "reflect" + "testing" + "testing/quick" +) + +func (v Element) String() string { + return hex.EncodeToString(v.Bytes()) +} + +// quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks) +// times. The default value of -quickchecks is 100. +var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10} + +func generateFieldElement(rand *mathrand.Rand) Element { + const maskLow52Bits = (1 << 52) - 1 + return Element{ + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + } +} + +// weirdLimbs can be combined to generate a range of edge-case field elements. +// 0 and -1 are intentionally more weighted, as they combine well. +var ( + weirdLimbs51 = []uint64{ + 0, 0, 0, 0, + 1, + 19 - 1, + 19, + 0x2aaaaaaaaaaaa, + 0x5555555555555, + (1 << 51) - 20, + (1 << 51) - 19, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + } + weirdLimbs52 = []uint64{ + 0, 0, 0, 0, 0, 0, + 1, + 19 - 1, + 19, + 0x2aaaaaaaaaaaa, + 0x5555555555555, + (1 << 51) - 20, + (1 << 51) - 19, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + 1 << 51, + (1 << 51) + 1, + (1 << 52) - 19, + (1 << 52) - 1, + } +) + +func generateWeirdFieldElement(rand *mathrand.Rand) Element { + return Element{ + weirdLimbs52[rand.Intn(len(weirdLimbs52))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + } +} + +func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value { + if rand.Intn(2) == 0 { + return reflect.ValueOf(generateWeirdFieldElement(rand)) + } + return reflect.ValueOf(generateFieldElement(rand)) +} + +// isInBounds returns whether the element is within the expected bit size bounds +// after a light reduction. +func isInBounds(x *Element) bool { + return bits.Len64(x.l0) <= 52 && + bits.Len64(x.l1) <= 52 && + bits.Len64(x.l2) <= 52 && + bits.Len64(x.l3) <= 52 && + bits.Len64(x.l4) <= 52 +} + +func TestMultiplyDistributesOverAdd(t *testing.T) { + multiplyDistributesOverAdd := func(x, y, z Element) bool { + // Compute t1 = (x+y)*z + t1 := new(Element) + t1.Add(&x, &y) + t1.Multiply(t1, &z) + + // Compute t2 = x*z + y*z + t2 := new(Element) + t3 := new(Element) + t2.Multiply(&x, &z) + t3.Multiply(&y, &z) + t2.Add(t2, t3) + + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) + } + + if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestMul64to128(t *testing.T) { + a := uint64(5) + b := uint64(5) + r := mul64(a, b) + if r.lo != 0x19 || r.hi != 0 { + t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) + } + + a = uint64(18014398509481983) // 2^54 - 1 + b = uint64(18014398509481983) // 2^54 - 1 + r = mul64(a, b) + if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff { + t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) + } + + a = uint64(1125899906842661) + b = uint64(2097155) + r = mul64(a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + if r.lo != 16888498990613035 || r.hi != 640 { + t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi) + } +} + +func TestSetBytesRoundTrip(t *testing.T) { + f1 := func(in [32]byte, fe Element) bool { + fe.SetBytes(in[:]) + + // Mask the most significant bit as it's ignored by SetBytes. (Now + // instead of earlier so we check the masking in SetBytes is working.) + in[len(in)-1] &= (1 << 7) - 1 + + return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe) + } + if err := quick.Check(f1, nil); err != nil { + t.Errorf("failed bytes->FE->bytes round-trip: %v", err) + } + + f2 := func(fe, r Element) bool { + r.SetBytes(fe.Bytes()) + + // Intentionally not using Equal not to go through Bytes again. + // Calling reduce because both Generate and SetBytes can produce + // non-canonical representations. + fe.reduce() + r.reduce() + return fe == r + } + if err := quick.Check(f2, nil); err != nil { + t.Errorf("failed FE->bytes->FE round-trip: %v", err) + } + + // Check some fixed vectors from dalek + type feRTTest struct { + fe Element + b []byte + } + var tests = []feRTTest{ + { + fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}, + b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31}, + }, + { + fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}, + b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122}, + }, + } + + for _, tt := range tests { + b := tt.fe.Bytes() + fe, _ := new(Element).SetBytes(tt.b) + if !bytes.Equal(b, tt.b) || fe.Equal(&tt.fe) != 1 { + t.Errorf("Failed fixed roundtrip: %v", tt) + } + } +} + +func swapEndianness(buf []byte) []byte { + for i := 0; i < len(buf)/2; i++ { + buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i] + } + return buf +} + +func TestBytesBigEquivalence(t *testing.T) { + f1 := func(in [32]byte, fe, fe1 Element) bool { + fe.SetBytes(in[:]) + + in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit + b := new(big.Int).SetBytes(swapEndianness(in[:])) + fe1.fromBig(b) + + if fe != fe1 { + return false + } + + buf := make([]byte, 32) + buf = swapEndianness(fe1.toBig().FillBytes(buf)) + + return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1) + } + if err := quick.Check(f1, nil); err != nil { + t.Error(err) + } +} + +// fromBig sets v = n, and returns v. The bit length of n must not exceed 256. +func (v *Element) fromBig(n *big.Int) *Element { + if n.BitLen() > 32*8 { + panic("edwards25519: invalid field element input size") + } + + buf := make([]byte, 0, 32) + for _, word := range n.Bits() { + for i := 0; i < bits.UintSize; i += 8 { + if len(buf) >= cap(buf) { + break + } + buf = append(buf, byte(word)) + word >>= 8 + } + } + + v.SetBytes(buf[:32]) + return v +} + +func (v *Element) fromDecimal(s string) *Element { + n, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("not a valid decimal: " + s) + } + return v.fromBig(n) +} + +// toBig returns v as a big.Int. +func (v *Element) toBig() *big.Int { + buf := v.Bytes() + + words := make([]big.Word, 32*8/bits.UintSize) + for n := range words { + for i := 0; i < bits.UintSize; i += 8 { + if len(buf) == 0 { + break + } + words[n] |= big.Word(buf[0]) << big.Word(i) + buf = buf[1:] + } + } + + return new(big.Int).SetBits(words) +} + +func TestDecimalConstants(t *testing.T) { + sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752" + if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 { + t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp) + } + // d is in the parent package, and we don't want to expose d or fromDecimal. + // dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555" + // if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 { + // t.Errorf("d is %v, expected %v", d, exp) + // } +} + +func TestSetBytesRoundTripEdgeCases(t *testing.T) { + // TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1, + // and between 2^255 and 2^256-1. Test both the documented SetBytes + // behavior, and that Bytes reduces them. +} + +// Tests self-consistency between Multiply and Square. +func TestConsistency(t *testing.T) { + var x Element + var x2, x2sq Element + + x = Element{1, 1, 1, 1, 1} + x2.Multiply(&x, &x) + x2sq.Square(&x) + + if x2 != x2sq { + t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) + } + + var bytes [32]byte + + _, err := io.ReadFull(rand.Reader, bytes[:]) + if err != nil { + t.Fatal(err) + } + x.SetBytes(bytes[:]) + + x2.Multiply(&x, &x) + x2sq.Square(&x) + + if x2 != x2sq { + t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) + } +} + +func TestEqual(t *testing.T) { + x := Element{1, 1, 1, 1, 1} + y := Element{5, 4, 3, 2, 1} + + eq := x.Equal(&x) + if eq != 1 { + t.Errorf("wrong about equality") + } + + eq = x.Equal(&y) + if eq != 0 { + t.Errorf("wrong about inequality") + } +} + +func TestInvert(t *testing.T) { + x := Element{1, 1, 1, 1, 1} + one := Element{1, 0, 0, 0, 0} + var xinv, r Element + + xinv.Invert(&x) + r.Multiply(&x, &xinv) + r.reduce() + + if one != r { + t.Errorf("inversion identity failed, got: %x", r) + } + + var bytes [32]byte + + _, err := io.ReadFull(rand.Reader, bytes[:]) + if err != nil { + t.Fatal(err) + } + x.SetBytes(bytes[:]) + + xinv.Invert(&x) + r.Multiply(&x, &xinv) + r.reduce() + + if one != r { + t.Errorf("random inversion identity failed, got: %x for field element %x", r, x) + } + + zero := Element{} + x.Set(&zero) + if xx := xinv.Invert(&x); xx != &xinv { + t.Errorf("inverting zero did not return the receiver") + } else if xinv.Equal(&zero) != 1 { + t.Errorf("inverting zero did not return zero") + } +} + +func TestSelectSwap(t *testing.T) { + a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676} + b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972} + + var c, d Element + + c.Select(&a, &b, 1) + d.Select(&a, &b, 0) + + if c.Equal(&a) != 1 || d.Equal(&b) != 1 { + t.Errorf("Select failed") + } + + c.Swap(&d, 0) + + if c.Equal(&a) != 1 || d.Equal(&b) != 1 { + t.Errorf("Swap failed") + } + + c.Swap(&d, 1) + + if c.Equal(&b) != 1 || d.Equal(&a) != 1 { + t.Errorf("Swap failed") + } +} + +func TestMult32(t *testing.T) { + mult32EquivalentToMul := func(x Element, y uint32) bool { + t1 := new(Element) + for i := 0; i < 100; i++ { + t1.Mult32(&x, y) + } + + ty := new(Element) + ty.l0 = uint64(y) + + t2 := new(Element) + for i := 0; i < 100; i++ { + t2.Multiply(&x, ty) + } + + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) + } + + if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestSqrtRatio(t *testing.T) { + // From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4. + type test struct { + u, v string + wasSquare int + r string + } + var tests = []test{ + // If u is 0, the function is defined to return (0, TRUE), even if v + // is zero. Note that where used in this package, the denominator v + // is never zero. + { + "0000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + 1, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // 0/1 == 0² + { + "0000000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 1, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // If u is non-zero and v is zero, defined to return (0, FALSE). + { + "0100000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + 0, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // 2/1 is not square in this field. + { + "0200000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54", + }, + // 4/1 == 2² + { + "0400000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 1, "0200000000000000000000000000000000000000000000000000000000000000", + }, + // 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem + { + "0100000000000000000000000000000000000000000000000000000000000000", + "0400000000000000000000000000000000000000000000000000000000000000", + 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f", + }, + } + + for i, tt := range tests { + u, _ := new(Element).SetBytes(decodeHex(tt.u)) + v, _ := new(Element).SetBytes(decodeHex(tt.v)) + want, _ := new(Element).SetBytes(decodeHex(tt.r)) + got, wasSquare := new(Element).SqrtRatio(u, v) + if got.Equal(want) == 0 || wasSquare != tt.wasSquare { + t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare) + } + } +} + +func TestCarryPropagate(t *testing.T) { + asmLikeGeneric := func(a [5]uint64) bool { + t1 := &Element{a[0], a[1], a[2], a[3], a[4]} + t2 := &Element{a[0], a[1], a[2], a[3], a[4]} + + t1.carryPropagate() + t2.carryPropagateGeneric() + + if *t1 != *t2 { + t.Logf("got: %#v,\nexpected: %#v", t1, t2) + } + + return *t1 == *t2 && isInBounds(t2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } + + if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) { + t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}") + } +} + +func TestFeSquare(t *testing.T) { + asmLikeGeneric := func(a Element) bool { + t1 := a + t2 := a + + feSquareGeneric(&t1, &t1) + feSquare(&t2, &t2) + + if t1 != t2 { + t.Logf("got: %#v,\nexpected: %#v", t1, t2) + } + + return t1 == t2 && isInBounds(&t2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestFeMul(t *testing.T) { + asmLikeGeneric := func(a, b Element) bool { + a1 := a + a2 := a + b1 := b + b2 := b + + feMulGeneric(&a1, &a1, &b1) + feMul(&a2, &a2, &b2) + + if a1 != a2 || b1 != b2 { + t.Logf("got: %#v,\nexpected: %#v", a1, a2) + t.Logf("got: %#v,\nexpected: %#v", b1, b2) + } + + return a1 == a2 && isInBounds(&a2) && + b1 == b2 && isInBounds(&b2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func decodeHex(s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} |