summaryrefslogtreecommitdiff
path: root/src/modules/extra
diff options
context:
space:
mode:
authorpeavey <peavey@e03df62e-2008-0410-955e-edbf42e46eb7>2007-07-16 17:30:04 +0000
committerpeavey <peavey@e03df62e-2008-0410-955e-edbf42e46eb7>2007-07-16 17:30:04 +0000
commitf2acdbc3820f0f4f5ef76a0a64e73d2a320df91f (patch)
tree0602469ef10e4dab4b3975599eb4f919a501c1eb /src/modules/extra
parent387f54199e9f335c58af888bdad5ddc1f5cf9bec (diff)
OOPS! We try again, since I'm smoking craq. LF is 0x0a NOT CR.
git-svn-id: http://svn.inspircd.org/repository/trunk/inspircd@7456 e03df62e-2008-0410-955e-edbf42e46eb7
Diffstat (limited to 'src/modules/extra')
-rw-r--r--src/modules/extra/README8
-rw-r--r--src/modules/extra/m_filter_pcre.cpp183
-rw-r--r--src/modules/extra/m_httpclienttest.cpp82
-rw-r--r--src/modules/extra/m_mysql.cpp890
-rw-r--r--src/modules/extra/m_pgsql.cpp985
-rw-r--r--src/modules/extra/m_sqlauth.cpp195
-rw-r--r--src/modules/extra/m_sqlite3.cpp661
-rw-r--r--src/modules/extra/m_sqllog.cpp311
-rw-r--r--src/modules/extra/m_sqloper.cpp284
-rw-r--r--src/modules/extra/m_sqlutils.cpp239
-rw-r--r--src/modules/extra/m_sqlutils.h144
-rw-r--r--src/modules/extra/m_sqlv2.h606
-rw-r--r--src/modules/extra/m_ssl_gnutls.cpp844
-rw-r--r--src/modules/extra/m_ssl_openssl.cpp902
-rw-r--r--src/modules/extra/m_ssl_oper_cert.cpp181
-rw-r--r--src/modules/extra/m_sslinfo.cpp95
-rw-r--r--src/modules/extra/m_testclient.cpp111
-rw-r--r--src/modules/extra/m_ziplink.cpp453
18 files changed, 7156 insertions, 18 deletions
diff --git a/src/modules/extra/README b/src/modules/extra/README
index 4c4beef9d..7e3096b34 100644
--- a/src/modules/extra/README
+++ b/src/modules/extra/README
@@ -1 +1,7 @@
-This directory stores modules which require external libraries to compile. For example, m_filter_pcre requires the PCRE libraries. To compile any of these modules first ensure you have the required dependencies (read the online documentation at http://www.inspircd.org/wiki/) and then cp the .cpp file from this directory into the parent directory (src/modules/) and re-configure your inspircd with ./configure -update to detect the new module. \ No newline at end of file
+This directory stores modules which require external libraries to compile.
+For example, m_filter_pcre requires the PCRE libraries.
+
+To compile any of these modules first ensure you have the required dependencies
+(read the online documentation at http://www.inspircd.org/wiki/) and then cp
+the .cpp file from this directory into the parent directory (src/modules/) and
+re-configure your inspircd with ./configure -update to detect the new module.
diff --git a/src/modules/extra/m_filter_pcre.cpp b/src/modules/extra/m_filter_pcre.cpp
index 0c6c05c8c..6fe79a981 100644
--- a/src/modules/extra/m_filter_pcre.cpp
+++ b/src/modules/extra/m_filter_pcre.cpp
@@ -1 +1,182 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <pcre.h> #include "users.h" #include "channels.h" #include "modules.h" #include "m_filter.h" /* $ModDesc: m_filter with regexps */ /* $CompileFlags: exec("pcre-config --cflags") */ /* $LinkerFlags: exec("pcre-config --libs") rpath("pcre-config --libs") -lpcre */ /* $ModDep: m_filter.h */ #ifdef WINDOWS #pragma comment(lib, "pcre.lib") #endif class PCREFilter : public FilterResult { public: pcre* regexp; PCREFilter(pcre* r, const std::string &rea, const std::string &act, long gline_time, const std::string &pat, const std::string &flags) : FilterResult(pat, rea, act, gline_time, flags), regexp(r) { } PCREFilter() { } }; class ModuleFilterPCRE : public FilterBase { std::vector<PCREFilter> filters; pcre *re; const char *error; int erroffset; PCREFilter fr; public: ModuleFilterPCRE(InspIRCd* Me) : FilterBase(Me, "m_filter_pcre.so") { OnRehash(NULL,""); } virtual ~ModuleFilterPCRE() { } virtual FilterResult* FilterMatch(userrec* user, const std::string &text, int flags) { for (std::vector<PCREFilter>::iterator index = filters.begin(); index != filters.end(); index++) { /* Skip ones that dont apply to us */ if (!FilterBase::AppliesToMe(user, dynamic_cast<FilterResult*>(&(*index)), flags)) continue; if (pcre_exec(index->regexp, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) > -1) { fr = *index; if (index != filters.begin()) { filters.erase(index); filters.insert(filters.begin(), fr); } return &fr; } } return NULL; } virtual bool DeleteFilter(const std::string &freeform) { for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) { if (i->freeform == freeform) { pcre_free((*i).regexp); filters.erase(i); return true; } } return false; } virtual void SyncFilters(Module* proto, void* opaque) { for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) { this->SendFilter(proto, opaque, &(*i)); } } virtual std::pair<bool, std::string> AddFilter(const std::string &freeform, const std::string &type, const std::string &reason, long duration, const std::string &flags) { for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) { if (i->freeform == freeform) { return std::make_pair(false, "Filter already exists"); } } re = pcre_compile(freeform.c_str(),0,&error,&erroffset,NULL); if (!re) { ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", freeform.c_str(), erroffset, error); ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", freeform.c_str()); return std::make_pair(false, "Error in regular expression at offset " + ConvToStr(erroffset) + ": "+error); } else { filters.push_back(PCREFilter(re, reason, type, duration, freeform, flags)); return std::make_pair(true, ""); } } virtual void OnRehash(userrec* user, const std::string &parameter) { ConfigReader MyConf(ServerInstance); for (int index = 0; index < MyConf.Enumerate("keyword"); index++) { this->DeleteFilter(MyConf.ReadValue("keyword", "pattern", index)); std::string pattern = MyConf.ReadValue("keyword", "pattern", index); std::string reason = MyConf.ReadValue("keyword", "reason", index); std::string action = MyConf.ReadValue("keyword", "action", index); std::string flags = MyConf.ReadValue("keyword", "flags", index); long gline_time = ServerInstance->Duration(MyConf.ReadValue("keyword", "duration", index)); if (action.empty()) action = "none"; if (flags.empty()) flags = "*"; re = pcre_compile(pattern.c_str(),0,&error,&erroffset,NULL); if (!re) { ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", pattern.c_str(), erroffset, error); ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", pattern.c_str()); } else { filters.push_back(PCREFilter(re, reason, action, gline_time, pattern, flags)); ServerInstance->Log(DEFAULT,"Regular expression %s loaded.", pattern.c_str()); } } } virtual int OnStats(char symbol, userrec* user, string_list &results) { if (symbol == 's') { std::string sn = ServerInstance->Config->ServerName; for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) { results.push_back(sn+" 223 "+user->nick+" :REGEXP:"+i->freeform+" "+i->flags+" "+i->action+" "+ConvToStr(i->gline_time)+" :"+i->reason); } } return 0; } }; MODULE_INIT(ModuleFilterPCRE); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include <pcre.h>
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "m_filter.h"
+
+/* $ModDesc: m_filter with regexps */
+/* $CompileFlags: exec("pcre-config --cflags") */
+/* $LinkerFlags: exec("pcre-config --libs") rpath("pcre-config --libs") -lpcre */
+/* $ModDep: m_filter.h */
+
+#ifdef WINDOWS
+#pragma comment(lib, "pcre.lib")
+#endif
+
+class PCREFilter : public FilterResult
+{
+ public:
+ pcre* regexp;
+
+ PCREFilter(pcre* r, const std::string &rea, const std::string &act, long gline_time, const std::string &pat, const std::string &flags)
+ : FilterResult(pat, rea, act, gline_time, flags), regexp(r)
+ {
+ }
+
+ PCREFilter()
+ {
+ }
+};
+
+class ModuleFilterPCRE : public FilterBase
+{
+ std::vector<PCREFilter> filters;
+ pcre *re;
+ const char *error;
+ int erroffset;
+ PCREFilter fr;
+
+ public:
+ ModuleFilterPCRE(InspIRCd* Me)
+ : FilterBase(Me, "m_filter_pcre.so")
+ {
+ OnRehash(NULL,"");
+ }
+
+ virtual ~ModuleFilterPCRE()
+ {
+ }
+
+ virtual FilterResult* FilterMatch(userrec* user, const std::string &text, int flags)
+ {
+ for (std::vector<PCREFilter>::iterator index = filters.begin(); index != filters.end(); index++)
+ {
+ /* Skip ones that dont apply to us */
+
+ if (!FilterBase::AppliesToMe(user, dynamic_cast<FilterResult*>(&(*index)), flags))
+ continue;
+
+ if (pcre_exec(index->regexp, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) > -1)
+ {
+ fr = *index;
+ if (index != filters.begin())
+ {
+ filters.erase(index);
+ filters.insert(filters.begin(), fr);
+ }
+ return &fr;
+ }
+ }
+ return NULL;
+ }
+
+ virtual bool DeleteFilter(const std::string &freeform)
+ {
+ for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
+ {
+ if (i->freeform == freeform)
+ {
+ pcre_free((*i).regexp);
+ filters.erase(i);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ virtual void SyncFilters(Module* proto, void* opaque)
+ {
+ for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
+ {
+ this->SendFilter(proto, opaque, &(*i));
+ }
+ }
+
+ virtual std::pair<bool, std::string> AddFilter(const std::string &freeform, const std::string &type, const std::string &reason, long duration, const std::string &flags)
+ {
+ for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
+ {
+ if (i->freeform == freeform)
+ {
+ return std::make_pair(false, "Filter already exists");
+ }
+ }
+
+ re = pcre_compile(freeform.c_str(),0,&error,&erroffset,NULL);
+
+ if (!re)
+ {
+ ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", freeform.c_str(), erroffset, error);
+ ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", freeform.c_str());
+ return std::make_pair(false, "Error in regular expression at offset " + ConvToStr(erroffset) + ": "+error);
+ }
+ else
+ {
+ filters.push_back(PCREFilter(re, reason, type, duration, freeform, flags));
+ return std::make_pair(true, "");
+ }
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ ConfigReader MyConf(ServerInstance);
+
+ for (int index = 0; index < MyConf.Enumerate("keyword"); index++)
+ {
+ this->DeleteFilter(MyConf.ReadValue("keyword", "pattern", index));
+
+ std::string pattern = MyConf.ReadValue("keyword", "pattern", index);
+ std::string reason = MyConf.ReadValue("keyword", "reason", index);
+ std::string action = MyConf.ReadValue("keyword", "action", index);
+ std::string flags = MyConf.ReadValue("keyword", "flags", index);
+ long gline_time = ServerInstance->Duration(MyConf.ReadValue("keyword", "duration", index));
+ if (action.empty())
+ action = "none";
+ if (flags.empty())
+ flags = "*";
+
+ re = pcre_compile(pattern.c_str(),0,&error,&erroffset,NULL);
+
+ if (!re)
+ {
+ ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", pattern.c_str(), erroffset, error);
+ ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", pattern.c_str());
+ }
+ else
+ {
+ filters.push_back(PCREFilter(re, reason, action, gline_time, pattern, flags));
+ ServerInstance->Log(DEFAULT,"Regular expression %s loaded.", pattern.c_str());
+ }
+ }
+ }
+
+ virtual int OnStats(char symbol, userrec* user, string_list &results)
+ {
+ if (symbol == 's')
+ {
+ std::string sn = ServerInstance->Config->ServerName;
+ for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
+ {
+ results.push_back(sn+" 223 "+user->nick+" :REGEXP:"+i->freeform+" "+i->flags+" "+i->action+" "+ConvToStr(i->gline_time)+" :"+i->reason);
+ }
+ }
+ return 0;
+ }
+};
+
+MODULE_INIT(ModuleFilterPCRE);
+
diff --git a/src/modules/extra/m_httpclienttest.cpp b/src/modules/extra/m_httpclienttest.cpp
index 3f74b549b..90e7a5159 100644
--- a/src/modules/extra/m_httpclienttest.cpp
+++ b/src/modules/extra/m_httpclienttest.cpp
@@ -1 +1,81 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include "users.h" #include "channels.h" #include "modules.h" #include "httpclient.h" /* $ModDep: httpclient.h */ class MyModule : public Module { public: MyModule(InspIRCd* Me) : Module::Module(Me) { } virtual ~MyModule() { } virtual void Implements(char* List) { List[I_OnRequest] = List[I_OnUserJoin] = List[I_OnUserPart] = 1; } virtual Version GetVersion() { return Version(1,0,0,1,VF_VENDOR,API_VERSION); } virtual void OnUserJoin(userrec* user, chanrec* channel, bool &silent) { // method called when a user joins a channel std::string chan = channel->name; std::string nick = user->nick; ServerInstance->Log(DEBUG,"User " + nick + " joined " + chan); Module* target = ServerInstance->FindModule("m_http_client.so"); if(target) { HTTPClientRequest req(ServerInstance, this, target, "http://znc.in/~psychon"); req.Send(); } else ServerInstance->Log(DEBUG,"module not found, load it!!"); } char* OnRequest(Request* req) { HTTPClientResponse* resp = (HTTPClientResponse*)req; if(!strcmp(resp->GetId(), HTTP_CLIENT_RESPONSE)) { ServerInstance->Log(DEBUG, resp->GetData()); } return NULL; } virtual void OnUserPart(userrec* user, chanrec* channel, const std::string &partmessage, bool &silent) { } }; MODULE_INIT(MyModule); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "httpclient.h"
+
+/* $ModDep: httpclient.h */
+
+class MyModule : public Module
+{
+
+public:
+
+ MyModule(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ }
+
+ virtual ~MyModule()
+ {
+ }
+
+ virtual void Implements(char* List)
+ {
+ List[I_OnRequest] = List[I_OnUserJoin] = List[I_OnUserPart] = 1;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,0,0,1,VF_VENDOR,API_VERSION);
+ }
+
+ virtual void OnUserJoin(userrec* user, chanrec* channel, bool &silent)
+ {
+ // method called when a user joins a channel
+
+ std::string chan = channel->name;
+ std::string nick = user->nick;
+ ServerInstance->Log(DEBUG,"User " + nick + " joined " + chan);
+
+ Module* target = ServerInstance->FindModule("m_http_client.so");
+ if(target)
+ {
+ HTTPClientRequest req(ServerInstance, this, target, "http://znc.in/~psychon");
+ req.Send();
+ }
+ else
+ ServerInstance->Log(DEBUG,"module not found, load it!!");
+ }
+
+ char* OnRequest(Request* req)
+ {
+ HTTPClientResponse* resp = (HTTPClientResponse*)req;
+ if(!strcmp(resp->GetId(), HTTP_CLIENT_RESPONSE))
+ {
+ ServerInstance->Log(DEBUG, resp->GetData());
+ }
+ return NULL;
+ }
+
+ virtual void OnUserPart(userrec* user, chanrec* channel, const std::string &partmessage, bool &silent)
+ {
+ }
+
+};
+
+MODULE_INIT(MyModule);
+
diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp
index eeabe5d48..6605bed3c 100644
--- a/src/modules/extra/m_mysql.cpp
+++ b/src/modules/extra/m_mysql.cpp
@@ -1 +1,889 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <mysql.h> #include <pthread.h> #include "users.h" #include "channels.h" #include "modules.h" #include "m_sqlv2.h" /* VERSION 2 API: With nonblocking (threaded) requests */ /* $ModDesc: SQL Service Provider module for all other m_sql* modules */ /* $CompileFlags: exec("mysql_config --include") */ /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */ /* $ModDep: m_sqlv2.h */ /* THE NONBLOCKING MYSQL API! * * MySQL provides no nonblocking (asyncronous) API of its own, and its developers recommend * that instead, you should thread your program. This is what i've done here to allow for * asyncronous SQL requests via mysql. The way this works is as follows: * * The module spawns a thread via pthreads, and performs its mysql queries in this thread, * using a queue with priorities. There is a mutex on either end which prevents two threads * adjusting the queue at the same time, and crashing the ircd. Every 50 milliseconds, the * worker thread wakes up, and checks if there is a request at the head of its queue. * If there is, it processes this request, blocking the worker thread but leaving the ircd * thread to go about its business as usual. During this period, the ircd thread is able * to insert futher pending requests into the queue. * * Once the processing of a request is complete, it is removed from the incoming queue to * an outgoing queue, and initialized as a 'response'. The worker thread then signals the * ircd thread (via a loopback socket) of the fact a result is available, by sending the * connection ID through the connection. * * The ircd thread then mutexes the queue once more, reads the outbound response off the head * of the queue, and sends it on its way to the original calling module. * * XXX: You might be asking "why doesnt he just send the response from within the worker thread?" * The answer to this is simple. The majority of InspIRCd, and in fact most ircd's are not * threadsafe. This module is designed to be threadsafe and is careful with its use of threads, * however, if we were to call a module's OnRequest even from within a thread which was not the * one the module was originally instantiated upon, there is a chance of all hell breaking loose * if a module is ever put in a re-enterant state (stack corruption could occur, crashes, data * corruption, and worse, so DONT think about it until the day comes when InspIRCd is 100% * gauranteed threadsafe!) * * For a diagram of this system please see http://www.inspircd.org/wiki/Mysql2 */ class SQLConnection; class Notifier; typedef std::map<std::string, SQLConnection*> ConnMap; bool giveup = false; static Module* SQLModule = NULL; static Notifier* MessagePipe = NULL; int QueueFD = -1; #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224 #define mysql_field_count mysql_num_fields #endif typedef std::deque<SQLresult*> ResultQueue; /* A mutex to wrap around queue accesses */ pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_t results_mutex = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_t logging_mutex = PTHREAD_MUTEX_INITIALIZER; /** Represents a mysql result set */ class MySQLresult : public SQLresult { int currentrow; std::vector<std::string> colnames; std::vector<SQLfieldList> fieldlists; SQLfieldMap* fieldmap; SQLfieldMap fieldmap2; SQLfieldList emptyfieldlist; int rows; public: MySQLresult(Module* self, Module* to, MYSQL_RES* res, int affected_rows, unsigned int id) : SQLresult(self, to, id), currentrow(0), fieldmap(NULL) { /* A number of affected rows from from mysql_affected_rows. */ fieldlists.clear(); rows = 0; if (affected_rows >= 1) { rows = affected_rows; fieldlists.resize(rows); } unsigned int field_count = 0; if (res) { MYSQL_ROW row; int n = 0; while ((row = mysql_fetch_row(res))) { if (fieldlists.size() < (unsigned int)rows+1) { fieldlists.resize(fieldlists.size()+1); } field_count = 0; MYSQL_FIELD *fields = mysql_fetch_fields(res); if(mysql_num_fields(res) == 0) break; if (fields && mysql_num_fields(res)) { colnames.clear(); while (field_count < mysql_num_fields(res)) { std::string a = (fields[field_count].name ? fields[field_count].name : ""); std::string b = (row[field_count] ? row[field_count] : ""); SQLfield sqlf(b, !row[field_count]); colnames.push_back(a); fieldlists[n].push_back(sqlf); field_count++; } n++; } rows++; } mysql_free_result(res); } } MySQLresult(Module* self, Module* to, SQLerror e, unsigned int id) : SQLresult(self, to, id), currentrow(0) { rows = 0; error = e; } ~MySQLresult() { } virtual int Rows() { return rows; } virtual int Cols() { return colnames.size(); } virtual std::string ColName(int column) { if (column < (int)colnames.size()) { return colnames[column]; } else { throw SQLbadColName(); } return ""; } virtual int ColNum(const std::string &column) { for (unsigned int i = 0; i < colnames.size(); i++) { if (column == colnames[i]) return i; } throw SQLbadColName(); return 0; } virtual SQLfield GetValue(int row, int column) { if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols())) { return fieldlists[row][column]; } throw SQLbadColName(); /* XXX: We never actually get here because of the throw */ return SQLfield("",true); } virtual SQLfieldList& GetRow() { if (currentrow < rows) return fieldlists[currentrow]; else return emptyfieldlist; } virtual SQLfieldMap& GetRowMap() { fieldmap2.clear(); if (currentrow < rows) { for (int i = 0; i < Cols(); i++) { fieldmap2.insert(std::make_pair(colnames[i],GetValue(currentrow, i))); } currentrow++; } return fieldmap2; } virtual SQLfieldList* GetRowPtr() { SQLfieldList* fieldlist = new SQLfieldList(); if (currentrow < rows) { for (int i = 0; i < Rows(); i++) { fieldlist->push_back(fieldlists[currentrow][i]); } currentrow++; } return fieldlist; } virtual SQLfieldMap* GetRowMapPtr() { fieldmap = new SQLfieldMap(); if (currentrow < rows) { for (int i = 0; i < Cols(); i++) { fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i))); } currentrow++; } return fieldmap; } virtual void Free(SQLfieldMap* fm) { delete fm; } virtual void Free(SQLfieldList* fl) { delete fl; } }; class SQLConnection; void NotifyMainThread(SQLConnection* connection_with_new_result); /** Represents a connection to a mysql database */ class SQLConnection : public classbase { protected: MYSQL connection; MYSQL_RES *res; MYSQL_ROW row; SQLhost host; std::map<std::string,std::string> thisrow; bool Enabled; public: QueryQueue queue; ResultQueue rq; // This constructor creates an SQLConnection object with the given credentials, but does not connect yet. SQLConnection(const SQLhost &hi) : host(hi), Enabled(false) { } ~SQLConnection() { Close(); } // This method connects to the database using the credentials supplied to the constructor, and returns // true upon success. bool Connect() { unsigned int timeout = 1; mysql_init(&connection); mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout); return mysql_real_connect(&connection, host.host.c_str(), host.user.c_str(), host.pass.c_str(), host.name.c_str(), host.port, NULL, 0); } void DoLeadingQuery() { if (!CheckConnection()) return; /* Parse the command string and dispatch it to mysql */ SQLrequest& req = queue.front(); /* Pointer to the buffer we screw around with substitution in */ char* query; /* Pointer to the current end of query, where we append new stuff */ char* queryend; /* Total length of the unescaped parameters */ unsigned long paramlen; /* Total length of query, used for binary-safety in mysql_real_query */ unsigned long querylength = 0; paramlen = 0; for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) { paramlen += i->size(); } /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. * sizeofquery + (totalparamlength*2) + 1 * * The +1 is for null-terminating the string for mysql_real_escape_string */ query = new char[req.query.q.length() + (paramlen*2) + 1]; queryend = query; /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting * the parameters into it... */ for(unsigned long i = 0; i < req.query.q.length(); i++) { if(req.query.q[i] == '?') { /* We found a place to substitute..what fun. * use mysql calls to escape and write the * escaped string onto the end of our query buffer, * then we "just" need to make sure queryend is * pointing at the right place. */ if(req.query.p.size()) { unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length()); queryend += len; req.query.p.pop_front(); } else break; } else { *queryend = req.query.q[i]; queryend++; } querylength++; } *queryend = 0; pthread_mutex_lock(&queue_mutex); req.query.q = query; pthread_mutex_unlock(&queue_mutex); if (!mysql_real_query(&connection, req.query.q.data(), req.query.q.length())) { /* Successfull query */ res = mysql_use_result(&connection); unsigned long rows = mysql_affected_rows(&connection); MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), res, rows, req.id); r->dbid = this->GetID(); r->query = req.query.q; /* Put this new result onto the results queue. * XXX: Remember to mutex the queue! */ pthread_mutex_lock(&results_mutex); rq.push_back(r); pthread_mutex_unlock(&results_mutex); } else { /* XXX: See /usr/include/mysql/mysqld_error.h for a list of * possible error numbers and error messages */ SQLerror e(QREPLY_FAIL, ConvToStr(mysql_errno(&connection)) + std::string(": ") + mysql_error(&connection)); MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), e, req.id); r->dbid = this->GetID(); r->query = req.query.q; pthread_mutex_lock(&results_mutex); rq.push_back(r); pthread_mutex_unlock(&results_mutex); } /* Now signal the main thread that we've got a result to process. * Pass them this connection id as what to examine */ delete[] query; NotifyMainThread(this); } bool ConnectionLost() { if (&connection) { return (mysql_ping(&connection) != 0); } else return false; } bool CheckConnection() { if (ConnectionLost()) { return Connect(); } else return true; } std::string GetError() { return mysql_error(&connection); } const std::string& GetID() { return host.id; } std::string GetHost() { return host.host; } void SetEnable(bool Enable) { Enabled = Enable; } bool IsEnabled() { return Enabled; } void Close() { mysql_close(&connection); } const SQLhost& GetConfHost() { return host; } }; ConnMap Connections; bool HasHost(const SQLhost &host) { for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++) { if (host == iter->second->GetConfHost()) return true; } return false; } bool HostInConf(ConfigReader* conf, const SQLhost &h) { for(int i = 0; i < conf->Enumerate("database"); i++) { SQLhost host; host.id = conf->ReadValue("database", "id", i); host.host = conf->ReadValue("database", "hostname", i); host.port = conf->ReadInteger("database", "port", i, true); host.name = conf->ReadValue("database", "name", i); host.user = conf->ReadValue("database", "username", i); host.pass = conf->ReadValue("database", "password", i); host.ssl = conf->ReadFlag("database", "ssl", i); if (h == host) return true; } return false; } void ClearOldConnections(ConfigReader* conf) { ConnMap::iterator i,safei; for (i = Connections.begin(); i != Connections.end(); i++) { if (!HostInConf(conf, i->second->GetConfHost())) { DELETE(i->second); safei = i; --i; Connections.erase(safei); } } } void ClearAllConnections() { ConnMap::iterator i; while ((i = Connections.begin()) != Connections.end()) { Connections.erase(i); DELETE(i->second); } } void ConnectDatabases(InspIRCd* ServerInstance) { for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++) { if (i->second->IsEnabled()) continue; i->second->SetEnable(true); if (!i->second->Connect()) { /* XXX: MUTEX */ pthread_mutex_lock(&logging_mutex); ServerInstance->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError()); i->second->SetEnable(false); pthread_mutex_unlock(&logging_mutex); } } } void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance) { ClearOldConnections(conf); for (int j =0; j < conf->Enumerate("database"); j++) { SQLhost host; host.id = conf->ReadValue("database", "id", j); host.host = conf->ReadValue("database", "hostname", j); host.port = conf->ReadInteger("database", "port", j, true); host.name = conf->ReadValue("database", "name", j); host.user = conf->ReadValue("database", "username", j); host.pass = conf->ReadValue("database", "password", j); host.ssl = conf->ReadFlag("database", "ssl", j); if (HasHost(host)) continue; if (!host.id.empty() && !host.host.empty() && !host.name.empty() && !host.user.empty() && !host.pass.empty()) { SQLConnection* ThisSQL = new SQLConnection(host); Connections[host.id] = ThisSQL; } } ConnectDatabases(ServerInstance); } char FindCharId(const std::string &id) { char i = 1; for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i) { if (iter->first == id) { return i; } } return 0; } ConnMap::iterator GetCharId(char id) { char i = 1; for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i) { if (i == id) return iter; } return Connections.end(); } void NotifyMainThread(SQLConnection* connection_with_new_result) { /* Here we write() to the socket the main thread has open * and we connect()ed back to before our thread became active. * The main thread is using a nonblocking socket tied into * the socket engine, so they wont block and they'll receive * nearly instant notification. Because we're in a seperate * thread, we can just use standard connect(), and we can * block if we like. We just send the connection id of the * connection back. * * NOTE: We only send a single char down the connection, this * way we know it wont get a partial read at the other end if * the system is especially congested (see bug #263). * The function FindCharId translates a connection name into a * one character id, and GetCharId translates a character id * back into an iterator. */ char id = FindCharId(connection_with_new_result->GetID()); send(QueueFD, &id, 1, 0); } void* DispatcherThread(void* arg); /** Used by m_mysql to notify one thread when the other has a result */ class Notifier : public InspSocket { insp_sockaddr sock_us; socklen_t uslen; public: /* Create a socket on a random port. Let the tcp stack allocate us an available port */ #ifdef IPV6 Notifier(InspIRCd* SI) : InspSocket(SI, "::1", 0, true, 3000) #else Notifier(InspIRCd* SI) : InspSocket(SI, "127.0.0.1", 0, true, 3000) #endif { uslen = sizeof(sock_us); if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen)) { throw ModuleException("Could not create random listening port on localhost"); } } Notifier(InspIRCd* SI, int newfd, char* ip) : InspSocket(SI, newfd, ip) { } /* Using getsockname and ntohs, we can determine which port number we were allocated */ int GetPort() { #ifdef IPV6 return ntohs(sock_us.sin6_port); #else return ntohs(sock_us.sin_port); #endif } virtual int OnIncomingConnection(int newsock, char* ip) { Notifier* n = new Notifier(this->Instance, newsock, ip); n = n; /* Stop bitching at me, GCC */ return true; } virtual bool OnDataReady() { char data = 0; /* NOTE: Only a single character is read so we know we * cant get a partial read. (We've been told that theres * data waiting, so we wont ever get EAGAIN) * The function GetCharId translates a single character * back into an iterator. */ if (read(this->GetFd(), &data, 1) > 0) { ConnMap::iterator iter = GetCharId(data); if (iter != Connections.end()) { /* Lock the mutex, send back the data */ pthread_mutex_lock(&results_mutex); ResultQueue::iterator n = iter->second->rq.begin(); (*n)->Send(); iter->second->rq.pop_front(); pthread_mutex_unlock(&results_mutex); return true; } /* No error, but unknown id */ return true; } /* Erk, error on descriptor! */ return false; } }; /** MySQL module */ class ModuleSQL : public Module { public: ConfigReader *Conf; InspIRCd* PublicServerInstance; pthread_t Dispatcher; int currid; bool rehashing; ModuleSQL(InspIRCd* Me) : Module::Module(Me), rehashing(false) { ServerInstance->UseInterface("SQLutils"); Conf = new ConfigReader(ServerInstance); PublicServerInstance = ServerInstance; currid = 0; SQLModule = this; MessagePipe = new Notifier(ServerInstance); pthread_attr_t attribs; pthread_attr_init(&attribs); pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED); if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0) { throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno))); } if (!ServerInstance->PublishFeature("SQL", this)) { /* Tell worker thread to exit NOW */ giveup = true; throw ModuleException("m_mysql: Unable to publish feature 'SQL'"); } ServerInstance->PublishInterface("SQL", this); } virtual ~ModuleSQL() { giveup = true; ClearAllConnections(); DELETE(Conf); ServerInstance->UnpublishInterface("SQL", this); ServerInstance->UnpublishFeature("SQL"); ServerInstance->DoneWithInterface("SQLutils"); } void Implements(char* List) { List[I_OnRehash] = List[I_OnRequest] = 1; } unsigned long NewID() { if (currid+1 == 0) currid++; return ++currid; } char* OnRequest(Request* request) { if(strcmp(SQLREQID, request->GetId()) == 0) { SQLrequest* req = (SQLrequest*)request; /* XXX: Lock */ pthread_mutex_lock(&queue_mutex); ConnMap::iterator iter; char* returnval = NULL; if((iter = Connections.find(req->dbid)) != Connections.end()) { req->id = NewID(); iter->second->queue.push(*req); returnval = SQLSUCCESS; } else { req->error.Id(BAD_DBID); } pthread_mutex_unlock(&queue_mutex); /* XXX: Unlock */ return returnval; } return NULL; } virtual void OnRehash(userrec* user, const std::string &parameter) { rehashing = true; } virtual Version GetVersion() { return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION); } }; void* DispatcherThread(void* arg) { ModuleSQL* thismodule = (ModuleSQL*)arg; LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance); /* Connect back to the Notifier */ if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1) { /* crap, we're out of sockets... */ return NULL; } insp_sockaddr addr; #ifdef IPV6 insp_aton("::1", &addr.sin6_addr); addr.sin6_family = AF_FAMILY; addr.sin6_port = htons(MessagePipe->GetPort()); #else insp_inaddr ia; insp_aton("127.0.0.1", &ia); addr.sin_family = AF_FAMILY; addr.sin_addr = ia; addr.sin_port = htons(MessagePipe->GetPort()); #endif if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1) { /* wtf, we cant connect to it, but we just created it! */ return NULL; } while (!giveup) { if (thismodule->rehashing) { /* XXX: Lock */ pthread_mutex_lock(&queue_mutex); thismodule->rehashing = false; LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance); pthread_mutex_unlock(&queue_mutex); /* XXX: Unlock */ } SQLConnection* conn = NULL; /* XXX: Lock here for safety */ pthread_mutex_lock(&queue_mutex); for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++) { if (i->second->queue.totalsize()) { conn = i->second; break; } } pthread_mutex_unlock(&queue_mutex); /* XXX: Unlock */ /* Theres an item! */ if (conn) { conn->DoLeadingQuery(); /* XXX: Lock */ pthread_mutex_lock(&queue_mutex); conn->queue.pop(); pthread_mutex_unlock(&queue_mutex); /* XXX: Unlock */ } usleep(50); } return NULL; } MODULE_INIT(ModuleSQL); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include <mysql.h>
+#include <pthread.h>
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "m_sqlv2.h"
+
+/* VERSION 2 API: With nonblocking (threaded) requests */
+
+/* $ModDesc: SQL Service Provider module for all other m_sql* modules */
+/* $CompileFlags: exec("mysql_config --include") */
+/* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
+/* $ModDep: m_sqlv2.h */
+
+/* THE NONBLOCKING MYSQL API!
+ *
+ * MySQL provides no nonblocking (asyncronous) API of its own, and its developers recommend
+ * that instead, you should thread your program. This is what i've done here to allow for
+ * asyncronous SQL requests via mysql. The way this works is as follows:
+ *
+ * The module spawns a thread via pthreads, and performs its mysql queries in this thread,
+ * using a queue with priorities. There is a mutex on either end which prevents two threads
+ * adjusting the queue at the same time, and crashing the ircd. Every 50 milliseconds, the
+ * worker thread wakes up, and checks if there is a request at the head of its queue.
+ * If there is, it processes this request, blocking the worker thread but leaving the ircd
+ * thread to go about its business as usual. During this period, the ircd thread is able
+ * to insert futher pending requests into the queue.
+ *
+ * Once the processing of a request is complete, it is removed from the incoming queue to
+ * an outgoing queue, and initialized as a 'response'. The worker thread then signals the
+ * ircd thread (via a loopback socket) of the fact a result is available, by sending the
+ * connection ID through the connection.
+ *
+ * The ircd thread then mutexes the queue once more, reads the outbound response off the head
+ * of the queue, and sends it on its way to the original calling module.
+ *
+ * XXX: You might be asking "why doesnt he just send the response from within the worker thread?"
+ * The answer to this is simple. The majority of InspIRCd, and in fact most ircd's are not
+ * threadsafe. This module is designed to be threadsafe and is careful with its use of threads,
+ * however, if we were to call a module's OnRequest even from within a thread which was not the
+ * one the module was originally instantiated upon, there is a chance of all hell breaking loose
+ * if a module is ever put in a re-enterant state (stack corruption could occur, crashes, data
+ * corruption, and worse, so DONT think about it until the day comes when InspIRCd is 100%
+ * gauranteed threadsafe!)
+ *
+ * For a diagram of this system please see http://www.inspircd.org/wiki/Mysql2
+ */
+
+
+class SQLConnection;
+class Notifier;
+
+
+typedef std::map<std::string, SQLConnection*> ConnMap;
+bool giveup = false;
+static Module* SQLModule = NULL;
+static Notifier* MessagePipe = NULL;
+int QueueFD = -1;
+
+
+#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
+#define mysql_field_count mysql_num_fields
+#endif
+
+typedef std::deque<SQLresult*> ResultQueue;
+
+/* A mutex to wrap around queue accesses */
+pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+pthread_mutex_t results_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+pthread_mutex_t logging_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+/** Represents a mysql result set
+ */
+class MySQLresult : public SQLresult
+{
+ int currentrow;
+ std::vector<std::string> colnames;
+ std::vector<SQLfieldList> fieldlists;
+ SQLfieldMap* fieldmap;
+ SQLfieldMap fieldmap2;
+ SQLfieldList emptyfieldlist;
+ int rows;
+ public:
+
+ MySQLresult(Module* self, Module* to, MYSQL_RES* res, int affected_rows, unsigned int id) : SQLresult(self, to, id), currentrow(0), fieldmap(NULL)
+ {
+ /* A number of affected rows from from mysql_affected_rows.
+ */
+ fieldlists.clear();
+ rows = 0;
+ if (affected_rows >= 1)
+ {
+ rows = affected_rows;
+ fieldlists.resize(rows);
+ }
+ unsigned int field_count = 0;
+ if (res)
+ {
+ MYSQL_ROW row;
+ int n = 0;
+ while ((row = mysql_fetch_row(res)))
+ {
+ if (fieldlists.size() < (unsigned int)rows+1)
+ {
+ fieldlists.resize(fieldlists.size()+1);
+ }
+ field_count = 0;
+ MYSQL_FIELD *fields = mysql_fetch_fields(res);
+ if(mysql_num_fields(res) == 0)
+ break;
+ if (fields && mysql_num_fields(res))
+ {
+ colnames.clear();
+ while (field_count < mysql_num_fields(res))
+ {
+ std::string a = (fields[field_count].name ? fields[field_count].name : "");
+ std::string b = (row[field_count] ? row[field_count] : "");
+ SQLfield sqlf(b, !row[field_count]);
+ colnames.push_back(a);
+ fieldlists[n].push_back(sqlf);
+ field_count++;
+ }
+ n++;
+ }
+ rows++;
+ }
+ mysql_free_result(res);
+ }
+ }
+
+ MySQLresult(Module* self, Module* to, SQLerror e, unsigned int id) : SQLresult(self, to, id), currentrow(0)
+ {
+ rows = 0;
+ error = e;
+ }
+
+ ~MySQLresult()
+ {
+ }
+
+ virtual int Rows()
+ {
+ return rows;
+ }
+
+ virtual int Cols()
+ {
+ return colnames.size();
+ }
+
+ virtual std::string ColName(int column)
+ {
+ if (column < (int)colnames.size())
+ {
+ return colnames[column];
+ }
+ else
+ {
+ throw SQLbadColName();
+ }
+ return "";
+ }
+
+ virtual int ColNum(const std::string &column)
+ {
+ for (unsigned int i = 0; i < colnames.size(); i++)
+ {
+ if (column == colnames[i])
+ return i;
+ }
+ throw SQLbadColName();
+ return 0;
+ }
+
+ virtual SQLfield GetValue(int row, int column)
+ {
+ if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
+ {
+ return fieldlists[row][column];
+ }
+
+ throw SQLbadColName();
+
+ /* XXX: We never actually get here because of the throw */
+ return SQLfield("",true);
+ }
+
+ virtual SQLfieldList& GetRow()
+ {
+ if (currentrow < rows)
+ return fieldlists[currentrow];
+ else
+ return emptyfieldlist;
+ }
+
+ virtual SQLfieldMap& GetRowMap()
+ {
+ fieldmap2.clear();
+
+ if (currentrow < rows)
+ {
+ for (int i = 0; i < Cols(); i++)
+ {
+ fieldmap2.insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
+ }
+ currentrow++;
+ }
+
+ return fieldmap2;
+ }
+
+ virtual SQLfieldList* GetRowPtr()
+ {
+ SQLfieldList* fieldlist = new SQLfieldList();
+
+ if (currentrow < rows)
+ {
+ for (int i = 0; i < Rows(); i++)
+ {
+ fieldlist->push_back(fieldlists[currentrow][i]);
+ }
+ currentrow++;
+ }
+ return fieldlist;
+ }
+
+ virtual SQLfieldMap* GetRowMapPtr()
+ {
+ fieldmap = new SQLfieldMap();
+
+ if (currentrow < rows)
+ {
+ for (int i = 0; i < Cols(); i++)
+ {
+ fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
+ }
+ currentrow++;
+ }
+
+ return fieldmap;
+ }
+
+ virtual void Free(SQLfieldMap* fm)
+ {
+ delete fm;
+ }
+
+ virtual void Free(SQLfieldList* fl)
+ {
+ delete fl;
+ }
+};
+
+class SQLConnection;
+
+void NotifyMainThread(SQLConnection* connection_with_new_result);
+
+/** Represents a connection to a mysql database
+ */
+class SQLConnection : public classbase
+{
+ protected:
+
+ MYSQL connection;
+ MYSQL_RES *res;
+ MYSQL_ROW row;
+ SQLhost host;
+ std::map<std::string,std::string> thisrow;
+ bool Enabled;
+
+ public:
+
+ QueryQueue queue;
+ ResultQueue rq;
+
+ // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
+ SQLConnection(const SQLhost &hi) : host(hi), Enabled(false)
+ {
+ }
+
+ ~SQLConnection()
+ {
+ Close();
+ }
+
+ // This method connects to the database using the credentials supplied to the constructor, and returns
+ // true upon success.
+ bool Connect()
+ {
+ unsigned int timeout = 1;
+ mysql_init(&connection);
+ mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
+ return mysql_real_connect(&connection, host.host.c_str(), host.user.c_str(), host.pass.c_str(), host.name.c_str(), host.port, NULL, 0);
+ }
+
+ void DoLeadingQuery()
+ {
+ if (!CheckConnection())
+ return;
+
+ /* Parse the command string and dispatch it to mysql */
+ SQLrequest& req = queue.front();
+
+ /* Pointer to the buffer we screw around with substitution in */
+ char* query;
+
+ /* Pointer to the current end of query, where we append new stuff */
+ char* queryend;
+
+ /* Total length of the unescaped parameters */
+ unsigned long paramlen;
+
+ /* Total length of query, used for binary-safety in mysql_real_query */
+ unsigned long querylength = 0;
+
+ paramlen = 0;
+
+ for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
+ {
+ paramlen += i->size();
+ }
+
+ /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
+ * sizeofquery + (totalparamlength*2) + 1
+ *
+ * The +1 is for null-terminating the string for mysql_real_escape_string
+ */
+
+ query = new char[req.query.q.length() + (paramlen*2) + 1];
+ queryend = query;
+
+ /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
+ * the parameters into it...
+ */
+
+ for(unsigned long i = 0; i < req.query.q.length(); i++)
+ {
+ if(req.query.q[i] == '?')
+ {
+ /* We found a place to substitute..what fun.
+ * use mysql calls to escape and write the
+ * escaped string onto the end of our query buffer,
+ * then we "just" need to make sure queryend is
+ * pointing at the right place.
+ */
+ if(req.query.p.size())
+ {
+ unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
+
+ queryend += len;
+ req.query.p.pop_front();
+ }
+ else
+ break;
+ }
+ else
+ {
+ *queryend = req.query.q[i];
+ queryend++;
+ }
+ querylength++;
+ }
+
+ *queryend = 0;
+
+ pthread_mutex_lock(&queue_mutex);
+ req.query.q = query;
+ pthread_mutex_unlock(&queue_mutex);
+
+ if (!mysql_real_query(&connection, req.query.q.data(), req.query.q.length()))
+ {
+ /* Successfull query */
+ res = mysql_use_result(&connection);
+ unsigned long rows = mysql_affected_rows(&connection);
+ MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), res, rows, req.id);
+ r->dbid = this->GetID();
+ r->query = req.query.q;
+ /* Put this new result onto the results queue.
+ * XXX: Remember to mutex the queue!
+ */
+ pthread_mutex_lock(&results_mutex);
+ rq.push_back(r);
+ pthread_mutex_unlock(&results_mutex);
+ }
+ else
+ {
+ /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
+ * possible error numbers and error messages */
+ SQLerror e(QREPLY_FAIL, ConvToStr(mysql_errno(&connection)) + std::string(": ") + mysql_error(&connection));
+ MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), e, req.id);
+ r->dbid = this->GetID();
+ r->query = req.query.q;
+
+ pthread_mutex_lock(&results_mutex);
+ rq.push_back(r);
+ pthread_mutex_unlock(&results_mutex);
+ }
+
+ /* Now signal the main thread that we've got a result to process.
+ * Pass them this connection id as what to examine
+ */
+
+ delete[] query;
+
+ NotifyMainThread(this);
+ }
+
+ bool ConnectionLost()
+ {
+ if (&connection) {
+ return (mysql_ping(&connection) != 0);
+ }
+ else return false;
+ }
+
+ bool CheckConnection()
+ {
+ if (ConnectionLost()) {
+ return Connect();
+ }
+ else return true;
+ }
+
+ std::string GetError()
+ {
+ return mysql_error(&connection);
+ }
+
+ const std::string& GetID()
+ {
+ return host.id;
+ }
+
+ std::string GetHost()
+ {
+ return host.host;
+ }
+
+ void SetEnable(bool Enable)
+ {
+ Enabled = Enable;
+ }
+
+ bool IsEnabled()
+ {
+ return Enabled;
+ }
+
+ void Close()
+ {
+ mysql_close(&connection);
+ }
+
+ const SQLhost& GetConfHost()
+ {
+ return host;
+ }
+
+};
+
+ConnMap Connections;
+
+bool HasHost(const SQLhost &host)
+{
+ for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++)
+ {
+ if (host == iter->second->GetConfHost())
+ return true;
+ }
+ return false;
+}
+
+bool HostInConf(ConfigReader* conf, const SQLhost &h)
+{
+ for(int i = 0; i < conf->Enumerate("database"); i++)
+ {
+ SQLhost host;
+ host.id = conf->ReadValue("database", "id", i);
+ host.host = conf->ReadValue("database", "hostname", i);
+ host.port = conf->ReadInteger("database", "port", i, true);
+ host.name = conf->ReadValue("database", "name", i);
+ host.user = conf->ReadValue("database", "username", i);
+ host.pass = conf->ReadValue("database", "password", i);
+ host.ssl = conf->ReadFlag("database", "ssl", i);
+ if (h == host)
+ return true;
+ }
+ return false;
+}
+
+void ClearOldConnections(ConfigReader* conf)
+{
+ ConnMap::iterator i,safei;
+ for (i = Connections.begin(); i != Connections.end(); i++)
+ {
+ if (!HostInConf(conf, i->second->GetConfHost()))
+ {
+ DELETE(i->second);
+ safei = i;
+ --i;
+ Connections.erase(safei);
+ }
+ }
+}
+
+void ClearAllConnections()
+{
+ ConnMap::iterator i;
+ while ((i = Connections.begin()) != Connections.end())
+ {
+ Connections.erase(i);
+ DELETE(i->second);
+ }
+}
+
+void ConnectDatabases(InspIRCd* ServerInstance)
+{
+ for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
+ {
+ if (i->second->IsEnabled())
+ continue;
+
+ i->second->SetEnable(true);
+ if (!i->second->Connect())
+ {
+ /* XXX: MUTEX */
+ pthread_mutex_lock(&logging_mutex);
+ ServerInstance->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
+ i->second->SetEnable(false);
+ pthread_mutex_unlock(&logging_mutex);
+ }
+ }
+}
+
+void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance)
+{
+ ClearOldConnections(conf);
+ for (int j =0; j < conf->Enumerate("database"); j++)
+ {
+ SQLhost host;
+ host.id = conf->ReadValue("database", "id", j);
+ host.host = conf->ReadValue("database", "hostname", j);
+ host.port = conf->ReadInteger("database", "port", j, true);
+ host.name = conf->ReadValue("database", "name", j);
+ host.user = conf->ReadValue("database", "username", j);
+ host.pass = conf->ReadValue("database", "password", j);
+ host.ssl = conf->ReadFlag("database", "ssl", j);
+
+ if (HasHost(host))
+ continue;
+
+ if (!host.id.empty() && !host.host.empty() && !host.name.empty() && !host.user.empty() && !host.pass.empty())
+ {
+ SQLConnection* ThisSQL = new SQLConnection(host);
+ Connections[host.id] = ThisSQL;
+ }
+ }
+ ConnectDatabases(ServerInstance);
+}
+
+char FindCharId(const std::string &id)
+{
+ char i = 1;
+ for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i)
+ {
+ if (iter->first == id)
+ {
+ return i;
+ }
+ }
+ return 0;
+}
+
+ConnMap::iterator GetCharId(char id)
+{
+ char i = 1;
+ for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i)
+ {
+ if (i == id)
+ return iter;
+ }
+ return Connections.end();
+}
+
+void NotifyMainThread(SQLConnection* connection_with_new_result)
+{
+ /* Here we write() to the socket the main thread has open
+ * and we connect()ed back to before our thread became active.
+ * The main thread is using a nonblocking socket tied into
+ * the socket engine, so they wont block and they'll receive
+ * nearly instant notification. Because we're in a seperate
+ * thread, we can just use standard connect(), and we can
+ * block if we like. We just send the connection id of the
+ * connection back.
+ *
+ * NOTE: We only send a single char down the connection, this
+ * way we know it wont get a partial read at the other end if
+ * the system is especially congested (see bug #263).
+ * The function FindCharId translates a connection name into a
+ * one character id, and GetCharId translates a character id
+ * back into an iterator.
+ */
+ char id = FindCharId(connection_with_new_result->GetID());
+ send(QueueFD, &id, 1, 0);
+}
+
+void* DispatcherThread(void* arg);
+
+/** Used by m_mysql to notify one thread when the other has a result
+ */
+class Notifier : public InspSocket
+{
+ insp_sockaddr sock_us;
+ socklen_t uslen;
+
+
+ public:
+
+ /* Create a socket on a random port. Let the tcp stack allocate us an available port */
+#ifdef IPV6
+ Notifier(InspIRCd* SI) : InspSocket(SI, "::1", 0, true, 3000)
+#else
+ Notifier(InspIRCd* SI) : InspSocket(SI, "127.0.0.1", 0, true, 3000)
+#endif
+ {
+ uslen = sizeof(sock_us);
+ if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
+ {
+ throw ModuleException("Could not create random listening port on localhost");
+ }
+ }
+
+ Notifier(InspIRCd* SI, int newfd, char* ip) : InspSocket(SI, newfd, ip)
+ {
+ }
+
+ /* Using getsockname and ntohs, we can determine which port number we were allocated */
+ int GetPort()
+ {
+#ifdef IPV6
+ return ntohs(sock_us.sin6_port);
+#else
+ return ntohs(sock_us.sin_port);
+#endif
+ }
+
+ virtual int OnIncomingConnection(int newsock, char* ip)
+ {
+ Notifier* n = new Notifier(this->Instance, newsock, ip);
+ n = n; /* Stop bitching at me, GCC */
+ return true;
+ }
+
+ virtual bool OnDataReady()
+ {
+ char data = 0;
+ /* NOTE: Only a single character is read so we know we
+ * cant get a partial read. (We've been told that theres
+ * data waiting, so we wont ever get EAGAIN)
+ * The function GetCharId translates a single character
+ * back into an iterator.
+ */
+ if (read(this->GetFd(), &data, 1) > 0)
+ {
+ ConnMap::iterator iter = GetCharId(data);
+ if (iter != Connections.end())
+ {
+ /* Lock the mutex, send back the data */
+ pthread_mutex_lock(&results_mutex);
+ ResultQueue::iterator n = iter->second->rq.begin();
+ (*n)->Send();
+ iter->second->rq.pop_front();
+ pthread_mutex_unlock(&results_mutex);
+ return true;
+ }
+ /* No error, but unknown id */
+ return true;
+ }
+
+ /* Erk, error on descriptor! */
+ return false;
+ }
+};
+
+/** MySQL module
+ */
+class ModuleSQL : public Module
+{
+ public:
+
+ ConfigReader *Conf;
+ InspIRCd* PublicServerInstance;
+ pthread_t Dispatcher;
+ int currid;
+ bool rehashing;
+
+ ModuleSQL(InspIRCd* Me)
+ : Module::Module(Me), rehashing(false)
+ {
+ ServerInstance->UseInterface("SQLutils");
+
+ Conf = new ConfigReader(ServerInstance);
+ PublicServerInstance = ServerInstance;
+ currid = 0;
+ SQLModule = this;
+
+ MessagePipe = new Notifier(ServerInstance);
+
+ pthread_attr_t attribs;
+ pthread_attr_init(&attribs);
+ pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED);
+ if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0)
+ {
+ throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno)));
+ }
+
+ if (!ServerInstance->PublishFeature("SQL", this))
+ {
+ /* Tell worker thread to exit NOW */
+ giveup = true;
+ throw ModuleException("m_mysql: Unable to publish feature 'SQL'");
+ }
+
+ ServerInstance->PublishInterface("SQL", this);
+ }
+
+ virtual ~ModuleSQL()
+ {
+ giveup = true;
+ ClearAllConnections();
+ DELETE(Conf);
+ ServerInstance->UnpublishInterface("SQL", this);
+ ServerInstance->UnpublishFeature("SQL");
+ ServerInstance->DoneWithInterface("SQLutils");
+ }
+
+
+ void Implements(char* List)
+ {
+ List[I_OnRehash] = List[I_OnRequest] = 1;
+ }
+
+ unsigned long NewID()
+ {
+ if (currid+1 == 0)
+ currid++;
+ return ++currid;
+ }
+
+ char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLREQID, request->GetId()) == 0)
+ {
+ SQLrequest* req = (SQLrequest*)request;
+
+ /* XXX: Lock */
+ pthread_mutex_lock(&queue_mutex);
+
+ ConnMap::iterator iter;
+
+ char* returnval = NULL;
+
+ if((iter = Connections.find(req->dbid)) != Connections.end())
+ {
+ req->id = NewID();
+ iter->second->queue.push(*req);
+ returnval = SQLSUCCESS;
+ }
+ else
+ {
+ req->error.Id(BAD_DBID);
+ }
+
+ pthread_mutex_unlock(&queue_mutex);
+ /* XXX: Unlock */
+
+ return returnval;
+ }
+
+ return NULL;
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ rehashing = true;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
+ }
+
+};
+
+void* DispatcherThread(void* arg)
+{
+ ModuleSQL* thismodule = (ModuleSQL*)arg;
+ LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance);
+
+ /* Connect back to the Notifier */
+
+ if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
+ {
+ /* crap, we're out of sockets... */
+ return NULL;
+ }
+
+ insp_sockaddr addr;
+
+#ifdef IPV6
+ insp_aton("::1", &addr.sin6_addr);
+ addr.sin6_family = AF_FAMILY;
+ addr.sin6_port = htons(MessagePipe->GetPort());
+#else
+ insp_inaddr ia;
+ insp_aton("127.0.0.1", &ia);
+ addr.sin_family = AF_FAMILY;
+ addr.sin_addr = ia;
+ addr.sin_port = htons(MessagePipe->GetPort());
+#endif
+
+ if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
+ {
+ /* wtf, we cant connect to it, but we just created it! */
+ return NULL;
+ }
+
+ while (!giveup)
+ {
+ if (thismodule->rehashing)
+ {
+ /* XXX: Lock */
+ pthread_mutex_lock(&queue_mutex);
+ thismodule->rehashing = false;
+ LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance);
+ pthread_mutex_unlock(&queue_mutex);
+ /* XXX: Unlock */
+ }
+
+ SQLConnection* conn = NULL;
+ /* XXX: Lock here for safety */
+ pthread_mutex_lock(&queue_mutex);
+ for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
+ {
+ if (i->second->queue.totalsize())
+ {
+ conn = i->second;
+ break;
+ }
+ }
+ pthread_mutex_unlock(&queue_mutex);
+ /* XXX: Unlock */
+
+ /* Theres an item! */
+ if (conn)
+ {
+ conn->DoLeadingQuery();
+
+ /* XXX: Lock */
+ pthread_mutex_lock(&queue_mutex);
+ conn->queue.pop();
+ pthread_mutex_unlock(&queue_mutex);
+ /* XXX: Unlock */
+ }
+
+ usleep(50);
+ }
+
+ return NULL;
+}
+
+MODULE_INIT(ModuleSQL);
+
diff --git a/src/modules/extra/m_pgsql.cpp b/src/modules/extra/m_pgsql.cpp
index 9e85a40de..5d267fc1a 100644
--- a/src/modules/extra/m_pgsql.cpp
+++ b/src/modules/extra/m_pgsql.cpp
@@ -1 +1,984 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <cstdlib> #include <sstream> #include <libpq-fe.h> #include "users.h" #include "channels.h" #include "modules.h" #include "configreader.h" #include "m_sqlv2.h" /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */ /* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */ /* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */ /* $ModDep: m_sqlv2.h */ /* SQLConn rewritten by peavey to * use EventHandler instead of * InspSocket. This is much neater * and gives total control of destroy * and delete of resources. */ /* Forward declare, so we can have the typedef neatly at the top */ class SQLConn; typedef std::map<std::string, SQLConn*> ConnMap; /* CREAD, Connecting and wants read event * CWRITE, Connecting and wants write event * WREAD, Connected/Working and wants read event * WWRITE, Connected/Working and wants write event * RREAD, Resetting and wants read event * RWRITE, Resetting and wants write event */ enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE, RREAD, RWRITE }; /** SQLhost::GetDSN() - Overload to return correct DSN for PostgreSQL */ std::string SQLhost::GetDSN() { std::ostringstream conninfo("connect_timeout = '2'"); if (ip.length()) conninfo << " hostaddr = '" << ip << "'"; if (port) conninfo << " port = '" << port << "'"; if (name.length()) conninfo << " dbname = '" << name << "'"; if (user.length()) conninfo << " user = '" << user << "'"; if (pass.length()) conninfo << " password = '" << pass << "'"; if (ssl) { conninfo << " sslmode = 'require'"; } else { conninfo << " sslmode = 'disable'"; } return conninfo.str(); } class ReconnectTimer : public InspTimer { private: Module* mod; public: ReconnectTimer(InspIRCd* SI, Module* m) : InspTimer(5, SI->Time(), false), mod(m) { } virtual void Tick(time_t TIME); }; /** Used to resolve sql server hostnames */ class SQLresolver : public Resolver { private: SQLhost host; Module* mod; public: SQLresolver(Module* m, InspIRCd* Instance, const SQLhost& hi, bool &cached) : Resolver(Instance, hi.host, DNS_QUERY_FORWARD, cached, (Module*)m), host(hi), mod(m) { } virtual void OnLookupComplete(const std::string &result, unsigned int ttl, bool cached); virtual void OnError(ResolverError e, const std::string &errormessage) { ServerInstance->Log(DEBUG, "PgSQL: DNS lookup failed (%s), dying horribly", errormessage.c_str()); } }; /** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult. * All SQL providers must create their own subclass and define it's methods using that * database library's data retriveal functions. The aim is to avoid a slow and inefficient process * of converting all data to a common format before it reaches the result structure. This way * data is passes to the module nearly as directly as if it was using the API directly itself. */ class PgSQLresult : public SQLresult { PGresult* res; int currentrow; int rows; int cols; SQLfieldList* fieldlist; SQLfieldMap* fieldmap; public: PgSQLresult(Module* self, Module* to, unsigned long id, PGresult* result) : SQLresult(self, to, id), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL) { rows = PQntuples(res); cols = PQnfields(res); } ~PgSQLresult() { /* If we allocated these, free them... */ if(fieldlist) DELETE(fieldlist); if(fieldmap) DELETE(fieldmap); PQclear(res); } virtual int Rows() { if(!cols && !rows) { return atoi(PQcmdTuples(res)); } else { return rows; } } virtual int Cols() { return PQnfields(res); } virtual std::string ColName(int column) { char* name = PQfname(res, column); return (name) ? name : ""; } virtual int ColNum(const std::string &column) { int n = PQfnumber(res, column.c_str()); if(n == -1) { throw SQLbadColName(); } else { return n; } } virtual SQLfield GetValue(int row, int column) { char* v = PQgetvalue(res, row, column); if(v) { return SQLfield(std::string(v, PQgetlength(res, row, column)), PQgetisnull(res, row, column)); } else { throw SQLbadColName(); } } virtual SQLfieldList& GetRow() { /* In an effort to reduce overhead we don't actually allocate the list * until the first time it's needed...so... */ if(fieldlist) { fieldlist->clear(); } else { fieldlist = new SQLfieldList; } if(currentrow < PQntuples(res)) { int cols = PQnfields(res); for(int i = 0; i < cols; i++) { fieldlist->push_back(GetValue(currentrow, i)); } currentrow++; } return *fieldlist; } virtual SQLfieldMap& GetRowMap() { /* In an effort to reduce overhead we don't actually allocate the map * until the first time it's needed...so... */ if(fieldmap) { fieldmap->clear(); } else { fieldmap = new SQLfieldMap; } if(currentrow < PQntuples(res)) { int cols = PQnfields(res); for(int i = 0; i < cols; i++) { fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i))); } currentrow++; } return *fieldmap; } virtual SQLfieldList* GetRowPtr() { SQLfieldList* fl = new SQLfieldList; if(currentrow < PQntuples(res)) { int cols = PQnfields(res); for(int i = 0; i < cols; i++) { fl->push_back(GetValue(currentrow, i)); } currentrow++; } return fl; } virtual SQLfieldMap* GetRowMapPtr() { SQLfieldMap* fm = new SQLfieldMap; if(currentrow < PQntuples(res)) { int cols = PQnfields(res); for(int i = 0; i < cols; i++) { fm->insert(std::make_pair(ColName(i), GetValue(currentrow, i))); } currentrow++; } return fm; } virtual void Free(SQLfieldMap* fm) { DELETE(fm); } virtual void Free(SQLfieldList* fl) { DELETE(fl); } }; /** SQLConn represents one SQL session. */ class SQLConn : public EventHandler { private: InspIRCd* Instance; SQLhost confhost; /* The <database> entry */ Module* us; /* Pointer to the SQL provider itself */ PGconn* sql; /* PgSQL database connection handle */ SQLstatus status; /* PgSQL database connection status */ bool qinprog; /* If there is currently a query in progress */ QueryQueue queue; /* Queue of queries waiting to be executed on this connection */ time_t idle; /* Time we last heard from the database */ public: SQLConn(InspIRCd* SI, Module* self, const SQLhost& hi) : EventHandler(), Instance(SI), confhost(hi), us(self), sql(NULL), status(CWRITE), qinprog(false) { idle = this->Instance->Time(); if(!DoConnect()) { Instance->Log(DEFAULT, "WARNING: Could not connect to database with id: " + ConvToStr(hi.id)); DelayReconnect(); } } ~SQLConn() { Close(); } virtual void HandleEvent(EventType et, int errornum) { switch (et) { case EVENT_READ: OnDataReady(); break; case EVENT_WRITE: OnWriteReady(); break; case EVENT_ERROR: DelayReconnect(); break; default: break; } } bool DoConnect() { if(!(sql = PQconnectStart(confhost.GetDSN().c_str()))) return false; if(PQstatus(sql) == CONNECTION_BAD) return false; if(PQsetnonblocking(sql, 1) == -1) return false; /* OK, we've initalised the connection, now to get it hooked into the socket engine * and then start polling it. */ this->fd = PQsocket(sql); if(this->fd <= -1) return false; if (!this->Instance->SE->AddFd(this)) { Instance->Log(DEBUG, "BUG: Couldn't add pgsql socket to socket engine"); return false; } /* Socket all hooked into the engine, now to tell PgSQL to start connecting */ return DoPoll(); } bool DoPoll() { switch(PQconnectPoll(sql)) { case PGRES_POLLING_WRITING: Instance->SE->WantWrite(this); status = CWRITE; return true; case PGRES_POLLING_READING: status = CREAD; return true; case PGRES_POLLING_FAILED: return false; case PGRES_POLLING_OK: status = WWRITE; return DoConnectedPoll(); default: return true; } } bool DoConnectedPoll() { if(!qinprog && queue.totalsize()) { /* There's no query currently in progress, and there's queries in the queue. */ SQLrequest& query = queue.front(); DoQuery(query); } if(PQconsumeInput(sql)) { /* We just read stuff from the server, that counts as it being alive * so update the idle-since time :p */ idle = this->Instance->Time(); if (PQisBusy(sql)) { /* Nothing happens here */ } else if (qinprog) { /* Grab the request we're processing */ SQLrequest& query = queue.front(); /* Get a pointer to the module we're about to return the result to */ Module* to = query.GetSource(); /* Fetch the result.. */ PGresult* result = PQgetResult(sql); /* PgSQL would allow a query string to be sent which has multiple * queries in it, this isn't portable across database backends and * we don't want modules doing it. But just in case we make sure we * drain any results there are and just use the last one. * If the module devs are behaving there will only be one result. */ while (PGresult* temp = PQgetResult(sql)) { PQclear(result); result = temp; } if(to) { /* ..and the result */ PgSQLresult reply(us, to, query.id, result); /* Fix by brain, make sure the original query gets sent back in the reply */ reply.query = query.query.q; switch(PQresultStatus(result)) { case PGRES_EMPTY_QUERY: case PGRES_BAD_RESPONSE: case PGRES_FATAL_ERROR: reply.error.Id(QREPLY_FAIL); reply.error.Str(PQresultErrorMessage(result)); default:; /* No action, other values are not errors */ } reply.Send(); /* PgSQLresult's destructor will free the PGresult */ } else { /* If the client module is unloaded partway through a query then the provider will set * the pointer to NULL. We cannot just cancel the query as the result will still come * through at some point...and it could get messy if we play with invalid pointers... */ PQclear(result); } qinprog = false; queue.pop(); DoConnectedPoll(); } return true; } else { /* I think we'll assume this means the server died...it might not, * but I think that any error serious enough we actually get here * deserves to reconnect [/excuse] * Returning true so the core doesn't try and close the connection. */ DelayReconnect(); return true; } } bool DoResetPoll() { switch(PQresetPoll(sql)) { case PGRES_POLLING_WRITING: Instance->SE->WantWrite(this); status = CWRITE; return DoPoll(); case PGRES_POLLING_READING: status = CREAD; return true; case PGRES_POLLING_FAILED: return false; case PGRES_POLLING_OK: status = WWRITE; return DoConnectedPoll(); default: return true; } } bool OnDataReady() { /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */ return DoEvent(); } bool OnWriteReady() { /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */ return DoEvent(); } bool OnConnected() { return DoEvent(); } void DelayReconnect(); bool DoEvent() { bool ret; if((status == CREAD) || (status == CWRITE)) { ret = DoPoll(); } else if((status == RREAD) || (status == RWRITE)) { ret = DoResetPoll(); } else { ret = DoConnectedPoll(); } return ret; } SQLerror DoQuery(SQLrequest &req) { if((status == WREAD) || (status == WWRITE)) { if(!qinprog) { /* Parse the command string and dispatch it */ /* Pointer to the buffer we screw around with substitution in */ char* query; /* Pointer to the current end of query, where we append new stuff */ char* queryend; /* Total length of the unescaped parameters */ unsigned int paramlen; paramlen = 0; for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) { paramlen += i->size(); } /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. * sizeofquery + (totalparamlength*2) + 1 * * The +1 is for null-terminating the string for PQsendQuery() */ query = new char[req.query.q.length() + (paramlen*2) + 1]; queryend = query; /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting * the parameters into it... */ for(unsigned int i = 0; i < req.query.q.length(); i++) { if(req.query.q[i] == '?') { /* We found a place to substitute..what fun. * Use the PgSQL calls to escape and write the * escaped string onto the end of our query buffer, * then we "just" need to make sure queryend is * pointing at the right place. */ if(req.query.p.size()) { int error = 0; size_t len = 0; #ifdef PGSQL_HAS_ESCAPECONN len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error); #else len = PQescapeString (queryend, req.query.p.front().c_str(), req.query.p.front().length()); #endif if(error) { Instance->Log(DEBUG, "BUG: Apparently PQescapeStringConn() failed somehow...don't know how or what to do..."); } /* Incremenet queryend to the end of the newly escaped parameter */ queryend += len; /* Remove the parameter we just substituted in */ req.query.p.pop_front(); } else { Instance->Log(DEBUG, "BUG: Found a substitution location but no parameter to substitute :|"); break; } } else { *queryend = req.query.q[i]; queryend++; } } /* Null-terminate the query */ *queryend = 0; req.query.q = query; if(PQsendQuery(sql, query)) { qinprog = true; delete[] query; return SQLerror(); } else { delete[] query; return SQLerror(QSEND_FAIL, PQerrorMessage(sql)); } } } return SQLerror(BAD_CONN, "Can't query until connection is complete"); } SQLerror Query(const SQLrequest &req) { queue.push(req); if(!qinprog && queue.totalsize()) { /* There's no query currently in progress, and there's queries in the queue. */ SQLrequest& query = queue.front(); return DoQuery(query); } else { return SQLerror(); } } void OnUnloadModule(Module* mod) { queue.PurgeModule(mod); } const SQLhost GetConfHost() { return confhost; } void Close() { if (!this->Instance->SE->DelFd(this)) { if (sql && PQstatus(sql) == CONNECTION_BAD) { this->Instance->SE->DelFd(this, true); } else { Instance->Log(DEBUG, "BUG: PQsocket cant be removed from socket engine!"); } } if(sql) { PQfinish(sql); sql = NULL; } } }; class ModulePgSQL : public Module { private: ConnMap connections; unsigned long currid; char* sqlsuccess; ReconnectTimer* retimer; public: ModulePgSQL(InspIRCd* Me) : Module::Module(Me), currid(0) { ServerInstance->UseInterface("SQLutils"); sqlsuccess = new char[strlen(SQLSUCCESS)+1]; strlcpy(sqlsuccess, SQLSUCCESS, strlen(SQLSUCCESS)); if (!ServerInstance->PublishFeature("SQL", this)) { throw ModuleException("BUG: PgSQL Unable to publish feature 'SQL'"); } ReadConf(); ServerInstance->PublishInterface("SQL", this); } virtual ~ModulePgSQL() { if (retimer) ServerInstance->Timers->DelTimer(retimer); ClearAllConnections(); delete[] sqlsuccess; ServerInstance->UnpublishInterface("SQL", this); ServerInstance->UnpublishFeature("SQL"); ServerInstance->DoneWithInterface("SQLutils"); } void Implements(char* List) { List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1; } virtual void OnRehash(userrec* user, const std::string &parameter) { ReadConf(); } bool HasHost(const SQLhost &host) { for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { if (host == iter->second->GetConfHost()) return true; } return false; } bool HostInConf(const SQLhost &h) { ConfigReader conf(ServerInstance); for(int i = 0; i < conf.Enumerate("database"); i++) { SQLhost host; host.id = conf.ReadValue("database", "id", i); host.host = conf.ReadValue("database", "hostname", i); host.port = conf.ReadInteger("database", "port", i, true); host.name = conf.ReadValue("database", "name", i); host.user = conf.ReadValue("database", "username", i); host.pass = conf.ReadValue("database", "password", i); host.ssl = conf.ReadFlag("database", "ssl", "0", i); if (h == host) return true; } return false; } void ReadConf() { ClearOldConnections(); ConfigReader conf(ServerInstance); for(int i = 0; i < conf.Enumerate("database"); i++) { SQLhost host; int ipvalid; host.id = conf.ReadValue("database", "id", i); host.host = conf.ReadValue("database", "hostname", i); host.port = conf.ReadInteger("database", "port", i, true); host.name = conf.ReadValue("database", "name", i); host.user = conf.ReadValue("database", "username", i); host.pass = conf.ReadValue("database", "password", i); host.ssl = conf.ReadFlag("database", "ssl", "0", i); if (HasHost(host)) continue; #ifdef IPV6 if (strchr(host.host.c_str(),':')) { in6_addr blargle; ipvalid = inet_pton(AF_INET6, host.host.c_str(), &blargle); } else #endif { in_addr blargle; ipvalid = inet_aton(host.host.c_str(), &blargle); } if(ipvalid > 0) { /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */ host.ip = host.host; this->AddConn(host); } else if(ipvalid == 0) { /* Conversion failed, assume it's a host */ SQLresolver* resolver; try { bool cached; resolver = new SQLresolver(this, ServerInstance, host, cached); ServerInstance->AddResolver(resolver, cached); } catch(...) { /* THE WORLD IS COMING TO AN END! */ } } else { /* Invalid address family, die horribly. */ ServerInstance->Log(DEBUG, "BUG: insp_aton failed returning -1, oh noes."); } } } void ClearOldConnections() { ConnMap::iterator iter,safei; for (iter = connections.begin(); iter != connections.end(); iter++) { if (!HostInConf(iter->second->GetConfHost())) { DELETE(iter->second); safei = iter; --iter; connections.erase(safei); } } } void ClearAllConnections() { ConnMap::iterator i; while ((i = connections.begin()) != connections.end()) { connections.erase(i); DELETE(i->second); } } void AddConn(const SQLhost& hi) { if (HasHost(hi)) { ServerInstance->Log(DEFAULT, "WARNING: A pgsql connection with id: %s already exists, possibly due to DNS delay. Aborting connection attempt.", hi.id.c_str()); return; } SQLConn* newconn; /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */ newconn = new SQLConn(ServerInstance, this, hi); connections.insert(std::make_pair(hi.id, newconn)); } void ReconnectConn(SQLConn* conn) { for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { if (conn == iter->second) { DELETE(iter->second); connections.erase(iter); break; } } retimer = new ReconnectTimer(ServerInstance, this); ServerInstance->Timers->AddTimer(retimer); } virtual char* OnRequest(Request* request) { if(strcmp(SQLREQID, request->GetId()) == 0) { SQLrequest* req = (SQLrequest*)request; ConnMap::iterator iter; if((iter = connections.find(req->dbid)) != connections.end()) { /* Execute query */ req->id = NewID(); req->error = iter->second->Query(*req); return (req->error.Id() == NO_ERROR) ? sqlsuccess : NULL; } else { req->error.Id(BAD_DBID); return NULL; } } return NULL; } virtual void OnUnloadModule(Module* mod, const std::string& name) { /* When a module unloads we have to check all the pending queries for all our connections * and set the Module* specifying where the query came from to NULL. If the query has already * been dispatched then when it is processed it will be dropped if the pointer is NULL. * * If the queries we find are not already being executed then we can simply remove them immediately. */ for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { iter->second->OnUnloadModule(mod); } } unsigned long NewID() { if (currid+1 == 0) currid++; return ++currid; } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION); } }; /* move this here to use AddConn, rather that than having the whole * module above SQLConn, since this is buggin me right now :/ */ void SQLresolver::OnLookupComplete(const std::string &result, unsigned int ttl, bool cached) { host.ip = result; ((ModulePgSQL*)mod)->AddConn(host); ((ModulePgSQL*)mod)->ClearOldConnections(); } void ReconnectTimer::Tick(time_t time) { ((ModulePgSQL*)mod)->ReadConf(); } void SQLConn::DelayReconnect() { ((ModulePgSQL*)us)->ReconnectConn(this); } MODULE_INIT(ModulePgSQL); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include <cstdlib>
+#include <sstream>
+#include <libpq-fe.h>
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "configreader.h"
+#include "m_sqlv2.h"
+
+/* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
+/* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
+/* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */
+/* $ModDep: m_sqlv2.h */
+
+
+/* SQLConn rewritten by peavey to
+ * use EventHandler instead of
+ * InspSocket. This is much neater
+ * and gives total control of destroy
+ * and delete of resources.
+ */
+
+/* Forward declare, so we can have the typedef neatly at the top */
+class SQLConn;
+
+typedef std::map<std::string, SQLConn*> ConnMap;
+
+/* CREAD, Connecting and wants read event
+ * CWRITE, Connecting and wants write event
+ * WREAD, Connected/Working and wants read event
+ * WWRITE, Connected/Working and wants write event
+ * RREAD, Resetting and wants read event
+ * RWRITE, Resetting and wants write event
+ */
+enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE, RREAD, RWRITE };
+
+/** SQLhost::GetDSN() - Overload to return correct DSN for PostgreSQL
+ */
+std::string SQLhost::GetDSN()
+{
+ std::ostringstream conninfo("connect_timeout = '2'");
+
+ if (ip.length())
+ conninfo << " hostaddr = '" << ip << "'";
+
+ if (port)
+ conninfo << " port = '" << port << "'";
+
+ if (name.length())
+ conninfo << " dbname = '" << name << "'";
+
+ if (user.length())
+ conninfo << " user = '" << user << "'";
+
+ if (pass.length())
+ conninfo << " password = '" << pass << "'";
+
+ if (ssl)
+ {
+ conninfo << " sslmode = 'require'";
+ }
+ else
+ {
+ conninfo << " sslmode = 'disable'";
+ }
+
+ return conninfo.str();
+}
+
+class ReconnectTimer : public InspTimer
+{
+ private:
+ Module* mod;
+ public:
+ ReconnectTimer(InspIRCd* SI, Module* m)
+ : InspTimer(5, SI->Time(), false), mod(m)
+ {
+ }
+ virtual void Tick(time_t TIME);
+};
+
+
+/** Used to resolve sql server hostnames
+ */
+class SQLresolver : public Resolver
+{
+ private:
+ SQLhost host;
+ Module* mod;
+ public:
+ SQLresolver(Module* m, InspIRCd* Instance, const SQLhost& hi, bool &cached)
+ : Resolver(Instance, hi.host, DNS_QUERY_FORWARD, cached, (Module*)m), host(hi), mod(m)
+ {
+ }
+
+ virtual void OnLookupComplete(const std::string &result, unsigned int ttl, bool cached);
+
+ virtual void OnError(ResolverError e, const std::string &errormessage)
+ {
+ ServerInstance->Log(DEBUG, "PgSQL: DNS lookup failed (%s), dying horribly", errormessage.c_str());
+ }
+};
+
+/** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
+ * All SQL providers must create their own subclass and define it's methods using that
+ * database library's data retriveal functions. The aim is to avoid a slow and inefficient process
+ * of converting all data to a common format before it reaches the result structure. This way
+ * data is passes to the module nearly as directly as if it was using the API directly itself.
+ */
+
+class PgSQLresult : public SQLresult
+{
+ PGresult* res;
+ int currentrow;
+ int rows;
+ int cols;
+
+ SQLfieldList* fieldlist;
+ SQLfieldMap* fieldmap;
+public:
+ PgSQLresult(Module* self, Module* to, unsigned long id, PGresult* result)
+ : SQLresult(self, to, id), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL)
+ {
+ rows = PQntuples(res);
+ cols = PQnfields(res);
+ }
+
+ ~PgSQLresult()
+ {
+ /* If we allocated these, free them... */
+ if(fieldlist)
+ DELETE(fieldlist);
+
+ if(fieldmap)
+ DELETE(fieldmap);
+
+ PQclear(res);
+ }
+
+ virtual int Rows()
+ {
+ if(!cols && !rows)
+ {
+ return atoi(PQcmdTuples(res));
+ }
+ else
+ {
+ return rows;
+ }
+ }
+
+ virtual int Cols()
+ {
+ return PQnfields(res);
+ }
+
+ virtual std::string ColName(int column)
+ {
+ char* name = PQfname(res, column);
+
+ return (name) ? name : "";
+ }
+
+ virtual int ColNum(const std::string &column)
+ {
+ int n = PQfnumber(res, column.c_str());
+
+ if(n == -1)
+ {
+ throw SQLbadColName();
+ }
+ else
+ {
+ return n;
+ }
+ }
+
+ virtual SQLfield GetValue(int row, int column)
+ {
+ char* v = PQgetvalue(res, row, column);
+
+ if(v)
+ {
+ return SQLfield(std::string(v, PQgetlength(res, row, column)), PQgetisnull(res, row, column));
+ }
+ else
+ {
+ throw SQLbadColName();
+ }
+ }
+
+ virtual SQLfieldList& GetRow()
+ {
+ /* In an effort to reduce overhead we don't actually allocate the list
+ * until the first time it's needed...so...
+ */
+ if(fieldlist)
+ {
+ fieldlist->clear();
+ }
+ else
+ {
+ fieldlist = new SQLfieldList;
+ }
+
+ if(currentrow < PQntuples(res))
+ {
+ int cols = PQnfields(res);
+
+ for(int i = 0; i < cols; i++)
+ {
+ fieldlist->push_back(GetValue(currentrow, i));
+ }
+
+ currentrow++;
+ }
+
+ return *fieldlist;
+ }
+
+ virtual SQLfieldMap& GetRowMap()
+ {
+ /* In an effort to reduce overhead we don't actually allocate the map
+ * until the first time it's needed...so...
+ */
+ if(fieldmap)
+ {
+ fieldmap->clear();
+ }
+ else
+ {
+ fieldmap = new SQLfieldMap;
+ }
+
+ if(currentrow < PQntuples(res))
+ {
+ int cols = PQnfields(res);
+
+ for(int i = 0; i < cols; i++)
+ {
+ fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
+ }
+
+ currentrow++;
+ }
+
+ return *fieldmap;
+ }
+
+ virtual SQLfieldList* GetRowPtr()
+ {
+ SQLfieldList* fl = new SQLfieldList;
+
+ if(currentrow < PQntuples(res))
+ {
+ int cols = PQnfields(res);
+
+ for(int i = 0; i < cols; i++)
+ {
+ fl->push_back(GetValue(currentrow, i));
+ }
+
+ currentrow++;
+ }
+
+ return fl;
+ }
+
+ virtual SQLfieldMap* GetRowMapPtr()
+ {
+ SQLfieldMap* fm = new SQLfieldMap;
+
+ if(currentrow < PQntuples(res))
+ {
+ int cols = PQnfields(res);
+
+ for(int i = 0; i < cols; i++)
+ {
+ fm->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
+ }
+
+ currentrow++;
+ }
+
+ return fm;
+ }
+
+ virtual void Free(SQLfieldMap* fm)
+ {
+ DELETE(fm);
+ }
+
+ virtual void Free(SQLfieldList* fl)
+ {
+ DELETE(fl);
+ }
+};
+
+/** SQLConn represents one SQL session.
+ */
+class SQLConn : public EventHandler
+{
+ private:
+ InspIRCd* Instance;
+ SQLhost confhost; /* The <database> entry */
+ Module* us; /* Pointer to the SQL provider itself */
+ PGconn* sql; /* PgSQL database connection handle */
+ SQLstatus status; /* PgSQL database connection status */
+ bool qinprog; /* If there is currently a query in progress */
+ QueryQueue queue; /* Queue of queries waiting to be executed on this connection */
+ time_t idle; /* Time we last heard from the database */
+
+ public:
+ SQLConn(InspIRCd* SI, Module* self, const SQLhost& hi)
+ : EventHandler(), Instance(SI), confhost(hi), us(self), sql(NULL), status(CWRITE), qinprog(false)
+ {
+ idle = this->Instance->Time();
+ if(!DoConnect())
+ {
+ Instance->Log(DEFAULT, "WARNING: Could not connect to database with id: " + ConvToStr(hi.id));
+ DelayReconnect();
+ }
+ }
+
+ ~SQLConn()
+ {
+ Close();
+ }
+
+ virtual void HandleEvent(EventType et, int errornum)
+ {
+ switch (et)
+ {
+ case EVENT_READ:
+ OnDataReady();
+ break;
+
+ case EVENT_WRITE:
+ OnWriteReady();
+ break;
+
+ case EVENT_ERROR:
+ DelayReconnect();
+ break;
+
+ default:
+ break;
+ }
+ }
+
+ bool DoConnect()
+ {
+ if(!(sql = PQconnectStart(confhost.GetDSN().c_str())))
+ return false;
+
+ if(PQstatus(sql) == CONNECTION_BAD)
+ return false;
+
+ if(PQsetnonblocking(sql, 1) == -1)
+ return false;
+
+ /* OK, we've initalised the connection, now to get it hooked into the socket engine
+ * and then start polling it.
+ */
+ this->fd = PQsocket(sql);
+
+ if(this->fd <= -1)
+ return false;
+
+ if (!this->Instance->SE->AddFd(this))
+ {
+ Instance->Log(DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
+ return false;
+ }
+
+ /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
+ return DoPoll();
+ }
+
+ bool DoPoll()
+ {
+ switch(PQconnectPoll(sql))
+ {
+ case PGRES_POLLING_WRITING:
+ Instance->SE->WantWrite(this);
+ status = CWRITE;
+ return true;
+ case PGRES_POLLING_READING:
+ status = CREAD;
+ return true;
+ case PGRES_POLLING_FAILED:
+ return false;
+ case PGRES_POLLING_OK:
+ status = WWRITE;
+ return DoConnectedPoll();
+ default:
+ return true;
+ }
+ }
+
+ bool DoConnectedPoll()
+ {
+ if(!qinprog && queue.totalsize())
+ {
+ /* There's no query currently in progress, and there's queries in the queue. */
+ SQLrequest& query = queue.front();
+ DoQuery(query);
+ }
+
+ if(PQconsumeInput(sql))
+ {
+ /* We just read stuff from the server, that counts as it being alive
+ * so update the idle-since time :p
+ */
+ idle = this->Instance->Time();
+
+ if (PQisBusy(sql))
+ {
+ /* Nothing happens here */
+ }
+ else if (qinprog)
+ {
+ /* Grab the request we're processing */
+ SQLrequest& query = queue.front();
+
+ /* Get a pointer to the module we're about to return the result to */
+ Module* to = query.GetSource();
+
+ /* Fetch the result.. */
+ PGresult* result = PQgetResult(sql);
+
+ /* PgSQL would allow a query string to be sent which has multiple
+ * queries in it, this isn't portable across database backends and
+ * we don't want modules doing it. But just in case we make sure we
+ * drain any results there are and just use the last one.
+ * If the module devs are behaving there will only be one result.
+ */
+ while (PGresult* temp = PQgetResult(sql))
+ {
+ PQclear(result);
+ result = temp;
+ }
+
+ if(to)
+ {
+ /* ..and the result */
+ PgSQLresult reply(us, to, query.id, result);
+
+ /* Fix by brain, make sure the original query gets sent back in the reply */
+ reply.query = query.query.q;
+
+ switch(PQresultStatus(result))
+ {
+ case PGRES_EMPTY_QUERY:
+ case PGRES_BAD_RESPONSE:
+ case PGRES_FATAL_ERROR:
+ reply.error.Id(QREPLY_FAIL);
+ reply.error.Str(PQresultErrorMessage(result));
+ default:;
+ /* No action, other values are not errors */
+ }
+
+ reply.Send();
+
+ /* PgSQLresult's destructor will free the PGresult */
+ }
+ else
+ {
+ /* If the client module is unloaded partway through a query then the provider will set
+ * the pointer to NULL. We cannot just cancel the query as the result will still come
+ * through at some point...and it could get messy if we play with invalid pointers...
+ */
+ PQclear(result);
+ }
+ qinprog = false;
+ queue.pop();
+ DoConnectedPoll();
+ }
+ return true;
+ }
+ else
+ {
+ /* I think we'll assume this means the server died...it might not,
+ * but I think that any error serious enough we actually get here
+ * deserves to reconnect [/excuse]
+ * Returning true so the core doesn't try and close the connection.
+ */
+ DelayReconnect();
+ return true;
+ }
+ }
+
+ bool DoResetPoll()
+ {
+ switch(PQresetPoll(sql))
+ {
+ case PGRES_POLLING_WRITING:
+ Instance->SE->WantWrite(this);
+ status = CWRITE;
+ return DoPoll();
+ case PGRES_POLLING_READING:
+ status = CREAD;
+ return true;
+ case PGRES_POLLING_FAILED:
+ return false;
+ case PGRES_POLLING_OK:
+ status = WWRITE;
+ return DoConnectedPoll();
+ default:
+ return true;
+ }
+ }
+
+ bool OnDataReady()
+ {
+ /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
+ return DoEvent();
+ }
+
+ bool OnWriteReady()
+ {
+ /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
+ return DoEvent();
+ }
+
+ bool OnConnected()
+ {
+ return DoEvent();
+ }
+
+ void DelayReconnect();
+
+ bool DoEvent()
+ {
+ bool ret;
+
+ if((status == CREAD) || (status == CWRITE))
+ {
+ ret = DoPoll();
+ }
+ else if((status == RREAD) || (status == RWRITE))
+ {
+ ret = DoResetPoll();
+ }
+ else
+ {
+ ret = DoConnectedPoll();
+ }
+ return ret;
+ }
+
+ SQLerror DoQuery(SQLrequest &req)
+ {
+ if((status == WREAD) || (status == WWRITE))
+ {
+ if(!qinprog)
+ {
+ /* Parse the command string and dispatch it */
+
+ /* Pointer to the buffer we screw around with substitution in */
+ char* query;
+ /* Pointer to the current end of query, where we append new stuff */
+ char* queryend;
+ /* Total length of the unescaped parameters */
+ unsigned int paramlen;
+
+ paramlen = 0;
+
+ for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
+ {
+ paramlen += i->size();
+ }
+
+ /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
+ * sizeofquery + (totalparamlength*2) + 1
+ *
+ * The +1 is for null-terminating the string for PQsendQuery()
+ */
+
+ query = new char[req.query.q.length() + (paramlen*2) + 1];
+ queryend = query;
+
+ /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
+ * the parameters into it...
+ */
+
+ for(unsigned int i = 0; i < req.query.q.length(); i++)
+ {
+ if(req.query.q[i] == '?')
+ {
+ /* We found a place to substitute..what fun.
+ * Use the PgSQL calls to escape and write the
+ * escaped string onto the end of our query buffer,
+ * then we "just" need to make sure queryend is
+ * pointing at the right place.
+ */
+
+ if(req.query.p.size())
+ {
+ int error = 0;
+ size_t len = 0;
+
+#ifdef PGSQL_HAS_ESCAPECONN
+ len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error);
+#else
+ len = PQescapeString (queryend, req.query.p.front().c_str(), req.query.p.front().length());
+#endif
+ if(error)
+ {
+ Instance->Log(DEBUG, "BUG: Apparently PQescapeStringConn() failed somehow...don't know how or what to do...");
+ }
+
+ /* Incremenet queryend to the end of the newly escaped parameter */
+ queryend += len;
+
+ /* Remove the parameter we just substituted in */
+ req.query.p.pop_front();
+ }
+ else
+ {
+ Instance->Log(DEBUG, "BUG: Found a substitution location but no parameter to substitute :|");
+ break;
+ }
+ }
+ else
+ {
+ *queryend = req.query.q[i];
+ queryend++;
+ }
+ }
+
+ /* Null-terminate the query */
+ *queryend = 0;
+ req.query.q = query;
+
+ if(PQsendQuery(sql, query))
+ {
+ qinprog = true;
+ delete[] query;
+ return SQLerror();
+ }
+ else
+ {
+ delete[] query;
+ return SQLerror(QSEND_FAIL, PQerrorMessage(sql));
+ }
+ }
+ }
+ return SQLerror(BAD_CONN, "Can't query until connection is complete");
+ }
+
+ SQLerror Query(const SQLrequest &req)
+ {
+ queue.push(req);
+
+ if(!qinprog && queue.totalsize())
+ {
+ /* There's no query currently in progress, and there's queries in the queue. */
+ SQLrequest& query = queue.front();
+ return DoQuery(query);
+ }
+ else
+ {
+ return SQLerror();
+ }
+ }
+
+ void OnUnloadModule(Module* mod)
+ {
+ queue.PurgeModule(mod);
+ }
+
+ const SQLhost GetConfHost()
+ {
+ return confhost;
+ }
+
+ void Close() {
+ if (!this->Instance->SE->DelFd(this))
+ {
+ if (sql && PQstatus(sql) == CONNECTION_BAD)
+ {
+ this->Instance->SE->DelFd(this, true);
+ }
+ else
+ {
+ Instance->Log(DEBUG, "BUG: PQsocket cant be removed from socket engine!");
+ }
+ }
+
+ if(sql)
+ {
+ PQfinish(sql);
+ sql = NULL;
+ }
+ }
+
+};
+
+class ModulePgSQL : public Module
+{
+ private:
+ ConnMap connections;
+ unsigned long currid;
+ char* sqlsuccess;
+ ReconnectTimer* retimer;
+
+ public:
+ ModulePgSQL(InspIRCd* Me)
+ : Module::Module(Me), currid(0)
+ {
+ ServerInstance->UseInterface("SQLutils");
+
+ sqlsuccess = new char[strlen(SQLSUCCESS)+1];
+
+ strlcpy(sqlsuccess, SQLSUCCESS, strlen(SQLSUCCESS));
+
+ if (!ServerInstance->PublishFeature("SQL", this))
+ {
+ throw ModuleException("BUG: PgSQL Unable to publish feature 'SQL'");
+ }
+
+ ReadConf();
+
+ ServerInstance->PublishInterface("SQL", this);
+ }
+
+ virtual ~ModulePgSQL()
+ {
+ if (retimer)
+ ServerInstance->Timers->DelTimer(retimer);
+ ClearAllConnections();
+ delete[] sqlsuccess;
+ ServerInstance->UnpublishInterface("SQL", this);
+ ServerInstance->UnpublishFeature("SQL");
+ ServerInstance->DoneWithInterface("SQLutils");
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ ReadConf();
+ }
+
+ bool HasHost(const SQLhost &host)
+ {
+ for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ if (host == iter->second->GetConfHost())
+ return true;
+ }
+ return false;
+ }
+
+ bool HostInConf(const SQLhost &h)
+ {
+ ConfigReader conf(ServerInstance);
+ for(int i = 0; i < conf.Enumerate("database"); i++)
+ {
+ SQLhost host;
+ host.id = conf.ReadValue("database", "id", i);
+ host.host = conf.ReadValue("database", "hostname", i);
+ host.port = conf.ReadInteger("database", "port", i, true);
+ host.name = conf.ReadValue("database", "name", i);
+ host.user = conf.ReadValue("database", "username", i);
+ host.pass = conf.ReadValue("database", "password", i);
+ host.ssl = conf.ReadFlag("database", "ssl", "0", i);
+ if (h == host)
+ return true;
+ }
+ return false;
+ }
+
+ void ReadConf()
+ {
+ ClearOldConnections();
+
+ ConfigReader conf(ServerInstance);
+ for(int i = 0; i < conf.Enumerate("database"); i++)
+ {
+ SQLhost host;
+ int ipvalid;
+
+ host.id = conf.ReadValue("database", "id", i);
+ host.host = conf.ReadValue("database", "hostname", i);
+ host.port = conf.ReadInteger("database", "port", i, true);
+ host.name = conf.ReadValue("database", "name", i);
+ host.user = conf.ReadValue("database", "username", i);
+ host.pass = conf.ReadValue("database", "password", i);
+ host.ssl = conf.ReadFlag("database", "ssl", "0", i);
+
+ if (HasHost(host))
+ continue;
+
+#ifdef IPV6
+ if (strchr(host.host.c_str(),':'))
+ {
+ in6_addr blargle;
+ ipvalid = inet_pton(AF_INET6, host.host.c_str(), &blargle);
+ }
+ else
+#endif
+ {
+ in_addr blargle;
+ ipvalid = inet_aton(host.host.c_str(), &blargle);
+ }
+
+ if(ipvalid > 0)
+ {
+ /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */
+ host.ip = host.host;
+ this->AddConn(host);
+ }
+ else if(ipvalid == 0)
+ {
+ /* Conversion failed, assume it's a host */
+ SQLresolver* resolver;
+
+ try
+ {
+ bool cached;
+ resolver = new SQLresolver(this, ServerInstance, host, cached);
+ ServerInstance->AddResolver(resolver, cached);
+ }
+ catch(...)
+ {
+ /* THE WORLD IS COMING TO AN END! */
+ }
+ }
+ else
+ {
+ /* Invalid address family, die horribly. */
+ ServerInstance->Log(DEBUG, "BUG: insp_aton failed returning -1, oh noes.");
+ }
+ }
+ }
+
+ void ClearOldConnections()
+ {
+ ConnMap::iterator iter,safei;
+ for (iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ if (!HostInConf(iter->second->GetConfHost()))
+ {
+ DELETE(iter->second);
+ safei = iter;
+ --iter;
+ connections.erase(safei);
+ }
+ }
+ }
+
+ void ClearAllConnections()
+ {
+ ConnMap::iterator i;
+ while ((i = connections.begin()) != connections.end())
+ {
+ connections.erase(i);
+ DELETE(i->second);
+ }
+ }
+
+ void AddConn(const SQLhost& hi)
+ {
+ if (HasHost(hi))
+ {
+ ServerInstance->Log(DEFAULT, "WARNING: A pgsql connection with id: %s already exists, possibly due to DNS delay. Aborting connection attempt.", hi.id.c_str());
+ return;
+ }
+
+ SQLConn* newconn;
+
+ /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */
+ newconn = new SQLConn(ServerInstance, this, hi);
+
+ connections.insert(std::make_pair(hi.id, newconn));
+ }
+
+ void ReconnectConn(SQLConn* conn)
+ {
+ for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ if (conn == iter->second)
+ {
+ DELETE(iter->second);
+ connections.erase(iter);
+ break;
+ }
+ }
+ retimer = new ReconnectTimer(ServerInstance, this);
+ ServerInstance->Timers->AddTimer(retimer);
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLREQID, request->GetId()) == 0)
+ {
+ SQLrequest* req = (SQLrequest*)request;
+ ConnMap::iterator iter;
+ if((iter = connections.find(req->dbid)) != connections.end())
+ {
+ /* Execute query */
+ req->id = NewID();
+ req->error = iter->second->Query(*req);
+
+ return (req->error.Id() == NO_ERROR) ? sqlsuccess : NULL;
+ }
+ else
+ {
+ req->error.Id(BAD_DBID);
+ return NULL;
+ }
+ }
+ return NULL;
+ }
+
+ virtual void OnUnloadModule(Module* mod, const std::string& name)
+ {
+ /* When a module unloads we have to check all the pending queries for all our connections
+ * and set the Module* specifying where the query came from to NULL. If the query has already
+ * been dispatched then when it is processed it will be dropped if the pointer is NULL.
+ *
+ * If the queries we find are not already being executed then we can simply remove them immediately.
+ */
+ for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ iter->second->OnUnloadModule(mod);
+ }
+ }
+
+ unsigned long NewID()
+ {
+ if (currid+1 == 0)
+ currid++;
+
+ return ++currid;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION);
+ }
+};
+
+/* move this here to use AddConn, rather that than having the whole
+ * module above SQLConn, since this is buggin me right now :/
+ */
+void SQLresolver::OnLookupComplete(const std::string &result, unsigned int ttl, bool cached)
+{
+ host.ip = result;
+ ((ModulePgSQL*)mod)->AddConn(host);
+ ((ModulePgSQL*)mod)->ClearOldConnections();
+}
+
+void ReconnectTimer::Tick(time_t time)
+{
+ ((ModulePgSQL*)mod)->ReadConf();
+}
+
+void SQLConn::DelayReconnect()
+{
+ ((ModulePgSQL*)us)->ReconnectConn(this);
+}
+
+MODULE_INIT(ModulePgSQL);
+
diff --git a/src/modules/extra/m_sqlauth.cpp b/src/modules/extra/m_sqlauth.cpp
index 862929919..6b05ee521 100644
--- a/src/modules/extra/m_sqlauth.cpp
+++ b/src/modules/extra/m_sqlauth.cpp
@@ -1 +1,194 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include "users.h" #include "channels.h" #include "modules.h" #include "m_sqlv2.h" #include "m_sqlutils.h" /* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */ /* $ModDep: m_sqlv2.h m_sqlutils.h */ class ModuleSQLAuth : public Module { Module* SQLutils; Module* SQLprovider; std::string usertable; std::string userfield; std::string passfield; std::string encryption; std::string killreason; std::string allowpattern; std::string databaseid; bool verbose; public: ModuleSQLAuth(InspIRCd* Me) : Module::Module(Me) { ServerInstance->UseInterface("SQLutils"); ServerInstance->UseInterface("SQL"); SQLutils = ServerInstance->FindModule("m_sqlutils.so"); if (!SQLutils) throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so."); SQLprovider = ServerInstance->FindFeature("SQL"); if (!SQLprovider) throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth."); OnRehash(NULL,""); } virtual ~ModuleSQLAuth() { ServerInstance->DoneWithInterface("SQL"); ServerInstance->DoneWithInterface("SQLutils"); } void Implements(char* List) { List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1; } virtual void OnRehash(userrec* user, const std::string &parameter) { ConfigReader Conf(ServerInstance); usertable = Conf.ReadValue("sqlauth", "usertable", 0); /* User table name */ databaseid = Conf.ReadValue("sqlauth", "dbid", 0); /* Database ID, given to the SQL service provider */ userfield = Conf.ReadValue("sqlauth", "userfield", 0); /* Field name where username can be found */ passfield = Conf.ReadValue("sqlauth", "passfield", 0); /* Field name where password can be found */ killreason = Conf.ReadValue("sqlauth", "killreason", 0); /* Reason to give when access is denied to a user (put your reg details here) */ allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */ encryption = Conf.ReadValue("sqlauth", "encryption", 0); /* Name of sql function used to encrypt password, e.g. "md5" or "passwd". * define, but leave blank if no encryption is to be used. */ verbose = Conf.ReadFlag("sqlauth", "verbose", 0); /* Set to true if failed connects should be reported to operators */ if (encryption.find("(") == std::string::npos) { encryption.append("("); } } virtual int OnUserRegister(userrec* user) { if ((!allowpattern.empty()) && (ServerInstance->MatchText(user->nick,allowpattern))) { user->Extend("sqlauthed"); return 0; } if (!CheckCredentials(user)) { userrec::QuitUser(ServerInstance,user,killreason); return 1; } return 0; } bool CheckCredentials(userrec* user) { SQLrequest req = SQLreq(this, SQLprovider, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password); if(req.Send()) { /* When we get the query response from the service provider we will be given an ID to play with, * just an ID number which is unique to this query. We need a way of associating that ID with a userrec * so we insert it into a map mapping the IDs to users. * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling * us to discard the query. */ AssociateUser(this, SQLutils, req.id, user).Send(); return true; } else { if (verbose) ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str()); return false; } } virtual char* OnRequest(Request* request) { if(strcmp(SQLRESID, request->GetId()) == 0) { SQLresult* res = static_cast<SQLresult*>(request); userrec* user = GetAssocUser(this, SQLutils, res->id).S().user; UnAssociate(this, SQLutils, res->id).S(); if(user) { if(res->error.Id() == NO_ERROR) { if(res->Rows()) { /* We got a row in the result, this is enough really */ user->Extend("sqlauthed"); } else if (verbose) { /* No rows in result, this means there was no record matching the user */ ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host); user->Extend("sqlauth_failed"); } } else if (verbose) { ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str()); user->Extend("sqlauth_failed"); } } else { return NULL; } if (!user->GetExt("sqlauthed")) { userrec::QuitUser(ServerInstance,user,killreason); } return SQLSUCCESS; } return NULL; } virtual void OnUserDisconnect(userrec* user) { user->Shrink("sqlauthed"); user->Shrink("sqlauth_failed"); } virtual bool OnCheckReady(userrec* user) { return user->GetExt("sqlauthed"); } virtual Version GetVersion() { return Version(1,1,1,0,VF_VENDOR,API_VERSION); } }; MODULE_INIT(ModuleSQLAuth); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "m_sqlv2.h"
+#include "m_sqlutils.h"
+
+/* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */
+/* $ModDep: m_sqlv2.h m_sqlutils.h */
+
+class ModuleSQLAuth : public Module
+{
+ Module* SQLutils;
+ Module* SQLprovider;
+
+ std::string usertable;
+ std::string userfield;
+ std::string passfield;
+ std::string encryption;
+ std::string killreason;
+ std::string allowpattern;
+ std::string databaseid;
+
+ bool verbose;
+
+public:
+ ModuleSQLAuth(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ ServerInstance->UseInterface("SQLutils");
+ ServerInstance->UseInterface("SQL");
+
+ SQLutils = ServerInstance->FindModule("m_sqlutils.so");
+ if (!SQLutils)
+ throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
+
+ SQLprovider = ServerInstance->FindFeature("SQL");
+ if (!SQLprovider)
+ throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth.");
+
+ OnRehash(NULL,"");
+ }
+
+ virtual ~ModuleSQLAuth()
+ {
+ ServerInstance->DoneWithInterface("SQL");
+ ServerInstance->DoneWithInterface("SQLutils");
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1;
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ ConfigReader Conf(ServerInstance);
+
+ usertable = Conf.ReadValue("sqlauth", "usertable", 0); /* User table name */
+ databaseid = Conf.ReadValue("sqlauth", "dbid", 0); /* Database ID, given to the SQL service provider */
+ userfield = Conf.ReadValue("sqlauth", "userfield", 0); /* Field name where username can be found */
+ passfield = Conf.ReadValue("sqlauth", "passfield", 0); /* Field name where password can be found */
+ killreason = Conf.ReadValue("sqlauth", "killreason", 0); /* Reason to give when access is denied to a user (put your reg details here) */
+ allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */
+ encryption = Conf.ReadValue("sqlauth", "encryption", 0); /* Name of sql function used to encrypt password, e.g. "md5" or "passwd".
+ * define, but leave blank if no encryption is to be used.
+ */
+ verbose = Conf.ReadFlag("sqlauth", "verbose", 0); /* Set to true if failed connects should be reported to operators */
+
+ if (encryption.find("(") == std::string::npos)
+ {
+ encryption.append("(");
+ }
+ }
+
+ virtual int OnUserRegister(userrec* user)
+ {
+ if ((!allowpattern.empty()) && (ServerInstance->MatchText(user->nick,allowpattern)))
+ {
+ user->Extend("sqlauthed");
+ return 0;
+ }
+
+ if (!CheckCredentials(user))
+ {
+ userrec::QuitUser(ServerInstance,user,killreason);
+ return 1;
+ }
+ return 0;
+ }
+
+ bool CheckCredentials(userrec* user)
+ {
+ SQLrequest req = SQLreq(this, SQLprovider, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password);
+
+ if(req.Send())
+ {
+ /* When we get the query response from the service provider we will be given an ID to play with,
+ * just an ID number which is unique to this query. We need a way of associating that ID with a userrec
+ * so we insert it into a map mapping the IDs to users.
+ * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
+ * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
+ * us to discard the query.
+ */
+ AssociateUser(this, SQLutils, req.id, user).Send();
+
+ return true;
+ }
+ else
+ {
+ if (verbose)
+ ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str());
+ return false;
+ }
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLRESID, request->GetId()) == 0)
+ {
+ SQLresult* res = static_cast<SQLresult*>(request);
+
+ userrec* user = GetAssocUser(this, SQLutils, res->id).S().user;
+ UnAssociate(this, SQLutils, res->id).S();
+
+ if(user)
+ {
+ if(res->error.Id() == NO_ERROR)
+ {
+ if(res->Rows())
+ {
+ /* We got a row in the result, this is enough really */
+ user->Extend("sqlauthed");
+ }
+ else if (verbose)
+ {
+ /* No rows in result, this means there was no record matching the user */
+ ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host);
+ user->Extend("sqlauth_failed");
+ }
+ }
+ else if (verbose)
+ {
+ ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str());
+ user->Extend("sqlauth_failed");
+ }
+ }
+ else
+ {
+ return NULL;
+ }
+
+ if (!user->GetExt("sqlauthed"))
+ {
+ userrec::QuitUser(ServerInstance,user,killreason);
+ }
+ return SQLSUCCESS;
+ }
+ return NULL;
+ }
+
+ virtual void OnUserDisconnect(userrec* user)
+ {
+ user->Shrink("sqlauthed");
+ user->Shrink("sqlauth_failed");
+ }
+
+ virtual bool OnCheckReady(userrec* user)
+ {
+ return user->GetExt("sqlauthed");
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,1,1,0,VF_VENDOR,API_VERSION);
+ }
+
+};
+
+MODULE_INIT(ModuleSQLAuth);
+
diff --git a/src/modules/extra/m_sqlite3.cpp b/src/modules/extra/m_sqlite3.cpp
index 6741d7745..66955de07 100644
--- a/src/modules/extra/m_sqlite3.cpp
+++ b/src/modules/extra/m_sqlite3.cpp
@@ -1 +1,660 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <sqlite3.h> #include "users.h" #include "channels.h" #include "modules.h" #include "m_sqlv2.h" /* $ModDesc: sqlite3 provider */ /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */ /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */ /* $ModDep: m_sqlv2.h */ class SQLConn; class SQLite3Result; class ResultNotifier; typedef std::map<std::string, SQLConn*> ConnMap; typedef std::deque<classbase*> paramlist; typedef std::deque<SQLite3Result*> ResultQueue; ResultNotifier* resultnotify = NULL; class ResultNotifier : public InspSocket { Module* mod; insp_sockaddr sock_us; socklen_t uslen; public: /* Create a socket on a random port. Let the tcp stack allocate us an available port */ #ifdef IPV6 ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m) #else ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m) #endif { uslen = sizeof(sock_us); if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen)) { throw ModuleException("Could not create random listening port on localhost"); } } ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m) { } /* Using getsockname and ntohs, we can determine which port number we were allocated */ int GetPort() { #ifdef IPV6 return ntohs(sock_us.sin6_port); #else return ntohs(sock_us.sin_port); #endif } virtual int OnIncomingConnection(int newsock, char* ip) { Dispatch(); return false; } void Dispatch(); }; class SQLite3Result : public SQLresult { private: int currentrow; int rows; int cols; std::vector<std::string> colnames; std::vector<SQLfieldList> fieldlists; SQLfieldList emptyfieldlist; SQLfieldList* fieldlist; SQLfieldMap* fieldmap; public: SQLite3Result(Module* self, Module* to, unsigned int id) : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL) { } ~SQLite3Result() { } void AddRow(int colsnum, char **data, char **colname) { colnames.clear(); cols = colsnum; for (int i = 0; i < colsnum; i++) { fieldlists.resize(fieldlists.size()+1); colnames.push_back(colname[i]); SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true); fieldlists[rows].push_back(sf); } rows++; } void UpdateAffectedCount() { rows++; } virtual int Rows() { return rows; } virtual int Cols() { return cols; } virtual std::string ColName(int column) { if (column < (int)colnames.size()) { return colnames[column]; } else { throw SQLbadColName(); } return ""; } virtual int ColNum(const std::string &column) { for (unsigned int i = 0; i < colnames.size(); i++) { if (column == colnames[i]) return i; } throw SQLbadColName(); return 0; } virtual SQLfield GetValue(int row, int column) { if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols())) { return fieldlists[row][column]; } throw SQLbadColName(); /* XXX: We never actually get here because of the throw */ return SQLfield("",true); } virtual SQLfieldList& GetRow() { if (currentrow < rows) return fieldlists[currentrow]; else return emptyfieldlist; } virtual SQLfieldMap& GetRowMap() { /* In an effort to reduce overhead we don't actually allocate the map * until the first time it's needed...so... */ if(fieldmap) { fieldmap->clear(); } else { fieldmap = new SQLfieldMap; } if (currentrow < rows) { for (int i = 0; i < Cols(); i++) { fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i))); } currentrow++; } return *fieldmap; } virtual SQLfieldList* GetRowPtr() { fieldlist = new SQLfieldList(); if (currentrow < rows) { for (int i = 0; i < Rows(); i++) { fieldlist->push_back(fieldlists[currentrow][i]); } currentrow++; } return fieldlist; } virtual SQLfieldMap* GetRowMapPtr() { fieldmap = new SQLfieldMap(); if (currentrow < rows) { for (int i = 0; i < Cols(); i++) { fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i))); } currentrow++; } return fieldmap; } virtual void Free(SQLfieldMap* fm) { delete fm; } virtual void Free(SQLfieldList* fl) { delete fl; } }; class SQLConn : public classbase { private: ResultQueue results; InspIRCd* Instance; Module* mod; SQLhost host; sqlite3* conn; public: SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi) : Instance(SI), mod(m), host(hi) { if (OpenDB() != SQLITE_OK) { Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id); CloseDB(); } } ~SQLConn() { CloseDB(); } SQLerror Query(SQLrequest &req) { /* Pointer to the buffer we screw around with substitution in */ char* query; /* Pointer to the current end of query, where we append new stuff */ char* queryend; /* Total length of the unescaped parameters */ unsigned long paramlen; /* Total length of query, used for binary-safety in mysql_real_query */ unsigned long querylength = 0; paramlen = 0; for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) { paramlen += i->size(); } /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. * sizeofquery + (totalparamlength*2) + 1 * * The +1 is for null-terminating the string for mysql_real_escape_string */ query = new char[req.query.q.length() + (paramlen*2) + 1]; queryend = query; for(unsigned long i = 0; i < req.query.q.length(); i++) { if(req.query.q[i] == '?') { if(req.query.p.size()) { char* escaped; escaped = sqlite3_mprintf("%q", req.query.p.front().c_str()); for (char* n = escaped; *n; n++) { *queryend = *n; queryend++; } sqlite3_free(escaped); req.query.p.pop_front(); } else break; } else { *queryend = req.query.q[i]; queryend++; } querylength++; } *queryend = 0; req.query.q = query; SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id); res->dbid = host.id; res->query = req.query.q; paramlist params; params.push_back(this); params.push_back(res); char *errmsg = 0; sqlite3_update_hook(conn, QueryUpdateHook, &params); if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK) { std::string error(errmsg); sqlite3_free(errmsg); delete[] query; delete res; return SQLerror(QSEND_FAIL, error); } delete[] query; results.push_back(res); SendNotify(); return SQLerror(); } static int QueryResult(void *params, int argc, char **argv, char **azColName) { paramlist* p = (paramlist*)params; ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName); return 0; } static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid) { paramlist* p = (paramlist*)params; ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1])); } void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames) { res->AddRow(cols, data, colnames); } void AffectedReady(SQLite3Result *res) { res->UpdateAffectedCount(); } int OpenDB() { return sqlite3_open(host.host.c_str(), &conn); } void CloseDB() { sqlite3_interrupt(conn); sqlite3_close(conn); } SQLhost GetConfHost() { return host; } void SendResults() { while (results.size()) { SQLite3Result* res = results[0]; if (res->GetDest()) { res->Send(); } else { /* If the client module is unloaded partway through a query then the provider will set * the pointer to NULL. We cannot just cancel the query as the result will still come * through at some point...and it could get messy if we play with invalid pointers... */ delete res; } results.pop_front(); } } void ClearResults() { while (results.size()) { SQLite3Result* res = results[0]; delete res; results.pop_front(); } } void SendNotify() { int QueueFD; if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1) { /* crap, we're out of sockets... */ return; } insp_sockaddr addr; #ifdef IPV6 insp_aton("::1", &addr.sin6_addr); addr.sin6_family = AF_FAMILY; addr.sin6_port = htons(resultnotify->GetPort()); #else insp_inaddr ia; insp_aton("127.0.0.1", &ia); addr.sin_family = AF_FAMILY; addr.sin_addr = ia; addr.sin_port = htons(resultnotify->GetPort()); #endif if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1) { /* wtf, we cant connect to it, but we just created it! */ return; } } }; class ModuleSQLite3 : public Module { private: ConnMap connections; unsigned long currid; public: ModuleSQLite3(InspIRCd* Me) : Module::Module(Me), currid(0) { ServerInstance->UseInterface("SQLutils"); if (!ServerInstance->PublishFeature("SQL", this)) { throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'"); } resultnotify = new ResultNotifier(ServerInstance, this); ReadConf(); ServerInstance->PublishInterface("SQL", this); } virtual ~ModuleSQLite3() { ClearQueue(); ClearAllConnections(); resultnotify->SetFd(-1); resultnotify->state = I_ERROR; resultnotify->OnError(I_ERR_SOCKET); resultnotify->ClosePending = true; delete resultnotify; ServerInstance->UnpublishInterface("SQL", this); ServerInstance->UnpublishFeature("SQL"); ServerInstance->DoneWithInterface("SQLutils"); } void Implements(char* List) { List[I_OnRequest] = List[I_OnRehash] = 1; } void SendQueue() { for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { iter->second->SendResults(); } } void ClearQueue() { for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { iter->second->ClearResults(); } } bool HasHost(const SQLhost &host) { for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { if (host == iter->second->GetConfHost()) return true; } return false; } bool HostInConf(const SQLhost &h) { ConfigReader conf(ServerInstance); for(int i = 0; i < conf.Enumerate("database"); i++) { SQLhost host; host.id = conf.ReadValue("database", "id", i); host.host = conf.ReadValue("database", "hostname", i); host.port = conf.ReadInteger("database", "port", i, true); host.name = conf.ReadValue("database", "name", i); host.user = conf.ReadValue("database", "username", i); host.pass = conf.ReadValue("database", "password", i); host.ssl = conf.ReadFlag("database", "ssl", "0", i); if (h == host) return true; } return false; } void ReadConf() { ClearOldConnections(); ConfigReader conf(ServerInstance); for(int i = 0; i < conf.Enumerate("database"); i++) { SQLhost host; host.id = conf.ReadValue("database", "id", i); host.host = conf.ReadValue("database", "hostname", i); host.port = conf.ReadInteger("database", "port", i, true); host.name = conf.ReadValue("database", "name", i); host.user = conf.ReadValue("database", "username", i); host.pass = conf.ReadValue("database", "password", i); host.ssl = conf.ReadFlag("database", "ssl", "0", i); if (HasHost(host)) continue; this->AddConn(host); } } void AddConn(const SQLhost& hi) { if (HasHost(hi)) { ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str()); return; } SQLConn* newconn; newconn = new SQLConn(ServerInstance, this, hi); connections.insert(std::make_pair(hi.id, newconn)); } void ClearOldConnections() { ConnMap::iterator iter,safei; for (iter = connections.begin(); iter != connections.end(); iter++) { if (!HostInConf(iter->second->GetConfHost())) { DELETE(iter->second); safei = iter; --iter; connections.erase(safei); } } } void ClearAllConnections() { ConnMap::iterator i; while ((i = connections.begin()) != connections.end()) { connections.erase(i); DELETE(i->second); } } virtual void OnRehash(userrec* user, const std::string &parameter) { ReadConf(); } virtual char* OnRequest(Request* request) { if(strcmp(SQLREQID, request->GetId()) == 0) { SQLrequest* req = (SQLrequest*)request; ConnMap::iterator iter; if((iter = connections.find(req->dbid)) != connections.end()) { req->id = NewID(); req->error = iter->second->Query(*req); return SQLSUCCESS; } else { req->error.Id(BAD_DBID); return NULL; } } return NULL; } unsigned long NewID() { if (currid+1 == 0) currid++; return ++currid; } virtual Version GetVersion() { return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION); } }; void ResultNotifier::Dispatch() { ((ModuleSQLite3*)mod)->SendQueue(); } MODULE_INIT(ModuleSQLite3); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include <sqlite3.h>
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+
+#include "m_sqlv2.h"
+
+/* $ModDesc: sqlite3 provider */
+/* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
+/* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
+/* $ModDep: m_sqlv2.h */
+
+
+class SQLConn;
+class SQLite3Result;
+class ResultNotifier;
+
+typedef std::map<std::string, SQLConn*> ConnMap;
+typedef std::deque<classbase*> paramlist;
+typedef std::deque<SQLite3Result*> ResultQueue;
+
+ResultNotifier* resultnotify = NULL;
+
+
+class ResultNotifier : public InspSocket
+{
+ Module* mod;
+ insp_sockaddr sock_us;
+ socklen_t uslen;
+
+ public:
+ /* Create a socket on a random port. Let the tcp stack allocate us an available port */
+#ifdef IPV6
+ ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
+#else
+ ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
+#endif
+ {
+ uslen = sizeof(sock_us);
+ if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
+ {
+ throw ModuleException("Could not create random listening port on localhost");
+ }
+ }
+
+ ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
+ {
+ }
+
+ /* Using getsockname and ntohs, we can determine which port number we were allocated */
+ int GetPort()
+ {
+#ifdef IPV6
+ return ntohs(sock_us.sin6_port);
+#else
+ return ntohs(sock_us.sin_port);
+#endif
+ }
+
+ virtual int OnIncomingConnection(int newsock, char* ip)
+ {
+ Dispatch();
+ return false;
+ }
+
+ void Dispatch();
+};
+
+
+class SQLite3Result : public SQLresult
+{
+ private:
+ int currentrow;
+ int rows;
+ int cols;
+
+ std::vector<std::string> colnames;
+ std::vector<SQLfieldList> fieldlists;
+ SQLfieldList emptyfieldlist;
+
+ SQLfieldList* fieldlist;
+ SQLfieldMap* fieldmap;
+
+ public:
+ SQLite3Result(Module* self, Module* to, unsigned int id)
+ : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
+ {
+ }
+
+ ~SQLite3Result()
+ {
+ }
+
+ void AddRow(int colsnum, char **data, char **colname)
+ {
+ colnames.clear();
+ cols = colsnum;
+ for (int i = 0; i < colsnum; i++)
+ {
+ fieldlists.resize(fieldlists.size()+1);
+ colnames.push_back(colname[i]);
+ SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
+ fieldlists[rows].push_back(sf);
+ }
+ rows++;
+ }
+
+ void UpdateAffectedCount()
+ {
+ rows++;
+ }
+
+ virtual int Rows()
+ {
+ return rows;
+ }
+
+ virtual int Cols()
+ {
+ return cols;
+ }
+
+ virtual std::string ColName(int column)
+ {
+ if (column < (int)colnames.size())
+ {
+ return colnames[column];
+ }
+ else
+ {
+ throw SQLbadColName();
+ }
+ return "";
+ }
+
+ virtual int ColNum(const std::string &column)
+ {
+ for (unsigned int i = 0; i < colnames.size(); i++)
+ {
+ if (column == colnames[i])
+ return i;
+ }
+ throw SQLbadColName();
+ return 0;
+ }
+
+ virtual SQLfield GetValue(int row, int column)
+ {
+ if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
+ {
+ return fieldlists[row][column];
+ }
+
+ throw SQLbadColName();
+
+ /* XXX: We never actually get here because of the throw */
+ return SQLfield("",true);
+ }
+
+ virtual SQLfieldList& GetRow()
+ {
+ if (currentrow < rows)
+ return fieldlists[currentrow];
+ else
+ return emptyfieldlist;
+ }
+
+ virtual SQLfieldMap& GetRowMap()
+ {
+ /* In an effort to reduce overhead we don't actually allocate the map
+ * until the first time it's needed...so...
+ */
+ if(fieldmap)
+ {
+ fieldmap->clear();
+ }
+ else
+ {
+ fieldmap = new SQLfieldMap;
+ }
+
+ if (currentrow < rows)
+ {
+ for (int i = 0; i < Cols(); i++)
+ {
+ fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
+ }
+ currentrow++;
+ }
+
+ return *fieldmap;
+ }
+
+ virtual SQLfieldList* GetRowPtr()
+ {
+ fieldlist = new SQLfieldList();
+
+ if (currentrow < rows)
+ {
+ for (int i = 0; i < Rows(); i++)
+ {
+ fieldlist->push_back(fieldlists[currentrow][i]);
+ }
+ currentrow++;
+ }
+ return fieldlist;
+ }
+
+ virtual SQLfieldMap* GetRowMapPtr()
+ {
+ fieldmap = new SQLfieldMap();
+
+ if (currentrow < rows)
+ {
+ for (int i = 0; i < Cols(); i++)
+ {
+ fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
+ }
+ currentrow++;
+ }
+
+ return fieldmap;
+ }
+
+ virtual void Free(SQLfieldMap* fm)
+ {
+ delete fm;
+ }
+
+ virtual void Free(SQLfieldList* fl)
+ {
+ delete fl;
+ }
+
+
+};
+
+class SQLConn : public classbase
+{
+ private:
+ ResultQueue results;
+ InspIRCd* Instance;
+ Module* mod;
+ SQLhost host;
+ sqlite3* conn;
+
+ public:
+ SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
+ : Instance(SI), mod(m), host(hi)
+ {
+ if (OpenDB() != SQLITE_OK)
+ {
+ Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
+ CloseDB();
+ }
+ }
+
+ ~SQLConn()
+ {
+ CloseDB();
+ }
+
+ SQLerror Query(SQLrequest &req)
+ {
+ /* Pointer to the buffer we screw around with substitution in */
+ char* query;
+
+ /* Pointer to the current end of query, where we append new stuff */
+ char* queryend;
+
+ /* Total length of the unescaped parameters */
+ unsigned long paramlen;
+
+ /* Total length of query, used for binary-safety in mysql_real_query */
+ unsigned long querylength = 0;
+
+ paramlen = 0;
+ for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
+ {
+ paramlen += i->size();
+ }
+
+ /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
+ * sizeofquery + (totalparamlength*2) + 1
+ *
+ * The +1 is for null-terminating the string for mysql_real_escape_string
+ */
+ query = new char[req.query.q.length() + (paramlen*2) + 1];
+ queryend = query;
+
+ for(unsigned long i = 0; i < req.query.q.length(); i++)
+ {
+ if(req.query.q[i] == '?')
+ {
+ if(req.query.p.size())
+ {
+ char* escaped;
+ escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
+ for (char* n = escaped; *n; n++)
+ {
+ *queryend = *n;
+ queryend++;
+ }
+ sqlite3_free(escaped);
+ req.query.p.pop_front();
+ }
+ else
+ break;
+ }
+ else
+ {
+ *queryend = req.query.q[i];
+ queryend++;
+ }
+ querylength++;
+ }
+ *queryend = 0;
+ req.query.q = query;
+
+ SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
+ res->dbid = host.id;
+ res->query = req.query.q;
+ paramlist params;
+ params.push_back(this);
+ params.push_back(res);
+
+ char *errmsg = 0;
+ sqlite3_update_hook(conn, QueryUpdateHook, &params);
+ if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
+ {
+ std::string error(errmsg);
+ sqlite3_free(errmsg);
+ delete[] query;
+ delete res;
+ return SQLerror(QSEND_FAIL, error);
+ }
+ delete[] query;
+
+ results.push_back(res);
+ SendNotify();
+ return SQLerror();
+ }
+
+ static int QueryResult(void *params, int argc, char **argv, char **azColName)
+ {
+ paramlist* p = (paramlist*)params;
+ ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
+ return 0;
+ }
+
+ static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
+ {
+ paramlist* p = (paramlist*)params;
+ ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
+ }
+
+ void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
+ {
+ res->AddRow(cols, data, colnames);
+ }
+
+ void AffectedReady(SQLite3Result *res)
+ {
+ res->UpdateAffectedCount();
+ }
+
+ int OpenDB()
+ {
+ return sqlite3_open(host.host.c_str(), &conn);
+ }
+
+ void CloseDB()
+ {
+ sqlite3_interrupt(conn);
+ sqlite3_close(conn);
+ }
+
+ SQLhost GetConfHost()
+ {
+ return host;
+ }
+
+ void SendResults()
+ {
+ while (results.size())
+ {
+ SQLite3Result* res = results[0];
+ if (res->GetDest())
+ {
+ res->Send();
+ }
+ else
+ {
+ /* If the client module is unloaded partway through a query then the provider will set
+ * the pointer to NULL. We cannot just cancel the query as the result will still come
+ * through at some point...and it could get messy if we play with invalid pointers...
+ */
+ delete res;
+ }
+ results.pop_front();
+ }
+ }
+
+ void ClearResults()
+ {
+ while (results.size())
+ {
+ SQLite3Result* res = results[0];
+ delete res;
+ results.pop_front();
+ }
+ }
+
+ void SendNotify()
+ {
+ int QueueFD;
+ if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
+ {
+ /* crap, we're out of sockets... */
+ return;
+ }
+
+ insp_sockaddr addr;
+
+#ifdef IPV6
+ insp_aton("::1", &addr.sin6_addr);
+ addr.sin6_family = AF_FAMILY;
+ addr.sin6_port = htons(resultnotify->GetPort());
+#else
+ insp_inaddr ia;
+ insp_aton("127.0.0.1", &ia);
+ addr.sin_family = AF_FAMILY;
+ addr.sin_addr = ia;
+ addr.sin_port = htons(resultnotify->GetPort());
+#endif
+
+ if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
+ {
+ /* wtf, we cant connect to it, but we just created it! */
+ return;
+ }
+ }
+
+};
+
+
+class ModuleSQLite3 : public Module
+{
+ private:
+ ConnMap connections;
+ unsigned long currid;
+
+ public:
+ ModuleSQLite3(InspIRCd* Me)
+ : Module::Module(Me), currid(0)
+ {
+ ServerInstance->UseInterface("SQLutils");
+
+ if (!ServerInstance->PublishFeature("SQL", this))
+ {
+ throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
+ }
+
+ resultnotify = new ResultNotifier(ServerInstance, this);
+
+ ReadConf();
+
+ ServerInstance->PublishInterface("SQL", this);
+ }
+
+ virtual ~ModuleSQLite3()
+ {
+ ClearQueue();
+ ClearAllConnections();
+ resultnotify->SetFd(-1);
+ resultnotify->state = I_ERROR;
+ resultnotify->OnError(I_ERR_SOCKET);
+ resultnotify->ClosePending = true;
+ delete resultnotify;
+ ServerInstance->UnpublishInterface("SQL", this);
+ ServerInstance->UnpublishFeature("SQL");
+ ServerInstance->DoneWithInterface("SQLutils");
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnRequest] = List[I_OnRehash] = 1;
+ }
+
+ void SendQueue()
+ {
+ for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ iter->second->SendResults();
+ }
+ }
+
+ void ClearQueue()
+ {
+ for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ iter->second->ClearResults();
+ }
+ }
+
+ bool HasHost(const SQLhost &host)
+ {
+ for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ if (host == iter->second->GetConfHost())
+ return true;
+ }
+ return false;
+ }
+
+ bool HostInConf(const SQLhost &h)
+ {
+ ConfigReader conf(ServerInstance);
+ for(int i = 0; i < conf.Enumerate("database"); i++)
+ {
+ SQLhost host;
+ host.id = conf.ReadValue("database", "id", i);
+ host.host = conf.ReadValue("database", "hostname", i);
+ host.port = conf.ReadInteger("database", "port", i, true);
+ host.name = conf.ReadValue("database", "name", i);
+ host.user = conf.ReadValue("database", "username", i);
+ host.pass = conf.ReadValue("database", "password", i);
+ host.ssl = conf.ReadFlag("database", "ssl", "0", i);
+ if (h == host)
+ return true;
+ }
+ return false;
+ }
+
+ void ReadConf()
+ {
+ ClearOldConnections();
+
+ ConfigReader conf(ServerInstance);
+ for(int i = 0; i < conf.Enumerate("database"); i++)
+ {
+ SQLhost host;
+
+ host.id = conf.ReadValue("database", "id", i);
+ host.host = conf.ReadValue("database", "hostname", i);
+ host.port = conf.ReadInteger("database", "port", i, true);
+ host.name = conf.ReadValue("database", "name", i);
+ host.user = conf.ReadValue("database", "username", i);
+ host.pass = conf.ReadValue("database", "password", i);
+ host.ssl = conf.ReadFlag("database", "ssl", "0", i);
+
+ if (HasHost(host))
+ continue;
+
+ this->AddConn(host);
+ }
+ }
+
+ void AddConn(const SQLhost& hi)
+ {
+ if (HasHost(hi))
+ {
+ ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
+ return;
+ }
+
+ SQLConn* newconn;
+
+ newconn = new SQLConn(ServerInstance, this, hi);
+
+ connections.insert(std::make_pair(hi.id, newconn));
+ }
+
+ void ClearOldConnections()
+ {
+ ConnMap::iterator iter,safei;
+ for (iter = connections.begin(); iter != connections.end(); iter++)
+ {
+ if (!HostInConf(iter->second->GetConfHost()))
+ {
+ DELETE(iter->second);
+ safei = iter;
+ --iter;
+ connections.erase(safei);
+ }
+ }
+ }
+
+ void ClearAllConnections()
+ {
+ ConnMap::iterator i;
+ while ((i = connections.begin()) != connections.end())
+ {
+ connections.erase(i);
+ DELETE(i->second);
+ }
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ ReadConf();
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLREQID, request->GetId()) == 0)
+ {
+ SQLrequest* req = (SQLrequest*)request;
+ ConnMap::iterator iter;
+ if((iter = connections.find(req->dbid)) != connections.end())
+ {
+ req->id = NewID();
+ req->error = iter->second->Query(*req);
+ return SQLSUCCESS;
+ }
+ else
+ {
+ req->error.Id(BAD_DBID);
+ return NULL;
+ }
+ }
+ return NULL;
+ }
+
+ unsigned long NewID()
+ {
+ if (currid+1 == 0)
+ currid++;
+
+ return ++currid;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
+ }
+
+};
+
+void ResultNotifier::Dispatch()
+{
+ ((ModuleSQLite3*)mod)->SendQueue();
+}
+
+MODULE_INIT(ModuleSQLite3);
+
diff --git a/src/modules/extra/m_sqllog.cpp b/src/modules/extra/m_sqllog.cpp
index 04eb1fef1..391e4bbba 100644
--- a/src/modules/extra/m_sqllog.cpp
+++ b/src/modules/extra/m_sqllog.cpp
@@ -1 +1,310 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include "users.h" #include "channels.h" #include "modules.h" #include "configreader.h" #include "m_sqlv2.h" static Module* SQLModule; static Module* MyMod; static std::string dbid; enum LogTypes { LT_OPER = 1, LT_KILL, LT_SERVLINK, LT_XLINE, LT_CONNECT, LT_DISCONNECT, LT_FLOOD, LT_LOADMODULE }; enum QueryState { FIND_SOURCE, FIND_NICK, FIND_HOST, DONE}; class QueryInfo; std::map<unsigned long,QueryInfo*> active_queries; class QueryInfo { public: QueryState qs; unsigned long id; std::string nick; std::string source; std::string hostname; int sourceid; int nickid; int hostid; int category; time_t date; bool insert; QueryInfo(const std::string &n, const std::string &s, const std::string &h, unsigned long i, int cat) { qs = FIND_SOURCE; nick = n; source = s; hostname = h; id = i; category = cat; sourceid = nickid = hostid = -1; date = time(NULL); insert = false; } void Go(SQLresult* res) { SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "", ""); switch (qs) { case FIND_SOURCE: if (res->Rows() && sourceid == -1 && !insert) { sourceid = atoi(res->GetValue(0,0).d.c_str()); req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick); if(req.Send()) { insert = false; qs = FIND_NICK; active_queries[req.id] = this; } } else if (res->Rows() && sourceid == -1 && insert) { req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source); if(req.Send()) { insert = false; qs = FIND_SOURCE; active_queries[req.id] = this; } } else { req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')", source); if(req.Send()) { insert = true; qs = FIND_SOURCE; active_queries[req.id] = this; } } break; case FIND_NICK: if (res->Rows() && nickid == -1 && !insert) { nickid = atoi(res->GetValue(0,0).d.c_str()); req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname); if(req.Send()) { insert = false; qs = FIND_HOST; active_queries[req.id] = this; } } else if (res->Rows() && nickid == -1 && insert) { req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick); if(req.Send()) { insert = false; qs = FIND_NICK; active_queries[req.id] = this; } } else { req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')",nick); if(req.Send()) { insert = true; qs = FIND_NICK; active_queries[req.id] = this; } } break; case FIND_HOST: if (res->Rows() && hostid == -1 && !insert) { hostid = atoi(res->GetValue(0,0).d.c_str()); req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log (category_id,nick,host,source,dtime) VALUES("+ConvToStr(category)+","+ConvToStr(nickid)+","+ConvToStr(hostid)+","+ConvToStr(sourceid)+","+ConvToStr(date)+")"); if(req.Send()) { insert = true; qs = DONE; active_queries[req.id] = this; } } else if (res->Rows() && hostid == -1 && insert) { req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname); if(req.Send()) { insert = false; qs = FIND_HOST; active_queries[req.id] = this; } } else { req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_hosts (hostname) VALUES('?')", hostname); if(req.Send()) { insert = true; qs = FIND_HOST; active_queries[req.id] = this; } } break; case DONE: delete active_queries[req.id]; active_queries[req.id] = NULL; break; } } }; /* $ModDesc: Logs network-wide data to an SQL database */ class ModuleSQLLog : public Module { ConfigReader* Conf; public: ModuleSQLLog(InspIRCd* Me) : Module::Module(Me) { ServerInstance->UseInterface("SQLutils"); ServerInstance->UseInterface("SQL"); Module* SQLutils = ServerInstance->FindModule("m_sqlutils.so"); if (!SQLutils) throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so."); SQLModule = ServerInstance->FindFeature("SQL"); OnRehash(NULL,""); MyMod = this; active_queries.clear(); } virtual ~ModuleSQLLog() { ServerInstance->DoneWithInterface("SQL"); ServerInstance->DoneWithInterface("SQLutils"); } void Implements(char* List) { List[I_OnRehash] = List[I_OnOper] = List[I_OnGlobalOper] = List[I_OnKill] = 1; List[I_OnPreCommand] = List[I_OnUserConnect] = 1; List[I_OnUserQuit] = List[I_OnLoadModule] = List[I_OnRequest] = 1; } void ReadConfig() { ConfigReader Conf(ServerInstance); dbid = Conf.ReadValue("sqllog","dbid",0); // database id of a database configured in sql module } virtual void OnRehash(userrec* user, const std::string &parameter) { ReadConfig(); } virtual char* OnRequest(Request* request) { if(strcmp(SQLRESID, request->GetId()) == 0) { SQLresult* res; std::map<unsigned long, QueryInfo*>::iterator n; res = static_cast<SQLresult*>(request); n = active_queries.find(res->id); if (n != active_queries.end()) { n->second->Go(res); std::map<unsigned long, QueryInfo*>::iterator n = active_queries.find(res->id); active_queries.erase(n); } return SQLSUCCESS; } return NULL; } void AddLogEntry(int category, const std::string &nick, const std::string &host, const std::string &source) { // is the sql module loaded? If not, we don't attempt to do anything. if (!SQLModule) return; SQLrequest req = SQLreq(this, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source); if(req.Send()) { QueryInfo* i = new QueryInfo(nick, source, host, req.id, category); i->qs = FIND_SOURCE; active_queries[req.id] = i; } } virtual void OnOper(userrec* user, const std::string &opertype) { AddLogEntry(LT_OPER,user->nick,user->host,user->server); } virtual void OnGlobalOper(userrec* user) { AddLogEntry(LT_OPER,user->nick,user->host,user->server); } virtual int OnKill(userrec* source, userrec* dest, const std::string &reason) { AddLogEntry(LT_KILL,dest->nick,dest->host,source->nick); return 0; } virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line) { if ((command == "GLINE" || command == "KLINE" || command == "ELINE" || command == "ZLINE") && validated) { AddLogEntry(LT_XLINE,user->nick,command[0]+std::string(":")+std::string(parameters[0]),user->server); } return 0; } virtual void OnUserConnect(userrec* user) { AddLogEntry(LT_CONNECT,user->nick,user->host,user->server); } virtual void OnUserQuit(userrec* user, const std::string &reason, const std::string &oper_message) { AddLogEntry(LT_DISCONNECT,user->nick,user->host,user->server); } virtual void OnLoadModule(Module* mod, const std::string &name) { AddLogEntry(LT_LOADMODULE,name,ServerInstance->Config->ServerName, ServerInstance->Config->ServerName); } virtual Version GetVersion() { return Version(1,1,0,1,VF_VENDOR,API_VERSION); } }; MODULE_INIT(ModuleSQLLog); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "configreader.h"
+#include "m_sqlv2.h"
+
+static Module* SQLModule;
+static Module* MyMod;
+static std::string dbid;
+
+enum LogTypes { LT_OPER = 1, LT_KILL, LT_SERVLINK, LT_XLINE, LT_CONNECT, LT_DISCONNECT, LT_FLOOD, LT_LOADMODULE };
+
+enum QueryState { FIND_SOURCE, FIND_NICK, FIND_HOST, DONE};
+
+class QueryInfo;
+
+std::map<unsigned long,QueryInfo*> active_queries;
+
+class QueryInfo
+{
+public:
+ QueryState qs;
+ unsigned long id;
+ std::string nick;
+ std::string source;
+ std::string hostname;
+ int sourceid;
+ int nickid;
+ int hostid;
+ int category;
+ time_t date;
+ bool insert;
+
+ QueryInfo(const std::string &n, const std::string &s, const std::string &h, unsigned long i, int cat)
+ {
+ qs = FIND_SOURCE;
+ nick = n;
+ source = s;
+ hostname = h;
+ id = i;
+ category = cat;
+ sourceid = nickid = hostid = -1;
+ date = time(NULL);
+ insert = false;
+ }
+
+ void Go(SQLresult* res)
+ {
+ SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "", "");
+ switch (qs)
+ {
+ case FIND_SOURCE:
+ if (res->Rows() && sourceid == -1 && !insert)
+ {
+ sourceid = atoi(res->GetValue(0,0).d.c_str());
+ req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick);
+ if(req.Send())
+ {
+ insert = false;
+ qs = FIND_NICK;
+ active_queries[req.id] = this;
+ }
+ }
+ else if (res->Rows() && sourceid == -1 && insert)
+ {
+ req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source);
+ if(req.Send())
+ {
+ insert = false;
+ qs = FIND_SOURCE;
+ active_queries[req.id] = this;
+ }
+ }
+ else
+ {
+ req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')", source);
+ if(req.Send())
+ {
+ insert = true;
+ qs = FIND_SOURCE;
+ active_queries[req.id] = this;
+ }
+ }
+ break;
+
+ case FIND_NICK:
+ if (res->Rows() && nickid == -1 && !insert)
+ {
+ nickid = atoi(res->GetValue(0,0).d.c_str());
+ req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname);
+ if(req.Send())
+ {
+ insert = false;
+ qs = FIND_HOST;
+ active_queries[req.id] = this;
+ }
+ }
+ else if (res->Rows() && nickid == -1 && insert)
+ {
+ req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick);
+ if(req.Send())
+ {
+ insert = false;
+ qs = FIND_NICK;
+ active_queries[req.id] = this;
+ }
+ }
+ else
+ {
+ req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')",nick);
+ if(req.Send())
+ {
+ insert = true;
+ qs = FIND_NICK;
+ active_queries[req.id] = this;
+ }
+ }
+ break;
+
+ case FIND_HOST:
+ if (res->Rows() && hostid == -1 && !insert)
+ {
+ hostid = atoi(res->GetValue(0,0).d.c_str());
+ req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log (category_id,nick,host,source,dtime) VALUES("+ConvToStr(category)+","+ConvToStr(nickid)+","+ConvToStr(hostid)+","+ConvToStr(sourceid)+","+ConvToStr(date)+")");
+ if(req.Send())
+ {
+ insert = true;
+ qs = DONE;
+ active_queries[req.id] = this;
+ }
+ }
+ else if (res->Rows() && hostid == -1 && insert)
+ {
+ req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname);
+ if(req.Send())
+ {
+ insert = false;
+ qs = FIND_HOST;
+ active_queries[req.id] = this;
+ }
+ }
+ else
+ {
+ req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_hosts (hostname) VALUES('?')", hostname);
+ if(req.Send())
+ {
+ insert = true;
+ qs = FIND_HOST;
+ active_queries[req.id] = this;
+ }
+ }
+ break;
+
+ case DONE:
+ delete active_queries[req.id];
+ active_queries[req.id] = NULL;
+ break;
+ }
+ }
+};
+
+/* $ModDesc: Logs network-wide data to an SQL database */
+
+class ModuleSQLLog : public Module
+{
+ ConfigReader* Conf;
+
+ public:
+ ModuleSQLLog(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ ServerInstance->UseInterface("SQLutils");
+ ServerInstance->UseInterface("SQL");
+
+ Module* SQLutils = ServerInstance->FindModule("m_sqlutils.so");
+ if (!SQLutils)
+ throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
+
+ SQLModule = ServerInstance->FindFeature("SQL");
+
+ OnRehash(NULL,"");
+ MyMod = this;
+ active_queries.clear();
+ }
+
+ virtual ~ModuleSQLLog()
+ {
+ ServerInstance->DoneWithInterface("SQL");
+ ServerInstance->DoneWithInterface("SQLutils");
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnRehash] = List[I_OnOper] = List[I_OnGlobalOper] = List[I_OnKill] = 1;
+ List[I_OnPreCommand] = List[I_OnUserConnect] = 1;
+ List[I_OnUserQuit] = List[I_OnLoadModule] = List[I_OnRequest] = 1;
+ }
+
+ void ReadConfig()
+ {
+ ConfigReader Conf(ServerInstance);
+ dbid = Conf.ReadValue("sqllog","dbid",0); // database id of a database configured in sql module
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ ReadConfig();
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLRESID, request->GetId()) == 0)
+ {
+ SQLresult* res;
+ std::map<unsigned long, QueryInfo*>::iterator n;
+
+ res = static_cast<SQLresult*>(request);
+ n = active_queries.find(res->id);
+
+ if (n != active_queries.end())
+ {
+ n->second->Go(res);
+ std::map<unsigned long, QueryInfo*>::iterator n = active_queries.find(res->id);
+ active_queries.erase(n);
+ }
+
+ return SQLSUCCESS;
+ }
+
+ return NULL;
+ }
+
+ void AddLogEntry(int category, const std::string &nick, const std::string &host, const std::string &source)
+ {
+ // is the sql module loaded? If not, we don't attempt to do anything.
+ if (!SQLModule)
+ return;
+
+ SQLrequest req = SQLreq(this, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source);
+ if(req.Send())
+ {
+ QueryInfo* i = new QueryInfo(nick, source, host, req.id, category);
+ i->qs = FIND_SOURCE;
+ active_queries[req.id] = i;
+ }
+ }
+
+ virtual void OnOper(userrec* user, const std::string &opertype)
+ {
+ AddLogEntry(LT_OPER,user->nick,user->host,user->server);
+ }
+
+ virtual void OnGlobalOper(userrec* user)
+ {
+ AddLogEntry(LT_OPER,user->nick,user->host,user->server);
+ }
+
+ virtual int OnKill(userrec* source, userrec* dest, const std::string &reason)
+ {
+ AddLogEntry(LT_KILL,dest->nick,dest->host,source->nick);
+ return 0;
+ }
+
+ virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line)
+ {
+ if ((command == "GLINE" || command == "KLINE" || command == "ELINE" || command == "ZLINE") && validated)
+ {
+ AddLogEntry(LT_XLINE,user->nick,command[0]+std::string(":")+std::string(parameters[0]),user->server);
+ }
+ return 0;
+ }
+
+ virtual void OnUserConnect(userrec* user)
+ {
+ AddLogEntry(LT_CONNECT,user->nick,user->host,user->server);
+ }
+
+ virtual void OnUserQuit(userrec* user, const std::string &reason, const std::string &oper_message)
+ {
+ AddLogEntry(LT_DISCONNECT,user->nick,user->host,user->server);
+ }
+
+ virtual void OnLoadModule(Module* mod, const std::string &name)
+ {
+ AddLogEntry(LT_LOADMODULE,name,ServerInstance->Config->ServerName, ServerInstance->Config->ServerName);
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,1,0,1,VF_VENDOR,API_VERSION);
+ }
+
+};
+
+MODULE_INIT(ModuleSQLLog);
+
diff --git a/src/modules/extra/m_sqloper.cpp b/src/modules/extra/m_sqloper.cpp
index 4b09ac26e..520869e21 100644
--- a/src/modules/extra/m_sqloper.cpp
+++ b/src/modules/extra/m_sqloper.cpp
@@ -1 +1,283 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include "users.h" #include "channels.h" #include "modules.h" #include "configreader.h" #include "m_sqlv2.h" #include "m_sqlutils.h" #include "m_hash.h" #include "commands/cmd_oper.h" /* $ModDesc: Allows storage of oper credentials in an SQL table */ /* $ModDep: m_sqlv2.h m_sqlutils.h */ class ModuleSQLOper : public Module { Module* SQLutils; Module* HashModule; std::string databaseid; public: ModuleSQLOper(InspIRCd* Me) : Module::Module(Me) { ServerInstance->UseInterface("SQLutils"); ServerInstance->UseInterface("SQL"); ServerInstance->UseInterface("HashRequest"); /* Attempt to locate the md5 service provider, bail if we can't find it */ HashModule = ServerInstance->FindModule("m_md5.so"); if (!HashModule) throw ModuleException("Can't find m_md5.so. Please load m_md5.so before m_sqloper.so."); SQLutils = ServerInstance->FindModule("m_sqlutils.so"); if (!SQLutils) throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqloper.so."); OnRehash(NULL,""); } virtual ~ModuleSQLOper() { ServerInstance->DoneWithInterface("SQL"); ServerInstance->DoneWithInterface("SQLutils"); ServerInstance->DoneWithInterface("HashRequest"); } void Implements(char* List) { List[I_OnRequest] = List[I_OnRehash] = List[I_OnPreCommand] = 1; } virtual void OnRehash(userrec* user, const std::string &parameter) { ConfigReader Conf(ServerInstance); databaseid = Conf.ReadValue("sqloper", "dbid", 0); /* Database ID of a database configured for the service provider module */ } virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line) { if ((validated) && (command == "OPER")) { if (LookupOper(user, parameters[0], parameters[1])) { /* Returning true here just means the query is in progress, or on it's way to being * in progress. Nothing about the /oper actually being successful.. * If the oper lookup fails later, we pass the command to the original handler * for /oper by calling its Handle method directly. */ return 1; } } return 0; } bool LookupOper(userrec* user, const std::string &username, const std::string &password) { Module* target; target = ServerInstance->FindFeature("SQL"); if (target) { /* Reset hash module first back to MD5 standard state */ HashResetRequest(this, HashModule).Send(); /* Make an MD5 hash of the password for using in the query */ std::string md5_pass_hash = HashSumRequest(this, HashModule, password.c_str()).Send(); /* We generate our own MD5 sum here because some database providers (e.g. SQLite) dont have a builtin md5 function, * also hashing it in the module and only passing a remote query containing a hash is more secure. */ SQLrequest req = SQLreq(this, target, databaseid, "SELECT username, password, hostname, type FROM ircd_opers WHERE username = '?' AND password='?'", username, md5_pass_hash); if (req.Send()) { /* When we get the query response from the service provider we will be given an ID to play with, * just an ID number which is unique to this query. We need a way of associating that ID with a userrec * so we insert it into a map mapping the IDs to users. * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling * us to discard the query. */ AssociateUser(this, SQLutils, req.id, user).Send(); user->Extend("oper_user", strdup(username.c_str())); user->Extend("oper_pass", strdup(password.c_str())); return true; } else { return false; } } else { ServerInstance->Log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be able to oper up unless their o:line is statically configured"); return false; } } virtual char* OnRequest(Request* request) { if (strcmp(SQLRESID, request->GetId()) == 0) { SQLresult* res = static_cast<SQLresult*>(request); userrec* user = GetAssocUser(this, SQLutils, res->id).S().user; UnAssociate(this, SQLutils, res->id).S(); char* tried_user = NULL; char* tried_pass = NULL; user->GetExt("oper_user", tried_user); user->GetExt("oper_pass", tried_pass); if (user) { if (res->error.Id() == NO_ERROR) { if (res->Rows()) { /* We got a row in the result, this means there was a record for the oper.. * now we just need to check if their host matches, and if it does then * oper them up. * * We now (previous versions of the module didn't) support multiple SQL * rows per-oper in the same way the config file does, all rows will be tried * until one is found which matches. This is useful to define several different * hosts for a single oper. * * The for() loop works as SQLresult::GetRowMap() returns an empty map when there * are no more rows to return. */ for (SQLfieldMap& row = res->GetRowMap(); row.size(); row = res->GetRowMap()) { if (OperUser(user, row["username"].d, row["password"].d, row["hostname"].d, row["type"].d)) { /* If/when one of the rows matches, stop checking and return */ return SQLSUCCESS; } if (tried_user && tried_pass) { LoginFail(user, tried_user, tried_pass); free(tried_user); free(tried_pass); user->Shrink("oper_user"); user->Shrink("oper_pass"); } } } else { /* No rows in result, this means there was no oper line for the user, * we should have already checked the o:lines so now we need an * "insufficient awesomeness" (invalid credentials) error */ if (tried_user && tried_pass) { LoginFail(user, tried_user, tried_pass); free(tried_user); free(tried_pass); user->Shrink("oper_user"); user->Shrink("oper_pass"); } } } else { /* This one shouldn't happen, the query failed for some reason. * We have to fail the /oper request and give them the same error * as above. */ if (tried_user && tried_pass) { LoginFail(user, tried_user, tried_pass); free(tried_user); free(tried_pass); user->Shrink("oper_user"); user->Shrink("oper_pass"); } } } return SQLSUCCESS; } return NULL; } void LoginFail(userrec* user, const std::string &username, const std::string &pass) { command_t* oper_command = ServerInstance->Parser->GetHandler("OPER"); if (oper_command) { const char* params[] = { username.c_str(), pass.c_str() }; oper_command->Handle(params, 2, user); } else { ServerInstance->Log(DEBUG, "BUG: WHAT?! Why do we have no OPER command?!"); } } bool OperUser(userrec* user, const std::string &username, const std::string &password, const std::string &pattern, const std::string &type) { ConfigReader Conf(ServerInstance); for (int j = 0; j < Conf.Enumerate("type"); j++) { std::string tname = Conf.ReadValue("type","name",j); std::string hostname(user->ident); hostname.append("@").append(user->host); if ((tname == type) && OneOfMatches(hostname.c_str(), user->GetIPString(), pattern.c_str())) { /* Opertype and host match, looks like this is it. */ std::string operhost = Conf.ReadValue("type", "host", j); if (operhost.size()) user->ChangeDisplayedHost(operhost.c_str()); ServerInstance->SNO->WriteToSnoMask('o',"%s (%s@%s) is now an IRC operator of type %s", user->nick, user->ident, user->host, type.c_str()); user->WriteServ("381 %s :You are now an IRC operator of type %s", user->nick, type.c_str()); if (!user->modes[UM_OPERATOR]) user->Oper(type); return true; } } return false; } virtual Version GetVersion() { return Version(1,1,1,0,VF_VENDOR,API_VERSION); } }; MODULE_INIT(ModuleSQLOper); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "configreader.h"
+
+#include "m_sqlv2.h"
+#include "m_sqlutils.h"
+#include "m_hash.h"
+#include "commands/cmd_oper.h"
+
+/* $ModDesc: Allows storage of oper credentials in an SQL table */
+/* $ModDep: m_sqlv2.h m_sqlutils.h */
+
+class ModuleSQLOper : public Module
+{
+ Module* SQLutils;
+ Module* HashModule;
+ std::string databaseid;
+
+public:
+ ModuleSQLOper(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ ServerInstance->UseInterface("SQLutils");
+ ServerInstance->UseInterface("SQL");
+ ServerInstance->UseInterface("HashRequest");
+
+ /* Attempt to locate the md5 service provider, bail if we can't find it */
+ HashModule = ServerInstance->FindModule("m_md5.so");
+ if (!HashModule)
+ throw ModuleException("Can't find m_md5.so. Please load m_md5.so before m_sqloper.so.");
+
+ SQLutils = ServerInstance->FindModule("m_sqlutils.so");
+ if (!SQLutils)
+ throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqloper.so.");
+
+ OnRehash(NULL,"");
+ }
+
+ virtual ~ModuleSQLOper()
+ {
+ ServerInstance->DoneWithInterface("SQL");
+ ServerInstance->DoneWithInterface("SQLutils");
+ ServerInstance->DoneWithInterface("HashRequest");
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnRequest] = List[I_OnRehash] = List[I_OnPreCommand] = 1;
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ ConfigReader Conf(ServerInstance);
+
+ databaseid = Conf.ReadValue("sqloper", "dbid", 0); /* Database ID of a database configured for the service provider module */
+ }
+
+ virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line)
+ {
+ if ((validated) && (command == "OPER"))
+ {
+ if (LookupOper(user, parameters[0], parameters[1]))
+ {
+ /* Returning true here just means the query is in progress, or on it's way to being
+ * in progress. Nothing about the /oper actually being successful..
+ * If the oper lookup fails later, we pass the command to the original handler
+ * for /oper by calling its Handle method directly.
+ */
+ return 1;
+ }
+ }
+ return 0;
+ }
+
+ bool LookupOper(userrec* user, const std::string &username, const std::string &password)
+ {
+ Module* target;
+
+ target = ServerInstance->FindFeature("SQL");
+
+ if (target)
+ {
+ /* Reset hash module first back to MD5 standard state */
+ HashResetRequest(this, HashModule).Send();
+ /* Make an MD5 hash of the password for using in the query */
+ std::string md5_pass_hash = HashSumRequest(this, HashModule, password.c_str()).Send();
+
+ /* We generate our own MD5 sum here because some database providers (e.g. SQLite) dont have a builtin md5 function,
+ * also hashing it in the module and only passing a remote query containing a hash is more secure.
+ */
+
+ SQLrequest req = SQLreq(this, target, databaseid, "SELECT username, password, hostname, type FROM ircd_opers WHERE username = '?' AND password='?'", username, md5_pass_hash);
+
+ if (req.Send())
+ {
+ /* When we get the query response from the service provider we will be given an ID to play with,
+ * just an ID number which is unique to this query. We need a way of associating that ID with a userrec
+ * so we insert it into a map mapping the IDs to users.
+ * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
+ * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
+ * us to discard the query.
+ */
+ AssociateUser(this, SQLutils, req.id, user).Send();
+
+ user->Extend("oper_user", strdup(username.c_str()));
+ user->Extend("oper_pass", strdup(password.c_str()));
+
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else
+ {
+ ServerInstance->Log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be able to oper up unless their o:line is statically configured");
+ return false;
+ }
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if (strcmp(SQLRESID, request->GetId()) == 0)
+ {
+ SQLresult* res = static_cast<SQLresult*>(request);
+
+ userrec* user = GetAssocUser(this, SQLutils, res->id).S().user;
+ UnAssociate(this, SQLutils, res->id).S();
+
+ char* tried_user = NULL;
+ char* tried_pass = NULL;
+
+ user->GetExt("oper_user", tried_user);
+ user->GetExt("oper_pass", tried_pass);
+
+ if (user)
+ {
+ if (res->error.Id() == NO_ERROR)
+ {
+ if (res->Rows())
+ {
+ /* We got a row in the result, this means there was a record for the oper..
+ * now we just need to check if their host matches, and if it does then
+ * oper them up.
+ *
+ * We now (previous versions of the module didn't) support multiple SQL
+ * rows per-oper in the same way the config file does, all rows will be tried
+ * until one is found which matches. This is useful to define several different
+ * hosts for a single oper.
+ *
+ * The for() loop works as SQLresult::GetRowMap() returns an empty map when there
+ * are no more rows to return.
+ */
+
+ for (SQLfieldMap& row = res->GetRowMap(); row.size(); row = res->GetRowMap())
+ {
+ if (OperUser(user, row["username"].d, row["password"].d, row["hostname"].d, row["type"].d))
+ {
+ /* If/when one of the rows matches, stop checking and return */
+ return SQLSUCCESS;
+ }
+ if (tried_user && tried_pass)
+ {
+ LoginFail(user, tried_user, tried_pass);
+ free(tried_user);
+ free(tried_pass);
+ user->Shrink("oper_user");
+ user->Shrink("oper_pass");
+ }
+ }
+ }
+ else
+ {
+ /* No rows in result, this means there was no oper line for the user,
+ * we should have already checked the o:lines so now we need an
+ * "insufficient awesomeness" (invalid credentials) error
+ */
+ if (tried_user && tried_pass)
+ {
+ LoginFail(user, tried_user, tried_pass);
+ free(tried_user);
+ free(tried_pass);
+ user->Shrink("oper_user");
+ user->Shrink("oper_pass");
+ }
+ }
+ }
+ else
+ {
+ /* This one shouldn't happen, the query failed for some reason.
+ * We have to fail the /oper request and give them the same error
+ * as above.
+ */
+ if (tried_user && tried_pass)
+ {
+ LoginFail(user, tried_user, tried_pass);
+ free(tried_user);
+ free(tried_pass);
+ user->Shrink("oper_user");
+ user->Shrink("oper_pass");
+ }
+
+ }
+ }
+
+ return SQLSUCCESS;
+ }
+
+ return NULL;
+ }
+
+ void LoginFail(userrec* user, const std::string &username, const std::string &pass)
+ {
+ command_t* oper_command = ServerInstance->Parser->GetHandler("OPER");
+
+ if (oper_command)
+ {
+ const char* params[] = { username.c_str(), pass.c_str() };
+ oper_command->Handle(params, 2, user);
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG, "BUG: WHAT?! Why do we have no OPER command?!");
+ }
+ }
+
+ bool OperUser(userrec* user, const std::string &username, const std::string &password, const std::string &pattern, const std::string &type)
+ {
+ ConfigReader Conf(ServerInstance);
+
+ for (int j = 0; j < Conf.Enumerate("type"); j++)
+ {
+ std::string tname = Conf.ReadValue("type","name",j);
+ std::string hostname(user->ident);
+
+ hostname.append("@").append(user->host);
+
+ if ((tname == type) && OneOfMatches(hostname.c_str(), user->GetIPString(), pattern.c_str()))
+ {
+ /* Opertype and host match, looks like this is it. */
+ std::string operhost = Conf.ReadValue("type", "host", j);
+
+ if (operhost.size())
+ user->ChangeDisplayedHost(operhost.c_str());
+
+ ServerInstance->SNO->WriteToSnoMask('o',"%s (%s@%s) is now an IRC operator of type %s", user->nick, user->ident, user->host, type.c_str());
+ user->WriteServ("381 %s :You are now an IRC operator of type %s", user->nick, type.c_str());
+
+ if (!user->modes[UM_OPERATOR])
+ user->Oper(type);
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,1,1,0,VF_VENDOR,API_VERSION);
+ }
+
+};
+
+MODULE_INIT(ModuleSQLOper);
+
diff --git a/src/modules/extra/m_sqlutils.cpp b/src/modules/extra/m_sqlutils.cpp
index 6cd09252b..b470f99af 100644
--- a/src/modules/extra/m_sqlutils.cpp
+++ b/src/modules/extra/m_sqlutils.cpp
@@ -1 +1,238 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <sstream> #include <list> #include "users.h" #include "channels.h" #include "modules.h" #include "configreader.h" #include "m_sqlutils.h" /* $ModDesc: Provides some utilities to SQL client modules, such as mapping queries to users and channels */ /* $ModDep: m_sqlutils.h */ typedef std::map<unsigned long, userrec*> IdUserMap; typedef std::map<unsigned long, chanrec*> IdChanMap; typedef std::list<unsigned long> AssocIdList; class ModuleSQLutils : public Module { private: IdUserMap iduser; IdChanMap idchan; public: ModuleSQLutils(InspIRCd* Me) : Module::Module(Me) { ServerInstance->PublishInterface("SQLutils", this); } virtual ~ModuleSQLutils() { ServerInstance->UnpublishInterface("SQLutils", this); } void Implements(char* List) { List[I_OnChannelDelete] = List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnUserDisconnect] = 1; } virtual char* OnRequest(Request* request) { if(strcmp(SQLUTILAU, request->GetId()) == 0) { AssociateUser* req = (AssociateUser*)request; iduser.insert(std::make_pair(req->id, req->user)); AttachList(req->user, req->id); } else if(strcmp(SQLUTILAC, request->GetId()) == 0) { AssociateChan* req = (AssociateChan*)request; idchan.insert(std::make_pair(req->id, req->chan)); AttachList(req->chan, req->id); } else if(strcmp(SQLUTILUA, request->GetId()) == 0) { UnAssociate* req = (UnAssociate*)request; /* Unassociate a given query ID with all users and channels * it is associated with. */ DoUnAssociate(iduser, req->id); DoUnAssociate(idchan, req->id); } else if(strcmp(SQLUTILGU, request->GetId()) == 0) { GetAssocUser* req = (GetAssocUser*)request; IdUserMap::iterator iter = iduser.find(req->id); if(iter != iduser.end()) { req->user = iter->second; } } else if(strcmp(SQLUTILGC, request->GetId()) == 0) { GetAssocChan* req = (GetAssocChan*)request; IdChanMap::iterator iter = idchan.find(req->id); if(iter != idchan.end()) { req->chan = iter->second; } } return SQLUTILSUCCESS; } virtual void OnUserDisconnect(userrec* user) { /* A user is disconnecting, first we need to check if they have a list of queries associated with them. * Then, if they do, we need to erase each of them from our IdUserMap (iduser) so when the module that * associated them asks to look them up then it gets a NULL result and knows to discard the query. */ AssocIdList* il; if(user->GetExt("sqlutils_queryids", il)) { for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++) { IdUserMap::iterator iter; iter = iduser.find(*listiter); if(iter != iduser.end()) { if(iter->second != user) { ServerInstance->Log(DEBUG, "BUG: ID associated with user %s doesn't have the same userrec* associated with it in the map (erasing anyway)", user->nick); } iduser.erase(iter); } else { ServerInstance->Log(DEBUG, "BUG: user %s was extended with sqlutils_queryids but there was nothing matching in the map", user->nick); } } user->Shrink("sqlutils_queryids"); delete il; } } void AttachList(Extensible* obj, unsigned long id) { AssocIdList* il; if(!obj->GetExt("sqlutils_queryids", il)) { /* Doesn't already exist, create a new list and attach it. */ il = new AssocIdList; obj->Extend("sqlutils_queryids", il); } /* Now either way we have a valid list in il, attached. */ il->push_back(id); } void RemoveFromList(Extensible* obj, unsigned long id) { AssocIdList* il; if(obj->GetExt("sqlutils_queryids", il)) { /* Only do anything if the list exists... (which it ought to) */ il->remove(id); if(il->empty()) { /* If we just emptied it.. */ delete il; obj->Shrink("sqlutils_queryids"); } } } template <class T> void DoUnAssociate(T &map, unsigned long id) { /* For each occurence of 'id' (well, only one..it's not a multimap) in 'map' * remove it from the map, take an Extensible* value from the map and remove * 'id' from the list of query IDs attached to it. */ typename T::iterator iter = map.find(id); if(iter != map.end()) { /* Found a value indexed by 'id', call RemoveFromList() * on it with 'id' to remove 'id' from the list attached * to the value. */ RemoveFromList(iter->second, id); } } virtual void OnChannelDelete(chanrec* chan) { /* A channel is being destroyed, first we need to check if it has a list of queries associated with it. * Then, if it does, we need to erase each of them from our IdChanMap (idchan) so when the module that * associated them asks to look them up then it gets a NULL result and knows to discard the query. */ AssocIdList* il; if(chan->GetExt("sqlutils_queryids", il)) { for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++) { IdChanMap::iterator iter; iter = idchan.find(*listiter); if(iter != idchan.end()) { if(iter->second != chan) { ServerInstance->Log(DEBUG, "BUG: ID associated with channel %s doesn't have the same chanrec* associated with it in the map (erasing anyway)", chan->name); } idchan.erase(iter); } else { ServerInstance->Log(DEBUG, "BUG: channel %s was extended with sqlutils_queryids but there was nothing matching in the map", chan->name); } } chan->Shrink("sqlutils_queryids"); delete il; } } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION); } }; MODULE_INIT(ModuleSQLutils); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include <sstream>
+#include <list>
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "configreader.h"
+#include "m_sqlutils.h"
+
+/* $ModDesc: Provides some utilities to SQL client modules, such as mapping queries to users and channels */
+/* $ModDep: m_sqlutils.h */
+
+typedef std::map<unsigned long, userrec*> IdUserMap;
+typedef std::map<unsigned long, chanrec*> IdChanMap;
+typedef std::list<unsigned long> AssocIdList;
+
+class ModuleSQLutils : public Module
+{
+private:
+ IdUserMap iduser;
+ IdChanMap idchan;
+
+public:
+ ModuleSQLutils(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ ServerInstance->PublishInterface("SQLutils", this);
+ }
+
+ virtual ~ModuleSQLutils()
+ {
+ ServerInstance->UnpublishInterface("SQLutils", this);
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnChannelDelete] = List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnUserDisconnect] = 1;
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLUTILAU, request->GetId()) == 0)
+ {
+ AssociateUser* req = (AssociateUser*)request;
+
+ iduser.insert(std::make_pair(req->id, req->user));
+
+ AttachList(req->user, req->id);
+ }
+ else if(strcmp(SQLUTILAC, request->GetId()) == 0)
+ {
+ AssociateChan* req = (AssociateChan*)request;
+
+ idchan.insert(std::make_pair(req->id, req->chan));
+
+ AttachList(req->chan, req->id);
+ }
+ else if(strcmp(SQLUTILUA, request->GetId()) == 0)
+ {
+ UnAssociate* req = (UnAssociate*)request;
+
+ /* Unassociate a given query ID with all users and channels
+ * it is associated with.
+ */
+
+ DoUnAssociate(iduser, req->id);
+ DoUnAssociate(idchan, req->id);
+ }
+ else if(strcmp(SQLUTILGU, request->GetId()) == 0)
+ {
+ GetAssocUser* req = (GetAssocUser*)request;
+
+ IdUserMap::iterator iter = iduser.find(req->id);
+
+ if(iter != iduser.end())
+ {
+ req->user = iter->second;
+ }
+ }
+ else if(strcmp(SQLUTILGC, request->GetId()) == 0)
+ {
+ GetAssocChan* req = (GetAssocChan*)request;
+
+ IdChanMap::iterator iter = idchan.find(req->id);
+
+ if(iter != idchan.end())
+ {
+ req->chan = iter->second;
+ }
+ }
+
+ return SQLUTILSUCCESS;
+ }
+
+ virtual void OnUserDisconnect(userrec* user)
+ {
+ /* A user is disconnecting, first we need to check if they have a list of queries associated with them.
+ * Then, if they do, we need to erase each of them from our IdUserMap (iduser) so when the module that
+ * associated them asks to look them up then it gets a NULL result and knows to discard the query.
+ */
+ AssocIdList* il;
+
+ if(user->GetExt("sqlutils_queryids", il))
+ {
+ for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++)
+ {
+ IdUserMap::iterator iter;
+
+ iter = iduser.find(*listiter);
+
+ if(iter != iduser.end())
+ {
+ if(iter->second != user)
+ {
+ ServerInstance->Log(DEBUG, "BUG: ID associated with user %s doesn't have the same userrec* associated with it in the map (erasing anyway)", user->nick);
+ }
+
+ iduser.erase(iter);
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG, "BUG: user %s was extended with sqlutils_queryids but there was nothing matching in the map", user->nick);
+ }
+ }
+
+ user->Shrink("sqlutils_queryids");
+ delete il;
+ }
+ }
+
+ void AttachList(Extensible* obj, unsigned long id)
+ {
+ AssocIdList* il;
+
+ if(!obj->GetExt("sqlutils_queryids", il))
+ {
+ /* Doesn't already exist, create a new list and attach it. */
+ il = new AssocIdList;
+ obj->Extend("sqlutils_queryids", il);
+ }
+
+ /* Now either way we have a valid list in il, attached. */
+ il->push_back(id);
+ }
+
+ void RemoveFromList(Extensible* obj, unsigned long id)
+ {
+ AssocIdList* il;
+
+ if(obj->GetExt("sqlutils_queryids", il))
+ {
+ /* Only do anything if the list exists... (which it ought to) */
+ il->remove(id);
+
+ if(il->empty())
+ {
+ /* If we just emptied it.. */
+ delete il;
+ obj->Shrink("sqlutils_queryids");
+ }
+ }
+ }
+
+ template <class T> void DoUnAssociate(T &map, unsigned long id)
+ {
+ /* For each occurence of 'id' (well, only one..it's not a multimap) in 'map'
+ * remove it from the map, take an Extensible* value from the map and remove
+ * 'id' from the list of query IDs attached to it.
+ */
+ typename T::iterator iter = map.find(id);
+
+ if(iter != map.end())
+ {
+ /* Found a value indexed by 'id', call RemoveFromList()
+ * on it with 'id' to remove 'id' from the list attached
+ * to the value.
+ */
+ RemoveFromList(iter->second, id);
+ }
+ }
+
+ virtual void OnChannelDelete(chanrec* chan)
+ {
+ /* A channel is being destroyed, first we need to check if it has a list of queries associated with it.
+ * Then, if it does, we need to erase each of them from our IdChanMap (idchan) so when the module that
+ * associated them asks to look them up then it gets a NULL result and knows to discard the query.
+ */
+ AssocIdList* il;
+
+ if(chan->GetExt("sqlutils_queryids", il))
+ {
+ for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++)
+ {
+ IdChanMap::iterator iter;
+
+ iter = idchan.find(*listiter);
+
+ if(iter != idchan.end())
+ {
+ if(iter->second != chan)
+ {
+ ServerInstance->Log(DEBUG, "BUG: ID associated with channel %s doesn't have the same chanrec* associated with it in the map (erasing anyway)", chan->name);
+ }
+ idchan.erase(iter);
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG, "BUG: channel %s was extended with sqlutils_queryids but there was nothing matching in the map", chan->name);
+ }
+ }
+
+ chan->Shrink("sqlutils_queryids");
+ delete il;
+ }
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION);
+ }
+
+};
+
+MODULE_INIT(ModuleSQLutils);
+
diff --git a/src/modules/extra/m_sqlutils.h b/src/modules/extra/m_sqlutils.h
index cdde51f67..92fbdf5c7 100644
--- a/src/modules/extra/m_sqlutils.h
+++ b/src/modules/extra/m_sqlutils.h
@@ -1 +1,143 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #ifndef INSPIRCD_SQLUTILS #define INSPIRCD_SQLUTILS #include "modules.h" #define SQLUTILAU "SQLutil AssociateUser" #define SQLUTILAC "SQLutil AssociateChan" #define SQLUTILUA "SQLutil UnAssociate" #define SQLUTILGU "SQLutil GetAssocUser" #define SQLUTILGC "SQLutil GetAssocChan" #define SQLUTILSUCCESS "You shouldn't be reading this (success)" /** Used to associate an SQL query with a user */ class AssociateUser : public Request { public: /** Query ID */ unsigned long id; /** User */ userrec* user; AssociateUser(Module* s, Module* d, unsigned long i, userrec* u) : Request(s, d, SQLUTILAU), id(i), user(u) { } AssociateUser& S() { Send(); return *this; } }; /** Used to associate an SQL query with a channel */ class AssociateChan : public Request { public: /** Query ID */ unsigned long id; /** Channel */ chanrec* chan; AssociateChan(Module* s, Module* d, unsigned long i, chanrec* u) : Request(s, d, SQLUTILAC), id(i), chan(u) { } AssociateChan& S() { Send(); return *this; } }; /** Unassociate a user or class from an SQL query */ class UnAssociate : public Request { public: /** The query ID */ unsigned long id; UnAssociate(Module* s, Module* d, unsigned long i) : Request(s, d, SQLUTILUA), id(i) { } UnAssociate& S() { Send(); return *this; } }; /** Get the user associated with an SQL query ID */ class GetAssocUser : public Request { public: /** The query id */ unsigned long id; /** The user */ userrec* user; GetAssocUser(Module* s, Module* d, unsigned long i) : Request(s, d, SQLUTILGU), id(i), user(NULL) { } GetAssocUser& S() { Send(); return *this; } }; /** Get the channel associated with an SQL query ID */ class GetAssocChan : public Request { public: /** The query id */ unsigned long id; /** The channel */ chanrec* chan; GetAssocChan(Module* s, Module* d, unsigned long i) : Request(s, d, SQLUTILGC), id(i), chan(NULL) { } GetAssocChan& S() { Send(); return *this; } }; #endif \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#ifndef INSPIRCD_SQLUTILS
+#define INSPIRCD_SQLUTILS
+
+#include "modules.h"
+
+#define SQLUTILAU "SQLutil AssociateUser"
+#define SQLUTILAC "SQLutil AssociateChan"
+#define SQLUTILUA "SQLutil UnAssociate"
+#define SQLUTILGU "SQLutil GetAssocUser"
+#define SQLUTILGC "SQLutil GetAssocChan"
+#define SQLUTILSUCCESS "You shouldn't be reading this (success)"
+
+/** Used to associate an SQL query with a user
+ */
+class AssociateUser : public Request
+{
+public:
+ /** Query ID
+ */
+ unsigned long id;
+ /** User
+ */
+ userrec* user;
+
+ AssociateUser(Module* s, Module* d, unsigned long i, userrec* u)
+ : Request(s, d, SQLUTILAU), id(i), user(u)
+ {
+ }
+
+ AssociateUser& S()
+ {
+ Send();
+ return *this;
+ }
+};
+
+/** Used to associate an SQL query with a channel
+ */
+class AssociateChan : public Request
+{
+public:
+ /** Query ID
+ */
+ unsigned long id;
+ /** Channel
+ */
+ chanrec* chan;
+
+ AssociateChan(Module* s, Module* d, unsigned long i, chanrec* u)
+ : Request(s, d, SQLUTILAC), id(i), chan(u)
+ {
+ }
+
+ AssociateChan& S()
+ {
+ Send();
+ return *this;
+ }
+};
+
+/** Unassociate a user or class from an SQL query
+ */
+class UnAssociate : public Request
+{
+public:
+ /** The query ID
+ */
+ unsigned long id;
+
+ UnAssociate(Module* s, Module* d, unsigned long i)
+ : Request(s, d, SQLUTILUA), id(i)
+ {
+ }
+
+ UnAssociate& S()
+ {
+ Send();
+ return *this;
+ }
+};
+
+/** Get the user associated with an SQL query ID
+ */
+class GetAssocUser : public Request
+{
+public:
+ /** The query id
+ */
+ unsigned long id;
+ /** The user
+ */
+ userrec* user;
+
+ GetAssocUser(Module* s, Module* d, unsigned long i)
+ : Request(s, d, SQLUTILGU), id(i), user(NULL)
+ {
+ }
+
+ GetAssocUser& S()
+ {
+ Send();
+ return *this;
+ }
+};
+
+/** Get the channel associated with an SQL query ID
+ */
+class GetAssocChan : public Request
+{
+public:
+ /** The query id
+ */
+ unsigned long id;
+ /** The channel
+ */
+ chanrec* chan;
+
+ GetAssocChan(Module* s, Module* d, unsigned long i)
+ : Request(s, d, SQLUTILGC), id(i), chan(NULL)
+ {
+ }
+
+ GetAssocChan& S()
+ {
+ Send();
+ return *this;
+ }
+};
+
+#endif
diff --git a/src/modules/extra/m_sqlv2.h b/src/modules/extra/m_sqlv2.h
index decac4b57..c7f6edbb9 100644
--- a/src/modules/extra/m_sqlv2.h
+++ b/src/modules/extra/m_sqlv2.h
@@ -1 +1,605 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #ifndef INSPIRCD_SQLAPI_2 #define INSPIRCD_SQLAPI_2 #include <string> #include <deque> #include <map> #include "modules.h" /** SQLreq define. * This is the voodoo magic which lets us pass multiple * parameters to the SQLrequest constructor... voodoo... */ #define SQLreq(a, b, c, d, e...) SQLrequest(a, b, c, (SQLquery(d), ##e)) /** Identifiers used to identify Request types */ #define SQLREQID "SQLv2 Request" #define SQLRESID "SQLv2 Result" #define SQLSUCCESS "You shouldn't be reading this (success)" /** Defines the error types which SQLerror may be set to */ enum SQLerrorNum { NO_ERROR, BAD_DBID, BAD_CONN, QSEND_FAIL, QREPLY_FAIL }; /** A list of format parameters for an SQLquery object. */ typedef std::deque<std::string> ParamL; /** The base class of SQL exceptions */ class SQLexception : public ModuleException { public: SQLexception(const std::string &reason) : ModuleException(reason) { } SQLexception() : ModuleException("SQLv2: Undefined exception") { } }; /** An exception thrown when a bad column or row name or id is requested */ class SQLbadColName : public SQLexception { public: SQLbadColName() : SQLexception("SQLv2: Bad column name") { } }; /** SQLerror holds the error state of any SQLrequest or SQLresult. * The error string varies from database software to database software * and should be used to display informational error messages to users. */ class SQLerror : public classbase { /** The error id */ SQLerrorNum id; /** The error string */ std::string str; public: /** Initialize an SQLerror * @param i The error ID to set * @param s The (optional) error string to set */ SQLerror(SQLerrorNum i = NO_ERROR, const std::string &s = "") : id(i), str(s) { } /** Return the ID of the error */ SQLerrorNum Id() { return id; } /** Set the ID of an error * @param i The new error ID to set * @return the ID which was set */ SQLerrorNum Id(SQLerrorNum i) { id = i; return id; } /** Set the error string for an error * @param s The new error string to set */ void Str(const std::string &s) { str = s; } /** Return the error string for an error */ const char* Str() { if(str.length()) return str.c_str(); switch(id) { case NO_ERROR: return "No error"; case BAD_DBID: return "Invalid database ID"; case BAD_CONN: return "Invalid connection"; case QSEND_FAIL: return "Sending query failed"; case QREPLY_FAIL: return "Getting query result failed"; default: return "Unknown error"; } } }; /** SQLquery provides a way to represent a query string, and its parameters in a type-safe way. * C++ has no native type-safe way of having a variable number of arguments to a function, * the workaround for this isn't easy to describe simply, but in a nutshell what's really * happening when - from the above example - you do this: * * SQLrequest foo = SQLreq(this, target, "databaseid", "SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?", "Hello", "42"); * * what's actually happening is functionally this: * * SQLrequest foo = SQLreq(this, target, "databaseid", query("SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?").addparam("Hello").addparam("42")); * * with 'query()' returning a reference to an object with a 'addparam()' member function which * in turn returns a reference to that object. There are actually four ways you can create a * SQLrequest..all have their disadvantages and advantages. In the real implementations the * 'query()' function is replaced by the constructor of another class 'SQLquery' which holds * the query string and a ParamL (std::deque<std::string>) of query parameters. * This is essentially the same as the above example except 'addparam()' is replaced by operator,(). The full syntax for this method is: * * SQLrequest foo = SQLrequest(this, target, "databaseid", (SQLquery("SELECT.. ?"), parameter, parameter)); */ class SQLquery : public classbase { public: /** The query 'format string' */ std::string q; /** The query parameter list * There should be one parameter for every ? character * within the format string shown above. */ ParamL p; /** Initialize an SQLquery with a given format string only */ SQLquery(const std::string &query) : q(query) { } /** Initialize an SQLquery with a format string and parameters. * If you provide parameters, you must initialize the list yourself * if you choose to do it via this method, using std::deque::push_back(). */ SQLquery(const std::string &query, const ParamL &params) : q(query), p(params) { } /** An overloaded operator for pushing parameters onto the parameter list */ template<typename T> SQLquery& operator,(const T &foo) { p.push_back(ConvToStr(foo)); return *this; } /** An overloaded operator for pushing parameters onto the parameter list. * This has higher precedence than 'operator,' and can save on parenthesis. */ template<typename T> SQLquery& operator%(const T &foo) { p.push_back(ConvToStr(foo)); return *this; } }; /** SQLrequest is sent to the SQL API to command it to run a query and return the result. * You must instantiate this object with a valid SQLquery object and its parameters, then * send it using its Send() method to the module providing the 'SQL' feature. To find this * module, use Server::FindFeature(). */ class SQLrequest : public Request { public: /** The fully parsed and expanded query string * This is initialized from the SQLquery parameter given in the constructor. */ SQLquery query; /** The database ID to apply the request to */ std::string dbid; /** True if this is a priority query. * Priority queries may 'queue jump' in the request queue. */ bool pri; /** The query ID, assigned by the SQL api. * After your request is processed, this will * be initialized for you by the API to a valid request ID, * except in the case of an error. */ unsigned long id; /** If an error occured, error.id will be any other value than NO_ERROR. */ SQLerror error; /** Initialize an SQLrequest. * For example: * * SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors VALUES('','?')", nick); * * @param s A pointer to the sending module, where the result should be routed * @param d A pointer to the receiving module, identified as implementing the 'SQL' feature * @param databaseid The database ID to perform the query on. This must match a valid * database ID from the configuration of the SQL module. * @param q A properly initialized SQLquery object. */ SQLrequest(Module* s, Module* d, const std::string &databaseid, const SQLquery &q) : Request(s, d, SQLREQID), query(q), dbid(databaseid), pri(false), id(0) { } /** Set the priority of a request. */ void Priority(bool p = true) { pri = p; } /** Set the source of a request. You should not need to use this method. */ void SetSource(Module* mod) { source = mod; } }; /** * This class contains a field's data plus a way to determine if the field * is NULL or not without having to mess around with NULL pointers. */ class SQLfield { public: /** * The data itself */ std::string d; /** * If the field was null */ bool null; /** Initialize an SQLfield */ SQLfield(const std::string &data = "", bool n = false) : d(data), null(n) { } }; /** A list of items which make up a row of a result or table (tuple) * This does not include field names. */ typedef std::vector<SQLfield> SQLfieldList; /** A list of items which make up a row of a result or table (tuple) * This also includes the field names. */ typedef std::map<std::string, SQLfield> SQLfieldMap; /** SQLresult is a reply to a previous query. * If you send a query to the SQL api, the response will arrive at your * OnRequest method of your module at some later time, depending on the * congestion of the SQL server and complexity of the query. The ID of * this result will match the ID assigned to your original request. * SQLresult contains its own internal cursor (row counter) which is * incremented with each method call which retrieves a single row. */ class SQLresult : public Request { public: /** The original query string passed initially to the SQL API */ std::string query; /** The database ID the query was executed on */ std::string dbid; /** * The error (if any) which occured. * If an error occured the value of error.id will be any * other value than NO_ERROR. */ SQLerror error; /** * This will match query ID you were given when sending * the request at an earlier time. */ unsigned long id; /** Used by the SQL API to instantiate an SQLrequest */ SQLresult(Module* s, Module* d, unsigned long i) : Request(s, d, SQLRESID), id(i) { } /** * Return the number of rows in the result * Note that if you have perfomed an INSERT * or UPDATE query or other query which will * not return rows, this will return the * number of affected rows, and SQLresult::Cols() * will contain 0. In this case you SHOULD NEVER * access any of the result set rows, as there arent any! * @returns Number of rows in the result set. */ virtual int Rows() = 0; /** * Return the number of columns in the result. * If you performed an UPDATE or INSERT which * does not return a dataset, this value will * be 0. * @returns Number of columns in the result set. */ virtual int Cols() = 0; /** * Get a string name of the column by an index number * @param column The id number of a column * @returns The column name associated with the given ID */ virtual std::string ColName(int column) = 0; /** * Get an index number for a column from a string name. * An exception of type SQLbadColName will be thrown if * the name given is invalid. * @param column The column name to get the ID of * @returns The ID number of the column provided */ virtual int ColNum(const std::string &column) = 0; /** * Get a string value in a given row and column * This does not effect the internal cursor. * @returns The value stored at [row,column] in the table */ virtual SQLfield GetValue(int row, int column) = 0; /** * Return a list of values in a row, this should * increment an internal counter so you can repeatedly * call it until it returns an empty vector. * This returns a reference to an internal object, * the same object is used for all calls to this function * and therefore the return value is only valid until * you call this function again. It is also invalid if * the SQLresult object is destroyed. * The internal cursor (row counter) is incremented by one. * @returns A reference to the current row's SQLfieldList */ virtual SQLfieldList& GetRow() = 0; /** * As above, but return a map indexed by key name. * The internal cursor (row counter) is incremented by one. * @returns A reference to the current row's SQLfieldMap */ virtual SQLfieldMap& GetRowMap() = 0; /** * Like GetRow(), but returns a pointer to a dynamically * allocated object which must be explicitly freed. For * portability reasons this must be freed with SQLresult::Free() * The internal cursor (row counter) is incremented by one. * @returns A newly-allocated SQLfieldList */ virtual SQLfieldList* GetRowPtr() = 0; /** * As above, but return a map indexed by key name * The internal cursor (row counter) is incremented by one. * @returns A newly-allocated SQLfieldMap */ virtual SQLfieldMap* GetRowMapPtr() = 0; /** * Overloaded function for freeing the lists and maps * returned by GetRowPtr or GetRowMapPtr. * @param fm The SQLfieldMap to free */ virtual void Free(SQLfieldMap* fm) = 0; /** * Overloaded function for freeing the lists and maps * returned by GetRowPtr or GetRowMapPtr. * @param fl The SQLfieldList to free */ virtual void Free(SQLfieldList* fl) = 0; }; /** SQLHost represents a <database> config line and is useful * for storing in a map and iterating on rehash to see which * <database> tags was added/removed/unchanged. */ class SQLhost { public: std::string id; /* Database handle id */ std::string host; /* Database server hostname */ std::string ip; /* resolved IP, needed for at least pgsql.so */ unsigned int port; /* Database server port */ std::string name; /* Database name */ std::string user; /* Database username */ std::string pass; /* Database password */ bool ssl; /* If we should require SSL */ SQLhost() { } SQLhost(const std::string& i, const std::string& h, unsigned int p, const std::string& n, const std::string& u, const std::string& pa, bool s) : id(i), host(h), port(p), name(n), user(u), pass(pa), ssl(s) { } /** Overload this to return a correct Data source Name (DSN) for * the current SQL module. */ std::string GetDSN(); }; /** Overload operator== for two SQLhost objects for easy comparison. */ bool operator== (const SQLhost& l, const SQLhost& r) { return (l.id == r.id && l.host == r.host && l.port == r.port && l.name == r.name && l.user == l.user && l.pass == r.pass && l.ssl == r.ssl); } /** QueryQueue, a queue of queries waiting to be executed. * This maintains two queues internally, one for 'priority' * queries and one for less important ones. Each queue has * new queries appended to it and ones to execute are popped * off the front. This keeps them flowing round nicely and no * query should ever get 'stuck' for too long. If there are * queries in the priority queue they will be executed first, * 'unimportant' queries will only be executed when the * priority queue is empty. * * We store lists of SQLrequest's here, by value as we want to avoid storing * any data allocated inside the client module (in case that module is unloaded * while the query is in progress). * * Because we want to work on the current SQLrequest in-situ, we need a way * of accessing the request we are currently processing, QueryQueue::front(), * but that call needs to always return the same request until that request * is removed from the queue, this is what the 'which' variable is. New queries are * always added to the back of one of the two queues, but if when front() * is first called then the priority queue is empty then front() will return * a query from the normal queue, but if a query is then added to the priority * queue then front() must continue to return the front of the *normal* queue * until pop() is called. */ class QueryQueue : public classbase { private: typedef std::deque<SQLrequest> ReqDeque; ReqDeque priority; /* The priority queue */ ReqDeque normal; /* The 'normal' queue */ enum { PRI, NOR, NON } which; /* Which queue the currently active element is at the front of */ public: QueryQueue() : which(NON) { } void push(const SQLrequest &q) { if(q.pri) priority.push_back(q); else normal.push_back(q); } void pop() { if((which == PRI) && priority.size()) { priority.pop_front(); } else if((which == NOR) && normal.size()) { normal.pop_front(); } /* Reset this */ which = NON; /* Silently do nothing if there was no element to pop() */ } SQLrequest& front() { switch(which) { case PRI: return priority.front(); case NOR: return normal.front(); default: if(priority.size()) { which = PRI; return priority.front(); } if(normal.size()) { which = NOR; return normal.front(); } /* This will probably result in a segfault, * but the caller should have checked totalsize() * first so..meh - moron :p */ return priority.front(); } } std::pair<int, int> size() { return std::make_pair(priority.size(), normal.size()); } int totalsize() { return priority.size() + normal.size(); } void PurgeModule(Module* mod) { DoPurgeModule(mod, priority); DoPurgeModule(mod, normal); } private: void DoPurgeModule(Module* mod, ReqDeque& q) { for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++) { if(iter->GetSource() == mod) { if(iter->id == front().id) { /* It's the currently active query.. :x */ iter->SetSource(NULL); } else { /* It hasn't been executed yet..just remove it */ iter = q.erase(iter); } } } } }; #endif \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#ifndef INSPIRCD_SQLAPI_2
+#define INSPIRCD_SQLAPI_2
+
+#include <string>
+#include <deque>
+#include <map>
+#include "modules.h"
+
+/** SQLreq define.
+ * This is the voodoo magic which lets us pass multiple
+ * parameters to the SQLrequest constructor... voodoo...
+ */
+#define SQLreq(a, b, c, d, e...) SQLrequest(a, b, c, (SQLquery(d), ##e))
+
+/** Identifiers used to identify Request types
+ */
+#define SQLREQID "SQLv2 Request"
+#define SQLRESID "SQLv2 Result"
+#define SQLSUCCESS "You shouldn't be reading this (success)"
+
+/** Defines the error types which SQLerror may be set to
+ */
+enum SQLerrorNum { NO_ERROR, BAD_DBID, BAD_CONN, QSEND_FAIL, QREPLY_FAIL };
+
+/** A list of format parameters for an SQLquery object.
+ */
+typedef std::deque<std::string> ParamL;
+
+/** The base class of SQL exceptions
+ */
+class SQLexception : public ModuleException
+{
+ public:
+ SQLexception(const std::string &reason) : ModuleException(reason)
+ {
+ }
+
+ SQLexception() : ModuleException("SQLv2: Undefined exception")
+ {
+ }
+};
+
+/** An exception thrown when a bad column or row name or id is requested
+ */
+class SQLbadColName : public SQLexception
+{
+public:
+ SQLbadColName() : SQLexception("SQLv2: Bad column name")
+ {
+ }
+};
+
+/** SQLerror holds the error state of any SQLrequest or SQLresult.
+ * The error string varies from database software to database software
+ * and should be used to display informational error messages to users.
+ */
+class SQLerror : public classbase
+{
+ /** The error id
+ */
+ SQLerrorNum id;
+ /** The error string
+ */
+ std::string str;
+public:
+ /** Initialize an SQLerror
+ * @param i The error ID to set
+ * @param s The (optional) error string to set
+ */
+ SQLerror(SQLerrorNum i = NO_ERROR, const std::string &s = "")
+ : id(i), str(s)
+ {
+ }
+
+ /** Return the ID of the error
+ */
+ SQLerrorNum Id()
+ {
+ return id;
+ }
+
+ /** Set the ID of an error
+ * @param i The new error ID to set
+ * @return the ID which was set
+ */
+ SQLerrorNum Id(SQLerrorNum i)
+ {
+ id = i;
+ return id;
+ }
+
+ /** Set the error string for an error
+ * @param s The new error string to set
+ */
+ void Str(const std::string &s)
+ {
+ str = s;
+ }
+
+ /** Return the error string for an error
+ */
+ const char* Str()
+ {
+ if(str.length())
+ return str.c_str();
+
+ switch(id)
+ {
+ case NO_ERROR:
+ return "No error";
+ case BAD_DBID:
+ return "Invalid database ID";
+ case BAD_CONN:
+ return "Invalid connection";
+ case QSEND_FAIL:
+ return "Sending query failed";
+ case QREPLY_FAIL:
+ return "Getting query result failed";
+ default:
+ return "Unknown error";
+ }
+ }
+};
+
+/** SQLquery provides a way to represent a query string, and its parameters in a type-safe way.
+ * C++ has no native type-safe way of having a variable number of arguments to a function,
+ * the workaround for this isn't easy to describe simply, but in a nutshell what's really
+ * happening when - from the above example - you do this:
+ *
+ * SQLrequest foo = SQLreq(this, target, "databaseid", "SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?", "Hello", "42");
+ *
+ * what's actually happening is functionally this:
+ *
+ * SQLrequest foo = SQLreq(this, target, "databaseid", query("SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?").addparam("Hello").addparam("42"));
+ *
+ * with 'query()' returning a reference to an object with a 'addparam()' member function which
+ * in turn returns a reference to that object. There are actually four ways you can create a
+ * SQLrequest..all have their disadvantages and advantages. In the real implementations the
+ * 'query()' function is replaced by the constructor of another class 'SQLquery' which holds
+ * the query string and a ParamL (std::deque<std::string>) of query parameters.
+ * This is essentially the same as the above example except 'addparam()' is replaced by operator,(). The full syntax for this method is:
+ *
+ * SQLrequest foo = SQLrequest(this, target, "databaseid", (SQLquery("SELECT.. ?"), parameter, parameter));
+ */
+class SQLquery : public classbase
+{
+public:
+ /** The query 'format string'
+ */
+ std::string q;
+ /** The query parameter list
+ * There should be one parameter for every ? character
+ * within the format string shown above.
+ */
+ ParamL p;
+
+ /** Initialize an SQLquery with a given format string only
+ */
+ SQLquery(const std::string &query)
+ : q(query)
+ {
+ }
+
+ /** Initialize an SQLquery with a format string and parameters.
+ * If you provide parameters, you must initialize the list yourself
+ * if you choose to do it via this method, using std::deque::push_back().
+ */
+ SQLquery(const std::string &query, const ParamL &params)
+ : q(query), p(params)
+ {
+ }
+
+ /** An overloaded operator for pushing parameters onto the parameter list
+ */
+ template<typename T> SQLquery& operator,(const T &foo)
+ {
+ p.push_back(ConvToStr(foo));
+ return *this;
+ }
+
+ /** An overloaded operator for pushing parameters onto the parameter list.
+ * This has higher precedence than 'operator,' and can save on parenthesis.
+ */
+ template<typename T> SQLquery& operator%(const T &foo)
+ {
+ p.push_back(ConvToStr(foo));
+ return *this;
+ }
+};
+
+/** SQLrequest is sent to the SQL API to command it to run a query and return the result.
+ * You must instantiate this object with a valid SQLquery object and its parameters, then
+ * send it using its Send() method to the module providing the 'SQL' feature. To find this
+ * module, use Server::FindFeature().
+ */
+class SQLrequest : public Request
+{
+public:
+ /** The fully parsed and expanded query string
+ * This is initialized from the SQLquery parameter given in the constructor.
+ */
+ SQLquery query;
+ /** The database ID to apply the request to
+ */
+ std::string dbid;
+ /** True if this is a priority query.
+ * Priority queries may 'queue jump' in the request queue.
+ */
+ bool pri;
+ /** The query ID, assigned by the SQL api.
+ * After your request is processed, this will
+ * be initialized for you by the API to a valid request ID,
+ * except in the case of an error.
+ */
+ unsigned long id;
+ /** If an error occured, error.id will be any other value than NO_ERROR.
+ */
+ SQLerror error;
+
+ /** Initialize an SQLrequest.
+ * For example:
+ *
+ * SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors VALUES('','?')", nick);
+ *
+ * @param s A pointer to the sending module, where the result should be routed
+ * @param d A pointer to the receiving module, identified as implementing the 'SQL' feature
+ * @param databaseid The database ID to perform the query on. This must match a valid
+ * database ID from the configuration of the SQL module.
+ * @param q A properly initialized SQLquery object.
+ */
+ SQLrequest(Module* s, Module* d, const std::string &databaseid, const SQLquery &q)
+ : Request(s, d, SQLREQID), query(q), dbid(databaseid), pri(false), id(0)
+ {
+ }
+
+ /** Set the priority of a request.
+ */
+ void Priority(bool p = true)
+ {
+ pri = p;
+ }
+
+ /** Set the source of a request. You should not need to use this method.
+ */
+ void SetSource(Module* mod)
+ {
+ source = mod;
+ }
+};
+
+/**
+ * This class contains a field's data plus a way to determine if the field
+ * is NULL or not without having to mess around with NULL pointers.
+ */
+class SQLfield
+{
+public:
+ /**
+ * The data itself
+ */
+ std::string d;
+
+ /**
+ * If the field was null
+ */
+ bool null;
+
+ /** Initialize an SQLfield
+ */
+ SQLfield(const std::string &data = "", bool n = false)
+ : d(data), null(n)
+ {
+
+ }
+};
+
+/** A list of items which make up a row of a result or table (tuple)
+ * This does not include field names.
+ */
+typedef std::vector<SQLfield> SQLfieldList;
+/** A list of items which make up a row of a result or table (tuple)
+ * This also includes the field names.
+ */
+typedef std::map<std::string, SQLfield> SQLfieldMap;
+
+/** SQLresult is a reply to a previous query.
+ * If you send a query to the SQL api, the response will arrive at your
+ * OnRequest method of your module at some later time, depending on the
+ * congestion of the SQL server and complexity of the query. The ID of
+ * this result will match the ID assigned to your original request.
+ * SQLresult contains its own internal cursor (row counter) which is
+ * incremented with each method call which retrieves a single row.
+ */
+class SQLresult : public Request
+{
+public:
+ /** The original query string passed initially to the SQL API
+ */
+ std::string query;
+ /** The database ID the query was executed on
+ */
+ std::string dbid;
+ /**
+ * The error (if any) which occured.
+ * If an error occured the value of error.id will be any
+ * other value than NO_ERROR.
+ */
+ SQLerror error;
+ /**
+ * This will match query ID you were given when sending
+ * the request at an earlier time.
+ */
+ unsigned long id;
+
+ /** Used by the SQL API to instantiate an SQLrequest
+ */
+ SQLresult(Module* s, Module* d, unsigned long i)
+ : Request(s, d, SQLRESID), id(i)
+ {
+ }
+
+ /**
+ * Return the number of rows in the result
+ * Note that if you have perfomed an INSERT
+ * or UPDATE query or other query which will
+ * not return rows, this will return the
+ * number of affected rows, and SQLresult::Cols()
+ * will contain 0. In this case you SHOULD NEVER
+ * access any of the result set rows, as there arent any!
+ * @returns Number of rows in the result set.
+ */
+ virtual int Rows() = 0;
+
+ /**
+ * Return the number of columns in the result.
+ * If you performed an UPDATE or INSERT which
+ * does not return a dataset, this value will
+ * be 0.
+ * @returns Number of columns in the result set.
+ */
+ virtual int Cols() = 0;
+
+ /**
+ * Get a string name of the column by an index number
+ * @param column The id number of a column
+ * @returns The column name associated with the given ID
+ */
+ virtual std::string ColName(int column) = 0;
+
+ /**
+ * Get an index number for a column from a string name.
+ * An exception of type SQLbadColName will be thrown if
+ * the name given is invalid.
+ * @param column The column name to get the ID of
+ * @returns The ID number of the column provided
+ */
+ virtual int ColNum(const std::string &column) = 0;
+
+ /**
+ * Get a string value in a given row and column
+ * This does not effect the internal cursor.
+ * @returns The value stored at [row,column] in the table
+ */
+ virtual SQLfield GetValue(int row, int column) = 0;
+
+ /**
+ * Return a list of values in a row, this should
+ * increment an internal counter so you can repeatedly
+ * call it until it returns an empty vector.
+ * This returns a reference to an internal object,
+ * the same object is used for all calls to this function
+ * and therefore the return value is only valid until
+ * you call this function again. It is also invalid if
+ * the SQLresult object is destroyed.
+ * The internal cursor (row counter) is incremented by one.
+ * @returns A reference to the current row's SQLfieldList
+ */
+ virtual SQLfieldList& GetRow() = 0;
+
+ /**
+ * As above, but return a map indexed by key name.
+ * The internal cursor (row counter) is incremented by one.
+ * @returns A reference to the current row's SQLfieldMap
+ */
+ virtual SQLfieldMap& GetRowMap() = 0;
+
+ /**
+ * Like GetRow(), but returns a pointer to a dynamically
+ * allocated object which must be explicitly freed. For
+ * portability reasons this must be freed with SQLresult::Free()
+ * The internal cursor (row counter) is incremented by one.
+ * @returns A newly-allocated SQLfieldList
+ */
+ virtual SQLfieldList* GetRowPtr() = 0;
+
+ /**
+ * As above, but return a map indexed by key name
+ * The internal cursor (row counter) is incremented by one.
+ * @returns A newly-allocated SQLfieldMap
+ */
+ virtual SQLfieldMap* GetRowMapPtr() = 0;
+
+ /**
+ * Overloaded function for freeing the lists and maps
+ * returned by GetRowPtr or GetRowMapPtr.
+ * @param fm The SQLfieldMap to free
+ */
+ virtual void Free(SQLfieldMap* fm) = 0;
+
+ /**
+ * Overloaded function for freeing the lists and maps
+ * returned by GetRowPtr or GetRowMapPtr.
+ * @param fl The SQLfieldList to free
+ */
+ virtual void Free(SQLfieldList* fl) = 0;
+};
+
+
+/** SQLHost represents a <database> config line and is useful
+ * for storing in a map and iterating on rehash to see which
+ * <database> tags was added/removed/unchanged.
+ */
+class SQLhost
+{
+ public:
+ std::string id; /* Database handle id */
+ std::string host; /* Database server hostname */
+ std::string ip; /* resolved IP, needed for at least pgsql.so */
+ unsigned int port; /* Database server port */
+ std::string name; /* Database name */
+ std::string user; /* Database username */
+ std::string pass; /* Database password */
+ bool ssl; /* If we should require SSL */
+
+ SQLhost()
+ {
+ }
+
+ SQLhost(const std::string& i, const std::string& h, unsigned int p, const std::string& n, const std::string& u, const std::string& pa, bool s)
+ : id(i), host(h), port(p), name(n), user(u), pass(pa), ssl(s)
+ {
+ }
+
+ /** Overload this to return a correct Data source Name (DSN) for
+ * the current SQL module.
+ */
+ std::string GetDSN();
+};
+
+/** Overload operator== for two SQLhost objects for easy comparison.
+ */
+bool operator== (const SQLhost& l, const SQLhost& r)
+{
+ return (l.id == r.id && l.host == r.host && l.port == r.port && l.name == r.name && l.user == l.user && l.pass == r.pass && l.ssl == r.ssl);
+}
+
+
+/** QueryQueue, a queue of queries waiting to be executed.
+ * This maintains two queues internally, one for 'priority'
+ * queries and one for less important ones. Each queue has
+ * new queries appended to it and ones to execute are popped
+ * off the front. This keeps them flowing round nicely and no
+ * query should ever get 'stuck' for too long. If there are
+ * queries in the priority queue they will be executed first,
+ * 'unimportant' queries will only be executed when the
+ * priority queue is empty.
+ *
+ * We store lists of SQLrequest's here, by value as we want to avoid storing
+ * any data allocated inside the client module (in case that module is unloaded
+ * while the query is in progress).
+ *
+ * Because we want to work on the current SQLrequest in-situ, we need a way
+ * of accessing the request we are currently processing, QueryQueue::front(),
+ * but that call needs to always return the same request until that request
+ * is removed from the queue, this is what the 'which' variable is. New queries are
+ * always added to the back of one of the two queues, but if when front()
+ * is first called then the priority queue is empty then front() will return
+ * a query from the normal queue, but if a query is then added to the priority
+ * queue then front() must continue to return the front of the *normal* queue
+ * until pop() is called.
+ */
+
+class QueryQueue : public classbase
+{
+private:
+ typedef std::deque<SQLrequest> ReqDeque;
+
+ ReqDeque priority; /* The priority queue */
+ ReqDeque normal; /* The 'normal' queue */
+ enum { PRI, NOR, NON } which; /* Which queue the currently active element is at the front of */
+
+public:
+ QueryQueue()
+ : which(NON)
+ {
+ }
+
+ void push(const SQLrequest &q)
+ {
+ if(q.pri)
+ priority.push_back(q);
+ else
+ normal.push_back(q);
+ }
+
+ void pop()
+ {
+ if((which == PRI) && priority.size())
+ {
+ priority.pop_front();
+ }
+ else if((which == NOR) && normal.size())
+ {
+ normal.pop_front();
+ }
+
+ /* Reset this */
+ which = NON;
+
+ /* Silently do nothing if there was no element to pop() */
+ }
+
+ SQLrequest& front()
+ {
+ switch(which)
+ {
+ case PRI:
+ return priority.front();
+ case NOR:
+ return normal.front();
+ default:
+ if(priority.size())
+ {
+ which = PRI;
+ return priority.front();
+ }
+
+ if(normal.size())
+ {
+ which = NOR;
+ return normal.front();
+ }
+
+ /* This will probably result in a segfault,
+ * but the caller should have checked totalsize()
+ * first so..meh - moron :p
+ */
+
+ return priority.front();
+ }
+ }
+
+ std::pair<int, int> size()
+ {
+ return std::make_pair(priority.size(), normal.size());
+ }
+
+ int totalsize()
+ {
+ return priority.size() + normal.size();
+ }
+
+ void PurgeModule(Module* mod)
+ {
+ DoPurgeModule(mod, priority);
+ DoPurgeModule(mod, normal);
+ }
+
+private:
+ void DoPurgeModule(Module* mod, ReqDeque& q)
+ {
+ for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
+ {
+ if(iter->GetSource() == mod)
+ {
+ if(iter->id == front().id)
+ {
+ /* It's the currently active query.. :x */
+ iter->SetSource(NULL);
+ }
+ else
+ {
+ /* It hasn't been executed yet..just remove it */
+ iter = q.erase(iter);
+ }
+ }
+ }
+ }
+};
+
+
+#endif
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp
index 037d2cf72..fd8b12d32 100644
--- a/src/modules/extra/m_ssl_gnutls.cpp
+++ b/src/modules/extra/m_ssl_gnutls.cpp
@@ -1 +1,843 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <gnutls/gnutls.h> #include <gnutls/x509.h> #include "inspircd_config.h" #include "configreader.h" #include "users.h" #include "channels.h" #include "modules.h" #include "socket.h" #include "hashcomp.h" #include "transport.h" #ifdef WINDOWS #pragma comment(lib, "libgnutls-13.lib") #undef MAX_DESCRIPTORS #define MAX_DESCRIPTORS 10000 #endif /* $ModDesc: Provides SSL support for clients */ /* $CompileFlags: exec("libgnutls-config --cflags") */ /* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */ /* $ModDep: transport.h */ enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED }; bool isin(int port, const std::vector<int> &portlist) { for(unsigned int i = 0; i < portlist.size(); i++) if(portlist[i] == port) return true; return false; } /** Represents an SSL user's extra data */ class issl_session : public classbase { public: gnutls_session_t sess; issl_status status; std::string outbuf; int inbufoffset; char* inbuf; int fd; }; class ModuleSSLGnuTLS : public Module { ConfigReader* Conf; char* dummy; std::vector<int> listenports; int inbufsize; issl_session sessions[MAX_DESCRIPTORS]; gnutls_certificate_credentials x509_cred; gnutls_dh_params dh_params; std::string keyfile; std::string certfile; std::string cafile; std::string crlfile; std::string sslports; int dh_bits; int clientactive; public: ModuleSSLGnuTLS(InspIRCd* Me) : Module(Me) { ServerInstance->PublishInterface("InspSocketHook", this); // Not rehashable...because I cba to reduce all the sizes of existing buffers. inbufsize = ServerInstance->Config->NetBufferSize; gnutls_global_init(); // This must be called once in the program if(gnutls_certificate_allocate_credentials(&x509_cred) != 0) ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials"); // Guessing return meaning if(gnutls_dh_params_init(&dh_params) < 0) ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters"); // Needs the flag as it ignores a plain /rehash OnRehash(NULL,"ssl"); // Void return, guess we assume success gnutls_certificate_set_dh_params(x509_cred, dh_params); } virtual void OnRehash(userrec* user, const std::string &param) { if(param != "ssl") return; Conf = new ConfigReader(ServerInstance); for(unsigned int i = 0; i < listenports.size(); i++) { ServerInstance->Config->DelIOHook(listenports[i]); } listenports.clear(); clientactive = 0; sslports.clear(); for(int i = 0; i < Conf->Enumerate("bind"); i++) { // For each <bind> tag std::string x = Conf->ReadValue("bind", "type", i); if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "gnutls")) { // Get the port we're meant to be listening on with SSL std::string port = Conf->ReadValue("bind", "port", i); irc::portparser portrange(port, false); long portno = -1; while ((portno = portrange.GetToken())) { clientactive++; try { if (ServerInstance->Config->AddIOHook(portno, this)) { listenports.push_back(portno); for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++) if (ServerInstance->Config->ports[i]->GetPort() == portno) ServerInstance->Config->ports[i]->SetDescription("ssl"); ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %d", portno); sslports.append("*:").append(ConvToStr(portno)).append(";"); } else { ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno); } } catch (ModuleException &e) { ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason()); } } } } std::string confdir(ServerInstance->ConfigFileName); // +1 so we the path ends with a / confdir = confdir.substr(0, confdir.find_last_of('/') + 1); cafile = Conf->ReadValue("gnutls", "cafile", 0); crlfile = Conf->ReadValue("gnutls", "crlfile", 0); certfile = Conf->ReadValue("gnutls", "certfile", 0); keyfile = Conf->ReadValue("gnutls", "keyfile", 0); dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false); // Set all the default values needed. if (cafile.empty()) cafile = "ca.pem"; if (crlfile.empty()) crlfile = "crl.pem"; if (certfile.empty()) certfile = "cert.pem"; if (keyfile.empty()) keyfile = "key.pem"; if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096)) dh_bits = 1024; // Prepend relative paths with the path to the config directory. if(cafile[0] != '/') cafile = confdir + cafile; if(crlfile[0] != '/') crlfile = confdir + crlfile; if(certfile[0] != '/') certfile = confdir + certfile; if(keyfile[0] != '/') keyfile = confdir + keyfile; int ret; if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret)); if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret)); if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) { // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException throw ModuleException("Unable to load GnuTLS server certificate: " + std::string(gnutls_strerror(ret))); } // This may be on a large (once a day or week) timer eventually. GenerateDHParams(); DELETE(Conf); } void GenerateDHParams() { // Generate Diffie Hellman parameters - for use with DHE // kx algorithms. These should be discarded and regenerated // once a day, once a week or once a month. Depending on the // security requirements. int ret; if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0) ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret)); } virtual ~ModuleSSLGnuTLS() { gnutls_dh_params_deinit(dh_params); gnutls_certificate_free_credentials(x509_cred); gnutls_global_deinit(); } virtual void OnCleanup(int target_type, void* item) { if(target_type == TYPE_USER) { userrec* user = (userrec*)item; if(user->GetExt("ssl", dummy) && isin(user->GetPort(), listenports)) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading"); } if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports)) { ssl_cert* tofree; user->GetExt("ssl_cert", tofree); delete tofree; user->Shrink("ssl_cert"); } } } virtual void OnUnloadModule(Module* mod, const std::string &name) { if(mod == this) { for(unsigned int i = 0; i < listenports.size(); i++) { ServerInstance->Config->DelIOHook(listenports[i]); for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++) if (ServerInstance->Config->ports[j]->GetPort() == listenports[i]) ServerInstance->Config->ports[j]->SetDescription("plaintext"); } } } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); } void Implements(char* List) { List[I_On005Numeric] = List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = 1; List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1; } virtual void On005Numeric(std::string &output) { output.append(" SSL=" + sslports); } virtual char* OnRequest(Request* request) { ISHRequest* ISR = (ISHRequest*)request; if (strcmp("IS_NAME", request->GetId()) == 0) { return "gnutls"; } else if (strcmp("IS_HOOK", request->GetId()) == 0) { char* ret = "OK"; try { ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; } catch (ModuleException &e) { return NULL; } return ret; } else if (strcmp("IS_UNHOOK", request->GetId()) == 0) { return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; } else if (strcmp("IS_HSDONE", request->GetId()) == 0) { if (ISR->Sock->GetFd() < 0) return (char*)"OK"; issl_session* session = &sessions[ISR->Sock->GetFd()]; return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : (char*)"OK"; } else if (strcmp("IS_ATTACH", request->GetId()) == 0) { if (ISR->Sock->GetFd() > -1) { issl_session* session = &sessions[ISR->Sock->GetFd()]; if (session->sess) { if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock)) { VerifyCertificate(session, (InspSocket*)ISR->Sock); return "OK"; } } } } return NULL; } virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport) { issl_session* session = &sessions[fd]; session->fd = fd; session->inbuf = new char[inbufsize]; session->inbufoffset = 0; gnutls_init(&session->sess, GNUTLS_SERVER); gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); /* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes * This needs testing, but it's easy enough to rollback if need be * Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket. * New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket. * * With testing this seems to...not work :/ */ gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket. gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. Handshake(session); } virtual void OnRawSocketConnect(int fd) { issl_session* session = &sessions[fd]; session->fd = fd; session->inbuf = new char[inbufsize]; session->inbufoffset = 0; gnutls_init(&session->sess, GNUTLS_CLIENT); gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket. Handshake(session); } virtual void OnRawSocketClose(int fd) { CloseSession(&sessions[fd]); EventHandler* user = ServerInstance->SE->GetRef(fd); if ((user) && (user->GetExt("ssl_cert", dummy))) { ssl_cert* tofree; user->GetExt("ssl_cert", tofree); delete tofree; user->Shrink("ssl_cert"); } } virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) { issl_session* session = &sessions[fd]; if (!session->sess) { readresult = 0; CloseSession(session); return 1; } if (session->status == ISSL_HANDSHAKING_READ) { // The handshake isn't finished, try to finish it. if(!Handshake(session)) { // Couldn't resume handshake. return -1; } } else if (session->status == ISSL_HANDSHAKING_WRITE) { errno = EAGAIN; return -1; } // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. if (session->status == ISSL_HANDSHAKEN) { // Is this right? Not sure if the unencrypted data is garaunteed to be the same length. // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet. int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset); if (ret == 0) { // Client closed connection. readresult = 0; CloseSession(session); return 1; } else if (ret < 0) { if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) { errno = EAGAIN; return -1; } else { readresult = 0; CloseSession(session); } } else { // Read successfully 'ret' bytes into inbuf + inbufoffset // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf' // 'buffer' is 'count' long unsigned int length = ret + session->inbufoffset; if(count <= length) { memcpy(buffer, session->inbuf, count); // Move the stuff left in inbuf to the beginning of it memcpy(session->inbuf, session->inbuf + count, (length - count)); // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp. session->inbufoffset = length - count; // Insp uses readresult as the count of how much data there is in buffer, so: readresult = count; } else { // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing. memcpy(buffer, session->inbuf, length); // Zero the offset, as there's nothing there.. session->inbufoffset = 0; // As above readresult = length; } } } else if(session->status == ISSL_CLOSING) readresult = 0; return 1; } virtual int OnRawSocketWrite(int fd, const char* buffer, int count) { if (!count) return 0; issl_session* session = &sessions[fd]; const char* sendbuffer = buffer; if (!session->sess) { ServerInstance->Log(DEBUG,"No session"); CloseSession(session); return 1; } session->outbuf.append(sendbuffer, count); sendbuffer = session->outbuf.c_str(); count = session->outbuf.size(); if (session->status == ISSL_HANDSHAKING_WRITE) { // The handshake isn't finished, try to finish it. ServerInstance->Log(DEBUG,"Finishing handshake"); Handshake(session); errno = EAGAIN; return -1; } int ret = 0; if (session->status == ISSL_HANDSHAKEN) { ServerInstance->Log(DEBUG,"Send record"); ret = gnutls_record_send(session->sess, sendbuffer, count); ServerInstance->Log(DEBUG,"Return: %d", ret); if (ret == 0) { CloseSession(session); } else if (ret < 0) { if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) { ServerInstance->Log(DEBUG,"Not egain or interrupt, close session"); CloseSession(session); } else { ServerInstance->Log(DEBUG,"Again please"); errno = EAGAIN; return -1; } } else { ServerInstance->Log(DEBUG,"Trim buffer"); session->outbuf = session->outbuf.substr(ret); } } /* Who's smart idea was it to return 1 when we havent written anything? * This fucks the buffer up in InspSocket :p */ return ret < 1 ? 0 : ret; } // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection virtual void OnWhois(userrec* source, userrec* dest) { if (!clientactive) return; // Bugfix, only send this numeric for *our* SSL users if(dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports))) { ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick); } } virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable) { // check if the linking module wants to know about OUR metadata if(extname == "ssl") { // check if this user has an swhois field to send if(user->GetExt(extname, dummy)) { // call this function in the linking module, let it format the data how it // sees fit, and send it on its way. We dont need or want to know how. proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON"); } } } virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata) { // check if its our metadata key, and its associated with a user if ((target_type == TYPE_USER) && (extname == "ssl")) { userrec* dest = (userrec*)target; // if they dont already have an ssl flag, accept the remote server's if (!dest->GetExt(extname, dummy)) { dest->Extend(extname, "ON"); } } } bool Handshake(issl_session* session) { int ret = gnutls_handshake(session->sess); if (ret < 0) { if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) { // Handshake needs resuming later, read() or write() would have blocked. if(gnutls_record_get_direction(session->sess) == 0) { // gnutls_handshake() wants to read() again. session->status = ISSL_HANDSHAKING_READ; } else { // gnutls_handshake() wants to write() again. session->status = ISSL_HANDSHAKING_WRITE; MakePollWrite(session); } } else { // Handshake failed. CloseSession(session); session->status = ISSL_CLOSING; } return false; } else { // Handshake complete. // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater. userrec* extendme = ServerInstance->FindDescriptor(session->fd); if (extendme) { if (!extendme->GetExt("ssl", dummy)) extendme->Extend("ssl", "ON"); } // Change the seesion state session->status = ISSL_HANDSHAKEN; // Finish writing, if any left MakePollWrite(session); return true; } } virtual void OnPostConnect(userrec* user) { // This occurs AFTER OnUserConnect so we can be sure the // protocol module has propogated the NICK message. if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user))) { // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW. std::deque<std::string>* metadata = new std::deque<std::string>; metadata->push_back(user->nick); metadata->push_back("ssl"); // The metadata id metadata->push_back("ON"); // The value to send Event* event = new Event((char*)metadata,(Module*)this,"send_metadata"); event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up. DELETE(event); DELETE(metadata); VerifyCertificate(&sessions[user->GetFd()],user); if (sessions[user->GetFd()].sess) { std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess)); cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-"); cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess))); user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, cipher.c_str()); } } } void MakePollWrite(issl_session* session) { OnRawSocketWrite(session->fd, NULL, 0); } void CloseSession(issl_session* session) { if(session->sess) { gnutls_bye(session->sess, GNUTLS_SHUT_WR); gnutls_deinit(session->sess); } if(session->inbuf) { delete[] session->inbuf; } session->outbuf.clear(); session->inbuf = NULL; session->sess = NULL; session->status = ISSL_NONE; } void VerifyCertificate(issl_session* session, Extensible* user) { if (!session->sess || !user) return; unsigned int status; const gnutls_datum_t* cert_list; int ret; unsigned int cert_list_size; gnutls_x509_crt_t cert; char name[MAXBUF]; unsigned char digest[MAXBUF]; size_t digest_size = sizeof(digest); size_t name_size = sizeof(name); ssl_cert* certinfo = new ssl_cert; user->Extend("ssl_cert",certinfo); /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. */ ret = gnutls_certificate_verify_peers2(session->sess, &status); if (ret < 0) { certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret)))); return; } if (status & GNUTLS_CERT_INVALID) { certinfo->data.insert(std::make_pair("invalid",ConvToStr(1))); } else { certinfo->data.insert(std::make_pair("invalid",ConvToStr(0))); } if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) { certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1))); } else { certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0))); } if (status & GNUTLS_CERT_REVOKED) { certinfo->data.insert(std::make_pair("revoked",ConvToStr(1))); } else { certinfo->data.insert(std::make_pair("revoked",ConvToStr(0))); } if (status & GNUTLS_CERT_SIGNER_NOT_CA) { certinfo->data.insert(std::make_pair("trusted",ConvToStr(0))); } else { certinfo->data.insert(std::make_pair("trusted",ConvToStr(1))); } /* Up to here the process is the same for X.509 certificates and * OpenPGP keys. From now on X.509 certificates are assumed. This can * be easily extended to work with openpgp keys as well. */ if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) { certinfo->data.insert(std::make_pair("error","No X509 keys sent")); return; } ret = gnutls_x509_crt_init(&cert); if (ret < 0) { certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret))); return; } cert_list_size = 0; cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size); if (cert_list == NULL) { certinfo->data.insert(std::make_pair("error","No certificate was found")); return; } /* This is not a real world example, since we only check the first * certificate in the given chain. */ ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER); if (ret < 0) { certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret))); return; } gnutls_x509_crt_get_dn(cert, name, &name_size); certinfo->data.insert(std::make_pair("dn",name)); gnutls_x509_crt_get_issuer_dn(cert, name, &name_size); certinfo->data.insert(std::make_pair("issuer",name)); if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0) { certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret))); } else { certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size))); } /* Beware here we do not check for errors. */ if ((gnutls_x509_crt_get_expiration_time(cert) < time(0)) || (gnutls_x509_crt_get_activation_time(cert) > time(0))) { certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate")); } gnutls_x509_crt_deinit(cert); return; } }; MODULE_INIT(ModuleSSLGnuTLS); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
+
+#include "inspircd_config.h"
+#include "configreader.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "socket.h"
+#include "hashcomp.h"
+#include "transport.h"
+
+#ifdef WINDOWS
+#pragma comment(lib, "libgnutls-13.lib")
+#undef MAX_DESCRIPTORS
+#define MAX_DESCRIPTORS 10000
+#endif
+
+/* $ModDesc: Provides SSL support for clients */
+/* $CompileFlags: exec("libgnutls-config --cflags") */
+/* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */
+/* $ModDep: transport.h */
+
+
+enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
+
+bool isin(int port, const std::vector<int> &portlist)
+{
+ for(unsigned int i = 0; i < portlist.size(); i++)
+ if(portlist[i] == port)
+ return true;
+
+ return false;
+}
+
+/** Represents an SSL user's extra data
+ */
+class issl_session : public classbase
+{
+public:
+ gnutls_session_t sess;
+ issl_status status;
+ std::string outbuf;
+ int inbufoffset;
+ char* inbuf;
+ int fd;
+};
+
+class ModuleSSLGnuTLS : public Module
+{
+
+ ConfigReader* Conf;
+
+ char* dummy;
+
+ std::vector<int> listenports;
+
+ int inbufsize;
+ issl_session sessions[MAX_DESCRIPTORS];
+
+ gnutls_certificate_credentials x509_cred;
+ gnutls_dh_params dh_params;
+
+ std::string keyfile;
+ std::string certfile;
+ std::string cafile;
+ std::string crlfile;
+ std::string sslports;
+ int dh_bits;
+
+ int clientactive;
+
+ public:
+
+ ModuleSSLGnuTLS(InspIRCd* Me)
+ : Module(Me)
+ {
+ ServerInstance->PublishInterface("InspSocketHook", this);
+
+ // Not rehashable...because I cba to reduce all the sizes of existing buffers.
+ inbufsize = ServerInstance->Config->NetBufferSize;
+
+ gnutls_global_init(); // This must be called once in the program
+
+ if(gnutls_certificate_allocate_credentials(&x509_cred) != 0)
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials");
+
+ // Guessing return meaning
+ if(gnutls_dh_params_init(&dh_params) < 0)
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters");
+
+ // Needs the flag as it ignores a plain /rehash
+ OnRehash(NULL,"ssl");
+
+ // Void return, guess we assume success
+ gnutls_certificate_set_dh_params(x509_cred, dh_params);
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &param)
+ {
+ if(param != "ssl")
+ return;
+
+ Conf = new ConfigReader(ServerInstance);
+
+ for(unsigned int i = 0; i < listenports.size(); i++)
+ {
+ ServerInstance->Config->DelIOHook(listenports[i]);
+ }
+
+ listenports.clear();
+ clientactive = 0;
+ sslports.clear();
+
+ for(int i = 0; i < Conf->Enumerate("bind"); i++)
+ {
+ // For each <bind> tag
+ std::string x = Conf->ReadValue("bind", "type", i);
+ if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "gnutls"))
+ {
+ // Get the port we're meant to be listening on with SSL
+ std::string port = Conf->ReadValue("bind", "port", i);
+ irc::portparser portrange(port, false);
+ long portno = -1;
+ while ((portno = portrange.GetToken()))
+ {
+ clientactive++;
+ try
+ {
+ if (ServerInstance->Config->AddIOHook(portno, this))
+ {
+ listenports.push_back(portno);
+ for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
+ if (ServerInstance->Config->ports[i]->GetPort() == portno)
+ ServerInstance->Config->ports[i]->SetDescription("ssl");
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %d", portno);
+ sslports.append("*:").append(ConvToStr(portno)).append(";");
+ }
+ else
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno);
+ }
+ }
+ catch (ModuleException &e)
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason());
+ }
+ }
+ }
+ }
+
+ std::string confdir(ServerInstance->ConfigFileName);
+ // +1 so we the path ends with a /
+ confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
+
+ cafile = Conf->ReadValue("gnutls", "cafile", 0);
+ crlfile = Conf->ReadValue("gnutls", "crlfile", 0);
+ certfile = Conf->ReadValue("gnutls", "certfile", 0);
+ keyfile = Conf->ReadValue("gnutls", "keyfile", 0);
+ dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false);
+
+ // Set all the default values needed.
+ if (cafile.empty())
+ cafile = "ca.pem";
+
+ if (crlfile.empty())
+ crlfile = "crl.pem";
+
+ if (certfile.empty())
+ certfile = "cert.pem";
+
+ if (keyfile.empty())
+ keyfile = "key.pem";
+
+ if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
+ dh_bits = 1024;
+
+ // Prepend relative paths with the path to the config directory.
+ if(cafile[0] != '/')
+ cafile = confdir + cafile;
+
+ if(crlfile[0] != '/')
+ crlfile = confdir + crlfile;
+
+ if(certfile[0] != '/')
+ certfile = confdir + certfile;
+
+ if(keyfile[0] != '/')
+ keyfile = confdir + keyfile;
+
+ int ret;
+
+ if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
+
+ if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
+
+ if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
+ {
+ // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
+ throw ModuleException("Unable to load GnuTLS server certificate: " + std::string(gnutls_strerror(ret)));
+ }
+
+ // This may be on a large (once a day or week) timer eventually.
+ GenerateDHParams();
+
+ DELETE(Conf);
+ }
+
+ void GenerateDHParams()
+ {
+ // Generate Diffie Hellman parameters - for use with DHE
+ // kx algorithms. These should be discarded and regenerated
+ // once a day, once a week or once a month. Depending on the
+ // security requirements.
+
+ int ret;
+
+ if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
+ ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
+ }
+
+ virtual ~ModuleSSLGnuTLS()
+ {
+ gnutls_dh_params_deinit(dh_params);
+ gnutls_certificate_free_credentials(x509_cred);
+ gnutls_global_deinit();
+ }
+
+ virtual void OnCleanup(int target_type, void* item)
+ {
+ if(target_type == TYPE_USER)
+ {
+ userrec* user = (userrec*)item;
+
+ if(user->GetExt("ssl", dummy) && isin(user->GetPort(), listenports))
+ {
+ // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
+ // Potentially there could be multiple SSL modules loaded at once on different ports.
+ ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading");
+ }
+ if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports))
+ {
+ ssl_cert* tofree;
+ user->GetExt("ssl_cert", tofree);
+ delete tofree;
+ user->Shrink("ssl_cert");
+ }
+ }
+ }
+
+ virtual void OnUnloadModule(Module* mod, const std::string &name)
+ {
+ if(mod == this)
+ {
+ for(unsigned int i = 0; i < listenports.size(); i++)
+ {
+ ServerInstance->Config->DelIOHook(listenports[i]);
+ for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
+ if (ServerInstance->Config->ports[j]->GetPort() == listenports[i])
+ ServerInstance->Config->ports[j]->SetDescription("plaintext");
+ }
+ }
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
+ }
+
+ void Implements(char* List)
+ {
+ List[I_On005Numeric] = List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = 1;
+ List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1;
+ }
+
+ virtual void On005Numeric(std::string &output)
+ {
+ output.append(" SSL=" + sslports);
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ ISHRequest* ISR = (ISHRequest*)request;
+ if (strcmp("IS_NAME", request->GetId()) == 0)
+ {
+ return "gnutls";
+ }
+ else if (strcmp("IS_HOOK", request->GetId()) == 0)
+ {
+ char* ret = "OK";
+ try
+ {
+ ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
+ }
+ catch (ModuleException &e)
+ {
+ return NULL;
+ }
+ return ret;
+ }
+ else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
+ {
+ return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
+ }
+ else if (strcmp("IS_HSDONE", request->GetId()) == 0)
+ {
+ if (ISR->Sock->GetFd() < 0)
+ return (char*)"OK";
+
+ issl_session* session = &sessions[ISR->Sock->GetFd()];
+ return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : (char*)"OK";
+ }
+ else if (strcmp("IS_ATTACH", request->GetId()) == 0)
+ {
+ if (ISR->Sock->GetFd() > -1)
+ {
+ issl_session* session = &sessions[ISR->Sock->GetFd()];
+ if (session->sess)
+ {
+ if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock))
+ {
+ VerifyCertificate(session, (InspSocket*)ISR->Sock);
+ return "OK";
+ }
+ }
+ }
+ }
+ return NULL;
+ }
+
+
+ virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
+ {
+ issl_session* session = &sessions[fd];
+
+ session->fd = fd;
+ session->inbuf = new char[inbufsize];
+ session->inbufoffset = 0;
+
+ gnutls_init(&session->sess, GNUTLS_SERVER);
+
+ gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
+ gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
+ gnutls_dh_set_prime_bits(session->sess, dh_bits);
+
+ /* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes
+ * This needs testing, but it's easy enough to rollback if need be
+ * Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
+ * New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket.
+ *
+ * With testing this seems to...not work :/
+ */
+
+ gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
+
+ gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
+
+ Handshake(session);
+ }
+
+ virtual void OnRawSocketConnect(int fd)
+ {
+ issl_session* session = &sessions[fd];
+
+ session->fd = fd;
+ session->inbuf = new char[inbufsize];
+ session->inbufoffset = 0;
+
+ gnutls_init(&session->sess, GNUTLS_CLIENT);
+
+ gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
+ gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
+ gnutls_dh_set_prime_bits(session->sess, dh_bits);
+ gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
+
+ Handshake(session);
+ }
+
+ virtual void OnRawSocketClose(int fd)
+ {
+ CloseSession(&sessions[fd]);
+
+ EventHandler* user = ServerInstance->SE->GetRef(fd);
+
+ if ((user) && (user->GetExt("ssl_cert", dummy)))
+ {
+ ssl_cert* tofree;
+ user->GetExt("ssl_cert", tofree);
+ delete tofree;
+ user->Shrink("ssl_cert");
+ }
+ }
+
+ virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
+ {
+ issl_session* session = &sessions[fd];
+
+ if (!session->sess)
+ {
+ readresult = 0;
+ CloseSession(session);
+ return 1;
+ }
+
+ if (session->status == ISSL_HANDSHAKING_READ)
+ {
+ // The handshake isn't finished, try to finish it.
+
+ if(!Handshake(session))
+ {
+ // Couldn't resume handshake.
+ return -1;
+ }
+ }
+ else if (session->status == ISSL_HANDSHAKING_WRITE)
+ {
+ errno = EAGAIN;
+ return -1;
+ }
+
+ // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
+
+ if (session->status == ISSL_HANDSHAKEN)
+ {
+ // Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
+ // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
+ int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
+
+ if (ret == 0)
+ {
+ // Client closed connection.
+ readresult = 0;
+ CloseSession(session);
+ return 1;
+ }
+ else if (ret < 0)
+ {
+ if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+ {
+ errno = EAGAIN;
+ return -1;
+ }
+ else
+ {
+ readresult = 0;
+ CloseSession(session);
+ }
+ }
+ else
+ {
+ // Read successfully 'ret' bytes into inbuf + inbufoffset
+ // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
+ // 'buffer' is 'count' long
+
+ unsigned int length = ret + session->inbufoffset;
+
+ if(count <= length)
+ {
+ memcpy(buffer, session->inbuf, count);
+ // Move the stuff left in inbuf to the beginning of it
+ memcpy(session->inbuf, session->inbuf + count, (length - count));
+ // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
+ session->inbufoffset = length - count;
+ // Insp uses readresult as the count of how much data there is in buffer, so:
+ readresult = count;
+ }
+ else
+ {
+ // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
+ memcpy(buffer, session->inbuf, length);
+ // Zero the offset, as there's nothing there..
+ session->inbufoffset = 0;
+ // As above
+ readresult = length;
+ }
+ }
+ }
+ else if(session->status == ISSL_CLOSING)
+ readresult = 0;
+
+ return 1;
+ }
+
+ virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
+ {
+ if (!count)
+ return 0;
+
+ issl_session* session = &sessions[fd];
+ const char* sendbuffer = buffer;
+
+ if (!session->sess)
+ {
+ ServerInstance->Log(DEBUG,"No session");
+ CloseSession(session);
+ return 1;
+ }
+
+ session->outbuf.append(sendbuffer, count);
+ sendbuffer = session->outbuf.c_str();
+ count = session->outbuf.size();
+
+ if (session->status == ISSL_HANDSHAKING_WRITE)
+ {
+ // The handshake isn't finished, try to finish it.
+ ServerInstance->Log(DEBUG,"Finishing handshake");
+ Handshake(session);
+ errno = EAGAIN;
+ return -1;
+ }
+
+ int ret = 0;
+
+ if (session->status == ISSL_HANDSHAKEN)
+ {
+ ServerInstance->Log(DEBUG,"Send record");
+ ret = gnutls_record_send(session->sess, sendbuffer, count);
+ ServerInstance->Log(DEBUG,"Return: %d", ret);
+
+ if (ret == 0)
+ {
+ CloseSession(session);
+ }
+ else if (ret < 0)
+ {
+ if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED)
+ {
+ ServerInstance->Log(DEBUG,"Not egain or interrupt, close session");
+ CloseSession(session);
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG,"Again please");
+ errno = EAGAIN;
+ return -1;
+ }
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG,"Trim buffer");
+ session->outbuf = session->outbuf.substr(ret);
+ }
+ }
+
+ /* Who's smart idea was it to return 1 when we havent written anything?
+ * This fucks the buffer up in InspSocket :p
+ */
+ return ret < 1 ? 0 : ret;
+ }
+
+ // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
+ virtual void OnWhois(userrec* source, userrec* dest)
+ {
+ if (!clientactive)
+ return;
+
+ // Bugfix, only send this numeric for *our* SSL users
+ if(dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports)))
+ {
+ ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick);
+ }
+ }
+
+ virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
+ {
+ // check if the linking module wants to know about OUR metadata
+ if(extname == "ssl")
+ {
+ // check if this user has an swhois field to send
+ if(user->GetExt(extname, dummy))
+ {
+ // call this function in the linking module, let it format the data how it
+ // sees fit, and send it on its way. We dont need or want to know how.
+ proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
+ }
+ }
+ }
+
+ virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
+ {
+ // check if its our metadata key, and its associated with a user
+ if ((target_type == TYPE_USER) && (extname == "ssl"))
+ {
+ userrec* dest = (userrec*)target;
+ // if they dont already have an ssl flag, accept the remote server's
+ if (!dest->GetExt(extname, dummy))
+ {
+ dest->Extend(extname, "ON");
+ }
+ }
+ }
+
+ bool Handshake(issl_session* session)
+ {
+ int ret = gnutls_handshake(session->sess);
+
+ if (ret < 0)
+ {
+ if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+ {
+ // Handshake needs resuming later, read() or write() would have blocked.
+
+ if(gnutls_record_get_direction(session->sess) == 0)
+ {
+ // gnutls_handshake() wants to read() again.
+ session->status = ISSL_HANDSHAKING_READ;
+ }
+ else
+ {
+ // gnutls_handshake() wants to write() again.
+ session->status = ISSL_HANDSHAKING_WRITE;
+ MakePollWrite(session);
+ }
+ }
+ else
+ {
+ // Handshake failed.
+ CloseSession(session);
+ session->status = ISSL_CLOSING;
+ }
+
+ return false;
+ }
+ else
+ {
+ // Handshake complete.
+ // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
+ userrec* extendme = ServerInstance->FindDescriptor(session->fd);
+ if (extendme)
+ {
+ if (!extendme->GetExt("ssl", dummy))
+ extendme->Extend("ssl", "ON");
+ }
+
+ // Change the seesion state
+ session->status = ISSL_HANDSHAKEN;
+
+ // Finish writing, if any left
+ MakePollWrite(session);
+
+ return true;
+ }
+ }
+
+ virtual void OnPostConnect(userrec* user)
+ {
+ // This occurs AFTER OnUserConnect so we can be sure the
+ // protocol module has propogated the NICK message.
+ if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user)))
+ {
+ // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
+ std::deque<std::string>* metadata = new std::deque<std::string>;
+ metadata->push_back(user->nick);
+ metadata->push_back("ssl"); // The metadata id
+ metadata->push_back("ON"); // The value to send
+ Event* event = new Event((char*)metadata,(Module*)this,"send_metadata");
+ event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up.
+ DELETE(event);
+ DELETE(metadata);
+
+ VerifyCertificate(&sessions[user->GetFd()],user);
+ if (sessions[user->GetFd()].sess)
+ {
+ std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
+ cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-");
+ cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
+ user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, cipher.c_str());
+ }
+ }
+ }
+
+ void MakePollWrite(issl_session* session)
+ {
+ OnRawSocketWrite(session->fd, NULL, 0);
+ }
+
+ void CloseSession(issl_session* session)
+ {
+ if(session->sess)
+ {
+ gnutls_bye(session->sess, GNUTLS_SHUT_WR);
+ gnutls_deinit(session->sess);
+ }
+
+ if(session->inbuf)
+ {
+ delete[] session->inbuf;
+ }
+
+ session->outbuf.clear();
+ session->inbuf = NULL;
+ session->sess = NULL;
+ session->status = ISSL_NONE;
+ }
+
+ void VerifyCertificate(issl_session* session, Extensible* user)
+ {
+ if (!session->sess || !user)
+ return;
+
+ unsigned int status;
+ const gnutls_datum_t* cert_list;
+ int ret;
+ unsigned int cert_list_size;
+ gnutls_x509_crt_t cert;
+ char name[MAXBUF];
+ unsigned char digest[MAXBUF];
+ size_t digest_size = sizeof(digest);
+ size_t name_size = sizeof(name);
+ ssl_cert* certinfo = new ssl_cert;
+
+ user->Extend("ssl_cert",certinfo);
+
+ /* This verification function uses the trusted CAs in the credentials
+ * structure. So you must have installed one or more CA certificates.
+ */
+ ret = gnutls_certificate_verify_peers2(session->sess, &status);
+
+ if (ret < 0)
+ {
+ certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret))));
+ return;
+ }
+
+ if (status & GNUTLS_CERT_INVALID)
+ {
+ certinfo->data.insert(std::make_pair("invalid",ConvToStr(1)));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("invalid",ConvToStr(0)));
+ }
+ if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
+ {
+ certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
+ }
+ if (status & GNUTLS_CERT_REVOKED)
+ {
+ certinfo->data.insert(std::make_pair("revoked",ConvToStr(1)));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("revoked",ConvToStr(0)));
+ }
+ if (status & GNUTLS_CERT_SIGNER_NOT_CA)
+ {
+ certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
+ }
+
+ /* Up to here the process is the same for X.509 certificates and
+ * OpenPGP keys. From now on X.509 certificates are assumed. This can
+ * be easily extended to work with openpgp keys as well.
+ */
+ if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
+ {
+ certinfo->data.insert(std::make_pair("error","No X509 keys sent"));
+ return;
+ }
+
+ ret = gnutls_x509_crt_init(&cert);
+ if (ret < 0)
+ {
+ certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
+ return;
+ }
+
+ cert_list_size = 0;
+ cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
+ if (cert_list == NULL)
+ {
+ certinfo->data.insert(std::make_pair("error","No certificate was found"));
+ return;
+ }
+
+ /* This is not a real world example, since we only check the first
+ * certificate in the given chain.
+ */
+
+ ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER);
+ if (ret < 0)
+ {
+ certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
+ return;
+ }
+
+ gnutls_x509_crt_get_dn(cert, name, &name_size);
+
+ certinfo->data.insert(std::make_pair("dn",name));
+
+ gnutls_x509_crt_get_issuer_dn(cert, name, &name_size);
+
+ certinfo->data.insert(std::make_pair("issuer",name));
+
+ if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0)
+ {
+ certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size)));
+ }
+
+ /* Beware here we do not check for errors.
+ */
+ if ((gnutls_x509_crt_get_expiration_time(cert) < time(0)) || (gnutls_x509_crt_get_activation_time(cert) > time(0)))
+ {
+ certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
+ }
+
+ gnutls_x509_crt_deinit(cert);
+
+ return;
+ }
+
+};
+
+MODULE_INIT(ModuleSSLGnuTLS);
+
diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp
index 43dc43aea..ffd9d4032 100644
--- a/src/modules/extra/m_ssl_openssl.cpp
+++ b/src/modules/extra/m_ssl_openssl.cpp
@@ -1 +1,901 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <openssl/ssl.h> #include <openssl/err.h> #ifdef WINDOWS #include <openssl/applink.c> #endif #include "configreader.h" #include "users.h" #include "channels.h" #include "modules.h" #include "socket.h" #include "hashcomp.h" #include "transport.h" #ifdef WINDOWS #pragma comment(lib, "libeay32MTd") #pragma comment(lib, "ssleay32MTd") #undef MAX_DESCRIPTORS #define MAX_DESCRIPTORS 10000 #endif /* $ModDesc: Provides SSL support for clients */ /* $CompileFlags: pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */ /* $LinkerFlags: rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto -ldl") */ /* $ModDep: transport.h */ enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_OPEN }; enum issl_io_status { ISSL_WRITE, ISSL_READ }; static bool SelfSigned = false; bool isin(int port, const std::vector<int> &portlist) { for(unsigned int i = 0; i < portlist.size(); i++) if(portlist[i] == port) return true; return false; } char* get_error() { return ERR_error_string(ERR_get_error(), NULL); } static int error_callback(const char *str, size_t len, void *u); /** Represents an SSL user's extra data */ class issl_session : public classbase { public: SSL* sess; issl_status status; issl_io_status rstat; issl_io_status wstat; unsigned int inbufoffset; char* inbuf; // Buffer OpenSSL reads into. std::string outbuf; // Buffer for outgoing data that OpenSSL will not take. int fd; bool outbound; issl_session() { outbound = false; rstat = ISSL_READ; wstat = ISSL_WRITE; } }; static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) { /* XXX: This will allow self signed certificates. * In the future if we want an option to not allow this, * we can just return preverify_ok here, and openssl * will boot off self-signed and invalid peer certs. */ int ve = X509_STORE_CTX_get_error(ctx); SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT); return 1; } class ModuleSSLOpenSSL : public Module { ConfigReader* Conf; std::vector<int> listenports; int inbufsize; issl_session sessions[MAX_DESCRIPTORS]; SSL_CTX* ctx; SSL_CTX* clictx; char* dummy; char cipher[MAXBUF]; std::string keyfile; std::string certfile; std::string cafile; // std::string crlfile; std::string dhfile; std::string sslports; int clientactive; public: InspIRCd* PublicInstance; ModuleSSLOpenSSL(InspIRCd* Me) : Module(Me), PublicInstance(Me) { ServerInstance->PublishInterface("InspSocketHook", this); // Not rehashable...because I cba to reduce all the sizes of existing buffers. inbufsize = ServerInstance->Config->NetBufferSize; /* Global SSL library initialization*/ SSL_library_init(); SSL_load_error_strings(); /* Build our SSL contexts: * NOTE: OpenSSL makes us have two contexts, one for servers and one for clients. ICK. */ ctx = SSL_CTX_new( SSLv23_server_method() ); clictx = SSL_CTX_new( SSLv23_client_method() ); SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); SSL_CTX_set_verify(clictx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); // Needs the flag as it ignores a plain /rehash OnRehash(NULL,"ssl"); } virtual void OnRehash(userrec* user, const std::string &param) { if (param != "ssl") return; Conf = new ConfigReader(ServerInstance); for (unsigned int i = 0; i < listenports.size(); i++) { ServerInstance->Config->DelIOHook(listenports[i]); } listenports.clear(); clientactive = 0; sslports.clear(); for (int i = 0; i < Conf->Enumerate("bind"); i++) { // For each <bind> tag std::string x = Conf->ReadValue("bind", "type", i); if (((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "openssl")) { // Get the port we're meant to be listening on with SSL std::string port = Conf->ReadValue("bind", "port", i); irc::portparser portrange(port, false); long portno = -1; while ((portno = portrange.GetToken())) { clientactive++; try { if (ServerInstance->Config->AddIOHook(portno, this)) { listenports.push_back(portno); for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++) if (ServerInstance->Config->ports[i]->GetPort() == portno) ServerInstance->Config->ports[i]->SetDescription("ssl"); ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Enabling SSL for port %d", portno); sslports.append("*:").append(ConvToStr(portno)).append(";"); } else { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno); } } catch (ModuleException &e) { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have another SSL or similar module loaded?", portno, e.GetReason()); } } } } if (!sslports.empty()) sslports.erase(sslports.end() - 1); std::string confdir(ServerInstance->ConfigFileName); // +1 so we the path ends with a / confdir = confdir.substr(0, confdir.find_last_of('/') + 1); cafile = Conf->ReadValue("openssl", "cafile", 0); certfile = Conf->ReadValue("openssl", "certfile", 0); keyfile = Conf->ReadValue("openssl", "keyfile", 0); dhfile = Conf->ReadValue("openssl", "dhfile", 0); // Set all the default values needed. if (cafile.empty()) cafile = "ca.pem"; if (certfile.empty()) certfile = "cert.pem"; if (keyfile.empty()) keyfile = "key.pem"; if (dhfile.empty()) dhfile = "dhparams.pem"; // Prepend relative paths with the path to the config directory. if (cafile[0] != '/') cafile = confdir + cafile; if (certfile[0] != '/') certfile = confdir + certfile; if (keyfile[0] != '/') keyfile = confdir + keyfile; if (dhfile[0] != '/') dhfile = confdir + dhfile; /* Load our keys and certificates * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck. */ if ((!SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str())) || (!SSL_CTX_use_certificate_chain_file(clictx, certfile.c_str()))) { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read certificate file %s. %s", certfile.c_str(), strerror(errno)); ERR_print_errors_cb(error_callback, this); } if (((!SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) || (!SSL_CTX_use_PrivateKey_file(clictx, keyfile.c_str(), SSL_FILETYPE_PEM))) { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read key file %s. %s", keyfile.c_str(), strerror(errno)); ERR_print_errors_cb(error_callback, this); } /* Load the CAs we trust*/ if (((!SSL_CTX_load_verify_locations(ctx, cafile.c_str(), 0))) || (!SSL_CTX_load_verify_locations(clictx, cafile.c_str(), 0))) { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read CA list from %s. %s", cafile.c_str(), strerror(errno)); ERR_print_errors_cb(error_callback, this); } FILE* dhpfile = fopen(dhfile.c_str(), "r"); DH* ret; if (dhpfile == NULL) { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so Couldn't open DH file %s: %s", dhfile.c_str(), strerror(errno)); throw ModuleException("Couldn't open DH file " + dhfile + ": " + strerror(errno)); } else { ret = PEM_read_DHparams(dhpfile, NULL, NULL, NULL); if ((SSL_CTX_set_tmp_dh(ctx, ret) < 0) || (SSL_CTX_set_tmp_dh(clictx, ret) < 0)) { ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s. SSL errors follow:", dhfile.c_str()); ERR_print_errors_cb(error_callback, this); } } fclose(dhpfile); DELETE(Conf); } virtual void On005Numeric(std::string &output) { output.append(" SSL=" + sslports); } virtual ~ModuleSSLOpenSSL() { SSL_CTX_free(ctx); SSL_CTX_free(clictx); } virtual void OnCleanup(int target_type, void* item) { if (target_type == TYPE_USER) { userrec* user = (userrec*)item; if (user->GetExt("ssl", dummy) && IS_LOCAL(user) && isin(user->GetPort(), listenports)) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading"); } if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports)) { ssl_cert* tofree; user->GetExt("ssl_cert", tofree); delete tofree; user->Shrink("ssl_cert"); } } } virtual void OnUnloadModule(Module* mod, const std::string &name) { if (mod == this) { for(unsigned int i = 0; i < listenports.size(); i++) { ServerInstance->Config->DelIOHook(listenports[i]); for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++) if (ServerInstance->Config->ports[j]->GetPort() == listenports[i]) ServerInstance->Config->ports[j]->SetDescription("plaintext"); } } } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); } void Implements(char* List) { List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = List[I_On005Numeric] = 1; List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1; } virtual char* OnRequest(Request* request) { ISHRequest* ISR = (ISHRequest*)request; if (strcmp("IS_NAME", request->GetId()) == 0) { return "openssl"; } else if (strcmp("IS_HOOK", request->GetId()) == 0) { char* ret = "OK"; try { ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; } catch (ModuleException &e) { return NULL; } return ret; } else if (strcmp("IS_UNHOOK", request->GetId()) == 0) { return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; } else if (strcmp("IS_HSDONE", request->GetId()) == 0) { ServerInstance->Log(DEBUG,"Module checking if handshake is done"); if (ISR->Sock->GetFd() < 0) return (char*)"OK"; issl_session* session = &sessions[ISR->Sock->GetFd()]; return (session->status == ISSL_HANDSHAKING) ? NULL : (char*)"OK"; } else if (strcmp("IS_ATTACH", request->GetId()) == 0) { issl_session* session = &sessions[ISR->Sock->GetFd()]; if (session->sess) { VerifyCertificate(session, (InspSocket*)ISR->Sock); return "OK"; } } return NULL; } virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport) { issl_session* session = &sessions[fd]; session->fd = fd; session->inbuf = new char[inbufsize]; session->inbufoffset = 0; session->sess = SSL_new(ctx); session->status = ISSL_NONE; session->outbound = false; if (session->sess == NULL) return; if (SSL_set_fd(session->sess, fd) == 0) { ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd); return; } Handshake(session); } virtual void OnRawSocketConnect(int fd) { ServerInstance->Log(DEBUG,"OnRawSocketConnect connecting"); issl_session* session = &sessions[fd]; session->fd = fd; session->inbuf = new char[inbufsize]; session->inbufoffset = 0; session->sess = SSL_new(clictx); session->status = ISSL_NONE; session->outbound = true; if (session->sess == NULL) return; if (SSL_set_fd(session->sess, fd) == 0) { ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd); return; } Handshake(session); ServerInstance->Log(DEBUG,"Exiting OnRawSocketConnect"); } virtual void OnRawSocketClose(int fd) { CloseSession(&sessions[fd]); EventHandler* user = ServerInstance->SE->GetRef(fd); if ((user) && (user->GetExt("ssl_cert", dummy))) { ssl_cert* tofree; user->GetExt("ssl_cert", tofree); delete tofree; user->Shrink("ssl_cert"); } } virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) { issl_session* session = &sessions[fd]; ServerInstance->Log(DEBUG,"OnRawSocketRead"); if (!session->sess) { ServerInstance->Log(DEBUG,"OnRawSocketRead has no session"); readresult = 0; CloseSession(session); return 1; } if (session->status == ISSL_HANDSHAKING) { if (session->rstat == ISSL_READ || session->wstat == ISSL_READ) { ServerInstance->Log(DEBUG,"Resume handshake in read"); // The handshake isn't finished and it wants to read, try to finish it. if (!Handshake(session)) { ServerInstance->Log(DEBUG,"Cant resume handshake in read"); // Couldn't resume handshake. return -1; } } else { errno = EAGAIN; return -1; } } // If we resumed the handshake then session->status will be ISSL_OPEN if (session->status == ISSL_OPEN) { if (session->wstat == ISSL_READ) { if(DoWrite(session) == 0) return 0; } if (session->rstat == ISSL_READ) { int ret = DoRead(session); if (ret > 0) { if (count <= session->inbufoffset) { memcpy(buffer, session->inbuf, count); // Move the stuff left in inbuf to the beginning of it memcpy(session->inbuf, session->inbuf + count, (session->inbufoffset - count)); // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp. session->inbufoffset -= count; // Insp uses readresult as the count of how much data there is in buffer, so: readresult = count; } else { // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing. memcpy(buffer, session->inbuf, session->inbufoffset); readresult = session->inbufoffset; // Zero the offset, as there's nothing there.. session->inbufoffset = 0; } return 1; } else { return ret; } } } return -1; } virtual int OnRawSocketWrite(int fd, const char* buffer, int count) { issl_session* session = &sessions[fd]; if (!session->sess) { ServerInstance->Log(DEBUG,"Close session missing sess"); CloseSession(session); return -1; } session->outbuf.append(buffer, count); if (session->status == ISSL_HANDSHAKING) { // The handshake isn't finished, try to finish it. if (session->rstat == ISSL_WRITE || session->wstat == ISSL_WRITE) { ServerInstance->Log(DEBUG,"Handshake resume"); Handshake(session); } } if (session->status == ISSL_OPEN) { if (session->rstat == ISSL_WRITE) { ServerInstance->Log(DEBUG,"DoRead"); DoRead(session); } if (session->wstat == ISSL_WRITE) { ServerInstance->Log(DEBUG,"DoWrite"); return DoWrite(session); } } return 1; } int DoWrite(issl_session* session) { if (!session->outbuf.size()) return -1; int ret = SSL_write(session->sess, session->outbuf.data(), session->outbuf.size()); if (ret == 0) { ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_write"); CloseSession(session); return 0; } else if (ret < 0) { int err = SSL_get_error(session->sess, ret); if (err == SSL_ERROR_WANT_WRITE) { session->wstat = ISSL_WRITE; return -1; } else if (err == SSL_ERROR_WANT_READ) { session->wstat = ISSL_READ; return -1; } else { ServerInstance->Log(DEBUG,"Close due to returned -1 in SSL_Write"); CloseSession(session); return 0; } } else { session->outbuf = session->outbuf.substr(ret); return ret; } } int DoRead(issl_session* session) { // Is this right? Not sure if the unencrypted data is garaunteed to be the same length. // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet. ServerInstance->Log(DEBUG,"DoRead"); int ret = SSL_read(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset); if (ret == 0) { // Client closed connection. ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_read"); CloseSession(session); return 0; } else if (ret < 0) { int err = SSL_get_error(session->sess, ret); if (err == SSL_ERROR_WANT_READ) { session->rstat = ISSL_READ; ServerInstance->Log(DEBUG,"Setting want_read"); return -1; } else if (err == SSL_ERROR_WANT_WRITE) { session->rstat = ISSL_WRITE; ServerInstance->Log(DEBUG,"Setting want_write"); return -1; } else { ServerInstance->Log(DEBUG,"Closed due to returned -1 in SSL_Read"); CloseSession(session); return 0; } } else { // Read successfully 'ret' bytes into inbuf + inbufoffset // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf' // 'buffer' is 'count' long session->inbufoffset += ret; return ret; } } // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection virtual void OnWhois(userrec* source, userrec* dest) { if (!clientactive) return; // Bugfix, only send this numeric for *our* SSL users if (dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports))) { ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick); } } virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable) { // check if the linking module wants to know about OUR metadata if (extname == "ssl") { // check if this user has an swhois field to send if(user->GetExt(extname, dummy)) { // call this function in the linking module, let it format the data how it // sees fit, and send it on its way. We dont need or want to know how. proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON"); } } } virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata) { // check if its our metadata key, and its associated with a user if ((target_type == TYPE_USER) && (extname == "ssl")) { userrec* dest = (userrec*)target; // if they dont already have an ssl flag, accept the remote server's if (!dest->GetExt(extname, dummy)) { dest->Extend(extname, "ON"); } } } bool Handshake(issl_session* session) { ServerInstance->Log(DEBUG,"Handshake"); int ret; if (session->outbound) { ServerInstance->Log(DEBUG,"SSL_connect"); ret = SSL_connect(session->sess); } else ret = SSL_accept(session->sess); if (ret < 0) { int err = SSL_get_error(session->sess, ret); if (err == SSL_ERROR_WANT_READ) { ServerInstance->Log(DEBUG,"Want read, handshaking"); session->rstat = ISSL_READ; session->status = ISSL_HANDSHAKING; return true; } else if (err == SSL_ERROR_WANT_WRITE) { ServerInstance->Log(DEBUG,"Want write, handshaking"); session->wstat = ISSL_WRITE; session->status = ISSL_HANDSHAKING; MakePollWrite(session); return true; } else { ServerInstance->Log(DEBUG,"Handshake failed"); CloseSession(session); } return false; } else if (ret > 0) { // Handshake complete. // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater. userrec* u = ServerInstance->FindDescriptor(session->fd); if (u) { if (!u->GetExt("ssl", dummy)) u->Extend("ssl", "ON"); } session->status = ISSL_OPEN; MakePollWrite(session); return true; } else if (ret == 0) { int ssl_err = SSL_get_error(session->sess, ret); char buf[1024]; ERR_print_errors_fp(stderr); ServerInstance->Log(DEBUG,"Handshake fail 2: %d: %s", ssl_err, ERR_error_string(ssl_err,buf)); CloseSession(session); return true; } return true; } virtual void OnPostConnect(userrec* user) { // This occurs AFTER OnUserConnect so we can be sure the // protocol module has propogated the NICK message. if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user))) { // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW. std::deque<std::string>* metadata = new std::deque<std::string>; metadata->push_back(user->nick); metadata->push_back("ssl"); // The metadata id metadata->push_back("ON"); // The value to send Event* event = new Event((char*)metadata,(Module*)this,"send_metadata"); event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up. DELETE(event); DELETE(metadata); VerifyCertificate(&sessions[user->GetFd()], user); if (sessions[user->GetFd()].sess) user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, SSL_get_cipher(sessions[user->GetFd()].sess)); } } void MakePollWrite(issl_session* session) { OnRawSocketWrite(session->fd, NULL, 0); //EventHandler* eh = ServerInstance->FindDescriptor(session->fd); //if (eh) // ServerInstance->SE->WantWrite(eh); } void CloseSession(issl_session* session) { if (session->sess) { SSL_shutdown(session->sess); SSL_free(session->sess); } if (session->inbuf) { delete[] session->inbuf; } session->outbuf.clear(); session->inbuf = NULL; session->sess = NULL; session->status = ISSL_NONE; } void VerifyCertificate(issl_session* session, Extensible* user) { if (!session->sess || !user) return; X509* cert; ssl_cert* certinfo = new ssl_cert; unsigned int n; unsigned char md[EVP_MAX_MD_SIZE]; const EVP_MD *digest = EVP_md5(); user->Extend("ssl_cert",certinfo); cert = SSL_get_peer_certificate((SSL*)session->sess); if (!cert) { certinfo->data.insert(std::make_pair("error","Could not get peer certificate: "+std::string(get_error()))); return; } certinfo->data.insert(std::make_pair("invalid", SSL_get_verify_result(session->sess) != X509_V_OK ? ConvToStr(1) : ConvToStr(0))); if (SelfSigned) { certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0))); certinfo->data.insert(std::make_pair("trusted",ConvToStr(1))); } else { certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1))); certinfo->data.insert(std::make_pair("trusted",ConvToStr(0))); } certinfo->data.insert(std::make_pair("dn",std::string(X509_NAME_oneline(X509_get_subject_name(cert),0,0)))); certinfo->data.insert(std::make_pair("issuer",std::string(X509_NAME_oneline(X509_get_issuer_name(cert),0,0)))); if (!X509_digest(cert, digest, md, &n)) { certinfo->data.insert(std::make_pair("error","Out of memory generating fingerprint")); } else { certinfo->data.insert(std::make_pair("fingerprint",irc::hex(md, n))); } if ((ASN1_UTCTIME_cmp_time_t(X509_get_notAfter(cert), time(NULL)) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_get_notBefore(cert), time(NULL)) == 0)) { certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate")); } X509_free(cert); } }; static int error_callback(const char *str, size_t len, void *u) { ModuleSSLOpenSSL* mssl = (ModuleSSLOpenSSL*)u; mssl->PublicInstance->Log(DEFAULT, "SSL error: " + std::string(str, len - 1)); return 0; } MODULE_INIT(ModuleSSLOpenSSL); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+
+#include <openssl/ssl.h>
+#include <openssl/err.h>
+
+#ifdef WINDOWS
+#include <openssl/applink.c>
+#endif
+
+#include "configreader.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+
+#include "socket.h"
+#include "hashcomp.h"
+
+#include "transport.h"
+
+#ifdef WINDOWS
+#pragma comment(lib, "libeay32MTd")
+#pragma comment(lib, "ssleay32MTd")
+#undef MAX_DESCRIPTORS
+#define MAX_DESCRIPTORS 10000
+#endif
+
+/* $ModDesc: Provides SSL support for clients */
+/* $CompileFlags: pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */
+/* $LinkerFlags: rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto -ldl") */
+/* $ModDep: transport.h */
+
+enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_OPEN };
+enum issl_io_status { ISSL_WRITE, ISSL_READ };
+
+static bool SelfSigned = false;
+
+bool isin(int port, const std::vector<int> &portlist)
+{
+ for(unsigned int i = 0; i < portlist.size(); i++)
+ if(portlist[i] == port)
+ return true;
+
+ return false;
+}
+
+char* get_error()
+{
+ return ERR_error_string(ERR_get_error(), NULL);
+}
+
+static int error_callback(const char *str, size_t len, void *u);
+
+/** Represents an SSL user's extra data
+ */
+class issl_session : public classbase
+{
+public:
+ SSL* sess;
+ issl_status status;
+ issl_io_status rstat;
+ issl_io_status wstat;
+
+ unsigned int inbufoffset;
+ char* inbuf; // Buffer OpenSSL reads into.
+ std::string outbuf; // Buffer for outgoing data that OpenSSL will not take.
+ int fd;
+ bool outbound;
+
+ issl_session()
+ {
+ outbound = false;
+ rstat = ISSL_READ;
+ wstat = ISSL_WRITE;
+ }
+};
+
+static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
+{
+ /* XXX: This will allow self signed certificates.
+ * In the future if we want an option to not allow this,
+ * we can just return preverify_ok here, and openssl
+ * will boot off self-signed and invalid peer certs.
+ */
+ int ve = X509_STORE_CTX_get_error(ctx);
+
+ SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT);
+
+ return 1;
+}
+
+class ModuleSSLOpenSSL : public Module
+{
+
+ ConfigReader* Conf;
+
+ std::vector<int> listenports;
+
+ int inbufsize;
+ issl_session sessions[MAX_DESCRIPTORS];
+
+ SSL_CTX* ctx;
+ SSL_CTX* clictx;
+
+ char* dummy;
+ char cipher[MAXBUF];
+
+ std::string keyfile;
+ std::string certfile;
+ std::string cafile;
+ // std::string crlfile;
+ std::string dhfile;
+ std::string sslports;
+
+ int clientactive;
+
+ public:
+
+ InspIRCd* PublicInstance;
+
+ ModuleSSLOpenSSL(InspIRCd* Me)
+ : Module(Me), PublicInstance(Me)
+ {
+ ServerInstance->PublishInterface("InspSocketHook", this);
+
+ // Not rehashable...because I cba to reduce all the sizes of existing buffers.
+ inbufsize = ServerInstance->Config->NetBufferSize;
+
+ /* Global SSL library initialization*/
+ SSL_library_init();
+ SSL_load_error_strings();
+
+ /* Build our SSL contexts:
+ * NOTE: OpenSSL makes us have two contexts, one for servers and one for clients. ICK.
+ */
+ ctx = SSL_CTX_new( SSLv23_server_method() );
+ clictx = SSL_CTX_new( SSLv23_client_method() );
+
+ SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+ SSL_CTX_set_verify(clictx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+
+ // Needs the flag as it ignores a plain /rehash
+ OnRehash(NULL,"ssl");
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &param)
+ {
+ if (param != "ssl")
+ return;
+
+ Conf = new ConfigReader(ServerInstance);
+
+ for (unsigned int i = 0; i < listenports.size(); i++)
+ {
+ ServerInstance->Config->DelIOHook(listenports[i]);
+ }
+
+ listenports.clear();
+ clientactive = 0;
+ sslports.clear();
+
+ for (int i = 0; i < Conf->Enumerate("bind"); i++)
+ {
+ // For each <bind> tag
+ std::string x = Conf->ReadValue("bind", "type", i);
+ if (((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "openssl"))
+ {
+ // Get the port we're meant to be listening on with SSL
+ std::string port = Conf->ReadValue("bind", "port", i);
+ irc::portparser portrange(port, false);
+ long portno = -1;
+ while ((portno = portrange.GetToken()))
+ {
+ clientactive++;
+ try
+ {
+ if (ServerInstance->Config->AddIOHook(portno, this))
+ {
+ listenports.push_back(portno);
+ for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
+ if (ServerInstance->Config->ports[i]->GetPort() == portno)
+ ServerInstance->Config->ports[i]->SetDescription("ssl");
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Enabling SSL for port %d", portno);
+ sslports.append("*:").append(ConvToStr(portno)).append(";");
+ }
+ else
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno);
+ }
+ }
+ catch (ModuleException &e)
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have another SSL or similar module loaded?", portno, e.GetReason());
+ }
+ }
+ }
+ }
+
+ if (!sslports.empty())
+ sslports.erase(sslports.end() - 1);
+
+ std::string confdir(ServerInstance->ConfigFileName);
+ // +1 so we the path ends with a /
+ confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
+
+ cafile = Conf->ReadValue("openssl", "cafile", 0);
+ certfile = Conf->ReadValue("openssl", "certfile", 0);
+ keyfile = Conf->ReadValue("openssl", "keyfile", 0);
+ dhfile = Conf->ReadValue("openssl", "dhfile", 0);
+
+ // Set all the default values needed.
+ if (cafile.empty())
+ cafile = "ca.pem";
+
+ if (certfile.empty())
+ certfile = "cert.pem";
+
+ if (keyfile.empty())
+ keyfile = "key.pem";
+
+ if (dhfile.empty())
+ dhfile = "dhparams.pem";
+
+ // Prepend relative paths with the path to the config directory.
+ if (cafile[0] != '/')
+ cafile = confdir + cafile;
+
+ if (certfile[0] != '/')
+ certfile = confdir + certfile;
+
+ if (keyfile[0] != '/')
+ keyfile = confdir + keyfile;
+
+ if (dhfile[0] != '/')
+ dhfile = confdir + dhfile;
+
+ /* Load our keys and certificates
+ * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck.
+ */
+ if ((!SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str())) || (!SSL_CTX_use_certificate_chain_file(clictx, certfile.c_str())))
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read certificate file %s. %s", certfile.c_str(), strerror(errno));
+ ERR_print_errors_cb(error_callback, this);
+ }
+
+ if (((!SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) || (!SSL_CTX_use_PrivateKey_file(clictx, keyfile.c_str(), SSL_FILETYPE_PEM)))
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read key file %s. %s", keyfile.c_str(), strerror(errno));
+ ERR_print_errors_cb(error_callback, this);
+ }
+
+ /* Load the CAs we trust*/
+ if (((!SSL_CTX_load_verify_locations(ctx, cafile.c_str(), 0))) || (!SSL_CTX_load_verify_locations(clictx, cafile.c_str(), 0)))
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read CA list from %s. %s", cafile.c_str(), strerror(errno));
+ ERR_print_errors_cb(error_callback, this);
+ }
+
+ FILE* dhpfile = fopen(dhfile.c_str(), "r");
+ DH* ret;
+
+ if (dhpfile == NULL)
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so Couldn't open DH file %s: %s", dhfile.c_str(), strerror(errno));
+ throw ModuleException("Couldn't open DH file " + dhfile + ": " + strerror(errno));
+ }
+ else
+ {
+ ret = PEM_read_DHparams(dhpfile, NULL, NULL, NULL);
+ if ((SSL_CTX_set_tmp_dh(ctx, ret) < 0) || (SSL_CTX_set_tmp_dh(clictx, ret) < 0))
+ {
+ ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s. SSL errors follow:", dhfile.c_str());
+ ERR_print_errors_cb(error_callback, this);
+ }
+ }
+
+ fclose(dhpfile);
+
+ DELETE(Conf);
+ }
+
+ virtual void On005Numeric(std::string &output)
+ {
+ output.append(" SSL=" + sslports);
+ }
+
+ virtual ~ModuleSSLOpenSSL()
+ {
+ SSL_CTX_free(ctx);
+ SSL_CTX_free(clictx);
+ }
+
+ virtual void OnCleanup(int target_type, void* item)
+ {
+ if (target_type == TYPE_USER)
+ {
+ userrec* user = (userrec*)item;
+
+ if (user->GetExt("ssl", dummy) && IS_LOCAL(user) && isin(user->GetPort(), listenports))
+ {
+ // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
+ // Potentially there could be multiple SSL modules loaded at once on different ports.
+ ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading");
+ }
+ if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports))
+ {
+ ssl_cert* tofree;
+ user->GetExt("ssl_cert", tofree);
+ delete tofree;
+ user->Shrink("ssl_cert");
+ }
+ }
+ }
+
+ virtual void OnUnloadModule(Module* mod, const std::string &name)
+ {
+ if (mod == this)
+ {
+ for(unsigned int i = 0; i < listenports.size(); i++)
+ {
+ ServerInstance->Config->DelIOHook(listenports[i]);
+ for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
+ if (ServerInstance->Config->ports[j]->GetPort() == listenports[i])
+ ServerInstance->Config->ports[j]->SetDescription("plaintext");
+ }
+ }
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = List[I_On005Numeric] = 1;
+ List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1;
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ ISHRequest* ISR = (ISHRequest*)request;
+ if (strcmp("IS_NAME", request->GetId()) == 0)
+ {
+ return "openssl";
+ }
+ else if (strcmp("IS_HOOK", request->GetId()) == 0)
+ {
+ char* ret = "OK";
+ try
+ {
+ ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
+ }
+ catch (ModuleException &e)
+ {
+ return NULL;
+ }
+
+ return ret;
+ }
+ else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
+ {
+ return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
+ }
+ else if (strcmp("IS_HSDONE", request->GetId()) == 0)
+ {
+ ServerInstance->Log(DEBUG,"Module checking if handshake is done");
+ if (ISR->Sock->GetFd() < 0)
+ return (char*)"OK";
+
+ issl_session* session = &sessions[ISR->Sock->GetFd()];
+ return (session->status == ISSL_HANDSHAKING) ? NULL : (char*)"OK";
+ }
+ else if (strcmp("IS_ATTACH", request->GetId()) == 0)
+ {
+ issl_session* session = &sessions[ISR->Sock->GetFd()];
+ if (session->sess)
+ {
+ VerifyCertificate(session, (InspSocket*)ISR->Sock);
+ return "OK";
+ }
+ }
+ return NULL;
+ }
+
+
+ virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
+ {
+ issl_session* session = &sessions[fd];
+
+ session->fd = fd;
+ session->inbuf = new char[inbufsize];
+ session->inbufoffset = 0;
+ session->sess = SSL_new(ctx);
+ session->status = ISSL_NONE;
+ session->outbound = false;
+
+ if (session->sess == NULL)
+ return;
+
+ if (SSL_set_fd(session->sess, fd) == 0)
+ {
+ ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd);
+ return;
+ }
+
+ Handshake(session);
+ }
+
+ virtual void OnRawSocketConnect(int fd)
+ {
+ ServerInstance->Log(DEBUG,"OnRawSocketConnect connecting");
+ issl_session* session = &sessions[fd];
+
+ session->fd = fd;
+ session->inbuf = new char[inbufsize];
+ session->inbufoffset = 0;
+ session->sess = SSL_new(clictx);
+ session->status = ISSL_NONE;
+ session->outbound = true;
+
+ if (session->sess == NULL)
+ return;
+
+ if (SSL_set_fd(session->sess, fd) == 0)
+ {
+ ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd);
+ return;
+ }
+
+ Handshake(session);
+ ServerInstance->Log(DEBUG,"Exiting OnRawSocketConnect");
+ }
+
+ virtual void OnRawSocketClose(int fd)
+ {
+ CloseSession(&sessions[fd]);
+
+ EventHandler* user = ServerInstance->SE->GetRef(fd);
+
+ if ((user) && (user->GetExt("ssl_cert", dummy)))
+ {
+ ssl_cert* tofree;
+ user->GetExt("ssl_cert", tofree);
+ delete tofree;
+ user->Shrink("ssl_cert");
+ }
+ }
+
+ virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
+ {
+ issl_session* session = &sessions[fd];
+
+ ServerInstance->Log(DEBUG,"OnRawSocketRead");
+
+ if (!session->sess)
+ {
+ ServerInstance->Log(DEBUG,"OnRawSocketRead has no session");
+ readresult = 0;
+ CloseSession(session);
+ return 1;
+ }
+
+ if (session->status == ISSL_HANDSHAKING)
+ {
+ if (session->rstat == ISSL_READ || session->wstat == ISSL_READ)
+ {
+ ServerInstance->Log(DEBUG,"Resume handshake in read");
+ // The handshake isn't finished and it wants to read, try to finish it.
+ if (!Handshake(session))
+ {
+ ServerInstance->Log(DEBUG,"Cant resume handshake in read");
+ // Couldn't resume handshake.
+ return -1;
+ }
+ }
+ else
+ {
+ errno = EAGAIN;
+ return -1;
+ }
+ }
+
+ // If we resumed the handshake then session->status will be ISSL_OPEN
+
+ if (session->status == ISSL_OPEN)
+ {
+ if (session->wstat == ISSL_READ)
+ {
+ if(DoWrite(session) == 0)
+ return 0;
+ }
+
+ if (session->rstat == ISSL_READ)
+ {
+ int ret = DoRead(session);
+
+ if (ret > 0)
+ {
+ if (count <= session->inbufoffset)
+ {
+ memcpy(buffer, session->inbuf, count);
+ // Move the stuff left in inbuf to the beginning of it
+ memcpy(session->inbuf, session->inbuf + count, (session->inbufoffset - count));
+ // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
+ session->inbufoffset -= count;
+ // Insp uses readresult as the count of how much data there is in buffer, so:
+ readresult = count;
+ }
+ else
+ {
+ // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
+ memcpy(buffer, session->inbuf, session->inbufoffset);
+
+ readresult = session->inbufoffset;
+ // Zero the offset, as there's nothing there..
+ session->inbufoffset = 0;
+ }
+
+ return 1;
+ }
+ else
+ {
+ return ret;
+ }
+ }
+ }
+
+ return -1;
+ }
+
+ virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
+ {
+ issl_session* session = &sessions[fd];
+
+ if (!session->sess)
+ {
+ ServerInstance->Log(DEBUG,"Close session missing sess");
+ CloseSession(session);
+ return -1;
+ }
+
+ session->outbuf.append(buffer, count);
+
+ if (session->status == ISSL_HANDSHAKING)
+ {
+ // The handshake isn't finished, try to finish it.
+ if (session->rstat == ISSL_WRITE || session->wstat == ISSL_WRITE)
+ {
+ ServerInstance->Log(DEBUG,"Handshake resume");
+ Handshake(session);
+ }
+ }
+
+ if (session->status == ISSL_OPEN)
+ {
+ if (session->rstat == ISSL_WRITE)
+ {
+ ServerInstance->Log(DEBUG,"DoRead");
+ DoRead(session);
+ }
+
+ if (session->wstat == ISSL_WRITE)
+ {
+ ServerInstance->Log(DEBUG,"DoWrite");
+ return DoWrite(session);
+ }
+ }
+
+ return 1;
+ }
+
+ int DoWrite(issl_session* session)
+ {
+ if (!session->outbuf.size())
+ return -1;
+
+ int ret = SSL_write(session->sess, session->outbuf.data(), session->outbuf.size());
+
+ if (ret == 0)
+ {
+ ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_write");
+ CloseSession(session);
+ return 0;
+ }
+ else if (ret < 0)
+ {
+ int err = SSL_get_error(session->sess, ret);
+
+ if (err == SSL_ERROR_WANT_WRITE)
+ {
+ session->wstat = ISSL_WRITE;
+ return -1;
+ }
+ else if (err == SSL_ERROR_WANT_READ)
+ {
+ session->wstat = ISSL_READ;
+ return -1;
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG,"Close due to returned -1 in SSL_Write");
+ CloseSession(session);
+ return 0;
+ }
+ }
+ else
+ {
+ session->outbuf = session->outbuf.substr(ret);
+ return ret;
+ }
+ }
+
+ int DoRead(issl_session* session)
+ {
+ // Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
+ // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
+
+ ServerInstance->Log(DEBUG,"DoRead");
+
+ int ret = SSL_read(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
+
+ if (ret == 0)
+ {
+ // Client closed connection.
+ ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_read");
+ CloseSession(session);
+ return 0;
+ }
+ else if (ret < 0)
+ {
+ int err = SSL_get_error(session->sess, ret);
+
+ if (err == SSL_ERROR_WANT_READ)
+ {
+ session->rstat = ISSL_READ;
+ ServerInstance->Log(DEBUG,"Setting want_read");
+ return -1;
+ }
+ else if (err == SSL_ERROR_WANT_WRITE)
+ {
+ session->rstat = ISSL_WRITE;
+ ServerInstance->Log(DEBUG,"Setting want_write");
+ return -1;
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG,"Closed due to returned -1 in SSL_Read");
+ CloseSession(session);
+ return 0;
+ }
+ }
+ else
+ {
+ // Read successfully 'ret' bytes into inbuf + inbufoffset
+ // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
+ // 'buffer' is 'count' long
+
+ session->inbufoffset += ret;
+
+ return ret;
+ }
+ }
+
+ // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
+ virtual void OnWhois(userrec* source, userrec* dest)
+ {
+ if (!clientactive)
+ return;
+
+ // Bugfix, only send this numeric for *our* SSL users
+ if (dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports)))
+ {
+ ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick);
+ }
+ }
+
+ virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
+ {
+ // check if the linking module wants to know about OUR metadata
+ if (extname == "ssl")
+ {
+ // check if this user has an swhois field to send
+ if(user->GetExt(extname, dummy))
+ {
+ // call this function in the linking module, let it format the data how it
+ // sees fit, and send it on its way. We dont need or want to know how.
+ proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
+ }
+ }
+ }
+
+ virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
+ {
+ // check if its our metadata key, and its associated with a user
+ if ((target_type == TYPE_USER) && (extname == "ssl"))
+ {
+ userrec* dest = (userrec*)target;
+ // if they dont already have an ssl flag, accept the remote server's
+ if (!dest->GetExt(extname, dummy))
+ {
+ dest->Extend(extname, "ON");
+ }
+ }
+ }
+
+ bool Handshake(issl_session* session)
+ {
+ ServerInstance->Log(DEBUG,"Handshake");
+ int ret;
+
+ if (session->outbound)
+ {
+ ServerInstance->Log(DEBUG,"SSL_connect");
+ ret = SSL_connect(session->sess);
+ }
+ else
+ ret = SSL_accept(session->sess);
+
+ if (ret < 0)
+ {
+ int err = SSL_get_error(session->sess, ret);
+
+ if (err == SSL_ERROR_WANT_READ)
+ {
+ ServerInstance->Log(DEBUG,"Want read, handshaking");
+ session->rstat = ISSL_READ;
+ session->status = ISSL_HANDSHAKING;
+ return true;
+ }
+ else if (err == SSL_ERROR_WANT_WRITE)
+ {
+ ServerInstance->Log(DEBUG,"Want write, handshaking");
+ session->wstat = ISSL_WRITE;
+ session->status = ISSL_HANDSHAKING;
+ MakePollWrite(session);
+ return true;
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG,"Handshake failed");
+ CloseSession(session);
+ }
+
+ return false;
+ }
+ else if (ret > 0)
+ {
+ // Handshake complete.
+ // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
+ userrec* u = ServerInstance->FindDescriptor(session->fd);
+ if (u)
+ {
+ if (!u->GetExt("ssl", dummy))
+ u->Extend("ssl", "ON");
+ }
+
+ session->status = ISSL_OPEN;
+
+ MakePollWrite(session);
+
+ return true;
+ }
+ else if (ret == 0)
+ {
+ int ssl_err = SSL_get_error(session->sess, ret);
+ char buf[1024];
+ ERR_print_errors_fp(stderr);
+ ServerInstance->Log(DEBUG,"Handshake fail 2: %d: %s", ssl_err, ERR_error_string(ssl_err,buf));
+ CloseSession(session);
+ return true;
+ }
+
+ return true;
+ }
+
+ virtual void OnPostConnect(userrec* user)
+ {
+ // This occurs AFTER OnUserConnect so we can be sure the
+ // protocol module has propogated the NICK message.
+ if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user)))
+ {
+ // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
+ std::deque<std::string>* metadata = new std::deque<std::string>;
+ metadata->push_back(user->nick);
+ metadata->push_back("ssl"); // The metadata id
+ metadata->push_back("ON"); // The value to send
+ Event* event = new Event((char*)metadata,(Module*)this,"send_metadata");
+ event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up.
+ DELETE(event);
+ DELETE(metadata);
+
+ VerifyCertificate(&sessions[user->GetFd()], user);
+ if (sessions[user->GetFd()].sess)
+ user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, SSL_get_cipher(sessions[user->GetFd()].sess));
+ }
+ }
+
+ void MakePollWrite(issl_session* session)
+ {
+ OnRawSocketWrite(session->fd, NULL, 0);
+ //EventHandler* eh = ServerInstance->FindDescriptor(session->fd);
+ //if (eh)
+ // ServerInstance->SE->WantWrite(eh);
+ }
+
+ void CloseSession(issl_session* session)
+ {
+ if (session->sess)
+ {
+ SSL_shutdown(session->sess);
+ SSL_free(session->sess);
+ }
+
+ if (session->inbuf)
+ {
+ delete[] session->inbuf;
+ }
+
+ session->outbuf.clear();
+ session->inbuf = NULL;
+ session->sess = NULL;
+ session->status = ISSL_NONE;
+ }
+
+ void VerifyCertificate(issl_session* session, Extensible* user)
+ {
+ if (!session->sess || !user)
+ return;
+
+ X509* cert;
+ ssl_cert* certinfo = new ssl_cert;
+ unsigned int n;
+ unsigned char md[EVP_MAX_MD_SIZE];
+ const EVP_MD *digest = EVP_md5();
+
+ user->Extend("ssl_cert",certinfo);
+
+ cert = SSL_get_peer_certificate((SSL*)session->sess);
+
+ if (!cert)
+ {
+ certinfo->data.insert(std::make_pair("error","Could not get peer certificate: "+std::string(get_error())));
+ return;
+ }
+
+ certinfo->data.insert(std::make_pair("invalid", SSL_get_verify_result(session->sess) != X509_V_OK ? ConvToStr(1) : ConvToStr(0)));
+
+ if (SelfSigned)
+ {
+ certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
+ certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
+ certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
+ }
+
+ certinfo->data.insert(std::make_pair("dn",std::string(X509_NAME_oneline(X509_get_subject_name(cert),0,0))));
+ certinfo->data.insert(std::make_pair("issuer",std::string(X509_NAME_oneline(X509_get_issuer_name(cert),0,0))));
+
+ if (!X509_digest(cert, digest, md, &n))
+ {
+ certinfo->data.insert(std::make_pair("error","Out of memory generating fingerprint"));
+ }
+ else
+ {
+ certinfo->data.insert(std::make_pair("fingerprint",irc::hex(md, n)));
+ }
+
+ if ((ASN1_UTCTIME_cmp_time_t(X509_get_notAfter(cert), time(NULL)) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_get_notBefore(cert), time(NULL)) == 0))
+ {
+ certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
+ }
+
+ X509_free(cert);
+ }
+};
+
+static int error_callback(const char *str, size_t len, void *u)
+{
+ ModuleSSLOpenSSL* mssl = (ModuleSSLOpenSSL*)u;
+ mssl->PublicInstance->Log(DEFAULT, "SSL error: " + std::string(str, len - 1));
+ return 0;
+}
+
+MODULE_INIT(ModuleSSLOpenSSL);
+
diff --git a/src/modules/extra/m_ssl_oper_cert.cpp b/src/modules/extra/m_ssl_oper_cert.cpp
index 7b1c90868..c67b50c8c 100644
--- a/src/modules/extra/m_ssl_oper_cert.cpp
+++ b/src/modules/extra/m_ssl_oper_cert.cpp
@@ -1 +1,180 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ /* $ModDesc: Allows for MD5 encrypted oper passwords */ /* $ModDep: transport.h */ #include "inspircd.h" #include "inspircd_config.h" #include "users.h" #include "channels.h" #include "modules.h" #include "transport.h" #include "wildcard.h" /** Handle /FINGERPRINT */ class cmd_fingerprint : public command_t { public: cmd_fingerprint (InspIRCd* Instance) : command_t(Instance,"FINGERPRINT", 0, 1) { this->source = "m_ssl_oper_cert.so"; syntax = "<nickname>"; } CmdResult Handle (const char** parameters, int pcnt, userrec *user) { userrec* target = ServerInstance->FindNick(parameters[0]); if (target) { ssl_cert* cert; if (target->GetExt("ssl_cert",cert)) { if (cert->GetFingerprint().length()) { user->WriteServ("NOTICE %s :Certificate fingerprint for %s is %s",user->nick,target->nick,cert->GetFingerprint().c_str()); return CMD_SUCCESS; } else { user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick,target->nick); return CMD_FAILURE; } } else { user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick, target->nick); return CMD_FAILURE; } } else { user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]); return CMD_FAILURE; } } }; class ModuleOperSSLCert : public Module { ssl_cert* cert; bool HasCert; cmd_fingerprint* mycommand; ConfigReader* cf; public: ModuleOperSSLCert(InspIRCd* Me) : Module(Me) { mycommand = new cmd_fingerprint(ServerInstance); ServerInstance->AddCommand(mycommand); cf = new ConfigReader(ServerInstance); } virtual ~ModuleOperSSLCert() { delete cf; } void Implements(char* List) { List[I_OnPreCommand] = List[I_OnRehash] = 1; } virtual void OnRehash(userrec* user, const std::string &parameter) { delete cf; cf = new ConfigReader(ServerInstance); } bool OneOfMatches(const char* host, const char* ip, const char* hostlist) { std::stringstream hl(hostlist); std::string xhost; while (hl >> xhost) { if (match(host,xhost.c_str()) || match(ip,xhost.c_str(),true)) { return true; } } return false; } virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line) { irc::string cmd = command.c_str(); if ((cmd == "OPER") && (validated)) { char TheHost[MAXBUF]; char TheIP[MAXBUF]; std::string LoginName; std::string Password; std::string OperType; std::string HostName; std::string FingerPrint; bool SSLOnly; char* dummy; snprintf(TheHost,MAXBUF,"%s@%s",user->ident,user->host); snprintf(TheIP, MAXBUF,"%s@%s",user->ident,user->GetIPString()); HasCert = user->GetExt("ssl_cert",cert); for (int i = 0; i < cf->Enumerate("oper"); i++) { LoginName = cf->ReadValue("oper", "name", i); Password = cf->ReadValue("oper", "password", i); OperType = cf->ReadValue("oper", "type", i); HostName = cf->ReadValue("oper", "host", i); FingerPrint = cf->ReadValue("oper", "fingerprint", i); SSLOnly = cf->ReadFlag("oper", "sslonly", i); if (SSLOnly || !FingerPrint.empty()) { if ((!strcmp(LoginName.c_str(),parameters[0])) && (!ServerInstance->OperPassCompare(Password.c_str(),parameters[1],i)) && (OneOfMatches(TheHost,TheIP,HostName.c_str()))) { if (SSLOnly && !user->GetExt("ssl", dummy)) { user->WriteServ("491 %s :This oper login name requires an SSL connection.", user->nick); return 1; } /* This oper would match */ if ((!cert) || (cert->GetFingerprint() != FingerPrint)) { user->WriteServ("491 %s :This oper login name requires a matching key fingerprint.",user->nick); ServerInstance->SNO->WriteToSnoMask('o',"'%s' cannot oper, does not match fingerprint", user->nick); ServerInstance->Log(DEFAULT,"OPER: Failed oper attempt by %s!%s@%s: credentials valid, but wrong fingerprint.",user->nick,user->ident,user->host); return 1; } } } } } return 0; } virtual Version GetVersion() { return Version(1,1,0,0,VF_VENDOR,API_VERSION); } }; MODULE_INIT(ModuleOperSSLCert); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+/* $ModDesc: Allows for MD5 encrypted oper passwords */
+/* $ModDep: transport.h */
+
+#include "inspircd.h"
+#include "inspircd_config.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "transport.h"
+#include "wildcard.h"
+
+/** Handle /FINGERPRINT
+ */
+class cmd_fingerprint : public command_t
+{
+ public:
+ cmd_fingerprint (InspIRCd* Instance) : command_t(Instance,"FINGERPRINT", 0, 1)
+ {
+ this->source = "m_ssl_oper_cert.so";
+ syntax = "<nickname>";
+ }
+
+ CmdResult Handle (const char** parameters, int pcnt, userrec *user)
+ {
+ userrec* target = ServerInstance->FindNick(parameters[0]);
+ if (target)
+ {
+ ssl_cert* cert;
+ if (target->GetExt("ssl_cert",cert))
+ {
+ if (cert->GetFingerprint().length())
+ {
+ user->WriteServ("NOTICE %s :Certificate fingerprint for %s is %s",user->nick,target->nick,cert->GetFingerprint().c_str());
+ return CMD_SUCCESS;
+ }
+ else
+ {
+ user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick,target->nick);
+ return CMD_FAILURE;
+ }
+ }
+ else
+ {
+ user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick, target->nick);
+ return CMD_FAILURE;
+ }
+ }
+ else
+ {
+ user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]);
+ return CMD_FAILURE;
+ }
+ }
+};
+
+
+
+class ModuleOperSSLCert : public Module
+{
+ ssl_cert* cert;
+ bool HasCert;
+ cmd_fingerprint* mycommand;
+ ConfigReader* cf;
+ public:
+
+ ModuleOperSSLCert(InspIRCd* Me)
+ : Module(Me)
+ {
+ mycommand = new cmd_fingerprint(ServerInstance);
+ ServerInstance->AddCommand(mycommand);
+ cf = new ConfigReader(ServerInstance);
+ }
+
+ virtual ~ModuleOperSSLCert()
+ {
+ delete cf;
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnPreCommand] = List[I_OnRehash] = 1;
+ }
+
+ virtual void OnRehash(userrec* user, const std::string &parameter)
+ {
+ delete cf;
+ cf = new ConfigReader(ServerInstance);
+ }
+
+ bool OneOfMatches(const char* host, const char* ip, const char* hostlist)
+ {
+ std::stringstream hl(hostlist);
+ std::string xhost;
+ while (hl >> xhost)
+ {
+ if (match(host,xhost.c_str()) || match(ip,xhost.c_str(),true))
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+
+ virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line)
+ {
+ irc::string cmd = command.c_str();
+
+ if ((cmd == "OPER") && (validated))
+ {
+ char TheHost[MAXBUF];
+ char TheIP[MAXBUF];
+ std::string LoginName;
+ std::string Password;
+ std::string OperType;
+ std::string HostName;
+ std::string FingerPrint;
+ bool SSLOnly;
+ char* dummy;
+
+ snprintf(TheHost,MAXBUF,"%s@%s",user->ident,user->host);
+ snprintf(TheIP, MAXBUF,"%s@%s",user->ident,user->GetIPString());
+
+ HasCert = user->GetExt("ssl_cert",cert);
+
+ for (int i = 0; i < cf->Enumerate("oper"); i++)
+ {
+ LoginName = cf->ReadValue("oper", "name", i);
+ Password = cf->ReadValue("oper", "password", i);
+ OperType = cf->ReadValue("oper", "type", i);
+ HostName = cf->ReadValue("oper", "host", i);
+ FingerPrint = cf->ReadValue("oper", "fingerprint", i);
+ SSLOnly = cf->ReadFlag("oper", "sslonly", i);
+
+ if (SSLOnly || !FingerPrint.empty())
+ {
+ if ((!strcmp(LoginName.c_str(),parameters[0])) && (!ServerInstance->OperPassCompare(Password.c_str(),parameters[1],i)) && (OneOfMatches(TheHost,TheIP,HostName.c_str())))
+ {
+ if (SSLOnly && !user->GetExt("ssl", dummy))
+ {
+ user->WriteServ("491 %s :This oper login name requires an SSL connection.", user->nick);
+ return 1;
+ }
+
+ /* This oper would match */
+ if ((!cert) || (cert->GetFingerprint() != FingerPrint))
+ {
+ user->WriteServ("491 %s :This oper login name requires a matching key fingerprint.",user->nick);
+ ServerInstance->SNO->WriteToSnoMask('o',"'%s' cannot oper, does not match fingerprint", user->nick);
+ ServerInstance->Log(DEFAULT,"OPER: Failed oper attempt by %s!%s@%s: credentials valid, but wrong fingerprint.",user->nick,user->ident,user->host);
+ return 1;
+ }
+ }
+ }
+ }
+ }
+ return 0;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1,1,0,0,VF_VENDOR,API_VERSION);
+ }
+};
+
+MODULE_INIT(ModuleOperSSLCert);
+
diff --git a/src/modules/extra/m_sslinfo.cpp b/src/modules/extra/m_sslinfo.cpp
index 83de798c8..dc9274f1e 100644
--- a/src/modules/extra/m_sslinfo.cpp
+++ b/src/modules/extra/m_sslinfo.cpp
@@ -1 +1,94 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include "users.h" #include "channels.h" #include "modules.h" #include "transport.h" #include "wildcard.h" #include "dns.h" /* $ModDesc: Provides /sslinfo command used to test who a mask matches */ /* $ModDep: transport.h */ /** Handle /SSLINFO */ class cmd_sslinfo : public command_t { public: cmd_sslinfo (InspIRCd* Instance) : command_t(Instance,"SSLINFO", 0, 1) { this->source = "m_sslinfo.so"; this->syntax = "<nick>"; } CmdResult Handle (const char** parameters, int pcnt, userrec *user) { userrec* target = ServerInstance->FindNick(parameters[0]); ssl_cert* cert; if (target) { if (target->GetExt("ssl_cert", cert)) { if (cert->GetError().length()) { user->WriteServ("NOTICE %s :*** Error: %s", user->nick, cert->GetError().c_str()); } user->WriteServ("NOTICE %s :*** Distinguised Name: %s", user->nick, cert->GetDN().c_str()); user->WriteServ("NOTICE %s :*** Issuer: %s", user->nick, cert->GetIssuer().c_str()); user->WriteServ("NOTICE %s :*** Key Fingerprint: %s", user->nick, cert->GetFingerprint().c_str()); return CMD_SUCCESS; } else { user->WriteServ("NOTICE %s :*** No SSL certificate information for this user.", user->nick); return CMD_FAILURE; } } else user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]); return CMD_FAILURE; } }; class ModuleSSLInfo : public Module { cmd_sslinfo* newcommand; public: ModuleSSLInfo(InspIRCd* Me) : Module(Me) { newcommand = new cmd_sslinfo(ServerInstance); ServerInstance->AddCommand(newcommand); } void Implements(char* List) { } virtual ~ModuleSSLInfo() { } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); } }; MODULE_INIT(ModuleSSLInfo); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "transport.h"
+#include "wildcard.h"
+#include "dns.h"
+
+/* $ModDesc: Provides /sslinfo command used to test who a mask matches */
+/* $ModDep: transport.h */
+
+/** Handle /SSLINFO
+ */
+class cmd_sslinfo : public command_t
+{
+ public:
+ cmd_sslinfo (InspIRCd* Instance) : command_t(Instance,"SSLINFO", 0, 1)
+ {
+ this->source = "m_sslinfo.so";
+ this->syntax = "<nick>";
+ }
+
+ CmdResult Handle (const char** parameters, int pcnt, userrec *user)
+ {
+ userrec* target = ServerInstance->FindNick(parameters[0]);
+ ssl_cert* cert;
+
+ if (target)
+ {
+ if (target->GetExt("ssl_cert", cert))
+ {
+ if (cert->GetError().length())
+ {
+ user->WriteServ("NOTICE %s :*** Error: %s", user->nick, cert->GetError().c_str());
+ }
+ user->WriteServ("NOTICE %s :*** Distinguised Name: %s", user->nick, cert->GetDN().c_str());
+ user->WriteServ("NOTICE %s :*** Issuer: %s", user->nick, cert->GetIssuer().c_str());
+ user->WriteServ("NOTICE %s :*** Key Fingerprint: %s", user->nick, cert->GetFingerprint().c_str());
+ return CMD_SUCCESS;
+ }
+ else
+ {
+ user->WriteServ("NOTICE %s :*** No SSL certificate information for this user.", user->nick);
+ return CMD_FAILURE;
+ }
+ }
+ else
+ user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]);
+
+ return CMD_FAILURE;
+ }
+};
+
+class ModuleSSLInfo : public Module
+{
+ cmd_sslinfo* newcommand;
+ public:
+ ModuleSSLInfo(InspIRCd* Me)
+ : Module(Me)
+ {
+
+ newcommand = new cmd_sslinfo(ServerInstance);
+ ServerInstance->AddCommand(newcommand);
+ }
+
+ void Implements(char* List)
+ {
+ }
+
+ virtual ~ModuleSSLInfo()
+ {
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
+ }
+};
+
+MODULE_INIT(ModuleSSLInfo);
+
diff --git a/src/modules/extra/m_testclient.cpp b/src/modules/extra/m_testclient.cpp
index a867dad20..f4e58b7b5 100644
--- a/src/modules/extra/m_testclient.cpp
+++ b/src/modules/extra/m_testclient.cpp
@@ -1 +1,110 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include "users.h" #include "channels.h" #include "modules.h" #include "configreader.h" #include "m_sqlv2.h" class ModuleTestClient : public Module { private: public: ModuleTestClient(InspIRCd* Me) : Module::Module(Me) { } void Implements(char* List) { List[I_OnRequest] = List[I_OnBackgroundTimer] = 1; } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); } virtual void OnBackgroundTimer(time_t foo) { Module* target = ServerInstance->FindFeature("SQL"); if(target) { SQLrequest foo = SQLreq(this, target, "foo", "UPDATE rawr SET foo = '?' WHERE bar = 42", ConvToStr(time(NULL))); if(foo.Send()) { ServerInstance->Log(DEBUG, "Sent query, got given ID %lu", foo.id); } else { ServerInstance->Log(DEBUG, "SQLrequest failed: %s", foo.error.Str()); } } } virtual char* OnRequest(Request* request) { if(strcmp(SQLRESID, request->GetId()) == 0) { ServerInstance->Log(DEBUG, "Got SQL result (%s)", request->GetId()); SQLresult* res = (SQLresult*)request; if (res->error.Id() == NO_ERROR) { if(res->Cols()) { ServerInstance->Log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols()); for (int r = 0; r < res->Rows(); r++) { ServerInstance->Log(DEBUG, "Row %d:", r); for(int i = 0; i < res->Cols(); i++) { ServerInstance->Log(DEBUG, "\t[%s]: %s", res->ColName(i).c_str(), res->GetValue(r, i).d.c_str()); } } } else { ServerInstance->Log(DEBUG, "%d rows affected in query", res->Rows()); } } else { ServerInstance->Log(DEBUG, "SQLrequest failed: %s", res->error.Str()); } return SQLSUCCESS; } ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId()); return NULL; } virtual ~ModuleTestClient() { } }; MODULE_INIT(ModuleTestClient); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "configreader.h"
+#include "m_sqlv2.h"
+
+class ModuleTestClient : public Module
+{
+private:
+
+
+public:
+ ModuleTestClient(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnRequest] = List[I_OnBackgroundTimer] = 1;
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
+ }
+
+ virtual void OnBackgroundTimer(time_t foo)
+ {
+ Module* target = ServerInstance->FindFeature("SQL");
+
+ if(target)
+ {
+ SQLrequest foo = SQLreq(this, target, "foo", "UPDATE rawr SET foo = '?' WHERE bar = 42", ConvToStr(time(NULL)));
+
+ if(foo.Send())
+ {
+ ServerInstance->Log(DEBUG, "Sent query, got given ID %lu", foo.id);
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG, "SQLrequest failed: %s", foo.error.Str());
+ }
+ }
+ }
+
+ virtual char* OnRequest(Request* request)
+ {
+ if(strcmp(SQLRESID, request->GetId()) == 0)
+ {
+ ServerInstance->Log(DEBUG, "Got SQL result (%s)", request->GetId());
+
+ SQLresult* res = (SQLresult*)request;
+
+ if (res->error.Id() == NO_ERROR)
+ {
+ if(res->Cols())
+ {
+ ServerInstance->Log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols());
+
+ for (int r = 0; r < res->Rows(); r++)
+ {
+ ServerInstance->Log(DEBUG, "Row %d:", r);
+
+ for(int i = 0; i < res->Cols(); i++)
+ {
+ ServerInstance->Log(DEBUG, "\t[%s]: %s", res->ColName(i).c_str(), res->GetValue(r, i).d.c_str());
+ }
+ }
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG, "%d rows affected in query", res->Rows());
+ }
+ }
+ else
+ {
+ ServerInstance->Log(DEBUG, "SQLrequest failed: %s", res->error.Str());
+
+ }
+
+ return SQLSUCCESS;
+ }
+
+ ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
+
+ return NULL;
+ }
+
+ virtual ~ModuleTestClient()
+ {
+ }
+};
+
+MODULE_INIT(ModuleTestClient);
+
diff --git a/src/modules/extra/m_ziplink.cpp b/src/modules/extra/m_ziplink.cpp
index 2a127258d..e815d1042 100644
--- a/src/modules/extra/m_ziplink.cpp
+++ b/src/modules/extra/m_ziplink.cpp
@@ -1 +1,452 @@
-/* +------------------------------------+ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * * InspIRCd: (C) 2002-2007 InspIRCd Development Team * See: http://www.inspircd.org/wiki/index.php/Credits * * This program is free but copyrighted software; see * the file COPYING for details. * * --------------------------------------------------- */ #include "inspircd.h" #include <zlib.h> #include "users.h" #include "channels.h" #include "modules.h" #include "socket.h" #include "hashcomp.h" #include "transport.h" /* $ModDesc: Provides zlib link support for servers */ /* $LinkerFlags: -lz */ /* $ModDep: transport.h */ /* * Compressed data is transmitted across the link in the following format: * * 0 1 2 3 4 ... n * +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ * | n | Z0 -> Zn | * +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ * * Where: n is the size of a frame, in network byte order, 4 bytes. * Z0 through Zn are Zlib compressed data, n bytes in length. * * If the module fails to read the entire frame, then it will buffer * the portion of the last frame it received, then attempt to read * the next part of the frame next time a write notification arrives. * * ZLIB_BEST_COMPRESSION (9) is used for all sending of data with * a flush after each frame. A frame may contain multiple lines * and should be treated as raw binary data. * */ /* Status of a connection */ enum izip_status { IZIP_OPEN, IZIP_CLOSED }; /* Maximum transfer size per read operation */ const unsigned int CHUNK = 128 * 1024; /* This class manages a compressed chunk of data preceeded by * a length count. * * It can handle having multiple chunks of data in the buffer * at any time. */ class CountedBuffer : public classbase { std::string buffer; /* Current buffer contents */ unsigned int amount_expected; /* Amount of data expected */ public: CountedBuffer() { amount_expected = 0; } /** Adds arbitrary compressed data to the buffer. * - Binsry safe, of course. */ void AddData(unsigned char* data, int data_length) { buffer.append((const char*)data, data_length); this->NextFrameSize(); } /** Works out the size of the next compressed frame */ void NextFrameSize() { if ((!amount_expected) && (buffer.length() >= 4)) { /* We have enough to read an int - * Yes, this is safe, but its ugly. Give me * a nicer way to read 4 bytes from a binary * stream, and push them into a 32 bit int, * and i'll consider replacing this. */ amount_expected = ntohl((buffer[3] << 24) | (buffer[2] << 16) | (buffer[1] << 8) | buffer[0]); buffer = buffer.substr(4); } } /** Gets the next frame and returns its size, or returns * zero if there isnt one available yet. * A frame can contain multiple plaintext lines. * - Binary safe. */ int GetFrame(unsigned char* frame, int maxsize) { if (amount_expected) { /* We know how much we're expecting... * Do we have enough yet? */ if (buffer.length() >= amount_expected) { int j = 0; for (unsigned int i = 0; i < amount_expected; i++, j++) frame[i] = buffer[i]; buffer = buffer.substr(j); amount_expected = 0; NextFrameSize(); return j; } } /* Not enough for a frame yet, COME AGAIN! */ return 0; } }; /** Represents an zipped connections extra data */ class izip_session : public classbase { public: z_stream c_stream; /* compression stream */ z_stream d_stream; /* decompress stream */ izip_status status; /* Connection status */ int fd; /* File descriptor */ CountedBuffer* inbuf; /* Holds input buffer */ std::string outbuf; /* Holds output buffer */ }; class ModuleZLib : public Module { izip_session sessions[MAX_DESCRIPTORS]; /* Used for stats z extensions */ float total_out_compressed; float total_in_compressed; float total_out_uncompressed; float total_in_uncompressed; public: ModuleZLib(InspIRCd* Me) : Module::Module(Me) { ServerInstance->PublishInterface("InspSocketHook", this); total_out_compressed = total_in_compressed = 0; total_out_uncompressed = total_out_uncompressed = 0; } virtual ~ModuleZLib() { ServerInstance->UnpublishInterface("InspSocketHook", this); } virtual Version GetVersion() { return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); } void Implements(char* List) { List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = 1; List[I_OnStats] = List[I_OnRequest] = 1; } /* Handle InspSocketHook API requests */ virtual char* OnRequest(Request* request) { ISHRequest* ISR = (ISHRequest*)request; if (strcmp("IS_NAME", request->GetId()) == 0) { /* Return name */ return "zip"; } else if (strcmp("IS_HOOK", request->GetId()) == 0) { /* Attach to an inspsocket */ char* ret = "OK"; try { ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; } catch (ModuleException& e) { return NULL; } return ret; } else if (strcmp("IS_UNHOOK", request->GetId()) == 0) { /* Detatch from an inspsocket */ return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; } else if (strcmp("IS_HSDONE", request->GetId()) == 0) { /* Check for completion of handshake * (actually, this module doesnt handshake) */ return "OK"; } else if (strcmp("IS_ATTACH", request->GetId()) == 0) { /* Attach certificate data to the inspsocket * (this module doesnt do that, either) */ return NULL; } return NULL; } /* Handle stats z (misc stats) */ virtual int OnStats(char symbol, userrec* user, string_list &results) { if (symbol == 'z') { std::string sn = ServerInstance->Config->ServerName; /* Yeah yeah, i know, floats are ew. * We used them here because we'd be casting to float anyway to do this maths, * and also only floating point numbers can deal with the pretty large numbers * involved in the total throughput of a server over a large period of time. * (we dont count 64 bit ints because not all systems have 64 bit ints, and floats * can still hold more. */ float outbound_r = 100 - ((total_out_compressed / (total_out_uncompressed + 0.001)) * 100); float inbound_r = 100 - ((total_in_compressed / (total_in_uncompressed + 0.001)) * 100); float total_compressed = total_in_compressed + total_out_compressed; float total_uncompressed = total_in_uncompressed + total_out_uncompressed; float total_r = 100 - ((total_compressed / (total_uncompressed + 0.001)) * 100); char outbound_ratio[MAXBUF], inbound_ratio[MAXBUF], combined_ratio[MAXBUF]; sprintf(outbound_ratio, "%3.2f%%", outbound_r); sprintf(inbound_ratio, "%3.2f%%", inbound_r); sprintf(combined_ratio, "%3.2f%%", total_r); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_compressed = "+ConvToStr(total_out_compressed)); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_compressed = "+ConvToStr(total_in_compressed)); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_uncompressed = "+ConvToStr(total_out_uncompressed)); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_uncompressed = "+ConvToStr(total_in_uncompressed)); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_ratio = "+outbound_ratio); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_ratio = "+inbound_ratio); results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS combined_ratio = "+combined_ratio); return 0; } return 0; } virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport) { izip_session* session = &sessions[fd]; /* allocate state and buffers */ session->fd = fd; session->status = IZIP_OPEN; session->inbuf = new CountedBuffer(); session->c_stream.zalloc = (alloc_func)0; session->c_stream.zfree = (free_func)0; session->c_stream.opaque = (voidpf)0; session->d_stream.zalloc = (alloc_func)0; session->d_stream.zfree = (free_func)0; session->d_stream.opaque = (voidpf)0; } virtual void OnRawSocketConnect(int fd) { /* Nothing special needs doing here compared to accept() */ OnRawSocketAccept(fd, "", 0); } virtual void OnRawSocketClose(int fd) { CloseSession(&sessions[fd]); } virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) { /* Find the sockets session */ izip_session* session = &sessions[fd]; if (session->status == IZIP_CLOSED) return 0; unsigned char compr[CHUNK + 4]; unsigned int offset = 0; unsigned int total_size = 0; /* Read CHUNK bytes at a time to the buffer (usually 128k) */ readresult = read(fd, compr, CHUNK); /* Did we get anything? */ if (readresult > 0) { /* Add it to the frame queue */ session->inbuf->AddData(compr, readresult); total_in_compressed += readresult; /* Parse all completed frames */ int size = 0; while ((size = session->inbuf->GetFrame(compr, CHUNK)) != 0) { session->d_stream.next_in = (Bytef*)compr; session->d_stream.avail_in = 0; session->d_stream.next_out = (Bytef*)(buffer + offset); /* If we cant call this, well, we're boned. */ if (inflateInit(&session->d_stream) != Z_OK) return 0; while ((session->d_stream.total_out < count) && (session->d_stream.total_in < (unsigned int)size)) { session->d_stream.avail_in = session->d_stream.avail_out = 1; if (inflate(&session->d_stream, Z_NO_FLUSH) == Z_STREAM_END) break; } /* Stick a fork in me, i'm done */ inflateEnd(&session->d_stream); /* Update counters and offsets */ total_size += session->d_stream.total_out; total_in_uncompressed += session->d_stream.total_out; offset += session->d_stream.total_out; } /* Null-terminate the buffer -- this doesnt harm binary data */ buffer[total_size] = 0; /* Set the read size to the correct total size */ readresult = total_size; } return (readresult > 0); } virtual int OnRawSocketWrite(int fd, const char* buffer, int count) { izip_session* session = &sessions[fd]; int ocount = count; if (!count) /* Nothing to do! */ return 0; if(session->status != IZIP_OPEN) { /* Seriously, wtf? */ CloseSession(session); return 0; } unsigned char compr[CHUNK + 4]; /* Gentlemen, start your engines! */ if (deflateInit(&session->c_stream, Z_BEST_COMPRESSION) != Z_OK) { CloseSession(session); return 0; } /* Set buffer sizes (we reserve 4 bytes at the start of the * buffer for the length counters) */ session->c_stream.next_in = (Bytef*)buffer; session->c_stream.next_out = compr + 4; /* Compress the text */ while ((session->c_stream.total_in < (unsigned int)count) && (session->c_stream.total_out < CHUNK)) { session->c_stream.avail_in = session->c_stream.avail_out = 1; if (deflate(&session->c_stream, Z_NO_FLUSH) != Z_OK) { CloseSession(session); return 0; } } /* Finish the stream */ for (session->c_stream.avail_out = 1; deflate(&session->c_stream, Z_FINISH) != Z_STREAM_END; session->c_stream.avail_out = 1); deflateEnd(&session->c_stream); total_out_uncompressed += ocount; total_out_compressed += session->c_stream.total_out; /** Assemble the frame length onto the frame, in network byte order */ compr[0] = (session->c_stream.total_out >> 24); compr[1] = (session->c_stream.total_out >> 16); compr[2] = (session->c_stream.total_out >> 8); compr[3] = (session->c_stream.total_out & 0xFF); /* Add compressed data plus leading length to the output buffer - * Note, we may have incomplete half-sent frames in here. */ session->outbuf.append((const char*)compr, session->c_stream.total_out + 4); /* Lets see how much we can send out */ int ret = write(fd, session->outbuf.data(), session->outbuf.length()); /* Check for errors, and advance the buffer if any was sent */ if (ret > 0) session->outbuf = session->outbuf.substr(ret); else if (ret < 1) { if (ret == -1) { if (errno == EAGAIN) return 0; else { session->outbuf.clear(); return 0; } } else { session->outbuf.clear(); return 0; } } /* ALL LIES the lot of it, we havent really written * this amount, but the layer above doesnt need to know. */ return ocount; } void CloseSession(izip_session* session) { if (session->status == IZIP_OPEN) { session->status = IZIP_CLOSED; session->outbuf.clear(); delete session->inbuf; } } }; MODULE_INIT(ModuleZLib); \ No newline at end of file
+/* +------------------------------------+
+ * | Inspire Internet Relay Chat Daemon |
+ * +------------------------------------+
+ *
+ * InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
+ * This program is free but copyrighted software; see
+ * the file COPYING for details.
+ *
+ * ---------------------------------------------------
+ */
+
+#include "inspircd.h"
+#include <zlib.h>
+#include "users.h"
+#include "channels.h"
+#include "modules.h"
+#include "socket.h"
+#include "hashcomp.h"
+#include "transport.h"
+
+/* $ModDesc: Provides zlib link support for servers */
+/* $LinkerFlags: -lz */
+/* $ModDep: transport.h */
+
+/*
+ * Compressed data is transmitted across the link in the following format:
+ *
+ * 0 1 2 3 4 ... n
+ * +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
+ * | n | Z0 -> Zn |
+ * +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
+ *
+ * Where: n is the size of a frame, in network byte order, 4 bytes.
+ * Z0 through Zn are Zlib compressed data, n bytes in length.
+ *
+ * If the module fails to read the entire frame, then it will buffer
+ * the portion of the last frame it received, then attempt to read
+ * the next part of the frame next time a write notification arrives.
+ *
+ * ZLIB_BEST_COMPRESSION (9) is used for all sending of data with
+ * a flush after each frame. A frame may contain multiple lines
+ * and should be treated as raw binary data.
+ *
+ */
+
+/* Status of a connection */
+enum izip_status { IZIP_OPEN, IZIP_CLOSED };
+
+/* Maximum transfer size per read operation */
+const unsigned int CHUNK = 128 * 1024;
+
+/* This class manages a compressed chunk of data preceeded by
+ * a length count.
+ *
+ * It can handle having multiple chunks of data in the buffer
+ * at any time.
+ */
+class CountedBuffer : public classbase
+{
+ std::string buffer; /* Current buffer contents */
+ unsigned int amount_expected; /* Amount of data expected */
+ public:
+ CountedBuffer()
+ {
+ amount_expected = 0;
+ }
+
+ /** Adds arbitrary compressed data to the buffer.
+ * - Binsry safe, of course.
+ */
+ void AddData(unsigned char* data, int data_length)
+ {
+ buffer.append((const char*)data, data_length);
+ this->NextFrameSize();
+ }
+
+ /** Works out the size of the next compressed frame
+ */
+ void NextFrameSize()
+ {
+ if ((!amount_expected) && (buffer.length() >= 4))
+ {
+ /* We have enough to read an int -
+ * Yes, this is safe, but its ugly. Give me
+ * a nicer way to read 4 bytes from a binary
+ * stream, and push them into a 32 bit int,
+ * and i'll consider replacing this.
+ */
+ amount_expected = ntohl((buffer[3] << 24) | (buffer[2] << 16) | (buffer[1] << 8) | buffer[0]);
+ buffer = buffer.substr(4);
+ }
+ }
+
+ /** Gets the next frame and returns its size, or returns
+ * zero if there isnt one available yet.
+ * A frame can contain multiple plaintext lines.
+ * - Binary safe.
+ */
+ int GetFrame(unsigned char* frame, int maxsize)
+ {
+ if (amount_expected)
+ {
+ /* We know how much we're expecting...
+ * Do we have enough yet?
+ */
+ if (buffer.length() >= amount_expected)
+ {
+ int j = 0;
+ for (unsigned int i = 0; i < amount_expected; i++, j++)
+ frame[i] = buffer[i];
+
+ buffer = buffer.substr(j);
+ amount_expected = 0;
+ NextFrameSize();
+ return j;
+ }
+ }
+ /* Not enough for a frame yet, COME AGAIN! */
+ return 0;
+ }
+};
+
+/** Represents an zipped connections extra data
+ */
+class izip_session : public classbase
+{
+ public:
+ z_stream c_stream; /* compression stream */
+ z_stream d_stream; /* decompress stream */
+ izip_status status; /* Connection status */
+ int fd; /* File descriptor */
+ CountedBuffer* inbuf; /* Holds input buffer */
+ std::string outbuf; /* Holds output buffer */
+};
+
+class ModuleZLib : public Module
+{
+ izip_session sessions[MAX_DESCRIPTORS];
+
+ /* Used for stats z extensions */
+ float total_out_compressed;
+ float total_in_compressed;
+ float total_out_uncompressed;
+ float total_in_uncompressed;
+
+ public:
+
+ ModuleZLib(InspIRCd* Me)
+ : Module::Module(Me)
+ {
+ ServerInstance->PublishInterface("InspSocketHook", this);
+
+ total_out_compressed = total_in_compressed = 0;
+ total_out_uncompressed = total_out_uncompressed = 0;
+ }
+
+ virtual ~ModuleZLib()
+ {
+ ServerInstance->UnpublishInterface("InspSocketHook", this);
+ }
+
+ virtual Version GetVersion()
+ {
+ return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
+ }
+
+ void Implements(char* List)
+ {
+ List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = 1;
+ List[I_OnStats] = List[I_OnRequest] = 1;
+ }
+
+ /* Handle InspSocketHook API requests */
+ virtual char* OnRequest(Request* request)
+ {
+ ISHRequest* ISR = (ISHRequest*)request;
+ if (strcmp("IS_NAME", request->GetId()) == 0)
+ {
+ /* Return name */
+ return "zip";
+ }
+ else if (strcmp("IS_HOOK", request->GetId()) == 0)
+ {
+ /* Attach to an inspsocket */
+ char* ret = "OK";
+ try
+ {
+ ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
+ }
+ catch (ModuleException& e)
+ {
+ return NULL;
+ }
+ return ret;
+ }
+ else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
+ {
+ /* Detatch from an inspsocket */
+ return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
+ }
+ else if (strcmp("IS_HSDONE", request->GetId()) == 0)
+ {
+ /* Check for completion of handshake
+ * (actually, this module doesnt handshake)
+ */
+ return "OK";
+ }
+ else if (strcmp("IS_ATTACH", request->GetId()) == 0)
+ {
+ /* Attach certificate data to the inspsocket
+ * (this module doesnt do that, either)
+ */
+ return NULL;
+ }
+ return NULL;
+ }
+
+ /* Handle stats z (misc stats) */
+ virtual int OnStats(char symbol, userrec* user, string_list &results)
+ {
+ if (symbol == 'z')
+ {
+ std::string sn = ServerInstance->Config->ServerName;
+
+ /* Yeah yeah, i know, floats are ew.
+ * We used them here because we'd be casting to float anyway to do this maths,
+ * and also only floating point numbers can deal with the pretty large numbers
+ * involved in the total throughput of a server over a large period of time.
+ * (we dont count 64 bit ints because not all systems have 64 bit ints, and floats
+ * can still hold more.
+ */
+ float outbound_r = 100 - ((total_out_compressed / (total_out_uncompressed + 0.001)) * 100);
+ float inbound_r = 100 - ((total_in_compressed / (total_in_uncompressed + 0.001)) * 100);
+
+ float total_compressed = total_in_compressed + total_out_compressed;
+ float total_uncompressed = total_in_uncompressed + total_out_uncompressed;
+
+ float total_r = 100 - ((total_compressed / (total_uncompressed + 0.001)) * 100);
+
+ char outbound_ratio[MAXBUF], inbound_ratio[MAXBUF], combined_ratio[MAXBUF];
+
+ sprintf(outbound_ratio, "%3.2f%%", outbound_r);
+ sprintf(inbound_ratio, "%3.2f%%", inbound_r);
+ sprintf(combined_ratio, "%3.2f%%", total_r);
+
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_compressed = "+ConvToStr(total_out_compressed));
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_compressed = "+ConvToStr(total_in_compressed));
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_uncompressed = "+ConvToStr(total_out_uncompressed));
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_uncompressed = "+ConvToStr(total_in_uncompressed));
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_ratio = "+outbound_ratio);
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_ratio = "+inbound_ratio);
+ results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS combined_ratio = "+combined_ratio);
+ return 0;
+ }
+
+ return 0;
+ }
+
+ virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
+ {
+ izip_session* session = &sessions[fd];
+
+ /* allocate state and buffers */
+ session->fd = fd;
+ session->status = IZIP_OPEN;
+ session->inbuf = new CountedBuffer();
+
+ session->c_stream.zalloc = (alloc_func)0;
+ session->c_stream.zfree = (free_func)0;
+ session->c_stream.opaque = (voidpf)0;
+
+ session->d_stream.zalloc = (alloc_func)0;
+ session->d_stream.zfree = (free_func)0;
+ session->d_stream.opaque = (voidpf)0;
+ }
+
+ virtual void OnRawSocketConnect(int fd)
+ {
+ /* Nothing special needs doing here compared to accept() */
+ OnRawSocketAccept(fd, "", 0);
+ }
+
+ virtual void OnRawSocketClose(int fd)
+ {
+ CloseSession(&sessions[fd]);
+ }
+
+ virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
+ {
+ /* Find the sockets session */
+ izip_session* session = &sessions[fd];
+
+ if (session->status == IZIP_CLOSED)
+ return 0;
+
+ unsigned char compr[CHUNK + 4];
+ unsigned int offset = 0;
+ unsigned int total_size = 0;
+
+ /* Read CHUNK bytes at a time to the buffer (usually 128k) */
+ readresult = read(fd, compr, CHUNK);
+
+ /* Did we get anything? */
+ if (readresult > 0)
+ {
+ /* Add it to the frame queue */
+ session->inbuf->AddData(compr, readresult);
+ total_in_compressed += readresult;
+
+ /* Parse all completed frames */
+ int size = 0;
+ while ((size = session->inbuf->GetFrame(compr, CHUNK)) != 0)
+ {
+ session->d_stream.next_in = (Bytef*)compr;
+ session->d_stream.avail_in = 0;
+ session->d_stream.next_out = (Bytef*)(buffer + offset);
+
+ /* If we cant call this, well, we're boned. */
+ if (inflateInit(&session->d_stream) != Z_OK)
+ return 0;
+
+ while ((session->d_stream.total_out < count) && (session->d_stream.total_in < (unsigned int)size))
+ {
+ session->d_stream.avail_in = session->d_stream.avail_out = 1;
+ if (inflate(&session->d_stream, Z_NO_FLUSH) == Z_STREAM_END)
+ break;
+ }
+
+ /* Stick a fork in me, i'm done */
+ inflateEnd(&session->d_stream);
+
+ /* Update counters and offsets */
+ total_size += session->d_stream.total_out;
+ total_in_uncompressed += session->d_stream.total_out;
+ offset += session->d_stream.total_out;
+ }
+
+ /* Null-terminate the buffer -- this doesnt harm binary data */
+ buffer[total_size] = 0;
+
+ /* Set the read size to the correct total size */
+ readresult = total_size;
+
+ }
+ return (readresult > 0);
+ }
+
+ virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
+ {
+ izip_session* session = &sessions[fd];
+ int ocount = count;
+
+ if (!count) /* Nothing to do! */
+ return 0;
+
+ if(session->status != IZIP_OPEN)
+ {
+ /* Seriously, wtf? */
+ CloseSession(session);
+ return 0;
+ }
+
+ unsigned char compr[CHUNK + 4];
+
+ /* Gentlemen, start your engines! */
+ if (deflateInit(&session->c_stream, Z_BEST_COMPRESSION) != Z_OK)
+ {
+ CloseSession(session);
+ return 0;
+ }
+
+ /* Set buffer sizes (we reserve 4 bytes at the start of the
+ * buffer for the length counters)
+ */
+ session->c_stream.next_in = (Bytef*)buffer;
+ session->c_stream.next_out = compr + 4;
+
+ /* Compress the text */
+ while ((session->c_stream.total_in < (unsigned int)count) && (session->c_stream.total_out < CHUNK))
+ {
+ session->c_stream.avail_in = session->c_stream.avail_out = 1;
+ if (deflate(&session->c_stream, Z_NO_FLUSH) != Z_OK)
+ {
+ CloseSession(session);
+ return 0;
+ }
+ }
+ /* Finish the stream */
+ for (session->c_stream.avail_out = 1; deflate(&session->c_stream, Z_FINISH) != Z_STREAM_END; session->c_stream.avail_out = 1);
+ deflateEnd(&session->c_stream);
+
+ total_out_uncompressed += ocount;
+ total_out_compressed += session->c_stream.total_out;
+
+ /** Assemble the frame length onto the frame, in network byte order */
+ compr[0] = (session->c_stream.total_out >> 24);
+ compr[1] = (session->c_stream.total_out >> 16);
+ compr[2] = (session->c_stream.total_out >> 8);
+ compr[3] = (session->c_stream.total_out & 0xFF);
+
+ /* Add compressed data plus leading length to the output buffer -
+ * Note, we may have incomplete half-sent frames in here.
+ */
+ session->outbuf.append((const char*)compr, session->c_stream.total_out + 4);
+
+ /* Lets see how much we can send out */
+ int ret = write(fd, session->outbuf.data(), session->outbuf.length());
+
+ /* Check for errors, and advance the buffer if any was sent */
+ if (ret > 0)
+ session->outbuf = session->outbuf.substr(ret);
+ else if (ret < 1)
+ {
+ if (ret == -1)
+ {
+ if (errno == EAGAIN)
+ return 0;
+ else
+ {
+ session->outbuf.clear();
+ return 0;
+ }
+ }
+ else
+ {
+ session->outbuf.clear();
+ return 0;
+ }
+ }
+
+ /* ALL LIES the lot of it, we havent really written
+ * this amount, but the layer above doesnt need to know.
+ */
+ return ocount;
+ }
+
+ void CloseSession(izip_session* session)
+ {
+ if (session->status == IZIP_OPEN)
+ {
+ session->status = IZIP_CLOSED;
+ session->outbuf.clear();
+ delete session->inbuf;
+ }
+ }
+
+};
+
+MODULE_INIT(ModuleZLib);
+