summaryrefslogtreecommitdiffstats
path: root/src/encoding/asn1/marshal.go
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-28 13:16:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-28 13:16:40 +0000
commit47ab3d4a42e9ab51c465c4322d2ec233f6324e6b (patch)
treea61a0ffd83f4a3def4b36e5c8e99630c559aa723 /src/encoding/asn1/marshal.go
parentInitial commit. (diff)
downloadgolang-1.18-upstream.tar.xz
golang-1.18-upstream.zip
Adding upstream version 1.18.10.upstream/1.18.10upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/encoding/asn1/marshal.go')
-rw-r--r--src/encoding/asn1/marshal.go747
1 files changed, 747 insertions, 0 deletions
diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go
new file mode 100644
index 0000000..c243349
--- /dev/null
+++ b/src/encoding/asn1/marshal.go
@@ -0,0 +1,747 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package asn1
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "math/big"
+ "reflect"
+ "sort"
+ "time"
+ "unicode/utf8"
+)
+
+var (
+ byte00Encoder encoder = byteEncoder(0x00)
+ byteFFEncoder encoder = byteEncoder(0xff)
+)
+
+// encoder represents an ASN.1 element that is waiting to be marshaled.
+type encoder interface {
+ // Len returns the number of bytes needed to marshal this element.
+ Len() int
+ // Encode encodes this element by writing Len() bytes to dst.
+ Encode(dst []byte)
+}
+
+type byteEncoder byte
+
+func (c byteEncoder) Len() int {
+ return 1
+}
+
+func (c byteEncoder) Encode(dst []byte) {
+ dst[0] = byte(c)
+}
+
+type bytesEncoder []byte
+
+func (b bytesEncoder) Len() int {
+ return len(b)
+}
+
+func (b bytesEncoder) Encode(dst []byte) {
+ if copy(dst, b) != len(b) {
+ panic("internal error")
+ }
+}
+
+type stringEncoder string
+
+func (s stringEncoder) Len() int {
+ return len(s)
+}
+
+func (s stringEncoder) Encode(dst []byte) {
+ if copy(dst, s) != len(s) {
+ panic("internal error")
+ }
+}
+
+type multiEncoder []encoder
+
+func (m multiEncoder) Len() int {
+ var size int
+ for _, e := range m {
+ size += e.Len()
+ }
+ return size
+}
+
+func (m multiEncoder) Encode(dst []byte) {
+ var off int
+ for _, e := range m {
+ e.Encode(dst[off:])
+ off += e.Len()
+ }
+}
+
+type setEncoder []encoder
+
+func (s setEncoder) Len() int {
+ var size int
+ for _, e := range s {
+ size += e.Len()
+ }
+ return size
+}
+
+func (s setEncoder) Encode(dst []byte) {
+ // Per X690 Section 11.6: The encodings of the component values of a
+ // set-of value shall appear in ascending order, the encodings being
+ // compared as octet strings with the shorter components being padded
+ // at their trailing end with 0-octets.
+ //
+ // First we encode each element to its TLV encoding and then use
+ // octetSort to get the ordering expected by X690 DER rules before
+ // writing the sorted encodings out to dst.
+ l := make([][]byte, len(s))
+ for i, e := range s {
+ l[i] = make([]byte, e.Len())
+ e.Encode(l[i])
+ }
+
+ sort.Slice(l, func(i, j int) bool {
+ // Since we are using bytes.Compare to compare TLV encodings we
+ // don't need to right pad s[i] and s[j] to the same length as
+ // suggested in X690. If len(s[i]) < len(s[j]) the length octet of
+ // s[i], which is the first determining byte, will inherently be
+ // smaller than the length octet of s[j]. This lets us skip the
+ // padding step.
+ return bytes.Compare(l[i], l[j]) < 0
+ })
+
+ var off int
+ for _, b := range l {
+ copy(dst[off:], b)
+ off += len(b)
+ }
+}
+
+type taggedEncoder struct {
+ // scratch contains temporary space for encoding the tag and length of
+ // an element in order to avoid extra allocations.
+ scratch [8]byte
+ tag encoder
+ body encoder
+}
+
+func (t *taggedEncoder) Len() int {
+ return t.tag.Len() + t.body.Len()
+}
+
+func (t *taggedEncoder) Encode(dst []byte) {
+ t.tag.Encode(dst)
+ t.body.Encode(dst[t.tag.Len():])
+}
+
+type int64Encoder int64
+
+func (i int64Encoder) Len() int {
+ n := 1
+
+ for i > 127 {
+ n++
+ i >>= 8
+ }
+
+ for i < -128 {
+ n++
+ i >>= 8
+ }
+
+ return n
+}
+
+func (i int64Encoder) Encode(dst []byte) {
+ n := i.Len()
+
+ for j := 0; j < n; j++ {
+ dst[j] = byte(i >> uint((n-1-j)*8))
+ }
+}
+
+func base128IntLength(n int64) int {
+ if n == 0 {
+ return 1
+ }
+
+ l := 0
+ for i := n; i > 0; i >>= 7 {
+ l++
+ }
+
+ return l
+}
+
+func appendBase128Int(dst []byte, n int64) []byte {
+ l := base128IntLength(n)
+
+ for i := l - 1; i >= 0; i-- {
+ o := byte(n >> uint(i*7))
+ o &= 0x7f
+ if i != 0 {
+ o |= 0x80
+ }
+
+ dst = append(dst, o)
+ }
+
+ return dst
+}
+
+func makeBigInt(n *big.Int) (encoder, error) {
+ if n == nil {
+ return nil, StructuralError{"empty integer"}
+ }
+
+ if n.Sign() < 0 {
+ // A negative number has to be converted to two's-complement
+ // form. So we'll invert and subtract 1. If the
+ // most-significant-bit isn't set then we'll need to pad the
+ // beginning with 0xff in order to keep the number negative.
+ nMinus1 := new(big.Int).Neg(n)
+ nMinus1.Sub(nMinus1, bigOne)
+ bytes := nMinus1.Bytes()
+ for i := range bytes {
+ bytes[i] ^= 0xff
+ }
+ if len(bytes) == 0 || bytes[0]&0x80 == 0 {
+ return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
+ }
+ return bytesEncoder(bytes), nil
+ } else if n.Sign() == 0 {
+ // Zero is written as a single 0 zero rather than no bytes.
+ return byte00Encoder, nil
+ } else {
+ bytes := n.Bytes()
+ if len(bytes) > 0 && bytes[0]&0x80 != 0 {
+ // We'll have to pad this with 0x00 in order to stop it
+ // looking like a negative number.
+ return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
+ }
+ return bytesEncoder(bytes), nil
+ }
+}
+
+func appendLength(dst []byte, i int) []byte {
+ n := lengthLength(i)
+
+ for ; n > 0; n-- {
+ dst = append(dst, byte(i>>uint((n-1)*8)))
+ }
+
+ return dst
+}
+
+func lengthLength(i int) (numBytes int) {
+ numBytes = 1
+ for i > 255 {
+ numBytes++
+ i >>= 8
+ }
+ return
+}
+
+func appendTagAndLength(dst []byte, t tagAndLength) []byte {
+ b := uint8(t.class) << 6
+ if t.isCompound {
+ b |= 0x20
+ }
+ if t.tag >= 31 {
+ b |= 0x1f
+ dst = append(dst, b)
+ dst = appendBase128Int(dst, int64(t.tag))
+ } else {
+ b |= uint8(t.tag)
+ dst = append(dst, b)
+ }
+
+ if t.length >= 128 {
+ l := lengthLength(t.length)
+ dst = append(dst, 0x80|byte(l))
+ dst = appendLength(dst, t.length)
+ } else {
+ dst = append(dst, byte(t.length))
+ }
+
+ return dst
+}
+
+type bitStringEncoder BitString
+
+func (b bitStringEncoder) Len() int {
+ return len(b.Bytes) + 1
+}
+
+func (b bitStringEncoder) Encode(dst []byte) {
+ dst[0] = byte((8 - b.BitLength%8) % 8)
+ if copy(dst[1:], b.Bytes) != len(b.Bytes) {
+ panic("internal error")
+ }
+}
+
+type oidEncoder []int
+
+func (oid oidEncoder) Len() int {
+ l := base128IntLength(int64(oid[0]*40 + oid[1]))
+ for i := 2; i < len(oid); i++ {
+ l += base128IntLength(int64(oid[i]))
+ }
+ return l
+}
+
+func (oid oidEncoder) Encode(dst []byte) {
+ dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
+ for i := 2; i < len(oid); i++ {
+ dst = appendBase128Int(dst, int64(oid[i]))
+ }
+}
+
+func makeObjectIdentifier(oid []int) (e encoder, err error) {
+ if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
+ return nil, StructuralError{"invalid object identifier"}
+ }
+
+ return oidEncoder(oid), nil
+}
+
+func makePrintableString(s string) (e encoder, err error) {
+ for i := 0; i < len(s); i++ {
+ // The asterisk is often used in PrintableString, even though
+ // it is invalid. If a PrintableString was specifically
+ // requested then the asterisk is permitted by this code.
+ // Ampersand is allowed in parsing due a handful of CA
+ // certificates, however when making new certificates
+ // it is rejected.
+ if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
+ return nil, StructuralError{"PrintableString contains invalid character"}
+ }
+ }
+
+ return stringEncoder(s), nil
+}
+
+func makeIA5String(s string) (e encoder, err error) {
+ for i := 0; i < len(s); i++ {
+ if s[i] > 127 {
+ return nil, StructuralError{"IA5String contains invalid character"}
+ }
+ }
+
+ return stringEncoder(s), nil
+}
+
+func makeNumericString(s string) (e encoder, err error) {
+ for i := 0; i < len(s); i++ {
+ if !isNumeric(s[i]) {
+ return nil, StructuralError{"NumericString contains invalid character"}
+ }
+ }
+
+ return stringEncoder(s), nil
+}
+
+func makeUTF8String(s string) encoder {
+ return stringEncoder(s)
+}
+
+func appendTwoDigits(dst []byte, v int) []byte {
+ return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
+}
+
+func appendFourDigits(dst []byte, v int) []byte {
+ var bytes [4]byte
+ for i := range bytes {
+ bytes[3-i] = '0' + byte(v%10)
+ v /= 10
+ }
+ return append(dst, bytes[:]...)
+}
+
+func outsideUTCRange(t time.Time) bool {
+ year := t.Year()
+ return year < 1950 || year >= 2050
+}
+
+func makeUTCTime(t time.Time) (e encoder, err error) {
+ dst := make([]byte, 0, 18)
+
+ dst, err = appendUTCTime(dst, t)
+ if err != nil {
+ return nil, err
+ }
+
+ return bytesEncoder(dst), nil
+}
+
+func makeGeneralizedTime(t time.Time) (e encoder, err error) {
+ dst := make([]byte, 0, 20)
+
+ dst, err = appendGeneralizedTime(dst, t)
+ if err != nil {
+ return nil, err
+ }
+
+ return bytesEncoder(dst), nil
+}
+
+func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
+ year := t.Year()
+
+ switch {
+ case 1950 <= year && year < 2000:
+ dst = appendTwoDigits(dst, year-1900)
+ case 2000 <= year && year < 2050:
+ dst = appendTwoDigits(dst, year-2000)
+ default:
+ return nil, StructuralError{"cannot represent time as UTCTime"}
+ }
+
+ return appendTimeCommon(dst, t), nil
+}
+
+func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
+ year := t.Year()
+ if year < 0 || year > 9999 {
+ return nil, StructuralError{"cannot represent time as GeneralizedTime"}
+ }
+
+ dst = appendFourDigits(dst, year)
+
+ return appendTimeCommon(dst, t), nil
+}
+
+func appendTimeCommon(dst []byte, t time.Time) []byte {
+ _, month, day := t.Date()
+
+ dst = appendTwoDigits(dst, int(month))
+ dst = appendTwoDigits(dst, day)
+
+ hour, min, sec := t.Clock()
+
+ dst = appendTwoDigits(dst, hour)
+ dst = appendTwoDigits(dst, min)
+ dst = appendTwoDigits(dst, sec)
+
+ _, offset := t.Zone()
+
+ switch {
+ case offset/60 == 0:
+ return append(dst, 'Z')
+ case offset > 0:
+ dst = append(dst, '+')
+ case offset < 0:
+ dst = append(dst, '-')
+ }
+
+ offsetMinutes := offset / 60
+ if offsetMinutes < 0 {
+ offsetMinutes = -offsetMinutes
+ }
+
+ dst = appendTwoDigits(dst, offsetMinutes/60)
+ dst = appendTwoDigits(dst, offsetMinutes%60)
+
+ return dst
+}
+
+func stripTagAndLength(in []byte) []byte {
+ _, offset, err := parseTagAndLength(in, 0)
+ if err != nil {
+ return in
+ }
+ return in[offset:]
+}
+
+func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
+ switch value.Type() {
+ case flagType:
+ return bytesEncoder(nil), nil
+ case timeType:
+ t := value.Interface().(time.Time)
+ if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
+ return makeGeneralizedTime(t)
+ }
+ return makeUTCTime(t)
+ case bitStringType:
+ return bitStringEncoder(value.Interface().(BitString)), nil
+ case objectIdentifierType:
+ return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
+ case bigIntType:
+ return makeBigInt(value.Interface().(*big.Int))
+ }
+
+ switch v := value; v.Kind() {
+ case reflect.Bool:
+ if v.Bool() {
+ return byteFFEncoder, nil
+ }
+ return byte00Encoder, nil
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return int64Encoder(v.Int()), nil
+ case reflect.Struct:
+ t := v.Type()
+
+ for i := 0; i < t.NumField(); i++ {
+ if !t.Field(i).IsExported() {
+ return nil, StructuralError{"struct contains unexported fields"}
+ }
+ }
+
+ startingField := 0
+
+ n := t.NumField()
+ if n == 0 {
+ return bytesEncoder(nil), nil
+ }
+
+ // If the first element of the structure is a non-empty
+ // RawContents, then we don't bother serializing the rest.
+ if t.Field(0).Type == rawContentsType {
+ s := v.Field(0)
+ if s.Len() > 0 {
+ bytes := s.Bytes()
+ /* The RawContents will contain the tag and
+ * length fields but we'll also be writing
+ * those ourselves, so we strip them out of
+ * bytes */
+ return bytesEncoder(stripTagAndLength(bytes)), nil
+ }
+
+ startingField = 1
+ }
+
+ switch n1 := n - startingField; n1 {
+ case 0:
+ return bytesEncoder(nil), nil
+ case 1:
+ return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
+ default:
+ m := make([]encoder, n1)
+ for i := 0; i < n1; i++ {
+ m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return multiEncoder(m), nil
+ }
+ case reflect.Slice:
+ sliceType := v.Type()
+ if sliceType.Elem().Kind() == reflect.Uint8 {
+ return bytesEncoder(v.Bytes()), nil
+ }
+
+ var fp fieldParameters
+
+ switch l := v.Len(); l {
+ case 0:
+ return bytesEncoder(nil), nil
+ case 1:
+ return makeField(v.Index(0), fp)
+ default:
+ m := make([]encoder, l)
+
+ for i := 0; i < l; i++ {
+ m[i], err = makeField(v.Index(i), fp)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if params.set {
+ return setEncoder(m), nil
+ }
+ return multiEncoder(m), nil
+ }
+ case reflect.String:
+ switch params.stringType {
+ case TagIA5String:
+ return makeIA5String(v.String())
+ case TagPrintableString:
+ return makePrintableString(v.String())
+ case TagNumericString:
+ return makeNumericString(v.String())
+ default:
+ return makeUTF8String(v.String()), nil
+ }
+ }
+
+ return nil, StructuralError{"unknown Go type"}
+}
+
+func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
+ if !v.IsValid() {
+ return nil, fmt.Errorf("asn1: cannot marshal nil value")
+ }
+ // If the field is an interface{} then recurse into it.
+ if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
+ return makeField(v.Elem(), params)
+ }
+
+ if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
+ return bytesEncoder(nil), nil
+ }
+
+ if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
+ defaultValue := reflect.New(v.Type()).Elem()
+ defaultValue.SetInt(*params.defaultValue)
+
+ if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
+ return bytesEncoder(nil), nil
+ }
+ }
+
+ // If no default value is given then the zero value for the type is
+ // assumed to be the default value. This isn't obviously the correct
+ // behavior, but it's what Go has traditionally done.
+ if params.optional && params.defaultValue == nil {
+ if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
+ return bytesEncoder(nil), nil
+ }
+ }
+
+ if v.Type() == rawValueType {
+ rv := v.Interface().(RawValue)
+ if len(rv.FullBytes) != 0 {
+ return bytesEncoder(rv.FullBytes), nil
+ }
+
+ t := new(taggedEncoder)
+
+ t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
+ t.body = bytesEncoder(rv.Bytes)
+
+ return t, nil
+ }
+
+ matchAny, tag, isCompound, ok := getUniversalType(v.Type())
+ if !ok || matchAny {
+ return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
+ }
+
+ if params.timeType != 0 && tag != TagUTCTime {
+ return nil, StructuralError{"explicit time type given to non-time member"}
+ }
+
+ if params.stringType != 0 && tag != TagPrintableString {
+ return nil, StructuralError{"explicit string type given to non-string member"}
+ }
+
+ switch tag {
+ case TagPrintableString:
+ if params.stringType == 0 {
+ // This is a string without an explicit string type. We'll use
+ // a PrintableString if the character set in the string is
+ // sufficiently limited, otherwise we'll use a UTF8String.
+ for _, r := range v.String() {
+ if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
+ if !utf8.ValidString(v.String()) {
+ return nil, errors.New("asn1: string not valid UTF-8")
+ }
+ tag = TagUTF8String
+ break
+ }
+ }
+ } else {
+ tag = params.stringType
+ }
+ case TagUTCTime:
+ if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
+ tag = TagGeneralizedTime
+ }
+ }
+
+ if params.set {
+ if tag != TagSequence {
+ return nil, StructuralError{"non sequence tagged as set"}
+ }
+ tag = TagSet
+ }
+
+ // makeField can be called for a slice that should be treated as a SET
+ // but doesn't have params.set set, for instance when using a slice
+ // with the SET type name suffix. In this case getUniversalType returns
+ // TagSet, but makeBody doesn't know about that so will treat the slice
+ // as a sequence. To work around this we set params.set.
+ if tag == TagSet && !params.set {
+ params.set = true
+ }
+
+ t := new(taggedEncoder)
+
+ t.body, err = makeBody(v, params)
+ if err != nil {
+ return nil, err
+ }
+
+ bodyLen := t.body.Len()
+
+ class := ClassUniversal
+ if params.tag != nil {
+ if params.application {
+ class = ClassApplication
+ } else if params.private {
+ class = ClassPrivate
+ } else {
+ class = ClassContextSpecific
+ }
+
+ if params.explicit {
+ t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))
+
+ tt := new(taggedEncoder)
+
+ tt.body = t
+
+ tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
+ class: class,
+ tag: *params.tag,
+ length: bodyLen + t.tag.Len(),
+ isCompound: true,
+ }))
+
+ return tt, nil
+ }
+
+ // implicit tag.
+ tag = *params.tag
+ }
+
+ t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
+
+ return t, nil
+}
+
+// Marshal returns the ASN.1 encoding of val.
+//
+// In addition to the struct tags recognised by Unmarshal, the following can be
+// used:
+//
+// ia5: causes strings to be marshaled as ASN.1, IA5String values
+// omitempty: causes empty slices to be skipped
+// printable: causes strings to be marshaled as ASN.1, PrintableString values
+// utf8: causes strings to be marshaled as ASN.1, UTF8String values
+// utc: causes time.Time to be marshaled as ASN.1, UTCTime values
+// generalized: causes time.Time to be marshaled as ASN.1, GeneralizedTime values
+func Marshal(val any) ([]byte, error) {
+ return MarshalWithParams(val, "")
+}
+
+// MarshalWithParams allows field parameters to be specified for the
+// top-level element. The form of the params is the same as the field tags.
+func MarshalWithParams(val any, params string) ([]byte, error) {
+ e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
+ if err != nil {
+ return nil, err
+ }
+ b := make([]byte, e.Len())
+ e.Encode(b)
+ return b, nil
+}