/* 

                          Firewall Builder

                 Copyright (C) 2001 Vadim Zaliva, Vadim Kurland

  Author:  Vadim Zaliva lord@crocodile.org

  $Id: dns.cc,v 1.3 2001/12/19 12:47:20 lord Exp $


  This program is free software which we release under the GNU General Public
  License. You may redistribute and/or modify this program under the terms
  of that license as published by the Free Software Foundation; either
  version 2 of the License, or (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
 
  To get a copy of the GNU General Public License, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

*/

#include <fwbuilder/dns.hh>

#include <netdb.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <errno.h>
#include <resolv.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <unistd.h>

#include <memory>

#include <fwbuilder/ThreadTools.hh>

using namespace std;
using namespace libfwbuilder;

// Size, in bytes (64 KiB), of the buffers this file allocates to receive
// resolver responses.
const size_t DNS::RSP_BUF_SIZE=64*1024;

/**
 * Default constructor. Performs no work of its own.
 */
DNS::DNS() {}

/**
 * Translates a DNS response code into a human-readable message.
 *
 * @param rcode response code, as obtained from
 *              ns_msg_getflag(handle, ns_f_rcode).
 * @return a fixed description for the well-known codes listed below,
 *         or "DNS Error '<rcode>'" for any other value.
 */
string DNS::getErrorMessage(int rcode)
{
    // A switch avoids building a temporary lookup table on every call.
    switch(rcode)
    {
    case ns_r_formerr:  return "Format error";
    case ns_r_servfail: return "Server failed";
    case ns_r_nxdomain: return "No such domain name";
    case ns_r_notimpl:  return "Not implemented";
    case ns_r_refused:  return "Refused";
    case ns_r_yxdomain: return "Domain name exists";
    case ns_r_yxrrset:  return "Rrset exists";
    case ns_r_nxrrset:  return "Rrset doesn't exist";
    case ns_r_notauth:  return "Not authoritative";
    case ns_r_notzone:  return "Not in zone";
    case ns_r_badsig:   return "Bad signature";
    case ns_r_badkey:   return "Bad key";
    case ns_r_badtime:  return "Bad time";
    default:
        break;
    }

    // Unknown code: report the numeric value. The write is bounded by
    // the buffer size.
    char buf[80];
    snprintf(buf, sizeof(buf), "DNS Error '%d'", rcode);
    return buf;
}

/**
 * Reverse-resolves an IPv4 address by sending a PTR query for
 * "d.c.b.a.IN-ADDR.ARPA." through a private resolver state.
 *
 * @param addr     address to look up; its four octets are read with
 *                 operator[].
 * @param retries_ value stored in the resolver's 'retry' field.
 * @param timeout_ value stored in the resolver's 'retrans' field.
 * @return HostEnt whose name is the first PTR target found in the
 *         answer section; any further PTR targets go to 'aliases'.
 * @throws FWException if the resolver cannot be initialized, the query
 *         cannot be built or sent, the response cannot be parsed,
 *         carries a non-zero response code, or has an empty answer
 *         section.
 */
HostEnt DNS::getHostByAddr(const IPAddress &addr, int retries_, int timeout_) throw(FWException)
{
    struct __res_state res;

    if(res_ninit(&res)==-1)
        throw FWException("Error initializing resolver library");

    // Release the resolver state on every exit path, including throws.
    struct ResolverGuard
    {
        struct __res_state *state;
        explicit ResolverGuard(struct __res_state *s) : state(s) {}
        ~ResolverGuard() { res_nclose(state); }
    } res_guard(&res);

    res.retrans = timeout_;
    res.retry   = retries_;

    // Build the reverse-lookup name: octets in reverse order.
    char host[NS_MAXDNAME];
    snprintf(host, sizeof(host), "%u.%u.%u.%u.IN-ADDR.ARPA.",
             addr[3], addr[2], addr[1], addr[0]);

    u_char buf[PACKETSZ];
    int n = res_nmkquery(&res, ns_o_query, host, ns_c_in, ns_t_ptr, NULL, 0, NULL,
                         buf, sizeof buf);
    if (n < 0)
        throw FWException(string("Resoving ")+host+" failed in res_nmkquery");

    // The vector owns the array; an auto_ptr would free it with scalar
    // delete, which is undefined for memory obtained from new[].
    vector<unsigned char> answer(RSP_BUF_SIZE);
    int len = res_nsend(&res, buf, n, &answer[0], RSP_BUF_SIZE);
    if (len < 0)
        throw FWException(string("Resoving ")+host+" failed in res_nsend");

    // Decode the answer.
    HostEnt v;

    ns_msg handle;

    if(ns_initparse(&answer[0], len, &handle) < 0)
        throw FWException("Zone parse error in initparse");

    if(ns_msg_getflag(handle, ns_f_rcode) != ns_r_noerror)
        throw FWException(getErrorMessage(ns_msg_getflag(handle, ns_f_rcode)));

    if(ns_msg_count(handle, ns_s_an) == 0)
        throw FWException("Answer contains no records");

    while(true)
    {
        ns_rr rr;
        // Record number -1 asks for the next record; running past the
        // end of the section is reported as ENODEV.
        if(ns_parserr(&handle, ns_s_an, -1, &rr))
        {
            if(errno != ENODEV)
                throw FWException("NS query response parse error in parserr");
            else
                break;
        }

        if(ns_rr_type(rr)==ns_t_ptr && ns_rr_class(rr)==ns_c_in)
        {
            char dn[NS_MAXDNAME];
            if(dn_expand(&answer[0], &answer[0] + len, ns_rr_rdata(rr), dn, sizeof(dn))<0)
                throw FWException("A record parse error in parserr");
            if(v.name.empty())
                v.name=dn;
            else
                v.aliases.insert(dn);
        }
    }
    return v;
}


/**
 * Reverse-resolves an IPv4 address with the reentrant system call
 * gethostbyaddr_r().
 *
 * @param addr address to look up.
 * @return HostEnt filled with the canonical host name and the aliases
 *         reported by the system.
 * @throws FWException if the lookup fails or yields no host entry.
 */
HostEnt DNS::getHostByAddr(const IPAddress &addr) throw(FWException)
{
    struct in_addr naddr;
    naddr.s_addr=addr;

    struct hostent  hostbuf;
    struct hostent *hp   = NULL;
    int             herr = 0;

    // Scratch space for the strings referenced by hostbuf. The vector
    // releases it on every exit path.
    vector<char> scratch(1024);

    // A too-small buffer is reported through the return value (ERANGE),
    // not through herr, so that is what drives the retry loop.
    // The (const char *) cast is accepted both by prototypes declaring
    // the address as 'const char *' and by those using 'const void *'.
    int res;
    while((res = gethostbyaddr_r((const char *)&naddr, sizeof(naddr),
                                 AF_INET,
                                 &hostbuf,
                                 &scratch[0], scratch.size(),
                                 &hp,
                                 &herr)) == ERANGE)
    {
        scratch.resize(scratch.size() * 2);
    }

    // herr is only meaningful after a failure, so success is decided by
    // the return value and the result pointer.
    if(res != 0 || hp == NULL)
        throw FWException(string("Hostname of address: '")+IPAddress(&naddr).toString()+"' not found");

    HostEnt v;
    if(hp->h_name)
        v.name=hp->h_name;
    if(hp->h_aliases)
        for(char **p = hp->h_aliases; *p; p++)
            v.aliases.insert(string(*p));

    return v;
}

/**
 * Resolves a host name to its addresses using the reentrant system call
 * gethostbyname_r(). Both the 6-argument and the 5-argument variant of
 * that call are supported, selected by the HAVE_FUNC_GETHOSTBYNAME_R_*
 * macros.
 *
 * @param name host name to resolve.
 * @return one IPAddress per entry of the h_addr_list returned by the
 *         system.
 * @throws FWException if the name cannot be resolved.
 */
vector<IPAddress> DNS::getHostByName(const string &name) throw(FWException)
{
    struct hostent  hostbuf;
    struct hostent *hp   = NULL;
    int             herr = 0;

    // Scratch space for the data referenced by hostbuf. The vector
    // releases it on every exit path, so no manual cleanup is needed.
    vector<char> scratch(1024);

#if defined(HAVE_FUNC_GETHOSTBYNAME_R_6)
    // This variant reports a too-small buffer through its return value.
    int res;
    while((res = gethostbyname_r(name.c_str(), &hostbuf,
                                 &scratch[0], scratch.size(),
                                 &hp, &herr)) == ERANGE)
    {
        scratch.resize(scratch.size() * 2);
    }
    // A zero return with a NULL result is possible when no entry was
    // found, so both must be checked before hp is dereferenced.
    if(res != 0 || hp == NULL)
        throw FWException("Host or network '"+name+"' not found");
#elif defined(HAVE_FUNC_GETHOSTBYNAME_R_5)
    // This variant returns NULL on failure. A too-small buffer is
    // recognised through errno; the original herr test is kept as well.
    while(true)
    {
        errno = 0;
        hp = gethostbyname_r(name.c_str(), &hostbuf,
                             &scratch[0], scratch.size(), &herr);
        if(hp != NULL || (errno != ERANGE && herr != ERANGE))
            break;
        scratch.resize(scratch.size() * 2);
    }
    if(!hp)
        throw FWException("Host or network '"+name+"' not found");
#else
#   error "unsupported gethostbyname_r()"
#endif

    vector<IPAddress> v;
    for(char **p = hp->h_addr_list; *p != 0; p++)
        v.push_back(IPAddress((struct in_addr *)(*p)));
    return v;
}


/**
 * Looks up the name servers of a domain (NS query) and resolves each
 * of them to its addresses.
 *
 * @param domain   domain whose NS records are requested.
 * @param logger   receives progress messages; it is dereferenced
 *                 unconditionally.
 * @param retries_ value stored in the resolver's 'retry' field.
 * @param timeout_ value stored in the resolver's 'retrans' field.
 * @return multimap from name server host name to each address that
 *         getHostByName() returned for it.
 * @throws FWException if the resolver cannot be initialized, the query
 *         fails, the response cannot be parsed, carries a non-zero
 *         response code, or has an empty answer section; also whatever
 *         check_stop() and getHostByName() throw.
 */
multimap<string, IPAddress> DNS::getNS(const string &domain, Logger *logger, int retries_, int timeout_) throw(FWException)
{
    struct __res_state res;

    if(res_ninit(&res)==-1)
        throw FWException("Error initializing resolver library");

    // Release the resolver state on every exit path, including throws.
    struct ResolverGuard
    {
        struct __res_state *state;
        explicit ResolverGuard(struct __res_state *s) : state(s) {}
        ~ResolverGuard() { res_nclose(state); }
    } res_guard(&res);

    res.retrans = timeout_;
    res.retry   = retries_;

    check_stop();

    // The vector owns the array; an auto_ptr would free it with scalar
    // delete, which is undefined for memory obtained from new[].
    vector<unsigned char> answer(RSP_BUF_SIZE);

    *logger << "Requesting list of name servers for domain '" << domain << "'"  << '\n';
    int  len = res_nquery(&res,
                          domain.c_str(),
                          ns_c_in,
                          ns_t_ns,
                          &answer[0],
                          RSP_BUF_SIZE);

    check_stop();

    if(len<0)
        throw FWException("Error returned while quering domain NS records");

    // Result being assembled.
    multimap<string, IPAddress> v;

    ns_msg handle;

    if(ns_initparse(&answer[0], len, &handle) < 0)
        throw FWException("Zone parse error in initparse");

    check_stop();

    if(ns_msg_getflag(handle, ns_f_rcode) != ns_r_noerror)
        throw FWException(getErrorMessage(ns_msg_getflag(handle, ns_f_rcode)));

    check_stop();

    if(ns_msg_count(handle, ns_s_an) == 0)
        throw FWException("Answer contains no records");

    while(true)
    {
        check_stop();
        ns_rr rr;
        // Record number -1 asks for the next record; running past the
        // end of the section is reported as ENODEV.
        if(ns_parserr(&handle, ns_s_an, -1, &rr))
        {
            if(errno != ENODEV)
                throw FWException("NS query response parse error in parserr");
            else
                break;
        }

        check_stop();
        if(ns_rr_type(rr)==ns_t_ns && ns_rr_class(rr)==ns_c_in)
        {
            // Expand the (possibly compressed) server name, then
            // resolve it to its addresses.
            char dn[NS_MAXDNAME];
            if(dn_expand(&answer[0], &answer[0] + len, ns_rr_rdata(rr), dn, sizeof(dn))<0)
                throw FWException("A record parse error in parserr");
            check_stop();
            vector<IPAddress> a=DNS::getHostByName(dn);
            check_stop();
            for(vector<IPAddress>::iterator i=a.begin();i!=a.end();++i)
                v.insert(pair<string, IPAddress>(string(dn), (*i)));
        }
    }

    *logger << "Succesfuly found " << (int)v.size() << " name servers."  << '\n';
    return v;
}

/**
 * Collects the A records of a domain by asking each of its name
 * servers in turn until one of them answers.
 *
 * 'Retries' applicable only to UDP part of the query (if any).
 * TCP connection to transfer zone established and attempted
 * only once.
 *
 * @param domain   domain to query.
 * @param logger   receives progress messages; it is dereferenced
 *                 unconditionally.
 * @param retries_ forwarded to the per-server findA().
 * @param timeout_ budget handed to a TimeoutCounter; the time left on
 *                 that counter is forwarded to each per-server query.
 * @return the result of the first per-server findA() that succeeds.
 * @throws FWException if no name servers are found, or a copy of the
 *         last server's failure when every server failed.
 */
map<string, set<IPAddress> > DNS::findA(const string &domain, Logger *logger, int retries_, int timeout_) throw(FWException)
{
    TimeoutCounter timeout(timeout_, "Getting A records");

    *logger << "Looking for authoritative servers" << '\n';

    // NOTE(review): retries_ and timeout_ are not forwarded here, so
    // getNS() runs with whatever defaults its declaration provides --
    // confirm this is intended.
    multimap<string, IPAddress> ns=DNS::getNS(domain, logger);
    if(ns.empty())
        throw FWException("No NS records found");

    check_stop();
    timeout.check();

    for(multimap<string, IPAddress>::iterator nsi=ns.begin(); nsi!=ns.end(); ++nsi)
    {
        try
        {
            return findA(domain, nsi->second, logger, retries_, timeout.timeLeft());
        } catch(FWException &ex)
        {
            *logger << "Quering NS " << nsi->second.toString()
                    << " ( " << nsi->first << " ) "
                    << " failed. with error: '" << ex.toString() << "'" << '\n'
                    << " Cycling to next one" << '\n';

            multimap<string, IPAddress>::iterator next = nsi;
            if(++next == ns.end())
            {
                // That was the last server: report its error. Throwing
                // a copy by value keeps the thrown type FWException
                // without any heap allocation to leak.
                *logger << "No more servers to ask. Query failed." << '\n';
                throw FWException(ex);
            }
        }
    }

    // Not reached: the list is non-empty, and the last iteration either
    // returns or throws. Kept so that every path leaves the function
    // explicitly.
    throw FWException("No NS records found");
}

/**
 * Collects the A records of a domain by requesting a zone transfer
 * (AXFR) from one name server over TCP.
 *
 * 'Retries' applicable only to UDP part of the query (if any).
 * TCP connection to transfer zone established and attempted
 * only once.
 *
 * @param domain   zone to transfer.
 * @param ns       address of the name server to connect to.
 * @param logger   receives progress messages; it is dereferenced
 *                 unconditionally.
 * @param retries_ value stored in the resolver's 'retry' field.
 * @param timeout_ value stored in the resolver's 'retrans' field and
 *                 handed to the TimeoutCounter that guards the
 *                 transfer.
 * @return map from record owner name to the set of addresses found in
 *         its A records (class IN).
 * @throws FWException on resolver, socket, I/O or parse failures, on a
 *         non-zero response code, and whatever check_stop() and the
 *         TimeoutCounter throw.
 */
map<string, set<IPAddress> > DNS::findA(const string &domain, const IPAddress &ns, Logger *logger, int retries_, int timeout_) throw(FWException)
{
    TimeoutCounter timeout(timeout_, "Getting A records");

    *logger << "Querying server: " << ns.toString() << '\n';

    struct __res_state res;

    if(res_ninit(&res)==-1)
        throw FWException("Error initializing resolver library");

    // Release the resolver state on every exit path, including throws.
    struct ResolverGuard
    {
        struct __res_state *state;
        explicit ResolverGuard(struct __res_state *s) : state(s) {}
        ~ResolverGuard() { res_nclose(state); }
    } res_guard(&res);

    res.retrans = timeout_;
    res.retry   = retries_;

    struct in_addr nsaddress;
    if(inet_aton(ns.toString().c_str(), &nsaddress) == 0)
        throw FWException("Invalid name server address '"+ns.toString()+"'");

    // Make the requested server the only one this resolver state knows.
    res.nscount = 1;
    res.nsaddr_list[0].sin_family = AF_INET;
    res.nsaddr_list[0].sin_port   = htons(NAMESERVER_PORT);
    res.nsaddr_list[0].sin_addr   = nsaddress;

    check_stop();
    timeout.check();

    // Create a query packet for the requested zone name.

    u_char buf[PACKETSZ];
    int msglen = res_nmkquery(&res, ns_o_query, domain.c_str(),
                              ns_c_in, ns_t_axfr, NULL,
                              0, 0, buf, sizeof buf);

    check_stop();
    timeout.check();

    if(msglen<0)
        throw FWException("Error returned while creating DNS query");

    const struct sockaddr_in *sin=&res.nsaddr_list[0];

    int sockFD = socket(sin->sin_family, SOCK_STREAM, 0);
    if(sockFD < 0)
        throw FWException("Error creating DNS socket");

    // Close the socket on every exit path, including a failed connect()
    // and exceptions of any type.
    struct SocketGuard
    {
        int fd;
        explicit SocketGuard(int f) : fd(f) {}
        ~SocketGuard() { (void) ::close(fd); }
    } sock_guard(sockFD);

    if (connect(sockFD, (const struct sockaddr *)sin, sizeof *sin) < 0)
        throw FWException("Error connecting to DNS server");

    check_stop();
    timeout.check();

    // Send length & message for zone transfer

    u_char tmp[NS_INT16SZ];
    ns_put16(msglen, tmp);

    map<string, set<IPAddress> > v;

    if(write(sockFD, tmp, NS_INT16SZ) != NS_INT16SZ)
        throw FWException("Error 01 sending AXFR query to DNS server");
    check_stop();
    timeout.check();

    *logger << "Sending Query" << '\n';
    if(write(sockFD, buf, msglen) != msglen)
        throw FWException("Error 02 sending AXFR query to DNS server");

    bool firstsoa = true  ;
    bool done     = false ;
    string soa;

    // Response buffer, reused for every message of the transfer. The
    // vector owns the array; an auto_ptr would free it with scalar
    // delete, which is undefined for memory obtained from new[].
    vector<unsigned char> answer;

    *logger << "Reading response." << '\n';
    while(!done)
    {
        check_stop();
        timeout.check();

        // Read the two-byte length prefix of the next message.
        u_char *cp = tmp;
        size_t amtToRead = NS_INT16SZ;
        while(amtToRead > 0)
        {
            check_stop();
            timeout.check();

            int numRead=timeout.read(sockFD, cp, amtToRead);

            if(numRead <= 0)
                throw FWException("Error reading AXFR response length");
            cp        += numRead;
            amtToRead -= numRead;
        }

        // A 16-bit length is never negative; zero ends the transfer.
        const int len = ns_get16(tmp);
        if(len == 0)
            break;

        check_stop();
        timeout.check();

        try
        {
            answer.resize(len);
        } catch(const bad_alloc &)
        {
            throw FWException("Response packet too big");
        }

        // Read the message body.
        cp        = &answer[0];
        amtToRead = len;
        while(amtToRead > 0)
        {
            check_stop();
            timeout.check();

            int numRead=timeout.read(sockFD, cp, amtToRead);

            if(numRead <= 0)
                throw FWException("Error reading AXFR response");
            cp        += numRead;
            amtToRead -= numRead;
        }

        // Parse the message just read.
        ns_msg handle;

        check_stop();
        timeout.check();

        if(ns_initparse(&answer[0], len, &handle) < 0)
            throw FWException("Zone parse error in initparse");

        if(ns_msg_getflag(handle, ns_f_rcode) != ns_r_noerror)
            throw FWException(getErrorMessage(ns_msg_getflag(handle, ns_f_rcode)));

        if(ns_msg_count(handle, ns_s_an) == 0)
            throw FWException("Answer contains no records");

        while(true)
        {
            check_stop();
            timeout.check();

            ns_rr rr;
            // Record number -1 asks for the next record; running past
            // the end of the section is reported as ENODEV.
            if(ns_parserr(&handle, ns_s_an, -1, &rr))
            {
                if(errno != ENODEV)
                    throw FWException("Zone parse error in parserr");
                else
                    break;
            }

            check_stop();
            timeout.check();

            if(ns_rr_type(rr)==ns_t_a && ns_rr_class(rr)==ns_c_in)
            {
                if(ns_rr_rdlen(rr) != NS_INADDRSZ)
                    throw FWException("Invalid address length in A record");

                // operator[] creates the set on first use.
                v[ns_rr_name(rr)].insert(IPAddress((const struct in_addr *)ns_rr_rdata(rr)));

            } else if(ns_rr_type(rr)==ns_t_soa)
            {
                if(firstsoa)
                {
                    // The first SOA opens the transfer; remember its
                    // owner name.
                    firstsoa = false;
                    soa=ns_rr_name(rr);
                } else if(ns_samename(soa.c_str(), ns_rr_name(rr)) == 1)
                {
                    // A second SOA with the same owner name closes it.
                    done=true;
                    break;
                }
            }
        }
    }

    *logger << "Succesfuly found " << (int)v.size() << " hosts."  << '\n';
    return v;
}

/**
 * Stores the parameters of a name server lookup to be executed later
 * by run_impl().
 *
 * @param domain_  domain to query.
 * @param retries_ retry count handed to DNS::getNS().
 * @param timeout_ timeout handed to DNS::getNS().
 */
DNS_getNS_query::DNS_getNS_query(const string &domain_, int retries_, int timeout_)
{
    this->timeout = timeout_;
    this->retries = retries_;
    this->domain  = domain_;
}

/**
 * Runs the stored query through DNS::getNS() and keeps the answer in
 * 'result'.
 *
 * @param logger passed through to DNS::getNS().
 * @throws FWException whatever DNS::getNS() throws.
 */
void DNS_getNS_query::run_impl(Logger *logger) throw(FWException)
{
    this->result = DNS::getNS(this->domain,
                              logger,
                              this->retries,
                              this->timeout);
}

/**
 * Default constructor. Assigns nothing; the query parameters are
 * expected to be supplied later through init().
 */
DNS_findA_query::DNS_findA_query() {}

/**
 * Creates a query that is ready to run.
 *
 * @param domain_  zone whose A records are wanted.
 * @param ns_      name server to ask.
 * @param retries_ retry count for the query.
 * @param timeout_ timeout for the query.
 */
DNS_findA_query::DNS_findA_query(const string &domain_, const IPAddress &ns_, int retries_, int timeout_)
{
    // init() performs exactly the same four assignments.
    init(domain_, ns_, retries_, timeout_);
}

void DNS_findA_query::init(const string &domain_, const IPAddress &ns_, int retries_, int timeout_)
{
        domain  = domain_  ;
        retries = retries_ ;
        timeout = timeout_ ;
        ns      = ns_      ;
}

/**
 * Runs the stored query through findA() against the configured name
 * server and keeps the answer in 'result'.
 *
 * @param logger passed through to findA().
 * @throws FWException whatever findA() throws.
 */
void DNS_findA_query::run_impl(Logger *logger) throw(FWException)
{
    this->result = findA(this->domain,
                         this->ns,
                         logger,
                         this->retries,
                         this->timeout);
}

/**
 * Prepares a bulk reverse-resolution job.
 *
 * @param _ips      addresses to resolve; each one is pushed onto the
 *                  work queue in the set's iteration order.
 * @param _nthreads number of worker threads run_impl() should start.
 * @param _retries  retry count used for every lookup.
 * @param _timeout  timeout used for every lookup.
 */
DNS_bulkBackResolve_query::DNS_bulkBackResolve_query(set<IPAddress> _ips,
                                                     unsigned int _nthreads,
                                                     int _retries,
                                                     int _timeout)
{
    nthreads = _nthreads;
    timeout  = _timeout;
    retries  = _retries;

    // Move the work items from the set into the queue.
    set<IPAddress>::iterator next = _ips.begin();
    while(next != _ips.end())
    {
        ips.push(*next);
        ++next;
    }
}

/**
 * Destructor. Performs no explicit cleanup.
 */
DNS_bulkBackResolve_query::~DNS_bulkBackResolve_query() {}

/**
 * Worker thread body for DNS_bulkBackResolve_query.
 *
 * Takes addresses off the query's queue one at a time until the queue
 * is empty or the stop flag is raised. Each address is resolved with
 * DNS::getHostByAddr(); the outcome is recorded in the query's
 * 'result' map or its 'failed' set. Before returning, the worker
 * decrements the query's running-thread counter and signals the
 * condition run_impl() waits on.
 *
 * @param args array of two pointers allocated with new[] by
 *             run_impl(): [0] the DNS_bulkBackResolve_query, [1] the
 *             Logger. The worker takes ownership of the array and
 *             frees it.
 * @return always NULL.
 */
void* libfwbuilder::DNS_bulkBackResolve_Thread(void *args)
{
    void **void_pair=(void**)args;
    DNS_bulkBackResolve_query *p      = static_cast<DNS_bulkBackResolve_query*>(void_pair[0]);
    Logger                    *logger = static_cast<Logger *>(void_pair[1]);

    // Both pointers have been copied out; the argument array is no
    // longer needed.
    delete[] void_pair;

    while(!p->get_stop_program_flag())
    {
        // Take the next address, holding the queue lock only while the
        // queue itself is touched.
        p->queue_mutex.lock();
        if(p->ips.empty())
        {
            p->queue_mutex.unlock();
            break;
        }
        IPAddress j=p->ips.front(); p->ips.pop();
        p->queue_mutex.unlock();

        try
        {
            if(p->get_stop_program_flag())
                break;
            HostEnt he=DNS::getHostByAddr(j, p->retries, p->timeout);
            *logger << start << "Resolved  " << j.toString() << ": " << he.name << "\n" << end ;
            p->result_mutex.lock();
            p->result[j]=he;
            p->result_mutex.unlock();
        } catch(FWException &)
        {
            *logger << start << "Could not resolve address " << j.toString() << " to a host name, will use generic name\n" << end ;
            p->failed_mutex.lock();
            p->failed.insert(j);
            p->failed_mutex.unlock();
        }
    }

    // Report completion. The counter is updated under the same mutex
    // the waiter holds, so a signal that arrives while the waiter is
    // not blocked on the condition still leaves a visible state change.
    p->running_mutex.lock   ();
    p->runnning_count--;
    p->running_cond.signal  ();
    p->running_mutex.unlock ();

    return NULL;
}

/**
 * Starts up to 'nthreads' detached worker threads and blocks until
 * every started worker has finished.
 *
 * @param logger passed to every worker for progress messages.
 * @throws FWException if workers were requested but none could be
 *         started, if any address failed to resolve, and whatever
 *         check_stop() throws.
 */
void DNS_bulkBackResolve_query::run_impl(Logger *logger) throw(FWException)
{
    // Hold running_mutex while spawning: no worker can report
    // completion before the counter is final and this thread is
    // waiting.
    running_mutex.lock();

    runnning_count = 0;
    for(unsigned int i=0;i<nthreads;i++)
    {
        void **void_pair = new void*[2];
        void_pair[0]     = this;
        void_pair[1]     = logger;

        pthread_t tid;
        if(pthread_create(&tid, NULL, DNS_bulkBackResolve_Thread, void_pair) != 0)
        {
            // The worker never ran, so it will neither free its
            // arguments nor decrement the counter.
            delete[] void_pair;
            continue;
        }
        pthread_detach(tid);
        runnning_count++;
    }
    const bool none_started = (nthreads > 0 && runnning_count == 0);

    // Workers decrement the counter themselves, so waiting on its value
    // tolerates both coalesced signals and spurious wakeups.
    while(runnning_count)
        running_cond.wait(running_mutex);

    running_mutex.unlock();

    if(none_started)
        throw FWException("Could not start resolver threads");

    check_stop();
    if(!failed.empty())
        throw FWException("Some of IPs did not resolve");
}
