#define _GNU_SOURCE
#include <sys/mount.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <errno.h>
#include <fcntl.h>
#include <sched.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Write `buflen` bytes from `buf` into the file `fname`, which is resolved
 * relative to the directory file descriptor `proc_pid_fd`.
 *
 * The data is passed to the kernel in one write() call; a short write is not
 * retried.
 *
 * Returns true on success.  On any failure (opening the file or writing to
 * it) the result is `ignore_errors`: when that flag is set the failure is
 * silent and reported as success, otherwise a diagnostic is printed to stderr
 * and false is returned.
 */
static bool write_proc(int proc_pid_fd, const char *fname, const char *buf,
                       size_t buflen, bool ignore_errors)
{
    int fd;
    bool ok;

    if ((fd = openat(proc_pid_fd, fname, O_WRONLY)) == -1) {
        /* A missing file is one of the errors a best-effort caller wants
         * to tolerate, so honour ignore_errors here as well. */
        if (!ignore_errors)
            fprintf(stderr, "open %s: %s\n", fname, strerror(errno));
        return ignore_errors;
    }

    ok = write(fd, buf, buflen) != -1;
    /* Report before close() so errno still belongs to the failed write. */
    if (!ok && !ignore_errors)
        fprintf(stderr, "write %s: %s\n", fname, strerror(errno));

    close(fd);
    return ok || ignore_errors;
}


/*
 * WRITE_IDMAP(file, value)
 *
 * Format the mapping line "<value> <value> 1" and write it to `file` (a
 * string literal naming an entry below the directory `proc_pid_fd`) via
 * write_proc().
 *
 * Must be expanded inside a function returning bool that has in scope:
 *   - `char buf[N]`      an actual array (sizeof is taken), used as scratch;
 *   - `int proc_pid_fd`  an open directory descriptor.
 * On any failure the macro closes `proc_pid_fd` and returns false from the
 * enclosing function.
 *
 * The snprintf() result is kept in a local int so that a negative (error)
 * return can be detected regardless of the types used by the caller.
 */
#define WRITE_IDMAP(file, value) \
    do { \
        int idmap_len_ = snprintf(buf, sizeof(buf), "%1$lu %1$lu 1", \
                                  (unsigned long)(value)); \
        if (idmap_len_ < 0) { \
            perror("snprintf " file " buffer"); \
            close(proc_pid_fd); \
            return false; \
        } \
        if ((size_t)idmap_len_ >= sizeof(buf)) { \
            fputs("Unable to write buffer for " file ".\n", stderr); \
            close(proc_pid_fd); \
            return false; \
        } \
        if (!write_proc(proc_pid_fd, file, buf, (size_t)idmap_len_, false)) { \
            close(proc_pid_fd); \
            return false; \
        } \
    } while (0)

/*
 * Set up identity uid/gid mappings for the process `parent_pid` by writing
 * to uid_map, setgroups and gid_map below /proc/<parent_pid>.
 *
 * The caller's effective uid and gid are each mapped to themselves with a
 * range of one id.
 *
 * Returns true when both maps were written, false otherwise (a diagnostic
 * is printed to stderr on failure).  The /proc directory descriptor opened
 * here is always closed before returning.
 */
static bool write_maps(pid_t parent_pid)
{
    int proc_pid_fd;
    /* int, not size_t: snprintf() reports errors with a negative value. */
    int buflen;
    char buf[100];

    buflen = snprintf(buf, sizeof(buf), "/proc/%lu", (unsigned long)parent_pid);
    if (buflen < 0) {
        perror("snprintf parent pid proc path");
        return false;
    } else if ((size_t)buflen >= sizeof(buf)) {
        fputs("Unable to write buffer for parent pid proc path.\n", stderr);
        return false;
    }

    if ((proc_pid_fd = open(buf, O_RDONLY | O_DIRECTORY)) == -1) {
        fprintf(stderr, "open %s: %s\n", buf, strerror(errno));
        return false;
    }

    /* On failure WRITE_IDMAP closes proc_pid_fd and returns false. */
    WRITE_IDMAP("uid_map", geteuid());

    // Kernels prior to Linux 3.19 which do not impose setgroups()
    // restrictions won't have this file, so ignore failure.
    // This is deliberately done before gid_map is written.
    write_proc(proc_pid_fd, "setgroups", "deny", 4, true);

    WRITE_IDMAP("gid_map", getegid());

    close(proc_pid_fd);
    return true;
}

/*
 * MKDIR_MOUNT(path)
 *
 * Create the directory FS_ROOT_DIR path and recursively bind-mount the
 * host directory `path` onto it.  `path` must be a string literal starting
 * with '/', because it is pasted onto FS_ROOT_DIR at compile time.
 *
 * Must be expanded inside a function returning int; on failure it prints a
 * diagnostic and returns EXIT_FAILURE from that function.
 */
#define MKDIR_MOUNT(path) \
    do { \
        if (mkdir(FS_ROOT_DIR path, 0755) == -1) { \
            perror("mkdir " path); \
            return EXIT_FAILURE; \
        } \
        if (mount(path, FS_ROOT_DIR path, "", MS_BIND | MS_REC, NULL) == -1) { \
            perror("mount " path); \
            return EXIT_FAILURE; \
        } \
    } while (0)

/*
 * Usage: <program> COMMAND [ARGS...]
 *
 * Runs COMMAND inside a fresh user + mount namespace whose root is a tmpfs
 * mounted on FS_ROOT_DIR, with /proc, /sys, /dev and /nix made available
 * through MKDIR_MOUNT.
 *
 * Sequence:
 *   1. fork a helper child that waits on a pipe;
 *   2. the parent unshares its user and mount namespaces, then closes the
 *      write end of the pipe to release the child (or sends 'X' to make it
 *      exit if unshare failed);
 *   3. the child writes the parent's uid/gid maps and exits;
 *   4. the parent builds the new root, chroots into it and execs COMMAND.
 *
 * Returns EXIT_FAILURE on any error; on success it does not return because
 * the process image is replaced by COMMAND.
 */
int main(int argc, char **argv)
{
    int sync_pipe[2], child_status = 0;
    char sync_status = '.';
    pid_t pid, parent_pid;

    /* Without a command there is nothing to exec (argv[1] would be NULL). */
    if (argc < 2) {
        fprintf(stderr, "Usage: %s COMMAND [ARGS...]\n",
                (argc > 0 && argv[0] != NULL) ? argv[0] : "<program>");
        return EXIT_FAILURE;
    }

    if (pipe(sync_pipe) == -1) {
        perror("pipe");
        return EXIT_FAILURE;
    }

    parent_pid = getpid();

    switch (pid = fork()) {
        case -1:
            perror("fork");
            return EXIT_FAILURE;
        case 0:
            /* Helper child: block until the parent has called unshare().
             * EOF on the pipe (read() == 0) means "go ahead"; an 'X' byte
             * means the parent failed and the child should just exit. */
            close(sync_pipe[1]);
            if (read(sync_pipe[0], &sync_status, 1) == -1) {
                perror("read pipe from parent");
                _exit(1);
            } else if (sync_status == 'X')
                _exit(1);
            close(sync_pipe[0]);
            _exit(write_maps(parent_pid) ? 0 : 1);
        default:
            /* The parent only writes to the pipe; drop the read end so it
             * is not inherited by the program exec'd below. */
            close(sync_pipe[0]);

            if (unshare(CLONE_NEWNS | CLONE_NEWUSER) == -1) {
                perror("unshare");
                if (write(sync_pipe[1], "X", 1) == -1)
                    perror("signal child exit");
                close(sync_pipe[1]);
                waitpid(pid, NULL, 0);
                return EXIT_FAILURE;
            }

            /* Closing the write end delivers EOF to the child. */
            close(sync_pipe[1]);
            if (waitpid(pid, &child_status, 0) == -1) {
                perror("waitpid");
                return EXIT_FAILURE;
            }
            if (WIFEXITED(child_status) && WEXITSTATUS(child_status) == 0)
                break;
            /* The child has already printed its own diagnostic. */
            return EXIT_FAILURE;
    }

    if (mount("none", FS_ROOT_DIR, "tmpfs",
              MS_NOEXEC | MS_NOSUID | MS_NODEV | MS_NOATIME, NULL) == -1) {
        perror("mount rootfs");
        return EXIT_FAILURE;
    }

    /* Each of these returns EXIT_FAILURE from main() on error. */
    MKDIR_MOUNT("/proc");
    MKDIR_MOUNT("/sys");
    MKDIR_MOUNT("/dev");
    MKDIR_MOUNT("/nix");

    if (chroot(FS_ROOT_DIR) == -1) {
        perror("chroot");
        return EXIT_FAILURE;
    }

    if (chdir("/") == -1) {
        perror("chdir");
        return EXIT_FAILURE;
    }

    /* Skip our own name: the remaining vector is COMMAND and its args. */
    argv++;
    execv(argv[0], argv);

    /* execv() only returns on failure. */
    perror("execv");
    return EXIT_FAILURE;
}