author    Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-16 19:23:18 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-16 19:23:18 +0000
commit    43a123c1ae6613b3efeed291fa552ecd909d3acf (patch)
tree      fd92518b7024bc74031f78a1cf9e454b65e73665 /src/crypto/internal/bigmod
parent    Initial commit. (diff)
Adding upstream version 1.20.14.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/crypto/internal/bigmod')
-rw-r--r--  src/crypto/internal/bigmod/_asm/go.mod           |  12
-rw-r--r--  src/crypto/internal/bigmod/_asm/go.sum           |  32
-rw-r--r--  src/crypto/internal/bigmod/_asm/nat_amd64_asm.go | 131
-rw-r--r--  src/crypto/internal/bigmod/nat.go                | 703
-rw-r--r--  src/crypto/internal/bigmod/nat_amd64.go          |   8
-rw-r--r--  src/crypto/internal/bigmod/nat_amd64.s           |  68
-rw-r--r--  src/crypto/internal/bigmod/nat_noasm.go          |  11
-rw-r--r--  src/crypto/internal/bigmod/nat_test.go           | 412
8 files changed, 1377 insertions, 0 deletions
diff --git a/src/crypto/internal/bigmod/_asm/go.mod b/src/crypto/internal/bigmod/_asm/go.mod
new file mode 100644
index 0000000..1ce2b5e
--- /dev/null
+++ b/src/crypto/internal/bigmod/_asm/go.mod
@@ -0,0 +1,12 @@
+module asm
+
+go 1.19
+
+require github.com/mmcloughlin/avo v0.4.0
+
+require (
+ golang.org/x/mod v0.4.2 // indirect
+ golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 // indirect
+ golang.org/x/tools v0.1.7 // indirect
+ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+)
diff --git a/src/crypto/internal/bigmod/_asm/go.sum b/src/crypto/internal/bigmod/_asm/go.sum
new file mode 100644
index 0000000..b4b5914
--- /dev/null
+++ b/src/crypto/internal/bigmod/_asm/go.sum
@@ -0,0 +1,32 @@
+github.com/mmcloughlin/avo v0.4.0 h1:jeHDRktVD+578ULxWpQHkilor6pkdLF7u7EiTzDbfcU=
+github.com/mmcloughlin/avo v0.4.0/go.mod h1:RW9BfYA3TgO9uCdNrKU2h6J8cPD8ZLznvfgHAeszb1s=
+github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
+golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo=
+golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 h1:giLT+HuUP/gXYrG2Plg9WTjj4qhfgaW424ZIFog3rlk=
+golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
+golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go b/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go
new file mode 100644
index 0000000..5690f04
--- /dev/null
+++ b/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go
@@ -0,0 +1,131 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ . "github.com/mmcloughlin/avo/build"
+ . "github.com/mmcloughlin/avo/operand"
+ . "github.com/mmcloughlin/avo/reg"
+)
+
+//go:generate go run . -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod
+
+func main() {
+ Package("crypto/internal/bigmod")
+ ConstraintExpr("amd64,gc,!purego")
+
+ Implement("montgomeryLoop")
+ Pragma("noescape")
+
+ size := Load(Param("d").Len(), GP64())
+ d := Mem{Base: Load(Param("d").Base(), GP64())}
+ b := Mem{Base: Load(Param("b").Base(), GP64())}
+ m := Mem{Base: Load(Param("m").Base(), GP64())}
+ m0inv := Load(Param("m0inv"), GP64())
+
+ overflow := zero()
+ i := zero()
+ Label("outerLoop")
+
+ ai := Load(Param("a").Base(), GP64())
+ MOVQ(Mem{Base: ai}.Idx(i, 8), ai)
+
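+ // Unrolled j = 0 iteration of the inner loop below: compute
+ // z = d[0] + a[i]*b[0] + f*m[0], deriving f from z's low word and
+ // m0inv, and keep carry = z_hi<<1 | z_lo>>_W.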
+ z := uint128{GP64(), GP64()}
+ mul64(z, b, ai)
+ add64(z, d)
+ f := GP64()
+ MOVQ(m0inv, f)
+ IMULQ(z.lo, f)
+ _MASK(f)
+ addMul64(z, m, f)
+ carry := shiftBy63(z)
+
+ j := zero()
+ INCQ(j)
+ JMP(LabelRef("innerLoopCondition"))
+ Label("innerLoop")
+
+ // z = d[j] + a[i] * b[j] + f * m[j] + carry
+ z = uint128{GP64(), GP64()}
+ mul64(z, b.Idx(j, 8), ai)
+ addMul64(z, m.Idx(j, 8), f)
+ add64(z, d.Idx(j, 8))
+ add64(z, carry)
+ // d[j-1] = z_lo & _MASK
+ storeMasked(z.lo, d.Idx(j, 8).Offset(-8))
+ // carry = z_hi<<1 | z_lo>>_W
+ MOVQ(shiftBy63(z), carry)
+
+ INCQ(j)
+ Label("innerLoopCondition")
+ CMPQ(size, j)
+ JGT(LabelRef("innerLoop"))
+
+ ADDQ(carry, overflow)
+ storeMasked(overflow, d.Idx(size, 8).Offset(-8))
+ SHRQ(Imm(63), overflow)
+
+ INCQ(i)
+ CMPQ(size, i)
+ JGT(LabelRef("outerLoop"))
+
+ Store(overflow, ReturnIndex(0))
+ RET()
+ Generate()
+}
+
+// zero zeroes a new register and returns it.
+func zero() Register {
+ r := GP64()
+ XORQ(r, r)
+ return r
+}
+
+// _MASK masks out the top bit of r.
+func _MASK(r Register) {
+ BTRQ(Imm(63), r)
+}
+
+type uint128 struct {
+ hi, lo GPVirtual
+}
+
+// storeMasked stores _MASK(src) in dst. It doesn't modify src.
+func storeMasked(src, dst Op) {
+ out := GP64()
+ MOVQ(src, out)
+ _MASK(out)
+ MOVQ(out, dst)
+}
+
+// shiftBy63 returns z >> 63. It reuses z.lo.
+func shiftBy63(z uint128) Register {
+ SHRQ(Imm(63), z.hi, z.lo)
+ result := z.lo
+ z.hi, z.lo = nil, nil
+ return result
+}
+
+// add64 sets r to r + a.
+func add64(r uint128, a Op) {
+ ADDQ(a, r.lo)
+ ADCQ(Imm(0), r.hi)
+}
+
+// mul64 sets r to a * b.
+func mul64(r uint128, a, b Op) {
+ MOVQ(a, RAX)
+ MULQ(b) // RDX, RAX = RAX * b
+ MOVQ(RAX, r.lo)
+ MOVQ(RDX, r.hi)
+}
+
+// addMul64 sets r to r + a * b.
+func addMul64(r uint128, a, b Op) {
+ MOVQ(a, RAX)
+ MULQ(b) // RDX, RAX = RAX * b
+ ADDQ(RAX, r.lo)
+ ADCQ(RDX, r.hi)
+}
diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go
new file mode 100644
index 0000000..804316f
--- /dev/null
+++ b/src/crypto/internal/bigmod/nat.go
@@ -0,0 +1,703 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bigmod
+
+import (
+ "errors"
+ "math/big"
+ "math/bits"
+)
+
+const (
+ // _W is the number of bits we use for our limbs.
+ _W = bits.UintSize - 1
+ // _MASK selects _W bits from a full machine word.
+ _MASK = (1 << _W) - 1
+)
+
+// choice represents a constant-time boolean. The value of choice is always
+// either 1 or 0. We use a uint instead of bool in order to make decisions in
+// constant time by turning it into a mask.
+type choice uint
+
+func not(c choice) choice { return 1 ^ c }
+
+const yes = choice(1)
+const no = choice(0)
+
+// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
+// function does not depend on its inputs. If on is any value besides 1 or 0,
+// the result is undefined.
+func ctSelect(on choice, x, y uint) uint {
+ // When on == 1, mask is 0b111..., otherwise mask is 0b000...
+ mask := -uint(on)
+ // When mask is all zeros, we just have y, otherwise, y cancels with itself.
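+ // For example, on == 1 gives y ^ (y ^ x) == x, and on == 0 gives y ^ 0 == y.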
+ return y ^ (mask & (y ^ x))
+}
+
+// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
+// function does not depend on its inputs.
+func ctEq(x, y uint) choice {
+ // If x != y, then either x - y or y - x will generate a carry.
+ _, c1 := bits.Sub(x, y, 0)
+ _, c2 := bits.Sub(y, x, 0)
+ return not(choice(c1 | c2))
+}
+
+// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
+// function does not depend on its inputs.
+func ctGeq(x, y uint) choice {
+ // If x < y, then x - y generates a carry.
+ _, carry := bits.Sub(x, y, 0)
+ return not(choice(carry))
+}
+
+// Nat represents an arbitrary natural number
+//
+// Each Nat has an announced length, which is the number of limbs it has stored.
+// Operations on this number are allowed to leak this length, but will not leak
+// any information about the values contained in those limbs.
+type Nat struct {
+ // limbs is a little-endian representation in base 2^W with
+ // W = bits.UintSize - 1. The top bit is always unset between operations.
+ //
+ // The top bit is left unset to optimize Montgomery multiplication, in the
+ // inner loop of exponentiation. Using fully saturated limbs would leave us
+ // working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
+ // and thus time.
+ limbs []uint
+}
+
+// preallocTarget is the size in bits of the numbers used to implement the most
+// common and most performant RSA key size. It's also enough to cover some of
+// the operations of key sizes up to 4096.
+const preallocTarget = 2048
+const preallocLimbs = (preallocTarget + _W - 1) / _W
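+// For example, on 64-bit platforms _W = 63, so preallocLimbs = (2048 + 62) / 63 = 33.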
+
+// NewNat returns a new nat with a size of zero, just like new(Nat), but with
+// the preallocated capacity to hold a number of up to preallocTarget bits.
+// NewNat inlines, so the allocation can live on the stack.
+func NewNat() *Nat {
+ limbs := make([]uint, 0, preallocLimbs)
+ return &Nat{limbs}
+}
+
+// expand expands x to n limbs, leaving its value unchanged.
+func (x *Nat) expand(n int) *Nat {
+ if len(x.limbs) > n {
+ panic("bigmod: internal error: shrinking nat")
+ }
+ if cap(x.limbs) < n {
+ newLimbs := make([]uint, n)
+ copy(newLimbs, x.limbs)
+ x.limbs = newLimbs
+ return x
+ }
+ extraLimbs := x.limbs[len(x.limbs):n]
+ for i := range extraLimbs {
+ extraLimbs[i] = 0
+ }
+ x.limbs = x.limbs[:n]
+ return x
+}
+
+// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
+func (x *Nat) reset(n int) *Nat {
+ if cap(x.limbs) < n {
+ x.limbs = make([]uint, n)
+ return x
+ }
+ for i := range x.limbs {
+ x.limbs[i] = 0
+ }
+ x.limbs = x.limbs[:n]
+ return x
+}
+
+// set assigns x = y, optionally resizing x to the appropriate size.
+func (x *Nat) set(y *Nat) *Nat {
+ x.reset(len(y.limbs))
+ copy(x.limbs, y.limbs)
+ return x
+}
+
+// setBig assigns x = n, optionally resizing x to the appropriate size.
+//
+// The announced length of x is set based on the actual bit size of the input,
+// ignoring leading zeroes.
+func (x *Nat) setBig(n *big.Int) *Nat {
+ requiredLimbs := (n.BitLen() + _W - 1) / _W
+ x.reset(requiredLimbs)
+
+ outI := 0
+ shift := 0
+ limbs := n.Bits()
+ for i := range limbs {
+ xi := uint(limbs[i])
+ x.limbs[outI] |= (xi << shift) & _MASK
+ outI++
+ if outI == requiredLimbs {
+ return x
+ }
+ x.limbs[outI] = xi >> (_W - shift)
+ shift++ // this assumes bits.UintSize - _W = 1
+ if shift == _W {
+ shift = 0
+ outI++
+ }
+ }
+ return x
+}
+
+// Bytes returns x as a zero-extended big-endian byte slice. The size of the
+// slice will match the size of m.
+//
+// x must have the same size as m and it must be reduced modulo m.
+func (x *Nat) Bytes(m *Modulus) []byte {
+ bytes := make([]byte, m.Size())
+ shift := 0
+ outI := len(bytes) - 1
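+ // Write limbs out from the least significant end. Since _W is not a
+ // multiple of 8, a limb usually ends mid-byte; shift tracks how many
+ // bits of the current byte the previous limb already filled.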
+ for _, limb := range x.limbs {
+ remainingBits := _W
+ for remainingBits >= 8 {
+ bytes[outI] |= byte(limb) << shift
+ consumed := 8 - shift
+ limb >>= consumed
+ remainingBits -= consumed
+ shift = 0
+ outI--
+ if outI < 0 {
+ return bytes
+ }
+ }
+ bytes[outI] = byte(limb)
+ shift = remainingBits
+ }
+ return bytes
+}
+
+// SetBytes assigns x = b, where b is a slice of big-endian bytes.
+// SetBytes returns an error if b >= m.
+//
+// The output will be resized to the size of m and overwritten.
+func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
+ if err := x.setBytes(b, m); err != nil {
+ return nil, err
+ }
+ if x.cmpGeq(m.nat) == yes {
+ return nil, errors.New("input overflows the modulus")
+ }
+ return x, nil
+}
+
+// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
+// SetOverflowingBytes returns an error if b has a longer bit length than m, but
+// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
+//
+// The output will be resized to the size of m and overwritten.
+func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
+ if err := x.setBytes(b, m); err != nil {
+ return nil, err
+ }
+ leading := _W - bitLen(x.limbs[len(x.limbs)-1])
+ if leading < m.leading {
+ return nil, errors.New("input overflows the modulus")
+ }
+ x.sub(x.cmpGeq(m.nat), m.nat)
+ return x, nil
+}
+
+func (x *Nat) setBytes(b []byte, m *Modulus) error {
+ outI := 0
+ shift := 0
+ x.resetFor(m)
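+ // Pack the big-endian bytes into little-endian _W-bit limbs, where
+ // shift is the bit offset of the next byte within the current limb.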
+ for i := len(b) - 1; i >= 0; i-- {
+ bi := b[i]
+ x.limbs[outI] |= uint(bi) << shift
+ shift += 8
+ if shift >= _W {
+ shift -= _W
+ x.limbs[outI] &= _MASK
+ overflow := bi >> (8 - shift)
+ outI++
+ if outI >= len(x.limbs) {
+ if overflow > 0 || i > 0 {
+ return errors.New("input overflows the modulus")
+ }
+ break
+ }
+ x.limbs[outI] = uint(overflow)
+ }
+ }
+ return nil
+}
+
+// Equal returns 1 if x == y, and 0 otherwise.
+//
+// Both operands must have the same announced length.
+func (x *Nat) Equal(y *Nat) choice {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ equal := yes
+ for i := 0; i < size; i++ {
+ equal &= ctEq(xLimbs[i], yLimbs[i])
+ }
+ return equal
+}
+
+// IsZero returns 1 if x == 0, and 0 otherwise.
+func (x *Nat) IsZero() choice {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+
+ zero := yes
+ for i := 0; i < size; i++ {
+ zero &= ctEq(xLimbs[i], 0)
+ }
+ return zero
+}
+
+// cmpGeq returns 1 if x >= y, and 0 otherwise.
+//
+// Both operands must have the same announced length.
+func (x *Nat) cmpGeq(y *Nat) choice {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ var c uint
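+ // Each limb fits in _W bits of the machine word, so the borrow of
+ // xLimbs[i] - yLimbs[i] - c lands in the top bit, and the shift
+ // extracts it as 0 or 1.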
+ for i := 0; i < size; i++ {
+ c = (xLimbs[i] - yLimbs[i] - c) >> _W
+ }
+ // If there was a carry, then subtracting y underflowed, so
+ // x is not greater than or equal to y.
+ return not(choice(c))
+}
+
+// assign sets x <- y if on == 1, and does nothing otherwise.
+//
+// Both operands must have the same announced length.
+func (x *Nat) assign(on choice, y *Nat) *Nat {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ for i := 0; i < size; i++ {
+ xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
+ }
+ return x
+}
+
+// add computes x += y if on == 1, and does nothing otherwise. It returns the
+// carry of the addition regardless of on.
+//
+// Both operands must have the same announced length.
+func (x *Nat) add(on choice, y *Nat) (c uint) {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ for i := 0; i < size; i++ {
+ res := xLimbs[i] + yLimbs[i] + c
+ xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
+ c = res >> _W
+ }
+ return
+}
+
+// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
+// borrow of the subtraction regardless of on.
+//
+// Both operands must have the same announced length.
+func (x *Nat) sub(on choice, y *Nat) (c uint) {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ for i := 0; i < size; i++ {
+ res := xLimbs[i] - yLimbs[i] - c
+ xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
+ c = res >> _W
+ }
+ return
+}
+
+// Modulus is used for modular arithmetic, precomputing relevant constants.
+//
+// Moduli are assumed to be odd numbers. Moduli can also leak the exact
+// number of bits needed to store their value, and are stored without padding.
+//
+// Their actual value is still kept secret.
+type Modulus struct {
+ // The underlying natural number for this modulus.
+ //
+ // This will be stored without any padding, and shouldn't alias with any
+ // other natural number being used.
+ nat *Nat
+ leading int // number of leading zeros in the modulus
+ m0inv uint // -nat.limbs[0]⁻¹ mod 2^_W
+ rr *Nat // R*R for montgomeryRepresentation
+}
+
+// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
+func rr(m *Modulus) *Nat {
+ rr := NewNat().ExpandFor(m)
+ // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
+ // most significant limb to 1. We then get to R*R by shifting left by _W,
+ // n + 1 times.
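+ // That is, rr = 2^(_W * (n - 1)) * 2^(_W * (n + 1)) = 2^(2 * _W * n) = R*R (mod m).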
+ n := len(rr.limbs)
+ rr.limbs[n-1] = 1
+ for i := n - 1; i < 2*n; i++ {
+ rr.shiftIn(0, m) // x = x * 2^_W mod m
+ }
+ return rr
+}
+
+// minusInverseModW computes -x⁻¹ mod 2^_W with x odd.
+//
+// This operation is used to precompute a constant involved in Montgomery
+// multiplication.
+func minusInverseModW(x uint) uint {
+ // Every iteration of this loop doubles the number of correct
+ // least-significant bits of the inverse in y. The first three bits are
+ // already correct (1⁻¹ = 1, 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so
+ // doubling five times is enough for 63 bits (and wastes only one
+ // iteration for 31 bits).
+ //
+ // See https://crypto.stackexchange.com/a/47496.
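+ //
+ // Concretely, if x*y ≡ 1 (mod 2^k), then for y' = y*(2 - x*y) we have
+ // x*y' - 1 = -(x*y - 1)², so y' is correct mod 2^2k.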
+ y := x
+ for i := 0; i < 5; i++ {
+ y = y * (2 - x*y)
+ }
+ return (1 << _W) - (y & _MASK)
+}
+
+// NewModulusFromBig creates a new Modulus from a [big.Int].
+//
+// The Int must be odd. The number of significant bits must be leakable.
+func NewModulusFromBig(n *big.Int) *Modulus {
+ m := &Modulus{}
+ m.nat = NewNat().setBig(n)
+ m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
+ m.m0inv = minusInverseModW(m.nat.limbs[0])
+ m.rr = rr(m)
+ return m
+}
+
+// bitLen is a version of bits.Len that only leaks the bit length of n, but not
+// its value. bits.Len and bits.LeadingZeros use a lookup table for the
+// low-order bits on some architectures.
+func bitLen(n uint) int {
+ var len int
+ // We assume, here and elsewhere, that comparison to zero is constant time
+ // with respect to different non-zero values.
+ for n != 0 {
+ len++
+ n >>= 1
+ }
+ return len
+}
+
+// Size returns the size of m in bytes.
+func (m *Modulus) Size() int {
+ return (m.BitLen() + 7) / 8
+}
+
+// BitLen returns the size of m in bits.
+func (m *Modulus) BitLen() int {
+ return len(m.nat.limbs)*_W - int(m.leading)
+}
+
+// Nat returns m as a Nat. The return value must not be written to.
+func (m *Modulus) Nat() *Nat {
+ return m.nat
+}
+
+// shiftIn calculates x = x << _W + y mod m.
+//
+// This assumes that x is already reduced mod m, and that y < 2^_W.
+func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
+ d := NewNat().resetFor(m)
+
+ // Eliminate bounds checks in the loop.
+ size := len(m.nat.limbs)
+ xLimbs := x.limbs[:size]
+ dLimbs := d.limbs[:size]
+ mLimbs := m.nat.limbs[:size]
+
+ // Each iteration of this loop computes x = 2x + b mod m, where b is a bit
+ // from y. Effectively, it left-shifts x and adds y one bit at a time,
+ // reducing it every time.
+ //
+ // To do the reduction, each iteration computes both 2x + b and 2x + b - m.
+ // The next iteration (and finally the return line) will use either result
+ // based on whether the subtraction underflowed.
+ needSubtraction := no
+ for i := _W - 1; i >= 0; i-- {
+ carry := (y >> i) & 1
+ var borrow uint
+ for i := 0; i < size; i++ {
+ l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
+
+ res := l<<1 + carry
+ xLimbs[i] = res & _MASK
+ carry = res >> _W
+
+ res = xLimbs[i] - mLimbs[i] - borrow
+ dLimbs[i] = res & _MASK
+ borrow = res >> _W
+ }
+ // See Add for how carry (aka overflow), borrow (aka underflow), and
+ // needSubtraction relate.
+ needSubtraction = ctEq(carry, borrow)
+ }
+ return x.assign(needSubtraction, d)
+}
+
+// Mod calculates out = x mod m.
+//
+// This works regardless of how large the value of x is.
+//
+// The output will be resized to the size of m and overwritten.
+func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
+ out.resetFor(m)
+ // Working our way from the most significant to the least significant limb,
+ // we can insert each limb at the least significant position, shifting all
+ // previous limbs left by _W. This way each limb will get shifted by the
+ // correct number of bits. We can insert at least N - 1 limbs without
+ // overflowing m. After that, we need to reduce every time we shift.
+ i := len(x.limbs) - 1
+ // For the first N - 1 limbs we can skip the actual shifting and position
+ // them at the shifted position, which starts at min(N - 2, i).
+ start := len(m.nat.limbs) - 2
+ if i < start {
+ start = i
+ }
+ for j := start; j >= 0; j-- {
+ out.limbs[j] = x.limbs[i]
+ i--
+ }
+ // We shift in the remaining limbs, reducing modulo m each time.
+ for i >= 0 {
+ out.shiftIn(x.limbs[i], m)
+ i--
+ }
+ return out
+}
+
+// ExpandFor ensures out has the right size to work with operations modulo m.
+//
+// The announced size of out must be smaller than or equal to that of m.
+func (out *Nat) ExpandFor(m *Modulus) *Nat {
+ return out.expand(len(m.nat.limbs))
+}
+
+// resetFor ensures out has the right size to work with operations modulo m.
+//
+// out is zeroed and may start at any size.
+func (out *Nat) resetFor(m *Modulus) *Nat {
+ return out.reset(len(m.nat.limbs))
+}
+
+// Sub computes x = x - y mod m.
+//
+// The length of both operands must be the same as the modulus. Both operands
+// must already be reduced modulo m.
+func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
+ underflow := x.sub(yes, y)
+ // If the subtraction underflowed, add m.
+ x.add(choice(underflow), m.nat)
+ return x
+}
+
+// Add computes x = x + y mod m.
+//
+// The length of both operands must be the same as the modulus. Both operands
+// must already be reduced modulo m.
+func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
+ overflow := x.add(yes, y)
+ underflow := not(x.cmpGeq(m.nat)) // x < m
+
+ // Three cases are possible:
+ //
+ // - overflow = 0, underflow = 0
+ //
+ // In this case, addition fits in our limbs, but we can still subtract away
+ // m without an underflow, so we need to perform the subtraction to reduce
+ // our result.
+ //
+ // - overflow = 0, underflow = 1
+ //
+ // The addition fits in our limbs, but we can't subtract m without
+ // underflowing. The result is already reduced.
+ //
+ // - overflow = 1, underflow = 1
+ //
+ // The addition does not fit in our limbs, and the subtraction's borrow
+ // would cancel out with the addition's carry. We need to subtract m to
+ // reduce our result.
+ //
+ // The overflow = 1, underflow = 0 case is not possible, because y is at
+ // most m - 1, and if adding m - 1 overflows, then subtracting m must
+ // necessarily underflow.
+ needSubtraction := ctEq(overflow, uint(underflow))
+
+ x.sub(needSubtraction, m.nat)
+ return x
+}
+
+// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs).
+//
+// Faster Montgomery multiplication replaces standard modular multiplication for
+// numbers in this representation.
+//
+// This assumes that x is already reduced mod m.
+func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
+ // A Montgomery multiplication (which computes a * b / R) by R * R works out
+ // to a multiplication by R, which takes the value out of the Montgomery domain.
+ return x.montgomeryMul(NewNat().set(x), m.rr, m)
+}
+
+// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs).
+//
+// This assumes that x is already reduced mod m.
+func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
+ // By Montgomery multiplying with 1 not in Montgomery representation, we
+ // convert out back from Montgomery representation, because it works out to
+ // dividing by R.
+ t0 := NewNat().set(x)
+ t1 := NewNat().ExpandFor(m)
+ t1.limbs[0] = 1
+ return x.montgomeryMul(t0, t1, m)
+}
+
+// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
+//
+// All inputs should be the same length, not aliasing d, and already
+// reduced modulo m. d will be resized to the size of m and overwritten.
+func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
+ d.resetFor(m)
+ if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
+ panic("bigmod: invalid montgomeryMul input")
+ }
+
+ // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
+ // for a description of the algorithm implemented mostly in montgomeryLoop.
+ // See Add for how overflow, underflow, and needSubtraction relate.
+ overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
+ underflow := not(d.cmpGeq(m.nat)) // d < m
+ needSubtraction := ctEq(overflow, uint(underflow))
+ d.sub(needSubtraction, m.nat)
+
+ return d
+}
+
+func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
+ // Eliminate bounds checks in the loop.
+ size := len(d)
+ a = a[:size]
+ b = b[:size]
+ m = m[:size]
+
+ for _, ai := range a {
+ // This is an unrolled iteration of the loop below with j = 0.
+ hi, lo := bits.Mul(ai, b[0])
+ z_lo, c := bits.Add(d[0], lo, 0)
+ f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
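+ // f is chosen so that d[0] + a[i]*b[0] + f*m[0] ≡ 0 mod 2^_W, since
+ // m0inv = -m[0]⁻¹ mod 2^_W; discarding the resulting zero low limb
+ // divides by 2^_W, which over the whole outer loop divides by R.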
+ z_hi, _ := bits.Add(0, hi, c)
+ hi, lo = bits.Mul(f, m[0])
+ z_lo, c = bits.Add(z_lo, lo, 0)
+ z_hi, _ = bits.Add(z_hi, hi, c)
+ carry := z_hi<<1 | z_lo>>_W
+
+ for j := 1; j < size; j++ {
+ // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
+ hi, lo := bits.Mul(ai, b[j])
+ z_lo, c := bits.Add(d[j], lo, 0)
+ z_hi, _ := bits.Add(0, hi, c)
+ hi, lo = bits.Mul(f, m[j])
+ z_lo, c = bits.Add(z_lo, lo, 0)
+ z_hi, _ = bits.Add(z_hi, hi, c)
+ z_lo, c = bits.Add(z_lo, carry, 0)
+ z_hi, _ = bits.Add(z_hi, 0, c)
+ d[j-1] = z_lo & _MASK
+ carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
+ }
+
+ z := overflow + carry // z <= 2^(W+1) - 1
+ d[size-1] = z & _MASK
+ overflow = z >> _W // overflow <= 1
+ }
+ return
+}
+
+// Mul calculates x *= y mod m.
+//
+// x and y must already be reduced modulo m, they must share its announced
+// length, and they may not alias.
+func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
+ // A Montgomery multiplication by a value out of the Montgomery domain
+ // takes the result out of Montgomery representation.
+ xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+ return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
+}
+
+// Exp calculates out = x^e mod m.
+//
+// The exponent e is represented in big-endian order. The output will be resized
+// to the size of m and overwritten. x must already be reduced modulo m.
+func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
+ // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
+ // than 2 bit windows, but use an extra 12 nats worth of scratch space.
+	// Using bit sizes that don't divide 8 is more complex to implement.
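+	//
+	// For example, an exponent byte 0xB3 is processed as two windows: four
+	// squarings then a multiplication by x^0xB, then four more squarings
+	// and a multiplication by x^0x3.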
+
+ table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
+ // newNat calls are unrolled so they are allocated on the stack.
+ NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+ NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+ NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+ }
+ table[0].set(x).montgomeryRepresentation(m)
+ for i := 1; i < len(table); i++ {
+ table[i].montgomeryMul(table[i-1], table[0], m)
+ }
+
+ out.resetFor(m)
+ out.limbs[0] = 1
+ out.montgomeryRepresentation(m)
+ t0 := NewNat().ExpandFor(m)
+ t1 := NewNat().ExpandFor(m)
+ for _, b := range e {
+ for _, j := range []int{4, 0} {
+ // Square four times.
+ t1.montgomeryMul(out, out, m)
+ out.montgomeryMul(t1, t1, m)
+ t1.montgomeryMul(out, out, m)
+ out.montgomeryMul(t1, t1, m)
+
+ // Select x^k in constant time from the table.
+ k := uint((b >> j) & 0b1111)
+ for i := range table {
+ t0.assign(ctEq(k, uint(i+1)), table[i])
+ }
+
+ // Multiply by x^k, discarding the result if k = 0.
+ t1.montgomeryMul(out, t0, m)
+ out.assign(not(ctEq(k, 0)), t1)
+ }
+ }
+
+ return out.montgomeryReduction(m)
+}
diff --git a/src/crypto/internal/bigmod/nat_amd64.go b/src/crypto/internal/bigmod/nat_amd64.go
new file mode 100644
index 0000000..e947782
--- /dev/null
+++ b/src/crypto/internal/bigmod/nat_amd64.go
@@ -0,0 +1,8 @@
+// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT.
+
+//go:build amd64 && gc && !purego
+
+package bigmod
+
+//go:noescape
+func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint
diff --git a/src/crypto/internal/bigmod/nat_amd64.s b/src/crypto/internal/bigmod/nat_amd64.s
new file mode 100644
index 0000000..12b7629
--- /dev/null
+++ b/src/crypto/internal/bigmod/nat_amd64.s
@@ -0,0 +1,68 @@
+// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT.
+
+//go:build amd64 && gc && !purego
+
+// func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint
+TEXT ·montgomeryLoop(SB), $8-112
+ MOVQ d_len+8(FP), CX
+ MOVQ d_base+0(FP), BX
+ MOVQ b_base+48(FP), SI
+ MOVQ m_base+72(FP), DI
+ MOVQ m0inv+96(FP), R8
+ XORQ R9, R9
+ XORQ R10, R10
+
+outerLoop:
+ MOVQ a_base+24(FP), R11
+ MOVQ (R11)(R10*8), R11
+ MOVQ (SI), AX
+ MULQ R11
+ MOVQ AX, R13
+ MOVQ DX, R12
+ ADDQ (BX), R13
+ ADCQ $0x00, R12
+ MOVQ R8, R14
+ IMULQ R13, R14
+ BTRQ $0x3f, R14
+ MOVQ (DI), AX
+ MULQ R14
+ ADDQ AX, R13
+ ADCQ DX, R12
+ SHRQ $0x3f, R12, R13
+ XORQ R12, R12
+ INCQ R12
+ JMP innerLoopCondition
+
+innerLoop:
+ MOVQ (SI)(R12*8), AX
+ MULQ R11
+ MOVQ AX, BP
+ MOVQ DX, R15
+ MOVQ (DI)(R12*8), AX
+ MULQ R14
+ ADDQ AX, BP
+ ADCQ DX, R15
+ ADDQ (BX)(R12*8), BP
+ ADCQ $0x00, R15
+ ADDQ R13, BP
+ ADCQ $0x00, R15
+ MOVQ BP, AX
+ BTRQ $0x3f, AX
+ MOVQ AX, -8(BX)(R12*8)
+ SHRQ $0x3f, R15, BP
+ MOVQ BP, R13
+ INCQ R12
+
+innerLoopCondition:
+ CMPQ CX, R12
+ JGT innerLoop
+ ADDQ R13, R9
+ MOVQ R9, AX
+ BTRQ $0x3f, AX
+ MOVQ AX, -8(BX)(CX*8)
+ SHRQ $0x3f, R9
+ INCQ R10
+ CMPQ CX, R10
+ JGT outerLoop
+ MOVQ R9, ret+104(FP)
+ RET
diff --git a/src/crypto/internal/bigmod/nat_noasm.go b/src/crypto/internal/bigmod/nat_noasm.go
new file mode 100644
index 0000000..870b445
--- /dev/null
+++ b/src/crypto/internal/bigmod/nat_noasm.go
@@ -0,0 +1,11 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !amd64 || !gc || purego
+
+package bigmod
+
+func montgomeryLoop(d, a, b, m []uint, m0inv uint) uint {
+ return montgomeryLoopGeneric(d, a, b, m, m0inv)
+}
diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go
new file mode 100644
index 0000000..6431d25
--- /dev/null
+++ b/src/crypto/internal/bigmod/nat_test.go
@@ -0,0 +1,412 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bigmod
+
+import (
+ "math/big"
+ "math/bits"
+ "math/rand"
+ "reflect"
+ "testing"
+ "testing/quick"
+)
+
+// Generate generates an even nat. It's used by testing/quick to produce random
+// *nat values for quick.Check invocations.
+func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
+ limbs := make([]uint, size)
+ for i := 0; i < size; i++ {
+ limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
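+ // Clearing the bottom bit of each limb makes the value even, so tests
+ // can use value+1 as an odd modulus (see testMontgomeryRoundtrip); the
+ // mask also keeps each limb below 2^_W.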
+ }
+ return reflect.ValueOf(&Nat{limbs})
+}
+
+func testModAddCommutative(a *Nat, b *Nat) bool {
+ m := maxModulus(uint(len(a.limbs)))
+ aPlusB := new(Nat).set(a)
+ aPlusB.Add(b, m)
+ bPlusA := new(Nat).set(b)
+ bPlusA.Add(a, m)
+ return aPlusB.Equal(bPlusA) == 1
+}
+
+func TestModAddCommutative(t *testing.T) {
+ err := quick.Check(testModAddCommutative, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
+ m := maxModulus(uint(len(a.limbs)))
+ original := new(Nat).set(a)
+ a.Sub(b, m)
+ a.Add(b, m)
+ return a.Equal(original) == 1
+}
+
+func TestModSubThenAddIdentity(t *testing.T) {
+ err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testMontgomeryRoundtrip(a *Nat) bool {
+ one := &Nat{make([]uint, len(a.limbs))}
+ one.limbs[0] = 1
+ aPlusOne := new(big.Int).SetBytes(natBytes(a))
+ aPlusOne.Add(aPlusOne, big.NewInt(1))
+ m := NewModulusFromBig(aPlusOne)
+ monty := new(Nat).set(a)
+ monty.montgomeryRepresentation(m)
+ aAgain := new(Nat).set(monty)
+ aAgain.montgomeryMul(monty, one, m)
+ return a.Equal(aAgain) == 1
+}
+
+func TestMontgomeryRoundtrip(t *testing.T) {
+ err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestShiftIn(t *testing.T) {
+ if bits.UintSize != 64 {
+ t.Skip("examples are only valid in 64 bit")
+ }
+ examples := []struct {
+ m, x, expected []byte
+ y uint64
+ }{{
+ m: []byte{13},
+ x: []byte{0},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{7},
+ }, {
+ m: []byte{13},
+ x: []byte{7},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{11},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: make([]byte, 9),
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ y: 0,
+ expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+ }}
+
+ for i, tt := range examples {
+ m := modulusFromBytes(tt.m)
+ got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
+ if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+ }
+ }
+}
+
+func TestModulusAndNatSizes(t *testing.T) {
+ // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
+ // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
+ // limbs, if they are not, they fit in three. This can be a problem because
+ // modulus strips leading zeroes and nat does not.
+ m := modulusFromBytes([]byte{
+ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+ xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
+ natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
+ NewNat().SetBytes(xb, m)
+}
+
+func TestSetBytes(t *testing.T) {
+ tests := []struct {
+ m, b []byte
+ fail bool
+ }{{
+ m: []byte{0xff, 0xff},
+ b: []byte{0x00, 0x01},
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0xff},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x00, 0x01},
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ fail: true,
+ }}
+
+ for i, tt := range tests {
+ m := modulusFromBytes(tt.m)
+ got, err := NewNat().SetBytes(tt.b, m)
+ if err != nil {
+ if !tt.fail {
+ t.Errorf("%d: unexpected error: %v", i, err)
+ }
+ continue
+ }
+ if err == nil && tt.fail {
+ t.Errorf("%d: unexpected success", i)
+ continue
+ }
+ if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
+ t.Errorf("%d: got %x, expected %x", i, got, expected)
+ }
+ }
+
+ f := func(xBytes []byte) bool {
+ m := maxModulus(uint(len(xBytes)*8/_W + 1))
+ got, err := NewNat().SetBytes(xBytes, m)
+ if err != nil {
+ return false
+ }
+ return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
+ }
+
+ err := quick.Check(f, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestExpand(t *testing.T) {
+ sliced := []uint{1, 2, 3, 4}
+ examples := []struct {
+ in []uint
+ n int
+ out []uint
+ }{{
+ []uint{1, 2},
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ sliced[:2],
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ []uint{1, 2},
+ 2,
+ []uint{1, 2},
+ }}
+
+ for i, tt := range examples {
+ got := (&Nat{tt.in}).expand(tt.n)
+ if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+ }
+ }
+}
+
+func TestMod(t *testing.T) {
+ m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
+ x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+ out := new(Nat)
+ out.Mod(x, m)
+ expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+ if out.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func TestModSub(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{6}}
+ y := &Nat{[]uint{7}}
+ x.Sub(y, m)
+ expected := &Nat{[]uint{12}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.Sub(y, m)
+ expected = &Nat{[]uint{5}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestModAdd(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{6}}
+ y := &Nat{[]uint{7}}
+ x.Add(y, m)
+ expected := &Nat{[]uint{0}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.Add(y, m)
+ expected = &Nat{[]uint{7}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestExp(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{3}}
+ out := &Nat{[]uint{0}}
+ out.Exp(x, []byte{12}, m)
+ expected := &Nat{[]uint{1}}
+ if out.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func natBytes(n *Nat) []byte {
+ return n.Bytes(maxModulus(uint(len(n.limbs))))
+}
+
+func natFromBytes(b []byte) *Nat {
+ bb := new(big.Int).SetBytes(b)
+ return NewNat().setBig(bb)
+}
+
+func modulusFromBytes(b []byte) *Modulus {
+ bb := new(big.Int).SetBytes(b)
+ return NewModulusFromBig(bb)
+}
+
+// maxModulus returns the biggest modulus that can fit in n limbs.
+func maxModulus(n uint) *Modulus {
+ m := big.NewInt(1)
+ m.Lsh(m, n*_W)
+ m.Sub(m, big.NewInt(1))
+ return NewModulusFromBig(m)
+}
+
+func makeBenchmarkModulus() *Modulus {
+ return maxModulus(32)
+}
+
+func makeBenchmarkValue() *Nat {
+ x := make([]uint, 32)
+ for i := 0; i < 32; i++ {
+ x[i] = _MASK - 1
+ }
+ return &Nat{limbs: x}
+}
+
+func makeBenchmarkExponent() []byte {
+ e := make([]byte, 256)
+ for i := 0; i < 32; i++ {
+ e[i] = 0xFF
+ }
+ return e
+}
+
+func BenchmarkModAdd(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Add(y, m)
+ }
+}
+
+func BenchmarkModSub(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Sub(y, m)
+ }
+}
+
+func BenchmarkMontgomeryRepr(b *testing.B) {
+ x := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.montgomeryRepresentation(m)
+ }
+}
+
+func BenchmarkMontgomeryMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.montgomeryMul(x, y, m)
+ }
+}
+
+func BenchmarkModMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Mul(y, m)
+ }
+}
+
+func BenchmarkExpBig(b *testing.B) {
+ out := new(big.Int)
+ exponentBytes := makeBenchmarkExponent()
+ x := new(big.Int).SetBytes(exponentBytes)
+ e := new(big.Int).SetBytes(exponentBytes)
+ n := new(big.Int).SetBytes(exponentBytes)
+ one := new(big.Int).SetUint64(1)
+ n.Add(n, one)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, n)
+ }
+}
+
+func BenchmarkExp(b *testing.B) {
+ x := makeBenchmarkValue()
+ e := makeBenchmarkExponent()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, m)
+ }
+}