summaryrefslogtreecommitdiff
path: root/src/modules/extra
diff options
context:
space:
mode:
Diffstat (limited to 'src/modules/extra')
-rw-r--r--src/modules/extra/m_geoip.cpp114
-rw-r--r--src/modules/extra/m_ldap.cpp677
-rw-r--r--src/modules/extra/m_ldapauth.cpp425
-rw-r--r--src/modules/extra/m_ldapoper.cpp255
-rw-r--r--src/modules/extra/m_mssql.cpp870
-rw-r--r--src/modules/extra/m_mysql.cpp132
-rw-r--r--src/modules/extra/m_pgsql.cpp184
-rw-r--r--src/modules/extra/m_regex_pcre.cpp52
-rw-r--r--src/modules/extra/m_regex_posix.cpp46
-rw-r--r--src/modules/extra/m_regex_re2.cpp86
-rw-r--r--src/modules/extra/m_regex_stdlib.cpp62
-rw-r--r--src/modules/extra/m_regex_tre.cpp52
-rw-r--r--src/modules/extra/m_sqlite3.cpp104
-rw-r--r--src/modules/extra/m_ssl_gnutls.cpp1741
-rw-r--r--src/modules/extra/m_ssl_mbedtls.cpp969
-rw-r--r--src/modules/extra/m_ssl_openssl.cpp1328
-rw-r--r--src/modules/extra/m_sslrehashsignal.cpp64
17 files changed, 4023 insertions, 3138 deletions
diff --git a/src/modules/extra/m_geoip.cpp b/src/modules/extra/m_geoip.cpp
index 03b7a55f7..0d7c2eb70 100644
--- a/src/modules/extra/m_geoip.cpp
+++ b/src/modules/extra/m_geoip.cpp
@@ -17,9 +17,25 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: find_compiler_flags("geoip" "")
+/// $LinkerFlags: find_linker_flags("geoip" "-lGeoIP")
+
+/// $PackageInfo: require_system("centos" "7.0") GeoIP-devel pkgconfig
+/// $PackageInfo: require_system("darwin") geoip pkg-config
+/// $PackageInfo: require_system("debian") libgeoip-dev pkg-config
+/// $PackageInfo: require_system("ubuntu") libgeoip-dev pkg-config
#include "inspircd.h"
#include "xline.h"
+#include "modules/stats.h"
+#include "modules/whois.h"
+
+// Fix warnings about the use of commas at end of enumerator lists on C++03.
+#if defined __clang__
+# pragma clang diagnostic ignored "-Wc++11-extensions"
+#elif defined __GNUC__
+# pragma GCC diagnostic ignored "-pedantic"
+#endif
#include <GeoIP.h>
@@ -27,41 +43,46 @@
# pragma comment(lib, "GeoIP.lib")
#endif
-/* $ModDesc: Provides a way to restrict users by country using GeoIP lookup */
-/* $LinkerFlags: -lGeoIP */
+enum
+{
+ // InspIRCd-specific.
+ RPL_WHOISCOUNTRY = 344
+};
-class ModuleGeoIP : public Module
+class ModuleGeoIP : public Module, public Stats::EventListener, public Whois::EventListener
{
- LocalStringExt ext;
+ StringExtItem ext;
+ bool extban;
GeoIP* gi;
- std::string* SetExt(LocalUser* user)
+ std::string* SetExt(User* user)
{
- const char* c = GeoIP_country_code_by_addr(gi, user->GetIPString());
+ const char* c = GeoIP_country_code_by_addr(gi, user->GetIPString().c_str());
if (!c)
c = "UNK";
- std::string* cc = new std::string(c);
- ext.set(user, cc);
- return cc;
+ ext.set(user, c);
+ return ext.get(user);
}
public:
- ModuleGeoIP() : ext("geoip_cc", this), gi(NULL)
+ ModuleGeoIP()
+ : Stats::EventListener(this)
+ , Whois::EventListener(this)
+ , ext("geoip_cc", ExtensionItem::EXT_USER, this)
+ , extban(true)
+ , gi(NULL)
{
}
- void init()
+ void init() CXX11_OVERRIDE
{
gi = GeoIP_new(GEOIP_STANDARD);
if (gi == NULL)
throw ModuleException("Unable to initialize geoip, are you missing GeoIP.dat?");
- ServerInstance->Modules->AddService(ext);
- Implementation eventlist[] = { I_OnSetConnectClass, I_OnSetUserIP, I_OnStats };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
-
- for (LocalUserList::const_iterator i = ServerInstance->Users->local_users.begin(); i != ServerInstance->Users->local_users.end(); ++i)
+ const UserManager::LocalList& list = ServerInstance->Users.GetLocalUsers();
+ for (UserManager::LocalList::const_iterator i = list.begin(); i != list.end(); ++i)
{
LocalUser* user = *i;
if ((user->registered == REG_ALL) && (!ext.get(user)))
@@ -77,12 +98,38 @@ class ModuleGeoIP : public Module
GeoIP_delete(gi);
}
- Version GetVersion()
+ void ReadConfig(ConfigStatus&) CXX11_OVERRIDE
{
- return Version("Provides a way to assign users to connect classes by country using GeoIP lookup", VF_VENDOR);
+ ConfigTag* tag = ServerInstance->Config->ConfValue("geoip");
+ extban = tag->getBool("extban");
}
- ModResult OnSetConnectClass(LocalUser* user, ConnectClass* myclass)
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("Provides a way to assign users to connect classes by country using GeoIP lookup", VF_OPTCOMMON|VF_VENDOR);
+ }
+
+ void On005Numeric(std::map<std::string, std::string>& tokens) CXX11_OVERRIDE
+ {
+ if (extban)
+ tokens["EXTBAN"].push_back('G');
+ }
+
+ ModResult OnCheckBan(User* user, Channel*, const std::string& mask) CXX11_OVERRIDE
+ {
+ if (extban && (mask.length() > 2) && (mask[0] == 'G') && (mask[1] == ':'))
+ {
+ std::string* cc = ext.get(user);
+ if (!cc)
+ cc = SetExt(user);
+
+ if (InspIRCd::Match(*cc, mask.substr(2)))
+ return MOD_RES_DENY;
+ }
+ return MOD_RES_PASSTHRU;
+ }
+
+ ModResult OnSetConnectClass(LocalUser* user, ConnectClass* myclass) CXX11_OVERRIDE
{
std::string* cc = ext.get(user);
if (!cc)
@@ -99,21 +146,36 @@ class ModuleGeoIP : public Module
return MOD_RES_DENY;
}
- void OnSetUserIP(LocalUser* user)
+ void OnSetUserIP(LocalUser* user) CXX11_OVERRIDE
{
// If user has sent NICK/USER, re-set the ExtItem as this is likely CGI:IRC changing the IP
if (user->registered == REG_NICKUSER)
SetExt(user);
}
- ModResult OnStats(char symbol, User* user, string_list &out)
+ void OnWhois(Whois::Context& whois) CXX11_OVERRIDE
+ {
+ // If the extban is disabled we don't expose users location.
+ if (!extban)
+ return;
+
+ std::string* cc = ext.get(whois.GetTarget());
+ if (!cc)
+ cc = SetExt(whois.GetTarget());
+
+ whois.SendLine(RPL_WHOISCOUNTRY, *cc, "is located in this country");
+ }
+
+ ModResult OnStats(Stats::Context& stats) CXX11_OVERRIDE
{
- if (symbol != 'G')
+ if (stats.GetSymbol() != 'G')
return MOD_RES_PASSTHRU;
unsigned int unknown = 0;
std::map<std::string, unsigned int> results;
- for (LocalUserList::const_iterator i = ServerInstance->Users->local_users.begin(); i != ServerInstance->Users->local_users.end(); ++i)
+
+ const UserManager::LocalList& list = ServerInstance->Users.GetLocalUsers();
+ for (UserManager::LocalList::const_iterator i = list.begin(); i != list.end(); ++i)
{
std::string* cc = ext.get(*i);
if (cc)
@@ -122,18 +184,16 @@ class ModuleGeoIP : public Module
unknown++;
}
- std::string p = ServerInstance->Config->ServerName + " 801 " + user->nick + " :GeoIPSTATS ";
for (std::map<std::string, unsigned int>::const_iterator i = results.begin(); i != results.end(); ++i)
{
- out.push_back(p + i->first + " " + ConvToStr(i->second));
+ stats.AddRow(801, "GeoIPSTATS " + i->first + " " + ConvToStr(i->second));
}
if (unknown)
- out.push_back(p + "Unknown " + ConvToStr(unknown));
+ stats.AddRow(801, "GeoIPSTATS Unknown " + ConvToStr(unknown));
return MOD_RES_DENY;
}
};
MODULE_INIT(ModuleGeoIP)
-
diff --git a/src/modules/extra/m_ldap.cpp b/src/modules/extra/m_ldap.cpp
new file mode 100644
index 000000000..8c2752dbf
--- /dev/null
+++ b/src/modules/extra/m_ldap.cpp
@@ -0,0 +1,677 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ * Copyright (C) 2013-2015 Adam <Adam@anope.org>
+ * Copyright (C) 2003-2015 Anope Team <team@anope.org>
+ *
+ * This file is part of InspIRCd. InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/// $LinkerFlags: -llber -lldap_r
+
+/// $PackageInfo: require_system("centos") openldap-devel
+/// $PackageInfo: require_system("debian") libldap2-dev
+/// $PackageInfo: require_system("ubuntu") libldap2-dev
+
+#include "inspircd.h"
+#include "modules/ldap.h"
+
+// Ignore OpenLDAP deprecation warnings on OS X Yosemite and newer.
+#if defined __APPLE__
+# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+#include <ldap.h>
+
+#ifdef _WIN32
+# pragma comment(lib, "libldap_r.lib")
+# pragma comment(lib, "liblber.lib")
+#endif
+
+class LDAPService;
+
+class LDAPRequest
+{
+ public:
+ LDAPService* service;
+ LDAPInterface* inter;
+ LDAPMessage* message; /* message returned by ldap_ */
+ LDAPResult* result; /* final result */
+ struct timeval tv;
+ QueryType type;
+
+ LDAPRequest(LDAPService* s, LDAPInterface* i)
+ : service(s)
+ , inter(i)
+ , message(NULL)
+ , result(NULL)
+ {
+ type = QUERY_UNKNOWN;
+ tv.tv_sec = 0;
+ tv.tv_usec = 100000;
+ }
+
+ virtual ~LDAPRequest()
+ {
+ delete result;
+ if (message != NULL)
+ ldap_msgfree(message);
+ }
+
+ virtual int run() = 0;
+};
+
+class LDAPBind : public LDAPRequest
+{
+ std::string who, pass;
+
+ public:
+ LDAPBind(LDAPService* s, LDAPInterface* i, const std::string& w, const std::string& p)
+ : LDAPRequest(s, i)
+ , who(w)
+ , pass(p)
+ {
+ type = QUERY_BIND;
+ }
+
+ int run() CXX11_OVERRIDE;
+};
+
+class LDAPSearch : public LDAPRequest
+{
+ std::string base;
+ int searchscope;
+ std::string filter;
+
+ public:
+ LDAPSearch(LDAPService* s, LDAPInterface* i, const std::string& b, int se, const std::string& f)
+ : LDAPRequest(s, i)
+ , base(b)
+ , searchscope(se)
+ , filter(f)
+ {
+ type = QUERY_SEARCH;
+ }
+
+ int run() CXX11_OVERRIDE;
+};
+
+class LDAPAdd : public LDAPRequest
+{
+ std::string dn;
+ LDAPMods attributes;
+
+ public:
+ LDAPAdd(LDAPService* s, LDAPInterface* i, const std::string& d, const LDAPMods& attr)
+ : LDAPRequest(s, i)
+ , dn(d)
+ , attributes(attr)
+ {
+ type = QUERY_ADD;
+ }
+
+ int run() CXX11_OVERRIDE;
+};
+
+class LDAPDel : public LDAPRequest
+{
+ std::string dn;
+
+ public:
+ LDAPDel(LDAPService* s, LDAPInterface* i, const std::string& d)
+ : LDAPRequest(s, i)
+ , dn(d)
+ {
+ type = QUERY_DELETE;
+ }
+
+ int run() CXX11_OVERRIDE;
+};
+
+class LDAPModify : public LDAPRequest
+{
+ std::string base;
+ LDAPMods attributes;
+
+ public:
+ LDAPModify(LDAPService* s, LDAPInterface* i, const std::string& b, const LDAPMods& attr)
+ : LDAPRequest(s, i)
+ , base(b)
+ , attributes(attr)
+ {
+ type = QUERY_MODIFY;
+ }
+
+ int run() CXX11_OVERRIDE;
+};
+
+class LDAPCompare : public LDAPRequest
+{
+ std::string dn, attr, val;
+
+ public:
+ LDAPCompare(LDAPService* s, LDAPInterface* i, const std::string& d, const std::string& a, const std::string& v)
+ : LDAPRequest(s, i)
+ , dn(d)
+ , attr(a)
+ , val(v)
+ {
+ type = QUERY_COMPARE;
+ }
+
+ int run() CXX11_OVERRIDE;
+};
+
+class LDAPService : public LDAPProvider, public SocketThread
+{
+ LDAP* con;
+ reference<ConfigTag> config;
+ time_t last_connect;
+ int searchscope;
+ time_t timeout;
+
+ public:
+ static LDAPMod** BuildMods(const LDAPMods& attributes)
+ {
+ LDAPMod** mods = new LDAPMod*[attributes.size() + 1];
+ memset(mods, 0, sizeof(LDAPMod*) * (attributes.size() + 1));
+ for (unsigned int x = 0; x < attributes.size(); ++x)
+ {
+ const LDAPModification& l = attributes[x];
+ LDAPMod* mod = new LDAPMod;
+ mods[x] = mod;
+
+ if (l.op == LDAPModification::LDAP_ADD)
+ mod->mod_op = LDAP_MOD_ADD;
+ else if (l.op == LDAPModification::LDAP_DEL)
+ mod->mod_op = LDAP_MOD_DELETE;
+ else if (l.op == LDAPModification::LDAP_REPLACE)
+ mod->mod_op = LDAP_MOD_REPLACE;
+ else if (l.op != 0)
+ {
+ FreeMods(mods);
+ throw LDAPException("Unknown LDAP operation");
+ }
+ mod->mod_type = strdup(l.name.c_str());
+ mod->mod_values = new char*[l.values.size() + 1];
+ memset(mod->mod_values, 0, sizeof(char*) * (l.values.size() + 1));
+ for (unsigned int j = 0, c = 0; j < l.values.size(); ++j)
+ if (!l.values[j].empty())
+ mod->mod_values[c++] = strdup(l.values[j].c_str());
+ }
+ return mods;
+ }
+
+ static void FreeMods(LDAPMod** mods)
+ {
+ for (unsigned int i = 0; mods[i] != NULL; ++i)
+ {
+ LDAPMod* mod = mods[i];
+ if (mod->mod_type != NULL)
+ free(mod->mod_type);
+ if (mod->mod_values != NULL)
+ {
+ for (unsigned int j = 0; mod->mod_values[j] != NULL; ++j)
+ free(mod->mod_values[j]);
+ delete[] mod->mod_values;
+ }
+ }
+ delete[] mods;
+ }
+
+ private:
+ void Reconnect()
+ {
+ // Only try one connect a minute. It is an expensive blocking operation
+ if (last_connect > ServerInstance->Time() - 60)
+ throw LDAPException("Unable to connect to LDAP service " + this->name + ": reconnecting too fast");
+ last_connect = ServerInstance->Time();
+
+ ldap_unbind_ext(this->con, NULL, NULL);
+ Connect();
+ }
+
+ void QueueRequest(LDAPRequest* r)
+ {
+ this->LockQueue();
+ this->queries.push_back(r);
+ this->UnlockQueueWakeup();
+ }
+
+ public:
+ typedef std::vector<LDAPRequest*> query_queue;
+ query_queue queries, results;
+ Mutex process_mutex; /* held when processing requests not in either queue */
+
+ LDAPService(Module* c, ConfigTag* tag)
+ : LDAPProvider(c, "LDAP/" + tag->getString("id"))
+ , con(NULL), config(tag), last_connect(0)
+ {
+ std::string scope = config->getString("searchscope");
+ if (stdalgo::string::equalsci(scope, "base"))
+ searchscope = LDAP_SCOPE_BASE;
+ else if (stdalgo::string::equalsci(scope, "onelevel"))
+ searchscope = LDAP_SCOPE_ONELEVEL;
+ else
+ searchscope = LDAP_SCOPE_SUBTREE;
+ timeout = config->getDuration("timeout", 5);
+
+ Connect();
+ }
+
+ ~LDAPService()
+ {
+ this->LockQueue();
+
+ for (unsigned int i = 0; i < this->queries.size(); ++i)
+ {
+ LDAPRequest* req = this->queries[i];
+
+ /* queries have no results yet */
+ req->result = new LDAPResult();
+ req->result->type = req->type;
+ req->result->error = "LDAP Interface is going away";
+ req->inter->OnError(*req->result);
+
+ delete req;
+ }
+ this->queries.clear();
+
+ for (unsigned int i = 0; i < this->results.size(); ++i)
+ {
+ LDAPRequest* req = this->results[i];
+
+ /* even though this may have already finished successfully we return that it didn't */
+ req->result->error = "LDAP Interface is going away";
+ req->inter->OnError(*req->result);
+
+ delete req;
+ }
+ this->results.clear();
+
+ this->UnlockQueue();
+
+ ldap_unbind_ext(this->con, NULL, NULL);
+ }
+
+ void Connect()
+ {
+ std::string server = config->getString("server");
+ int i = ldap_initialize(&this->con, server.c_str());
+ if (i != LDAP_SUCCESS)
+ throw LDAPException("Unable to connect to LDAP service " + this->name + ": " + ldap_err2string(i));
+
+ const int version = LDAP_VERSION3;
+ i = ldap_set_option(this->con, LDAP_OPT_PROTOCOL_VERSION, &version);
+ if (i != LDAP_OPT_SUCCESS)
+ {
+ ldap_unbind_ext(this->con, NULL, NULL);
+ this->con = NULL;
+ throw LDAPException("Unable to set protocol version for " + this->name + ": " + ldap_err2string(i));
+ }
+
+ const struct timeval tv = { 0, 0 };
+ i = ldap_set_option(this->con, LDAP_OPT_NETWORK_TIMEOUT, &tv);
+ if (i != LDAP_OPT_SUCCESS)
+ {
+ ldap_unbind_ext(this->con, NULL, NULL);
+ this->con = NULL;
+ throw LDAPException("Unable to set timeout for " + this->name + ": " + ldap_err2string(i));
+ }
+ }
+
+ void BindAsManager(LDAPInterface* i) CXX11_OVERRIDE
+ {
+ std::string binddn = config->getString("binddn");
+ std::string bindauth = config->getString("bindauth");
+ this->Bind(i, binddn, bindauth);
+ }
+
+ void Bind(LDAPInterface* i, const std::string& who, const std::string& pass) CXX11_OVERRIDE
+ {
+ LDAPBind* b = new LDAPBind(this, i, who, pass);
+ QueueRequest(b);
+ }
+
+ void Search(LDAPInterface* i, const std::string& base, const std::string& filter) CXX11_OVERRIDE
+ {
+ if (i == NULL)
+ throw LDAPException("No interface");
+
+ LDAPSearch* s = new LDAPSearch(this, i, base, searchscope, filter);
+ QueueRequest(s);
+ }
+
+ void Add(LDAPInterface* i, const std::string& dn, LDAPMods& attributes) CXX11_OVERRIDE
+ {
+ LDAPAdd* add = new LDAPAdd(this, i, dn, attributes);
+ QueueRequest(add);
+ }
+
+ void Del(LDAPInterface* i, const std::string& dn) CXX11_OVERRIDE
+ {
+ LDAPDel* del = new LDAPDel(this, i, dn);
+ QueueRequest(del);
+ }
+
+ void Modify(LDAPInterface* i, const std::string& base, LDAPMods& attributes) CXX11_OVERRIDE
+ {
+ LDAPModify* mod = new LDAPModify(this, i, base, attributes);
+ QueueRequest(mod);
+ }
+
+ void Compare(LDAPInterface* i, const std::string& dn, const std::string& attr, const std::string& val) CXX11_OVERRIDE
+ {
+ LDAPCompare* comp = new LDAPCompare(this, i, dn, attr, val);
+ QueueRequest(comp);
+ }
+
+ private:
+ void BuildReply(int res, LDAPRequest* req)
+ {
+ LDAPResult* ldap_result = req->result = new LDAPResult();
+ req->result->type = req->type;
+
+ if (res != LDAP_SUCCESS)
+ {
+ ldap_result->error = ldap_err2string(res);
+ return;
+ }
+
+ if (req->message == NULL)
+ {
+ return;
+ }
+
+ /* a search result */
+
+ for (LDAPMessage* cur = ldap_first_message(this->con, req->message); cur; cur = ldap_next_message(this->con, cur))
+ {
+ LDAPAttributes attributes;
+
+ char* dn = ldap_get_dn(this->con, cur);
+ if (dn != NULL)
+ {
+ attributes["dn"].push_back(dn);
+ ldap_memfree(dn);
+ dn = NULL;
+ }
+
+ BerElement* ber = NULL;
+
+ for (char* attr = ldap_first_attribute(this->con, cur, &ber); attr; attr = ldap_next_attribute(this->con, cur, ber))
+ {
+ berval** vals = ldap_get_values_len(this->con, cur, attr);
+ int count = ldap_count_values_len(vals);
+
+ std::vector<std::string> attrs;
+ for (int j = 0; j < count; ++j)
+ attrs.push_back(vals[j]->bv_val);
+ attributes[attr] = attrs;
+
+ ldap_value_free_len(vals);
+ ldap_memfree(attr);
+ }
+ if (ber != NULL)
+ ber_free(ber, 0);
+
+ ldap_result->messages.push_back(attributes);
+ }
+ }
+
+ void SendRequests()
+ {
+ process_mutex.Lock();
+
+ query_queue q;
+ this->LockQueue();
+ queries.swap(q);
+ this->UnlockQueue();
+
+ if (q.empty())
+ {
+ process_mutex.Unlock();
+ return;
+ }
+
+ for (unsigned int i = 0; i < q.size(); ++i)
+ {
+ LDAPRequest* req = q[i];
+ int ret = req->run();
+
+ if (ret == LDAP_SERVER_DOWN || ret == LDAP_TIMEOUT)
+ {
+ /* try again */
+ try
+ {
+ Reconnect();
+ }
+ catch (const LDAPException &)
+ {
+ }
+
+ ret = req->run();
+ }
+
+ BuildReply(ret, req);
+
+ this->LockQueue();
+ this->results.push_back(req);
+ this->UnlockQueue();
+ }
+
+ this->NotifyParent();
+
+ process_mutex.Unlock();
+ }
+
+ public:
+ void Run() CXX11_OVERRIDE
+ {
+ while (!this->GetExitFlag())
+ {
+ this->LockQueue();
+ if (this->queries.empty())
+ this->WaitForQueue();
+ this->UnlockQueue();
+
+ SendRequests();
+ }
+ }
+
+ void OnNotify() CXX11_OVERRIDE
+ {
+ query_queue r;
+
+ this->LockQueue();
+ this->results.swap(r);
+ this->UnlockQueue();
+
+ for (unsigned int i = 0; i < r.size(); ++i)
+ {
+ LDAPRequest* req = r[i];
+ LDAPInterface* li = req->inter;
+ LDAPResult* res = req->result;
+
+ if (!res->error.empty())
+ li->OnError(*res);
+ else
+ li->OnResult(*res);
+
+ delete req;
+ }
+ }
+
+ LDAP* GetConnection()
+ {
+ return con;
+ }
+};
+
+class ModuleLDAP : public Module
+{
+ typedef insp::flat_map<std::string, LDAPService*> ServiceMap;
+ ServiceMap LDAPServices;
+
+ public:
+ void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
+ {
+ ServiceMap conns;
+
+ ConfigTagList tags = ServerInstance->Config->ConfTags("database");
+ for (ConfigIter i = tags.first; i != tags.second; i++)
+ {
+ const reference<ConfigTag>& tag = i->second;
+
+ if (!stdalgo::string::equalsci(tag->getString("module"), "ldap"))
+ continue;
+
+ std::string id = tag->getString("id");
+
+ ServiceMap::iterator curr = LDAPServices.find(id);
+ if (curr == LDAPServices.end())
+ {
+ LDAPService* conn = new LDAPService(this, tag);
+ conns[id] = conn;
+
+ ServerInstance->Modules->AddService(*conn);
+ ServerInstance->Threads.Start(conn);
+ }
+ else
+ {
+ conns.insert(*curr);
+ LDAPServices.erase(curr);
+ }
+ }
+
+ for (ServiceMap::iterator i = LDAPServices.begin(); i != LDAPServices.end(); ++i)
+ {
+ LDAPService* conn = i->second;
+ ServerInstance->Modules->DelService(*conn);
+ conn->join();
+ conn->OnNotify();
+ delete conn;
+ }
+
+ LDAPServices.swap(conns);
+ }
+
+ void OnUnloadModule(Module* m) CXX11_OVERRIDE
+ {
+ for (ServiceMap::iterator it = this->LDAPServices.begin(); it != this->LDAPServices.end(); ++it)
+ {
+ LDAPService* s = it->second;
+
+ s->process_mutex.Lock();
+ s->LockQueue();
+
+ for (unsigned int i = s->queries.size(); i > 0; --i)
+ {
+ LDAPRequest* req = s->queries[i - 1];
+ LDAPInterface* li = req->inter;
+
+ if (li->creator == m)
+ {
+ s->queries.erase(s->queries.begin() + i - 1);
+ delete req;
+ }
+ }
+
+ for (unsigned int i = s->results.size(); i > 0; --i)
+ {
+ LDAPRequest* req = s->results[i - 1];
+ LDAPInterface* li = req->inter;
+
+ if (li->creator == m)
+ {
+ s->results.erase(s->results.begin() + i - 1);
+ delete req;
+ }
+ }
+
+ s->UnlockQueue();
+ s->process_mutex.Unlock();
+ }
+ }
+
+ ~ModuleLDAP()
+ {
+ for (ServiceMap::iterator i = LDAPServices.begin(); i != LDAPServices.end(); ++i)
+ {
+ LDAPService* conn = i->second;
+ conn->join();
+ conn->OnNotify();
+ delete conn;
+ }
+ }
+
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("LDAP support", VF_VENDOR);
+ }
+};
+
+int LDAPBind::run()
+{
+ berval cred;
+ cred.bv_val = strdup(pass.c_str());
+ cred.bv_len = pass.length();
+
+ int i = ldap_sasl_bind_s(service->GetConnection(), who.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL);
+
+ free(cred.bv_val);
+
+ return i;
+}
+
+int LDAPSearch::run()
+{
+ return ldap_search_ext_s(service->GetConnection(), base.c_str(), searchscope, filter.c_str(), NULL, 0, NULL, NULL, &tv, 0, &message);
+}
+
+int LDAPAdd::run()
+{
+ LDAPMod** mods = LDAPService::BuildMods(attributes);
+ int i = ldap_add_ext_s(service->GetConnection(), dn.c_str(), mods, NULL, NULL);
+ LDAPService::FreeMods(mods);
+ return i;
+}
+
+int LDAPDel::run()
+{
+ return ldap_delete_ext_s(service->GetConnection(), dn.c_str(), NULL, NULL);
+}
+
+int LDAPModify::run()
+{
+ LDAPMod** mods = LDAPService::BuildMods(attributes);
+ int i = ldap_modify_ext_s(service->GetConnection(), base.c_str(), mods, NULL, NULL);
+ LDAPService::FreeMods(mods);
+ return i;
+}
+
+int LDAPCompare::run()
+{
+ berval cred;
+ cred.bv_val = strdup(val.c_str());
+ cred.bv_len = val.length();
+
+ int ret = ldap_compare_ext_s(service->GetConnection(), dn.c_str(), attr.c_str(), &cred, NULL, NULL);
+
+ free(cred.bv_val);
+
+ return ret;
+
+}
+
+MODULE_INIT(ModuleLDAP)
diff --git a/src/modules/extra/m_ldapauth.cpp b/src/modules/extra/m_ldapauth.cpp
deleted file mode 100644
index 405bab082..000000000
--- a/src/modules/extra/m_ldapauth.cpp
+++ /dev/null
@@ -1,425 +0,0 @@
-/*
- * InspIRCd -- Internet Relay Chat Daemon
- *
- * Copyright (C) 2011 Pierre Carrier <pierre@spotify.com>
- * Copyright (C) 2009-2010 Robin Burchell <robin+git@viroteck.net>
- * Copyright (C) 2009 Daniel De Graaf <danieldg@inspircd.org>
- * Copyright (C) 2008 Pippijn van Steenhoven <pip88nl@gmail.com>
- * Copyright (C) 2008 Craig Edwards <craigedwards@brainbox.cc>
- * Copyright (C) 2008 Dennis Friis <peavey@inspircd.org>
- * Copyright (C) 2007 Carsten Valdemar Munk <carsten.munk+inspircd@gmail.com>
- *
- * This file is part of InspIRCd. InspIRCd is free software: you can
- * redistribute it and/or modify it under the terms of the GNU General Public
- * License as published by the Free Software Foundation, version 2.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
- * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
- * details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-
-#include "inspircd.h"
-#include "users.h"
-#include "channels.h"
-#include "modules.h"
-
-#include <ldap.h>
-
-#ifdef _WIN32
-# pragma comment(lib, "libldap.lib")
-# pragma comment(lib, "liblber.lib")
-#endif
-
-/* $ModDesc: Allow/Deny connections based upon answer from LDAP server */
-/* $LinkerFlags: -lldap */
-
-struct RAIILDAPString
-{
- char *str;
-
- RAIILDAPString(char *Str)
- : str(Str)
- {
- }
-
- ~RAIILDAPString()
- {
- ldap_memfree(str);
- }
-
- operator char*()
- {
- return str;
- }
-
- operator std::string()
- {
- return str;
- }
-};
-
-struct RAIILDAPMessage
-{
- RAIILDAPMessage()
- {
- }
-
- ~RAIILDAPMessage()
- {
- dealloc();
- }
-
- void dealloc()
- {
- ldap_msgfree(msg);
- }
-
- operator LDAPMessage*()
- {
- return msg;
- }
-
- LDAPMessage **operator &()
- {
- return &msg;
- }
-
- LDAPMessage *msg;
-};
-
-class ModuleLDAPAuth : public Module
-{
- LocalIntExt ldapAuthed;
- LocalStringExt ldapVhost;
- std::string base;
- std::string attribute;
- std::string ldapserver;
- std::string allowpattern;
- std::string killreason;
- std::string username;
- std::string password;
- std::string vhost;
- std::vector<std::string> whitelistedcidrs;
- std::vector<std::pair<std::string, std::string> > requiredattributes;
- int searchscope;
- bool verbose;
- bool useusername;
- LDAP *conn;
-
-public:
- ModuleLDAPAuth()
- : ldapAuthed("ldapauth", this)
- , ldapVhost("ldapauth_vhost", this)
- {
- conn = NULL;
- }
-
- void init()
- {
- ServerInstance->Modules->AddService(ldapAuthed);
- ServerInstance->Modules->AddService(ldapVhost);
- Implementation eventlist[] = { I_OnCheckReady, I_OnRehash,I_OnUserRegister, I_OnUserConnect };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- OnRehash(NULL);
- }
-
- ~ModuleLDAPAuth()
- {
- if (conn)
- ldap_unbind_ext(conn, NULL, NULL);
- }
-
- void OnRehash(User* user)
- {
- ConfigTag* tag = ServerInstance->Config->ConfValue("ldapauth");
- whitelistedcidrs.clear();
- requiredattributes.clear();
-
- base = tag->getString("baserdn");
- attribute = tag->getString("attribute");
- ldapserver = tag->getString("server");
- allowpattern = tag->getString("allowpattern");
- killreason = tag->getString("killreason");
- std::string scope = tag->getString("searchscope");
- username = tag->getString("binddn");
- password = tag->getString("bindauth");
- vhost = tag->getString("host");
- verbose = tag->getBool("verbose"); /* Set to true if failed connects should be reported to operators */
- useusername = tag->getBool("userfield");
-
- ConfigTagList whitelisttags = ServerInstance->Config->ConfTags("ldapwhitelist");
-
- for (ConfigIter i = whitelisttags.first; i != whitelisttags.second; ++i)
- {
- std::string cidr = i->second->getString("cidr");
- if (!cidr.empty()) {
- whitelistedcidrs.push_back(cidr);
- }
- }
-
- ConfigTagList attributetags = ServerInstance->Config->ConfTags("ldaprequire");
-
- for (ConfigIter i = attributetags.first; i != attributetags.second; ++i)
- {
- const std::string attr = i->second->getString("attribute");
- const std::string val = i->second->getString("value");
-
- if (!attr.empty() && !val.empty())
- requiredattributes.push_back(make_pair(attr, val));
- }
-
- if (scope == "base")
- searchscope = LDAP_SCOPE_BASE;
- else if (scope == "onelevel")
- searchscope = LDAP_SCOPE_ONELEVEL;
- else searchscope = LDAP_SCOPE_SUBTREE;
-
- Connect();
- }
-
- bool Connect()
- {
- if (conn != NULL)
- ldap_unbind_ext(conn, NULL, NULL);
- int res, v = LDAP_VERSION3;
- res = ldap_initialize(&conn, ldapserver.c_str());
- if (res != LDAP_SUCCESS)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "LDAP connection failed: %s", ldap_err2string(res));
- conn = NULL;
- return false;
- }
-
- res = ldap_set_option(conn, LDAP_OPT_PROTOCOL_VERSION, (void *)&v);
- if (res != LDAP_SUCCESS)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "LDAP set protocol to v3 failed: %s", ldap_err2string(res));
- ldap_unbind_ext(conn, NULL, NULL);
- conn = NULL;
- return false;
- }
- return true;
- }
-
- std::string SafeReplace(const std::string &text, std::map<std::string,
- std::string> &replacements)
- {
- std::string result;
- result.reserve(MAXBUF);
-
- for (unsigned int i = 0; i < text.length(); ++i) {
- char c = text[i];
- if (c == '$') {
- // find the first nonalpha
- i++;
- unsigned int start = i;
-
- while (i < text.length() - 1 && isalpha(text[i + 1]))
- ++i;
-
- std::string key = text.substr(start, (i - start) + 1);
- result.append(replacements[key]);
- } else {
- result.push_back(c);
- }
- }
-
- return result;
- }
-
- virtual void OnUserConnect(LocalUser *user)
- {
- std::string* cc = ldapVhost.get(user);
- if (cc)
- {
- user->ChangeDisplayedHost(cc->c_str());
- ldapVhost.unset(user);
- }
- }
-
- ModResult OnUserRegister(LocalUser* user)
- {
- if ((!allowpattern.empty()) && (InspIRCd::Match(user->nick,allowpattern)))
- {
- ldapAuthed.set(user,1);
- return MOD_RES_PASSTHRU;
- }
-
- for (std::vector<std::string>::iterator i = whitelistedcidrs.begin(); i != whitelistedcidrs.end(); i++)
- {
- if (InspIRCd::MatchCIDR(user->GetIPString(), *i, ascii_case_insensitive_map))
- {
- ldapAuthed.set(user,1);
- return MOD_RES_PASSTHRU;
- }
- }
-
- if (!CheckCredentials(user))
- {
- ServerInstance->Users->QuitUser(user, killreason);
- return MOD_RES_DENY;
- }
- return MOD_RES_PASSTHRU;
- }
-
- bool CheckCredentials(LocalUser* user)
- {
- if (conn == NULL)
- if (!Connect())
- return false;
-
- if (user->password.empty())
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (No password provided)", user->GetFullRealHost().c_str());
- return false;
- }
-
- int res;
- // bind anonymously if no bind DN and authentication are given in the config
- struct berval cred;
- cred.bv_val = const_cast<char*>(password.c_str());
- cred.bv_len = password.length();
-
- if ((res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS)
- {
- if (res == LDAP_SERVER_DOWN)
- {
- // Attempt to reconnect if the connection dropped
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('a', "LDAP server has gone away - reconnecting...");
- Connect();
- res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL);
- }
-
- if (res != LDAP_SUCCESS)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP bind failed: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
- ldap_unbind_ext(conn, NULL, NULL);
- conn = NULL;
- return false;
- }
- }
-
- RAIILDAPMessage msg;
- std::string what;
- std::string::size_type pos = user->password.find(':');
- // If a username is provided in PASS, use it, othewrise user their nick or ident
- if (pos != std::string::npos)
- {
- what = (attribute + "=" + user->password.substr(0, pos));
-
- // Trim the user: prefix, leaving just 'pass' for later password check
- user->password = user->password.substr(pos + 1);
- }
- else
- {
- what = (attribute + "=" + (useusername ? user->ident : user->nick));
- }
- if ((res = ldap_search_ext_s(conn, base.c_str(), searchscope, what.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg)) != LDAP_SUCCESS)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search failed: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
- return false;
- }
- if (ldap_count_entries(conn, msg) > 1)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search returned more than one result: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
- return false;
- }
-
- LDAPMessage *entry;
- if ((entry = ldap_first_entry(conn, msg)) == NULL)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search returned no results: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
- return false;
- }
- cred.bv_val = (char*)user->password.data();
- cred.bv_len = user->password.length();
- RAIILDAPString DN(ldap_get_dn(conn, entry));
- if ((res = ldap_sasl_bind_s(conn, DN, LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (%s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
- return false;
- }
-
- if (!requiredattributes.empty())
- {
- bool authed = false;
-
- for (std::vector<std::pair<std::string, std::string> >::const_iterator it = requiredattributes.begin(); it != requiredattributes.end(); ++it)
- {
- const std::string &attr = it->first;
- const std::string &val = it->second;
-
- struct berval attr_value;
- attr_value.bv_val = const_cast<char*>(val.c_str());
- attr_value.bv_len = val.length();
-
- ServerInstance->Logs->Log("m_ldapauth", DEBUG, "LDAP compare: %s=%s", attr.c_str(), val.c_str());
-
- authed = (ldap_compare_ext_s(conn, DN, attr.c_str(), &attr_value, NULL, NULL) == LDAP_COMPARE_TRUE);
-
- if (authed)
- break;
- }
-
- if (!authed)
- {
- if (verbose)
- ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (Lacks required LDAP attributes)", user->GetFullRealHost().c_str());
- return false;
- }
- }
-
- if (!vhost.empty())
- {
- irc::commasepstream stream(DN);
-
- // mashed map of key:value parts of the DN
- std::map<std::string, std::string> dnParts;
-
- std::string dnPart;
- while (stream.GetToken(dnPart))
- {
- pos = dnPart.find('=');
- if (pos == std::string::npos) // malformed
- continue;
-
- std::string key = dnPart.substr(0, pos);
- std::string value = dnPart.substr(pos + 1, dnPart.length() - pos + 1); // +1s to skip the = itself
- dnParts[key] = value;
- }
-
- // change host according to config key
- ldapVhost.set(user, SafeReplace(vhost, dnParts));
- }
-
- ldapAuthed.set(user,1);
- return true;
- }
-
- ModResult OnCheckReady(LocalUser* user)
- {
- return ldapAuthed.get(user) ? MOD_RES_PASSTHRU : MOD_RES_DENY;
- }
-
- Version GetVersion()
- {
- return Version("Allow/Deny connections based upon answer from LDAP server", VF_VENDOR);
- }
-
-};
-
-MODULE_INIT(ModuleLDAPAuth)
diff --git a/src/modules/extra/m_ldapoper.cpp b/src/modules/extra/m_ldapoper.cpp
deleted file mode 100644
index 1f46361d4..000000000
--- a/src/modules/extra/m_ldapoper.cpp
+++ /dev/null
@@ -1,255 +0,0 @@
-/*
- * InspIRCd -- Internet Relay Chat Daemon
- *
- * Copyright (C) 2009 Robin Burchell <robin+git@viroteck.net>
- * Copyright (C) 2008 Pippijn van Steenhoven <pip88nl@gmail.com>
- * Copyright (C) 2008 Craig Edwards <craigedwards@brainbox.cc>
- * Copyright (C) 2007 Carsten Valdemar Munk <carsten.munk+inspircd@gmail.com>
- *
- * This file is part of InspIRCd. InspIRCd is free software: you can
- * redistribute it and/or modify it under the terms of the GNU General Public
- * License as published by the Free Software Foundation, version 2.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
- * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
- * details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-
-#include "inspircd.h"
-#include "users.h"
-#include "channels.h"
-#include "modules.h"
-
-#include <ldap.h>
-
-#ifdef _WIN32
-# pragma comment(lib, "libldap.lib")
-# pragma comment(lib, "liblber.lib")
-#endif
-
-/* $ModDesc: Adds the ability to authenticate opers via LDAP */
-/* $LinkerFlags: -lldap */
-
-// Duplicated code, also found in cmd_oper and m_sqloper
-static bool OneOfMatches(const char* host, const char* ip, const std::string& hostlist)
-{
- std::stringstream hl(hostlist);
- std::string xhost;
- while (hl >> xhost)
- {
- if (InspIRCd::Match(host, xhost, ascii_case_insensitive_map) || InspIRCd::MatchCIDR(ip, xhost, ascii_case_insensitive_map))
- {
- return true;
- }
- }
- return false;
-}
-
-struct RAIILDAPString
-{
- char *str;
-
- RAIILDAPString(char *Str)
- : str(Str)
- {
- }
-
- ~RAIILDAPString()
- {
- ldap_memfree(str);
- }
-
- operator char*()
- {
- return str;
- }
-
- operator std::string()
- {
- return str;
- }
-};
-
-class ModuleLDAPAuth : public Module
-{
- std::string base;
- std::string ldapserver;
- std::string username;
- std::string password;
- std::string attribute;
- int searchscope;
- LDAP *conn;
-
- bool HandleOper(LocalUser* user, const std::string& opername, const std::string& inputpass)
- {
- OperIndex::iterator it = ServerInstance->Config->oper_blocks.find(opername);
- if (it == ServerInstance->Config->oper_blocks.end())
- return false;
-
- ConfigTag* tag = it->second->oper_block;
- if (!tag)
- return false;
-
- std::string acceptedhosts = tag->getString("host");
- std::string hostname = user->ident + "@" + user->host;
- if (!OneOfMatches(hostname.c_str(), user->GetIPString(), acceptedhosts))
- return false;
-
- if (!LookupOper(opername, inputpass))
- return false;
-
- user->Oper(it->second);
- return true;
- }
-
-public:
- ModuleLDAPAuth()
- : conn(NULL)
- {
- }
-
- void init()
- {
- Implementation eventlist[] = { I_OnRehash, I_OnPreCommand };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- OnRehash(NULL);
- }
-
- virtual ~ModuleLDAPAuth()
- {
- if (conn)
- ldap_unbind_ext(conn, NULL, NULL);
- }
-
- virtual void OnRehash(User* user)
- {
- ConfigTag* tag = ServerInstance->Config->ConfValue("ldapoper");
-
- base = tag->getString("baserdn");
- ldapserver = tag->getString("server");
- std::string scope = tag->getString("searchscope");
- username = tag->getString("binddn");
- password = tag->getString("bindauth");
- attribute = tag->getString("attribute");
-
- if (scope == "base")
- searchscope = LDAP_SCOPE_BASE;
- else if (scope == "onelevel")
- searchscope = LDAP_SCOPE_ONELEVEL;
- else searchscope = LDAP_SCOPE_SUBTREE;
-
- Connect();
- }
-
- bool Connect()
- {
- if (conn != NULL)
- ldap_unbind_ext(conn, NULL, NULL);
- int res, v = LDAP_VERSION3;
- res = ldap_initialize(&conn, ldapserver.c_str());
- if (res != LDAP_SUCCESS)
- {
- conn = NULL;
- return false;
- }
-
- res = ldap_set_option(conn, LDAP_OPT_PROTOCOL_VERSION, (void *)&v);
- if (res != LDAP_SUCCESS)
- {
- ldap_unbind_ext(conn, NULL, NULL);
- conn = NULL;
- return false;
- }
- return true;
- }
-
- ModResult OnPreCommand(std::string& command, std::vector<std::string>& parameters, LocalUser* user, bool validated, const std::string& original_line)
- {
- if (validated && command == "OPER" && parameters.size() >= 2)
- {
- if (HandleOper(user, parameters[0], parameters[1]))
- return MOD_RES_DENY;
- }
- return MOD_RES_PASSTHRU;
- }
-
- bool LookupOper(const std::string& opername, const std::string& opassword)
- {
- if (conn == NULL)
- if (!Connect())
- return false;
-
- int res;
- char* authpass = strdup(password.c_str());
- // bind anonymously if no bind DN and authentication are given in the config
- struct berval cred;
- cred.bv_val = authpass;
- cred.bv_len = password.length();
-
- if ((res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS)
- {
- if (res == LDAP_SERVER_DOWN)
- {
- // Attempt to reconnect if the connection dropped
- ServerInstance->SNO->WriteToSnoMask('a', "LDAP server has gone away - reconnecting...");
- Connect();
- res = ldap_sasl_bind_s(conn, username.c_str(), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL);
- }
-
- if (res != LDAP_SUCCESS)
- {
- free(authpass);
- ldap_unbind_ext(conn, NULL, NULL);
- conn = NULL;
- return false;
- }
- }
- free(authpass);
-
- LDAPMessage *msg, *entry;
- std::string what = attribute + "=" + opername;
- if ((res = ldap_search_ext_s(conn, base.c_str(), searchscope, what.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg)) != LDAP_SUCCESS)
- {
- return false;
- }
- if (ldap_count_entries(conn, msg) > 1)
- {
- ldap_msgfree(msg);
- return false;
- }
- if ((entry = ldap_first_entry(conn, msg)) == NULL)
- {
- ldap_msgfree(msg);
- return false;
- }
- authpass = strdup(opassword.c_str());
- cred.bv_val = authpass;
- cred.bv_len = opassword.length();
- RAIILDAPString DN(ldap_get_dn(conn, entry));
- if ((res = ldap_sasl_bind_s(conn, DN, LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) == LDAP_SUCCESS)
- {
- free(authpass);
- ldap_msgfree(msg);
- return true;
- }
- else
- {
- free(authpass);
- ldap_msgfree(msg);
- return false;
- }
- }
-
- virtual Version GetVersion()
- {
- return Version("Adds the ability to authenticate opers via LDAP", VF_VENDOR);
- }
-
-};
-
-MODULE_INIT(ModuleLDAPAuth)
diff --git a/src/modules/extra/m_mssql.cpp b/src/modules/extra/m_mssql.cpp
deleted file mode 100644
index 598f9aac9..000000000
--- a/src/modules/extra/m_mssql.cpp
+++ /dev/null
@@ -1,870 +0,0 @@
-/*
- * InspIRCd -- Internet Relay Chat Daemon
- *
- * Copyright (C) 2008-2009 Dennis Friis <peavey@inspircd.org>
- * Copyright (C) 2009 Daniel De Graaf <danieldg@inspircd.org>
- * Copyright (C) 2008-2009 Craig Edwards <craigedwards@brainbox.cc>
- * Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
- * Copyright (C) 2008 Pippijn van Steenhoven <pip88nl@gmail.com>
- *
- * This file is part of InspIRCd. InspIRCd is free software: you can
- * redistribute it and/or modify it under the terms of the GNU General Public
- * License as published by the Free Software Foundation, version 2.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
- * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
- * details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-
-#include "inspircd.h"
-#include <tds.h>
-#include <tdsconvert.h>
-#include "users.h"
-#include "channels.h"
-#include "modules.h"
-
-#include "m_sqlv2.h"
-
-/* $ModDesc: MsSQL provider */
-/* $CompileFlags: exec("grep VERSION_NO /usr/include/tdsver.h 2>/dev/null | perl -e 'print "-D_TDSVER=".((<> =~ /freetds v(\d+\.\d+)/i) ? $1*100 : 0);'") */
-/* $LinkerFlags: -ltds */
-/* $ModDep: m_sqlv2.h */
-
-class SQLConn;
-class MsSQLResult;
-class ModuleMsSQL;
-
-typedef std::map<std::string, SQLConn*> ConnMap;
-typedef std::deque<MsSQLResult*> ResultQueue;
-
-unsigned long count(const char * const str, char a)
-{
- unsigned long n = 0;
- for (const char *p = str; *p; ++p)
- {
- if (*p == '?')
- ++n;
- }
- return n;
-}
-
-ConnMap connections;
-Mutex* ResultsMutex;
-Mutex* LoggingMutex;
-
-class QueryThread : public SocketThread
-{
- private:
- ModuleMsSQL* const Parent;
- public:
- QueryThread(ModuleMsSQL* mod) : Parent(mod) { }
- ~QueryThread() { }
- virtual void Run();
- virtual void OnNotify();
-};
-
-class MsSQLResult : public SQLresult
-{
- private:
- int currentrow;
- int rows;
- int cols;
-
- std::vector<std::string> colnames;
- std::vector<SQLfieldList> fieldlists;
- SQLfieldList emptyfieldlist;
-
- SQLfieldList* fieldlist;
- SQLfieldMap* fieldmap;
-
- public:
- MsSQLResult(Module* self, Module* to, unsigned int rid)
- : SQLresult(self, to, rid), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
- {
- }
-
- ~MsSQLResult()
- {
- }
-
- void AddRow(int colsnum, char **dat, char **colname)
- {
- colnames.clear();
- cols = colsnum;
- for (int i = 0; i < colsnum; i++)
- {
- fieldlists.resize(fieldlists.size()+1);
- colnames.push_back(colname[i]);
- SQLfield sf(dat[i] ? dat[i] : "", dat[i] ? false : true);
- fieldlists[rows].push_back(sf);
- }
- rows++;
- }
-
- void UpdateAffectedCount()
- {
- rows++;
- }
-
- virtual int Rows()
- {
- return rows;
- }
-
- virtual int Cols()
- {
- return cols;
- }
-
- virtual std::string ColName(int column)
- {
- if (column < (int)colnames.size())
- {
- return colnames[column];
- }
- else
- {
- throw SQLbadColName();
- }
- return "";
- }
-
- virtual int ColNum(const std::string &column)
- {
- for (unsigned int i = 0; i < colnames.size(); i++)
- {
- if (column == colnames[i])
- return i;
- }
- throw SQLbadColName();
- return 0;
- }
-
- virtual SQLfield GetValue(int row, int column)
- {
- if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
- {
- return fieldlists[row][column];
- }
-
- throw SQLbadColName();
-
- /* XXX: We never actually get here because of the throw */
- return SQLfield("",true);
- }
-
- virtual SQLfieldList& GetRow()
- {
- if (currentrow < rows)
- return fieldlists[currentrow];
- else
- return emptyfieldlist;
- }
-
- virtual SQLfieldMap& GetRowMap()
- {
- /* In an effort to reduce overhead we don't actually allocate the map
- * until the first time it's needed...so...
- */
- if(fieldmap)
- {
- fieldmap->clear();
- }
- else
- {
- fieldmap = new SQLfieldMap;
- }
-
- if (currentrow < rows)
- {
- for (int i = 0; i < Cols(); i++)
- {
- fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
- }
- currentrow++;
- }
-
- return *fieldmap;
- }
-
- virtual SQLfieldList* GetRowPtr()
- {
- fieldlist = new SQLfieldList();
-
- if (currentrow < rows)
- {
- for (int i = 0; i < Rows(); i++)
- {
- fieldlist->push_back(fieldlists[currentrow][i]);
- }
- currentrow++;
- }
- return fieldlist;
- }
-
- virtual SQLfieldMap* GetRowMapPtr()
- {
- fieldmap = new SQLfieldMap();
-
- if (currentrow < rows)
- {
- for (int i = 0; i < Cols(); i++)
- {
- fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
- }
- currentrow++;
- }
-
- return fieldmap;
- }
-
- virtual void Free(SQLfieldMap* fm)
- {
- delete fm;
- }
-
- virtual void Free(SQLfieldList* fl)
- {
- delete fl;
- }
-};
-
-class SQLConn : public classbase
-{
- private:
- ResultQueue results;
- Module* mod;
- SQLhost host;
- TDSLOGIN* login;
- TDSSOCKET* sock;
- TDSCONTEXT* context;
-
- public:
- QueryQueue queue;
-
- SQLConn(Module* m, const SQLhost& hi)
- : mod(m), host(hi), login(NULL), sock(NULL), context(NULL)
- {
- if (OpenDB())
- {
- std::string query("USE " + host.name);
- if (tds_submit_query(sock, query.c_str()) == TDS_SUCCEED)
- {
- if (tds_process_simple_query(sock) != TDS_SUCCEED)
- {
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: Could not select database " + host.name + " for DB with id: " + host.id);
- LoggingMutex->Unlock();
- CloseDB();
- }
- }
- else
- {
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: Could not select database " + host.name + " for DB with id: " + host.id);
- LoggingMutex->Unlock();
- CloseDB();
- }
- }
- else
- {
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: Could not connect to DB with id: " + host.id);
- LoggingMutex->Unlock();
- CloseDB();
- }
- }
-
- ~SQLConn()
- {
- CloseDB();
- }
-
- SQLerror Query(SQLrequest* req)
- {
- if (!sock)
- return SQLerror(SQL_BAD_CONN, "Socket was NULL, check if SQL server is running.");
-
- /* Pointer to the buffer we screw around with substitution in */
- char* query;
-
- /* Pointer to the current end of query, where we append new stuff */
- char* queryend;
-
- /* Total length of the unescaped parameters */
- unsigned long maxparamlen, paramcount;
-
- /* The length of the longest parameter */
- maxparamlen = 0;
-
- for(ParamL::iterator i = req->query.p.begin(); i != req->query.p.end(); i++)
- {
- if (i->size() > maxparamlen)
- maxparamlen = i->size();
- }
-
- /* How many params are there in the query? */
- paramcount = count(req->query.q.c_str(), '?');
-
- /* This stores copy of params to be inserted with using numbered params 1;3B*/
- ParamL paramscopy(req->query.p);
-
- /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
- * sizeofquery + (maxtotalparamlength*2) + 1
- *
- * The +1 is for null-terminating the string
- */
-
- query = new char[req->query.q.length() + (maxparamlen*paramcount*2) + 1];
- queryend = query;
-
- for(unsigned long i = 0; i < req->query.q.length(); i++)
- {
- if(req->query.q[i] == '?')
- {
- /* We found a place to substitute..what fun.
- * use mssql calls to escape and write the
- * escaped string onto the end of our query buffer,
- * then we "just" need to make sure queryend is
- * pointing at the right place.
- */
-
- /* Is it numbered parameter?
- */
-
- bool numbered;
- numbered = false;
-
- /* Numbered parameter number :|
- */
- unsigned int paramnum;
- paramnum = 0;
-
- /* Let's check if it's a numbered param. And also calculate it's number.
- */
-
- while ((i < req->query.q.length() - 1) && (req->query.q[i+1] >= '0') && (req->query.q[i+1] <= '9'))
- {
- numbered = true;
- ++i;
- paramnum = paramnum * 10 + req->query.q[i] - '0';
- }
-
- if (paramnum > paramscopy.size() - 1)
- {
- /* index is out of range!
- */
- numbered = false;
- }
-
- if (numbered)
- {
- /* Custom escaping for this one. converting ' to '' should make SQL Server happy. Ugly but fast :]
- */
- char* escaped = new char[(paramscopy[paramnum].length() * 2) + 1];
- char* escend = escaped;
- for (std::string::iterator p = paramscopy[paramnum].begin(); p < paramscopy[paramnum].end(); p++)
- {
- if (*p == '\'')
- {
- *escend = *p;
- escend++;
- *escend = *p;
- }
- *escend = *p;
- escend++;
- }
- *escend = 0;
-
- for (char* n = escaped; *n; n++)
- {
- *queryend = *n;
- queryend++;
- }
- delete[] escaped;
- }
- else if (req->query.p.size())
- {
- /* Custom escaping for this one. converting ' to '' should make SQL Server happy. Ugly but fast :]
- */
- char* escaped = new char[(req->query.p.front().length() * 2) + 1];
- char* escend = escaped;
- for (std::string::iterator p = req->query.p.front().begin(); p < req->query.p.front().end(); p++)
- {
- if (*p == '\'')
- {
- *escend = *p;
- escend++;
- *escend = *p;
- }
- *escend = *p;
- escend++;
- }
- *escend = 0;
-
- for (char* n = escaped; *n; n++)
- {
- *queryend = *n;
- queryend++;
- }
- delete[] escaped;
- req->query.p.pop_front();
- }
- else
- break;
- }
- else
- {
- *queryend = req->query.q[i];
- queryend++;
- }
- }
- *queryend = 0;
- req->query.q = query;
-
- MsSQLResult* res = new MsSQLResult((Module*)mod, req->source, req->id);
- res->dbid = host.id;
- res->query = req->query.q;
-
- char* msquery = strdup(req->query.q.data());
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql",DEBUG,"doing Query: %s",msquery);
- LoggingMutex->Unlock();
- if (tds_submit_query(sock, msquery) != TDS_SUCCEED)
- {
- std::string error("failed to execute: "+std::string(req->query.q.data()));
- delete[] query;
- delete res;
- free(msquery);
- return SQLerror(SQL_QSEND_FAIL, error);
- }
- delete[] query;
- free(msquery);
-
- int tds_res;
- while (tds_process_tokens(sock, &tds_res, NULL, TDS_TOKEN_RESULTS) == TDS_SUCCEED)
- {
- //ServerInstance->Logs->Log("m_mssql",DEBUG,"<******> result type: %d", tds_res);
- //ServerInstance->Logs->Log("m_mssql",DEBUG,"AFFECTED ROWS: %d", sock->rows_affected);
- switch (tds_res)
- {
- case TDS_ROWFMT_RESULT:
- break;
-
- case TDS_DONE_RESULT:
- if (sock->rows_affected > -1)
- {
- for (int c = 0; c < sock->rows_affected; c++) res->UpdateAffectedCount();
- continue;
- }
- break;
-
- case TDS_ROW_RESULT:
- while (tds_process_tokens(sock, &tds_res, NULL, TDS_STOPAT_ROWFMT|TDS_RETURN_DONE|TDS_RETURN_ROW) == TDS_SUCCEED)
- {
- if (tds_res != TDS_ROW_RESULT)
- break;
-
- if (!sock->current_results)
- continue;
-
- if (sock->res_info->row_count > 0)
- {
- int cols = sock->res_info->num_cols;
- char** name = new char*[MAXBUF];
- char** data = new char*[MAXBUF];
- for (int j=0; j<cols; j++)
- {
- TDSCOLUMN* col = sock->current_results->columns[j];
- name[j] = col->column_name;
-
- int ctype;
- int srclen;
- unsigned char* src;
- CONV_RESULT dres;
- ctype = tds_get_conversion_type(col->column_type, col->column_size);
-#if _TDSVER >= 82
- src = col->column_data;
-#else
- src = &(sock->current_results->current_row[col->column_offset]);
-#endif
- srclen = col->column_cur_size;
- tds_convert(sock->tds_ctx, ctype, (TDS_CHAR *) src, srclen, SYBCHAR, &dres);
- data[j] = (char*)dres.ib;
- }
- ResultReady(res, cols, data, name);
- }
- }
- break;
-
- default:
- break;
- }
- }
- ResultsMutex->Lock();
- results.push_back(res);
- ResultsMutex->Unlock();
- return SQLerror();
- }
-
- static int HandleMessage(const TDSCONTEXT * pContext, TDSSOCKET * pTdsSocket, TDSMESSAGE * pMessage)
- {
- SQLConn* sc = (SQLConn*)pContext->parent;
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql", DEBUG, "Message for DB with id: %s -> %s", sc->host.id.c_str(), pMessage->message);
- LoggingMutex->Unlock();
- return 0;
- }
-
- static int HandleError(const TDSCONTEXT * pContext, TDSSOCKET * pTdsSocket, TDSMESSAGE * pMessage)
- {
- SQLConn* sc = (SQLConn*)pContext->parent;
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql", DEFAULT, "Error for DB with id: %s -> %s", sc->host.id.c_str(), pMessage->message);
- LoggingMutex->Unlock();
- return 0;
- }
-
- void ResultReady(MsSQLResult *res, int cols, char **data, char **colnames)
- {
- res->AddRow(cols, data, colnames);
- }
-
- void AffectedReady(MsSQLResult *res)
- {
- res->UpdateAffectedCount();
- }
-
- bool OpenDB()
- {
- CloseDB();
-
- TDSCONNECTION* conn = NULL;
-
- login = tds_alloc_login();
- tds_set_app(login, "TSQL");
- tds_set_library(login,"TDS-Library");
- tds_set_host(login, "");
- tds_set_server(login, host.host.c_str());
- tds_set_server_addr(login, host.host.c_str());
- tds_set_user(login, host.user.c_str());
- tds_set_passwd(login, host.pass.c_str());
- tds_set_port(login, host.port);
- tds_set_packet(login, 512);
-
- context = tds_alloc_context(this);
- context->msg_handler = HandleMessage;
- context->err_handler = HandleError;
-
- sock = tds_alloc_socket(context, 512);
- tds_set_parent(sock, NULL);
-
- conn = tds_read_config_info(NULL, login, context->locale);
-
- if (tds_connect(sock, conn) == TDS_SUCCEED)
- {
- tds_free_connection(conn);
- return 1;
- }
- tds_free_connection(conn);
- return 0;
- }
-
- void CloseDB()
- {
- if (sock)
- {
- tds_free_socket(sock);
- sock = NULL;
- }
- if (context)
- {
- tds_free_context(context);
- context = NULL;
- }
- if (login)
- {
- tds_free_login(login);
- login = NULL;
- }
- }
-
- SQLhost GetConfHost()
- {
- return host;
- }
-
- void SendResults()
- {
- while (results.size())
- {
- MsSQLResult* res = results[0];
- ResultsMutex->Lock();
- if (res->dest)
- {
- res->Send();
- }
- else
- {
- /* If the client module is unloaded partway through a query then the provider will set
- * the pointer to NULL. We cannot just cancel the query as the result will still come
- * through at some point...and it could get messy if we play with invalid pointers...
- */
- delete res;
- }
- results.pop_front();
- ResultsMutex->Unlock();
- }
- }
-
- void ClearResults()
- {
- while (results.size())
- {
- MsSQLResult* res = results[0];
- delete res;
- results.pop_front();
- }
- }
-
- void DoLeadingQuery()
- {
- SQLrequest* req = queue.front();
- req->error = Query(req);
- }
-
-};
-
-
-class ModuleMsSQL : public Module
-{
- private:
- unsigned long currid;
- QueryThread* queryDispatcher;
- ServiceProvider sqlserv;
-
- public:
- ModuleMsSQL()
- : currid(0), sqlserv(this, "SQL/mssql", SERVICE_DATA)
- {
- LoggingMutex = new Mutex();
- ResultsMutex = new Mutex();
- queryDispatcher = new QueryThread(this);
- }
-
- void init()
- {
- ReadConf();
-
- ServerInstance->Threads->Start(queryDispatcher);
-
- Implementation eventlist[] = { I_OnRehash };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- ServerInstance->Modules->AddService(sqlserv);
- }
-
- virtual ~ModuleMsSQL()
- {
- queryDispatcher->join();
- delete queryDispatcher;
- ClearQueue();
- ClearAllConnections();
-
- delete LoggingMutex;
- delete ResultsMutex;
- }
-
- void SendQueue()
- {
- for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
- {
- iter->second->SendResults();
- }
- }
-
- void ClearQueue()
- {
- for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
- {
- iter->second->ClearResults();
- }
- }
-
- bool HasHost(const SQLhost &host)
- {
- for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
- {
- if (host == iter->second->GetConfHost())
- return true;
- }
- return false;
- }
-
- bool HostInConf(const SQLhost &h)
- {
- ConfigTagList tags = ServerInstance->Config->ConfTags("database");
- for (ConfigIter i = tags.first; i != tags.second; ++i)
- {
- ConfigTag* tag = i->second;
- SQLhost host;
- host.id = tag->getString("id");
- host.host = tag->getString("hostname");
- host.port = tag->getInt("port", 1433);
- host.name = tag->getString("name");
- host.user = tag->getString("username");
- host.pass = tag->getString("password");
- if (h == host)
- return true;
- }
- return false;
- }
-
- void ReadConf()
- {
- ClearOldConnections();
-
- ConfigTagList tags = ServerInstance->Config->ConfTags("database");
- for (ConfigIter i = tags.first; i != tags.second; ++i)
- {
- ConfigTag* tag = i->second;
- SQLhost host;
-
- host.id = tag->getString("id");
- host.host = tag->getString("hostname");
- host.port = tag->getInt("port", 1433);
- host.name = tag->getString("name");
- host.user = tag->getString("username");
- host.pass = tag->getString("password");
-
- if (HasHost(host))
- continue;
-
- this->AddConn(host);
- }
- }
-
- void AddConn(const SQLhost& hi)
- {
- if (HasHost(hi))
- {
- LoggingMutex->Lock();
- ServerInstance->Logs->Log("m_mssql",DEFAULT, "WARNING: A MsSQL connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
- LoggingMutex->Unlock();
- return;
- }
-
- SQLConn* newconn;
-
- newconn = new SQLConn(this, hi);
-
- connections.insert(std::make_pair(hi.id, newconn));
- }
-
- void ClearOldConnections()
- {
- ConnMap::iterator iter,safei;
- for (iter = connections.begin(); iter != connections.end(); iter++)
- {
- if (!HostInConf(iter->second->GetConfHost()))
- {
- delete iter->second;
- safei = iter;
- --iter;
- connections.erase(safei);
- }
- }
- }
-
- void ClearAllConnections()
- {
- for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i)
- delete i->second;
- connections.clear();
- }
-
- virtual void OnRehash(User* user)
- {
- queryDispatcher->LockQueue();
- ReadConf();
- queryDispatcher->UnlockQueueWakeup();
- }
-
- void OnRequest(Request& request)
- {
- if(strcmp(SQLREQID, request.id) == 0)
- {
- SQLrequest* req = (SQLrequest*)&request;
-
- queryDispatcher->LockQueue();
-
- ConnMap::iterator iter;
-
- if((iter = connections.find(req->dbid)) != connections.end())
- {
- req->id = NewID();
- iter->second->queue.push(new SQLrequest(*req));
- }
- else
- {
- req->error.Id(SQL_BAD_DBID);
- }
- queryDispatcher->UnlockQueueWakeup();
- }
- }
-
- unsigned long NewID()
- {
- if (currid+1 == 0)
- currid++;
-
- return ++currid;
- }
-
- virtual Version GetVersion()
- {
- return Version("MsSQL provider", VF_VENDOR);
- }
-
-};
-
-void QueryThread::OnNotify()
-{
- Parent->SendQueue();
-}
-
-void QueryThread::Run()
-{
- this->LockQueue();
- while (this->GetExitFlag() == false)
- {
- SQLConn* conn = NULL;
- for (ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
- {
- if (i->second->queue.totalsize())
- {
- conn = i->second;
- break;
- }
- }
- if (conn)
- {
- this->UnlockQueue();
- conn->DoLeadingQuery();
- this->NotifyParent();
- this->LockQueue();
- conn->queue.pop();
- }
- else
- {
- this->WaitForQueue();
- }
- }
- this->UnlockQueue();
-}
-
-MODULE_INIT(ModuleMsSQL)
diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp
index 159a0b8b2..41c3a2a65 100644
--- a/src/modules/extra/m_mysql.cpp
+++ b/src/modules/extra/m_mysql.cpp
@@ -19,13 +19,26 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
+/// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
-/* Stop mysql wanting to use long long */
-#define NO_CLIENT_LONG_LONG
+/// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
+/// $PackageInfo: require_system("centos" "7.0") mariadb-devel
+/// $PackageInfo: require_system("darwin") mysql-connector-c
+/// $PackageInfo: require_system("debian") libmysqlclient-dev
+/// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
+
+
+// Fix warnings about the use of `long long` on C++03.
+#if defined __clang__
+# pragma clang diagnostic ignored "-Wc++11-long-long"
+#elif defined __GNUC__
+# pragma GCC diagnostic ignored "-Wlong-long"
+#endif
#include "inspircd.h"
#include <mysql.h>
-#include "sql.h"
+#include "modules/sql.h"
#ifdef _WIN32
# pragma comment(lib, "libmysql.lib")
@@ -33,10 +46,6 @@
/* VERSION 3 API: With nonblocking (threaded) requests */
-/* $ModDesc: SQL Service Provider module for all other m_sql* modules */
-/* $CompileFlags: exec("mysql_config --include") */
-/* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
-
/* THE NONBLOCKING MYSQL API!
*
* MySQL provides no nonblocking (asyncronous) API of its own, and its developers recommend
@@ -75,20 +84,20 @@ class DispatcherThread;
struct QQueueItem
{
- SQLQuery* q;
+ SQL::Query* q;
std::string query;
SQLConnection* c;
- QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
+ QQueueItem(SQL::Query* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
};
struct RQueueItem
{
- SQLQuery* q;
+ SQL::Query* q;
MySQLresult* r;
- RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
+ RQueueItem(SQL::Query* Q, MySQLresult* R) : q(Q), r(R) {}
};
-typedef std::map<std::string, SQLConnection*> ConnMap;
+typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
typedef std::deque<QQueueItem> QueryQueue;
typedef std::deque<RQueueItem> ResultQueue;
@@ -103,11 +112,11 @@ class ModuleSQL : public Module
ConnMap connections; // main thread only
ModuleSQL();
- void init();
+ void init() CXX11_OVERRIDE;
~ModuleSQL();
- void OnRehash(User* user);
- void OnUnloadModule(Module* mod);
- Version GetVersion();
+ void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
+ void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
+ Version GetVersion() CXX11_OVERRIDE;
};
class DispatcherThread : public SocketThread
@@ -117,8 +126,8 @@ class DispatcherThread : public SocketThread
public:
DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
~DispatcherThread() { }
- virtual void Run();
- virtual void OnNotify();
+ void Run() CXX11_OVERRIDE;
+ void OnNotify() CXX11_OVERRIDE;
};
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
@@ -127,16 +136,16 @@ class DispatcherThread : public SocketThread
/** Represents a mysql result set
*/
-class MySQLresult : public SQLResult
+class MySQLresult : public SQL::Result
{
public:
- SQLerror err;
+ SQL::Error err;
int currentrow;
int rows;
std::vector<std::string> colnames;
- std::vector<SQLEntries> fieldlists;
+ std::vector<SQL::Row> fieldlists;
- MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
+ MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL::SUCCESS), currentrow(0), rows(0)
{
if (affected_rows >= 1)
{
@@ -165,9 +174,9 @@ class MySQLresult : public SQLResult
{
std::string a = (fields[field_count].name ? fields[field_count].name : "");
if (row[field_count])
- fieldlists[n].push_back(SQLEntry(row[field_count]));
+ fieldlists[n].push_back(SQL::Field(row[field_count]));
else
- fieldlists[n].push_back(SQLEntry());
+ fieldlists[n].push_back(SQL::Field());
colnames.push_back(a);
field_count++;
}
@@ -179,35 +188,44 @@ class MySQLresult : public SQLResult
}
}
- MySQLresult(SQLerror& e) : err(e)
+ MySQLresult(SQL::Error& e) : err(e)
{
}
- ~MySQLresult()
+ int Rows() CXX11_OVERRIDE
{
+ return rows;
}
- virtual int Rows()
+ void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
{
- return rows;
+ result.assign(colnames.begin(), colnames.end());
}
- virtual void GetCols(std::vector<std::string>& result)
+ bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
{
- result.assign(colnames.begin(), colnames.end());
+ for (size_t i = 0; i < colnames.size(); ++i)
+ {
+ if (colnames[i] == column)
+ {
+ index = i;
+ return true;
+ }
+ }
+ return false;
}
- virtual SQLEntry GetValue(int row, int column)
+ SQL::Field GetValue(int row, int column)
{
if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
{
return fieldlists[row][column];
}
- return SQLEntry();
+ return SQL::Field();
}
- virtual bool GetRow(SQLEntries& result)
+ bool GetRow(SQL::Row& result) CXX11_OVERRIDE
{
if (currentrow < rows)
{
@@ -225,7 +243,7 @@ class MySQLresult : public SQLResult
/** Represents a connection to a mysql database
*/
-class SQLConnection : public SQLProvider
+class SQLConnection : public SQL::Provider
{
public:
reference<ConfigTag> config;
@@ -233,7 +251,7 @@ class SQLConnection : public SQLProvider
Mutex lock;
// This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
- SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
+ SQLConnection(Module* p, ConfigTag* tag) : SQL::Provider(p, "SQL/" + tag->getString("id")),
config(tag), connection(NULL)
{
}
@@ -254,10 +272,16 @@ class SQLConnection : public SQLProvider
std::string user = config->getString("user");
std::string pass = config->getString("pass");
std::string dbname = config->getString("name");
- int port = config->getInt("port");
+ unsigned int port = config->getUInt("port", 3306);
bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
if (!rv)
return rv;
+
+ // Enable character set settings
+ std::string charset = config->getString("charset");
+ if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
+
std::string initquery;
if (config->readString("initialquery", initquery))
{
@@ -286,7 +310,7 @@ class SQLConnection : public SQLProvider
{
/* XXX: See /usr/include/mysql/mysqld_error.h for a list of
* possible error numbers and error messages */
- SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
+ SQL::Error e(SQL::QREPLY_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
return new MySQLresult(e);
}
}
@@ -308,14 +332,14 @@ class SQLConnection : public SQLProvider
mysql_close(connection);
}
- void submit(SQLQuery* q, const std::string& qs)
+ void Submit(SQL::Query* q, const std::string& qs) CXX11_OVERRIDE
{
Parent()->Dispatcher->LockQueue();
Parent()->qq.push_back(QQueueItem(q, qs, this));
Parent()->Dispatcher->UnlockQueueWakeup();
}
- void submit(SQLQuery* call, const std::string& q, const ParamL& p)
+ void Submit(SQL::Query* call, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
{
std::string res;
unsigned int param = 0;
@@ -332,18 +356,17 @@ class SQLConnection : public SQLProvider
// and one byte is the terminating null
std::vector<char> buffer(parm.length() * 2 + 1);
- // The return value of mysql_escape_string() is the length of the encoded string,
+ // The return value of mysql_real_escape_string() is the length of the encoded string,
// not including the terminating null
- unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
-// mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
+ unsigned long escapedsize = mysql_real_escape_string(connection, &buffer[0], parm.c_str(), parm.length());
res.append(&buffer[0], escapedsize);
}
}
}
- submit(call, res);
+ Submit(call, res);
}
- void submit(SQLQuery* call, const std::string& q, const ParamM& p)
+ void Submit(SQL::Query* call, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
{
std::string res;
for(std::string::size_type i = 0; i < q.length(); i++)
@@ -358,7 +381,7 @@ class SQLConnection : public SQLProvider
field.push_back(q[i++]);
i--;
- ParamM::const_iterator it = p.find(field);
+ SQL::ParamMap::const_iterator it = p.find(field);
if (it != p.end())
{
std::string parm = it->second;
@@ -369,7 +392,7 @@ class SQLConnection : public SQLProvider
}
}
}
- submit(call, res);
+ Submit(call, res);
}
};
@@ -381,12 +404,7 @@ ModuleSQL::ModuleSQL()
void ModuleSQL::init()
{
Dispatcher = new DispatcherThread(this);
- ServerInstance->Threads->Start(Dispatcher);
-
- Implementation eventlist[] = { I_OnRehash, I_OnUnloadModule };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
-
- OnRehash(NULL);
+ ServerInstance->Threads.Start(Dispatcher);
}
ModuleSQL::~ModuleSQL()
@@ -403,13 +421,13 @@ ModuleSQL::~ModuleSQL()
}
}
-void ModuleSQL::OnRehash(User* user)
+void ModuleSQL::ReadConfig(ConfigStatus& status)
{
ConnMap conns;
ConfigTagList tags = ServerInstance->Config->ConfTags("database");
for(ConfigIter i = tags.first; i != tags.second; i++)
{
- if (i->second->getString("module", "mysql") != "mysql")
+ if (!stdalgo::string::equalsci(i->second->getString("provider"), "mysql"))
continue;
std::string id = i->second->getString("id");
ConnMap::iterator curr = connections.find(id);
@@ -428,7 +446,7 @@ void ModuleSQL::OnRehash(User* user)
// now clean up the deleted databases
Dispatcher->LockQueue();
- SQLerror err(SQL_BAD_DBID);
+ SQL::Error err(SQL::BAD_DBID);
for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
{
ServerInstance->Modules->DelService(*i->second);
@@ -455,7 +473,7 @@ void ModuleSQL::OnRehash(User* user)
void ModuleSQL::OnUnloadModule(Module* mod)
{
- SQLerror err(SQL_BAD_DBID);
+ SQL::Error err(SQL::BAD_DBID);
Dispatcher->LockQueue();
unsigned int i = qq.size();
while (i > 0)
@@ -535,7 +553,7 @@ void DispatcherThread::OnNotify()
for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
{
MySQLresult* res = i->r;
- if (res->err.id == SQL_NO_ERROR)
+ if (res->err.code == SQL::SUCCESS)
i->q->OnResult(*res);
else
i->q->OnError(res->err);
diff --git a/src/modules/extra/m_pgsql.cpp b/src/modules/extra/m_pgsql.cpp
index ac247548a..ec89208dd 100644
--- a/src/modules/extra/m_pgsql.cpp
+++ b/src/modules/extra/m_pgsql.cpp
@@ -21,16 +21,19 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: -Iexecute("pg_config --includedir" "POSTGRESQL_INCLUDE_DIR")
+/// $LinkerFlags: -Lexecute("pg_config --libdir" "POSTGRESQL_LIBRARY_DIR") -lpq
+
+/// $PackageInfo: require_system("centos") postgresql-devel
+/// $PackageInfo: require_system("darwin") postgresql
+/// $PackageInfo: require_system("debian") libpq-dev
+/// $PackageInfo: require_system("ubuntu") libpq-dev
+
#include "inspircd.h"
#include <cstdlib>
-#include <sstream>
#include <libpq-fe.h>
-#include "sql.h"
-
-/* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
-/* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
-/* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */
+#include "modules/sql.h"
/* SQLConn rewritten by peavey to
* use EventHandler instead of
@@ -43,7 +46,7 @@
class SQLConn;
class ModulePgSQL;
-typedef std::map<std::string, SQLConn*> ConnMap;
+typedef insp::flat_map<std::string, SQLConn*> ConnMap;
/* CREAD, Connecting and wants read event
* CWRITE, Connecting and wants write event
@@ -59,17 +62,17 @@ class ReconnectTimer : public Timer
private:
ModulePgSQL* mod;
public:
- ReconnectTimer(ModulePgSQL* m) : Timer(5, ServerInstance->Time(), false), mod(m)
+ ReconnectTimer(ModulePgSQL* m) : Timer(5, false), mod(m)
{
}
- virtual void Tick(time_t TIME);
+ bool Tick(time_t TIME) CXX11_OVERRIDE;
};
struct QueueItem
{
- SQLQuery* c;
+ SQL::Query* c;
std::string q;
- QueueItem(SQLQuery* C, const std::string& Q) : c(C), q(Q) {}
+ QueueItem(SQL::Query* C, const std::string& Q) : c(C), q(Q) {}
};
/** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
@@ -79,11 +82,21 @@ struct QueueItem
* data is passes to the module nearly as directly as if it was using the API directly itself.
*/
-class PgSQLresult : public SQLResult
+class PgSQLresult : public SQL::Result
{
PGresult* res;
int currentrow;
int rows;
+ std::vector<std::string> colnames;
+
+ void getColNames()
+ {
+ colnames.resize(PQnfields(res));
+ for(unsigned int i=0; i < colnames.size(); i++)
+ {
+ colnames[i] = PQfname(res, i);
+ }
+ }
public:
PgSQLresult(PGresult* result) : res(result), currentrow(0)
{
@@ -97,30 +110,44 @@ class PgSQLresult : public SQLResult
PQclear(res);
}
- virtual int Rows()
+ int Rows() CXX11_OVERRIDE
{
return rows;
}
- virtual void GetCols(std::vector<std::string>& result)
+ void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
+ {
+ if (colnames.empty())
+ getColNames();
+ result = colnames;
+ }
+
+ bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
{
- result.resize(PQnfields(res));
- for(unsigned int i=0; i < result.size(); i++)
+ if (colnames.empty())
+ getColNames();
+
+ for (size_t i = 0; i < colnames.size(); ++i)
{
- result[i] = PQfname(res, i);
+ if (colnames[i] == column)
+ {
+ index = i;
+ return true;
+ }
}
+ return false;
}
- virtual SQLEntry GetValue(int row, int column)
+ SQL::Field GetValue(int row, int column)
{
char* v = PQgetvalue(res, row, column);
if (!v || PQgetisnull(res, row, column))
- return SQLEntry();
+ return SQL::Field();
- return SQLEntry(std::string(v, PQgetlength(res, row, column)));
+ return SQL::Field(std::string(v, PQgetlength(res, row, column)));
}
- virtual bool GetRow(SQLEntries& result)
+ bool GetRow(SQL::Row& result) CXX11_OVERRIDE
{
if (currentrow >= PQntuples(res))
return false;
@@ -138,7 +165,7 @@ class PgSQLresult : public SQLResult
/** SQLConn represents one SQL session.
*/
-class SQLConn : public SQLProvider, public EventHandler
+class SQLConn : public SQL::Provider, public EventHandler
{
public:
reference<ConfigTag> conf; /* The <database> entry */
@@ -148,25 +175,25 @@ class SQLConn : public SQLProvider, public EventHandler
QueueItem qinprog; /* If there is currently a query in progress */
SQLConn(Module* Creator, ConfigTag* tag)
- : SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL, "")
+ : SQL::Provider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL, "")
{
if (!DoConnect())
{
- ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: Could not connect to database " + tag->getString("id"));
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not connect to database " + tag->getString("id"));
DelayReconnect();
}
}
- CullResult cull()
+ CullResult cull() CXX11_OVERRIDE
{
- this->SQLProvider::cull();
+ this->SQL::Provider::cull();
ServerInstance->Modules->DelService(*this);
return this->EventHandler::cull();
}
~SQLConn()
{
- SQLerror err(SQL_BAD_DBID);
+ SQL::Error err(SQL::BAD_DBID);
if (qinprog.c)
{
qinprog.c->OnError(err);
@@ -174,24 +201,25 @@ class SQLConn : public SQLProvider, public EventHandler
}
for(std::deque<QueueItem>::iterator i = queue.begin(); i != queue.end(); i++)
{
- SQLQuery* q = i->c;
+ SQL::Query* q = i->c;
q->OnError(err);
delete q;
}
}
- virtual void HandleEvent(EventType et, int errornum)
+ void OnEventHandlerRead() CXX11_OVERRIDE
{
- switch (et)
- {
- case EVENT_READ:
- case EVENT_WRITE:
- DoEvent();
- break;
+ DoEvent();
+ }
- case EVENT_ERROR:
- DelayReconnect();
- }
+ void OnEventHandlerWrite() CXX11_OVERRIDE
+ {
+ DoEvent();
+ }
+
+ void OnEventHandlerError(int errornum) CXX11_OVERRIDE
+ {
+ DelayReconnect();
}
std::string GetDSN()
@@ -242,9 +270,9 @@ class SQLConn : public SQLProvider, public EventHandler
if(this->fd <= -1)
return false;
- if (!ServerInstance->SE->AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ))
+ if (!SocketEngine::AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ))
{
- ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
return false;
}
@@ -257,17 +285,17 @@ class SQLConn : public SQLProvider, public EventHandler
switch(PQconnectPoll(sql))
{
case PGRES_POLLING_WRITING:
- ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
+ SocketEngine::ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
status = CWRITE;
return true;
case PGRES_POLLING_READING:
- ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
status = CREAD;
return true;
case PGRES_POLLING_FAILED:
return false;
case PGRES_POLLING_OK:
- ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
status = WWRITE;
DoConnectedPoll();
default:
@@ -316,7 +344,7 @@ restart:
case PGRES_BAD_RESPONSE:
case PGRES_FATAL_ERROR:
{
- SQLerror err(SQL_QREPLY_FAIL, PQresultErrorMessage(result));
+ SQL::Error err(SQL::QREPLY_FAIL, PQresultErrorMessage(result));
qinprog.c->OnError(err);
break;
}
@@ -350,17 +378,17 @@ restart:
switch(PQresetPoll(sql))
{
case PGRES_POLLING_WRITING:
- ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
+ SocketEngine::ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
status = CWRITE;
return DoPoll();
case PGRES_POLLING_READING:
- ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
status = CREAD;
return true;
case PGRES_POLLING_FAILED:
return false;
case PGRES_POLLING_OK:
- ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ SocketEngine::ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
status = WWRITE;
DoConnectedPoll();
default:
@@ -386,7 +414,7 @@ restart:
}
}
- void submit(SQLQuery *req, const std::string& q)
+ void Submit(SQL::Query *req, const std::string& q) CXX11_OVERRIDE
{
if (qinprog.q.empty())
{
@@ -399,7 +427,7 @@ restart:
}
}
- void submit(SQLQuery *req, const std::string& q, const ParamL& p)
+ void Submit(SQL::Query *req, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
{
std::string res;
unsigned int param = 0;
@@ -413,22 +441,18 @@ restart:
{
std::string parm = p[param++];
std::vector<char> buffer(parm.length() * 2 + 1);
-#ifdef PGSQL_HAS_ESCAPECONN
int error;
size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error);
if (error)
- ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
-#else
- size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length());
-#endif
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed");
res.append(&buffer[0], escapedsize);
}
}
}
- submit(req, res);
+ Submit(req, res);
}
- void submit(SQLQuery *req, const std::string& q, const ParamM& p)
+ void Submit(SQL::Query *req, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
{
std::string res;
for(std::string::size_type i = 0; i < q.length(); i++)
@@ -443,24 +467,20 @@ restart:
field.push_back(q[i++]);
i--;
- ParamM::const_iterator it = p.find(field);
+ SQL::ParamMap::const_iterator it = p.find(field);
if (it != p.end())
{
std::string parm = it->second;
std::vector<char> buffer(parm.length() * 2 + 1);
-#ifdef PGSQL_HAS_ESCAPECONN
int error;
size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error);
if (error)
- ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
-#else
- size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length());
-#endif
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed");
res.append(&buffer[0], escapedsize);
}
}
}
- submit(req, res);
+ Submit(req, res);
}
void DoQuery(const QueueItem& req)
@@ -468,7 +488,7 @@ restart:
if (status != WREAD && status != WWRITE)
{
// whoops, not connected...
- SQLerror err(SQL_BAD_CONN);
+ SQL::Error err(SQL::BAD_CONN);
req.c->OnError(err);
delete req.c;
return;
@@ -480,7 +500,7 @@ restart:
}
else
{
- SQLerror err(SQL_QSEND_FAIL, PQerrorMessage(sql));
+ SQL::Error err(SQL::QSEND_FAIL, PQerrorMessage(sql));
req.c->OnError(err);
delete req.c;
}
@@ -488,7 +508,7 @@ restart:
void Close()
{
- ServerInstance->SE->DelFd(this);
+ SocketEngine::DelFd(this);
if(sql)
{
@@ -505,25 +525,17 @@ class ModulePgSQL : public Module
ReconnectTimer* retimer;
ModulePgSQL()
+ : retimer(NULL)
{
}
- void init()
- {
- ReadConf();
-
- Implementation eventlist[] = { I_OnUnloadModule, I_OnRehash };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- }
-
- virtual ~ModulePgSQL()
+ ~ModulePgSQL()
{
- if (retimer)
- ServerInstance->Timers->DelTimer(retimer);
+ delete retimer;
ClearAllConnections();
}
- virtual void OnRehash(User* user)
+ void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
{
ReadConf();
}
@@ -534,7 +546,7 @@ class ModulePgSQL : public Module
ConfigTagList tags = ServerInstance->Config->ConfTags("database");
for(ConfigIter i = tags.first; i != tags.second; i++)
{
- if (i->second->getString("module", "pgsql") != "pgsql")
+ if (!stdalgo::string::equalsci(i->second->getString("provider"), "pgsql"))
continue;
std::string id = i->second->getString("id");
ConnMap::iterator curr = connections.find(id);
@@ -564,9 +576,9 @@ class ModulePgSQL : public Module
connections.clear();
}
- void OnUnloadModule(Module* mod)
+ void OnUnloadModule(Module* mod) CXX11_OVERRIDE
{
- SQLerror err(SQL_BAD_DBID);
+ SQL::Error err(SQL::BAD_DBID);
for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
{
SQLConn* conn = i->second;
@@ -579,7 +591,7 @@ class ModulePgSQL : public Module
std::deque<QueueItem>::iterator j = conn->queue.begin();
while (j != conn->queue.end())
{
- SQLQuery* q = j->c;
+ SQL::Query* q = j->c;
if (q->creator == mod)
{
q->OnError(err);
@@ -592,16 +604,18 @@ class ModulePgSQL : public Module
}
}
- Version GetVersion()
+ Version GetVersion() CXX11_OVERRIDE
{
return Version("PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API", VF_VENDOR);
}
};
-void ReconnectTimer::Tick(time_t time)
+bool ReconnectTimer::Tick(time_t time)
{
mod->retimer = NULL;
mod->ReadConf();
+ delete this;
+ return false;
}
void SQLConn::DelayReconnect()
@@ -615,7 +629,7 @@ void SQLConn::DelayReconnect()
if (!mod->retimer)
{
mod->retimer = new ReconnectTimer(mod);
- ServerInstance->Timers->AddTimer(mod->retimer);
+ ServerInstance->Timers.AddTimer(mod->retimer);
}
}
}
diff --git a/src/modules/extra/m_regex_pcre.cpp b/src/modules/extra/m_regex_pcre.cpp
index cba234c8c..e8ef96c22 100644
--- a/src/modules/extra/m_regex_pcre.cpp
+++ b/src/modules/extra/m_regex_pcre.cpp
@@ -17,35 +17,28 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: execute("pcre-config --cflags" "PCRE_CXXFLAGS")
+/// $LinkerFlags: execute("pcre-config --libs" "PCRE_LDFLAGS" "-lpcre")
+
+/// $PackageInfo: require_system("centos") pcre-devel pkgconfig
+/// $PackageInfo: require_system("darwin") pcre pkg-config
+/// $PackageInfo: require_system("debian") libpcre3-dev pkg-config
+/// $PackageInfo: require_system("ubuntu") libpcre3-dev pkg-config
+
#include "inspircd.h"
#include <pcre.h>
-#include "m_regex.h"
-
-/* $ModDesc: Regex Provider Module for PCRE */
-/* $ModDep: m_regex.h */
-/* $CompileFlags: exec("pcre-config --cflags") */
-/* $LinkerFlags: exec("pcre-config --libs") rpath("pcre-config --libs") -lpcre */
+#include "modules/regex.h"
#ifdef _WIN32
# pragma comment(lib, "libpcre.lib")
#endif
-class PCREException : public ModuleException
-{
-public:
- PCREException(const std::string& rx, const std::string& error, int erroffset)
- : ModuleException("Error in regex " + rx + " at offset " + ConvToStr(erroffset) + ": " + error)
- {
- }
-};
-
class PCRERegex : public Regex
{
-private:
pcre* regex;
-public:
+ public:
PCRERegex(const std::string& rx) : Regex(rx)
{
const char* error;
@@ -53,24 +46,19 @@ public:
regex = pcre_compile(rx.c_str(), 0, &error, &erroffset, NULL);
if (!regex)
{
- ServerInstance->Logs->Log("REGEX", DEBUG, "pcre_compile failed: /%s/ [%d] %s", rx.c_str(), erroffset, error);
- throw PCREException(rx, error, erroffset);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "pcre_compile failed: /%s/ [%d] %s", rx.c_str(), erroffset, error);
+ throw RegexException(rx, error, erroffset);
}
}
- virtual ~PCRERegex()
+ ~PCRERegex()
{
pcre_free(regex);
}
- virtual bool Matches(const std::string& text)
+ bool Matches(const std::string& text) CXX11_OVERRIDE
{
- if (pcre_exec(regex, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) > -1)
- {
- // Bang. :D
- return true;
- }
- return false;
+ return (pcre_exec(regex, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) >= 0);
}
};
@@ -78,7 +66,7 @@ class PCREFactory : public RegexFactory
{
public:
PCREFactory(Module* m) : RegexFactory(m, "regex/pcre") {}
- Regex* Create(const std::string& expr)
+ Regex* Create(const std::string& expr) CXX11_OVERRIDE
{
return new PCRERegex(expr);
}
@@ -86,13 +74,13 @@ class PCREFactory : public RegexFactory
class ModuleRegexPCRE : public Module
{
-public:
+ public:
PCREFactory ref;
- ModuleRegexPCRE() : ref(this) {
- ServerInstance->Modules->AddService(ref);
+ ModuleRegexPCRE() : ref(this)
+ {
}
- Version GetVersion()
+ Version GetVersion() CXX11_OVERRIDE
{
return Version("Regex Provider Module for PCRE", VF_VENDOR);
}
diff --git a/src/modules/extra/m_regex_posix.cpp b/src/modules/extra/m_regex_posix.cpp
index b3afd60c8..b5fddfab8 100644
--- a/src/modules/extra/m_regex_posix.cpp
+++ b/src/modules/extra/m_regex_posix.cpp
@@ -19,28 +19,15 @@
#include "inspircd.h"
-#include "m_regex.h"
+#include "modules/regex.h"
#include <sys/types.h>
#include <regex.h>
-/* $ModDesc: Regex Provider Module for POSIX Regular Expressions */
-/* $ModDep: m_regex.h */
-
-class POSIXRegexException : public ModuleException
-{
-public:
- POSIXRegexException(const std::string& rx, const std::string& error)
- : ModuleException("Error in regex " + rx + ": " + error)
- {
- }
-};
-
class POSIXRegex : public Regex
{
-private:
regex_t regbuf;
-public:
+ public:
POSIXRegex(const std::string& rx, bool extended) : Regex(rx)
{
int flags = (extended ? REG_EXTENDED : 0) | REG_NOSUB;
@@ -58,23 +45,18 @@ public:
error = errbuf;
delete[] errbuf;
regfree(&regbuf);
- throw POSIXRegexException(rx, error);
+ throw RegexException(rx, error);
}
}
- virtual ~POSIXRegex()
+ ~POSIXRegex()
{
regfree(&regbuf);
}
- virtual bool Matches(const std::string& text)
+ bool Matches(const std::string& text) CXX11_OVERRIDE
{
- if (regexec(&regbuf, text.c_str(), 0, NULL, 0) == 0)
- {
- // Bang. :D
- return true;
- }
- return false;
+ return (regexec(&regbuf, text.c_str(), 0, NULL, 0) == 0);
}
};
@@ -83,7 +65,7 @@ class PosixFactory : public RegexFactory
public:
bool extended;
PosixFactory(Module* m) : RegexFactory(m, "regex/posix") {}
- Regex* Create(const std::string& expr)
+ Regex* Create(const std::string& expr) CXX11_OVERRIDE
{
return new POSIXRegex(expr, extended);
}
@@ -92,20 +74,18 @@ class PosixFactory : public RegexFactory
class ModuleRegexPOSIX : public Module
{
PosixFactory ref;
-public:
- ModuleRegexPOSIX() : ref(this) {
- ServerInstance->Modules->AddService(ref);
- Implementation eventlist[] = { I_OnRehash };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- OnRehash(NULL);
+
+ public:
+ ModuleRegexPOSIX() : ref(this)
+ {
}
- Version GetVersion()
+ Version GetVersion() CXX11_OVERRIDE
{
return Version("Regex Provider Module for POSIX Regular Expressions", VF_VENDOR);
}
- void OnRehash(User* u)
+ void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
{
ref.extended = ServerInstance->Config->ConfValue("posix")->getBool("extended");
}
diff --git a/src/modules/extra/m_regex_re2.cpp b/src/modules/extra/m_regex_re2.cpp
new file mode 100644
index 000000000..4bcf287ca
--- /dev/null
+++ b/src/modules/extra/m_regex_re2.cpp
@@ -0,0 +1,86 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ * Copyright (C) 2013 Peter Powell <petpow@saberuk.com>
+ * Copyright (C) 2012 ChrisTX <chris@rev-crew.info>
+ *
+ * This file is part of InspIRCd. InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/// $CompilerFlags: find_compiler_flags("re2" "")
+/// $LinkerFlags: find_linker_flags("re2" "-lre2")
+
+/// $PackageInfo: require_system("darwin") pkg-config re2
+/// $PackageInfo: require_system("debian" "8.0") libre2-dev pkg-config
+/// $PackageInfo: require_system("ubuntu" "15.10") libre2-dev pkg-config
+
+
+#include "inspircd.h"
+#include "modules/regex.h"
+
+// Fix warnings about the use of `long long` on C++03 and
+// shadowing on GCC.
+#if defined __clang__
+# pragma clang diagnostic ignored "-Wc++11-long-long"
+#elif defined __GNUC__
+# pragma GCC diagnostic ignored "-Wlong-long"
+# pragma GCC diagnostic ignored "-Wshadow"
+#endif
+
+#include <re2/re2.h>
+
+class RE2Regex : public Regex
+{
+ RE2 regexcl;
+
+ public:
+ RE2Regex(const std::string& rx) : Regex(rx), regexcl(rx, RE2::Quiet)
+ {
+ if (!regexcl.ok())
+ {
+ throw RegexException(rx, regexcl.error());
+ }
+ }
+
+ bool Matches(const std::string& text) CXX11_OVERRIDE
+ {
+ return RE2::FullMatch(text, regexcl);
+ }
+};
+
+class RE2Factory : public RegexFactory
+{
+ public:
+ RE2Factory(Module* m) : RegexFactory(m, "regex/re2") { }
+ Regex* Create(const std::string& expr) CXX11_OVERRIDE
+ {
+ return new RE2Regex(expr);
+ }
+};
+
+class ModuleRegexRE2 : public Module
+{
+ RE2Factory ref;
+
+ public:
+ ModuleRegexRE2() : ref(this)
+ {
+ }
+
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("Regex Provider Module for RE2", VF_VENDOR);
+ }
+};
+
+MODULE_INIT(ModuleRegexRE2)
diff --git a/src/modules/extra/m_regex_stdlib.cpp b/src/modules/extra/m_regex_stdlib.cpp
index 204728b65..42e5c8bf1 100644
--- a/src/modules/extra/m_regex_stdlib.cpp
+++ b/src/modules/extra/m_regex_stdlib.cpp
@@ -15,32 +15,19 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-
-#include "inspircd.h"
-#include "m_regex.h"
-#include <regex>
-/* $ModDesc: Regex Provider Module for std::regex Regular Expressions */
-/* $ModConfig: <stdregex type="ecmascript">
- * Specify the Regular Expression engine to use here. Valid settings are
- * bre, ere, awk, grep, egrep, ecmascript (default if not specified)*/
-/* $CompileFlags: -std=c++11 */
-/* $ModDep: m_regex.h */
+/// $CompilerFlags: -std=c++11
-class StdRegexException : public ModuleException
-{
-public:
- StdRegexException(const std::string& rx, const std::string& error)
- : ModuleException(std::string("Error in regex ") + rx + ": " + error)
- {
- }
-};
+
+#include "inspircd.h"
+#include "modules/regex.h"
+#include <regex>
class StdRegex : public Regex
{
-private:
std::regex regexcl;
-public:
+
+ public:
StdRegex(const std::string& rx, std::regex::flag_type fltype) : Regex(rx)
{
try{
@@ -48,11 +35,11 @@ public:
}
catch(std::regex_error rxerr)
{
- throw StdRegexException(rx, rxerr.what());
+ throw RegexException(rx, rxerr.what());
}
}
-
- virtual bool Matches(const std::string& text)
+
+ bool Matches(const std::string& text) CXX11_OVERRIDE
{
return std::regex_search(text, regexcl);
}
@@ -63,7 +50,7 @@ class StdRegexFactory : public RegexFactory
public:
std::regex::flag_type regextype;
StdRegexFactory(Module* m) : RegexFactory(m, "regex/stdregex") {}
- Regex* Create(const std::string& expr)
+ Regex* Create(const std::string& expr) CXX11_OVERRIDE
{
return new StdRegex(expr, regextype);
}
@@ -73,36 +60,33 @@ class ModuleRegexStd : public Module
{
public:
StdRegexFactory ref;
- ModuleRegexStd() : ref(this) {
- ServerInstance->Modules->AddService(ref);
- Implementation eventlist[] = { I_OnRehash };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- OnRehash(NULL);
+ ModuleRegexStd() : ref(this)
+ {
}
- Version GetVersion()
+ Version GetVersion() CXX11_OVERRIDE
{
return Version("Regex Provider Module for std::regex", VF_VENDOR);
}
-
- void OnRehash(User* u)
+
+ void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
{
ConfigTag* Conf = ServerInstance->Config->ConfValue("stdregex");
std::string regextype = Conf->getString("type", "ecmascript");
-
- if(regextype == "bre")
+
+ if (stdalgo::string::equalsci(regextype, "bre"))
ref.regextype = std::regex::basic;
- else if(regextype == "ere")
+ else if (stdalgo::string::equalsci(regextype, "ere"))
ref.regextype = std::regex::extended;
- else if(regextype == "awk")
+ else if (stdalgo::string::equalsci(regextype, "awk"))
ref.regextype = std::regex::awk;
- else if(regextype == "grep")
+ else if (stdalgo::string::equalsci(regextype, "grep"))
ref.regextype = std::regex::grep;
- else if(regextype == "egrep")
+ else if (stdalgo::string::equalsci(regextype, "egrep"))
ref.regextype = std::regex::egrep;
else
{
- if(regextype != "ecmascript")
+ if (!stdalgo::string::equalsci(regextype, "ecmascript"))
ServerInstance->SNO->WriteToSnoMask('a', "WARNING: Non-existent regex engine '%s' specified. Falling back to ECMAScript.", regextype.c_str());
ref.regextype = std::regex::ECMAScript;
}
diff --git a/src/modules/extra/m_regex_tre.cpp b/src/modules/extra/m_regex_tre.cpp
index 4b9eab472..aa3f1d41e 100644
--- a/src/modules/extra/m_regex_tre.cpp
+++ b/src/modules/extra/m_regex_tre.cpp
@@ -17,29 +17,20 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: find_compiler_flags("tre")
+/// $LinkerFlags: find_linker_flags("tre" "-ltre")
+
+/// $PackageInfo: require_system("darwin") pkg-config tre
+/// $PackageInfo: require_system("debian") libtre-dev pkg-config
+/// $PackageInfo: require_system("ubuntu") libtre-dev pkg-config
#include "inspircd.h"
-#include "m_regex.h"
+#include "modules/regex.h"
#include <sys/types.h>
#include <tre/regex.h>
-/* $ModDesc: Regex Provider Module for TRE Regular Expressions */
-/* $CompileFlags: pkgconfincludes("tre","tre/regex.h","") */
-/* $LinkerFlags: pkgconflibs("tre","/libtre.so","-ltre") rpath("pkg-config --libs tre") */
-/* $ModDep: m_regex.h */
-
-class TRERegexException : public ModuleException
-{
-public:
- TRERegexException(const std::string& rx, const std::string& error)
- : ModuleException("Error in regex " + rx + ": " + error)
- {
- }
-};
-
class TRERegex : public Regex
{
-private:
regex_t regbuf;
public:
@@ -60,30 +51,26 @@ public:
error = errbuf;
delete[] errbuf;
regfree(&regbuf);
- throw TRERegexException(rx, error);
+ throw RegexException(rx, error);
}
}
- virtual ~TRERegex()
+ ~TRERegex()
{
regfree(&regbuf);
}
- virtual bool Matches(const std::string& text)
+ bool Matches(const std::string& text) CXX11_OVERRIDE
{
- if (regexec(&regbuf, text.c_str(), 0, NULL, 0) == 0)
- {
- // Bang. :D
- return true;
- }
- return false;
+ return (regexec(&regbuf, text.c_str(), 0, NULL, 0) == 0);
}
};
-class TREFactory : public RegexFactory {
+class TREFactory : public RegexFactory
+{
public:
TREFactory(Module* m) : RegexFactory(m, "regex/tre") {}
- Regex* Create(const std::string& expr)
+ Regex* Create(const std::string& expr) CXX11_OVERRIDE
{
return new TRERegex(expr);
}
@@ -92,18 +79,15 @@ class TREFactory : public RegexFactory {
class ModuleRegexTRE : public Module
{
TREFactory trf;
-public:
- ModuleRegexTRE() : trf(this) {
- ServerInstance->Modules->AddService(trf);
- }
- Version GetVersion()
+ public:
+ ModuleRegexTRE() : trf(this)
{
- return Version("Regex Provider Module for TRE Regular Expressions", VF_VENDOR);
}
- ~ModuleRegexTRE()
+ Version GetVersion() CXX11_OVERRIDE
{
+ return Version("Regex Provider Module for TRE Regular Expressions", VF_VENDOR);
}
};
diff --git a/src/modules/extra/m_sqlite3.cpp b/src/modules/extra/m_sqlite3.cpp
index 47880c02c..b46810062 100644
--- a/src/modules/extra/m_sqlite3.cpp
+++ b/src/modules/extra/m_sqlite3.cpp
@@ -19,45 +19,51 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: find_compiler_flags("sqlite3")
+/// $LinkerFlags: find_linker_flags("sqlite3" "-lsqlite3")
+
+/// $PackageInfo: require_system("centos") pkgconfig sqlite-devel
+/// $PackageInfo: require_system("darwin") pkg-config sqlite3
+/// $PackageInfo: require_system("debian") libsqlite3-dev pkg-config
+/// $PackageInfo: require_system("ubuntu") libsqlite3-dev pkg-config
#include "inspircd.h"
+#include "modules/sql.h"
+
+// Fix warnings about the use of `long long` on C++03.
+#if defined __clang__
+# pragma clang diagnostic ignored "-Wc++11-long-long"
+#elif defined __GNUC__
+# pragma GCC diagnostic ignored "-Wlong-long"
+#endif
+
#include <sqlite3.h>
-#include "sql.h"
#ifdef _WIN32
# pragma comment(lib, "sqlite3.lib")
#endif
-/* $ModDesc: sqlite3 provider */
-/* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
-/* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
-/* $NoPedantic */
-
class SQLConn;
-typedef std::map<std::string, SQLConn*> ConnMap;
+typedef insp::flat_map<std::string, SQLConn*> ConnMap;
-class SQLite3Result : public SQLResult
+class SQLite3Result : public SQL::Result
{
public:
int currentrow;
int rows;
std::vector<std::string> columns;
- std::vector<SQLEntries> fieldlists;
+ std::vector<SQL::Row> fieldlists;
SQLite3Result() : currentrow(0), rows(0)
{
}
- ~SQLite3Result()
- {
- }
-
- virtual int Rows()
+ int Rows() CXX11_OVERRIDE
{
return rows;
}
- virtual bool GetRow(SQLEntries& result)
+ bool GetRow(SQL::Row& result) CXX11_OVERRIDE
{
if (currentrow < rows)
{
@@ -72,20 +78,32 @@ class SQLite3Result : public SQLResult
}
}
- virtual void GetCols(std::vector<std::string>& result)
+ void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
{
result.assign(columns.begin(), columns.end());
}
+
+ bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
+ {
+ for (size_t i = 0; i < columns.size(); ++i)
+ {
+ if (columns[i] == column)
+ {
+ index = i;
+ return true;
+ }
+ }
+ return false;
+ }
};
-class SQLConn : public SQLProvider
+class SQLConn : public SQL::Provider
{
- private:
sqlite3* conn;
reference<ConfigTag> config;
public:
- SQLConn(Module* Parent, ConfigTag* tag) : SQLProvider(Parent, "SQL/" + tag->getString("id")), config(tag)
+ SQLConn(Module* Parent, ConfigTag* tag) : SQL::Provider(Parent, "SQL/" + tag->getString("id")), config(tag)
{
std::string host = tag->getString("hostname");
if (sqlite3_open_v2(host.c_str(), &conn, SQLITE_OPEN_READWRITE, 0) != SQLITE_OK)
@@ -93,7 +111,7 @@ class SQLConn : public SQLProvider
// Even in case of an error conn must be closed
sqlite3_close(conn);
conn = NULL;
- ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + tag->getString("id"));
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not open DB with id: " + tag->getString("id"));
}
}
@@ -106,14 +124,14 @@ class SQLConn : public SQLProvider
}
}
- void Query(SQLQuery* query, const std::string& q)
+ void Query(SQL::Query* query, const std::string& q)
{
SQLite3Result res;
sqlite3_stmt *stmt;
int err = sqlite3_prepare_v2(conn, q.c_str(), q.length(), &stmt, NULL);
if (err != SQLITE_OK)
{
- SQLerror error(SQL_QSEND_FAIL, sqlite3_errmsg(conn));
+ SQL::Error error(SQL::QSEND_FAIL, sqlite3_errmsg(conn));
query->OnError(error);
return;
}
@@ -135,7 +153,7 @@ class SQLConn : public SQLProvider
{
const char* txt = (const char*)sqlite3_column_text(stmt, i);
if (txt)
- res.fieldlists[res.rows][i] = SQLEntry(txt);
+ res.fieldlists[res.rows][i] = SQL::Field(txt);
}
res.rows++;
}
@@ -146,7 +164,7 @@ class SQLConn : public SQLProvider
}
else
{
- SQLerror error(SQL_QREPLY_FAIL, sqlite3_errmsg(conn));
+ SQL::Error error(SQL::QREPLY_FAIL, sqlite3_errmsg(conn));
query->OnError(error);
break;
}
@@ -154,13 +172,13 @@ class SQLConn : public SQLProvider
sqlite3_finalize(stmt);
}
- virtual void submit(SQLQuery* query, const std::string& q)
+ void Submit(SQL::Query* query, const std::string& q) CXX11_OVERRIDE
{
Query(query, q);
delete query;
}
- virtual void submit(SQLQuery* query, const std::string& q, const ParamL& p)
+ void Submit(SQL::Query* query, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
{
std::string res;
unsigned int param = 0;
@@ -178,10 +196,10 @@ class SQLConn : public SQLProvider
}
}
}
- submit(query, res);
+ Submit(query, res);
}
- virtual void submit(SQLQuery* query, const std::string& q, const ParamM& p)
+ void Submit(SQL::Query* query, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
{
std::string res;
for(std::string::size_type i = 0; i < q.length(); i++)
@@ -196,7 +214,7 @@ class SQLConn : public SQLProvider
field.push_back(q[i++]);
i--;
- ParamM::const_iterator it = p.find(field);
+ SQL::ParamMap::const_iterator it = p.find(field);
if (it != p.end())
{
char* escaped = sqlite3_mprintf("%q", it->second.c_str());
@@ -205,29 +223,16 @@ class SQLConn : public SQLProvider
}
}
}
- submit(query, res);
+ Submit(query, res);
}
};
class ModuleSQLite3 : public Module
{
- private:
ConnMap conns;
public:
- ModuleSQLite3()
- {
- }
-
- void init()
- {
- ReadConf();
-
- Implementation eventlist[] = { I_OnRehash };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- }
-
- virtual ~ModuleSQLite3()
+ ~ModuleSQLite3()
{
ClearConns();
}
@@ -243,13 +248,13 @@ class ModuleSQLite3 : public Module
conns.clear();
}
- void ReadConf()
+ void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
{
ClearConns();
ConfigTagList tags = ServerInstance->Config->ConfTags("database");
for(ConfigIter i = tags.first; i != tags.second; i++)
{
- if (i->second->getString("module", "sqlite") != "sqlite")
+ if (!stdalgo::string::equalsci(i->second->getString("provider"), "sqlite"))
continue;
SQLConn* conn = new SQLConn(this, i->second);
conns.insert(std::make_pair(i->second->getString("id"), conn));
@@ -257,12 +262,7 @@ class ModuleSQLite3 : public Module
}
}
- void OnRehash(User* user)
- {
- ReadConf();
- }
-
- Version GetVersion()
+ Version GetVersion() CXX11_OVERRIDE
{
return Version("sqlite3 provider", VF_VENDOR);
}
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp
index 2f4acf3f0..8bd73b2bb 100644
--- a/src/modules/extra/m_ssl_gnutls.cpp
+++ b/src/modules/extra/m_ssl_gnutls.cpp
@@ -20,120 +20,98 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+/// $CompilerFlags: find_compiler_flags("gnutls")
+/// $CompilerFlags: require_version("gnutls" "1.0" "2.12") execute("libgcrypt-config --cflags" "LIBGCRYPT_CXXFLAGS")
-#include "inspircd.h"
-#include <gnutls/gnutls.h>
-#include <gnutls/x509.h>
-#include "ssl.h"
-#include "m_cap.h"
+/// $LinkerFlags: find_linker_flags("gnutls" "-lgnutls")
+/// $LinkerFlags: require_version("gnutls" "1.0" "2.12") execute("libgcrypt-config --libs" "LIBGCRYPT_LDFLAGS")
-#ifdef _WIN32
-# pragma comment(lib, "libgnutls-30.lib")
+/// $PackageInfo: require_system("centos") gnutls-devel pkgconfig
+/// $PackageInfo: require_system("darwin") gnutls pkg-config
+/// $PackageInfo: require_system("debian" "1.0" "7.99") libgcrypt11-dev
+/// $PackageInfo: require_system("debian") gnutls-bin libgnutls28-dev pkg-config
+/// $PackageInfo: require_system("ubuntu" "1.0" "13.10") libgcrypt11-dev
+/// $PackageInfo: require_system("ubuntu") gnutls-bin libgnutls-dev pkg-config
+
+#include "inspircd.h"
+#include "modules/ssl.h"
+#include <memory>
+
+// Fix warnings about the use of commas at end of enumerator lists on C++03.
+#if defined __clang__
+# pragma clang diagnostic ignored "-Wc++11-extensions"
+#elif defined __GNUC__
+# if __GNUC__ < 6
+# pragma GCC diagnostic ignored "-pedantic"
+# endif
#endif
-/* $ModDesc: Provides SSL support for clients */
-/* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") iflt("pkg-config --modversion gnutls","2.12") exec("libgcrypt-config --cflags") */
-/* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") iflt("pkg-config --modversion gnutls","2.12") exec("libgcrypt-config --libs") */
-/* $NoPedantic */
+// Fix warnings about using std::auto_ptr on C++11 or newer.
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-#ifndef GNUTLS_VERSION_MAJOR
-#define GNUTLS_VERSION_MAJOR LIBGNUTLS_VERSION_MAJOR
-#define GNUTLS_VERSION_MINOR LIBGNUTLS_VERSION_MINOR
-#define GNUTLS_VERSION_PATCH LIBGNUTLS_VERSION_PATCH
-#endif
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
-// These don't exist in older GnuTLS versions
-#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 1) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 1 && GNUTLS_VERSION_PATCH >= 7))
-#define GNUTLS_NEW_PRIO_API
+#ifndef GNUTLS_VERSION_NUMBER
+#define GNUTLS_VERSION_NUMBER LIBGNUTLS_VERSION_NUMBER
+#define GNUTLS_VERSION LIBGNUTLS_VERSION
#endif
-#if(GNUTLS_VERSION_MAJOR < 2)
-typedef gnutls_certificate_credentials_t gnutls_certificate_credentials;
-typedef gnutls_dh_params_t gnutls_dh_params;
+// Check if the GnuTLS library is at least version major.minor.patch
+#define INSPIRCD_GNUTLS_HAS_VERSION(major, minor, patch) (GNUTLS_VERSION_NUMBER >= ((major << 16) | (minor << 8) | patch))
+
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 9, 8)
+#define GNUTLS_HAS_MAC_GET_ID
+#include <gnutls/crypto.h>
#endif
-#if (GNUTLS_VERSION_MAJOR > 2 || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR >= 12))
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0)
# define GNUTLS_HAS_RND
-# include <gnutls/crypto.h>
#else
# include <gcrypt.h>
#endif
-enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
-
-struct SSLConfig : public refcountbase
-{
- gnutls_certificate_credentials_t x509_cred;
- std::vector<gnutls_x509_crt_t> x509_certs;
- gnutls_x509_privkey_t x509_key;
- gnutls_dh_params_t dh_params;
-#ifdef GNUTLS_NEW_PRIO_API
- gnutls_priority_t priority;
+#ifdef _WIN32
+# pragma comment(lib, "libgnutls-30.lib")
#endif
- SSLConfig()
- : x509_cred(NULL)
- , x509_key(NULL)
- , dh_params(NULL)
-#ifdef GNUTLS_NEW_PRIO_API
- , priority(NULL)
+// These don't exist in older GnuTLS versions
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 1, 7)
+#define GNUTLS_NEW_PRIO_API
#endif
- {
- }
-
- ~SSLConfig()
- {
- ServerInstance->Logs->Log("m_ssl_gnutls", DEBUG, "Destroying SSLConfig %p", (void*)this);
-
- if (x509_cred)
- gnutls_certificate_free_credentials(x509_cred);
- for (unsigned int i = 0; i < x509_certs.size(); i++)
- gnutls_x509_crt_deinit(x509_certs[i]);
+enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_HANDSHAKEN };
- if (x509_key)
- gnutls_x509_privkey_deinit(x509_key);
-
- if (dh_params)
- gnutls_dh_params_deinit(dh_params);
-
-#ifdef GNUTLS_NEW_PRIO_API
- if (priority)
- gnutls_priority_deinit(priority);
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0)
+#define INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
+#define GNUTLS_NEW_CERT_CALLBACK_API
+typedef gnutls_retr2_st cert_cb_last_param_type;
+#else
+typedef gnutls_retr_st cert_cb_last_param_type;
#endif
- }
-};
-
-static reference<SSLConfig> currconf;
-
-static SSLConfig* GetSessionConfig(gnutls_session_t session);
-#if(GNUTLS_VERSION_MAJOR < 2 || ( GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR < 12 ) )
-static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_ca_rdn, int nreqs,
- const gnutls_pk_algorithm_t * sign_algos, int sign_algos_length, gnutls_retr_st * st) {
+#if INSPIRCD_GNUTLS_HAS_VERSION(3, 3, 5)
+#define INSPIRCD_GNUTLS_HAS_RECV_PACKET
+#endif
- st->type = GNUTLS_CRT_X509;
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 99, 0)
+// The second parameter of gnutls_init() has changed in 2.99.0 from gnutls_connection_end_t to unsigned int
+// (it became a general flags parameter) and the enum has been deprecated and generates a warning on use.
+typedef unsigned int inspircd_gnutls_session_init_flags_t;
#else
-static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_ca_rdn, int nreqs,
- const gnutls_pk_algorithm_t * sign_algos, int sign_algos_length, gnutls_retr2_st * st) {
- st->cert_type = GNUTLS_CRT_X509;
- st->key_type = GNUTLS_PRIVKEY_X509;
+typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t;
#endif
- SSLConfig* conf = GetSessionConfig(session);
- std::vector<gnutls_x509_crt_t>& x509_certs = conf->x509_certs;
- st->ncerts = x509_certs.size();
- st->cert.x509 = &x509_certs[0];
- st->key.x509 = conf->x509_key;
- st->deinit_all = 0;
- return 0;
-}
+#if INSPIRCD_GNUTLS_HAS_VERSION(3, 1, 9)
+#define INSPIRCD_GNUTLS_HAS_CORK
+#endif
-class RandGen : public HandlerBase2<void, char*, size_t>
+static Module* thismod;
+
+class RandGen
{
public:
- RandGen() {}
- void Call(char* buffer, size_t len)
+ static void Call(char* buffer, size_t len)
{
#ifdef GNUTLS_HAS_RND
gnutls_rnd(GNUTLS_RND_RANDOM, buffer, len);
@@ -143,749 +121,677 @@ class RandGen : public HandlerBase2<void, char*, size_t>
}
};
-/** Represents an SSL user's extra data
- */
-class issl_session
-{
-public:
- StreamSocket* socket;
- gnutls_session_t sess;
- issl_status status;
- reference<ssl_cert> cert;
- reference<SSLConfig> config;
-
- issl_session() : socket(NULL), sess(NULL), status(ISSL_NONE) {}
-};
-
-static SSLConfig* GetSessionConfig(gnutls_session_t sess)
+namespace GnuTLS
{
- issl_session* session = reinterpret_cast<issl_session*>(gnutls_transport_get_ptr(sess));
- return session->config;
-}
-
-class CommandStartTLS : public SplitCommand
-{
- public:
- bool enabled;
- CommandStartTLS (Module* mod) : SplitCommand(mod, "STARTTLS")
+ class Init
{
- enabled = true;
- works_before_reg = true;
- }
+ public:
+ Init() { gnutls_global_init(); }
+ ~Init() { gnutls_global_deinit(); }
+ };
- CmdResult HandleLocal(const std::vector<std::string> &parameters, LocalUser *user)
+ class Exception : public ModuleException
{
- if (!enabled)
- {
- user->WriteNumeric(691, "%s :STARTTLS is not enabled", user->nick.c_str());
- return CMD_FAILURE;
- }
+ public:
+ Exception(const std::string& reason)
+ : ModuleException(reason) { }
+ };
- if (user->registered == REG_ALL)
- {
- user->WriteNumeric(691, "%s :STARTTLS is not permitted after client registration is complete", user->nick.c_str());
- }
- else
+ void ThrowOnError(int errcode, const char* msg)
+ {
+ if (errcode < 0)
{
- if (!user->eh.GetIOHook())
- {
- user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
- /* We need to flush the write buffer prior to adding the IOHook,
- * otherwise we'll be sending this line inside the SSL session - which
- * won't start its handshake until the client gets this line. Currently,
- * we assume the write will not block here; this is usually safe, as
- * STARTTLS is sent very early on in the registration phase, where the
- * user hasn't built up much sendq. Handling a blocked write here would
- * be very annoying.
- */
- user->eh.DoWrite();
- user->eh.AddIOHook(creator);
- creator->OnStreamSocketAccept(&user->eh, NULL, NULL);
- }
- else
- user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str());
+ std::string reason = msg;
+ reason.append(" :").append(gnutls_strerror(errcode));
+ throw Exception(reason);
}
-
- return CMD_FAILURE;
}
-};
-
-class ModuleSSLGnuTLS : public Module
-{
- issl_session* sessions;
- gnutls_digest_algorithm_t hash;
-
- std::string sslports;
- int dh_bits;
+ /** Used to create a gnutls_datum_t* from a std::string
+ */
+ class Datum
+ {
+ gnutls_datum_t datum;
- RandGen randhandler;
- CommandStartTLS starttls;
+ public:
+ Datum(const std::string& dat)
+ {
+ datum.data = (unsigned char*)dat.data();
+ datum.size = static_cast<unsigned int>(dat.length());
+ }
- GenericCap capHandler;
- ServiceProvider iohook;
+ const gnutls_datum_t* get() const { return &datum; }
+ };
- inline static const char* UnknownIfNULL(const char* str)
+ class Hash
{
- return str ? str : "UNKNOWN";
- }
+ gnutls_digest_algorithm_t hash;
- static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size)
- {
- issl_session* session = reinterpret_cast<issl_session*>(session_wrap);
- if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK)
+ public:
+ // Nothing to deallocate, constructor may throw freely
+ Hash(const std::string& hashname)
{
-#ifdef _WIN32
- gnutls_transport_set_errno(session->sess, EAGAIN);
+ // As older versions of gnutls can't do this, let's disable it where needed.
+#ifdef GNUTLS_HAS_MAC_GET_ID
+ // As gnutls_digest_algorithm_t and gnutls_mac_algorithm_t are mapped 1:1, we can do this
+ // There is no gnutls_dig_get_id() at the moment, but it may come later
+ hash = (gnutls_digest_algorithm_t)gnutls_mac_get_id(hashname.c_str());
+ if (hash == GNUTLS_DIG_UNKNOWN)
+ throw Exception("Unknown hash type " + hashname);
+
+ // Check if the user is giving us something that is a valid MAC but not digest
+ gnutls_hash_hd_t is_digest;
+ if (gnutls_hash_init(&is_digest, hash) < 0)
+ throw Exception("Unknown hash type " + hashname);
+ gnutls_hash_deinit(is_digest, NULL);
#else
- errno = EAGAIN;
+ if (stdalgo::string::equalsci(hashname, "md5"))
+ hash = GNUTLS_DIG_MD5;
+ else if (stdalgo::string::equalsci(hashname, "sha1"))
+ hash = GNUTLS_DIG_SHA1;
+#ifdef INSPIRCD_GNUTLS_ENABLE_SHA256_FINGERPRINT
+ else if (stdalgo::string::equalsci(hashname, "sha256"))
+ hash = GNUTLS_DIG_SHA256;
+#endif
+ else
+ throw Exception("Unknown hash type " + hashname);
#endif
- return -1;
}
- int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0);
+ gnutls_digest_algorithm_t get() const { return hash; }
+ };
-#ifdef _WIN32
- if (rv < 0)
+ class DHParams
+ {
+ gnutls_dh_params_t dh_params;
+
+ DHParams()
{
- /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError()
- * and then set errno appropriately.
- * The gnutls library may also have a different errno variable than us, see
- * gnutls_transport_set_errno(3).
- */
- gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
+ ThrowOnError(gnutls_dh_params_init(&dh_params), "gnutls_dh_params_init() failed");
}
-#endif
- if (rv < (int)size)
- ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK);
- return rv;
- }
-
- static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size)
- {
- issl_session* session = reinterpret_cast<issl_session*>(session_wrap);
- if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK)
+ public:
+ /** Import */
+ static std::auto_ptr<DHParams> Import(const std::string& dhstr)
{
-#ifdef _WIN32
- gnutls_transport_set_errno(session->sess, EAGAIN);
-#else
- errno = EAGAIN;
-#endif
- return -1;
+ std::auto_ptr<DHParams> dh(new DHParams);
+ int ret = gnutls_dh_params_import_pkcs3(dh->dh_params, Datum(dhstr).get(), GNUTLS_X509_FMT_PEM);
+ ThrowOnError(ret, "Unable to import DH params");
+ return dh;
}
- int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0);
-
-#ifdef _WIN32
- if (rv < 0)
+ ~DHParams()
{
- /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError()
- * and then set errno appropriately.
- * The gnutls library may also have a different errno variable than us, see
- * gnutls_transport_set_errno(3).
- */
- gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
+ gnutls_dh_params_deinit(dh_params);
}
-#endif
- if (rv < (int)size)
- ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK);
- return rv;
- }
-
- public:
+ const gnutls_dh_params_t& get() const { return dh_params; }
+ };
- ModuleSSLGnuTLS()
- : starttls(this), capHandler(this, "tls"), iohook(this, "ssl/gnutls", SERVICE_IOHOOK)
+ class X509Key
{
-#ifndef GNUTLS_HAS_RND
- gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
-#endif
-
- sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
-
- gnutls_global_init(); // This must be called once in the program
- }
+ /** Ensure that the key is deinited in case the constructor of X509Key throws
+ */
+ class RAIIKey
+ {
+ public:
+ gnutls_x509_privkey_t key;
- void init()
- {
- currconf = new SSLConfig;
- InitSSLConfig(currconf);
+ RAIIKey()
+ {
+ ThrowOnError(gnutls_x509_privkey_init(&key), "gnutls_x509_privkey_init() failed");
+ }
- ServerInstance->GenRandom = &randhandler;
+ ~RAIIKey()
+ {
+ gnutls_x509_privkey_deinit(key);
+ }
+ } key;
- Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnUserConnect,
- I_OnEvent, I_OnHookIO, I_OnCheckReady };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
+ public:
+ /** Import */
+ X509Key(const std::string& keystr)
+ {
+ int ret = gnutls_x509_privkey_import(key.key, Datum(keystr).get(), GNUTLS_X509_FMT_PEM);
+ ThrowOnError(ret, "Unable to import private key");
+ }
- ServerInstance->Modules->AddService(iohook);
- ServerInstance->Modules->AddService(starttls);
- }
+ gnutls_x509_privkey_t& get() { return key.key; }
+ };
- void OnRehash(User* user)
+ class X509CertList
{
- sslports.clear();
+ std::vector<gnutls_x509_crt_t> certs;
- ConfigTag* Conf = ServerInstance->Config->ConfValue("gnutls");
- starttls.enabled = Conf->getBool("starttls", true);
-
- if (Conf->getBool("showports", true))
+ public:
+ /** Import */
+ X509CertList(const std::string& certstr)
{
- sslports = Conf->getString("advertisedports");
- if (!sslports.empty())
- return;
+ unsigned int certcount = 3;
+ certs.resize(certcount);
+ Datum datum(certstr);
- for (size_t i = 0; i < ServerInstance->ports.size(); i++)
+ int ret = gnutls_x509_crt_list_import(raw(), &certcount, datum.get(), GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED);
+ if (ret == GNUTLS_E_SHORT_MEMORY_BUFFER)
{
- ListenSocket* port = ServerInstance->ports[i];
- if (port->bind_tag->getString("ssl") != "gnutls")
- continue;
-
- const std::string& portid = port->bind_desc;
- ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %s", portid.c_str());
-
- if (port->bind_tag->getString("type", "clients") == "clients" && port->bind_addr != "127.0.0.1")
- {
- /*
- * Found an SSL port for clients that is not bound to 127.0.0.1 and handled by us, display
- * the IP:port in ISUPPORT.
- *
- * We used to advertise all ports seperated by a ';' char that matched the above criteria,
- * but this resulted in too long ISUPPORT lines if there were lots of ports to be displayed.
- * To solve this by default we now only display the first IP:port found and let the user
- * configure the exact value for the 005 token, if necessary.
- */
- sslports = portid;
- break;
- }
+ // the buffer wasn't big enough to hold all certs but gnutls changed certcount to the number of available certs,
+ // try again with a bigger buffer
+ certs.resize(certcount);
+ ret = gnutls_x509_crt_list_import(raw(), &certcount, datum.get(), GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED);
}
- }
- }
- void OnModuleRehash(User* user, const std::string &param)
- {
- if(param != "ssl")
- return;
+ ThrowOnError(ret, "Unable to load certificates");
- reference<SSLConfig> newconf = new SSLConfig;
- try
- {
- InitSSLConfig(newconf);
+ // Resize the vector to the actual number of certs because we rely on its size being correct
+ // when deallocating the certs
+ certs.resize(certcount);
}
- catch (ModuleException& ex)
+
+ ~X509CertList()
{
- ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls: Not applying new config. %s", ex.GetReason());
- return;
+ for (std::vector<gnutls_x509_crt_t>::iterator i = certs.begin(); i != certs.end(); ++i)
+ gnutls_x509_crt_deinit(*i);
}
- ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls: Applying new config, old config is in use by %d connection(s)", currconf->GetReferenceCount()-1);
- currconf = newconf;
- }
+ gnutls_x509_crt_t* raw() { return &certs[0]; }
+ unsigned int size() const { return certs.size(); }
+ };
- void InitSSLConfig(SSLConfig* config)
+ class X509CRL : public refcountbase
{
- ServerInstance->Logs->Log("m_ssl_gnutls", DEBUG, "Initializing new SSLConfig %p", (void*)config);
-
- std::string keyfile;
- std::string certfile;
- std::string cafile;
- std::string crlfile;
- OnRehash(NULL);
-
- ConfigTag* Conf = ServerInstance->Config->ConfValue("gnutls");
-
- cafile = Conf->getString("cafile", CONFIG_PATH "/ca.pem");
- crlfile = Conf->getString("crlfile", CONFIG_PATH "/crl.pem");
- certfile = Conf->getString("certfile", CONFIG_PATH "/cert.pem");
- keyfile = Conf->getString("keyfile", CONFIG_PATH "/key.pem");
- dh_bits = Conf->getInt("dhbits");
- std::string hashname = Conf->getString("hash", "md5");
-
- // The GnuTLS manual states that the gnutls_set_default_priority()
- // call we used previously when initializing the session is the same
- // as setting the "NORMAL" priority string.
- // Thus if the setting below is not in the config we will behave exactly
- // the same as before, when the priority setting wasn't available.
- std::string priorities = Conf->getString("priority", "NORMAL");
-
- if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
- dh_bits = 1024;
-
- if (hashname == "md5")
- hash = GNUTLS_DIG_MD5;
- else if (hashname == "sha1")
- hash = GNUTLS_DIG_SHA1;
-#ifdef INSPIRCD_GNUTLS_ENABLE_SHA256_FINGERPRINT
- else if (hashname == "sha256")
- hash = GNUTLS_DIG_SHA256;
-#endif
- else
- throw ModuleException("Unknown hash type " + hashname);
-
+ class RAIICRL
+ {
+ public:
+ gnutls_x509_crl_t crl;
- int ret;
+ RAIICRL()
+ {
+ ThrowOnError(gnutls_x509_crl_init(&crl), "gnutls_x509_crl_init() failed");
+ }
- gnutls_certificate_credentials_t& x509_cred = config->x509_cred;
+ ~RAIICRL()
+ {
+ gnutls_x509_crl_deinit(crl);
+ }
+ } crl;
- ret = gnutls_certificate_allocate_credentials(&x509_cred);
- if (ret < 0)
+ public:
+ /** Import */
+ X509CRL(const std::string& crlstr)
{
- // Set to NULL because we can't be sure what value is in it and we must not try to
- // deallocate it in case of an error
- x509_cred = NULL;
- throw ModuleException("Failed to allocate certificate credentials: " + std::string(gnutls_strerror(ret)));
+ int ret = gnutls_x509_crl_import(get(), Datum(crlstr).get(), GNUTLS_X509_FMT_PEM);
+ ThrowOnError(ret, "Unable to load certificate revocation list");
}
- if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
- ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
-
- if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
- ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
+ gnutls_x509_crl_t& get() { return crl.crl; }
+ };
- FileReader reader;
-
- reader.LoadFile(certfile);
- std::string cert_string = reader.Contents();
- gnutls_datum_t cert_datum = { (unsigned char*)cert_string.data(), static_cast<unsigned int>(cert_string.length()) };
+#ifdef GNUTLS_NEW_PRIO_API
+ class Priority
+ {
+ gnutls_priority_t priority;
- reader.LoadFile(keyfile);
- std::string key_string = reader.Contents();
- gnutls_datum_t key_datum = { (unsigned char*)key_string.data(), static_cast<unsigned int>(key_string.length()) };
+ public:
+ Priority(const std::string& priorities)
+ {
+ // Try to set the priorities for ciphers, kex methods etc. to the user supplied string
+ // If the user did not supply anything then the string is already set to "NORMAL"
+ const char* priocstr = priorities.c_str();
+ const char* prioerror;
- std::vector<gnutls_x509_crt_t>& x509_certs = config->x509_certs;
+ int ret = gnutls_priority_init(&priority, priocstr, &prioerror);
+ if (ret < 0)
+ {
+ // gnutls did not understand the user supplied string
+ throw Exception("Unable to initialize priorities to \"" + priorities + "\": " + gnutls_strerror(ret) + " Syntax error at position " + ConvToStr((unsigned int) (prioerror - priocstr)));
+ }
+ }
- // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
- unsigned int certcount = 3;
- x509_certs.resize(certcount);
- ret = gnutls_x509_crt_list_import(&x509_certs[0], &certcount, &cert_datum, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED);
- if (ret == GNUTLS_E_SHORT_MEMORY_BUFFER)
+ ~Priority()
{
- // the buffer wasn't big enough to hold all certs but gnutls updated certcount to the number of available certs, try again with a bigger buffer
- x509_certs.resize(certcount);
- ret = gnutls_x509_crt_list_import(&x509_certs[0], &certcount, &cert_datum, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED);
+ gnutls_priority_deinit(priority);
}
- if (ret <= 0)
+ void SetupSession(gnutls_session_t sess)
{
- // clear the vector so we won't call gnutls_x509_crt_deinit() on the (uninited) certs later
- x509_certs.clear();
- throw ModuleException("Unable to load GnuTLS server certificate (" + certfile + "): " + ((ret < 0) ? (std::string(gnutls_strerror(ret))) : "No certs could be read"));
+ gnutls_priority_set(sess, priority);
}
- x509_certs.resize(ret);
- gnutls_x509_privkey_t& x509_key = config->x509_key;
- if (gnutls_x509_privkey_init(&x509_key) < 0)
+ static const char* GetDefault()
{
- // Make sure the destructor does not try to deallocate this, see above
- x509_key = NULL;
- throw ModuleException("Unable to initialize private key");
+ return "NORMAL:%SERVER_PRECEDENCE:-VERS-SSL3.0";
}
- if((ret = gnutls_x509_privkey_import(x509_key, &key_datum, GNUTLS_X509_FMT_PEM)) < 0)
- throw ModuleException("Unable to load GnuTLS server private key (" + keyfile + "): " + std::string(gnutls_strerror(ret)));
+ static std::string RemoveUnknownTokens(const std::string& prio)
+ {
+ std::string ret;
+ irc::sepstream ss(prio, ':');
+ for (std::string token; ss.GetToken(token); )
+ {
+ // Save current position so we can revert later if needed
+ const std::string::size_type prevpos = ret.length();
+ // Append next token
+ if (!ret.empty())
+ ret.push_back(':');
+ ret.append(token);
+
+ gnutls_priority_t test;
+ if (gnutls_priority_init(&test, ret.c_str(), NULL) < 0)
+ {
+ // The new token broke the priority string, revert to the previously working one
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Priority string token not recognized: \"%s\"", token.c_str());
+ ret.erase(prevpos);
+ }
+ else
+ {
+ // Worked
+ gnutls_priority_deinit(test);
+ }
+ }
+ return ret;
+ }
+ };
+#else
+ /** Dummy class, used when gnutls_priority_set() is not available
+ */
+ class Priority
+ {
+ public:
+ Priority(const std::string& priorities)
+ {
+ if (priorities != GetDefault())
+ throw Exception("You've set a non-default priority string, but GnuTLS lacks support for it");
+ }
- if((ret = gnutls_certificate_set_x509_key(x509_cred, &x509_certs[0], certcount, x509_key)) < 0)
- throw ModuleException("Unable to set GnuTLS cert/key pair: " + std::string(gnutls_strerror(ret)));
+ static void SetupSession(gnutls_session_t sess)
+ {
+ // Always set the default priorities
+ gnutls_set_default_priority(sess);
+ }
- #ifdef GNUTLS_NEW_PRIO_API
- // Try to set the priorities for ciphers, kex methods etc. to the user supplied string
- // If the user did not supply anything then the string is already set to "NORMAL"
- const char* priocstr = priorities.c_str();
- const char* prioerror;
+ static const char* GetDefault()
+ {
+ return "NORMAL";
+ }
- gnutls_priority_t& priority = config->priority;
- if ((ret = gnutls_priority_init(&priority, priocstr, &prioerror)) < 0)
+ static std::string RemoveUnknownTokens(const std::string& prio)
{
- // gnutls did not understand the user supplied string, log and fall back to the default priorities
- ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set priorities to \"%s\": %s Syntax error at position %u, falling back to default (NORMAL)", priorities.c_str(), gnutls_strerror(ret), (unsigned int) (prioerror - priocstr));
- gnutls_priority_init(&priority, "NORMAL", NULL);
+ // We don't do anything here because only NORMAL is accepted
+ return prio;
}
+ };
+#endif
- #else
- if (priorities != "NORMAL")
- ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: You've set <gnutls:priority> to a value other than the default, but this is only supported with GnuTLS v2.1.7 or newer. Your GnuTLS version is older than that so the option will have no effect.");
- #endif
+ class CertCredentials
+ {
+ /** DH parameters associated with these credentials
+ */
+ std::auto_ptr<DHParams> dh;
- #if(GNUTLS_VERSION_MAJOR < 2 || ( GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR < 12 ) )
- gnutls_certificate_client_set_retrieve_function (x509_cred, cert_callback);
- #else
- gnutls_certificate_set_retrieve_function (x509_cred, cert_callback);
- #endif
+ protected:
+ gnutls_certificate_credentials_t cred;
- gnutls_dh_params_t& dh_params = config->dh_params;
- ret = gnutls_dh_params_init(&dh_params);
- if (ret < 0)
+ public:
+ CertCredentials()
{
- // Make sure the destructor does not try to deallocate this, see above
- dh_params = NULL;
- ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
- return;
+ ThrowOnError(gnutls_certificate_allocate_credentials(&cred), "Cannot allocate certificate credentials");
}
- std::string dhfile = Conf->getString("dhfile");
- if (!dhfile.empty())
+ ~CertCredentials()
{
- // Try to load DH params from file
- reader.LoadFile(dhfile);
- std::string dhstring = reader.Contents();
- gnutls_datum_t dh_datum = { (unsigned char*)dhstring.data(), static_cast<unsigned int>(dhstring.length()) };
-
- if ((ret = gnutls_dh_params_import_pkcs3(dh_params, &dh_datum, GNUTLS_X509_FMT_PEM)) < 0)
- {
- // File unreadable or GnuTLS was unhappy with the contents, generate the DH primes now
- ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Generating DH parameters because I failed to load them from file '%s': %s", dhfile.c_str(), gnutls_strerror(ret));
- GenerateDHParams(dh_params);
- }
+ gnutls_certificate_free_credentials(cred);
}
- else
+
+ /** Associates these credentials with the session
+ */
+ void SetupSession(gnutls_session_t sess)
{
- GenerateDHParams(dh_params);
+ gnutls_credentials_set(sess, GNUTLS_CRD_CERTIFICATE, cred);
}
- gnutls_certificate_set_dh_params(x509_cred, dh_params);
- }
+ /** Set the given DH parameters to be used with these credentials
+ */
+ void SetDH(std::auto_ptr<DHParams>& DH)
+ {
+ dh = DH;
+ gnutls_certificate_set_dh_params(cred, dh->get());
+ }
+ };
- void GenerateDHParams(gnutls_dh_params_t dh_params)
+ class X509Credentials : public CertCredentials
{
- // Generate Diffie Hellman parameters - for use with DHE
- // kx algorithms. These should be discarded and regenerated
- // once a day, once a week or once a month. Depending on the
- // security requirements.
+ /** Private key
+ */
+ X509Key key;
- int ret;
+ /** Certificate list, presented to the peer
+ */
+ X509CertList certs;
- if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
- ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
- }
+ /** Trusted CA, may be NULL
+ */
+ std::auto_ptr<X509CertList> trustedca;
- ~ModuleSSLGnuTLS()
- {
- currconf = NULL;
+ /** Certificate revocation list, may be NULL
+ */
+ std::auto_ptr<X509CRL> crl;
- gnutls_global_deinit();
- delete[] sessions;
- ServerInstance->GenRandom = &ServerInstance->HandleGenRandom;
- }
+ static int cert_callback(gnutls_session_t session, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st);
- void OnCleanup(int target_type, void* item)
- {
- if(target_type == TYPE_USER)
+ public:
+ X509Credentials(const std::string& certstr, const std::string& keystr)
+ : key(keystr)
+ , certs(certstr)
{
- LocalUser* user = IS_LOCAL(static_cast<User*>(item));
+ // Throwing is ok here, the destructor of Credentials is called in that case
+ int ret = gnutls_certificate_set_x509_key(cred, certs.raw(), certs.size(), key.get());
+ ThrowOnError(ret, "Unable to set cert/key pair");
- if (user && user->eh.GetIOHook() == this)
- {
- // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
- // Potentially there could be multiple SSL modules loaded at once on different ports.
- ServerInstance->Users->QuitUser(user, "SSL module unloading");
- }
+#ifdef GNUTLS_NEW_CERT_CALLBACK_API
+ gnutls_certificate_set_retrieve_function(cred, cert_callback);
+#else
+ gnutls_certificate_client_set_retrieve_function(cred, cert_callback);
+#endif
}
- }
- Version GetVersion()
- {
- return Version("Provides SSL support for clients", VF_VENDOR);
- }
+ /** Sets the trusted CA and the certificate revocation list
+ * to use when verifying certificates
+ */
+ void SetCA(std::auto_ptr<X509CertList>& certlist, std::auto_ptr<X509CRL>& CRL)
+ {
+ // Do nothing if certlist is NULL
+ if (certlist.get())
+ {
+ int ret = gnutls_certificate_set_x509_trust(cred, certlist->raw(), certlist->size());
+ ThrowOnError(ret, "gnutls_certificate_set_x509_trust() failed");
+
+ if (CRL.get())
+ {
+ ret = gnutls_certificate_set_x509_crl(cred, &CRL->get(), 1);
+ ThrowOnError(ret, "gnutls_certificate_set_x509_crl() failed");
+ }
+ trustedca = certlist;
+ crl = CRL;
+ }
+ }
+ };
- void On005Numeric(std::string &output)
+ class DataReader
{
- if (!sslports.empty())
- output.append(" SSL=" + sslports);
- if (starttls.enabled)
- output.append(" STARTTLS");
- }
+ int retval;
+#ifdef INSPIRCD_GNUTLS_HAS_RECV_PACKET
+ gnutls_packet_t packet;
- void OnHookIO(StreamSocket* user, ListenSocket* lsb)
- {
- if (!user->GetIOHook() && lsb->bind_tag->getString("ssl") == "gnutls")
+ public:
+ DataReader(gnutls_session_t sess)
{
- /* Hook the user with our module */
- user->AddIOHook(this);
+ // Using the packet API avoids the final copy of the data which GnuTLS does if we supply
+ // our own buffer. Instead, we get the buffer containing the data from GnuTLS and copy it
+ // to the recvq directly from there in appendto().
+ retval = gnutls_record_recv_packet(sess, &packet);
}
- }
- void OnRequest(Request& request)
- {
- if (strcmp("GET_SSL_CERT", request.id) == 0)
+ void appendto(std::string& recvq)
{
- SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
- int fd = req.sock->GetFd();
- issl_session* session = &sessions[fd];
+ // Copy data from GnuTLS buffers to recvq
+ gnutls_datum_t datum;
+ gnutls_packet_get(packet, &datum, NULL);
+ recvq.append(reinterpret_cast<const char*>(datum.data), datum.size);
- req.cert = session->cert;
+ gnutls_packet_deinit(packet);
}
- else if (!strcmp("GET_RAW_SSL_SESSION", request.id))
+#else
+ char* const buffer;
+
+ public:
+ DataReader(gnutls_session_t sess)
+ : buffer(ServerInstance->GetReadBuffer())
{
- SSLRawSessionRequest& req = static_cast<SSLRawSessionRequest&>(request);
- if ((req.fd >= 0) && (req.fd < ServerInstance->SE->GetMaxFds()))
- req.data = reinterpret_cast<void*>(sessions[req.fd].sess);
+ // Read data from GnuTLS buffers into ReadBuffer
+ retval = gnutls_record_recv(sess, buffer, ServerInstance->Config->NetBufferSize);
}
- }
- void InitSession(StreamSocket* user, bool me_server)
- {
- issl_session* session = &sessions[user->GetFd()];
-
- gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT);
- session->socket = user;
- session->config = currconf;
-
- #ifdef GNUTLS_NEW_PRIO_API
- gnutls_priority_set(session->sess, currconf->priority);
- #else
- gnutls_set_default_priority(session->sess);
- #endif
- gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, currconf->x509_cred);
- gnutls_dh_set_prime_bits(session->sess, dh_bits);
- gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session));
- gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper);
- gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper);
-
- if (me_server)
- gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
+ void appendto(std::string& recvq)
+ {
+ // Copy data from ReadBuffer to recvq
+ recvq.append(buffer, retval);
+ }
+#endif
- Handshake(session, user);
- }
+ int ret() const { return retval; }
+ };
- void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
+ class Profile
{
- issl_session* session = &sessions[user->GetFd()];
+ /** Name of this profile
+ */
+ const std::string name;
- /* For STARTTLS: Don't try and init a session on a socket that already has a session */
- if (session->sess)
- return;
+ /** X509 certificate(s) and key
+ */
+ X509Credentials x509cred;
- InitSession(user, true);
- }
+ /** The minimum length in bits for the DH prime to be accepted as a client
+ */
+ unsigned int min_dh_bits;
- void OnStreamSocketConnect(StreamSocket* user)
- {
- InitSession(user, false);
- }
+ /** Hashing algorithm to use when generating certificate fingerprints
+ */
+ Hash hash;
- void OnStreamSocketClose(StreamSocket* user)
- {
- CloseSession(&sessions[user->GetFd()]);
- }
+ /** Priorities for ciphers, compression methods, etc.
+ */
+ Priority priority;
- int OnStreamSocketRead(StreamSocket* user, std::string& recvq)
- {
- issl_session* session = &sessions[user->GetFd()];
+ /** Rough max size of records to send
+ */
+ const unsigned int outrecsize;
- if (!session->sess)
+ /** True to request a client certificate as a server
+ */
+ const bool requestclientcert;
+
+ static std::string ReadFile(const std::string& filename)
{
- CloseSession(session);
- user->SetError("No SSL session");
- return -1;
+ FileReader reader(filename);
+ std::string ret = reader.GetString();
+ if (ret.empty())
+ throw Exception("Cannot read file " + filename);
+ return ret;
}
- if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE)
+ static std::string GetPrioStr(const std::string& profilename, ConfigTag* tag)
{
- // The handshake isn't finished, try to finish it.
-
- if(!Handshake(session, user))
+ // Use default priority string if this tag does not specify one
+ std::string priostr = GnuTLS::Priority::GetDefault();
+ bool found = tag->readString("priority", priostr);
+ // If the prio string isn't set in the config don't be strict about the default one because it doesn't work on all versions of GnuTLS
+ if (!tag->getBool("strictpriority", found))
{
- if (session->status != ISSL_CLOSING)
- return 0;
- return -1;
+ std::string stripped = GnuTLS::Priority::RemoveUnknownTokens(priostr);
+ if (stripped.empty())
+ {
+ // Stripping failed, act as if a prio string wasn't set
+ stripped = GnuTLS::Priority::RemoveUnknownTokens(GnuTLS::Priority::GetDefault());
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Priority string for profile \"%s\" contains unknown tokens and stripping it didn't yield a working one either, falling back to \"%s\"", profilename.c_str(), stripped.c_str());
+ }
+ else if ((found) && (stripped != priostr))
+ {
+ // Prio string was set in the config and we ended up with something that works but different
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Priority string for profile \"%s\" contains unknown tokens, stripped to \"%s\"", profilename.c_str(), stripped.c_str());
+ }
+ priostr.swap(stripped);
}
+ return priostr;
}
- // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
-
- if (session->status == ISSL_HANDSHAKEN)
+ public:
+ struct Config
{
- char* buffer = ServerInstance->GetReadBuffer();
- size_t bufsiz = ServerInstance->Config->NetBufferSize;
- int ret = gnutls_record_recv(session->sess, buffer, bufsiz);
- if (ret > 0)
- {
- recvq.append(buffer, ret);
- // Schedule a read if there is still data in the GnuTLS buffer
- if (gnutls_record_check_pending(session->sess) > 0)
- ServerInstance->SE->ChangeEventMask(user, FD_ADD_TRIAL_READ);
- return 1;
- }
- else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
- {
- return 0;
- }
- else if (ret == 0)
+ std::string name;
+
+ std::auto_ptr<X509CertList> ca;
+ std::auto_ptr<X509CRL> crl;
+
+ std::string certstr;
+ std::string keystr;
+ std::auto_ptr<DHParams> dh;
+
+ std::string priostr;
+ unsigned int mindh;
+ std::string hashstr;
+
+ unsigned int outrecsize;
+ bool requestclientcert;
+
+ Config(const std::string& profilename, ConfigTag* tag)
+ : name(profilename)
+ , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+ , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+ , dh(DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem"))))
+ , priostr(GetPrioStr(profilename, tag))
+ , mindh(tag->getUInt("mindhbits", 1024))
+ , hashstr(tag->getString("hash", "md5"))
+ , requestclientcert(tag->getBool("requestclientcert", true))
{
- user->SetError("Connection closed");
- CloseSession(session);
- return -1;
- }
- else
- {
- user->SetError(gnutls_strerror(ret));
- CloseSession(session);
- return -1;
- }
- }
- else if (session->status == ISSL_CLOSING)
- return -1;
-
- return 0;
- }
+ // Load trusted CA and revocation list, if set
+ std::string filename = tag->getString("cafile");
+ if (!filename.empty())
+ {
+ ca.reset(new X509CertList(ReadFile(filename)));
- int OnStreamSocketWrite(StreamSocket* user, std::string& sendq)
- {
- issl_session* session = &sessions[user->GetFd()];
+ filename = tag->getString("crlfile");
+ if (!filename.empty())
+ crl.reset(new X509CRL(ReadFile(filename)));
+ }
- if (!session->sess)
+#ifdef INSPIRCD_GNUTLS_HAS_CORK
+ // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked
+ outrecsize = tag->getUInt("outrecsize", 2048, 512);
+#else
+ outrecsize = tag->getUInt("outrecsize", 2048, 512, 16384);
+#endif
+ }
+ };
+
+ Profile(Config& config)
+ : name(config.name)
+ , x509cred(config.certstr, config.keystr)
+ , min_dh_bits(config.mindh)
+ , hash(config.hashstr)
+ , priority(config.priostr)
+ , outrecsize(config.outrecsize)
+ , requestclientcert(config.requestclientcert)
{
- CloseSession(session);
- user->SetError("No SSL session");
- return -1;
+ x509cred.SetDH(config.dh);
+ x509cred.SetCA(config.ca, config.crl);
}
-
- if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ)
+ /** Set up the given session with the settings in this profile
+ */
+ void SetupSession(gnutls_session_t sess)
{
- // The handshake isn't finished, try to finish it.
- Handshake(session, user);
- if (session->status != ISSL_CLOSING)
- return 0;
- return -1;
+ priority.SetupSession(sess);
+ x509cred.SetupSession(sess);
+ gnutls_dh_set_prime_bits(sess, min_dh_bits);
+
+ // Request client certificate if enabled and we are a server, no-op if we're a client
+ if (requestclientcert)
+ gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
}
- int ret = 0;
+ const std::string& GetName() const { return name; }
+ X509Credentials& GetX509Credentials() { return x509cred; }
+ gnutls_digest_algorithm_t GetHash() const { return hash.get(); }
+ unsigned int GetOutgoingRecordSize() const { return outrecsize; }
+ };
+}
- if (session->status == ISSL_HANDSHAKEN)
- {
- ret = gnutls_record_send(session->sess, sendq.data(), sendq.length());
+class GnuTLSIOHook : public SSLIOHook
+{
+ private:
+ gnutls_session_t sess;
+ issl_status status;
+#ifdef INSPIRCD_GNUTLS_HAS_CORK
+ size_t gbuffersize;
+#endif
- if (ret == (int)sendq.length())
- {
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_WRITE);
- return 1;
- }
- else if (ret > 0)
- {
- sendq = sendq.substr(ret);
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
- return 0;
- }
- else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED || ret == 0)
- {
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
- return 0;
- }
- else // (ret < 0)
- {
- user->SetError(gnutls_strerror(ret));
- CloseSession(session);
- return -1;
- }
+ void CloseSession()
+ {
+ if (this->sess)
+ {
+ gnutls_bye(this->sess, GNUTLS_SHUT_WR);
+ gnutls_deinit(this->sess);
}
-
- return 0;
+ sess = NULL;
+ certificate = NULL;
+ status = ISSL_NONE;
}
- bool Handshake(issl_session* session, StreamSocket* user)
+ // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
+ int Handshake(StreamSocket* user)
{
- int ret = gnutls_handshake(session->sess);
+ int ret = gnutls_handshake(this->sess);
if (ret < 0)
{
if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
{
// Handshake needs resuming later, read() or write() would have blocked.
+ this->status = ISSL_HANDSHAKING;
- if(gnutls_record_get_direction(session->sess) == 0)
+ if (gnutls_record_get_direction(this->sess) == 0)
{
// gnutls_handshake() wants to read() again.
- session->status = ISSL_HANDSHAKING_READ;
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
}
else
{
// gnutls_handshake() wants to write() again.
- session->status = ISSL_HANDSHAKING_WRITE;
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
+ SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
}
+
+ return 0;
}
else
{
user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret)));
- CloseSession(session);
- session->status = ISSL_CLOSING;
+ CloseSession();
+ return -1;
}
-
- return false;
}
else
{
// Change the seesion state
- session->status = ISSL_HANDSHAKEN;
+ this->status = ISSL_HANDSHAKEN;
- VerifyCertificate(session,user);
+ VerifyCertificate();
// Finish writing, if any left
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
-
- return true;
- }
- }
-
- void OnUserConnect(LocalUser* user)
- {
- if (user->eh.GetIOHook() == this)
- {
- if (sessions[user->eh.GetFd()].sess)
- {
- const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess;
- std::string cipher = UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess)));
- cipher.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-");
- cipher.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess))));
-
- ssl_cert* cert = sessions[user->eh.GetFd()].cert;
- if (cert->fingerprint.empty())
- user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
- else
- user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\""
- " and your SSL fingerprint is %s", user->nick.c_str(), cipher.c_str(), cert->fingerprint.c_str());
- }
- }
- }
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
- void CloseSession(issl_session* session)
- {
- if (session->sess)
- {
- gnutls_bye(session->sess, GNUTLS_SHUT_WR);
- gnutls_deinit(session->sess);
+ return 1;
}
- session->socket = NULL;
- session->sess = NULL;
- session->cert = NULL;
- session->status = ISSL_NONE;
- session->config = NULL;
}
- void VerifyCertificate(issl_session* session, StreamSocket* user)
+ void VerifyCertificate()
{
- if (!session->sess || !user)
- return;
-
- unsigned int status;
+ unsigned int certstatus;
const gnutls_datum_t* cert_list;
int ret;
unsigned int cert_list_size;
gnutls_x509_crt_t cert;
- char name[MAXBUF];
- unsigned char digest[MAXBUF];
+ char str[512];
+ unsigned char digest[512];
size_t digest_size = sizeof(digest);
- size_t name_size = sizeof(name);
+ size_t name_size = sizeof(str);
ssl_cert* certinfo = new ssl_cert;
- session->cert = certinfo;
+ this->certificate = certinfo;
/* This verification function uses the trusted CAs in the credentials
* structure. So you must have installed one or more CA certificates.
*/
- ret = gnutls_certificate_verify_peers2(session->sess, &status);
+ ret = gnutls_certificate_verify_peers2(this->sess, &certstatus);
if (ret < 0)
{
@@ -893,16 +799,16 @@ class ModuleSSLGnuTLS : public Module
return;
}
- certinfo->invalid = (status & GNUTLS_CERT_INVALID);
- certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND);
- certinfo->revoked = (status & GNUTLS_CERT_REVOKED);
- certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA);
+ certinfo->invalid = (certstatus & GNUTLS_CERT_INVALID);
+ certinfo->unknownsigner = (certstatus & GNUTLS_CERT_SIGNER_NOT_FOUND);
+ certinfo->revoked = (certstatus & GNUTLS_CERT_REVOKED);
+ certinfo->trusted = !(certstatus & GNUTLS_CERT_SIGNER_NOT_CA);
/* Up to here the process is the same for X.509 certificates and
* OpenPGP keys. From now on X.509 certificates are assumed. This can
* be easily extended to work with openpgp keys as well.
*/
- if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
+ if (gnutls_certificate_type_get(this->sess) != GNUTLS_CRT_X509)
{
certinfo->error = "No X509 keys sent";
return;
@@ -916,7 +822,7 @@ class ModuleSSLGnuTLS : public Module
}
cert_list_size = 0;
- cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
+ cert_list = gnutls_certificate_get_peers(this->sess, &cert_list_size);
if (cert_list == NULL)
{
certinfo->error = "No certificate was found";
@@ -934,31 +840,31 @@ class ModuleSSLGnuTLS : public Module
goto info_done_dealloc;
}
- if (gnutls_x509_crt_get_dn(cert, name, &name_size) == 0)
+ if (gnutls_x509_crt_get_dn(cert, str, &name_size) == 0)
{
std::string& dn = certinfo->dn;
- dn = name;
+ dn = str;
// Make sure there are no chars in the string that we consider invalid
if (dn.find_first_of("\r\n") != std::string::npos)
dn.clear();
}
- name_size = sizeof(name);
- if (gnutls_x509_crt_get_issuer_dn(cert, name, &name_size) == 0)
+ name_size = sizeof(str);
+ if (gnutls_x509_crt_get_issuer_dn(cert, str, &name_size) == 0)
{
std::string& issuer = certinfo->issuer;
- issuer = name;
+ issuer = str;
if (issuer.find_first_of("\r\n") != std::string::npos)
issuer.clear();
}
- if ((ret = gnutls_x509_crt_get_fingerprint(cert, hash, digest, &digest_size)) < 0)
+ if ((ret = gnutls_x509_crt_get_fingerprint(cert, GetProfile().GetHash(), digest, &digest_size)) < 0)
{
certinfo->error = gnutls_strerror(ret);
}
else
{
- certinfo->fingerprint = irc::hex(digest, digest_size);
+ certinfo->fingerprint = BinToHex(digest, digest_size);
}
/* Beware here we do not check for errors.
@@ -972,15 +878,522 @@ info_done_dealloc:
gnutls_x509_crt_deinit(cert);
}
- void OnEvent(Event& ev)
+ // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
+ int PrepareIO(StreamSocket* sock)
+ {
+ if (status == ISSL_HANDSHAKEN)
+ return 1;
+ else if (status == ISSL_HANDSHAKING)
+ {
+ // The handshake isn't finished, try to finish it
+ return Handshake(sock);
+ }
+
+ CloseSession();
+ sock->SetError("No SSL session");
+ return -1;
+ }
+
+#ifdef INSPIRCD_GNUTLS_HAS_CORK
+ int FlushBuffer(StreamSocket* sock)
+ {
+ // If GnuTLS has some data buffered, write it
+ if (gbuffersize)
+ return HandleWriteRet(sock, gnutls_record_uncork(this->sess, 0));
+ return 1;
+ }
+#endif
+
+ int HandleWriteRet(StreamSocket* sock, int ret)
+ {
+ if (ret > 0)
+ {
+#ifdef INSPIRCD_GNUTLS_HAS_CORK
+ gbuffersize -= ret;
+ if (gbuffersize)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+#endif
+ return ret;
+ }
+ else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED || ret == 0)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+ else // (ret < 0)
+ {
+ sock->SetError(gnutls_strerror(ret));
+ CloseSession();
+ return -1;
+ }
+ }
+
+ static const char* UnknownIfNULL(const char* str)
+ {
+ return str ? str : "UNKNOWN";
+ }
+
+ static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size)
+ {
+ StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
+#ifdef _WIN32
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
+#endif
+
+ if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
+ {
+#ifdef _WIN32
+ gnutls_transport_set_errno(session->sess, EAGAIN);
+#else
+ errno = EAGAIN;
+#endif
+ return -1;
+ }
+
+ int rv = SocketEngine::Recv(sock, reinterpret_cast<char *>(buffer), size, 0);
+
+#ifdef _WIN32
+ if (rv < 0)
+ {
+ /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError()
+ * and then set errno appropriately.
+ * The gnutls library may also have a different errno variable than us, see
+ * gnutls_transport_set_errno(3).
+ */
+ gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
+ }
+#endif
+
+ if (rv < (int)size)
+ SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
+ return rv;
+ }
+
+#ifdef INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
+ static ssize_t VectorPush(gnutls_transport_ptr_t transportptr, const giovec_t* iov, int iovcnt)
+ {
+ StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr);
+#ifdef _WIN32
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
+#endif
+
+ if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
+ {
+#ifdef _WIN32
+ gnutls_transport_set_errno(session->sess, EAGAIN);
+#else
+ errno = EAGAIN;
+#endif
+ return -1;
+ }
+
+ // Cast the giovec_t to iovec not to IOVector so the correct function is called on Windows
+ int ret = SocketEngine::WriteV(sock, reinterpret_cast<const iovec*>(iov), iovcnt);
+#ifdef _WIN32
+ // See the function above for more info about the usage of gnutls_transport_set_errno() on Windows
+ if (ret < 0)
+ gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
+#endif
+
+ int size = 0;
+ for (int i = 0; i < iovcnt; i++)
+ size += iov[i].iov_len;
+
+ if (ret < size)
+ SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
+ return ret;
+ }
+
+#else // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
+ static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size)
+ {
+ StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
+#ifdef _WIN32
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
+#endif
+
+ if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
+ {
+#ifdef _WIN32
+ gnutls_transport_set_errno(session->sess, EAGAIN);
+#else
+ errno = EAGAIN;
+#endif
+ return -1;
+ }
+
+ int rv = SocketEngine::Send(sock, reinterpret_cast<const char *>(buffer), size, 0);
+
+#ifdef _WIN32
+ if (rv < 0)
+ {
+ /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError()
+ * and then set errno appropriately.
+ * The gnutls library may also have a different errno variable than us, see
+ * gnutls_transport_set_errno(3).
+ */
+ gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
+ }
+#endif
+
+ if (rv < (int)size)
+ SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
+ return rv;
+ }
+#endif // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
+
+ public:
+ GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags)
+ : SSLIOHook(hookprov)
+ , sess(NULL)
+ , status(ISSL_NONE)
+#ifdef INSPIRCD_GNUTLS_HAS_CORK
+ , gbuffersize(0)
+#endif
+ {
+ gnutls_init(&sess, flags);
+ gnutls_transport_set_ptr(sess, reinterpret_cast<gnutls_transport_ptr_t>(sock));
+#ifdef INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
+ gnutls_transport_set_vec_push_function(sess, VectorPush);
+#else
+ gnutls_transport_set_push_function(sess, gnutls_push_wrapper);
+#endif
+ gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper);
+ GetProfile().SetupSession(sess);
+
+ sock->AddIOHook(this);
+ Handshake(sock);
+ }
+
+ void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE
+ {
+ CloseSession();
+ }
+
+ int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE
+ {
+ // Finish handshake if needed
+ int prepret = PrepareIO(user);
+ if (prepret <= 0)
+ return prepret;
+
+ // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
+ {
+ GnuTLS::DataReader reader(sess);
+ int ret = reader.ret();
+ if (ret > 0)
+ {
+ reader.appendto(recvq);
+ // Schedule a read if there is still data in the GnuTLS buffer
+ if (gnutls_record_check_pending(sess) > 0)
+ SocketEngine::ChangeEventMask(user, FD_ADD_TRIAL_READ);
+ return 1;
+ }
+ else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+ {
+ return 0;
+ }
+ else if (ret == 0)
+ {
+ user->SetError("Connection closed");
+ CloseSession();
+ return -1;
+ }
+ else
+ {
+ user->SetError(gnutls_strerror(ret));
+ CloseSession();
+ return -1;
+ }
+ }
+ }
+
+ int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
+ {
+ // Finish handshake if needed
+ int prepret = PrepareIO(user);
+ if (prepret <= 0)
+ return prepret;
+
+ // Session is ready for transferring application data
+
+#ifdef INSPIRCD_GNUTLS_HAS_CORK
+ while (true)
+ {
+ // If there is something in the GnuTLS buffer try to send() it
+ int ret = FlushBuffer(user);
+ if (ret <= 0)
+ return ret; // Couldn't flush entire buffer, retry later (or close on error)
+
+ // GnuTLS buffer is empty, if the sendq is empty as well then break to set FD_WANT_NO_WRITE
+ if (sendq.empty())
+ break;
+
+ // GnuTLS buffer is empty but sendq is not, begin sending data from the sendq
+ gnutls_record_cork(this->sess);
+ while ((!sendq.empty()) && (gbuffersize < GetProfile().GetOutgoingRecordSize()))
+ {
+ const StreamSocket::SendQueue::Element& elem = sendq.front();
+ gbuffersize += elem.length();
+ ret = gnutls_record_send(this->sess, elem.data(), elem.length());
+ if (ret < 0)
+ {
+ CloseSession();
+ return -1;
+ }
+ sendq.pop_front();
+ }
+ }
+#else
+ int ret = 0;
+
+ while (!sendq.empty())
+ {
+ FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
+ const StreamSocket::SendQueue::Element& buffer = sendq.front();
+ ret = HandleWriteRet(user, gnutls_record_send(this->sess, buffer.data(), buffer.length()));
+
+ if (ret <= 0)
+ return ret;
+ else if (ret < (int)buffer.length())
+ {
+ sendq.erase_front(ret);
+ SocketEngine::ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+
+ // Wrote entire record, continue sending
+ sendq.pop_front();
+ }
+#endif
+
+ SocketEngine::ChangeEventMask(user, FD_WANT_NO_WRITE);
+ return 1;
+ }
+
+ void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
+ {
+ if (!IsHandshakeDone())
+ return;
+ out.append(UnknownIfNULL(gnutls_protocol_get_name(gnutls_protocol_get_version(sess)))).push_back('-');
+ out.append(UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess)))).push_back('-');
+ out.append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).push_back('-');
+ out.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess))));
+ }
+
+ bool GetServerName(std::string& out) const CXX11_OVERRIDE
+ {
+ std::vector<char> nameBuffer;
+ size_t nameLength = 0;
+ unsigned int nameType = GNUTLS_NAME_DNS;
+
+ // First, determine the size of the hostname.
+ if (gnutls_server_name_get(sess, &nameBuffer[0], &nameLength, &nameType, 0) != GNUTLS_E_SHORT_MEMORY_BUFFER)
+ return false;
+
+ // Then retrieve the hostname.
+ nameBuffer.resize(nameLength);
+ if (gnutls_server_name_get(sess, &nameBuffer[0], &nameLength, &nameType, 0) != GNUTLS_E_SUCCESS)
+ return false;
+
+ out.append(&nameBuffer[0]);
+ return true;
+ }
+
+ GnuTLS::Profile& GetProfile();
+ bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
+};
+
+int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st)
+{
+#ifndef GNUTLS_NEW_CERT_CALLBACK_API
+ st->type = GNUTLS_CRT_X509;
+#else
+ st->cert_type = GNUTLS_CRT_X509;
+ st->key_type = GNUTLS_PRIVKEY_X509;
+#endif
+ StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
+ GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile().GetX509Credentials();
+
+ st->ncerts = cred.certs.size();
+ st->cert.x509 = cred.certs.raw();
+ st->key.x509 = cred.key.get();
+ st->deinit_all = 0;
+
+ return 0;
+}
+
+class GnuTLSIOHookProvider : public IOHookProvider
+{
+ GnuTLS::Profile profile;
+
+ public:
+ GnuTLSIOHookProvider(Module* mod, GnuTLS::Profile::Config& config)
+ : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+ , profile(config)
+ {
+ ServerInstance->Modules->AddService(*this);
+ }
+
+ ~GnuTLSIOHookProvider()
+ {
+ ServerInstance->Modules->DelService(*this);
+ }
+
+ void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
+ {
+ new GnuTLSIOHook(this, sock, GNUTLS_SERVER);
+ }
+
+ void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
{
- if (starttls.enabled)
- capHandler.HandleEvent(ev);
+ new GnuTLSIOHook(this, sock, GNUTLS_CLIENT);
+ }
+
+ GnuTLS::Profile& GetProfile() { return profile; }
+};
+
+GnuTLS::Profile& GnuTLSIOHook::GetProfile()
+{
+ IOHookProvider* hookprov = prov;
+ return static_cast<GnuTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
+class ModuleSSLGnuTLS : public Module
+{
+ typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList;
+
+ // First member of the class, gets constructed first and destructed last
+ GnuTLS::Init libinit;
+ ProfileList profiles;
+
+ void ReadProfiles()
+ {
+ // First, store all profiles in a new, temporary container. If no problems occur, swap the two
+ // containers; this way if something goes wrong we can go back and continue using the current profiles,
+ // avoiding unpleasant situations where no new SSL connections are possible.
+ ProfileList newprofiles;
+
+ ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
+ if (tags.first == tags.second)
+ {
+ // No <sslprofile> tags found, create a profile named "gnutls" from settings in the <gnutls> block
+ const std::string defname = "gnutls";
+ ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <gnutls> tag");
+
+ try
+ {
+ GnuTLS::Profile::Config profileconfig(defname, tag);
+ newprofiles.push_back(new GnuTLSIOHookProvider(this, profileconfig));
+ }
+ catch (CoreException& ex)
+ {
+ throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
+ }
+ }
+
+ for (ConfigIter i = tags.first; i != tags.second; ++i)
+ {
+ ConfigTag* tag = i->second;
+ if (!stdalgo::string::equalsci(tag->getString("provider"), "gnutls"))
+ continue;
+
+ std::string name = tag->getString("name");
+ if (name.empty())
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
+ continue;
+ }
+
+ reference<GnuTLSIOHookProvider> prov;
+ try
+ {
+ GnuTLS::Profile::Config profileconfig(name, tag);
+ prov = new GnuTLSIOHookProvider(this, profileconfig);
+ }
+ catch (CoreException& ex)
+ {
+ throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
+ }
+
+ newprofiles.push_back(prov);
+ }
+
+ // New profiles are ok, begin using them
+ // Old profiles are deleted when their refcount drops to zero
+ for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+ {
+ GnuTLSIOHookProvider& prov = **i;
+ ServerInstance->Modules.DelService(prov);
+ }
+
+ profiles.swap(newprofiles);
+ }
+
+ public:
+ ModuleSSLGnuTLS()
+ {
+#ifndef GNUTLS_HAS_RND
+ gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
+#endif
+ thismod = this;
+ }
+
+ void init() CXX11_OVERRIDE
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "GnuTLS lib version %s module was compiled for " GNUTLS_VERSION, gnutls_check_version(NULL));
+ ReadProfiles();
+ ServerInstance->GenRandom = RandGen::Call;
+ }
+
+ void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
+ {
+ if(param != "ssl")
+ return;
+
+ try
+ {
+ ReadProfiles();
+ }
+ catch (ModuleException& ex)
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
+ }
+ }
+
+ ~ModuleSSLGnuTLS()
+ {
+ ServerInstance->GenRandom = &InspIRCd::DefaultGenRandom;
+ }
+
+ void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
+ {
+ if (type == ExtensionItem::EXT_USER)
+ {
+ LocalUser* user = IS_LOCAL(static_cast<User*>(item));
+
+ if ((user) && (user->eh.GetModHook(this)))
+ {
+ // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
+ // Potentially there could be multiple SSL modules loaded at once on different ports.
+ ServerInstance->Users->QuitUser(user, "SSL module unloading");
+ }
+ }
+ }
+
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("Provides SSL support for clients", VF_VENDOR);
}
- ModResult OnCheckReady(LocalUser* user)
+ ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook() == this) && (sessions[user->eh.GetFd()].status != ISSL_HANDSHAKEN))
+ const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp
new file mode 100644
index 000000000..75b25fbc4
--- /dev/null
+++ b/src/modules/extra/m_ssl_mbedtls.cpp
@@ -0,0 +1,969 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ * Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *
+ * This file is part of InspIRCd. InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/// $LinkerFlags: -lmbedtls
+
+/// $PackageInfo: require_system("darwin") mbedtls
+/// $PackageInfo: require_system("debian" "9.0") libmbedtls-dev
+/// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
+
+
+#include "inspircd.h"
+#include "modules/ssl.h"
+
+#include <mbedtls/ctr_drbg.h>
+#include <mbedtls/dhm.h>
+#include <mbedtls/ecp.h>
+#include <mbedtls/entropy.h>
+#include <mbedtls/error.h>
+#include <mbedtls/md.h>
+#include <mbedtls/pk.h>
+#include <mbedtls/ssl.h>
+#include <mbedtls/ssl_ciphersuites.h>
+#include <mbedtls/version.h>
+#include <mbedtls/x509.h>
+#include <mbedtls/x509_crt.h>
+#include <mbedtls/x509_crl.h>
+
+#ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
+#include <mbedtls/debug.h>
+#endif
+
+namespace mbedTLS
+{
+ class Exception : public ModuleException
+ {
+ public:
+ Exception(const std::string& reason)
+ : ModuleException(reason) { }
+ };
+
+ std::string ErrorToString(int errcode)
+ {
+ char buf[256];
+ mbedtls_strerror(errcode, buf, sizeof(buf));
+ return buf;
+ }
+
+ void ThrowOnError(int errcode, const char* msg)
+ {
+ if (errcode != 0)
+ {
+ std::string reason = msg;
+ reason.append(" :").append(ErrorToString(errcode));
+ throw Exception(reason);
+ }
+ }
+
+ template <typename T, void (*init)(T*), void (*deinit)(T*)>
+ class RAIIObj
+ {
+ T obj;
+
+ public:
+ RAIIObj()
+ {
+ init(&obj);
+ }
+
+ ~RAIIObj()
+ {
+ deinit(&obj);
+ }
+
+ T* get() { return &obj; }
+ const T* get() const { return &obj; }
+ };
+
+ typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
+
+ class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
+ {
+ public:
+ bool Seed(Entropy& entropy)
+ {
+ return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
+ }
+
+ void SetupConf(mbedtls_ssl_config* conf)
+ {
+ mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
+ }
+ };
+
+ class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
+ {
+ public:
+ void set(const std::string& dhstr)
+ {
+ // Last parameter is buffer size, must include the terminating null
+ int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
+ ThrowOnError(ret, "Unable to import DH params");
+ }
+ };
+
+ class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
+ {
+ public:
+ /** Import */
+ X509Key(const std::string& keystr)
+ {
+ int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
+ ThrowOnError(ret, "Unable to import private key");
+ }
+ };
+
+ class Ciphersuites
+ {
+ std::vector<int> list;
+
+ public:
+ Ciphersuites(const std::string& str)
+ {
+ // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
+ // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
+ irc::sepstream ss(str, ':');
+ for (std::string token; ss.GetToken(token); )
+ {
+ // Prepend "TLS-" if not there
+ if (token.compare(0, 4, "TLS-", 4))
+ token.insert(0, "TLS-");
+
+ const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
+ if (!id)
+ throw Exception("Unknown ciphersuite " + token);
+ list.push_back(id);
+ }
+ list.push_back(0);
+ }
+
+ const int* get() const { return &list.front(); }
+ bool empty() const { return (list.size() <= 1); }
+ };
+
+ class Curves
+ {
+ std::vector<mbedtls_ecp_group_id> list;
+
+ public:
+ Curves(const std::string& str)
+ {
+ irc::sepstream ss(str, ':');
+ for (std::string token; ss.GetToken(token); )
+ {
+ const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
+ if (!curve)
+ throw Exception("Unknown curve " + token);
+ list.push_back(curve->grp_id);
+ }
+ list.push_back(MBEDTLS_ECP_DP_NONE);
+ }
+
+ const mbedtls_ecp_group_id* get() const { return &list.front(); }
+ bool empty() const { return (list.size() <= 1); }
+ };
+
+ class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
+ {
+ public:
+ /** Import or create empty */
+ X509CertList(const std::string& certstr, bool allowempty = false)
+ {
+ if ((allowempty) && (certstr.empty()))
+ return;
+ int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
+ ThrowOnError(ret, "Unable to load certificates");
+ }
+
+ bool empty() const { return (get()->raw.p != NULL); }
+ };
+
+ class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
+ {
+ public:
+ X509CRL(const std::string& crlstr)
+ {
+ if (crlstr.empty())
+ return;
+ int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
+ ThrowOnError(ret, "Unable to load CRL");
+ }
+ };
+
+ class X509Credentials
+ {
+ /** Private key
+ */
+ X509Key key;
+
+ /** Certificate list, presented to the peer
+ */
+ X509CertList certs;
+
+ public:
+ X509Credentials(const std::string& certstr, const std::string& keystr)
+ : key(keystr)
+ , certs(certstr)
+ {
+ // Verify that one of the certs match the private key
+ bool found = false;
+ for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
+ {
+ if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
+ {
+ found = true;
+ break;
+ }
+ }
+ if (!found)
+ throw Exception("Public/private key pair does not match");
+ }
+
+ mbedtls_pk_context* getkey() { return key.get(); }
+ mbedtls_x509_crt* getcerts() { return certs.get(); }
+ };
+
+ class Context
+ {
+ mbedtls_ssl_config conf;
+
+#ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
+ static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
+ {
+ // Remove trailing \n
+ size_t len = strlen(msg);
+ if ((len > 0) && (msg[len-1] == '\n'))
+ len--;
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
+ }
+#endif
+
+ public:
+ Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
+ {
+ mbedtls_ssl_config_init(&conf);
+#ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
+ mbedtls_debug_set_threshold(INT_MAX);
+ mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
+#endif
+
+ // TODO: check ret of mbedtls_ssl_config_defaults
+ mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
+ ctrdrbg.SetupConf(&conf);
+ }
+
+ ~Context()
+ {
+ mbedtls_ssl_config_free(&conf);
+ }
+
+ void SetMinDHBits(unsigned int mindh)
+ {
+ mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
+ }
+
+ void SetDHParams(DHParams& dh)
+ {
+ mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
+ }
+
+ void SetX509CertAndKey(X509Credentials& x509cred)
+ {
+ mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
+ }
+
+ void SetCiphersuites(const Ciphersuites& ciphersuites)
+ {
+ mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
+ }
+
+ void SetCurves(const Curves& curves)
+ {
+ mbedtls_ssl_conf_curves(&conf, curves.get());
+ }
+
+ void SetVersion(int minver, int maxver)
+ {
+ // SSL v3 support cannot be enabled
+ if (minver)
+ mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
+ if (maxver)
+ mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
+ }
+
+ void SetCA(X509CertList& certs, X509CRL& crl)
+ {
+ mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
+ }
+
+ void SetOptionalVerifyCert()
+ {
+ mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+ }
+
+ const mbedtls_ssl_config* GetConf() const { return &conf; }
+ };
+
+ class Hash
+ {
+ const mbedtls_md_info_t* md;
+
+ /** Buffer where cert hashes are written temporarily
+ */
+ mutable std::vector<unsigned char> buf;
+
+ public:
+ Hash(std::string hashstr)
+ {
+ std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
+ md = mbedtls_md_info_from_string(hashstr.c_str());
+ if (!md)
+ throw Exception("Unknown hash: " + hashstr);
+
+ buf.resize(mbedtls_md_get_size(md));
+ }
+
+ std::string hash(const unsigned char* input, size_t length) const
+ {
+ mbedtls_md(md, input, length, &buf.front());
+ return BinToHex(&buf.front(), buf.size());
+ }
+ };
+
+ class Profile
+ {
+ /** Name of this profile
+ */
+ const std::string name;
+
+ X509Credentials x509cred;
+
+ /** Ciphersuites to use
+ */
+ Ciphersuites ciphersuites;
+
+ /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
+ */
+ Curves curves;
+
+ Context serverctx;
+ Context clientctx;
+
+ DHParams dhparams;
+
+ X509CertList cacerts;
+
+ X509CRL crl;
+
+ /** Hashing algorithm to use when generating certificate fingerprints
+ */
+ Hash hash;
+
+ /** Rough max size of records to send
+ */
+ const unsigned int outrecsize;
+
+ public:
+ struct Config
+ {
+ const std::string name;
+
+ CTRDRBG& ctrdrbg;
+
+ const std::string certstr;
+ const std::string keystr;
+ const std::string dhstr;
+
+ const std::string ciphersuitestr;
+ const std::string curvestr;
+ const unsigned int mindh;
+ const std::string hashstr;
+
+ std::string crlstr;
+ std::string castr;
+
+ const int minver;
+ const int maxver;
+ const unsigned int outrecsize;
+ const bool requestclientcert;
+
+ Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
+ : name(profilename)
+ , ctrdrbg(ctr_drbg)
+ , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+ , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+ , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
+ , ciphersuitestr(tag->getString("ciphersuites"))
+ , curvestr(tag->getString("curves"))
+ , mindh(tag->getUInt("mindhbits", 2048))
+ , hashstr(tag->getString("hash", "sha256"))
+ , castr(tag->getString("cafile"))
+ , minver(tag->getUInt("minver", 0))
+ , maxver(tag->getUInt("maxver", 0))
+ , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
+ , requestclientcert(tag->getBool("requestclientcert", true))
+ {
+ if (!castr.empty())
+ {
+ castr = ReadFile(castr);
+ crlstr = tag->getString("crlfile");
+ if (!crlstr.empty())
+ crlstr = ReadFile(crlstr);
+ }
+ }
+ };
+
+ Profile(Config& config)
+ : name(config.name)
+ , x509cred(config.certstr, config.keystr)
+ , ciphersuites(config.ciphersuitestr)
+ , curves(config.curvestr)
+ , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
+ , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
+ , cacerts(config.castr, true)
+ , crl(config.crlstr)
+ , hash(config.hashstr)
+ , outrecsize(config.outrecsize)
+ {
+ serverctx.SetX509CertAndKey(x509cred);
+ clientctx.SetX509CertAndKey(x509cred);
+ clientctx.SetMinDHBits(config.mindh);
+
+ if (!ciphersuites.empty())
+ {
+ serverctx.SetCiphersuites(ciphersuites);
+ clientctx.SetCiphersuites(ciphersuites);
+ }
+
+ if (!curves.empty())
+ {
+ serverctx.SetCurves(curves);
+ clientctx.SetCurves(curves);
+ }
+
+ serverctx.SetVersion(config.minver, config.maxver);
+ clientctx.SetVersion(config.minver, config.maxver);
+
+ if (!config.dhstr.empty())
+ {
+ dhparams.set(config.dhstr);
+ serverctx.SetDHParams(dhparams);
+ }
+
+ clientctx.SetOptionalVerifyCert();
+ clientctx.SetCA(cacerts, crl);
+ // The default for servers is to not request a client certificate from the peer
+ if (config.requestclientcert)
+ {
+ serverctx.SetOptionalVerifyCert();
+ serverctx.SetCA(cacerts, crl);
+ }
+ }
+
+ static std::string ReadFile(const std::string& filename)
+ {
+ FileReader reader(filename);
+ std::string ret = reader.GetString();
+ if (ret.empty())
+ throw Exception("Cannot read file " + filename);
+ return ret;
+ }
+
+ /** Set up the given session with the settings in this profile
+ */
+ void SetupClientSession(mbedtls_ssl_context* sess)
+ {
+ mbedtls_ssl_setup(sess, clientctx.GetConf());
+ }
+
+ void SetupServerSession(mbedtls_ssl_context* sess)
+ {
+ mbedtls_ssl_setup(sess, serverctx.GetConf());
+ }
+
+ const std::string& GetName() const { return name; }
+ X509Credentials& GetX509Credentials() { return x509cred; }
+ unsigned int GetOutgoingRecordSize() const { return outrecsize; }
+ const Hash& GetHash() const { return hash; }
+ };
+}
+
+class mbedTLSIOHook : public SSLIOHook
+{
+ enum Status
+ {
+ ISSL_NONE,
+ ISSL_HANDSHAKING,
+ ISSL_HANDSHAKEN
+ };
+
+ mbedtls_ssl_context sess;
+ Status status;
+
+ void CloseSession()
+ {
+ if (status == ISSL_NONE)
+ return;
+
+ mbedtls_ssl_close_notify(&sess);
+ mbedtls_ssl_free(&sess);
+ certificate = NULL;
+ status = ISSL_NONE;
+ }
+
+ // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
+ int Handshake(StreamSocket* sock)
+ {
+ int ret = mbedtls_ssl_handshake(&sess);
+ if (ret == 0)
+ {
+ // Change the seesion state
+ this->status = ISSL_HANDSHAKEN;
+
+ VerifyCertificate();
+
+ // Finish writing, if any left
+ SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
+
+ return 1;
+ }
+
+ this->status = ISSL_HANDSHAKING;
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ return 0;
+ }
+ else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+
+ sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
+ CloseSession();
+ return -1;
+ }
+
+ // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
+ int PrepareIO(StreamSocket* sock)
+ {
+ if (status == ISSL_HANDSHAKEN)
+ return 1;
+ else if (status == ISSL_HANDSHAKING)
+ {
+ // The handshake isn't finished, try to finish it
+ return Handshake(sock);
+ }
+
+ CloseSession();
+ sock->SetError("No SSL session");
+ return -1;
+ }
+
+ void VerifyCertificate()
+ {
+ this->certificate = new ssl_cert;
+ const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
+ if (!cert)
+ {
+ certificate->error = "No client certificate sent";
+ return;
+ }
+
+ // If there is a certificate we can always generate a fingerprint
+ certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
+
+ // At this point mbedTLS verified the cert already, we just need to check the results
+ const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
+ if (flags == 0xFFFFFFFF)
+ {
+ certificate->error = "Internal error during verification";
+ return;
+ }
+
+ if (flags == 0)
+ {
+ // Verification succeeded
+ certificate->trusted = true;
+ }
+ else
+ {
+ // Verification failed
+ certificate->trusted = false;
+ if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
+ certificate->error = "Not activated, or expired certificate";
+ }
+
+ certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
+ certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
+ certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
+
+ GetDNString(&cert->subject, certificate->dn);
+ GetDNString(&cert->issuer, certificate->issuer);
+ }
+
+ static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
+ {
+ char buf[512];
+ const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
+ if (ret <= 0)
+ return;
+
+ out.assign(buf, ret);
+ }
+
+ static int Pull(void* userptr, unsigned char* buffer, size_t size)
+ {
+ StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
+ if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
+ return MBEDTLS_ERR_SSL_WANT_READ;
+
+ const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
+ if (ret < (int)size)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
+ if ((ret == -1) && (SocketEngine::IgnoreError()))
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
+ return ret;
+ }
+
+ static int Push(void* userptr, const unsigned char* buffer, size_t size)
+ {
+ StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
+ if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+
+ const int ret = SocketEngine::Send(sock, buffer, size, 0);
+ if (ret < (int)size)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
+ if ((ret == -1) && (SocketEngine::IgnoreError()))
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+ }
+ return ret;
+ }
+
+ public:
+ mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
+ : SSLIOHook(hookprov)
+ , status(ISSL_NONE)
+ {
+ mbedtls_ssl_init(&sess);
+ if (isserver)
+ GetProfile().SetupServerSession(&sess);
+ else
+ GetProfile().SetupClientSession(&sess);
+
+ mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
+
+ sock->AddIOHook(this);
+ Handshake(sock);
+ }
+
+ void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
+ {
+ CloseSession();
+ }
+
+ int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
+ {
+ // Finish handshake if needed
+ int prepret = PrepareIO(sock);
+ if (prepret <= 0)
+ return prepret;
+
+ // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
+ char* const readbuf = ServerInstance->GetReadBuffer();
+ const size_t readbufsize = ServerInstance->Config->NetBufferSize;
+ int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
+ if (ret > 0)
+ {
+ recvq.append(readbuf, ret);
+
+ // Schedule a read if there is still data in the mbedTLS buffer
+ if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
+ SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
+ return 1;
+ }
+ else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
+ return 0;
+ }
+ else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+ else if (ret == 0)
+ {
+ sock->SetError("Connection closed");
+ CloseSession();
+ return -1;
+ }
+ else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
+ {
+ sock->SetError(mbedTLS::ErrorToString(ret));
+ CloseSession();
+ return -1;
+ }
+ }
+
+ int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
+ {
+ // Finish handshake if needed
+ int prepret = PrepareIO(sock);
+ if (prepret <= 0)
+ return prepret;
+
+ // Session is ready for transferring application data
+ while (!sendq.empty())
+ {
+ FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
+ const StreamSocket::SendQueue::Element& buffer = sendq.front();
+ int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
+ if (ret == (int)buffer.length())
+ {
+ // Wrote entire record, continue sending
+ sendq.pop_front();
+ }
+ else if (ret > 0)
+ {
+ sendq.erase_front(ret);
+ SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+ else if (ret == 0)
+ {
+ sock->SetError("Connection closed");
+ CloseSession();
+ return -1;
+ }
+ else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
+ return 0;
+ }
+ else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
+ {
+ SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
+ return 0;
+ }
+ else
+ {
+ sock->SetError(mbedTLS::ErrorToString(ret));
+ CloseSession();
+ return -1;
+ }
+ }
+
+ SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
+ return 1;
+ }
+
+ void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
+ {
+ if (!IsHandshakeDone())
+ return;
+ out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
+
+ // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
+ const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
+ const char prefix[] = "TLS-";
+ unsigned int skip = sizeof(prefix)-1;
+ if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
+ skip = 0;
+ out.append(ciphersuitestr + skip);
+ }
+
+ bool GetServerName(std::string& out) const CXX11_OVERRIDE
+ {
+ // TODO: Implement SNI support.
+ return false;
+ }
+
+ mbedTLS::Profile& GetProfile();
+ bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
+};
+
+class mbedTLSIOHookProvider : public IOHookProvider
+{
+ mbedTLS::Profile profile;
+
+ public:
+ mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
+ : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+ , profile(config)
+ {
+ ServerInstance->Modules->AddService(*this);
+ }
+
+ ~mbedTLSIOHookProvider()
+ {
+ ServerInstance->Modules->DelService(*this);
+ }
+
+ void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
+ {
+ new mbedTLSIOHook(this, sock, true);
+ }
+
+ void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
+ {
+ new mbedTLSIOHook(this, sock, false);
+ }
+
+ mbedTLS::Profile& GetProfile() { return profile; }
+};
+
+mbedTLS::Profile& mbedTLSIOHook::GetProfile()
+{
+ IOHookProvider* hookprov = prov;
+ return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
+class ModuleSSLmbedTLS : public Module
+{
+ typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
+
+ mbedTLS::Entropy entropy;
+ mbedTLS::CTRDRBG ctr_drbg;
+ ProfileList profiles;
+
+ void ReadProfiles()
+ {
+ // First, store all profiles in a new, temporary container. If no problems occur, swap the two
+ // containers; this way if something goes wrong we can go back and continue using the current profiles,
+ // avoiding unpleasant situations where no new SSL connections are possible.
+ ProfileList newprofiles;
+
+ ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
+ if (tags.first == tags.second)
+ {
+ // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
+ const std::string defname = "mbedtls";
+ ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag");
+
+ try
+ {
+ mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
+ newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
+ }
+ catch (CoreException& ex)
+ {
+ throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
+ }
+ }
+
+ for (ConfigIter i = tags.first; i != tags.second; ++i)
+ {
+ ConfigTag* tag = i->second;
+ if (!stdalgo::string::equalsci(tag->getString("provider"), "mbedtls"))
+ continue;
+
+ std::string name = tag->getString("name");
+ if (name.empty())
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
+ continue;
+ }
+
+ reference<mbedTLSIOHookProvider> prov;
+ try
+ {
+ mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
+ prov = new mbedTLSIOHookProvider(this, profileconfig);
+ }
+ catch (CoreException& ex)
+ {
+ throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
+ }
+
+ newprofiles.push_back(prov);
+ }
+
+ // New profiles are ok, begin using them
+ // Old profiles are deleted when their refcount drops to zero
+ for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+ {
+ mbedTLSIOHookProvider& prov = **i;
+ ServerInstance->Modules.DelService(prov);
+ }
+
+ profiles.swap(newprofiles);
+ }
+
+ public:
+ void init() CXX11_OVERRIDE
+ {
+ char verbuf[16]; // Should be at least 9 bytes in size
+ mbedtls_version_get_string(verbuf);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
+
+ if (!ctr_drbg.Seed(entropy))
+ throw ModuleException("CTR DRBG seed failed");
+ ReadProfiles();
+ }
+
+ void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
+ {
+ if (param != "ssl")
+ return;
+
+ try
+ {
+ ReadProfiles();
+ }
+ catch (ModuleException& ex)
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
+ }
+ }
+
+ void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
+ {
+ if (type != ExtensionItem::EXT_USER)
+ return;
+
+ LocalUser* user = IS_LOCAL(static_cast<User*>(item));
+ if ((user) && (user->eh.GetModHook(this)))
+ {
+ // User is using SSL, they're a local user, and they're using our IOHook.
+ // Potentially there could be multiple SSL modules loaded at once on different ports.
+ ServerInstance->Users.QuitUser(user, "SSL module unloading");
+ }
+ }
+
+ ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
+ {
+ const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
+ return MOD_RES_PASSTHRU;
+ }
+
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR);
+ }
+};
+
+MODULE_INIT(ModuleSSLmbedTLS)
diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp
index f2189f257..5f61c71a9 100644
--- a/src/modules/extra/m_ssl_openssl.cpp
+++ b/src/modules/extra/m_ssl_openssl.cpp
@@ -21,852 +21,1050 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
- /* HACK: This prevents OpenSSL on OS X 10.7 and later from spewing deprecation
- * warnings for every single function call. As far as I (SaberUK) know, Apple
- * have no plans to remove OpenSSL so this warning just causes needless spam.
- */
-#ifdef __APPLE__
-# define __AVAILABILITYMACROS__
-# define DEPRECATED_IN_MAC_OS_X_VERSION_10_7_AND_LATER
-#endif
-
+/// $CompilerFlags: find_compiler_flags("openssl")
+/// $LinkerFlags: find_linker_flags("openssl" "-lssl -lcrypto")
+
+/// $PackageInfo: require_system("centos") openssl-devel pkgconfig
+/// $PackageInfo: require_system("darwin") openssl pkg-config
+/// $PackageInfo: require_system("debian") libssl-dev openssl pkg-config
+/// $PackageInfo: require_system("ubuntu") libssl-dev openssl pkg-config
+
+
#include "inspircd.h"
+#include "iohook.h"
+#include "modules/ssl.h"
+
+// Ignore OpenSSL deprecation warnings on OS X Lion and newer.
+#if defined __APPLE__
+# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+// Fix warnings about the use of `long long` on C++03.
+#if defined __clang__
+# pragma clang diagnostic ignored "-Wc++11-long-long"
+#elif defined __GNUC__
+# pragma GCC diagnostic ignored "-Wlong-long"
+#endif
+
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/dh.h>
-#include "ssl.h"
#ifdef _WIN32
# pragma comment(lib, "ssleay32.lib")
# pragma comment(lib, "libeay32.lib")
-# undef MAX_DESCRIPTORS
-# define MAX_DESCRIPTORS 10000
#endif
// Compatibility layer to allow OpenSSL 1.0 to use the 1.1 API.
#if ((defined LIBRESSL_VERSION_NUMBER) || (OPENSSL_VERSION_NUMBER < 0x10100000L))
+
+// BIO is opaque in OpenSSL 1.1 but the access API does not exist in 1.0.
+# define BIO_get_data(BIO) BIO->ptr
+# define BIO_set_data(BIO, VALUE) BIO->ptr = VALUE;
+# define BIO_set_init(BIO, VALUE) BIO->init = VALUE;
+
+// These functions have been renamed in OpenSSL 1.1.
+# define OpenSSL_version SSLeay_version
# define X509_getm_notAfter X509_get_notAfter
# define X509_getm_notBefore X509_get_notBefore
# define OPENSSL_init_ssl(OPTIONS, SETTINGS) \
SSL_library_init(); \
SSL_load_error_strings();
-#endif
-/* $ModDesc: Provides SSL support for clients */
+// These macros have been renamed in OpenSSL 1.1.
+# define OPENSSL_VERSION SSLEAY_VERSION
-/* $LinkerFlags: if("USE_FREEBSD_BASE_SSL") -lssl -lcrypto */
-/* $CompileFlags: if(!"USE_FREEBSD_BASE_SSL") pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */
-/* $LinkerFlags: if(!"USE_FREEBSD_BASE_SSL") rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto -ldl") */
-
-/* $NoPedantic */
-
-
-class ModuleSSLOpenSSL;
+#else
+# define INSPIRCD_OPENSSL_OPAQUE_BIO
+#endif
enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_OPEN };
static bool SelfSigned = false;
-
-#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION
-static ModuleSSLOpenSSL* opensslmod = NULL;
-#endif
+static int exdataindex;
char* get_error()
{
return ERR_error_string(ERR_get_error(), NULL);
}
-static int error_callback(const char *str, size_t len, void *u);
+static int OnVerify(int preverify_ok, X509_STORE_CTX* ctx);
+static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc);
-/** Represents an SSL user's extra data
- */
-class issl_session
+namespace OpenSSL
{
-public:
- SSL* sess;
- issl_status status;
- reference<ssl_cert> cert;
-
- bool outbound;
- bool data_to_write;
-
- issl_session()
- : sess(NULL)
- , status(ISSL_NONE)
+ class Exception : public ModuleException
{
- outbound = false;
- data_to_write = false;
- }
-};
-
-static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
-{
- /* XXX: This will allow self signed certificates.
- * In the future if we want an option to not allow this,
- * we can just return preverify_ok here, and openssl
- * will boot off self-signed and invalid peer certs.
- */
- int ve = X509_STORE_CTX_get_error(ctx);
-
- SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT);
+ public:
+ Exception(const std::string& reason)
+ : ModuleException(reason) { }
+ };
- return 1;
-}
+ class DHParams
+ {
+ DH* dh;
-class ModuleSSLOpenSSL : public Module
-{
- issl_session* sessions;
+ public:
+ DHParams(const std::string& filename)
+ {
+ BIO* dhpfile = BIO_new_file(filename.c_str(), "r");
+ if (dhpfile == NULL)
+ throw Exception("Couldn't open DH file " + filename);
- SSL_CTX* ctx;
- SSL_CTX* clictx;
+ dh = PEM_read_bio_DHparams(dhpfile, NULL, NULL, NULL);
+ BIO_free(dhpfile);
- long ctx_options;
- long clictx_options;
+ if (!dh)
+ throw Exception("Couldn't read DH params from file " + filename);
+ }
- std::string sslports;
- bool use_sha;
+ ~DHParams()
+ {
+ DH_free(dh);
+ }
- ServiceProvider iohook;
+ DH* get()
+ {
+ return dh;
+ }
+ };
- static void SetContextOptions(SSL_CTX* ctx, long defoptions, const std::string& ctxname, ConfigTag* tag)
+ class Context
{
- long setoptions = tag->getInt(ctxname + "setoptions");
- // User-friendly config options for setting context options
-#ifdef SSL_OP_CIPHER_SERVER_PREFERENCE
- if (tag->getBool("cipherserverpref"))
- setoptions |= SSL_OP_CIPHER_SERVER_PREFERENCE;
+ SSL_CTX* const ctx;
+ long ctx_options;
+
+ public:
+ Context(SSL_CTX* context)
+ : ctx(context)
+ {
+ // Sane default options for OpenSSL see https://www.openssl.org/docs/ssl/SSL_CTX_set_options.html
+ // and when choosing a cipher, use the server's preferences instead of the client preferences.
+ long opts = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION | SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_SINGLE_DH_USE;
+ // Only turn options on if they exist
+#ifdef SSL_OP_SINGLE_ECDH_USE
+ opts |= SSL_OP_SINGLE_ECDH_USE;
#endif
-#ifdef SSL_OP_NO_COMPRESSION
- if (!tag->getBool("compression", true))
- setoptions |= SSL_OP_NO_COMPRESSION;
+#ifdef SSL_OP_NO_TICKET
+ opts |= SSL_OP_NO_TICKET;
#endif
- if (!tag->getBool("sslv3", true))
- setoptions |= SSL_OP_NO_SSLv3;
- if (!tag->getBool("tlsv1", true))
- setoptions |= SSL_OP_NO_TLSv1;
- long clearoptions = tag->getInt(ctxname + "clearoptions");
- ServerInstance->Logs->Log("m_ssl_openssl", DEBUG, "Setting OpenSSL %s context options, default: %ld set: %ld clear: %ld", ctxname.c_str(), defoptions, setoptions, clearoptions);
+ ctx_options = SSL_CTX_set_options(ctx, opts);
- // Clear everything
- SSL_CTX_clear_options(ctx, SSL_CTX_get_options(ctx));
-
- // Set the default options and what is in the conf
- SSL_CTX_set_options(ctx, defoptions | setoptions);
- long final = SSL_CTX_clear_options(ctx, clearoptions);
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "OpenSSL %s context options: %ld", ctxname.c_str(), final);
- }
-
-#ifdef INSPIRCD_OPENSSL_ENABLE_ECDH
- void SetupECDH(ConfigTag* tag)
- {
- std::string curvename = tag->getString("ecdhcurve", "prime256v1");
- if (curvename.empty())
- return;
+ long mode = SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER;
+#ifdef SSL_MODE_RELEASE_BUFFERS
+ mode |= SSL_MODE_RELEASE_BUFFERS;
+#endif
+ SSL_CTX_set_mode(ctx, mode);
+ SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
+ SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
+ SSL_CTX_set_info_callback(ctx, StaticSSLInfoCallback);
+ }
- int nid = OBJ_sn2nid(curvename.c_str());
- if (nid == 0)
+ ~Context()
{
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Unknown curve: \"%s\"", curvename.c_str());
- return;
+ SSL_CTX_free(ctx);
}
- EC_KEY* eckey = EC_KEY_new_by_curve_name(nid);
- if (!eckey)
+ bool SetDH(DHParams& dh)
{
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Unable to create EC key object");
- return;
+ ERR_clear_error();
+ return (SSL_CTX_set_tmp_dh(ctx, dh.get()) >= 0);
}
- ERR_clear_error();
- if (SSL_CTX_set_tmp_ecdh(ctx, eckey) < 0)
+#ifndef OPENSSL_NO_ECDH
+ void SetECDH(const std::string& curvename)
{
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Couldn't set ECDH parameters");
- ERR_print_errors_cb(error_callback, this);
- }
+ int nid = OBJ_sn2nid(curvename.c_str());
+ if (nid == 0)
+ throw Exception("Unknown curve: " + curvename);
- EC_KEY_free(eckey);
- }
+ EC_KEY* eckey = EC_KEY_new_by_curve_name(nid);
+ if (!eckey)
+ throw Exception("Unable to create EC key object");
+
+ ERR_clear_error();
+ bool ret = (SSL_CTX_set_tmp_ecdh(ctx, eckey) >= 0);
+ EC_KEY_free(eckey);
+ if (!ret)
+ throw Exception("Couldn't set ECDH parameters");
+ }
#endif
-#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION
- static void SSLInfoCallback(const SSL* ssl, int where, int rc)
- {
- int fd = SSL_get_fd(const_cast<SSL*>(ssl));
- issl_session& session = opensslmod->sessions[fd];
+ bool SetCiphers(const std::string& ciphers)
+ {
+ ERR_clear_error();
+ return SSL_CTX_set_cipher_list(ctx, ciphers.c_str());
+ }
- if ((where & SSL_CB_HANDSHAKE_START) && (session.status == ISSL_OPEN))
+ bool SetCerts(const std::string& filename)
{
- // The other side is trying to renegotiate, kill the connection and change status
- // to ISSL_NONE so CheckRenego() closes the session
- session.status = ISSL_NONE;
- ServerInstance->SE->Shutdown(fd, 2);
+ ERR_clear_error();
+ return SSL_CTX_use_certificate_chain_file(ctx, filename.c_str());
}
- }
- bool CheckRenego(StreamSocket* sock, issl_session* session)
- {
- if (session->status != ISSL_NONE)
- return true;
+ bool SetPrivateKey(const std::string& filename)
+ {
+ ERR_clear_error();
+ return SSL_CTX_use_PrivateKey_file(ctx, filename.c_str(), SSL_FILETYPE_PEM);
+ }
- ServerInstance->Logs->Log("m_ssl_openssl", DEBUG, "Session %p killed, attempted to renegotiate", (void*)session->sess);
- CloseSession(session);
- sock->SetError("Renegotiation is not allowed");
- return false;
- }
-#endif
+ bool SetCA(const std::string& filename)
+ {
+ ERR_clear_error();
+ return SSL_CTX_load_verify_locations(ctx, filename.c_str(), 0);
+ }
- public:
+ void SetCRL(const std::string& crlfile, const std::string& crlpath, const std::string& crlmode)
+ {
+ if (crlfile.empty() && crlpath.empty())
+ return;
- ModuleSSLOpenSSL() : iohook(this, "ssl/openssl", SERVICE_IOHOOK)
- {
-#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION
- opensslmod = this;
-#endif
- sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
+ /* Set CRL mode */
+ unsigned long crlflags = X509_V_FLAG_CRL_CHECK;
+ if (stdalgo::string::equalsci(crlmode, "chain"))
+ {
+ crlflags |= X509_V_FLAG_CRL_CHECK_ALL;
+ }
+ else if (!stdalgo::string::equalsci(crlmode, "leaf"))
+ {
+ throw ModuleException("Unknown mode '" + crlmode + "'; expected either 'chain' (default) or 'leaf'");
+ }
- /* Global SSL library initialization*/
- OPENSSL_init_ssl(0, NULL);
+ /* Load CRL files */
+ X509_STORE* store = SSL_CTX_get_cert_store(ctx);
+ if (!store)
+ {
+ throw ModuleException("Unable to get X509_STORE from SSL context; this should never happen");
+ }
+ ERR_clear_error();
+ if (!X509_STORE_load_locations(store,
+ crlfile.empty() ? NULL : crlfile.c_str(),
+ crlpath.empty() ? NULL : crlpath.c_str()))
+ {
+ int err = ERR_get_error();
+ throw ModuleException("Unable to load CRL file '" + crlfile + "' or CRL path '" + crlpath + "': '" + (err ? ERR_error_string(err, NULL) : "unknown") + "'");
+ }
- /* Build our SSL contexts:
- * NOTE: OpenSSL makes us have two contexts, one for servers and one for clients. ICK.
- */
- ctx = SSL_CTX_new( SSLv23_server_method() );
- clictx = SSL_CTX_new( SSLv23_client_method() );
+ /* Set CRL mode */
+ if (X509_STORE_set_flags(store, crlflags) != 1)
+ {
+ throw ModuleException("Unable to set X509 CRL flags");
+ }
+ }
- SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
- SSL_CTX_set_mode(clictx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
- SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
- SSL_CTX_set_verify(clictx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+ long GetDefaultContextOptions() const
+ {
+ return ctx_options;
+ }
- SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
- SSL_CTX_set_session_cache_mode(clictx, SSL_SESS_CACHE_OFF);
+ long SetRawContextOptions(long setoptions, long clearoptions)
+ {
+ // Clear everything
+ SSL_CTX_clear_options(ctx, SSL_CTX_get_options(ctx));
- long opts = SSL_OP_NO_SSLv2 | SSL_OP_SINGLE_DH_USE;
- // Only turn options on if they exist
-#ifdef SSL_OP_SINGLE_ECDH_USE
- opts |= SSL_OP_SINGLE_ECDH_USE;
-#endif
-#ifdef SSL_OP_NO_TICKET
- opts |= SSL_OP_NO_TICKET;
-#endif
+ // Set the default options and what is in the conf
+ SSL_CTX_set_options(ctx, ctx_options | setoptions);
+ return SSL_CTX_clear_options(ctx, clearoptions);
+ }
- ctx_options = SSL_CTX_set_options(ctx, opts);
- clictx_options = SSL_CTX_set_options(clictx, opts);
- }
+ void SetVerifyCert()
+ {
+ SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+ }
- void init()
- {
- // Needs the flag as it ignores a plain /rehash
- OnModuleRehash(NULL,"ssl");
- Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnHookIO, I_OnUserConnect };
- ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
- ServerInstance->Modules->AddService(iohook);
- }
+ SSL* CreateServerSession()
+ {
+ SSL* sess = SSL_new(ctx);
+ SSL_set_accept_state(sess); // Act as server
+ return sess;
+ }
- void OnHookIO(StreamSocket* user, ListenSocket* lsb)
- {
- if (!user->GetIOHook() && lsb->bind_tag->getString("ssl") == "openssl")
+ SSL* CreateClientSession()
{
- /* Hook the user with our module */
- user->AddIOHook(this);
+ SSL* sess = SSL_new(ctx);
+ SSL_set_connect_state(sess); // Act as client
+ return sess;
}
- }
+ };
- void OnRehash(User* user)
+ class Profile
{
- sslports.clear();
+ /** Name of this profile
+ */
+ const std::string name;
+
+ /** DH parameters in use
+ */
+ DHParams dh;
- ConfigTag* Conf = ServerInstance->Config->ConfValue("openssl");
+ /** OpenSSL makes us have two contexts, one for servers and one for clients
+ */
+ Context ctx;
+ Context clictx;
+
+ /** Digest to use when generating fingerprints
+ */
+ const EVP_MD* digest;
+
+ /** Last error, set by error_callback()
+ */
+ std::string lasterr;
+
+ /** True if renegotiations are allowed, false if not
+ */
+ const bool allowrenego;
-#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION
- // Set the callback if we are not allowing renegotiations, unset it if we do
- if (Conf->getBool("renegotiation", true))
+ /** Rough max size of records to send
+ */
+ const unsigned int outrecsize;
+
+ static int error_callback(const char* str, size_t len, void* u)
{
- SSL_CTX_set_info_callback(ctx, NULL);
- SSL_CTX_set_info_callback(clictx, NULL);
+ Profile* profile = reinterpret_cast<Profile*>(u);
+ profile->lasterr = std::string(str, len - 1);
+ return 0;
}
- else
+
+ /** Set raw OpenSSL context (SSL_CTX) options from a config tag
+ * @param ctxname Name of the context, client or server
+ * @param tag Config tag defining this profile
+ * @param context Context object to manipulate
+ */
+ void SetContextOptions(const std::string& ctxname, ConfigTag* tag, Context& context)
{
- SSL_CTX_set_info_callback(ctx, SSLInfoCallback);
- SSL_CTX_set_info_callback(clictx, SSLInfoCallback);
- }
+ long setoptions = tag->getInt(ctxname + "setoptions", 0);
+ long clearoptions = tag->getInt(ctxname + "clearoptions", 0);
+#ifdef SSL_OP_NO_COMPRESSION
+ if (!tag->getBool("compression", false)) // Disable compression by default
+ setoptions |= SSL_OP_NO_COMPRESSION;
#endif
+ // Disable TLSv1.0 by default.
+ if (!tag->getBool("tlsv1", false))
+ setoptions |= SSL_OP_NO_TLSv1;
- if (Conf->getBool("showports", true))
- {
- sslports = Conf->getString("advertisedports");
- if (!sslports.empty())
- return;
+ if (!setoptions && !clearoptions)
+ return; // Nothing to do
- for (size_t i = 0; i < ServerInstance->ports.size(); i++)
- {
- ListenSocket* port = ServerInstance->ports[i];
- if (port->bind_tag->getString("ssl") != "openssl")
- continue;
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Setting %s %s context options, default: %ld set: %ld clear: %ld", name.c_str(), ctxname.c_str(), ctx.GetDefaultContextOptions(), setoptions, clearoptions);
+ long final = context.SetRawContextOptions(setoptions, clearoptions);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "%s %s context options: %ld", name.c_str(), ctxname.c_str(), final);
+ }
+
+ public:
+ Profile(const std::string& profilename, ConfigTag* tag)
+ : name(profilename)
+ , dh(ServerInstance->Config->Paths.PrependConfig(tag->getString("dhfile", "dhparams.pem")))
+ , ctx(SSL_CTX_new(SSLv23_server_method()))
+ , clictx(SSL_CTX_new(SSLv23_client_method()))
+ , allowrenego(tag->getBool("renegotiation")) // Disallow by default
+ , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
+ {
+ if ((!ctx.SetDH(dh)) || (!clictx.SetDH(dh)))
+ throw Exception("Couldn't set DH parameters");
- const std::string& portid = port->bind_desc;
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Enabling SSL for port %s", portid.c_str());
+ std::string hash = tag->getString("hash", "md5");
+ digest = EVP_get_digestbyname(hash.c_str());
+ if (digest == NULL)
+ throw Exception("Unknown hash type " + hash);
- if (port->bind_tag->getString("type", "clients") == "clients" && port->bind_addr != "127.0.0.1")
+ std::string ciphers = tag->getString("ciphers");
+ if (!ciphers.empty())
+ {
+ if ((!ctx.SetCiphers(ciphers)) || (!clictx.SetCiphers(ciphers)))
{
- /*
- * Found an SSL port for clients that is not bound to 127.0.0.1 and handled by us, display
- * the IP:port in ISUPPORT.
- *
- * We used to advertise all ports seperated by a ';' char that matched the above criteria,
- * but this resulted in too long ISUPPORT lines if there were lots of ports to be displayed.
- * To solve this by default we now only display the first IP:port found and let the user
- * configure the exact value for the 005 token, if necessary.
- */
- sslports = portid;
- break;
+ ERR_print_errors_cb(error_callback, this);
+ throw Exception("Can't set cipher list to \"" + ciphers + "\" " + lasterr);
}
}
- }
- }
-
- void OnModuleRehash(User* user, const std::string &param)
- {
- if (param != "ssl")
- return;
-
- std::string keyfile;
- std::string certfile;
- std::string cafile;
- std::string dhfile;
- OnRehash(user);
- ConfigTag* conf = ServerInstance->Config->ConfValue("openssl");
+#ifndef OPENSSL_NO_ECDH
+ std::string curvename = tag->getString("ecdhcurve", "prime256v1");
+ if (!curvename.empty())
+ ctx.SetECDH(curvename);
+#endif
- cafile = conf->getString("cafile", CONFIG_PATH "/ca.pem");
- certfile = conf->getString("certfile", CONFIG_PATH "/cert.pem");
- keyfile = conf->getString("keyfile", CONFIG_PATH "/key.pem");
- dhfile = conf->getString("dhfile", CONFIG_PATH "/dhparams.pem");
- std::string hash = conf->getString("hash", "md5");
- if (hash != "sha1" && hash != "md5")
- throw ModuleException("Unknown hash type " + hash);
- use_sha = (hash == "sha1");
+ SetContextOptions("server", tag, ctx);
+ SetContextOptions("client", tag, clictx);
- if (conf->getBool("customcontextoptions"))
- {
- SetContextOptions(ctx, ctx_options, "server", conf);
- SetContextOptions(clictx, clictx_options, "client", conf);
- }
+ /* Load our keys and certificates
+ * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck.
+ */
+ std::string filename = ServerInstance->Config->Paths.PrependConfig(tag->getString("certfile", "cert.pem"));
+ if ((!ctx.SetCerts(filename)) || (!clictx.SetCerts(filename)))
+ {
+ ERR_print_errors_cb(error_callback, this);
+ throw Exception("Can't read certificate file: " + lasterr);
+ }
- std::string ciphers = conf->getString("ciphers", "");
+ filename = ServerInstance->Config->Paths.PrependConfig(tag->getString("keyfile", "key.pem"));
+ if ((!ctx.SetPrivateKey(filename)) || (!clictx.SetPrivateKey(filename)))
+ {
+ ERR_print_errors_cb(error_callback, this);
+ throw Exception("Can't read key file: " + lasterr);
+ }
- if (!ciphers.empty())
- {
- ERR_clear_error();
- if ((!SSL_CTX_set_cipher_list(ctx, ciphers.c_str())) || (!SSL_CTX_set_cipher_list(clictx, ciphers.c_str())))
+ // Load the CAs we trust
+ filename = ServerInstance->Config->Paths.PrependConfig(tag->getString("cafile", "ca.pem"));
+ if ((!ctx.SetCA(filename)) || (!clictx.SetCA(filename)))
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't set cipher list to %s.", ciphers.c_str());
ERR_print_errors_cb(error_callback, this);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Can't read CA list from %s. This is only a problem if you want to verify client certificates, otherwise it's safe to ignore this message. Error: %s", filename.c_str(), lasterr.c_str());
}
+
+ // Load the CRLs.
+ std::string crlfile = tag->getString("crlfile");
+ std::string crlpath = tag->getString("crlpath");
+ std::string crlmode = tag->getString("crlmode", "chain");
+ ctx.SetCRL(crlfile, crlpath, crlmode);
+
+ clictx.SetVerifyCert();
+ if (tag->getBool("requestclientcert", true))
+ ctx.SetVerifyCert();
}
- /* Load our keys and certificates
- * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck.
- */
- ERR_clear_error();
- if ((!SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str())) || (!SSL_CTX_use_certificate_chain_file(clictx, certfile.c_str())))
+ const std::string& GetName() const { return name; }
+ SSL* CreateServerSession() { return ctx.CreateServerSession(); }
+ SSL* CreateClientSession() { return clictx.CreateClientSession(); }
+ const EVP_MD* GetDigest() { return digest; }
+ bool AllowRenegotiation() const { return allowrenego; }
+ unsigned int GetOutgoingRecordSize() const { return outrecsize; }
+ };
+
+ namespace BIOMethod
+ {
+ static int create(BIO* bio)
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't read certificate file %s. %s", certfile.c_str(), strerror(errno));
- ERR_print_errors_cb(error_callback, this);
+ BIO_set_init(bio, 1);
+ return 1;
}
- ERR_clear_error();
- if (((!SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) || (!SSL_CTX_use_PrivateKey_file(clictx, keyfile.c_str(), SSL_FILETYPE_PEM)))
+ static int destroy(BIO* bio)
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't read key file %s. %s", keyfile.c_str(), strerror(errno));
- ERR_print_errors_cb(error_callback, this);
+ // XXX: Dummy function to avoid a memory leak in OpenSSL.
+ // The memory leak happens in BIO_free() (bio_lib.c) when the destroy func of the BIO is NULL.
+ // This is fixed in OpenSSL but some distros still ship the unpatched version hence we provide this workaround.
+ return 1;
}
- /* Load the CAs we trust*/
- ERR_clear_error();
- if (((!SSL_CTX_load_verify_locations(ctx, cafile.c_str(), 0))) || (!SSL_CTX_load_verify_locations(clictx, cafile.c_str(), 0)))
+ static long ctrl(BIO* bio, int cmd, long num, void* ptr)
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so: Can't read CA list from %s. This is only a problem if you want to verify client certificates, otherwise it's safe to ignore this message. Error: %s", cafile.c_str(), strerror(errno));
- ERR_print_errors_cb(error_callback, this);
+ if (cmd == BIO_CTRL_FLUSH)
+ return 1;
+ return 0;
}
-#ifdef _WIN32
- BIO* dhpfile = BIO_new_file(dhfile.c_str(), "r");
-#else
- FILE* dhpfile = fopen(dhfile.c_str(), "r");
-#endif
- DH* ret;
+ static int read(BIO* bio, char* buf, int len);
+ static int write(BIO* bio, const char* buf, int len);
- if (dhpfile == NULL)
+#ifdef INSPIRCD_OPENSSL_OPAQUE_BIO
+ static BIO_METHOD* alloc()
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "m_ssl_openssl.so Couldn't open DH file %s: %s", dhfile.c_str(), strerror(errno));
- throw ModuleException("Couldn't open DH file " + dhfile + ": " + strerror(errno));
+ BIO_METHOD* meth = BIO_meth_new(100 | BIO_TYPE_SOURCE_SINK, "inspircd");
+ BIO_meth_set_write(meth, OpenSSL::BIOMethod::write);
+ BIO_meth_set_read(meth, OpenSSL::BIOMethod::read);
+ BIO_meth_set_ctrl(meth, OpenSSL::BIOMethod::ctrl);
+ BIO_meth_set_create(meth, OpenSSL::BIOMethod::create);
+ BIO_meth_set_destroy(meth, OpenSSL::BIOMethod::destroy);
+ return meth;
}
- else
- {
-#ifdef _WIN32
- ret = PEM_read_bio_DHparams(dhpfile, NULL, NULL, NULL);
- BIO_free(dhpfile);
+#endif
+ }
+}
+
+// BIO_METHOD is opaque in OpenSSL 1.1 so we can't do this.
+// See OpenSSL::BIOMethod::alloc for the new method.
+#ifndef INSPIRCD_OPENSSL_OPAQUE_BIO
+static BIO_METHOD biomethods =
+{
+ (100 | BIO_TYPE_SOURCE_SINK),
+ "inspircd",
+ OpenSSL::BIOMethod::write,
+ OpenSSL::BIOMethod::read,
+ NULL, // puts
+ NULL, // gets
+ OpenSSL::BIOMethod::ctrl,
+ OpenSSL::BIOMethod::create,
+ OpenSSL::BIOMethod::destroy, // destroy, does nothing, see function body for more info
+ NULL // callback_ctrl
+};
#else
- ret = PEM_read_DHparams(dhpfile, NULL, NULL, NULL);
+static BIO_METHOD* biomethods;
#endif
- ERR_clear_error();
- if (ret)
+static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
+{
+ /* XXX: This will allow self signed certificates.
+ * In the future if we want an option to not allow this,
+ * we can just return preverify_ok here, and openssl
+ * will boot off self-signed and invalid peer certs.
+ */
+ int ve = X509_STORE_CTX_get_error(ctx);
+
+ SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT);
+
+ return 1;
+}
+
+class OpenSSLIOHook : public SSLIOHook
+{
+ private:
+ SSL* sess;
+ issl_status status;
+ bool data_to_write;
+
+ // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
+ int Handshake(StreamSocket* user)
+ {
+ ERR_clear_error();
+ int ret = SSL_do_handshake(sess);
+ if (ret < 0)
+ {
+ int err = SSL_get_error(sess, ret);
+
+ if (err == SSL_ERROR_WANT_READ)
{
- if ((SSL_CTX_set_tmp_dh(ctx, ret) < 0) || (SSL_CTX_set_tmp_dh(clictx, ret) < 0))
- {
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s. SSL errors follow:", dhfile.c_str());
- ERR_print_errors_cb(error_callback, this);
- }
- DH_free(ret);
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ this->status = ISSL_HANDSHAKING;
+ return 0;
+ }
+ else if (err == SSL_ERROR_WANT_WRITE)
+ {
+ SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
+ this->status = ISSL_HANDSHAKING;
+ return 0;
}
else
{
- ServerInstance->Logs->Log("m_ssl_openssl", DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s.", dhfile.c_str());
+ CloseSession();
+ return -1;
}
}
+ else if (ret > 0)
+ {
+ // Handshake complete.
+ VerifyCertificate();
-#ifndef _WIN32
- fclose(dhpfile);
-#endif
-
-#ifdef INSPIRCD_OPENSSL_ENABLE_ECDH
- SetupECDH(conf);
-#endif
- }
-
- void On005Numeric(std::string &output)
- {
- if (!sslports.empty())
- output.append(" SSL=" + sslports);
- }
+ status = ISSL_OPEN;
- ~ModuleSSLOpenSSL()
- {
- SSL_CTX_free(ctx);
- SSL_CTX_free(clictx);
- delete[] sessions;
- }
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
- void OnUserConnect(LocalUser* user)
- {
- if (user->eh.GetIOHook() == this)
+ return 1;
+ }
+ else if (ret == 0)
{
- if (sessions[user->eh.GetFd()].sess)
- {
- if (!sessions[user->eh.GetFd()].cert->fingerprint.empty())
- user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\""
- " and your SSL fingerprint is %s", user->nick.c_str(), SSL_get_cipher(sessions[user->eh.GetFd()].sess), sessions[user->eh.GetFd()].cert->fingerprint.c_str());
- else
- user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), SSL_get_cipher(sessions[user->eh.GetFd()].sess));
- }
+ CloseSession();
}
+ return -1;
}
- void OnCleanup(int target_type, void* item)
+ void CloseSession()
{
- if (target_type == TYPE_USER)
+ if (sess)
{
- LocalUser* user = IS_LOCAL((User*)item);
-
- if (user && user->eh.GetIOHook() == this)
- {
- // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
- // Potentially there could be multiple SSL modules loaded at once on different ports.
- ServerInstance->Users->QuitUser(user, "SSL module unloading");
- }
+ SSL_shutdown(sess);
+ SSL_free(sess);
}
+ sess = NULL;
+ certificate = NULL;
+ status = ISSL_NONE;
}
- Version GetVersion()
+ void VerifyCertificate()
{
- return Version("Provides SSL support for clients", VF_VENDOR);
- }
+ X509* cert;
+ ssl_cert* certinfo = new ssl_cert;
+ this->certificate = certinfo;
+ unsigned int n;
+ unsigned char md[EVP_MAX_MD_SIZE];
- void OnRequest(Request& request)
- {
- if (strcmp("GET_SSL_CERT", request.id) == 0)
+ cert = SSL_get_peer_certificate(sess);
+
+ if (!cert)
{
- SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
- int fd = req.sock->GetFd();
- issl_session* session = &sessions[fd];
+ certinfo->error = "Could not get peer certificate: "+std::string(get_error());
+ return;
+ }
+
+ certinfo->invalid = (SSL_get_verify_result(sess) != X509_V_OK);
- req.cert = session->cert;
+ if (!SelfSigned)
+ {
+ certinfo->unknownsigner = false;
+ certinfo->trusted = true;
}
- else if (!strcmp("GET_RAW_SSL_SESSION", request.id))
+ else
{
- SSLRawSessionRequest& req = static_cast<SSLRawSessionRequest&>(request);
- if ((req.fd >= 0) && (req.fd < ServerInstance->SE->GetMaxFds()))
- req.data = reinterpret_cast<void*>(sessions[req.fd].sess);
+ certinfo->unknownsigner = true;
+ certinfo->trusted = false;
}
- }
- void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
- {
- int fd = user->GetFd();
-
- issl_session* session = &sessions[fd];
+ char buf[512];
+ X509_NAME_oneline(X509_get_subject_name(cert), buf, sizeof(buf));
+ certinfo->dn = buf;
+ // Make sure there are no chars in the string that we consider invalid
+ if (certinfo->dn.find_first_of("\r\n") != std::string::npos)
+ certinfo->dn.clear();
- session->sess = SSL_new(ctx);
- session->status = ISSL_NONE;
- session->outbound = false;
- session->data_to_write = false;
+ X509_NAME_oneline(X509_get_issuer_name(cert), buf, sizeof(buf));
+ certinfo->issuer = buf;
+ if (certinfo->issuer.find_first_of("\r\n") != std::string::npos)
+ certinfo->issuer.clear();
- if (session->sess == NULL)
- return;
+ if (!X509_digest(cert, GetProfile().GetDigest(), md, &n))
+ {
+ certinfo->error = "Out of memory generating fingerprint";
+ }
+ else
+ {
+ certinfo->fingerprint = BinToHex(md, n);
+ }
- if (SSL_set_fd(session->sess, fd) == 0)
+ if ((ASN1_UTCTIME_cmp_time_t(X509_getm_notAfter(cert), ServerInstance->Time()) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_getm_notBefore(cert), ServerInstance->Time()) == 0))
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd);
- return;
+ certinfo->error = "Not activated, or expired certificate";
}
- Handshake(user, session);
+ X509_free(cert);
}
- void OnStreamSocketConnect(StreamSocket* user)
+ void SSLInfoCallback(int where, int rc)
{
- int fd = user->GetFd();
- /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
- if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() -1))
- return;
+ if ((where & SSL_CB_HANDSHAKE_START) && (status == ISSL_OPEN))
+ {
+ if (GetProfile().AllowRenegotiation())
+ return;
- issl_session* session = &sessions[fd];
+ // The other side is trying to renegotiate, kill the connection and change status
+ // to ISSL_NONE so CheckRenego() closes the session
+ status = ISSL_NONE;
+ BIO* bio = SSL_get_rbio(sess);
+ EventHandler* eh = static_cast<StreamSocket*>(BIO_get_data(bio));
+ SocketEngine::Shutdown(eh, 2);
+ }
+ }
- session->sess = SSL_new(clictx);
- session->status = ISSL_NONE;
- session->outbound = true;
- session->data_to_write = false;
+ bool CheckRenego(StreamSocket* sock)
+ {
+ if (status != ISSL_NONE)
+ return true;
- if (session->sess == NULL)
- return;
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Session %p killed, attempted to renegotiate", (void*)sess);
+ CloseSession();
+ sock->SetError("Renegotiation is not allowed");
+ return false;
+ }
- if (SSL_set_fd(session->sess, fd) == 0)
+ // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
+ int PrepareIO(StreamSocket* sock)
+ {
+ if (status == ISSL_OPEN)
+ return 1;
+ else if (status == ISSL_HANDSHAKING)
{
- ServerInstance->Logs->Log("m_ssl_openssl",DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd);
- return;
+ // The handshake isn't finished, try to finish it
+ return Handshake(sock);
}
- Handshake(user, session);
+ CloseSession();
+ return -1;
}
- void OnStreamSocketClose(StreamSocket* user)
+ // Calls our private SSLInfoCallback()
+ friend void StaticSSLInfoCallback(const SSL* ssl, int where, int rc);
+
+ public:
+ OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session)
+ : SSLIOHook(hookprov)
+ , sess(session)
+ , status(ISSL_NONE)
+ , data_to_write(false)
{
- int fd = user->GetFd();
- /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
- if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
- return;
+ // Create BIO instance and store a pointer to the socket in it which will be used by the read and write functions
+#ifdef INSPIRCD_OPENSSL_OPAQUE_BIO
+ BIO* bio = BIO_new(biomethods);
+#else
+ BIO* bio = BIO_new(&biomethods);
+#endif
+ BIO_set_data(bio, sock);
+ SSL_set_bio(sess, bio, bio);
- CloseSession(&sessions[fd]);
+ SSL_set_ex_data(sess, exdataindex, this);
+ sock->AddIOHook(this);
+ Handshake(sock);
}
- int OnStreamSocketRead(StreamSocket* user, std::string& recvq)
+ void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE
{
- int fd = user->GetFd();
- /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
- if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
- return -1;
-
- issl_session* session = &sessions[fd];
-
- if (!session->sess)
- {
- CloseSession(session);
- return -1;
- }
-
- if (session->status == ISSL_HANDSHAKING)
- {
- // The handshake isn't finished and it wants to read, try to finish it.
- if (!Handshake(user, session))
- {
- // Couldn't resume handshake.
- if (session->status == ISSL_NONE)
- return -1;
- return 0;
- }
- }
+ CloseSession();
+ }
- // If we resumed the handshake then session->status will be ISSL_OPEN
+ int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE
+ {
+ // Finish handshake if needed
+ int prepret = PrepareIO(user);
+ if (prepret <= 0)
+ return prepret;
- if (session->status == ISSL_OPEN)
+ // If we resumed the handshake then this->status will be ISSL_OPEN
{
ERR_clear_error();
char* buffer = ServerInstance->GetReadBuffer();
size_t bufsiz = ServerInstance->Config->NetBufferSize;
- int ret = SSL_read(session->sess, buffer, bufsiz);
+ int ret = SSL_read(sess, buffer, bufsiz);
-#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION
- if (!CheckRenego(user, session))
+ if (!CheckRenego(user))
return -1;
-#endif
if (ret > 0)
{
recvq.append(buffer, ret);
-
int mask = 0;
// Schedule a read if there is still data in the OpenSSL buffer
- if (SSL_pending(session->sess) > 0)
+ if (SSL_pending(sess) > 0)
mask |= FD_ADD_TRIAL_READ;
- if (session->data_to_write)
+ if (data_to_write)
mask |= FD_WANT_POLL_READ | FD_WANT_SINGLE_WRITE;
if (mask != 0)
- ServerInstance->SE->ChangeEventMask(user, mask);
+ SocketEngine::ChangeEventMask(user, mask);
return 1;
}
else if (ret == 0)
{
// Client closed connection.
- CloseSession(session);
+ CloseSession();
user->SetError("Connection closed");
return -1;
}
- else if (ret < 0)
+ else // if (ret < 0)
{
- int err = SSL_get_error(session->sess, ret);
+ int err = SSL_get_error(sess, ret);
if (err == SSL_ERROR_WANT_READ)
{
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ);
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ);
return 0;
}
else if (err == SSL_ERROR_WANT_WRITE)
{
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
+ SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
return 0;
}
else
{
- CloseSession(session);
+ CloseSession();
return -1;
}
}
}
-
- return 0;
}
- int OnStreamSocketWrite(StreamSocket* user, std::string& buffer)
+ int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
{
- int fd = user->GetFd();
-
- issl_session* session = &sessions[fd];
-
- if (!session->sess)
- {
- CloseSession(session);
- return -1;
- }
-
- session->data_to_write = true;
+ // Finish handshake if needed
+ int prepret = PrepareIO(user);
+ if (prepret <= 0)
+ return prepret;
- if (session->status == ISSL_HANDSHAKING)
- {
- if (!Handshake(user, session))
- {
- // Couldn't resume handshake.
- if (session->status == ISSL_NONE)
- return -1;
- return 0;
- }
- }
+ data_to_write = true;
- if (session->status == ISSL_OPEN)
+ // Session is ready for transferring application data
+ while (!sendq.empty())
{
ERR_clear_error();
- int ret = SSL_write(session->sess, buffer.data(), buffer.size());
+ FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
+ const StreamSocket::SendQueue::Element& buffer = sendq.front();
+ int ret = SSL_write(sess, buffer.data(), buffer.size());
-#ifdef INSPIRCD_OPENSSL_ENABLE_RENEGO_DETECTION
- if (!CheckRenego(user, session))
+ if (!CheckRenego(user))
return -1;
-#endif
if (ret == (int)buffer.length())
{
- session->data_to_write = false;
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
- return 1;
+ // Wrote entire record, continue sending
+ sendq.pop_front();
}
else if (ret > 0)
{
- buffer = buffer.substr(ret);
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
+ sendq.erase_front(ret);
+ SocketEngine::ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
return 0;
}
else if (ret == 0)
{
- CloseSession(session);
+ CloseSession();
return -1;
}
- else if (ret < 0)
+ else // if (ret < 0)
{
- int err = SSL_get_error(session->sess, ret);
+ int err = SSL_get_error(sess, ret);
if (err == SSL_ERROR_WANT_WRITE)
{
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
+ SocketEngine::ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
return 0;
}
else if (err == SSL_ERROR_WANT_READ)
{
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ);
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ);
return 0;
}
else
{
- CloseSession(session);
+ CloseSession();
return -1;
}
}
}
- return 0;
+
+ data_to_write = false;
+ SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+ return 1;
}
- bool Handshake(StreamSocket* user, issl_session* session)
+ void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
{
- int ret;
+ if (!IsHandshakeDone())
+ return;
+ out.append(SSL_get_version(sess)).push_back('-');
+ out.append(SSL_get_cipher(sess));
+ }
- ERR_clear_error();
- if (session->outbound)
- ret = SSL_connect(session->sess);
- else
- ret = SSL_accept(session->sess);
+ bool GetServerName(std::string& out) const CXX11_OVERRIDE
+ {
+ const char* name = SSL_get_servername(sess, TLSEXT_NAMETYPE_host_name);
+ if (!name)
+ return false;
- if (ret < 0)
+ out.append(name);
+ return true;
+ }
+
+ bool IsHandshakeDone() const { return (status == ISSL_OPEN); }
+ OpenSSL::Profile& GetProfile();
+};
+
+static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc)
+{
+ OpenSSLIOHook* hook = static_cast<OpenSSLIOHook*>(SSL_get_ex_data(ssl, exdataindex));
+ hook->SSLInfoCallback(where, rc);
+}
+
+static int OpenSSL::BIOMethod::write(BIO* bio, const char* buffer, int size)
+{
+ BIO_clear_retry_flags(bio);
+
+ StreamSocket* sock = static_cast<StreamSocket*>(BIO_get_data(bio));
+ if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
+ {
+ // Writes blocked earlier, don't retry syscall
+ BIO_set_retry_write(bio);
+ return -1;
+ }
+
+ int ret = SocketEngine::Send(sock, buffer, size, 0);
+ if ((ret < size) && ((ret > 0) || (SocketEngine::IgnoreError())))
+ {
+ // Blocked, set retry flag for OpenSSL
+ SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
+ BIO_set_retry_write(bio);
+ }
+
+ return ret;
+}
+
+static int OpenSSL::BIOMethod::read(BIO* bio, char* buffer, int size)
+{
+ BIO_clear_retry_flags(bio);
+
+ StreamSocket* sock = static_cast<StreamSocket*>(BIO_get_data(bio));
+ if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
+ {
+ // Reads blocked earlier, don't retry syscall
+ BIO_set_retry_read(bio);
+ return -1;
+ }
+
+ int ret = SocketEngine::Recv(sock, buffer, size, 0);
+ if ((ret < size) && ((ret > 0) || (SocketEngine::IgnoreError())))
+ {
+ // Blocked, set retry flag for OpenSSL
+ SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
+ BIO_set_retry_read(bio);
+ }
+
+ return ret;
+}
+
+class OpenSSLIOHookProvider : public IOHookProvider
+{
+ OpenSSL::Profile profile;
+
+ public:
+ OpenSSLIOHookProvider(Module* mod, const std::string& profilename, ConfigTag* tag)
+ : IOHookProvider(mod, "ssl/" + profilename, IOHookProvider::IOH_SSL)
+ , profile(profilename, tag)
+ {
+ ServerInstance->Modules->AddService(*this);
+ }
+
+ ~OpenSSLIOHookProvider()
+ {
+ ServerInstance->Modules->DelService(*this);
+ }
+
+ void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
+ {
+ new OpenSSLIOHook(this, sock, profile.CreateServerSession());
+ }
+
+ void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
+ {
+ new OpenSSLIOHook(this, sock, profile.CreateClientSession());
+ }
+
+ OpenSSL::Profile& GetProfile() { return profile; }
+};
+
+OpenSSL::Profile& OpenSSLIOHook::GetProfile()
+{
+ IOHookProvider* hookprov = prov;
+ return static_cast<OpenSSLIOHookProvider*>(hookprov)->GetProfile();
+}
+
+class ModuleSSLOpenSSL : public Module
+{
+ typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList;
+
+ ProfileList profiles;
+
+ void ReadProfiles()
+ {
+ ProfileList newprofiles;
+ ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
+ if (tags.first == tags.second)
{
- int err = SSL_get_error(session->sess, ret);
+ // Create a default profile named "openssl"
+ const std::string defname = "openssl";
+ ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found, using settings from the <openssl> tag");
- if (err == SSL_ERROR_WANT_READ)
+ try
{
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
- session->status = ISSL_HANDSHAKING;
- return true;
+ newprofiles.push_back(new OpenSSLIOHookProvider(this, defname, tag));
}
- else if (err == SSL_ERROR_WANT_WRITE)
- {
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
- session->status = ISSL_HANDSHAKING;
- return true;
- }
- else
+ catch (OpenSSL::Exception& ex)
{
- CloseSession(session);
+ throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
}
-
- return false;
}
- else if (ret > 0)
+
+ for (ConfigIter i = tags.first; i != tags.second; ++i)
{
- // Handshake complete.
- VerifyCertificate(session, user);
+ ConfigTag* tag = i->second;
+ if (!stdalgo::string::equalsci(tag->getString("provider"), "openssl"))
+ continue;
- session->status = ISSL_OPEN;
+ std::string name = tag->getString("name");
+ if (name.empty())
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
+ continue;
+ }
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
+ reference<OpenSSLIOHookProvider> prov;
+ try
+ {
+ prov = new OpenSSLIOHookProvider(this, name, tag);
+ }
+ catch (CoreException& ex)
+ {
+ throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
+ }
- return true;
+ newprofiles.push_back(prov);
}
- else if (ret == 0)
+
+ for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
{
- CloseSession(session);
+ OpenSSLIOHookProvider& prov = **i;
+ ServerInstance->Modules.DelService(prov);
}
- return false;
+
+ profiles.swap(newprofiles);
}
- void CloseSession(issl_session* session)
+ public:
+ ModuleSSLOpenSSL()
{
- if (session->sess)
- {
- SSL_shutdown(session->sess);
- SSL_free(session->sess);
- }
+ // Initialize OpenSSL
+ OPENSSL_init_ssl(0, NULL);
+#ifdef INSPIRCD_OPENSSL_OPAQUE_BIO
+ biomethods = OpenSSL::BIOMethod::alloc();
+ }
- session->sess = NULL;
- session->status = ISSL_NONE;
- session->cert = NULL;
+ ~ModuleSSLOpenSSL()
+ {
+ BIO_meth_free(biomethods);
+#endif
}
- void VerifyCertificate(issl_session* session, StreamSocket* user)
+ void init() CXX11_OVERRIDE
{
- if (!session->sess || !user)
- return;
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "OpenSSL lib version \"%s\" module was compiled for \"" OPENSSL_VERSION_TEXT "\"", OpenSSL_version(OPENSSL_VERSION));
- X509* cert;
- ssl_cert* certinfo = new ssl_cert;
- session->cert = certinfo;
- unsigned int n;
- unsigned char md[EVP_MAX_MD_SIZE];
- const EVP_MD *digest = use_sha ? EVP_sha1() : EVP_md5();
+ // Register application specific data
+ char exdatastr[] = "inspircd";
+ exdataindex = SSL_get_ex_new_index(0, exdatastr, NULL, NULL, NULL);
+ if (exdataindex < 0)
+ throw ModuleException("Failed to register application specific data");
- cert = SSL_get_peer_certificate((SSL*)session->sess);
+ ReadProfiles();
+ }
- if (!cert)
- {
- certinfo->error = "Could not get peer certificate: "+std::string(get_error());
+ void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
+ {
+ if (param != "ssl")
return;
- }
-
- certinfo->invalid = (SSL_get_verify_result(session->sess) != X509_V_OK);
- if (!SelfSigned)
+ try
{
- certinfo->unknownsigner = false;
- certinfo->trusted = true;
+ ReadProfiles();
}
- else
+ catch (ModuleException& ex)
{
- certinfo->unknownsigner = true;
- certinfo->trusted = false;
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
}
+ }
- char buf[512];
- X509_NAME_oneline(X509_get_subject_name(cert), buf, sizeof(buf));
- certinfo->dn = buf;
- // Make sure there are no chars in the string that we consider invalid
- if (certinfo->dn.find_first_of("\r\n") != std::string::npos)
- certinfo->dn.clear();
-
- X509_NAME_oneline(X509_get_issuer_name(cert), buf, sizeof(buf));
- certinfo->issuer = buf;
- if (certinfo->issuer.find_first_of("\r\n") != std::string::npos)
- certinfo->issuer.clear();
-
- if (!X509_digest(cert, digest, md, &n))
- {
- certinfo->error = "Out of memory generating fingerprint";
- }
- else
+ void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
+ {
+ if (type == ExtensionItem::EXT_USER)
{
- certinfo->fingerprint = irc::hex(md, n);
- }
+ LocalUser* user = IS_LOCAL((User*)item);
- if ((ASN1_UTCTIME_cmp_time_t(X509_getm_notAfter(cert), ServerInstance->Time()) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_getm_notBefore(cert), ServerInstance->Time()) == 0))
- {
- certinfo->error = "Not activated, or expired certificate";
+ if ((user) && (user->eh.GetModHook(this)))
+ {
+ // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
+ // Potentially there could be multiple SSL modules loaded at once on different ports.
+ ServerInstance->Users->QuitUser(user, "SSL module unloading");
+ }
}
+ }
- X509_free(cert);
+ ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
+ {
+ const OpenSSLIOHook* const iohook = static_cast<OpenSSLIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
+ return MOD_RES_PASSTHRU;
}
-};
-static int error_callback(const char *str, size_t len, void *u)
-{
- ServerInstance->Logs->Log("m_ssl_openssl",DEFAULT, "SSL error: " + std::string(str, len - 1));
-
- //
- // XXX: Remove this line, it causes valgrind warnings...
- //
- // MD_update(&m, buf, j);
- //
- //
- // ... ONLY JOKING! :-)
- //
-
- return 0;
-}
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("Provides SSL support for clients", VF_VENDOR);
+ }
+};
MODULE_INIT(ModuleSSLOpenSSL)
diff --git a/src/modules/extra/m_sslrehashsignal.cpp b/src/modules/extra/m_sslrehashsignal.cpp
new file mode 100644
index 000000000..fea32326a
--- /dev/null
+++ b/src/modules/extra/m_sslrehashsignal.cpp
@@ -0,0 +1,64 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ * Copyright (C) 2018 Peter Powell <petpow@saberuk.com>
+ * Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *
+ * This file is part of InspIRCd. InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+#include "inspircd.h"
+
+static volatile sig_atomic_t signaled;
+
+class ModuleSSLRehashSignal : public Module
+{
+ private:
+ static void SignalHandler(int)
+ {
+ signaled = 1;
+ }
+
+ public:
+ ~ModuleSSLRehashSignal()
+ {
+ signal(SIGUSR1, SIG_DFL);
+ }
+
+ void init()
+ {
+ signal(SIGUSR1, SignalHandler);
+ }
+
+ void OnBackgroundTimer(time_t)
+ {
+ if (!signaled)
+ return;
+
+ const std::string feedbackmsg = "Got SIGUSR1, reloading SSL credentials";
+ ServerInstance->SNO->WriteGlobalSno('a', feedbackmsg);
+ ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, feedbackmsg);
+
+ const std::string str = "ssl";
+ FOREACH_MOD(OnModuleRehash, (NULL, str));
+ signaled = 0;
+ }
+
+ Version GetVersion()
+ {
+ return Version("Reloads SSL credentials on SIGUSR1", VF_VENDOR);
+ }
+};
+
+MODULE_INIT(ModuleSSLRehashSignal)