/**
 * Copyright (c) 2025, Fabian Groffen. All rights reserved.
 *
 * See LICENSE for the license.
 */

#include <stdio.h>
#include <unistd.h>
#include <strings.h>
#include <stdbool.h>
#include <stdlib.h>

#include <ldns/ldns.h>
#include <readline/readline.h>
#include <readline/history.h>

#include "util.h"

/**
 * Print the nameservers configured on a resolver to stdout.
 *
 * Every server is written as a "Default server:" line followed by an
 * "Address:" line that carries the address and the resolver's port.
 *
 * @param res        resolver whose nameserver list is printed
 * @param firstonly  when true, at most the first nameserver is printed
 */
static void
nslookup_print_servers
(
    ldns_resolver *res,
    bool           firstonly
)
{
    ldns_rdf **list;
    size_t     listlen;
    size_t     i;

    /* print funky looking overview */
    list    = ldns_resolver_nameservers(res);
    listlen = ldns_resolver_nameserver_count(res);
    if (listlen > 0 && firstonly)
        listlen = 1;
    for (i = 0; i < listlen; i++)
    {
        char *addr = ldns_rdf2str(list[i]);

        /* conversion failed, nothing sensible to print for this entry */
        if (addr == NULL)
            continue;

        fprintf(stdout,
                "Default server: %s\n"
                "Address: %s#%u\n",
                addr,
                addr, ldns_resolver_port(res));

        /* ldns_rdf2str hands out an allocated string, release it */
        free(addr);
    }
}

/**
 * Apply a single nslookup option (the argument of "set", or a command
 * line option with its leading dash removed).
 *
 * Recognised options are: [no]search, [no]rec/[no]recurse, [no]vc,
 * [no]fail, timeout=N, retry=N, port=N/po=N, ndots=N, domain=NAME,
 * querytype=/type=/q=/ty=TYPE, class=CLASS and all.
 *
 * @param res      resolver whose settings are changed or displayed
 * @param lres     only tested against NULL: when non-NULL the "all"
 *                 option also prints the nameserver overview
 * @param ndots    in/out, updated by "ndots=", shown by "all"
 * @param qtype    in/out, updated by the querytype options, shown by "all"
 * @param qclass   in/out, updated by "class=", shown by "all"
 * @param command  the option text
 *
 * @return EXIT_SUCCESS when the option was applied, EXIT_FAILURE when
 *         it is unknown or its value is invalid (a message is printed
 *         on stderr in that case)
 */
static int
nslookup_handle_set
(
    ldns_resolver *res,
    ldns_resolver *lres,
    uint8_t       *ndots,
    ldns_rr_type  *qtype,
    ldns_rr_class *qclass,
    const char    *command
)
{
    /* for all boolean options a leading 'n' ("no...") means disable */
    if (strcmp(command, "search") == 0 ||
        strcmp(command, "nosearch") == 0)
    {
        ldns_resolver_set_dnsrch(res, command[0] == 'n' ? false : true);
    }
    else if (strcmp(command, "rec") == 0 ||
             strcmp(command, "recurse") == 0 ||
             strcmp(command, "norec") == 0 ||
             strcmp(command, "norecurse") == 0)
    {
        ldns_resolver_set_recursive(res, command[0] == 'n' ? false : true);
    }
    else if (strcmp(command, "vc") == 0 ||
             strcmp(command, "novc") == 0)
    {
        ldns_resolver_set_usevc(res, command[0] == 'n' ? false : true);
    }
    else if (strcmp(command, "fail") == 0 ||
             strcmp(command, "nofail") == 0)
    {
        ldns_resolver_set_fail(res, command[0] == 'n' ? false : true);
    }
    else if (strncmp(command, "timeout=", sizeof("timeout=") - 1) == 0)
    {
        char *endp;
        long  timeout;

        command += sizeof("timeout=") - 1;
        timeout  = strtol(command, &endp, 10);

        if (endp == command ||
            timeout < 0)
        {
            fprintf(stderr, "invalid timeout '%s': not a valid number\n",
                    command);
            return EXIT_FAILURE;
        }
        else
        {
            struct timeval tv = {
                .tv_sec  = (time_t)timeout
            };

            ldns_resolver_set_timeout(res, tv);
        }
    }
    else if (strncmp(command, "retry=", sizeof("retry=") - 1) == 0)
    {
        char *endp;
        long  retries;

        command += sizeof("retry=") - 1;
        retries  = strtol(command, &endp, 10);

        if (endp == command ||
            retries < 0 ||
            retries > UINT8_MAX)
        {
            fprintf(stderr, "invalid retries '%s': %s\n",
                    retries > UINT8_MAX ? "out of range" :
                    "not a valid number",
                    command);
            return EXIT_FAILURE;
        }
        else
        {
            ldns_resolver_set_retry(res, (uint8_t)retries);
        }
    }
    else if (strncmp(command, "port=", sizeof("port=") - 1) == 0 ||
             strncmp(command, "po=", sizeof("po=") - 1) == 0)
    {
        char *endp;
        long  port;

        command = strchr(command, '=') + 1;  /* must exist */
        port    = strtol(command, &endp, 10);

        if (endp == command ||
            port < 0 ||
            port > UINT16_MAX)
        {
            fprintf(stderr, "invalid port '%s': %s\n",
                    port > UINT16_MAX ? "out of range" :
                    "not a valid number",
                    command);
            return EXIT_FAILURE;
        }
        else
        {
            ldns_resolver_set_port(res, (uint16_t)port);
        }
    }
    else if (strncmp(command, "ndots=", sizeof("ndots=") - 1) == 0)
    {
        char *endp;
        long  arg;

        command += sizeof("ndots=") - 1;
        arg      = strtol(command, &endp, 10);

        if (endp == command ||
            arg < 0 ||
            arg > UINT8_MAX)
        {
            fprintf(stderr, "invalid ndots '%s': %s\n",
                    arg > UINT8_MAX ? "out of range" :
                    "not a valid number",
                    command);
            return EXIT_FAILURE;
        }
        else
        {
            *ndots = (uint8_t)arg;
        }
    }
    else if (strncmp(command, "domain=", sizeof("domain=") - 1) == 0)
    {
        ldns_rdf *dom;

        /* skip exactly the "domain=" prefix, so the name keeps its
         * first character */
        command += sizeof("domain=") - 1;

        dom = ldns_dname_new_frm_str(command);
        if (dom == NULL)
        {
            fprintf(stderr, "invalid domain '%s': not a valid domain name\n",
                    command);
            return EXIT_FAILURE;
        }

        /* NOTE(review): a previously configured domain is not released
         * here; confirm whether ldns_resolver_set_domain frees it */
        ldns_resolver_set_domain(res, dom);
    }
    else if (strncmp(command,
                     "querytype=", sizeof("querytype=") - 1) == 0 ||
             strncmp(command, "type=", sizeof("type=") - 1) == 0 ||
             strncmp(command, "q=", sizeof("q=") - 1) == 0 ||
             strncmp(command, "ty=", sizeof("ty=") - 1) == 0)
    {
        ldns_rr_type tpe;

        command = strchr(command, '=') + 1;  /* must exist */
        tpe     = ldns_get_rr_type_by_name(command);

        if (tpe == (ldns_rr_type)0)
        {
            fprintf(stderr, "invalid querytype '%s': not a valid type\n",
                    command);
            return EXIT_FAILURE;
        }
        else
        {
            *qtype = tpe;
        }
    }
    else if (strncmp(command, "class=", sizeof("class=") - 1) == 0)
    {
        ldns_rr_class cls;

        command += sizeof("class=") - 1;
        cls      = ldns_get_rr_class_by_name(command);

        if (cls == (ldns_rr_class)0)
        {
            fprintf(stderr, "invalid class '%s': not a valid class\n",
                    command);
            return EXIT_FAILURE;
        }
        else
        {
            *qclass = cls;
        }
    }
    else if (strcmp(command, "all") == 0)
    {
        char          *querytype  = ldns_rr_type2str(*qtype);
        char          *queryclass = ldns_rr_class2str(*qclass);
        struct timeval timeout    = ldns_resolver_timeout(res);
        ldns_rdf     **list;
        size_t         listlen;
        size_t         i;

        /* callers pass a NULL lres to suppress the server overview */
        if (lres != NULL)
            nslookup_print_servers(res, false);

        fprintf(stdout,
                "Set options:\n"
                "  %-20s  %-14s  %-14s\n"
                "  %-20s  %-14s\n"
                "  timeout = %-10u  retry = %-6u  port = %-7u  ndots = %-6u\n"
                "  querytype = %-8s  class = %-6s\n"
                "  srchlist = ",
                ldns_resolver_usevc(res) ? "vc" : "novc",
                ldns_resolver_debug(res) ? "debug" : "nodebug",
                "nod2",
                ldns_resolver_dnsrch(res) ? "search" : "nosearch",
                ldns_resolver_recursive(res) ? "recurse" : "norecurse",
                (uint32_t)timeout.tv_sec,
                ldns_resolver_retry(res),
                ldns_resolver_port(res),
                *ndots,
                querytype,
                queryclass);

        /* search list entries are printed '/'-separated, without the
         * trailing dot of absolute names */
        list    = ldns_resolver_searchlist(res);
        listlen = ldns_resolver_searchlist_count(res);
        for (i = 0; i < listlen; i++)
        {
            char *dom = ldns_rdf2str(list[i]);

            if (ldns_dname_str_absolute(dom))
                dom[strlen(dom) - 1] = '\0';
            fprintf(stdout, "%s%s", i > 0 ? "/" : "", dom);
            free(dom);
        }
        fprintf(stdout, "\n");

        free(querytype);
        free(queryclass);
    }
    else
    {
        /* not sure why, it only supports x=y syntax, but ok;
         * report just the first word, without modifying the
         * caller's (const) string */
        const char *sp  = strchr(command, ' ');
        size_t      len = sp != NULL ? (size_t)(sp - command)
                                     : strlen(command);

        fprintf(stderr, "*** Invalid option: %.*s\n", (int)len, command);
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

/**
 * Replace the nameservers of a resolver by the addresses of a server
 * given by name or address, and print the resulting server list.
 *
 * @param res      resolver whose nameserver list is replaced
 * @param useres   resolver handed to util_addr_frm_str to look up
 *                 the addresses of the requested server
 * @param ndots    ndots setting passed on to util_addr_frm_str
 * @param command  name or address of the server to switch to
 *
 * @return EXIT_SUCCESS when the server list was replaced, EXIT_FAILURE
 *         when no address could be obtained (res is left untouched)
 */
static int
nslookup_handle_server
(
    ldns_resolver *res,
    ldns_resolver *useres,
    uint8_t        ndots,
    const char    *command
)
{
    ldns_rdf     **server;
    ldns_rdf      *swalk;
    size_t         i;

    server = util_addr_frm_str(useres, command, ndots);
    if (server == NULL ||
        server[0] == NULL)
    {
        /* an empty list would leave the resolver without any
         * nameserver, treat it like a failed lookup */
        free(server);
        fprintf(stderr,
                "nslookup: couldn't get address "
                "for '%s': not found\n", command);
        return EXIT_FAILURE;
    }

    /* this is a tad bit expensive, but I cannot find an API to
     * replace the list of nameservers, or clear the current list,
     * perhaps in the future there will be such functionality, for
     * now it isn't a big deal here */

    /* remove existing nameservers, the popped rdf is handed to us so
     * it must be released here */
    while ((swalk = ldns_resolver_pop_nameserver(res)) != NULL)
        ldns_rdf_deep_free(swalk);

    /* add new ones, the list is NULL-terminated */
    for (i = 0; (swalk = server[i]) != NULL; i++)
    {
        ldns_resolver_push_nameserver(res, swalk);
        ldns_rdf_deep_free(swalk);
    }
    free(server);

    nslookup_print_servers(res, false);

    return EXIT_SUCCESS;
}

static int
nslookup_print_record
(
    ldns_rr_list *ans,
    bool          additional
)
{
    size_t i;
    size_t j;
    int    ret = EXIT_FAILURE;

    for (i = 0; i < ldns_rr_list_rr_count(ans); i++)
    {
        ldns_rr      *rr   = ldns_rr_list_rr(ans, i);
        ldns_rdf     *base = ldns_rr_owner(rr);
        ldns_rr_type  tpe  = ldns_rr_get_type(rr);
        char         *trg;
        char         *own;
        uint32_t      prio;

        if (ldns_rr_rd_count(rr) == 0)
            continue;

        /* drop trailing dot */
        ldns_rdf_set_size(base, ldns_rdf_size(base) - 1);
        own = ldns_rdf2str(base);

        if (tpe == LDNS_RR_TYPE_A)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            if (!additional)
            {
                fprintf(stdout,
                        "Name:\t%s\n"
                        "Address: %s\n",
                        own,
                        dest);
            }
            else
            {
                fprintf(stdout, "%s\tinternet address = %s\n", own, dest);
            }

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_AAAA)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s\thas AAAA address %s\n",
                    own, dest);

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_CNAME)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s\tcanonical name = %s\n", own, dest);

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_MX)
        {
            trg = NULL;
            for (j = 0; j < ldns_rr_rd_count(rr); j++)
            {
                ldns_rdf      *targ = ldns_rr_rdf(rr, j);
                ldns_rdf_type  t    = ldns_rdf_get_type(targ);

                switch (t)
                {
                    case LDNS_RDF_TYPE_DNAME:
                        trg = ldns_rdf2str(targ);
                        break;
                    case LDNS_RDF_TYPE_INT8:
                        prio = (uint32_t)ldns_rdf2native_int8(targ);
                        break;
                    case LDNS_RDF_TYPE_INT16:
                        prio = (uint32_t)ldns_rdf2native_int16(targ);
                        break;
                    case LDNS_RDF_TYPE_INT32:
                        prio = ldns_rdf2native_int32(targ);
                        break;
                    default:
                        /* ignore this RD */
                        break;
                }
            }
            if (trg != NULL)
            {
                fprintf(stdout, "%s\tmail exchanger = %u %s\n",
                        own, prio, trg);
                ret = EXIT_SUCCESS;
            }

            free(trg);
        }
        else if (tpe == LDNS_RR_TYPE_PTR)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s\tname = %s\n", own, dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_TXT)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s\ttext = %s\n", own, dest);

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_NS)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s\tnameserver = %s\n", own, dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_SOA)
        {
            ldns_rdf *targ;
            char     *dest;
            size_t    count   = ldns_rr_rd_count(rr);
            char     *names[] = {
                "origin",
                "mail addr",
                "serial",
                "refresh",
                "retry",
                "expire",
                "minimum"
            };

            if (count > sizeof(names) / sizeof(names[0]))
                count = sizeof(names) / sizeof(names[0]);

            fprintf(stdout, "%s\n", own);

            for (j = 0; j < count; j++)
            {
                targ = ldns_rr_rdf(rr, j);
                dest = ldns_rdf2str(targ);
                fprintf(stdout, "\t%s = %s\n", names[j], dest);
                free(dest);
            }

            ret = EXIT_SUCCESS;
        }

        free(own);
    }

    return ret;
}

/**
 * Look up a name (or address) and print the result in nslookup style.
 *
 * When dname parses as an IPv4 or IPv6 address the query type is
 * forced to PTR.  The answer section is printed on stdout, followed by
 * the additional section when the reply is not authoritative.
 *
 * @param dname   name or address to look up
 * @param res     resolver used to perform the query
 * @param lres    unused here
 * @param ndots   ndots setting passed on to util_dname_frm_str
 * @param qtype   record type to query for
 * @param qclass  record class to query for
 *
 * @return EXIT_SUCCESS when a reply of type answer or nodata was
 *         received, EXIT_FAILURE when no reply was received or the
 *         reply was of another type
 */
static int
nslookup_resolve
(
    const char    *dname,
    ldns_resolver *res,
    ldns_resolver *lres,
    uint8_t        ndots,
    ldns_rr_type   qtype,
    ldns_rr_class  qclass
)
{
    ldns_rdf      *d   = NULL;
    ldns_pkt      *pkt = NULL;
    ldns_rr_list  *ans;
    int            ret = EXIT_FAILURE;

    (void)lres;

    if (ldns_str2rdf_a(&d, dname) == LDNS_STATUS_OK ||
        ldns_str2rdf_aaaa(&d, dname) == LDNS_STATUS_OK)
    {
        /* force pointer lookup for addresses, regardless */
        qtype = LDNS_RR_TYPE_PTR;
        ldns_rdf_deep_free(d);
    }

    /* NOTE(review): d is released below on the assumption that
     * util_dname_frm_str returns a freshly allocated rdf - confirm */
    d   = util_dname_frm_str(dname, ndots, false);
    pkt = ldns_resolver_search(res, d, qtype, qclass, LDNS_RD);

    if (pkt == NULL)
    {
        fprintf(stderr,
                ";; connection timed out; no servers could be reached\n");
        goto cleanup;
    }

    if (ldns_pkt_reply_type(pkt) != LDNS_PACKET_ANSWER &&
        ldns_pkt_reply_type(pkt) != LDNS_PACKET_NODATA)
    {
        char *rcd = ldns_pkt_rcode2str(ldns_pkt_get_rcode(pkt));

        fprintf(stderr, "** host not found: %s\n",
                rcd != NULL ? rcd : "unknown");
        free(rcd);
        goto cleanup;
    }

    if (ldns_resolver_nameserver_count(res) > 0)
    {
        nslookup_print_servers(res, true);
        fprintf(stdout, "\n");
    }

    ans = ldns_pkt_answer(pkt);
    if (qtype == LDNS_RR_TYPE_MX ||
        qtype == LDNS_RR_TYPE_ANY)
        ldns_rr_list_sort(ans);
    if (ldns_rr_list_rr_count(ans) > 0)
    {
        fprintf(stdout, "%s answer:\n",
                ldns_pkt_aa(pkt) ? "Authoritative" : "Non-authoritative");
        nslookup_print_record(ans, false);
    }
    else
    {
        char *dn = ldns_rdf2str(d);
        fprintf(stderr, "*** Can't find %s: No answer\n", dn);
        free(dn);
    }

    if (!ldns_pkt_aa(pkt))
    {
        ans = ldns_pkt_additional(pkt);
        if (ldns_rr_list_rr_count(ans) > 0)
        {
            fprintf(stdout, "\nAuthoritative answers can be found from:\n");
            nslookup_print_record(ans, true);
        }
    }

    ret = EXIT_SUCCESS;

cleanup:
    /* the reply packet and the query name are ours to release, this
     * matters for the interactive mode which resolves repeatedly */
    if (pkt != NULL)
        ldns_pkt_free(pkt);
    ldns_rdf_deep_free(d);

    return ret;
}

/**
 * nslookup entry point.
 *
 * Leading arguments starting with '-' are treated as options and are
 * applied via nslookup_handle_set ("-version" prints the version and
 * exits).  Without further arguments, or with "- server", interactive
 * mode is entered and commands are read with readline until "exit" or
 * end of input.  Otherwise the first remaining argument is resolved
 * and the exit status reflects the lookup result.
 */
int
main
(
    int    argc,
    char **argv
)
{
    ldns_resolver *res      = NULL;
    ldns_resolver *lres;
    ldns_rr_type   qtype    = LDNS_RR_TYPE_A;
    ldns_rr_class  qclass   = LDNS_RR_CLASS_IN;
    ldns_status    s;
    uint8_t        ndots    = 1;
    char          *command;

    /* create resolver from /etc/resolv.conf */
    s = ldns_resolver_new_frm_file(&res, NULL);
    if (s != LDNS_STATUS_OK)
        res = ldns_resolver_new();
    if (res == NULL)
    {
        fprintf(stderr, "nslookup: failed to create resolver\n");
        exit(EXIT_FAILURE);
    }
    lres = ldns_resolver_clone(res);  /* initial resolver */
    if (lres == NULL)
    {
        fprintf(stderr, "nslookup: failed to create resolver\n");
        exit(EXIT_FAILURE);
    }

    /* first eat all "options", they must be supplied as first arguments */
    while (argc > 1)
    {
        if (argv[1][0] != '-' ||
            argv[1][1] == '\0')
            break;  /* not an option */

        if (strcmp(argv[1], "-version") == 0)
        {
            fprintf(stdout, "nslookup %s from ldns-tools (using ldns %s)\n",
                    LDNS_TOOLS_VERSION, ldns_version());
            exit(EXIT_SUCCESS);
        }

        if (nslookup_handle_set(res, NULL, &ndots, &qtype, &qclass,
                                &argv[1][1]) == EXIT_FAILURE)
            exit(EXIT_FAILURE);

        /* shift */
        argv = &argv[1];
        argc--;
    }

    if (argc == 1 ||
        (argc == 3 &&
         argv[1][0] == '-' &&
         argv[1][1] == '\0'))
    {
        if (argc == 3)
        {
            /* "- server": start interactive mode using this server;
             * its address is looked up with the initial resolver */
            if (nslookup_handle_server(res, lres, ndots,
                                       argv[2]) == EXIT_FAILURE)
                exit(EXIT_FAILURE);
        }

        while ((command = readline("> ")) != NULL)
        {
            bool done = false;

            /* nothing to do for an empty line */
            if (command[0] == '\0')
            {
                free(command);
                continue;
            }

            add_history(command);

            if (strncmp(command, "set ", sizeof("set ") - 1) == 0)
                nslookup_handle_set(res, lres, &ndots, &qtype, &qclass,
                                    command + sizeof("set ") - 1);
            else if (strcmp(command, "exit") == 0)
                done = true;
            else if (strcmp(command, "root") == 0 ||
                     strcmp(command, "finger") == 0 ||
                     strcmp(command, "ls") == 0 ||
                     strcmp(command, "view") == 0 ||
                     strcmp(command, "help") == 0 ||
                     strcmp(command, "?") == 0)
                fprintf(stderr, "The '%s' command is not implemented.\n",
                        command);
            else if (strncmp(command, "server ", sizeof("server ") - 1) == 0 ||
                     strncmp(command, "lserver ", sizeof("lserver ") - 1) == 0)
                /* lserver looks the name up via the initial resolver,
                 * server via the current one; the extra offset skips
                 * the leading 'l' */
                nslookup_handle_server(res,
                                       command[0] == 'l' ? lres : res,
                                       ndots,
                                       (command + sizeof("server ") - 1 +
                                        (command[0] == 'l' ? 1 : 0)));
            else
                nslookup_resolve(command, res, lres, ndots, qtype, qclass);

            /* the line returned by readline is allocated for us */
            free(command);

            if (done)
                break;
        }

        exit(EXIT_SUCCESS);
    }
    else
    {
        /* execute command, and exit */
        exit(nslookup_resolve(argv[1], res, lres, ndots, qtype, qclass));
    }
}

/* vim: set ts=4 sw=4 expandtab cinoptions=(0,u0,U1,W2s,l1: */
