blob: e337065422a7c58cc29cb65374461773ac2d9b82 [file] [log] [blame]
#!/usr/bin/python
#
# Copyright (C) 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script for unittesting the daemon module"""
import unittest
import signal
import os
import socket
import time
import tempfile
import shutil
from ganeti import daemon
from ganeti import errors
from ganeti import constants
from ganeti import utils
import testutils
class TestMainloop(testutils.GanetiTestCase):
  """Test daemon.Mainloop

  The tests queue calls to _SendSig on the mainloop scheduler, run the
  mainloop and then check which signals were sent to this process and, when
  the test case has been registered through RegisterSignal, which ones were
  reported back through OnSignal.

  """

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    # Signals sent by _SendSig, in sending order
    self.sendsig_events = []
    # Signals reported through OnSignal, in reporting order
    self.onsignal_events = []

  def _CancelEvent(self, handle):
    """Cancel a scheduled event (used as a scheduled action itself)"""
    self.mainloop.scheduler.cancel(handle)

  def _SendSig(self, sig):
    """Record sig and send it to the current process"""
    self.sendsig_events.append(sig)
    os.kill(os.getpid(), sig)

  def OnSignal(self, signum):
    """Record a signal reported by the mainloop"""
    self.onsignal_events.append(signum)

  def testRunAndTermBySched(self):
    """A scheduled SIGTERM makes Run() return"""
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run() # terminates by _SendSig being scheduled
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])

  def testTerminatingSignals(self):
    """Run() returns after SIGINT as well as after SIGTERM"""
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
                                           signal.SIGTERM])

  def testSchedulerCancel(self):
    """An event cancelled before Run() is never executed"""
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])

  def testRegisterSignal(self):
    """Signals sent while running are reported to the registered object"""
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    # The following are not delivered because they are scheduled after TERM
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testDeferredCancel(self):
    """Events can be cancelled by another event while the loop is running"""
    self.mainloop.RegisterSignal(self)
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGCHLD])
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
                                               [signal.SIGCHLD])
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
                                               [signal.SIGCHLD])
    # Both cancellations are due before the events they cancel
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle1])
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle2])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testReRun(self):
    """Events left over after a terminated run execute on the next Run()"""
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
                      signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testPriority(self):
    """Events due at the same time run in priority order"""
    # for events at the same time, the highest priority one executes first
    # (here the one entered with priority value 1, i.e. the SIGTERM)
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
                                     [signal.SIGCHLD])
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])
    # The pending SIGCHLD is only sent during the second run
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """AsyncUDPSocket recording received payloads and handle_error calls"""

  def __init__(self, family):
    daemon.AsyncUDPSocket.__init__(self, family)
    self.error_count = 0
    self.received = []

  def handle_datagram(self, payload, ip, port):
    """Store the payload, then act on the two special payloads

    "error" raises errors.GenericError, "terminate" sends SIGTERM to the
    current process; any other payload is only recorded.

    """
    self.received.append(payload)
    if payload == "error":
      raise errors.GenericError("error")
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)

  def handle_error(self):
    """Count the error and re-raise the exception being handled"""
    self.error_count += 1
    raise
class _BaseAsyncUDPSocketTest:
  """Base class for AsyncUDPSocket tests

  This is a mixin: concrete test cases set "family" and "address", also
  inherit from testutils.GanetiTestCase and are responsible for calling that
  class' setUp/tearDown in addition to the ones defined here.

  """
  family = None
  address = None

  def setUp(self):
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket(self.family)
    self.client = _MyAsyncUDPSocket(self.family)
    # Port 0: the actual port is read back from the bound socket
    self.server.bind((self.address, 0))
    self.port = self.server.getsockname()[1]
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    # ...and restore it as well; done first so that a failure while closing
    # the sockets cannot leave the replacement in place for other tests
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.server.close()
    self.client.close()
    # testutils.GanetiTestCase.tearDown is deliberately not called here, the
    # concrete test classes already do that themselves

  def testNoDoubleBind(self):
    """Binding the client to the server's address and port fails"""
    self.assertRaises(socket.error, self.client.bind, (self.address, self.port))

  def testAsyncClientServer(self):
    """Datagrams queued on the client reach the server via the mainloop"""
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.client.enqueue_send(self.address, self.port, "terminate")
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    """Sending and receiving by calling the handlers directly"""
    # Nothing is queued yet, this call must simply not fail
    self.client.handle_write()
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2"])
    self.client.enqueue_send(self.address, self.port, "p3")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    """An exception raised while handling a datagram propagates out of Run"""
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.client.enqueue_send(self.address, self.port, "error")
    self.client.enqueue_send(self.address, self.port, "p3")
    self.client.enqueue_send(self.address, self.port, "error")
    self.client.enqueue_send(self.address, self.port, "terminate")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error"])
    self.assertEqual(self.server.error_count, 1)
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error"])
    self.assertEqual(self.server.error_count, 2)
    self.mainloop.Run()
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEqual(self.server.error_count, 2)

  def testSignaledWhileReceiving(self):
    """A read that yields nothing must not lose or record any datagram"""
    # Replace the helper with one which never calls the wrapped function and
    # always returns None
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.server.handle_read()
    self.assertEqual(self.server.received, [])
    self.client.enqueue_send(self.address, self.port, "terminate")
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testOversizedDatagram(self):
    """Payloads larger than constants.MAX_UDP_DATA_SIZE are refused"""
    oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
                      self.address, self.port, oversized_data)
class TestAsyncIP4UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
  """Test IP4 daemon.AsyncUDPSocket"""
  address = "127.0.0.1"
  family = socket.AF_INET

  def setUp(self):
    # Generic test case setup first, then the sockets used by the tests
    for base in (testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
      base.setUp(self)

  def tearDown(self):
    # Same order as in setUp
    for base in (testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
      base.tearDown(self)
class TestAsyncIP6UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
  """Test IP6 daemon.AsyncUDPSocket"""
  address = "::1"
  family = socket.AF_INET6

  def setUp(self):
    # Generic test case setup first, then the sockets used by the tests
    for base in (testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
      base.setUp(self)

  def tearDown(self):
    # Same order as in setUp
    for base in (testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
      base.tearDown(self)
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
  """AsyncStreamServer forwarding connections to a callback

  Calls to handle_error and handle_expt are counted in "error_count" and
  "expt_count" respectively.

  """

  def __init__(self, family, address, handle_connection_fn):
    daemon.AsyncStreamServer.__init__(self, family, address)
    self.error_count = self.expt_count = 0
    self.handle_connection_fn = handle_connection_fn

  def handle_connection(self, connected_socket, client_address):
    """Pass the new connection on to the callback given at construction"""
    callback = self.handle_connection_fn
    callback(connected_socket, client_address)

  def handle_error(self):
    """Count the error, close the server, re-raise the current exception"""
    self.error_count += 1
    self.close()
    raise

  def handle_expt(self):
    """Count the event and close the server"""
    self.expt_count += 1
    self.close()
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
  """AsyncTerminatedMessageStream forwarding messages to a callback

  The handler carries the "client_id" it was given and counts calls to
  handle_error in "error_count".

  """

  def __init__(self, connected_socket, client_address, terminator, family,
               message_fn, client_id, unhandled_limit):
    base_init = daemon.AsyncTerminatedMessageStream.__init__
    base_init(self, connected_socket, client_address, terminator, family,
              unhandled_limit)
    self.error_count = 0
    self.client_id = client_id
    self.message_fn = message_fn

  def handle_message(self, message, message_id):
    """Hand the message to the callback together with this handler"""
    callback = self.message_fn
    callback(self, message, message_id)

  def handle_error(self):
    """Count the error and re-raise the exception being handled"""
    self.error_count += 1
    raise
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
  """Test daemon.AsyncStreamServer with a TCP connection

  Subclasses test other socket types by overriding "family" and getAddress.

  """
  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.address = self.getAddress()
    self.server = _MyAsyncStreamServer(self.family, self.address,
                                       self.handle_connection)
    self.client_handler = _MyMessageStreamHandler
    self.unhandled_limit = None
    self.terminator = "\3"
    # From here on use the address the server is actually bound to
    self.address = self.server.getsockname()
    # Client-side sockets created by getClient
    self.clients = []
    # Server-side handlers created by handle_connection
    self.connections = []
    # Messages received, as lists keyed by client id
    self.messages = {}
    # Countdowns used by countTerminate; None disables the countdown
    self.connect_terminate_count = 0
    self.message_terminate_count = 0
    self.next_client_id = 0
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    # This class used to define tearDown twice; the second definition
    # silently replaced the first one, so clients and connections were never
    # closed and utils.IgnoreSignals was never restored. This is the single,
    # merged version.
    # ...and restore it as well; done first so that a failure while closing
    # a socket cannot leave the replacement in place for other tests
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    for c in self.clients:
      c.close()
    for c in self.connections:
      c.close()
    self.server.close()
    testutils.GanetiTestCase.tearDown(self)

  def getAddress(self):
    """Return the address to bind the server to (port 0 on localhost)"""
    return ("127.0.0.1", 0)

  def countTerminate(self, name):
    """Decrement the counter attribute "name", maybe stopping the mainloop

    SIGTERM is sent to the current process once the counter has dropped to
    zero or below. A counter set to None is left alone and never triggers
    the signal.

    """
    value = getattr(self, name)
    if value is not None:
      value -= 1
      setattr(self, name, value)
      if value <= 0:
        os.kill(os.getpid(), signal.SIGTERM)

  def handle_connection(self, connected_socket, client_address):
    """Callback for the server: wrap the new connection in a handler"""
    client_id = self.next_client_id
    self.next_client_id += 1
    client_handler = self.client_handler(connected_socket, client_address,
                                         self.terminator, self.family,
                                         self.handle_message,
                                         client_id, self.unhandled_limit)
    self.connections.append(client_handler)
    self.countTerminate("connect_terminate_count")

  def handle_message(self, handler, message, message_id):
    """Callback for the handlers: record the message received"""
    self.messages.setdefault(handler.client_id, [])
    # We should just check that the message_ids are monotonically increasing.
    # If in the unit tests we never remove messages from the received queue,
    # though, we can just require that the queue length is the same as the
    # message id, before pushing the message to it. This forces a more
    # restrictive check, but we can live with this for now.
    self.assertEqual(len(self.messages[handler.client_id]), message_id)
    self.messages[handler.client_id].append(message)
    if message == "error":
      raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")

  def getClient(self):
    """Return a new socket connected to the server

    The socket is remembered in self.clients so tearDown can close it.

    """
    client = socket.socket(self.family, socket.SOCK_STREAM)
    client.connect(self.address)
    self.clients.append(client)
    return client

  def testConnect(self):
    """Each client connecting results in one server-side connection"""
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 2)
    self.connect_terminate_count = 4
    self.getClient()
    self.getClient()
    self.getClient()
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 6)

  def testBasicMessage(self):
    """A single terminated message is delivered without its terminator"""
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.assertEqual(len(self.messages[0]), 1)
    self.assertEqual(self.messages[0][0], "ciao")

  def testDoubleMessage(self):
    """Two messages sent on the same connection in separate runs"""
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    client.send("foobar\3")
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.assertEqual(len(self.messages[0]), 2)
    self.assertEqual(self.messages[0][1], "foobar")

  def testComposedMessage(self):
    """Several messages sent in one go are split at the terminator"""
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEqual(len(self.messages[0]), 3)
    self.assertEqual(self.messages[0], ["one", "composed", "message"])

  def testLongTerminator(self):
    """Terminators longer than one character work as well"""
    self.terminator = "\0\1\2"
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
    self.mainloop.Run()
    self.assertEqual(len(self.messages[0]), 3)
    self.assertEqual(self.messages[0], ["one", "composed", "message"])

  def testErrorHandling(self):
    """An exception raised while handling a message propagates out of Run"""
    self.connect_terminate_count = None
    self.message_terminate_count = None
    client = self.getClient()
    client.send("one\3two\3error\3three\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.connections[0].error_count, 1)
    self.assertEqual(self.messages[0], ["one", "two", "error"])
    client.send("error\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.connections[0].error_count, 2)
    self.assertEqual(self.messages[0], ["one", "two", "error", "three",
                                        "error"])

  def testDoubleClient(self):
    """Messages of two clients are kept apart"""
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("c1m1\3")
    client2.send("c2m1\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["c1m1"])
    self.assertEqual(self.messages[1], ["c2m1"])

  def testUnterminatedMessage(self):
    """Unterminated data is held back until its terminator arrives"""
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("message\3unterminated")
    client2.send("c2m1\3c2m2\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["message"])
    self.assertEqual(self.messages[1], ["c2m1", "c2m2"])
    client1.send("message\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["message", "unterminatedmessage"])

  def testSignaledWhileAccepting(self):
    """An accept that yields nothing must neither crash nor lose the client"""
    # Replace the helper with one which never calls the wrapped function and
    # always returns None
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.getClient()
    self.server.handle_accept()
    # When interrupted while accepting we don't have a connection, but we
    # didn't crash either.
    self.assertEqual(len(self.connections), 0)
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)

  def testSendMessage(self):
    """Replies are terminated and reach only the client they are meant for"""
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["one", "composed", "message"])
    self.assertFalse(self.connections[0].writable())
    self.assertFalse(self.connections[1].writable())
    self.connections[0].send_message("r0")
    self.assertTrue(self.connections[0].writable())
    self.assertFalse(self.connections[1].writable())
    self.connections[0].send_message("r1")
    self.connections[0].send_message("r2")
    # We currently have no way to terminate the mainloop on write events, but
    # let's assume handle_write will be called if writable() is True.
    while self.connections[0].writable():
      self.connections[0].handle_write()
    # Non-blocking, so that recv on the socket without data raises instead of
    # hanging the test
    client1.setblocking(0)
    client2.setblocking(0)
    self.assertEqual(client1.recv(4096), "r0\3r1\3r2\3")
    self.assertRaises(socket.error, client2.recv, 4096)

  def testLimitedUnhandledMessages(self):
    """With a limit of two, reading pauses until replies are written"""
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    self.unhandled_limit = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3long\3message\3")
    client2.send("c2one\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["one", "composed"])
    self.assertEqual(self.messages[1], ["c2one"])
    self.assertFalse(self.connections[0].readable())
    self.assertTrue(self.connections[1].readable())
    self.connections[0].send_message("r0")
    self.message_terminate_count = None
    client1.send("another\3")
    # when we write replies messages queued also get handled, but not the ones
    # in the socket.
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertFalse(self.connections[0].readable())
    self.assertEqual(self.messages[0], ["one", "composed", "long"])
    self.connections[0].send_message("r1")
    self.connections[0].send_message("r2")
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertEqual(self.messages[0], ["one", "composed", "long", "message"])
    self.assertTrue(self.connections[0].readable())

  def testLimitedUnhandledMessagesOne(self):
    """Same as testLimitedUnhandledMessages, with a limit of one"""
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    self.unhandled_limit = 1
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3message\3")
    client2.send("c2one\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["one"])
    self.assertEqual(self.messages[1], ["c2one"])
    self.assertFalse(self.connections[0].readable())
    self.assertFalse(self.connections[1].readable())
    self.connections[0].send_message("r0")
    self.message_terminate_count = None
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertFalse(self.connections[0].readable())
    self.assertEqual(self.messages[0], ["one", "composed"])
    self.connections[0].send_message("r2")
    self.connections[0].send_message("r3")
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertEqual(self.messages[0], ["one", "composed", "message"])
    self.assertTrue(self.connections[0].readable())
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix path connection"""
  family = socket.AF_UNIX

  def getAddress(self):
    """Return a socket path inside a fresh temporary directory

    The directory is remembered in self.tmpdir and removed by tearDown.

    """
    self.tmpdir = tempfile.mkdtemp()
    return os.path.join(self.tmpdir, "server.sock")

  def tearDown(self):
    # Let the parent class close its sockets before the directory holding the
    # socket file goes away, and remove the directory even if that fails
    try:
      TestAsyncStreamServerTCP.tearDown(self)
    finally:
      shutil.rmtree(self.tmpdir)
class TestAsyncStreamServerUnixAbstract(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix abstract connection"""
  family = socket.AF_UNIX

  def getAddress(self):
    """Return the abstract socket name used by the server

    The name starts with a NUL byte, hence no file has to be cleaned up.

    """
    abstract_name = "\0myabstractsocketaddress"
    return abstract_name
class TestAsyncAwaker(testutils.GanetiTestCase):
  """Test daemon.AsyncAwaker"""
  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
    # Number of times handle_signal was called
    self.signal_count = 0
    # Number of handle_signal calls left before the mainloop is stopped
    self.signal_terminate_count = 1

  def tearDown(self):
    self.awaker.close()
    # setUp calls the base class, so tearDown has to do the same
    testutils.GanetiTestCase.tearDown(self)

  def handle_signal(self):
    """Callback for the awaker: count the call, maybe stop the mainloop

    SIGTERM is sent to the current process once signal_terminate_count has
    dropped to zero or below.

    """
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def testBasicSignaling(self):
    """One signal() results in one call of the signal function"""
    self.awaker.signal()
    self.mainloop.Run()
    self.assertEqual(self.signal_count, 1)

  def testDoubleSignaling(self):
    """Two signal() calls before running result in a single callback"""
    self.awaker.signal()
    self.awaker.signal()
    self.mainloop.Run()
    # The second signal is never delivered
    self.assertEqual(self.signal_count, 1)

  def testReallyDoubleSignaling(self):
    """Same as testDoubleSignaling, with need_signal forced back to True"""
    self.assertTrue(self.awaker.readable())
    self.awaker.signal()
    # Let's suppose two threads overlap, and both find need_signal True
    self.awaker.need_signal = True
    self.awaker.signal()
    self.mainloop.Run()
    # We still get only one signaling
    self.assertEqual(self.signal_count, 1)

  def testNoSignalFnArgument(self):
    """An awaker without signal function can still be signalled and read"""
    myawaker = daemon.AsyncAwaker()
    # Close the awaker even when one of the assertions fails
    try:
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
    finally:
      myawaker.close()

  def testWrongSignalFnArgument(self):
    """Non-callable signal functions are rejected with an AssertionError"""
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
if __name__ == "__main__":
  # Executed as a script: hand over to testutils.GanetiTestProgram
  testutils.GanetiTestProgram()