diff options
Diffstat (limited to 'src/crypto/internal/bigmod')
-rw-r--r-- | src/crypto/internal/bigmod/_asm/go.mod | 12 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/_asm/go.sum | 32 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/_asm/nat_amd64_asm.go | 113 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat.go | 770 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_386.s | 47 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_amd64.s | 1230 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_arm.s | 47 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_arm64.s | 69 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_asm.go | 28 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_noasm.go | 21 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_ppc64x.s | 51 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_s390x.s | 85 | ||||
-rw-r--r-- | src/crypto/internal/bigmod/nat_test.go | 480 |
13 files changed, 2985 insertions, 0 deletions
diff --git a/src/crypto/internal/bigmod/_asm/go.mod b/src/crypto/internal/bigmod/_asm/go.mod new file mode 100644 index 0000000..7600a4a --- /dev/null +++ b/src/crypto/internal/bigmod/_asm/go.mod @@ -0,0 +1,12 @@ +module std/crypto/internal/bigmod/_asm + +go 1.19 + +require github.com/mmcloughlin/avo v0.4.0 + +require ( + golang.org/x/mod v0.4.2 // indirect + golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 // indirect + golang.org/x/tools v0.1.7 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect +) diff --git a/src/crypto/internal/bigmod/_asm/go.sum b/src/crypto/internal/bigmod/_asm/go.sum new file mode 100644 index 0000000..b4b5914 --- /dev/null +++ b/src/crypto/internal/bigmod/_asm/go.sum @@ -0,0 +1,32 @@ +github.com/mmcloughlin/avo v0.4.0 h1:jeHDRktVD+578ULxWpQHkilor6pkdLF7u7EiTzDbfcU= +github.com/mmcloughlin/avo v0.4.0/go.mod h1:RW9BfYA3TgO9uCdNrKU2h6J8cPD8ZLznvfgHAeszb1s= +github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 h1:giLT+HuUP/gXYrG2Plg9WTjj4qhfgaW424ZIFog3rlk= +golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ= +golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go b/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go new file mode 100644 index 0000000..bf64565 --- /dev/null +++ b/src/crypto/internal/bigmod/_asm/nat_amd64_asm.go @@ -0,0 +1,113 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "strconv" + + . "github.com/mmcloughlin/avo/build" + . "github.com/mmcloughlin/avo/operand" + . "github.com/mmcloughlin/avo/reg" +) + +//go:generate go run . -out ../nat_amd64.s -pkg bigmod + +func main() { + Package("crypto/internal/bigmod") + ConstraintExpr("!purego") + + addMulVVW(1024) + addMulVVW(1536) + addMulVVW(2048) + + Generate() +} + +func addMulVVW(bits int) { + if bits%64 != 0 { + panic("bit size unsupported") + } + + Implement("addMulVVW" + strconv.Itoa(bits)) + + CMPB(Mem{Symbol: Symbol{Name: "·supportADX"}, Base: StaticBase}, Imm(1)) + JEQ(LabelRef("adx")) + + z := Mem{Base: Load(Param("z"), GP64())} + x := Mem{Base: Load(Param("x"), GP64())} + y := Load(Param("y"), GP64()) + + carry := GP64() + XORQ(carry, carry) // zero out carry + + for i := 0; i < bits/64; i++ { + Comment("Iteration " + strconv.Itoa(i)) + hi, lo := RDX, RAX // implicit MULQ inputs and outputs + MOVQ(x.Offset(i*8), lo) + MULQ(y) + ADDQ(z.Offset(i*8), lo) + ADCQ(Imm(0), hi) + ADDQ(carry, lo) + ADCQ(Imm(0), hi) + MOVQ(hi, carry) + MOVQ(lo, z.Offset(i*8)) + } + + Store(carry, ReturnIndex(0)) + RET() + + Label("adx") + + // The ADX strategy implements the following function, where c1 and c2 are + // the overflow and the carry flag respectively. + // + // func addMulVVW(z, x []uint, y uint) (carry uint) { + // var c1, c2 uint + // for i := range z { + // hi, lo := bits.Mul(x[i], y) + // lo, c1 = bits.Add(lo, z[i], c1) + // z[i], c2 = bits.Add(lo, carry, c2) + // carry = hi + // } + // return carry + c1 + c2 + // } + // + // The loop is fully unrolled and the hi / carry registers are alternated + // instead of introducing a MOV. + + z = Mem{Base: Load(Param("z"), GP64())} + x = Mem{Base: Load(Param("x"), GP64())} + Load(Param("y"), RDX) // implicit source of MULXQ + + carry = GP64() + XORQ(carry, carry) // zero out carry + z0 := GP64() + XORQ(z0, z0) // unset flags and zero out z0 + + for i := 0; i < bits/64; i++ { + hi, lo := GP64(), GP64() + + Comment("Iteration " + strconv.Itoa(i)) + MULXQ(x.Offset(i*8), lo, hi) + ADCXQ(carry, lo) + ADOXQ(z.Offset(i*8), lo) + MOVQ(lo, z.Offset(i*8)) + + i++ + + Comment("Iteration " + strconv.Itoa(i)) + MULXQ(x.Offset(i*8), lo, carry) + ADCXQ(hi, lo) + ADOXQ(z.Offset(i*8), lo) + MOVQ(lo, z.Offset(i*8)) + } + + Comment("Add back carry flags and return") + ADCXQ(z0, carry) + ADOXQ(z0, carry) + + Store(carry, ReturnIndex(0)) + RET() +} diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go new file mode 100644 index 0000000..5605e9f --- /dev/null +++ b/src/crypto/internal/bigmod/nat.go @@ -0,0 +1,770 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bigmod + +import ( + "encoding/binary" + "errors" + "math/big" + "math/bits" +) + +const ( + // _W is the size in bits of our limbs. + _W = bits.UintSize + // _S is the size in bytes of our limbs. + _S = _W / 8 +) + +// choice represents a constant-time boolean. The value of choice is always +// either 1 or 0. We use an int instead of bool in order to make decisions in +// constant time by turning it into a mask. +type choice uint + +func not(c choice) choice { return 1 ^ c } + +const yes = choice(1) +const no = choice(0) + +// ctMask is all 1s if on is yes, and all 0s otherwise. +func ctMask(on choice) uint { return -uint(on) } + +// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this +// function does not depend on its inputs. +func ctEq(x, y uint) choice { + // If x != y, then either x - y or y - x will generate a carry. + _, c1 := bits.Sub(x, y, 0) + _, c2 := bits.Sub(y, x, 0) + return not(choice(c1 | c2)) +} + +// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this +// function does not depend on its inputs. +func ctGeq(x, y uint) choice { + // If x < y, then x - y generates a carry. + _, carry := bits.Sub(x, y, 0) + return not(choice(carry)) +} + +// Nat represents an arbitrary natural number +// +// Each Nat has an announced length, which is the number of limbs it has stored. +// Operations on this number are allowed to leak this length, but will not leak +// any information about the values contained in those limbs. +type Nat struct { + // limbs is little-endian in base 2^W with W = bits.UintSize. + limbs []uint +} + +// preallocTarget is the size in bits of the numbers used to implement the most +// common and most performant RSA key size. It's also enough to cover some of +// the operations of key sizes up to 4096. +const preallocTarget = 2048 +const preallocLimbs = (preallocTarget + _W - 1) / _W + +// NewNat returns a new nat with a size of zero, just like new(Nat), but with +// the preallocated capacity to hold a number of up to preallocTarget bits. +// NewNat inlines, so the allocation can live on the stack. +func NewNat() *Nat { + limbs := make([]uint, 0, preallocLimbs) + return &Nat{limbs} +} + +// expand expands x to n limbs, leaving its value unchanged. +func (x *Nat) expand(n int) *Nat { + if len(x.limbs) > n { + panic("bigmod: internal error: shrinking nat") + } + if cap(x.limbs) < n { + newLimbs := make([]uint, n) + copy(newLimbs, x.limbs) + x.limbs = newLimbs + return x + } + extraLimbs := x.limbs[len(x.limbs):n] + for i := range extraLimbs { + extraLimbs[i] = 0 + } + x.limbs = x.limbs[:n] + return x +} + +// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). +func (x *Nat) reset(n int) *Nat { + if cap(x.limbs) < n { + x.limbs = make([]uint, n) + return x + } + for i := range x.limbs { + x.limbs[i] = 0 + } + x.limbs = x.limbs[:n] + return x +} + +// set assigns x = y, optionally resizing x to the appropriate size. +func (x *Nat) set(y *Nat) *Nat { + x.reset(len(y.limbs)) + copy(x.limbs, y.limbs) + return x +} + +// setBig assigns x = n, optionally resizing n to the appropriate size. +// +// The announced length of x is set based on the actual bit size of the input, +// ignoring leading zeroes. +func (x *Nat) setBig(n *big.Int) *Nat { + limbs := n.Bits() + x.reset(len(limbs)) + for i := range limbs { + x.limbs[i] = uint(limbs[i]) + } + return x +} + +// Bytes returns x as a zero-extended big-endian byte slice. The size of the +// slice will match the size of m. +// +// x must have the same size as m and it must be reduced modulo m. +func (x *Nat) Bytes(m *Modulus) []byte { + i := m.Size() + bytes := make([]byte, i) + for _, limb := range x.limbs { + for j := 0; j < _S; j++ { + i-- + if i < 0 { + if limb == 0 { + break + } + panic("bigmod: modulus is smaller than nat") + } + bytes[i] = byte(limb) + limb >>= 8 + } + } + return bytes +} + +// SetBytes assigns x = b, where b is a slice of big-endian bytes. +// SetBytes returns an error if b >= m. +// +// The output will be resized to the size of m and overwritten. +func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { + if err := x.setBytes(b, m); err != nil { + return nil, err + } + if x.cmpGeq(m.nat) == yes { + return nil, errors.New("input overflows the modulus") + } + return x, nil +} + +// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. +// SetOverflowingBytes returns an error if b has a longer bit length than m, but +// reduces overflowing values up to 2^⌈log2(m)⌉ - 1. +// +// The output will be resized to the size of m and overwritten. +func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { + if err := x.setBytes(b, m); err != nil { + return nil, err + } + leading := _W - bitLen(x.limbs[len(x.limbs)-1]) + if leading < m.leading { + return nil, errors.New("input overflows the modulus size") + } + x.maybeSubtractModulus(no, m) + return x, nil +} + +// bigEndianUint returns the contents of buf interpreted as a +// big-endian encoded uint value. +func bigEndianUint(buf []byte) uint { + if _W == 64 { + return uint(binary.BigEndian.Uint64(buf)) + } + return uint(binary.BigEndian.Uint32(buf)) +} + +func (x *Nat) setBytes(b []byte, m *Modulus) error { + x.resetFor(m) + i, k := len(b), 0 + for k < len(x.limbs) && i >= _S { + x.limbs[k] = bigEndianUint(b[i-_S : i]) + i -= _S + k++ + } + for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 { + x.limbs[k] |= uint(b[i-1]) << s + i-- + } + if i > 0 { + return errors.New("input overflows the modulus size") + } + return nil +} + +// Equal returns 1 if x == y, and 0 otherwise. +// +// Both operands must have the same announced length. +func (x *Nat) Equal(y *Nat) choice { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + equal := yes + for i := 0; i < size; i++ { + equal &= ctEq(xLimbs[i], yLimbs[i]) + } + return equal +} + +// IsZero returns 1 if x == 0, and 0 otherwise. +func (x *Nat) IsZero() choice { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + + zero := yes + for i := 0; i < size; i++ { + zero &= ctEq(xLimbs[i], 0) + } + return zero +} + +// cmpGeq returns 1 if x >= y, and 0 otherwise. +// +// Both operands must have the same announced length. +func (x *Nat) cmpGeq(y *Nat) choice { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + var c uint + for i := 0; i < size; i++ { + _, c = bits.Sub(xLimbs[i], yLimbs[i], c) + } + // If there was a carry, then subtracting y underflowed, so + // x is not greater than or equal to y. + return not(choice(c)) +} + +// assign sets x <- y if on == 1, and does nothing otherwise. +// +// Both operands must have the same announced length. +func (x *Nat) assign(on choice, y *Nat) *Nat { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + mask := ctMask(on) + for i := 0; i < size; i++ { + xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i]) + } + return x +} + +// add computes x += y and returns the carry. +// +// Both operands must have the same announced length. +func (x *Nat) add(y *Nat) (c uint) { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + for i := 0; i < size; i++ { + xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c) + } + return +} + +// sub computes x -= y. It returns the borrow of the subtraction. +// +// Both operands must have the same announced length. +func (x *Nat) sub(y *Nat) (c uint) { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + for i := 0; i < size; i++ { + xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c) + } + return +} + +// Modulus is used for modular arithmetic, precomputing relevant constants. +// +// Moduli are assumed to be odd numbers. Moduli can also leak the exact +// number of bits needed to store their value, and are stored without padding. +// +// Their actual value is still kept secret. +type Modulus struct { + // The underlying natural number for this modulus. + // + // This will be stored without any padding, and shouldn't alias with any + // other natural number being used. + nat *Nat + leading int // number of leading zeros in the modulus + m0inv uint // -nat.limbs[0]⁻¹ mod _W + rr *Nat // R*R for montgomeryRepresentation +} + +// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). +func rr(m *Modulus) *Nat { + rr := NewNat().ExpandFor(m) + // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the + // most significant limb to 1. We then get to R*R by shifting left by _W + // n + 1 times. + n := len(rr.limbs) + rr.limbs[n-1] = 1 + for i := n - 1; i < 2*n; i++ { + rr.shiftIn(0, m) // x = x * 2^_W mod m + } + return rr +} + +// minusInverseModW computes -x⁻¹ mod _W with x odd. +// +// This operation is used to precompute a constant involved in Montgomery +// multiplication. +func minusInverseModW(x uint) uint { + // Every iteration of this loop doubles the least-significant bits of + // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, + // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough + // for 64 bits (and wastes only one iteration for 32 bits). + // + // See https://crypto.stackexchange.com/a/47496. + y := x + for i := 0; i < 5; i++ { + y = y * (2 - x*y) + } + return -y +} + +// NewModulusFromBig creates a new Modulus from a [big.Int]. +// +// The Int must be odd. The number of significant bits (and nothing else) is +// leaked through timing side-channels. +func NewModulusFromBig(n *big.Int) (*Modulus, error) { + if b := n.Bits(); len(b) == 0 { + return nil, errors.New("modulus must be >= 0") + } else if b[0]&1 != 1 { + return nil, errors.New("modulus must be odd") + } + m := &Modulus{} + m.nat = NewNat().setBig(n) + m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) + m.m0inv = minusInverseModW(m.nat.limbs[0]) + m.rr = rr(m) + return m, nil +} + +// bitLen is a version of bits.Len that only leaks the bit length of n, but not +// its value. bits.Len and bits.LeadingZeros use a lookup table for the +// low-order bits on some architectures. +func bitLen(n uint) int { + var len int + // We assume, here and elsewhere, that comparison to zero is constant time + // with respect to different non-zero values. + for n != 0 { + len++ + n >>= 1 + } + return len +} + +// Size returns the size of m in bytes. +func (m *Modulus) Size() int { + return (m.BitLen() + 7) / 8 +} + +// BitLen returns the size of m in bits. +func (m *Modulus) BitLen() int { + return len(m.nat.limbs)*_W - int(m.leading) +} + +// Nat returns m as a Nat. The return value must not be written to. +func (m *Modulus) Nat() *Nat { + return m.nat +} + +// shiftIn calculates x = x << _W + y mod m. +// +// This assumes that x is already reduced mod m. +func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { + d := NewNat().resetFor(m) + + // Eliminate bounds checks in the loop. + size := len(m.nat.limbs) + xLimbs := x.limbs[:size] + dLimbs := d.limbs[:size] + mLimbs := m.nat.limbs[:size] + + // Each iteration of this loop computes x = 2x + b mod m, where b is a bit + // from y. Effectively, it left-shifts x and adds y one bit at a time, + // reducing it every time. + // + // To do the reduction, each iteration computes both 2x + b and 2x + b - m. + // The next iteration (and finally the return line) will use either result + // based on whether 2x + b overflows m. + needSubtraction := no + for i := _W - 1; i >= 0; i-- { + carry := (y >> i) & 1 + var borrow uint + mask := ctMask(needSubtraction) + for i := 0; i < size; i++ { + l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) + xLimbs[i], carry = bits.Add(l, l, carry) + dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) + } + // Like in maybeSubtractModulus, we need the subtraction if either it + // didn't underflow (meaning 2x + b > m) or if computing 2x + b + // overflowed (meaning 2x + b > 2^_W*n > m). + needSubtraction = not(choice(borrow)) | choice(carry) + } + return x.assign(needSubtraction, d) +} + +// Mod calculates out = x mod m. +// +// This works regardless how large the value of x is. +// +// The output will be resized to the size of m and overwritten. +func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { + out.resetFor(m) + // Working our way from the most significant to the least significant limb, + // we can insert each limb at the least significant position, shifting all + // previous limbs left by _W. This way each limb will get shifted by the + // correct number of bits. We can insert at least N - 1 limbs without + // overflowing m. After that, we need to reduce every time we shift. + i := len(x.limbs) - 1 + // For the first N - 1 limbs we can skip the actual shifting and position + // them at the shifted position, which starts at min(N - 2, i). + start := len(m.nat.limbs) - 2 + if i < start { + start = i + } + for j := start; j >= 0; j-- { + out.limbs[j] = x.limbs[i] + i-- + } + // We shift in the remaining limbs, reducing modulo m each time. + for i >= 0 { + out.shiftIn(x.limbs[i], m) + i-- + } + return out +} + +// ExpandFor ensures x has the right size to work with operations modulo m. +// +// The announced size of x must be smaller than or equal to that of m. +func (x *Nat) ExpandFor(m *Modulus) *Nat { + return x.expand(len(m.nat.limbs)) +} + +// resetFor ensures out has the right size to work with operations modulo m. +// +// out is zeroed and may start at any size. +func (out *Nat) resetFor(m *Modulus) *Nat { + return out.reset(len(m.nat.limbs)) +} + +// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes. +// +// It can be used to reduce modulo m a value up to 2m - 1, which is a common +// range for results computed by higher level operations. +// +// always is usually a carry that indicates that the operation that produced x +// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. +// +// x and m operands must have the same announced length. +func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { + t := NewNat().set(x) + underflow := t.sub(m.nat) + // We keep the result if x - m didn't underflow (meaning x >= m) + // or if always was set. + keep := not(choice(underflow)) | choice(always) + x.assign(keep, t) +} + +// Sub computes x = x - y mod m. +// +// The length of both operands must be the same as the modulus. Both operands +// must already be reduced modulo m. +func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { + underflow := x.sub(y) + // If the subtraction underflowed, add m. + t := NewNat().set(x) + t.add(m.nat) + x.assign(choice(underflow), t) + return x +} + +// Add computes x = x + y mod m. +// +// The length of both operands must be the same as the modulus. Both operands +// must already be reduced modulo m. +func (x *Nat) Add(y *Nat, m *Modulus) *Nat { + overflow := x.add(y) + x.maybeSubtractModulus(choice(overflow), m) + return x +} + +// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and +// n = len(m.nat.limbs). +// +// Faster Montgomery multiplication replaces standard modular multiplication for +// numbers in this representation. +// +// This assumes that x is already reduced mod m. +func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat { + // A Montgomery multiplication (which computes a * b / R) by R * R works out + // to a multiplication by R, which takes the value out of the Montgomery domain. + return x.montgomeryMul(x, m.rr, m) +} + +// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and +// n = len(m.nat.limbs). +// +// This assumes that x is already reduced mod m. +func (x *Nat) montgomeryReduction(m *Modulus) *Nat { + // By Montgomery multiplying with 1 not in Montgomery representation, we + // convert out back from Montgomery representation, because it works out to + // dividing by R. + one := NewNat().ExpandFor(m) + one.limbs[0] = 1 + return x.montgomeryMul(x, one, m) +} + +// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and +// n = len(m.nat.limbs), also known as a Montgomery multiplication. +// +// All inputs should be the same length and already reduced modulo m. +// x will be resized to the size of m and overwritten. +func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { + n := len(m.nat.limbs) + mLimbs := m.nat.limbs[:n] + aLimbs := a.limbs[:n] + bLimbs := b.limbs[:n] + + switch n { + default: + // Attempt to use a stack-allocated backing array. + T := make([]uint, 0, preallocLimbs*2) + if cap(T) < n*2 { + T = make([]uint, 0, n*2) + } + T = T[:n*2] + + // This loop implements Word-by-Word Montgomery Multiplication, as + // described in Algorithm 4 (Fig. 3) of "Efficient Software + // Implementations of Modular Exponentiation" by Shay Gueron + // [https://eprint.iacr.org/2011/239.pdf]. + var c uint + for i := 0; i < n; i++ { + _ = T[n+i] // bounds check elimination hint + + // Step 1 (T = a × b) is computed as a large pen-and-paper column + // multiplication of two numbers with n base-2^_W digits. If we just + // wanted to produce 2n-wide T, we would do + // + // for i := 0; i < n; i++ { + // d := bLimbs[i] + // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d) + // } + // + // where d is a digit of the multiplier, T[i:n+i] is the shifted + // position of the product of that digit, and T[n+i] is the final carry. + // Note that T[i] isn't modified after processing the i-th digit. + // + // Instead of running two loops, one for Step 1 and one for Steps 2–6, + // the result of Step 1 is computed during the next loop. This is + // possible because each iteration only uses T[i] in Step 2 and then + // discards it in Step 6. + d := bLimbs[i] + c1 := addMulVVW(T[i:n+i], aLimbs, d) + + // Step 6 is replaced by shifting the virtual window we operate + // over: T of the algorithm is T[i:] for us. That means that T1 in + // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv. + Y := T[i] * m.m0inv + + // Step 4 and 5 add Y × m to T, which as mentioned above is stored + // at T[i:]. The two carries (from a × d and Y × m) are added up in + // the next word T[n+i], and the carry bit from that addition is + // brought forward to the next iteration. + c2 := addMulVVW(T[i:n+i], mLimbs, Y) + T[n+i], c = bits.Add(c1, c2, c) + } + + // Finally for Step 7 we copy the final T window into x, and subtract m + // if necessary (which as explained in maybeSubtractModulus can be the + // case both if x >= m, or if x overflowed). + // + // The paper suggests in Section 4 that we can do an "Almost Montgomery + // Multiplication" by subtracting only in the overflow case, but the + // cost is very similar since the constant time subtraction tells us if + // x >= m as a side effect, and taking care of the broken invariant is + // highly undesirable (see https://go.dev/issue/13907). + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + + // The following specialized cases follow the exact same algorithm, but + // optimized for the sizes most used in RSA. addMulVVW is implemented in + // assembly with loop unrolling depending on the architecture and bounds + // checks are removed by the compiler thanks to the constant size. + case 1024 / _W: + const n = 1024 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW1024(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW1024(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + + case 1536 / _W: + const n = 1536 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW1536(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW1536(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + + case 2048 / _W: + const n = 2048 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW2048(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW2048(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) + } + + return x +} + +// addMulVVW multiplies the multi-word value x by the single-word value y, +// adding the result to the multi-word value z and returning the final carry. +// It can be thought of as one row of a pen-and-paper column multiplication. +func addMulVVW(z, x []uint, y uint) (carry uint) { + _ = x[len(z)-1] // bounds check elimination hint + for i := range z { + hi, lo := bits.Mul(x[i], y) + lo, c := bits.Add(lo, z[i], 0) + // We use bits.Add with zero to get an add-with-carry instruction that + // absorbs the carry from the previous bits.Add. + hi, _ = bits.Add(hi, 0, c) + lo, c = bits.Add(lo, carry, 0) + hi, _ = bits.Add(hi, 0, c) + carry = hi + z[i] = lo + } + return carry +} + +// Mul calculates x = x * y mod m. +// +// The length of both operands must be the same as the modulus. Both operands +// must already be reduced modulo m. +func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { + // A Montgomery multiplication by a value out of the Montgomery domain + // takes the result out of Montgomery representation. + xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m + return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m +} + +// Exp calculates out = x^e mod m. +// +// The exponent e is represented in big-endian order. The output will be resized +// to the size of m and overwritten. x must already be reduced modulo m. +func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { + // We use a 4 bit window. For our RSA workload, 4 bit windows are faster + // than 2 bit windows, but use an extra 12 nats worth of scratch space. + // Using bit sizes that don't divide 8 are more complex to implement, but + // are likely to be more efficient if necessary. + + table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) + // newNat calls are unrolled so they are allocated on the stack. + NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), + NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), + NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), + } + table[0].set(x).montgomeryRepresentation(m) + for i := 1; i < len(table); i++ { + table[i].montgomeryMul(table[i-1], table[0], m) + } + + out.resetFor(m) + out.limbs[0] = 1 + out.montgomeryRepresentation(m) + tmp := NewNat().ExpandFor(m) + for _, b := range e { + for _, j := range []int{4, 0} { + // Square four times. Optimization note: this can be implemented + // more efficiently than with generic Montgomery multiplication. + out.montgomeryMul(out, out, m) + out.montgomeryMul(out, out, m) + out.montgomeryMul(out, out, m) + out.montgomeryMul(out, out, m) + + // Select x^k in constant time from the table. + k := uint((b >> j) & 0b1111) + for i := range table { + tmp.assign(ctEq(k, uint(i+1)), table[i]) + } + + // Multiply by x^k, discarding the result if k = 0. + tmp.montgomeryMul(out, tmp, m) + out.assign(not(ctEq(k, 0)), tmp) + } + } + + return out.montgomeryReduction(m) +} + +// ExpShort calculates out = x^e mod m. +// +// The output will be resized to the size of m and overwritten. x must already +// be reduced modulo m. This leaks the exact bit size of the exponent. +func (out *Nat) ExpShort(x *Nat, e uint, m *Modulus) *Nat { + xR := NewNat().set(x).montgomeryRepresentation(m) + + out.resetFor(m) + out.limbs[0] = 1 + out.montgomeryRepresentation(m) + + // For short exponents, precomputing a table and using a window like in Exp + // doesn't pay off. Instead, we do a simple constant-time conditional + // square-and-multiply chain, skipping the initial run of zeroes. + tmp := NewNat().ExpandFor(m) + for i := bits.UintSize - bitLen(e); i < bits.UintSize; i++ { + out.montgomeryMul(out, out, m) + k := (e >> (bits.UintSize - i - 1)) & 1 + tmp.montgomeryMul(out, xR, m) + out.assign(ctEq(k, 1), tmp) + } + return out.montgomeryReduction(m) +} diff --git a/src/crypto/internal/bigmod/nat_386.s b/src/crypto/internal/bigmod/nat_386.s new file mode 100644 index 0000000..0637d27 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_386.s @@ -0,0 +1,47 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-16 + MOVL $32, BX + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-16 + MOVL $48, BX + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-16 + MOVL $64, BX + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVL z+0(FP), DI + MOVL x+4(FP), SI + MOVL y+8(FP), BP + LEAL (DI)(BX*4), DI + LEAL (SI)(BX*4), SI + NEGL BX // i = -n + MOVL $0, CX // c = 0 + JMP E6 + +L6: MOVL (SI)(BX*4), AX + MULL BP + ADDL CX, AX + ADCL $0, DX + ADDL AX, (DI)(BX*4) + ADCL $0, DX + MOVL DX, CX + ADDL $1, BX // i++ + +E6: CMPL BX, $0 // i < 0 + JL L6 + + MOVL CX, c+12(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_amd64.s b/src/crypto/internal/bigmod/nat_amd64.s new file mode 100644 index 0000000..ab94344 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_amd64.s @@ -0,0 +1,1230 @@ +// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -pkg bigmod. DO NOT EDIT. + +//go:build !purego + +// func addMulVVW1024(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW1024(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + + // Iteration 4 + MOVQ 32(BX), AX + MULQ SI + ADDQ 32(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 32(CX) + + // Iteration 5 + MOVQ 40(BX), AX + MULQ SI + ADDQ 40(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 40(CX) + + // Iteration 6 + MOVQ 48(BX), AX + MULQ SI + ADDQ 48(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 48(CX) + + // Iteration 7 + MOVQ 56(BX), AX + MULQ SI + ADDQ 56(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 56(CX) + + // Iteration 8 + MOVQ 64(BX), AX + MULQ SI + ADDQ 64(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 64(CX) + + // Iteration 9 + MOVQ 72(BX), AX + MULQ SI + ADDQ 72(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 72(CX) + + // Iteration 10 + MOVQ 80(BX), AX + MULQ SI + ADDQ 80(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 80(CX) + + // Iteration 11 + MOVQ 88(BX), AX + MULQ SI + ADDQ 88(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 88(CX) + + // Iteration 12 + MOVQ 96(BX), AX + MULQ SI + ADDQ 96(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 96(CX) + + // Iteration 13 + MOVQ 104(BX), AX + MULQ SI + ADDQ 104(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 104(CX) + + // Iteration 14 + MOVQ 112(BX), AX + MULQ SI + ADDQ 112(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 112(CX) + + // Iteration 15 + MOVQ 120(BX), AX + MULQ SI + ADDQ 120(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 120(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + + // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Iteration 4 + MULXQ 32(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 32(AX), R8 + MOVQ R8, 32(AX) + + // Iteration 5 + MULXQ 40(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 40(AX), R8 + MOVQ R8, 40(AX) + + // Iteration 6 + MULXQ 48(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 48(AX), R8 + MOVQ R8, 48(AX) + + // Iteration 7 + MULXQ 56(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 56(AX), R8 + MOVQ R8, 56(AX) + + // Iteration 8 + MULXQ 64(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 64(AX), R8 + MOVQ R8, 64(AX) + + // Iteration 9 + MULXQ 72(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 72(AX), R8 + MOVQ R8, 72(AX) + + // Iteration 10 + MULXQ 80(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 80(AX), R8 + MOVQ R8, 80(AX) + + // Iteration 11 + MULXQ 88(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 88(AX), R8 + MOVQ R8, 88(AX) + + // Iteration 12 + MULXQ 96(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 96(AX), R8 + MOVQ R8, 96(AX) + + // Iteration 13 + MULXQ 104(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 104(AX), R8 + MOVQ R8, 104(AX) + + // Iteration 14 + MULXQ 112(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 112(AX), R8 + MOVQ R8, 112(AX) + + // Iteration 15 + MULXQ 120(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 120(AX), R8 + MOVQ R8, 120(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) + RET + +// func addMulVVW1536(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW1536(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + + // Iteration 4 + MOVQ 32(BX), AX + MULQ SI + ADDQ 32(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 32(CX) + + // Iteration 5 + MOVQ 40(BX), AX + MULQ SI + ADDQ 40(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 40(CX) + + // Iteration 6 + MOVQ 48(BX), AX + MULQ SI + ADDQ 48(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 48(CX) + + // Iteration 7 + MOVQ 56(BX), AX + MULQ SI + ADDQ 56(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 56(CX) + + // Iteration 8 + MOVQ 64(BX), AX + MULQ SI + ADDQ 64(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 64(CX) + + // Iteration 9 + MOVQ 72(BX), AX + MULQ SI + ADDQ 72(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 72(CX) + + // Iteration 10 + MOVQ 80(BX), AX + MULQ SI + ADDQ 80(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 80(CX) + + // Iteration 11 + MOVQ 88(BX), AX + MULQ SI + ADDQ 88(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 88(CX) + + // Iteration 12 + MOVQ 96(BX), AX + MULQ SI + ADDQ 96(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 96(CX) + + // Iteration 13 + MOVQ 104(BX), AX + MULQ SI + ADDQ 104(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 104(CX) + + // Iteration 14 + MOVQ 112(BX), AX + MULQ SI + ADDQ 112(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 112(CX) + + // Iteration 15 + MOVQ 120(BX), AX + MULQ SI + ADDQ 120(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 120(CX) + + // Iteration 16 + MOVQ 128(BX), AX + MULQ SI + ADDQ 128(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 128(CX) + + // Iteration 17 + MOVQ 136(BX), AX + MULQ SI + ADDQ 136(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 136(CX) + + // Iteration 18 + MOVQ 144(BX), AX + MULQ SI + ADDQ 144(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 144(CX) + + // Iteration 19 + MOVQ 152(BX), AX + MULQ SI + ADDQ 152(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 152(CX) + + // Iteration 20 + MOVQ 160(BX), AX + MULQ SI + ADDQ 160(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 160(CX) + + // Iteration 21 + MOVQ 168(BX), AX + MULQ SI + ADDQ 168(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 168(CX) + + // Iteration 22 + MOVQ 176(BX), AX + MULQ SI + ADDQ 176(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 176(CX) + + // Iteration 23 + MOVQ 184(BX), AX + MULQ SI + ADDQ 184(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 184(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + + // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Iteration 4 + MULXQ 32(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 32(AX), R8 + MOVQ R8, 32(AX) + + // Iteration 5 + MULXQ 40(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 40(AX), R8 + MOVQ R8, 40(AX) + + // Iteration 6 + MULXQ 48(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 48(AX), R8 + MOVQ R8, 48(AX) + + // Iteration 7 + MULXQ 56(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 56(AX), R8 + MOVQ R8, 56(AX) + + // Iteration 8 + MULXQ 64(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 64(AX), R8 + MOVQ R8, 64(AX) + + // Iteration 9 + MULXQ 72(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 72(AX), R8 + MOVQ R8, 72(AX) + + // Iteration 10 + MULXQ 80(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 80(AX), R8 + MOVQ R8, 80(AX) + + // Iteration 11 + MULXQ 88(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 88(AX), R8 + MOVQ R8, 88(AX) + + // Iteration 12 + MULXQ 96(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 96(AX), R8 + MOVQ R8, 96(AX) + + // Iteration 13 + MULXQ 104(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 104(AX), R8 + MOVQ R8, 104(AX) + + // Iteration 14 + MULXQ 112(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 112(AX), R8 + MOVQ R8, 112(AX) + + // Iteration 15 + MULXQ 120(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 120(AX), R8 + MOVQ R8, 120(AX) + + // Iteration 16 + MULXQ 128(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 128(AX), R8 + MOVQ R8, 128(AX) + + // Iteration 17 + MULXQ 136(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 136(AX), R8 + MOVQ R8, 136(AX) + + // Iteration 18 + MULXQ 144(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 144(AX), R8 + MOVQ R8, 144(AX) + + // Iteration 19 + MULXQ 152(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 152(AX), R8 + MOVQ R8, 152(AX) + + // Iteration 20 + MULXQ 160(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 160(AX), R8 + MOVQ R8, 160(AX) + + // Iteration 21 + MULXQ 168(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 168(AX), R8 + MOVQ R8, 168(AX) + + // Iteration 22 + MULXQ 176(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 176(AX), R8 + MOVQ R8, 176(AX) + + // Iteration 23 + MULXQ 184(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 184(AX), R8 + MOVQ R8, 184(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) + RET + +// func addMulVVW2048(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW2048(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + + // Iteration 4 + MOVQ 32(BX), AX + MULQ SI + ADDQ 32(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 32(CX) + + // Iteration 5 + MOVQ 40(BX), AX + MULQ SI + ADDQ 40(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 40(CX) + + // Iteration 6 + MOVQ 48(BX), AX + MULQ SI + ADDQ 48(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 48(CX) + + // Iteration 7 + MOVQ 56(BX), AX + MULQ SI + ADDQ 56(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 56(CX) + + // Iteration 8 + MOVQ 64(BX), AX + MULQ SI + ADDQ 64(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 64(CX) + + // Iteration 9 + MOVQ 72(BX), AX + MULQ SI + ADDQ 72(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 72(CX) + + // Iteration 10 + MOVQ 80(BX), AX + MULQ SI + ADDQ 80(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 80(CX) + + // Iteration 11 + MOVQ 88(BX), AX + MULQ SI + ADDQ 88(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 88(CX) + + // Iteration 12 + MOVQ 96(BX), AX + MULQ SI + ADDQ 96(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 96(CX) + + // Iteration 13 + MOVQ 104(BX), AX + MULQ SI + ADDQ 104(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 104(CX) + + // Iteration 14 + MOVQ 112(BX), AX + MULQ SI + ADDQ 112(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 112(CX) + + // Iteration 15 + MOVQ 120(BX), AX + MULQ SI + ADDQ 120(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 120(CX) + + // Iteration 16 + MOVQ 128(BX), AX + MULQ SI + ADDQ 128(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 128(CX) + + // Iteration 17 + MOVQ 136(BX), AX + MULQ SI + ADDQ 136(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 136(CX) + + // Iteration 18 + MOVQ 144(BX), AX + MULQ SI + ADDQ 144(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 144(CX) + + // Iteration 19 + MOVQ 152(BX), AX + MULQ SI + ADDQ 152(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 152(CX) + + // Iteration 20 + MOVQ 160(BX), AX + MULQ SI + ADDQ 160(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 160(CX) + + // Iteration 21 + MOVQ 168(BX), AX + MULQ SI + ADDQ 168(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 168(CX) + + // Iteration 22 + MOVQ 176(BX), AX + MULQ SI + ADDQ 176(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 176(CX) + + // Iteration 23 + MOVQ 184(BX), AX + MULQ SI + ADDQ 184(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 184(CX) + + // Iteration 24 + MOVQ 192(BX), AX + MULQ SI + ADDQ 192(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 192(CX) + + // Iteration 25 + MOVQ 200(BX), AX + MULQ SI + ADDQ 200(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 200(CX) + + // Iteration 26 + MOVQ 208(BX), AX + MULQ SI + ADDQ 208(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 208(CX) + + // Iteration 27 + MOVQ 216(BX), AX + MULQ SI + ADDQ 216(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 216(CX) + + // Iteration 28 + MOVQ 224(BX), AX + MULQ SI + ADDQ 224(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 224(CX) + + // Iteration 29 + MOVQ 232(BX), AX + MULQ SI + ADDQ 232(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 232(CX) + + // Iteration 30 + MOVQ 240(BX), AX + MULQ SI + ADDQ 240(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 240(CX) + + // Iteration 31 + MOVQ 248(BX), AX + MULQ SI + ADDQ 248(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 248(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + + // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Iteration 4 + MULXQ 32(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 32(AX), R8 + MOVQ R8, 32(AX) + + // Iteration 5 + MULXQ 40(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 40(AX), R8 + MOVQ R8, 40(AX) + + // Iteration 6 + MULXQ 48(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 48(AX), R8 + MOVQ R8, 48(AX) + + // Iteration 7 + MULXQ 56(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 56(AX), R8 + MOVQ R8, 56(AX) + + // Iteration 8 + MULXQ 64(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 64(AX), R8 + MOVQ R8, 64(AX) + + // Iteration 9 + MULXQ 72(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 72(AX), R8 + MOVQ R8, 72(AX) + + // Iteration 10 + MULXQ 80(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 80(AX), R8 + MOVQ R8, 80(AX) + + // Iteration 11 + MULXQ 88(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 88(AX), R8 + MOVQ R8, 88(AX) + + // Iteration 12 + MULXQ 96(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 96(AX), R8 + MOVQ R8, 96(AX) + + // Iteration 13 + MULXQ 104(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 104(AX), R8 + MOVQ R8, 104(AX) + + // Iteration 14 + MULXQ 112(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 112(AX), R8 + MOVQ R8, 112(AX) + + // Iteration 15 + MULXQ 120(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 120(AX), R8 + MOVQ R8, 120(AX) + + // Iteration 16 + MULXQ 128(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 128(AX), R8 + MOVQ R8, 128(AX) + + // Iteration 17 + MULXQ 136(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 136(AX), R8 + MOVQ R8, 136(AX) + + // Iteration 18 + MULXQ 144(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 144(AX), R8 + MOVQ R8, 144(AX) + + // Iteration 19 + MULXQ 152(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 152(AX), R8 + MOVQ R8, 152(AX) + + // Iteration 20 + MULXQ 160(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 160(AX), R8 + MOVQ R8, 160(AX) + + // Iteration 21 + MULXQ 168(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 168(AX), R8 + MOVQ R8, 168(AX) + + // Iteration 22 + MULXQ 176(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 176(AX), R8 + MOVQ R8, 176(AX) + + // Iteration 23 + MULXQ 184(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 184(AX), R8 + MOVQ R8, 184(AX) + + // Iteration 24 + MULXQ 192(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 192(AX), R8 + MOVQ R8, 192(AX) + + // Iteration 25 + MULXQ 200(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 200(AX), R8 + MOVQ R8, 200(AX) + + // Iteration 26 + MULXQ 208(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 208(AX), R8 + MOVQ R8, 208(AX) + + // Iteration 27 + MULXQ 216(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 216(AX), R8 + MOVQ R8, 216(AX) + + // Iteration 28 + MULXQ 224(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 224(AX), R8 + MOVQ R8, 224(AX) + + // Iteration 29 + MULXQ 232(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 232(AX), R8 + MOVQ R8, 232(AX) + + // Iteration 30 + MULXQ 240(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 240(AX), R8 + MOVQ R8, 240(AX) + + // Iteration 31 + MULXQ 248(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 248(AX), R8 + MOVQ R8, 248(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_arm.s b/src/crypto/internal/bigmod/nat_arm.s new file mode 100644 index 0000000..c7397b8 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_arm.s @@ -0,0 +1,47 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-16 + MOVW $32, R5 + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-16 + MOVW $48, R5 + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-16 + MOVW $64, R5 + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVW $0, R0 + MOVW z+0(FP), R1 + MOVW x+4(FP), R2 + MOVW y+8(FP), R3 + ADD R5<<2, R1, R5 + MOVW $0, R4 + B E9 + +L9: MOVW.P 4(R2), R6 + MULLU R6, R3, (R7, R6) + ADD.S R4, R6 + ADC R0, R7 + MOVW 0(R1), R4 + ADD.S R4, R6 + ADC R0, R7 + MOVW.P R6, 4(R1) + MOVW R7, R4 + +E9: TEQ R1, R5 + BNE L9 + + MOVW R4, c+12(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_arm64.s b/src/crypto/internal/bigmod/nat_arm64.s new file mode 100644 index 0000000..ba1e611 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_arm64.s @@ -0,0 +1,69 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-32 + MOVD $16, R0 + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-32 + MOVD $24, R0 + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-32 + MOVD $32, R0 + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVD z+0(FP), R1 + MOVD x+8(FP), R2 + MOVD y+16(FP), R3 + MOVD $0, R4 + +// The main loop of this code operates on a block of 4 words every iteration +// performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9] +// where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next +// 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z. +loop: + CBZ R0, done + + LDP.P 16(R2), (R5, R6) + LDP.P 16(R2), (R7, R8) + + LDP (R1), (R9, R10) + ADDS R4, R9 + MUL R6, R3, R14 + ADCS R14, R10 + MUL R7, R3, R15 + LDP 16(R1), (R11, R12) + ADCS R15, R11 + MUL R8, R3, R16 + ADCS R16, R12 + UMULH R8, R3, R20 + ADC $0, R20 + + MUL R5, R3, R13 + ADDS R13, R9 + UMULH R5, R3, R17 + ADCS R17, R10 + UMULH R6, R3, R21 + STP.P (R9, R10), 16(R1) + ADCS R21, R11 + UMULH R7, R3, R19 + ADCS R19, R12 + STP.P (R11, R12), 16(R1) + ADC $0, R20, R4 + + SUB $4, R0 + B loop + +done: + MOVD R4, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_asm.go b/src/crypto/internal/bigmod/nat_asm.go new file mode 100644 index 0000000..5eb91e1 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_asm.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego && (386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x) + +package bigmod + +import "internal/cpu" + +// amd64 assembly uses ADCX/ADOX/MULX if ADX is available to run two carry +// chains in the flags in parallel across the whole operation, and aggressively +// unrolls loops. arm64 processes four words at a time. +// +// It's unclear why the assembly for all other architectures, as well as for +// amd64 without ADX, perform better than the compiler output. +// TODO(filippo): file cmd/compile performance issue. + +var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2 + +//go:noescape +func addMulVVW1024(z, x *uint, y uint) (c uint) + +//go:noescape +func addMulVVW1536(z, x *uint, y uint) (c uint) + +//go:noescape +func addMulVVW2048(z, x *uint, y uint) (c uint) diff --git a/src/crypto/internal/bigmod/nat_noasm.go b/src/crypto/internal/bigmod/nat_noasm.go new file mode 100644 index 0000000..eff1253 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_noasm.go @@ -0,0 +1,21 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build purego || !(386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x) + +package bigmod + +import "unsafe" + +func addMulVVW1024(z, x *uint, y uint) (c uint) { + return addMulVVW(unsafe.Slice(z, 1024/_W), unsafe.Slice(x, 1024/_W), y) +} + +func addMulVVW1536(z, x *uint, y uint) (c uint) { + return addMulVVW(unsafe.Slice(z, 1536/_W), unsafe.Slice(x, 1536/_W), y) +} + +func addMulVVW2048(z, x *uint, y uint) (c uint) { + return addMulVVW(unsafe.Slice(z, 2048/_W), unsafe.Slice(x, 2048/_W), y) +} diff --git a/src/crypto/internal/bigmod/nat_ppc64x.s b/src/crypto/internal/bigmod/nat_ppc64x.s new file mode 100644 index 0000000..974f4f9 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_ppc64x.s @@ -0,0 +1,51 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego && (ppc64 || ppc64le) + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-32 + MOVD $16, R22 // R22 = z_len + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-32 + MOVD $24, R22 // R22 = z_len + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-32 + MOVD $32, R22 // R22 = z_len + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+8(FP), R8 // R8 = x[] + MOVD y+16(FP), R9 // R9 = y + + MOVD R0, R3 // R3 will be the index register + CMP R0, R22 + MOVD R0, R4 // R4 = c = 0 + MOVD R22, CTR // Initialize loop counter + BEQ done + PCALIGN $16 + +loop: + MOVD (R8)(R3), R20 // Load x[i] + MOVD (R10)(R3), R21 // Load z[i] + MULLD R9, R20, R6 // R6 = Low-order(x[i]*y) + MULHDU R9, R20, R7 // R7 = High-order(x[i]*y) + ADDC R21, R6 // R6 = z0 + ADDZE R7 // R7 = z1 + ADDC R4, R6 // R6 = z0 + c + 0 + ADDZE R7, R4 // c += z1 + MOVD R6, (R10)(R3) // Store z[i] + ADD $8, R3 + BC 16, 0, loop // bdnz + +done: + MOVD R4, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_s390x.s b/src/crypto/internal/bigmod/nat_s390x.s new file mode 100644 index 0000000..0c07a0c --- /dev/null +++ b/src/crypto/internal/bigmod/nat_s390x.s @@ -0,0 +1,85 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !purego + +#include "textflag.h" + +// func addMulVVW1024(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1024(SB), $0-32 + MOVD $16, R5 + JMP addMulVVWx(SB) + +// func addMulVVW1536(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW1536(SB), $0-32 + MOVD $24, R5 + JMP addMulVVWx(SB) + +// func addMulVVW2048(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW2048(SB), $0-32 + MOVD $32, R5 + JMP addMulVVWx(SB) + +TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0 + MOVD z+0(FP), R2 + MOVD x+8(FP), R8 + MOVD y+16(FP), R9 + + MOVD $0, R1 // i*8 = 0 + MOVD $0, R7 // i = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R4 // c = 0 + + MOVD R5, R12 + AND $-2, R12 + CMPBGE R5, $2, A6 + BR E6 + +A6: + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + MOVD (8)(R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (8)(R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (8)(R2)(R1*1) + + ADD $16, R1 // i*8 + 8 + ADD $2, R7 // i++ + + CMPBLT R7, R12, A6 + BR E6 + +L6: + // TODO: drop unused single-step loop. + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + ADD $8, R1 // i*8 + 8 + ADD $1, R7 // i++ + +E6: + CMPBLT R7, R5, L6 // i < n + + MOVD R4, c+24(FP) + RET diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go new file mode 100644 index 0000000..76e5570 --- /dev/null +++ b/src/crypto/internal/bigmod/nat_test.go @@ -0,0 +1,480 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bigmod + +import ( + "fmt" + "math/big" + "math/bits" + "math/rand" + "reflect" + "strings" + "testing" + "testing/quick" +) + +func (n *Nat) String() string { + var limbs []string + for i := range n.limbs { + limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i])) + } + return "{" + strings.Join(limbs, " ") + "}" +} + +// Generate generates an even nat. It's used by testing/quick to produce random +// *nat values for quick.Check invocations. +func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { + limbs := make([]uint, size) + for i := 0; i < size; i++ { + limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) + } + return reflect.ValueOf(&Nat{limbs}) +} + +func testModAddCommutative(a *Nat, b *Nat) bool { + m := maxModulus(uint(len(a.limbs))) + aPlusB := new(Nat).set(a) + aPlusB.Add(b, m) + bPlusA := new(Nat).set(b) + bPlusA.Add(a, m) + return aPlusB.Equal(bPlusA) == 1 +} + +func TestModAddCommutative(t *testing.T) { + err := quick.Check(testModAddCommutative, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func testModSubThenAddIdentity(a *Nat, b *Nat) bool { + m := maxModulus(uint(len(a.limbs))) + original := new(Nat).set(a) + a.Sub(b, m) + a.Add(b, m) + return a.Equal(original) == 1 +} + +func TestModSubThenAddIdentity(t *testing.T) { + err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func TestMontgomeryRoundtrip(t *testing.T) { + err := quick.Check(func(a *Nat) bool { + one := &Nat{make([]uint, len(a.limbs))} + one.limbs[0] = 1 + aPlusOne := new(big.Int).SetBytes(natBytes(a)) + aPlusOne.Add(aPlusOne, big.NewInt(1)) + m, _ := NewModulusFromBig(aPlusOne) + monty := new(Nat).set(a) + monty.montgomeryRepresentation(m) + aAgain := new(Nat).set(monty) + aAgain.montgomeryMul(monty, one, m) + if a.Equal(aAgain) != 1 { + t.Errorf("%v != %v", a, aAgain) + return false + } + return true + }, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func TestShiftIn(t *testing.T) { + if bits.UintSize != 64 { + t.Skip("examples are only valid in 64 bit") + } + examples := []struct { + m, x, expected []byte + y uint64 + }{{ + m: []byte{13}, + x: []byte{0}, + y: 0xFFFF_FFFF_FFFF_FFFF, + expected: []byte{2}, + }, { + m: []byte{13}, + x: []byte{7}, + y: 0xFFFF_FFFF_FFFF_FFFF, + expected: []byte{10}, + }, { + m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, + x: make([]byte, 9), + y: 0xFFFF_FFFF_FFFF_FFFF, + expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, { + m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, + x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + y: 0, + expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06}, + }} + + for i, tt := range examples { + m := modulusFromBytes(tt.m) + got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m) + if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 { + t.Errorf("%d: got %v, expected %v", i, got, exp) + } + } +} + +func TestModulusAndNatSizes(t *testing.T) { + // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as + // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two + // limbs, if they are not, they fit in three. This can be a problem because + // modulus strips leading zeroes and nat does not. + m := modulusFromBytes([]byte{ + 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe} + natFromBytes(xb).ExpandFor(m) // must not panic for shrinking + NewNat().SetBytes(xb, m) +} + +func TestSetBytes(t *testing.T) { + tests := []struct { + m, b []byte + fail bool + }{{ + m: []byte{0xff, 0xff}, + b: []byte{0x00, 0x01}, + }, { + m: []byte{0xff, 0xff}, + b: []byte{0xff, 0xff}, + fail: true, + }, { + m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + b: []byte{0x00, 0x01}, + }, { + m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + }, { + m: []byte{0xff, 0xff}, + b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + fail: true, + }, { + m: []byte{0xff, 0xff}, + b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + fail: true, + }, { + m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + }, { + m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + fail: true, + }, { + m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + fail: true, + }, { + m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + fail: true, + }, { + m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd}, + b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + fail: true, + }} + + for i, tt := range tests { + m := modulusFromBytes(tt.m) + got, err := NewNat().SetBytes(tt.b, m) + if err != nil { + if !tt.fail { + t.Errorf("%d: unexpected error: %v", i, err) + } + continue + } + if tt.fail { + t.Errorf("%d: unexpected success", i) + continue + } + if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes { + t.Errorf("%d: got %v, expected %v", i, got, expected) + } + } + + f := func(xBytes []byte) bool { + m := maxModulus(uint(len(xBytes)*8/_W + 1)) + got, err := NewNat().SetBytes(xBytes, m) + if err != nil { + return false + } + return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes + } + + err := quick.Check(f, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func TestExpand(t *testing.T) { + sliced := []uint{1, 2, 3, 4} + examples := []struct { + in []uint + n int + out []uint + }{{ + []uint{1, 2}, + 4, + []uint{1, 2, 0, 0}, + }, { + sliced[:2], + 4, + []uint{1, 2, 0, 0}, + }, { + []uint{1, 2}, + 2, + []uint{1, 2}, + }} + + for i, tt := range examples { + got := (&Nat{tt.in}).expand(tt.n) + if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 { + t.Errorf("%d: got %v, expected %v", i, got, tt.out) + } + } +} + +func TestMod(t *testing.T) { + m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}) + x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) + out := new(Nat) + out.Mod(x, m) + expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) + if out.Equal(expected) != 1 { + t.Errorf("%+v != %+v", out, expected) + } +} + +func TestModSub(t *testing.T) { + m := modulusFromBytes([]byte{13}) + x := &Nat{[]uint{6}} + y := &Nat{[]uint{7}} + x.Sub(y, m) + expected := &Nat{[]uint{12}} + if x.Equal(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } + x.Sub(y, m) + expected = &Nat{[]uint{5}} + if x.Equal(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } +} + +func TestModAdd(t *testing.T) { + m := modulusFromBytes([]byte{13}) + x := &Nat{[]uint{6}} + y := &Nat{[]uint{7}} + x.Add(y, m) + expected := &Nat{[]uint{0}} + if x.Equal(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } + x.Add(y, m) + expected = &Nat{[]uint{7}} + if x.Equal(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } +} + +func TestExp(t *testing.T) { + m := modulusFromBytes([]byte{13}) + x := &Nat{[]uint{3}} + out := &Nat{[]uint{0}} + out.Exp(x, []byte{12}, m) + expected := &Nat{[]uint{1}} + if out.Equal(expected) != 1 { + t.Errorf("%+v != %+v", out, expected) + } +} + +func TestExpShort(t *testing.T) { + m := modulusFromBytes([]byte{13}) + x := &Nat{[]uint{3}} + out := &Nat{[]uint{0}} + out.ExpShort(x, 12, m) + expected := &Nat{[]uint{1}} + if out.Equal(expected) != 1 { + t.Errorf("%+v != %+v", out, expected) + } +} + +// TestMulReductions tests that Mul reduces results equal or slightly greater +// than the modulus. Some Montgomery algorithms don't and need extra care to +// return correct results. See https://go.dev/issue/13907. +func TestMulReductions(t *testing.T) { + // Two short but multi-limb primes. + a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10) + b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) + n := new(big.Int).Mul(a, b) + + N, _ := NewModulusFromBig(n) + A := NewNat().setBig(a).ExpandFor(N) + B := NewNat().setBig(b).ExpandFor(N) + + if A.Mul(B, N).IsZero() != 1 { + t.Error("a * b mod (a * b) != 0") + } + + i := new(big.Int).ModInverse(a, b) + N, _ = NewModulusFromBig(b) + A = NewNat().setBig(a).ExpandFor(N) + I := NewNat().setBig(i).ExpandFor(N) + one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) + + if A.Mul(I, N).Equal(one) != 1 { + t.Error("a * inv(a) mod b != 1") + } +} + +func natBytes(n *Nat) []byte { + return n.Bytes(maxModulus(uint(len(n.limbs)))) +} + +func natFromBytes(b []byte) *Nat { + // Must not use Nat.SetBytes as it's used in TestSetBytes. + bb := new(big.Int).SetBytes(b) + return NewNat().setBig(bb) +} + +func modulusFromBytes(b []byte) *Modulus { + bb := new(big.Int).SetBytes(b) + m, _ := NewModulusFromBig(bb) + return m +} + +// maxModulus returns the biggest modulus that can fit in n limbs. +func maxModulus(n uint) *Modulus { + b := big.NewInt(1) + b.Lsh(b, n*_W) + b.Sub(b, big.NewInt(1)) + m, _ := NewModulusFromBig(b) + return m +} + +func makeBenchmarkModulus() *Modulus { + return maxModulus(32) +} + +func makeBenchmarkValue() *Nat { + x := make([]uint, 32) + for i := 0; i < 32; i++ { + x[i]-- + } + return &Nat{limbs: x} +} + +func makeBenchmarkExponent() []byte { + e := make([]byte, 256) + for i := 0; i < 32; i++ { + e[i] = 0xFF + } + return e +} + +func BenchmarkModAdd(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Add(y, m) + } +} + +func BenchmarkModSub(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Sub(y, m) + } +} + +func BenchmarkMontgomeryRepr(b *testing.B) { + x := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.montgomeryRepresentation(m) + } +} + +func BenchmarkMontgomeryMul(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + out := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out.montgomeryMul(x, y, m) + } +} + +func BenchmarkModMul(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Mul(y, m) + } +} + +func BenchmarkExpBig(b *testing.B) { + out := new(big.Int) + exponentBytes := makeBenchmarkExponent() + x := new(big.Int).SetBytes(exponentBytes) + e := new(big.Int).SetBytes(exponentBytes) + n := new(big.Int).SetBytes(exponentBytes) + one := new(big.Int).SetUint64(1) + n.Add(n, one) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out.Exp(x, e, n) + } +} + +func BenchmarkExp(b *testing.B) { + x := makeBenchmarkValue() + e := makeBenchmarkExponent() + out := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out.Exp(x, e, m) + } +} + +func TestNewModFromBigZero(t *testing.T) { + expected := "modulus must be >= 0" + _, err := NewModulusFromBig(big.NewInt(0)) + if err == nil || err.Error() != expected { + t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) + } + + expected = "modulus must be odd" + _, err = NewModulusFromBig(big.NewInt(2)) + if err == nil || err.Error() != expected { + t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) + } +} |