diff options
Diffstat (limited to 'src/modules')
-rw-r--r-- | src/modules/extra/m_sqlauth.cpp | 266 |
1 files changed, 162 insertions, 104 deletions
diff --git a/src/modules/extra/m_sqlauth.cpp b/src/modules/extra/m_sqlauth.cpp index 771749075..17381731b 100644 --- a/src/modules/extra/m_sqlauth.cpp +++ b/src/modules/extra/m_sqlauth.cpp @@ -6,6 +6,7 @@ * E-mail: * <brain@chatspike.net> * <Craig@chatspike.net> + * <omster@gmail.com> * * Written by Craig Edwards, Craig McLure, and others. * This program is free but copyrighted software; see @@ -14,169 +15,227 @@ * --------------------------------------------------- */ -using namespace std; - -#include <stdio.h> #include <string> -#include <stdlib.h> -#include <time.h> -#include <sys/types.h> -#include <sys/socket.h> -#include <sys/time.h> -#include <string.h> -#include <unistd.h> -#include <errno.h> -#include <fcntl.h> -#include <poll.h> +#include <map> + #include "users.h" #include "channels.h" #include "modules.h" #include "inspircd.h" #include "helperfuncs.h" -#include "m_sql.h" +#include "m_sqlv2.h" /* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */ +typedef std::map<unsigned int, userrec*> QueryUserMap; + class ModuleSQLAuth : public Module { Server* Srv; - ConfigReader* Conf; + std::string usertable; std::string userfield; std::string passfield; std::string encryption; std::string killreason; std::string allowpattern; - bool WallOperFail; - unsigned long dbid; - Module* SQLModule; - - public: - bool ReadConfig() - { - Conf = new ConfigReader(); - usertable = Conf->ReadValue("sqlauth","usertable",0); // user table name - dbid = Conf->ReadInteger("sqlauth","dbid",0,true); // database id of a database configured in m_sql (see m_sql config) - userfield = Conf->ReadValue("sqlauth","userfield",0); // field name where username can be found - passfield = Conf->ReadValue("sqlauth","passfield",0); // field name where password can be found - killreason = Conf->ReadValue("sqlauth","killreason",0); // reason to give when access is denied to a user (put your reg details here) - encryption = Conf->ReadValue("sqlauth","encryption",0); // name of sql function used to encrypt password, e.g. "md5" or "passwd". - // define, but leave blank if no encryption is to be used. - WallOperFail = Conf->ReadFlag("sqlauth","verbose",0); // set to 1 if failed connects should be reported to operators - allowpattern = Conf->ReadValue("sqlauth","allowpattern",0); // allow nicks matching the pattern without requiring auth - if (encryption.find("(") == std::string::npos) - { - encryption.append("("); - } - DELETE(Conf); - SQLModule = Srv->FindModule("m_sql.so"); - if (!SQLModule) - Srv->Log(DEFAULT,"WARNING: m_sqlauth.so could not initialize because m_sql.so is not loaded. Load the module and rehash your server."); - return (SQLModule); - } - + std::string databaseid; + + bool verbose; + + QueryUserMap qumap; + +public: ModuleSQLAuth(Server* Me) - : Module::Module(Me) + : Module::Module(Me) { Srv = Me; - ReadConfig(); + OnRehash(""); } void Implements(char* List) { - List[I_OnRehash] = List[I_OnUserRegister] = 1; + List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1; } virtual void OnRehash(const std::string ¶meter) { - ReadConfig(); - } + ConfigReader Conf; + + usertable = Conf.ReadValue("sqlauth", "usertable", 0); /* User table name */ + databaseid = Conf.ReadValue("sqlauth", "dbid", 0); /* Database ID, given to the SQL service provider */ + userfield = Conf.ReadValue("sqlauth", "userfield", 0); /* Field name where username can be found */ + passfield = Conf.ReadValue("sqlauth", "passfield", 0); /* Field name where password can be found */ + killreason = Conf.ReadValue("sqlauth", "killreason", 0); /* Reason to give when access is denied to a user (put your reg details here) */ + allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */ + encryption = Conf.ReadValue("sqlauth", "encryption", 0); /* Name of sql function used to encrypt password, e.g. "md5" or "passwd". + * define, but leave blank if no encryption is to be used. + */ + verbose = Conf.ReadFlag("sqlauth", "verbose", 0); /* Set to true if failed connects should be reported to operators */ + + if (encryption.find("(") == std::string::npos) + { + encryption.append("("); + } + } virtual void OnUserRegister(userrec* user) { if ((allowpattern != "") && (Srv->MatchText(user->nick,allowpattern))) return; - if (!CheckCredentials(user->nick,user->password)) + if (!CheckCredentials(user)) { - if (WallOperFail) + if (verbose) WriteOpers("Forbidden connection from %s!%s@%s (invalid login/password)",user->nick,user->ident,user->host); Srv->QuitUser(user,killreason); } } - bool CheckCredentials(const std::string &s_username, const std::string &s_password) + bool CheckCredentials(userrec* user) { - bool found = false; - - // is the sql module loaded? If not, we don't attempt to do anything. - if (!SQLModule) - return false; - - // sanitize the password (we dont want any mysql insertion exploits!) - std::string username = SQLQuery::Sanitise(s_username); - std::string password = SQLQuery::Sanitise(s_password); - - // Create a request containing the SQL query and send it to m_sql.so - std::string querystr("SELECT * FROM "+usertable+" WHERE "+userfield+"='"+username+"' AND "+passfield+"="+encryption+"'"+password+"')"); + bool found; + Module* target; - Srv->Log(DEBUG, "m_sqlauth.so: Query: " + querystr); + found = false; + target = Srv->FindFeature("SQL"); - SQLRequest* query = new SQLRequest(SQL_RESULT,dbid,querystr); - Request queryrequest((char*)query, this, SQLModule); - SQLResult* result = (SQLResult*)queryrequest.Send(); - - // Did we get "OK" as a result? - if (result->GetType() == SQL_OK) + if(target) { - log(DEBUG, "m_sqlauth.so: Query OK"); + SQLrequest req = SQLreq(this, target, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password); - // if we did, this means we may now request a row... there should be only one row for each user, so, - // we don't need to loop to fetch multiple rows. - SQLRequest* rowrequest = new SQLRequest(SQL_ROW,dbid,""); - Request rowquery((char*)rowrequest, this, SQLModule); - SQLResult* rowresult = (SQLResult*)rowquery.Send(); - - // did we get a row? If we did, we can now do something with the fields - if (rowresult->GetType() == SQL_ROW) + if(req.Send()) { - log(DEBUG, "m_sqlauth.so: Got row...user '%s'", rowresult->GetField(userfield).c_str()); + /* When we get the query response from the service provider we will be given an ID to play with, + * just an ID number which is unique to this query. We need a way of associating that ID with a userrec + * so we insert it into a map mapping the IDs to users. + * This isn't quite enough though, as if the user quit while the query was in progress then when the result + * came to be processed we'd get an invalid userrec* out of the map. Now we *could* solve this by watching + * OnUserDisconnect() and iterating the map every time someone quits to make sure they didn't have any queries + * in progress, but that would be relatively slow and inefficient. Instead (thanks to w00t ;p) we attach a list + * of query IDs associated with it to the userrec, so in OnUserDisconnect() we can remove it immediately. + */ + log(DEBUG, "Sent query, got given ID %lu", req.id); + qumap.insert(std::make_pair(req.id, user)); - if (rowresult->GetField(userfield) == username) + if(!user->Extend("sqlauth_queryid", new unsigned long(req.id))) { - log(DEBUG, "m_sqlauth.so: Got correct user..."); - // because the query directly asked for the password hash, we do not need to check it - - // if it didnt match it wont be returned in the first place from the SELECT. - // This just checks we didnt get an empty row by accident. - found = true; + log(DEBUG, "BUG: user being sqlauth'd already extended with 'sqlauth_queryid' :/"); } + + return true; } else { - log(DEBUG, "m_sqlauth.so: Couldn't find row"); - // we didn't have a row. - found = false; + log(DEBUG, "SQLrequest failed: %s", req.error.Str()); + + if (verbose) + WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str()); + + return false; } - - DELETE(rowrequest); - DELETE(rowresult); } else { - log(DEBUG, "m_sqlauth.so: Query failed"); - // the query was bad - found = false; + log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be allowed to connect until it comes back unless they match an exception"); + return false; + } + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLRESID, request->GetData()) == 0) + { + SQLresult* res; + QueryUserMap::iterator iter; + + res = static_cast<SQLresult*>(request); + + log(DEBUG, "Got SQL result (%s) with ID %lu", res->GetData(), res->id); + + iter = qumap.find(res->id); + + if(iter != qumap.end()) + { + userrec* user; + unsigned long* id; + + user = iter->second; + + log(DEBUG, "Associated query ID %lu with user %s", res->id, user->nick); + + log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols()); + + if(res->Rows()) + { + /* We got a row in the result, this is enough really */ + user->Extend("sqlauthed"); + } + else if (verbose) + { + /* No rows in result, this means there was no record matching the user */ + WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host); + } + + /* Remove our ID from the lookup table to keep it as small and neat as possible */ + qumap.erase(iter); + + /* Cleanup the userrec, no point leaving this here */ + if(user->GetExt("sqlauth_queryid", id)) + { + user->Shrink("sqlauth_queryid"); + delete id; + } + } + else + { + log(DEBUG, "Got query with unknown ID, this probably means the user quit while the query was in progress"); + } + + return SQLSUCCESS; } - query->SetQueryType(SQL_DONE); - query->SetConnID(dbid); - Request donerequest((char*)query, this, SQLModule); - donerequest.Send(); + log(DEBUG, "Got unsupported API version string: %s", request->GetData()); - DELETE(query); - DELETE(result); + return NULL; + } + + virtual void OnUserDisconnect(userrec* user) + { + unsigned long* id; - return found; + if(user->GetExt("sqlauth_queryid", id)) + { + QueryUserMap::iterator iter; + + iter = qumap.find(*id); + + if(iter != qumap.end()) + { + if(iter->second == user) + { + qumap.erase(iter); + + log(DEBUG, "Erased query from map associated with quitting user %s", user->nick); + } + else + { + log(DEBUG, "BUG: ID associated with user %s doesn't have the same userrec* associated with it in the map"); + } + } + else + { + log(DEBUG, "BUG: user %s was extended with sqlauth_queryid but there was nothing matching in the map", user->nick); + } + + user->Shrink("sqlauth_queryid"); + delete id; + } + } + + virtual bool OnCheckReady(userrec* user) + { + return user->GetExt("sqlauthed"); } virtual ~ModuleSQLAuth() @@ -185,7 +244,7 @@ class ModuleSQLAuth : public Module virtual Version GetVersion() { - return Version(1,0,0,2,VF_VENDOR); + return Version(1,0,1,0,VF_VENDOR); } }; @@ -213,4 +272,3 @@ extern "C" void * init_module( void ) { return new ModuleSQLAuthFactory; } - |