diff options
Diffstat (limited to 'src/dns.cpp')
-rw-r--r-- | src/dns.cpp | 102 |
1 files changed, 87 insertions, 15 deletions
diff --git a/src/dns.cpp b/src/dns.cpp index d2f08a6e0..385f8639e 100644 --- a/src/dns.cpp +++ b/src/dns.cpp @@ -95,8 +95,9 @@ class DNSRequest insp_inaddr myserver; /* DNS server address*/ DNS* dnsobj; /* DNS caller (where we get our FD from) */ unsigned long ttl; /* Time to live */ + std::string orig; /* Original requested name/ip */ - DNSRequest(InspIRCd* Instance, DNS* dns, insp_inaddr server, int id); + DNSRequest(InspIRCd* Instance, DNS* dns, insp_inaddr server, int id, const std::string &original); ~DNSRequest(); DNSInfo ResultIsReady(DNSHeader &h, int length); int SendRequests(const DNSHeader *header, const int length, QueryType qt); @@ -134,11 +135,12 @@ class RequestTimeout : public InspTimer }; /* Allocate the processing buffer */ -DNSRequest::DNSRequest(InspIRCd* Instance, DNS* dns, insp_inaddr server, int id) : dnsobj(dns) +DNSRequest::DNSRequest(InspIRCd* Instance, DNS* dns, insp_inaddr server, int id, const std::string &original) : dnsobj(dns) { res = new unsigned char[512]; *res = 0; memcpy(&myserver, &server, sizeof(insp_inaddr)); + orig = original; RequestTimeout* RT = new RequestTimeout(Instance->Config->dns_timeout ? Instance->Config->dns_timeout : 5, Instance, this, id); Instance->Timers->AddTimer(RT); /* The timer manager frees this */ } @@ -218,7 +220,7 @@ int DNSRequest::SendRequests(const DNSHeader *header, const int length, QueryTyp } /** Add a query with a predefined header, and allocate an ID for it. */ -DNSRequest* DNS::AddQuery(DNSHeader *header, int &id) +DNSRequest* DNS::AddQuery(DNSHeader *header, int &id, const char* original) { /* Is the DNS connection down? */ if (this->GetFd() == -1) @@ -231,7 +233,7 @@ DNSRequest* DNS::AddQuery(DNSHeader *header, int &id) while (requests[id]) id = this->PRNG() & DNS::MAX_REQUEST_ID; - DNSRequest* req = new DNSRequest(ServerInstance, this, this->myserver, id); + DNSRequest* req = new DNSRequest(ServerInstance, this, this->myserver, id, original); header->id[0] = req->id[0] = id >> 8; header->id[1] = req->id[1] = id & 0xFF; @@ -263,6 +265,19 @@ void DNS::Rehash() shutdown(this->GetFd(), 2); close(this->GetFd()); this->SetFd(-1); + + /* Rehash the cache */ + dnscache* newcache = new dnscache(); + for (dnscache::iterator i = this->cache->begin(); i != this->cache->end(); i++) + newcache->insert(*i); + + delete this->cache; + this->cache = newcache; + } + else + { + /* Create initial dns cache */ + this->cache = new dnscache(); } if (insp_aton(ServerInstance->Config->DNSServer,&addr) > 0) @@ -358,9 +373,18 @@ DNS::DNS(InspIRCd* Instance) : ServerInstance(Instance) /* Set the id of the next request to 0 */ currid = 0; + + /* DNS::Rehash() sets this to a valid ptr + */ + this->cache = NULL; + /* Again, DNS::Rehash() sets this to a + * valid value + */ this->SetFd(-1); + /* Actually read the settings + */ this->Rehash(); } @@ -411,7 +435,7 @@ int DNS::GetIP(const char *name) if ((length = this->MakePayload(name, DNS_QUERY_A, 1, (unsigned char*)&h.payload)) == -1) return -1; - DNSRequest* req = this->AddQuery(&h, id); + DNSRequest* req = this->AddQuery(&h, id, name); if ((!req) || (req->SendRequests(&h, length, DNS_QUERY_A) == -1)) return -1; @@ -429,7 +453,7 @@ int DNS::GetIP6(const char *name) if ((length = this->MakePayload(name, DNS_QUERY_AAAA, 1, (unsigned char*)&h.payload)) == -1) return -1; - DNSRequest* req = this->AddQuery(&h, id); + DNSRequest* req = this->AddQuery(&h, id, name); if ((!req) || (req->SendRequests(&h, length, DNS_QUERY_AAAA) == -1)) return -1; @@ -447,7 +471,7 @@ int DNS::GetCName(const char *alias) if ((length = this->MakePayload(alias, DNS_QUERY_CNAME, 1, (unsigned char*)&h.payload)) == -1) return -1; - DNSRequest* req = this->AddQuery(&h, id); + DNSRequest* req = this->AddQuery(&h, id, alias); if ((!req) || (req->SendRequests(&h, length, DNS_QUERY_CNAME) == -1)) return -1; @@ -479,7 +503,7 @@ int DNS::GetName(const insp_inaddr *ip) if ((length = this->MakePayload(query, DNS_QUERY_PTR, 1, (unsigned char*)&h.payload)) == -1) return -1; - DNSRequest* req = this->AddQuery(&h, id); + DNSRequest* req = this->AddQuery(&h, id, insp_ntoa(*ip)); if ((!req) || (req->SendRequests(&h, length, DNS_QUERY_PTR) == -1)) return -1; @@ -525,7 +549,7 @@ int DNS::GetNameForce(const char *ip, ForceProtocol fp) if ((length = this->MakePayload(query, DNS_QUERY_PTR, 1, (unsigned char*)&h.payload)) == -1) return -1; - DNSRequest* req = this->AddQuery(&h, id); + DNSRequest* req = this->AddQuery(&h, id, ip); if ((!req) || (req->SendRequests(&h, length, DNS_QUERY_PTR) == -1)) return -1; @@ -577,7 +601,7 @@ DNSResult DNS::GetResult() { /* Nope - something screwed up. */ ServerInstance->Log(DEBUG,"Whole header not read!"); - return DNSResult(-1,"",0); + return DNSResult(-1,"",0,""); } /* Check wether the reply came from a different DNS @@ -605,7 +629,7 @@ DNSResult DNS::GetResult() if ((port_from != DNS::QUERY_PORT) || (strcasecmp(ipaddr_from, ServerInstance->Config->DNSServer))) { ServerInstance->Log(DEBUG,"port %d is not 53, or %s is not %s",port_from, ipaddr_from, ServerInstance->Config->DNSServer); - return DNSResult(-1,"",0); + return DNSResult(-1,"",0,""); } } @@ -623,7 +647,7 @@ DNSResult DNS::GetResult() { /* Somehow we got a DNS response for a request we never made... */ ServerInstance->Log(DEBUG,"DNS: got a response for a query we didnt send with fd=%d queryid=%d",this->GetFd(),this_id); - return DNSResult(-1,"",0); + return DNSResult(-1,"",0,""); } else { @@ -648,8 +672,9 @@ DNSResult DNS::GetResult() * an error response and needs to be treated uniquely. * Put the error message in the second field. */ + std::string ro = req->orig; delete req; - return DNSResult(this_id | ERROR_MASK, data.second, 0); + return DNSResult(this_id | ERROR_MASK, data.second, 0, ro); } else { @@ -714,8 +739,9 @@ DNSResult DNS::GetResult() } /* Build the reply with the id and hostname/ip in it */ + std::string ro = req->orig; delete req; - return DNSResult(this_id,resultstr,ttl); + return DNSResult(this_id,resultstr,ttl,ro); } } @@ -871,11 +897,50 @@ DNS::~DNS() close(this->GetFd()); } +CachedQuery* DNS::GetCache(const std::string &source) +{ + dnscache::iterator x = cache->find(source.c_str()); + if (x != cache->end()) + return &(x->second); + else + return NULL; +} + +void DNS::DelCache(const std::string &source) +{ + cache->erase(source.c_str()); +} + +void Resolver::OnLookupComplete(const std::string &result, unsigned int ttl) +{ + throw CoreException("Someone didnt define an OnLookupComplete method for their Resolver class!"); +} + /** High level abstraction of dns used by application at large */ -Resolver::Resolver(InspIRCd* Instance, const std::string &source, QueryType qt, Module* creator) : ServerInstance(Instance), Creator(creator), input(source), querytype(qt) +Resolver::Resolver(InspIRCd* Instance, const std::string &source, QueryType qt, bool &cached, Module* creator) : ServerInstance(Instance), Creator(creator), input(source), querytype(qt) { ServerInstance->Log(DEBUG,"Instance: %08x %08x",Instance, ServerInstance); + cached = false; + + CachedQuery* CQ = ServerInstance->Res->GetCache(source); + if (CQ) + { + int time_left = CQ->CalcTTLRemaining(); + if (!time_left) + { + ServerInstance->Log(DEBUG,"Cached but EXPIRED result: %s", CQ->data.c_str()); + ServerInstance->Res->DelCache(source); + } + else + { + cached = true; + ServerInstance->Log(DEBUG,"Cached result: %s", CQ->data.c_str()); + OnLookupComplete(CQ->data, time_left); + return; + } + } + insp_inaddr binip; switch (querytype) @@ -988,6 +1053,13 @@ void DNS::HandleEvent(EventType et, int errornum) { if (ServerInstance && ServerInstance->stats) ServerInstance->stats->statsDnsGood++; + + if (!this->GetCache(res.original.c_str())) + { + ServerInstance->Log(DEBUG,"Caching result: %s->%s for %lu secs", res.original.c_str(), res.result.c_str(), res.ttl); + this->cache->insert(std::make_pair(res.original.c_str(), CachedQuery(res.result, res.ttl))); + } + Classes[res.id]->OnLookupComplete(res.result, res.ttl); delete Classes[res.id]; Classes[res.id] = NULL; |