/* $Id$ */

/*
 *
 * Copyright (C) 2004 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "asmtpd.h"

/* Global domain table; localcheck::init consults it to map a
 * recipient/sender domain to an avenger user (see domain_map::lookup). */
domain_map dmap;

/*
 * Alias table backed by the file named in opt->alias_file.  Loading and
 * parsing are inherited from map_base; lookup() adds alias expansion.
 */
struct alias_map : public map_base {
  str path ()
  {
    return opt->alias_file;
  }
  str lookup (str name);
};

/*
 * Regular expression for one map-file entry of the form
 * "key : value".  Group 1 is the key (no control chars, spaces, DEL
 * or ':'); group 2 is the possibly-empty value (no control chars,
 * spaces or DEL).  Surrounding whitespace is permitted.  The object is
 * a function-local static, built once and shared by every caller.
 */
rxx &
map_base::linerx ()
{
  static rxx entry_rx ("^\\s*([^\\x00-\\x20\\x7f:]+)\\s*:"
		       "\\s*([^\\x00-\\x20\\x7f]*)\\s*$");
  return entry_rx;
}

/*
 * (Re)load the table from path () if the file's mtime/ctime or
 * opt->configno differ from what was recorded at the last successful
 * load.  Every line must match linerx () or be blank; keys are
 * lowercased and a later entry for the same key overrides an earlier
 * one (with a warning).
 *
 * Returns false when stat, or a read of the file, fails with an error
 * that smtpd::tmperr classifies as temporary.  For other stat/open/read
 * failures the table is left empty and true is returned.  If the file
 * is seen to change while it is being read, the whole load is retried.
 * A table is only marked current (loadno/latest) after a clean read.
 */
bool
map_base::load ()
{
  for (;;) {
    struct stat sb;
    if (stat (path (), &sb) < 0) {
      latest = 0;
      loadno = 0;
      if (smtpd::tmperr (errno)) {
	warn << path () << ": " << strerror (errno) << "\n";
	return false;
      }
      table.deleteall ();
      return true;
    }
    // Nothing changed since the last successful load.
    if (opt->configno == loadno && latest == max (sb.st_mtime, sb.st_ctime))
      return true;

    warn << "loading " << path () << "\n";

    loadno = 0;
    table.deleteall ();
    int fd = open (path (), O_RDONLY);
    if (fd < 0)
      return !smtpd::tmperr (errno);
    suio buf;
    u_int lineno = 0;
    ssize_t n;
    while ((n = buf.input (fd)) > 0)
      while (str line = suio_getline (&buf)) {
	lineno++;
	rxx &aliasrx = linerx ();
	static rxx blankrx ("^\\s*$");
	if (aliasrx.match (line)) {
	  str k = mytolower (aliasrx[1]);
	  if (table[k])
	    warn << path () << ":" << lineno
		 << ": overriding previous entry for " << k << "\n";
	  //warn << k << " -> " << aliasrx[2] << "\n";
	  table.insert (k, aliasrx[2]);
	}
	else if (!blankrx.match (line))
	  warn << path () << ":" << lineno << ": syntax error\n";
      }
    if (n < 0) {
      // A failed read must not be mistaken for EOF: the table is
      // incomplete, so discard it and leave loadno at 0 so the next
      // call retries instead of caching a truncated map.
      int read_errno = errno;
      close (fd);
      warn << path () << ": " << strerror (read_errno) << "\n";
      table.deleteall ();
      return !smtpd::tmperr (read_errno);
    }
    if (buf.resid ())
      warn << path () << ": ignoring incomplete last line\n";

    // Check fstat's return value rather than errno: errno is only
    // meaningful on failure, and close () below may itself set it.
    struct stat sb2;
    int fstat_res = fstat (fd, &sb2);
    int fstat_errno = errno;
    close (fd);

    if (fstat_res < 0) {
      warn << path () << ": " << strerror (fstat_errno) << "\n";
      continue;
    }
    if (sb.st_dev != sb2.st_dev || sb.st_ino != sb2.st_ino
	|| sb.st_mtime != sb2.st_mtime || sb.st_ctime != sb2.st_ctime) {
      warn << path () << ": file changed while reading\n";
      continue;
    }

    latest = max (sb.st_ctime, sb.st_mtime);
    loadno = opt->configno;
    return true;
  }
}

/*
 * Expand the local part ORIG through the alias file.
 *
 * A non-empty, case-insensitive exact entry for ORIG is returned as is.
 * Otherwise, when opt->separator is set, trailing separator-delimited
 * extensions are peeled off one at a time.  Whenever the remaining base
 * has a non-empty entry, the base is replaced by that entry and the
 * peeled extensions are re-appended; the result is then expanded again.
 * After 20 such rewrites the expansion is abandoned with a warning and
 * ORIG is returned unchanged.
 *
 * Returns NULL if the alias file could not be loaded.
 */
str
alias_map::lookup (str orig)
{
  if (!load ())
    return NULL;

  str *to = table[mytolower (orig)];
  if (to && to->len ())
    return *to;
  if (!opt->separator)
    return orig;

  str name (orig), base (name), ext;
  for (int i = 0; i < 20;) {
    const char *p = strrchr (base, opt->separator);
    if (!p || p == base.cstr ())
      return name;
    if (ext)
      ext = strbuf ("%s%c", p + 1, opt->separator) << ext;
    else
      ext = p + 1;
    base = substr (base, 0, p - base.cstr ());
    // Ignore empty entries here just as the exact-match test above
    // does; otherwise an entry "user:" would rewrite "user<sep>x" into
    // "<sep>x", which has no user part at all.
    if ((to = table[mytolower (base)]) && to->len ()) {
      i++;
      base = name = strbuf ("%s%c%s", to->cstr (),
			    opt->separator, ext.cstr ());
      ext = NULL;
    }
  }

  maybe_warn (strbuf () << "possible loop during alias expansion for "
	      << orig << "\n");
  return orig;
}

/*
 * Map the (lowercased) domain of RECIP through the domain table and
 * store the resulting avenger user in *AVUSER:
 *   - no entry                        -> NULL
 *   - empty entry                     -> the local part of RECIP
 *   - entry ending in opt->separator  -> entry followed by the local part
 *   - any other entry                 -> the entry itself
 * Returns false (leaving *AVUSER untouched) only if the table could not
 * be loaded.
 */
bool
domain_map::lookup (str *avuser, str recip)
{
  if (!load ())
    return false;

  str local = extract_local (recip);
  str domain = mytolower (extract_domain (recip));

  str *mapped = table[domain];
  if (!mapped)
    *avuser = NULL;
  else if (!mapped->len ())
    *avuser = local;
  else if ((*mapped)[mapped->len () - 1] != opt->separator)
    *avuser = *mapped;
  else
    *avuser = strbuf () << *mapped << local;

  return true;
}

/* File-local alias table; localcheck::avenge expands avuser through it. */
static alias_map amap;

/*
 * Set up a check for address R on connection S; C receives the final
 * SMTP reply string.  M selects the policy:
 *   'r'  recipient check; per-user scripts only if opt->user_rcpt
 *        (then "unknown" handles nonexistent users), fallback "default"
 *   'R'  recipient check without per-user scripts, fallback "secondary"
 *   'm'  sender check; per-user scripts if opt->user_mail, fallback "relay"
 *   'M'  sender check without per-user scripts, fallback "relay"
 * Any other M is a programming error and panics.
 */
localcheck::localcheck (smtpd *s, str r, char m, cbs c)
  : smtp (s), recip (r), cb (c), depth (0), indefault (false)
{
  switch (m) {
  case 'r':
    mode = 'r';
    if (opt->user_rcpt) {
      try_user = true;
      unknown_user = "unknown";
    }
    else
      try_user = false;
    fallback_user = "default";
    break;
  case 'R':
    mode = 'r';
    try_user = false;
    fallback_user = "secondary";
    break;
  case 'm':
    mode = 'm';
    try_user = opt->user_mail;
    fallback_user = "relay";
    break;
  case 'M':
    mode = 'm';
    try_user = false;
    fallback_user = "relay";
    break;
  default:
    // Report the argument: the member `mode' has not been assigned yet.
    panic ("localcheck: bad mode %c\n", m);
  }
}

/*
 * Deliver the final answer to the caller and destroy this object.
 * If a body command was requested (BODYCMD), body checking is set up
 * to run it as execuser and the reply is smtpd::okstr.  A NULL or 2xx
 * RES also goes through bodycheck; any other RES is passed straight
 * through.
 */
void
localcheck::reply (str res, str bodycmd)
{
  if (bodycmd)
    (*cb) (smtp->bodycheck (execuser, bodycmd, smtpd::okstr));
  else if (res && res[0] != '2')
    (*cb) (res);
  else
    (*cb) (smtp->bodycheck (NULL, NULL, res));
  delete this;
}

/*
 * Start the check.  The address examined is the recipient in 'r' mode
 * and the envelope sender in 'm' mode.  When per-user checking is on,
 * the domain table picks the avenger user (falling back to the local
 * part of recip); otherwise, or if no user results, the fallback user
 * is run.
 */
void
localcheck::init ()
{
  str addr;
  if (mode == 'r')
    addr = recip;
  else if (mode == 'm')
    addr = smtp->get_from ();

  if (!try_user) {
    dodefault ();
    return;
  }

  if (!dmap.lookup (&avuser, addr)) {
    reply ("451 temporary error processing domain file\r\n");
    return;
  }

  if (!avuser)
    avuser = extract_local (recip);
  if (avuser)
    avenge ();
  else
    dodefault ();
}

/*
 * Expand avuser through the alias map, then run that user's checks.
 * A failed alias load yields a 451.  Revisiting a user, or reaching
 * 20 expansions, is treated as a loop: it is logged and the check
 * completes with a NULL result.
 */
void
localcheck::avenge ()
{
  if (avuser)
    avuser = amap.lookup (avuser);
  if (!avuser) {
    reply ("451 temporary error processing alias file\r\n");
    return;
  }

  if (loop.insert (mytolower (avuser)) && ++depth < 20) {
    avenge_1 ();
    return;
  }

  warn << "loop while checking rcpt " << recip << "\n";
  reply (NULL);
}

/*
 * Run the avenger script for avuser.  The part after the first
 * opt->separator is the (lowercased) extension.  For a valid user the
 * script runs under that account once a per-uid avcount slot is
 * acquired; otherwise this method is queued on the slot's waiters and
 * retried.  Unknown users go to unknown_user (run as opt->av_user) if
 * one is configured, else straight to dodefault ().
 */
void
localcheck::avenge_1 ()
{
  str user = avuser;
  str ext;
  const char *sep = opt->separator ? strchr (user, opt->separator) : NULL;
  if (sep) {
    ext = mytolower (sep + 1);
    user = substr (user, 0, sep - user.cstr ());
  }

  struct passwd *pw = validuser (user, user != opt->av_user->pw_name);
  if (pw) {
    execuser = pw->pw_name;
    avcount *avc = avcount::get (pw->pw_uid);
    if (!avc->acquire ()) {
      avc->waiters.push_back (wrap (this, &localcheck::avenge_1));
      return;
    }
    avif::alloc (pw, smtp, recip, mode, avc, ext, avuser,
		 wrap (this, &localcheck::avenge_2),
		 modeenv ());
    return;
  }

  execuser = NULL;
  if (!unknown_user) {
    dodefault ();
    return;
  }
  avif::alloc (opt->av_user, smtp, recip, 's', NULL, unknown_user, avuser,
	       wrap (this, &localcheck::avenge_2),
	       modeenv ());
}

/*
 * Run the fallback script once (as opt->av_user).  If it has already
 * run, or none is configured, finish: in 'r'/'R' mode the callback gets
 * the result of bodycheck with all-NULL arguments, otherwise NULL.
 * This object is deleted after the callback.
 */
void
localcheck::dodefault ()
{
  if (!indefault && fallback_user) {
    indefault = true;
    execuser = NULL;
    avif::alloc (opt->av_user, smtp, recip, 's', NULL, fallback_user, avuser,
		 wrap (this, &localcheck::avenge_2),
		 modeenv ());
    return;
  }

  if (mode == 'r' || mode == 'R')
    (*cb) (smtp->bodycheck (NULL, NULL, NULL));
  else
    (*cb) (NULL);
  delete this;
}

/*
 * Handle the disposition returned by an avenger script.
 *   NEXT  - fall through to the default script
 *   DONE  - R is the final reply
 *   REDIR - R names another local user to check instead; permitted if no
 *           execuser is set, if the redirect stays with the same user
 *           (case-insensitively, ignoring any extension), or if either
 *           side is opt->av_user.  Invalid or disallowed redirects are
 *           logged and fall through to the default script.
 *   BODY  - R is a body-check command
 */
void
localcheck::avenge_2 (avif::disp_t disp, str r)
{
  switch (disp) {
  case avif::NEXT:
    dodefault ();
    return;
  case avif::DONE:
    reply (r);
    return;
  case avif::BODY:
    reply (NULL, r);
    return;
  case avif::REDIR:
    {
      if (!validate_local (r)) {
	warn << "bad redirect from " << avuser << " -> " << r << "\n";
	dodefault ();
	return;
      }

      str target = r;
      if (opt->separator)
	if (const char *sep = strchr (target, opt->separator))
	  target = substr (target, 0, sep - target.cstr ());

      bool permitted = !execuser
	|| !strcasecmp (execuser, target)
	|| !strcmp (execuser, opt->av_user->pw_name)
	|| !strcmp (target, opt->av_user->pw_name);
      if (!permitted) {
	warn << "bad redirect from " << avuser << " -> " << r << "\n";
	dodefault ();
	return;
      }

      avuser = r;
      avenge ();
    }
    return;
  }
}

/*
 * Environment string telling the avenger script which phase it runs
 * in: "AVENGER_MODE=rcpt" or "AVENGER_MODE=mail".  Panics on any other
 * mode.
 */
str
localcheck::modeenv ()
{
  static str rcpt ("AVENGER_MODE=rcpt");
  static str mail ("AVENGER_MODE=mail");
  if (mode == 'r')
    return rcpt;
  if (mode == 'm')
    return mail;
  panic ("localcheck::modeenv: bad mode %c\n", mode);
}

/*
 * Entry point: check RECIP on connection S using policy MODE (see the
 * localcheck constructor) and report the SMTP reply through CB.  The
 * localcheck object deletes itself when it replies.
 */
void
rcptcheck (smtpd *s, str recip, char mode, cbs cb)
{
  (New localcheck (s, recip, mode, cb))->init ();
}

