blob: eedc85238019155cd8fea1f69447a49625a0b18b [file] [log] [blame]
#!/usr/bin/python
#
# Copyright (C) 2006, 2007, 2008 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script for unittesting the ssh module"""
import os
import tempfile
import unittest
import shutil
import testutils
import mocks
from ganeti import constants
from ganeti import utils
from ganeti import ssh
from ganeti import errors
class TestKnownHosts(testutils.GanetiTestCase):
"""Test case for function writing the known_hosts file"""
def setUp(self):
testutils.GanetiTestCase.setUp(self)
self.tmpfile = self._CreateTempFile()
def test(self):
cfg = mocks.FakeConfig()
ssh.WriteKnownHostsFile(cfg, self.tmpfile)
self.assertFileContent(self.tmpfile,
"%s ssh-rsa %s\n%s ssh-dss %s\n" %
(cfg.GetClusterName(), mocks.FAKE_CLUSTER_KEY,
cfg.GetClusterName(), mocks.FAKE_CLUSTER_KEY))
class TestGetUserFiles(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmpdir)
@staticmethod
def _GetNoHomedir(_):
return None
def _GetTempHomedir(self, _):
return self.tmpdir
def testNonExistantUser(self):
for kind in constants.SSHK_ALL:
self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
kind=kind, _homedir_fn=self._GetNoHomedir)
def testUnknownKind(self):
kind = "something-else"
assert kind not in constants.SSHK_ALL
self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
kind=kind, _homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
def testNoSshDirectory(self):
for kind in constants.SSHK_ALL:
self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
kind=kind, _homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
def testSshIsFile(self):
utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
for kind in constants.SSHK_ALL:
self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
kind=kind, _homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
def testMakeSshDirectory(self):
sshdir = os.path.join(self.tmpdir, ".ssh")
self.assertEqual(os.listdir(self.tmpdir), [])
for kind in constants.SSHK_ALL:
ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
def testFilenames(self):
sshdir = os.path.join(self.tmpdir, ".ssh")
os.mkdir(sshdir)
for kind in constants.SSHK_ALL:
result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(result, [
os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
])
self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
self.assertEqual(os.listdir(sshdir), [])
def testNoDirCheck(self):
self.assertEqual(os.listdir(self.tmpdir), [])
for kind in constants.SSHK_ALL:
ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
def testGetAllUserFiles(self):
result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(result,
(os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
constants.SSHK_RSA:
(os.path.join(self.tmpdir, ".ssh", "id_rsa"),
os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
constants.SSHK_DSA:
(os.path.join(self.tmpdir, ".ssh", "id_dsa"),
os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
}))
self.assertEqual(os.listdir(self.tmpdir), [])
def testGetAllUserFilesNoDirectoryNoMkdir(self):
self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
"example17270", mkdir=False, dircheck=True,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
if __name__ == "__main__":
testutils.GanetiTestProgram()