summaryrefslogtreecommitdiffstats
path: root/src/crypto/elliptic/elliptic_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/elliptic/elliptic_test.go')
-rw-r--r--src/crypto/elliptic/elliptic_test.go331
1 files changed, 331 insertions, 0 deletions
diff --git a/src/crypto/elliptic/elliptic_test.go b/src/crypto/elliptic/elliptic_test.go
new file mode 100644
index 0000000..3fe53c5
--- /dev/null
+++ b/src/crypto/elliptic/elliptic_test.go
@@ -0,0 +1,331 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package elliptic
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/hex"
+ "math/big"
+ "testing"
+)
+
+// genericParamsForCurve returns the dereferenced CurveParams for
+// the specified curve. This is used to avoid the logic for
+// upgrading a curve to it's specific implementation, forcing
+// usage of the generic implementation. This is only relevant
+// for the P224, P256, and P521 curves.
+func genericParamsForCurve(c Curve) *CurveParams {
+ d := *(c.Params())
+ return &d
+}
+
+func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
+ tests := []struct {
+ name string
+ curve Curve
+ }{
+ {"P256", P256()},
+ {"P256/Params", genericParamsForCurve(P256())},
+ {"P224", P224()},
+ {"P224/Params", genericParamsForCurve(P224())},
+ {"P384", P384()},
+ {"P384/Params", genericParamsForCurve(P384())},
+ {"P521", P521()},
+ {"P521/Params", genericParamsForCurve(P521())},
+ }
+ if testing.Short() {
+ tests = tests[:1]
+ }
+ for _, test := range tests {
+ curve := test.curve
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+ f(t, curve)
+ })
+ }
+}
+
+func TestOnCurve(t *testing.T) {
+ testAllCurves(t, func(t *testing.T, curve Curve) {
+ if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
+ t.Error("basepoint is not on the curve")
+ }
+ })
+}
+
+func TestOffCurve(t *testing.T) {
+ testAllCurves(t, func(t *testing.T, curve Curve) {
+ x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
+ if curve.IsOnCurve(x, y) {
+ t.Errorf("point off curve is claimed to be on the curve")
+ }
+ b := Marshal(curve, x, y)
+ x1, y1 := Unmarshal(curve, b)
+ if x1 != nil || y1 != nil {
+ t.Errorf("unmarshaling a point not on the curve succeeded")
+ }
+ })
+}
+
+func TestInfinity(t *testing.T) {
+ testAllCurves(t, testInfinity)
+}
+
+func testInfinity(t *testing.T, curve Curve) {
+ _, x, y, _ := GenerateKey(curve, rand.Reader)
+ x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
+ if x.Sign() != 0 || y.Sign() != 0 {
+ t.Errorf("x^q != ∞")
+ }
+
+ x, y = curve.ScalarBaseMult([]byte{0})
+ if x.Sign() != 0 || y.Sign() != 0 {
+ t.Errorf("b^0 != ∞")
+ x.SetInt64(0)
+ y.SetInt64(0)
+ }
+
+ x2, y2 := curve.Double(x, y)
+ if x2.Sign() != 0 || y2.Sign() != 0 {
+ t.Errorf("2∞ != ∞")
+ }
+
+ baseX := curve.Params().Gx
+ baseY := curve.Params().Gy
+
+ x3, y3 := curve.Add(baseX, baseY, x, y)
+ if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
+ t.Errorf("x+∞ != x")
+ }
+
+ x4, y4 := curve.Add(x, y, baseX, baseY)
+ if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
+ t.Errorf("∞+x != x")
+ }
+
+ if curve.IsOnCurve(x, y) {
+ t.Errorf("IsOnCurve(∞) == true")
+ }
+}
+
+func TestMarshal(t *testing.T) {
+ testAllCurves(t, func(t *testing.T, curve Curve) {
+ _, x, y, err := GenerateKey(curve, rand.Reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+ serialized := Marshal(curve, x, y)
+ xx, yy := Unmarshal(curve, serialized)
+ if xx == nil {
+ t.Fatal("failed to unmarshal")
+ }
+ if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
+ t.Fatal("unmarshal returned different values")
+ }
+ })
+}
+
+func TestUnmarshalToLargeCoordinates(t *testing.T) {
+ // See https://golang.org/issues/20482.
+ testAllCurves(t, testUnmarshalToLargeCoordinates)
+}
+
+func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
+ p := curve.Params().P
+ byteLen := (p.BitLen() + 7) / 8
+
+ // Set x to be greater than curve's parameter P – specifically, to P+5.
+ // Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the
+ // curve.
+ x := new(big.Int).Add(p, big.NewInt(5))
+ y := curve.Params().polynomial(x)
+ y.ModSqrt(y, p)
+
+ invalid := make([]byte, byteLen*2+1)
+ invalid[0] = 4 // uncompressed encoding
+ x.FillBytes(invalid[1 : 1+byteLen])
+ y.FillBytes(invalid[1+byteLen:])
+
+ if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
+ t.Errorf("Unmarshal accepts invalid X coordinate")
+ }
+
+ if curve == p256 {
+ // This is a point on the curve with a small y value, small enough that
+ // we can add p and still be within 32 bytes.
+ x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
+ y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
+ y.Add(y, p)
+
+ if p.Cmp(y) > 0 || y.BitLen() != 256 {
+ t.Fatal("y not within expected range")
+ }
+
+ // marshal
+ x.FillBytes(invalid[1 : 1+byteLen])
+ y.FillBytes(invalid[1+byteLen:])
+
+ if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
+ t.Errorf("Unmarshal accepts invalid Y coordinate")
+ }
+ }
+}
+
+// TestInvalidCoordinates tests big.Int values that are not valid field elements
+// (negative or bigger than P). They are expected to return false from
+// IsOnCurve, all other behavior is undefined.
+func TestInvalidCoordinates(t *testing.T) {
+ testAllCurves(t, testInvalidCoordinates)
+}
+
+func testInvalidCoordinates(t *testing.T, curve Curve) {
+ checkIsOnCurveFalse := func(name string, x, y *big.Int) {
+ if curve.IsOnCurve(x, y) {
+ t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
+ }
+ }
+
+ p := curve.Params().P
+ _, x, y, _ := GenerateKey(curve, rand.Reader)
+ xx, yy := new(big.Int), new(big.Int)
+
+ // Check if the sign is getting dropped.
+ xx.Neg(x)
+ checkIsOnCurveFalse("-x, y", xx, y)
+ yy.Neg(y)
+ checkIsOnCurveFalse("x, -y", x, yy)
+
+ // Check if negative values are reduced modulo P.
+ xx.Sub(x, p)
+ checkIsOnCurveFalse("x-P, y", xx, y)
+ yy.Sub(y, p)
+ checkIsOnCurveFalse("x, y-P", x, yy)
+
+ // Check if positive values are reduced modulo P.
+ xx.Add(x, p)
+ checkIsOnCurveFalse("x+P, y", xx, y)
+ yy.Add(y, p)
+ checkIsOnCurveFalse("x, y+P", x, yy)
+
+ // Check if the overflow is dropped.
+ xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
+ checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
+ yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
+ checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
+
+ // Check if P is treated like zero (if possible).
+ // y^2 = x^3 - 3x + B
+ // y = mod_sqrt(x^3 - 3x + B)
+ // y = mod_sqrt(B) if x = 0
+ // If there is no modsqrt, there is no point with x = 0, can't test x = P.
+ if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
+ if !curve.IsOnCurve(big.NewInt(0), yy) {
+ t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
+ }
+ checkIsOnCurveFalse("P, y", p, yy)
+ }
+}
+
+func TestMarshalCompressed(t *testing.T) {
+ t.Run("P-256/03", func(t *testing.T) {
+ data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
+ x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
+ y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
+ testMarshalCompressed(t, P256(), x, y, data)
+ })
+ t.Run("P-256/02", func(t *testing.T) {
+ data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
+ x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
+ y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
+ testMarshalCompressed(t, P256(), x, y, data)
+ })
+
+ t.Run("Invalid", func(t *testing.T) {
+ data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
+ X, Y := UnmarshalCompressed(P256(), data)
+ if X != nil || Y != nil {
+ t.Error("expected an error for invalid encoding")
+ }
+ })
+
+ if testing.Short() {
+ t.Skip("skipping other curves on short test")
+ }
+
+ testAllCurves(t, func(t *testing.T, curve Curve) {
+ _, x, y, err := GenerateKey(curve, rand.Reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testMarshalCompressed(t, curve, x, y, nil)
+ })
+
+}
+
+func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
+ if !curve.IsOnCurve(x, y) {
+ t.Fatal("invalid test point")
+ }
+ got := MarshalCompressed(curve, x, y)
+ if want != nil && !bytes.Equal(got, want) {
+ t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
+ }
+
+ X, Y := UnmarshalCompressed(curve, got)
+ if X == nil || Y == nil {
+ t.Fatalf("UnmarshalCompressed failed unexpectedly")
+ }
+
+ if !curve.IsOnCurve(X, Y) {
+ t.Error("UnmarshalCompressed returned a point not on the curve")
+ }
+ if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
+ t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
+ }
+}
+
+func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) {
+ tests := []struct {
+ name string
+ curve Curve
+ }{
+ {"P256", P256()},
+ {"P224", P224()},
+ {"P384", P384()},
+ {"P521", P521()},
+ }
+ for _, test := range tests {
+ curve := test.curve
+ t.Run(test.name, func(t *testing.B) {
+ f(t, curve)
+ })
+ }
+}
+
+func BenchmarkScalarBaseMult(b *testing.B) {
+ benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
+ priv, _, _, _ := GenerateKey(curve, rand.Reader)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x, _ := curve.ScalarBaseMult(priv)
+ // Prevent the compiler from optimizing out the operation.
+ priv[0] ^= byte(x.Bits()[0])
+ }
+ })
+}
+
+func BenchmarkScalarMult(b *testing.B) {
+ benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
+ _, x, y, _ := GenerateKey(curve, rand.Reader)
+ priv, _, _, _ := GenerateKey(curve, rand.Reader)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x, y = curve.ScalarMult(x, y, priv)
+ }
+ })
+}