#!/usr/bin/env python3

# Test that the hashing rounds (iteration count) configuration of pw-pbkdf2 works.
# LP: #2125685

import binascii
import re
import subprocess
import shlex
import unittest
from typing import Optional, Tuple

from passlib.hash import ldap_pbkdf2_sha512


def get_slapd_hash(password: str, rounds: Optional[int]) -> Tuple[str, int, bytes]:
    """Hash *password* with ``slappasswd`` using the pw-pbkdf2 module.

    Args:
        password: Cleartext secret handed to ``slappasswd -s``.
        rounds: Iteration count appended to the ``module-load`` argument.
            When falsy (``None`` or ``0``) no count is passed at all.

    Returns:
        A ``(hash_raw, rounds, salt)`` tuple: the stripped ``slappasswd``
        output, the round count parsed out of it, and the decoded salt bytes.

    Raises:
        subprocess.CalledProcessError: ``slappasswd`` exited non-zero.
        ValueError: the output is not of the form
            ``{PBKDF2-SHA512}<rounds>$<salt>$<hash>``.
    """
    module = f"pw-pbkdf2.so {rounds}" if rounds else "pw-pbkdf2.so"

    # Build argv explicitly instead of formatting a command string and
    # re-tokenizing it: a password containing whitespace or quote characters
    # must still reach slappasswd as exactly one argument.
    argv = [
        "slappasswd",
        "-o", f"module-load={module}",
        "-h", "{PBKDF2-SHA512}",
        "-s", password,
    ]
    hash_raw = subprocess.check_output(argv).decode().strip()

    mat = re.match(
        r"\{PBKDF2-SHA512\}(?P<rounds>\d+)\$(?P<salt_b64>[^\$]+)\$(?P<hash>.+)",
        hash_raw,
    )
    if not mat:
        raise ValueError(f"output hash has unknown format: {hash_raw!r}")

    # The salt field is unpadded and uses "." where standard base64 uses "+".
    # Restore both before decoding; the padding is derived from the length so
    # it does not silently assume one particular salt size.
    salt_b64 = mat.group("salt_b64").replace(".", "+")
    salt_b64 += "=" * (-len(salt_b64) % 4)
    salt_raw = binascii.a2b_base64(salt_b64)

    return hash_raw, int(mat.group("rounds")), salt_raw


def get_py_hash(password: str, rounds: int, salt: bytes):
    """Compute the reference hash of *password* with passlib.

    Uses ``ldap_pbkdf2_sha512`` configured with the given *rounds* and *salt*,
    then rewrites every "+" in the result to "." so the string can be compared
    directly with the one printed by slappasswd.
    """
    hasher = ldap_pbkdf2_sha512.using(rounds=rounds, salt=salt)
    hashed = hasher.hash(password)
    # for compat with slappasswd hash
    return hashed.replace("+", ".")

class TestStringMethods(unittest.TestCase):
    """Exercise the iteration-count argument of pw-pbkdf2 through slappasswd."""

    @staticmethod
    def _run_slappasswd(module_args: str, password: str) -> "subprocess.CompletedProcess":
        """Run slappasswd with *module_args* appended to the module-load option.

        The scheme is passed as the literal ``{PBKDF2-SHA512}``. Writing it as
        ``{{PBKDF2-SHA512}}`` is only correct inside an f-string; in a plain
        string the doubled braces reach slappasswd unchanged, and the command
        may then fail because of the scheme name rather than because of the
        module arguments under test.

        The process is never checked here; callers inspect ``returncode``.
        """
        return subprocess.run(
            [
                "slappasswd",
                "-o", f"module-load=pw-pbkdf2.so {module_args}",
                "-h", "{PBKDF2-SHA512}",
                "-s", password,
            ],
            capture_output=True,
            check=False,
        )

    def test_defaults(self):
        """Without an explicit count the hash uses 10000 rounds and matches passlib."""
        password = "Hunter2"

        ldap_hash, rounds, salt = get_slapd_hash(password, rounds=None)
        self.assertEqual(rounds, 10000)

        test_hash = get_py_hash(password, rounds, salt)
        self.assertEqual(ldap_hash, test_hash)

    def test_custom_rounds(self):
        """A requested round count shows up in the hash and matches passlib."""
        password = "Mb2.r5oHf-0t"
        rounds_req = 1337

        ldap_hash, rounds, salt = get_slapd_hash(password, rounds=rounds_req)
        self.assertEqual(rounds, rounds_req)

        test_hash = get_py_hash(password, rounds, salt)
        self.assertEqual(ldap_hash, test_hash)

    def test_mod_args_wrong_type(self):
        """Text in place of the numeric iteration argument must be rejected."""
        # text instead of number
        proc = self._run_slappasswd("sometext", "bestpassword")
        self.assertNotEqual(proc.returncode, 0, "iteration argument must be int")

    def test_mod_args_extra(self):
        """An additional argument after the iteration count must be rejected."""
        proc = self._run_slappasswd("1337 rofl", "awesomepassword")
        self.assertNotEqual(proc.returncode, 0, "extra args invalid")


# When executed as a script (rather than imported), hand control to
# unittest's command-line runner, which runs the test cases in this module.
if __name__ == "__main__":
    unittest.main()
