summaryrefslogtreecommitdiffstats
path: root/test/unit/nts_ke_session.c
diff options
context:
space:
mode:
Diffstat (limited to 'test/unit/nts_ke_session.c')
-rw-r--r--test/unit/nts_ke_session.c218
1 files changed, 218 insertions, 0 deletions
diff --git a/test/unit/nts_ke_session.c b/test/unit/nts_ke_session.c
new file mode 100644
index 0000000..adcade6
--- /dev/null
+++ b/test/unit/nts_ke_session.c
@@ -0,0 +1,218 @@
+/*
+ **********************************************************************
+ * Copyright (C) Miroslav Lichvar 2020
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+ *
+ **********************************************************************
+ */
+
+#include <config.h>
+#include "test.h"
+
+#ifdef FEAT_NTS
+
+#include <nts_ke_session.c>
+
+#include <local.h>
+#include <socket.h>
+#include <sched.h>
+
+static NKSN_Instance client, server;
+static unsigned char record[NKE_MAX_MESSAGE_LENGTH];
+static int record_length, critical, type_start, records;
+static int request_received;
+static int response_received;
+
+static void
+send_message(NKSN_Instance inst)
+{
+ int i;
+
+ record_length = random() % (NKE_MAX_MESSAGE_LENGTH - 4 + 1);
+ for (i = 0; i < record_length; i++)
+ record[i] = random() % 256;
+ critical = random() % 2;
+ type_start = random() % 30000 + 1;
+ assert(sizeof (struct RecordHeader) == 4);
+ records = random() % ((NKE_MAX_MESSAGE_LENGTH - 4) / (4 + record_length) + 1);
+
+ DEBUG_LOG("critical=%d type_start=%d records=%d*%d",
+ critical, type_start, records, record_length);
+
+ NKSN_BeginMessage(inst);
+
+ TEST_CHECK(check_message_format(&inst->message, 0));
+ TEST_CHECK(!check_message_format(&inst->message, 1));
+
+ TEST_CHECK(!NKSN_AddRecord(inst, 0, 1, record, NKE_MAX_MESSAGE_LENGTH - 4 + 1));
+
+ TEST_CHECK(check_message_format(&inst->message, 0));
+ TEST_CHECK(!check_message_format(&inst->message, 1));
+
+ for (i = 0; i < records; i++) {
+ TEST_CHECK(NKSN_AddRecord(inst, critical, type_start + i, record, record_length));
+ TEST_CHECK(!NKSN_AddRecord(inst, 0, 1, &record,
+ NKE_MAX_MESSAGE_LENGTH - inst->message.length - 4 + 1));
+
+ TEST_CHECK(check_message_format(&inst->message, 0));
+ TEST_CHECK(!check_message_format(&inst->message, 1));
+ }
+
+ TEST_CHECK(NKSN_EndMessage(inst));
+
+ TEST_CHECK(check_message_format(&inst->message, 0));
+ TEST_CHECK(check_message_format(&inst->message, 1));
+}
+
+static void
+verify_message(NKSN_Instance inst)
+{
+ unsigned char buffer[NKE_MAX_MESSAGE_LENGTH];
+ int i, c, t, length, buffer_length, msg_length, prev_parsed;
+ NKE_Key c2s, s2c;
+
+ for (i = 0; i < records; i++) {
+ memset(buffer, 0, sizeof (buffer));
+ buffer_length = random() % (record_length + 1);
+ assert(buffer_length <= sizeof (buffer));
+
+ prev_parsed = inst->message.parsed;
+ msg_length = inst->message.length;
+
+ TEST_CHECK(NKSN_GetRecord(inst, &c, &t, &length, buffer, buffer_length));
+ TEST_CHECK(c == critical);
+ TEST_CHECK(t == type_start + i);
+ TEST_CHECK(length == record_length);
+ TEST_CHECK(memcmp(record, buffer, buffer_length) == 0);
+ if (buffer_length < record_length)
+ TEST_CHECK(buffer[buffer_length] == 0);
+
+ inst->message.length = inst->message.parsed - 1;
+ inst->message.parsed = prev_parsed;
+ TEST_CHECK(!get_record(&inst->message, NULL, NULL, NULL, buffer, buffer_length));
+ TEST_CHECK(inst->message.parsed == prev_parsed);
+ inst->message.length = msg_length;
+ if (msg_length < 0x8000) {
+ inst->message.data[prev_parsed + 2] ^= 0x80;
+ TEST_CHECK(!get_record(&inst->message, NULL, NULL, NULL, buffer, buffer_length));
+ TEST_CHECK(inst->message.parsed == prev_parsed);
+ inst->message.data[prev_parsed + 2] ^= 0x80;
+ }
+ TEST_CHECK(get_record(&inst->message, NULL, NULL, NULL, buffer, buffer_length));
+ TEST_CHECK(inst->message.parsed > prev_parsed);
+ }
+
+ TEST_CHECK(!NKSN_GetRecord(inst, &critical, &t, &length, buffer, sizeof (buffer)));
+
+ TEST_CHECK(NKSN_GetKeys(inst, AEAD_AES_SIV_CMAC_256, &c2s, &s2c));
+ TEST_CHECK(c2s.length == SIV_GetKeyLength(AEAD_AES_SIV_CMAC_256));
+ TEST_CHECK(s2c.length == SIV_GetKeyLength(AEAD_AES_SIV_CMAC_256));
+}
+
+static int
+handle_request(void *arg)
+{
+ NKSN_Instance server = arg;
+
+ verify_message(server);
+
+ request_received = 1;
+
+ send_message(server);
+
+ return 1;
+}
+
+static int
+handle_response(void *arg)
+{
+ NKSN_Instance client = arg;
+
+ response_received = 1;
+
+ verify_message(client);
+
+ return 1;
+}
+
+static void
+check_finished(void *arg)
+{
+ DEBUG_LOG("checking for stopped sessions");
+ if (!NKSN_IsStopped(server) || !NKSN_IsStopped(client)) {
+ SCH_AddTimeoutByDelay(0.001, check_finished, NULL);
+ return;
+ }
+
+ SCH_QuitProgram();
+}
+
+void
+test_unit(void)
+{
+ void *client_cred, *server_cred;
+ int sock_fds[2], i;
+
+ LCL_Initialise();
+ TST_RegisterDummyDrivers();
+
+ for (i = 0; i < 50; i++) {
+ SCH_Initialise();
+
+ server = NKSN_CreateInstance(1, NULL, handle_request, NULL);
+ client = NKSN_CreateInstance(0, "test", handle_response, NULL);
+
+ server_cred = NKSN_CreateCertCredentials("nts_ke.crt", "nts_ke.key", NULL);
+ client_cred = NKSN_CreateCertCredentials(NULL, NULL, "nts_ke.crt");
+
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds) == 0);
+ TEST_CHECK(fcntl(sock_fds[0], F_SETFL, O_NONBLOCK) == 0);
+ TEST_CHECK(fcntl(sock_fds[1], F_SETFL, O_NONBLOCK) == 0);
+
+ TEST_CHECK(NKSN_StartSession(server, sock_fds[0], "client", server_cred, 4.0));
+ TEST_CHECK(NKSN_StartSession(client, sock_fds[1], "server", client_cred, 4.0));
+
+ send_message(client);
+
+ request_received = response_received = 0;
+
+ check_finished(NULL);
+
+ SCH_MainLoop();
+
+ TEST_CHECK(NKSN_IsStopped(server));
+ TEST_CHECK(NKSN_IsStopped(client));
+
+ TEST_CHECK(request_received);
+ TEST_CHECK(response_received);
+
+ NKSN_DestroyInstance(server);
+ NKSN_DestroyInstance(client);
+
+ NKSN_DestroyCertCredentials(server_cred);
+ NKSN_DestroyCertCredentials(client_cred);
+
+ SCH_Finalise();
+ }
+
+ LCL_Finalise();
+}
+#else
+void
+test_unit(void)
+{
+ TEST_REQUIRE(0);
+}
+#endif