#!/usr/bin/python3
# SPDX-License-Identifier: GPL-2.0-or-later
########################################################################
# FILE:	greylistd.py
# PURPOSE:	Simple greylisting daemon.  See "greylistd(8)".
# For an introduction to greylisting, see:
# http://projects.puremagic.com/greylisting/
#
# This program listens for connections on a UNIX domain
# socket, presumably from an MTA such as Exim.  Nominally,
# it reads an identifier (referred to as a "triplet"),
# and returns a single word ("white" or "grey") depending
# on prior knowledge of said identifier.
#
# Copyright (C)
# 2004, Tor Slettnes <tor@slett.net>
# 2006, Sven Anderson <sven(at)anderson.de>
########################################################################

import time
import socket
from socket import AF_UNIX, SOCK_STREAM
import os
import pwd
import grp
import sys
import signal
from signal import SIGTERM, SIGHUP
import syslog
from syslog import LOG_NOTICE, LOG_ERR
import select
import inspect


# Ensure that we can run this program.  The version is compared as a
# tuple: testing major and minor separately would wrongly reject any
# future release whose minor number is below 6 (e.g. 4.0).
if sys.version_info < (3, 6):
    sys.stderr.write("This program requires Python 3.6 or newer\n")
    sys.exit(1)


# Configuration file sections, items

# [data] section and the keys it accepts
DATA = "data"
STATEFILE = "statefile"
TRIPLETFILE = "tripletfile"
SAVETRIPLETS = "savetriplets"
UPDATEINTERVAL = "update"
SINGLECHECK = "singlecheck"
SINGLEUPDATE = "singleupdate"

# [socket] section and the keys it accepts
SOCKET = "socket"
SOCKPATH = "path"
SOCKOWNER = "owner"
SOCKMODE = "mode"

# [timeouts] section and the keys it accepts
TIMEOUTS = "timeouts"
RETRYMIN = "retryMin"
RETRYMAX = "retryMax"
EXPIRE = "expire"


# Defaults for various configuration items
conffile = "/etc/greylistd/config"

# Built-in settings, one sub-dictionary per configuration file section.
# The type of each default value also decides how a value read from the
# configuration file is converted (see loadFromFile/typeConvert).
config = {}

config[DATA] = {
    STATEFILE: "/var/lib/greylistd/states",
    TRIPLETFILE: "/var/lib/greylistd/triplets",
    SAVETRIPLETS: True,
    SINGLECHECK: False,
    SINGLEUPDATE: False,
    UPDATEINTERVAL: 300,
}

config[SOCKET] = {
    SOCKPATH: "/var/run/greylistd/socket",
    SOCKOWNER: "greylist:greylist",
    SOCKMODE: "0660",
}

# All timeouts are in seconds
config[TIMEOUTS] = {
    RETRYMIN: 3600,              # 1 hour
    RETRYMAX: 8 * 3600,          # 8 hours
    EXPIRE: 60 * 24 * 3600,      # 60 days
}


# Lists/states
WHITE = "white"
GREY = "grey"
BLACK = "black"
UNSEEN = "unseen"

# Additional data file sections/items
STATS = "statistics"
START = "start"
LASTSAVE = "lastsave"

# Greylist data: one (initially empty) dictionary per list, followed by
# the statistics counters and timestamps
data = {listname: {} for listname in (WHITE, GREY, BLACK)}
data[STATS] = {WHITE: 0,
               GREY: 0,
               BLACK: 0,
               START: 0,
               LASTSAVE: 0}

# Type conversions for items in data file: (key type, value types)
datatypes = {listname: (int, (int, int, int))
             for listname in (WHITE, GREY, BLACK)}

# Index of elements in data
IDX_LAST, IDX_FIRST, IDX_COUNT = 0, 1, 2

# Unhashed/original data not yet saved to disk
newTriplets = dict()


class RunException(Exception):
    """Error raised by this program's own file, socket and command code."""


class CommandError(RunException):
    """A RunException caused by the client's command or its arguments."""


def expireKeys(now):
    """Drop every list entry whose last-seen time has timed out.

    Grey entries live for config[TIMEOUTS][RETRYMAX] seconds after they
    were last seen; white and black entries for config[TIMEOUTS][EXPIRE].

    now -- current time, in seconds since the epoch
    """
    lifetimes = {GREY: RETRYMAX, WHITE: EXPIRE, BLACK: EXPIRE}

    for listname, timeoutname in lifetimes.items():
        entries = data[listname]
        limit = config[TIMEOUTS][timeoutname]

        # Collect first, delete afterwards: the dictionary must not
        # change size while it is being iterated.
        stale = [entrykey for entrykey, entry in entries.items()
                 if entry[IDX_LAST] + limit < now]

        for entrykey in stale:
            del entries[entrykey]


def listStatus(searchkey):
    """Return the name of the first list holding searchkey, or None.

    Lists are searched in the order in which they appear in datatypes.
    """
    matches = (name for name in datatypes if searchkey in data[name])
    return next(matches, None)


def log(message, priority=LOG_NOTICE):
    """Report a message to the terminal if there is one, else to syslog.

    message  -- text to report
    priority -- syslog priority; only used when stderr is not a terminal
    """
    interactive = os.isatty(sys.stderr.fileno())

    if not interactive:
        syslog.syslog(priority, message)
        return

    sys.stderr.write("%s\n" % (message,))


def duration(secs):
    """Return a human readable description of a time span.

    The result uses at most two adjacent units and omits the smaller one
    when it is zero:

      * below one minute:  "N seconds"
      * below one hour:    "N minutes and M seconds"
      * below one day:     "N hours and M minutes" (rounded to the minute)
      * otherwise:         "N days and M hours"    (rounded to the hour)

    secs -- length of the time span in whole seconds
    """
    def unit(count, name):
        # Every count except exactly 1 takes the plural form, so that zero
        # reads "0 seconds" rather than "0 second".
        return "%d %s%s" % (count, name, "" if count == 1 else "s")

    def pair(major, majorName, minor, minorName):
        text = unit(major, majorName)
        if minor:
            text += " and " + unit(minor, minorName)
        return text

    if secs < 60:
        return unit(secs, "second")

    if secs < 60 * 60:
        return pair(secs // 60, "minute", secs % 60, "second")

    if (secs + 30) < 60 * 60 * 24:
        # Adding half a minute before truncating rounds to the nearest minute
        rounded = secs + 30
        return pair(rounded // 3600, "hour", (rounded // 60) % 60, "minute")

    # Adding half an hour before truncating rounds to the nearest hour
    rounded = secs + 1800
    return pair(rounded // 86400, "day", (rounded // 3600) % 24, "hour")


def typeConvert(typeobject, string):
    """Convert string according to typeobject and return the result.

    typeobject may be:
      * a tuple or list of type objects: string is split on whitespace
        into that many words (the last word keeps any remaining text)
        and a list of the converted words is returned;
      * a value that is not itself a type: its type is used instead;
      * bool: accepts an integer, or yes/true/on and no/false/off in any
        letter case;
      * any other type: called with string as its only argument.

    Raises ValueError for a value that cannot be converted and
    IndexError when string has fewer words than typeobject has types.
    """
    kind = type(typeobject)

    if kind is tuple or kind is list:
        words = string.split(None, len(typeobject) - 1)
        return [typeConvert(subtype, words[position])
                for position, subtype in enumerate(typeobject)]

    if kind is not type:
        # An example value was given; convert to that value's type
        return typeConvert(kind, string)

    if typeobject is not bool:
        return typeobject(string)

    try:
        return bool(int(string))

    except ValueError:
        word = string.lower()

        if word in ("yes", "true", "on"):
            return True

        if word in ("no", "false", "off"):
            return False

        raise ValueError("Not a valid boolean: '%s'" % string)


def loadFromFile(datafile, dictionary, typelist=None):
    """Read an INI-style file into dictionary, converting values by type.

    The file consists of "[section]" headers followed by "key = value"
    lines; lines starting with "#" are comments.  Sections that do not
    exist in dictionary, and anything inside them, are reported and
    skipped.

    datafile   -- path of the file to read
    dictionary -- dict of dicts, indexed by section name, updated in place
    typelist   -- optional dict mapping a section name to a
                  (key type, value type) pair.  Sections not listed here
                  only accept keys already present in dictionary, and
                  each value is converted to the type of the existing
                  value.

    Malformed lines are logged and skipped.  Raises RunException if the
    file cannot be read.
    """
    logPfx = 'In %s:' % datafile

    try:
        # The context manager closes the file on every exit path
        with open(datafile) as fp:
            section = None

            for line in fp:
                line = line.strip()

                if line.startswith("#"):
                    continue

                elif (line[0:1] == '[') and (']' in line):
                    section = line[1:line.find(']')].strip().lower()

                    if section not in dictionary:
                        log("%s Invalid or obsolete section: [%s]" %
                            (logPfx, section))
                        section = None

                elif section and ('=' in line):
                    key, value = [s.strip() for s in line.split('=', 1)]

                    if typelist and (section in typelist):
                        keytype, valuetype = typelist[section]

                    elif key in dictionary[section]:
                        keytype = str
                        valuetype = dictionary[section][key]
                    else:
                        log("%s Invalid or obsolete key: [%s] %s" %
                            (logPfx, section, key))
                        continue

                    try:
                        key = typeConvert(keytype, key)
                        dictionary[section][key] = typeConvert(valuetype,
                                                               value)

                    except ValueError:
                        log("%s Invalid value for [%s] %s: '%s'" %
                            (logPfx, section, key, value))

                    except IndexError:
                        # typeConvert ran out of words for a multi-value
                        # item; report how many were found and expected.
                        if isinstance(valuetype, (tuple, list)):
                            expected = len(valuetype)
                        else:
                            expected = 1

                        log("%s Too few values for [%s] %s (%d, should be %d)" %
                            (logPfx, section, key,
                             len(value.split()), expected))

                elif line:
                    log("%s Invalid line: '%s'" % (logPfx, line))

    except IOError as e:
        raise RunException("Cannot read from '%s': %s" % (datafile, str(e)))


def saveToFile(datafile, dictionary, perm=0o600):
    """Write dictionary to datafile in the format read by loadFromFile.

    Each top-level key becomes a "[section]" header, followed by one
    "key = value" line per item and a blank line.  List and tuple values
    are written as space separated words.

    datafile   -- path of the file to (over)write
    dictionary -- dict of dicts, indexed by section name
    perm       -- permission bits applied to the file before any data is
                  written

    Raises RunException if the file cannot be written or its mode cannot
    be set.
    """
    try:
        # The context manager closes the file on every exit path
        with open(datafile, 'w') as fp:
            # IOError is an alias of OSError in Python 3, so a chmod
            # failure needs its own handler to get its own message.
            try:
                os.chmod(datafile, perm)
            except OSError as e:
                raise RunException("Cannot set mode 0%o on %s: %s" %
                                   (perm, datafile, str(e)))

            for (section, subdict) in dictionary.items():
                fp.write("[%s]\n" % section)

                for (key, value) in subdict.items():
                    if isinstance(value, (list, tuple)):
                        value = " ".join(map(str, value))
                    fp.write("%s = %s\n" % (key, value))

                fp.write("\n")

    except OSError as e:
        raise RunException("Cannot write to %s: %s" % (datafile, str(e)))


def loadConfigAndData():
    """(Re)read the configuration file and the saved state file.

    An unreadable configuration file is logged and the current settings
    stay in effect.  An unreadable state file is not reported; the
    statistics start time is set to the current time instead.  Timed-out
    entries are dropped afterwards and the last-save timestamp is reset
    to the current time.
    """
    timestamp = int(time.time())

    try:
        loadFromFile(conffile, config)
    except RunException as problem:
        log(str(problem))

    # Looked up only now, since the configuration may have relocated it
    statefile = config[DATA][STATEFILE]

    try:
        loadFromFile(statefile, data, datatypes)
    except RunException:
        data[STATS][START] = timestamp

    expireKeys(timestamp)
    data[STATS][LASTSAVE] = timestamp


def saveData(datafile):
    """Save data hashes and timestamps to datafile.

    The data is written to a temporary file named after datafile and
    this process' ID, which is then renamed over datafile.

    Raises RunException if writing or renaming fails.
    """
    scratch = "%s.%s" % (datafile, os.getpid())

    saveToFile(scratch, data)

    try:
        os.rename(scratch, datafile)
    except OSError as failure:
        message = "Cannot rename %s to %s: %s" % (scratch, datafile,
                                                  str(failure))
        raise RunException(message)


def syncTriplets(datafile, perm=0o600):
    """Merge not yet saved triplets into the triplet file.

    A new file is written next to datafile and then renamed over it.  It
    receives the lines of the existing file whose key is still on one of
    the lists and is not superseded by newTriplets, followed by the
    entries of newTriplets whose key is on one of the lists.  Lines that
    cannot be parsed are dropped.

    newTriplets is emptied only once the new file is in place, so the
    pending triplets survive a failed attempt.

    datafile -- path of the triplet file
    perm     -- permission bits for the new file

    Raises RunException if the new file cannot be written, its mode
    cannot be set, or it cannot be renamed.
    """
    source = datafile
    target = "%s.%s" % (source, os.getpid())

    # A missing or unreadable existing file just means nothing to carry over
    try:
        infile = open(source, "r")
    except IOError:
        infile = None

    try:
        try:
            # errors="replace": triplets come from the socket decoded with
            # 'replace' and may hold U+FFFD, which not every locale
            # encoding can represent; an encoding error here would not be
            # a RunException.
            with open(target, "w", errors="replace") as outfile:
                # IOError is an alias of OSError in Python 3, so a chmod
                # failure needs its own handler to get its own message.
                try:
                    os.chmod(target, perm)
                except OSError as e:
                    raise RunException("Cannot set mode 0%o on %s: %s" %
                                       (perm, target, str(e)))

                if infile:
                    for line in infile:
                        try:
                            (key, value) = line.split(" = ", 1)
                            key = int(key)

                            if listStatus(key) and key not in newTriplets:
                                outfile.write(line)

                        except ValueError:
                            continue

                for (key, triplet) in newTriplets.items():
                    if listStatus(key):
                        outfile.write("%d = %s\n" % (key, triplet))

        except OSError as e:
            raise RunException("Could not write to %s: %s" % (target, str(e)))

    finally:
        if infile:
            infile.close()

    try:
        os.rename(target, source)
    except OSError as e:
        raise RunException("Could not rename %s to %s: %s" %
                           (target, source, str(e)))

    newTriplets.clear()


def listTriplets(fp, socket, options):
    """Send a listing of the original triplets on the given lists.

    For every requested list that is not empty, a heading is sent,
    followed by one line (last seen, count, original text) for each line
    of the triplet file whose key is on that list.

    fp      -- open, seekable triplet file; it is re-read for every list
    socket  -- connected client socket that receives the listing
    options -- names of the lists to show; all lists when empty

    Unparseable lines of the triplet file are logged during the first
    pass over the file only.  Raises CommandError for an unknown list
    name.
    """
    parseErrors = False
    firstPass = True
    listformat = "  %-20s %5s  %s\n"

    for listkey in (options or datatypes):
        # Validate against datatypes rather than data: data also holds
        # the statistics section, which is not a list of triplets.
        if listkey not in datatypes:
            raise CommandError("Invalid list: %s" % listkey)

        elif not data[listkey]:
            continue

        listdata = data[listkey]
        line = "%slist data:" % listkey.capitalize()
        dash = "=" * len(line)
        out = "\n%s\n%s\n" % (line, dash)
        # sendall: a plain send() may transmit only part of the buffer
        socket.sendall(out.encode('ascii', 'replace'))
        out = listformat % ("Last Seen", "Count", "Data")
        socket.sendall(out.encode('ascii', 'replace'))

        fp.seek(0)
        for line in fp:
            try:
                (key, value) = line.split(" = ", 1)
                key = int(key)

                if key in listdata:
                    last, first, num = listdata[key]
                    ldate = time.strftime("%Y-%m-%d %H:%M:%S",
                                          time.localtime(last))
                    out = listformat % (ldate, num, value.strip())
                    socket.sendall(out.encode('ascii', 'replace'))

            except ValueError:
                if not parseErrors:
                    log("While reading triplets from %s:" %
                        (config[DATA][TRIPLETFILE],))
                    parseErrors = True

                if firstPass:
                    log("Invalid line: '%s'" % line)

        firstPass = False


def do_add(options, key):
    """Put key on a list and return a confirmation message.

    options -- optional list whose first element names the target list;
               the whitelist is used when it is empty
    key     -- hashed identifier

    A key already on the target list keeps its first-seen time and gets
    its hit count incremented.  Otherwise the key is moved from whatever
    list it was on, starts with a count of one, and the target list's
    statistics counter is incremented.

    Raises CommandError for an unknown list name.
    """
    if options:
        state = options[0]
        if state not in datatypes:
            raise CommandError("No such list: '%s'" % state)
    else:
        state = WHITE

    now = int(time.time())
    previous = listStatus(key)

    if previous == state:
        (_, firstseen, count) = data[state][key]
    else:
        firstseen, count = now, 0
        data[STATS][state] += 1
        if previous:
            del data[previous][key]

    data[state][key] = (now, firstseen, count + 1)
    return "Added to %slist" % state


def do_delete(options, key):
    """Remove key from the list it is on and return a confirmation.

    options -- not used
    key     -- hashed identifier

    Raises CommandError if the key is not on any list.
    """
    listkey = listStatus(key)

    if not listkey:
        raise CommandError("Not found")

    del data[listkey][key]
    return "Removed from %slist" % listkey


def do_status(options, args):
    """Return the name of the list an identifier is on, or "unseen".

    options -- not used
    args    -- words of the identifier; with the "singlecheck" setting
               enabled only the first word is looked up
    """
    if config[DATA][SINGLECHECK]:
        identifier = args[0]
    else:
        identifier = " ".join(args)

    found = listStatus(hash(identifier))
    return UNSEEN if found is None else found


def do_check(options, args, update=False):
    """Return the state ("white", "grey" or "black") of an identifier.

    options -- optional list whose first element names a state; when
               given, "true" or "false" is returned depending on whether
               the identifier is in that state
    args    -- words of the identifier (the "triplet")
    update  -- when true, the identifier is also recorded: unknown ones
               are greylisted, grey ones whose first sighting is more
               than "retryMin" seconds ago are whitelisted, and the hit
               count is incremented

    An unknown identifier is reported as grey.  Raises CommandError if
    options names an unknown state.

    NOTE(review): keys are computed with the built-in hash().  Python
    salts str hashes per process unless PYTHONHASHSEED is fixed, so keys
    written to the state file would only match after a restart if the
    service pins the seed -- confirm how the daemon is started.
    """
    truthtest = options[0] if options else None

    if truthtest is not None and truthtest not in datatypes:
        raise CommandError("'%s' is not a known state" % truthtest)

    now = int(time.time())
    # according to #375504 this call to expireKeys() can be
    # removed here, expiration takes place when periodically
    # saving to disk
    # 375504: expireKeys(now)

    triplet = " ".join(args)
    keepTriplets = config[DATA][TRIPLETFILE] and config[DATA][SAVETRIPLETS]

    # With "singlecheck", an entry for the first word alone takes
    # precedence over one for the complete identifier
    state = None

    if config[DATA][SINGLECHECK]:
        key = hash(args[0])
        state = listStatus(key)

    if state is None:
        key = hash(triplet)
        state = listStatus(key)

    if state is None:
        state = GREY
        if update and keepTriplets:
            newTriplets[key] = triplet.lower()

    elif (state == GREY and
          data[GREY][key][IDX_FIRST] + config[TIMEOUTS][RETRYMIN] < now):
        state = WHITE
        if update and config[DATA][SINGLEUPDATE]:
            # Replace the entry by one for the first word alone
            do_delete(None, key)
            key = hash(args[0])
            if keepTriplets:
                newTriplets[key] = args[0]

    if update:
        do_add([state], key)

    if not truthtest:
        return state

    if state == truthtest:
        return "true"

    return "false"


def do_update(options, args):
    """Check an identifier and record the sighting; see do_check."""
    result = do_check(options, args, update=True)
    return result


def do_stats():
    """Return a multi-line, human readable statistics report.

    Timed-out entries are dropped first.  The report gives, per list,
    the number of entries and the sum of their hit counts, followed by a
    summary of what became of previously greylisted items.
    """
    now = int(time.time())
    stats = data[STATS]
    starttime = stats.get(START, None)
    expireKeys(now)

    if starttime:
        title = "Statistics since %s (%s ago)" % (
            time.ctime(starttime), duration(now - starttime))
    else:
        title = "Statistics"

    text = [title, "-" * len(title)]

    # Per list: number of entries, and total of their hit counts
    items = {listkey: len(data[listkey]) for listkey in datatypes}
    hits = {listkey: sum(count
                         for (_, _, count) in data[listkey].values())
            for listkey in datatypes}

    # Right-align both columns to the width of their largest number
    hitdigits = len(str(max(hits.values())))
    itemdigits = len(str(max(items.values())))

    for listkey in datatypes:
        text.append("%s items, matching %s requests, are currently %slisted" %
                    (str(items[listkey]).rjust(itemdigits),
                     str(hits[listkey]).rjust(hitdigits),
                     listkey))

    # Items counted as greylisted that are no longer on the greylist
    previousGrey = stats[GREY] - len(data[GREY])
    expiredGrey = previousGrey - stats[WHITE]

    if previousGrey:
        digits = len(str(previousGrey))
        whitePct = 100.0 * stats[WHITE] / previousGrey
        expiredPct = 100.0 * expiredGrey / previousGrey

        text.extend([
            "",
            "Of %s items that were initially greylisted:" %
            str(previousGrey).rjust(digits),
            " - %s (%5.1f%%) became whitelisted" %
            (str(stats[WHITE]).rjust(digits), whitePct),
            " - %s (%5.1f%%) expired from the greylist" %
            (str(expiredGrey).rjust(digits), expiredPct),
        ])

    text.append('')
    return "\n".join(text)


def do_mrtg():
    """Return four lines of statistics: the greylisted and whitelisted
    counters, the time since statistics were started, and the host name.

    Timed-out entries are dropped first.
    """
    now = int(time.time())
    stats = data[STATS]
    starttime = stats.get(START, None)
    expireKeys(now)

    lines = [
        str(stats[GREY]),
        str(stats[WHITE]),
        duration(now - starttime),
        socket.gethostname(),
        '',                      # makes the output end in a newline
    ]
    return "\n".join(lines)


def do_list(options, socket):
    """Send the original triplets on the requested lists to the client.

    options -- names of the lists to show; all lists when empty
    socket  -- connected client socket that receives the listing

    Pending data is saved first, so that the triplet file is up to date.
    Raises CommandError if triplets are not being retained or if the
    triplet file cannot be read.
    """
    if not config[DATA][TRIPLETFILE] or not config[DATA][SAVETRIPLETS]:
        raise CommandError("Original triplet data is not retained.")

    do_save()

    try:
        # The context manager closes the file even if the listing fails
        with open(config[DATA][TRIPLETFILE], "r") as infile:
            listTriplets(infile, socket, options)

    except IOError as e:
        raise CommandError("Cannot read from '%s': %s\n" %
                           (str(config[DATA][TRIPLETFILE]), str(e)))


def do_clear(options):
    """Empty the given lists and reset their statistics counters.

    options -- names of the lists to clear; all lists when empty, in
               which case the statistics start time is reset as well

    All names are validated before anything is cleared.  They are
    checked against datatypes rather than data, because data also holds
    the statistics section: clearing that would remove the counters that
    do_stats and do_mrtg read.

    Raises CommandError for an unknown list name.
    """
    listkeys = options or datatypes

    for listkey in listkeys:
        if listkey not in datatypes:
            raise CommandError("Invalid list: '%s'" % listkey)

    for listkey in listkeys:
        data[listkey].clear()
        data[STATS][listkey] = 0

    if not options:
        data[STATS][START] = int(time.time())

    return "data and statistics cleared"


def do_reload():
    """Save the current data, then re-read configuration and data."""
    # Save first, so nothing collected so far is lost by reloading
    do_save()
    loadConfigAndData()

    confirmation = "configuration and data reloaded"
    return confirmation


def do_save():
    """Drop timed-out entries and write all data to disk.

    The state file is always written; the triplet file only when there
    are triplets that have not been saved yet.  The last-save timestamp
    is updated once everything has been written.
    """
    started = int(time.time())
    expireKeys(started)

    # Data hashes and timestamps
    statefile = config[DATA][STATEFILE]
    saveData(statefile)

    # Unhashed triplets
    if newTriplets:
        syncTriplets(config[DATA][TRIPLETFILE])

    data[STATS][LASTSAVE] = started

    return "greylistd data has been saved"


def nodata():
    """Stand-in command for an empty request; always raises CommandError."""
    raise CommandError("No data received")


def runCommand(line, client):
    """Execute one request line and return the reply text.

    The line is lower-cased and split into words.  If the first word
    names a "do_<word>" function, that function is the command;
    otherwise the whole line is an identifier passed to do_update.
    Words starting with "-" that follow are options, the rest are
    arguments.

    The command's parameter names decide what it receives:
      options -- list of option words, without leading dashes
      key     -- hash of the argument words
      args    -- list of argument words
      socket  -- the client socket

    line   -- request text received from the client
    client -- connected client socket

    A RunException raised by the command is turned into an "error: ..."
    reply; unless it is a CommandError it is logged as well.

    NOTE(review): keys are computed with the built-in hash().  Python
    salts str hashes per process unless PYTHONHASHSEED is fixed, so keys
    written to the state file would only match after a restart if the
    service pins the seed -- confirm how the daemon is started.
    """
    words = line.lower().split()
    options = []

    if not words:
        function = nodata
    elif "do_" + words[0] in globals():
        function = globals()["do_" + words.pop(0)]
    else:
        function = do_update

    argnames = inspect.getfullargspec(function).args

    while words and words[0].startswith("-"):
        options.append(words.pop(0).lstrip("-"))

    arglist = []
    key = None
    useargs = False

    for arg in argnames:
        if arg == "options":
            arglist.append(options)

        elif arg == "key":
            key = hash(" ".join(words))
            useargs = True
            arglist.append(key)

        elif arg == "args":
            useargs = True
            arglist.append(words)

        elif arg == "socket":
            arglist.append(client)

        else:
            # First parameter that is not supplied from the request
            break

    try:
        if useargs and not words:
            raise CommandError("Missing argument")

        if words and not useargs:
            raise CommandError("Too many arguments")

        # "is not None": a hash value of 0 is a valid key
        if (key is not None and
                config[DATA][TRIPLETFILE] and config[DATA][SAVETRIPLETS]):
            newTriplets[key] = " ".join(words)

        return function(*arglist) or ""

    except RunException as e:
        if not isinstance(e, CommandError):
            log(str(e))

        return "error: %s" % str(e).replace("\n", "\n       ")


def createSocket(path, owner, mode):
    """Return a listening UNIX domain socket.

    path  -- file system path of the socket
    owner -- "user" or "user:group"; applied only when running as root
    mode  -- permission bits, as a string of octal digits

    If the LISTEN_FDS environment variable holds a non-zero number, the
    socket is taken from file descriptor 3 (systemd socket activation);
    otherwise a new socket is created and bound to path.

    Raises RunException if the owner or mode is invalid or cannot be
    applied.  Errors from creating or binding the socket are not
    converted.
    """
    inherited = int(os.environ.get("LISTEN_FDS", 0))

    if inherited:
        # systemd sockets start at SD_LISTEN_FDS_START = 3
        sock = socket.fromfd(3, AF_UNIX, SOCK_STREAM)
    else:
        sock = socket.socket(AF_UNIX, SOCK_STREAM)
        sock.bind(path)

    if os.getuid() == 0:
        try:
            user, _, group = owner.partition(":")

            account = pwd.getpwnam(user)
            uid, gid = account.pw_uid, account.pw_gid

            # An explicit group overrides the user's primary group
            if group:
                gid = grp.getgrnam(group).gr_gid

        except KeyError:
            raise RunException(
                "Invalid owner specified in configuration file: %s" % owner)

        try:
            os.chown(path, uid, gid)
        except OSError as e:
            raise RunException(
                "Could not change ownership of socket %s: %s" % (path, str(e)))

    try:
        os.chmod(path, int(mode, 8))

    except ValueError:
        raise RunException(
            "Specified socket mode '%s' is not a valid octal number" % mode)
    except OSError as e:
        raise RunException(
            "Could not set mode 0%o on socket %s: %s" % (int(mode, 8), path, str(e)))

    sock.listen(5)
    return sock


def startup():
    """Install signal handlers, load all data and open the listener.

    Sets the module level names "listener" (the listening socket) and
    "sockets" (the list of sockets to watch).  Exits with status 1 if
    the socket cannot be set up.
    """
    global listener, sockets
    listener = None
    sockets = []

    for (signum, handler) in ((SIGTERM, term), (SIGHUP, hangup)):
        signal.signal(signum, handler)

    syslog.openlog("greylistd")

    loadConfigAndData()

    try:
        listener = createSocket(**config[SOCKET])

    except socket.error as e:
        log("Could not bind/listen to socket %s: %s" %
            (config[SOCKET][SOCKPATH], str(e)))
        sys.exit(1)

    except RunException as e:
        log(str(e))
        cleanup(False)
        sys.exit(1)

    else:
        sockets = [listener]


def cleanup(save=True):
    """Remove the socket file and, unless save is false, save all data."""
    sockpath = config[SOCKET][SOCKPATH]

    if os.path.exists(sockpath):
        os.remove(sockpath)

    if not save:
        return

    do_save()


def term(signum=None, frame=None):
    """SIGTERM handler: clean up, save data and exit with status 0.

    signum, frame -- supplied by the signal machinery; not used
    """
    cleanup(save=True)
    sys.exit(0)


def hangup(signum=None, frame=None):
    """SIGHUP handler: save data, then reload configuration and data.

    signum, frame -- supplied by the signal machinery; not used

    A RunException (for instance from a failed save) is logged rather
    than propagated: an exception leaving a signal handler surfaces in
    the interrupted main loop, where it would be treated as fatal.
    """
    try:
        do_reload()
    except RunException as e:
        log(str(e))


# An alternate configuration file may be given as the only argument
if len(sys.argv) > 1:
    conffile = sys.argv[1]

startup()

# Main loop: serve one request per connection, and save data whenever
# the update interval has passed without a pending request.
try:
    while sockets:
        interval = config[DATA][UPDATEINTERVAL]
        lastsave = data[STATS][LASTSAVE]

        if not interval:
            # No update interval: block until a request arrives instead
            # of polling (and saving) in a tight loop.
            # NOTE(review): assumes 0 means "no periodic saving" --
            # confirm against greylistd(8).
            timeout = None
        elif lastsave + interval < time.time():
            # A save is overdue: only look for already pending requests
            timeout = 0
        else:
            timeout = interval

        (inlist, outlist, errlist) = select.select(sockets, [], [], timeout)

        if not inlist:
            if interval:
                try:
                    do_save()
                except RunException as e:
                    log(str(e))
                    # do_save() did not get to update the timestamp.
                    # Move it forward so the next attempt is made one
                    # interval from now, not immediately and repeatedly.
                    data[STATS][LASTSAVE] = int(time.time())

        elif inlist[0] is listener:
            (client, addr) = listener.accept()
            sockets.append(client)

        else:
            client = inlist[0]

            try:
                line = client.recv(16384).decode('ascii', 'replace')
                reply = runCommand(line, client)
                # sendall: a plain send() may transmit only part of the reply
                client.sendall(reply.encode('ascii', 'replace'))

            except socket.error as e:
                log("Socket error: %s" % str(e))

            client.close()
            sockets.remove(client)


except SystemExit:
    pass

except KeyboardInterrupt:
    cleanup()

except Exception as e:
    # Report where the exception was raised: the innermost traceback entry
    tb = e.__traceback__
    while tb.tb_next:
        tb = tb.tb_next

    filename = tb.tb_frame.f_code.co_filename
    lineno = tb.tb_lineno

    log("### Fatal event in %s, line %d:" % (filename, lineno), LOG_ERR)
    log(">>> %s" % e, LOG_ERR)

    cleanup(save=False)
