From be956cd27353a4bb585b1a648e8469cf7adb5edf Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 14 Jan 2022 16:03:48 +0100 Subject: Adding upstream version 0.2.0. Signed-off-by: Daniel Baumann --- src/writer.c | 710 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 710 insertions(+) create mode 100644 src/writer.c (limited to 'src/writer.c') diff --git a/src/writer.c b/src/writer.c new file mode 100644 index 0000000..885ff91 --- /dev/null +++ b/src/writer.c @@ -0,0 +1,710 @@ +/* + * Author Jerry Lundström + * Copyright (c) 2019, OARC, Inc. + * All rights reserved. + * + * This file is part of the dnswire library. + * + * dnswire library is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * dnswire library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with dnswire library. If not, see . + */ + +#include "config.h" + +#include "dnswire/writer.h" +#include "dnswire/trace.h" +#include "dnswire/dnswire.h" + +#include +#include + +const char* const dnswire_writer_state_string[] = { + "encoding_ready", + "writing_ready", + "reading_accept", + "decoding_accept", + "encoding", + "writing", + "stopping", + "encoding_stop", + "writing_stop", + "reading_finish", + "decoding_finish", + "done", +}; + +#define __state(h, s) \ + __trace("state %s => %s", dnswire_writer_state_string[(h)->state], dnswire_writer_state_string[s]); \ + (h)->state = s; + +static struct dnswire_writer _defaults = { + .state = dnswire_writer_encoding, + + .encoder = DNSWIRE_ENCODER_INITIALIZER, + .buf = 0, + .size = DNSWIRE_DEFAULT_BUF_SIZE, + .inc = DNSWIRE_DEFAULT_BUF_SIZE, + .max = DNSWIRE_MAXIMUM_BUF_SIZE, + .at = 0, + .left = 0, + .popped = 0, + + .decoder = DNSWIRE_DECODER_INITIALIZER, + .read_buf = 0, + .read_size = DNSWIRE_DEFAULT_BUF_SIZE, + .read_inc = DNSWIRE_DEFAULT_BUF_SIZE, + .read_max = DNSWIRE_MAXIMUM_BUF_SIZE, + .read_at = 0, + .read_left = 0, + .read_pushed = 0, + + .bidirectional = false, +}; + +enum dnswire_result dnswire_writer_init(struct dnswire_writer* handle) +{ + assert(handle); + + *handle = _defaults; + + if (!(handle->buf = malloc(handle->size))) { + return dnswire_error; + } + + return dnswire_ok; +} + +enum dnswire_result dnswire_writer_set_bidirectional(struct dnswire_writer* handle, bool bidirectional) +{ + assert(handle); + + if (bidirectional) { + if (!handle->read_buf) { + if (!(handle->read_buf = malloc(handle->read_size))) { + return dnswire_error; + } + } + + handle->encoder.state = dnswire_encoder_control_ready; + __state(handle, dnswire_writer_encoding_ready); + } else { + handle->encoder.state = dnswire_encoder_control_start; + __state(handle, dnswire_writer_encoding); + } + + handle->bidirectional = bidirectional; + + return dnswire_ok; +} + +enum dnswire_result dnswire_writer_set_bufsize(struct dnswire_writer* handle, size_t size) +{ + assert(handle); + assert(size); + assert(handle->buf); + + if (handle->left > size) { + // we got data and it doesn't fit in the new size + return dnswire_error; + } + if (size > handle->max) { + // don't expand over max + return dnswire_error; + } + + if (handle->at + handle->left > size) { + // move what's left to the start + if (handle->left) { + memmove(handle->buf, &handle->buf[handle->at], handle->left); + } + handle->at = 0; + } + + uint8_t* buf = realloc(handle->buf, size); + if (!buf) { + return dnswire_error; + } + + handle->buf = buf; + handle->size = size; + + return dnswire_ok; +} + +enum dnswire_result dnswire_writer_set_bufinc(struct dnswire_writer* handle, size_t inc) +{ + assert(handle); + assert(inc); + + handle->inc = inc; + + return dnswire_ok; +} + +enum dnswire_result dnswire_writer_set_bufmax(struct dnswire_writer* handle, size_t max) +{ + assert(handle); + assert(max); + + if (max < handle->size) { + return dnswire_error; + } + + handle->max = max; + + return dnswire_ok; +} + +static enum dnswire_result _encoding(struct dnswire_writer* handle) +{ + enum dnswire_result res; + + while (1) { + res = dnswire_encoder_encode(&handle->encoder, &handle->buf[handle->at], handle->size - handle->at); + __trace("encode %s", dnswire_result_string[res]); + + switch (res) { + case dnswire_ok: + case dnswire_again: + case dnswire_endofdata: + handle->at += dnswire_encoder_encoded(handle->encoder); + handle->left += dnswire_encoder_encoded(handle->encoder); + break; + + case dnswire_need_more: { + if (handle->size >= handle->max) { + // already at max size and it's not enough + return dnswire_error; + } + + // no space left, expand + size_t size = handle->size + handle->inc > handle->max ? handle->max : handle->size + handle->inc; + uint8_t* buf = realloc(handle->buf, size); + if (!buf) { + return dnswire_error; + } + handle->buf = buf; + handle->size = size; + continue; + } + default: + break; + } + break; + } + return res; +} + +enum dnswire_result dnswire_writer_pop(struct dnswire_writer* handle, uint8_t* data, size_t len, uint8_t* in_data, size_t* in_len) +{ + assert(handle); + assert(data); + assert(len); + assert(handle->buf); + assert(!handle->bidirectional || in_data); + + handle->popped = 0; + size_t in_len_orig = 0; + if (in_len) { + in_len_orig = *in_len; + *in_len = 0; + } + + enum dnswire_result res = dnswire_again; + + __trace("state %s", dnswire_writer_state_string[handle->state]); + + switch (handle->state) { + case dnswire_writer_encoding_ready: + res = _encoding(handle); + __trace("left %zu", handle->left); + if (res != dnswire_error && handle->left) { + __state(handle, dnswire_writer_writing); + // fallthrough + } else { + break; + } + + case dnswire_writer_writing_ready: + handle->popped = len < handle->left ? len : handle->left; + memcpy(data, &handle->buf[handle->at - handle->left], handle->popped); + __trace("wrote %zd", handle->popped); + handle->left -= handle->popped; + __trace("left %zu", handle->left); + if (handle->left) { + break; + } + handle->at = 0; + __state(handle, dnswire_writer_reading_accept); + + case dnswire_writer_reading_accept: + if (!in_len_orig) { + return dnswire_need_more; + } + *in_len = handle->read_size - handle->read_at - handle->read_left < in_len_orig ? handle->read_size - handle->read_at - handle->read_left : in_len_orig; + if (*in_len) { + memcpy(&handle->read_buf[handle->read_at + handle->read_left], in_data, *in_len); + __trace("%s", __printable_string(&handle->read_buf[handle->read_at + handle->read_left], *in_len)); + handle->left += *in_len; + } + __state(handle, dnswire_writer_decoding_accept); + // fallthrough + + case dnswire_writer_decoding_accept: + switch (dnswire_decoder_decode(&handle->decoder, &handle->read_buf[handle->read_at], handle->read_left)) { + case dnswire_bidirectional: + handle->read_at += dnswire_decoder_decoded(handle->decoder); + handle->read_left -= dnswire_decoder_decoded(handle->decoder); + if (!handle->read_left) { + handle->read_at = 0; + } + + if (!handle->decoder.accept_support_dnstap_protobuf) { + return dnswire_error; + } + + __state(handle, dnswire_writer_encoding); + return dnswire_again; + + case dnswire_again: + handle->read_at += dnswire_decoder_decoded(handle->decoder); + handle->read_left -= dnswire_decoder_decoded(handle->decoder); + if (!handle->read_left) { + handle->read_at = 0; + __state(handle, dnswire_writer_reading_accept); + } + return dnswire_again; + + case dnswire_need_more: + if (handle->read_left < handle->read_size) { + // still space left to fill + if (handle->read_at) { + // move what's left to the start + if (handle->read_left) { + memmove(handle->read_buf, &handle->read_buf[handle->read_at], handle->read_left); + } + handle->read_at = 0; + } + } else if (handle->read_size < handle->read_max) { + // no space left, expand + size_t size = handle->read_size + handle->read_inc > handle->read_max ? handle->read_max : handle->read_size + handle->read_inc; + uint8_t* buf = realloc(handle->read_buf, size); + if (!buf) { + return dnswire_error; + } + handle->read_buf = buf; + handle->read_size = size; + } else { + // already at max size, and full + return dnswire_error; + } + __state(handle, dnswire_writer_reading_accept); + return dnswire_need_more; + + default: + break; + } + return dnswire_error; + + case dnswire_writer_encoding: + res = _encoding(handle); + __trace("left %zu", handle->left); + if (res != dnswire_error && handle->left) { + __state(handle, dnswire_writer_writing); + // fallthrough + } else { + break; + } + + case dnswire_writer_writing: + handle->popped = len < handle->left ? len : handle->left; + memcpy(data, &handle->buf[handle->at - handle->left], handle->popped); + __trace("wrote %zd", handle->popped); + handle->left -= handle->popped; + __trace("left %zu", handle->left); + if (!handle->left) { + handle->at = 0; + __state(handle, dnswire_writer_encoding); + } + break; + + case dnswire_writer_stopping: + if (handle->left) { + handle->popped = len < handle->left ? len : handle->left; + memcpy(data, &handle->buf[handle->at - handle->left], handle->popped); + __trace("wrote %zd", handle->popped); + handle->left -= handle->popped; + if (handle->left) { + __trace("left %zu", handle->left); + return dnswire_again; + } + handle->at = 0; + } + __state(handle, dnswire_writer_encoding_stop); + // fallthrough + + case dnswire_writer_encoding_stop: + res = _encoding(handle); + if (res == dnswire_endofdata) { + __state(handle, dnswire_writer_writing_stop); + // fallthrough + } else { + break; + } + + case dnswire_writer_writing_stop: + if (handle->left) { + handle->popped = len < handle->left ? len : handle->left; + memcpy(data, &handle->buf[handle->at - handle->left], handle->popped); + __trace("wrote %zd", handle->popped); + handle->left -= handle->popped; + if (handle->left) { + __trace("left %zu", handle->left); + return dnswire_again; + } + handle->at = 0; + } + if (handle->bidirectional) { + __state(handle, dnswire_writer_reading_finish); + return dnswire_again; + } + __state(handle, dnswire_writer_done); + return dnswire_endofdata; + + case dnswire_writer_reading_finish: + if (!in_len_orig) { + return dnswire_need_more; + } + *in_len = handle->read_size - handle->read_at - handle->read_left < in_len_orig ? handle->read_size - handle->read_at - handle->read_left : in_len_orig; + if (*in_len) { + memcpy(&handle->read_buf[handle->read_at + handle->read_left], in_data, *in_len); + __trace("%s", __printable_string(&handle->read_buf[handle->read_at + handle->read_left], *in_len)); + handle->left += *in_len; + } + __state(handle, dnswire_writer_decoding_finish); + // fallthrough + + case dnswire_writer_decoding_finish: + switch (dnswire_decoder_decode(&handle->decoder, &handle->read_buf[handle->read_at], handle->read_left)) { + case dnswire_endofdata: + __state(handle, dnswire_writer_done); + return dnswire_endofdata; + + case dnswire_need_more: + if (handle->read_left < handle->read_size) { + // still space left to fill + if (handle->read_at) { + // move what's left to the start + if (handle->read_left) { + memmove(handle->read_buf, &handle->read_buf[handle->read_at], handle->read_left); + } + handle->read_at = 0; + } + } else if (handle->read_size < handle->read_max) { + // no space left, expand + size_t size = handle->read_size + handle->read_inc > handle->read_max ? handle->read_max : handle->read_size + handle->read_inc; + uint8_t* buf = realloc(handle->read_buf, size); + if (!buf) { + return dnswire_error; + } + handle->read_buf = buf; + handle->read_size = size; + } else { + // already at max size, and full + return dnswire_error; + } + __state(handle, dnswire_writer_reading_accept); + return dnswire_need_more; + + default: + break; + } + return dnswire_error; + + case dnswire_writer_done: + return dnswire_error; + } + + return res; +} + +enum dnswire_result dnswire_writer_write(struct dnswire_writer* handle, int fd) +{ + assert(handle); + assert(handle->buf); + + enum dnswire_result res = dnswire_again; + + __trace("state %s", dnswire_writer_state_string[handle->state]); + + switch (handle->state) { + case dnswire_writer_encoding_ready: + res = _encoding(handle); + __trace("left %zu", handle->left); + if (res != dnswire_error && handle->left) { + __state(handle, dnswire_writer_writing); + // fallthrough + } else { + break; + } + + case dnswire_writer_writing_ready: { + ssize_t nwrote = write(fd, &handle->buf[handle->at - handle->left], handle->left); + __trace("wrote %zd", nwrote); + if (nwrote < 0) { + // TODO + return dnswire_error; + } else if (!nwrote) { + // TODO + return dnswire_error; + } + + handle->left -= nwrote; + __trace("left %zu", handle->left); + if (handle->left) { + break; + } + handle->at = 0; + __state(handle, dnswire_writer_reading_accept); + // fallthrough + } + + case dnswire_writer_reading_accept: { + ssize_t nread = read(fd, &handle->read_buf[handle->read_at + handle->read_left], handle->read_size - handle->read_at - handle->read_left); + if (nread < 0) { + // TODO + return dnswire_error; + } else if (!nread) { + // TODO + return dnswire_error; + } + __trace("%s", __printable_string(&handle->read_buf[handle->read_at + handle->read_left], nread)); + handle->read_left += nread; + __state(handle, dnswire_writer_decoding_accept); + // fallthrough + } + + case dnswire_writer_decoding_accept: + switch (dnswire_decoder_decode(&handle->decoder, &handle->read_buf[handle->read_at], handle->read_left)) { + case dnswire_bidirectional: + handle->read_at += dnswire_decoder_decoded(handle->decoder); + handle->read_left -= dnswire_decoder_decoded(handle->decoder); + if (!handle->read_left) { + handle->read_at = 0; + } + + if (!handle->decoder.accept_support_dnstap_protobuf) { + return dnswire_error; + } + + __state(handle, dnswire_writer_encoding); + return dnswire_again; + + case dnswire_again: + handle->read_at += dnswire_decoder_decoded(handle->decoder); + handle->read_left -= dnswire_decoder_decoded(handle->decoder); + if (!handle->read_left) { + handle->read_at = 0; + __state(handle, dnswire_writer_reading_accept); + } + return dnswire_again; + + case dnswire_need_more: + if (handle->read_left < handle->read_size) { + // still space left to fill + if (handle->read_at) { + // move what's left to the start + if (handle->read_left) { + memmove(handle->read_buf, &handle->read_buf[handle->read_at], handle->read_left); + } + handle->read_at = 0; + } + } else if (handle->read_size < handle->read_max) { + // no space left, expand + size_t size = handle->read_size + handle->read_inc > handle->read_max ? handle->read_max : handle->read_size + handle->read_inc; + uint8_t* buf = realloc(handle->read_buf, size); + if (!buf) { + return dnswire_error; + } + handle->read_buf = buf; + handle->read_size = size; + } else { + // already at max size, and full + return dnswire_error; + } + __state(handle, dnswire_writer_reading_accept); + return dnswire_need_more; + + default: + break; + } + return dnswire_error; + + case dnswire_writer_encoding: + res = _encoding(handle); + __trace("left %zu", handle->left); + if (res != dnswire_error && handle->left) { + __state(handle, dnswire_writer_writing); + // fallthrough + } else { + break; + } + + case dnswire_writer_writing: { + ssize_t nwrote = write(fd, &handle->buf[handle->at - handle->left], handle->left); + __trace("wrote %zd", nwrote); + if (nwrote < 0) { + // TODO + return dnswire_error; + } else if (!nwrote) { + // TODO + return dnswire_error; + } + + handle->left -= nwrote; + __trace("left %zu", handle->left); + if (!handle->left) { + handle->at = 0; + __state(handle, dnswire_writer_encoding); + } + break; + } + + case dnswire_writer_stopping: + if (handle->left) { + ssize_t nwrote = write(fd, &handle->buf[handle->at - handle->left], handle->left); + __trace("wrote %zd", nwrote); + if (nwrote < 0) { + // TODO + return dnswire_error; + } else if (!nwrote) { + // TODO + return dnswire_error; + } + + handle->left -= nwrote; + if (handle->left) { + __trace("left %zu", handle->left); + return dnswire_again; + } + handle->at = 0; + } + __state(handle, dnswire_writer_encoding_stop); + // fallthrough + + case dnswire_writer_encoding_stop: + res = _encoding(handle); + if (res == dnswire_endofdata) { + __state(handle, dnswire_writer_writing_stop); + // fallthrough + } else { + break; + } + + case dnswire_writer_writing_stop: + if (handle->left) { + ssize_t nwrote = write(fd, &handle->buf[handle->at - handle->left], handle->left); + __trace("wrote %zd", nwrote); + if (nwrote < 0) { + // TODO + return dnswire_error; + } else if (!nwrote) { + // TODO + return dnswire_error; + } + + handle->left -= nwrote; + if (handle->left) { + __trace("left %zu", handle->left); + return dnswire_again; + } + handle->at = 0; + } + if (handle->bidirectional) { + __state(handle, dnswire_writer_reading_finish); + return dnswire_again; + } + __state(handle, dnswire_writer_done); + return dnswire_endofdata; + + case dnswire_writer_reading_finish: { + ssize_t nread = read(fd, &handle->read_buf[handle->read_at + handle->read_left], handle->read_size - handle->read_at - handle->read_left); + if (nread < 0) { + // TODO + return dnswire_error; + } else if (!nread) { + // TODO + return dnswire_error; + } + __trace("%s", __printable_string(&handle->read_buf[handle->read_at + handle->read_left], nread)); + handle->read_left += nread; + __state(handle, dnswire_writer_decoding_finish); + // fallthrough + } + + case dnswire_writer_decoding_finish: + switch (dnswire_decoder_decode(&handle->decoder, &handle->read_buf[handle->read_at], handle->read_left)) { + case dnswire_endofdata: + __state(handle, dnswire_writer_done); + return dnswire_endofdata; + + case dnswire_need_more: + if (handle->read_left < handle->read_size) { + // still space left to fill + if (handle->read_at) { + // move what's left to the start + if (handle->read_left) { + memmove(handle->read_buf, &handle->read_buf[handle->read_at], handle->read_left); + } + handle->read_at = 0; + } + } else if (handle->read_size < handle->read_max) { + // no space left, expand + size_t size = handle->read_size + handle->read_inc > handle->read_max ? handle->read_max : handle->read_size + handle->read_inc; + uint8_t* buf = realloc(handle->read_buf, size); + if (!buf) { + return dnswire_error; + } + handle->read_buf = buf; + handle->read_size = size; + } else { + // already at max size, and full + return dnswire_error; + } + __state(handle, dnswire_writer_reading_accept); + return dnswire_need_more; + + default: + break; + } + return dnswire_error; + + case dnswire_writer_done: + return dnswire_error; + } + + return res; +} + +enum dnswire_result dnswire_writer_stop(struct dnswire_writer* handle) +{ + assert(handle); + + enum dnswire_result res = dnswire_encoder_stop(&handle->encoder); + + if (res == dnswire_ok) { + __state(handle, dnswire_writer_stopping); + } + + return res; +} -- cgit v1.2.3