# Copyright (c) 2021, 2022  Peter Pentchev <roam@ringlet.net>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
"""Run the remrun tests against a local SSH server instance."""

from __future__ import annotations

import argparse
import contextlib
import dataclasses
import errno
import os
import pathlib
import pwd
import shlex
import shutil
import socket
import subprocess
import sys
import tempfile
import time

from typing import Iterator

import cfg_diag
import utf8_locale


# The version string of this test tool.
VERSION = "0.1.0"

# The directory the SSH server needs to be able to chroot into;
# ensure_privsep_path() creates it if it does not exist.
PATH_PRIVSEP = pathlib.Path("/run/sshd")


@dataclasses.dataclass(frozen=True)
class Config(cfg_diag.ConfigDiag):
    """Runtime configuration for the remrun test tool.

    The `verbose` field and the `diag()` method used throughout this file
    come from cfg_diag.ConfigDiag (presumably diag() only prints in verbose
    mode - verify against cfg_diag).
    """

    # Absolute path to the remrun program being tested.
    prog: pathlib.Path
    # Absolute path to the run-test.sh suite to run instead of the built-in
    # checks, or None to run the built-in remrun checks.
    test_prog: pathlib.Path | None
    # Name of the account to switch to before running the tests, or None to
    # run them as the invoking user.
    unpriv_account: str | None
    # Environment for the child processes: a UTF-8 locale setup, and, after
    # create_ssh_wrapper(), a PATH with the ssh wrapper directory first.
    utf8_env: dict[str, str]


@dataclasses.dataclass(frozen=True)
class SSHConfig:
    """Information about the generated SSH configuration."""

    # The loopback address the server listens on; also the host the client
    # connects to.
    addr: str
    # The TCP port the server listens on.
    port: int
    # The account running this tool; the only one the server lets log in.
    username: str
    # The generated home directory (holds .ssh/ and, later, bin/).
    home: pathlib.Path
    # The generated SSH client configuration file.
    client_config: pathlib.Path
    # The generated SSH server (sshd_config) configuration file.
    server_config: pathlib.Path


def parse_args() -> Config:
    """Parse the command-line arguments.

    Both the remrun program and the optional test suite paths are made
    absolute here, before the tool changes into its temporary directory.

    Exits with an error message if the remrun program is not an executable
    regular file, or if a test suite was specified that is not a readable
    regular file (checked now so that the failure is reported before any
    keys are generated or the SSH server is started).
    """
    parser = argparse.ArgumentParser(prog="run_sshd_test")
    parser.add_argument(
        "-t",
        "--test-prog",
        type=pathlib.Path,
        help="the path to run-test.sh if it is to be run",
    )
    parser.add_argument(
        "-u",
        "--unprivileged",
        type=str,
        help="the username of the unprivileged account to switch to",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="verbose operation; display diagnostic output",
    )
    parser.add_argument(
        "remrun", type=pathlib.Path, help="the path to the remrun program to test"
    )

    args = parser.parse_args()

    prog = args.remrun.absolute()
    if not prog.is_file() or not os.access(prog, os.R_OK | os.X_OK):
        sys.exit(f"Not an executable regular file: {prog}")

    test_prog = None
    if args.test_prog is not None:
        test_prog = args.test_prog.absolute()
        # The suite is run as `sh -- test_prog`, so it only needs to be readable.
        if not test_prog.is_file() or not os.access(test_prog, os.R_OK):
            sys.exit(f"Not a readable regular file: {test_prog}")

    return Config(
        prog=prog,
        test_prog=test_prog,
        # An empty --unprivileged value means "do not switch accounts".
        unpriv_account=args.unprivileged if args.unprivileged else None,
        utf8_env=utf8_locale.UTF8Detect().detect().env,
        verbose=args.verbose,
    )


def find_listening_port(cfg: Config) -> tuple[str, int]:
    """Find a free TCP port in the 8086-8199 range on a loopback address.

    127.0.0.1 is tried first, then ::1. Each probe socket is closed right
    after the check, whether binding succeeded or not, so the port is only
    known to have been free at the time of the check.

    Returns the (address, port) pair; exits with an error message if no port
    could be bound on either address.
    """
    for addr, family in (("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)):
        cfg.diag(f"Looking for a port to listen on at {addr}")
        for port in range(8086, 8200):
            try:
                lsock = socket.socket(
                    family=family, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
                )
            except OSError as err:
                # e.g. IPv6 disabled: no port in this family can work.
                cfg.diag(f"- could not create a socket for {addr}: {err}")
                break

            # The `with` closes the probe socket on every path, including
            # failed bind() attempts (which previously leaked it).
            with lsock:
                try:
                    lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    lsock.bind((addr, port))
                except OSError as err:
                    cfg.diag(f"- could not bind to {addr}:{port}: {err}")
                    continue

            cfg.diag(f"- got {addr}:{port}")
            return addr, port

        cfg.diag(f"- could not bind to any of the desired ports at {addr}")

    sys.exit("Could not find a local address/port to listen on")


def create_ssh_config(
    cfg: Config, addr: str, port: int, tempd: pathlib.Path
) -> SSHConfig:
    """Set up the SSH client and server configuration under `tempd`.

    Creates `tempd/home/.ssh` (client key, known_hosts, client config) and
    `tempd/server` (host key, authorized_keys, sshd_config). The server only
    allows the current user to log in, with the generated client key.

    Returns an SSHConfig describing the generated files.
    """
    username = pwd.getpwuid(os.getuid()).pw_name

    home_dir = tempd / "home"
    cli_dir = home_dir / ".ssh"
    srv_dir = tempd / "server"

    cli_cfg = cli_dir / "config"
    cli_key = cli_dir / "id"
    cli_known = cli_dir / "known_hosts"

    srv_authkeys = srv_dir / "authorized_keys"
    srv_cfg = srv_dir / "sshd_config"
    srv_key = srv_dir / "ssh_host_key"
    srv_pid = srv_dir / "sshd.pid"

    home_dir.mkdir(mode=0o700)
    cli_dir.mkdir(mode=0o700)
    srv_dir.mkdir(mode=0o700)

    cfg.diag(f"Generating the SSH host key at {srv_key}")
    subprocess.check_call(
        ["ssh-keygen", "-f", srv_key, "-t", "ed25519", "-N", ""], env=cfg.utf8_env
    )

    cfg.diag(f"Generating the SSH client key at {cli_key}")
    subprocess.check_call(
        ["ssh-keygen", "-f", cli_key, "-t", "ed25519", "-N", ""], env=cfg.utf8_env
    )

    cfg.diag("Copying the client public key to the authorized keys file")
    srv_authkeys.write_text(
        cli_key.with_suffix(".pub").read_text(encoding="UTF-8"), encoding="UTF-8"
    )
    srv_authkeys.chmod(0o600)

    cfg.diag("Generating the client known hosts file")
    # The server listens on a non-default port, so ssh looks its host key up
    # under the "[addr]:port" name (see "SSH_KNOWN_HOSTS FILE FORMAT" in
    # sshd(8)); with StrictHostKeyChecking a bare-address entry alone would
    # not match. The bare address is kept as well for plain host lookups.
    host_patterns = f"{addr},[{addr}]:{port}"
    cli_known.write_text(
        host_patterns + " " + srv_key.with_suffix(".pub").read_text(encoding="UTF-8"),
        encoding="UTF-8",
    )

    cfg.diag("Generating the SSH client config file")
    cli_cfg.write_text(
        f"""
Host *
ForwardAgent no
ForwardX11 no
GlobalKnownHostsFile /dev/null
GSSAPIAuthentication no
HostbasedAuthentication no
IdentitiesOnly yes
IdentityFile {cli_key}
KbdInteractiveAuthentication no
PasswordAuthentication no
Port {port}
PubkeyAuthentication yes
RequestTTY no
StrictHostKeyChecking yes
Tunnel no
UpdateHostKeys no
User {username}
UserKnownHostsFile {cli_known}
VerifyHostKeyDNS no
""",
        encoding="UTF-8",
    )

    cfg.diag("Generating the SSH server config file")
    srv_cfg.write_text(
        f"""
AllowUsers {username}
AuthorizedKeysFile {srv_authkeys}
DisableForwarding yes
GSSAPIAuthentication no
HostKey {srv_key}
IgnoreRhosts yes
KbdInteractiveAuthentication no
ListenAddress {addr}
PasswordAuthentication no
PermitRootLogin {'yes' if username == 'root' else 'no'}
PermitTTY no
PidFile {srv_pid}
Port {port}
PubkeyAuthentication yes
StrictModes no
UseDNS no
""",
        encoding="UTF-8",
    )

    # NOTE: dumps every generated file to stdout for debugging, including the
    # (throwaway) private keys.
    subprocess.check_call(["grep", "-Ere", "^", "."], cwd=tempd, env=cfg.utf8_env)

    if username == "root":
        # NOTE(review): presumably sshd refuses logins to a locked root
        # account in some setups - confirm.
        cfg.diag("Unlocking the root account, just in case")
        subprocess.check_call(["usermod", "-U", "root"], env=cfg.utf8_env)

    return SSHConfig(
        addr=addr,
        port=port,
        username=username,
        home=home_dir,
        client_config=cli_cfg,
        server_config=srv_cfg,
    )


def create_ssh_wrapper(cfg: Config, ssh_cfg: SSHConfig) -> Config:
    """Install an `ssh` wrapper script and put it first in the search path.

    The wrapper lives in `bin/` under the generated home directory and runs
    the real SSH client with `-F <generated client config>`. Both lookups of
    `ssh` (before and after the PATH change) are verified.

    Returns a copy of `cfg` whose environment has the wrapper's directory
    prepended to PATH.
    """
    bin_dir = ssh_cfg.home / "bin"
    bin_dir.mkdir(mode=0o700)

    cfg.diag("Determining the full path of the real SSH executable")
    found = subprocess.check_output(
        ["sh", "-c", "command -v ssh"], encoding="UTF-8", env=cfg.utf8_env
    ).splitlines()
    if len(found) != 1:
        sys.exit(f"Expected `command -v ssh` to output exactly one line, got {found!r}")
    real_ssh = pathlib.Path(found[0])
    if not (real_ssh.is_file() and os.access(real_ssh, os.R_OK | os.X_OK)):
        sys.exit(
            f"Expected `command -v ssh` to point to an executable file, got {real_ssh!r}"
        )

    cfg.diag("Generating the SSH wrapper")
    wrapper = bin_dir / "ssh"
    quoted_ssh = shlex.quote(str(real_ssh))
    quoted_cfg = shlex.quote(str(ssh_cfg.client_config))
    wrapper.write_text(
        "#!/bin/sh\n\nexec " + quoted_ssh + " -F " + quoted_cfg + ' "$@"\n',
        encoding="UTF-8",
    )
    wrapper.chmod(0o700)

    new_env = dict(cfg.utf8_env)
    old_path = new_env.get("PATH")
    new_env["PATH"] = str(bin_dir) if old_path is None else f"{bin_dir}:{old_path}"

    cfg.diag("Checking that the SSH-specific environment is sane")
    found = subprocess.check_output(
        ["sh", "-c", "command -v ssh"], encoding="UTF-8", env=new_env
    ).splitlines()
    if found != [str(wrapper)]:
        sys.exit(f"Expected `command -v ssh` to output {wrapper}, got {found!r}")

    return dataclasses.replace(cfg, utf8_env=new_env)


def check_ssh_connection(cfg: Config, ssh_cfg: SSHConfig) -> None:
    """Verify that the SSH client can reach the freshly started server.

    Runs `printenv SSH_CONNECTION` remotely through the `ssh` found in PATH
    and expects a single line ending in the server's address and port;
    exits with an error message otherwise.
    """
    cfg.diag("Checking that our SSH client and server both work")
    command = f"ssh -- {ssh_cfg.addr} printenv SSH_CONNECTION"
    output = subprocess.check_output(
        ["sh", "-c", command], encoding="UTF-8", env=cfg.utf8_env
    )
    lines = output.splitlines()
    expected_tail = [ssh_cfg.addr, str(ssh_cfg.port)]
    if len(lines) == 1 and lines[0].split()[-2:] == expected_tail:
        return
    sys.exit(
        f"Expected `printenv SSH_CONNECTION` to end with '{ssh_cfg.addr} {ssh_cfg.port}', "
        f"got {lines!r}"
    )


@contextlib.contextmanager
def start_sshd(cfg: Config, ssh_cfg: SSHConfig) -> Iterator[subprocess.Popen[str]]:
    """Start an SSH server listening at the configured address and port.

    Runs /usr/sbin/sshd with -D and -e (per sshd(8): do not detach, log to
    standard error) and the generated server configuration, and yields the
    process. On exit, kills the server if it is still running, then reaps it.
    """
    cfg.diag("Starting an SSH server")
    proc = subprocess.Popen(
        ["/usr/sbin/sshd", "-D", "-e", "-f", ssh_cfg.server_config],
        encoding="UTF-8",
        env=cfg.utf8_env,
    )
    try:
        cfg.diag(f"- got SSH server process {proc.pid}")
        yield proc
    finally:
        # poll() returns None while the process is still running; a plain
        # truth test would also treat a clean exit (code 0) as "running".
        if proc.poll() is None:
            cfg.diag("Killing the SSH server")
            proc.kill()
        cfg.diag(f"The SSH server is done, code {proc.wait()}")


def create_test_script(cfg: Config, tempd: pathlib.Path) -> pathlib.Path:
    """Write the `test_printenv` shell script into `tempd` and return its path.

    The script prints the USER and SSH_CONNECTION environment variables; it
    is made readable and executable by its owner only.
    """
    cfg.diag("Creating the test printenv script to run on the other side")
    script = tempd / "test_printenv"
    body = "\n".join(["#!/bin/sh", "", "printenv USER SSH_CONNECTION", ""])
    script.write_text(body, encoding="UTF-8")
    script.chmod(0o700)
    return script


def test_remrun(cfg: Config, ssh_cfg: SSHConfig, test_printenv: pathlib.Path) -> None:
    """Run remrun twice against the local server and validate its output.

    The first run passes the script's path on the command line. The second
    passes "-" and feeds the script's contents on standard input. Each run
    must print exactly two lines: the user name, then an SSH_CONNECTION
    value ending in the server's address and port. Exits with an error
    message otherwise.
    """
    assert cfg.test_prog is None

    def output_ok(lines: list[str]) -> bool:
        # Short-circuits, so lines[0] and lines[1] are only read when present.
        return (
            len(lines) == 2
            and lines[0] == ssh_cfg.username
            and lines[1].split()[-2:] == [ssh_cfg.addr, str(ssh_cfg.port)]
        )

    expected = (
        f"{ssh_cfg.username!r} and "
        f"something ending in {ssh_cfg.addr!r} {ssh_cfg.port!r}"
    )

    cfg.diag("Now running remrun with our client against our server")
    by_path = subprocess.check_output(
        [cfg.prog, "--", ssh_cfg.addr, test_printenv],
        encoding="UTF-8",
        env=cfg.utf8_env,
    ).splitlines()
    if not output_ok(by_path):
        sys.exit(f"Expected `remrun test_printenv` to output {expected}, got {by_path!r}")

    via_stdin = subprocess.run(
        [str(cfg.prog), "--", ssh_cfg.addr, "-"],
        capture_output=True,
        check=True,
        encoding="UTF-8",
        env=cfg.utf8_env,
        input=test_printenv.read_text(encoding="UTF-8"),
    ).stdout
    if not via_stdin:
        sys.exit("`remrun -` did not output anything")
    stdin_lines = via_stdin.splitlines()
    if not output_ok(stdin_lines):
        sys.exit(f"Expected `remrun -` to output {expected}, got {stdin_lines!r}")


def test_prog(cfg: Config, ssh_cfg: SSHConfig) -> None:
    """Run the external run-test.sh suite against the local server.

    The suite is run with `sh` and gets the remrun program as its argument.
    It learns which host to connect to from REMRUN_TEST_HOSTSPEC.
    """
    assert cfg.test_prog is not None
    cfg.diag(f"Running the {cfg.test_prog} testsuite")
    subprocess.check_call(
        ["sh", "--", cfg.test_prog, cfg.prog],
        env={**cfg.utf8_env, "REMRUN_TEST_HOSTSPEC": ssh_cfg.addr},
    )


def ensure_privsep_path(cfg: Config, tempd: pathlib.Path) -> bool:
    """Make sure the SSH server will be able to chroot into /run/sshd.

    If an unprivileged account was requested, also:
    - hand the temporary directory over to that account;
    - fork;
    - drop privileges in the child process.

    Returns True if the calling process should go on and run the tests.
    That is the case when no unprivileged account was requested, or when
    this is the child that now runs as that account. Returns False in the
    parent process once the child has finished successfully.

    Exits with an error message if:
    - the account is unknown;
    - changing ownership or dropping privileges fails;
    - the child process did not exit with code 0.
    """
    if not PATH_PRIVSEP.is_dir():
        cfg.diag(f"Creating the {PATH_PRIVSEP} directory")
        os.makedirs(PATH_PRIVSEP, mode=0o755)

    if cfg.unpriv_account is None:
        return True

    cfg.diag(f"Getting information about the {cfg.unpriv_account!r} account")
    try:
        acc_ent = pwd.getpwnam(cfg.unpriv_account)
    except KeyError:
        sys.exit(f"Unknown user account {cfg.unpriv_account!r}")

    cfg.diag(f"Changing the ownership of {tempd} to {cfg.unpriv_account!r}")
    try:
        os.chown(tempd, acc_ent.pw_uid, acc_ent.pw_gid)
    except OSError as err:
        sys.exit(
            f"Could not chown() {tempd} to {acc_ent.pw_uid}:{acc_ent.pw_gid} for "
            f"{cfg.unpriv_account!r}: {err}"
        )

    child_pid = os.fork()
    if child_pid != 0:
        cfg.diag(f"Skipping the tests, waiting for process {child_pid} to end")
        _, status = os.waitpid(child_pid, 0)
        cfg.diag(f"Process {child_pid} ended, wait status {status}")
        # The child runs the actual tests; its failure must be reflected in
        # this (parent) process's exit code, not silently reported as success.
        if os.WIFEXITED(status):
            code = os.WEXITSTATUS(status)
            if code != 0:
                sys.exit(
                    f"The test process {child_pid} running as "
                    f"{cfg.unpriv_account!r} exited with code {code}"
                )
        elif os.WIFSIGNALED(status):
            sys.exit(
                f"The test process {child_pid} running as "
                f"{cfg.unpriv_account!r} was killed by signal {os.WTERMSIG(status)}"
            )
        else:
            sys.exit(
                f"The test process {child_pid} running as "
                f"{cfg.unpriv_account!r} ended with wait status {status}"
            )
        return False

    cfg.diag(f"Trying to setuid() to {cfg.unpriv_account!r}")
    try:
        if os.geteuid() == 0:
            # setgid() alone leaves root's supplementary groups in place;
            # replace them with the target account's groups first.
            os.initgroups(acc_ent.pw_name, acc_ent.pw_gid)
        # The group must be changed while we still have the privileges to do so.
        os.setgid(acc_ent.pw_gid)
        os.setuid(acc_ent.pw_uid)
    except OSError as err:
        sys.exit(
            f"Could not setuid()/setgid() to {acc_ent.pw_uid}:{acc_ent.pw_gid} for "
            f"{cfg.unpriv_account!r}: {err}"
        )

    return True


@contextlib.contextmanager
def create_temp_dir(cfg: Config) -> Iterator[pathlib.Path]:
    """Create a temporary directory under the current one, remove it at the end.

    Yields the directory's absolute path. On exit the directory is removed
    only by the process that created it, and only if that process still
    runs under its original user ID.

    A child forked while the directory is in use leaves it alone, even if
    it exits before changing its user ID. Otherwise it would delete the
    directory out from under the parent, whose own removal would then fail.
    """
    initial_pid = os.getpid()
    initial_uid = os.getuid()
    tempd = pathlib.Path(tempfile.mkdtemp(prefix="run_sshd_test.", dir=".")).absolute()
    try:
        cfg.diag(f"Using {tempd} as a temporary directory, initial uid {initial_uid}")
        yield tempd
    finally:
        current_pid = os.getpid()
        current_uid = os.getuid()
        if current_pid != initial_pid:
            cfg.diag(
                f"Not removing {tempd} in process {current_pid}: "
                f"created by process {initial_pid}"
            )
        elif initial_uid != current_uid:
            cfg.diag(
                f"Not removing {tempd} in process {current_pid}: "
                f"uid {current_uid} != {initial_uid}"
            )
        else:
            cfg.diag(f"Removing {tempd} in process {current_pid}")
            shutil.rmtree(tempd)


def wait_for_sshd_banner(cfg: Config, ssh_cfg: SSHConfig) -> None:
    """Wait for the SSH server to accept connections and send its banner.

    Retries the connection every half second, 20 times, for as long as it is
    refused. Once connected, reads until a complete line has arrived: the
    peer closes, 4096 bytes are in, or 10 seconds pass without data. The
    banner may arrive in more than one TCP segment.

    Exits with an error message if the server never accepts a connection or
    does not send a line starting with "SSH-". Connection errors other than
    "connection refused" are re-raised.
    """
    ainfo = socket.getaddrinfo(
        ssh_cfg.addr,
        ssh_cfg.port,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
        flags=socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
    )[0]
    if len(ainfo) < 5 or not isinstance(ainfo[4], tuple) or len(ainfo[4]) < 2:
        sys.exit(f"getaddrinfo({ssh_cfg.addr!r}, {ssh_cfg.port!r}) returned {ainfo!r}")
    s_family, s_type, s_proto = ainfo[:3]
    s_address, s_port = ainfo[4][:2]
    cfg.diag(
        f"Waiting for the SSH server at {s_address}:{s_port} "
        f"(address family {s_family.name}) to start accepting connections"
    )

    for _ in range(20):
        time.sleep(0.5)
        cfg.diag(f"Trying to connect to {s_address} port {s_port}...")
        with socket.socket(s_family, s_type, s_proto) as sock:
            try:
                sock.connect((s_address, s_port))
            except OSError as err:
                if err.errno != errno.ECONNREFUSED:
                    raise
                # Go through diag() like every other progress message.
                cfg.diag("Connection refused, will retry")
                continue

            cfg.diag("Connected!")
            # Do not block forever if the server accepts but stays silent.
            sock.settimeout(10.0)
            raw = b""
            try:
                while b"\n" not in raw and len(raw) < 4096:
                    chunk = sock.recv(4096)
                    if not chunk:
                        break
                    raw += chunk
            except socket.timeout:
                cfg.diag("Timed out waiting for the rest of the banner")
            data = raw.decode("ISO-8859-15")
            cfg.diag(f"Got banner {data!r}")
            if not data.startswith("SSH-") or "\n" not in data:
                sys.exit(
                    f"Expected an SSH banner from the server started at {s_address}:{s_port}, "
                    f"got {data!r}"
                )
            return

    sys.exit(f"Could not connect to {s_address}:{s_port} for 10 seconds")


def main() -> None:
    """Entry point: parse options, start a private SSH server, run the tests.

    All work happens inside a temporary directory. The process changes
    into it for the duration of the run and always changes back to "/"
    afterwards, so the directory can be removed.
    """
    cfg = parse_args()

    with create_temp_dir(cfg) as tempd:
        try:
            os.chdir(tempd)
            if not ensure_privsep_path(cfg, tempd):
                # The tests ran (or were attempted) in a forked child process.
                return

            addr, port = find_listening_port(cfg)
            ssh_cfg = create_ssh_config(cfg, addr, port, tempd)
            cfg = create_ssh_wrapper(cfg, ssh_cfg)
            with start_sshd(cfg, ssh_cfg):
                wait_for_sshd_banner(cfg, ssh_cfg)
                check_ssh_connection(cfg, ssh_cfg)
                if cfg.test_prog is not None:
                    test_prog(cfg, ssh_cfg)
                else:
                    test_remrun(cfg, ssh_cfg, create_test_script(cfg, tempd))

            print("The remrun tool seems to be operational")
        finally:
            os.chdir("/")


# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
