diff options
Diffstat (limited to 'src/image/jpeg/huffman.go')
-rw-r--r-- | src/image/jpeg/huffman.go | 247 |
1 files changed, 247 insertions, 0 deletions
diff --git a/src/image/jpeg/huffman.go b/src/image/jpeg/huffman.go new file mode 100644 index 0000000..7244436 --- /dev/null +++ b/src/image/jpeg/huffman.go @@ -0,0 +1,247 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "io" +) + +// maxCodeLength is the maximum (inclusive) number of bits in a Huffman code. +const maxCodeLength = 16 + +// maxNCodes is the maximum (inclusive) number of codes in a Huffman tree. +const maxNCodes = 256 + +// lutSize is the log-2 size of the Huffman decoder's look-up table. +const lutSize = 8 + +// huffman is a Huffman decoder, specified in section C. +type huffman struct { + // length is the number of codes in the tree. + nCodes int32 + // lut is the look-up table for the next lutSize bits in the bit-stream. + // The high 8 bits of the uint16 are the encoded value. The low 8 bits + // are 1 plus the code length, or 0 if the value is too large to fit in + // lutSize bits. + lut [1 << lutSize]uint16 + // vals are the decoded values, sorted by their encoding. + vals [maxNCodes]uint8 + // minCodes[i] is the minimum code of length i, or -1 if there are no + // codes of that length. + minCodes [maxCodeLength]int32 + // maxCodes[i] is the maximum code of length i, or -1 if there are no + // codes of that length. + maxCodes [maxCodeLength]int32 + // valsIndices[i] is the index into vals of minCodes[i]. + valsIndices [maxCodeLength]int32 +} + +// errShortHuffmanData means that an unexpected EOF occurred while decoding +// Huffman data. +var errShortHuffmanData = FormatError("short Huffman data") + +// ensureNBits reads bytes from the byte buffer to ensure that d.bits.n is at +// least n. For best performance (avoiding function calls inside hot loops), +// the caller is the one responsible for first checking that d.bits.n < n. +func (d *decoder) ensureNBits(n int32) error { + for { + c, err := d.readByteStuffedByte() + if err != nil { + if err == io.ErrUnexpectedEOF { + return errShortHuffmanData + } + return err + } + d.bits.a = d.bits.a<<8 | uint32(c) + d.bits.n += 8 + if d.bits.m == 0 { + d.bits.m = 1 << 7 + } else { + d.bits.m <<= 8 + } + if d.bits.n >= n { + break + } + } + return nil +} + +// receiveExtend is the composition of RECEIVE and EXTEND, specified in section +// F.2.2.1. +func (d *decoder) receiveExtend(t uint8) (int32, error) { + if d.bits.n < int32(t) { + if err := d.ensureNBits(int32(t)); err != nil { + return 0, err + } + } + d.bits.n -= int32(t) + d.bits.m >>= t + s := int32(1) << t + x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1) + if x < s>>1 { + x += ((-1) << t) + 1 + } + return x, nil +} + +// processDHT processes a Define Huffman Table marker, and initializes a huffman +// struct from its contents. Specified in section B.2.4.2. +func (d *decoder) processDHT(n int) error { + for n > 0 { + if n < 17 { + return FormatError("DHT has wrong length") + } + if err := d.readFull(d.tmp[:17]); err != nil { + return err + } + tc := d.tmp[0] >> 4 + if tc > maxTc { + return FormatError("bad Tc value") + } + th := d.tmp[0] & 0x0f + // The baseline th <= 1 restriction is specified in table B.5. + if th > maxTh || (d.baseline && th > 1) { + return FormatError("bad Th value") + } + h := &d.huff[tc][th] + + // Read nCodes and h.vals (and derive h.nCodes). + // nCodes[i] is the number of codes with code length i. + // h.nCodes is the total number of codes. + h.nCodes = 0 + var nCodes [maxCodeLength]int32 + for i := range nCodes { + nCodes[i] = int32(d.tmp[i+1]) + h.nCodes += nCodes[i] + } + if h.nCodes == 0 { + return FormatError("Huffman table has zero length") + } + if h.nCodes > maxNCodes { + return FormatError("Huffman table has excessive length") + } + n -= int(h.nCodes) + 17 + if n < 0 { + return FormatError("DHT has wrong length") + } + if err := d.readFull(h.vals[:h.nCodes]); err != nil { + return err + } + + // Derive the look-up table. + for i := range h.lut { + h.lut[i] = 0 + } + var x, code uint32 + for i := uint32(0); i < lutSize; i++ { + code <<= 1 + for j := int32(0); j < nCodes[i]; j++ { + // The codeLength is 1+i, so shift code by 8-(1+i) to + // calculate the high bits for every 8-bit sequence + // whose codeLength's high bits matches code. + // The high 8 bits of lutValue are the encoded value. + // The low 8 bits are 1 plus the codeLength. + base := uint8(code << (7 - i)) + lutValue := uint16(h.vals[x])<<8 | uint16(2+i) + for k := uint8(0); k < 1<<(7-i); k++ { + h.lut[base|k] = lutValue + } + code++ + x++ + } + } + + // Derive minCodes, maxCodes, and valsIndices. + var c, index int32 + for i, n := range nCodes { + if n == 0 { + h.minCodes[i] = -1 + h.maxCodes[i] = -1 + h.valsIndices[i] = -1 + } else { + h.minCodes[i] = c + h.maxCodes[i] = c + n - 1 + h.valsIndices[i] = index + c += n + index += n + } + c <<= 1 + } + } + return nil +} + +// decodeHuffman returns the next Huffman-coded value from the bit-stream, +// decoded according to h. +func (d *decoder) decodeHuffman(h *huffman) (uint8, error) { + if h.nCodes == 0 { + return 0, FormatError("uninitialized Huffman table") + } + + if d.bits.n < 8 { + if err := d.ensureNBits(8); err != nil { + if err != errMissingFF00 && err != errShortHuffmanData { + return 0, err + } + // There are no more bytes of data in this segment, but we may still + // be able to read the next symbol out of the previously read bits. + // First, undo the readByte that the ensureNBits call made. + if d.bytes.nUnreadable != 0 { + d.unreadByteStuffedByte() + } + goto slowPath + } + } + if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 { + n := (v & 0xff) - 1 + d.bits.n -= int32(n) + d.bits.m >>= n + return uint8(v >> 8), nil + } + +slowPath: + for i, code := 0, int32(0); i < maxCodeLength; i++ { + if d.bits.n == 0 { + if err := d.ensureNBits(1); err != nil { + return 0, err + } + } + if d.bits.a&d.bits.m != 0 { + code |= 1 + } + d.bits.n-- + d.bits.m >>= 1 + if code <= h.maxCodes[i] { + return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil + } + code <<= 1 + } + return 0, FormatError("bad Huffman code") +} + +func (d *decoder) decodeBit() (bool, error) { + if d.bits.n == 0 { + if err := d.ensureNBits(1); err != nil { + return false, err + } + } + ret := d.bits.a&d.bits.m != 0 + d.bits.n-- + d.bits.m >>= 1 + return ret, nil +} + +func (d *decoder) decodeBits(n int32) (uint32, error) { + if d.bits.n < n { + if err := d.ensureNBits(n); err != nil { + return 0, err + } + } + ret := d.bits.a >> uint32(d.bits.n-n) + ret &= (1 << uint32(n)) - 1 + d.bits.n -= n + d.bits.m >>= uint32(n) + return ret, nil +} |