#!/usr/bin/env python3 # # This file is open source software, licensed to you under the terms # of the Apache License, Version 2.0 (the "License"). See the NOTICE file # distributed with this work for additional information regarding copyright # ownership. You may not use this file except in compliance with the License. # # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # from contextlib import contextmanager import socket import struct import sys import random import argparse import time import re import unittest server_addr = None call = None args = None class TimeoutError(Exception): pass @contextmanager def tcp_connection(timeout=1): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(timeout) s.connect(server_addr) def call(msg): s.send(msg.encode()) return s.recv(16*1024) yield call s.close() def slow(f): def wrapper(self): if args.fast: raise unittest.SkipTest('Slow') return f(self) return wrapper def recv_all(s): m = b'' while True: data = s.recv(1024) if not data: break m += data return m def tcp_call(msg, timeout=1): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(timeout) s.connect(server_addr) s.send(msg.encode()) s.shutdown(socket.SHUT_WR) data = recv_all(s) s.close() return data def udp_call_for_fragments(msg, timeout=1): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(timeout) this_req_id = random.randint(-32768, 32767) datagram = struct.pack(">hhhh", this_req_id, 0, 1, 0) + msg.encode() sock.sendto(datagram, server_addr) messages = {} n_determined = None while True: data, addr = sock.recvfrom(1500) req_id, seq, n, res = struct.unpack_from(">hhhh", data) content = data[8:] if n_determined and n_determined != n: raise Exception('Inconsitent number of total messages, %d and %d' % (n_determined, n)) n_determined = n if req_id != this_req_id: raise Exception('Invalid request id: ' + req_id + ', expected ' + this_req_id) if seq in messages: raise Exception('Duplicate message for seq=' + seq) messages[seq] = content if len(messages) == n: break for k, v in sorted(messages.items(), key=lambda e: e[0]): yield v sock.close() def udp_call(msg, **kwargs): return b''.join(udp_call_for_fragments(msg, **kwargs)) class MemcacheTest(unittest.TestCase): def set(self, key, value, flags=0, expiry=0): self.assertEqual(call('set %s %d %d %d\r\n%s\r\n' % (key, flags, expiry, len(value), value)), b'STORED\r\n') def delete(self, key): self.assertEqual(call('delete %s\r\n' % key), b'DELETED\r\n') def assertHasKey(self, key): resp = call('get %s\r\n' % key) if not resp.startswith(('VALUE %s' % key).encode()): self.fail('Key \'%s\' should be present, but got: %s' % (key, resp.decode())) def assertNoKey(self, key): resp = call('get %s\r\n' % key) if resp != b'END\r\n': self.fail('Key \'%s\' should not be present, but got: %s' % (key, resp.decode())) def setKey(self, key): self.set(key, 'some value') def getItemVersion(self, key): m = re.match(r'VALUE %s \d+ \d+ (?P\d+)' % key, call('gets %s\r\n' % key).decode()) return int(m.group('version')) def getStat(self, name, call_fn=None): if not call_fn: call_fn = call resp = call_fn('stats\r\n').decode() m = re.search(r'STAT %s (?P.+)' % re.escape(name), resp, re.MULTILINE) return m.group('value') def flush(self): self.assertEqual(call('flush_all\r\n'), b'OK\r\n') def tearDown(self): self.flush() class TcpSpecificTests(MemcacheTest): def test_recovers_from_errors_in_the_stream(self): with tcp_connection() as conn: self.assertEqual(conn('get\r\n'), b'ERROR\r\n') self.assertEqual(conn('get key\r\n'), b'END\r\n') def test_incomplete_command_results_in_error(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.connect(server_addr) s.send(b'get') s.shutdown(socket.SHUT_WR) self.assertEqual(recv_all(s), b'ERROR\r\n') s.close() def test_stream_closed_results_in_error(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.connect(server_addr) s.shutdown(socket.SHUT_WR) self.assertEqual(recv_all(s), b'') s.close() def test_unsuccesful_parsing_does_not_leave_data_behind(self): with tcp_connection() as conn: self.assertEqual(conn('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n') self.assertRegex(conn('delete a b c\r\n'), b'^(CLIENT_)?ERROR.*\r\n$') self.assertEqual(conn('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') self.assertEqual(conn('delete key\r\n'), b'DELETED\r\n') def test_flush_all_no_reply(self): self.assertEqual(call('flush_all noreply\r\n'), b'') def test_set_no_reply(self): self.assertEqual(call('set key 0 0 5 noreply\r\nhello\r\nget key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') self.delete('key') def test_delete_no_reply(self): self.setKey('key') self.assertEqual(call('delete key noreply\r\nget key\r\n'), b'END\r\n') def test_add_no_reply(self): self.assertEqual(call('add key 0 0 1 noreply\r\na\r\nget key\r\n'), b'VALUE key 0 1\r\na\r\nEND\r\n') self.delete('key') def test_replace_no_reply(self): self.assertEqual(call('set key 0 0 1\r\na\r\n'), b'STORED\r\n') self.assertEqual(call('replace key 0 0 1 noreply\r\nb\r\nget key\r\n'), b'VALUE key 0 1\r\nb\r\nEND\r\n') self.delete('key') def test_cas_noreply(self): self.assertNoKey('key') self.assertEqual(call('cas key 0 0 1 1 noreply\r\na\r\n'), b'') self.assertNoKey('key') self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n') version = self.getItemVersion('key') self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version + 1)), b'') self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version)), b'') self.assertEqual(call('get key\r\n'), b'VALUE key 1 5\r\naloha\r\nEND\r\n') self.delete('key') @slow def test_connection_statistics(self): with tcp_connection() as conn: curr_connections = int(self.getStat('curr_connections', call_fn=conn)) total_connections = int(self.getStat('total_connections', call_fn=conn)) with tcp_connection() as conn2: self.assertEqual(curr_connections + 1, int(self.getStat('curr_connections', call_fn=conn))) self.assertEqual(total_connections + 1, int(self.getStat('total_connections', call_fn=conn))) self.assertEqual(total_connections + 1, int(self.getStat('total_connections', call_fn=conn))) time.sleep(0.1) self.assertEqual(curr_connections, int(self.getStat('curr_connections', call_fn=conn))) class UdpSpecificTests(MemcacheTest): def test_large_response_is_split_into_mtu_chunks(self): max_datagram_size = 1400 data = '1' * (max_datagram_size*3) self.set('key', data) chunks = list(udp_call_for_fragments('get key\r\n')) for chunk in chunks: self.assertLessEqual(len(chunk), max_datagram_size) self.assertEqual(b''.join(chunks).decode(), 'VALUE key 0 %d\r\n%s\r\n' \ 'END\r\n' % (len(data), data)) self.delete('key') class TestCommands(MemcacheTest): def test_basic_commands(self): self.assertEqual(call('get key\r\n'), b'END\r\n') self.assertEqual(call('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n') self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') self.assertEqual(call('delete key\r\n'), b'DELETED\r\n') self.assertEqual(call('delete key\r\n'), b'NOT_FOUND\r\n') self.assertEqual(call('get key\r\n'), b'END\r\n') def test_error_handling(self): self.assertEqual(call('get\r\n'), b'ERROR\r\n') @slow def test_expiry(self): self.assertEqual(call('set key 0 1 5\r\nhello\r\n'), b'STORED\r\n') self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') time.sleep(2) self.assertEqual(call('get key\r\n'), b'END\r\n') @slow def test_expiry_at_epoch_time(self): expiry = int(time.time()) + 1 self.assertEqual(call('set key 0 %d 5\r\nhello\r\n' % expiry), b'STORED\r\n') self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') time.sleep(2) self.assertEqual(call('get key\r\n'), b'END\r\n') def test_multiple_keys_in_get(self): self.assertEqual(call('set key1 0 0 2\r\nv1\r\n'), b'STORED\r\n') self.assertEqual(call('set key 0 0 2\r\nv2\r\n'), b'STORED\r\n') resp = call('get key1 key\r\n') self.assertRegex(resp, b'^(VALUE key1 0 2\r\nv1\r\nVALUE key 0 2\r\nv2\r\nEND\r\n)|(VALUE key 0 2\r\nv2\r\nVALUE key1 0 2\r\nv1\r\nEND\r\n)$') self.delete("key") self.delete("key1") def test_flush_all(self): self.set('key', 'value') self.assertEqual(call('flush_all\r\n'), b'OK\r\n') self.assertNoKey('key') def test_keys_set_after_flush_remain(self): self.assertEqual(call('flush_all\r\n'), b'OK\r\n') self.setKey('key') self.assertHasKey('key') self.delete('key') @slow def test_flush_all_with_timeout_flushes_all_keys_even_those_set_after_flush(self): self.setKey('key') self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') self.assertHasKey('key') self.setKey('key2') time.sleep(3) self.assertNoKey('key') self.assertNoKey('key2') @slow def test_subsequent_flush_is_merged(self): self.setKey('key') self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') # Can flush in anything between 1-2 self.assertEqual(call('flush_all 4\r\n'), b'OK\r\n') # Can flush in anything between 3-4 time.sleep(3) self.assertHasKey('key') self.setKey('key2') time.sleep(4) self.assertNoKey('key') self.assertNoKey('key2') @slow def test_immediate_flush_cancels_delayed_flush(self): self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') self.assertEqual(call('flush_all\r\n'), b'OK\r\n') self.setKey('key') time.sleep(1) self.assertHasKey('key') self.delete('key') @slow def test_flushing_in_the_past(self): self.setKey('key1') time.sleep(1) self.setKey('key2') key2_time = int(time.time()) self.assertEqual(call('flush_all %d\r\n' % (key2_time - 1)), b'OK\r\n') time.sleep(1) self.assertNoKey("key1") self.assertNoKey("key2") @slow def test_memcache_does_not_crash_when_flushing_with_already_expred_items(self): self.assertEqual(call('set key1 0 2 5\r\nhello\r\n'), b'STORED\r\n') time.sleep(1) self.assertEqual(call('flush_all\r\n'), b'OK\r\n') def test_response_spanning_many_datagrams(self): key1_data = '1' * 1000 key2_data = '2' * 1000 key3_data = '3' * 1000 self.set('key1', key1_data) self.set('key2', key2_data) self.set('key3', key3_data) resp = call('get key1 key2 key3\r\n').decode() pattern = '^VALUE (?P.*?\r\n.*?)\r\nVALUE (?P.*?\r\n.*?)\r\nVALUE (?P.*?\r\n.*?)\r\nEND\r\n$' self.assertRegex(resp, pattern) m = re.match(pattern, resp) self.assertEqual(set([m.group('v1'), m.group('v2'), m.group('v3')]), set(['key1 0 %d\r\n%s' % (len(key1_data), key1_data), 'key2 0 %d\r\n%s' % (len(key2_data), key2_data), 'key3 0 %d\r\n%s' % (len(key3_data), key3_data)])) self.delete('key1') self.delete('key2') self.delete('key3') def test_version(self): self.assertRegex(call('version\r\n'), b'^VERSION .*\r\n$') def test_add(self): self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n') self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n') self.delete('key') def test_replace(self): self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n') self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'STORED\r\n') self.delete('key') self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n') def test_cas_and_gets(self): self.assertEqual(call('cas key 0 0 1 1\r\na\r\n'), b'NOT_FOUND\r\n') self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n') version = self.getItemVersion('key') self.assertEqual(call('set key 1 0 5\r\nhello\r\n'), b'STORED\r\n') self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 1 5 %d\r\nhello\r\nEND\r\n' % (version + 1)) self.assertEqual(call('cas key 0 0 5 %d\r\nhello\r\n' % (version)), b'EXISTS\r\n') self.assertEqual(call('cas key 0 0 5 %d\r\naloha\r\n' % (version + 1)), b'STORED\r\n') self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 0 5 %d\r\naloha\r\nEND\r\n' % (version + 2)) self.delete('key') def test_curr_items_stat(self): self.assertEqual(0, int(self.getStat('curr_items'))) self.setKey('key') self.assertEqual(1, int(self.getStat('curr_items'))) self.delete('key') self.assertEqual(0, int(self.getStat('curr_items'))) def test_how_stats_change_with_different_commands(self): get_count = int(self.getStat('cmd_get')) set_count = int(self.getStat('cmd_set')) flush_count = int(self.getStat('cmd_flush')) total_items = int(self.getStat('total_items')) get_misses = int(self.getStat('get_misses')) get_hits = int(self.getStat('get_hits')) cas_hits = int(self.getStat('cas_hits')) cas_badval = int(self.getStat('cas_badval')) cas_misses = int(self.getStat('cas_misses')) delete_misses = int(self.getStat('delete_misses')) delete_hits = int(self.getStat('delete_hits')) curr_connections = int(self.getStat('curr_connections')) incr_hits = int(self.getStat('incr_hits')) incr_misses = int(self.getStat('incr_misses')) decr_hits = int(self.getStat('decr_hits')) decr_misses = int(self.getStat('decr_misses')) call('get key\r\n') get_count += 1 get_misses += 1 call('gets key\r\n') get_count += 1 get_misses += 1 call('set key1 0 0 1\r\na\r\n') set_count += 1 total_items += 1 call('get key1\r\n') get_count += 1 get_hits += 1 call('add key1 0 0 1\r\na\r\n') set_count += 1 call('add key2 0 0 1\r\na\r\n') set_count += 1 total_items += 1 call('replace key1 0 0 1\r\na\r\n') set_count += 1 total_items += 1 call('replace key3 0 0 1\r\na\r\n') set_count += 1 call('cas key4 0 0 1 1\r\na\r\n') set_count += 1 cas_misses += 1 call('cas key1 0 0 1 %d\r\na\r\n' % self.getItemVersion('key1')) set_count += 1 get_count += 1 get_hits += 1 cas_hits += 1 total_items += 1 call('cas key1 0 0 1 %d\r\na\r\n' % (self.getItemVersion('key1') + 1)) set_count += 1 get_count += 1 get_hits += 1 cas_badval += 1 call('delete key1\r\n') delete_hits += 1 call('delete key1\r\n') delete_misses += 1 call('incr num 1\r\n') incr_misses += 1 call('decr num 1\r\n') decr_misses += 1 call('set num 0 0 1\r\n0\r\n') set_count += 1 total_items += 1 call('incr num 1\r\n') incr_hits += 1 call('decr num 1\r\n') decr_hits += 1 self.flush() flush_count += 1 self.assertEqual(get_count, int(self.getStat('cmd_get'))) self.assertEqual(set_count, int(self.getStat('cmd_set'))) self.assertEqual(flush_count, int(self.getStat('cmd_flush'))) self.assertEqual(total_items, int(self.getStat('total_items'))) self.assertEqual(get_hits, int(self.getStat('get_hits'))) self.assertEqual(get_misses, int(self.getStat('get_misses'))) self.assertEqual(cas_misses, int(self.getStat('cas_misses'))) self.assertEqual(cas_hits, int(self.getStat('cas_hits'))) self.assertEqual(cas_badval, int(self.getStat('cas_badval'))) self.assertEqual(delete_misses, int(self.getStat('delete_misses'))) self.assertEqual(delete_hits, int(self.getStat('delete_hits'))) self.assertEqual(0, int(self.getStat('curr_items'))) self.assertEqual(curr_connections, int(self.getStat('curr_connections'))) self.assertEqual(incr_misses, int(self.getStat('incr_misses'))) self.assertEqual(incr_hits, int(self.getStat('incr_hits'))) self.assertEqual(decr_misses, int(self.getStat('decr_misses'))) self.assertEqual(decr_hits, int(self.getStat('decr_hits'))) def test_incr(self): self.assertEqual(call('incr key 0\r\n'), b'NOT_FOUND\r\n') self.assertEqual(call('set key 0 0 1\r\n0\r\n'), b'STORED\r\n') self.assertEqual(call('incr key 0\r\n'), b'0\r\n') self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n0\r\nEND\r\n') self.assertEqual(call('incr key 1\r\n'), b'1\r\n') self.assertEqual(call('incr key 2\r\n'), b'3\r\n') self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 1)), b'2\r\n') self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 3)), b'18446744073709551615\r\n') self.assertRegex(call('incr key 1\r\n').decode(), r'0(\w+)?\r\n') self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n') self.assertEqual(call('incr key 1\r\n'), b'2\r\n') self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n') self.assertEqual(call('incr key 1\r\n'), b'10\r\n') def test_decr(self): self.assertEqual(call('decr key 0\r\n'), b'NOT_FOUND\r\n') self.assertEqual(call('set key 0 0 1\r\n7\r\n'), b'STORED\r\n') self.assertEqual(call('decr key 1\r\n'), b'6\r\n') self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n6\r\nEND\r\n') self.assertEqual(call('decr key 6\r\n'), b'0\r\n') self.assertEqual(call('decr key 2\r\n'), b'0\r\n') self.assertEqual(call('set key 0 0 2\r\n20\r\n'), b'STORED\r\n') self.assertRegex(call('decr key 11\r\n').decode(), r'^9( )?\r\n$') self.assertEqual(call('set key 0 0 3\r\n100\r\n'), b'STORED\r\n') self.assertRegex(call('decr key 91\r\n').decode(), r'^9( )?\r\n$') self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n') self.assertEqual(call('decr key 1\r\n'), b'0\r\n') self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n') self.assertEqual(call('decr key 1\r\n'), b'8\r\n') def test_incr_and_decr_on_invalid_input(self): error_msg = b'CLIENT_ERROR cannot increment or decrement non-numeric value\r\n' for cmd in ['incr', 'decr']: for value in ['', '-1', 'a', '0x1', '18446744073709551616']: self.assertEqual(call('set key 0 0 %d\r\n%s\r\n' % (len(value), value)), b'STORED\r\n') prev = call('get key\r\n') self.assertEqual(call(cmd + ' key 1\r\n'), error_msg, "cmd=%s, value=%s" % (cmd, value)) self.assertEqual(call('get key\r\n'), prev) self.delete('key') def wait_for_memcache_tcp(timeout=4): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) timeout_at = time.time() + timeout while True: if time.time() >= timeout_at: raise TimeoutError() try: s.connect(server_addr) s.close() break except ConnectionRefusedError: time.sleep(0.1) def wait_for_memcache_udp(timeout=4): timeout_at = time.time() + timeout while True: if time.time() >= timeout_at: raise TimeoutError() try: udp_call('version\r\n', timeout=0.2) break except socket.timeout: pass if __name__ == '__main__': parser = argparse.ArgumentParser(description="memcache protocol tests") parser.add_argument('--server', '-s', action="store", help="server adddress in : format", default="localhost:11211") parser.add_argument('--udp', '-U', action="store_true", help="Use UDP protocol") parser.add_argument('--fast', action="store_true", help="Run only fast tests") args = parser.parse_args() host, port = args.server.split(':') server_addr = (host, int(port)) if args.udp: call = udp_call wait_for_memcache_udp() else: call = tcp_call wait_for_memcache_tcp() runner = unittest.TextTestRunner() loader = unittest.TestLoader() suite = unittest.TestSuite() suite.addTest(loader.loadTestsFromTestCase(TestCommands)) if args.udp: suite.addTest(loader.loadTestsFromTestCase(UdpSpecificTests)) else: suite.addTest(loader.loadTestsFromTestCase(TcpSpecificTests)) result = runner.run(suite) if not result.wasSuccessful(): sys.exit(1)