diff options
Diffstat (limited to '')
-rw-r--r-- | src/seastar/apps/memcached/tests/CMakeLists.txt | 75 | ||||
-rwxr-xr-x | src/seastar/apps/memcached/tests/test.py | 49 | ||||
-rw-r--r-- | src/seastar/apps/memcached/tests/test_ascii_parser.cc | 335 | ||||
-rwxr-xr-x | src/seastar/apps/memcached/tests/test_memcached.py | 600 |
4 files changed, 1059 insertions, 0 deletions
diff --git a/src/seastar/apps/memcached/tests/CMakeLists.txt b/src/seastar/apps/memcached/tests/CMakeLists.txt new file mode 100644 index 00000000..9301cea7 --- /dev/null +++ b/src/seastar/apps/memcached/tests/CMakeLists.txt @@ -0,0 +1,75 @@ +# +# This file is open source software, licensed to you under the terms +# of the Apache License, Version 2.0 (the "License"). See the NOTICE file +# distributed with this work for additional information regarding copyright +# ownership. You may not use this file except in compliance with the License. +# +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# +# Copyright (C) 2018 Scylladb, Ltd. +# + +if (Seastar_EXECUTE_ONLY_FAST_TESTS) + set (memcached_test_args --fast) +else () + set (memcached_test_args "") +endif () + +add_custom_target (app_memcached_test_memcached_run + DEPENDS + ${memcached_app} + ${CMAKE_CURRENT_SOURCE_DIR}/test.py + ${CMAKE_CURRENT_SOURCE_DIR}/test_memcached.py + COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test.py --memcached $<TARGET_FILE:app_memcached> ${memcached_test_args} + USES_TERMINAL) + +add_test ( + NAME Seastar.app.memcached.memcached + COMMAND ${CMAKE_COMMAND} --build ${Seastar_BINARY_DIR} --target app_memcached_test_memcached_run) + +set_tests_properties (Seastar.app.memcached.memcached + PROPERTIES + TIMEOUT ${Seastar_TEST_TIMEOUT}) + +add_executable (app_memcached_test_ascii + test_ascii_parser.cc) + +add_dependencies (app_memcached_test_ascii app_memcached) + +target_include_directories (app_memcached_test_ascii + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${Seastar_APP_MEMCACHED_BINARY_DIR} + ${Seastar_APP_MEMCACHED_SOURCE_DIR}) + +target_compile_definitions (app_memcached_test_ascii + PRIVATE SEASTAR_TESTING_MAIN) + +target_link_libraries (app_memcached_test_ascii + PRIVATE + seastar_with_flags + seastar_testing) + +add_custom_target (app_memcached_test_ascii_run + DEPENDS app_memcached_test_ascii + COMMAND app_memcached_test_ascii -- -c 2 + USES_TERMINAL) + +add_test ( + NAME Seastar.app.memcached.ascii + COMMAND ${CMAKE_COMMAND} --build ${Seastar_BINARY_DIR} --target app_memcached_test_ascii_run) + +set_tests_properties (Seastar.app.memcached.ascii + PROPERTIES + TIMEOUT ${Seastar_TEST_TIMEOUT}) diff --git a/src/seastar/apps/memcached/tests/test.py b/src/seastar/apps/memcached/tests/test.py new file mode 100755 index 00000000..c2f2b80c --- /dev/null +++ b/src/seastar/apps/memcached/tests/test.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# +# This file is open source software, licensed to you under the terms +# of the Apache License, Version 2.0 (the "License"). See the NOTICE file +# distributed with this work for additional information regarding copyright +# ownership. You may not use this file except in compliance with the License. +# +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import time +import sys +import os +import argparse +import subprocess + +DIR_PATH = os.path.dirname(os.path.realpath(__file__)) + +def run(args, cmd): + mc = subprocess.Popen([args.memcached, '--smp=2']) + print('Memcached started.') + try: + cmdline = [DIR_PATH + '/test_memcached.py'] + cmd + if args.fast: + cmdline.append('--fast') + print('Running: ' + ' '.join(cmdline)) + subprocess.check_call(cmdline) + finally: + print('Killing memcached...') + mc.terminate(); + mc.wait() + print('Memcached killed.') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Seastar test runner") + parser.add_argument('--fast', action="store_true", help="Run only fast tests") + parser.add_argument('--memcached', required=True, help='Path of the memcached executable') + args = parser.parse_args() + + run(args, []) + run(args, ['-U']) diff --git a/src/seastar/apps/memcached/tests/test_ascii_parser.cc b/src/seastar/apps/memcached/tests/test_ascii_parser.cc new file mode 100644 index 00000000..596d193e --- /dev/null +++ b/src/seastar/apps/memcached/tests/test_ascii_parser.cc @@ -0,0 +1,335 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright (C) 2014 Cloudius Systems, Ltd. + */ + +#include <iostream> +#include <limits> +#include <seastar/testing/test_case.hh> +#include <seastar/core/shared_ptr.hh> +#include <seastar/net/packet-data-source.hh> +#include "ascii.hh" +#include <seastar/core/future-util.hh> + +using namespace seastar; +using namespace net; +using namespace memcache; + +using parser_type = memcache_ascii_parser; + +static packet make_packet(std::vector<std::string> chunks, size_t buffer_size) { + packet p; + for (auto&& chunk : chunks) { + size_t size = chunk.size(); + for (size_t pos = 0; pos < size; pos += buffer_size) { + auto now = std::min(pos + buffer_size, chunk.size()) - pos; + p.append(packet(chunk.data() + pos, now)); + } + } + return p; +} + +static auto make_input_stream(packet&& p) { + return input_stream<char>(data_source( + std::make_unique<packet_data_source>(std::move(p)))); +} + +static auto parse(packet&& p) { + auto is = make_lw_shared<input_stream<char>>(make_input_stream(std::move(p))); + auto parser = make_lw_shared<parser_type>(); + parser->init(); + return is->consume(*parser).then([is, parser] { + return make_ready_future<lw_shared_ptr<parser_type>>(parser); + }); +} + +auto for_each_fragment_size = [] (auto&& func) { + auto buffer_sizes = { 100000, 1000, 100, 10, 5, 2, 1 }; + return do_for_each(buffer_sizes.begin(), buffer_sizes.end(), [func] (size_t buffer_size) { + return func([buffer_size] (std::vector<std::string> chunks) { + return make_packet(chunks, buffer_size); + }); + }); +}; + +SEASTAR_TEST_CASE(test_set_command_is_parsed) { + return for_each_fragment_size([] (auto make_packet) { + return parse(make_packet({"set key 1 2 3\r\nabc\r\n"})).then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_flags_str == "1"); + BOOST_REQUIRE(p->_expiration == 2); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_blob == "abc"); + }); + }); +} + +SEASTAR_TEST_CASE(test_empty_data_is_parsed) { + return for_each_fragment_size([] (auto make_packet) { + return parse(make_packet({"set key 1 2 0\r\n\r\n"})).then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_flags_str == "1"); + BOOST_REQUIRE(p->_expiration == 2); + BOOST_REQUIRE(p->_size == 0); + BOOST_REQUIRE(p->_size_str == "0"); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_blob == ""); + }); + }); +} + +SEASTAR_TEST_CASE(test_superflous_data_is_an_error) { + return for_each_fragment_size([] (auto make_packet) { + return parse(make_packet({"set key 0 0 0\r\nasd\r\n"})).then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::error); + }); + }); +} + +SEASTAR_TEST_CASE(test_not_enough_data_is_an_error) { + return for_each_fragment_size([] (auto make_packet) { + return parse(make_packet({"set key 0 0 3\r\n"})).then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::error); + }); + }); +} + +SEASTAR_TEST_CASE(test_u32_parsing) { + return for_each_fragment_size([] (auto make_packet) { + return make_ready_future<>().then([make_packet] { + return parse(make_packet({"set key 0 0 0\r\n\r\n"})).then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_flags_str == "0"); + }); + }).then([make_packet] { + return parse(make_packet({"set key 12345 0 0\r\n\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_flags_str == "12345"); + }); + }).then([make_packet] { + return parse(make_packet({"set key -1 0 0\r\n\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::error); + }); + }).then([make_packet] { + return parse(make_packet({"set key 1-1 0 0\r\n\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::error); + }); + }).then([make_packet] { + return parse(make_packet({"set key " + std::to_string(std::numeric_limits<uint32_t>::max()) + " 0 0\r\n\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_flags_str == to_sstring(std::numeric_limits<uint32_t>::max())); + }); + }); + }); +} + +SEASTAR_TEST_CASE(test_parsing_of_split_data) { + return for_each_fragment_size([] (auto make_packet) { + return make_ready_future<>() + .then([make_packet] { + return parse(make_packet({"set key 11", "1 222 3\r\nasd\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }).then([make_packet] { + return parse(make_packet({"set key 11", "1 22", "2 3", "\r\nasd\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }).then([make_packet] { + return parse(make_packet({"set k", "ey 11", "1 2", "2", "2 3", "\r\nasd\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }).then([make_packet] { + return parse(make_packet({"set key 111 222 3\r\n", "asd\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }).then([make_packet] { + return parse(make_packet({"set key 111 222 3\r\na", "sd\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }).then([make_packet] { + return parse(make_packet({"set key 111 222 3\r\nasd", "\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }).then([make_packet] { + return parse(make_packet({"set key 111 222 3\r\nasd\r", "\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key"); + BOOST_REQUIRE(p->_flags_str == "111"); + BOOST_REQUIRE(p->_expiration == 222); + BOOST_REQUIRE(p->_size == 3); + BOOST_REQUIRE(p->_size_str == "3"); + BOOST_REQUIRE(p->_blob == "asd"); + }); + }); + }); +} + +static std::vector<sstring> as_strings(std::vector<item_key>& keys) { + std::vector<sstring> v; + for (auto&& key : keys) { + v.push_back(key.key()); + } + return v; +} + +SEASTAR_TEST_CASE(test_get_parsing) { + return for_each_fragment_size([] (auto make_packet) { + return make_ready_future<>() + .then([make_packet] { + return parse(make_packet({"get key1\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_get); + BOOST_REQUIRE_EQUAL(as_strings(p->_keys), std::vector<sstring>({"key1"})); + }); + }).then([make_packet] { + return parse(make_packet({"get key1 key2\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_get); + BOOST_REQUIRE_EQUAL(as_strings(p->_keys), std::vector<sstring>({"key1", "key2"})); + }); + }).then([make_packet] { + return parse(make_packet({"get key1 key2 key3\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_get); + BOOST_REQUIRE_EQUAL(as_strings(p->_keys), std::vector<sstring>({"key1", "key2", "key3"})); + }); + }); + }); +} + +SEASTAR_TEST_CASE(test_catches_errors_in_get) { + return for_each_fragment_size([] (auto make_packet) { + return make_ready_future<>() + .then([make_packet] { + return parse(make_packet({"get\r\n"})) + .then([] (auto p) { + BOOST_REQUIRE(p->_state == parser_type::state::error); + }); + }); + }); +} + +SEASTAR_TEST_CASE(test_parser_returns_eof_state_when_no_command_follows) { + return for_each_fragment_size([] (auto make_packet) { + auto p = make_shared<parser_type>(); + auto is = make_shared<input_stream<char>>(make_input_stream(make_packet({"get key\r\n"}))); + p->init(); + return is->consume(*p).then([p] { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_get); + }).then([is, p] { + p->init(); + return is->consume(*p).then([p, is] { + BOOST_REQUIRE(p->_state == parser_type::state::eof); + }); + }); + }); +} + +SEASTAR_TEST_CASE(test_incomplete_command_is_an_error) { + return for_each_fragment_size([] (auto make_packet) { + auto p = make_shared<parser_type>(); + auto is = make_shared<input_stream<char>>(make_input_stream(make_packet({"get"}))); + p->init(); + return is->consume(*p).then([p] { + BOOST_REQUIRE(p->_state == parser_type::state::error); + }).then([is, p] { + p->init(); + return is->consume(*p).then([p, is] { + BOOST_REQUIRE(p->_state == parser_type::state::eof); + }); + }); + }); +} + +SEASTAR_TEST_CASE(test_multiple_requests_in_one_stream) { + return for_each_fragment_size([] (auto make_packet) { + auto p = make_shared<parser_type>(); + auto is = make_shared<input_stream<char>>(make_input_stream(make_packet({"set key1 1 1 5\r\ndata1\r\nset key2 2 2 6\r\ndata2+\r\n"}))); + p->init(); + return is->consume(*p).then([p] { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key1"); + BOOST_REQUIRE(p->_flags_str == "1"); + BOOST_REQUIRE(p->_expiration == 1); + BOOST_REQUIRE(p->_size == 5); + BOOST_REQUIRE(p->_size_str == "5"); + BOOST_REQUIRE(p->_blob == "data1"); + }).then([is, p] { + p->init(); + return is->consume(*p).then([p, is] { + BOOST_REQUIRE(p->_state == parser_type::state::cmd_set); + BOOST_REQUIRE(p->_key.key() == "key2"); + BOOST_REQUIRE(p->_flags_str == "2"); + BOOST_REQUIRE(p->_expiration == 2); + BOOST_REQUIRE(p->_size == 6); + BOOST_REQUIRE(p->_size_str == "6"); + BOOST_REQUIRE(p->_blob == "data2+"); + }); + }); + }); +} diff --git a/src/seastar/apps/memcached/tests/test_memcached.py b/src/seastar/apps/memcached/tests/test_memcached.py new file mode 100755 index 00000000..4aca858e --- /dev/null +++ b/src/seastar/apps/memcached/tests/test_memcached.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# +# This file is open source software, licensed to you under the terms +# of the Apache License, Version 2.0 (the "License"). See the NOTICE file +# distributed with this work for additional information regarding copyright +# ownership. You may not use this file except in compliance with the License. +# +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from contextlib import contextmanager +import socket +import struct +import sys +import random +import argparse +import time +import re +import unittest + +server_addr = None +call = None +args = None + +class TimeoutError(Exception): + pass + +@contextmanager +def tcp_connection(timeout=1): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(timeout) + s.connect(server_addr) + def call(msg): + s.send(msg.encode()) + return s.recv(16*1024) + yield call + s.close() + +def slow(f): + def wrapper(self): + if args.fast: + raise unittest.SkipTest('Slow') + return f(self) + return wrapper + +def recv_all(s): + m = b'' + while True: + data = s.recv(1024) + if not data: + break + m += data + return m + +def tcp_call(msg, timeout=1): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(timeout) + s.connect(server_addr) + s.send(msg.encode()) + s.shutdown(socket.SHUT_WR) + data = recv_all(s) + s.close() + return data + +def udp_call_for_fragments(msg, timeout=1): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(timeout) + this_req_id = random.randint(-32768, 32767) + + datagram = struct.pack(">hhhh", this_req_id, 0, 1, 0) + msg.encode() + sock.sendto(datagram, server_addr) + + messages = {} + n_determined = None + while True: + data, addr = sock.recvfrom(1500) + req_id, seq, n, res = struct.unpack_from(">hhhh", data) + content = data[8:] + + if n_determined and n_determined != n: + raise Exception('Inconsitent number of total messages, %d and %d' % (n_determined, n)) + n_determined = n + + if req_id != this_req_id: + raise Exception('Invalid request id: ' + req_id + ', expected ' + this_req_id) + + if seq in messages: + raise Exception('Duplicate message for seq=' + seq) + + messages[seq] = content + if len(messages) == n: + break + + for k, v in sorted(messages.items(), key=lambda e: e[0]): + yield v + + sock.close() + +def udp_call(msg, **kwargs): + return b''.join(udp_call_for_fragments(msg, **kwargs)) + +class MemcacheTest(unittest.TestCase): + def set(self, key, value, flags=0, expiry=0): + self.assertEqual(call('set %s %d %d %d\r\n%s\r\n' % (key, flags, expiry, len(value), value)), b'STORED\r\n') + + def delete(self, key): + self.assertEqual(call('delete %s\r\n' % key), b'DELETED\r\n') + + def assertHasKey(self, key): + resp = call('get %s\r\n' % key) + if not resp.startswith(('VALUE %s' % key).encode()): + self.fail('Key \'%s\' should be present, but got: %s' % (key, resp.decode())) + + def assertNoKey(self, key): + resp = call('get %s\r\n' % key) + if resp != b'END\r\n': + self.fail('Key \'%s\' should not be present, but got: %s' % (key, resp.decode())) + + def setKey(self, key): + self.set(key, 'some value') + + def getItemVersion(self, key): + m = re.match(r'VALUE %s \d+ \d+ (?P<version>\d+)' % key, call('gets %s\r\n' % key).decode()) + return int(m.group('version')) + + def getStat(self, name, call_fn=None): + if not call_fn: call_fn = call + resp = call_fn('stats\r\n').decode() + m = re.search(r'STAT %s (?P<value>.+)' % re.escape(name), resp, re.MULTILINE) + return m.group('value') + + def flush(self): + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + + def tearDown(self): + self.flush() + +class TcpSpecificTests(MemcacheTest): + def test_recovers_from_errors_in_the_stream(self): + with tcp_connection() as conn: + self.assertEqual(conn('get\r\n'), b'ERROR\r\n') + self.assertEqual(conn('get key\r\n'), b'END\r\n') + + def test_incomplete_command_results_in_error(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(server_addr) + s.send(b'get') + s.shutdown(socket.SHUT_WR) + self.assertEqual(recv_all(s), b'ERROR\r\n') + s.close() + + def test_stream_closed_results_in_error(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(server_addr) + s.shutdown(socket.SHUT_WR) + self.assertEqual(recv_all(s), b'') + s.close() + + def test_unsuccesful_parsing_does_not_leave_data_behind(self): + with tcp_connection() as conn: + self.assertEqual(conn('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n') + self.assertRegex(conn('delete a b c\r\n'), b'^(CLIENT_)?ERROR.*\r\n$') + self.assertEqual(conn('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + self.assertEqual(conn('delete key\r\n'), b'DELETED\r\n') + + def test_flush_all_no_reply(self): + self.assertEqual(call('flush_all noreply\r\n'), b'') + + def test_set_no_reply(self): + self.assertEqual(call('set key 0 0 5 noreply\r\nhello\r\nget key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + self.delete('key') + + def test_delete_no_reply(self): + self.setKey('key') + self.assertEqual(call('delete key noreply\r\nget key\r\n'), b'END\r\n') + + def test_add_no_reply(self): + self.assertEqual(call('add key 0 0 1 noreply\r\na\r\nget key\r\n'), b'VALUE key 0 1\r\na\r\nEND\r\n') + self.delete('key') + + def test_replace_no_reply(self): + self.assertEqual(call('set key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.assertEqual(call('replace key 0 0 1 noreply\r\nb\r\nget key\r\n'), b'VALUE key 0 1\r\nb\r\nEND\r\n') + self.delete('key') + + def test_cas_noreply(self): + self.assertNoKey('key') + self.assertEqual(call('cas key 0 0 1 1 noreply\r\na\r\n'), b'') + self.assertNoKey('key') + + self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n') + version = self.getItemVersion('key') + + self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version + 1)), b'') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + + self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version)), b'') + self.assertEqual(call('get key\r\n'), b'VALUE key 1 5\r\naloha\r\nEND\r\n') + + self.delete('key') + + @slow + def test_connection_statistics(self): + with tcp_connection() as conn: + curr_connections = int(self.getStat('curr_connections', call_fn=conn)) + total_connections = int(self.getStat('total_connections', call_fn=conn)) + with tcp_connection() as conn2: + self.assertEqual(curr_connections + 1, int(self.getStat('curr_connections', call_fn=conn))) + self.assertEqual(total_connections + 1, int(self.getStat('total_connections', call_fn=conn))) + self.assertEqual(total_connections + 1, int(self.getStat('total_connections', call_fn=conn))) + time.sleep(0.1) + self.assertEqual(curr_connections, int(self.getStat('curr_connections', call_fn=conn))) + +class UdpSpecificTests(MemcacheTest): + def test_large_response_is_split_into_mtu_chunks(self): + max_datagram_size = 1400 + data = '1' * (max_datagram_size*3) + self.set('key', data) + + chunks = list(udp_call_for_fragments('get key\r\n')) + + for chunk in chunks: + self.assertLessEqual(len(chunk), max_datagram_size) + + self.assertEqual(b''.join(chunks).decode(), + 'VALUE key 0 %d\r\n%s\r\n' \ + 'END\r\n' % (len(data), data)) + + self.delete('key') + +class TestCommands(MemcacheTest): + def test_basic_commands(self): + self.assertEqual(call('get key\r\n'), b'END\r\n') + self.assertEqual(call('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + self.assertEqual(call('delete key\r\n'), b'DELETED\r\n') + self.assertEqual(call('delete key\r\n'), b'NOT_FOUND\r\n') + self.assertEqual(call('get key\r\n'), b'END\r\n') + + def test_error_handling(self): + self.assertEqual(call('get\r\n'), b'ERROR\r\n') + + @slow + def test_expiry(self): + self.assertEqual(call('set key 0 1 5\r\nhello\r\n'), b'STORED\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + time.sleep(2) + self.assertEqual(call('get key\r\n'), b'END\r\n') + + @slow + def test_expiry_at_epoch_time(self): + expiry = int(time.time()) + 1 + self.assertEqual(call('set key 0 %d 5\r\nhello\r\n' % expiry), b'STORED\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + time.sleep(2) + self.assertEqual(call('get key\r\n'), b'END\r\n') + + def test_multiple_keys_in_get(self): + self.assertEqual(call('set key1 0 0 2\r\nv1\r\n'), b'STORED\r\n') + self.assertEqual(call('set key 0 0 2\r\nv2\r\n'), b'STORED\r\n') + resp = call('get key1 key\r\n') + self.assertRegex(resp, b'^(VALUE key1 0 2\r\nv1\r\nVALUE key 0 2\r\nv2\r\nEND\r\n)|(VALUE key 0 2\r\nv2\r\nVALUE key1 0 2\r\nv1\r\nEND\r\n)$') + self.delete("key") + self.delete("key1") + + def test_flush_all(self): + self.set('key', 'value') + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + self.assertNoKey('key') + + def test_keys_set_after_flush_remain(self): + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + self.setKey('key') + self.assertHasKey('key') + self.delete('key') + + @slow + def test_flush_all_with_timeout_flushes_all_keys_even_those_set_after_flush(self): + self.setKey('key') + self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') + self.assertHasKey('key') + self.setKey('key2') + time.sleep(3) + self.assertNoKey('key') + self.assertNoKey('key2') + + @slow + def test_subsequent_flush_is_merged(self): + self.setKey('key') + self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') # Can flush in anything between 1-2 + self.assertEqual(call('flush_all 4\r\n'), b'OK\r\n') # Can flush in anything between 3-4 + time.sleep(3) + self.assertHasKey('key') + self.setKey('key2') + time.sleep(4) + self.assertNoKey('key') + self.assertNoKey('key2') + + @slow + def test_immediate_flush_cancels_delayed_flush(self): + self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + self.setKey('key') + time.sleep(1) + self.assertHasKey('key') + self.delete('key') + + @slow + def test_flushing_in_the_past(self): + self.setKey('key1') + time.sleep(1) + self.setKey('key2') + key2_time = int(time.time()) + self.assertEqual(call('flush_all %d\r\n' % (key2_time - 1)), b'OK\r\n') + time.sleep(1) + self.assertNoKey("key1") + self.assertNoKey("key2") + + @slow + def test_memcache_does_not_crash_when_flushing_with_already_expred_items(self): + self.assertEqual(call('set key1 0 2 5\r\nhello\r\n'), b'STORED\r\n') + time.sleep(1) + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + + def test_response_spanning_many_datagrams(self): + key1_data = '1' * 1000 + key2_data = '2' * 1000 + key3_data = '3' * 1000 + self.set('key1', key1_data) + self.set('key2', key2_data) + self.set('key3', key3_data) + + resp = call('get key1 key2 key3\r\n').decode() + + pattern = '^VALUE (?P<v1>.*?\r\n.*?)\r\nVALUE (?P<v2>.*?\r\n.*?)\r\nVALUE (?P<v3>.*?\r\n.*?)\r\nEND\r\n$' + self.assertRegex(resp, pattern) + + m = re.match(pattern, resp) + self.assertEqual(set([m.group('v1'), m.group('v2'), m.group('v3')]), + set(['key1 0 %d\r\n%s' % (len(key1_data), key1_data), + 'key2 0 %d\r\n%s' % (len(key2_data), key2_data), + 'key3 0 %d\r\n%s' % (len(key3_data), key3_data)])) + + self.delete('key1') + self.delete('key2') + self.delete('key3') + + def test_version(self): + self.assertRegex(call('version\r\n'), b'^VERSION .*\r\n$') + + def test_add(self): + self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n') + self.delete('key') + + def test_replace(self): + self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.delete('key') + self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n') + + def test_cas_and_gets(self): + self.assertEqual(call('cas key 0 0 1 1\r\na\r\n'), b'NOT_FOUND\r\n') + self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n') + version = self.getItemVersion('key') + + self.assertEqual(call('set key 1 0 5\r\nhello\r\n'), b'STORED\r\n') + self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 1 5 %d\r\nhello\r\nEND\r\n' % (version + 1)) + + self.assertEqual(call('cas key 0 0 5 %d\r\nhello\r\n' % (version)), b'EXISTS\r\n') + self.assertEqual(call('cas key 0 0 5 %d\r\naloha\r\n' % (version + 1)), b'STORED\r\n') + self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 0 5 %d\r\naloha\r\nEND\r\n' % (version + 2)) + + self.delete('key') + + def test_curr_items_stat(self): + self.assertEqual(0, int(self.getStat('curr_items'))) + self.setKey('key') + self.assertEqual(1, int(self.getStat('curr_items'))) + self.delete('key') + self.assertEqual(0, int(self.getStat('curr_items'))) + + def test_how_stats_change_with_different_commands(self): + get_count = int(self.getStat('cmd_get')) + set_count = int(self.getStat('cmd_set')) + flush_count = int(self.getStat('cmd_flush')) + total_items = int(self.getStat('total_items')) + get_misses = int(self.getStat('get_misses')) + get_hits = int(self.getStat('get_hits')) + cas_hits = int(self.getStat('cas_hits')) + cas_badval = int(self.getStat('cas_badval')) + cas_misses = int(self.getStat('cas_misses')) + delete_misses = int(self.getStat('delete_misses')) + delete_hits = int(self.getStat('delete_hits')) + curr_connections = int(self.getStat('curr_connections')) + incr_hits = int(self.getStat('incr_hits')) + incr_misses = int(self.getStat('incr_misses')) + decr_hits = int(self.getStat('decr_hits')) + decr_misses = int(self.getStat('decr_misses')) + + call('get key\r\n') + get_count += 1 + get_misses += 1 + + call('gets key\r\n') + get_count += 1 + get_misses += 1 + + call('set key1 0 0 1\r\na\r\n') + set_count += 1 + total_items += 1 + + call('get key1\r\n') + get_count += 1 + get_hits += 1 + + call('add key1 0 0 1\r\na\r\n') + set_count += 1 + + call('add key2 0 0 1\r\na\r\n') + set_count += 1 + total_items += 1 + + call('replace key1 0 0 1\r\na\r\n') + set_count += 1 + total_items += 1 + + call('replace key3 0 0 1\r\na\r\n') + set_count += 1 + + call('cas key4 0 0 1 1\r\na\r\n') + set_count += 1 + cas_misses += 1 + + call('cas key1 0 0 1 %d\r\na\r\n' % self.getItemVersion('key1')) + set_count += 1 + get_count += 1 + get_hits += 1 + cas_hits += 1 + total_items += 1 + + call('cas key1 0 0 1 %d\r\na\r\n' % (self.getItemVersion('key1') + 1)) + set_count += 1 + get_count += 1 + get_hits += 1 + cas_badval += 1 + + call('delete key1\r\n') + delete_hits += 1 + + call('delete key1\r\n') + delete_misses += 1 + + call('incr num 1\r\n') + incr_misses += 1 + call('decr num 1\r\n') + decr_misses += 1 + + call('set num 0 0 1\r\n0\r\n') + set_count += 1 + total_items += 1 + + call('incr num 1\r\n') + incr_hits += 1 + call('decr num 1\r\n') + decr_hits += 1 + + self.flush() + flush_count += 1 + + self.assertEqual(get_count, int(self.getStat('cmd_get'))) + self.assertEqual(set_count, int(self.getStat('cmd_set'))) + self.assertEqual(flush_count, int(self.getStat('cmd_flush'))) + self.assertEqual(total_items, int(self.getStat('total_items'))) + self.assertEqual(get_hits, int(self.getStat('get_hits'))) + self.assertEqual(get_misses, int(self.getStat('get_misses'))) + self.assertEqual(cas_misses, int(self.getStat('cas_misses'))) + self.assertEqual(cas_hits, int(self.getStat('cas_hits'))) + self.assertEqual(cas_badval, int(self.getStat('cas_badval'))) + self.assertEqual(delete_misses, int(self.getStat('delete_misses'))) + self.assertEqual(delete_hits, int(self.getStat('delete_hits'))) + self.assertEqual(0, int(self.getStat('curr_items'))) + self.assertEqual(curr_connections, int(self.getStat('curr_connections'))) + self.assertEqual(incr_misses, int(self.getStat('incr_misses'))) + self.assertEqual(incr_hits, int(self.getStat('incr_hits'))) + self.assertEqual(decr_misses, int(self.getStat('decr_misses'))) + self.assertEqual(decr_hits, int(self.getStat('decr_hits'))) + + def test_incr(self): + self.assertEqual(call('incr key 0\r\n'), b'NOT_FOUND\r\n') + + self.assertEqual(call('set key 0 0 1\r\n0\r\n'), b'STORED\r\n') + self.assertEqual(call('incr key 0\r\n'), b'0\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n0\r\nEND\r\n') + + self.assertEqual(call('incr key 1\r\n'), b'1\r\n') + self.assertEqual(call('incr key 2\r\n'), b'3\r\n') + self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 1)), b'2\r\n') + self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 3)), b'18446744073709551615\r\n') + self.assertRegex(call('incr key 1\r\n').decode(), r'0(\w+)?\r\n') + + self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n') + self.assertEqual(call('incr key 1\r\n'), b'2\r\n') + + self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n') + self.assertEqual(call('incr key 1\r\n'), b'10\r\n') + + def test_decr(self): + self.assertEqual(call('decr key 0\r\n'), b'NOT_FOUND\r\n') + + self.assertEqual(call('set key 0 0 1\r\n7\r\n'), b'STORED\r\n') + self.assertEqual(call('decr key 1\r\n'), b'6\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n6\r\nEND\r\n') + + self.assertEqual(call('decr key 6\r\n'), b'0\r\n') + self.assertEqual(call('decr key 2\r\n'), b'0\r\n') + + self.assertEqual(call('set key 0 0 2\r\n20\r\n'), b'STORED\r\n') + self.assertRegex(call('decr key 11\r\n').decode(), r'^9( )?\r\n$') + + self.assertEqual(call('set key 0 0 3\r\n100\r\n'), b'STORED\r\n') + self.assertRegex(call('decr key 91\r\n').decode(), r'^9( )?\r\n$') + + self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n') + self.assertEqual(call('decr key 1\r\n'), b'0\r\n') + + self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n') + self.assertEqual(call('decr key 1\r\n'), b'8\r\n') + + def test_incr_and_decr_on_invalid_input(self): + error_msg = b'CLIENT_ERROR cannot increment or decrement non-numeric value\r\n' + for cmd in ['incr', 'decr']: + for value in ['', '-1', 'a', '0x1', '18446744073709551616']: + self.assertEqual(call('set key 0 0 %d\r\n%s\r\n' % (len(value), value)), b'STORED\r\n') + prev = call('get key\r\n') + self.assertEqual(call(cmd + ' key 1\r\n'), error_msg, "cmd=%s, value=%s" % (cmd, value)) + self.assertEqual(call('get key\r\n'), prev) + self.delete('key') + +def wait_for_memcache_tcp(timeout=4): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + timeout_at = time.time() + timeout + while True: + if time.time() >= timeout_at: + raise TimeoutError() + try: + s.connect(server_addr) + s.close() + break + except ConnectionRefusedError: + time.sleep(0.1) + + +def wait_for_memcache_udp(timeout=4): + timeout_at = time.time() + timeout + while True: + if time.time() >= timeout_at: + raise TimeoutError() + try: + udp_call('version\r\n', timeout=0.2) + break + except socket.timeout: + pass + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="memcache protocol tests") + parser.add_argument('--server', '-s', action="store", help="server adddress in <host>:<port> format", default="localhost:11211") + parser.add_argument('--udp', '-U', action="store_true", help="Use UDP protocol") + parser.add_argument('--fast', action="store_true", help="Run only fast tests") + args = parser.parse_args() + + host, port = args.server.split(':') + server_addr = (host, int(port)) + + if args.udp: + call = udp_call + wait_for_memcache_udp() + else: + call = tcp_call + wait_for_memcache_tcp() + + runner = unittest.TextTestRunner() + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTest(loader.loadTestsFromTestCase(TestCommands)) + if args.udp: + suite.addTest(loader.loadTestsFromTestCase(UdpSpecificTests)) + else: + suite.addTest(loader.loadTestsFromTestCase(TcpSpecificTests)) + result = runner.run(suite) + if not result.wasSuccessful(): + sys.exit(1) |