summaryrefslogtreecommitdiffstats
path: root/src/seastar/apps/memcached/tests
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 18:24:20 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 18:24:20 +0000
commit483eb2f56657e8e7f419ab1a4fab8dce9ade8609 (patch)
treee5d88d25d870d5dedacb6bbdbe2a966086a0a5cf /src/seastar/apps/memcached/tests
parentInitial commit. (diff)
downloadceph-upstream.tar.xz
ceph-upstream.zip
Adding upstream version 14.2.21.upstream/14.2.21upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--src/seastar/apps/memcached/tests/CMakeLists.txt75
-rwxr-xr-xsrc/seastar/apps/memcached/tests/test.py49
-rw-r--r--src/seastar/apps/memcached/tests/test_ascii_parser.cc335
-rwxr-xr-xsrc/seastar/apps/memcached/tests/test_memcached.py600
4 files changed, 1059 insertions, 0 deletions
diff --git a/src/seastar/apps/memcached/tests/CMakeLists.txt b/src/seastar/apps/memcached/tests/CMakeLists.txt
new file mode 100644
index 00000000..9301cea7
--- /dev/null
+++ b/src/seastar/apps/memcached/tests/CMakeLists.txt
@@ -0,0 +1,75 @@
+#
+# This file is open source software, licensed to you under the terms
+# of the Apache License, Version 2.0 (the "License"). See the NOTICE file
+# distributed with this work for additional information regarding copyright
+# ownership. You may not use this file except in compliance with the License.
+#
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+#
+# Copyright (C) 2018 Scylladb, Ltd.
+#
+
+if (Seastar_EXECUTE_ONLY_FAST_TESTS)
+ set (memcached_test_args --fast)
+else ()
+ set (memcached_test_args "")
+endif ()
+
+add_custom_target (app_memcached_test_memcached_run
+ DEPENDS
+ ${memcached_app}
+ ${CMAKE_CURRENT_SOURCE_DIR}/test.py
+ ${CMAKE_CURRENT_SOURCE_DIR}/test_memcached.py
+ COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test.py --memcached $<TARGET_FILE:app_memcached> ${memcached_test_args}
+ USES_TERMINAL)
+
+add_test (
+ NAME Seastar.app.memcached.memcached
+ COMMAND ${CMAKE_COMMAND} --build ${Seastar_BINARY_DIR} --target app_memcached_test_memcached_run)
+
+set_tests_properties (Seastar.app.memcached.memcached
+ PROPERTIES
+ TIMEOUT ${Seastar_TEST_TIMEOUT})
+
+add_executable (app_memcached_test_ascii
+ test_ascii_parser.cc)
+
+add_dependencies (app_memcached_test_ascii app_memcached)
+
+target_include_directories (app_memcached_test_ascii
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${Seastar_APP_MEMCACHED_BINARY_DIR}
+ ${Seastar_APP_MEMCACHED_SOURCE_DIR})
+
+target_compile_definitions (app_memcached_test_ascii
+ PRIVATE SEASTAR_TESTING_MAIN)
+
+target_link_libraries (app_memcached_test_ascii
+ PRIVATE
+ seastar_with_flags
+ seastar_testing)
+
+add_custom_target (app_memcached_test_ascii_run
+ DEPENDS app_memcached_test_ascii
+ COMMAND app_memcached_test_ascii -- -c 2
+ USES_TERMINAL)
+
+add_test (
+ NAME Seastar.app.memcached.ascii
+ COMMAND ${CMAKE_COMMAND} --build ${Seastar_BINARY_DIR} --target app_memcached_test_ascii_run)
+
+set_tests_properties (Seastar.app.memcached.ascii
+ PROPERTIES
+ TIMEOUT ${Seastar_TEST_TIMEOUT})
diff --git a/src/seastar/apps/memcached/tests/test.py b/src/seastar/apps/memcached/tests/test.py
new file mode 100755
index 00000000..c2f2b80c
--- /dev/null
+++ b/src/seastar/apps/memcached/tests/test.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+#
+# This file is open source software, licensed to you under the terms
+# of the Apache License, Version 2.0 (the "License"). See the NOTICE file
+# distributed with this work for additional information regarding copyright
+# ownership. You may not use this file except in compliance with the License.
+#
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import time
+import sys
+import os
+import argparse
+import subprocess
+
+DIR_PATH = os.path.dirname(os.path.realpath(__file__))
+
+def run(args, cmd):
+ mc = subprocess.Popen([args.memcached, '--smp=2'])
+ print('Memcached started.')
+ try:
+ cmdline = [DIR_PATH + '/test_memcached.py'] + cmd
+ if args.fast:
+ cmdline.append('--fast')
+ print('Running: ' + ' '.join(cmdline))
+ subprocess.check_call(cmdline)
+ finally:
+ print('Killing memcached...')
+ mc.terminate();
+ mc.wait()
+ print('Memcached killed.')
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Seastar test runner")
+ parser.add_argument('--fast', action="store_true", help="Run only fast tests")
+ parser.add_argument('--memcached', required=True, help='Path of the memcached executable')
+ args = parser.parse_args()
+
+ run(args, [])
+ run(args, ['-U'])
diff --git a/src/seastar/apps/memcached/tests/test_ascii_parser.cc b/src/seastar/apps/memcached/tests/test_ascii_parser.cc
new file mode 100644
index 00000000..596d193e
--- /dev/null
+++ b/src/seastar/apps/memcached/tests/test_ascii_parser.cc
@@ -0,0 +1,335 @@
+/*
+ * This file is open source software, licensed to you under the terms
+ * of the Apache License, Version 2.0 (the "License"). See the NOTICE file
+ * distributed with this work for additional information regarding copyright
+ * ownership. You may not use this file except in compliance with the License.
+ *
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * Copyright (C) 2014 Cloudius Systems, Ltd.
+ */
+
+#include <iostream>
+#include <limits>
+#include <seastar/testing/test_case.hh>
+#include <seastar/core/shared_ptr.hh>
+#include <seastar/net/packet-data-source.hh>
+#include "ascii.hh"
+#include <seastar/core/future-util.hh>
+
+using namespace seastar;
+using namespace net;
+using namespace memcache;
+
+using parser_type = memcache_ascii_parser;
+
+static packet make_packet(std::vector<std::string> chunks, size_t buffer_size) {
+ packet p;
+ for (auto&& chunk : chunks) {
+ size_t size = chunk.size();
+ for (size_t pos = 0; pos < size; pos += buffer_size) {
+ auto now = std::min(pos + buffer_size, chunk.size()) - pos;
+ p.append(packet(chunk.data() + pos, now));
+ }
+ }
+ return p;
+}
+
+static auto make_input_stream(packet&& p) {
+ return input_stream<char>(data_source(
+ std::make_unique<packet_data_source>(std::move(p))));
+}
+
+static auto parse(packet&& p) {
+ auto is = make_lw_shared<input_stream<char>>(make_input_stream(std::move(p)));
+ auto parser = make_lw_shared<parser_type>();
+ parser->init();
+ return is->consume(*parser).then([is, parser] {
+ return make_ready_future<lw_shared_ptr<parser_type>>(parser);
+ });
+}
+
+auto for_each_fragment_size = [] (auto&& func) {
+ auto buffer_sizes = { 100000, 1000, 100, 10, 5, 2, 1 };
+ return do_for_each(buffer_sizes.begin(), buffer_sizes.end(), [func] (size_t buffer_size) {
+ return func([buffer_size] (std::vector<std::string> chunks) {
+ return make_packet(chunks, buffer_size);
+ });
+ });
+};
+
+SEASTAR_TEST_CASE(test_set_command_is_parsed) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return parse(make_packet({"set key 1 2 3\r\nabc\r\n"})).then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_flags_str == "1");
+ BOOST_REQUIRE(p->_expiration == 2);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_blob == "abc");
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_empty_data_is_parsed) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return parse(make_packet({"set key 1 2 0\r\n\r\n"})).then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_flags_str == "1");
+ BOOST_REQUIRE(p->_expiration == 2);
+ BOOST_REQUIRE(p->_size == 0);
+ BOOST_REQUIRE(p->_size_str == "0");
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_blob == "");
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_superflous_data_is_an_error) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return parse(make_packet({"set key 0 0 0\r\nasd\r\n"})).then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::error);
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_not_enough_data_is_an_error) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return parse(make_packet({"set key 0 0 3\r\n"})).then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::error);
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_u32_parsing) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return make_ready_future<>().then([make_packet] {
+ return parse(make_packet({"set key 0 0 0\r\n\r\n"})).then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_flags_str == "0");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 12345 0 0\r\n\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_flags_str == "12345");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key -1 0 0\r\n\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::error);
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 1-1 0 0\r\n\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::error);
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key " + std::to_string(std::numeric_limits<uint32_t>::max()) + " 0 0\r\n\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_flags_str == to_sstring(std::numeric_limits<uint32_t>::max()));
+ });
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_parsing_of_split_data) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return make_ready_future<>()
+ .then([make_packet] {
+ return parse(make_packet({"set key 11", "1 222 3\r\nasd\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 11", "1 22", "2 3", "\r\nasd\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set k", "ey 11", "1 2", "2", "2 3", "\r\nasd\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 111 222 3\r\n", "asd\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 111 222 3\r\na", "sd\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 111 222 3\r\nasd", "\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"set key 111 222 3\r\nasd\r", "\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key");
+ BOOST_REQUIRE(p->_flags_str == "111");
+ BOOST_REQUIRE(p->_expiration == 222);
+ BOOST_REQUIRE(p->_size == 3);
+ BOOST_REQUIRE(p->_size_str == "3");
+ BOOST_REQUIRE(p->_blob == "asd");
+ });
+ });
+ });
+}
+
+static std::vector<sstring> as_strings(std::vector<item_key>& keys) {
+ std::vector<sstring> v;
+ for (auto&& key : keys) {
+ v.push_back(key.key());
+ }
+ return v;
+}
+
+SEASTAR_TEST_CASE(test_get_parsing) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return make_ready_future<>()
+ .then([make_packet] {
+ return parse(make_packet({"get key1\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_get);
+ BOOST_REQUIRE_EQUAL(as_strings(p->_keys), std::vector<sstring>({"key1"}));
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"get key1 key2\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_get);
+ BOOST_REQUIRE_EQUAL(as_strings(p->_keys), std::vector<sstring>({"key1", "key2"}));
+ });
+ }).then([make_packet] {
+ return parse(make_packet({"get key1 key2 key3\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_get);
+ BOOST_REQUIRE_EQUAL(as_strings(p->_keys), std::vector<sstring>({"key1", "key2", "key3"}));
+ });
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_catches_errors_in_get) {
+ return for_each_fragment_size([] (auto make_packet) {
+ return make_ready_future<>()
+ .then([make_packet] {
+ return parse(make_packet({"get\r\n"}))
+ .then([] (auto p) {
+ BOOST_REQUIRE(p->_state == parser_type::state::error);
+ });
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_parser_returns_eof_state_when_no_command_follows) {
+ return for_each_fragment_size([] (auto make_packet) {
+ auto p = make_shared<parser_type>();
+ auto is = make_shared<input_stream<char>>(make_input_stream(make_packet({"get key\r\n"})));
+ p->init();
+ return is->consume(*p).then([p] {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_get);
+ }).then([is, p] {
+ p->init();
+ return is->consume(*p).then([p, is] {
+ BOOST_REQUIRE(p->_state == parser_type::state::eof);
+ });
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_incomplete_command_is_an_error) {
+ return for_each_fragment_size([] (auto make_packet) {
+ auto p = make_shared<parser_type>();
+ auto is = make_shared<input_stream<char>>(make_input_stream(make_packet({"get"})));
+ p->init();
+ return is->consume(*p).then([p] {
+ BOOST_REQUIRE(p->_state == parser_type::state::error);
+ }).then([is, p] {
+ p->init();
+ return is->consume(*p).then([p, is] {
+ BOOST_REQUIRE(p->_state == parser_type::state::eof);
+ });
+ });
+ });
+}
+
+SEASTAR_TEST_CASE(test_multiple_requests_in_one_stream) {
+ return for_each_fragment_size([] (auto make_packet) {
+ auto p = make_shared<parser_type>();
+ auto is = make_shared<input_stream<char>>(make_input_stream(make_packet({"set key1 1 1 5\r\ndata1\r\nset key2 2 2 6\r\ndata2+\r\n"})));
+ p->init();
+ return is->consume(*p).then([p] {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key1");
+ BOOST_REQUIRE(p->_flags_str == "1");
+ BOOST_REQUIRE(p->_expiration == 1);
+ BOOST_REQUIRE(p->_size == 5);
+ BOOST_REQUIRE(p->_size_str == "5");
+ BOOST_REQUIRE(p->_blob == "data1");
+ }).then([is, p] {
+ p->init();
+ return is->consume(*p).then([p, is] {
+ BOOST_REQUIRE(p->_state == parser_type::state::cmd_set);
+ BOOST_REQUIRE(p->_key.key() == "key2");
+ BOOST_REQUIRE(p->_flags_str == "2");
+ BOOST_REQUIRE(p->_expiration == 2);
+ BOOST_REQUIRE(p->_size == 6);
+ BOOST_REQUIRE(p->_size_str == "6");
+ BOOST_REQUIRE(p->_blob == "data2+");
+ });
+ });
+ });
+}
diff --git a/src/seastar/apps/memcached/tests/test_memcached.py b/src/seastar/apps/memcached/tests/test_memcached.py
new file mode 100755
index 00000000..4aca858e
--- /dev/null
+++ b/src/seastar/apps/memcached/tests/test_memcached.py
@@ -0,0 +1,600 @@
+#!/usr/bin/env python3
+#
+# This file is open source software, licensed to you under the terms
+# of the Apache License, Version 2.0 (the "License"). See the NOTICE file
+# distributed with this work for additional information regarding copyright
+# ownership. You may not use this file except in compliance with the License.
+#
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from contextlib import contextmanager
+import socket
+import struct
+import sys
+import random
+import argparse
+import time
+import re
+import unittest
+
+server_addr = None
+call = None
+args = None
+
+class TimeoutError(Exception):
+ pass
+
+@contextmanager
+def tcp_connection(timeout=1):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.settimeout(timeout)
+ s.connect(server_addr)
+ def call(msg):
+ s.send(msg.encode())
+ return s.recv(16*1024)
+ yield call
+ s.close()
+
+def slow(f):
+ def wrapper(self):
+ if args.fast:
+ raise unittest.SkipTest('Slow')
+ return f(self)
+ return wrapper
+
+def recv_all(s):
+ m = b''
+ while True:
+ data = s.recv(1024)
+ if not data:
+ break
+ m += data
+ return m
+
+def tcp_call(msg, timeout=1):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.settimeout(timeout)
+ s.connect(server_addr)
+ s.send(msg.encode())
+ s.shutdown(socket.SHUT_WR)
+ data = recv_all(s)
+ s.close()
+ return data
+
+def udp_call_for_fragments(msg, timeout=1):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sock.settimeout(timeout)
+ this_req_id = random.randint(-32768, 32767)
+
+ datagram = struct.pack(">hhhh", this_req_id, 0, 1, 0) + msg.encode()
+ sock.sendto(datagram, server_addr)
+
+ messages = {}
+ n_determined = None
+ while True:
+ data, addr = sock.recvfrom(1500)
+ req_id, seq, n, res = struct.unpack_from(">hhhh", data)
+ content = data[8:]
+
+ if n_determined and n_determined != n:
+ raise Exception('Inconsitent number of total messages, %d and %d' % (n_determined, n))
+ n_determined = n
+
+ if req_id != this_req_id:
+ raise Exception('Invalid request id: ' + req_id + ', expected ' + this_req_id)
+
+ if seq in messages:
+ raise Exception('Duplicate message for seq=' + seq)
+
+ messages[seq] = content
+ if len(messages) == n:
+ break
+
+ for k, v in sorted(messages.items(), key=lambda e: e[0]):
+ yield v
+
+ sock.close()
+
+def udp_call(msg, **kwargs):
+ return b''.join(udp_call_for_fragments(msg, **kwargs))
+
+class MemcacheTest(unittest.TestCase):
+ def set(self, key, value, flags=0, expiry=0):
+ self.assertEqual(call('set %s %d %d %d\r\n%s\r\n' % (key, flags, expiry, len(value), value)), b'STORED\r\n')
+
+ def delete(self, key):
+ self.assertEqual(call('delete %s\r\n' % key), b'DELETED\r\n')
+
+ def assertHasKey(self, key):
+ resp = call('get %s\r\n' % key)
+ if not resp.startswith(('VALUE %s' % key).encode()):
+ self.fail('Key \'%s\' should be present, but got: %s' % (key, resp.decode()))
+
+ def assertNoKey(self, key):
+ resp = call('get %s\r\n' % key)
+ if resp != b'END\r\n':
+ self.fail('Key \'%s\' should not be present, but got: %s' % (key, resp.decode()))
+
+ def setKey(self, key):
+ self.set(key, 'some value')
+
+ def getItemVersion(self, key):
+ m = re.match(r'VALUE %s \d+ \d+ (?P<version>\d+)' % key, call('gets %s\r\n' % key).decode())
+ return int(m.group('version'))
+
+ def getStat(self, name, call_fn=None):
+ if not call_fn: call_fn = call
+ resp = call_fn('stats\r\n').decode()
+ m = re.search(r'STAT %s (?P<value>.+)' % re.escape(name), resp, re.MULTILINE)
+ return m.group('value')
+
+ def flush(self):
+ self.assertEqual(call('flush_all\r\n'), b'OK\r\n')
+
+ def tearDown(self):
+ self.flush()
+
+class TcpSpecificTests(MemcacheTest):
+ def test_recovers_from_errors_in_the_stream(self):
+ with tcp_connection() as conn:
+ self.assertEqual(conn('get\r\n'), b'ERROR\r\n')
+ self.assertEqual(conn('get key\r\n'), b'END\r\n')
+
+ def test_incomplete_command_results_in_error(self):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.connect(server_addr)
+ s.send(b'get')
+ s.shutdown(socket.SHUT_WR)
+ self.assertEqual(recv_all(s), b'ERROR\r\n')
+ s.close()
+
+ def test_stream_closed_results_in_error(self):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.connect(server_addr)
+ s.shutdown(socket.SHUT_WR)
+ self.assertEqual(recv_all(s), b'')
+ s.close()
+
+ def test_unsuccesful_parsing_does_not_leave_data_behind(self):
+ with tcp_connection() as conn:
+ self.assertEqual(conn('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n')
+ self.assertRegex(conn('delete a b c\r\n'), b'^(CLIENT_)?ERROR.*\r\n$')
+ self.assertEqual(conn('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n')
+ self.assertEqual(conn('delete key\r\n'), b'DELETED\r\n')
+
+ def test_flush_all_no_reply(self):
+ self.assertEqual(call('flush_all noreply\r\n'), b'')
+
+ def test_set_no_reply(self):
+ self.assertEqual(call('set key 0 0 5 noreply\r\nhello\r\nget key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n')
+ self.delete('key')
+
+ def test_delete_no_reply(self):
+ self.setKey('key')
+ self.assertEqual(call('delete key noreply\r\nget key\r\n'), b'END\r\n')
+
+ def test_add_no_reply(self):
+ self.assertEqual(call('add key 0 0 1 noreply\r\na\r\nget key\r\n'), b'VALUE key 0 1\r\na\r\nEND\r\n')
+ self.delete('key')
+
+ def test_replace_no_reply(self):
+ self.assertEqual(call('set key 0 0 1\r\na\r\n'), b'STORED\r\n')
+ self.assertEqual(call('replace key 0 0 1 noreply\r\nb\r\nget key\r\n'), b'VALUE key 0 1\r\nb\r\nEND\r\n')
+ self.delete('key')
+
+ def test_cas_noreply(self):
+ self.assertNoKey('key')
+ self.assertEqual(call('cas key 0 0 1 1 noreply\r\na\r\n'), b'')
+ self.assertNoKey('key')
+
+ self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n')
+ version = self.getItemVersion('key')
+
+ self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version + 1)), b'')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n')
+
+ self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version)), b'')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 1 5\r\naloha\r\nEND\r\n')
+
+ self.delete('key')
+
+ @slow
+ def test_connection_statistics(self):
+ with tcp_connection() as conn:
+ curr_connections = int(self.getStat('curr_connections', call_fn=conn))
+ total_connections = int(self.getStat('total_connections', call_fn=conn))
+ with tcp_connection() as conn2:
+ self.assertEqual(curr_connections + 1, int(self.getStat('curr_connections', call_fn=conn)))
+ self.assertEqual(total_connections + 1, int(self.getStat('total_connections', call_fn=conn)))
+ self.assertEqual(total_connections + 1, int(self.getStat('total_connections', call_fn=conn)))
+ time.sleep(0.1)
+ self.assertEqual(curr_connections, int(self.getStat('curr_connections', call_fn=conn)))
+
+class UdpSpecificTests(MemcacheTest):
+ def test_large_response_is_split_into_mtu_chunks(self):
+ max_datagram_size = 1400
+ data = '1' * (max_datagram_size*3)
+ self.set('key', data)
+
+ chunks = list(udp_call_for_fragments('get key\r\n'))
+
+ for chunk in chunks:
+ self.assertLessEqual(len(chunk), max_datagram_size)
+
+ self.assertEqual(b''.join(chunks).decode(),
+ 'VALUE key 0 %d\r\n%s\r\n' \
+ 'END\r\n' % (len(data), data))
+
+ self.delete('key')
+
+class TestCommands(MemcacheTest):
+ def test_basic_commands(self):
+ self.assertEqual(call('get key\r\n'), b'END\r\n')
+ self.assertEqual(call('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n')
+ self.assertEqual(call('delete key\r\n'), b'DELETED\r\n')
+ self.assertEqual(call('delete key\r\n'), b'NOT_FOUND\r\n')
+ self.assertEqual(call('get key\r\n'), b'END\r\n')
+
+ def test_error_handling(self):
+ self.assertEqual(call('get\r\n'), b'ERROR\r\n')
+
+ @slow
+ def test_expiry(self):
+ self.assertEqual(call('set key 0 1 5\r\nhello\r\n'), b'STORED\r\n')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n')
+ time.sleep(2)
+ self.assertEqual(call('get key\r\n'), b'END\r\n')
+
+ @slow
+ def test_expiry_at_epoch_time(self):
+ expiry = int(time.time()) + 1
+ self.assertEqual(call('set key 0 %d 5\r\nhello\r\n' % expiry), b'STORED\r\n')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n')
+ time.sleep(2)
+ self.assertEqual(call('get key\r\n'), b'END\r\n')
+
+ def test_multiple_keys_in_get(self):
+ self.assertEqual(call('set key1 0 0 2\r\nv1\r\n'), b'STORED\r\n')
+ self.assertEqual(call('set key 0 0 2\r\nv2\r\n'), b'STORED\r\n')
+ resp = call('get key1 key\r\n')
+ self.assertRegex(resp, b'^(VALUE key1 0 2\r\nv1\r\nVALUE key 0 2\r\nv2\r\nEND\r\n)|(VALUE key 0 2\r\nv2\r\nVALUE key1 0 2\r\nv1\r\nEND\r\n)$')
+ self.delete("key")
+ self.delete("key1")
+
+ def test_flush_all(self):
+ self.set('key', 'value')
+ self.assertEqual(call('flush_all\r\n'), b'OK\r\n')
+ self.assertNoKey('key')
+
+ def test_keys_set_after_flush_remain(self):
+ self.assertEqual(call('flush_all\r\n'), b'OK\r\n')
+ self.setKey('key')
+ self.assertHasKey('key')
+ self.delete('key')
+
+ @slow
+ def test_flush_all_with_timeout_flushes_all_keys_even_those_set_after_flush(self):
+ self.setKey('key')
+ self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n')
+ self.assertHasKey('key')
+ self.setKey('key2')
+ time.sleep(3)
+ self.assertNoKey('key')
+ self.assertNoKey('key2')
+
+ @slow
+ def test_subsequent_flush_is_merged(self):
+ self.setKey('key')
+ self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') # Can flush in anything between 1-2
+ self.assertEqual(call('flush_all 4\r\n'), b'OK\r\n') # Can flush in anything between 3-4
+ time.sleep(3)
+ self.assertHasKey('key')
+ self.setKey('key2')
+ time.sleep(4)
+ self.assertNoKey('key')
+ self.assertNoKey('key2')
+
+ @slow
+ def test_immediate_flush_cancels_delayed_flush(self):
+ self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n')
+ self.assertEqual(call('flush_all\r\n'), b'OK\r\n')
+ self.setKey('key')
+ time.sleep(1)
+ self.assertHasKey('key')
+ self.delete('key')
+
+ @slow
+ def test_flushing_in_the_past(self):
+ self.setKey('key1')
+ time.sleep(1)
+ self.setKey('key2')
+ key2_time = int(time.time())
+ self.assertEqual(call('flush_all %d\r\n' % (key2_time - 1)), b'OK\r\n')
+ time.sleep(1)
+ self.assertNoKey("key1")
+ self.assertNoKey("key2")
+
+ @slow
+ def test_memcache_does_not_crash_when_flushing_with_already_expred_items(self):
+ self.assertEqual(call('set key1 0 2 5\r\nhello\r\n'), b'STORED\r\n')
+ time.sleep(1)
+ self.assertEqual(call('flush_all\r\n'), b'OK\r\n')
+
+ def test_response_spanning_many_datagrams(self):
+ key1_data = '1' * 1000
+ key2_data = '2' * 1000
+ key3_data = '3' * 1000
+ self.set('key1', key1_data)
+ self.set('key2', key2_data)
+ self.set('key3', key3_data)
+
+ resp = call('get key1 key2 key3\r\n').decode()
+
+ pattern = '^VALUE (?P<v1>.*?\r\n.*?)\r\nVALUE (?P<v2>.*?\r\n.*?)\r\nVALUE (?P<v3>.*?\r\n.*?)\r\nEND\r\n$'
+ self.assertRegex(resp, pattern)
+
+ m = re.match(pattern, resp)
+ self.assertEqual(set([m.group('v1'), m.group('v2'), m.group('v3')]),
+ set(['key1 0 %d\r\n%s' % (len(key1_data), key1_data),
+ 'key2 0 %d\r\n%s' % (len(key2_data), key2_data),
+ 'key3 0 %d\r\n%s' % (len(key3_data), key3_data)]))
+
+ self.delete('key1')
+ self.delete('key2')
+ self.delete('key3')
+
+ def test_version(self):
+ self.assertRegex(call('version\r\n'), b'^VERSION .*\r\n$')
+
+ def test_add(self):
+ self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n')
+ self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n')
+ self.delete('key')
+
+ def test_replace(self):
+ self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n')
+ self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'STORED\r\n')
+ self.delete('key')
+ self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n')
+
+ def test_cas_and_gets(self):
+ self.assertEqual(call('cas key 0 0 1 1\r\na\r\n'), b'NOT_FOUND\r\n')
+ self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n')
+ version = self.getItemVersion('key')
+
+ self.assertEqual(call('set key 1 0 5\r\nhello\r\n'), b'STORED\r\n')
+ self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 1 5 %d\r\nhello\r\nEND\r\n' % (version + 1))
+
+ self.assertEqual(call('cas key 0 0 5 %d\r\nhello\r\n' % (version)), b'EXISTS\r\n')
+ self.assertEqual(call('cas key 0 0 5 %d\r\naloha\r\n' % (version + 1)), b'STORED\r\n')
+ self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 0 5 %d\r\naloha\r\nEND\r\n' % (version + 2))
+
+ self.delete('key')
+
+ def test_curr_items_stat(self):
+ self.assertEqual(0, int(self.getStat('curr_items')))
+ self.setKey('key')
+ self.assertEqual(1, int(self.getStat('curr_items')))
+ self.delete('key')
+ self.assertEqual(0, int(self.getStat('curr_items')))
+
+ def test_how_stats_change_with_different_commands(self):
+ get_count = int(self.getStat('cmd_get'))
+ set_count = int(self.getStat('cmd_set'))
+ flush_count = int(self.getStat('cmd_flush'))
+ total_items = int(self.getStat('total_items'))
+ get_misses = int(self.getStat('get_misses'))
+ get_hits = int(self.getStat('get_hits'))
+ cas_hits = int(self.getStat('cas_hits'))
+ cas_badval = int(self.getStat('cas_badval'))
+ cas_misses = int(self.getStat('cas_misses'))
+ delete_misses = int(self.getStat('delete_misses'))
+ delete_hits = int(self.getStat('delete_hits'))
+ curr_connections = int(self.getStat('curr_connections'))
+ incr_hits = int(self.getStat('incr_hits'))
+ incr_misses = int(self.getStat('incr_misses'))
+ decr_hits = int(self.getStat('decr_hits'))
+ decr_misses = int(self.getStat('decr_misses'))
+
+ call('get key\r\n')
+ get_count += 1
+ get_misses += 1
+
+ call('gets key\r\n')
+ get_count += 1
+ get_misses += 1
+
+ call('set key1 0 0 1\r\na\r\n')
+ set_count += 1
+ total_items += 1
+
+ call('get key1\r\n')
+ get_count += 1
+ get_hits += 1
+
+ call('add key1 0 0 1\r\na\r\n')
+ set_count += 1
+
+ call('add key2 0 0 1\r\na\r\n')
+ set_count += 1
+ total_items += 1
+
+ call('replace key1 0 0 1\r\na\r\n')
+ set_count += 1
+ total_items += 1
+
+ call('replace key3 0 0 1\r\na\r\n')
+ set_count += 1
+
+ call('cas key4 0 0 1 1\r\na\r\n')
+ set_count += 1
+ cas_misses += 1
+
+ call('cas key1 0 0 1 %d\r\na\r\n' % self.getItemVersion('key1'))
+ set_count += 1
+ get_count += 1
+ get_hits += 1
+ cas_hits += 1
+ total_items += 1
+
+ call('cas key1 0 0 1 %d\r\na\r\n' % (self.getItemVersion('key1') + 1))
+ set_count += 1
+ get_count += 1
+ get_hits += 1
+ cas_badval += 1
+
+ call('delete key1\r\n')
+ delete_hits += 1
+
+ call('delete key1\r\n')
+ delete_misses += 1
+
+ call('incr num 1\r\n')
+ incr_misses += 1
+ call('decr num 1\r\n')
+ decr_misses += 1
+
+ call('set num 0 0 1\r\n0\r\n')
+ set_count += 1
+ total_items += 1
+
+ call('incr num 1\r\n')
+ incr_hits += 1
+ call('decr num 1\r\n')
+ decr_hits += 1
+
+ self.flush()
+ flush_count += 1
+
+ self.assertEqual(get_count, int(self.getStat('cmd_get')))
+ self.assertEqual(set_count, int(self.getStat('cmd_set')))
+ self.assertEqual(flush_count, int(self.getStat('cmd_flush')))
+ self.assertEqual(total_items, int(self.getStat('total_items')))
+ self.assertEqual(get_hits, int(self.getStat('get_hits')))
+ self.assertEqual(get_misses, int(self.getStat('get_misses')))
+ self.assertEqual(cas_misses, int(self.getStat('cas_misses')))
+ self.assertEqual(cas_hits, int(self.getStat('cas_hits')))
+ self.assertEqual(cas_badval, int(self.getStat('cas_badval')))
+ self.assertEqual(delete_misses, int(self.getStat('delete_misses')))
+ self.assertEqual(delete_hits, int(self.getStat('delete_hits')))
+ self.assertEqual(0, int(self.getStat('curr_items')))
+ self.assertEqual(curr_connections, int(self.getStat('curr_connections')))
+ self.assertEqual(incr_misses, int(self.getStat('incr_misses')))
+ self.assertEqual(incr_hits, int(self.getStat('incr_hits')))
+ self.assertEqual(decr_misses, int(self.getStat('decr_misses')))
+ self.assertEqual(decr_hits, int(self.getStat('decr_hits')))
+
+ def test_incr(self):
+ self.assertEqual(call('incr key 0\r\n'), b'NOT_FOUND\r\n')
+
+ self.assertEqual(call('set key 0 0 1\r\n0\r\n'), b'STORED\r\n')
+ self.assertEqual(call('incr key 0\r\n'), b'0\r\n')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n0\r\nEND\r\n')
+
+ self.assertEqual(call('incr key 1\r\n'), b'1\r\n')
+ self.assertEqual(call('incr key 2\r\n'), b'3\r\n')
+ self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 1)), b'2\r\n')
+ self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 3)), b'18446744073709551615\r\n')
+ self.assertRegex(call('incr key 1\r\n').decode(), r'0(\w+)?\r\n')
+
+ self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n')
+ self.assertEqual(call('incr key 1\r\n'), b'2\r\n')
+
+ self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n')
+ self.assertEqual(call('incr key 1\r\n'), b'10\r\n')
+
+ def test_decr(self):
+ self.assertEqual(call('decr key 0\r\n'), b'NOT_FOUND\r\n')
+
+ self.assertEqual(call('set key 0 0 1\r\n7\r\n'), b'STORED\r\n')
+ self.assertEqual(call('decr key 1\r\n'), b'6\r\n')
+ self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n6\r\nEND\r\n')
+
+ self.assertEqual(call('decr key 6\r\n'), b'0\r\n')
+ self.assertEqual(call('decr key 2\r\n'), b'0\r\n')
+
+ self.assertEqual(call('set key 0 0 2\r\n20\r\n'), b'STORED\r\n')
+ self.assertRegex(call('decr key 11\r\n').decode(), r'^9( )?\r\n$')
+
+ self.assertEqual(call('set key 0 0 3\r\n100\r\n'), b'STORED\r\n')
+ self.assertRegex(call('decr key 91\r\n').decode(), r'^9( )?\r\n$')
+
+ self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n')
+ self.assertEqual(call('decr key 1\r\n'), b'0\r\n')
+
+ self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n')
+ self.assertEqual(call('decr key 1\r\n'), b'8\r\n')
+
+ def test_incr_and_decr_on_invalid_input(self):
+ error_msg = b'CLIENT_ERROR cannot increment or decrement non-numeric value\r\n'
+ for cmd in ['incr', 'decr']:
+ for value in ['', '-1', 'a', '0x1', '18446744073709551616']:
+ self.assertEqual(call('set key 0 0 %d\r\n%s\r\n' % (len(value), value)), b'STORED\r\n')
+ prev = call('get key\r\n')
+ self.assertEqual(call(cmd + ' key 1\r\n'), error_msg, "cmd=%s, value=%s" % (cmd, value))
+ self.assertEqual(call('get key\r\n'), prev)
+ self.delete('key')
+
+def wait_for_memcache_tcp(timeout=4):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ timeout_at = time.time() + timeout
+ while True:
+ if time.time() >= timeout_at:
+ raise TimeoutError()
+ try:
+ s.connect(server_addr)
+ s.close()
+ break
+ except ConnectionRefusedError:
+ time.sleep(0.1)
+
+
+def wait_for_memcache_udp(timeout=4):
+ timeout_at = time.time() + timeout
+ while True:
+ if time.time() >= timeout_at:
+ raise TimeoutError()
+ try:
+ udp_call('version\r\n', timeout=0.2)
+ break
+ except socket.timeout:
+ pass
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="memcache protocol tests")
+ parser.add_argument('--server', '-s', action="store", help="server adddress in <host>:<port> format", default="localhost:11211")
+ parser.add_argument('--udp', '-U', action="store_true", help="Use UDP protocol")
+ parser.add_argument('--fast', action="store_true", help="Run only fast tests")
+ args = parser.parse_args()
+
+ host, port = args.server.split(':')
+ server_addr = (host, int(port))
+
+ if args.udp:
+ call = udp_call
+ wait_for_memcache_udp()
+ else:
+ call = tcp_call
+ wait_for_memcache_tcp()
+
+ runner = unittest.TextTestRunner()
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ suite.addTest(loader.loadTestsFromTestCase(TestCommands))
+ if args.udp:
+ suite.addTest(loader.loadTestsFromTestCase(UdpSpecificTests))
+ else:
+ suite.addTest(loader.loadTestsFromTestCase(TcpSpecificTests))
+ result = runner.run(suite)
+ if not result.wasSuccessful():
+ sys.exit(1)