diff options
Diffstat (limited to 'src/modules/extra')
-rw-r--r-- | src/modules/extra/m_geoip.cpp | 34 | ||||
-rw-r--r-- | src/modules/extra/m_ldap.cpp | 669 | ||||
-rw-r--r-- | src/modules/extra/m_ldapauth.cpp | 436 | ||||
-rw-r--r-- | src/modules/extra/m_ldapoper.cpp | 255 | ||||
-rw-r--r-- | src/modules/extra/m_mssql.cpp | 75 | ||||
-rw-r--r-- | src/modules/extra/m_mysql.cpp | 44 | ||||
-rw-r--r-- | src/modules/extra/m_pgsql.cpp | 89 | ||||
-rw-r--r-- | src/modules/extra/m_regex_pcre.cpp | 41 | ||||
-rw-r--r-- | src/modules/extra/m_regex_posix.cpp | 46 | ||||
-rw-r--r-- | src/modules/extra/m_regex_re2.cpp | 81 | ||||
-rw-r--r-- | src/modules/extra/m_regex_stdlib.cpp | 45 | ||||
-rw-r--r-- | src/modules/extra/m_regex_tre.cpp | 43 | ||||
-rw-r--r-- | src/modules/extra/m_sqlite3.cpp | 57 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 1685 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_mbedtls.cpp | 932 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_openssl.cpp | 1241 |
16 files changed, 3552 insertions, 2221 deletions
diff --git a/src/modules/extra/m_geoip.cpp b/src/modules/extra/m_geoip.cpp index a36c39bc8..967c6a761 100644 --- a/src/modules/extra/m_geoip.cpp +++ b/src/modules/extra/m_geoip.cpp @@ -27,7 +27,6 @@ # pragma comment(lib, "GeoIP.lib") #endif -/* $ModDesc: Provides a way to restrict users by country using GeoIP lookup */ /* $LinkerFlags: -lGeoIP */ class ModuleGeoIP : public Module @@ -37,7 +36,7 @@ class ModuleGeoIP : public Module std::string* SetExt(LocalUser* user) { - const char* c = GeoIP_country_code_by_addr(gi, user->GetIPString()); + const char* c = GeoIP_country_code_by_addr(gi, user->GetIPString().c_str()); if (!c) c = "UNK"; @@ -47,21 +46,20 @@ class ModuleGeoIP : public Module } public: - ModuleGeoIP() : ext("geoip_cc", this), gi(NULL) + ModuleGeoIP() + : ext("geoip_cc", ExtensionItem::EXT_USER, this) + , gi(NULL) { } - void init() + void init() CXX11_OVERRIDE { gi = GeoIP_new(GEOIP_STANDARD); if (gi == NULL) throw ModuleException("Unable to initialize geoip, are you missing GeoIP.dat?"); - ServerInstance->Modules->AddService(ext); - Implementation eventlist[] = { I_OnSetConnectClass, I_OnStats }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - - for (LocalUserList::const_iterator i = ServerInstance->Users->local_users.begin(); i != ServerInstance->Users->local_users.end(); ++i) + const UserManager::LocalList& list = ServerInstance->Users.GetLocalUsers(); + for (UserManager::LocalList::const_iterator i = list.begin(); i != list.end(); ++i) { LocalUser* user = *i; if ((user->registered == REG_ALL) && (!ext.get(user))) @@ -77,12 +75,12 @@ class ModuleGeoIP : public Module GeoIP_delete(gi); } - Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("Provides a way to assign users to connect classes by country using GeoIP lookup", VF_VENDOR); } - ModResult OnSetConnectClass(LocalUser* user, ConnectClass* myclass) + ModResult OnSetConnectClass(LocalUser* user, ConnectClass* myclass) CXX11_OVERRIDE { std::string* cc = ext.get(user); if (!cc) @@ -99,14 +97,16 @@ class ModuleGeoIP : public Module return MOD_RES_DENY; } - ModResult OnStats(char symbol, User* user, string_list &out) + ModResult OnStats(Stats::Context& stats) CXX11_OVERRIDE { - if (symbol != 'G') + if (stats.GetSymbol() != 'G') return MOD_RES_PASSTHRU; unsigned int unknown = 0; std::map<std::string, unsigned int> results; - for (LocalUserList::const_iterator i = ServerInstance->Users->local_users.begin(); i != ServerInstance->Users->local_users.end(); ++i) + + const UserManager::LocalList& list = ServerInstance->Users.GetLocalUsers(); + for (UserManager::LocalList::const_iterator i = list.begin(); i != list.end(); ++i) { std::string* cc = ext.get(*i); if (cc) @@ -115,18 +115,16 @@ class ModuleGeoIP : public Module unknown++; } - std::string p = ServerInstance->Config->ServerName + " 801 " + user->nick + " :GeoIPSTATS "; for (std::map<std::string, unsigned int>::const_iterator i = results.begin(); i != results.end(); ++i) { - out.push_back(p + i->first + " " + ConvToStr(i->second)); + stats.AddRow(801, "GeoIPSTATS " + i->first + " " + ConvToStr(i->second)); } if (unknown) - out.push_back(p + "Unknown " + ConvToStr(unknown)); + stats.AddRow(801, "GeoIPSTATS Unknown " + ConvToStr(unknown)); return MOD_RES_DENY; } }; MODULE_INIT(ModuleGeoIP) - diff --git a/src/modules/extra/m_ldap.cpp b/src/modules/extra/m_ldap.cpp new file mode 100644 index 000000000..c11025836 --- /dev/null +++ b/src/modules/extra/m_ldap.cpp @@ -0,0 +1,669 @@ +/* + * InspIRCd -- Internet Relay Chat Daemon + * + * Copyright (C) 2013-2015 Adam <Adam@anope.org> + * Copyright (C) 2003-2015 Anope Team <team@anope.org> + * + * This file is part of InspIRCd. InspIRCd is free software: you can + * redistribute it and/or modify it under the terms of the GNU General Public + * License as published by the Free Software Foundation, version 2. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "inspircd.h" +#include "modules/ldap.h" + +#include <ldap.h> + +#ifdef _WIN32 +# pragma comment(lib, "libldap_r.lib") +# pragma comment(lib, "liblber.lib") +#endif + +/* $LinkerFlags: -lldap_r */ + +class LDAPService; + +class LDAPRequest +{ + public: + LDAPService* service; + LDAPInterface* inter; + LDAPMessage* message; /* message returned by ldap_ */ + LDAPResult* result; /* final result */ + struct timeval tv; + QueryType type; + + LDAPRequest(LDAPService* s, LDAPInterface* i) + : service(s) + , inter(i) + , message(NULL) + , result(NULL) + { + type = QUERY_UNKNOWN; + tv.tv_sec = 0; + tv.tv_usec = 100000; + } + + virtual ~LDAPRequest() + { + delete result; + if (message != NULL) + ldap_msgfree(message); + } + + virtual int run() = 0; +}; + +class LDAPBind : public LDAPRequest +{ + std::string who, pass; + + public: + LDAPBind(LDAPService* s, LDAPInterface* i, const std::string& w, const std::string& p) + : LDAPRequest(s, i) + , who(w) + , pass(p) + { + type = QUERY_BIND; + } + + int run() CXX11_OVERRIDE; +}; + +class LDAPSearch : public LDAPRequest +{ + std::string base; + int searchscope; + std::string filter; + + public: + LDAPSearch(LDAPService* s, LDAPInterface* i, const std::string& b, int se, const std::string& f) + : LDAPRequest(s, i) + , base(b) + , searchscope(se) + , filter(f) + { + type = QUERY_SEARCH; + } + + int run() CXX11_OVERRIDE; +}; + +class LDAPAdd : public LDAPRequest +{ + std::string dn; + LDAPMods attributes; + + public: + LDAPAdd(LDAPService* s, LDAPInterface* i, const std::string& d, const LDAPMods& attr) + : LDAPRequest(s, i) + , dn(d) + , attributes(attr) + { + type = QUERY_ADD; + } + + int run() CXX11_OVERRIDE; +}; + +class LDAPDel : public LDAPRequest +{ + std::string dn; + + public: + LDAPDel(LDAPService* s, LDAPInterface* i, const std::string& d) + : LDAPRequest(s, i) + , dn(d) + { + type = QUERY_DELETE; + } + + int run() CXX11_OVERRIDE; +}; + +class LDAPModify : public LDAPRequest +{ + std::string base; + LDAPMods attributes; + + public: + LDAPModify(LDAPService* s, LDAPInterface* i, const std::string& b, const LDAPMods& attr) + : LDAPRequest(s, i) + , base(b) + , attributes(attr) + { + type = QUERY_MODIFY; + } + + int run() CXX11_OVERRIDE; +}; + +class LDAPCompare : public LDAPRequest +{ + std::string dn, attr, val; + + public: + LDAPCompare(LDAPService* s, LDAPInterface* i, const std::string& d, const std::string& a, const std::string& v) + : LDAPRequest(s, i) + , dn(d) + , attr(a) + , val(v) + { + type = QUERY_COMPARE; + } + + int run() CXX11_OVERRIDE; +}; + +class LDAPService : public LDAPProvider, public SocketThread +{ + LDAP* con; + reference<ConfigTag> config; + time_t last_connect; + int searchscope; + time_t timeout; + time_t last_timeout_check; + + public: + static LDAPMod** BuildMods(const LDAPMods& attributes) + { + LDAPMod** mods = new LDAPMod*[attributes.size() + 1]; + memset(mods, 0, sizeof(LDAPMod*) * (attributes.size() + 1)); + for (unsigned int x = 0; x < attributes.size(); ++x) + { + const LDAPModification& l = attributes[x]; + LDAPMod* mod = new LDAPMod; + mods[x] = mod; + + if (l.op == LDAPModification::LDAP_ADD) + mod->mod_op = LDAP_MOD_ADD; + else if (l.op == LDAPModification::LDAP_DEL) + mod->mod_op = LDAP_MOD_DELETE; + else if (l.op == LDAPModification::LDAP_REPLACE) + mod->mod_op = LDAP_MOD_REPLACE; + else if (l.op != 0) + { + FreeMods(mods); + throw LDAPException("Unknown LDAP operation"); + } + mod->mod_type = strdup(l.name.c_str()); + mod->mod_values = new char*[l.values.size() + 1]; + memset(mod->mod_values, 0, sizeof(char*) * (l.values.size() + 1)); + for (unsigned int j = 0, c = 0; j < l.values.size(); ++j) + if (!l.values[j].empty()) + mod->mod_values[c++] = strdup(l.values[j].c_str()); + } + return mods; + } + + static void FreeMods(LDAPMod** mods) + { + for (unsigned int i = 0; mods[i] != NULL; ++i) + { + LDAPMod* mod = mods[i]; + if (mod->mod_type != NULL) + free(mod->mod_type); + if (mod->mod_values != NULL) + { + for (unsigned int j = 0; mod->mod_values[j] != NULL; ++j) + free(mod->mod_values[j]); + delete[] mod->mod_values; + } + } + delete[] mods; + } + + private: + void Reconnect() + { + // Only try one connect a minute. It is an expensive blocking operation + if (last_connect > ServerInstance->Time() - 60) + throw LDAPException("Unable to connect to LDAP service " + this->name + ": reconnecting too fast"); + last_connect = ServerInstance->Time(); + + ldap_unbind_ext(this->con, NULL, NULL); + Connect(); + } + + void QueueRequest(LDAPRequest* r) + { + this->LockQueue(); + this->queries.push_back(r); + this->UnlockQueueWakeup(); + } + + public: + typedef std::vector<LDAPRequest*> query_queue; + query_queue queries, results; + Mutex process_mutex; /* held when processing requests not in either queue */ + + LDAPService(Module* c, ConfigTag* tag) + : LDAPProvider(c, "LDAP/" + tag->getString("id")) + , con(NULL), config(tag), last_connect(0), last_timeout_check(0) + { + std::string scope = config->getString("searchscope"); + if (scope == "base") + searchscope = LDAP_SCOPE_BASE; + else if (scope == "onelevel") + searchscope = LDAP_SCOPE_ONELEVEL; + else + searchscope = LDAP_SCOPE_SUBTREE; + timeout = config->getInt("timeout", 5); + + Connect(); + } + + ~LDAPService() + { + this->LockQueue(); + + for (unsigned int i = 0; i < this->queries.size(); ++i) + { + LDAPRequest* req = this->queries[i]; + + /* queries have no results yet */ + req->result = new LDAPResult(); + req->result->type = req->type; + req->result->error = "LDAP Interface is going away"; + req->inter->OnError(*req->result); + + delete req; + } + this->queries.clear(); + + for (unsigned int i = 0; i < this->results.size(); ++i) + { + LDAPRequest* req = this->results[i]; + + /* even though this may have already finished successfully we return that it didn't */ + req->result->error = "LDAP Interface is going away"; + req->inter->OnError(*req->result); + + delete req; + } + this->results.clear(); + + this->UnlockQueue(); + + ldap_unbind_ext(this->con, NULL, NULL); + } + + void Connect() + { + std::string server = config->getString("server"); + int i = ldap_initialize(&this->con, server.c_str()); + if (i != LDAP_SUCCESS) + throw LDAPException("Unable to connect to LDAP service " + this->name + ": " + ldap_err2string(i)); + + const int version = LDAP_VERSION3; + i = ldap_set_option(this->con, LDAP_OPT_PROTOCOL_VERSION, &version); + if (i != LDAP_OPT_SUCCESS) + { + ldap_unbind_ext(this->con, NULL, NULL); + this->con = NULL; + throw LDAPException("Unable to set protocol version for " + this->name + ": " + ldap_err2string(i)); + } + + const struct timeval tv = { 0, 0 }; + i = ldap_set_option(this->con, LDAP_OPT_NETWORK_TIMEOUT, &tv); + if (i != LDAP_OPT_SUCCESS) + { + ldap_unbind_ext(this->con, NULL, NULL); + this->con = NULL; + throw LDAPException("Unable to set timeout for " + this->name + ": " + ldap_err2string(i)); + } + } + + void BindAsManager(LDAPInterface* i) CXX11_OVERRIDE + { + std::string binddn = config->getString("binddn"); + std::string bindauth = config->getString("bindauth"); + this->Bind(i, binddn, bindauth); + } + + void Bind(LDAPInterface* i, const std::string& who, const std::string& pass) CXX11_OVERRIDE + { + LDAPBind* b = new LDAPBind(this, i, who, pass); + QueueRequest(b); + } + + void Search(LDAPInterface* i, const std::string& base, const std::string& filter) CXX11_OVERRIDE + { + if (i == NULL) + throw LDAPException("No interface"); + + LDAPSearch* s = new LDAPSearch(this, i, base, searchscope, filter); + QueueRequest(s); + } + + void Add(LDAPInterface* i, const std::string& dn, LDAPMods& attributes) CXX11_OVERRIDE + { + LDAPAdd* add = new LDAPAdd(this, i, dn, attributes); + QueueRequest(add); + } + + void Del(LDAPInterface* i, const std::string& dn) CXX11_OVERRIDE + { + LDAPDel* del = new LDAPDel(this, i, dn); + QueueRequest(del); + } + + void Modify(LDAPInterface* i, const std::string& base, LDAPMods& attributes) CXX11_OVERRIDE + { + LDAPModify* mod = new LDAPModify(this, i, base, attributes); + QueueRequest(mod); + } + + void Compare(LDAPInterface* i, const std::string& dn, const std::string& attr, const std::string& val) CXX11_OVERRIDE + { + LDAPCompare* comp = new LDAPCompare(this, i, dn, attr, val); + QueueRequest(comp); + } + + private: + void BuildReply(int res, LDAPRequest* req) + { + LDAPResult* ldap_result = req->result = new LDAPResult(); + req->result->type = req->type; + + if (res != LDAP_SUCCESS) + { + ldap_result->error = ldap_err2string(res); + return; + } + + if (req->message == NULL) + { + return; + } + + /* a search result */ + + for (LDAPMessage* cur = ldap_first_message(this->con, req->message); cur; cur = ldap_next_message(this->con, cur)) + { + LDAPAttributes attributes; + + char* dn = ldap_get_dn(this->con, cur); + if (dn != NULL) + { + attributes["dn"].push_back(dn); + ldap_memfree(dn); + dn = NULL; + } + + BerElement* ber = NULL; + + for (char* attr = ldap_first_attribute(this->con, cur, &ber); attr; attr = ldap_next_attribute(this->con, cur, ber)) + { + berval** vals = ldap_get_values_len(this->con, cur, attr); + int count = ldap_count_values_len(vals); + + std::vector<std::string> attrs; + for (int j = 0; j < count; ++j) + attrs.push_back(vals[j]->bv_val); + attributes[attr] = attrs; + + ldap_value_free_len(vals); + ldap_memfree(attr); + } + if (ber != NULL) + ber_free(ber, 0); + + ldap_result->messages.push_back(attributes); + } + } + + void SendRequests() + { + process_mutex.Lock(); + + query_queue q; + this->LockQueue(); + queries.swap(q); + this->UnlockQueue(); + + if (q.empty()) + { + process_mutex.Unlock(); + return; + } + + for (unsigned int i = 0; i < q.size(); ++i) + { + LDAPRequest* req = q[i]; + int ret = req->run(); + + if (ret == LDAP_SERVER_DOWN || ret == LDAP_TIMEOUT) + { + /* try again */ + try + { + Reconnect(); + } + catch (const LDAPException &) + { + } + + ret = req->run(); + } + + BuildReply(ret, req); + + this->LockQueue(); + this->results.push_back(req); + this->UnlockQueue(); + } + + this->NotifyParent(); + + process_mutex.Unlock(); + } + + public: + void Run() CXX11_OVERRIDE + { + while (!this->GetExitFlag()) + { + this->LockQueue(); + if (this->queries.empty()) + this->WaitForQueue(); + this->UnlockQueue(); + + SendRequests(); + } + } + + void OnNotify() CXX11_OVERRIDE + { + query_queue r; + + this->LockQueue(); + this->results.swap(r); + this->UnlockQueue(); + + for (unsigned int i = 0; i < r.size(); ++i) + { + LDAPRequest* req = r[i]; + LDAPInterface* li = req->inter; + LDAPResult* res = req->result; + + if (!res->error.empty()) + li->OnError(*res); + else + li->OnResult(*res); + + delete req; + } + } + + LDAP* GetConnection() + { + return con; + } +}; + +class ModuleLDAP : public Module +{ + typedef insp::flat_map<std::string, LDAPService*> ServiceMap; + ServiceMap LDAPServices; + + public: + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE + { + ServiceMap conns; + + ConfigTagList tags = ServerInstance->Config->ConfTags("database"); + for (ConfigIter i = tags.first; i != tags.second; i++) + { + const reference<ConfigTag>& tag = i->second; + + if (tag->getString("module") != "ldap") + continue; + + std::string id = tag->getString("id"); + + ServiceMap::iterator curr = LDAPServices.find(id); + if (curr == LDAPServices.end()) + { + LDAPService* conn = new LDAPService(this, tag); + conns[id] = conn; + + ServerInstance->Modules->AddService(*conn); + ServerInstance->Threads.Start(conn); + } + else + { + conns.insert(*curr); + LDAPServices.erase(curr); + } + } + + for (ServiceMap::iterator i = LDAPServices.begin(); i != LDAPServices.end(); ++i) + { + LDAPService* conn = i->second; + ServerInstance->Modules->DelService(*conn); + conn->join(); + conn->OnNotify(); + delete conn; + } + + LDAPServices.swap(conns); + } + + void OnUnloadModule(Module* m) CXX11_OVERRIDE + { + for (ServiceMap::iterator it = this->LDAPServices.begin(); it != this->LDAPServices.end(); ++it) + { + LDAPService* s = it->second; + + s->process_mutex.Lock(); + s->LockQueue(); + + for (unsigned int i = s->queries.size(); i > 0; --i) + { + LDAPRequest* req = s->queries[i - 1]; + LDAPInterface* li = req->inter; + + if (li->creator == m) + { + s->queries.erase(s->queries.begin() + i - 1); + delete req; + } + } + + for (unsigned int i = s->results.size(); i > 0; --i) + { + LDAPRequest* req = s->results[i - 1]; + LDAPInterface* li = req->inter; + + if (li->creator == m) + { + s->results.erase(s->results.begin() + i - 1); + delete req; + } + } + + s->UnlockQueue(); + s->process_mutex.Unlock(); + } + } + + ~ModuleLDAP() + { + for (ServiceMap::iterator i = LDAPServices.begin(); i != LDAPServices.end(); ++i) + { + LDAPService* conn = i->second; + conn->join(); + conn->OnNotify(); + delete conn; + } + } + + Version GetVersion() CXX11_OVERRIDE + { + return Version("LDAP support", VF_VENDOR); + } +}; + +int LDAPBind::run() +{ + berval cred; + cred.bv_val = strdup(pass.c_str()); + cred.bv_len = pass.length(); + + int i = ldap_sasl_bind_s(service->GetConnection(), who.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL); + + free(cred.bv_val); + + return i; +} + +int LDAPSearch::run() +{ + return ldap_search_ext_s(service->GetConnection(), base.c_str(), searchscope, filter.c_str(), NULL, 0, NULL, NULL, &tv, 0, &message); +} + +int LDAPAdd::run() +{ + LDAPMod** mods = LDAPService::BuildMods(attributes); + int i = ldap_add_ext_s(service->GetConnection(), dn.c_str(), mods, NULL, NULL); + LDAPService::FreeMods(mods); + return i; +} + +int LDAPDel::run() +{ + return ldap_delete_ext_s(service->GetConnection(), dn.c_str(), NULL, NULL); +} + +int LDAPModify::run() +{ + LDAPMod** mods = LDAPService::BuildMods(attributes); + int i = ldap_modify_ext_s(service->GetConnection(), base.c_str(), mods, NULL, NULL); + LDAPService::FreeMods(mods); + return i; +} + +int LDAPCompare::run() +{ + berval cred; + cred.bv_val = strdup(val.c_str()); + cred.bv_len = val.length(); + + int ret = ldap_compare_ext_s(service->GetConnection(), dn.c_str(), attr.c_str(), &cred, NULL, NULL); + + free(cred.bv_val); + + return ret; + +} + +MODULE_INIT(ModuleLDAP) diff --git a/src/modules/extra/m_ldapauth.cpp b/src/modules/extra/m_ldapauth.cpp deleted file mode 100644 index 6c765fb2e..000000000 --- a/src/modules/extra/m_ldapauth.cpp +++ /dev/null @@ -1,436 +0,0 @@ -/* - * InspIRCd -- Internet Relay Chat Daemon - * - * Copyright (C) 2011 Pierre Carrier <pierre@spotify.com> - * Copyright (C) 2009-2010 Robin Burchell <robin+git@viroteck.net> - * Copyright (C) 2009 Daniel De Graaf <danieldg@inspircd.org> - * Copyright (C) 2008 Pippijn van Steenhoven <pip88nl@gmail.com> - * Copyright (C) 2008 Craig Edwards <craigedwards@brainbox.cc> - * Copyright (C) 2008 Dennis Friis <peavey@inspircd.org> - * Copyright (C) 2007 Carsten Valdemar Munk <carsten.munk+inspircd@gmail.com> - * - * This file is part of InspIRCd. InspIRCd is free software: you can - * redistribute it and/or modify it under the terms of the GNU General Public - * License as published by the Free Software Foundation, version 2. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS - * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more - * details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see <http://www.gnu.org/licenses/>. - */ - - -#include "inspircd.h" -#include "users.h" -#include "channels.h" -#include "modules.h" - -#include <ldap.h> - -#ifdef _WIN32 -# pragma comment(lib, "libldap.lib") -# pragma comment(lib, "liblber.lib") -#endif - -/* $ModDesc: Allow/Deny connections based upon answer from LDAP server */ -/* $LinkerFlags: -lldap */ - -struct RAIILDAPString -{ - char *str; - - RAIILDAPString(char *Str) - : str(Str) - { - } - - ~RAIILDAPString() - { - ldap_memfree(str); - } - - operator char*() - { - return str; - } - - operator std::string() - { - return str; - } -}; - -struct RAIILDAPMessage -{ - RAIILDAPMessage() - { - } - - ~RAIILDAPMessage() - { - dealloc(); - } - - void dealloc() - { - ldap_msgfree(msg); - } - - operator LDAPMessage*() - { - return msg; - } - - LDAPMessage **operator &() - { - return &msg; - } - - LDAPMessage *msg; -}; - -class ModuleLDAPAuth : public Module -{ - LocalIntExt ldapAuthed; - LocalStringExt ldapVhost; - std::string base; - std::string attribute; - std::string ldapserver; - std::string allowpattern; - std::string killreason; - std::string username; - std::string password; - std::string vhost; - std::vector<std::string> whitelistedcidrs; - std::vector<std::pair<std::string, std::string> > requiredattributes; - int searchscope; - bool verbose; - bool useusername; - LDAP *conn; - -public: - ModuleLDAPAuth() - : ldapAuthed("ldapauth", this) - , ldapVhost("ldapauth_vhost", this) - { - conn = NULL; - } - - void init() - { - ServerInstance->Modules->AddService(ldapAuthed); - ServerInstance->Modules->AddService(ldapVhost); - Implementation eventlist[] = { I_OnCheckReady, I_OnRehash,I_OnUserRegister, I_OnUserConnect }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - OnRehash(NULL); - } - - ~ModuleLDAPAuth() - { - if (conn) - ldap_unbind_ext(conn, NULL, NULL); - } - - void OnRehash(User* user) - { - ConfigTag* tag = ServerInstance->Config->ConfValue("ldapauth"); - whitelistedcidrs.clear(); - requiredattributes.clear(); - - base = tag->getString("baserdn"); - attribute = tag->getString("attribute"); - ldapserver = tag->getString("server"); - allowpattern = tag->getString("allowpattern"); - killreason = tag->getString("killreason"); - std::string scope = tag->getString("searchscope"); - username = tag->getString("binddn"); - password = tag->getString("bindauth"); - vhost = tag->getString("host"); - verbose = tag->getBool("verbose"); /* Set to true if failed connects should be reported to operators */ - useusername = tag->getBool("userfield"); - - ConfigTagList whitelisttags = ServerInstance->Config->ConfTags("ldapwhitelist"); - - for (ConfigIter i = whitelisttags.first; i != whitelisttags.second; ++i) - { - std::string cidr = i->second->getString("cidr"); - if (!cidr.empty()) { - whitelistedcidrs.push_back(cidr); - } - } - - ConfigTagList attributetags = ServerInstance->Config->ConfTags("ldaprequire"); - - for (ConfigIter i = attributetags.first; i != attributetags.second; ++i) - { - const std::string attr = i->second->getString("attribute"); - const std::string val = i->second->getString("value"); - - if (!attr.empty() && !val.empty()) - requiredattributes.push_back(make_pair(attr, val)); - } - - if (scope == "base") - searchscope = LDAP_SCOPE_BASE; - else if (scope == "onelevel") - searchscope = LDAP_SCOPE_ONELEVEL; - else searchscope = LDAP_SCOPE_SUBTREE; - - Connect(); - } - - bool Connect() - { - if (conn != NULL) - ldap_unbind_ext(conn, NULL, NULL); - int res, v = LDAP_VERSION3; - res = ldap_initialize(&conn, ldapserver.c_str()); - if (res != LDAP_SUCCESS) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "LDAP connection failed: %s", ldap_err2string(res)); - conn = NULL; - return false; - } - - res = ldap_set_option(conn, LDAP_OPT_PROTOCOL_VERSION, (void *)&v); - if (res != LDAP_SUCCESS) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "LDAP set protocol to v3 failed: %s", ldap_err2string(res)); - ldap_unbind_ext(conn, NULL, NULL); - conn = NULL; - return false; - } - return true; - } - - std::string SafeReplace(const std::string &text, std::map<std::string, - std::string> &replacements) - { - std::string result; - result.reserve(MAXBUF); - - for (unsigned int i = 0; i < text.length(); ++i) { - char c = text[i]; - if (c == '$') { - // find the first nonalpha - i++; - unsigned int start = i; - - while (i < text.length() - 1 && isalpha(text[i + 1])) - ++i; - - std::string key = text.substr(start, (i - start) + 1); - result.append(replacements[key]); - } else { - result.push_back(c); - } - } - - return result; - } - - virtual void OnUserConnect(LocalUser *user) - { - std::string* cc = ldapVhost.get(user); - if (cc) - { - user->ChangeDisplayedHost(cc->c_str()); - ldapVhost.unset(user); - } - } - - ModResult OnUserRegister(LocalUser* user) - { - if ((!allowpattern.empty()) && (InspIRCd::Match(user->nick,allowpattern))) - { - ldapAuthed.set(user,1); - return MOD_RES_PASSTHRU; - } - - for (std::vector<std::string>::iterator i = whitelistedcidrs.begin(); i != whitelistedcidrs.end(); i++) - { - if (InspIRCd::MatchCIDR(user->GetIPString(), *i, ascii_case_insensitive_map)) - { - ldapAuthed.set(user,1); - return MOD_RES_PASSTHRU; - } - } - - if (!CheckCredentials(user)) - { - ServerInstance->Users->QuitUser(user, killreason); - return MOD_RES_DENY; - } - return MOD_RES_PASSTHRU; - } - - bool CheckCredentials(LocalUser* user) - { - if (conn == NULL) - if (!Connect()) - return false; - - if (user->password.empty()) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (No password provided)", user->GetFullRealHost().c_str()); - return false; - } - - int res; - // bind anonymously if no bind DN and authentication are given in the config - struct berval cred; - cred.bv_val = const_cast<char*>(password.c_str()); - cred.bv_len = password.length(); - - if ((res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS) - { - if (res == LDAP_SERVER_DOWN) - { - // Attempt to reconnect if the connection dropped - if (verbose) - ServerInstance->SNO->WriteToSnoMask('a', "LDAP server has gone away - reconnecting..."); - Connect(); - res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL); - } - - if (res != LDAP_SUCCESS) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP bind failed: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res)); - ldap_unbind_ext(conn, NULL, NULL); - conn = NULL; - return false; - } - } - - RAIILDAPMessage msg; - std::string what = (attribute + "=" + (useusername ? user->ident : user->nick)); - if ((res = ldap_search_ext_s(conn, base.c_str(), searchscope, what.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg)) != LDAP_SUCCESS) - { - // Do a second search, based on password, if it contains a : - // That is, PASS <user>:<password> will work. - size_t pos = user->password.find(":"); - if (pos != std::string::npos) - { - // manpage says we must deallocate regardless of success or failure - // since we're about to do another query (and reset msg), first - // free the old one. - msg.dealloc(); - - std::string cutpassword = user->password.substr(0, pos); - res = ldap_search_ext_s(conn, base.c_str(), searchscope, cutpassword.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg); - - if (res == LDAP_SUCCESS) - { - // Trim the user: prefix, leaving just 'pass' for later password check - user->password = user->password.substr(pos + 1); - } - } - - // It may have found based on user:pass check above. - if (res != LDAP_SUCCESS) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search failed: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res)); - return false; - } - } - if (ldap_count_entries(conn, msg) > 1) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search returned more than one result: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res)); - return false; - } - - LDAPMessage *entry; - if ((entry = ldap_first_entry(conn, msg)) == NULL) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search returned no results: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res)); - return false; - } - cred.bv_val = (char*)user->password.data(); - cred.bv_len = user->password.length(); - RAIILDAPString DN(ldap_get_dn(conn, entry)); - if ((res = ldap_sasl_bind_s(conn, DN, LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (%s)", user->GetFullRealHost().c_str(), ldap_err2string(res)); - return false; - } - - if (!requiredattributes.empty()) - { - bool authed = false; - - for (std::vector<std::pair<std::string, std::string> >::const_iterator it = requiredattributes.begin(); it != requiredattributes.end(); ++it) - { - const std::string &attr = it->first; - const std::string &val = it->second; - - struct berval attr_value; - attr_value.bv_val = const_cast<char*>(val.c_str()); - attr_value.bv_len = val.length(); - - ServerInstance->Logs->Log("m_ldapauth", DEBUG, "LDAP compare: %s=%s", attr.c_str(), val.c_str()); - - authed = (ldap_compare_ext_s(conn, DN, attr.c_str(), &attr_value, NULL, NULL) == LDAP_COMPARE_TRUE); - - if (authed) - break; - } - - if (!authed) - { - if (verbose) - ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (Lacks required LDAP attributes)", user->GetFullRealHost().c_str()); - return false; - } - } - - if (!vhost.empty()) - { - irc::commasepstream stream(DN); - - // mashed map of key:value parts of the DN - std::map<std::string, std::string> dnParts; - - std::string dnPart; - while (stream.GetToken(dnPart)) - { - std::string::size_type pos = dnPart.find('='); - if (pos == std::string::npos) // malformed - continue; - - std::string key = dnPart.substr(0, pos); - std::string value = dnPart.substr(pos + 1, dnPart.length() - pos + 1); // +1s to skip the = itself - dnParts[key] = value; - } - - // change host according to config key - ldapVhost.set(user, SafeReplace(vhost, dnParts)); - } - - ldapAuthed.set(user,1); - return true; - } - - ModResult OnCheckReady(LocalUser* user) - { - return ldapAuthed.get(user) ? MOD_RES_PASSTHRU : MOD_RES_DENY; - } - - Version GetVersion() - { - return Version("Allow/Deny connections based upon answer from LDAP server", VF_VENDOR); - } - -}; - -MODULE_INIT(ModuleLDAPAuth) diff --git a/src/modules/extra/m_ldapoper.cpp b/src/modules/extra/m_ldapoper.cpp deleted file mode 100644 index 1f46361d4..000000000 --- a/src/modules/extra/m_ldapoper.cpp +++ /dev/null @@ -1,255 +0,0 @@ -/* - * InspIRCd -- Internet Relay Chat Daemon - * - * Copyright (C) 2009 Robin Burchell <robin+git@viroteck.net> - * Copyright (C) 2008 Pippijn van Steenhoven <pip88nl@gmail.com> - * Copyright (C) 2008 Craig Edwards <craigedwards@brainbox.cc> - * Copyright (C) 2007 Carsten Valdemar Munk <carsten.munk+inspircd@gmail.com> - * - * This file is part of InspIRCd. InspIRCd is free software: you can - * redistribute it and/or modify it under the terms of the GNU General Public - * License as published by the Free Software Foundation, version 2. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS - * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more - * details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see <http://www.gnu.org/licenses/>. - */ - - -#include "inspircd.h" -#include "users.h" -#include "channels.h" -#include "modules.h" - -#include <ldap.h> - -#ifdef _WIN32 -# pragma comment(lib, "libldap.lib") -# pragma comment(lib, "liblber.lib") -#endif - -/* $ModDesc: Adds the ability to authenticate opers via LDAP */ -/* $LinkerFlags: -lldap */ - -// Duplicated code, also found in cmd_oper and m_sqloper -static bool OneOfMatches(const char* host, const char* ip, const std::string& hostlist) -{ - std::stringstream hl(hostlist); - std::string xhost; - while (hl >> xhost) - { - if (InspIRCd::Match(host, xhost, ascii_case_insensitive_map) || InspIRCd::MatchCIDR(ip, xhost, ascii_case_insensitive_map)) - { - return true; - } - } - return false; -} - -struct RAIILDAPString -{ - char *str; - - RAIILDAPString(char *Str) - : str(Str) - { - } - - ~RAIILDAPString() - { - ldap_memfree(str); - } - - operator char*() - { - return str; - } - - operator std::string() - { - return str; - } -}; - -class ModuleLDAPAuth : public Module -{ - std::string base; - std::string ldapserver; - std::string username; - std::string password; - std::string attribute; - int searchscope; - LDAP *conn; - - bool HandleOper(LocalUser* user, const std::string& opername, const std::string& inputpass) - { - OperIndex::iterator it = ServerInstance->Config->oper_blocks.find(opername); - if (it == ServerInstance->Config->oper_blocks.end()) - return false; - - ConfigTag* tag = it->second->oper_block; - if (!tag) - return false; - - std::string acceptedhosts = tag->getString("host"); - std::string hostname = user->ident + "@" + user->host; - if (!OneOfMatches(hostname.c_str(), user->GetIPString(), acceptedhosts)) - return false; - - if (!LookupOper(opername, inputpass)) - return false; - - user->Oper(it->second); - return true; - } - -public: - ModuleLDAPAuth() - : conn(NULL) - { - } - - void init() - { - Implementation eventlist[] = { I_OnRehash, I_OnPreCommand }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - OnRehash(NULL); - } - - virtual ~ModuleLDAPAuth() - { - if (conn) - ldap_unbind_ext(conn, NULL, NULL); - } - - virtual void OnRehash(User* user) - { - ConfigTag* tag = ServerInstance->Config->ConfValue("ldapoper"); - - base = tag->getString("baserdn"); - ldapserver = tag->getString("server"); - std::string scope = tag->getString("searchscope"); - username = tag->getString("binddn"); - password = tag->getString("bindauth"); - attribute = tag->getString("attribute"); - - if (scope == "base") - searchscope = LDAP_SCOPE_BASE; - else if (scope == "onelevel") - searchscope = LDAP_SCOPE_ONELEVEL; - else searchscope = LDAP_SCOPE_SUBTREE; - - Connect(); - } - - bool Connect() - { - if (conn != NULL) - ldap_unbind_ext(conn, NULL, NULL); - int res, v = LDAP_VERSION3; - res = ldap_initialize(&conn, ldapserver.c_str()); - if (res != LDAP_SUCCESS) - { - conn = NULL; - return false; - } - - res = ldap_set_option(conn, LDAP_OPT_PROTOCOL_VERSION, (void *)&v); - if (res != LDAP_SUCCESS) - { - ldap_unbind_ext(conn, NULL, NULL); - conn = NULL; - return false; - } - return true; - } - - ModResult OnPreCommand(std::string& command, std::vector<std::string>& parameters, LocalUser* user, bool validated, const std::string& original_line) - { - if (validated && command == "OPER" && parameters.size() >= 2) - { - if (HandleOper(user, parameters[0], parameters[1])) - return MOD_RES_DENY; - } - return MOD_RES_PASSTHRU; - } - - bool LookupOper(const std::string& opername, const std::string& opassword) - { - if (conn == NULL) - if (!Connect()) - return false; - - int res; - char* authpass = strdup(password.c_str()); - // bind anonymously if no bind DN and authentication are given in the config - struct berval cred; - cred.bv_val = authpass; - cred.bv_len = password.length(); - - if ((res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS) - { - if (res == LDAP_SERVER_DOWN) - { - // Attempt to reconnect if the connection dropped - ServerInstance->SNO->WriteToSnoMask('a', "LDAP server has gone away - reconnecting..."); - Connect(); - res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL); - } - - if (res != LDAP_SUCCESS) - { - free(authpass); - ldap_unbind_ext(conn, NULL, NULL); - conn = NULL; - return false; - } - } - free(authpass); - - LDAPMessage *msg, *entry; - std::string what = attribute + "=" + opername; - if ((res = ldap_search_ext_s(conn, base.c_str(), searchscope, what.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg)) != LDAP_SUCCESS) - { - return false; - } - if (ldap_count_entries(conn, msg) > 1) - { - ldap_msgfree(msg); - return false; - } - if ((entry = ldap_first_entry(conn, msg)) == NULL) - { - ldap_msgfree(msg); - return false; - } - authpass = strdup(opassword.c_str()); - cred.bv_val = authpass; - cred.bv_len = opassword.length(); - RAIILDAPString DN(ldap_get_dn(conn, entry)); - if ((res = ldap_sasl_bind_s(conn, DN, LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) == LDAP_SUCCESS) - { - free(authpass); - ldap_msgfree(msg); - return true; - } - else - { - free(authpass); - ldap_msgfree(msg); - return false; - } - } - - virtual Version GetVersion() - { - return Version("Adds the ability to authenticate opers via LDAP", VF_VENDOR); - } - -}; - -MODULE_INIT(ModuleLDAPAuth) diff --git a/src/modules/extra/m_mssql.cpp b/src/modules/extra/m_mssql.cpp index 598f9aac9..0e8c8cf55 100644 --- a/src/modules/extra/m_mssql.cpp +++ b/src/modules/extra/m_mssql.cpp @@ -24,22 +24,17 @@ #include "inspircd.h" #include <tds.h> #include <tdsconvert.h> -#include "users.h" -#include "channels.h" -#include "modules.h" #include "m_sqlv2.h" -/* $ModDesc: MsSQL provider */ /* $CompileFlags: exec("grep VERSION_NO /usr/include/tdsver.h 2>/dev/null | perl -e 'print "-D_TDSVER=".((<> =~ /freetds v(\d+\.\d+)/i) ? $1*100 : 0);'") */ /* $LinkerFlags: -ltds */ -/* $ModDep: m_sqlv2.h */ class SQLConn; class MsSQLResult; class ModuleMsSQL; -typedef std::map<std::string, SQLConn*> ConnMap; +typedef insp::flat_map<std::string, SQLConn*> ConnMap; typedef std::deque<MsSQLResult*> ResultQueue; unsigned long count(const char * const str, char a) @@ -64,8 +59,8 @@ class QueryThread : public SocketThread public: QueryThread(ModuleMsSQL* mod) : Parent(mod) { } ~QueryThread() { } - virtual void Run(); - virtual void OnNotify(); + void Run(); + void OnNotify(); }; class MsSQLResult : public SQLresult @@ -88,10 +83,6 @@ class MsSQLResult : public SQLresult { } - ~MsSQLResult() - { - } - void AddRow(int colsnum, char **dat, char **colname) { colnames.clear(); @@ -111,17 +102,17 @@ class MsSQLResult : public SQLresult rows++; } - virtual int Rows() + int Rows() { return rows; } - virtual int Cols() + int Cols() { return cols; } - virtual std::string ColName(int column) + std::string ColName(int column) { if (column < (int)colnames.size()) { @@ -134,7 +125,7 @@ class MsSQLResult : public SQLresult return ""; } - virtual int ColNum(const std::string &column) + int ColNum(const std::string &column) { for (unsigned int i = 0; i < colnames.size(); i++) { @@ -145,7 +136,7 @@ class MsSQLResult : public SQLresult return 0; } - virtual SQLfield GetValue(int row, int column) + SQLfield GetValue(int row, int column) { if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols())) { @@ -158,7 +149,7 @@ class MsSQLResult : public SQLresult return SQLfield("",true); } - virtual SQLfieldList& GetRow() + SQLfieldList& GetRow() { if (currentrow < rows) return fieldlists[currentrow]; @@ -166,7 +157,7 @@ class MsSQLResult : public SQLresult return emptyfieldlist; } - virtual SQLfieldMap& GetRowMap() + SQLfieldMap& GetRowMap() { /* In an effort to reduce overhead we don't actually allocate the map * until the first time it's needed...so... @@ -192,7 +183,7 @@ class MsSQLResult : public SQLresult return *fieldmap; } - virtual SQLfieldList* GetRowPtr() + SQLfieldList* GetRowPtr() { fieldlist = new SQLfieldList(); @@ -207,7 +198,7 @@ class MsSQLResult : public SQLresult return fieldlist; } - virtual SQLfieldMap* GetRowMapPtr() + SQLfieldMap* GetRowMapPtr() { fieldmap = new SQLfieldMap(); @@ -223,12 +214,12 @@ class MsSQLResult : public SQLresult return fieldmap; } - virtual void Free(SQLfieldMap* fm) + void Free(SQLfieldMap* fm) { delete fm; } - virtual void Free(SQLfieldList* fl) + void Free(SQLfieldList* fl) { delete fl; } @@ -258,7 +249,7 @@ class SQLConn : public classbase if (tds_process_simple_query(sock) != TDS_SUCCEED) { LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: Could not select database " + host.name + " for DB with id: " + host.id); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not select database " + host.name + " for DB with id: " + host.id); LoggingMutex->Unlock(); CloseDB(); } @@ -266,7 +257,7 @@ class SQLConn : public classbase else { LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: Could not select database " + host.name + " for DB with id: " + host.id); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not select database " + host.name + " for DB with id: " + host.id); LoggingMutex->Unlock(); CloseDB(); } @@ -274,7 +265,7 @@ class SQLConn : public classbase else { LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: Could not connect to DB with id: " + host.id); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not connect to DB with id: " + host.id); LoggingMutex->Unlock(); CloseDB(); } @@ -433,7 +424,7 @@ class SQLConn : public classbase char* msquery = strdup(req->query.q.data()); LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql",DEBUG,"doing Query: %s",msquery); + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "doing Query: %s",msquery); LoggingMutex->Unlock(); if (tds_submit_query(sock, msquery) != TDS_SUCCEED) { @@ -449,8 +440,8 @@ class SQLConn : public classbase int tds_res; while (tds_process_tokens(sock, &tds_res, NULL, TDS_TOKEN_RESULTS) == TDS_SUCCEED) { - //ServerInstance->Logs->Log("m_mssql",DEBUG,"<******> result type: %d", tds_res); - //ServerInstance->Logs->Log("m_mssql",DEBUG,"AFFECTED ROWS: %d", sock->rows_affected); + //ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "<******> result type: %d", tds_res); + //ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "AFFECTED ROWS: %d", sock->rows_affected); switch (tds_res) { case TDS_ROWFMT_RESULT: @@ -476,8 +467,8 @@ class SQLConn : public classbase if (sock->res_info->row_count > 0) { int cols = sock->res_info->num_cols; - char** name = new char*[MAXBUF]; - char** data = new char*[MAXBUF]; + char** name = new char*[512]; + char** data = new char*[512]; for (int j=0; j<cols; j++) { TDSCOLUMN* col = sock->current_results->columns[j]; @@ -516,7 +507,7 @@ class SQLConn : public classbase { SQLConn* sc = (SQLConn*)pContext->parent; LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql", DEBUG, "Message for DB with id: %s -> %s", sc->host.id.c_str(), pMessage->message); + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Message for DB with id: %s -> %s", sc->host.id.c_str(), pMessage->message); LoggingMutex->Unlock(); return 0; } @@ -525,7 +516,7 @@ class SQLConn : public classbase { SQLConn* sc = (SQLConn*)pContext->parent; LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql", DEFAULT, "Error for DB with id: %s -> %s", sc->host.id.c_str(), pMessage->message); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Error for DB with id: %s -> %s", sc->host.id.c_str(), pMessage->message); LoggingMutex->Unlock(); return 0; } @@ -657,18 +648,14 @@ class ModuleMsSQL : public Module queryDispatcher = new QueryThread(this); } - void init() + void init() CXX11_OVERRIDE { ReadConf(); - ServerInstance->Threads->Start(queryDispatcher); - - Implementation eventlist[] = { I_OnRehash }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - ServerInstance->Modules->AddService(sqlserv); + ServerInstance->Threads.Start(queryDispatcher); } - virtual ~ModuleMsSQL() + ~ModuleMsSQL() { queryDispatcher->join(); delete queryDispatcher; @@ -753,7 +740,7 @@ class ModuleMsSQL : public Module if (HasHost(hi)) { LoggingMutex->Lock(); - ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: A MsSQL connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str()); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: A MsSQL connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str()); LoggingMutex->Unlock(); return; } @@ -787,14 +774,14 @@ class ModuleMsSQL : public Module connections.clear(); } - virtual void OnRehash(User* user) + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { queryDispatcher->LockQueue(); ReadConf(); queryDispatcher->UnlockQueueWakeup(); } - void OnRequest(Request& request) + void OnRequest(Request& request) CXX11_OVERRIDE { if(strcmp(SQLREQID, request.id) == 0) { @@ -825,7 +812,7 @@ class ModuleMsSQL : public Module return ++currid; } - virtual Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("MsSQL provider", VF_VENDOR); } diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp index 01b1553b0..d8dda27a4 100644 --- a/src/modules/extra/m_mysql.cpp +++ b/src/modules/extra/m_mysql.cpp @@ -25,7 +25,7 @@ #include "inspircd.h" #include <mysql.h> -#include "sql.h" +#include "modules/sql.h" #ifdef _WIN32 # pragma comment(lib, "libmysql.lib") @@ -33,7 +33,6 @@ /* VERSION 3 API: With nonblocking (threaded) requests */ -/* $ModDesc: SQL Service Provider module for all other m_sql* modules */ /* $CompileFlags: exec("mysql_config --include") */ /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */ @@ -90,7 +89,7 @@ struct RQueueItem RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {} }; -typedef std::map<std::string, SQLConnection*> ConnMap; +typedef insp::flat_map<std::string, SQLConnection*> ConnMap; typedef std::deque<QQueueItem> QueryQueue; typedef std::deque<RQueueItem> ResultQueue; @@ -105,11 +104,11 @@ class ModuleSQL : public Module ConnMap connections; // main thread only ModuleSQL(); - void init(); + void init() CXX11_OVERRIDE; ~ModuleSQL(); - void OnRehash(User* user); - void OnUnloadModule(Module* mod); - Version GetVersion(); + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE; + void OnUnloadModule(Module* mod) CXX11_OVERRIDE; + Version GetVersion() CXX11_OVERRIDE; }; class DispatcherThread : public SocketThread @@ -119,8 +118,8 @@ class DispatcherThread : public SocketThread public: DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { } ~DispatcherThread() { } - virtual void Run(); - virtual void OnNotify(); + void Run(); + void OnNotify(); }; #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224 @@ -186,21 +185,17 @@ class MySQLresult : public SQLResult } - ~MySQLresult() - { - } - - virtual int Rows() + int Rows() { return rows; } - virtual void GetCols(std::vector<std::string>& result) + void GetCols(std::vector<std::string>& result) { result.assign(colnames.begin(), colnames.end()); } - virtual SQLEntry GetValue(int row, int column) + SQLEntry GetValue(int row, int column) { if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size())) { @@ -209,7 +204,7 @@ class MySQLresult : public SQLResult return SQLEntry(); } - virtual bool GetRow(SQLEntries& result) + bool GetRow(SQLEntries& result) { if (currentrow < rows) { @@ -260,6 +255,12 @@ class SQLConnection : public SQLProvider bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0); if (!rv) return rv; + + // Enable character set settings + std::string charset = config->getString("charset"); + if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str()))) + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str()); + std::string initquery; if (config->readString("initialquery", initquery)) { @@ -383,12 +384,7 @@ ModuleSQL::ModuleSQL() void ModuleSQL::init() { Dispatcher = new DispatcherThread(this); - ServerInstance->Threads->Start(Dispatcher); - - Implementation eventlist[] = { I_OnRehash, I_OnUnloadModule }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - - OnRehash(NULL); + ServerInstance->Threads.Start(Dispatcher); } ModuleSQL::~ModuleSQL() @@ -405,7 +401,7 @@ ModuleSQL::~ModuleSQL() } } -void ModuleSQL::OnRehash(User* user) +void ModuleSQL::ReadConfig(ConfigStatus& status) { ConnMap conns; ConfigTagList tags = ServerInstance->Config->ConfTags("database"); diff --git a/src/modules/extra/m_pgsql.cpp b/src/modules/extra/m_pgsql.cpp index ac247548a..ff8c1174c 100644 --- a/src/modules/extra/m_pgsql.cpp +++ b/src/modules/extra/m_pgsql.cpp @@ -24,11 +24,9 @@ #include "inspircd.h" #include <cstdlib> -#include <sstream> #include <libpq-fe.h> -#include "sql.h" +#include "modules/sql.h" -/* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */ /* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */ /* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */ @@ -43,7 +41,7 @@ class SQLConn; class ModulePgSQL; -typedef std::map<std::string, SQLConn*> ConnMap; +typedef insp::flat_map<std::string, SQLConn*> ConnMap; /* CREAD, Connecting and wants read event * CWRITE, Connecting and wants write event @@ -59,10 +57,10 @@ class ReconnectTimer : public Timer private: ModulePgSQL* mod; public: - ReconnectTimer(ModulePgSQL* m) : Timer(5, ServerInstance->Time(), false), mod(m) + ReconnectTimer(ModulePgSQL* m) : Timer(5, false), mod(m) { } - virtual void Tick(time_t TIME); + bool Tick(time_t TIME); }; struct QueueItem @@ -97,12 +95,12 @@ class PgSQLresult : public SQLResult PQclear(res); } - virtual int Rows() + int Rows() { return rows; } - virtual void GetCols(std::vector<std::string>& result) + void GetCols(std::vector<std::string>& result) { result.resize(PQnfields(res)); for(unsigned int i=0; i < result.size(); i++) @@ -111,7 +109,7 @@ class PgSQLresult : public SQLResult } } - virtual SQLEntry GetValue(int row, int column) + SQLEntry GetValue(int row, int column) { char* v = PQgetvalue(res, row, column); if (!v || PQgetisnull(res, row, column)) @@ -120,7 +118,7 @@ class PgSQLresult : public SQLResult return SQLEntry(std::string(v, PQgetlength(res, row, column))); } - virtual bool GetRow(SQLEntries& result) + bool GetRow(SQLEntries& result) { if (currentrow >= PQntuples(res)) return false; @@ -152,7 +150,7 @@ class SQLConn : public SQLProvider, public EventHandler { if (!DoConnect()) { - ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: Could not connect to database " + tag->getString("id")); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not connect to database " + tag->getString("id")); DelayReconnect(); } } @@ -180,18 +178,19 @@ class SQLConn : public SQLProvider, public EventHandler } } - virtual void HandleEvent(EventType et, int errornum) + void OnEventHandlerRead() CXX11_OVERRIDE { - switch (et) - { - case EVENT_READ: - case EVENT_WRITE: - DoEvent(); - break; + DoEvent(); + } - case EVENT_ERROR: - DelayReconnect(); - } + void OnEventHandlerWrite() CXX11_OVERRIDE + { + DoEvent(); + } + + void OnEventHandlerError(int errornum) CXX11_OVERRIDE + { + DelayReconnect(); } std::string GetDSN() @@ -242,9 +241,9 @@ class SQLConn : public SQLProvider, public EventHandler if(this->fd <= -1) return false; - if (!ServerInstance->SE->AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ)) + if (!SocketEngine::AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ)) { - ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Couldn't add pgsql socket to socket engine"); + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Couldn't add pgsql socket to socket engine"); return false; } @@ -257,17 +256,17 @@ class SQLConn : public SQLProvider, public EventHandler switch(PQconnectPoll(sql)) { case PGRES_POLLING_WRITING: - ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ); + SocketEngine::ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ); status = CWRITE; return true; case PGRES_POLLING_READING: - ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); status = CREAD; return true; case PGRES_POLLING_FAILED: return false; case PGRES_POLLING_OK: - ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); status = WWRITE; DoConnectedPoll(); default: @@ -350,17 +349,17 @@ restart: switch(PQresetPoll(sql)) { case PGRES_POLLING_WRITING: - ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ); + SocketEngine::ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ); status = CWRITE; return DoPoll(); case PGRES_POLLING_READING: - ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); status = CREAD; return true; case PGRES_POLLING_FAILED: return false; case PGRES_POLLING_OK: - ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); status = WWRITE; DoConnectedPoll(); default: @@ -417,7 +416,7 @@ restart: int error; size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error); if (error) - ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed"); + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed"); #else size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length()); #endif @@ -452,7 +451,7 @@ restart: int error; size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error); if (error) - ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed"); + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed"); #else size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length()); #endif @@ -488,7 +487,7 @@ restart: void Close() { - ServerInstance->SE->DelFd(this); + SocketEngine::DelFd(this); if(sql) { @@ -505,25 +504,17 @@ class ModulePgSQL : public Module ReconnectTimer* retimer; ModulePgSQL() + : retimer(NULL) { } - void init() - { - ReadConf(); - - Implementation eventlist[] = { I_OnUnloadModule, I_OnRehash }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - } - - virtual ~ModulePgSQL() + ~ModulePgSQL() { - if (retimer) - ServerInstance->Timers->DelTimer(retimer); + delete retimer; ClearAllConnections(); } - virtual void OnRehash(User* user) + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { ReadConf(); } @@ -564,7 +555,7 @@ class ModulePgSQL : public Module connections.clear(); } - void OnUnloadModule(Module* mod) + void OnUnloadModule(Module* mod) CXX11_OVERRIDE { SQLerror err(SQL_BAD_DBID); for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++) @@ -592,16 +583,18 @@ class ModulePgSQL : public Module } } - Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API", VF_VENDOR); } }; -void ReconnectTimer::Tick(time_t time) +bool ReconnectTimer::Tick(time_t time) { mod->retimer = NULL; mod->ReadConf(); + delete this; + return false; } void SQLConn::DelayReconnect() @@ -615,7 +608,7 @@ void SQLConn::DelayReconnect() if (!mod->retimer) { mod->retimer = new ReconnectTimer(mod); - ServerInstance->Timers->AddTimer(mod->retimer); + ServerInstance->Timers.AddTimer(mod->retimer); } } } diff --git a/src/modules/extra/m_regex_pcre.cpp b/src/modules/extra/m_regex_pcre.cpp index cba234c8c..9ae6719ba 100644 --- a/src/modules/extra/m_regex_pcre.cpp +++ b/src/modules/extra/m_regex_pcre.cpp @@ -20,10 +20,8 @@ #include "inspircd.h" #include <pcre.h> -#include "m_regex.h" +#include "modules/regex.h" -/* $ModDesc: Regex Provider Module for PCRE */ -/* $ModDep: m_regex.h */ /* $CompileFlags: exec("pcre-config --cflags") */ /* $LinkerFlags: exec("pcre-config --libs") rpath("pcre-config --libs") -lpcre */ @@ -31,21 +29,11 @@ # pragma comment(lib, "libpcre.lib") #endif -class PCREException : public ModuleException -{ -public: - PCREException(const std::string& rx, const std::string& error, int erroffset) - : ModuleException("Error in regex " + rx + " at offset " + ConvToStr(erroffset) + ": " + error) - { - } -}; - class PCRERegex : public Regex { -private: pcre* regex; -public: + public: PCRERegex(const std::string& rx) : Regex(rx) { const char* error; @@ -53,24 +41,19 @@ public: regex = pcre_compile(rx.c_str(), 0, &error, &erroffset, NULL); if (!regex) { - ServerInstance->Logs->Log("REGEX", DEBUG, "pcre_compile failed: /%s/ [%d] %s", rx.c_str(), erroffset, error); - throw PCREException(rx, error, erroffset); + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "pcre_compile failed: /%s/ [%d] %s", rx.c_str(), erroffset, error); + throw RegexException(rx, error, erroffset); } } - virtual ~PCRERegex() + ~PCRERegex() { pcre_free(regex); } - virtual bool Matches(const std::string& text) + bool Matches(const std::string& text) CXX11_OVERRIDE { - if (pcre_exec(regex, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) > -1) - { - // Bang. :D - return true; - } - return false; + return (pcre_exec(regex, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) >= 0); } }; @@ -78,7 +61,7 @@ class PCREFactory : public RegexFactory { public: PCREFactory(Module* m) : RegexFactory(m, "regex/pcre") {} - Regex* Create(const std::string& expr) + Regex* Create(const std::string& expr) CXX11_OVERRIDE { return new PCRERegex(expr); } @@ -86,13 +69,13 @@ class PCREFactory : public RegexFactory class ModuleRegexPCRE : public Module { -public: + public: PCREFactory ref; - ModuleRegexPCRE() : ref(this) { - ServerInstance->Modules->AddService(ref); + ModuleRegexPCRE() : ref(this) + { } - Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("Regex Provider Module for PCRE", VF_VENDOR); } diff --git a/src/modules/extra/m_regex_posix.cpp b/src/modules/extra/m_regex_posix.cpp index b3afd60c8..b5fddfab8 100644 --- a/src/modules/extra/m_regex_posix.cpp +++ b/src/modules/extra/m_regex_posix.cpp @@ -19,28 +19,15 @@ #include "inspircd.h" -#include "m_regex.h" +#include "modules/regex.h" #include <sys/types.h> #include <regex.h> -/* $ModDesc: Regex Provider Module for POSIX Regular Expressions */ -/* $ModDep: m_regex.h */ - -class POSIXRegexException : public ModuleException -{ -public: - POSIXRegexException(const std::string& rx, const std::string& error) - : ModuleException("Error in regex " + rx + ": " + error) - { - } -}; - class POSIXRegex : public Regex { -private: regex_t regbuf; -public: + public: POSIXRegex(const std::string& rx, bool extended) : Regex(rx) { int flags = (extended ? REG_EXTENDED : 0) | REG_NOSUB; @@ -58,23 +45,18 @@ public: error = errbuf; delete[] errbuf; regfree(®buf); - throw POSIXRegexException(rx, error); + throw RegexException(rx, error); } } - virtual ~POSIXRegex() + ~POSIXRegex() { regfree(®buf); } - virtual bool Matches(const std::string& text) + bool Matches(const std::string& text) CXX11_OVERRIDE { - if (regexec(®buf, text.c_str(), 0, NULL, 0) == 0) - { - // Bang. :D - return true; - } - return false; + return (regexec(®buf, text.c_str(), 0, NULL, 0) == 0); } }; @@ -83,7 +65,7 @@ class PosixFactory : public RegexFactory public: bool extended; PosixFactory(Module* m) : RegexFactory(m, "regex/posix") {} - Regex* Create(const std::string& expr) + Regex* Create(const std::string& expr) CXX11_OVERRIDE { return new POSIXRegex(expr, extended); } @@ -92,20 +74,18 @@ class PosixFactory : public RegexFactory class ModuleRegexPOSIX : public Module { PosixFactory ref; -public: - ModuleRegexPOSIX() : ref(this) { - ServerInstance->Modules->AddService(ref); - Implementation eventlist[] = { I_OnRehash }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - OnRehash(NULL); + + public: + ModuleRegexPOSIX() : ref(this) + { } - Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("Regex Provider Module for POSIX Regular Expressions", VF_VENDOR); } - void OnRehash(User* u) + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { ref.extended = ServerInstance->Config->ConfValue("posix")->getBool("extended"); } diff --git a/src/modules/extra/m_regex_re2.cpp b/src/modules/extra/m_regex_re2.cpp new file mode 100644 index 000000000..c4657bf8b --- /dev/null +++ b/src/modules/extra/m_regex_re2.cpp @@ -0,0 +1,81 @@ +/* + * InspIRCd -- Internet Relay Chat Daemon + * + * Copyright (C) 2013 Peter Powell <petpow@saberuk.com> + * Copyright (C) 2012 ChrisTX <chris@rev-crew.info> + * + * This file is part of InspIRCd. InspIRCd is free software: you can + * redistribute it and/or modify it under the terms of the GNU General Public + * License as published by the Free Software Foundation, version 2. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + + +#include "inspircd.h" +#include "modules/regex.h" + +// Fix warnings about the use of `long long` on C++03 and +// shadowing on GCC. +#if defined __clang__ +# pragma clang diagnostic ignored "-Wc++11-long-long" +#elif defined __GNUC__ +# pragma GCC diagnostic ignored "-Wlong-long" +# pragma GCC diagnostic ignored "-Wshadow" +#endif + +#include <re2/re2.h> + +/* $LinkerFlags: -lre2 */ + +class RE2Regex : public Regex +{ + RE2 regexcl; + + public: + RE2Regex(const std::string& rx) : Regex(rx), regexcl(rx, RE2::Quiet) + { + if (!regexcl.ok()) + { + throw RegexException(rx, regexcl.error()); + } + } + + bool Matches(const std::string& text) CXX11_OVERRIDE + { + return RE2::FullMatch(text, regexcl); + } +}; + +class RE2Factory : public RegexFactory +{ + public: + RE2Factory(Module* m) : RegexFactory(m, "regex/re2") { } + Regex* Create(const std::string& expr) CXX11_OVERRIDE + { + return new RE2Regex(expr); + } +}; + +class ModuleRegexRE2 : public Module +{ + RE2Factory ref; + + public: + ModuleRegexRE2() : ref(this) + { + } + + Version GetVersion() CXX11_OVERRIDE + { + return Version("Regex Provider Module for RE2", VF_VENDOR); + } +}; + +MODULE_INIT(ModuleRegexRE2) diff --git a/src/modules/extra/m_regex_stdlib.cpp b/src/modules/extra/m_regex_stdlib.cpp index 204728b65..8e7bd0da2 100644 --- a/src/modules/extra/m_regex_stdlib.cpp +++ b/src/modules/extra/m_regex_stdlib.cpp @@ -15,32 +15,18 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ - + #include "inspircd.h" -#include "m_regex.h" +#include "modules/regex.h" #include <regex> -/* $ModDesc: Regex Provider Module for std::regex Regular Expressions */ -/* $ModConfig: <stdregex type="ecmascript"> - * Specify the Regular Expression engine to use here. Valid settings are - * bre, ere, awk, grep, egrep, ecmascript (default if not specified)*/ /* $CompileFlags: -std=c++11 */ -/* $ModDep: m_regex.h */ - -class StdRegexException : public ModuleException -{ -public: - StdRegexException(const std::string& rx, const std::string& error) - : ModuleException(std::string("Error in regex ") + rx + ": " + error) - { - } -}; class StdRegex : public Regex { -private: std::regex regexcl; -public: + + public: StdRegex(const std::string& rx, std::regex::flag_type fltype) : Regex(rx) { try{ @@ -48,11 +34,11 @@ public: } catch(std::regex_error rxerr) { - throw StdRegexException(rx, rxerr.what()); + throw RegexException(rx, rxerr.what()); } } - - virtual bool Matches(const std::string& text) + + bool Matches(const std::string& text) CXX11_OVERRIDE { return std::regex_search(text, regexcl); } @@ -63,7 +49,7 @@ class StdRegexFactory : public RegexFactory public: std::regex::flag_type regextype; StdRegexFactory(Module* m) : RegexFactory(m, "regex/stdregex") {} - Regex* Create(const std::string& expr) + Regex* Create(const std::string& expr) CXX11_OVERRIDE { return new StdRegex(expr, regextype); } @@ -73,23 +59,20 @@ class ModuleRegexStd : public Module { public: StdRegexFactory ref; - ModuleRegexStd() : ref(this) { - ServerInstance->Modules->AddService(ref); - Implementation eventlist[] = { I_OnRehash }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - OnRehash(NULL); + ModuleRegexStd() : ref(this) + { } - Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("Regex Provider Module for std::regex", VF_VENDOR); } - - void OnRehash(User* u) + + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { ConfigTag* Conf = ServerInstance->Config->ConfValue("stdregex"); std::string regextype = Conf->getString("type", "ecmascript"); - + if(regextype == "bre") ref.regextype = std::regex::basic; else if(regextype == "ere") diff --git a/src/modules/extra/m_regex_tre.cpp b/src/modules/extra/m_regex_tre.cpp index 4b9eab472..8a1d54248 100644 --- a/src/modules/extra/m_regex_tre.cpp +++ b/src/modules/extra/m_regex_tre.cpp @@ -19,27 +19,15 @@ #include "inspircd.h" -#include "m_regex.h" +#include "modules/regex.h" #include <sys/types.h> #include <tre/regex.h> -/* $ModDesc: Regex Provider Module for TRE Regular Expressions */ /* $CompileFlags: pkgconfincludes("tre","tre/regex.h","") */ /* $LinkerFlags: pkgconflibs("tre","/libtre.so","-ltre") rpath("pkg-config --libs tre") */ -/* $ModDep: m_regex.h */ - -class TRERegexException : public ModuleException -{ -public: - TRERegexException(const std::string& rx, const std::string& error) - : ModuleException("Error in regex " + rx + ": " + error) - { - } -}; class TRERegex : public Regex { -private: regex_t regbuf; public: @@ -60,30 +48,26 @@ public: error = errbuf; delete[] errbuf; regfree(®buf); - throw TRERegexException(rx, error); + throw RegexException(rx, error); } } - virtual ~TRERegex() + ~TRERegex() { regfree(®buf); } - virtual bool Matches(const std::string& text) + bool Matches(const std::string& text) CXX11_OVERRIDE { - if (regexec(®buf, text.c_str(), 0, NULL, 0) == 0) - { - // Bang. :D - return true; - } - return false; + return (regexec(®buf, text.c_str(), 0, NULL, 0) == 0); } }; -class TREFactory : public RegexFactory { +class TREFactory : public RegexFactory +{ public: TREFactory(Module* m) : RegexFactory(m, "regex/tre") {} - Regex* Create(const std::string& expr) + Regex* Create(const std::string& expr) CXX11_OVERRIDE { return new TRERegex(expr); } @@ -92,18 +76,15 @@ class TREFactory : public RegexFactory { class ModuleRegexTRE : public Module { TREFactory trf; -public: - ModuleRegexTRE() : trf(this) { - ServerInstance->Modules->AddService(trf); - } - Version GetVersion() + public: + ModuleRegexTRE() : trf(this) { - return Version("Regex Provider Module for TRE Regular Expressions", VF_VENDOR); } - ~ModuleRegexTRE() + Version GetVersion() CXX11_OVERRIDE { + return Version("Regex Provider Module for TRE Regular Expressions", VF_VENDOR); } }; diff --git a/src/modules/extra/m_sqlite3.cpp b/src/modules/extra/m_sqlite3.cpp index 1e3a65a18..05203da39 100644 --- a/src/modules/extra/m_sqlite3.cpp +++ b/src/modules/extra/m_sqlite3.cpp @@ -21,20 +21,26 @@ #include "inspircd.h" +#include "modules/sql.h" + +// Fix warnings about the use of `long long` on C++03. +#if defined __clang__ +# pragma clang diagnostic ignored "-Wc++11-long-long" +#elif defined __GNUC__ +# pragma GCC diagnostic ignored "-Wlong-long" +#endif + #include <sqlite3.h> -#include "sql.h" #ifdef _WIN32 # pragma comment(lib, "sqlite3.lib") #endif -/* $ModDesc: sqlite3 provider */ /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */ /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */ -/* $NoPedantic */ class SQLConn; -typedef std::map<std::string, SQLConn*> ConnMap; +typedef insp::flat_map<std::string, SQLConn*> ConnMap; class SQLite3Result : public SQLResult { @@ -48,16 +54,12 @@ class SQLite3Result : public SQLResult { } - ~SQLite3Result() - { - } - - virtual int Rows() + int Rows() { return rows; } - virtual bool GetRow(SQLEntries& result) + bool GetRow(SQLEntries& result) { if (currentrow < rows) { @@ -72,7 +74,7 @@ class SQLite3Result : public SQLResult } } - virtual void GetCols(std::vector<std::string>& result) + void GetCols(std::vector<std::string>& result) { result.assign(columns.begin(), columns.end()); } @@ -80,7 +82,6 @@ class SQLite3Result : public SQLResult class SQLConn : public SQLProvider { - private: sqlite3* conn; reference<ConfigTag> config; @@ -90,7 +91,7 @@ class SQLConn : public SQLProvider std::string host = tag->getString("hostname"); if (sqlite3_open_v2(host.c_str(), &conn, SQLITE_OPEN_READWRITE, 0) != SQLITE_OK) { - ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + tag->getString("id")); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not open DB with id: " + tag->getString("id")); conn = NULL; } } @@ -152,13 +153,13 @@ class SQLConn : public SQLProvider sqlite3_finalize(stmt); } - virtual void submit(SQLQuery* query, const std::string& q) + void submit(SQLQuery* query, const std::string& q) { Query(query, q); delete query; } - virtual void submit(SQLQuery* query, const std::string& q, const ParamL& p) + void submit(SQLQuery* query, const std::string& q, const ParamL& p) { std::string res; unsigned int param = 0; @@ -179,7 +180,7 @@ class SQLConn : public SQLProvider submit(query, res); } - virtual void submit(SQLQuery* query, const std::string& q, const ParamM& p) + void submit(SQLQuery* query, const std::string& q, const ParamM& p) { std::string res; for(std::string::size_type i = 0; i < q.length(); i++) @@ -209,23 +210,10 @@ class SQLConn : public SQLProvider class ModuleSQLite3 : public Module { - private: ConnMap conns; public: - ModuleSQLite3() - { - } - - void init() - { - ReadConf(); - - Implementation eventlist[] = { I_OnRehash }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - } - - virtual ~ModuleSQLite3() + ~ModuleSQLite3() { ClearConns(); } @@ -241,7 +229,7 @@ class ModuleSQLite3 : public Module conns.clear(); } - void ReadConf() + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { ClearConns(); ConfigTagList tags = ServerInstance->Config->ConfTags("database"); @@ -255,12 +243,7 @@ class ModuleSQLite3 : public Module } } - void OnRehash(User* user) - { - ReadConf(); - } - - Version GetVersion() + Version GetVersion() CXX11_OVERRIDE { return Version("sqlite3 provider", VF_VENDOR); } diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 2f4acf3f0..e5cb8ee90 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -22,117 +22,90 @@ #include "inspircd.h" +#include "modules/ssl.h" +#include <memory> + +// Fix warnings about the use of commas at end of enumerator lists on C++03. +#if defined __clang__ +# pragma clang diagnostic ignored "-Wc++11-extensions" +#elif defined __GNUC__ +# if __GNUC__ < 6 +# pragma GCC diagnostic ignored "-pedantic" +# else +# pragma GCC diagnostic ignored "-Wdeprecated-declarations" +# endif +#endif + #include <gnutls/gnutls.h> #include <gnutls/x509.h> -#include "ssl.h" -#include "m_cap.h" -#ifdef _WIN32 -# pragma comment(lib, "libgnutls-30.lib") +#ifndef GNUTLS_VERSION_NUMBER +#define GNUTLS_VERSION_NUMBER LIBGNUTLS_VERSION_NUMBER +#define GNUTLS_VERSION LIBGNUTLS_VERSION #endif -/* $ModDesc: Provides SSL support for clients */ -/* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") iflt("pkg-config --modversion gnutls","2.12") exec("libgcrypt-config --cflags") */ -/* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") iflt("pkg-config --modversion gnutls","2.12") exec("libgcrypt-config --libs") */ -/* $NoPedantic */ +// Check if the GnuTLS library is at least version major.minor.patch +#define INSPIRCD_GNUTLS_HAS_VERSION(major, minor, patch) (GNUTLS_VERSION_NUMBER >= ((major << 16) | (minor << 8) | patch)) -#ifndef GNUTLS_VERSION_MAJOR -#define GNUTLS_VERSION_MAJOR LIBGNUTLS_VERSION_MAJOR -#define GNUTLS_VERSION_MINOR LIBGNUTLS_VERSION_MINOR -#define GNUTLS_VERSION_PATCH LIBGNUTLS_VERSION_PATCH +#if INSPIRCD_GNUTLS_HAS_VERSION(2, 9, 8) +#define GNUTLS_HAS_MAC_GET_ID +#include <gnutls/crypto.h> #endif -// These don't exist in older GnuTLS versions -#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 1) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 1 && GNUTLS_VERSION_PATCH >= 7)) -#define GNUTLS_NEW_PRIO_API -#endif - -#if(GNUTLS_VERSION_MAJOR < 2) -typedef gnutls_certificate_credentials_t gnutls_certificate_credentials; -typedef gnutls_dh_params_t gnutls_dh_params; -#endif - -#if (GNUTLS_VERSION_MAJOR > 2 || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR >= 12)) +#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0) # define GNUTLS_HAS_RND -# include <gnutls/crypto.h> #else # include <gcrypt.h> #endif -enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED }; - -struct SSLConfig : public refcountbase -{ - gnutls_certificate_credentials_t x509_cred; - std::vector<gnutls_x509_crt_t> x509_certs; - gnutls_x509_privkey_t x509_key; - gnutls_dh_params_t dh_params; -#ifdef GNUTLS_NEW_PRIO_API - gnutls_priority_t priority; +#ifdef _WIN32 +# pragma comment(lib, "libgnutls-30.lib") #endif - SSLConfig() - : x509_cred(NULL) - , x509_key(NULL) - , dh_params(NULL) -#ifdef GNUTLS_NEW_PRIO_API - , priority(NULL) -#endif - { - } +/* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") eval("print `libgcrypt-config --cflags | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */ +/* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") eval("print `libgcrypt-config --libs | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */ - ~SSLConfig() - { - ServerInstance->Logs->Log("m_ssl_gnutls", DEBUG, "Destroying SSLConfig %p", (void*)this); - - if (x509_cred) - gnutls_certificate_free_credentials(x509_cred); - - for (unsigned int i = 0; i < x509_certs.size(); i++) - gnutls_x509_crt_deinit(x509_certs[i]); - - if (x509_key) - gnutls_x509_privkey_deinit(x509_key); - - if (dh_params) - gnutls_dh_params_deinit(dh_params); +// These don't exist in older GnuTLS versions +#if INSPIRCD_GNUTLS_HAS_VERSION(2, 1, 7) +#define GNUTLS_NEW_PRIO_API +#endif -#ifdef GNUTLS_NEW_PRIO_API - if (priority) - gnutls_priority_deinit(priority); +#if (!INSPIRCD_GNUTLS_HAS_VERSION(2, 0, 0)) +typedef gnutls_certificate_credentials_t gnutls_certificate_credentials; +typedef gnutls_dh_params_t gnutls_dh_params; #endif - } -}; -static reference<SSLConfig> currconf; +enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_HANDSHAKEN }; -static SSLConfig* GetSessionConfig(gnutls_session_t session); +#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0) +#define INSPIRCD_GNUTLS_HAS_VECTOR_PUSH +#define GNUTLS_NEW_CERT_CALLBACK_API +typedef gnutls_retr2_st cert_cb_last_param_type; +#else +typedef gnutls_retr_st cert_cb_last_param_type; +#endif -#if(GNUTLS_VERSION_MAJOR < 2 || ( GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR < 12 ) ) -static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_ca_rdn, int nreqs, - const gnutls_pk_algorithm_t * sign_algos, int sign_algos_length, gnutls_retr_st * st) { +#if INSPIRCD_GNUTLS_HAS_VERSION(3, 3, 5) +#define INSPIRCD_GNUTLS_HAS_RECV_PACKET +#endif - st->type = GNUTLS_CRT_X509; +#if INSPIRCD_GNUTLS_HAS_VERSION(2, 99, 0) +// The second parameter of gnutls_init() has changed in 2.99.0 from gnutls_connection_end_t to unsigned int +// (it became a general flags parameter) and the enum has been deprecated and generates a warning on use. +typedef unsigned int inspircd_gnutls_session_init_flags_t; #else -static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_ca_rdn, int nreqs, - const gnutls_pk_algorithm_t * sign_algos, int sign_algos_length, gnutls_retr2_st * st) { - st->cert_type = GNUTLS_CRT_X509; - st->key_type = GNUTLS_PRIVKEY_X509; +typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t; #endif - SSLConfig* conf = GetSessionConfig(session); - std::vector<gnutls_x509_crt_t>& x509_certs = conf->x509_certs; - st->ncerts = x509_certs.size(); - st->cert.x509 = &x509_certs[0]; - st->key.x509 = conf->x509_key; - st->deinit_all = 0; - return 0; -} +#if INSPIRCD_GNUTLS_HAS_VERSION(3, 1, 9) +#define INSPIRCD_GNUTLS_HAS_CORK +#endif + +static Module* thismod; class RandGen : public HandlerBase2<void, char*, size_t> { public: - RandGen() {} void Call(char* buffer, size_t len) { #ifdef GNUTLS_HAS_RND @@ -143,749 +116,670 @@ class RandGen : public HandlerBase2<void, char*, size_t> } }; -/** Represents an SSL user's extra data - */ -class issl_session -{ -public: - StreamSocket* socket; - gnutls_session_t sess; - issl_status status; - reference<ssl_cert> cert; - reference<SSLConfig> config; - - issl_session() : socket(NULL), sess(NULL), status(ISSL_NONE) {} -}; - -static SSLConfig* GetSessionConfig(gnutls_session_t sess) -{ - issl_session* session = reinterpret_cast<issl_session*>(gnutls_transport_get_ptr(sess)); - return session->config; -} - -class CommandStartTLS : public SplitCommand +namespace GnuTLS { - public: - bool enabled; - CommandStartTLS (Module* mod) : SplitCommand(mod, "STARTTLS") + class Init { - enabled = true; - works_before_reg = true; - } + public: + Init() { gnutls_global_init(); } + ~Init() { gnutls_global_deinit(); } + }; - CmdResult HandleLocal(const std::vector<std::string> ¶meters, LocalUser *user) + class Exception : public ModuleException { - if (!enabled) - { - user->WriteNumeric(691, "%s :STARTTLS is not enabled", user->nick.c_str()); - return CMD_FAILURE; - } + public: + Exception(const std::string& reason) + : ModuleException(reason) { } + }; - if (user->registered == REG_ALL) - { - user->WriteNumeric(691, "%s :STARTTLS is not permitted after client registration is complete", user->nick.c_str()); - } - else + void ThrowOnError(int errcode, const char* msg) + { + if (errcode < 0) { - if (!user->eh.GetIOHook()) - { - user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str()); - /* We need to flush the write buffer prior to adding the IOHook, - * otherwise we'll be sending this line inside the SSL session - which - * won't start its handshake until the client gets this line. Currently, - * we assume the write will not block here; this is usually safe, as - * STARTTLS is sent very early on in the registration phase, where the - * user hasn't built up much sendq. Handling a blocked write here would - * be very annoying. - */ - user->eh.DoWrite(); - user->eh.AddIOHook(creator); - creator->OnStreamSocketAccept(&user->eh, NULL, NULL); - } - else - user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str()); + std::string reason = msg; + reason.append(" :").append(gnutls_strerror(errcode)); + throw Exception(reason); } - - return CMD_FAILURE; } -}; - -class ModuleSSLGnuTLS : public Module -{ - issl_session* sessions; - - gnutls_digest_algorithm_t hash; - std::string sslports; - int dh_bits; + /** Used to create a gnutls_datum_t* from a std::string + */ + class Datum + { + gnutls_datum_t datum; - RandGen randhandler; - CommandStartTLS starttls; + public: + Datum(const std::string& dat) + { + datum.data = (unsigned char*)dat.data(); + datum.size = static_cast<unsigned int>(dat.length()); + } - GenericCap capHandler; - ServiceProvider iohook; + const gnutls_datum_t* get() const { return &datum; } + }; - inline static const char* UnknownIfNULL(const char* str) + class Hash { - return str ? str : "UNKNOWN"; - } + gnutls_digest_algorithm_t hash; - static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size) - { - issl_session* session = reinterpret_cast<issl_session*>(session_wrap); - if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK) + public: + // Nothing to deallocate, constructor may throw freely + Hash(const std::string& hashname) { -#ifdef _WIN32 - gnutls_transport_set_errno(session->sess, EAGAIN); + // As older versions of gnutls can't do this, let's disable it where needed. +#ifdef GNUTLS_HAS_MAC_GET_ID + // As gnutls_digest_algorithm_t and gnutls_mac_algorithm_t are mapped 1:1, we can do this + // There is no gnutls_dig_get_id() at the moment, but it may come later + hash = (gnutls_digest_algorithm_t)gnutls_mac_get_id(hashname.c_str()); + if (hash == GNUTLS_DIG_UNKNOWN) + throw Exception("Unknown hash type " + hashname); + + // Check if the user is giving us something that is a valid MAC but not digest + gnutls_hash_hd_t is_digest; + if (gnutls_hash_init(&is_digest, hash) < 0) + throw Exception("Unknown hash type " + hashname); + gnutls_hash_deinit(is_digest, NULL); #else - errno = EAGAIN; + if (hashname == "md5") + hash = GNUTLS_DIG_MD5; + else if (hashname == "sha1") + hash = GNUTLS_DIG_SHA1; +#ifdef INSPIRCD_GNUTLS_ENABLE_SHA256_FINGERPRINT + else if (hashname == "sha256") + hash = GNUTLS_DIG_SHA256; +#endif + else + throw Exception("Unknown hash type " + hashname); #endif - return -1; } - int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0); + gnutls_digest_algorithm_t get() const { return hash; } + }; -#ifdef _WIN32 - if (rv < 0) + class DHParams + { + gnutls_dh_params_t dh_params; + + DHParams() { - /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError() - * and then set errno appropriately. - * The gnutls library may also have a different errno variable than us, see - * gnutls_transport_set_errno(3). - */ - gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); + ThrowOnError(gnutls_dh_params_init(&dh_params), "gnutls_dh_params_init() failed"); } -#endif - - if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK); - return rv; - } - static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size) - { - issl_session* session = reinterpret_cast<issl_session*>(session_wrap); - if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK) + public: + /** Import */ + static std::auto_ptr<DHParams> Import(const std::string& dhstr) { -#ifdef _WIN32 - gnutls_transport_set_errno(session->sess, EAGAIN); -#else - errno = EAGAIN; -#endif - return -1; + std::auto_ptr<DHParams> dh(new DHParams); + int ret = gnutls_dh_params_import_pkcs3(dh->dh_params, Datum(dhstr).get(), GNUTLS_X509_FMT_PEM); + ThrowOnError(ret, "Unable to import DH params"); + return dh; } - int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0); - -#ifdef _WIN32 - if (rv < 0) + ~DHParams() { - /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError() - * and then set errno appropriately. - * The gnutls library may also have a different errno variable than us, see - * gnutls_transport_set_errno(3). - */ - gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); + gnutls_dh_params_deinit(dh_params); } -#endif - - if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK); - return rv; - } - public: + const gnutls_dh_params_t& get() const { return dh_params; } + }; - ModuleSSLGnuTLS() - : starttls(this), capHandler(this, "tls"), iohook(this, "ssl/gnutls", SERVICE_IOHOOK) + class X509Key { -#ifndef GNUTLS_HAS_RND - gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0); -#endif - - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; - - gnutls_global_init(); // This must be called once in the program - } + /** Ensure that the key is deinited in case the constructor of X509Key throws + */ + class RAIIKey + { + public: + gnutls_x509_privkey_t key; - void init() - { - currconf = new SSLConfig; - InitSSLConfig(currconf); + RAIIKey() + { + ThrowOnError(gnutls_x509_privkey_init(&key), "gnutls_x509_privkey_init() failed"); + } - ServerInstance->GenRandom = &randhandler; + ~RAIIKey() + { + gnutls_x509_privkey_deinit(key); + } + } key; - Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnUserConnect, - I_OnEvent, I_OnHookIO, I_OnCheckReady }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); + public: + /** Import */ + X509Key(const std::string& keystr) + { + int ret = gnutls_x509_privkey_import(key.key, Datum(keystr).get(), GNUTLS_X509_FMT_PEM); + ThrowOnError(ret, "Unable to import private key"); + } - ServerInstance->Modules->AddService(iohook); - ServerInstance->Modules->AddService(starttls); - } + gnutls_x509_privkey_t& get() { return key.key; } + }; - void OnRehash(User* user) + class X509CertList { - sslports.clear(); + std::vector<gnutls_x509_crt_t> certs; - ConfigTag* Conf = ServerInstance->Config->ConfValue("gnutls"); - starttls.enabled = Conf->getBool("starttls", true); - - if (Conf->getBool("showports", true)) + public: + /** Import */ + X509CertList(const std::string& certstr) { - sslports = Conf->getString("advertisedports"); - if (!sslports.empty()) - return; + unsigned int certcount = 3; + certs.resize(certcount); + Datum datum(certstr); - for (size_t i = 0; i < ServerInstance->ports.size(); i++) + int ret = gnutls_x509_crt_list_import(raw(), &certcount, datum.get(), GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED); + if (ret == GNUTLS_E_SHORT_MEMORY_BUFFER) { - ListenSocket* port = ServerInstance->ports[i]; - if (port->bind_tag->getString("ssl") != "gnutls") - continue; - - const std::string& portid = port->bind_desc; - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %s", portid.c_str()); - - if (port->bind_tag->getString("type", "clients") == "clients" && port->bind_addr != "127.0.0.1") - { - /* - * Found an SSL port for clients that is not bound to 127.0.0.1 and handled by us, display - * the IP:port in ISUPPORT. - * - * We used to advertise all ports seperated by a ';' char that matched the above criteria, - * but this resulted in too long ISUPPORT lines if there were lots of ports to be displayed. - * To solve this by default we now only display the first IP:port found and let the user - * configure the exact value for the 005 token, if necessary. - */ - sslports = portid; - break; - } + // the buffer wasn't big enough to hold all certs but gnutls changed certcount to the number of available certs, + // try again with a bigger buffer + certs.resize(certcount); + ret = gnutls_x509_crt_list_import(raw(), &certcount, datum.get(), GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED); } - } - } - void OnModuleRehash(User* user, const std::string ¶m) - { - if(param != "ssl") - return; + ThrowOnError(ret, "Unable to load certificates"); - reference<SSLConfig> newconf = new SSLConfig; - try - { - InitSSLConfig(newconf); + // Resize the vector to the actual number of certs because we rely on its size being correct + // when deallocating the certs + certs.resize(certcount); } - catch (ModuleException& ex) + + ~X509CertList() { - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls: Not applying new config. %s", ex.GetReason()); - return; + for (std::vector<gnutls_x509_crt_t>::iterator i = certs.begin(); i != certs.end(); ++i) + gnutls_x509_crt_deinit(*i); } - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls: Applying new config, old config is in use by %d connection(s)", currconf->GetReferenceCount()-1); - currconf = newconf; - } + gnutls_x509_crt_t* raw() { return &certs[0]; } + unsigned int size() const { return certs.size(); } + }; - void InitSSLConfig(SSLConfig* config) + class X509CRL : public refcountbase { - ServerInstance->Logs->Log("m_ssl_gnutls", DEBUG, "Initializing new SSLConfig %p", (void*)config); - - std::string keyfile; - std::string certfile; - std::string cafile; - std::string crlfile; - OnRehash(NULL); - - ConfigTag* Conf = ServerInstance->Config->ConfValue("gnutls"); - - cafile = Conf->getString("cafile", CONFIG_PATH "/ca.pem"); - crlfile = Conf->getString("crlfile", CONFIG_PATH "/crl.pem"); - certfile = Conf->getString("certfile", CONFIG_PATH "/cert.pem"); - keyfile = Conf->getString("keyfile", CONFIG_PATH "/key.pem"); - dh_bits = Conf->getInt("dhbits"); - std::string hashname = Conf->getString("hash", "md5"); - - // The GnuTLS manual states that the gnutls_set_default_priority() - // call we used previously when initializing the session is the same - // as setting the "NORMAL" priority string. - // Thus if the setting below is not in the config we will behave exactly - // the same as before, when the priority setting wasn't available. - std::string priorities = Conf->getString("priority", "NORMAL"); - - if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096)) - dh_bits = 1024; - - if (hashname == "md5") - hash = GNUTLS_DIG_MD5; - else if (hashname == "sha1") - hash = GNUTLS_DIG_SHA1; -#ifdef INSPIRCD_GNUTLS_ENABLE_SHA256_FINGERPRINT - else if (hashname == "sha256") - hash = GNUTLS_DIG_SHA256; -#endif - else - throw ModuleException("Unknown hash type " + hashname); - + class RAIICRL + { + public: + gnutls_x509_crl_t crl; - int ret; + RAIICRL() + { + ThrowOnError(gnutls_x509_crl_init(&crl), "gnutls_x509_crl_init() failed"); + } - gnutls_certificate_credentials_t& x509_cred = config->x509_cred; + ~RAIICRL() + { + gnutls_x509_crl_deinit(crl); + } + } crl; - ret = gnutls_certificate_allocate_credentials(&x509_cred); - if (ret < 0) + public: + /** Import */ + X509CRL(const std::string& crlstr) { - // Set to NULL because we can't be sure what value is in it and we must not try to - // deallocate it in case of an error - x509_cred = NULL; - throw ModuleException("Failed to allocate certificate credentials: " + std::string(gnutls_strerror(ret))); + int ret = gnutls_x509_crl_import(get(), Datum(crlstr).get(), GNUTLS_X509_FMT_PEM); + ThrowOnError(ret, "Unable to load certificate revocation list"); } - if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) - ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret)); - - if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) - ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret)); - - FileReader reader; + gnutls_x509_crl_t& get() { return crl.crl; } + }; - reader.LoadFile(certfile); - std::string cert_string = reader.Contents(); - gnutls_datum_t cert_datum = { (unsigned char*)cert_string.data(), static_cast<unsigned int>(cert_string.length()) }; +#ifdef GNUTLS_NEW_PRIO_API + class Priority + { + gnutls_priority_t priority; - reader.LoadFile(keyfile); - std::string key_string = reader.Contents(); - gnutls_datum_t key_datum = { (unsigned char*)key_string.data(), static_cast<unsigned int>(key_string.length()) }; + public: + Priority(const std::string& priorities) + { + // Try to set the priorities for ciphers, kex methods etc. to the user supplied string + // If the user did not supply anything then the string is already set to "NORMAL" + const char* priocstr = priorities.c_str(); + const char* prioerror; - std::vector<gnutls_x509_crt_t>& x509_certs = config->x509_certs; + int ret = gnutls_priority_init(&priority, priocstr, &prioerror); + if (ret < 0) + { + // gnutls did not understand the user supplied string + throw Exception("Unable to initialize priorities to \"" + priorities + "\": " + gnutls_strerror(ret) + " Syntax error at position " + ConvToStr((unsigned int) (prioerror - priocstr))); + } + } - // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException - unsigned int certcount = 3; - x509_certs.resize(certcount); - ret = gnutls_x509_crt_list_import(&x509_certs[0], &certcount, &cert_datum, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED); - if (ret == GNUTLS_E_SHORT_MEMORY_BUFFER) + ~Priority() { - // the buffer wasn't big enough to hold all certs but gnutls updated certcount to the number of available certs, try again with a bigger buffer - x509_certs.resize(certcount); - ret = gnutls_x509_crt_list_import(&x509_certs[0], &certcount, &cert_datum, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED); + gnutls_priority_deinit(priority); } - if (ret <= 0) + void SetupSession(gnutls_session_t sess) { - // clear the vector so we won't call gnutls_x509_crt_deinit() on the (uninited) certs later - x509_certs.clear(); - throw ModuleException("Unable to load GnuTLS server certificate (" + certfile + "): " + ((ret < 0) ? (std::string(gnutls_strerror(ret))) : "No certs could be read")); + gnutls_priority_set(sess, priority); } - x509_certs.resize(ret); - gnutls_x509_privkey_t& x509_key = config->x509_key; - if (gnutls_x509_privkey_init(&x509_key) < 0) + static const char* GetDefault() { - // Make sure the destructor does not try to deallocate this, see above - x509_key = NULL; - throw ModuleException("Unable to initialize private key"); + return "NORMAL:%SERVER_PRECEDENCE:-VERS-SSL3.0"; } - if((ret = gnutls_x509_privkey_import(x509_key, &key_datum, GNUTLS_X509_FMT_PEM)) < 0) - throw ModuleException("Unable to load GnuTLS server private key (" + keyfile + "): " + std::string(gnutls_strerror(ret))); + static std::string RemoveUnknownTokens(const std::string& prio) + { + std::string ret; + irc::sepstream ss(prio, ':'); + for (std::string token; ss.GetToken(token); ) + { + // Save current position so we can revert later if needed + const std::string::size_type prevpos = ret.length(); + // Append next token + if (!ret.empty()) + ret.push_back(':'); + ret.append(token); + + gnutls_priority_t test; + if (gnutls_priority_init(&test, ret.c_str(), NULL) < 0) + { + // The new token broke the priority string, revert to the previously working one + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Priority string token not recognized: \"%s\"", token.c_str()); + ret.erase(prevpos); + } + else + { + // Worked + gnutls_priority_deinit(test); + } + } + return ret; + } + }; +#else + /** Dummy class, used when gnutls_priority_set() is not available + */ + class Priority + { + public: + Priority(const std::string& priorities) + { + if (priorities != GetDefault()) + throw Exception("You've set a non-default priority string, but GnuTLS lacks support for it"); + } - if((ret = gnutls_certificate_set_x509_key(x509_cred, &x509_certs[0], certcount, x509_key)) < 0) - throw ModuleException("Unable to set GnuTLS cert/key pair: " + std::string(gnutls_strerror(ret))); + static void SetupSession(gnutls_session_t sess) + { + // Always set the default priorities + gnutls_set_default_priority(sess); + } - #ifdef GNUTLS_NEW_PRIO_API - // Try to set the priorities for ciphers, kex methods etc. to the user supplied string - // If the user did not supply anything then the string is already set to "NORMAL" - const char* priocstr = priorities.c_str(); - const char* prioerror; + static const char* GetDefault() + { + return "NORMAL"; + } - gnutls_priority_t& priority = config->priority; - if ((ret = gnutls_priority_init(&priority, priocstr, &prioerror)) < 0) + static std::string RemoveUnknownTokens(const std::string& prio) { - // gnutls did not understand the user supplied string, log and fall back to the default priorities - ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set priorities to \"%s\": %s Syntax error at position %u, falling back to default (NORMAL)", priorities.c_str(), gnutls_strerror(ret), (unsigned int) (prioerror - priocstr)); - gnutls_priority_init(&priority, "NORMAL", NULL); + // We don't do anything here because only NORMAL is accepted + return prio; } + }; +#endif - #else - if (priorities != "NORMAL") - ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: You've set <gnutls:priority> to a value other than the default, but this is only supported with GnuTLS v2.1.7 or newer. Your GnuTLS version is older than that so the option will have no effect."); - #endif + class CertCredentials + { + /** DH parameters associated with these credentials + */ + std::auto_ptr<DHParams> dh; - #if(GNUTLS_VERSION_MAJOR < 2 || ( GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR < 12 ) ) - gnutls_certificate_client_set_retrieve_function (x509_cred, cert_callback); - #else - gnutls_certificate_set_retrieve_function (x509_cred, cert_callback); - #endif + protected: + gnutls_certificate_credentials_t cred; - gnutls_dh_params_t& dh_params = config->dh_params; - ret = gnutls_dh_params_init(&dh_params); - if (ret < 0) + public: + CertCredentials() { - // Make sure the destructor does not try to deallocate this, see above - dh_params = NULL; - ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret)); - return; + ThrowOnError(gnutls_certificate_allocate_credentials(&cred), "Cannot allocate certificate credentials"); } - std::string dhfile = Conf->getString("dhfile"); - if (!dhfile.empty()) + ~CertCredentials() { - // Try to load DH params from file - reader.LoadFile(dhfile); - std::string dhstring = reader.Contents(); - gnutls_datum_t dh_datum = { (unsigned char*)dhstring.data(), static_cast<unsigned int>(dhstring.length()) }; - - if ((ret = gnutls_dh_params_import_pkcs3(dh_params, &dh_datum, GNUTLS_X509_FMT_PEM)) < 0) - { - // File unreadable or GnuTLS was unhappy with the contents, generate the DH primes now - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Generating DH parameters because I failed to load them from file '%s': %s", dhfile.c_str(), gnutls_strerror(ret)); - GenerateDHParams(dh_params); - } + gnutls_certificate_free_credentials(cred); } - else + + /** Associates these credentials with the session + */ + void SetupSession(gnutls_session_t sess) { - GenerateDHParams(dh_params); + gnutls_credentials_set(sess, GNUTLS_CRD_CERTIFICATE, cred); } - gnutls_certificate_set_dh_params(x509_cred, dh_params); - } + /** Set the given DH parameters to be used with these credentials + */ + void SetDH(std::auto_ptr<DHParams>& DH) + { + dh = DH; + gnutls_certificate_set_dh_params(cred, dh->get()); + } + }; - void GenerateDHParams(gnutls_dh_params_t dh_params) + class X509Credentials : public CertCredentials { - // Generate Diffie Hellman parameters - for use with DHE - // kx algorithms. These should be discarded and regenerated - // once a day, once a week or once a month. Depending on the - // security requirements. + /** Private key + */ + X509Key key; - int ret; + /** Certificate list, presented to the peer + */ + X509CertList certs; - if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0) - ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret)); - } + /** Trusted CA, may be NULL + */ + std::auto_ptr<X509CertList> trustedca; - ~ModuleSSLGnuTLS() - { - currconf = NULL; + /** Certificate revocation list, may be NULL + */ + std::auto_ptr<X509CRL> crl; - gnutls_global_deinit(); - delete[] sessions; - ServerInstance->GenRandom = &ServerInstance->HandleGenRandom; - } + static int cert_callback(gnutls_session_t session, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st); - void OnCleanup(int target_type, void* item) - { - if(target_type == TYPE_USER) + public: + X509Credentials(const std::string& certstr, const std::string& keystr) + : key(keystr) + , certs(certstr) { - LocalUser* user = IS_LOCAL(static_cast<User*>(item)); + // Throwing is ok here, the destructor of Credentials is called in that case + int ret = gnutls_certificate_set_x509_key(cred, certs.raw(), certs.size(), key.get()); + ThrowOnError(ret, "Unable to set cert/key pair"); - if (user && user->eh.GetIOHook() == this) - { - // User is using SSL, they're a local user, and they're using one of *our* SSL ports. - // Potentially there could be multiple SSL modules loaded at once on different ports. - ServerInstance->Users->QuitUser(user, "SSL module unloading"); - } +#ifdef GNUTLS_NEW_CERT_CALLBACK_API + gnutls_certificate_set_retrieve_function(cred, cert_callback); +#else + gnutls_certificate_client_set_retrieve_function(cred, cert_callback); +#endif } - } - Version GetVersion() - { - return Version("Provides SSL support for clients", VF_VENDOR); - } + /** Sets the trusted CA and the certificate revocation list + * to use when verifying certificates + */ + void SetCA(std::auto_ptr<X509CertList>& certlist, std::auto_ptr<X509CRL>& CRL) + { + // Do nothing if certlist is NULL + if (certlist.get()) + { + int ret = gnutls_certificate_set_x509_trust(cred, certlist->raw(), certlist->size()); + ThrowOnError(ret, "gnutls_certificate_set_x509_trust() failed"); + if (CRL.get()) + { + ret = gnutls_certificate_set_x509_crl(cred, &CRL->get(), 1); + ThrowOnError(ret, "gnutls_certificate_set_x509_crl() failed"); + } - void On005Numeric(std::string &output) - { - if (!sslports.empty()) - output.append(" SSL=" + sslports); - if (starttls.enabled) - output.append(" STARTTLS"); - } + trustedca = certlist; + crl = CRL; + } + } + }; - void OnHookIO(StreamSocket* user, ListenSocket* lsb) + class DataReader { - if (!user->GetIOHook() && lsb->bind_tag->getString("ssl") == "gnutls") + int retval; +#ifdef INSPIRCD_GNUTLS_HAS_RECV_PACKET + gnutls_packet_t packet; + + public: + DataReader(gnutls_session_t sess) { - /* Hook the user with our module */ - user->AddIOHook(this); + // Using the packet API avoids the final copy of the data which GnuTLS does if we supply + // our own buffer. Instead, we get the buffer containing the data from GnuTLS and copy it + // to the recvq directly from there in appendto(). + retval = gnutls_record_recv_packet(sess, &packet); } - } - void OnRequest(Request& request) - { - if (strcmp("GET_SSL_CERT", request.id) == 0) + void appendto(std::string& recvq) { - SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request); - int fd = req.sock->GetFd(); - issl_session* session = &sessions[fd]; + // Copy data from GnuTLS buffers to recvq + gnutls_datum_t datum; + gnutls_packet_get(packet, &datum, NULL); + recvq.append(reinterpret_cast<const char*>(datum.data), datum.size); - req.cert = session->cert; + gnutls_packet_deinit(packet); } - else if (!strcmp("GET_RAW_SSL_SESSION", request.id)) +#else + char* const buffer; + + public: + DataReader(gnutls_session_t sess) + : buffer(ServerInstance->GetReadBuffer()) { - SSLRawSessionRequest& req = static_cast<SSLRawSessionRequest&>(request); - if ((req.fd >= 0) && (req.fd < ServerInstance->SE->GetMaxFds())) - req.data = reinterpret_cast<void*>(sessions[req.fd].sess); + // Read data from GnuTLS buffers into ReadBuffer + retval = gnutls_record_recv(sess, buffer, ServerInstance->Config->NetBufferSize); } - } - - void InitSession(StreamSocket* user, bool me_server) - { - issl_session* session = &sessions[user->GetFd()]; - gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->socket = user; - session->config = currconf; - - #ifdef GNUTLS_NEW_PRIO_API - gnutls_priority_set(session->sess, currconf->priority); - #else - gnutls_set_default_priority(session->sess); - #endif - gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, currconf->x509_cred); - gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session)); - gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); - gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); - - if (me_server) - gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. + void appendto(std::string& recvq) + { + // Copy data from ReadBuffer to recvq + recvq.append(buffer, retval); + } +#endif - Handshake(session, user); - } + int ret() const { return retval; } + }; - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) + class Profile : public refcountbase { - issl_session* session = &sessions[user->GetFd()]; + /** Name of this profile + */ + const std::string name; - /* For STARTTLS: Don't try and init a session on a socket that already has a session */ - if (session->sess) - return; + /** X509 certificate(s) and key + */ + X509Credentials x509cred; - InitSession(user, true); - } + /** The minimum length in bits for the DH prime to be accepted as a client + */ + unsigned int min_dh_bits; - void OnStreamSocketConnect(StreamSocket* user) - { - InitSession(user, false); - } + /** Hashing algorithm to use when generating certificate fingerprints + */ + Hash hash; - void OnStreamSocketClose(StreamSocket* user) - { - CloseSession(&sessions[user->GetFd()]); - } + /** Priorities for ciphers, compression methods, etc. + */ + Priority priority; - int OnStreamSocketRead(StreamSocket* user, std::string& recvq) - { - issl_session* session = &sessions[user->GetFd()]; + /** Rough max size of records to send + */ + const unsigned int outrecsize; - if (!session->sess) + /** True to request a client certificate as a server + */ + const bool requestclientcert; + + Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr, + std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr, + const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL, + unsigned int recsize, bool Requestclientcert) + : name(profilename) + , x509cred(certstr, keystr) + , min_dh_bits(mindh) + , hash(hashstr) + , priority(priostr) + , outrecsize(recsize) + , requestclientcert(Requestclientcert) { - CloseSession(session); - user->SetError("No SSL session"); - return -1; + x509cred.SetDH(DH); + x509cred.SetCA(CA, CRL); } - if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) + static std::string ReadFile(const std::string& filename) { - // The handshake isn't finished, try to finish it. + FileReader reader(filename); + std::string ret = reader.GetString(); + if (ret.empty()) + throw Exception("Cannot read file " + filename); + return ret; + } - if(!Handshake(session, user)) + static std::string GetPrioStr(const std::string& profilename, ConfigTag* tag) + { + // Use default priority string if this tag does not specify one + std::string priostr = GnuTLS::Priority::GetDefault(); + bool found = tag->readString("priority", priostr); + // If the prio string isn't set in the config don't be strict about the default one because it doesn't work on all versions of GnuTLS + if (!tag->getBool("strictpriority", found)) { - if (session->status != ISSL_CLOSING) - return 0; - return -1; + std::string stripped = GnuTLS::Priority::RemoveUnknownTokens(priostr); + if (stripped.empty()) + { + // Stripping failed, act as if a prio string wasn't set + stripped = GnuTLS::Priority::RemoveUnknownTokens(GnuTLS::Priority::GetDefault()); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Priority string for profile \"%s\" contains unknown tokens and stripping it didn't yield a working one either, falling back to \"%s\"", profilename.c_str(), stripped.c_str()); + } + else if ((found) && (stripped != priostr)) + { + // Prio string was set in the config and we ended up with something that works but different + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Priority string for profile \"%s\" contains unknown tokens, stripped to \"%s\"", profilename.c_str(), stripped.c_str()); + } + priostr.swap(stripped); } + return priostr; } - // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. - - if (session->status == ISSL_HANDSHAKEN) + public: + static reference<Profile> Create(const std::string& profilename, ConfigTag* tag) { - char* buffer = ServerInstance->GetReadBuffer(); - size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = gnutls_record_recv(session->sess, buffer, bufsiz); - if (ret > 0) - { - recvq.append(buffer, ret); - // Schedule a read if there is still data in the GnuTLS buffer - if (gnutls_record_check_pending(session->sess) > 0) - ServerInstance->SE->ChangeEventMask(user, FD_ADD_TRIAL_READ); - return 1; - } - else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) - { - return 0; - } - else if (ret == 0) - { - user->SetError("Connection closed"); - CloseSession(session); - return -1; - } - else + std::string certstr = ReadFile(tag->getString("certfile", "cert.pem")); + std::string keystr = ReadFile(tag->getString("keyfile", "key.pem")); + + std::auto_ptr<DHParams> dh = DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem"))); + + std::string priostr = GetPrioStr(profilename, tag); + unsigned int mindh = tag->getInt("mindhbits", 1024); + std::string hashstr = tag->getString("hash", "md5"); + + // Load trusted CA and revocation list, if set + std::auto_ptr<X509CertList> ca; + std::auto_ptr<X509CRL> crl; + std::string filename = tag->getString("cafile"); + if (!filename.empty()) { - user->SetError(gnutls_strerror(ret)); - CloseSession(session); - return -1; + ca.reset(new X509CertList(ReadFile(filename))); + + filename = tag->getString("crlfile"); + if (!filename.empty()) + crl.reset(new X509CRL(ReadFile(filename))); } - } - else if (session->status == ISSL_CLOSING) - return -1; - return 0; - } +#ifdef INSPIRCD_GNUTLS_HAS_CORK + // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked + unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512); +#else + unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384); +#endif - int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) - { - issl_session* session = &sessions[user->GetFd()]; + const bool requestclientcert = tag->getBool("requestclientcert", true); - if (!session->sess) - { - CloseSession(session); - user->SetError("No SSL session"); - return -1; + return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert); } - if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ) + /** Set up the given session with the settings in this profile + */ + void SetupSession(gnutls_session_t sess) { - // The handshake isn't finished, try to finish it. - Handshake(session, user); - if (session->status != ISSL_CLOSING) - return 0; - return -1; + priority.SetupSession(sess); + x509cred.SetupSession(sess); + gnutls_dh_set_prime_bits(sess, min_dh_bits); + + // Request client certificate if enabled and we are a server, no-op if we're a client + if (requestclientcert) + gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST); } - int ret = 0; + const std::string& GetName() const { return name; } + X509Credentials& GetX509Credentials() { return x509cred; } + gnutls_digest_algorithm_t GetHash() const { return hash.get(); } + unsigned int GetOutgoingRecordSize() const { return outrecsize; } + }; +} - if (session->status == ISSL_HANDSHAKEN) - { - ret = gnutls_record_send(session->sess, sendq.data(), sendq.length()); +class GnuTLSIOHook : public SSLIOHook +{ + private: + gnutls_session_t sess; + issl_status status; + reference<GnuTLS::Profile> profile; +#ifdef INSPIRCD_GNUTLS_HAS_CORK + size_t gbuffersize; +#endif - if (ret == (int)sendq.length()) - { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_WRITE); - return 1; - } - else if (ret > 0) - { - sendq = sendq.substr(ret); - ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE); - return 0; - } - else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED || ret == 0) - { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE); - return 0; - } - else // (ret < 0) - { - user->SetError(gnutls_strerror(ret)); - CloseSession(session); - return -1; - } + void CloseSession() + { + if (this->sess) + { + gnutls_bye(this->sess, GNUTLS_SHUT_WR); + gnutls_deinit(this->sess); } - - return 0; + sess = NULL; + certificate = NULL; + status = ISSL_NONE; } - bool Handshake(issl_session* session, StreamSocket* user) + // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed + int Handshake(StreamSocket* user) { - int ret = gnutls_handshake(session->sess); + int ret = gnutls_handshake(this->sess); if (ret < 0) { if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) { // Handshake needs resuming later, read() or write() would have blocked. + this->status = ISSL_HANDSHAKING; - if(gnutls_record_get_direction(session->sess) == 0) + if (gnutls_record_get_direction(this->sess) == 0) { // gnutls_handshake() wants to read() again. - session->status = ISSL_HANDSHAKING_READ; - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); } else { // gnutls_handshake() wants to write() again. - session->status = ISSL_HANDSHAKING_WRITE; - ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); + SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); } + + return 0; } else { user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret))); - CloseSession(session); - session->status = ISSL_CLOSING; + CloseSession(); + return -1; } - - return false; } else { // Change the seesion state - session->status = ISSL_HANDSHAKEN; + this->status = ISSL_HANDSHAKEN; - VerifyCertificate(session,user); + VerifyCertificate(); // Finish writing, if any left - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); - - return true; - } - } - - void OnUserConnect(LocalUser* user) - { - if (user->eh.GetIOHook() == this) - { - if (sessions[user->eh.GetFd()].sess) - { - const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess; - std::string cipher = UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess))); - cipher.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-"); - cipher.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))); - - ssl_cert* cert = sessions[user->eh.GetFd()].cert; - if (cert->fingerprint.empty()) - user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str()); - else - user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"" - " and your SSL fingerprint is %s", user->nick.c_str(), cipher.c_str(), cert->fingerprint.c_str()); - } - } - } + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); - void CloseSession(issl_session* session) - { - if (session->sess) - { - gnutls_bye(session->sess, GNUTLS_SHUT_WR); - gnutls_deinit(session->sess); + return 1; } - session->socket = NULL; - session->sess = NULL; - session->cert = NULL; - session->status = ISSL_NONE; - session->config = NULL; } - void VerifyCertificate(issl_session* session, StreamSocket* user) + void VerifyCertificate() { - if (!session->sess || !user) - return; - - unsigned int status; + unsigned int certstatus; const gnutls_datum_t* cert_list; int ret; unsigned int cert_list_size; gnutls_x509_crt_t cert; - char name[MAXBUF]; - unsigned char digest[MAXBUF]; + char str[512]; + unsigned char digest[512]; size_t digest_size = sizeof(digest); - size_t name_size = sizeof(name); + size_t name_size = sizeof(str); ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; + this->certificate = certinfo; /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. */ - ret = gnutls_certificate_verify_peers2(session->sess, &status); + ret = gnutls_certificate_verify_peers2(this->sess, &certstatus); if (ret < 0) { @@ -893,16 +787,16 @@ class ModuleSSLGnuTLS : public Module return; } - certinfo->invalid = (status & GNUTLS_CERT_INVALID); - certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND); - certinfo->revoked = (status & GNUTLS_CERT_REVOKED); - certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA); + certinfo->invalid = (certstatus & GNUTLS_CERT_INVALID); + certinfo->unknownsigner = (certstatus & GNUTLS_CERT_SIGNER_NOT_FOUND); + certinfo->revoked = (certstatus & GNUTLS_CERT_REVOKED); + certinfo->trusted = !(certstatus & GNUTLS_CERT_SIGNER_NOT_CA); /* Up to here the process is the same for X.509 certificates and * OpenPGP keys. From now on X.509 certificates are assumed. This can * be easily extended to work with openpgp keys as well. */ - if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) + if (gnutls_certificate_type_get(this->sess) != GNUTLS_CRT_X509) { certinfo->error = "No X509 keys sent"; return; @@ -916,7 +810,7 @@ class ModuleSSLGnuTLS : public Module } cert_list_size = 0; - cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size); + cert_list = gnutls_certificate_get_peers(this->sess, &cert_list_size); if (cert_list == NULL) { certinfo->error = "No certificate was found"; @@ -934,31 +828,31 @@ class ModuleSSLGnuTLS : public Module goto info_done_dealloc; } - if (gnutls_x509_crt_get_dn(cert, name, &name_size) == 0) + if (gnutls_x509_crt_get_dn(cert, str, &name_size) == 0) { std::string& dn = certinfo->dn; - dn = name; + dn = str; // Make sure there are no chars in the string that we consider invalid if (dn.find_first_of("\r\n") != std::string::npos) dn.clear(); } - name_size = sizeof(name); - if (gnutls_x509_crt_get_issuer_dn(cert, name, &name_size) == 0) + name_size = sizeof(str); + if (gnutls_x509_crt_get_issuer_dn(cert, str, &name_size) == 0) { std::string& issuer = certinfo->issuer; - issuer = name; + issuer = str; if (issuer.find_first_of("\r\n") != std::string::npos) issuer.clear(); } - if ((ret = gnutls_x509_crt_get_fingerprint(cert, hash, digest, &digest_size)) < 0) + if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0) { certinfo->error = gnutls_strerror(ret); } else { - certinfo->fingerprint = irc::hex(digest, digest_size); + certinfo->fingerprint = BinToHex(digest, digest_size); } /* Beware here we do not check for errors. @@ -972,15 +866,490 @@ info_done_dealloc: gnutls_x509_crt_deinit(cert); } - void OnEvent(Event& ev) + // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error + int PrepareIO(StreamSocket* sock) + { + if (status == ISSL_HANDSHAKEN) + return 1; + else if (status == ISSL_HANDSHAKING) + { + // The handshake isn't finished, try to finish it + return Handshake(sock); + } + + CloseSession(); + sock->SetError("No SSL session"); + return -1; + } + +#ifdef INSPIRCD_GNUTLS_HAS_CORK + int FlushBuffer(StreamSocket* sock) + { + // If GnuTLS has some data buffered, write it + if (gbuffersize) + return HandleWriteRet(sock, gnutls_record_uncork(this->sess, 0)); + return 1; + } +#endif + + int HandleWriteRet(StreamSocket* sock, int ret) + { + if (ret > 0) + { +#ifdef INSPIRCD_GNUTLS_HAS_CORK + gbuffersize -= ret; + if (gbuffersize) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE); + return 0; + } +#endif + return ret; + } + else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED || ret == 0) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE); + return 0; + } + else // (ret < 0) + { + sock->SetError(gnutls_strerror(ret)); + CloseSession(); + return -1; + } + } + + static const char* UnknownIfNULL(const char* str) + { + return str ? str : "UNKNOWN"; + } + + static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size) + { + StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod)); +#endif + + if (sock->GetEventMask() & FD_READ_WILL_BLOCK) + { +#ifdef _WIN32 + gnutls_transport_set_errno(session->sess, EAGAIN); +#else + errno = EAGAIN; +#endif + return -1; + } + + int rv = SocketEngine::Recv(sock, reinterpret_cast<char *>(buffer), size, 0); + +#ifdef _WIN32 + if (rv < 0) + { + /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError() + * and then set errno appropriately. + * The gnutls library may also have a different errno variable than us, see + * gnutls_transport_set_errno(3). + */ + gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); + } +#endif + + if (rv < (int)size) + SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK); + return rv; + } + +#ifdef INSPIRCD_GNUTLS_HAS_VECTOR_PUSH + static ssize_t VectorPush(gnutls_transport_ptr_t transportptr, const giovec_t* iov, int iovcnt) + { + StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod)); +#endif + + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) + { +#ifdef _WIN32 + gnutls_transport_set_errno(session->sess, EAGAIN); +#else + errno = EAGAIN; +#endif + return -1; + } + + // Cast the giovec_t to iovec not to IOVector so the correct function is called on Windows + int ret = SocketEngine::WriteV(sock, reinterpret_cast<const iovec*>(iov), iovcnt); +#ifdef _WIN32 + // See the function above for more info about the usage of gnutls_transport_set_errno() on Windows + if (ret < 0) + gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); +#endif + + int size = 0; + for (int i = 0; i < iovcnt; i++) + size += iov[i].iov_len; + + if (ret < size) + SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); + return ret; + } + +#else // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH + static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size) + { + StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod)); +#endif + + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) + { +#ifdef _WIN32 + gnutls_transport_set_errno(session->sess, EAGAIN); +#else + errno = EAGAIN; +#endif + return -1; + } + + int rv = SocketEngine::Send(sock, reinterpret_cast<const char *>(buffer), size, 0); + +#ifdef _WIN32 + if (rv < 0) + { + /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError() + * and then set errno appropriately. + * The gnutls library may also have a different errno variable than us, see + * gnutls_transport_set_errno(3). + */ + gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); + } +#endif + + if (rv < (int)size) + SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); + return rv; + } +#endif // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH + + public: + GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags, const reference<GnuTLS::Profile>& sslprofile) + : SSLIOHook(hookprov) + , sess(NULL) + , status(ISSL_NONE) + , profile(sslprofile) +#ifdef INSPIRCD_GNUTLS_HAS_CORK + , gbuffersize(0) +#endif + { + gnutls_init(&sess, flags); + gnutls_transport_set_ptr(sess, reinterpret_cast<gnutls_transport_ptr_t>(sock)); +#ifdef INSPIRCD_GNUTLS_HAS_VECTOR_PUSH + gnutls_transport_set_vec_push_function(sess, VectorPush); +#else + gnutls_transport_set_push_function(sess, gnutls_push_wrapper); +#endif + gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper); + profile->SetupSession(sess); + + sock->AddIOHook(this); + Handshake(sock); + } + + void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE + { + CloseSession(); + } + + int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE + { + // Finish handshake if needed + int prepret = PrepareIO(user); + if (prepret <= 0) + return prepret; + + // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN. + { + GnuTLS::DataReader reader(sess); + int ret = reader.ret(); + if (ret > 0) + { + reader.appendto(recvq); + // Schedule a read if there is still data in the GnuTLS buffer + if (gnutls_record_check_pending(sess) > 0) + SocketEngine::ChangeEventMask(user, FD_ADD_TRIAL_READ); + return 1; + } + else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) + { + return 0; + } + else if (ret == 0) + { + user->SetError("Connection closed"); + CloseSession(); + return -1; + } + else + { + user->SetError(gnutls_strerror(ret)); + CloseSession(); + return -1; + } + } + } + + int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE + { + // Finish handshake if needed + int prepret = PrepareIO(user); + if (prepret <= 0) + return prepret; + + // Session is ready for transferring application data + +#ifdef INSPIRCD_GNUTLS_HAS_CORK + while (true) + { + // If there is something in the GnuTLS buffer try to send() it + int ret = FlushBuffer(user); + if (ret <= 0) + return ret; // Couldn't flush entire buffer, retry later (or close on error) + + // GnuTLS buffer is empty, if the sendq is empty as well then break to set FD_WANT_NO_WRITE + if (sendq.empty()) + break; + + // GnuTLS buffer is empty but sendq is not, begin sending data from the sendq + gnutls_record_cork(this->sess); + while ((!sendq.empty()) && (gbuffersize < profile->GetOutgoingRecordSize())) + { + const StreamSocket::SendQueue::Element& elem = sendq.front(); + gbuffersize += elem.length(); + ret = gnutls_record_send(this->sess, elem.data(), elem.length()); + if (ret < 0) + { + CloseSession(); + return -1; + } + sendq.pop_front(); + } + } +#else + int ret = 0; + + while (!sendq.empty()) + { + FlattenSendQueue(sendq, profile->GetOutgoingRecordSize()); + const StreamSocket::SendQueue::Element& buffer = sendq.front(); + ret = HandleWriteRet(user, gnutls_record_send(this->sess, buffer.data(), buffer.length())); + + if (ret <= 0) + return ret; + else if (ret < (int)buffer.length()) + { + sendq.erase_front(ret); + SocketEngine::ChangeEventMask(user, FD_WANT_SINGLE_WRITE); + return 0; + } + + // Wrote entire record, continue sending + sendq.pop_front(); + } +#endif + + SocketEngine::ChangeEventMask(user, FD_WANT_NO_WRITE); + return 1; + } + + void GetCiphersuite(std::string& out) const CXX11_OVERRIDE + { + if (!IsHandshakeDone()) + return; + out.append(UnknownIfNULL(gnutls_protocol_get_name(gnutls_protocol_get_version(sess)))).push_back('-'); + out.append(UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess)))).push_back('-'); + out.append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).push_back('-'); + out.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))); + } + + GnuTLS::Profile* GetProfile() { return profile; } + bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); } +}; + +int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st) +{ +#ifndef GNUTLS_NEW_CERT_CALLBACK_API + st->type = GNUTLS_CRT_X509; +#else + st->cert_type = GNUTLS_CRT_X509; + st->key_type = GNUTLS_PRIVKEY_X509; +#endif + StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess)); + GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials(); + + st->ncerts = cred.certs.size(); + st->cert.x509 = cred.certs.raw(); + st->key.x509 = cred.key.get(); + st->deinit_all = 0; + + return 0; +} + +class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider +{ + reference<GnuTLS::Profile> profile; + + public: + GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~GnuTLSIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, GNUTLS_SERVER, profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, GNUTLS_CLIENT, profile); + } +}; + +class ModuleSSLGnuTLS : public Module +{ + typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList; + + // First member of the class, gets constructed first and destructed last + GnuTLS::Init libinit; + RandGen randhandler; + ProfileList profiles; + + void ReadProfiles() + { + // First, store all profiles in a new, temporary container. If no problems occur, swap the two + // containers; this way if something goes wrong we can go back and continue using the current profiles, + // avoiding unpleasant situations where no new SSL connections are possible. + ProfileList newprofiles; + + ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile"); + if (tags.first == tags.second) + { + // No <sslprofile> tags found, create a profile named "gnutls" from settings in the <gnutls> block + const std::string defname = "gnutls"; + ConfigTag* tag = ServerInstance->Config->ConfValue(defname); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <gnutls> tag"); + + try + { + reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag)); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); + } + catch (CoreException& ex) + { + throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason()); + } + } + + for (ConfigIter i = tags.first; i != tags.second; ++i) + { + ConfigTag* tag = i->second; + if (tag->getString("provider") != "gnutls") + continue; + + std::string name = tag->getString("name"); + if (name.empty()) + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation()); + continue; + } + + reference<GnuTLS::Profile> profile; + try + { + profile = GnuTLS::Profile::Create(name, tag); + } + catch (CoreException& ex) + { + throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); + } + + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); + } + + // New profiles are ok, begin using them + // Old profiles are deleted when their refcount drops to zero + profiles.swap(newprofiles); + } + + public: + ModuleSSLGnuTLS() + { +#ifndef GNUTLS_HAS_RND + gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0); +#endif + thismod = this; + } + + void init() CXX11_OVERRIDE + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "GnuTLS lib version %s module was compiled for " GNUTLS_VERSION, gnutls_check_version(NULL)); + ReadProfiles(); + ServerInstance->GenRandom = &randhandler; + } + + void OnModuleRehash(User* user, const std::string ¶m) CXX11_OVERRIDE + { + if(param != "ssl") + return; + + try + { + ReadProfiles(); + } + catch (ModuleException& ex) + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings."); + } + } + + ~ModuleSSLGnuTLS() + { + ServerInstance->GenRandom = &ServerInstance->HandleGenRandom; + } + + void OnCleanup(int target_type, void* item) CXX11_OVERRIDE + { + if(target_type == TYPE_USER) + { + LocalUser* user = IS_LOCAL(static_cast<User*>(item)); + + if ((user) && (user->eh.GetModHook(this))) + { + // User is using SSL, they're a local user, and they're using one of *our* SSL ports. + // Potentially there could be multiple SSL modules loaded at once on different ports. + ServerInstance->Users->QuitUser(user, "SSL module unloading"); + } + } + } + + Version GetVersion() CXX11_OVERRIDE { - if (starttls.enabled) - capHandler.HandleEvent(ev); + return Version("Provides SSL support for clients", VF_VENDOR); } - ModResult OnCheckReady(LocalUser* user) + ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE { - if ((user->eh.GetIOHook() == this) && (sessions[user->eh.GetFd()].status != ISSL_HANDSHAKEN)) + const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this)); + if ((iohook) && (!iohook->IsHandshakeDone())) return MOD_RES_DENY; return MOD_RES_PASSTHRU; } diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp new file mode 100644 index 000000000..ffe0a71b8 --- /dev/null +++ b/src/modules/extra/m_ssl_mbedtls.cpp @@ -0,0 +1,932 @@ +/* + * InspIRCd -- Internet Relay Chat Daemon + * + * Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com> + * + * This file is part of InspIRCd. InspIRCd is free software: you can + * redistribute it and/or modify it under the terms of the GNU General Public + * License as published by the Free Software Foundation, version 2. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + + +/* $LinkerFlags: -lmbedtls */ + +#include "inspircd.h" +#include "modules/ssl.h" + +#include <mbedtls/ctr_drbg.h> +#include <mbedtls/dhm.h> +#include <mbedtls/ecp.h> +#include <mbedtls/entropy.h> +#include <mbedtls/error.h> +#include <mbedtls/md.h> +#include <mbedtls/pk.h> +#include <mbedtls/ssl.h> +#include <mbedtls/ssl_ciphersuites.h> +#include <mbedtls/version.h> +#include <mbedtls/x509.h> +#include <mbedtls/x509_crt.h> +#include <mbedtls/x509_crl.h> + +#ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG +#include <mbedtls/debug.h> +#endif + +namespace mbedTLS +{ + class Exception : public ModuleException + { + public: + Exception(const std::string& reason) + : ModuleException(reason) { } + }; + + std::string ErrorToString(int errcode) + { + char buf[256]; + mbedtls_strerror(errcode, buf, sizeof(buf)); + return buf; + } + + void ThrowOnError(int errcode, const char* msg) + { + if (errcode != 0) + { + std::string reason = msg; + reason.append(" :").append(ErrorToString(errcode)); + throw Exception(reason); + } + } + + template <typename T, void (*init)(T*), void (*deinit)(T*)> + class RAIIObj + { + T obj; + + public: + RAIIObj() + { + init(&obj); + } + + ~RAIIObj() + { + deinit(&obj); + } + + T* get() { return &obj; } + const T* get() const { return &obj; } + }; + + typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy; + + class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free> + { + public: + bool Seed(Entropy& entropy) + { + return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0); + } + + void SetupConf(mbedtls_ssl_config* conf) + { + mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get()); + } + }; + + class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free> + { + public: + void set(const std::string& dhstr) + { + // Last parameter is buffer size, must include the terminating null + int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1); + ThrowOnError(ret, "Unable to import DH params"); + } + }; + + class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free> + { + public: + /** Import */ + X509Key(const std::string& keystr) + { + int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0); + ThrowOnError(ret, "Unable to import private key"); + } + }; + + class Ciphersuites + { + std::vector<int> list; + + public: + Ciphersuites(const std::string& str) + { + // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally. + // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts. + irc::sepstream ss(str, ':'); + for (std::string token; ss.GetToken(token); ) + { + // Prepend "TLS-" if not there + if (token.compare(0, 4, "TLS-", 4)) + token.insert(0, "TLS-"); + + const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str()); + if (!id) + throw Exception("Unknown ciphersuite " + token); + list.push_back(id); + } + list.push_back(0); + } + + const int* get() const { return &list.front(); } + bool empty() const { return (list.size() <= 1); } + }; + + class Curves + { + std::vector<mbedtls_ecp_group_id> list; + + public: + Curves(const std::string& str) + { + irc::sepstream ss(str, ':'); + for (std::string token; ss.GetToken(token); ) + { + const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str()); + if (!curve) + throw Exception("Unknown curve " + token); + list.push_back(curve->grp_id); + } + list.push_back(MBEDTLS_ECP_DP_NONE); + } + + const mbedtls_ecp_group_id* get() const { return &list.front(); } + bool empty() const { return (list.size() <= 1); } + }; + + class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free> + { + public: + /** Import or create empty */ + X509CertList(const std::string& certstr, bool allowempty = false) + { + if ((allowempty) && (certstr.empty())) + return; + int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1); + ThrowOnError(ret, "Unable to load certificates"); + } + + bool empty() const { return (get()->raw.p != NULL); } + }; + + class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free> + { + public: + X509CRL(const std::string& crlstr) + { + if (crlstr.empty()) + return; + int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1); + ThrowOnError(ret, "Unable to load CRL"); + } + }; + + class X509Credentials + { + /** Private key + */ + X509Key key; + + /** Certificate list, presented to the peer + */ + X509CertList certs; + + public: + X509Credentials(const std::string& certstr, const std::string& keystr) + : key(keystr) + , certs(certstr) + { + // Verify that one of the certs match the private key + bool found = false; + for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next) + { + if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0) + { + found = true; + break; + } + } + if (!found) + throw Exception("Public/private key pair does not match"); + } + + mbedtls_pk_context* getkey() { return key.get(); } + mbedtls_x509_crt* getcerts() { return certs.get(); } + }; + + class Context + { + mbedtls_ssl_config conf; + +#ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG + static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg) + { + // Remove trailing \n + size_t len = strlen(msg); + if ((len > 0) && (msg[len-1] == '\n')) + len--; + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg); + } +#endif + + public: + Context(CTRDRBG& ctrdrbg, unsigned int endpoint) + { + mbedtls_ssl_config_init(&conf); +#ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG + mbedtls_debug_set_threshold(INT_MAX); + mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL); +#endif + + // TODO: check ret of mbedtls_ssl_config_defaults + mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); + ctrdrbg.SetupConf(&conf); + } + + ~Context() + { + mbedtls_ssl_config_free(&conf); + } + + void SetMinDHBits(unsigned int mindh) + { + mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh); + } + + void SetDHParams(DHParams& dh) + { + mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get()); + } + + void SetX509CertAndKey(X509Credentials& x509cred) + { + mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey()); + } + + void SetCiphersuites(const Ciphersuites& ciphersuites) + { + mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get()); + } + + void SetCurves(const Curves& curves) + { + mbedtls_ssl_conf_curves(&conf, curves.get()); + } + + void SetVersion(int minver, int maxver) + { + // SSL v3 support cannot be enabled + if (minver) + mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver); + if (maxver) + mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver); + } + + void SetCA(X509CertList& certs, X509CRL& crl) + { + mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get()); + } + + void SetOptionalVerifyCert() + { + mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL); + } + + const mbedtls_ssl_config* GetConf() const { return &conf; } + }; + + class Hash + { + const mbedtls_md_info_t* md; + + /** Buffer where cert hashes are written temporarily + */ + mutable std::vector<unsigned char> buf; + + public: + Hash(std::string hashstr) + { + std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper); + md = mbedtls_md_info_from_string(hashstr.c_str()); + if (!md) + throw Exception("Unknown hash: " + hashstr); + + buf.resize(mbedtls_md_get_size(md)); + } + + std::string hash(const unsigned char* input, size_t length) const + { + mbedtls_md(md, input, length, &buf.front()); + return BinToHex(&buf.front(), buf.size()); + } + }; + + class Profile : public refcountbase + { + /** Name of this profile + */ + const std::string name; + + X509Credentials x509cred; + + /** Ciphersuites to use + */ + Ciphersuites ciphersuites; + + /** Curves accepted for use in ECDHE and in the peer's end-entity certificate + */ + Curves curves; + + Context serverctx; + Context clientctx; + + DHParams dhparams; + + X509CertList cacerts; + + X509CRL crl; + + /** Hashing algorithm to use when generating certificate fingerprints + */ + Hash hash; + + /** Rough max size of records to send + */ + const unsigned int outrecsize; + + Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr, + const std::string& dhstr, unsigned int mindh, const std::string& hashstr, + const std::string& ciphersuitestr, const std::string& curvestr, + const std::string& castr, const std::string& crlstr, + unsigned int recsize, + CTRDRBG& ctrdrbg, + int minver, int maxver, + bool requestclientcert + ) + : name(profilename) + , x509cred(certstr, keystr) + , ciphersuites(ciphersuitestr) + , curves(curvestr) + , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER) + , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT) + , cacerts(castr, true) + , crl(crlstr) + , hash(hashstr) + , outrecsize(recsize) + { + serverctx.SetX509CertAndKey(x509cred); + clientctx.SetX509CertAndKey(x509cred); + clientctx.SetMinDHBits(mindh); + + if (!ciphersuites.empty()) + { + serverctx.SetCiphersuites(ciphersuites); + clientctx.SetCiphersuites(ciphersuites); + } + + if (!curves.empty()) + { + serverctx.SetCurves(curves); + clientctx.SetCurves(curves); + } + + serverctx.SetVersion(minver, maxver); + clientctx.SetVersion(minver, maxver); + + if (!dhstr.empty()) + { + dhparams.set(dhstr); + serverctx.SetDHParams(dhparams); + } + + clientctx.SetOptionalVerifyCert(); + clientctx.SetCA(cacerts, crl); + // The default for servers is to not request a client certificate from the peer + if (requestclientcert) + { + serverctx.SetOptionalVerifyCert(); + serverctx.SetCA(cacerts, crl); + } + } + + static std::string ReadFile(const std::string& filename) + { + FileReader reader(filename); + std::string ret = reader.GetString(); + if (ret.empty()) + throw Exception("Cannot read file " + filename); + return ret; + } + + public: + static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg) + { + const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem")); + const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem")); + const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem")); + + const std::string ciphersuitestr = tag->getString("ciphersuites"); + const std::string curvestr = tag->getString("curves"); + unsigned int mindh = tag->getInt("mindhbits", 2048); + std::string hashstr = tag->getString("hash", "sha256"); + + std::string crlstr; + std::string castr = tag->getString("cafile"); + if (!castr.empty()) + { + castr = ReadFile(castr); + crlstr = tag->getString("crlfile"); + if (!crlstr.empty()) + crlstr = ReadFile(crlstr); + } + + int minver = tag->getInt("minver"); + int maxver = tag->getInt("maxver"); + unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384); + const bool requestclientcert = tag->getBool("requestclientcert", true); + return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert); + } + + /** Set up the given session with the settings in this profile + */ + void SetupClientSession(mbedtls_ssl_context* sess) + { + mbedtls_ssl_setup(sess, clientctx.GetConf()); + } + + void SetupServerSession(mbedtls_ssl_context* sess) + { + mbedtls_ssl_setup(sess, serverctx.GetConf()); + } + + const std::string& GetName() const { return name; } + X509Credentials& GetX509Credentials() { return x509cred; } + unsigned int GetOutgoingRecordSize() const { return outrecsize; } + const Hash& GetHash() const { return hash; } + }; +} + +class mbedTLSIOHook : public SSLIOHook +{ + enum Status + { + ISSL_NONE, + ISSL_HANDSHAKING, + ISSL_HANDSHAKEN + }; + + mbedtls_ssl_context sess; + Status status; + reference<mbedTLS::Profile> profile; + + void CloseSession() + { + if (status == ISSL_NONE) + return; + + mbedtls_ssl_close_notify(&sess); + mbedtls_ssl_free(&sess); + certificate = NULL; + status = ISSL_NONE; + } + + // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed + int Handshake(StreamSocket* sock) + { + int ret = mbedtls_ssl_handshake(&sess); + if (ret == 0) + { + // Change the seesion state + this->status = ISSL_HANDSHAKEN; + + VerifyCertificate(); + + // Finish writing, if any left + SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); + + return 1; + } + + this->status = ISSL_HANDSHAKING; + if (ret == MBEDTLS_ERR_SSL_WANT_READ) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + return 0; + } + else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); + return 0; + } + + sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret)); + CloseSession(); + return -1; + } + + // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error + int PrepareIO(StreamSocket* sock) + { + if (status == ISSL_HANDSHAKEN) + return 1; + else if (status == ISSL_HANDSHAKING) + { + // The handshake isn't finished, try to finish it + return Handshake(sock); + } + + CloseSession(); + sock->SetError("No SSL session"); + return -1; + } + + void VerifyCertificate() + { + this->certificate = new ssl_cert; + const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess); + if (!cert) + { + certificate->error = "No client certificate sent"; + return; + } + + // If there is a certificate we can always generate a fingerprint + certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len); + + // At this point mbedTLS verified the cert already, we just need to check the results + const uint32_t flags = mbedtls_ssl_get_verify_result(&sess); + if (flags == 0xFFFFFFFF) + { + certificate->error = "Internal error during verification"; + return; + } + + if (flags == 0) + { + // Verification succeeded + certificate->trusted = true; + } + else + { + // Verification failed + certificate->trusted = false; + if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE)) + certificate->error = "Not activated, or expired certificate"; + } + + certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED); + certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED); + certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK)); + + GetDNString(&cert->subject, certificate->dn); + GetDNString(&cert->issuer, certificate->issuer); + } + + static void GetDNString(const mbedtls_x509_name* x509name, std::string& out) + { + char buf[512]; + const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name); + if (ret <= 0) + return; + + out.assign(buf, ret); + } + + static int Pull(void* userptr, unsigned char* buffer, size_t size) + { + StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr); + if (sock->GetEventMask() & FD_READ_WILL_BLOCK) + return MBEDTLS_ERR_SSL_WANT_READ; + + const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0); + if (ret < (int)size) + { + SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK); + if ((ret == -1) && (SocketEngine::IgnoreError())) + return MBEDTLS_ERR_SSL_WANT_READ; + } + return ret; + } + + static int Push(void* userptr, const unsigned char* buffer, size_t size) + { + StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr); + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) + return MBEDTLS_ERR_SSL_WANT_WRITE; + + const int ret = SocketEngine::Send(sock, buffer, size, 0); + if (ret < (int)size) + { + SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); + if ((ret == -1) && (SocketEngine::IgnoreError())) + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + return ret; + } + + public: + mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile) + : SSLIOHook(hookprov) + , status(ISSL_NONE) + , profile(sslprofile) + { + mbedtls_ssl_init(&sess); + if (isserver) + profile->SetupServerSession(&sess); + else + profile->SetupClientSession(&sess); + + mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL); + + sock->AddIOHook(this); + Handshake(sock); + } + + void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE + { + CloseSession(); + } + + int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE + { + // Finish handshake if needed + int prepret = PrepareIO(sock); + if (prepret <= 0) + return prepret; + + // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN. + char* const readbuf = ServerInstance->GetReadBuffer(); + const size_t readbufsize = ServerInstance->Config->NetBufferSize; + int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize); + if (ret > 0) + { + recvq.append(readbuf, ret); + + // Schedule a read if there is still data in the mbedTLS buffer + if (mbedtls_ssl_get_bytes_avail(&sess) > 0) + SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ); + return 1; + } + else if (ret == MBEDTLS_ERR_SSL_WANT_READ) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ); + return 0; + } + else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); + return 0; + } + else if (ret == 0) + { + sock->SetError("Connection closed"); + CloseSession(); + return -1; + } + else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error + { + sock->SetError(mbedTLS::ErrorToString(ret)); + CloseSession(); + return -1; + } + } + + int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE + { + // Finish handshake if needed + int prepret = PrepareIO(sock); + if (prepret <= 0) + return prepret; + + // Session is ready for transferring application data + while (!sendq.empty()) + { + FlattenSendQueue(sendq, profile->GetOutgoingRecordSize()); + const StreamSocket::SendQueue::Element& buffer = sendq.front(); + int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length()); + if (ret == (int)buffer.length()) + { + // Wrote entire record, continue sending + sendq.pop_front(); + } + else if (ret > 0) + { + sendq.erase_front(ret); + SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE); + return 0; + } + else if (ret == 0) + { + sock->SetError("Connection closed"); + CloseSession(); + return -1; + } + else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE); + return 0; + } + else if (ret == MBEDTLS_ERR_SSL_WANT_READ) + { + SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ); + return 0; + } + else + { + sock->SetError(mbedTLS::ErrorToString(ret)); + CloseSession(); + return -1; + } + } + + SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE); + return 1; + } + + void GetCiphersuite(std::string& out) const CXX11_OVERRIDE + { + if (!IsHandshakeDone()) + return; + out.append(mbedtls_ssl_get_version(&sess)).push_back('-'); + + // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes + const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess); + const char prefix[] = "TLS-"; + unsigned int skip = sizeof(prefix)-1; + if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1)) + skip = 0; + out.append(ciphersuitestr + skip); + } + + bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); } +}; + +class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider +{ + reference<mbedTLS::Profile> profile; + + public: + mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~mbedTLSIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new mbedTLSIOHook(this, sock, true, profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new mbedTLSIOHook(this, sock, false, profile); + } +}; + +class ModuleSSLmbedTLS : public Module +{ + typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList; + + mbedTLS::Entropy entropy; + mbedTLS::CTRDRBG ctr_drbg; + ProfileList profiles; + + void ReadProfiles() + { + // First, store all profiles in a new, temporary container. If no problems occur, swap the two + // containers; this way if something goes wrong we can go back and continue using the current profiles, + // avoiding unpleasant situations where no new SSL connections are possible. + ProfileList newprofiles; + + ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile"); + if (tags.first == tags.second) + { + // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block + const std::string defname = "mbedtls"; + ConfigTag* tag = ServerInstance->Config->ConfValue(defname); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag"); + + try + { + reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg)); + newprofiles.push_back(new mbedTLSIOHookProvider(this, profile)); + } + catch (CoreException& ex) + { + throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason()); + } + } + + for (ConfigIter i = tags.first; i != tags.second; ++i) + { + ConfigTag* tag = i->second; + if (tag->getString("provider") != "mbedtls") + continue; + + std::string name = tag->getString("name"); + if (name.empty()) + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation()); + continue; + } + + reference<mbedTLS::Profile> profile; + try + { + profile = mbedTLS::Profile::Create(name, tag, ctr_drbg); + } + catch (CoreException& ex) + { + throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); + } + + newprofiles.push_back(new mbedTLSIOHookProvider(this, profile)); + } + + // New profiles are ok, begin using them + // Old profiles are deleted when their refcount drops to zero + profiles.swap(newprofiles); + } + + public: + void init() CXX11_OVERRIDE + { + char verbuf[16]; // Should be at least 9 bytes in size + mbedtls_version_get_string(verbuf); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf); + + if (!ctr_drbg.Seed(entropy)) + throw ModuleException("CTR DRBG seed failed"); + ReadProfiles(); + } + + void OnModuleRehash(User* user, const std::string ¶m) CXX11_OVERRIDE + { + if (param != "ssl") + return; + + try + { + ReadProfiles(); + } + catch (ModuleException& ex) + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings."); + } + } + + void OnCleanup(int target_type, void* item) CXX11_OVERRIDE + { + if (target_type != TYPE_USER) + return; + + LocalUser* user = IS_LOCAL(static_cast<User*>(item)); + if ((user) && (user->eh.GetModHook(this))) + { + // User is using SSL, they're a local user, and they're using our IOHook. + // Potentially there could be multiple SSL modules loaded at once on different ports. + ServerInstance->Users.QuitUser(user, "SSL module unloading"); + } + } + + ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE + { + const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this)); + if ((iohook) && (!iohook->IsHandshakeDone())) + return MOD_RES_DENY; + return MOD_RES_PASSTHRU; + } + + Version GetVersion() CXX11_OVERRIDE + { + return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR); + } +}; + +MODULE_INIT(ModuleSSLmbedTLS) diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index aee7a5e34..8467cc6d4 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -21,843 +21,930 @@ * along with this program. If not, see <http://www.gnu.org/licenses/>. */ - /* HACK: This prevents OpenSSL on OS X 10.7 and later from spewing deprecation - * warnings for every single function call. As far as I (SaberUK) know, Apple - * have no plans to remove OpenSSL so this warning just causes needless spam. - */ -#ifdef __APPLE__ -# define __AVAILABILITYMACROS__ -# define DEPRECATED_IN_MAC_OS_X_VERSION_10_7_AND_LATER -#endif - + #include "inspircd.h" +#include "iohook.h" +#include "modules/ssl.h" + +// Ignore OpenSSL deprecation warnings on OS X Lion and newer. +#if defined __APPLE__ +# pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif + +// Fix warnings about the use of `long long` on C++03. +#if defined __clang__ +# pragma clang diagnostic ignored "-Wc++11-long-long" +#elif defined __GNUC__ +# pragma GCC diagnostic ignored "-Wlong-long" +#endif + #include <openssl/ssl.h> #include <openssl/err.h> -#include "ssl.h" #ifdef _WIN32 # pragma comment(lib, "ssleay32.lib") # pragma comment(lib, "libeay32.lib") -# undef MAX_DESCRIPTORS -# define MAX_DESCRIPTORS 10000 #endif -/* $ModDesc: Provides SSL support for clients */ - -/* $LinkerFlags: if("USE_FREEBSD_BASE_SSL") -lssl -lcrypto */ -/* $CompileFlags: if(!"USE_FREEBSD_BASE_SSL") pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */ -/* $LinkerFlags: if(!"USE_FREEBSD_BASE_SSL") rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto -ldl") */ +/* $CompileFlags: pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */ +/* $LinkerFlags: rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto") */ -/* $NoPedantic */ - - -class ModuleSSLOpenSSL; +#if ((OPENSSL_VERSION_NUMBER >= 0x10000000L) && (!(defined(OPENSSL_NO_ECDH)))) +// OpenSSL 0.9.8 includes some ECC support, but it's unfinished. Enable only for 1.0.0 and later. +#define INSPIRCD_OPENSSL_ENABLE_ECDH +#endif enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_OPEN }; static bool SelfSigned = false; - -#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION -static ModuleSSLOpenSSL* opensslmod = NULL; -#endif +static int exdataindex; char* get_error() { return ERR_error_string(ERR_get_error(), NULL); } -static int error_callback(const char *str, size_t len, void *u); +static int OnVerify(int preverify_ok, X509_STORE_CTX* ctx); +static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc); -/** Represents an SSL user's extra data - */ -class issl_session +namespace OpenSSL { -public: - SSL* sess; - issl_status status; - reference<ssl_cert> cert; - - bool outbound; - bool data_to_write; - - issl_session() - : sess(NULL) - , status(ISSL_NONE) + class Exception : public ModuleException { - outbound = false; - data_to_write = false; - } -}; + public: + Exception(const std::string& reason) + : ModuleException(reason) { } + }; -static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) -{ - /* XXX: This will allow self signed certificates. - * In the future if we want an option to not allow this, - * we can just return preverify_ok here, and openssl - * will boot off self-signed and invalid peer certs. - */ - int ve = X509_STORE_CTX_get_error(ctx); - - SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT); - - return 1; -} + class DHParams + { + DH* dh; -class ModuleSSLOpenSSL : public Module -{ - issl_session* sessions; + public: + DHParams(const std::string& filename) + { + BIO* dhpfile = BIO_new_file(filename.c_str(), "r"); + if (dhpfile == NULL) + throw Exception("Couldn't open DH file " + filename); - SSL_CTX* ctx; - SSL_CTX* clictx; + dh = PEM_read_bio_DHparams(dhpfile, NULL, NULL, NULL); + BIO_free(dhpfile); - long ctx_options; - long clictx_options; + if (!dh) + throw Exception("Couldn't read DH params from file " + filename); + } - std::string sslports; - bool use_sha; + ~DHParams() + { + DH_free(dh); + } - ServiceProvider iohook; + DH* get() + { + return dh; + } + }; - static void SetContextOptions(SSL_CTX* ctx, long defoptions, const std::string& ctxname, ConfigTag* tag) + class Context { - long setoptions = tag->getInt(ctxname + "setoptions"); - // User-friendly config options for setting context options -#ifdef SSL_OP_CIPHER_SERVER_PREFERENCE - if (tag->getBool("cipherserverpref")) - setoptions |= SSL_OP_CIPHER_SERVER_PREFERENCE; + SSL_CTX* const ctx; + long ctx_options; + + public: + Context(SSL_CTX* context) + : ctx(context) + { + // Sane default options for OpenSSL see https://www.openssl.org/docs/ssl/SSL_CTX_set_options.html + // and when choosing a cipher, use the server's preferences instead of the client preferences. + long opts = SSL_OP_NO_SSLv2 | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION | SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_SINGLE_DH_USE; + // Only turn options on if they exist +#ifdef SSL_OP_SINGLE_ECDH_USE + opts |= SSL_OP_SINGLE_ECDH_USE; #endif -#ifdef SSL_OP_NO_COMPRESSION - if (!tag->getBool("compression", true)) - setoptions |= SSL_OP_NO_COMPRESSION; +#ifdef SSL_OP_NO_TICKET + opts |= SSL_OP_NO_TICKET; #endif - if (!tag->getBool("sslv3", true)) - setoptions |= SSL_OP_NO_SSLv3; - if (!tag->getBool("tlsv1", true)) - setoptions |= SSL_OP_NO_TLSv1; - long clearoptions = tag->getInt(ctxname + "clearoptions"); - ServerInstance->Logs->Log("m_ssl_openssl", DEBUG, "Setting OpenSSL %s context options, default: %ld set: %ld clear: %ld", ctxname.c_str(), defoptions, setoptions, clearoptions); + ctx_options = SSL_CTX_set_options(ctx, opts); - // Clear everything - SSL_CTX_clear_options(ctx, SSL_CTX_get_options(ctx)); + long mode = SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER; +#ifdef SSL_MODE_RELEASE_BUFFERS + mode |= SSL_MODE_RELEASE_BUFFERS; +#endif + SSL_CTX_set_mode(ctx, mode); + SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL); + SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF); + SSL_CTX_set_info_callback(ctx, StaticSSLInfoCallback); + } - // Set the default options and what is in the conf - SSL_CTX_set_options(ctx, defoptions | setoptions); - long final = SSL_CTX_clear_options(ctx, clearoptions); - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "OpenSSL %s context options: %ld", ctxname.c_str(), final); - } + ~Context() + { + SSL_CTX_free(ctx); + } -#ifdef INSPIRCD_OPENSSL_ENABLE_ECDH - void SetupECDH(ConfigTag* tag) - { - std::string curvename = tag->getString("ecdhcurve", "prime256v1"); - if (curvename.empty()) - return; + bool SetDH(DHParams& dh) + { + ERR_clear_error(); + return (SSL_CTX_set_tmp_dh(ctx, dh.get()) >= 0); + } - int nid = OBJ_sn2nid(curvename.c_str()); - if (nid == 0) +#ifdef INSPIRCD_OPENSSL_ENABLE_ECDH + void SetECDH(const std::string& curvename) { - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Unknown curve: \"%s\"", curvename.c_str()); - return; + int nid = OBJ_sn2nid(curvename.c_str()); + if (nid == 0) + throw Exception("Unknown curve: " + curvename); + + EC_KEY* eckey = EC_KEY_new_by_curve_name(nid); + if (!eckey) + throw Exception("Unable to create EC key object"); + + ERR_clear_error(); + bool ret = (SSL_CTX_set_tmp_ecdh(ctx, eckey) >= 0); + EC_KEY_free(eckey); + if (!ret) + throw Exception("Couldn't set ECDH parameters"); } +#endif - EC_KEY* eckey = EC_KEY_new_by_curve_name(nid); - if (!eckey) + bool SetCiphers(const std::string& ciphers) { - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Unable to create EC key object"); - return; + ERR_clear_error(); + return SSL_CTX_set_cipher_list(ctx, ciphers.c_str()); } - ERR_clear_error(); - if (SSL_CTX_set_tmp_ecdh(ctx, eckey) < 0) + bool SetCerts(const std::string& filename) { - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Couldn't set ECDH parameters"); - ERR_print_errors_cb(error_callback, this); + ERR_clear_error(); + return SSL_CTX_use_certificate_chain_file(ctx, filename.c_str()); } - EC_KEY_free(eckey); - } -#endif + bool SetPrivateKey(const std::string& filename) + { + ERR_clear_error(); + return SSL_CTX_use_PrivateKey_file(ctx, filename.c_str(), SSL_FILETYPE_PEM); + } -#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION - static void SSLInfoCallback(const SSL* ssl, int where, int rc) - { - int fd = SSL_get_fd(const_cast<SSL*>(ssl)); - issl_session& session = opensslmod->sessions[fd]; + bool SetCA(const std::string& filename) + { + ERR_clear_error(); + return SSL_CTX_load_verify_locations(ctx, filename.c_str(), 0); + } - if ((where & SSL_CB_HANDSHAKE_START) && (session.status == ISSL_OPEN)) + long GetDefaultContextOptions() const { - // The other side is trying to renegotiate, kill the connection and change status - // to ISSL_NONE so CheckRenego() closes the session - session.status = ISSL_NONE; - ServerInstance->SE->Shutdown(fd, 2); + return ctx_options; } - } - bool CheckRenego(StreamSocket* sock, issl_session* session) - { - if (session->status != ISSL_NONE) - return true; + long SetRawContextOptions(long setoptions, long clearoptions) + { + // Clear everything + SSL_CTX_clear_options(ctx, SSL_CTX_get_options(ctx)); - ServerInstance->Logs->Log("m_ssl_openssl", DEBUG, "Session %p killed, attempted to renegotiate", (void*)session->sess); - CloseSession(session); - sock->SetError("Renegotiation is not allowed"); - return false; - } -#endif + // Set the default options and what is in the conf + SSL_CTX_set_options(ctx, ctx_options | setoptions); + return SSL_CTX_clear_options(ctx, clearoptions); + } - public: + void SetVerifyCert() + { + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); + } - ModuleSSLOpenSSL() : iohook(this, "ssl/openssl", SERVICE_IOHOOK) - { -#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION - opensslmod = this; -#endif - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; + SSL* CreateServerSession() + { + SSL* sess = SSL_new(ctx); + SSL_set_accept_state(sess); // Act as server + return sess; + } - /* Global SSL library initialization*/ - SSL_library_init(); - SSL_load_error_strings(); + SSL* CreateClientSession() + { + SSL* sess = SSL_new(ctx); + SSL_set_connect_state(sess); // Act as client + return sess; + } + }; - /* Build our SSL contexts: - * NOTE: OpenSSL makes us have two contexts, one for servers and one for clients. ICK. + class Profile : public refcountbase + { + /** Name of this profile */ - ctx = SSL_CTX_new( SSLv23_server_method() ); - clictx = SSL_CTX_new( SSLv23_client_method() ); + const std::string name; - SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - SSL_CTX_set_mode(clictx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + /** DH parameters in use + */ + DHParams dh; - SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); - SSL_CTX_set_verify(clictx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); + /** OpenSSL makes us have two contexts, one for servers and one for clients + */ + Context ctx; + Context clictx; - SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF); - SSL_CTX_set_session_cache_mode(clictx, SSL_SESS_CACHE_OFF); + /** Digest to use when generating fingerprints + */ + const EVP_MD* digest; - long opts = SSL_OP_NO_SSLv2 | SSL_OP_SINGLE_DH_USE; - // Only turn options on if they exist -#ifdef SSL_OP_SINGLE_ECDH_USE - opts |= SSL_OP_SINGLE_ECDH_USE; -#endif -#ifdef SSL_OP_NO_TICKET - opts |= SSL_OP_NO_TICKET; -#endif + /** Last error, set by error_callback() + */ + std::string lasterr; - ctx_options = SSL_CTX_set_options(ctx, opts); - clictx_options = SSL_CTX_set_options(clictx, opts); - } + /** True if renegotiations are allowed, false if not + */ + const bool allowrenego; - void init() - { - // Needs the flag as it ignores a plain /rehash - OnModuleRehash(NULL,"ssl"); - Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnHookIO, I_OnUserConnect }; - ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); - ServerInstance->Modules->AddService(iohook); - } + /** Rough max size of records to send + */ + const unsigned int outrecsize; - void OnHookIO(StreamSocket* user, ListenSocket* lsb) - { - if (!user->GetIOHook() && lsb->bind_tag->getString("ssl") == "openssl") + static int error_callback(const char* str, size_t len, void* u) { - /* Hook the user with our module */ - user->AddIOHook(this); + Profile* profile = reinterpret_cast<Profile*>(u); + profile->lasterr = std::string(str, len - 1); + return 0; } - } - - void OnRehash(User* user) - { - sslports.clear(); - ConfigTag* Conf = ServerInstance->Config->ConfValue("openssl"); - -#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION - // Set the callback if we are not allowing renegotiations, unset it if we do - if (Conf->getBool("renegotiation", true)) - { - SSL_CTX_set_info_callback(ctx, NULL); - SSL_CTX_set_info_callback(clictx, NULL); - } - else + /** Set raw OpenSSL context (SSL_CTX) options from a config tag + * @param ctxname Name of the context, client or server + * @param tag Config tag defining this profile + * @param context Context object to manipulate + */ + void SetContextOptions(const std::string& ctxname, ConfigTag* tag, Context& context) { - SSL_CTX_set_info_callback(ctx, SSLInfoCallback); - SSL_CTX_set_info_callback(clictx, SSLInfoCallback); - } + long setoptions = tag->getInt(ctxname + "setoptions"); + long clearoptions = tag->getInt(ctxname + "clearoptions"); +#ifdef SSL_OP_NO_COMPRESSION + if (!tag->getBool("compression", false)) // Disable compression by default + setoptions |= SSL_OP_NO_COMPRESSION; #endif - - if (Conf->getBool("showports", true)) - { - sslports = Conf->getString("advertisedports"); - if (!sslports.empty()) - return; - - for (size_t i = 0; i < ServerInstance->ports.size(); i++) + if (!tag->getBool("sslv3", false)) // Disable SSLv3 by default + setoptions |= SSL_OP_NO_SSLv3; + if (!tag->getBool("tlsv1", true)) + setoptions |= SSL_OP_NO_TLSv1; + + if (!setoptions && !clearoptions) + return; // Nothing to do + + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Setting %s %s context options, default: %ld set: %ld clear: %ld", name.c_str(), ctxname.c_str(), ctx.GetDefaultContextOptions(), setoptions, clearoptions); + long final = context.SetRawContextOptions(setoptions, clearoptions); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "%s %s context options: %ld", name.c_str(), ctxname.c_str(), final); + } + + public: + Profile(const std::string& profilename, ConfigTag* tag) + : name(profilename) + , dh(ServerInstance->Config->Paths.PrependConfig(tag->getString("dhfile", "dh.pem"))) + , ctx(SSL_CTX_new(SSLv23_server_method())) + , clictx(SSL_CTX_new(SSLv23_client_method())) + , allowrenego(tag->getBool("renegotiation")) // Disallow by default + , outrecsize(tag->getInt("outrecsize", 2048, 512, 16384)) + { + if ((!ctx.SetDH(dh)) || (!clictx.SetDH(dh))) + throw Exception("Couldn't set DH parameters"); + + std::string hash = tag->getString("hash", "md5"); + digest = EVP_get_digestbyname(hash.c_str()); + if (digest == NULL) + throw Exception("Unknown hash type " + hash); + + std::string ciphers = tag->getString("ciphers"); + if (!ciphers.empty()) { - ListenSocket* port = ServerInstance->ports[i]; - if (port->bind_tag->getString("ssl") != "openssl") - continue; - - const std::string& portid = port->bind_desc; - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Enabling SSL for port %s", portid.c_str()); - - if (port->bind_tag->getString("type", "clients") == "clients" && port->bind_addr != "127.0.0.1") + if ((!ctx.SetCiphers(ciphers)) || (!clictx.SetCiphers(ciphers))) { - /* - * Found an SSL port for clients that is not bound to 127.0.0.1 and handled by us, display - * the IP:port in ISUPPORT. - * - * We used to advertise all ports seperated by a ';' char that matched the above criteria, - * but this resulted in too long ISUPPORT lines if there were lots of ports to be displayed. - * To solve this by default we now only display the first IP:port found and let the user - * configure the exact value for the 005 token, if necessary. - */ - sslports = portid; - break; + ERR_print_errors_cb(error_callback, this); + throw Exception("Can't set cipher list to \"" + ciphers + "\" " + lasterr); } } - } - } - - void OnModuleRehash(User* user, const std::string ¶m) - { - if (param != "ssl") - return; - std::string keyfile; - std::string certfile; - std::string cafile; - std::string dhfile; - OnRehash(user); - - ConfigTag* conf = ServerInstance->Config->ConfValue("openssl"); +#ifdef INSPIRCD_OPENSSL_ENABLE_ECDH + std::string curvename = tag->getString("ecdhcurve", "prime256v1"); + if (!curvename.empty()) + ctx.SetECDH(curvename); +#endif - cafile = conf->getString("cafile", CONFIG_PATH "/ca.pem"); - certfile = conf->getString("certfile", CONFIG_PATH "/cert.pem"); - keyfile = conf->getString("keyfile", CONFIG_PATH "/key.pem"); - dhfile = conf->getString("dhfile", CONFIG_PATH "/dhparams.pem"); - std::string hash = conf->getString("hash", "md5"); - if (hash != "sha1" && hash != "md5") - throw ModuleException("Unknown hash type " + hash); - use_sha = (hash == "sha1"); + SetContextOptions("server", tag, ctx); + SetContextOptions("client", tag, clictx); - if (conf->getBool("customcontextoptions")) - { - SetContextOptions(ctx, ctx_options, "server", conf); - SetContextOptions(clictx, clictx_options, "client", conf); - } + /* Load our keys and certificates + * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck. + */ + std::string filename = ServerInstance->Config->Paths.PrependConfig(tag->getString("certfile", "cert.pem")); + if ((!ctx.SetCerts(filename)) || (!clictx.SetCerts(filename))) + { + ERR_print_errors_cb(error_callback, this); + throw Exception("Can't read certificate file: " + lasterr); + } - std::string ciphers = conf->getString("ciphers", ""); + filename = ServerInstance->Config->Paths.PrependConfig(tag->getString("keyfile", "key.pem")); + if ((!ctx.SetPrivateKey(filename)) || (!clictx.SetPrivateKey(filename))) + { + ERR_print_errors_cb(error_callback, this); + throw Exception("Can't read key file: " + lasterr); + } - if (!ciphers.empty()) - { - ERR_clear_error(); - if ((!SSL_CTX_set_cipher_list(ctx, ciphers.c_str())) || (!SSL_CTX_set_cipher_list(clictx, ciphers.c_str()))) + // Load the CAs we trust + filename = ServerInstance->Config->Paths.PrependConfig(tag->getString("cafile", "ca.pem")); + if ((!ctx.SetCA(filename)) || (!clictx.SetCA(filename))) { - ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't set cipher list to %s.", ciphers.c_str()); ERR_print_errors_cb(error_callback, this); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Can't read CA list from %s. This is only a problem if you want to verify client certificates, otherwise it's safe to ignore this message. Error: %s", filename.c_str(), lasterr.c_str()); } + + clictx.SetVerifyCert(); + if (tag->getBool("requestclientcert", true)) + ctx.SetVerifyCert(); } - /* Load our keys and certificates - * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck. - */ - ERR_clear_error(); - if ((!SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str())) || (!SSL_CTX_use_certificate_chain_file(clictx, certfile.c_str()))) + const std::string& GetName() const { return name; } + SSL* CreateServerSession() { return ctx.CreateServerSession(); } + SSL* CreateClientSession() { return clictx.CreateClientSession(); } + const EVP_MD* GetDigest() { return digest; } + bool AllowRenegotiation() const { return allowrenego; } + unsigned int GetOutgoingRecordSize() const { return outrecsize; } + }; + + namespace BIOMethod + { + static int create(BIO* bio) { - ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't read certificate file %s. %s", certfile.c_str(), strerror(errno)); - ERR_print_errors_cb(error_callback, this); + bio->init = 1; + return 1; } - ERR_clear_error(); - if (((!SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) || (!SSL_CTX_use_PrivateKey_file(clictx, keyfile.c_str(), SSL_FILETYPE_PEM))) + static int destroy(BIO* bio) { - ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't read key file %s. %s", keyfile.c_str(), strerror(errno)); - ERR_print_errors_cb(error_callback, this); + // XXX: Dummy function to avoid a memory leak in OpenSSL. + // The memory leak happens in BIO_free() (bio_lib.c) when the destroy func of the BIO is NULL. + // This is fixed in OpenSSL but some distros still ship the unpatched version hence we provide this workaround. + return 1; } - /* Load the CAs we trust*/ - ERR_clear_error(); - if (((!SSL_CTX_load_verify_locations(ctx, cafile.c_str(), 0))) || (!SSL_CTX_load_verify_locations(clictx, cafile.c_str(), 0))) + static long ctrl(BIO* bio, int cmd, long num, void* ptr) { - ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't read CA list from %s. This is only a problem if you want to verify client certificates, otherwise it's safe to ignore this message. Error: %s", cafile.c_str(), strerror(errno)); - ERR_print_errors_cb(error_callback, this); + if (cmd == BIO_CTRL_FLUSH) + return 1; + return 0; } -#ifdef _WIN32 - BIO* dhpfile = BIO_new_file(dhfile.c_str(), "r"); -#else - FILE* dhpfile = fopen(dhfile.c_str(), "r"); -#endif - DH* ret; + static int read(BIO* bio, char* buf, int len); + static int write(BIO* bio, const char* buf, int len); + } +} - if (dhpfile == NULL) - { - ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so Couldn't open DH file %s: %s", dhfile.c_str(), strerror(errno)); - throw ModuleException("Couldn't open DH file " + dhfile + ": " + strerror(errno)); - } - else +static BIO_METHOD biomethods = +{ + (100 | BIO_TYPE_SOURCE_SINK), + "inspircd", + OpenSSL::BIOMethod::write, + OpenSSL::BIOMethod::read, + NULL, // puts + NULL, // gets + OpenSSL::BIOMethod::ctrl, + OpenSSL::BIOMethod::create, + OpenSSL::BIOMethod::destroy, // destroy, does nothing, see function body for more info + NULL // callback_ctrl +}; + +static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) +{ + /* XXX: This will allow self signed certificates. + * In the future if we want an option to not allow this, + * we can just return preverify_ok here, and openssl + * will boot off self-signed and invalid peer certs. + */ + int ve = X509_STORE_CTX_get_error(ctx); + + SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT); + + return 1; +} + +class OpenSSLIOHook : public SSLIOHook +{ + private: + SSL* sess; + issl_status status; + bool data_to_write; + reference<OpenSSL::Profile> profile; + + // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed + int Handshake(StreamSocket* user) + { + ERR_clear_error(); + int ret = SSL_do_handshake(sess); + if (ret < 0) { -#ifdef _WIN32 - ret = PEM_read_bio_DHparams(dhpfile, NULL, NULL, NULL); - BIO_free(dhpfile); -#else - ret = PEM_read_DHparams(dhpfile, NULL, NULL, NULL); -#endif + int err = SSL_get_error(sess, ret); - ERR_clear_error(); - if (ret) + if (err == SSL_ERROR_WANT_READ) { - if ((SSL_CTX_set_tmp_dh(ctx, ret) < 0) || (SSL_CTX_set_tmp_dh(clictx, ret) < 0)) - { - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s. SSL errors follow:", dhfile.c_str()); - ERR_print_errors_cb(error_callback, this); - } - DH_free(ret); + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + this->status = ISSL_HANDSHAKING; + return 0; + } + else if (err == SSL_ERROR_WANT_WRITE) + { + SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); + this->status = ISSL_HANDSHAKING; + return 0; } else { - ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s.", dhfile.c_str()); + CloseSession(); + return -1; } } + else if (ret > 0) + { + // Handshake complete. + VerifyCertificate(); -#ifndef _WIN32 - fclose(dhpfile); -#endif - -#ifdef INSPIRCD_OPENSSL_ENABLE_ECDH - SetupECDH(conf); -#endif - } + status = ISSL_OPEN; - void On005Numeric(std::string &output) - { - if (!sslports.empty()) - output.append(" SSL=" + sslports); - } + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); - ~ModuleSSLOpenSSL() - { - SSL_CTX_free(ctx); - SSL_CTX_free(clictx); - delete[] sessions; - } - - void OnUserConnect(LocalUser* user) - { - if (user->eh.GetIOHook() == this) + return 1; + } + else if (ret == 0) { - if (sessions[user->eh.GetFd()].sess) - { - if (!sessions[user->eh.GetFd()].cert->fingerprint.empty()) - user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"" - " and your SSL fingerprint is %s", user->nick.c_str(), SSL_get_cipher(sessions[user->eh.GetFd()].sess), sessions[user->eh.GetFd()].cert->fingerprint.c_str()); - else - user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), SSL_get_cipher(sessions[user->eh.GetFd()].sess)); - } + CloseSession(); } + return -1; } - void OnCleanup(int target_type, void* item) + void CloseSession() { - if (target_type == TYPE_USER) + if (sess) { - LocalUser* user = IS_LOCAL((User*)item); - - if (user && user->eh.GetIOHook() == this) - { - // User is using SSL, they're a local user, and they're using one of *our* SSL ports. - // Potentially there could be multiple SSL modules loaded at once on different ports. - ServerInstance->Users->QuitUser(user, "SSL module unloading"); - } + SSL_shutdown(sess); + SSL_free(sess); } + sess = NULL; + certificate = NULL; + status = ISSL_NONE; } - Version GetVersion() + void VerifyCertificate() { - return Version("Provides SSL support for clients", VF_VENDOR); - } + X509* cert; + ssl_cert* certinfo = new ssl_cert; + this->certificate = certinfo; + unsigned int n; + unsigned char md[EVP_MAX_MD_SIZE]; - void OnRequest(Request& request) - { - if (strcmp("GET_SSL_CERT", request.id) == 0) + cert = SSL_get_peer_certificate(sess); + + if (!cert) { - SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request); - int fd = req.sock->GetFd(); - issl_session* session = &sessions[fd]; + certinfo->error = "Could not get peer certificate: "+std::string(get_error()); + return; + } + + certinfo->invalid = (SSL_get_verify_result(sess) != X509_V_OK); - req.cert = session->cert; + if (!SelfSigned) + { + certinfo->unknownsigner = false; + certinfo->trusted = true; } - else if (!strcmp("GET_RAW_SSL_SESSION", request.id)) + else { - SSLRawSessionRequest& req = static_cast<SSLRawSessionRequest&>(request); - if ((req.fd >= 0) && (req.fd < ServerInstance->SE->GetMaxFds())) - req.data = reinterpret_cast<void*>(sessions[req.fd].sess); + certinfo->unknownsigner = true; + certinfo->trusted = false; } - } - - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) - { - int fd = user->GetFd(); - issl_session* session = &sessions[fd]; + char buf[512]; + X509_NAME_oneline(X509_get_subject_name(cert), buf, sizeof(buf)); + certinfo->dn = buf; + // Make sure there are no chars in the string that we consider invalid + if (certinfo->dn.find_first_of("\r\n") != std::string::npos) + certinfo->dn.clear(); - session->sess = SSL_new(ctx); - session->status = ISSL_NONE; - session->outbound = false; - session->data_to_write = false; + X509_NAME_oneline(X509_get_issuer_name(cert), buf, sizeof(buf)); + certinfo->issuer = buf; + if (certinfo->issuer.find_first_of("\r\n") != std::string::npos) + certinfo->issuer.clear(); - if (session->sess == NULL) - return; + if (!X509_digest(cert, profile->GetDigest(), md, &n)) + { + certinfo->error = "Out of memory generating fingerprint"; + } + else + { + certinfo->fingerprint = BinToHex(md, n); + } - if (SSL_set_fd(session->sess, fd) == 0) + if ((ASN1_UTCTIME_cmp_time_t(X509_get_notAfter(cert), ServerInstance->Time()) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_get_notBefore(cert), ServerInstance->Time()) == 0)) { - ServerInstance->Logs->Log("m_ssl_openssl",DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd); - return; + certinfo->error = "Not activated, or expired certificate"; } - Handshake(user, session); + X509_free(cert); } - void OnStreamSocketConnect(StreamSocket* user) + void SSLInfoCallback(int where, int rc) { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() -1)) - return; + if ((where & SSL_CB_HANDSHAKE_START) && (status == ISSL_OPEN)) + { + if (profile->AllowRenegotiation()) + return; - issl_session* session = &sessions[fd]; + // The other side is trying to renegotiate, kill the connection and change status + // to ISSL_NONE so CheckRenego() closes the session + status = ISSL_NONE; + BIO* bio = SSL_get_rbio(sess); + EventHandler* eh = static_cast<StreamSocket*>(bio->ptr); + SocketEngine::Shutdown(eh, 2); + } + } - session->sess = SSL_new(clictx); - session->status = ISSL_NONE; - session->outbound = true; - session->data_to_write = false; + bool CheckRenego(StreamSocket* sock) + { + if (status != ISSL_NONE) + return true; - if (session->sess == NULL) - return; + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Session %p killed, attempted to renegotiate", (void*)sess); + CloseSession(); + sock->SetError("Renegotiation is not allowed"); + return false; + } - if (SSL_set_fd(session->sess, fd) == 0) + // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error + int PrepareIO(StreamSocket* sock) + { + if (status == ISSL_OPEN) + return 1; + else if (status == ISSL_HANDSHAKING) { - ServerInstance->Logs->Log("m_ssl_openssl",DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd); - return; + // The handshake isn't finished, try to finish it + return Handshake(sock); } - Handshake(user, session); + CloseSession(); + return -1; } - void OnStreamSocketClose(StreamSocket* user) - { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; + // Calls our private SSLInfoCallback() + friend void StaticSSLInfoCallback(const SSL* ssl, int where, int rc); - CloseSession(&sessions[fd]); + public: + OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session, const reference<OpenSSL::Profile>& sslprofile) + : SSLIOHook(hookprov) + , sess(session) + , status(ISSL_NONE) + , data_to_write(false) + , profile(sslprofile) + { + // Create BIO instance and store a pointer to the socket in it which will be used by the read and write functions + BIO* bio = BIO_new(&biomethods); + bio->ptr = sock; + SSL_set_bio(sess, bio, bio); + + SSL_set_ex_data(sess, exdataindex, this); + sock->AddIOHook(this); + Handshake(sock); } - int OnStreamSocketRead(StreamSocket* user, std::string& recvq) + void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return -1; - - issl_session* session = &sessions[fd]; - - if (!session->sess) - { - CloseSession(session); - return -1; - } - - if (session->status == ISSL_HANDSHAKING) - { - // The handshake isn't finished and it wants to read, try to finish it. - if (!Handshake(user, session)) - { - // Couldn't resume handshake. - if (session->status == ISSL_NONE) - return -1; - return 0; - } - } + CloseSession(); + } - // If we resumed the handshake then session->status will be ISSL_OPEN + int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE + { + // Finish handshake if needed + int prepret = PrepareIO(user); + if (prepret <= 0) + return prepret; - if (session->status == ISSL_OPEN) + // If we resumed the handshake then this->status will be ISSL_OPEN { ERR_clear_error(); char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = SSL_read(session->sess, buffer, bufsiz); + int ret = SSL_read(sess, buffer, bufsiz); -#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION - if (!CheckRenego(user, session)) + if (!CheckRenego(user)) return -1; -#endif if (ret > 0) { recvq.append(buffer, ret); - int mask = 0; // Schedule a read if there is still data in the OpenSSL buffer - if (SSL_pending(session->sess) > 0) + if (SSL_pending(sess) > 0) mask |= FD_ADD_TRIAL_READ; - if (session->data_to_write) + if (data_to_write) mask |= FD_WANT_POLL_READ | FD_WANT_SINGLE_WRITE; if (mask != 0) - ServerInstance->SE->ChangeEventMask(user, mask); + SocketEngine::ChangeEventMask(user, mask); return 1; } else if (ret == 0) { // Client closed connection. - CloseSession(session); + CloseSession(); user->SetError("Connection closed"); return -1; } - else if (ret < 0) + else // if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_READ) { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ); + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ); return 0; } else if (err == SSL_ERROR_WANT_WRITE) { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); + SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); return 0; } else { - CloseSession(session); + CloseSession(); return -1; } } } - - return 0; } - int OnStreamSocketWrite(StreamSocket* user, std::string& buffer) + int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE { - int fd = user->GetFd(); - - issl_session* session = &sessions[fd]; - - if (!session->sess) - { - CloseSession(session); - return -1; - } - - session->data_to_write = true; + // Finish handshake if needed + int prepret = PrepareIO(user); + if (prepret <= 0) + return prepret; - if (session->status == ISSL_HANDSHAKING) - { - if (!Handshake(user, session)) - { - // Couldn't resume handshake. - if (session->status == ISSL_NONE) - return -1; - return 0; - } - } + data_to_write = true; - if (session->status == ISSL_OPEN) + // Session is ready for transferring application data + while (!sendq.empty()) { ERR_clear_error(); - int ret = SSL_write(session->sess, buffer.data(), buffer.size()); + FlattenSendQueue(sendq, profile->GetOutgoingRecordSize()); + const StreamSocket::SendQueue::Element& buffer = sendq.front(); + int ret = SSL_write(sess, buffer.data(), buffer.size()); -#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION - if (!CheckRenego(user, session)) + if (!CheckRenego(user)) return -1; -#endif if (ret == (int)buffer.length()) { - session->data_to_write = false; - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); - return 1; + // Wrote entire record, continue sending + sendq.pop_front(); } else if (ret > 0) { - buffer = buffer.substr(ret); - ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE); + sendq.erase_front(ret); + SocketEngine::ChangeEventMask(user, FD_WANT_SINGLE_WRITE); return 0; } else if (ret == 0) { - CloseSession(session); + CloseSession(); return -1; } - else if (ret < 0) + else // if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_WRITE) { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE); + SocketEngine::ChangeEventMask(user, FD_WANT_SINGLE_WRITE); return 0; } else if (err == SSL_ERROR_WANT_READ) { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ); + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ); return 0; } else { - CloseSession(session); + CloseSession(); return -1; } } } - return 0; + + data_to_write = false; + SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + return 1; } - bool Handshake(StreamSocket* user, issl_session* session) + void GetCiphersuite(std::string& out) const CXX11_OVERRIDE { - int ret; + if (!IsHandshakeDone()) + return; + out.append(SSL_get_version(sess)).push_back('-'); + out.append(SSL_get_cipher(sess)); + } - ERR_clear_error(); - if (session->outbound) - ret = SSL_connect(session->sess); - else - ret = SSL_accept(session->sess); + bool IsHandshakeDone() const { return (status == ISSL_OPEN); } +}; - if (ret < 0) - { - int err = SSL_get_error(session->sess, ret); +static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc) +{ + OpenSSLIOHook* hook = static_cast<OpenSSLIOHook*>(SSL_get_ex_data(ssl, exdataindex)); + hook->SSLInfoCallback(where, rc); +} - if (err == SSL_ERROR_WANT_READ) - { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); - session->status = ISSL_HANDSHAKING; - return true; - } - else if (err == SSL_ERROR_WANT_WRITE) - { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); - session->status = ISSL_HANDSHAKING; - return true; - } - else - { - CloseSession(session); - } +static int OpenSSL::BIOMethod::write(BIO* bio, const char* buffer, int size) +{ + BIO_clear_retry_flags(bio); - return false; - } - else if (ret > 0) - { - // Handshake complete. - VerifyCertificate(session, user); + StreamSocket* sock = static_cast<StreamSocket*>(bio->ptr); + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) + { + // Writes blocked earlier, don't retry syscall + BIO_set_retry_write(bio); + return -1; + } - session->status = ISSL_OPEN; + int ret = SocketEngine::Send(sock, buffer, size, 0); + if ((ret < size) && ((ret > 0) || (SocketEngine::IgnoreError()))) + { + // Blocked, set retry flag for OpenSSL + SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); + BIO_set_retry_write(bio); + } - ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); + return ret; +} - return true; - } - else if (ret == 0) - { - CloseSession(session); - } - return false; +static int OpenSSL::BIOMethod::read(BIO* bio, char* buffer, int size) +{ + BIO_clear_retry_flags(bio); + + StreamSocket* sock = static_cast<StreamSocket*>(bio->ptr); + if (sock->GetEventMask() & FD_READ_WILL_BLOCK) + { + // Reads blocked earlier, don't retry syscall + BIO_set_retry_read(bio); + return -1; } - void CloseSession(issl_session* session) + int ret = SocketEngine::Recv(sock, buffer, size, 0); + if ((ret < size) && ((ret > 0) || (SocketEngine::IgnoreError()))) { - if (session->sess) - { - SSL_shutdown(session->sess); - SSL_free(session->sess); - } + // Blocked, set retry flag for OpenSSL + SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK); + BIO_set_retry_read(bio); + } + + return ret; +} + +class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider +{ + reference<OpenSSL::Profile> profile; - session->sess = NULL; - session->status = ISSL_NONE; - session->cert = NULL; + public: + OpenSSLIOHookProvider(Module* mod, reference<OpenSSL::Profile>& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); } - void VerifyCertificate(issl_session* session, StreamSocket* user) + ~OpenSSLIOHookProvider() { - if (!session->sess || !user) - return; + ServerInstance->Modules->DelService(*this); + } - X509* cert; - ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; - unsigned int n; - unsigned char md[EVP_MAX_MD_SIZE]; - const EVP_MD *digest = use_sha ? EVP_sha1() : EVP_md5(); + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new OpenSSLIOHook(this, sock, profile->CreateServerSession(), profile); + } - cert = SSL_get_peer_certificate((SSL*)session->sess); + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new OpenSSLIOHook(this, sock, profile->CreateClientSession(), profile); + } +}; - if (!cert) - { - certinfo->error = "Could not get peer certificate: "+std::string(get_error()); - return; - } +class ModuleSSLOpenSSL : public Module +{ + typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList; - certinfo->invalid = (SSL_get_verify_result(session->sess) != X509_V_OK); + ProfileList profiles; - if (!SelfSigned) + void ReadProfiles() + { + ProfileList newprofiles; + ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile"); + if (tags.first == tags.second) { - certinfo->unknownsigner = false; - certinfo->trusted = true; + // Create a default profile named "openssl" + const std::string defname = "openssl"; + ConfigTag* tag = ServerInstance->Config->ConfValue(defname); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found, using settings from the <openssl> tag"); + + try + { + reference<OpenSSL::Profile> profile(new OpenSSL::Profile(defname, tag)); + newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); + } + catch (OpenSSL::Exception& ex) + { + throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason()); + } } - else + + for (ConfigIter i = tags.first; i != tags.second; ++i) { - certinfo->unknownsigner = true; - certinfo->trusted = false; + ConfigTag* tag = i->second; + if (tag->getString("provider") != "openssl") + continue; + + std::string name = tag->getString("name"); + if (name.empty()) + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation()); + continue; + } + + reference<OpenSSL::Profile> profile; + try + { + profile = new OpenSSL::Profile(name, tag); + } + catch (CoreException& ex) + { + throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); + } + + newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); } - char buf[512]; - X509_NAME_oneline(X509_get_subject_name(cert), buf, sizeof(buf)); - certinfo->dn = buf; - // Make sure there are no chars in the string that we consider invalid - if (certinfo->dn.find_first_of("\r\n") != std::string::npos) - certinfo->dn.clear(); + profiles.swap(newprofiles); + } - X509_NAME_oneline(X509_get_issuer_name(cert), buf, sizeof(buf)); - certinfo->issuer = buf; - if (certinfo->issuer.find_first_of("\r\n") != std::string::npos) - certinfo->issuer.clear(); + public: + ModuleSSLOpenSSL() + { + // Initialize OpenSSL + SSL_library_init(); + SSL_load_error_strings(); + } + + void init() CXX11_OVERRIDE + { + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "OpenSSL lib version \"%s\" module was compiled for \"" OPENSSL_VERSION_TEXT "\"", SSLeay_version(SSLEAY_VERSION)); + + // Register application specific data + char exdatastr[] = "inspircd"; + exdataindex = SSL_get_ex_new_index(0, exdatastr, NULL, NULL, NULL); + if (exdataindex < 0) + throw ModuleException("Failed to register application specific data"); + + ReadProfiles(); + } + + void OnModuleRehash(User* user, const std::string ¶m) CXX11_OVERRIDE + { + if (param != "ssl") + return; - if (!X509_digest(cert, digest, md, &n)) + try { - certinfo->error = "Out of memory generating fingerprint"; + ReadProfiles(); } - else + catch (ModuleException& ex) { - certinfo->fingerprint = irc::hex(md, n); + ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings."); } + } - if ((ASN1_UTCTIME_cmp_time_t(X509_get_notAfter(cert), ServerInstance->Time()) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_get_notBefore(cert), ServerInstance->Time()) == 0)) + void OnCleanup(int target_type, void* item) CXX11_OVERRIDE + { + if (target_type == TYPE_USER) { - certinfo->error = "Not activated, or expired certificate"; + LocalUser* user = IS_LOCAL((User*)item); + + if ((user) && (user->eh.GetModHook(this))) + { + // User is using SSL, they're a local user, and they're using one of *our* SSL ports. + // Potentially there could be multiple SSL modules loaded at once on different ports. + ServerInstance->Users->QuitUser(user, "SSL module unloading"); + } } + } - X509_free(cert); + ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE + { + const OpenSSLIOHook* const iohook = static_cast<OpenSSLIOHook*>(user->eh.GetModHook(this)); + if ((iohook) && (!iohook->IsHandshakeDone())) + return MOD_RES_DENY; + return MOD_RES_PASSTHRU; } -}; -static int error_callback(const char *str, size_t len, void *u) -{ - ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "SSL error: " + std::string(str, len - 1)); - - // - // XXX: Remove this line, it causes valgrind warnings... - // - // MD_update(&m, buf, j); - // - // - // ... ONLY JOKING! :-) - // - - return 0; -} + Version GetVersion() CXX11_OVERRIDE + { + return Version("Provides SSL support for clients", VF_VENDOR); + } +}; MODULE_INIT(ModuleSSLOpenSSL) |