import os
import unittest
import paramiko

from shutil import copyfile
from paramiko.client import RejectPolicy, WarningPolicy
from tests.utils import make_tests_data_path
from webssh.policy import (
    AutoAddPolicy, get_policy_dictionary, load_host_keys,
    get_policy_class, check_policy_setting
)


class TestPolicy(unittest.TestCase):

    def test_get_policy_dictionary(self):
        classes = [AutoAddPolicy, RejectPolicy, WarningPolicy]
        dic = get_policy_dictionary()
        for cls in classes:
            val = dic[cls.__name__.lower()]
            self.assertIs(cls, val)

    def test_load_host_keys(self):
        path = '/path-not-exists'
        host_keys = load_host_keys(path)
        self.assertFalse(host_keys)

        path = '/tmp'
        host_keys = load_host_keys(path)
        self.assertFalse(host_keys)

        path = make_tests_data_path('known_hosts_example')
        host_keys = load_host_keys(path)
        self.assertEqual(host_keys, paramiko.hostkeys.HostKeys(path))

    def test_get_policy_class(self):
        keys = ['autoadd', 'reject', 'warning']
        vals = [AutoAddPolicy, RejectPolicy, WarningPolicy]
        for key, val in zip(keys, vals):
            cls = get_policy_class(key)
            self.assertIs(cls, val)

        key = 'non-exists'
        with self.assertRaises(ValueError):
            get_policy_class(key)

    def test_check_policy_setting(self):
        host_keys_filename = make_tests_data_path('host_keys_test.db')
        host_keys_settings = dict(
            host_keys=paramiko.hostkeys.HostKeys(),
            system_host_keys=paramiko.hostkeys.HostKeys(),
            host_keys_filename=host_keys_filename
        )

        with self.assertRaises(ValueError):
            check_policy_setting(RejectPolicy, host_keys_settings)

        try:
            os.unlink(host_keys_filename)
        except OSError:
            pass
        check_policy_setting(AutoAddPolicy, host_keys_settings)
        self.assertEqual(os.path.exists(host_keys_filename), True)

    def test_is_missing_host_key(self):
        client = paramiko.SSHClient()
        file1 = make_tests_data_path('known_hosts_example')
        file2 = make_tests_data_path('known_hosts_example2')
        client.load_host_keys(file1)
        client.load_system_host_keys(file2)

        autoadd = AutoAddPolicy()
        for f in [file1, file2]:
            entry = paramiko.hostkeys.HostKeys(f)._entries[0]
            hostname = entry.hostnames[0]
            key = entry.key
            self.assertIsNone(
                autoadd.is_missing_host_key(client, hostname, key)
            )

        for f in [file1, file2]:
            entry = paramiko.hostkeys.HostKeys(f)._entries[0]
            hostname = entry.hostnames[0]
            key = entry.key
            key.get_name = lambda: 'unknown'
            self.assertTrue(
                autoadd.is_missing_host_key(client, hostname, key)
            )
        del key.get_name

        for f in [file1, file2]:
            entry = paramiko.hostkeys.HostKeys(f)._entries[0]
            hostname = entry.hostnames[0][1:]
            key = entry.key
            self.assertTrue(
                autoadd.is_missing_host_key(client, hostname, key)
            )

        file3 = make_tests_data_path('known_hosts_example3')
        entry = paramiko.hostkeys.HostKeys(file3)._entries[0]
        hostname = entry.hostnames[0]
        key = entry.key
        with self.assertRaises(paramiko.BadHostKeyException):
            autoadd.is_missing_host_key(client, hostname, key)

    def test_missing_host_key(self):
        client = paramiko.SSHClient()
        file1 = make_tests_data_path('known_hosts_example')
        file2 = make_tests_data_path('known_hosts_example2')
        filename = make_tests_data_path('known_hosts')
        copyfile(file1, filename)
        client.load_host_keys(filename)
        n1 = len(client._host_keys)

        autoadd = AutoAddPolicy()
        entry = paramiko.hostkeys.HostKeys(file2)._entries[0]
        hostname = entry.hostnames[0]
        key = entry.key
        autoadd.missing_host_key(client, hostname, key)
        self.assertEqual(len(client._host_keys),  n1 + 1)
        self.assertEqual(paramiko.hostkeys.HostKeys(filename),
                         client._host_keys)
        os.unlink(filename)
