summaryrefslogtreecommitdiffstats
path: root/src/crypto/internal/bigmod/nat_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/internal/bigmod/nat_test.go')
-rw-r--r--src/crypto/internal/bigmod/nat_test.go412
1 files changed, 412 insertions, 0 deletions
diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go
new file mode 100644
index 0000000..6431d25
--- /dev/null
+++ b/src/crypto/internal/bigmod/nat_test.go
@@ -0,0 +1,412 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bigmod
+
+import (
+ "math/big"
+ "math/bits"
+ "math/rand"
+ "reflect"
+ "testing"
+ "testing/quick"
+)
+
+// Generate generates an even nat. It's used by testing/quick to produce random
+// *nat values for quick.Check invocations.
+func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
+ limbs := make([]uint, size)
+ for i := 0; i < size; i++ {
+ limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
+ }
+ return reflect.ValueOf(&Nat{limbs})
+}
+
+func testModAddCommutative(a *Nat, b *Nat) bool {
+ m := maxModulus(uint(len(a.limbs)))
+ aPlusB := new(Nat).set(a)
+ aPlusB.Add(b, m)
+ bPlusA := new(Nat).set(b)
+ bPlusA.Add(a, m)
+ return aPlusB.Equal(bPlusA) == 1
+}
+
+func TestModAddCommutative(t *testing.T) {
+ err := quick.Check(testModAddCommutative, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
+ m := maxModulus(uint(len(a.limbs)))
+ original := new(Nat).set(a)
+ a.Sub(b, m)
+ a.Add(b, m)
+ return a.Equal(original) == 1
+}
+
+func TestModSubThenAddIdentity(t *testing.T) {
+ err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testMontgomeryRoundtrip(a *Nat) bool {
+ one := &Nat{make([]uint, len(a.limbs))}
+ one.limbs[0] = 1
+ aPlusOne := new(big.Int).SetBytes(natBytes(a))
+ aPlusOne.Add(aPlusOne, big.NewInt(1))
+ m := NewModulusFromBig(aPlusOne)
+ monty := new(Nat).set(a)
+ monty.montgomeryRepresentation(m)
+ aAgain := new(Nat).set(monty)
+ aAgain.montgomeryMul(monty, one, m)
+ return a.Equal(aAgain) == 1
+}
+
+func TestMontgomeryRoundtrip(t *testing.T) {
+ err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestShiftIn(t *testing.T) {
+ if bits.UintSize != 64 {
+ t.Skip("examples are only valid in 64 bit")
+ }
+ examples := []struct {
+ m, x, expected []byte
+ y uint64
+ }{{
+ m: []byte{13},
+ x: []byte{0},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{7},
+ }, {
+ m: []byte{13},
+ x: []byte{7},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{11},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: make([]byte, 9),
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ y: 0,
+ expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+ }}
+
+ for i, tt := range examples {
+ m := modulusFromBytes(tt.m)
+ got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
+ if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+ }
+ }
+}
+
+func TestModulusAndNatSizes(t *testing.T) {
+ // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
+ // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
+ // limbs, if they are not, they fit in three. This can be a problem because
+ // modulus strips leading zeroes and nat does not.
+ m := modulusFromBytes([]byte{
+ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+ xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
+ natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
+ NewNat().SetBytes(xb, m)
+}
+
+func TestSetBytes(t *testing.T) {
+ tests := []struct {
+ m, b []byte
+ fail bool
+ }{{
+ m: []byte{0xff, 0xff},
+ b: []byte{0x00, 0x01},
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0xff},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x00, 0x01},
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ fail: true,
+ }}
+
+ for i, tt := range tests {
+ m := modulusFromBytes(tt.m)
+ got, err := NewNat().SetBytes(tt.b, m)
+ if err != nil {
+ if !tt.fail {
+ t.Errorf("%d: unexpected error: %v", i, err)
+ }
+ continue
+ }
+ if err == nil && tt.fail {
+ t.Errorf("%d: unexpected success", i)
+ continue
+ }
+ if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
+ t.Errorf("%d: got %x, expected %x", i, got, expected)
+ }
+ }
+
+ f := func(xBytes []byte) bool {
+ m := maxModulus(uint(len(xBytes)*8/_W + 1))
+ got, err := NewNat().SetBytes(xBytes, m)
+ if err != nil {
+ return false
+ }
+ return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
+ }
+
+ err := quick.Check(f, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestExpand(t *testing.T) {
+ sliced := []uint{1, 2, 3, 4}
+ examples := []struct {
+ in []uint
+ n int
+ out []uint
+ }{{
+ []uint{1, 2},
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ sliced[:2],
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ []uint{1, 2},
+ 2,
+ []uint{1, 2},
+ }}
+
+ for i, tt := range examples {
+ got := (&Nat{tt.in}).expand(tt.n)
+ if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+ }
+ }
+}
+
+func TestMod(t *testing.T) {
+ m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
+ x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+ out := new(Nat)
+ out.Mod(x, m)
+ expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+ if out.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func TestModSub(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{6}}
+ y := &Nat{[]uint{7}}
+ x.Sub(y, m)
+ expected := &Nat{[]uint{12}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.Sub(y, m)
+ expected = &Nat{[]uint{5}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestModAdd(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{6}}
+ y := &Nat{[]uint{7}}
+ x.Add(y, m)
+ expected := &Nat{[]uint{0}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.Add(y, m)
+ expected = &Nat{[]uint{7}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestExp(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{3}}
+ out := &Nat{[]uint{0}}
+ out.Exp(x, []byte{12}, m)
+ expected := &Nat{[]uint{1}}
+ if out.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func natBytes(n *Nat) []byte {
+ return n.Bytes(maxModulus(uint(len(n.limbs))))
+}
+
+func natFromBytes(b []byte) *Nat {
+ bb := new(big.Int).SetBytes(b)
+ return NewNat().setBig(bb)
+}
+
+func modulusFromBytes(b []byte) *Modulus {
+ bb := new(big.Int).SetBytes(b)
+ return NewModulusFromBig(bb)
+}
+
+// maxModulus returns the biggest modulus that can fit in n limbs.
+func maxModulus(n uint) *Modulus {
+ m := big.NewInt(1)
+ m.Lsh(m, n*_W)
+ m.Sub(m, big.NewInt(1))
+ return NewModulusFromBig(m)
+}
+
+func makeBenchmarkModulus() *Modulus {
+ return maxModulus(32)
+}
+
+func makeBenchmarkValue() *Nat {
+ x := make([]uint, 32)
+ for i := 0; i < 32; i++ {
+ x[i] = _MASK - 1
+ }
+ return &Nat{limbs: x}
+}
+
+func makeBenchmarkExponent() []byte {
+ e := make([]byte, 256)
+ for i := 0; i < 32; i++ {
+ e[i] = 0xFF
+ }
+ return e
+}
+
+func BenchmarkModAdd(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Add(y, m)
+ }
+}
+
+func BenchmarkModSub(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Sub(y, m)
+ }
+}
+
+func BenchmarkMontgomeryRepr(b *testing.B) {
+ x := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.montgomeryRepresentation(m)
+ }
+}
+
+func BenchmarkMontgomeryMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.montgomeryMul(x, y, m)
+ }
+}
+
+func BenchmarkModMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Mul(y, m)
+ }
+}
+
+func BenchmarkExpBig(b *testing.B) {
+ out := new(big.Int)
+ exponentBytes := makeBenchmarkExponent()
+ x := new(big.Int).SetBytes(exponentBytes)
+ e := new(big.Int).SetBytes(exponentBytes)
+ n := new(big.Int).SetBytes(exponentBytes)
+ one := new(big.Int).SetUint64(1)
+ n.Add(n, one)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, n)
+ }
+}
+
+func BenchmarkExp(b *testing.B) {
+ x := makeBenchmarkValue()
+ e := makeBenchmarkExponent()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, m)
+ }
+}