summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/modules/extra/m_mysql.cpp52
1 files changed, 29 insertions, 23 deletions
diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp
index 7b6e2906d..dcdbe0004 100644
--- a/src/modules/extra/m_mysql.cpp
+++ b/src/modules/extra/m_mysql.cpp
@@ -253,6 +253,31 @@ class MySQLresult : public SQL::Result
*/
class SQLConnection : public SQL::Provider
{
+ private:
+ bool EscapeString(SQL::Query* query, const std::string& in, std::string& out)
+ {
+ // In the worst case each character may need to be encoded as using two bytes and one
+ // byte is the NUL terminator.
+ std::vector<char> buffer(in.length() * 2 + 1);
+
+ // The return value of mysql_escape_string() is either an error or the length of the
+ // encoded string not including the NUL terminator.
+ //
+ // Unfortunately, someone genius decided that mysql_escape_string should return an
+ // unsigned type even though -1 is returned on error so checking whether an error
+ // happened is a bit cursed.
+ unsigned long escapedsize = mysql_escape_string(&buffer[0], in.c_str(), in.length());
+ if (escapedsize == static_cast<unsigned long>(-1))
+ {
+ SQL::Error err(SQL::QSEND_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
+ query->OnError(err);
+ return false;
+ }
+
+ out.append(&buffer[0], escapedsize);
+ return true;
+ }
+
public:
reference<ConfigTag> config;
MYSQL *connection;
@@ -356,21 +381,8 @@ class SQLConnection : public SQL::Provider
{
if (q[i] != '?')
res.push_back(q[i]);
- else
- {
- if (param < p.size())
- {
- std::string parm = p[param++];
- // In the worst case, each character may need to be encoded as using two bytes,
- // and one byte is the terminating null
- std::vector<char> buffer(parm.length() * 2 + 1);
-
- // The return value of mysql_real_escape_string() is the length of the encoded string,
- // not including the terminating null
- unsigned long escapedsize = mysql_real_escape_string(connection, &buffer[0], parm.c_str(), parm.length());
- res.append(&buffer[0], escapedsize);
- }
- }
+ else if (param < p.size() && !EscapeString(call, p[param++], res))
+ return;
}
Submit(call, res);
}
@@ -391,14 +403,8 @@ class SQLConnection : public SQL::Provider
i--;
SQL::ParamMap::const_iterator it = p.find(field);
- if (it != p.end())
- {
- std::string parm = it->second;
- // NOTE: See above
- std::vector<char> buffer(parm.length() * 2 + 1);
- unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
- res.append(&buffer[0], escapedsize);
- }
+ if (it != p.end() && !EscapeString(call, it->second, res))
+ return;
}
}
Submit(call, res);