Skip to content

Commit

Permalink
deal with lengths in characters and in bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
artpaul committed Aug 31, 2017
1 parent 9056a2e commit 35fcce0
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 47 deletions.
7 changes: 4 additions & 3 deletions driver/attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl_SQLGetEnvAttr(SQLHENV environment_handle, SQLINTEGER attribute,
fillOutputNumber<SQLUINTEGER>(environment.odbc_version, out_value, out_value_max_length, out_value_length);
return SQL_SUCCESS;

CASE_NUM(SQL_ATTR_METADATA_ID, SQLUINTEGER, environment.metadata_id);
CASE_NUM(SQL_ATTR_METADATA_ID, SQLUINTEGER, environment.metadata_id);

case SQL_ATTR_CONNECTION_POOLING:
case SQL_ATTR_CP_MATCH:
Expand All @@ -78,6 +78,7 @@ impl_SQLGetEnvAttr(SQLHENV environment_handle, SQLINTEGER attribute,
}


/// Description: https://docs.microsoft.com/en-us/sql/odbc/reference/syntax/sqlsetconnectattr-function
RETCODE
impl_SQLSetConnectAttr(SQLHDBC connection_handle, SQLINTEGER attribute,
SQLPOINTER value, SQLINTEGER value_length)
Expand All @@ -100,7 +101,7 @@ impl_SQLSetConnectAttr(SQLHDBC connection_handle, SQLINTEGER attribute,
}

case SQL_ATTR_CURRENT_CATALOG:
connection.database = stringFromSQLChar((SQLTCHAR *)value, value_length);
connection.setDatabase(stringFromSQLBytes((SQLTCHAR *)value, value_length));
return SQL_SUCCESS;

case SQL_ATTR_ACCESS_MODE:
Expand Down Expand Up @@ -145,7 +146,7 @@ impl_SQLGetConnectAttr(SQLHDBC connection_handle, SQLINTEGER attribute,
CASE_NUM(SQL_ATTR_LOGIN_TIMEOUT, SQLUSMALLINT, connection.session.getTimeout().seconds())

case SQL_ATTR_CURRENT_CATALOG:
fillOutputPlatformString(connection.database, out_value, out_value_max_length, out_value_length);
fillOutputPlatformString(connection.getDatabase(), out_value, out_value_max_length, out_value_length);
return SQL_SUCCESS;

case SQL_ATTR_ACCESS_MODE:
Expand Down
10 changes: 10 additions & 0 deletions driver/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ std::string Connection::connectionString() const
return ret;
}

const std::string & Connection::getDatabase() const
{
return database;
}

void Connection::setDatabase(const std::string & db)
{
database = db;
}

void Connection::init()
{
loadConfiguration();
Expand Down
10 changes: 9 additions & 1 deletion driver/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ struct Connection
uint16_t port = 0;
std::string user;
std::string password;
std::string database;

Poco::Net::HTTPClientSession session;
DiagnosticRecord diagnostic_record;
Expand All @@ -25,6 +24,12 @@ struct Connection
/// Returns the completed connection string.
std::string connectionString() const;

/// Returns database associated with the current connection.
const std::string & getDatabase() const;

/// Sets database to the current connection;
void setDatabase(const std::string & db);

void init();

void init(
Expand All @@ -42,4 +47,7 @@ struct Connection

/// Sets uninitialized fields to their default values.
void setDefaults();

private:
std::string database;
};
2 changes: 1 addition & 1 deletion driver/info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ SQLGetInfo(HDBC connection_handle,
CASE_STRING(SQL_DATA_SOURCE_NAME, connection.data_source)
CASE_STRING(SQL_CATALOG_TERM, "catalog")
CASE_STRING(SQL_COLLATION_SEQ, "UTF-8")
CASE_STRING(SQL_DATABASE_NAME, connection.database)
CASE_STRING(SQL_DATABASE_NAME, connection.getDatabase())
CASE_STRING(SQL_KEYWORDS, "")
CASE_STRING(SQL_PROCEDURE_TERM, "stored procedure")
CASE_STRING(SQL_CATALOG_NAME_SEPARATOR, ".")
Expand Down
42 changes: 21 additions & 21 deletions driver/odbc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ SQLConnect(HDBC connection_handle,

return doWith<Connection>(connection_handle, [&](Connection & connection)
{
std::string dsn_str = stringFromSQLChar(dsn, dsn_size);
std::string user_str = stringFromSQLChar(user, user_size);
std::string password_str = stringFromSQLChar(password, password_size);
std::string dsn_str = stringFromSQLSymbols(dsn, dsn_size);
std::string user_str = stringFromSQLSymbols(user, user_size);
std::string password_str = stringFromSQLSymbols(password, password_size);

connection.init(dsn_str, 0, user_str, password_str, "");
return SQL_SUCCESS;
Expand All @@ -64,9 +64,9 @@ SQLDriverConnect(HDBC connection_handle,
LOG(__FUNCTION__);
return doWith<Connection>(connection_handle, [&](Connection & connection)
{
connection.init(stringFromSQLChar(connection_str_in, connection_str_in_size));
connection.init(stringFromSQLSymbols(connection_str_in, connection_str_in_size));
// Copy complete connection string.
fillOutputPlatformString(connection.connectionString(), connection_str_out, connection_str_out_max_size, connection_str_out_size);
fillOutputPlatformString(connection.connectionString(), connection_str_out, connection_str_out_max_size, connection_str_out_size, false);
return SQL_SUCCESS;
});
}
Expand All @@ -80,7 +80,7 @@ SQLPrepare(HSTMT statement_handle,

return doWith<Statement>(statement_handle, [&](Statement & statement)
{
const std::string & query = stringFromSQLChar(statement_text, statement_text_size);
const std::string & query = stringFromSQLSymbols(statement_text, statement_text_size);

if (!statement.isEmpty())
throw std::runtime_error("Prepare called, but statement query is not empty.");
Expand Down Expand Up @@ -116,7 +116,7 @@ SQLExecDirect(HSTMT statement_handle,

return doWith<Statement>(statement_handle, [&](Statement & statement)
{
const std::string & query = stringFromSQLChar(statement_text, statement_text_size);
const std::string & query = stringFromSQLSymbols(statement_text, statement_text_size);

if (!statement.isEmpty())
{
Expand Down Expand Up @@ -299,7 +299,7 @@ SQLDescribeCol(HSTMT statement_handle,
if (out_is_nullable)
*out_is_nullable = SQL_NO_NULLS;

return fillOutputPlatformString(column_info.name, out_column_name, out_column_name_max_size, out_column_name_size);;
return fillOutputPlatformString(column_info.name, out_column_name, out_column_name_max_size, out_column_name_size, false);
});
}

Expand Down Expand Up @@ -570,7 +570,7 @@ impl_SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle,
if (out_native_error_code)
*out_native_error_code = diagnostic_record->native_error_code;

return fillOutputPlatformString(diagnostic_record->message, out_mesage, out_message_max_size, out_message_size);
return fillOutputPlatformString(diagnostic_record->message, out_mesage, out_message_max_size, out_message_size, false);
}


Expand Down Expand Up @@ -626,7 +626,7 @@ SQLTables(HSTMT statement_handle,
// TODO (artpaul) Take statement.getMetatadaId() into account.
return doWith<Statement>(statement_handle, [&](Statement & statement)
{
const std::string catalog = stringFromSQLChar(catalog_name, catalog_name_length);
const std::string catalog = stringFromSQLSymbols(catalog_name, catalog_name_length);

std::stringstream query;

Expand Down Expand Up @@ -654,7 +654,7 @@ SQLTables(HSTMT statement_handle,
", '' AS REMARKS"
" FROM system.tables"
" WHERE (database == '";
query << statement.connection.database << "')";
query << statement.connection.getDatabase() << "')";
query << " ORDER BY TABLE_TYPE, TABLE_CAT, TABLE_SCHEM, TABLE_NAME";
}
// Get a list of databases on the current connection's server.
Expand Down Expand Up @@ -685,13 +685,13 @@ SQLTables(HSTMT statement_handle,
" WHERE (1 == 1)";

if (catalog_name && catalog_name_length)
query << " AND TABLE_CAT LIKE '" << stringFromSQLChar(catalog_name, catalog_name_length) << "'";
query << " AND TABLE_CAT LIKE '" << stringFromSQLSymbols(catalog_name, catalog_name_length) << "'";
//if (schema_name_length)
// query << " AND TABLE_SCHEM LIKE '" << stringFromSQLChar(schema_name, schema_name_length) << "'";
// query << " AND TABLE_SCHEM LIKE '" << stringFromSQLSymbols(schema_name, schema_name_length) << "'";
if (table_name && table_name_length)
query << " AND TABLE_NAME LIKE '" << stringFromSQLChar(table_name, table_name_length) << "'";
query << " AND TABLE_NAME LIKE '" << stringFromSQLSymbols(table_name, table_name_length) << "'";
//if (table_type_length)
// query << " AND TABLE_TYPE = '" << stringFromSQLChar(table_type, table_type_length) << "'";
// query << " AND TABLE_TYPE = '" << stringFromSQLSymbols(table_type, table_type_length) << "'";

query << " ORDER BY TABLE_TYPE, TABLE_CAT, TABLE_SCHEM, TABLE_NAME";
}
Expand Down Expand Up @@ -739,13 +739,13 @@ SQLColumns(HSTMT statement_handle,
" WHERE (1 == 1)";

if (catalog_name_length)
query << " AND TABLE_CAT LIKE '" << stringFromSQLChar(catalog_name, catalog_name_length) << "'";
query << " AND TABLE_CAT LIKE '" << stringFromSQLSymbols(catalog_name, catalog_name_length) << "'";
if (schema_name_length)
query << " AND TABLE_SCHEM LIKE '" << stringFromSQLChar(schema_name, schema_name_length) << "'";
query << " AND TABLE_SCHEM LIKE '" << stringFromSQLSymbols(schema_name, schema_name_length) << "'";
if (table_name_length)
query << " AND TABLE_NAME LIKE '" << stringFromSQLChar(table_name, table_name_length) << "'";
query << " AND TABLE_NAME LIKE '" << stringFromSQLSymbols(table_name, table_name_length) << "'";
if (column_name_length)
query << " AND COLUMN_NAME LIKE '" << stringFromSQLChar(column_name, column_name_length) << "'";
query << " AND COLUMN_NAME LIKE '" << stringFromSQLSymbols(column_name, column_name_length) << "'";

query << " ORDER BY TABLE_CAT, TABLE_SCHEM, TABLE_NAME, ORDINAL_POSITION";

Expand Down Expand Up @@ -858,8 +858,8 @@ SQLNativeSql(HDBC connection_handle,

return doWith<Connection>(connection_handle, [&](Connection & connection)
{
std::string query_str = stringFromSQLChar(query, query_length);
return fillOutputRawString(query_str, out_query, out_query_max_length, out_query_length);
std::string query_str = stringFromSQLSymbols(query, query_length);
return fillOutputPlatformString(query_str, out_query, out_query_max_length, out_query_length, false);
});
}

Expand Down
2 changes: 1 addition & 1 deletion driver/statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void Statement::sendRequest()
request.setKeepAlive(true);
request.setChunkedTransferEncoding(true);
request.setCredentials("Basic", user_password_base64.str());
request.setURI("/?database=" + connection.database + "&default_format=ODBCDriver"); /// TODO Ability to transfer settings. TODO escaping
request.setURI("/?database=" + connection.getDatabase() + "&default_format=ODBCDriver"); /// TODO Ability to transfer settings. TODO escaping

if (in && in->peek() != EOF)
connection.session.reset();
Expand Down
89 changes: 69 additions & 20 deletions driver/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,31 +62,65 @@ static const char * nextKeyValuePair(const char * data, const char * end, String
return value_end;
}

template <typename SIZE_TYPE>
std::string stringFromSQLSymbols(SQLTCHAR * data, SIZE_TYPE symbols)
{
if (!data || symbols == 0)
return{};
if (symbols == SQL_NTS)
{
#ifdef UNICODE
symbols = (SIZE_TYPE)wcslen(reinterpret_cast<LPCTSTR>(data));
#else
symbols = (SIZE_TYPE)strlen(reinterpret_cast<LPCTSTR>(data));
#endif
}
else if (symbols < 0)
throw std::runtime_error("invalid size of string : " + std::to_string(symbols));
#ifdef UNICODE
return std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t>()
.to_bytes(std::wstring(data, symbols));
#else
return{ (const char*)data, (size_t)symbols };
#endif
}

template <typename SIZE_TYPE>
std::string stringFromSQLChar(SQLTCHAR * data, SIZE_TYPE size)
std::string stringFromSQLBytes(SQLTCHAR * data, SIZE_TYPE size)
{
if (!data || size == 0)
return {};
// Count of symblols in the string
size_t symbols = 0;

if (size == SQL_NTS)
{
#ifdef UNICODE
size = (SIZE_TYPE)wcslen(reinterpret_cast<LPCTSTR>(data));
symbols = (SIZE_TYPE)wcslen(reinterpret_cast<LPCTSTR>(data));
#else
size = (SIZE_TYPE)strlen(reinterpret_cast<LPCTSTR>(data));
symbols = (SIZE_TYPE)strlen(reinterpret_cast<LPCTSTR>(data));
#endif
}
else if (size == SQL_IS_POINTER || size == SQL_IS_UINTEGER ||
size == SQL_IS_INTEGER || size == SQL_IS_USMALLINT ||
size == SQL_IS_SMALLINT)
{
throw std::runtime_error("SQL data is not a string");
}
else if (size < 0)
{
throw std::runtime_error("invalid size of string : " + std::to_string(size));
return{ reinterpret_cast<const char*>(data), (size_t)SQL_LEN_BINARY_ATTR(size) };
}
else
{
symbols = static_cast<size_t>(size) / sizeof(SQLTCHAR);
}

#ifdef UNICODE
std::wstring wstr(reinterpret_cast<LPCTSTR>(data), static_cast<size_t>(size));
return std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t>().to_bytes(wstr);
return std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t>()
.to_bytes(std::wstring(data, symbols));
#else
return{ reinterpret_cast<LPCTSTR>(data), static_cast<size_t>(size) };
return{ (const char*)data, (size_t)symbols };
#endif
}

Expand Down Expand Up @@ -120,31 +154,46 @@ template <typename STRING, typename PTR, typename LENGTH>
RETCODE fillOutputStringImpl(const STRING & value,
PTR out_value,
LENGTH out_value_max_length,
LENGTH * out_value_length)
LENGTH * out_value_length,
bool length_in_bytes)
{
using CharType = typename STRING::value_type;
LENGTH size_without_zero = static_cast<LENGTH>(value.size());
LENGTH symbols = static_cast<LENGTH>(value.size());

if (out_value_length)
*out_value_length = size_without_zero * sizeof(CharType);
{
if (length_in_bytes)
*out_value_length = symbols * sizeof(CharType);
else
*out_value_length = symbols;
}

if (out_value_max_length < 0)
return SQL_ERROR;

if (out_value)
{
if (out_value_max_length >= (size_without_zero + 1) * sizeof(CharType))
size_t max_length_in_bytes;

if (length_in_bytes)
max_length_in_bytes = out_value_max_length;
else
max_length_in_bytes = out_value_max_length * sizeof(CharType);

if (max_length_in_bytes >= (symbols + 1) * sizeof(CharType))
{
memcpy(out_value, value.c_str(), (size_without_zero + 1) * sizeof(CharType));
memcpy(out_value, value.c_str(), (symbols + 1) * sizeof(CharType));
}
else
{
if (out_value_max_length >= 2 * sizeof(CharType))
if (max_length_in_bytes >= sizeof(CharType))
{
memset(out_value, 0, out_value_max_length);
memcpy(out_value, value.data(), (out_value_max_length - 2) * sizeof(CharType));
memcpy(out_value, value.data(), max_length_in_bytes - sizeof(CharType));
reinterpret_cast<CharType*>(out_value)[(max_length_in_bytes / sizeof(CharType)) - 1] = 0;
}
return SQL_SUCCESS_WITH_INFO;

;
}
}

Expand All @@ -155,12 +204,12 @@ template <typename PTR, typename LENGTH>
RETCODE fillOutputRawString(const std::string & value,
PTR out_value, LENGTH out_value_max_length, LENGTH * out_value_length)
{
return fillOutputStringImpl(value, out_value, out_value_max_length, out_value_length);
return fillOutputStringImpl(value, out_value, out_value_max_length, out_value_length, true);
}

template <typename PTR, typename LENGTH>
RETCODE fillOutputUSC2String(const std::string & value,
PTR out_value, LENGTH out_value_max_length, LENGTH * out_value_length)
PTR out_value, LENGTH out_value_max_length, LENGTH * out_value_length, bool length_in_bytes = true)
{
#if defined (_win_)
using CharType = uint_least16_t;
Expand All @@ -169,16 +218,16 @@ RETCODE fillOutputUSC2String(const std::string & value,
#endif
return fillOutputStringImpl(
std::wstring_convert<std::codecvt_utf8<CharType>, CharType>().from_bytes(value),
out_value, out_value_max_length, out_value_length);
out_value, out_value_max_length, out_value_length, length_in_bytes);
}

template <typename PTR, typename LENGTH>
RETCODE fillOutputPlatformString(
const std::string & value,
PTR out_value, LENGTH out_value_max_length, LENGTH * out_value_length)
PTR out_value, LENGTH out_value_max_length, LENGTH * out_value_length, bool length_in_bytes = true)
{
#ifdef UNICODE
return fillOutputUSC2String(value, out_value, out_value_max_length, out_value_length);
return fillOutputUSC2String(value, out_value, out_value_max_length, out_value_length, length_in_bytes);
#else
return fillOutputRawString(value, out_value, out_value_max_length, out_value_length);
#endif
Expand Down

0 comments on commit 35fcce0

Please sign in to comment.