diff --git a/cpp/src/arrow/flight/sql/odbc/Brewfile b/cpp/src/arrow/flight/sql/odbc/Brewfile new file mode 100644 index 0000000000000..197f84c764fb3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/Brewfile @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +brew "aws-sdk-cpp" +brew "bash" +brew "boost" +brew "brotli" +brew "bzip2" +brew "c-ares" +brew "curl" +brew "ccache" +brew "cmake" +brew "flatbuffers" +brew "git" +brew "glog" +brew "googletest" +brew "grpc" +brew "llvm@14" +brew "lz4" +brew "mimalloc" +brew "ninja" +brew "node" +brew "openssl@3" +brew "pkg-config" +brew "protobuf" +brew "python" +brew "rapidjson" +brew "re2" +brew "snappy" +brew "thrift" +brew "utf8proc" +brew "wget" +brew "xsimd" +brew "zstd" +brew "libiodbc" diff --git a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt new file mode 100644 index 0000000000000..67440ccf28bc3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.16) +set(CMAKE_CXX_STANDARD 17) + +project(flightsql_odbc) + +if(CMAKE_BUILD_TYPE STREQUAL "Release") + add_compile_definitions(NDEBUG) +endif() + +# Add Boost dependencies. Should be pre-installed (Brew on Mac). +find_package(Boost REQUIRED) +find_package(ODBC REQUIRED) + +# Fetch and include GTest +# Adapted from Google's documentation: https://google.github.io/googletest/quickstart-cmake.html#set-up-a-project +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/v1.14.0.zip +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +add_subdirectory(flight_sql) +add_subdirectory(odbcabstraction) diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt new file mode 100644 index 0000000000000..186189c1e7c0d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.16) +set(CMAKE_CXX_STANDARD 17) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +include_directories(include include/flight_sql + ${CMAKE_SOURCE_DIR}/odbcabstraction/include) + +if(DEFINED CMAKE_TOOLCHAIN_FILE) + include(${CMAKE_TOOLCHAIN_FILE}) +endif() + +# Add Zlib dependencies needed by Arrow Flight. Should be pre-installed +# unless provided by VCPKG. find_package(ZLIB REQUIRED) + +# Add Protobuf dependencies needed by Arrow Flight. Should be pre-installed. +# set(Protobuf_USE_STATIC_LIBS OFF) find_package(Protobuf REQUIRED) + +# Add OpenSSL dependencies needed by Arrow Flight. Should be pre-installed. # +# May need to set OPENSSL_ROOT_DIR first. On Mac if using brew: # brew install +# openssl@1.1 # add to the cmake line +# -DOPENSSL_ROOT_DIR=/usr/local/Cellar/openssl@1.1/1.1.1m if (NOT DEFINED +# OPENSSL_ROOT_DIR AND DEFINED APPLE AND NOT DEFINED CMAKE_TOOLCHAIN_FILE) +# set(OPENSSL_ROOT_DIR /usr/local/Cellar/openssl@1.1/1.1.1m) endif() # This is +# based on Arrow's FindOpenSSL module. It's not clear if both variables # need +# to be set. if (NOT DEFINED MSVC) set(OpenSSL_USE_STATIC_LIBS ON) +# set(OPENSSL_USE_STATIC_LIBS ON) endif() find_package(OpenSSL REQUIRED) + +# OpenSSL depends on krb5 on CentOS if (UNIX) list(APPEND OPENSSL_LIBRARIES krb5 +# k5crypto) endif() + +# Add gRPC dependencies needed by Arrow Flight. Should be pre-installed. +# find_package(gRPC 1.36 CONFIG REQUIRED) + +find_package(RapidJSON CONFIG REQUIRED) + +if(MSVC) + # the following definitions stop arrow from using __declspec when staticly + # linking and will break on Windows without them + add_compile_definitions(ARROW_STATIC ARROW_FLIGHT_STATIC) +endif() + +enable_testing() + +dd_library(arrow_odbc_spi_impl + include/flight_sql/flight_sql_driver.h + accessors/binary_array_accessor.cc + accessors/binary_array_accessor.h + accessors/boolean_array_accessor.cc + accessors/boolean_array_accessor.h + accessors/common.h + accessors/date_array_accessor.cc + accessors/date_array_accessor.h + accessors/decimal_array_accessor.cc + accessors/decimal_array_accessor.h + accessors/main.h + accessors/primitive_array_accessor.cc + accessors/primitive_array_accessor.h + accessors/string_array_accessor.cc + accessors/string_array_accessor.h + accessors/time_array_accessor.cc + accessors/time_array_accessor.h + accessors/timestamp_array_accessor.cc + accessors/timestamp_array_accessor.h + address_info.cc + address_info.h + flight_sql_auth_method.cc + flight_sql_auth_method.h + flight_sql_connection.cc + flight_sql_connection.h + flight_sql_driver.cc + flight_sql_get_tables_reader.cc + flight_sql_get_tables_reader.h + flight_sql_get_type_info_reader.cc + flight_sql_get_type_info_reader.h + flight_sql_result_set.cc + flight_sql_result_set.h + flight_sql_result_set_accessors.cc + flight_sql_result_set_accessors.h + flight_sql_result_set_column.cc + flight_sql_result_set_column.h + flight_sql_result_set_metadata.cc + flight_sql_result_set_metadata.h + flight_sql_ssl_config.cc + flight_sql_ssl_config.h + flight_sql_statement.cc + flight_sql_statement.h + flight_sql_statement_get_columns.cc + flight_sql_statement_get_columns.h + flight_sql_statement_get_tables.cc + flight_sql_statement_get_tables.h + flight_sql_statement_get_type_info.cc + flight_sql_statement_get_type_info.h + flight_sql_stream_chunk_buffer.cc + flight_sql_stream_chunk_buffer.h + get_info_cache.cc + get_info_cache.h + json_converter.cc + json_converter.h + record_batch_transformer.cc + record_batch_transformer.h + scalar_function_reporter.cc + scalar_function_reporter.h + system_trust_store.cc + system_trust_store.h + utils.cc) +target_include_directories(arrow_odbc_spi_impl PUBLIC ${CMAKE_CURRENT_LIST_DIR}) +if(WIN32) + target_sources(arrow_odbc_spi_impl + include/flight_sql/config/configuration.h + include/flight_sql/config/connection_string_parser.h + include/flight_sql/ui/add_property_window.h + include/flight_sql/ui/custom_window.h + include/flight_sql/ui/dsn_configuration_window.h + include/flight_sql/ui/window.h + config/configuration.cc + config/connection_string_parser.cc + ui/custom_window.cc + ui/window.cc + ui/dsn_configuration_window.cc + ui/add_property_window.cc + system_dsn.cc) +endif() + +find_package(ArrowFlightSql) + +target_link_libraries(arrow_odbc_spi_impl PUBLIC odbcabstraction + ArrowFlightSql::arrow_flight_sql_static) + +if(MSVC) + set(CMAKE_CXX_FLAGS_RELEASE "/MD") + set(CMAKE_CXX_FLAGS_DEBUG "/MDd") +# else +# target_link_libraries(arrow_odbc_spi_impl PUBLIC ArrowFlightSql::arrow_flight_sql_static) +endif() + +# set(ARROW_ODBC_SPI_THIRDPARTY_LIBS +# ${ARROW_LIBS} gRPC::grpc++ ${ZLIB_LIBRARIES} ${Protobuf_LIBRARIES} +# ${OPENSSL_LIBRARIES} ${RapidJSON_LIBRARIES}) + +if(MSVC) + find_package(Boost REQUIRED COMPONENTS locale) + target_link_libraries(arrow_odbc_spi_impl PUBLIC Boost::locale) +endif() + +set_target_properties( + arrow_odbc_spi_impl + PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib) + +# target_include_directories(arrow_odbc_spi_impl +# PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +# CLI +add_executable(arrow_odbc_spi_impl_cli main.cc) +set_target_properties( + arrow_odbc_spi_impl_cli PROPERTIES RUNTIME_OUTPUT_DIRECTORY + ${CMAKE_BINARY_DIR}/$/bin) +target_link_libraries(arrow_odbc_spi_impl_cli arrow_odbc_spi_impl) + +# Unit tests +add_executable( + arrow_odbc_spi_impl_test + accessors/boolean_array_accessor_test.cc + accessors/binary_array_accessor_test.cc + accessors/date_array_accessor_test.cc + accessors/decimal_array_accessor_test.cc + accessors/primitive_array_accessor_test.cc + accessors/string_array_accessor_test.cc + accessors/time_array_accessor_test.cc + accessors/timestamp_array_accessor_test.cc + flight_sql_connection_test.cc + parse_table_types_test.cc + json_converter_test.cc + record_batch_transformer_test.cc + utils_test.cc) + +set_target_properties( + arrow_odbc_spi_impl_test PROPERTIES RUNTIME_OUTPUT_DIRECTORY + ${CMAKE_BINARY_DIR}/test/$/bin) +target_link_libraries(arrow_odbc_spi_impl_test arrow_odbc_spi_impl gtest + gtest_main) + +add_test(connection_test arrow_odbc_spi_impl_test) +add_test(transformer_test arrow_odbc_spi_impl_test) + +add_custom_command( + TARGET arrow_odbc_spi_impl_test + COMMENT "Run tests" + POST_BUILD + COMMAND ${CMAKE_BINARY_DIR}/test/$/bin/arrow_odbc_spi_impl_test) diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor.cc new file mode 100644 index 0000000000000..b7154b9368335 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor.cc @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "binary_array_accessor.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +namespace { + +inline RowStatus MoveSingleCellToBinaryBuffer(ColumnBinding *binding, + BinaryArray *array, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + RowStatus result = odbcabstraction::RowStatus_SUCCESS; + + const char *value = array->Value(arrow_row).data(); + size_t size_in_bytes = array->value_length(arrow_row); + + size_t remaining_length = static_cast(size_in_bytes - value_offset); + size_t value_length = + std::min(remaining_length, + binding->buffer_length); + + auto *byte_buffer = static_cast(binding->buffer) + + i * binding->buffer_length; + memcpy(byte_buffer, ((char *)value) + value_offset, value_length); + + if (remaining_length > binding->buffer_length) { + result = odbcabstraction::RowStatus_SUCCESS_WITH_INFO; + diagnostics.AddTruncationWarning(); + if (update_value_offset) { + value_offset += value_length; + } + } else if (update_value_offset) { + value_offset = -1; + } + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(remaining_length); + } + + return result; +} + +} // namespace + +template +BinaryArrayFlightSqlAccessor::BinaryArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array) {} + +template <> +RowStatus BinaryArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + return MoveSingleCellToBinaryBuffer(binding, this->GetArray(), arrow_row, i, value_offset, + update_value_offset, diagnostics); +} + +template +size_t BinaryArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return binding->buffer_length; +} + +template class BinaryArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor.h new file mode 100644 index 0000000000000..aea7b41ca487c --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor.h @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class BinaryArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit BinaryArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor_test.cc new file mode 100644 index 0000000000000..4b75b455b4c8e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/binary_array_accessor_test.cc @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/builder.h" +#include "binary_array_accessor.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Basic) { + std::vector values = {"foo", "barx", "baz123"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + BinaryArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 64; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_BINARY, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(values[i].length(), strlen_buffer[i]); + // Beware that CDataType_BINARY values are not null terminated. + // It's safe to create a std::string from this data because we know it's + // ASCII, this doesn't work with arbitrary binary data. + ASSERT_EQ(values[i], + std::string(buffer.data() + i * max_strlen, + buffer.data() + i * max_strlen + strlen_buffer[i])); + } +} + +TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Truncation) { + std::vector values = { + "ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + BinaryArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 8; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_BINARY, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + std::stringstream ss; + int64_t value_offset = 0; + + // Construct the whole string by concatenating smaller chunks from + // GetColumnarData + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + do { + diagnostics.Clear(); + int64_t original_value_offset = value_offset; + ASSERT_EQ(1, accessor.GetColumnarData(&binding, 0, 1, value_offset, true, diagnostics, nullptr)); + ASSERT_EQ(values[0].length() - original_value_offset, strlen_buffer[0]); + + int64_t chunk_length = 0; + if (value_offset == -1) { + chunk_length = strlen_buffer[0]; + } else { + chunk_length = max_strlen; + } + + // Beware that CDataType_BINARY values are not null terminated. + // It's safe to create a std::string from this data because we know it's + // ASCII, this doesn't work with arbitrary binary data. + ss << std::string(buffer.data(), buffer.data() + chunk_length); + } while (value_offset < values[0].length() && value_offset != -1); + + ASSERT_EQ(values[0], ss.str()); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor.cc new file mode 100644 index 0000000000000..1f1883c99a6b6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor.cc @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "boolean_array_accessor.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +BooleanArrayFlightSqlAccessor::BooleanArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array) {} + +template +RowStatus BooleanArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + typedef unsigned char c_type; + bool value = this->GetArray()->Value(arrow_row); + + auto *buffer = static_cast(binding->buffer); + buffer[i] = value ? 1 : 0; + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t BooleanArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(unsigned char); +} + +template class BooleanArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor.h new file mode 100644 index 0000000000000..5a5fbe68a8f43 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class BooleanArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit BooleanArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, + int64_t i, int64_t &value_offset, + bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor_test.cc new file mode 100644 index 0000000000000..94e1899564b49 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/boolean_array_accessor_test.cc @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/builder.h" +#include "boolean_array_accessor.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(BooleanArrayFlightSqlAccessor, Test_BooleanArray_CDataType_BIT) { + const std::vector values = {true, false, true}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + BooleanArrayFlightSqlAccessor accessor(array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_BIT, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(unsigned char), strlen_buffer[i]); + ASSERT_EQ(values[i] ? 1 : 0, buffer[i]); + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/common.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/common.h new file mode 100644 index 0000000000000..0de9a3bc74139 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/common.h @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "types.h" +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +inline size_t CopyFromArrayValuesToBinding(ARRAY_TYPE* array, + ColumnBinding *binding, + int64_t starting_row, int64_t cells) { + constexpr ssize_t element_size = sizeof(typename ARRAY_TYPE::value_type); + + if (binding->strlen_buffer) { + for (int64_t i = 0; i < cells; ++i) { + int64_t current_row = starting_row + i; + if (array->IsNull(current_row)) { + binding->strlen_buffer[i] = NULL_DATA; + } else { + binding->strlen_buffer[i] = element_size; + } + } + } else { + // Duplicate this loop to avoid null checks within the loop. + for (int64_t i = starting_row; i < starting_row + cells; ++i) { + if (array->IsNull(i)) { + throw odbcabstraction::NullWithoutIndicatorException(); + } + } + } + + // Copy the entire array to the bound ODBC buffers. + // Note that the array should already have been sliced down to the same number + // of elements in the ODBC data array by the point in which this function is called. + const auto *values = array->raw_values(); + memcpy(binding->buffer, &values[starting_row], element_size * cells); + + return cells; +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor.cc new file mode 100644 index 0000000000000..af8c6cbb6c7f6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor.cc @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "date_array_accessor.h" +#include "time.h" +#include "arrow/compute/api.h" +#include "odbcabstraction/calendar_utils.h" + +using namespace arrow; + + +namespace { + template int64_t convertDate(typename T::value_type value) { + return value; + } + +/// Converts the value from the array, which is in milliseconds, to seconds. +/// \param value the value extracted from the array in milliseconds. +/// \return the converted value in seconds. + template <> int64_t convertDate(int64_t value) { + return value / driver::flight_sql::MILLI_TO_SECONDS_DIVISOR; + } + +/// Converts the value from the array, which is in days, to seconds. +/// \param value the value extracted from the array in days. +/// \return the converted value in seconds. + template <> int64_t convertDate(int32_t value) { + return value * driver::flight_sql::DAYS_TO_SECONDS_MULTIPLIER; + } +} // namespace + +namespace driver { +namespace flight_sql { + +using namespace odbcabstraction; + +template +DateArrayFlightSqlAccessor< + TARGET_TYPE, ARROW_ARRAY>::DateArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor>( + array) {} + +template +RowStatus DateArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + auto *buffer = static_cast(binding->buffer); + auto value = convertDate(this->GetArray()->Value(arrow_row)); + tm date{}; + + GetTimeForSecondsSinceEpoch(date, value); + + buffer[cell_counter].year = 1900 + (date.tm_year); + buffer[cell_counter].month = date.tm_mon + 1; + buffer[cell_counter].day = date.tm_mday; + + if (binding->strlen_buffer) { + binding->strlen_buffer[cell_counter] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t DateArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(DATE_STRUCT); +} + +template class DateArrayFlightSqlAccessor; +template class DateArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor.h new file mode 100644 index 0000000000000..237965be3fa56 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class DateArrayFlightSqlAccessor + : public FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + DateArrayFlightSqlAccessor> { + +public: + explicit DateArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor_test.cc new file mode 100644 index 0000000000000..a6cc7e2f01e3d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/date_array_accessor_test.cc @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/builder.h" +#include "boolean_array_accessor.h" +#include "date_array_accessor.h" +#include "gtest/gtest.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(DateArrayAccessor, Test_Date32Array_CDataType_DATE) { + std::vector values = {7589, 12320, 18980, 19095}; + + std::shared_ptr array; + ArrayFromVector(values, &array); + + DateArrayFlightSqlAccessor accessor( + dynamic_cast *>(array.get())); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_DATE, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(DATE_STRUCT), strlen_buffer[i]); + tm date{}; + + int64_t converted_time = values[i] * 86400; + GetTimeForSecondsSinceEpoch(date, converted_time); + ASSERT_EQ((date.tm_year + 1900), buffer[i].year); + ASSERT_EQ(date.tm_mon + 1, buffer[i].month); + ASSERT_EQ(date.tm_mday, buffer[i].day); + } +} + +TEST(DateArrayAccessor, Test_Date64Array_CDataType_DATE) { + std::vector values = {86400000, 172800000, 259200000, 1649793238110, + 345600000, 432000000, 518400000}; + + std::shared_ptr array; + ArrayFromVector(values, &array); + + DateArrayFlightSqlAccessor accessor( + dynamic_cast *>(array.get())); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_DATE, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(DATE_STRUCT), strlen_buffer[i]); + tm date{}; + + int64_t converted_time = values[i] / 1000; + GetTimeForSecondsSinceEpoch(date, converted_time); + ASSERT_EQ((date.tm_year + 1900), buffer[i].year); + ASSERT_EQ(date.tm_mon + 1, buffer[i].month); + ASSERT_EQ(date.tm_mday, buffer[i].day); + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor.cc new file mode 100644 index 0000000000000..5b0300b15e394 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor.cc @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "decimal_array_accessor.h" + +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +DecimalArrayFlightSqlAccessor::DecimalArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array), + data_type_(static_cast(array->type().get())) { +} + +template <> +RowStatus DecimalArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + auto result = &(static_cast(binding->buffer)[i]); + int32_t original_scale = data_type_->scale(); + + const uint8_t* bytes = this->GetArray()->Value(arrow_row); + Decimal128 value(bytes); + if (original_scale != binding->scale) { + const Status &status = value.Rescale(original_scale, binding->scale).Value(&value); + ThrowIfNotOK(status); + } + if (!value.FitsInPrecision(binding->precision)) { + throw DriverException("Decimal value doesn't fit in precision " + std::to_string(binding->precision)); + } + + result->sign = value.IsNegative() ? 0 : 1; + + // Take the absolute value since the ODBC SQL_NUMERIC_STRUCT holds + // a positive-only number. + if (value.IsNegative()) { + Decimal128 abs_value = Decimal128::Abs(value); + abs_value.ToBytes(result->val); + } else { + value.ToBytes(result->val); + } + result->precision = static_cast(binding->precision); + result->scale = static_cast(binding->scale); + + result->precision = data_type_->precision(); + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t DecimalArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(NUMERIC_STRUCT); +} + +template class DecimalArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor.h new file mode 100644 index 0000000000000..51bec2ad1cc3b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include "utils.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class DecimalArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit DecimalArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; + +private: + Decimal128Type *data_type_; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor_test.cc new file mode 100644 index 0000000000000..359ea55586cdc --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/decimal_array_accessor_test.cc @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/decimal.h" +#include "arrow/builder.h" +#include "arrow/testing/builder.h" +#include "decimal_array_accessor.h" +#include "gtest/gtest.h" + +namespace { + +using namespace arrow; +using namespace driver::odbcabstraction; +using driver::flight_sql::ThrowIfNotOK; + +std::vector MakeDecimalVector(const std::vector &values, + int32_t scale) { + std::vector ret; + for (const auto &str: values) { + Decimal128 str_value; + int32_t str_precision; + int32_t str_scale; + + ThrowIfNotOK(Decimal128::FromString(str, &str_value, &str_precision, &str_scale)); + + Decimal128 scaled_value; + if (str_scale == scale) { + scaled_value = str_value; + } else { + scaled_value = str_value.Rescale(str_scale, scale).ValueOrDie(); + } + ret.push_back(scaled_value); + } + return ret; +} + +std::string ConvertNumericToString(NUMERIC_STRUCT &numeric) { + auto v = reinterpret_cast(numeric.val); + auto decimal = Decimal128(v[1], v[0]); + if (numeric.sign == 0) { + decimal.Negate(); + } + const std::string &string = decimal.ToString(numeric.scale); + + return string; +} +} + +namespace driver { +namespace flight_sql { + +void AssertNumericOutput(int input_precision, int input_scale, const std::vector &values_str, + int output_precision, int output_scale, const std::vector &expected_values_str) { + auto decimal_type = std::make_shared(input_precision, input_scale); + const std::vector &values = MakeDecimalVector(values_str, decimal_type->scale()); + + std::shared_ptr array; + ArrayFromVector(decimal_type, values, &array); + + DecimalArrayFlightSqlAccessor accessor(array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_NUMERIC, output_precision, output_scale, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(NUMERIC_STRUCT), strlen_buffer[i]); + + ASSERT_EQ(output_precision, buffer[i].precision); + ASSERT_EQ(output_scale, buffer[i].scale); + ASSERT_STREQ(expected_values_str[i].c_str(), ConvertNumericToString(buffer[i]).c_str()); + } +} + +TEST(DecimalArrayFlightSqlAccessor, Test_Decimal128Array_CDataType_NUMERIC_SameScale) { + const std::vector &input_values = {"25.212", "-25.212", "-123456789.123", "123456789.123"}; + const std::vector &output_values = input_values; // String values should be the same + + AssertNumericOutput(38, 3, input_values, 38, 3, output_values); +} + +TEST(DecimalArrayFlightSqlAccessor, Test_Decimal128Array_CDataType_NUMERIC_IncreasingScale) { + const std::vector &input_values = {"25.212", "-25.212", "-123456789.123", "123456789.123"}; + const std::vector &output_values = {"25.2120", "-25.2120", "-123456789.1230", "123456789.1230"}; + + AssertNumericOutput(38, 3, input_values, 38, 4, output_values); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/main.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/main.h new file mode 100644 index 0000000000000..14a49033278b3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/main.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "binary_array_accessor.h" +#include "boolean_array_accessor.h" +#include "date_array_accessor.h" +#include "time_array_accessor.h" +#include "timestamp_array_accessor.h" +#include "decimal_array_accessor.h" +#include "primitive_array_accessor.h" +#include "string_array_accessor.h" diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.cc new file mode 100644 index 0000000000000..7be374c82e67b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.cc @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "primitive_array_accessor.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +PrimitiveArrayFlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE>::PrimitiveArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + PrimitiveArrayFlightSqlAccessor>(array) {} + +template +size_t +PrimitiveArrayFlightSqlAccessor::GetColumnarData_impl( + ColumnBinding *binding, int64_t starting_row, + int64_t cells, int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) { + return CopyFromArrayValuesToBinding(this->GetArray(), binding, starting_row, cells); +} + +template +size_t PrimitiveArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(typename ARROW_ARRAY::TypeClass::c_type); +} + +template class PrimitiveArrayFlightSqlAccessor< + Int64Array, odbcabstraction::CDataType_SBIGINT>; +template class PrimitiveArrayFlightSqlAccessor< + Int32Array, odbcabstraction::CDataType_SLONG>; +template class PrimitiveArrayFlightSqlAccessor< + Int16Array, odbcabstraction::CDataType_SSHORT>; +template class PrimitiveArrayFlightSqlAccessor< + Int8Array, odbcabstraction::CDataType_STINYINT>; +template class PrimitiveArrayFlightSqlAccessor< + UInt64Array, odbcabstraction::CDataType_UBIGINT>; +template class PrimitiveArrayFlightSqlAccessor< + UInt32Array, odbcabstraction::CDataType_ULONG>; +template class PrimitiveArrayFlightSqlAccessor< + UInt16Array, odbcabstraction::CDataType_USHORT>; +template class PrimitiveArrayFlightSqlAccessor< + UInt8Array, odbcabstraction::CDataType_UTINYINT>; +template class PrimitiveArrayFlightSqlAccessor< + DoubleArray, odbcabstraction::CDataType_DOUBLE>; +template class PrimitiveArrayFlightSqlAccessor< + FloatArray, odbcabstraction::CDataType_FLOAT>; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.h new file mode 100644 index 0000000000000..6830ab4c1d925 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.h @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "../flight_sql_result_set.h" +#include "common.h" +#include "types.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class PrimitiveArrayFlightSqlAccessor + : public FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + PrimitiveArrayFlightSqlAccessor> { +public: + explicit PrimitiveArrayFlightSqlAccessor(Array *array); + + size_t GetColumnarData_impl(ColumnBinding *binding, int64_t starting_row, int64_t cells, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc new file mode 100644 index 0000000000000..daa561d0cf458 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/builder.h" +#include "primitive_array_accessor.h" +#include +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +void TestPrimitiveArraySqlAccessor() { + typedef typename ARROW_ARRAY::TypeClass::c_type c_type; + + std::vector values = {0, 1, 2, 3, 127}; + + std::shared_ptr array; + ArrayFromVector(values, &array); + + PrimitiveArrayFlightSqlAccessor accessor( + array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(TARGET_TYPE, 0, 0, buffer.data(), values.size(), + strlen_buffer.data()); + + int64_t value_offset = 0; + driver::odbcabstraction::Diagnostics diagnostics("Dummy", "Dummy", odbcabstraction::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(c_type), strlen_buffer[i]); + ASSERT_EQ(values[i], buffer[i]); + } +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int64Array_CDataType_SBIGINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int32Array_CDataType_SLONG) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int16Array_CDataType_SSHORT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int8Array_CDataType_STINYINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt64Array_CDataType_UBIGINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt32Array_CDataType_ULONG) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt16Array_CDataType_USHORT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt8Array_CDataType_UTINYINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_FloatArray_CDataType_FLOAT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_DoubleArray_CDataType_DOUBLE) { + TestPrimitiveArraySqlAccessor(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor.cc new file mode 100644 index 0000000000000..8e24b47f16ba5 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor.cc @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "string_array_accessor.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +namespace { + +#if defined _WIN32 || defined _WIN64 +std::string utf8_to_clocale(const char *utf8str, int len) +{ + thread_local boost::locale::generator g; + g.locale_cache_enabled(true); + std::locale loc = g(boost::locale::util::get_system_locale()); + return boost::locale::conv::from_utf(utf8str, utf8str + len, loc); +} +#endif + +template +inline RowStatus MoveSingleCellToCharBuffer(std::vector &buffer, + int64_t& last_retrieved_arrow_row, +#if defined _WIN32 || defined _WIN64 + std::string &clocale_str, +#endif + ColumnBinding *binding, + StringArray *array, int64_t arrow_row, int64_t i, + int64_t &value_offset, + bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics) { + RowStatus result = odbcabstraction::RowStatus_SUCCESS; + + // Arrow strings come as UTF-8 + const char *raw_value = array->Value(arrow_row).data(); + const size_t raw_value_length = array->value_length(arrow_row); + const void *value; + + size_t size_in_bytes; + if (sizeof(CHAR_TYPE) > sizeof(char)) { + if (last_retrieved_arrow_row != arrow_row) { + Utf8ToWcs(raw_value, raw_value_length, &buffer); + last_retrieved_arrow_row = arrow_row; + } + value = buffer.data(); + size_in_bytes = buffer.size(); + } else { +#if defined _WIN32 || defined _WIN64 + // Convert to C locale string + if (last_retrieved_arrow_row != arrow_row) { + clocale_str = utf8_to_clocale(raw_value, raw_value_length); + last_retrieved_arrow_row = arrow_row; + } + const char* clocale_data = clocale_str.data(); + size_t clocale_length = clocale_str.size(); + + value = clocale_data; + size_in_bytes = clocale_length; +#else + value = raw_value; + size_in_bytes = raw_value_length; +#endif + } + + size_t remaining_length = static_cast(size_in_bytes - value_offset); + size_t value_length = + std::min(remaining_length, + binding->buffer_length); + + auto *byte_buffer = + static_cast(binding->buffer) + i * binding->buffer_length; + auto *char_buffer = (CHAR_TYPE *)byte_buffer; + memcpy(char_buffer, ((char *)value) + value_offset, value_length); + + // Write a NUL terminator + if (binding->buffer_length >= remaining_length + sizeof(CHAR_TYPE)) { + // The entire remainder of the data was consumed. + char_buffer[remaining_length / sizeof(CHAR_TYPE)] = '\0'; + if (update_value_offset) { + // Mark that there's no data remaining. + value_offset = -1; + } + } else { + result = odbcabstraction::RowStatus_SUCCESS_WITH_INFO; + diagnostics.AddTruncationWarning(); + size_t chars_written = binding->buffer_length / sizeof(CHAR_TYPE); + // If we failed to even write one char, the buffer is too small to hold a + // NUL-terminator. + if (chars_written > 0) { + char_buffer[(chars_written - 1)] = '\0'; + if (update_value_offset) { + value_offset += binding->buffer_length - sizeof(CHAR_TYPE); + } + } + } + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(remaining_length); + } + + return result; +} + +} // namespace + +template +StringArrayFlightSqlAccessor::StringArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array), + last_arrow_row_(-1){} + +template +RowStatus StringArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + return MoveSingleCellToCharBuffer(buffer_, last_arrow_row_, +#if defined _WIN32 || defined _WIN64 + clocale_str_, +#endif + binding, + this->GetArray(), arrow_row, i, value_offset, update_value_offset, diagnostics); +} + +template +size_t StringArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return binding->buffer_length; +} + +template class StringArrayFlightSqlAccessor; +template class StringArrayFlightSqlAccessor; +template class StringArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor.h new file mode 100644 index 0000000000000..da1caeff060d8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor.h @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include "utils.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class StringArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit StringArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; + +private: + std::vector buffer_; +#if defined _WIN32 || defined _WIN64 + std::string clocale_str_; +#endif + int64_t last_arrow_row_; +}; + +inline Accessor* CreateWCharStringArrayAccessor(arrow::Array *array) { + switch(GetSqlWCharSize()) { + case sizeof(char16_t): + return new StringArrayFlightSqlAccessor(array); + case sizeof(char32_t): + return new StringArrayFlightSqlAccessor(array); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc new file mode 100644 index 0000000000000..078ed3c771a4d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/builder.h" +#include "string_array_accessor.h" +#include "gtest/gtest.h" +#include "odbcabstraction/encoding.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(StringArrayAccessor, Test_CDataType_CHAR_Basic) { + std::vector values = {"foo", "barx", "baz123"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + StringArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 64; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_CHAR, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(values[i].length(), strlen_buffer[i]); + ASSERT_EQ(values[i], std::string(buffer.data() + i * max_strlen)); + } +} + +TEST(StringArrayAccessor, Test_CDataType_CHAR_Truncation) { + std::vector values = { + "ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + StringArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 8; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_CHAR, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + std::stringstream ss; + int64_t value_offset = 0; + + // Construct the whole string by concatenating smaller chunks from + // GetColumnarData + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + do { + diagnostics.Clear(); + int64_t original_value_offset = value_offset; + ASSERT_EQ(1, accessor.GetColumnarData(&binding, 0, 1, value_offset, true, diagnostics, nullptr)); + ASSERT_EQ(values[0].length() - original_value_offset, strlen_buffer[0]); + + ss << buffer.data(); + } while (value_offset < values[0].length() && value_offset != -1); + + ASSERT_EQ(values[0], ss.str()); +} + +TEST(StringArrayAccessor, Test_CDataType_WCHAR_Basic) { + std::vector values = {"foo", "barx", "baz123"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + auto accessor = CreateWCharStringArrayAccessor(array.get()); + + size_t max_strlen = 64; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_WCHAR, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor->GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(values[i].length() * GetSqlWCharSize(), strlen_buffer[i]); + std::vector expected; + Utf8ToWcs(values[i].c_str(), &expected); + uint8_t *start = buffer.data() + i * max_strlen; + auto actual = std::vector(start, start + strlen_buffer[i]); + ASSERT_EQ(expected, actual); + } +} + +TEST(StringArrayAccessor, Test_CDataType_WCHAR_Truncation) { + std::vector values = { + "ABCDEFA"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + auto accessor = CreateWCharStringArrayAccessor(array.get()); + + size_t max_strlen = 8; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_WCHAR, 0, 0, buffer.data(), + max_strlen, strlen_buffer.data()); + + std::basic_stringstream ss; + int64_t value_offset = 0; + + // Construct the whole string by concatenating smaller chunks from + // GetColumnarData + std::vector finalStr; + driver::odbcabstraction::Diagnostics diagnostics("Dummy", "Dummy", odbcabstraction::V_3); + do { + int64_t original_value_offset = value_offset; + ASSERT_EQ(1, accessor->GetColumnarData(&binding, 0, 1, value_offset, true, diagnostics, nullptr)); + ASSERT_EQ(values[0].length() * GetSqlWCharSize() - original_value_offset, strlen_buffer[0]); + + size_t length = value_offset - original_value_offset; + if (value_offset == -1) { + length = buffer.size(); + } + finalStr.insert(finalStr.end(), buffer.data(), buffer.data() + length); + + } while (value_offset < values[0].length() * GetSqlWCharSize() && value_offset != -1); + + // Trim final null bytes + finalStr.resize(values[0].length() * GetSqlWCharSize()); + + std::vector expected; + Utf8ToWcs(values[0].c_str(), &expected); + ASSERT_EQ(expected, finalStr); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor.cc new file mode 100644 index 0000000000000..28e01d6e1e7a2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor.cc @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "time_array_accessor.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +Accessor* CreateTimeAccessor(arrow::Array *array, arrow::Type::type type) { + auto time_type = + arrow::internal::checked_pointer_cast(array->type()); + auto time_unit = time_type->unit(); + + if (type == arrow::Type::TIME32) { + switch (time_unit) { + case TimeUnit::SECOND: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MILLI: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MICRO: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::NANO: + return new TimeArrayFlightSqlAccessor(array); + } + } else if (type == arrow::Type::TIME64) { + switch (time_unit) { + case TimeUnit::SECOND: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MILLI: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MICRO: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::NANO: + return new TimeArrayFlightSqlAccessor(array); + } + } + assert(false); + throw DriverException("Unsupported input supplied to CreateTimeAccessor"); +} + +namespace { +template +int64_t ConvertTimeValue(typename T::value_type value, TimeUnit::type unit) { + return value; +} + +template <> +int64_t ConvertTimeValue(int32_t value, TimeUnit::type unit) { + return unit == TimeUnit::SECOND ? value : value / MILLI_TO_SECONDS_DIVISOR; +} + +template <> +int64_t ConvertTimeValue(int64_t value, TimeUnit::type unit) { + return unit == TimeUnit::MICRO ? value / MICRO_TO_SECONDS_DIVISOR + : value / NANO_TO_SECONDS_DIVISOR; +} +} // namespace + +template +TimeArrayFlightSqlAccessor< + TARGET_TYPE, ARROW_ARRAY, UNIT>::TimeArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor>( + array) {} + +template +RowStatus TimeArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostic) { + auto *buffer = static_cast(binding->buffer); + + tm time{}; + + auto converted_value_seconds = + ConvertTimeValue(this->GetArray()->Value(arrow_row), UNIT); + + GetTimeForSecondsSinceEpoch(time, converted_value_seconds); + + buffer[cell_counter].hour = time.tm_hour; + buffer[cell_counter].minute = time.tm_min; + buffer[cell_counter].second = time.tm_sec; + + if (binding->strlen_buffer) { + binding->strlen_buffer[cell_counter] = static_cast(GetCellLength_impl(binding)); + } + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t TimeArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(TIME_STRUCT); +} + +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor.h new file mode 100644 index 0000000000000..8b3e26fb72e92 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +Accessor* CreateTimeAccessor(arrow::Array *array, arrow::Type::type type); + +template +class TimeArrayFlightSqlAccessor + : public FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + TimeArrayFlightSqlAccessor> { + +public: + explicit TimeArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostic); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor_test.cc new file mode 100644 index 0000000000000..78eac3f1e6eb8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/time_array_accessor_test.cc @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/builder.h" +#include "time_array_accessor.h" +#include "utils.h" +#include "gtest/gtest.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(TEST_TIME32, TIME_WITH_SECONDS) { + auto value_field = field("f0", time32(TimeUnit::SECOND)); + + std::vector t32_values = {14896, 14897, 14892, 85400, 14893, 14895}; + + std::shared_ptr time32_array; + ArrayFromVector(value_field->type(), + t32_values, &time32_array); + + TimeArrayFlightSqlAccessor accessor(time32_array.get()); + + std::vector buffer(t32_values.size()); + std::vector strlen_buffer(t32_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t32_values.size(), + accessor.GetColumnarData(&binding, 0, t32_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t32_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + GetTimeForSecondsSinceEpoch(time, t32_values[i]); + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} + +TEST(TEST_TIME32, TIME_WITH_MILLI) { + auto value_field = field("f0", time32(TimeUnit::MILLI)); + std::vector t32_values = {14896000, 14897000, 14892000, + 85400000, 14893000, 14895000}; + + std::shared_ptr time32_array; + ArrayFromVector(value_field->type(), + t32_values, &time32_array); + + TimeArrayFlightSqlAccessor accessor(time32_array.get()); + + std::vector buffer(t32_values.size()); + std::vector strlen_buffer(t32_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t32_values.size(), + accessor.GetColumnarData(&binding, 0, t32_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t32_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + auto convertedValue = t32_values[i] / MILLI_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(time, convertedValue); + + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} + +TEST(TEST_TIME64, TIME_WITH_MICRO) { + auto value_field = field("f0", time64(TimeUnit::MICRO)); + + std::vector t64_values = {14896000, 14897000, 14892000, + 85400000, 14893000, 14895000}; + + std::shared_ptr time64_array; + ArrayFromVector(value_field->type(), + t64_values, &time64_array); + + TimeArrayFlightSqlAccessor accessor(time64_array.get()); + + std::vector buffer(t64_values.size()); + std::vector strlen_buffer(t64_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t64_values.size(), + accessor.GetColumnarData(&binding, 0, t64_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t64_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + const auto convertedValue = t64_values[i] / MICRO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(time, convertedValue); + + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} + +TEST(TEST_TIME64, TIME_WITH_NANO) { + auto value_field = field("f0", time64(TimeUnit::NANO)); + std::vector t64_values = {14896000000, 14897000000, 14892000000, + 85400000000, 14893000000, 14895000000}; + + std::shared_ptr time64_array; + ArrayFromVector(value_field->type(), + t64_values, &time64_array); + + TimeArrayFlightSqlAccessor accessor( + time64_array.get()); + + std::vector buffer(t64_values.size()); + std::vector strlen_buffer(t64_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t64_values.size(), + accessor.GetColumnarData(&binding, 0, t64_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t64_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + const auto convertedValue = t64_values[i] / NANO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(time, convertedValue); + + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc new file mode 100644 index 0000000000000..0d18857df6bad --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "timestamp_array_accessor.h" +#include "odbcabstraction/calendar_utils.h" + +using namespace arrow; + +namespace { +int64_t GetConversionToSecondsDivisor(TimeUnit::type unit) { + int64_t divisor = 1; + switch (unit) { + case TimeUnit::SECOND: + divisor = 1; + break; + case TimeUnit::MILLI: + divisor = driver::flight_sql::MILLI_TO_SECONDS_DIVISOR; + break; + case TimeUnit::MICRO: + divisor = driver::flight_sql::MICRO_TO_SECONDS_DIVISOR; + break; + case TimeUnit::NANO: + divisor = driver::flight_sql::NANO_TO_SECONDS_DIVISOR; + break; + default: + assert(false); + throw driver::odbcabstraction::DriverException("Unrecognized time unit value: " + std::to_string(unit)); + } + return divisor; +} + +uint32_t CalculateFraction(TimeUnit::type unit, uint64_t units_since_epoch) { + // Convert the given remainder and time unit to nanoseconds + // since the fraction field on TIMESTAMP_STRUCT is in nanoseconds. + switch (unit) { + case TimeUnit::SECOND: + return 0; + case TimeUnit::MILLI: + // 1000000 nanoseconds = 1 millisecond. + return (units_since_epoch % + driver::odbcabstraction::MILLI_TO_SECONDS_DIVISOR) * + 1000000; + case TimeUnit::MICRO: + // 1000 nanoseconds = 1 microsecond. + return (units_since_epoch % + driver::odbcabstraction::MICRO_TO_SECONDS_DIVISOR) * 1000; + case TimeUnit::NANO: + // 1000 nanoseconds = 1 microsecond. + return (units_since_epoch % + driver::odbcabstraction::NANO_TO_SECONDS_DIVISOR); + } + return 0; +} +} // namespace + +namespace driver { +namespace flight_sql { + +using namespace odbcabstraction; + +template +TimestampArrayFlightSqlAccessor::TimestampArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor>(array) {} + +template +RowStatus +TimestampArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics) { + auto *buffer = static_cast(binding->buffer); + + int64_t value = this->GetArray()->Value(arrow_row); + const auto divisor = GetConversionToSecondsDivisor(UNIT); + const auto converted_result_seconds = value / divisor; + tm timestamp = {0}; + + GetTimeForSecondsSinceEpoch(timestamp, converted_result_seconds); + + buffer[cell_counter].year = 1900 + (timestamp.tm_year); + buffer[cell_counter].month = timestamp.tm_mon + 1; + buffer[cell_counter].day = timestamp.tm_mday; + buffer[cell_counter].hour = timestamp.tm_hour; + buffer[cell_counter].minute = timestamp.tm_min; + buffer[cell_counter].second = timestamp.tm_sec; + buffer[cell_counter].fraction = CalculateFraction(UNIT, value); + + if (binding->strlen_buffer) { + binding->strlen_buffer[cell_counter] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t TimestampArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(TIMESTAMP_STRUCT); +} + +template class TimestampArrayFlightSqlAccessor; +template class TimestampArrayFlightSqlAccessor; +template class TimestampArrayFlightSqlAccessor; +template class TimestampArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.h new file mode 100644 index 0000000000000..8fa727aeda8bb --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.h @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class TimestampArrayFlightSqlAccessor + : public FlightSqlAccessor> { + +public: + explicit TimestampArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor_test.cc new file mode 100644 index 0000000000000..ebecd22fe34dd --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor_test.cc @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/builder.h" +#include "timestamp_array_accessor.h" +#include "utils.h" +#include "gtest/gtest.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MILLI) { + std::vector values = {86400370, 172800000, 259200000, 1649793238110LL, + 345600000, 432000000, 518400000}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::MILLI)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + + tm date{}; + + auto converted_time = values[i] / MILLI_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + + constexpr uint32_t NANOSECONDS_PER_MILLI = 1000000; + ASSERT_EQ(buffer[i].fraction, (values[i] % MILLI_TO_SECONDS_DIVISOR) * NANOSECONDS_PER_MILLI); + } +} + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_SECONDS) { + std::vector values = {86400, 172800, 259200, 1649793238, + 345600, 432000, 518400}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::SECOND)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + tm date{}; + + auto converted_time = values[i]; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + ASSERT_EQ(buffer[i].fraction, 0); + } +} + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MICRO) { + std::vector values = {86400000000, 1649793238000000}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::MICRO)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + + tm date{}; + + auto converted_time = values[i] / MICRO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + constexpr uint32_t MICROS_PER_NANO = 1000; + ASSERT_EQ(buffer[i].fraction, (values[i] % MICRO_TO_SECONDS_DIVISOR) * MICROS_PER_NANO); + } +} + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_NANO) { + std::vector values = {86400000010000, 1649793238000000000}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::NANO)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + tm date{}; + + auto converted_time = values[i] / NANO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + ASSERT_EQ(buffer[i].fraction, (values[i] % NANO_TO_SECONDS_DIVISOR)); + } +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/types.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/types.h new file mode 100644 index 0000000000000..f765e3d392da3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/types.h @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; +using odbcabstraction::CDataType; + +class FlightSqlResultSet; + +struct ColumnBinding { + void *buffer; + ssize_t *strlen_buffer; + size_t buffer_length; + CDataType target_type; + int precision; + int scale; + + ColumnBinding() = default; + + ColumnBinding(CDataType target_type, int precision, int scale, void *buffer, + size_t buffer_length, ssize_t *strlen_buffer) + : target_type(target_type), precision(precision), scale(scale), + buffer(buffer), buffer_length(buffer_length), + strlen_buffer(strlen_buffer) {} +}; + +/// \brief Accessor interface meant to provide a way of populating data of a +/// single column to buffers bound by `ColumnarResultSet::BindColumn`. +class Accessor { +public: + const CDataType target_type_; + + Accessor(CDataType target_type) : target_type_(target_type) {} + + virtual ~Accessor() = default; + + /// \brief Populates next cells + virtual size_t GetColumnarData(ColumnBinding *binding, int64_t starting_row, + size_t cells, int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) = 0; + + virtual size_t GetCellLength(ColumnBinding *binding) const = 0; +}; + +template +class FlightSqlAccessor : public Accessor { +public: + explicit FlightSqlAccessor(Array *array) + : Accessor(TARGET_TYPE), + array_(arrow::internal::checked_cast(array)) {} + + size_t GetColumnarData(ColumnBinding *binding, int64_t starting_row, + size_t cells, int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) override { + return static_cast(this)->GetColumnarData_impl( + binding, starting_row, cells, value_offset, update_value_offset, + diagnostics, row_status_array); + } + + size_t GetCellLength(ColumnBinding *binding) const override { + return static_cast(this)->GetCellLength_impl(binding); + } + +protected: + size_t GetColumnarData_impl(ColumnBinding *binding, int64_t starting_row, int64_t cells, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) { + for (int64_t i = 0; i < cells; ++i) { + int64_t current_arrow_row = starting_row + i; + if (array_->IsNull(current_arrow_row)) { + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = odbcabstraction::NULL_DATA; + } else { + throw odbcabstraction::NullWithoutIndicatorException(); + } + } else { + // TODO: Optimize this by creating different versions of MoveSingleCell + // depending on if strlen_buffer is null. + auto row_status = MoveSingleCell( + binding, current_arrow_row, i, value_offset, update_value_offset, + diagnostics); + if (row_status_array) { + row_status_array[i] = row_status; + } + } + } + + return static_cast(cells); + } + + inline ARROW_ARRAY *GetArray() { + return array_; + } + +private: + ARROW_ARRAY *array_; + + odbcabstraction::RowStatus MoveSingleCell(ColumnBinding *binding, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics) { + return static_cast(this)->MoveSingleCell_impl(binding, arrow_row, i, + value_offset, update_value_offset, diagnostics); + } + + odbcabstraction::RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, + int64_t i, int64_t &value_offset, bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + std::stringstream ss; + ss << "Unknown type conversion from StringArray to target C type " + << TARGET_TYPE; + throw odbcabstraction::DriverException(ss.str()); + } +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.cc new file mode 100644 index 0000000000000..76e9e441773a1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.cc @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "address_info.h" + +namespace driver { + +bool AddressInfo::GetAddressInfo(const std::string &host, char *host_name_info, int64_t max_host) { + if (addrinfo_result_) { + freeaddrinfo(addrinfo_result_); + addrinfo_result_ = nullptr; + } + + int error; + error = getaddrinfo(host.c_str(), NULL, NULL, &addrinfo_result_); + + if (error != 0) { + return false; + } + + error = getnameinfo(addrinfo_result_->ai_addr, addrinfo_result_->ai_addrlen, host_name_info, + max_host, NULL, 0, 0); + return error == 0; +} + +AddressInfo::~AddressInfo() { + if (addrinfo_result_) { + freeaddrinfo(addrinfo_result_); + addrinfo_result_ = nullptr; + } +} + +AddressInfo::AddressInfo() : addrinfo_result_(nullptr) {} +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h new file mode 100644 index 0000000000000..dee58ebf291e4 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#if !_WIN32 +#include +#endif + +namespace driver { + +class AddressInfo { +private: + struct addrinfo * addrinfo_result_; + +public: + AddressInfo(); + + ~AddressInfo(); + + bool GetAddressInfo(const std::string &host, char *host_name_info, int64_t max_host); +}; +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc new file mode 100644 index 0000000000000..564eb52790142 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "config/configuration.h" + +#include "flight_sql_connection.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +static const std::string DEFAULT_DSN = "Apache Arrow Flight SQL"; +static const std::string DEFAULT_ENABLE_ENCRYPTION = TRUE_STR; +static const std::string DEFAULT_USE_CERT_STORE = TRUE_STR; +static const std::string DEFAULT_DISABLE_CERT_VERIFICATION = FALSE_STR; + +namespace { +std::string ReadDsnString(const std::string& dsn, const std::string& key, const std::string& dflt = "") +{ + #define BUFFER_SIZE (1024) + std::vector buf(BUFFER_SIZE); + + int ret = SQLGetPrivateProfileString(dsn.c_str(), key.c_str(), dflt.c_str(), buf.data(), static_cast(buf.size()), "ODBC.INI"); + + if (ret > BUFFER_SIZE) + { + // If there wasn't enough space, try again with the right size buffer. + buf.resize(ret + 1); + ret = SQLGetPrivateProfileString(dsn.c_str(), key.c_str(), dflt.c_str(), buf.data(), static_cast(buf.size()), "ODBC.INI"); + } + + return std::string(buf.data(), ret); +} + +void RemoveAllKnownKeys(std::vector& keys) { + // Remove all known DSN keys from the passed in set of keys, case insensitively. + keys.erase(std::remove_if(keys.begin(), keys.end(), [&](auto& x) { + return std::find_if(FlightSqlConnection::ALL_KEYS.begin(), FlightSqlConnection::ALL_KEYS.end(), [&](auto& s) { + return boost::iequals(x, s);}) != FlightSqlConnection::ALL_KEYS.end(); + }), keys.end()); +} + +std::vector ReadAllKeys(const std::string& dsn) +{ + std::vector buf(BUFFER_SIZE); + + int ret = SQLGetPrivateProfileString(dsn.c_str(), NULL, "", buf.data(), static_cast(buf.size()), "ODBC.INI"); + + if (ret > BUFFER_SIZE) + { + // If there wasn't enough space, try again with the right size buffer. + buf.resize(ret + 1); + ret = SQLGetPrivateProfileString(dsn.c_str(), NULL, "", buf.data(), static_cast(buf.size()), "ODBC.INI"); + } + + // When you pass NULL to SQLGetPrivateProfileString it gives back a \0 delimited list of all the keys. + // The below loop simply tokenizes all the keys and places them into a vector. + std::vector keys; + char* begin = buf.data(); + while (begin && *begin != '\0') { + char* cur; + for (cur = begin; *cur != '\0'; ++cur); + keys.emplace_back(begin, cur); + begin = ++cur; + } + return keys; +} +} + +Configuration::Configuration() +{ + // No-op. +} + +Configuration::~Configuration() +{ + // No-op. +} + +void Configuration::LoadDefaults() +{ + Set(FlightSqlConnection::DSN, DEFAULT_DSN); + Set(FlightSqlConnection::USE_ENCRYPTION, DEFAULT_ENABLE_ENCRYPTION); + Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, DEFAULT_USE_CERT_STORE); + Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, DEFAULT_DISABLE_CERT_VERIFICATION); +} + +void Configuration::LoadDsn(const std::string& dsn) +{ + Set(FlightSqlConnection::DSN, dsn); + Set(FlightSqlConnection::HOST, ReadDsnString(dsn, FlightSqlConnection::HOST)); + Set(FlightSqlConnection::PORT, ReadDsnString(dsn, FlightSqlConnection::PORT)); + Set(FlightSqlConnection::TOKEN, ReadDsnString(dsn, FlightSqlConnection::TOKEN)); + Set(FlightSqlConnection::UID, ReadDsnString(dsn, FlightSqlConnection::UID)); + Set(FlightSqlConnection::PWD, ReadDsnString(dsn, FlightSqlConnection::PWD)); + Set(FlightSqlConnection::USE_ENCRYPTION, + ReadDsnString(dsn, FlightSqlConnection::USE_ENCRYPTION, DEFAULT_ENABLE_ENCRYPTION)); + Set(FlightSqlConnection::TRUSTED_CERTS, ReadDsnString(dsn, FlightSqlConnection::TRUSTED_CERTS)); + Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + ReadDsnString(dsn, FlightSqlConnection::USE_SYSTEM_TRUST_STORE, DEFAULT_USE_CERT_STORE)); + Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, + ReadDsnString(dsn, FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, DEFAULT_DISABLE_CERT_VERIFICATION)); + + auto customKeys = ReadAllKeys(dsn); + RemoveAllKnownKeys(customKeys); + for (auto key : customKeys) { + Set(key, ReadDsnString(dsn, key)); + } +} + +void Configuration::Clear() +{ + this->properties.clear(); +} + +bool Configuration::IsSet(const std::string& key) const +{ + return 0 != this->properties.count(key); +} + +const std::string& Configuration::Get(const std::string& key) const +{ + const auto itr = this->properties.find(key); + if (itr == this->properties.cend()) { + static const std::string empty(""); + return empty; + } + return itr->second; +} + +void Configuration::Set(const std::string& key, const std::string& value) +{ + const std::string copy = boost::trim_copy(value); + if (!copy.empty()) { + this->properties[key] = value; + } +} + +const driver::odbcabstraction::Connection::ConnPropertyMap& Configuration::GetProperties() const +{ + return this->properties; +} + +std::vector Configuration::GetCustomKeys() const +{ + driver::odbcabstraction::Connection::ConnPropertyMap copyProps(properties); + for (auto& key : FlightSqlConnection::ALL_KEYS) { + copyProps.erase(key); + } + std::vector keys; + boost::copy(copyProps | boost::adaptors::map_keys, std::back_inserter(keys)); + return keys; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/config/connection_string_parser.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/config/connection_string_parser.cc new file mode 100644 index 0000000000000..c6dc8ce3e3e81 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/config/connection_string_parser.cc @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "config/connection_string_parser.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +ConnectionStringParser::ConnectionStringParser(Configuration& cfg): + cfg(cfg) +{ + // No-op. +} + +ConnectionStringParser::~ConnectionStringParser() +{ + // No-op. +} + +void ConnectionStringParser::ParseConnectionString(const char* str, size_t len, char delimiter) +{ + std::string connect_str(str, len); + + while (connect_str.rbegin() != connect_str.rend() && *connect_str.rbegin() == 0) + connect_str.erase(connect_str.size() - 1); + + while (!connect_str.empty()) + { + size_t attr_begin = connect_str.rfind(delimiter); + + if (attr_begin == std::string::npos) + attr_begin = 0; + else + ++attr_begin; + + size_t attr_eq_pos = connect_str.rfind('='); + + if (attr_eq_pos == std::string::npos) + attr_eq_pos = 0; + + if (attr_begin < attr_eq_pos) + { + const char* key_begin = connect_str.data() + attr_begin; + const char* key_end = connect_str.data() + attr_eq_pos; + std::string key(key_begin, key_end); + boost::algorithm::trim(key); + + const char* value_begin = connect_str.data() + attr_eq_pos + 1; + const char* value_end = connect_str.data() + connect_str.size(); + std::string value(value_begin, value_end); + boost::algorithm::trim(value); + + if (value[0] == '{' && value[value.size() - 1] == '}') { + value = value.substr(1, value.size() - 2); + } + + cfg.Set(key, value); + } + + if (!attr_begin) + break; + + connect_str.erase(attr_begin - 1); + } +} + +void ConnectionStringParser::ParseConnectionString(const std::string& str) +{ + ParseConnectionString(str.data(), str.size(), ';'); +} + +void ConnectionStringParser::ParseConfigAttributes(const char* str) +{ + size_t len = 0; + + // Getting list length. List is terminated by two '\0'. + while (str[len] || str[len + 1]) + ++len; + + ++len; + + ParseConnectionString(str, len, '\0'); +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc new file mode 100644 index 0000000000000..c3ef4134d773b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_auth_method.h" + +#include + +#include "flight_sql_connection.h" +#include + +#include +#include +#include + +#include + +using namespace driver::flight_sql; + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClient; +using arrow::flight::TimeoutDuration; +using driver::odbcabstraction::AuthenticationException; +using driver::odbcabstraction::CommunicationException; +using driver::odbcabstraction::Connection; + +namespace { +class NoOpAuthMethod : public FlightSqlAuthMethod { +public: + void Authenticate(FlightSqlConnection &connection, + FlightCallOptions &call_options) override { + // Do nothing + } +}; + +class NoOpClientAuthHandler : public arrow::flight::ClientAuthHandler { +public: + NoOpClientAuthHandler() {} + + arrow::Status Authenticate(arrow::flight::ClientAuthSender* outgoing, arrow::flight::ClientAuthReader* incoming) override { + // Write a blank string. The server should ignore this and just accept any Handshake request. + return outgoing->Write(std::string()); + } + + arrow::Status GetToken(std::string* token) override { + *token = std::string(); + return arrow::Status::OK(); + } +}; + +class UserPasswordAuthMethod : public FlightSqlAuthMethod { +public: + UserPasswordAuthMethod(FlightClient &client, std::string user, + std::string password) + : client_(client), user_(std::move(user)), + password_(std::move(password)) {} + + void Authenticate(FlightSqlConnection &connection, + FlightCallOptions &call_options) override { + FlightCallOptions auth_call_options; + const boost::optional &login_timeout = + connection.GetAttribute(Connection::LOGIN_TIMEOUT); + if (login_timeout && boost::get(*login_timeout) > 0) { + // ODBC's LOGIN_TIMEOUT attribute and FlightCallOptions.timeout use + // seconds as time unit. + double timeout_seconds = static_cast(boost::get(*login_timeout)); + if (timeout_seconds > 0) { + auth_call_options.timeout = TimeoutDuration{timeout_seconds}; + } + } + + Result> bearer_result = + client_.AuthenticateBasicToken(auth_call_options, user_, password_); + + if (!bearer_result.ok()) { + const auto& flightStatus = arrow::flight::FlightStatusDetail::UnwrapStatus(bearer_result.status()); + if (flightStatus != nullptr) { + if (flightStatus->code() == arrow::flight::FlightStatusCode::Unauthenticated) { + throw AuthenticationException("Failed to authenticate with user and password: " + + bearer_result.status().ToString()); + } else if (flightStatus->code() == arrow::flight::FlightStatusCode::Unavailable) { + throw CommunicationException(bearer_result.status().message()); + } + } + + throw odbcabstraction::DriverException(bearer_result.status().message()); + } + + call_options.headers.push_back(bearer_result.ValueOrDie()); + } + + std::string GetUser() override { return user_; } + +private: + FlightClient &client_; + std::string user_; + std::string password_; +}; + + class TokenAuthMethod : public FlightSqlAuthMethod { + private: + FlightClient &client_; + std::string token_; // this is the token the user provides + + public: + TokenAuthMethod(FlightClient &client, std::string token): client_{client}, token_{std::move(token)} {} + + void Authenticate(FlightSqlConnection &connection, FlightCallOptions &call_options) override { + // add the token to the headers + const std::pair token_header("authorization", "Bearer " + token_); + call_options.headers.push_back(token_header); + + const arrow::Status status = client_.Authenticate(call_options, std::unique_ptr(new NoOpClientAuthHandler())); + if (!status.ok()) { + const auto& flightStatus = arrow::flight::FlightStatusDetail::UnwrapStatus(status); + if (flightStatus != nullptr) { + if (flightStatus->code() == arrow::flight::FlightStatusCode::Unauthenticated) { + throw AuthenticationException("Failed to authenticate with token: " + token_ + " Message: " + status.message()); + } else if (flightStatus->code() == arrow::flight::FlightStatusCode::Unavailable) { + throw CommunicationException(status.message()); + } + } + throw odbcabstraction::DriverException(status.message()); + } + } + }; +} // namespace + +std::unique_ptr FlightSqlAuthMethod::FromProperties( + const std::unique_ptr &client, + const Connection::ConnPropertyMap &properties) { + + // Check if should use user-password authentication + auto it_user = properties.find(FlightSqlConnection::USER); + if (it_user == properties.end()) { + // The Microsoft OLE DB to ODBC bridge provider (MSDASQL) will write + // "User ID" and "Password" properties instead of mapping + // to ODBC compliant UID/PWD keys. + it_user = properties.find(FlightSqlConnection::USER_ID); + } + + auto it_password = properties.find(FlightSqlConnection::PASSWORD); + auto it_token = properties.find(FlightSqlConnection::TOKEN); + + if (it_user == properties.end() || it_password == properties.end()) { + // Accept UID/PWD as aliases for User/Password. These are suggested as + // standard properties in the documentation for SQLDriverConnect. + it_user = properties.find(FlightSqlConnection::UID); + it_password = properties.find(FlightSqlConnection::PWD); + } + if (it_user != properties.end() || it_password != properties.end()) { + const std::string &user = + it_user != properties.end() + ? it_user->second + : ""; + const std::string &password = + it_password != properties.end() + ? it_password->second + : ""; + + return std::unique_ptr( + new UserPasswordAuthMethod(*client, user, password)); + } else if (it_token != properties.end()) { + const auto& token = it_token->second; + return std::unique_ptr(new TokenAuthMethod(*client, token)); + } + + return std::unique_ptr(new NoOpAuthMethod); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.h new file mode 100644 index 0000000000000..b07b3f4cea75d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "flight_sql_connection.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +class FlightSqlAuthMethod { +public: + virtual ~FlightSqlAuthMethod() = default; + + virtual void Authenticate(FlightSqlConnection &connection, + arrow::flight::FlightCallOptions &call_options) = 0; + + virtual std::string GetUser() { return std::string(); } + + static std::unique_ptr FromProperties( + const std::unique_ptr &client, + const odbcabstraction::Connection::ConnPropertyMap &properties); + +protected: + FlightSqlAuthMethod() = default; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc new file mode 100644 index 0000000000000..662bc502b9194 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc @@ -0,0 +1,452 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_connection.h" + +#include +#include + +#include +#include +#include "address_info.h" +#include "flight_sql_auth_method.h" +#include "flight_sql_statement.h" +#include "flight_sql_ssl_config.h" +#include "utils.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "system_trust_store.h" + +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::Status; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClient; +using arrow::flight::FlightClientOptions; +using arrow::flight::Location; +using arrow::flight::TimeoutDuration; +using arrow::flight::sql::FlightSqlClient; +using driver::odbcabstraction::AsBool; +using driver::odbcabstraction::Connection; +using driver::odbcabstraction::DriverException; +using driver::odbcabstraction::CommunicationException; +using driver::odbcabstraction::OdbcVersion; +using driver::odbcabstraction::Statement; + +const std::string FlightSqlConnection::DSN = "dsn"; +const std::string FlightSqlConnection::DRIVER = "driver"; +const std::string FlightSqlConnection::HOST = "host"; +const std::string FlightSqlConnection::PORT = "port"; +const std::string FlightSqlConnection::USER = "user"; +const std::string FlightSqlConnection::USER_ID = "user id"; +const std::string FlightSqlConnection::UID = "uid"; +const std::string FlightSqlConnection::PASSWORD = "password"; +const std::string FlightSqlConnection::PWD = "pwd"; +const std::string FlightSqlConnection::TOKEN = "token"; +const std::string FlightSqlConnection::USE_ENCRYPTION = "useEncryption"; +const std::string FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION = "disableCertificateVerification"; +const std::string FlightSqlConnection::TRUSTED_CERTS = "trustedCerts"; +const std::string FlightSqlConnection::USE_SYSTEM_TRUST_STORE = "useSystemTrustStore"; +const std::string FlightSqlConnection::STRING_COLUMN_LENGTH = "StringColumnLength"; +const std::string FlightSqlConnection::USE_WIDE_CHAR = "UseWideChar"; +const std::string FlightSqlConnection::CHUNK_BUFFER_CAPACITY = "ChunkBufferCapacity"; + +const std::vector FlightSqlConnection::ALL_KEYS = { + FlightSqlConnection::DSN, FlightSqlConnection::DRIVER, FlightSqlConnection::HOST, FlightSqlConnection::PORT, + FlightSqlConnection::TOKEN, FlightSqlConnection::UID, FlightSqlConnection::USER_ID, FlightSqlConnection::PWD, + FlightSqlConnection::USE_ENCRYPTION, FlightSqlConnection::TRUSTED_CERTS, FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, FlightSqlConnection::STRING_COLUMN_LENGTH, + FlightSqlConnection::USE_WIDE_CHAR, FlightSqlConnection::CHUNK_BUFFER_CAPACITY}; + +namespace { + +#if _WIN32 || _WIN64 +constexpr auto SYSTEM_TRUST_STORE_DEFAULT = true; +constexpr auto STORES = { + "CA", + "MY", + "ROOT", + "SPC" +}; + +inline std::string GetCerts() { + std::string certs; + + for (auto store : STORES) { + std::shared_ptr cert_iterator = std::make_shared(store); + + if (!cert_iterator->SystemHasStore()) { + // If the system does not have the specific store, we skip it using the continue. + continue; + } + while (cert_iterator->HasNext()) { + certs += cert_iterator->GetNext(); + } + } + + return certs; +} + +#else + +constexpr auto SYSTEM_TRUST_STORE_DEFAULT = false; +inline std::string GetCerts() { + return ""; +} + +#endif + +const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::HOST, + FlightSqlConnection::PORT, + FlightSqlConnection::USER, + FlightSqlConnection::USER_ID, + FlightSqlConnection::UID, + FlightSqlConnection::PASSWORD, + FlightSqlConnection::PWD, + FlightSqlConnection::TOKEN, + FlightSqlConnection::USE_ENCRYPTION, + FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, + FlightSqlConnection::TRUSTED_CERTS, + FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + FlightSqlConnection::STRING_COLUMN_LENGTH, + FlightSqlConnection::USE_WIDE_CHAR +}; + +Connection::ConnPropertyMap::const_iterator +TrackMissingRequiredProperty(const std::string &property, + const Connection::ConnPropertyMap &properties, + std::vector &missing_attr) { + auto prop_iter = + properties.find(property); + if (properties.end() == prop_iter) { + missing_attr.push_back(property); + } + return prop_iter; +} +} // namespace + +std::shared_ptr LoadFlightSslConfigs(const Connection::ConnPropertyMap &connPropertyMap) { + bool use_encryption = AsBool(connPropertyMap, FlightSqlConnection::USE_ENCRYPTION).value_or(true); + bool disable_cert = AsBool(connPropertyMap, FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION).value_or(false); + bool use_system_trusted = AsBool(connPropertyMap, FlightSqlConnection::USE_SYSTEM_TRUST_STORE).value_or(SYSTEM_TRUST_STORE_DEFAULT); + + auto trusted_certs_iterator = connPropertyMap.find( + FlightSqlConnection::TRUSTED_CERTS); + auto trusted_certs = + trusted_certs_iterator != connPropertyMap.end() ? trusted_certs_iterator->second : ""; + + return std::make_shared(disable_cert, trusted_certs, + use_system_trusted, use_encryption); +} + +void FlightSqlConnection::Connect(const ConnPropertyMap &properties, + std::vector &missing_attr) { + try { + auto flight_ssl_configs = LoadFlightSslConfigs(properties); + + Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); + FlightClientOptions client_options = + BuildFlightClientOptions(properties, missing_attr, + flight_ssl_configs); + + const std::shared_ptr + &cookie_factory = arrow::flight::GetCookieFactory(); + client_options.middleware.push_back(cookie_factory); + + std::unique_ptr flight_client; + ThrowIfNotOK( + FlightClient::Connect(location, client_options, &flight_client)); + + std::unique_ptr auth_method = + FlightSqlAuthMethod::FromProperties(flight_client, properties); + auth_method->Authenticate(*this, call_options_); + + sql_client_.reset(new FlightSqlClient(std::move(flight_client))); + closed_ = false; + + // Note: This should likely come from Flight instead of being from the + // connection properties to allow reporting a user for other auth mechanisms + // and also decouple the database user from user credentials. + info_.SetProperty(SQL_USER_NAME, auth_method->GetUser()); + attribute_[CONNECTION_DEAD] = static_cast(SQL_FALSE); + + PopulateMetadataSettings(properties); + PopulateCallOptions(properties); + } catch (...) { + attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); + sql_client_.reset(); + + throw; + } +} + +void FlightSqlConnection::PopulateMetadataSettings(const Connection::ConnPropertyMap &conn_property_map) { + metadata_settings_.string_column_length_ = GetStringColumnLength(conn_property_map); + metadata_settings_.use_wide_char_ = GetUseWideChar(conn_property_map); + metadata_settings_.chunk_buffer_capacity_ = GetChunkBufferCapacity(conn_property_map); +} + +boost::optional FlightSqlConnection::GetStringColumnLength(const Connection::ConnPropertyMap &conn_property_map) { + const int32_t min_string_column_length = 1; + + try { + return AsInt32(min_string_column_length, conn_property_map, FlightSqlConnection::STRING_COLUMN_LENGTH); + } catch (const std::exception& e) { + diagnostics_.AddWarning( + std::string("Invalid value for connection property " + FlightSqlConnection::STRING_COLUMN_LENGTH + + ". Please ensure it has a valid numeric value. Message: " + e.what()), + "01000", odbcabstraction::ODBCErrorCodes_GENERAL_WARNING); + } + + return boost::none; +} + +bool FlightSqlConnection::GetUseWideChar(const ConnPropertyMap &connPropertyMap) { + #if defined _WIN32 || defined _WIN64 + // Windows should use wide chars by default + bool default_value = true; + #else + // Mac and Linux should not use wide chars by default + bool default_value = false; +#endif + return AsBool(connPropertyMap, FlightSqlConnection::USE_WIDE_CHAR).value_or(default_value); +} + +size_t FlightSqlConnection::GetChunkBufferCapacity(const ConnPropertyMap &connPropertyMap) { + size_t default_value = 5; + try { + return AsInt32(1, connPropertyMap, FlightSqlConnection::CHUNK_BUFFER_CAPACITY).value_or(default_value); + } catch (const std::exception& e) { + diagnostics_.AddWarning( + std::string("Invalid value for connection property " + FlightSqlConnection::CHUNK_BUFFER_CAPACITY + + ". Please ensure it has a valid numeric value. Message: " + e.what()), + "01000", odbcabstraction::ODBCErrorCodes_GENERAL_WARNING); + } + + return default_value; +} + +const FlightCallOptions & +FlightSqlConnection::PopulateCallOptions(const ConnPropertyMap &props) { + // Set CONNECTION_TIMEOUT attribute or LOGIN_TIMEOUT depending on if this + // is the first request. + const boost::optional &connection_timeout = closed_ ? + GetAttribute(LOGIN_TIMEOUT) : GetAttribute(CONNECTION_TIMEOUT); + if (connection_timeout && boost::get(*connection_timeout) > 0) { + call_options_.timeout = + TimeoutDuration{static_cast(boost::get(*connection_timeout))}; + } + + for (auto prop : props) { + if (BUILT_IN_PROPERTIES.count(prop.first) != 0) { + continue; + } + + if (prop.first.find(' ') != std::string::npos) { + // Connection properties containing spaces will crash gRPC, but some tools + // such as the OLE DB to ODBC bridge generate unused properties containing spaces. + diagnostics_.AddWarning( + std::string("Ignoring connection option " + prop.first) + + ". Server-specific options must be valid HTTP header names and " + + "cannot contain spaces.", + "01000", odbcabstraction::ODBCErrorCodes_GENERAL_WARNING); + continue; + } + + // Note: header names must be lower case for gRPC. + // gRPC will crash if they are not lower-case. + std::string key_lc = boost::algorithm::to_lower_copy(prop.first); + call_options_.headers.emplace_back(std::make_pair(key_lc, prop.second)); + } + + return call_options_; +} + +FlightClientOptions +FlightSqlConnection::BuildFlightClientOptions(const ConnPropertyMap &properties, + std::vector &missing_attr, + const std::shared_ptr& ssl_config) { + FlightClientOptions options; + // Persist state information using cookies if the FlightProducer supports it. + options.middleware.push_back(arrow::flight::GetCookieFactory()); + + if (ssl_config->useEncryption()) { + if (ssl_config->shouldDisableCertificateVerification()) { + options.disable_server_verification = ssl_config->shouldDisableCertificateVerification(); + } else { + if (ssl_config->useSystemTrustStore()) { + const std::string certs = GetCerts(); + + options.tls_root_certs = certs; + } else if (!ssl_config->getTrustedCerts().empty()) { + flight::CertKeyPair cert_key_pair; + ssl_config->populateOptionsWithCerts(&cert_key_pair); + options.tls_root_certs = cert_key_pair.pem_cert; + } + } + } + + return std::move(options); +} + +Location +FlightSqlConnection::BuildLocation(const ConnPropertyMap &properties, + std::vector &missing_attr, + const std::shared_ptr& ssl_config) { + const auto &host_iter = + TrackMissingRequiredProperty(HOST, properties, missing_attr); + + const auto &port_iter = + TrackMissingRequiredProperty(PORT, properties, missing_attr); + + if (!missing_attr.empty()) { + std::string missing_attr_str = + std::string("Missing required properties: ") + + boost::algorithm::join(missing_attr, ", "); + throw DriverException(missing_attr_str); + } + + const std::string &host = host_iter->second; + const int &port = boost::lexical_cast(port_iter->second); + + Location location; + if (ssl_config->useEncryption()) { + AddressInfo address_info; + char host_name_info[NI_MAXHOST] = ""; + bool operation_result = false; + + try { + auto ip_address = boost::asio::ip::make_address(host); + // We should only attempt to resolve the hostname from the IP if the given + // HOST input is an IP address. + if (ip_address.is_v4() || ip_address.is_v6()) { + operation_result = address_info.GetAddressInfo(host, host_name_info, + NI_MAXHOST); + if (operation_result) { + ThrowIfNotOK(Location::ForGrpcTls(host_name_info, port, &location)); + return location; + } + // TODO: We should log that we could not convert an IP to hostname here. + } + } + catch (...) { + // This is expected. The Host attribute can be an IP or name, but make_address will throw + // if it is not an IP. + } + + ThrowIfNotOK(Location::ForGrpcTls(host, port, &location)); + return location; + } + + ThrowIfNotOK(Location::ForGrpcTcp(host, port, &location)); + return location; +} + +void FlightSqlConnection::Close() { + if (closed_) { + throw DriverException("Connection already closed."); + } + + sql_client_.reset(); + closed_ = true; + attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); +} + +std::shared_ptr FlightSqlConnection::CreateStatement() { + return std::shared_ptr( + new FlightSqlStatement( + diagnostics_, + *sql_client_, + call_options_, + metadata_settings_ + ) + ); +} + +bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, + const Connection::Attribute &value) { + switch (attribute) { + case ACCESS_MODE: + // We will always return read-write. + return CheckIfSetToOnlyValidValue(value, static_cast(SQL_MODE_READ_WRITE)); + case PACKET_SIZE: + return CheckIfSetToOnlyValidValue(value, static_cast(0)); + default: + attribute_[attribute] = value; + return true; + } +} + +boost::optional +FlightSqlConnection::GetAttribute(Connection::AttributeId attribute) { + switch (attribute) { + case ACCESS_MODE: + // FlightSQL does not provide this metadata. + return boost::make_optional(Attribute(static_cast(SQL_MODE_READ_WRITE))); + case PACKET_SIZE: + return boost::make_optional(Attribute(static_cast(0))); + default: + const auto &it = attribute_.find(attribute); + return boost::make_optional(it != attribute_.end(), it->second); + } +} + +Connection::Info FlightSqlConnection::GetInfo(uint16_t info_type) { + auto result = info_.GetInfo(info_type); + if (info_type == SQL_DBMS_NAME || info_type == SQL_SERVER_NAME) { + // Update the database component reported in error messages. + // We do this lazily for performance reasons. + diagnostics_.SetDataSourceComponent(boost::get(result)); + } + return result; +} + +FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string &driver_version) + : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), + odbc_version_(odbc_version), info_(call_options_, sql_client_, driver_version), + closed_(true) { + attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); + attribute_[LOGIN_TIMEOUT] = static_cast(0); + attribute_[CONNECTION_TIMEOUT] = static_cast(0); + attribute_[CURRENT_CATALOG] = ""; +} +odbcabstraction::Diagnostics &FlightSqlConnection::GetDiagnostics() { + return diagnostics_; +} + +void FlightSqlConnection::SetClosed(bool is_closed) { + closed_ = is_closed; +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h new file mode 100644 index 0000000000000..71605a2776eb1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#include + +#include "get_info_cache.h" +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { + +class FlightSqlSslConfig; + +/// \brief Create an instance of the FlightSqlSslConfig class, from the properties passed +/// into the map. +/// \param connPropertyMap the map with the Connection properties. +/// \return An instance of the FlightSqlSslConfig. +std::shared_ptr LoadFlightSslConfigs( + const odbcabstraction::Connection::ConnPropertyMap &connPropertyMap); + + +class FlightSqlConnection : public odbcabstraction::Connection { + +private: + odbcabstraction::MetadataSettings metadata_settings_; + std::map attribute_; + arrow::flight::FlightClientOptions client_options_; + arrow::flight::FlightCallOptions call_options_; + std::unique_ptr sql_client_; + GetInfoCache info_; + odbcabstraction::Diagnostics diagnostics_; + odbcabstraction::OdbcVersion odbc_version_; + bool closed_; + + void PopulateMetadataSettings(const Connection::ConnPropertyMap &connPropertyMap); + +public: + static const std::vector ALL_KEYS; + static const std::string DSN; + static const std::string DRIVER; + static const std::string HOST; + static const std::string PORT; + static const std::string USER; + static const std::string UID; + static const std::string USER_ID; + static const std::string PASSWORD; + static const std::string PWD; + static const std::string TOKEN; + static const std::string USE_ENCRYPTION; + static const std::string DISABLE_CERTIFICATE_VERIFICATION; + static const std::string TRUSTED_CERTS; + static const std::string USE_SYSTEM_TRUST_STORE; + static const std::string STRING_COLUMN_LENGTH; + static const std::string USE_WIDE_CHAR; + static const std::string CHUNK_BUFFER_CAPACITY; + + explicit FlightSqlConnection(odbcabstraction::OdbcVersion odbc_version, const std::string &driver_version = "0.9.0.0"); + + void Connect(const ConnPropertyMap &properties, + std::vector &missing_attr) override; + + void Close() override; + + std::shared_ptr CreateStatement() override; + + bool SetAttribute(AttributeId attribute, const Attribute &value) override; + + boost::optional + GetAttribute(Connection::AttributeId attribute) override; + + Info GetInfo(uint16_t info_type) override; + + /// \brief Builds a Location used for FlightClient connection. + /// \note Visible for testing + static arrow::flight::Location + BuildLocation(const ConnPropertyMap &properties, std::vector &missing_attr, + const std::shared_ptr& ssl_config); + + /// \brief Builds a FlightClientOptions used for FlightClient connection. + /// \note Visible for testing + static arrow::flight::FlightClientOptions + BuildFlightClientOptions(const ConnPropertyMap &properties, + std::vector &missing_attr, + const std::shared_ptr& ssl_config); + + /// \brief Builds a FlightCallOptions used on gRPC calls. + /// \note Visible for testing + const arrow::flight::FlightCallOptions &PopulateCallOptions(const ConnPropertyMap &properties); + + odbcabstraction::Diagnostics &GetDiagnostics() override; + + /// \brief A setter to the field closed_. + /// \note Visible for testing + void SetClosed(bool is_closed); + + boost::optional GetStringColumnLength(const ConnPropertyMap &connPropertyMap); + + bool GetUseWideChar(const ConnPropertyMap &connPropertyMap); + + size_t GetChunkBufferCapacity(const ConnPropertyMap &connPropertyMap); +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc new file mode 100644 index 0000000000000..51ca84ecfae66 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_connection.h" + +#include + +#include "gtest/gtest.h" +#include + +namespace driver { +namespace flight_sql { + +using arrow::flight::Location; +using arrow::flight::TimeoutDuration; +using odbcabstraction::Connection; + +TEST(AttributeTests, SetAndGetAttribute) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(200)); + const boost::optional firstValue = + connection.GetAttribute(Connection::CONNECTION_TIMEOUT); + + EXPECT_TRUE(firstValue); + + EXPECT_EQ(boost::get(*firstValue), static_cast(200)); + + connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(300)); + + const boost::optional changeValue = + connection.GetAttribute(Connection::CONNECTION_TIMEOUT); + + EXPECT_TRUE(changeValue); + EXPECT_EQ(boost::get(*changeValue), static_cast(300)); + + connection.Close(); +} + +TEST(AttributeTests, GetAttributeWithoutSetting) { + FlightSqlConnection connection(odbcabstraction::V_3); + + const boost::optional optional = + connection.GetAttribute(Connection::CONNECTION_TIMEOUT); + connection.SetClosed(false); + + EXPECT_EQ(0, boost::get(*optional)); + + connection.Close(); +} + +TEST(MetadataSettingsTest, StringColumnLengthTest) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + const int32_t expected_string_column_length = 100000; + + const Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, // expect not used + {FlightSqlConnection::PORT, std::string("32010")}, // expect not used + {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, // expect not used + {FlightSqlConnection::STRING_COLUMN_LENGTH, std::to_string(expected_string_column_length)}, + }; + + const boost::optional actual_string_column_length = connection.GetStringColumnLength(properties); + + EXPECT_TRUE(actual_string_column_length); + EXPECT_EQ(expected_string_column_length, *actual_string_column_length); + + connection.Close(); +} + +TEST(MetadataSettingsTest, UseWideCharTest) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + const Connection::ConnPropertyMap properties1 = { + {FlightSqlConnection::USE_WIDE_CHAR, std::string("true")}, + }; + const Connection::ConnPropertyMap properties2 = { + {FlightSqlConnection::USE_WIDE_CHAR, std::string("false")}, + }; + + EXPECT_EQ(true, connection.GetUseWideChar(properties1)); + EXPECT_EQ(false, connection.GetUseWideChar(properties2)); + + connection.Close(); +} + +TEST(BuildLocationTests, ForTcp) { + std::vector missing_attr; + Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32010")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + }; + + const std::shared_ptr &ssl_config = + LoadFlightSslConfigs(properties); + + const Location &actual_location1 = + FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); + const Location &actual_location2 = FlightSqlConnection::BuildLocation( + { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32011")}, + }, + missing_attr, ssl_config); + + Location expected_location; + ASSERT_TRUE( + Location::ForGrpcTcp("localhost", 32010, &expected_location).ok()); + ASSERT_EQ(expected_location, actual_location1); + ASSERT_NE(expected_location, actual_location2); +} + +TEST(BuildLocationTests, ForTls) { + std::vector missing_attr; + Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32010")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + }; + + const std::shared_ptr &ssl_config = + LoadFlightSslConfigs(properties); + + const Location &actual_location1 = + FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); + + Connection::ConnPropertyMap second_properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32011")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + }; + + const std::shared_ptr &second_ssl_config = + LoadFlightSslConfigs(properties); + + const Location &actual_location2 = FlightSqlConnection::BuildLocation( + second_properties, missing_attr, ssl_config); + + Location expected_location; + ASSERT_TRUE( + Location::ForGrpcTls("localhost", 32010, &expected_location).ok()); + ASSERT_EQ(expected_location, actual_location1); + ASSERT_NE(expected_location, actual_location2); +} + +TEST(PopulateCallOptionsTest, ConnectionTimeout) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + // Expect default timeout to be -1 + ASSERT_EQ(TimeoutDuration{-1.0}, + connection.PopulateCallOptions(Connection::ConnPropertyMap()).timeout); + + connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(10)); + ASSERT_EQ(TimeoutDuration{10.0}, + connection.PopulateCallOptions(Connection::ConnPropertyMap()).timeout); +} + +TEST(PopulateCallOptionsTest, GenericOption) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + Connection::ConnPropertyMap properties; + properties["Foo"] = "Bar"; + auto options = connection.PopulateCallOptions(properties); + auto headers = options.headers; + ASSERT_EQ(1, headers.size()); + + // Header name must be lower-case because gRPC will crash if it is not lower-case. + ASSERT_EQ("foo", headers[0].first); + + // Header value should preserve case. + ASSERT_EQ("Bar", headers[0].second); +} + +TEST(PopulateCallOptionsTest, GenericOptionWithSpaces) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + Connection::ConnPropertyMap properties; + properties["Persist Security Info"] = "False"; + auto options = connection.PopulateCallOptions(properties); + auto headers = options.headers; + // Header names with spaces must be omitted or gRPC will crash. + ASSERT_TRUE(headers.empty()); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc new file mode 100644 index 0000000000000..ff1fbf78932c7 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_connection.h" +#include "odbcabstraction/utils.h" +#include +#include +#include + + +#define DEFAULT_MAXIMUM_FILE_SIZE 16777216 +#define CONFIG_FILE_NAME "arrow-odbc.ini" + +namespace driver { +namespace flight_sql { + +using odbcabstraction::Connection; +using odbcabstraction::OdbcVersion; +using odbcabstraction::LogLevel; +using odbcabstraction::SPDLogger; + +namespace { + LogLevel ToLogLevel(int64_t level) { + switch (level) { + case 0: + return LogLevel::LogLevel_TRACE; + case 1: + return LogLevel::LogLevel_DEBUG; + case 2: + return LogLevel::LogLevel_INFO; + case 3: + return LogLevel::LogLevel_WARN; + case 4: + return LogLevel::LogLevel_ERROR; + default: + return LogLevel::LogLevel_OFF; + } + } +} + +FlightSqlDriver::FlightSqlDriver() + : diagnostics_("Apache Arrow", "Flight SQL", OdbcVersion::V_3), + version_("0.9.0.0") +{} + +std::shared_ptr +FlightSqlDriver::CreateConnection(OdbcVersion odbc_version) { + return std::make_shared(odbc_version, version_); +} + +odbcabstraction::Diagnostics &FlightSqlDriver::GetDiagnostics() { + return diagnostics_; +} + +void FlightSqlDriver::SetVersion(std::string version) { + version_ = std::move(version); +} + +void FlightSqlDriver::RegisterLog() { + odbcabstraction::PropertyMap propertyMap; + driver::odbcabstraction::ReadConfigFile(propertyMap, CONFIG_FILE_NAME); + + auto log_enable_iterator = propertyMap.find(SPDLogger::LOG_ENABLED); + auto log_enabled = log_enable_iterator != propertyMap.end() ? + odbcabstraction::AsBool(log_enable_iterator->second) : false; + if (!log_enabled) { + return; + } + + auto log_path_iterator = propertyMap.find(SPDLogger::LOG_PATH); + auto log_path = + log_path_iterator != propertyMap.end() ? log_path_iterator->second : ""; + if (log_path.empty()) { + return; + } + + auto log_level_iterator = propertyMap.find(SPDLogger::LOG_LEVEL); + auto log_level = + ToLogLevel(log_level_iterator != propertyMap.end() ? std::stoi(log_level_iterator->second) : 1); + if (log_level == odbcabstraction::LogLevel_OFF) { + return; + } + + auto maximum_file_size_iterator = propertyMap.find(SPDLogger::MAXIMUM_FILE_SIZE); + auto maximum_file_size = maximum_file_size_iterator != propertyMap.end() ? + std::stoi(maximum_file_size_iterator->second) : DEFAULT_MAXIMUM_FILE_SIZE; + + auto maximum_file_quantity_iterator = propertyMap.find(SPDLogger::FILE_QUANTITY); + auto maximum_file_quantity = + maximum_file_quantity_iterator != propertyMap.end() ? std::stoi( + maximum_file_quantity_iterator->second) : 1; + + std::unique_ptr logger(new odbcabstraction::SPDLogger()); + + logger->init(maximum_file_quantity, maximum_file_size, + log_path, log_level); + odbcabstraction::Logger::SetInstance(std::move(logger)); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc new file mode 100644 index 0000000000000..99d083c48d596 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_get_tables_reader.h" +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +namespace driver { +namespace flight_sql { + +using arrow::internal::checked_pointer_cast; +using arrow::util::nullopt; + +GetTablesReader::GetTablesReader(std::shared_ptr record_batch) + : record_batch_(std::move(record_batch)), current_row_(-1) {} + +bool GetTablesReader::Next() { + return ++current_row_ < record_batch_->num_rows(); +} + +optional GetTablesReader::GetCatalogName() { + const auto &array = + checked_pointer_cast(record_batch_->column(0)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional GetTablesReader::GetDbSchemaName() { + const auto &array = + checked_pointer_cast(record_batch_->column(1)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +std::string GetTablesReader::GetTableName() { + const auto &array = + checked_pointer_cast(record_batch_->column(2)); + + return array->GetString(current_row_); +} + +std::string GetTablesReader::GetTableType() { + const auto &array = + checked_pointer_cast(record_batch_->column(3)); + + return array->GetString(current_row_); +} + +std::shared_ptr GetTablesReader::GetSchema() { + const auto &array = + checked_pointer_cast(record_batch_->column(4)); + if (array == nullptr) { + return nullptr; + } + + io::BufferReader dataset_schema_reader(array->GetView(current_row_)); + ipc::DictionaryMemo in_memo; + const Result> &result = + ReadSchema(&dataset_schema_reader, &in_memo); + if (!result.ok()) { + // TODO: Ignoring this error until we fix the problem on Dremio server + // The problem is that complex types columns are being returned without the children types. + return nullptr; + } + + return result.ValueOrDie(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.h new file mode 100644 index 0000000000000..c7f317cefe8e3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.h @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "record_batch_transformer.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using std::optional; + +class GetTablesReader { +private: + std::shared_ptr record_batch_; + int64_t current_row_; + +public: + explicit GetTablesReader(std::shared_ptr record_batch); + + bool Next(); + + optional GetCatalogName(); + + optional GetDbSchemaName(); + + std::string GetTableName(); + + std::string GetTableType(); + + std::shared_ptr GetSchema(); +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_type_info_reader.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_type_info_reader.cc new file mode 100644 index 0000000000000..f55e8f46a7ec3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_type_info_reader.cc @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_get_type_info_reader.h" +#include +#include +#include +#include "utils.h" + +#include + +namespace driver { +namespace flight_sql { + +using arrow::internal::checked_pointer_cast; +using arrow::util::nullopt; + +GetTypeInfoReader::GetTypeInfoReader(std::shared_ptr record_batch) + : record_batch_(std::move(record_batch)), current_row_(-1) {} + +bool GetTypeInfoReader::Next() { + return ++current_row_ < record_batch_->num_rows(); +} + +std::string GetTypeInfoReader::GetTypeName() { + const auto &array = + checked_pointer_cast(record_batch_->column(0)); + + return array->GetString(current_row_); +} + +int32_t GetTypeInfoReader::GetDataType() { + const auto &array = + checked_pointer_cast(record_batch_->column(1)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetColumnSize() { + const auto &array = + checked_pointer_cast(record_batch_->column(2)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetLiteralPrefix() { + const auto &array = + checked_pointer_cast(record_batch_->column(3)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional GetTypeInfoReader::GetLiteralSuffix() { + const auto &array = + checked_pointer_cast(record_batch_->column(4)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional> GetTypeInfoReader::GetCreateParams() { + const auto &array = + checked_pointer_cast(record_batch_->column(5)); + + if (array->IsNull(current_row_)) + return nullopt; + + int values_length = array->value_length(current_row_); + int start_offset = array->value_offset(current_row_); + const auto &values_array = checked_pointer_cast(array->values()); + + std::vector result(values_length); + for (int i = 0; i < values_length; ++i) { + result[i] = values_array->GetString(start_offset + i); + } + + return result; +} + +int32_t GetTypeInfoReader::GetNullable() { + const auto &array = + checked_pointer_cast(record_batch_->column(6)); + + return array->GetView(current_row_); +} + +bool GetTypeInfoReader::GetCaseSensitive() { + const auto &array = + checked_pointer_cast(record_batch_->column(7)); + + return array->GetView(current_row_); +} + +int32_t GetTypeInfoReader::GetSearchable() { + const auto &array = + checked_pointer_cast(record_batch_->column(8)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetUnsignedAttribute() { + const auto &array = + checked_pointer_cast(record_batch_->column(9)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +bool GetTypeInfoReader::GetFixedPrecScale() { + const auto &array = + checked_pointer_cast(record_batch_->column(10)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetAutoIncrement() { + const auto &array = + checked_pointer_cast(record_batch_->column(11)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetLocalTypeName() { + const auto &array = + checked_pointer_cast(record_batch_->column(12)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional GetTypeInfoReader::GetMinimumScale() { + const auto &array = + checked_pointer_cast(record_batch_->column(13)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetMaximumScale() { + const auto &array = + checked_pointer_cast(record_batch_->column(14)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +int32_t GetTypeInfoReader::GetSqlDataType() { + const auto &array = + checked_pointer_cast(record_batch_->column(15)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetDatetimeSubcode() { + const auto &array = + checked_pointer_cast(record_batch_->column(16)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetNumPrecRadix() { + const auto &array = + checked_pointer_cast(record_batch_->column(17)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetIntervalPrecision() { + const auto &array = + checked_pointer_cast(record_batch_->column(18)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_type_info_reader.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_type_info_reader.h new file mode 100644 index 0000000000000..8ac5d18cbe926 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_type_info_reader.h @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "record_batch_transformer.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using std::optional; + +class GetTypeInfoReader { +private: + std::shared_ptr record_batch_; + int64_t current_row_; + +public: + explicit GetTypeInfoReader(std::shared_ptr record_batch); + + bool Next(); + + std::string GetTypeName(); + + int32_t GetDataType(); + + optional GetColumnSize(); + + optional GetLiteralPrefix(); + + optional GetLiteralSuffix(); + + optional> GetCreateParams(); + + int32_t GetNullable(); + + bool GetCaseSensitive(); + + int32_t GetSearchable(); + + optional GetUnsignedAttribute(); + + bool GetFixedPrecScale(); + + optional GetAutoIncrement(); + + optional GetLocalTypeName(); + + optional GetMinimumScale(); + + optional GetMaximumScale(); + + int32_t GetSqlDataType(); + + optional GetDatetimeSubcode(); + + optional GetNumPrecRadix(); + + optional GetIntervalPrecision(); + +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc new file mode 100644 index 0000000000000..7f7eac9d6a3dd --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_result_set.h" +#include + +#include +#include +#include + +#include "flight_sql_result_set_column.h" +#include "flight_sql_result_set_metadata.h" +#include "utils.h" +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; +using arrow::RecordBatch; +using arrow::Scalar; +using arrow::Status; +using arrow::flight::FlightEndpoint; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using odbcabstraction::CDataType; +using odbcabstraction::DriverException; + +FlightSqlResultSet::FlightSqlResultSet( + FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + const std::shared_ptr &transformer, + odbcabstraction::Diagnostics& diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings) + : + metadata_settings_(metadata_settings), + chunk_buffer_(flight_sql_client, call_options, flight_info, metadata_settings_.chunk_buffer_capacity_), + transformer_(transformer), + metadata_(transformer ? new FlightSqlResultSetMetadata(transformer->GetTransformedSchema(), + metadata_settings_) + : new FlightSqlResultSetMetadata(flight_info, metadata_settings_)), + columns_(metadata_->GetColumnCount()), + get_data_offsets_(metadata_->GetColumnCount(), 0), + diagnostics_(diagnostics), + current_row_(0), num_binding_(0), reset_get_data_(false) { + current_chunk_.data = nullptr; + if (transformer_) { + schema_ = transformer_->GetTransformedSchema(); + } else { + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema_)); + } + + for (size_t i = 0; i < columns_.size(); ++i) { + columns_[i] = FlightSqlResultSetColumn(metadata_settings.use_wide_char_); + } +} + +size_t FlightSqlResultSet::Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t *row_status_array) { + // Consider it might be the first call to Move() and current_chunk is not + // populated yet + assert(rows > 0); + if (current_chunk_.data == nullptr) { + if (!chunk_buffer_.GetNext(¤t_chunk_)) { + return 0; + } + + if (transformer_) { + current_chunk_.data = transformer_->Transform(current_chunk_.data); + } + + for (size_t column_num = 0; column_num < columns_.size(); ++column_num) { + columns_[column_num].ResetAccessor(current_chunk_.data->column(column_num)); + } + } + + // Reset GetData value offsets. + if (num_binding_ != get_data_offsets_.size() && reset_get_data_) { + std::fill(get_data_offsets_.begin(), get_data_offsets_.end(), 0); + } + + size_t fetched_rows = 0; + while (fetched_rows < rows) { + size_t batch_rows = current_chunk_.data->num_rows(); + size_t rows_to_fetch = + std::min(static_cast(rows - fetched_rows), + static_cast(batch_rows - current_row_)); + + if (rows_to_fetch == 0) { + if (!chunk_buffer_.GetNext(¤t_chunk_)) { + break; + } + + if (transformer_) { + current_chunk_.data = transformer_->Transform(current_chunk_.data); + } + + for (size_t column_num = 0; column_num < columns_.size(); ++column_num) { + columns_[column_num].ResetAccessor(current_chunk_.data->column(column_num)); + } + current_row_ = 0; + continue; + } + + for (auto & column : columns_) { + // There can be unbound columns. + if (!column.is_bound_) + continue; + + auto *accessor = column.GetAccessorForBinding(); + ColumnBinding shifted_binding = column.binding_; + uint16_t *shifted_row_status_array = row_status_array ? &row_status_array[fetched_rows] : nullptr; + + if (shifted_row_status_array) { + std::fill(shifted_row_status_array, &shifted_row_status_array[rows_to_fetch], odbcabstraction::RowStatus_SUCCESS); + } + + size_t accessor_rows = 0; + try { + if (!bind_type) { + // Columnar binding. Have the accessor convert multiple rows. + if (shifted_binding.buffer) { + shifted_binding.buffer = + static_cast(shifted_binding.buffer) + + accessor->GetCellLength(&shifted_binding) * fetched_rows + + bind_offset; + } + + if (shifted_binding.strlen_buffer) { + shifted_binding.strlen_buffer = reinterpret_cast( + reinterpret_cast( + &shifted_binding.strlen_buffer[fetched_rows]) + + bind_offset); + } + + int64_t value_offset = 0; + accessor_rows = accessor->GetColumnarData(&shifted_binding, current_row_, rows_to_fetch, value_offset, false, + diagnostics_, shifted_row_status_array); + } + else { + // Row-wise binding. Identify the base position of the buffer and indicator based on the bind offset, + // the number of already-fetched rows, and the bind_type holding the size of an application-side row. + if (shifted_binding.buffer) { + shifted_binding.buffer = + static_cast(shifted_binding.buffer) + bind_offset + + bind_type * fetched_rows; + } + + if (shifted_binding.strlen_buffer) { + shifted_binding.strlen_buffer = reinterpret_cast( + reinterpret_cast(shifted_binding.strlen_buffer) + + bind_offset + bind_type * fetched_rows); + } + + // Loop and run the accessor one-row-at-a-time. + for (size_t i = 0; i < rows_to_fetch; ++i) { + int64_t value_offset = 0; + + // Adjust offsets passed to the accessor as we fetch rows. + // Note that current_row_ is updated outside of this loop. + accessor_rows += accessor->GetColumnarData(&shifted_binding, current_row_ + i, 1, value_offset, false, + diagnostics_, shifted_row_status_array); + if (shifted_binding.buffer) { + shifted_binding.buffer = + static_cast(shifted_binding.buffer) + bind_type; + } + + if (shifted_binding.strlen_buffer) { + shifted_binding.strlen_buffer = reinterpret_cast( + reinterpret_cast(shifted_binding.strlen_buffer) + + bind_type); + } + + if (shifted_row_status_array) { + shifted_row_status_array++; + } + } + } + } catch (...) { + if (shifted_row_status_array) { + std::fill(shifted_row_status_array, &shifted_row_status_array[rows_to_fetch], odbcabstraction::RowStatus_ERROR); + } + throw; + } + + + if (rows_to_fetch != accessor_rows) { + throw DriverException( + "Expected the same number of rows for all columns"); + } + } + + current_row_ += static_cast(rows_to_fetch); + fetched_rows += rows_to_fetch; + } + + if (rows > fetched_rows && row_status_array) { + std::fill(&row_status_array[fetched_rows], &row_status_array[rows], odbcabstraction::RowStatus_NOROW); + } + return fetched_rows; +} + +void FlightSqlResultSet::Close() { + chunk_buffer_.Close(); + current_chunk_.data = nullptr; +} + +void FlightSqlResultSet::Cancel() { + chunk_buffer_.Close(); + current_chunk_.data = nullptr; +} + +bool FlightSqlResultSet::GetData(int column_n, int16_t target_type, + int precision, int scale, void *buffer, + size_t buffer_length, ssize_t *strlen_buffer) { + reset_get_data_ = true; + // Check if the offset is already at the end. + int64_t& value_offset = get_data_offsets_[column_n - 1]; + if (value_offset == -1) { + return false; + } + + ColumnBinding binding(ConvertCDataTypeFromV2ToV3(target_type), precision, scale, buffer, buffer_length, + strlen_buffer); + + auto &column = columns_[column_n - 1]; + Accessor *accessor = column.GetAccessorForGetData(binding.target_type); + + + // Note: current_row_ is always positioned at the index _after_ the one we are + // on after calling Move(). So if we want to get data from the _last_ row + // fetched, we need to subtract one from the current row. + accessor->GetColumnarData(&binding, current_row_ - 1, 1, value_offset, true, diagnostics_, nullptr); + + // If there was truncation, the converter would have reported it to the diagnostics. + return diagnostics_.HasWarning(); +} + +std::shared_ptr FlightSqlResultSet::GetMetadata() { + return metadata_; +} + +void FlightSqlResultSet::BindColumn(int column_n, int16_t target_type, + int precision, int scale, void *buffer, + size_t buffer_length, + ssize_t *strlen_buffer) { + auto &column = columns_[column_n - 1]; + if (buffer == nullptr) { + if (column.is_bound_) { + num_binding_--; + } + column.ResetBinding(); + return; + } + + if (!column.is_bound_) { + num_binding_++; + } + + ColumnBinding binding(ConvertCDataTypeFromV2ToV3(target_type), precision, scale, buffer, buffer_length, + strlen_buffer); + column.SetBinding(binding, schema_->field(column_n - 1)->type()->id()); +} + +FlightSqlResultSet::~FlightSqlResultSet() = default; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h new file mode 100644 index 0000000000000..9f76a05b6032e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "flight_sql_stream_chunk_buffer.h" +#include "record_batch_transformer.h" +#include "utils.h" +#include "odbcabstraction/types.h" +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using arrow::Schema; +using arrow::flight::FlightEndpoint; +using arrow::flight::FlightInfo; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using arrow::flight::sql::FlightSqlClient; +using odbcabstraction::CDataType; +using odbcabstraction::DriverException; +using odbcabstraction::ResultSet; +using odbcabstraction::ResultSetMetadata; + +class FlightSqlResultSetColumn; + +class FlightSqlResultSet : public ResultSet { +private: + const odbcabstraction::MetadataSettings& metadata_settings_; + FlightStreamChunkBuffer chunk_buffer_; + FlightStreamChunk current_chunk_; + std::shared_ptr schema_; + std::shared_ptr transformer_; + std::shared_ptr metadata_; + std::vector columns_; + std::vector get_data_offsets_; + odbcabstraction::Diagnostics &diagnostics_; + int64_t current_row_; + int num_binding_; + bool reset_get_data_; + +public: + ~FlightSqlResultSet() override; + + FlightSqlResultSet( + FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + const std::shared_ptr &transformer, + odbcabstraction::Diagnostics& diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); + + void Close() override; + + void Cancel() override; + + bool GetData(int column_n, int16_t target_type, int precision, int scale, + void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) override; + + size_t Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t *row_status_array) override; + + std::shared_ptr GetMetadata() override; + + void BindColumn(int column_n, int16_t target_type, int precision, int scale, + void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) override; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.cc new file mode 100644 index 0000000000000..336a1ced899f4 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.cc @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "accessors/main.h" + +#include +#include + +namespace driver { +namespace flight_sql { + +using odbcabstraction::CDataType; + +typedef std::pair SourceAndTargetPair; +typedef std::function AccessorConstructor; + +namespace { + +const std::unordered_map> + ACCESSORS_CONSTRUCTORS = { + {SourceAndTargetPair(arrow::Type::type::STRING, CDataType_CHAR), + [](arrow::Array *array) { + return new StringArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::STRING, CDataType_WCHAR), + CreateWCharStringArrayAccessor}, + {SourceAndTargetPair(arrow::Type::type::DOUBLE, CDataType_DOUBLE), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::FLOAT, CDataType_FLOAT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT64, CDataType_SBIGINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT64, CDataType_UBIGINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT32, CDataType_SLONG), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT32, CDataType_ULONG), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT16, CDataType_SSHORT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT16, CDataType_USHORT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT8, CDataType_STINYINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor( + array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT8, CDataType_UTINYINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor( + array); + }}, + {SourceAndTargetPair(arrow::Type::type::BOOL, CDataType_BIT), + [](arrow::Array *array) { + return new BooleanArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::BINARY, CDataType_BINARY), + [](arrow::Array *array) { + return new BinaryArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::DATE32, CDataType_DATE), + [](arrow::Array *array) { + return new DateArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::DATE64, CDataType_DATE), + [](arrow::Array *array) { + return new DateArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::TIMESTAMP, CDataType_TIMESTAMP), + [](arrow::Array *array) { + auto time_type = + arrow::internal::checked_pointer_cast(array->type()); + auto time_unit = time_type->unit(); + Accessor* result; + switch (time_unit) { + case TimeUnit::SECOND: + result = new TimestampArrayFlightSqlAccessor(array); + break; + case TimeUnit::MILLI: + result = new TimestampArrayFlightSqlAccessor(array); + break; + case TimeUnit::MICRO: + result = new TimestampArrayFlightSqlAccessor(array); + break; + case TimeUnit::NANO: + result = new TimestampArrayFlightSqlAccessor(array); + break; + default: + assert(false); + throw DriverException("Unrecognized time unit " + std::to_string(time_unit)); + } + return result; + }}, + {SourceAndTargetPair(arrow::Type::type::TIME32, CDataType_TIME), + [](arrow::Array *array) { + return CreateTimeAccessor(array, arrow::Type::type::TIME32); + }}, + {SourceAndTargetPair(arrow::Type::type::TIME64, CDataType_TIME), + [](arrow::Array *array) { + return CreateTimeAccessor(array, arrow::Type::type::TIME64); + }}, + {SourceAndTargetPair(arrow::Type::type::DECIMAL128, CDataType_NUMERIC), + [](arrow::Array *array) { + return new DecimalArrayFlightSqlAccessor(array); + }}}; +} + +std::unique_ptr CreateAccessor(arrow::Array *source_array, + CDataType target_type) { + auto it = ACCESSORS_CONSTRUCTORS.find( + SourceAndTargetPair(source_array->type_id(), target_type)); + if (it != ACCESSORS_CONSTRUCTORS.end()) { + auto accessor = it->second(source_array); + return std::unique_ptr(accessor); + } + + std::stringstream ss; + ss << "Unsupported type conversion! Tried to convert '" + << source_array->type()->ToString() << "' to C type '" << target_type + << "'"; + throw odbcabstraction::DriverException(ss.str()); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h new file mode 100644 index 0000000000000..ddc6017f9a8a1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +class Accessor; +class FlightSqlResultSet; + +std::unique_ptr +CreateAccessor(arrow::Array *source_array, + odbcabstraction::CDataType target_type); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_column.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_column.cc new file mode 100644 index 0000000000000..3a758c99a2726 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_column.cc @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_result_set_column.h" +#include +#include "flight_sql_result_set_accessors.h" +#include "utils.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { + +namespace { +std::shared_ptr +CastArray(const std::shared_ptr &original_array, + CDataType target_type) { + bool conversion = NeedArrayConversion(original_array->type()->id(), target_type); + + if (conversion) { + auto converter = GetConverter(original_array->type_id(), target_type); + return converter(original_array); + } else { + return original_array; + } +} +} // namespace + +std::unique_ptr +FlightSqlResultSetColumn::CreateAccessor(CDataType target_type) { + cached_casted_array_ = CastArray(original_array_, target_type); + + return flight_sql::CreateAccessor(cached_casted_array_.get(), target_type); +} + +Accessor * +FlightSqlResultSetColumn::GetAccessorForTargetType(CDataType target_type) { + // Cast the original array to a type matching the target_type. + if (target_type == odbcabstraction::CDataType_DEFAULT) { + target_type = ConvertArrowTypeToC(original_array_->type_id(), use_wide_char_); + } + + cached_accessor_ = CreateAccessor(target_type); + return cached_accessor_.get(); +} + +FlightSqlResultSetColumn::FlightSqlResultSetColumn(bool use_wide_char) + : use_wide_char_(use_wide_char), + is_bound_(false) {} + +void FlightSqlResultSetColumn::SetBinding(const ColumnBinding& new_binding, arrow::Type::type arrow_type) { + binding_ = new_binding; + is_bound_ = true; + + if (binding_.target_type == odbcabstraction::CDataType_DEFAULT) { + binding_.target_type = ConvertArrowTypeToC(arrow_type, use_wide_char_); + } + + // Overwrite the binding if the caller is using SQL_C_NUMERIC and has used zero + // precision if it is zero (this is precision unset and will always fail). + if (binding_.precision == 0 && + binding_.target_type == odbcabstraction::CDataType_NUMERIC) { + binding_.precision = arrow::Decimal128Type::kMaxPrecision; + } + + // Rebuild the accessor and casted array if the target type changed. + if (original_array_ && (!cached_casted_array_ || cached_accessor_->target_type_ != binding_.target_type)) { + cached_accessor_ = CreateAccessor(binding_.target_type); + } +} + +void FlightSqlResultSetColumn::ResetBinding() { + is_bound_ = false; + cached_casted_array_.reset(); + cached_accessor_.reset(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_column.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_column.h new file mode 100644 index 0000000000000..e55c4b7af1883 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_column.h @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "utils.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; + +class FlightSqlResultSetColumn { +private: + std::shared_ptr original_array_; + std::shared_ptr cached_casted_array_; + std::unique_ptr cached_accessor_; + + std::unique_ptr CreateAccessor(CDataType target_type); + + Accessor *GetAccessorForTargetType(CDataType target_type); + +public: + FlightSqlResultSetColumn() = default; + explicit FlightSqlResultSetColumn(bool use_wide_char); + + ColumnBinding binding_; + bool use_wide_char_; + bool is_bound_; + + inline Accessor *GetAccessorForBinding() { + return cached_accessor_.get(); + } + + inline Accessor *GetAccessorForGetData(CDataType target_type) { + if (target_type == odbcabstraction::CDataType_DEFAULT) { + target_type = ConvertArrowTypeToC(original_array_->type_id(), use_wide_char_); + } + + if (cached_accessor_ && cached_accessor_->target_type_ == target_type) { + return cached_accessor_.get(); + } + return GetAccessorForTargetType(target_type); + } + + void SetBinding(const ColumnBinding& new_binding, arrow::Type::type arrow_type); + + void ResetBinding(); + + inline void ResetAccessor(std::shared_ptr array) { + original_array_ = std::move(array); + if (cached_accessor_) { + cached_accessor_ = CreateAccessor(cached_accessor_->target_type_); + } else if (is_bound_) { + cached_accessor_ = CreateAccessor(binding_.target_type); + } else { + cached_casted_array_.reset(); + cached_accessor_.reset(); + } + } +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc new file mode 100644 index 0000000000000..1d52872bcf0f0 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_result_set_metadata.h" +#include +#include +#include +#include "utils.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace odbcabstraction; +using arrow::DataType; +using arrow::Field; +using arrow::util::make_optional; +using arrow::util::nullopt; + +constexpr int32_t DefaultDecimalPrecision = 38; + +// This indicates the column length used when the both property StringColumnLength is not specified and +// the server does not provide a length on column metadata. +constexpr int32_t DefaultLengthForVariableLengthColumns = 1024; + +namespace { +std::shared_ptr empty_metadata_map(new arrow::KeyValueMetadata); + +inline arrow::flight::sql::ColumnMetadata GetMetadata(const std::shared_ptr &field) { + const auto &metadata_map = field->metadata(); + + arrow::flight::sql::ColumnMetadata metadata(metadata_map ? metadata_map : empty_metadata_map); + return metadata; +} + +arrow::Result GetFieldPrecision(const std::shared_ptr &field) { + return GetMetadata(field).GetPrecision(); +} +} + +size_t FlightSqlResultSetMetadata::GetColumnCount() { + return schema_->num_fields(); +} + +std::string FlightSqlResultSetMetadata::GetColumnName(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +std::string FlightSqlResultSetMetadata::GetName(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +size_t FlightSqlResultSetMetadata::GetPrecision(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + int32_t column_size = GetFieldPrecision(field).ValueOrElse([] { return 0; }); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetColumnSize(data_type_v3, column_size).value_or(0); +} + +size_t FlightSqlResultSetMetadata::GetScale(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(field); + + int32_t type_scale = metadata.GetScale().ValueOrElse([] { return 0; }); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetTypeScale(data_type_v3, type_scale).value_or(0); +} + +uint16_t FlightSqlResultSetMetadata::GetDataType(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + const SqlDataType conciseType = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + return GetNonConciseDataType(conciseType); +} + +driver::odbcabstraction::Nullability +FlightSqlResultSetMetadata::IsNullable(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + return field->nullable() ? odbcabstraction::NULLABILITY_NULLABLE : odbcabstraction::NULLABILITY_NO_NULLS; +} + +std::string FlightSqlResultSetMetadata::GetSchemaName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetSchemaName().ValueOrElse([] { return ""; }); +} + +std::string FlightSqlResultSetMetadata::GetCatalogName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetCatalogName().ValueOrElse([] { return ""; }); +} + +std::string FlightSqlResultSetMetadata::GetTableName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetTableName().ValueOrElse([] { return ""; }); +} + +std::string FlightSqlResultSetMetadata::GetColumnLabel(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +size_t FlightSqlResultSetMetadata::GetColumnDisplaySize( + int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + int32_t column_size = metadata_settings_.string_column_length_.value_or(GetFieldPrecision(field).ValueOr(DefaultLengthForVariableLengthColumns)); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetDisplaySize(data_type_v3, column_size).value_or(NO_TOTAL); +} + +std::string FlightSqlResultSetMetadata::GetBaseColumnName(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +std::string FlightSqlResultSetMetadata::GetBaseTableName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + return metadata.GetTableName().ValueOrElse([] { return ""; }); +} + +uint16_t FlightSqlResultSetMetadata::GetConciseType(int column_position) { + const std::shared_ptr &field = schema_->field(column_position -1); + + const SqlDataType sqlColumnType = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + return sqlColumnType; +} + +size_t FlightSqlResultSetMetadata::GetLength(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + int32_t column_size = metadata_settings_.string_column_length_.value_or(GetFieldPrecision(field).ValueOr(DefaultLengthForVariableLengthColumns)); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return flight_sql::GetLength(data_type_v3, column_size).value_or(DefaultLengthForVariableLengthColumns); +} + +std::string FlightSqlResultSetMetadata::GetLiteralPrefix(int column_position) { + // TODO: Flight SQL column metadata does not have this, should we add to the spec? + return ""; +} + +std::string FlightSqlResultSetMetadata::GetLiteralSuffix(int column_position) { + // TODO: Flight SQL column metadata does not have this, should we add to the spec? + return ""; +} + +std::string FlightSqlResultSetMetadata::GetLocalTypeName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + // TODO: Is local type name the same as type name? + return metadata.GetTypeName().ValueOrElse([] { return ""; }); +} + +size_t FlightSqlResultSetMetadata::GetNumPrecRadix(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetRadixFromSqlDataType(data_type_v3).value_or(NO_TOTAL); +} + +size_t FlightSqlResultSetMetadata::GetOctetLength(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(field); + + int32_t column_size = metadata_settings_.string_column_length_.value_or(GetFieldPrecision(field).ValueOr(DefaultLengthForVariableLengthColumns)); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + // Workaround to get the precision for Decimal and Numeric types, since server doesn't return it currently. + // TODO: Use the server precision when its fixed. + std::shared_ptr arrow_type = field->type(); + if (arrow_type->id() == arrow::Type::DECIMAL128){ + int32_t precision = GetDecimalTypePrecision(arrow_type); + return GetCharOctetLength(data_type_v3, column_size, precision).value_or(DefaultDecimalPrecision+2); + } + + return GetCharOctetLength(data_type_v3, column_size).value_or(DefaultLengthForVariableLengthColumns); +} + +std::string FlightSqlResultSetMetadata::GetTypeName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetTypeName().ValueOrElse([] { return ""; }); +} + +driver::odbcabstraction::Updatability +FlightSqlResultSetMetadata::GetUpdatable(int column_position) { + return odbcabstraction::UPDATABILITY_READWRITE_UNKNOWN; +} + +bool FlightSqlResultSetMetadata::IsAutoUnique(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + // TODO: Is AutoUnique equivalent to AutoIncrement? + return metadata.GetIsAutoIncrement().ValueOrElse([] { return false; }); +} + +bool FlightSqlResultSetMetadata::IsCaseSensitive(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetIsCaseSensitive().ValueOrElse([] { return false; }); +} + +driver::odbcabstraction::Searchability +FlightSqlResultSetMetadata::IsSearchable(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + bool is_searchable = metadata.GetIsSearchable().ValueOrElse([] { return false; }); + return is_searchable ? odbcabstraction::SEARCHABILITY_ALL : odbcabstraction::SEARCHABILITY_NONE; +} + +bool FlightSqlResultSetMetadata::IsUnsigned(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + switch (field->type()->id()) { + case arrow::Type::UINT8: + case arrow::Type::UINT16: + case arrow::Type::UINT32: + case arrow::Type::UINT64: + return true; + default: + return false; + } +} + +bool FlightSqlResultSetMetadata::IsFixedPrecScale(int column_position) { + // TODO: Flight SQL column metadata does not have this, should we add to the spec? + return false; +} + +FlightSqlResultSetMetadata::FlightSqlResultSetMetadata( + std::shared_ptr schema, + const odbcabstraction::MetadataSettings& metadata_settings) + : + metadata_settings_(metadata_settings), + schema_(std::move(schema)) {} + +FlightSqlResultSetMetadata::FlightSqlResultSetMetadata( + const std::shared_ptr &flight_info, + const odbcabstraction::MetadataSettings& metadata_settings) + : + metadata_settings_(metadata_settings){ + arrow::ipc::DictionaryMemo dict_memo; + + ThrowIfNotOK(flight_info->GetSchema(&dict_memo, &schema_)); +} + +} // namespace flight_sql +} // namespace driver \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h new file mode 100644 index 0000000000000..0a28a17dd0480 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { +class FlightSqlResultSetMetadata : public odbcabstraction::ResultSetMetadata { +private: + const odbcabstraction::MetadataSettings& metadata_settings_; + std::shared_ptr schema_; + +public: + FlightSqlResultSetMetadata( + const std::shared_ptr &flight_info, + const odbcabstraction::MetadataSettings& metadata_settings); + + FlightSqlResultSetMetadata( + std::shared_ptr schema, + const odbcabstraction::MetadataSettings& metadata_settings); + + size_t GetColumnCount() override; + + std::string GetColumnName(int column_position) override; + + size_t GetPrecision(int column_position) override; + + size_t GetScale(int column_position) override; + + uint16_t GetDataType(int column_position) override; + + odbcabstraction::Nullability IsNullable(int column_position) override; + + std::string GetSchemaName(int column_position) override; + + std::string GetCatalogName(int column_position) override; + + std::string GetTableName(int column_position) override; + + std::string GetColumnLabel(int column_position) override; + + size_t GetColumnDisplaySize(int column_position) override; + + std::string GetBaseColumnName(int column_position) override; + + std::string GetBaseTableName(int column_position) override; + + uint16_t GetConciseType(int column_position) override; + + size_t GetLength(int column_position) override; + + std::string GetLiteralPrefix(int column_position) override; + + std::string GetLiteralSuffix(int column_position) override; + + std::string GetLocalTypeName(int column_position) override; + + std::string GetName(int column_position) override; + + size_t GetNumPrecRadix(int column_position) override; + + size_t GetOctetLength(int column_position) override; + + std::string GetTypeName(int column_position) override; + + odbcabstraction::Updatability GetUpdatable(int column_position) override; + + bool IsAutoUnique(int column_position) override; + + bool IsCaseSensitive(int column_position) override; + + odbcabstraction::Searchability IsSearchable(int column_position) override; + + bool IsUnsigned(int column_position) override; + + bool IsFixedPrecScale(int column_position) override; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.cc new file mode 100644 index 0000000000000..3c95b7790b9a8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.cc @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_ssl_config.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + + +FlightSqlSslConfig::FlightSqlSslConfig( + bool disableCertificateVerification, const std::string& trustedCerts, + bool systemTrustStore, bool useEncryption) + : trustedCerts_(trustedCerts), useEncryption_(useEncryption), + disableCertificateVerification_(disableCertificateVerification), + systemTrustStore_(systemTrustStore) {} + +bool FlightSqlSslConfig::useEncryption() const { + return useEncryption_; +} + +bool FlightSqlSslConfig::shouldDisableCertificateVerification() const { + return disableCertificateVerification_; +} + +const std::string& FlightSqlSslConfig::getTrustedCerts() const { + return trustedCerts_; +} + +bool FlightSqlSslConfig::useSystemTrustStore() const { + return systemTrustStore_; +} + +void FlightSqlSslConfig::populateOptionsWithCerts(arrow::flight::CertKeyPair* out) { + try { + std::ifstream cert_file(trustedCerts_); + if (!cert_file) { + throw odbcabstraction::DriverException("Could not open certificate: " + trustedCerts_); + } + std::stringstream cert; + cert << cert_file.rdbuf(); + out->pem_cert = cert.str(); + } + catch (const std::ifstream::failure& e) { + throw odbcabstraction::DriverException(e.what()); + } +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h new file mode 100644 index 0000000000000..ca3bb38d4ce81 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +/// \brief An Auxiliary class that holds all the information to perform +/// a SSL connection. +class FlightSqlSslConfig { +public: + FlightSqlSslConfig(bool disableCertificateVerification, + const std::string &trustedCerts, bool systemTrustStore, + bool useEncryption); + + /// \brief Tells if ssl is enabled. By default it will be true. + /// \return Whether ssl is enabled. + bool useEncryption() const; + + /// \brief Tells if disable certificate verification is enabled. + /// \return Whether disable certificate verification is enabled. + bool shouldDisableCertificateVerification() const; + + /// \brief The path to the trusted certificate. + /// \return Certificate path. + const std::string &getTrustedCerts() const; + + /// \brief Tells if we need to check if the certificate is in the system trust store. + /// \return Whether to use the system trust store. + bool useSystemTrustStore() const; + + /// \brief Loads the certificate file and extract the certificate file from it + /// and create the object CertKeyPair with it on. + /// \param out A CertKeyPair with the cert on it. + /// \return The cert key pair object + void populateOptionsWithCerts(arrow::flight::CertKeyPair *out); + +private: + const std::string trustedCerts_; + const bool useEncryption_; + const bool disableCertificateVerification_; + const bool systemTrustStore_; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc new file mode 100644 index 0000000000000..04c04bf38d5d4 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_statement.h" +#include +#include "flight_sql_result_set.h" +#include "flight_sql_result_set_metadata.h" +#include "flight_sql_statement_get_columns.h" +#include "flight_sql_statement_get_tables.h" +#include "flight_sql_statement_get_type_info.h" +#include "record_batch_transformer.h" +#include "utils.h" +#include +#include +#include + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::Status; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightInfo; +using arrow::flight::Location; +using arrow::flight::TimeoutDuration; +using arrow::flight::sql::FlightSqlClient; +using arrow::flight::sql::PreparedStatement; +using driver::odbcabstraction::DriverException; +using driver::odbcabstraction::ResultSet; +using driver::odbcabstraction::ResultSetMetadata; +using driver::odbcabstraction::Statement; + +namespace { + +void ClosePreparedStatementIfAny( + std::shared_ptr + &prepared_statement) { + if (prepared_statement != nullptr) { + ThrowIfNotOK(prepared_statement->Close()); + prepared_statement.reset(); + } +} + +} // namespace + +FlightSqlStatement::FlightSqlStatement( + const odbcabstraction::Diagnostics& diagnostics, + FlightSqlClient &sql_client, + FlightCallOptions call_options, + const odbcabstraction::MetadataSettings& metadata_settings) + : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), + sql_client_(sql_client), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { + attribute_[METADATA_ID] = static_cast(SQL_FALSE); + attribute_[MAX_LENGTH] = static_cast(0); + attribute_[NOSCAN] = static_cast(SQL_NOSCAN_OFF); + attribute_[QUERY_TIMEOUT] = static_cast(0); + call_options_.timeout = TimeoutDuration{-1}; +} + +bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, + const Attribute &value) { + switch (attribute) { + case METADATA_ID: + return CheckIfSetToOnlyValidValue(value, static_cast(SQL_FALSE)); + case NOSCAN: + return CheckIfSetToOnlyValidValue(value, static_cast(SQL_NOSCAN_OFF)); + case MAX_LENGTH: + return CheckIfSetToOnlyValidValue(value, static_cast(0)); + case QUERY_TIMEOUT: + if (boost::get(value) > 0) { + call_options_.timeout = + TimeoutDuration{static_cast(boost::get(value))}; + } else { + call_options_.timeout = TimeoutDuration{-1}; + // Intentional fall-through. + } + default: + attribute_[attribute] = value; + return true; + } +} + +boost::optional +FlightSqlStatement::GetAttribute(StatementAttributeId attribute) { + const auto &it = attribute_.find(attribute); + return boost::make_optional(it != attribute_.end(), it->second); +} + +boost::optional> +FlightSqlStatement::Prepare(const std::string &query) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = + sql_client_.Prepare(call_options_, query); + ThrowIfNotOK(result.status()); + + prepared_statement_ = *result; + + const auto &result_set_metadata = + std::make_shared( + prepared_statement_->dataset_schema(), metadata_settings_); + return boost::optional>( + result_set_metadata); +} + +bool FlightSqlStatement::ExecutePrepared() { + assert(prepared_statement_.get() != nullptr); + + Result> result = prepared_statement_->Execute(); + ThrowIfNotOK(result.status()); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, metadata_settings_); + + return true; +} + +bool FlightSqlStatement::Execute(const std::string &query) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = + sql_client_.Execute(call_options_, query); + ThrowIfNotOK(result.status()); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, metadata_settings_); + + return true; +} + +std::shared_ptr FlightSqlStatement::GetResultSet() { + return current_result_set_; +} + +long FlightSqlStatement::GetUpdateCount() { return -1; } + +std::shared_ptr FlightSqlStatement::GetTables( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type, + const ColumnNames &column_names) { + ClosePreparedStatementIfAny(prepared_statement_); + + std::vector table_types; + + if ((catalog_name && *catalog_name == "%") && + (schema_name && schema_name->empty()) && + (table_name && table_name->empty())) { + current_result_set_ = + GetTablesForSQLAllCatalogs( + column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + } else if ((catalog_name && catalog_name->empty()) && + (schema_name && *schema_name == "%") && + (table_name && table_name->empty())) { + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, call_options_, sql_client_, schema_name, diagnostics_, metadata_settings_); + } else if ((catalog_name && catalog_name->empty()) && + (schema_name && schema_name->empty()) && + (table_name && table_name->empty()) && + (table_type && *table_type == "%")) { + current_result_set_ = + GetTablesForSQLAllTableTypes( + column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + } else { + if (table_type) { + ParseTableTypes(*table_type, table_types); + } + + current_result_set_ = GetTablesForGenericUse( + column_names, call_options_, sql_client_, catalog_name, schema_name, + table_name, table_types, diagnostics_, metadata_settings_); + } + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetTables_V2( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) { + ColumnNames column_names{"TABLE_QUALIFIER", "TABLE_OWNER", "TABLE_NAME", + "TABLE_TYPE", "REMARKS"}; + + return GetTables(catalog_name, schema_name, table_name, table_type, + column_names); +} + +std::shared_ptr FlightSqlStatement::GetTables_V3( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) { + ColumnNames column_names{"TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", + "TABLE_TYPE", "REMARKS"}; + + return GetTables(catalog_name, schema_name, table_name, table_type, + column_names); +} + +std::shared_ptr FlightSqlStatement::GetColumns_V2( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetTables( + call_options_, catalog_name, schema_name, table_name, true, nullptr); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_2, column_name); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetColumns_V3( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetTables( + call_options_, catalog_name, schema_name, table_name, true, nullptr); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_3, column_name); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetXdbcTypeInfo( + call_options_); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_2, data_type); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetXdbcTypeInfo( + call_options_); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_3, data_type); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +odbcabstraction::Diagnostics &FlightSqlStatement::GetDiagnostics() { + return diagnostics_; +} + +void FlightSqlStatement::Cancel() { + if (!current_result_set_) return; + current_result_set_->Cancel(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h new file mode 100644 index 0000000000000..f51977636a439 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "flight_sql_statement_get_tables.h" +#include "odbcabstraction/types.h" +#include +#include + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +class FlightSqlStatement : public odbcabstraction::Statement { + +private: + odbcabstraction::Diagnostics diagnostics_; + std::map attribute_; + arrow::flight::FlightCallOptions call_options_; + arrow::flight::sql::FlightSqlClient &sql_client_; + std::shared_ptr current_result_set_; + std::shared_ptr prepared_statement_; + const odbcabstraction::MetadataSettings& metadata_settings_; + + std::shared_ptr + GetTables(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type, + const ColumnNames &column_names); + +public: + FlightSqlStatement( + const odbcabstraction::Diagnostics &diagnostics, + arrow::flight::sql::FlightSqlClient &sql_client, + arrow::flight::FlightCallOptions call_options, + const odbcabstraction::MetadataSettings& metadata_settings); + + bool SetAttribute(StatementAttributeId attribute, const Attribute &value) override; + + boost::optional GetAttribute(StatementAttributeId attribute) override; + + boost::optional> + Prepare(const std::string &query) override; + + bool ExecutePrepared() override; + + bool Execute(const std::string &query) override; + + std::shared_ptr GetResultSet() override; + + long GetUpdateCount() override; + + std::shared_ptr + GetTables_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) override; + + std::shared_ptr + GetTables_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) override; + + std::shared_ptr + GetColumns_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) override; + + std::shared_ptr + GetColumns_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) override; + + std::shared_ptr GetTypeInfo_V2(int16_t data_type) override; + + std::shared_ptr GetTypeInfo_V3(int16_t data_type) override; + + odbcabstraction::Diagnostics &GetDiagnostics() override; + + void Cancel() override; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc new file mode 100644 index 0000000000000..cd0f788a1b94d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_statement_get_columns.h" +#include +#include "flight_sql_connection.h" +#include "flight_sql_get_tables_reader.h" +#include "utils.h" +#include + +namespace driver { +namespace flight_sql { + +using arrow::flight::sql::ColumnMetadata; +using arrow::util::make_optional; +using arrow::util::nullopt; +using arrow::util::optional; + +namespace { +std::shared_ptr GetColumns_V3_Schema() { + return schema({ + field("TABLE_CAT", utf8()), + field("TABLE_SCHEM", utf8()), + field("TABLE_NAME", utf8()), + field("COLUMN_NAME", utf8()), + field("DATA_TYPE", int16()), + field("TYPE_NAME", utf8()), + field("COLUMN_SIZE", int32()), + field("BUFFER_LENGTH", int32()), + field("DECIMAL_DIGITS", int16()), + field("NUM_PREC_RADIX", int16()), + field("NULLABLE", int16()), + field("REMARKS", utf8()), + field("COLUMN_DEF", utf8()), + field("SQL_DATA_TYPE", int16()), + field("SQL_DATETIME_SUB", int16()), + field("CHAR_OCTET_LENGTH", int32()), + field("ORDINAL_POSITION", int32()), + field("IS_NULLABLE", utf8()), + }); +} + +std::shared_ptr GetColumns_V2_Schema() { + return schema({ + field("TABLE_QUALIFIER", utf8()), + field("TABLE_OWNER", utf8()), + field("TABLE_NAME", utf8()), + field("COLUMN_NAME", utf8()), + field("DATA_TYPE", int16()), + field("TYPE_NAME", utf8()), + field("PRECISION", int32()), + field("LENGTH", int32()), + field("SCALE", int16()), + field("RADIX", int16()), + field("NULLABLE", int16()), + field("REMARKS", utf8()), + field("COLUMN_DEF", utf8()), + field("SQL_DATA_TYPE", int16()), + field("SQL_DATETIME_SUB", int16()), + field("CHAR_OCTET_LENGTH", int32()), + field("ORDINAL_POSITION", int32()), + field("IS_NULLABLE", utf8()), + }); +} + +Result> +Transform_inner(const odbcabstraction::OdbcVersion odbc_version, + const std::shared_ptr &original, + const optional &column_name_pattern, + const MetadataSettings& metadata_settings) { + GetColumns_RecordBatchBuilder builder(odbc_version); + GetColumns_RecordBatchBuilder::Data data; + + GetTablesReader reader(original); + + optional column_name_regex = + column_name_pattern + ? make_optional(ConvertSqlPatternToRegex(*column_name_pattern)) + : nullopt; + + while (reader.Next()) { + const auto &table_catalog = reader.GetCatalogName(); + const auto &table_schema = reader.GetDbSchemaName(); + const auto &table_name = reader.GetTableName(); + const std::shared_ptr &schema = reader.GetSchema(); + if (schema == nullptr) { + // TODO: Remove this if after fixing TODO on GetTablesReader::GetSchema() + // This is because of a problem on Dremio server, where complex types columns + // are being returned without the children types, so we are simply ignoring + // it by now. + continue; + } + for (int i = 0; i < schema->num_fields(); ++i) { + const std::shared_ptr &field = schema->field(i); + + if (column_name_regex && + !boost::xpressive::regex_match(field->name(), + *column_name_regex)) { + continue; + } + + odbcabstraction::SqlDataType data_type_v3 = + GetDataTypeFromArrowField_V3(field, metadata_settings.use_wide_char_); + + ColumnMetadata metadata(field->metadata()); + + data.table_cat = table_catalog; + data.table_schem = table_schema; + data.table_name = table_name; + data.column_name = field->name(); + data.data_type = odbc_version == odbcabstraction::V_3 + ? data_type_v3 + : ConvertSqlDataTypeFromV3ToV2(data_type_v3); + + // TODO: Use `metadata.GetTypeName()` when ARROW-16064 is merged. + const auto &type_name_result = field->metadata()->Get("ARROW:FLIGHT:SQL:TYPE_NAME"); + data.type_name = type_name_result.ok() ? + type_name_result.ValueOrDie() : + GetTypeNameFromSqlDataType(data_type_v3); + + const Result &precision_result = metadata.GetPrecision(); + data.column_size = precision_result.ok() + ? make_optional(precision_result.ValueOrDie()) + : nullopt; + data.char_octet_length = + GetCharOctetLength(data_type_v3, precision_result); + + data.buffer_length = GetBufferLength(data_type_v3, data.column_size); + + const Result &scale_result = metadata.GetScale(); + data.decimal_digits = scale_result.ok() + ? make_optional(scale_result.ValueOrDie()) + : nullopt; + data.num_prec_radix = GetRadixFromSqlDataType(data_type_v3); + data.nullable = field->nullable(); + data.remarks = nullopt; + data.column_def = nullopt; + data.sql_data_type = GetNonConciseDataType(data_type_v3); + data.sql_datetime_sub = GetSqlDateTimeSubCode(data_type_v3); + data.ordinal_position = i + 1; + data.is_nullable = field->nullable() ? "YES" : "NO"; + + ARROW_RETURN_NOT_OK(builder.Append(data)); + } + } + + return builder.Build(); +} +} // namespace + +GetColumns_RecordBatchBuilder::GetColumns_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version) + : odbc_version_(odbc_version) {} + +Result> GetColumns_RecordBatchBuilder::Build() { + ARROW_ASSIGN_OR_RAISE(auto TABLE_CAT_Array, TABLE_CAT_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto TABLE_SCHEM_Array, TABLE_SCHEM_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto TABLE_NAME_Array, TABLE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_NAME_Array, COLUMN_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto DATA_TYPE_Array, DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto TYPE_NAME_Array, TYPE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_SIZE_Array, COLUMN_SIZE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto BUFFER_LENGTH_Array, + BUFFER_LENGTH_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto DECIMAL_DIGITS_Array, + DECIMAL_DIGITS_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NUM_PREC_RADIX_Array, + NUM_PREC_RADIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NULLABLE_Array, NULLABLE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto REMARKS_Array, REMARKS_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_DEF_Array, COLUMN_DEF_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATA_TYPE_Array, + SQL_DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATETIME_SUB_Array, + SQL_DATETIME_SUB_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto CHAR_OCTET_LENGTH_Array, + CHAR_OCTET_LENGTH_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto ORDINAL_POSITION_Array, + ORDINAL_POSITION_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto IS_NULLABLE_Array, IS_NULLABLE_Builder_.Finish()) + + std::vector> arrays = { + TABLE_CAT_Array, TABLE_SCHEM_Array, TABLE_NAME_Array, + COLUMN_NAME_Array, DATA_TYPE_Array, TYPE_NAME_Array, + COLUMN_SIZE_Array, BUFFER_LENGTH_Array, DECIMAL_DIGITS_Array, + NUM_PREC_RADIX_Array, NULLABLE_Array, REMARKS_Array, + COLUMN_DEF_Array, SQL_DATA_TYPE_Array, SQL_DATETIME_SUB_Array, + CHAR_OCTET_LENGTH_Array, ORDINAL_POSITION_Array, IS_NULLABLE_Array}; + + const std::shared_ptr &schema = odbc_version_ == odbcabstraction::V_3 + ? GetColumns_V3_Schema() + : GetColumns_V2_Schema(); + return RecordBatch::Make(schema, num_rows_, arrays); +} + +Status GetColumns_RecordBatchBuilder::Append( + const GetColumns_RecordBatchBuilder::Data &data) { + ARROW_RETURN_NOT_OK(AppendToBuilder(TABLE_CAT_Builder_, data.table_cat)); + ARROW_RETURN_NOT_OK(AppendToBuilder(TABLE_SCHEM_Builder_, data.table_schem)); + ARROW_RETURN_NOT_OK(AppendToBuilder(TABLE_NAME_Builder_, data.table_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_NAME_Builder_, data.column_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(DATA_TYPE_Builder_, data.data_type)); + ARROW_RETURN_NOT_OK(AppendToBuilder(TYPE_NAME_Builder_, data.type_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_SIZE_Builder_, data.column_size)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(BUFFER_LENGTH_Builder_, data.buffer_length)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(DECIMAL_DIGITS_Builder_, data.decimal_digits)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(NUM_PREC_RADIX_Builder_, data.num_prec_radix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(NULLABLE_Builder_, data.nullable)); + ARROW_RETURN_NOT_OK(AppendToBuilder(REMARKS_Builder_, data.remarks)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_DEF_Builder_, data.column_def)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(SQL_DATA_TYPE_Builder_, data.sql_data_type)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(SQL_DATETIME_SUB_Builder_, data.sql_datetime_sub)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(CHAR_OCTET_LENGTH_Builder_, data.char_octet_length)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(ORDINAL_POSITION_Builder_, data.ordinal_position)); + ARROW_RETURN_NOT_OK(AppendToBuilder(IS_NULLABLE_Builder_, data.is_nullable)); + num_rows_++; + + return Status::OK(); +} + +GetColumns_Transformer::GetColumns_Transformer( + const MetadataSettings& metadata_settings, + const odbcabstraction::OdbcVersion odbc_version, + const std::string *column_name_pattern) + : metadata_settings_(metadata_settings), + odbc_version_(odbc_version), + column_name_pattern_( + column_name_pattern ? make_optional(*column_name_pattern) : nullopt) { +} + +std::shared_ptr GetColumns_Transformer::Transform( + const std::shared_ptr &original) { + const Result> &result = + Transform_inner(odbc_version_, original, column_name_pattern_, metadata_settings_); + ThrowIfNotOK(result.status()); + + return result.ValueOrDie(); +} + +std::shared_ptr GetColumns_Transformer::GetTransformedSchema() { + return odbc_version_ == odbcabstraction::V_3 ? GetColumns_V3_Schema() + : GetColumns_V2_Schema(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.h new file mode 100644 index 0000000000000..e7593b6991e63 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.h @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "record_batch_transformer.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using odbcabstraction::MetadataSettings; +using std::optional; + +class GetColumns_RecordBatchBuilder { +private: + odbcabstraction::OdbcVersion odbc_version_; + + StringBuilder TABLE_CAT_Builder_; + StringBuilder TABLE_SCHEM_Builder_; + StringBuilder TABLE_NAME_Builder_; + StringBuilder COLUMN_NAME_Builder_; + Int16Builder DATA_TYPE_Builder_; + StringBuilder TYPE_NAME_Builder_; + Int32Builder COLUMN_SIZE_Builder_; + Int32Builder BUFFER_LENGTH_Builder_; + Int16Builder DECIMAL_DIGITS_Builder_; + Int16Builder NUM_PREC_RADIX_Builder_; + Int16Builder NULLABLE_Builder_; + StringBuilder REMARKS_Builder_; + StringBuilder COLUMN_DEF_Builder_; + Int16Builder SQL_DATA_TYPE_Builder_; + Int16Builder SQL_DATETIME_SUB_Builder_; + Int32Builder CHAR_OCTET_LENGTH_Builder_; + Int32Builder ORDINAL_POSITION_Builder_; + StringBuilder IS_NULLABLE_Builder_; + int64_t num_rows_{0}; + +public: + struct Data { + optional table_cat; + optional table_schem; + std::string table_name; + std::string column_name; + std::string type_name; + optional column_size; + optional buffer_length; + optional decimal_digits; + optional num_prec_radix; + optional remarks; + optional column_def; + int16_t sql_data_type{}; + optional sql_datetime_sub; + optional char_octet_length; + optional is_nullable; + int16_t data_type; + int16_t nullable; + int32_t ordinal_position; + }; + + explicit GetColumns_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version); + + Result> Build(); + + Status Append(const Data &data); +}; + +class GetColumns_Transformer : public RecordBatchTransformer { +private: + const MetadataSettings& metadata_settings_; + odbcabstraction::OdbcVersion odbc_version_; + optional column_name_pattern_; + +public: + explicit GetColumns_Transformer(const MetadataSettings& metadata_settings, + odbcabstraction::OdbcVersion odbc_version, + const std::string *column_name_pattern); + + std::shared_ptr + Transform(const std::shared_ptr &original) override; + + std::shared_ptr GetTransformedSchema() override; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc new file mode 100644 index 0000000000000..b1bd263b177c8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_statement_get_tables.h" +#include +#include "arrow/flight/api.h" +#include "arrow/flight/types.h" +#include "flight_sql_result_set.h" +#include "record_batch_transformer.h" +#include "utils.h" + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightInfo; +using arrow::flight::sql::FlightSqlClient; + +void ParseTableTypes(const std::string &table_type, + std::vector &table_types) { + bool encountered = false; // for checking if there is a single quote + std::string curr_parse; // the current string + + for (char temp : table_type) { // while still in the string + switch (temp) { // switch depending on the character + case '\'': // if the character is a single quote + if (encountered) { + encountered = + false; // if we already found a single quote, reset encountered + } else { + encountered = + true; // if we haven't found a single quote, set encountered to true + } + break; + case ',': // if it is a comma + if (!encountered) { // if we have not found a single quote + table_types.push_back( + curr_parse); // put our current string into our vector + curr_parse = ""; // reset the current string + break; + } + default: // if it is a normal character + if (encountered && isspace(temp)) { + curr_parse.push_back(temp); // if we have found a single quote put the + // whitespace, we don't care + } else if (temp == '\'' || temp == ' ') { + break; // if the current character is a single quote, trash it and go to + // the next character. + } else { + curr_parse.push_back(temp); // if all of the above failed, put the + // character into the current string + } + break; // go to the next character + } + } + table_types.emplace_back( + curr_parse); // if we have found a single quote put the whitespace, + // we don't care +} + +std::shared_ptr +GetTablesForSQLAllCatalogs(const ColumnNames &names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = + sql_client.GetCatalogs(call_options); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField("catalog_name", names.catalog_column) + .AddFieldOfNulls(names.schema_column, utf8()) + .AddFieldOfNulls(names.table_column, utf8()) + .AddFieldOfNulls(names.table_type_column, utf8()) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +std::shared_ptr GetTablesForSQLAllDbSchemas( + const ColumnNames &names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *schema_name, + odbcabstraction::Diagnostics &diagnostics, const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = + sql_client.GetDbSchemas(call_options, nullptr, schema_name); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .AddFieldOfNulls(names.catalog_column, utf8()) + .RenameField("db_schema_name", names.schema_column) + .AddFieldOfNulls(names.table_column, utf8()) + .AddFieldOfNulls(names.table_type_column, utf8()) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +std::shared_ptr +GetTablesForSQLAllTableTypes(const ColumnNames &names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = + sql_client.GetTableTypes(call_options); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .AddFieldOfNulls(names.catalog_column, utf8()) + .AddFieldOfNulls(names.schema_column, utf8()) + .AddFieldOfNulls(names.table_column, utf8()) + .RenameField("table_type", names.table_type_column) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +std::shared_ptr GetTablesForGenericUse( + const ColumnNames &names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *catalog_name, + const std::string *schema_name, const std::string *table_name, + const std::vector &table_types, + odbcabstraction::Diagnostics &diagnostics, const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = sql_client.GetTables( + call_options, catalog_name, schema_name, table_name, false, &table_types); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField("catalog_name", names.catalog_column) + .RenameField("db_schema_name", names.schema_column) + .RenameField("table_name", names.table_column) + .RenameField("table_type", names.table_type_column) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h new file mode 100644 index 0000000000000..b63d1ca6e8c34 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "flight_sql_connection.h" +#include "arrow/flight/types.h" +#include +#include +#include "record_batch_transformer.h" +#include "odbcabstraction/types.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using arrow::flight::FlightCallOptions; +using arrow::flight::sql::FlightSqlClient; +using odbcabstraction::ResultSet; +using odbcabstraction::MetadataSettings; + +typedef struct { + std::string catalog_column; + std::string schema_column; + std::string table_column; + std::string table_type_column; + std::string remarks_column; +} ColumnNames; + +void ParseTableTypes(const std::string &table_type, + std::vector &table_types); + +std::shared_ptr +GetTablesForSQLAllCatalogs(const ColumnNames &column_names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); + +std::shared_ptr GetTablesForSQLAllDbSchemas( + const ColumnNames &column_names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *schema_name, + odbcabstraction::Diagnostics &diagnostics, const odbcabstraction::MetadataSettings &metadata_settings); + +std::shared_ptr +GetTablesForSQLAllTableTypes(const ColumnNames &column_names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); + +std::shared_ptr GetTablesForGenericUse( + const ColumnNames &column_names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *catalog_name, + const std::string *schema_name, const std::string *table_name, + const std::vector &table_types, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc new file mode 100644 index 0000000000000..58cf2d37903f1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_statement_get_type_info.h" +#include +#include "flight_sql_get_type_info_reader.h" +#include "flight_sql_connection.h" +#include "utils.h" +#include + +namespace driver { +namespace flight_sql { + +using arrow::util::make_optional; +using arrow::util::nullopt; +using arrow::util::optional; + +namespace { +std::shared_ptr GetTypeInfo_V3_Schema() { + return schema({ + field("TYPE_NAME", utf8(), false), + field("DATA_TYPE", int16(), false), + field("COLUMN_SIZE", int32()), + field("LITERAL_PREFIX", utf8()), + field("LITERAL_SUFFIX", utf8()), + field("CREATE_PARAMS", utf8()), + field("NULLABLE", int16(), false), + field("CASE_SENSITIVE", int16(), false), + field("SEARCHABLE", int16(), false), + field("UNSIGNED_ATTRIBUTE", int16()), + field("FIXED_PREC_SCALE", int16(), false), + field("AUTO_UNIQUE_VALUE", int16()), + field("LOCAL_TYPE_NAME", utf8()), + field("MINIMUM_SCALE", int16()), + field("MAXIMUM_SCALE", int16()), + field("SQL_DATA_TYPE", int16(), false), + field("SQL_DATETIME_SUB", int16()), + field("NUM_PREC_RADIX", int32()), + field("INTERVAL_PRECISION", int16()), + }); +} + +std::shared_ptr GetTypeInfo_V2_Schema() { + return schema({ + field("TYPE_NAME", utf8(), false), + field("DATA_TYPE", int16(), false), + field("PRECISION", int32()), + field("LITERAL_PREFIX", utf8()), + field("LITERAL_SUFFIX", utf8()), + field("CREATE_PARAMS", utf8()), + field("NULLABLE", int16(), false), + field("CASE_SENSITIVE", int16(), false), + field("SEARCHABLE", int16(), false), + field("UNSIGNED_ATTRIBUTE", int16()), + field("MONEY", int16(), false), + field("AUTO_INCREMENT", int16()), + field("LOCAL_TYPE_NAME", utf8()), + field("MINIMUM_SCALE", int16()), + field("MAXIMUM_SCALE", int16()), + field("SQL_DATA_TYPE", int16(), false), + field("SQL_DATETIME_SUB", int16()), + field("NUM_PREC_RADIX", int32()), + field("INTERVAL_PRECISION", int16()), + }); +} + +Result> +Transform_inner(const odbcabstraction::OdbcVersion odbc_version, + const std::shared_ptr &original, + int data_type, + const MetadataSettings& metadata_settings_) { + GetTypeInfo_RecordBatchBuilder builder(odbc_version); + GetTypeInfo_RecordBatchBuilder::Data data; + + GetTypeInfoReader reader(original); + + while (reader.Next()) { + auto data_type_v3 = EnsureRightSqlCharType(static_cast(reader.GetDataType()), metadata_settings_.use_wide_char_); + int16_t data_type_v2 = ConvertSqlDataTypeFromV3ToV2(data_type_v3); + + if (data_type != odbcabstraction::ALL_TYPES && data_type_v3 != data_type && data_type_v2 != data_type) { + continue; + } + + data.data_type = odbc_version == odbcabstraction::V_3 + ? data_type_v3 + : data_type_v2; + data.type_name = reader.GetTypeName(); + data.column_size = reader.GetColumnSize(); + data.literal_prefix = reader.GetLiteralPrefix(); + data.literal_suffix = reader.GetLiteralSuffix(); + + const auto &create_params = reader.GetCreateParams(); + if (create_params) { + data.create_params = boost::algorithm::join(*create_params, ","); + } else { + data.create_params = nullopt; + } + + data.nullable = reader.GetNullable() ? odbcabstraction::NULLABILITY_NULLABLE : odbcabstraction::NULLABILITY_NO_NULLS; + data.case_sensitive = reader.GetCaseSensitive(); + data.searchable = reader.GetSearchable() ? odbcabstraction::SEARCHABILITY_ALL : odbcabstraction::SEARCHABILITY_NONE; + data.unsigned_attribute = reader.GetUnsignedAttribute(); + data.fixed_prec_scale = reader.GetFixedPrecScale(); + data.auto_unique_value = reader.GetAutoIncrement(); + data.local_type_name = reader.GetLocalTypeName(); + data.minimum_scale = reader.GetMinimumScale(); + data.maximum_scale = reader.GetMaximumScale(); + data.sql_data_type = EnsureRightSqlCharType(static_cast(reader.GetSqlDataType()), metadata_settings_.use_wide_char_); + data.sql_datetime_sub = GetSqlDateTimeSubCode(static_cast(data.data_type)); + data.num_prec_radix = reader.GetNumPrecRadix(); + data.interval_precision = reader.GetIntervalPrecision(); + + ARROW_RETURN_NOT_OK(builder.Append(data)); + } + + return builder.Build(); +} +} // namespace + +GetTypeInfo_RecordBatchBuilder::GetTypeInfo_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version) + : odbc_version_(odbc_version) {} + +Result> GetTypeInfo_RecordBatchBuilder::Build() { + + ARROW_ASSIGN_OR_RAISE(auto TYPE_NAME_Array, TYPE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto DATA_TYPE_Array, DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_SIZE_Array, COLUMN_SIZE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto LITERAL_PREFIX_Array, LITERAL_PREFIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto LITERAL_SUFFIX_Array, LITERAL_SUFFIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto CREATE_PARAMS_Array, CREATE_PARAMS_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NULLABLE_Array, NULLABLE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto CASE_SENSITIVE_Array, CASE_SENSITIVE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SEARCHABLE_Array, SEARCHABLE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto UNSIGNED_ATTRIBUTE_Array, UNSIGNED_ATTRIBUTE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto FIXED_PREC_SCALE_Array, FIXED_PREC_SCALE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto AUTO_UNIQUE_VALUE_Array, AUTO_UNIQUE_VALUE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto LOCAL_TYPE_NAME_Array, LOCAL_TYPE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto MINIMUM_SCALE_Array, MINIMUM_SCALE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto MAXIMUM_SCALE_Array, MAXIMUM_SCALE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATA_TYPE_Array, SQL_DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATETIME_SUB_Array, SQL_DATETIME_SUB_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NUM_PREC_RADIX_Array, NUM_PREC_RADIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto INTERVAL_PRECISION_Array, INTERVAL_PRECISION_Builder_.Finish()) + + std::vector> arrays = { + TYPE_NAME_Array, + DATA_TYPE_Array, + COLUMN_SIZE_Array, + LITERAL_PREFIX_Array, + LITERAL_SUFFIX_Array, + CREATE_PARAMS_Array, + NULLABLE_Array, + CASE_SENSITIVE_Array, + SEARCHABLE_Array, + UNSIGNED_ATTRIBUTE_Array, + FIXED_PREC_SCALE_Array, + AUTO_UNIQUE_VALUE_Array, + LOCAL_TYPE_NAME_Array, + MINIMUM_SCALE_Array, + MAXIMUM_SCALE_Array, + SQL_DATA_TYPE_Array, + SQL_DATETIME_SUB_Array, + NUM_PREC_RADIX_Array, + INTERVAL_PRECISION_Array + }; + + const std::shared_ptr &schema = odbc_version_ == odbcabstraction::V_3 + ? GetTypeInfo_V3_Schema() + : GetTypeInfo_V2_Schema(); + return RecordBatch::Make(schema, num_rows_, arrays); +} + +Status GetTypeInfo_RecordBatchBuilder::Append( + const GetTypeInfo_RecordBatchBuilder::Data &data) { + ARROW_RETURN_NOT_OK(AppendToBuilder(TYPE_NAME_Builder_, data.type_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(DATA_TYPE_Builder_, data.data_type)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_SIZE_Builder_, data.column_size)); + ARROW_RETURN_NOT_OK(AppendToBuilder(LITERAL_PREFIX_Builder_, data.literal_prefix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(LITERAL_SUFFIX_Builder_, data.literal_suffix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(CREATE_PARAMS_Builder_, data.create_params)); + ARROW_RETURN_NOT_OK(AppendToBuilder(NULLABLE_Builder_, data.nullable)); + ARROW_RETURN_NOT_OK(AppendToBuilder(CASE_SENSITIVE_Builder_, data.case_sensitive)); + ARROW_RETURN_NOT_OK(AppendToBuilder(SEARCHABLE_Builder_, data.searchable)); + ARROW_RETURN_NOT_OK(AppendToBuilder(UNSIGNED_ATTRIBUTE_Builder_, data.unsigned_attribute)); + ARROW_RETURN_NOT_OK(AppendToBuilder(FIXED_PREC_SCALE_Builder_, data.fixed_prec_scale)); + ARROW_RETURN_NOT_OK(AppendToBuilder(AUTO_UNIQUE_VALUE_Builder_, data.auto_unique_value)); + ARROW_RETURN_NOT_OK(AppendToBuilder(LOCAL_TYPE_NAME_Builder_, data.local_type_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(MINIMUM_SCALE_Builder_, data.minimum_scale)); + ARROW_RETURN_NOT_OK(AppendToBuilder(MAXIMUM_SCALE_Builder_, data.maximum_scale)); + ARROW_RETURN_NOT_OK(AppendToBuilder(SQL_DATA_TYPE_Builder_, data.sql_data_type)); + ARROW_RETURN_NOT_OK(AppendToBuilder(SQL_DATETIME_SUB_Builder_, data.sql_datetime_sub)); + ARROW_RETURN_NOT_OK(AppendToBuilder(NUM_PREC_RADIX_Builder_, data.num_prec_radix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(INTERVAL_PRECISION_Builder_, data.interval_precision)); + num_rows_++; + + return Status::OK(); +} + +GetTypeInfo_Transformer::GetTypeInfo_Transformer( + const MetadataSettings& metadata_settings, + const odbcabstraction::OdbcVersion odbc_version, + int data_type) + : metadata_settings_(metadata_settings), + odbc_version_(odbc_version), + data_type_(data_type) { +} + +std::shared_ptr GetTypeInfo_Transformer::Transform( + const std::shared_ptr &original) { + const Result> &result = + Transform_inner(odbc_version_, original, data_type_, metadata_settings_); + ThrowIfNotOK(result.status()); + + return result.ValueOrDie(); +} + +std::shared_ptr GetTypeInfo_Transformer::GetTransformedSchema() { + return odbc_version_ == odbcabstraction::V_3 ? GetTypeInfo_V3_Schema() + : GetTypeInfo_V2_Schema(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.h new file mode 100644 index 0000000000000..3d6d34df40994 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.h @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "record_batch_transformer.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using odbcabstraction::MetadataSettings; +using std::optional; + +class GetTypeInfo_RecordBatchBuilder { +private: + odbcabstraction::OdbcVersion odbc_version_; + + StringBuilder TYPE_NAME_Builder_; + Int16Builder DATA_TYPE_Builder_; + Int32Builder COLUMN_SIZE_Builder_; + StringBuilder LITERAL_PREFIX_Builder_; + StringBuilder LITERAL_SUFFIX_Builder_; + StringBuilder CREATE_PARAMS_Builder_; + Int16Builder NULLABLE_Builder_; + Int16Builder CASE_SENSITIVE_Builder_; + Int16Builder SEARCHABLE_Builder_; + Int16Builder UNSIGNED_ATTRIBUTE_Builder_; + Int16Builder FIXED_PREC_SCALE_Builder_; + Int16Builder AUTO_UNIQUE_VALUE_Builder_; + StringBuilder LOCAL_TYPE_NAME_Builder_; + Int16Builder MINIMUM_SCALE_Builder_; + Int16Builder MAXIMUM_SCALE_Builder_; + Int16Builder SQL_DATA_TYPE_Builder_; + Int16Builder SQL_DATETIME_SUB_Builder_; + Int32Builder NUM_PREC_RADIX_Builder_; + Int16Builder INTERVAL_PRECISION_Builder_; + int64_t num_rows_{0}; + +public: + struct Data { + std::string type_name; + int16_t data_type; + optional column_size; + optional literal_prefix; + optional literal_suffix; + optional create_params; + int16_t nullable; + int16_t case_sensitive; + int16_t searchable; + optional unsigned_attribute; + int16_t fixed_prec_scale; + optional auto_unique_value; + optional local_type_name; + optional minimum_scale; + optional maximum_scale; + int16_t sql_data_type; + optional sql_datetime_sub; + optional num_prec_radix; + optional interval_precision; + }; + + explicit GetTypeInfo_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version); + + Result> Build(); + + Status Append(const Data &data); +}; + +class GetTypeInfo_Transformer : public RecordBatchTransformer { +private: + const MetadataSettings& metadata_settings_; + odbcabstraction::OdbcVersion odbc_version_; + int data_type_; + +public: + explicit GetTypeInfo_Transformer(const MetadataSettings& metadata_settings, + odbcabstraction::OdbcVersion odbc_version, + int data_type); + + std::shared_ptr + Transform(const std::shared_ptr &original) override; + + std::shared_ptr GetTransformedSchema() override; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc new file mode 100644 index 0000000000000..f771c8cc80169 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_stream_chunk_buffer.h" +#include "utils.h" + + +namespace driver { +namespace flight_sql { + +using arrow::flight::FlightEndpoint; + +FlightStreamChunkBuffer::FlightStreamChunkBuffer(FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + size_t queue_capacity): queue_(queue_capacity) { + + // FIXME: Endpoint iteration should consider endpoints may be at different hosts + for (const auto & endpoint : flight_info->endpoints()) { + const arrow::flight::Ticket &ticket = endpoint.ticket; + + auto result = flight_sql_client.DoGet(call_options, ticket); + ThrowIfNotOK(result.status()); + std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); + + BlockingQueue>::Supplier supplier = [=] { + auto result = stream_reader_ptr->Next(); + bool isNotOk = !result.ok(); + bool isNotEmpty = result.ok() && (result.ValueOrDie().data != nullptr); + + return boost::make_optional(isNotOk || isNotEmpty, std::move(result)); + }; + queue_.AddProducer(std::move(supplier)); + } +} + +bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk *chunk) { + Result result; + if (!queue_.Pop(&result)) { + return false; + } + + if (!result.status().ok()) { + Close(); + throw odbcabstraction::DriverException(result.status().message()); + } + *chunk = std::move(result.ValueOrDie()); + return chunk->data != nullptr; +} + +void FlightStreamChunkBuffer::Close() { + queue_.Close(); +} + +FlightStreamChunkBuffer::~FlightStreamChunkBuffer() { + Close(); +} + +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h new file mode 100644 index 0000000000000..3f452770c4829 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::flight::FlightInfo; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using arrow::flight::sql::FlightSqlClient; +using driver::odbcabstraction::BlockingQueue; + +class FlightStreamChunkBuffer { + BlockingQueue> queue_; + +public: + FlightStreamChunkBuffer(FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + size_t queue_capacity = 5); + + ~FlightStreamChunkBuffer(); + + void Close(); + + bool GetNext(FlightStreamChunk* chunk); + +}; + +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc new file mode 100644 index 0000000000000..ddefb3db561b6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc @@ -0,0 +1,1361 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "get_info_cache.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "flight_sql_stream_chunk_buffer.h" +#include "scalar_function_reporter.h" +#include "utils.h" + +// Aliases for entries in SqlInfoOptions::SqlInfo that are defined here +// due to causing compilation errors conflicting with ODBC definitions. +#define ARROW_SQL_IDENTIFIER_CASE 503 +#define ARROW_SQL_IDENTIFIER_QUOTE_CHAR 504 +#define ARROW_SQL_QUOTED_IDENTIFIER_CASE 505 +#define ARROW_SQL_KEYWORDS 508 +#define ARROW_SQL_NUMERIC_FUNCTIONS 509 +#define ARROW_SQL_STRING_FUNCTIONS 510 +#define ARROW_SQL_SYSTEM_FUNCTIONS 511 +#define ARROW_SQL_SCHEMA_TERM 529 +#define ARROW_SQL_PROCEDURE_TERM 530 +#define ARROW_SQL_CATALOG_TERM 531 +#define ARROW_SQL_MAX_COLUMNS_IN_GROUP_BY 544 +#define ARROW_SQL_MAX_COLUMNS_IN_INDEX 545 +#define ARROW_SQL_MAX_COLUMNS_IN_ORDER_BY 546 +#define ARROW_SQL_MAX_COLUMNS_IN_SELECT 547 +#define ARROW_SQL_MAX_COLUMNS_IN_TABLE 548 +#define ARROW_SQL_MAX_ROW_SIZE 555 +#define ARROW_SQL_MAX_TABLES_IN_SELECT 560 + +#define ARROW_CONVERT_BIGINT 0 +#define ARROW_CONVERT_BINARY 1 +#define ARROW_CONVERT_BIT 2 +#define ARROW_CONVERT_CHAR 3 +#define ARROW_CONVERT_DATE 4 +#define ARROW_CONVERT_DECIMAL 5 +#define ARROW_CONVERT_FLOAT 6 +#define ARROW_CONVERT_INTEGER 7 +#define ARROW_CONVERT_INTERVAL_DAY_TIME 8 +#define ARROW_CONVERT_INTERVAL_YEAR_MONTH 9 +#define ARROW_CONVERT_LONGVARBINARY 10 +#define ARROW_CONVERT_LONGVARCHAR 11 +#define ARROW_CONVERT_NUMERIC 12 +#define ARROW_CONVERT_REAL 13 +#define ARROW_CONVERT_SMALLINT 14 +#define ARROW_CONVERT_TIME 15 +#define ARROW_CONVERT_TIMESTAMP 16 +#define ARROW_CONVERT_TINYINT 17 +#define ARROW_CONVERT_VARBINARY 18 +#define ARROW_CONVERT_VARCHAR 19 + +namespace { +// Return the corresponding field in SQLGetInfo's SQL_CONVERT_* field +// types for the given Arrow SqlConvert enum value. +// +// The caller is responsible for casting the result to a uint16. Note +// that -1 is returned if there's no corresponding entry. +int32_t GetInfoTypeForArrowConvertEntry(int32_t convert_entry) { + switch (convert_entry) { + case ARROW_CONVERT_BIGINT: + return SQL_CONVERT_BIGINT; + case ARROW_CONVERT_BINARY: + return SQL_CONVERT_BINARY; + case ARROW_CONVERT_BIT: + return SQL_CONVERT_BIT; + case ARROW_CONVERT_CHAR: + return SQL_CONVERT_CHAR; + case ARROW_CONVERT_DATE: + return SQL_CONVERT_DATE; + case ARROW_CONVERT_DECIMAL: + return SQL_CONVERT_DECIMAL; + case ARROW_CONVERT_FLOAT: + return SQL_CONVERT_FLOAT; + case ARROW_CONVERT_INTEGER: + return SQL_CONVERT_INTEGER; + case ARROW_CONVERT_INTERVAL_DAY_TIME: + return SQL_CONVERT_INTERVAL_DAY_TIME; + case ARROW_CONVERT_INTERVAL_YEAR_MONTH: + return SQL_CONVERT_INTERVAL_YEAR_MONTH; + case ARROW_CONVERT_LONGVARBINARY: + return SQL_CONVERT_LONGVARBINARY; + case ARROW_CONVERT_LONGVARCHAR: + return SQL_CONVERT_LONGVARCHAR; + case ARROW_CONVERT_NUMERIC: + return SQL_CONVERT_NUMERIC; + case ARROW_CONVERT_REAL: + return SQL_CONVERT_REAL; + case ARROW_CONVERT_SMALLINT: + return SQL_CONVERT_SMALLINT; + case ARROW_CONVERT_TIME: + return SQL_CONVERT_TIME; + case ARROW_CONVERT_TIMESTAMP: + return SQL_CONVERT_TIMESTAMP; + case ARROW_CONVERT_TINYINT: + return SQL_CONVERT_TINYINT; + case ARROW_CONVERT_VARBINARY: + return SQL_CONVERT_VARBINARY; + case ARROW_CONVERT_VARCHAR: + return SQL_CONVERT_VARCHAR; + } + // Arbitrarily return a negative value + return -1; +} + +// Return the corresponding bitmask to OR in SQLGetInfo's SQL_CONVERT_* field +// value for the given Arrow SqlConvert enum value. +// +// This is _not_ a bit position, it is an integer with only a single bit set. +uint32_t GetCvtBitForArrowConvertEntry(int32_t convert_entry) { + switch (convert_entry) { + case ARROW_CONVERT_BIGINT: + return SQL_CVT_BIGINT; + case ARROW_CONVERT_BINARY: + return SQL_CVT_BINARY; + case ARROW_CONVERT_BIT: + return SQL_CVT_BIT; + case ARROW_CONVERT_CHAR: + return SQL_CVT_CHAR | SQL_CVT_WCHAR; + case ARROW_CONVERT_DATE: + return SQL_CVT_DATE; + case ARROW_CONVERT_DECIMAL: + return SQL_CVT_DECIMAL; + case ARROW_CONVERT_FLOAT: + return SQL_CVT_FLOAT; + case ARROW_CONVERT_INTEGER: + return SQL_CVT_INTEGER; + case ARROW_CONVERT_INTERVAL_DAY_TIME: + return SQL_CVT_INTERVAL_DAY_TIME; + case ARROW_CONVERT_INTERVAL_YEAR_MONTH: + return SQL_CVT_INTERVAL_YEAR_MONTH; + case ARROW_CONVERT_LONGVARBINARY: + return SQL_CVT_LONGVARBINARY; + case ARROW_CONVERT_LONGVARCHAR: + return SQL_CVT_LONGVARCHAR | SQL_CVT_WLONGVARCHAR; + case ARROW_CONVERT_NUMERIC: + return SQL_CVT_NUMERIC; + case ARROW_CONVERT_REAL: + return SQL_CVT_REAL; + case ARROW_CONVERT_SMALLINT: + return SQL_CVT_SMALLINT; + case ARROW_CONVERT_TIME: + return SQL_CVT_TIME; + case ARROW_CONVERT_TIMESTAMP: + return SQL_CVT_TIMESTAMP; + case ARROW_CONVERT_TINYINT: + return SQL_CVT_TINYINT; + case ARROW_CONVERT_VARBINARY: + return SQL_CVT_VARBINARY; + case ARROW_CONVERT_VARCHAR: + return SQL_CVT_VARCHAR | SQL_CVT_WLONGVARCHAR; + } + // Note: GUID not supported by GetSqlInfo. + // Return zero, which has no bits set. + return 0; +} + +inline int32_t ScalarToInt32(arrow::UnionScalar *scalar) { + return reinterpret_cast(scalar->value.get())->value; +} + +inline int64_t ScalarToInt64(arrow::UnionScalar *scalar) { + return reinterpret_cast(scalar->value.get())->value; +} + +inline std::string ScalarToBoolString(arrow::UnionScalar *scalar) { + return reinterpret_cast(scalar->value.get())->value ? "Y" : "N"; +} + +inline void SetDefaultIfMissing(std::unordered_map& cache, + uint16_t info_type, driver::odbcabstraction::Connection::Info default_value) { + // Note: emplace() only writes if the key isn't found. + cache.emplace(info_type, std::move(default_value)); +} + +} // namespace + +namespace driver { +namespace flight_sql { +using namespace arrow::flight::sql; +using namespace arrow::flight; +using namespace driver::odbcabstraction; + +GetInfoCache::GetInfoCache(FlightCallOptions &call_options, + std::unique_ptr &client, const std::string &driver_version) + : call_options_(call_options), sql_client_(client), + has_server_info_(false) { + info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; + info_[SQL_DRIVER_VER] = ConvertToDBMSVer(driver_version); + + info_[SQL_GETDATA_EXTENSIONS] = + static_cast(SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER); + info_[SQL_CURSOR_SENSITIVITY] = static_cast(SQL_UNSPECIFIED); + + // Properties which don't currently have SqlGetInfo fields but probably + // should. + info_[SQL_ACCESSIBLE_PROCEDURES] = "N"; + info_[SQL_COLLATION_SEQ] = ""; + info_[SQL_ALTER_DOMAIN] = static_cast(0); + info_[SQL_ALTER_TABLE] = static_cast(0); + info_[SQL_COLUMN_ALIAS] = "Y"; + info_[SQL_DATETIME_LITERALS] = static_cast( + SQL_DL_SQL92_DATE | SQL_DL_SQL92_TIME | SQL_DL_SQL92_TIMESTAMP); + info_[SQL_CREATE_ASSERTION] = static_cast(0); + info_[SQL_CREATE_CHARACTER_SET] = static_cast(0); + info_[SQL_CREATE_COLLATION] = static_cast(0); + info_[SQL_CREATE_DOMAIN] = static_cast(0); + info_[SQL_INDEX_KEYWORDS] = static_cast(SQL_IK_NONE); + info_[SQL_TIMEDATE_ADD_INTERVALS] = static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | + SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR); + info_[SQL_TIMEDATE_DIFF_INTERVALS] = static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | + SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR); + info_[SQL_CURSOR_COMMIT_BEHAVIOR] = static_cast(SQL_CB_CLOSE); + info_[SQL_CURSOR_ROLLBACK_BEHAVIOR] = static_cast(SQL_CB_CLOSE); + info_[SQL_CREATE_TRANSLATION] = static_cast(0); + info_[SQL_DDL_INDEX] = static_cast(0); + info_[SQL_DROP_ASSERTION] = static_cast(0); + info_[SQL_DROP_CHARACTER_SET] = static_cast(0); + info_[SQL_DROP_COLLATION] = static_cast(0); + info_[SQL_DROP_DOMAIN] = static_cast(0); + info_[SQL_DROP_SCHEMA] = static_cast(0); + info_[SQL_DROP_TABLE] = static_cast(0); + info_[SQL_DROP_TRANSLATION] = static_cast(0); + info_[SQL_DROP_VIEW] = static_cast(0); + info_[SQL_MAX_IDENTIFIER_LEN] = static_cast(65535); // arbitrary + + // Assume all aggregate functions reported in ODBC are supported. + info_[SQL_AGGREGATE_FUNCTIONS] = static_cast( + SQL_AF_ALL | SQL_AF_AVG | SQL_AF_COUNT | SQL_AF_DISTINCT | SQL_AF_MAX | + SQL_AF_MIN | SQL_AF_SUM); + + // Assume catalogs are not supported by default. ODBC checks if SQL_CATALOG_NAME is + // "Y" or "N" to determine if catalogs are supported. + info_[SQL_CATALOG_TERM] = ""; + info_[SQL_CATALOG_NAME] = "N"; + info_[SQL_CATALOG_NAME_SEPARATOR] = ""; + info_[SQL_CATALOG_LOCATION] = static_cast(0); +} + +void GetInfoCache::SetProperty( + uint16_t property, driver::odbcabstraction::Connection::Info value) { + info_[property] = value; +} + +Connection::Info GetInfoCache::GetInfo(uint16_t info_type) { + auto it = info_.find(info_type); + + if (info_.end() == it) { + if (LoadInfoFromServer()) { + it = info_.find(info_type); + } + if (info_.end() == it) { + throw DriverException("Unknown GetInfo type: " + + std::to_string(info_type)); + } + } + return it->second; +} + +bool GetInfoCache::LoadInfoFromServer() { + if (sql_client_ && !has_server_info_.exchange(true)) { + std::unique_lock lock(mutex_); + arrow::Result> result = + sql_client_->GetSqlInfo(call_options_, {}); + ThrowIfNotOK(result.status()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, + result.ValueOrDie()); + + FlightStreamChunk chunk; + bool supports_correlation_name = false; + bool requires_different_correlation_name = false; + bool transactions_supported = false; + bool transaction_ddl_commit = false; + bool transaction_ddl_ignore = false; + while (chunk_iter.GetNext(&chunk)) { + auto name_array = chunk.data->GetColumnByName("info_name"); + auto value_array = chunk.data->GetColumnByName("value"); + + arrow::UInt32Array *info_type_array = + static_cast(name_array.get()); + arrow::UnionArray *value_union_array = + static_cast(value_array.get()); + for (int64_t i = 0; i < chunk.data->num_rows(); ++i) { + if (!value_array->IsNull(i)) { + auto info_type = + static_cast( + info_type_array->Value(i)); + auto result_scalar = value_union_array->GetScalar(i); + ThrowIfNotOK(result_scalar.status()); + std::shared_ptr scalar_ptr = + result_scalar.ValueOrDie(); + arrow::UnionScalar *scalar = + reinterpret_cast(scalar_ptr.get()); + switch (info_type) { + // String properties + case SqlInfoOptions::FLIGHT_SQL_SERVER_NAME: { + std::string server_name(reinterpret_cast(scalar->value.get())->view()); + + // TODO: Consider creating different properties in GetSqlInfo. + // TODO: Investigate if SQL_SERVER_NAME should just be the host + // address as well. In JDBC, FLIGHT_SQL_SERVER_NAME is only used for + // the DatabaseProductName. + info_[SQL_SERVER_NAME] = server_name; + info_[SQL_DBMS_NAME] = server_name; + info_[SQL_DATABASE_NAME] = + server_name; // This is usually the current catalog. May need to + // throw HYC00 instead. + break; + } + case SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION: { + info_[SQL_DBMS_VER] = ConvertToDBMSVer( + std::string(reinterpret_cast(scalar->value.get())->view())); + break; + } + case SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION: { + // Unused. + break; + } + case SqlInfoOptions::SQL_SEARCH_STRING_ESCAPE: { + info_[SQL_SEARCH_PATTERN_ESCAPE] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_IDENTIFIER_QUOTE_CHAR: { + info_[SQL_IDENTIFIER_QUOTE_CHAR] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case SqlInfoOptions::SQL_EXTRA_NAME_CHARACTERS: { + info_[SQL_SPECIAL_CHARACTERS] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_SCHEMA_TERM: { + info_[SQL_SCHEMA_TERM] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_PROCEDURE_TERM: { + info_[SQL_PROCEDURE_TERM] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_CATALOG_TERM: { + std::string catalog_term(std::string(reinterpret_cast(scalar->value.get())->view())); + if (catalog_term.empty()) { + info_[SQL_CATALOG_NAME] = "N"; + info_[SQL_CATALOG_NAME_SEPARATOR] = ""; + info_[SQL_CATALOG_LOCATION] = static_cast(0); + } else { + info_[SQL_CATALOG_NAME] = "Y"; + info_[SQL_CATALOG_NAME_SEPARATOR] = "."; + info_[SQL_CATALOG_LOCATION] = static_cast(SQL_CL_START); + } + info_[SQL_CATALOG_TERM] = std::string(reinterpret_cast(scalar->value.get())->view()); + + break; + } + + // Bool properties + case SqlInfoOptions::FLIGHT_SQL_SERVER_READ_ONLY: { + info_[SQL_DATA_SOURCE_READ_ONLY] = ScalarToBoolString(scalar); + + // Assume all forms of insert are supported, however this should + // come from a property. + info_[SQL_INSERT_STATEMENT] = static_cast( + SQL_IS_INSERT_LITERALS | SQL_IS_INSERT_SEARCHED | + SQL_IS_SELECT_INTO); + break; + } + case SqlInfoOptions::SQL_DDL_CATALOG: + // Unused by ODBC. + break; + case SqlInfoOptions::SQL_DDL_SCHEMA: { + bool supports_schema_ddl = + reinterpret_cast(scalar->value.get())->value; + // Note: this is a bitmask and we can't describe cascade or restrict + // flags. + info_[SQL_DROP_SCHEMA] = static_cast(SQL_DS_DROP_SCHEMA); + + // Note: this is a bitmask and we can't describe authorization or + // collation + info_[SQL_CREATE_SCHEMA] = + static_cast(SQL_CS_CREATE_SCHEMA); + break; + } + case SqlInfoOptions::SQL_DDL_TABLE: { + bool supports_table_ddl = + reinterpret_cast(scalar->value.get())->value; + // This is a bitmask and we cannot describe all clauses. + info_[SQL_CREATE_TABLE] = + static_cast(SQL_CT_CREATE_TABLE); + info_[SQL_DROP_TABLE] = static_cast(SQL_DT_DROP_TABLE); + break; + } + case SqlInfoOptions::SQL_ALL_TABLES_ARE_SELECTABLE: { + info_[SQL_ACCESSIBLE_TABLES] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_COLUMN_ALIASING: { + info_[SQL_COLUMN_ALIAS] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_NULL_PLUS_NULL_IS_NULL: { + info_[SQL_CONCAT_NULL_BEHAVIOR] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_CB_NULL + : SQL_CB_NON_NULL); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_TABLE_CORRELATION_NAMES: { + // Simply cache SQL_SUPPORTS_TABLE_CORRELATION_NAMES and + // SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES since we need both + // properties to determine the value for SQL_CORRELATION_NAME. + supports_correlation_name = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES: { + // Simply cache SQL_SUPPORTS_TABLE_CORRELATION_NAMES and + // SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES since we need both + // properties to determine the value for SQL_CORRELATION_NAME. + requires_different_correlation_name = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY: { + info_[SQL_EXPRESSIONS_IN_ORDERBY] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_ORDER_BY_UNRELATED: { + // Note: this is the negation of the Flight SQL property. + info_[SQL_ORDER_BY_COLUMNS_IN_SELECT] = + reinterpret_cast(scalar->value.get())->value ? "N" + : "Y"; + break; + } + case SqlInfoOptions::SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE: { + info_[SQL_LIKE_ESCAPE_CLAUSE] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_NON_NULLABLE_COLUMNS: { + info_[SQL_NON_NULLABLE_COLUMNS] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_NNC_NON_NULL + : SQL_NNC_NULL); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY: { + info_[SQL_INTEGRITY] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_CATALOG_AT_START: { + info_[SQL_CATALOG_LOCATION] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_CL_START + : SQL_CL_END); + break; + } + case SqlInfoOptions::SQL_SELECT_FOR_UPDATE_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_STORED_PROCEDURES_SUPPORTED: { + info_[SQL_PROCEDURES] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_MAX_ROW_SIZE_INCLUDES_BLOBS: { + info_[SQL_MAX_ROW_SIZE_INCLUDES_LONG] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_TRANSACTIONS_SUPPORTED: { + transactions_supported = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT: { + transaction_ddl_commit = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED: { + transaction_ddl_ignore = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_BATCH_UPDATES_SUPPORTED: { + info_[SQL_BATCH_SUPPORT] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_BS_ROW_COUNT_EXPLICIT + : 0); + break; + } + case SqlInfoOptions::SQL_SAVEPOINTS_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_NAMED_PARAMETERS_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_LOCATORS_UPDATE_COPY: + // Not used. + break; + case SqlInfoOptions::SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_CORRELATED_SUBQUERIES_SUPPORTED: + // Not used. This is implied by SQL_SUPPORTED_SUBQUERIES. + break; + + // Int64 properties + case ARROW_SQL_IDENTIFIER_CASE: { + // Missing from C++ enum. constant from Java. + constexpr int64_t LOWER = 3; + uint16_t value = 0; + int64_t sensitivity = ScalarToInt64(scalar); + switch (sensitivity) { + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UNKNOWN: + value = SQL_IC_SENSITIVE; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_CASE_INSENSITIVE: + value = SQL_IC_MIXED; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UPPERCASE: + value = SQL_IC_UPPER; + break; + case LOWER: + value = SQL_IC_LOWER; + break; + default: + value = SQL_IC_SENSITIVE; + break; + } + info_[SQL_IDENTIFIER_CASE] = value; + break; + } + case SqlInfoOptions::SQL_NULL_ORDERING: { + uint16_t value = 0; + int64_t scalar_value = ScalarToInt64(scalar); + switch (scalar_value) { + case SqlInfoOptions::SQL_NULLS_SORTED_AT_START: + value = SQL_NC_START; + break; + case SqlInfoOptions::SQL_NULLS_SORTED_AT_END: + value = SQL_NC_END; + break; + case SqlInfoOptions::SQL_NULLS_SORTED_HIGH: + value = SQL_NC_HIGH; + break; + case SqlInfoOptions::SQL_NULLS_SORTED_LOW: + default: + value = SQL_NC_LOW; + break; + } + info_[SQL_NULL_COLLATION] = value; + break; + } + case ARROW_SQL_QUOTED_IDENTIFIER_CASE: { + // Missing from C++ enum. constant from Java. + constexpr int64_t LOWER = 3; + uint16_t value = 0; + int64_t sensitivity = ScalarToInt64(scalar); + switch (sensitivity) { + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UNKNOWN: + value = SQL_IC_SENSITIVE; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_CASE_INSENSITIVE: + value = SQL_IC_MIXED; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UPPERCASE: + value = SQL_IC_UPPER; + break; + case LOWER: + value = SQL_IC_LOWER; + break; + default: + value = SQL_IC_SENSITIVE; + break; + } + info_[SQL_QUOTED_IDENTIFIER_CASE] = value; + break; + } + case SqlInfoOptions::SQL_MAX_BINARY_LITERAL_LENGTH: { + info_[SQL_MAX_BINARY_LITERAL_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CHAR_LITERAL_LENGTH: { + info_[SQL_MAX_CHAR_LITERAL_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_COLUMN_NAME_LENGTH: { + info_[SQL_MAX_COLUMN_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_GROUP_BY: { + info_[SQL_MAX_COLUMNS_IN_GROUP_BY] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_INDEX: { + info_[SQL_MAX_COLUMNS_IN_INDEX] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_ORDER_BY: { + info_[SQL_MAX_COLUMNS_IN_ORDER_BY] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_SELECT: { + info_[SQL_MAX_COLUMNS_IN_SELECT] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_TABLE: { + info_[SQL_MAX_COLUMNS_IN_TABLE] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CONNECTIONS: { + info_[SQL_MAX_DRIVER_CONNECTIONS] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CURSOR_NAME_LENGTH: { + info_[SQL_MAX_CURSOR_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_INDEX_LENGTH: { + info_[SQL_MAX_INDEX_SIZE] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_SCHEMA_NAME_LENGTH: { + info_[SQL_MAX_SCHEMA_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_PROCEDURE_NAME_LENGTH: { + info_[SQL_MAX_PROCEDURE_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CATALOG_NAME_LENGTH: { + info_[SQL_MAX_CATALOG_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_ROW_SIZE: { + info_[SQL_MAX_ROW_SIZE] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_STATEMENT_LENGTH: { + info_[SQL_MAX_STATEMENT_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_STATEMENTS: { + info_[SQL_MAX_CONCURRENT_ACTIVITIES] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_TABLE_NAME_LENGTH: { + info_[SQL_MAX_TABLE_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_TABLES_IN_SELECT: { + info_[SQL_MAX_TABLES_IN_SELECT] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_USERNAME_LENGTH: { + info_[SQL_MAX_USER_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_DEFAULT_TRANSACTION_ISOLATION: { + constexpr int32_t NONE = 0; + constexpr int32_t READ_UNCOMMITTED = 1; + constexpr int32_t READ_COMMITTED = 2; + constexpr int32_t REPEATABLE_READ = 3; + constexpr int32_t SERIALIZABLE = 4; + int64_t scalar_value = static_cast(ScalarToInt64(scalar)); + uint32_t result_val = 0; + if ((scalar_value & (1 << READ_UNCOMMITTED)) != 0) { + result_val = SQL_TXN_READ_UNCOMMITTED; + } else if ((scalar_value & (1 << READ_COMMITTED)) != 0) { + result_val = SQL_TXN_READ_COMMITTED; + } else if ((scalar_value & (1 << REPEATABLE_READ)) != 0) { + result_val = SQL_TXN_REPEATABLE_READ; + } else if ((scalar_value & (1 << SERIALIZABLE)) != 0) { + result_val = SQL_TXN_SERIALIZABLE; + } + info_[SQL_DEFAULT_TXN_ISOLATION] = result_val; + break; + } + + // Int32 properties + case SqlInfoOptions::SQL_SUPPORTED_GROUP_BY: { + // Note: SqlGroupBy enum is missing in C++. Using Java values. + constexpr int32_t UNRELATED = 0; + constexpr int32_t BEYOND_SELECT = 1; + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + uint16_t result_val = SQL_GB_NOT_SUPPORTED; + if ((scalar_value & (1 << UNRELATED)) != 0) { + result_val = SQL_GB_NO_RELATION; + } else if ((scalar_value & (1 << BEYOND_SELECT)) != 0) { + result_val = SQL_GB_GROUP_BY_CONTAINS_SELECT; + } + // Note GROUP_BY_EQUALS_SELECT and COLLATE cannot be described. + info_[SQL_GROUP_BY] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_GRAMMAR: { + // Note: SupportedSqlGrammar enum is missing in C++. Using Java + // values. + constexpr int32_t MINIMUM = 0; + constexpr int32_t CORE = 1; + constexpr int32_t EXTENDED = 2; + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + uint32_t result_val = SQL_OIC_CORE; + if ((scalar_value & (1 << MINIMUM)) != 0) { + result_val = SQL_OIC_CORE; + } else if ((scalar_value & (1 << CORE)) != 0) { + result_val = SQL_OIC_LEVEL1; + } else if ((scalar_value & (1 << EXTENDED)) != 0) { + result_val = SQL_OIC_LEVEL2; + } + info_[SQL_ODBC_API_CONFORMANCE] = result_val; + break; + } + case SqlInfoOptions::SQL_ANSI92_SUPPORTED_LEVEL: { + // Note: SupportedAnsi92SqlGrammarLevel enum is missing in C++. + // Using Java values. + constexpr int32_t ENTRY = 0; + constexpr int32_t INTERMEDIATE = 1; + constexpr int32_t FULL = 2; + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + uint32_t result_val = SQL_SC_SQL92_ENTRY; + uint16_t odbc_sql_conformance = SQL_OSC_MINIMUM; + if ((scalar_value & (1 << ENTRY)) != 0) { + result_val = SQL_SC_SQL92_ENTRY; + } else if ((scalar_value & (1 << INTERMEDIATE)) != 0) { + result_val = SQL_SC_SQL92_INTERMEDIATE; + odbc_sql_conformance = SQL_OSC_CORE; + } else if ((scalar_value & (1 << FULL)) != 0) { + result_val = SQL_SC_SQL92_FULL; + odbc_sql_conformance = SQL_OSC_EXTENDED; + } + info_[SQL_SQL_CONFORMANCE] = result_val; + info_[SQL_ODBC_SQL_CONFORMANCE] = odbc_sql_conformance; + break; + } + case SqlInfoOptions::SQL_OUTER_JOINS_SUPPORT_LEVEL: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // If limited outer joins is supported, we can't tell which joins + // are supported so just report none. If full outer joins is + // supported, nested joins are supported and full outer joins are + // supported, so all joins + nested are supported. + constexpr int32_t UNSUPPORTED = 0; + constexpr int32_t LIMITED = 1; + constexpr int32_t FULL = 2; + uint32_t result_val = 0; + // Assume inner and cross joins are supported. Flight SQL can't + // report this currently. + uint32_t relational_operators = + SQL_SRJO_CROSS_JOIN | SQL_SRJO_INNER_JOIN; + if ((scalar_value & (1 << FULL)) != 0) { + result_val = SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL | SQL_OJ_NESTED; + relational_operators |= SQL_SRJO_FULL_OUTER_JOIN | + SQL_SRJO_LEFT_OUTER_JOIN | + SQL_SRJO_RIGHT_OUTER_JOIN; + } else if ((scalar_value & (1 << LIMITED)) != 0) { + result_val = SQL_SC_SQL92_INTERMEDIATE; + } else if ((scalar_value & (1 << UNSUPPORTED)) != 0) { + result_val = 0; + } + info_[SQL_OJ_CAPABILITIES] = result_val; + info_[SQL_OUTER_JOINS] = result_val != 0 ? "Y" : "N"; + info_[SQL_SQL92_RELATIONAL_JOIN_OPERATORS] = relational_operators; + break; + } + case SqlInfoOptions::SQL_SCHEMAS_SUPPORTED_ACTIONS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing SqlSupportedElementActions enum in C++. Values taken from + // java. + constexpr int32_t PROCEDURE = 0; + constexpr int32_t INDEX = 1; + constexpr int32_t PRIVILEGE = 2; + // Assume schemas are supported in DML and Table manipulation. + uint32_t result_val = SQL_SU_DML_STATEMENTS | SQL_SU_TABLE_DEFINITION; + if ((scalar_value & (1 << PROCEDURE)) != 0) { + result_val |= SQL_SU_PROCEDURE_INVOCATION; + } + if ((scalar_value & (1 << INDEX)) != 0) { + result_val |= SQL_SU_INDEX_DEFINITION; + } + if ((scalar_value & (1 << PRIVILEGE)) != 0) { + result_val |= SQL_SU_PRIVILEGE_DEFINITION; + } + info_[SQL_SCHEMA_USAGE] = result_val; + break; + } + case SqlInfoOptions::SQL_CATALOGS_SUPPORTED_ACTIONS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing SqlSupportedElementActions enum in C++. Values taken from + // java. + constexpr int32_t PROCEDURE = 0; + constexpr int32_t INDEX = 1; + constexpr int32_t PRIVILEGE = 2; + // Assume catalogs are supported in DML and Table manipulation. + uint32_t result_val = SQL_CU_DML_STATEMENTS | SQL_CU_TABLE_DEFINITION; + if ((scalar_value & (1 << PROCEDURE)) != 0) { + result_val |= SQL_CU_PROCEDURE_INVOCATION; + } + if ((scalar_value & (1 << INDEX)) != 0) { + result_val |= SQL_CU_INDEX_DEFINITION; + } + if ((scalar_value & (1 << PRIVILEGE)) != 0) { + result_val |= SQL_CU_PRIVILEGE_DEFINITION; + } + info_[SQL_CATALOG_USAGE] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_POSITIONED_COMMANDS: { + // Ignore, positioned updates/deletes unsupported. + break; + } + case SqlInfoOptions::SQL_SUPPORTED_SUBQUERIES: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing SqlSupportedElementActions enum in C++. Values taken from + // java. + constexpr int32_t COMPARISONS = 0; + constexpr int32_t EXISTS = 1; + constexpr int32_t INN = 2; + constexpr int32_t QUANTIFIEDS = 3; + uint32_t result_val = 0; + if ((scalar_value & (1 << COMPARISONS)) != 0) { + result_val |= SQL_SQ_COMPARISON; + } + if ((scalar_value & (1 << EXISTS)) != 0) { + result_val |= SQL_SQ_EXISTS; + } + if ((scalar_value & (1 << INN)) != 0) { + result_val |= SQL_SQ_IN; + } + if ((scalar_value & (1 << QUANTIFIEDS)) != 0) { + result_val |= SQL_SQ_QUANTIFIED; + } + info_[SQL_SUBQUERIES] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_UNIONS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing enum in C++. Values taken from java. + constexpr int32_t UNION = 0; + constexpr int32_t UNION_ALL = 1; + uint32_t result_val = 0; + if ((scalar_value & (1 << UNION)) != 0) { + result_val |= SQL_U_UNION; + } + if ((scalar_value & (1 << UNION_ALL)) != 0) { + result_val |= SQL_U_UNION_ALL; + } + info_[SQL_UNION] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing enum in C++. Values taken from java. + constexpr int32_t NONE = 0; + constexpr int32_t READ_UNCOMMITTED = 1; + constexpr int32_t READ_COMMITTED = 2; + constexpr int32_t REPEATABLE_READ = 3; + constexpr int32_t SERIALIZABLE = 4; + uint32_t result_val = 0; + if ((scalar_value & (1 << NONE)) != 0) { + result_val = 0; + } + if ((scalar_value & (1 << READ_UNCOMMITTED)) != 0) { + result_val |= SQL_TXN_READ_UNCOMMITTED; + } + if ((scalar_value & (1 << READ_COMMITTED)) != 0) { + result_val |= SQL_TXN_READ_COMMITTED; + } + if ((scalar_value & (1 << REPEATABLE_READ)) != 0) { + result_val |= SQL_TXN_REPEATABLE_READ; + } + if ((scalar_value & (1 << SERIALIZABLE)) != 0) { + result_val |= SQL_TXN_SERIALIZABLE; + } + info_[SQL_TXN_ISOLATION_OPTION] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_RESULT_SET_TYPES: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE: + // Ignored. Warpdrive supports forward-only only. + break; + + // List properties + case ARROW_SQL_NUMERIC_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t result_val = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportNumericFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + result_val); + } + } + info_[SQL_NUMERIC_FUNCTIONS] = result_val; + break; + } + + case ARROW_SQL_STRING_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t result_val = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportStringFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + result_val); + } + } + info_[SQL_STRING_FUNCTIONS] = result_val; + break; + } + case ARROW_SQL_SYSTEM_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t sys_result = 0; + uint32_t convert_result = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportSystemFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + sys_result, convert_result); + } + } + info_[SQL_CONVERT_FUNCTIONS] = convert_result; + info_[SQL_SYSTEM_FUNCTIONS] = sys_result; + break; + } + case SqlInfoOptions::SQL_DATETIME_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t result_val = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportDatetimeFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + result_val); + } + } + info_[SQL_TIMEDATE_FUNCTIONS] = result_val; + break; + } + + case ARROW_SQL_KEYWORDS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + std::string result_str; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + if (list_index != 0) { + result_str += ", "; + } + + result_str += reinterpret_cast(list_value.get()) + ->GetString(list_index); + } + } + info_[SQL_KEYWORDS] = std::move(result_str); + break; + } + + // Map properties + case SqlInfoOptions::SQL_SUPPORTS_CONVERT: { + arrow::MapScalar *map_scalar = + reinterpret_cast(scalar->value.get()); + auto data_array = map_scalar->value; + arrow::StructArray *map_contents = + reinterpret_cast(data_array.get()); + auto map_keys = map_contents->field(0); + auto map_values = map_contents->field(1); + for (int64_t map_index = 0; map_index < map_contents->length(); + ++map_index) { + if (!map_values->IsNull(map_index)) { + auto map_key_scalar_ptr = + map_keys->GetScalar(map_index).ValueOrDie(); + auto map_value_scalar_ptr = + map_values->GetScalar(map_index).ValueOrDie(); + int32_t map_key_scalar = reinterpret_cast( + map_key_scalar_ptr.get()) + ->value; + auto map_value_scalar = + reinterpret_cast( + map_value_scalar_ptr.get()) + ->value; + + int32_t get_info_type = + GetInfoTypeForArrowConvertEntry(map_key_scalar); + if (get_info_type < 0) { + continue; + } + uint32_t info_bitmask_value_to_write = 0; + for (int64_t map_value_array_index = 0; + map_value_array_index < map_value_scalar->length(); + ++map_value_array_index) { + if (!map_value_scalar->IsNull(map_value_array_index)) { + auto list_entry_scalar = + map_value_scalar->GetScalar(map_value_array_index) + .ValueOrDie(); + info_bitmask_value_to_write |= GetCvtBitForArrowConvertEntry( + reinterpret_cast( + list_entry_scalar.get()) + ->value); + } + } + info_[get_info_type] = info_bitmask_value_to_write; + } + } + break; + } + + default: + // Ignore unrecognized. + break; + } + } + } + + if (transactions_supported) { + if (transaction_ddl_commit) { + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_DDL_COMMIT); + } else if (transaction_ddl_ignore) { + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_DDL_IGNORE); + } else { + // Ambiguous if this means transactions on DDL is supported or not. + // Assume not + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_DML); + } + } else { + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_NONE); + } + + if (supports_correlation_name) { + if (requires_different_correlation_name) { + info_[SQL_CORRELATION_NAME] = static_cast(SQL_CN_DIFFERENT); + } else { + info_[SQL_CORRELATION_NAME] = static_cast(SQL_CN_ANY); + } + } else { + info_[SQL_CORRELATION_NAME] = static_cast(SQL_CN_NONE); + } + } + LoadDefaultsForMissingEntries(); + return true; + } + + return false; +} + +void GetInfoCache::LoadDefaultsForMissingEntries() { + // For safety's sake, this function does not discriminate between driver and hard-coded values. + SetDefaultIfMissing(info_, SQL_ACCESSIBLE_PROCEDURES, "N"); + SetDefaultIfMissing(info_, SQL_ACCESSIBLE_TABLES, "Y"); + SetDefaultIfMissing(info_, SQL_ACTIVE_ENVIRONMENTS, static_cast(0)); + SetDefaultIfMissing(info_, SQL_AGGREGATE_FUNCTIONS, + static_cast(SQL_AF_ALL | SQL_AF_AVG | + SQL_AF_COUNT | SQL_AF_DISTINCT | + SQL_AF_MAX | SQL_AF_MIN | + SQL_AF_SUM)); + SetDefaultIfMissing(info_, SQL_ALTER_DOMAIN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_ALTER_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_ASYNC_MODE, + static_cast(SQL_AM_NONE)); + SetDefaultIfMissing(info_, SQL_BATCH_ROW_COUNT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_BATCH_SUPPORT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_BOOKMARK_PERSISTENCE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CATALOG_LOCATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CATALOG_NAME, "N"); + SetDefaultIfMissing(info_, SQL_CATALOG_NAME_SEPARATOR, ""); + SetDefaultIfMissing(info_, SQL_CATALOG_TERM, ""); + SetDefaultIfMissing(info_, SQL_CATALOG_USAGE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_COLLATION_SEQ, ""); + SetDefaultIfMissing(info_, SQL_COLUMN_ALIAS, "Y"); + SetDefaultIfMissing(info_, SQL_CONCAT_NULL_BEHAVIOR, + static_cast(SQL_CB_NULL)); + SetDefaultIfMissing(info_, SQL_CONVERT_BIGINT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_BINARY, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_BIT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_CHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_DATE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_DECIMAL, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_DOUBLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_FLOAT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_GUID, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_INTEGER,static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_INTERVAL_YEAR_MONTH, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_INTERVAL_DAY_TIME, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_LONGVARBINARY, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_LONGVARCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_NUMERIC, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_REAL, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_SMALLINT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_TIME, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_TIMESTAMP, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_TINYINT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_VARBINARY, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_VARCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WVARCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WLONGVARCHAR, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WLONGVARCHAR, + static_cast(SQL_FN_CVT_CAST)); + SetDefaultIfMissing(info_, SQL_CORRELATION_NAME, + static_cast(SQL_CN_NONE)); + SetDefaultIfMissing(info_, SQL_CREATE_ASSERTION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_CHARACTER_SET, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_DOMAIN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_SCHEMA, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_TRANSLATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_VIEW, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CURSOR_COMMIT_BEHAVIOR, + static_cast(SQL_CB_CLOSE)); + SetDefaultIfMissing(info_, SQL_CURSOR_ROLLBACK_BEHAVIOR, + static_cast(SQL_CB_CLOSE)); + SetDefaultIfMissing(info_, SQL_CURSOR_SENSITIVITY, + static_cast(SQL_UNSPECIFIED)); + SetDefaultIfMissing(info_, SQL_DATA_SOURCE_READ_ONLY, "N"); + SetDefaultIfMissing(info_, SQL_DBMS_NAME, "Arrow Flight SQL Server"); + SetDefaultIfMissing(info_, SQL_DBMS_VER, "00.01.0000"); + SetDefaultIfMissing(info_, SQL_DDL_INDEX, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DEFAULT_TXN_ISOLATION, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_DESCRIBE_PARAMETER, "N"); + SetDefaultIfMissing(info_, SQL_DRIVER_NAME, "Arrow Flight SQL Driver"); + SetDefaultIfMissing(info_, SQL_DRIVER_ODBC_VER, "03.80"); + SetDefaultIfMissing(info_, SQL_DRIVER_VER, "00.09.0000"); + SetDefaultIfMissing(info_, SQL_DROP_ASSERTION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_CHARACTER_SET, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_COLLATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_DOMAIN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_SCHEMA, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_TRANSLATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_VIEW, static_cast(0)); + SetDefaultIfMissing(info_, SQL_EXPRESSIONS_IN_ORDERBY, "N"); + SetDefaultIfMissing( + info_, SQL_GETDATA_EXTENSIONS, + static_cast(SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER)); + SetDefaultIfMissing(info_, SQL_GROUP_BY, + static_cast(SQL_GB_GROUP_BY_CONTAINS_SELECT)); + SetDefaultIfMissing(info_, SQL_IDENTIFIER_CASE, + static_cast(SQL_IC_MIXED)); + SetDefaultIfMissing(info_, SQL_IDENTIFIER_QUOTE_CHAR, "\""); + SetDefaultIfMissing(info_, SQL_INDEX_KEYWORDS, + static_cast(SQL_IK_NONE)); + SetDefaultIfMissing( + info_, SQL_INFO_SCHEMA_VIEWS, + static_cast(SQL_ISV_TABLES | SQL_ISV_COLUMNS | SQL_ISV_VIEWS)); + SetDefaultIfMissing(info_, SQL_INSERT_STATEMENT, + static_cast(SQL_IS_INSERT_LITERALS | + SQL_IS_INSERT_SEARCHED | + SQL_IS_SELECT_INTO)); + SetDefaultIfMissing(info_, SQL_INTEGRITY, "N"); + SetDefaultIfMissing(info_, SQL_KEYWORDS, ""); + SetDefaultIfMissing(info_, SQL_LIKE_ESCAPE_CLAUSE, "Y"); + SetDefaultIfMissing(info_, SQL_MAX_ASYNC_CONCURRENT_STATEMENTS, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_BINARY_LITERAL_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CATALOG_NAME_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CHAR_LITERAL_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMN_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_GROUP_BY, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_INDEX, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_ORDER_BY, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_SELECT, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_TABLE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CURSOR_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_DRIVER_CONNECTIONS, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_IDENTIFIER_LEN, + static_cast(65535)); + SetDefaultIfMissing(info_, SQL_MAX_INDEX_SIZE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_PROCEDURE_NAME_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_ROW_SIZE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_ROW_SIZE_INCLUDES_LONG, "N"); + SetDefaultIfMissing(info_, SQL_MAX_SCHEMA_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_STATEMENT_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_TABLE_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_TABLES_IN_SELECT, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_USER_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_NON_NULLABLE_COLUMNS, + static_cast(SQL_NNC_NULL)); + SetDefaultIfMissing(info_, SQL_NULL_COLLATION, + static_cast(SQL_NC_END)); + SetDefaultIfMissing(info_, SQL_NUMERIC_FUNCTIONS, static_cast(0)); + SetDefaultIfMissing( + info_, SQL_OJ_CAPABILITIES, + static_cast(SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL)); + SetDefaultIfMissing(info_, SQL_ORDER_BY_COLUMNS_IN_SELECT, "Y"); + SetDefaultIfMissing(info_, SQL_PROCEDURE_TERM, ""); + SetDefaultIfMissing(info_, SQL_PROCEDURES, "N"); + SetDefaultIfMissing(info_, SQL_QUOTED_IDENTIFIER_CASE, + static_cast(SQL_IC_SENSITIVE)); + SetDefaultIfMissing(info_, SQL_SCHEMA_TERM, "schema"); + SetDefaultIfMissing(info_, SQL_SCHEMA_USAGE, + static_cast(SQL_SU_DML_STATEMENTS)); + SetDefaultIfMissing(info_, SQL_SEARCH_PATTERN_ESCAPE, "\\"); + SetDefaultIfMissing(info_, SQL_SERVER_NAME, + "Arrow Flight SQL Server"); // This might actually need to be the hostname. + SetDefaultIfMissing(info_, SQL_SQL_CONFORMANCE, + static_cast(SQL_SC_SQL92_ENTRY)); + SetDefaultIfMissing(info_, SQL_SQL92_DATETIME_FUNCTIONS, + static_cast(SQL_SDF_CURRENT_DATE | + SQL_SDF_CURRENT_TIME | + SQL_SDF_CURRENT_TIMESTAMP)); + SetDefaultIfMissing(info_, SQL_SQL92_FOREIGN_KEY_DELETE_RULE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_FOREIGN_KEY_UPDATE_RULE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_GRANT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_NUMERIC_VALUE_FUNCTIONS, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_PREDICATES, + static_cast(SQL_SP_BETWEEN | SQL_SP_COMPARISON | + SQL_SP_EXISTS | SQL_SP_IN | + SQL_SP_ISNOTNULL | SQL_SP_ISNULL | + SQL_SP_LIKE)); + SetDefaultIfMissing(info_, SQL_SQL92_RELATIONAL_JOIN_OPERATORS, + static_cast( + SQL_SRJO_INNER_JOIN | SQL_SRJO_CROSS_JOIN | + SQL_SRJO_LEFT_OUTER_JOIN | SQL_SRJO_FULL_OUTER_JOIN | + SQL_SRJO_RIGHT_OUTER_JOIN)); + SetDefaultIfMissing(info_, SQL_SQL92_REVOKE, static_cast(0)); + SetDefaultIfMissing( + info_, SQL_SQL92_ROW_VALUE_CONSTRUCTOR, + static_cast(SQL_SRVC_VALUE_EXPRESSION | SQL_SRVC_NULL)); + SetDefaultIfMissing( + info_, SQL_SQL92_STRING_FUNCTIONS, + static_cast(SQL_SSF_CONVERT | SQL_SSF_LOWER | SQL_SSF_UPPER | + SQL_SSF_SUBSTRING | SQL_SSF_TRIM_BOTH | + SQL_SSF_TRIM_LEADING | SQL_SSF_TRIM_TRAILING)); + SetDefaultIfMissing(info_, SQL_SQL92_VALUE_EXPRESSIONS, + static_cast(SQL_SVE_CASE | SQL_SVE_CAST | + SQL_SVE_COALESCE | SQL_SVE_NULLIF)); + SetDefaultIfMissing(info_, SQL_STANDARD_CLI_CONFORMANCE, + static_cast(0)); + SetDefaultIfMissing( + info_, SQL_STRING_FUNCTIONS, + static_cast(SQL_FN_STR_CONCAT | SQL_FN_STR_LCASE | + SQL_FN_STR_LENGTH | SQL_FN_STR_LTRIM | + SQL_FN_STR_RTRIM | SQL_FN_STR_SPACE | + SQL_FN_STR_SUBSTRING | SQL_FN_STR_UCASE)); + SetDefaultIfMissing(info_, SQL_SUBQUERIES, + static_cast(SQL_SQ_CORRELATED_SUBQUERIES | + SQL_SQ_COMPARISON | SQL_SQ_EXISTS | + SQL_SQ_IN | SQL_SQ_QUANTIFIED)); + SetDefaultIfMissing( + info_, SQL_SYSTEM_FUNCTIONS, + static_cast(SQL_FN_SYS_IFNULL | SQL_FN_SYS_USERNAME)); + SetDefaultIfMissing(info_, SQL_TIMEDATE_ADD_INTERVALS, + static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | + SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | + SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + SetDefaultIfMissing(info_, SQL_TIMEDATE_DIFF_INTERVALS, + static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | + SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | + SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + SetDefaultIfMissing(info_, SQL_UNION, + static_cast(SQL_U_UNION | SQL_U_UNION_ALL)); + SetDefaultIfMissing(info_, SQL_XOPEN_CLI_YEAR, "1995"); + SetDefaultIfMissing(info_, SQL_ODBC_SQL_CONFORMANCE, + static_cast(SQL_OSC_MINIMUM)); + SetDefaultIfMissing(info_, SQL_ODBC_SAG_CLI_CONFORMANCE, + static_cast(SQL_OSCC_COMPLIANT)); + } + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h new file mode 100644 index 0000000000000..0a4ebf15b3248 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace arrow { +namespace flight { +namespace sql { +class FlightSqlClient; +} +} // namespace flight +} // namespace arrow + +namespace driver { +namespace flight_sql { + +class GetInfoCache { + +private: + std::unordered_map info_; + arrow::flight::FlightCallOptions &call_options_; + std::unique_ptr &sql_client_; + std::mutex mutex_; + std::atomic has_server_info_; + +public: + GetInfoCache(arrow::flight::FlightCallOptions &call_options, + std::unique_ptr &client, + const std::string &driver_version); + void SetProperty(uint16_t property, + driver::odbcabstraction::Connection::Info value); + driver::odbcabstraction::Connection::Info GetInfo(uint16_t info_type); + +private: + bool LoadInfoFromServer(); + void LoadDefaultsForMissingEntries(); +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h new file mode 100644 index 0000000000000..85c8753a0639d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include "winuser.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +#define TRUE_STR "true" +#define FALSE_STR "false" + +/** + * ODBC configuration abstraction. + */ +class Configuration +{ +public: + /** + * Default constructor. + */ + Configuration(); + + /** + * Destructor. + */ + ~Configuration(); + + /** + * Convert configure to connect string. + * + * @return Connect string. + */ + std::string ToConnectString() const; + + void LoadDefaults(); + void LoadDsn(const std::string& dsn); + + void Clear(); + bool IsSet(const std::string& key) const; + const std::string& Get(const std::string& key) const; + void Set(const std::string& key, const std::string& value); + + /** + * Get properties map. + */ + const driver::odbcabstraction::Connection::ConnPropertyMap& GetProperties() const; + + std::vector GetCustomKeys() const; + +private: + driver::odbcabstraction::Connection::ConnPropertyMap properties; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/connection_string_parser.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/connection_string_parser.h new file mode 100644 index 0000000000000..258fae540b9f6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/connection_string_parser.h @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "config/configuration.h" + +namespace driver { +namespace flight_sql { +namespace config { + +/** + * ODBC configuration parser abstraction. + */ +class ConnectionStringParser +{ +public: + /** + * Constructor. + * + * @param cfg Configuration. + */ + explicit ConnectionStringParser(Configuration& cfg); + + /** + * Destructor. + */ + ~ConnectionStringParser(); + + /** + * Parse connect string. + * + * @param str String to parse. + * @param len String length. + * @param delimiter delimiter. + */ + void ParseConnectionString(const char* str, size_t len, char delimiter); + + /** + * Parse connect string. + * + * @param str String to parse. + */ + void ParseConnectionString(const std::string& str); + + /** + * Parse config attributes. + * + * @param str String to parse. + */ + void ParseConfigAttributes(const char* str); + +private: + ConnectionStringParser(const ConnectionStringParser& parser) = delete; + ConnectionStringParser& operator=(const ConnectionStringParser&) = delete; + + /** Configuration. */ + Configuration& cfg; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h new file mode 100644 index 0000000000000..15bed38db7c6f --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace driver { +namespace flight_sql { + +class FlightSqlDriver : public odbcabstraction::Driver { +private: + odbcabstraction::Diagnostics diagnostics_; + std::string version_; + +public: + FlightSqlDriver(); + + std::shared_ptr + CreateConnection(odbcabstraction::OdbcVersion odbc_version) override; + + odbcabstraction::Diagnostics &GetDiagnostics() override; + + void SetVersion(std::string version) override; + + void RegisterLog() override; +}; + +}; // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h new file mode 100644 index 0000000000000..bf49fb1d138ec --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include "ui/custom_window.h" + +namespace driver { +namespace flight_sql { +namespace config { +/** + * Add property window class. + */ +class AddPropertyWindow : public CustomWindow +{ + /** + * Children windows ids. + */ + struct ChildId + { + enum Type + { + KEY_EDIT = 100, + KEY_LABEL, + VALUE_EDIT, + VALUE_LABEL, + OK_BUTTON, + CANCEL_BUTTON + }; + }; + +public: + /** + * Constructor. + * + * @param parent Parent window handle. + */ + explicit AddPropertyWindow(Window* parent); + + /** + * Destructor. + */ + virtual ~AddPropertyWindow(); + + /** + * Create window in the center of the parent window. + */ + void Create(); + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnCreate + */ + virtual void OnCreate() override; + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnMessage + */ + virtual bool OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) override; + + /** + * Get the property from the dialog. + * + * @return true if the dialog was OK'd, false otherwise. + */ + bool GetProperty(std::string& key, std::string& value); + +private: + /** + * Create property edit boxes. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateEdits(int posX, int posY, int sizeX); + + void CheckEnableOk(); + + std::vector > labels; + + /** Ok button. */ + std::unique_ptr okButton; + + /** Cancel button. */ + std::unique_ptr cancelButton; + + std::unique_ptr keyEdit; + + std::unique_ptr valueEdit; + + std::string key; + + std::string value; + + /** Window width. */ + int width; + + /** Window height. */ + int height; + + /** Flag indicating whether OK option was selected. */ + bool accepted; + + bool isInitialized; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h new file mode 100644 index 0000000000000..bd8f8daa43745 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "ui/window.h" + +namespace driver { +namespace flight_sql { +namespace config { +/** + * Application execution result. + */ +struct Result +{ + enum Type + { + OK, + CANCEL + }; +}; + +/** + * Process UI messages in current thread. + * Blocks until quit message has been received. + * + * @param window Main window. + * @return Application execution result. + */ +Result::Type ProcessMessages(Window& window); + +/** + * Window class. + */ +class CustomWindow : public Window +{ +public: + // Window margin size. + enum { MARGIN = 10 }; + + // Standard interval between UI elements. + enum { INTERVAL = 10 }; + + // Standard row height. + enum { ROW_HEIGHT = 20 }; + + // Standard button width. + enum { BUTTON_WIDTH = 80 }; + + // Standard button height. + enum { BUTTON_HEIGHT = 25 }; + + /** + * Constructor. + * + * @param parent Parent window. + * @param className Window class name. + * @param title Window title. + */ + CustomWindow(Window* parent, const char* className, const char* title); + + /** + * Destructor. + */ + virtual ~CustomWindow(); + + /** + * Callback which is called upon receiving new message. + * Pure virtual. Should be defined by user. + * + * @param msg Message. + * @param wParam Word-sized parameter. + * @param lParam Long parameter. + * @return Should return true if the message has been + * processed by the handler and false otherwise. + */ + virtual bool OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) = 0; + + /** + * Callback that is called upon window creation. + */ + virtual void OnCreate() = 0; + +private: +// IGNITE_NO_COPY_ASSIGNMENT(CustomWindow) + + /** + * Static callback. + * + * @param hwnd Window handle. + * @param msg Message. + * @param wParam Word-sized parameter. + * @param lParam Long parameter. + * @return Operation result. + */ + static LRESULT CALLBACK WndProc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam); +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h new file mode 100644 index 0000000000000..4d5e3f8b7b82f --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "config/configuration.h" +#include "ui/custom_window.h" + +namespace driver { +namespace flight_sql { +namespace config { +/** + * DSN configuration window class. + */ +class DsnConfigurationWindow : public CustomWindow +{ + /** + * Children windows ids. + */ + struct ChildId + { + enum Type + { + CONNECTION_SETTINGS_GROUP_BOX = 100, + AUTH_SETTINGS_GROUP_BOX, + ENCRYPTION_SETTINGS_GROUP_BOX, + NAME_EDIT, + NAME_LABEL, + SERVER_EDIT, + SERVER_LABEL, + PORT_EDIT, + PORT_LABEL, + AUTH_TYPE_LABEL, + AUTH_TYPE_COMBOBOX, + USER_LABEL, + USER_EDIT, + PASSWORD_LABEL, + PASSWORD_EDIT, + AUTH_TOKEN_LABEL, + AUTH_TOKEN_EDIT, + ENABLE_ENCRYPTION_LABEL, + ENABLE_ENCRYPTION_CHECKBOX, + CERTIFICATE_LABEL, + CERTIFICATE_EDIT, + CERTIFICATE_BROWSE_BUTTON, + USE_SYSTEM_CERT_STORE_LABEL, + USE_SYSTEM_CERT_STORE_CHECKBOX, + DISABLE_CERT_VERIFICATION_LABEL, + DISABLE_CERT_VERIFICATION_CHECKBOX, + PROPERTY_GROUP_BOX, + PROPERTY_LIST, + ADD_BUTTON, + DELETE_BUTTON, + TAB_CONTROL, + TEST_CONNECTION_BUTTON, + OK_BUTTON, + CANCEL_BUTTON + }; + }; + +public: + /** + * Constructor. + * + * @param parent Parent window handle. + */ + DsnConfigurationWindow(Window* parent, config::Configuration& config); + + /** + * Destructor. + */ + virtual ~DsnConfigurationWindow(); + + /** + * Create window in the center of the parent window. + */ + void Create(); + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnCreate + */ + virtual void OnCreate() override; + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnMessage + */ + virtual bool OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) override; + +private: + /** + * Create connection settings group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateConnectionSettingsGroup(int posX, int posY, int sizeX); + + /** + * Create aythentication settings group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateAuthSettingsGroup(int posX, int posY, int sizeX); + + /** + * Create Encryption settings group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateEncryptionSettingsGroup(int posX, int posY, int sizeX); + + /** + * Create advanced properties group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreatePropertiesGroup(int posX, int posY, int sizeX); + + void SelectTab(int tabIndex); + + void CheckEnableOk(); + + void CheckAuthType(); + + void SaveParameters(Configuration& targetConfig); + + /** Window width. */ + int width; + + /** Window height. */ + int height; + + std::unique_ptr tabControl; + + std::unique_ptr commonContent; + + std::unique_ptr advancedContent; + + /** Connection settings group box. */ + std::unique_ptr connectionSettingsGroupBox; + + /** Authentication settings group box. */ + std::unique_ptr authSettingsGroupBox; + + /** Encryption settings group box. */ + std::unique_ptr encryptionSettingsGroupBox; + + std::vector > labels; + + /** Test button. */ + std::unique_ptr testButton; + + /** Ok button. */ + std::unique_ptr okButton; + + /** Cancel button. */ + std::unique_ptr cancelButton; + + /** DSN name edit field. */ + std::unique_ptr nameEdit; + + std::unique_ptr serverEdit; + + std::unique_ptr portEdit; + + std::unique_ptr authTypeComboBox; + + /** User edit. */ + std::unique_ptr userEdit; + + /** Password edit. */ + std::unique_ptr passwordEdit; + + std::unique_ptr authTokenEdit; + + std::unique_ptr enableEncryptionCheckBox; + + std::unique_ptr certificateEdit; + + std::unique_ptr certificateBrowseButton; + + std::unique_ptr useSystemCertStoreCheckBox; + + std::unique_ptr disableCertVerificationCheckBox; + + std::unique_ptr propertyGroupBox; + + std::unique_ptr propertyList; + + std::unique_ptr addButton; + + std::unique_ptr deleteButton; + + /** Configuration. */ + Configuration& config; + + /** Flag indicating whether OK option was selected. */ + bool accepted; + + bool isInitialized; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h new file mode 100644 index 0000000000000..8b9755c055aa0 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +/** + * Get handle for the current module. + * + * @return Handle for the current module. + */ +HINSTANCE GetHInstance(); + +/** + * Window class. + */ +class Window +{ +public: + /** + * Constructor for a new window that is going to be created. + * + * @param parent Parent window handle. + * @param className Window class name. + * @param title Window title. + * @param callback Event processing function. + */ + Window(Window* parent, const char* className, const char* title); + + /** + * Constructor for the existing window. + * + * @param handle Window handle. + */ + explicit Window(HWND handle); + + /** + * Destructor. + */ + virtual ~Window(); + + /** + * Create window. + * + * @param style Window style. + * @param posX Window x position. + * @param posY Window y position. + * @param width Window width. + * @param height Window height. + * @param id ID for child window. + */ + void Create(DWORD style, int posX, int posY, int width, int height, int id); + + /** + * Create child tab controlwindow. + * + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateTabControl(int id); + + /** + * Create child list view window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateList(int posX, int posY, + int sizeX, int sizeY, int id); + + /** + * Create child group box window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateGroupBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id); + + /** + * Create child label window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateLabel(int posX, int posY, + int sizeX, int sizeY, const char* title, int id); + + /** + * Create child Edit window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateEdit(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style = 0); + + /** + * Create child button window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateButton(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style = 0); + + /** + * Create child CheckBox window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateCheckBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, bool state); + + /** + * Create child ComboBox window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateComboBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id); + + /** + * Show window. + */ + void Show(); + + /** + * Update window. + */ + void Update(); + + /** + * Destroy window. + */ + void Destroy(); + + /** + * Get window handle. + * + * @return Window handle. + */ + HWND GetHandle() const + { + return handle; + } + + void SetVisible(bool isVisible); + + void ListAddColumn(const std::string& name, int index, int width); + + void ListAddItem(const std::vector& items); + + void ListDeleteSelectedItem(); + + std::vector > ListGetAll(); + + void AddTab(const std::string& name, int index); + + bool IsTextEmpty() const; + + /** + * Get window text. + * + * @param text Text. + */ + void GetText(std::string& text) const; + + /** + * Set window text. + * + * @param text Text. + */ + void SetText(const std::string& text) const; + + /** + * Get CheckBox state. + * + * @param True if checked. + */ + bool IsChecked() const; + + /** + * Set CheckBox state. + * + * @param state True if checked. + */ + void SetChecked(bool state); + + /** + * Add string. + * + * @param str String. + */ + void AddString(const std::string& str); + + /** + * Set current ComboBox selection. + * + * @param idx List index. + */ + void SetSelection(int idx); + + /** + * Get current ComboBox selection. + * + * @return idx List index. + */ + int GetSelection() const; + + /** + * Set enabled. + * + * @param enabled Enable flag. + */ + void SetEnabled(bool enabled); + + /** + * Check if the window is enabled. + * + * @return True if enabled. + */ + bool IsEnabled() const; + +protected: + /** + * Set window handle. + * + * @param value Window handle. + */ + void SetHandle(HWND value) + { + handle = value; + } + + /** Window class name. */ + std::string className; + + /** Window title. */ + std::string title; + + /** Window handle. */ + HWND handle; + + /** Window parent. */ + Window* parent; + + /** Specifies whether window has been created by the thread and needs destruction. */ + bool created; + +private: + Window(const Window& window) = delete; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.cc new file mode 100644 index 0000000000000..04d67bf58dc93 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.cc @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "json_converter.h" + +#include +#include +#include +#include +#include +#include "utils.h" +#include + +using namespace arrow; +using namespace boost::beast::detail; +using driver::flight_sql::ThrowIfNotOK; + +namespace { +template +Status ConvertScalarToStringAndWrite(const ScalarT& scalar, rapidjson::Writer& writer) { + ARROW_ASSIGN_OR_RAISE(auto string_scalar, scalar.CastTo(utf8())) + const auto &view = reinterpret_cast(string_scalar.get())->view(); + writer.String(view.data(), view.length(), true); + return Status::OK(); +} + +template +Status ConvertBinaryToBase64StringAndWrite(const BinaryScalarT& scalar, rapidjson::Writer& writer) { + const auto &view = scalar.view(); + size_t encoded_size = base64::encoded_size(view.length()); + std::vector encoded(std::max(encoded_size, static_cast(1))); + base64::encode(&encoded[0], view.data(), view.length()); + writer.String(&encoded[0], encoded_size, true); + return Status::OK(); +} + +template +Status WriteListScalar(const ListScalarT& scalar, rapidjson::Writer& writer, + arrow::ScalarVisitor* visitor) { + writer.StartArray(); + for (int64_t i = 0; i < scalar.value->length(); ++i) { + if (scalar.value->IsNull(i)) { + writer.Null(); + } else { + const auto &result = scalar.value->GetScalar(i); + ThrowIfNotOK(result.status()); + ThrowIfNotOK(result.ValueOrDie()->Accept(visitor)); + } + } + + writer.EndArray(); + return Status::OK(); +} + + +class ScalarToJson : public arrow::ScalarVisitor { +private: + rapidjson::StringBuffer string_buffer_; + rapidjson::Writer writer_{string_buffer_}; + +public: + void Reset() { + string_buffer_.Clear(); + writer_.Reset(string_buffer_); + } + + std::string ToString() { + return string_buffer_.GetString(); + } + + Status Visit(const NullScalar &scalar) override { + writer_.Null(); + + return Status::OK(); + } + + Status Visit(const BooleanScalar &scalar) override { + writer_.Bool(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int8Scalar &scalar) override { + writer_.Int(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int16Scalar &scalar) override { + writer_.Int(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int32Scalar &scalar) override { + writer_.Int(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int64Scalar &scalar) override { + writer_.Int64(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt8Scalar &scalar) override { + writer_.Uint(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt16Scalar &scalar) override { + writer_.Uint(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt32Scalar &scalar) override { + writer_.Uint(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt64Scalar &scalar) override { + writer_.Uint64(scalar.value); + + return Status::OK(); + } + + Status Visit(const HalfFloatScalar &scalar) override { + return Status::NotImplemented("Cannot convert HalfFloatScalar to JSON."); + } + + Status Visit(const FloatScalar &scalar) override { + writer_.Double(scalar.value); + + return Status::OK(); + } + + Status Visit(const DoubleScalar &scalar) override { + writer_.Double(scalar.value); + + return Status::OK(); + } + + Status Visit(const StringScalar &scalar) override { + const auto &view = scalar.view(); + writer_.String(view.data(), view.length()); + + return Status::OK(); + } + + Status Visit(const BinaryScalar &scalar) override { + return ConvertBinaryToBase64StringAndWrite(scalar, writer_); + } + + Status Visit(const LargeStringScalar &scalar) override { + const auto &view = scalar.view(); + writer_.String(view.data(), view.length()); + + return Status::OK(); + } + + Status Visit(const LargeBinaryScalar &scalar) override { + return ConvertBinaryToBase64StringAndWrite(scalar, writer_); + } + + Status Visit(const FixedSizeBinaryScalar &scalar) override { + return ConvertBinaryToBase64StringAndWrite(scalar, writer_); + } + + Status Visit(const Date64Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Date32Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Time32Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Time64Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const TimestampScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const DayTimeIntervalScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const MonthDayNanoIntervalScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const MonthIntervalScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const DurationScalar &scalar) override { + // TODO: Append TimeUnit on conversion + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Decimal128Scalar &scalar) override { + const auto &view = scalar.ToString(); + writer_.RawValue(view.data(), view.length(), rapidjson::kNumberType); + + return Status::OK(); + } + + Status Visit(const Decimal256Scalar &scalar) override { + const auto &view = scalar.ToString(); + writer_.RawValue(view.data(), view.length(), rapidjson::kNumberType); + + return Status::OK(); + } + + Status Visit(const ListScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const LargeListScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const MapScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const FixedSizeListScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const StructScalar &scalar) override { + writer_.StartObject(); + + const std::shared_ptr &data_type = std::static_pointer_cast(scalar.type); + for (int i = 0; i < data_type->num_fields(); ++i) { + const auto& result = scalar.field(i); + ThrowIfNotOK(result.status()); + const auto& value = result.ValueOrDie(); + writer_.Key(data_type->field(i)->name().c_str()); + if (value->is_valid) { + ThrowIfNotOK(value->Accept(this)); + } + else { + writer_.Null(); + } + } + writer_.EndObject(); + return Status::OK(); + } + + Status Visit(const DictionaryScalar &scalar) override { + return Status::NotImplemented("Cannot convert DictionaryScalar to JSON."); + } + + Status Visit(const SparseUnionScalar &scalar) override { + return scalar.value->Accept(this); + } + + Status Visit(const DenseUnionScalar &scalar) override { + return scalar.value->Accept(this); + } + + Status Visit(const ExtensionScalar &scalar) override { + return Status::NotImplemented("Cannot convert ExtensionScalar to JSON."); + } +}; +} + +namespace driver { +namespace flight_sql { + +std::string ConvertToJson(const arrow::Scalar &scalar) { + static thread_local ScalarToJson converter; + converter.Reset(); + ThrowIfNotOK(scalar.Accept(&converter)); + + return converter.ToString(); +} + +arrow::Result> ConvertToJson(const std::shared_ptr& input) { + arrow::StringBuilder builder; + int64_t length = input->length(); + RETURN_NOT_OK(builder.ReserveData(length)); + + for (int64_t i = 0; i < length; ++i) { + if (input->IsNull(i)) { + RETURN_NOT_OK(builder.AppendNull()); + } else { + ARROW_ASSIGN_OR_RAISE(auto scalar, input->GetScalar(i)) + RETURN_NOT_OK(builder.Append(ConvertToJson(*scalar))); + } + } + + return builder.Finish(); +} + +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h new file mode 100644 index 0000000000000..590785cefbc8d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace driver { +namespace flight_sql { + +std::string ConvertToJson(const arrow::Scalar& scalar); + +arrow::Result> ConvertToJson(const std::shared_ptr& input); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter_test.cc new file mode 100644 index 0000000000000..5de12643c89a3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter_test.cc @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "json_converter.h" + +#include "gtest/gtest.h" +#include "arrow/testing/builder.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; + +TEST(ConvertToJson, String) { + ASSERT_EQ("\"\"", ConvertToJson(StringScalar(""))); + ASSERT_EQ("\"string\"", ConvertToJson(StringScalar("string"))); + ASSERT_EQ("\"string\\\"\"", ConvertToJson(StringScalar("string\""))); +} + +TEST(ConvertToJson, LargeString) { + ASSERT_EQ("\"\"", ConvertToJson(LargeStringScalar(""))); + ASSERT_EQ("\"string\"", ConvertToJson(LargeStringScalar("string"))); + ASSERT_EQ("\"string\\\"\"", ConvertToJson(LargeStringScalar("string\""))); +} + +TEST(ConvertToJson, Binary) { + ASSERT_EQ("\"\"", ConvertToJson(BinaryScalar(""))); + ASSERT_EQ("\"c3RyaW5n\"", ConvertToJson(BinaryScalar("string"))); + ASSERT_EQ("\"c3RyaW5nIg==\"", ConvertToJson(BinaryScalar("string\""))); +} + +TEST(ConvertToJson, LargeBinary) { + ASSERT_EQ("\"\"", ConvertToJson(LargeBinaryScalar(""))); + ASSERT_EQ("\"c3RyaW5n\"", ConvertToJson(LargeBinaryScalar("string"))); + ASSERT_EQ("\"c3RyaW5nIg==\"", ConvertToJson(LargeBinaryScalar("string\""))); +} + +TEST(ConvertToJson, FixedSizeBinary) { + ASSERT_EQ("\"\"", ConvertToJson(FixedSizeBinaryScalar(""))); + ASSERT_EQ("\"c3RyaW5n\"", ConvertToJson(FixedSizeBinaryScalar("string"))); + ASSERT_EQ("\"c3RyaW5nIg==\"", ConvertToJson(FixedSizeBinaryScalar("string\""))); +} + +TEST(ConvertToJson, Int8) { + ASSERT_EQ("127", ConvertToJson(Int8Scalar(127))); + ASSERT_EQ("-128", ConvertToJson(Int8Scalar(-128))); +} + +TEST(ConvertToJson, Int16) { + ASSERT_EQ("32767", ConvertToJson(Int16Scalar(32767))); + ASSERT_EQ("-32768", ConvertToJson(Int16Scalar(-32768))); +} + +TEST(ConvertToJson, Int32) { + ASSERT_EQ("2147483647", ConvertToJson(Int32Scalar(2147483647))); + ASSERT_EQ("-2147483648", ConvertToJson(Int32Scalar(-2147483648))); +} + +TEST(ConvertToJson, Int64) { + ASSERT_EQ("9223372036854775807", ConvertToJson(Int64Scalar(9223372036854775807LL))); + ASSERT_EQ("-9223372036854775808", ConvertToJson(Int64Scalar(-9223372036854775808ULL))); +} + +TEST(ConvertToJson, UInt8) { + ASSERT_EQ("127", ConvertToJson(UInt8Scalar(127))); + ASSERT_EQ("255", ConvertToJson(UInt8Scalar(255))); +} + +TEST(ConvertToJson, UInt16) { + ASSERT_EQ("32767", ConvertToJson(UInt16Scalar(32767))); + ASSERT_EQ("65535", ConvertToJson(UInt16Scalar(65535))); +} + +TEST(ConvertToJson, UInt32) { + ASSERT_EQ("2147483647", ConvertToJson(UInt32Scalar(2147483647))); + ASSERT_EQ("4294967295", ConvertToJson(UInt32Scalar(4294967295))); +} + +TEST(ConvertToJson, UInt64) { + ASSERT_EQ("9223372036854775807", ConvertToJson(UInt64Scalar(9223372036854775807LL))); + ASSERT_EQ("18446744073709551615", ConvertToJson(UInt64Scalar(18446744073709551615ULL))); +} + +TEST(ConvertToJson, Float) { + ASSERT_EQ("1.5", ConvertToJson(FloatScalar(1.5))); + ASSERT_EQ("-1.5", ConvertToJson(FloatScalar(-1.5))); +} + +TEST(ConvertToJson, Double) { + ASSERT_EQ("1.5", ConvertToJson(DoubleScalar(1.5))); + ASSERT_EQ("-1.5", ConvertToJson(DoubleScalar(-1.5))); +} + +TEST(ConvertToJson, Boolean) { + ASSERT_EQ("true", ConvertToJson(BooleanScalar(true))); + ASSERT_EQ("false", ConvertToJson(BooleanScalar(false))); +} + +TEST(ConvertToJson, Null) { + ASSERT_EQ("null", ConvertToJson(NullScalar())); +} + +TEST(ConvertToJson, Date32) { + ASSERT_EQ("\"1969-12-31\"", ConvertToJson(Date32Scalar(-1))); + ASSERT_EQ("\"1970-01-01\"", ConvertToJson(Date32Scalar(0))); + ASSERT_EQ("\"2022-01-01\"", ConvertToJson(Date32Scalar(18993))); +} + +TEST(ConvertToJson, Date64) { + ASSERT_EQ("\"1969-12-31\"", ConvertToJson(Date64Scalar(-86400000))); + ASSERT_EQ("\"1970-01-01\"", ConvertToJson(Date64Scalar(0))); + ASSERT_EQ("\"2022-01-01\"", ConvertToJson(Date64Scalar(1640995200000))); +} + +TEST(ConvertToJson, Time32) { + ASSERT_EQ("\"00:00:00\"", ConvertToJson(Time32Scalar(0, TimeUnit::SECOND))); + ASSERT_EQ("\"01:02:03\"", ConvertToJson(Time32Scalar(3723, TimeUnit::SECOND))); + ASSERT_EQ("\"00:00:00.123\"", ConvertToJson(Time32Scalar(123, TimeUnit::MILLI))); +} + +TEST(ConvertToJson, Time64) { + ASSERT_EQ("\"00:00:00.123456\"", ConvertToJson(Time64Scalar(123456, TimeUnit::MICRO))); + ASSERT_EQ("\"00:00:00.123456789\"", ConvertToJson(Time64Scalar(123456789, TimeUnit::NANO))); +} + +TEST(ConvertToJson, Timestamp) { + ASSERT_EQ("\"1969-12-31 00:00:00.000\"", ConvertToJson(TimestampScalar(-86400000, TimeUnit::MILLI))); + ASSERT_EQ("\"1970-01-01 00:00:00.000\"", ConvertToJson(TimestampScalar(0, TimeUnit::MILLI))); + ASSERT_EQ("\"2022-01-01 00:00:00.000\"", ConvertToJson(TimestampScalar(1640995200000, TimeUnit::MILLI))); + ASSERT_EQ("\"2022-01-01 00:00:01.234\"", ConvertToJson(TimestampScalar(1640995201234, TimeUnit::MILLI))); +} + +TEST(ConvertToJson, DayTimeInterval) { + ASSERT_EQ("\"123d0ms\"", ConvertToJson(DayTimeIntervalScalar({123, 0}))); + ASSERT_EQ("\"1d234ms\"", ConvertToJson(DayTimeIntervalScalar({1, 234}))); +} + +TEST(ConvertToJson, MonthDayNanoInterval) { + ASSERT_EQ("\"12M34d56ns\"", ConvertToJson(MonthDayNanoIntervalScalar({12, 34, 56}))); +} + +TEST(ConvertToJson, MonthInterval) { + ASSERT_EQ("\"1M\"", ConvertToJson(MonthIntervalScalar(1))); +} + +TEST(ConvertToJson, Duration) { + // TODO: Append TimeUnit on conversion + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::SECOND))); + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::MILLI))); + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::MICRO))); + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::NANO))); +} + +TEST(ConvertToJson, Lists) { + std::vector values = {"ABC", "DEF", "XYZ"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + const char *expected_string = R"(["ABC","DEF","XYZ"])"; + ASSERT_EQ(expected_string, ConvertToJson(ListScalar{array})); + ASSERT_EQ(expected_string, ConvertToJson(FixedSizeListScalar{array})); + ASSERT_EQ(expected_string, ConvertToJson(LargeListScalar{array})); + + StringBuilder builder; + ASSERT_OK(builder.AppendNull()); + ASSERT_EQ("[null]", ConvertToJson(ListScalar{builder.Finish().ValueOrDie()})); + ASSERT_EQ("[]", ConvertToJson(ListScalar{StringBuilder().Finish().ValueOrDie()})); +} + +TEST(ConvertToJson, Struct) { + auto i32 = MakeScalar(1); + auto f64 = MakeScalar(2.5); + auto str = MakeScalar("yo"); + ASSERT_OK_AND_ASSIGN(auto scalar, + StructScalar::Make({i32, f64, str, + MakeNullScalar(std::shared_ptr(new arrow::Date32Type()))}, + {"i", "f", "s", "null"})); + ASSERT_EQ("{\"i\":1,\"f\":2.5,\"s\":\"yo\",\"null\":null}", ConvertToJson(*scalar)); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc new file mode 100644 index 0000000000000..c0c537a4b2f69 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "flight_sql_connection.h" +#include "flight_sql_result_set.h" +#include "flight_sql_result_set_metadata.h" +#include "flight_sql_statement.h" + +#include +#include +#include +#include + +using arrow::Status; +using arrow::flight::FlightClient; +using arrow::flight::Location; +using arrow::flight::sql::FlightSqlClient; + +using driver::flight_sql::FlightSqlConnection; +using driver::flight_sql::FlightSqlDriver; +using driver::odbcabstraction::Connection; +using driver::odbcabstraction::ResultSet; +using driver::odbcabstraction::ResultSetMetadata; +using driver::odbcabstraction::Statement; + +void TestBindColumn(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + statement->Execute( + "SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10"); + + const std::shared_ptr &result_set = statement->GetResultSet(); + + const int batch_size = 100; + const int max_strlen = 1000; + + char IncidntNum[batch_size][max_strlen]; + ssize_t IncidntNum_length[batch_size]; + + char Category[batch_size][max_strlen]; + ssize_t Category_length[batch_size]; + + result_set->BindColumn(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + IncidntNum, max_strlen, IncidntNum_length); + result_set->BindColumn(2, driver::odbcabstraction::CDataType_CHAR, 0, 0, + Category, max_strlen, Category_length); + + size_t total = 0; + while (true) { + size_t fetched_rows = result_set->Move(batch_size, 0, 0, nullptr); + std::cout << "Fetched " << fetched_rows << " rows." << std::endl; + + total += fetched_rows; + std::cout << "Total:" << total << std::endl; + + for (int i = 0; i < fetched_rows; ++i) { + std::cout << "Row[" << i << "] IncidntNum: '" << IncidntNum[i] + << "', Category: '" << Category[i] << "'" << std::endl; + } + + if (fetched_rows < batch_size) + break; + } +} + +void TestGetData(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + statement->Execute( + "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5 UNION ALL SELECT 6"); + + const std::shared_ptr &result_set = statement->GetResultSet(); + const std::shared_ptr &metadata = result_set->GetMetadata(); + + while (result_set->Move(1, 0, 0, nullptr) == 1) { + char result[128]; + ssize_t result_length; + result_set->GetData(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + &result, sizeof(result), &result_length); + std::cout << result << std::endl; + } +} + +void TestBindColumnBigInt(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + statement->Execute( + "SELECT IncidntNum, CAST(\"IncidntNum\" AS DOUBLE) / 100 AS " + "double_field, Category\n" + "FROM (\n" + " SELECT CONVERT_TO_INTEGER(IncidntNum, 1, 1, 0) AS IncidntNum, " + "Category\n" + " FROM (\n" + " SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10\n" + " ) nested_0\n" + ") nested_0"); + + const std::shared_ptr &result_set = statement->GetResultSet(); + + const int batch_size = 100; + const int max_strlen = 1000; + + char IncidntNum[batch_size][max_strlen]; + ssize_t IncidntNum_length[batch_size]; + + double double_field[batch_size]; + ssize_t double_field_length[batch_size]; + + char Category[batch_size][max_strlen]; + ssize_t Category_length[batch_size]; + + result_set->BindColumn(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + IncidntNum, max_strlen, IncidntNum_length); + result_set->BindColumn(2, driver::odbcabstraction::CDataType_DOUBLE, 0, 0, + double_field, max_strlen, double_field_length); + result_set->BindColumn(3, driver::odbcabstraction::CDataType_CHAR, 0, 0, + Category, max_strlen, Category_length); + + size_t total = 0; + while (true) { + size_t fetched_rows = result_set->Move(batch_size, 0, 0, nullptr); + std::cout << "Fetched " << fetched_rows << " rows." << std::endl; + + total += fetched_rows; + std::cout << "Total:" << total << std::endl; + + for (int i = 0; i < fetched_rows; ++i) { + std::cout << "Row[" << i << "] IncidntNum: '" << IncidntNum[i] << "', " + << "double_field: '" << double_field[i] << "', " + << "Category: '" << Category[i] << "'" << std::endl; + } + + if (fetched_rows < batch_size) + break; + } +} + +void TestGetTablesV2(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + const std::shared_ptr &result_set = + statement->GetTables_V2(nullptr, nullptr, nullptr, nullptr); + + const std::shared_ptr &metadata = + result_set->GetMetadata(); + size_t column_count = metadata->GetColumnCount(); + + while (result_set->Move(1, 0, 0, nullptr) == 1) { + int buffer_length = 1024; + std::vector result(buffer_length); + ssize_t result_length; + result_set->GetData(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + result.data(), buffer_length, &result_length); + std::cout << result.data() << std::endl; + } + + std::cout << column_count << std::endl; +} + +void TestGetColumnsV3(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + std::string table_name = "test_numeric"; + std::string column_name = "%"; + const std::shared_ptr &result_set = + statement->GetColumns_V3(nullptr, nullptr, &table_name, &column_name); + + const std::shared_ptr &metadata = + result_set->GetMetadata(); + size_t column_count = metadata->GetColumnCount(); + + int buffer_length = 1024; + std::vector result(buffer_length); + ssize_t result_length; + + while (result_set->Move(1, 0, 0, nullptr) == 1) { + for (int i = 0; i < column_count; ++i) { + result_set->GetData(1 + i, driver::odbcabstraction::CDataType_CHAR, 0, 0, + result.data(), buffer_length, &result_length); + std::cout << (result_length != -1 ? result.data() : "NULL") << '\t'; + } + + std::cout << std::endl; + } + + std::cout << column_count << std::endl; +} + +int main() { + FlightSqlDriver driver; + + const std::shared_ptr &connection = + driver.CreateConnection(driver::odbcabstraction::V_3); + + Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("automaster.drem.io")}, + {FlightSqlConnection::PORT, std::string("32010")}, + {FlightSqlConnection::USER, std::string("dremio")}, + {FlightSqlConnection::PASSWORD, std::string("dremio123")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + }; + std::vector missing_attr; + connection->Connect(properties, missing_attr); + + // TestBindColumnBigInt(connection); +// TestBindColumn(connection); + TestGetData(connection); + // TestGetTablesV2(connection); +// TestGetColumnsV3(connection); + + connection->Close(); + return 0; +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc new file mode 100644 index 0000000000000..b16cfb3ca4b6e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "flight_sql_statement_get_tables.h" +#include +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +void AssertParseTest(const std::string &input_string, + const std::vector &assert_vector) { + std::vector table_types; + + ParseTableTypes(input_string, table_types); + ASSERT_EQ(table_types, assert_vector); +} + +TEST(TableTypeParser, ParsingWithoutSingleQuotesWithLeadingWhiteSpace) { + AssertParseTest("TABLE, VIEW", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithoutSingleQuotesWithoutLeadingWhiteSpace) { + AssertParseTest("TABLE,VIEW", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithSingleQuotesWithLeadingWhiteSpace) { + AssertParseTest("'TABLE', 'VIEW'", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithSingleQuotesWithoutLeadingWhiteSpace) { + AssertParseTest("'TABLE','VIEW'", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithCommaInsideSingleQuotes) { + AssertParseTest("'TABLE, TEST', 'VIEW, TEMPORARY'", + {"TABLE, TEST", "VIEW, TEMPORARY"}); +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.cc new file mode 100644 index 0000000000000..69d7087bb012c --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.cc @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "record_batch_transformer.h" +#include + +#include "utils.h" +#include +#include +#include +#include + +#include "arrow/array/array_base.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; + +namespace { +Result> MakeEmptyArray(std::shared_ptr type, + MemoryPool *memory_pool, + int64_t array_size) { + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(memory_pool, type, &builder)); + RETURN_NOT_OK(builder->AppendNulls(array_size)); + return builder->Finish(); +} + +/// A transformer class which is responsible to convert the name of fields +/// inside a RecordBatch. These fields are changed based on tasks created by the +/// methods RenameField() and AddFieldOfNulls(). The execution of the tasks is +/// handled by the method transformer. +class RecordBatchTransformerWithTasks : public RecordBatchTransformer { +private: + std::vector> fields_; + std::vector( + const std::shared_ptr &original_record_batch, + const std::shared_ptr &transformed_schema)>> + tasks_; + +public: + RecordBatchTransformerWithTasks( + std::vector> fields, + std::vector( + const std::shared_ptr &original_record_batch, + const std::shared_ptr &transformed_schema)>> + tasks) { + this->fields_.swap(fields); + this->tasks_.swap(tasks); + } + + std::shared_ptr + Transform(const std::shared_ptr &original) override { + auto new_schema = schema(fields_); + + std::vector> arrays; + arrays.reserve(new_schema->num_fields()); + + for (const auto &item : tasks_) { + arrays.emplace_back(item(original, new_schema)); + } + + auto transformed_batch = + RecordBatch::Make(new_schema, original->num_rows(), arrays); + return transformed_batch; + } + + std::shared_ptr GetTransformedSchema() override { + return schema(fields_); + } +}; +} // namespace + +RecordBatchTransformerWithTasksBuilder & +RecordBatchTransformerWithTasksBuilder::RenameField( + const std::string &original_name, const std::string &transformed_name) { + + auto rename_task = [=](const std::shared_ptr &original_record, + const std::shared_ptr &transformed_schema) { + auto original_data_type = + original_record->schema()->GetFieldByName(original_name); + auto transformed_data_type = + transformed_schema->GetFieldByName(transformed_name); + + if (original_data_type->type() != transformed_data_type->type()) { + throw odbcabstraction::DriverException( + "Original data and target data has different types"); + } + + return original_record->GetColumnByName(original_name); + }; + + task_collection_.emplace_back(rename_task); + + auto original_fields = schema_->GetFieldByName(original_name); + + if (original_fields->HasMetadata()) { + new_fields_.push_back(field(transformed_name, original_fields->type(), + original_fields->metadata())); + } else { + new_fields_.push_back( + field(transformed_name, original_fields->type(), std::shared_ptr())); + } + + return *this; +} + +RecordBatchTransformerWithTasksBuilder & +RecordBatchTransformerWithTasksBuilder::AddFieldOfNulls( + const std::string &field_name, const std::shared_ptr &data_type) { + auto empty_fields_task = + [=](const std::shared_ptr &original_record, + const std::shared_ptr &transformed_schema) { + auto result = + MakeEmptyArray(data_type, nullptr, original_record->num_rows()); + ThrowIfNotOK(result.status()); + + return result.ValueOrDie(); + }; + + task_collection_.emplace_back(empty_fields_task); + + new_fields_.push_back(field(field_name, data_type)); + + return *this; +} + +std::shared_ptr +RecordBatchTransformerWithTasksBuilder::Build() { + std::shared_ptr transformer( + new RecordBatchTransformerWithTasks(this->new_fields_, + this->task_collection_)); + + return transformer; +} + +RecordBatchTransformerWithTasksBuilder::RecordBatchTransformerWithTasksBuilder( + std::shared_ptr schema) + : schema_(std::move(schema)) {} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h new file mode 100644 index 0000000000000..70b5395f4aa7f --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; + +typedef std::function( + const std::shared_ptr &original_record_batch, + const std::shared_ptr &transformed_schema)> + TransformTask; + +/// A base class to implement different types of transformer. +class RecordBatchTransformer { +public: + virtual ~RecordBatchTransformer() = default; + + /// Execute the transformation based on predeclared tasks created by + /// RenameField() method and/or AddFieldOfNulls(). + /// \param original The original RecordBatch that will be used as base + /// for the transformation. + /// \return The new transformed RecordBatch. + virtual std::shared_ptr + Transform(const std::shared_ptr &original) = 0; + + /// Use the new list of fields constructed during creation of task + /// to return the new schema. + /// \return the schema from the transformedRecordBatch. + virtual std::shared_ptr GetTransformedSchema() = 0; +}; + +class RecordBatchTransformerWithTasksBuilder { +private: + std::vector> new_fields_; + std::vector task_collection_; + std::shared_ptr schema_; + +public: + /// Based on the original array name and in a target array name it prepares + /// a task that will execute the transformation. + /// \param original_name The original name of the field. + /// \param transformed_name The name after the transformation. + RecordBatchTransformerWithTasksBuilder & + RenameField(const std::string &original_name, + const std::string &transformed_name); + + /// Add an empty field to the transformed record batch. + /// \param field_name The name of the empty fields. + /// \param data_type The target data type for the new fields. + RecordBatchTransformerWithTasksBuilder & + AddFieldOfNulls(const std::string &field_name, + const std::shared_ptr &data_type); + + /// It creates an object of RecordBatchTransformerWithTasksBuilder + /// \return a RecordBatchTransformerWithTasksBuilder object. + std::shared_ptr Build(); + + /// Instantiate a RecordBatchTransformerWithTasksBuilder object. + /// \param schema The schema from the original RecordBatch. + explicit RecordBatchTransformerWithTasksBuilder( + std::shared_ptr schema); +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer_test.cc new file mode 100644 index 0000000000000..b1150c0fdff9c --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer_test.cc @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include "arrow/testing/builder.h" +#include "record_batch_transformer.h" +#include "gtest/gtest.h" +#include +using namespace arrow; + +namespace { +std::shared_ptr CreateOriginalRecordBatch() { + std::vector values = {1, 2, 3, 4, 5}; + std::shared_ptr array; + + ArrayFromVector(values, &array); + + auto schema = arrow::schema({field("test", int32(), false)}); + + return RecordBatch::Make(schema, 4, {array}); +} +} // namespace + +namespace driver { +namespace flight_sql { + +TEST(Transformer, TransformerRenameTest) { + // Prepare the Original Record Batch + auto original_record_batch = CreateOriginalRecordBatch(); + auto schema = original_record_batch->schema(); + + // Execute the transformation of the Record Batch + std::string original_name("test"); + std::string transformed_name("test1"); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField(original_name, transformed_name) + .Build(); + + auto transformed_record_batch = transformer->Transform(original_record_batch); + + auto transformed_array_ptr = + transformed_record_batch->GetColumnByName(transformed_name); + + auto original_array_ptr = + original_record_batch->GetColumnByName(original_name); + + // Assert that the arrays are being the same and we are not creating new + // buffers + ASSERT_EQ(transformed_array_ptr, original_array_ptr); + + // Assert if the schema is not the same + ASSERT_NE(original_record_batch->schema(), + transformed_record_batch->schema()); + // Assert if the data is not changed + ASSERT_EQ(original_record_batch->GetColumnByName(original_name), + transformed_record_batch->GetColumnByName(transformed_name)); +} + +TEST(Transformer, TransformerAddEmptyVectorTest) { + // Prepare the Original Record Batch + auto original_record_batch = CreateOriginalRecordBatch(); + auto schema = original_record_batch->schema(); + + std::string original_name("test"); + std::string transformed_name("test1"); + auto emptyField = std::string("empty"); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField(original_name, transformed_name) + .AddFieldOfNulls(emptyField, int32()) + .Build(); + + auto transformed_schema = transformer->GetTransformedSchema(); + + ASSERT_EQ(transformed_schema->num_fields(), 2); + ASSERT_EQ(transformed_schema->GetFieldIndex(transformed_name), 0); + ASSERT_EQ(transformed_schema->GetFieldIndex(emptyField), 1); + + auto transformed_record_batch = transformer->Transform(original_record_batch); + + auto transformed_array_ptr = + transformed_record_batch->GetColumnByName(transformed_name); + + auto original_array_ptr = + original_record_batch->GetColumnByName(original_name); + + // Assert that the arrays are being the same and we are not creating new + // buffers + ASSERT_EQ(transformed_array_ptr, original_array_ptr); + + // Assert if the schema is not the same + ASSERT_NE(original_record_batch->schema(), + transformed_record_batch->schema()); + // Assert if the data is not changed + ASSERT_EQ(original_record_batch->GetColumnByName(original_name), + transformed_record_batch->GetColumnByName(transformed_name)); +} + +TEST(Transformer, TransformerChangingOrderOfArrayTest) { + std::vector first_array_value = {1, 2, 3, 4, 5}; + std::vector second_array_value = {6, 7, 8, 9, 10}; + std::vector third_array_value = {2, 4, 6, 8, 10}; + std::shared_ptr first_array; + std::shared_ptr second_array; + std::shared_ptr third_array; + + ArrayFromVector(first_array_value, &first_array); + ArrayFromVector(second_array_value, &second_array); + ArrayFromVector(third_array_value, &third_array); + + auto schema = arrow::schema({field("first_array", int32(), false), + field("second_array", int32(), false), + field("third_array", int32(), false)}); + + auto original_record_batch = + RecordBatch::Make(schema, 5, {first_array, second_array, third_array}); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField("third_array", "test3") + .RenameField("second_array", "test2") + .RenameField("first_array", "test1") + .AddFieldOfNulls("empty", int32()) + .Build(); + + const std::shared_ptr &transformed_record_batch = + transformer->Transform(original_record_batch); + + auto transformed_schema = transformed_record_batch->schema(); + + // Assert to check if the empty fields was added + ASSERT_EQ(transformed_record_batch->num_columns(), 4); + + // Assert to make sure that the elements changed his order. + ASSERT_EQ(transformed_schema->GetFieldIndex("test3"), 0); + ASSERT_EQ(transformed_schema->GetFieldIndex("test2"), 1); + ASSERT_EQ(transformed_schema->GetFieldIndex("test1"), 2); + ASSERT_EQ(transformed_schema->GetFieldIndex("empty"), 3); + + // Assert to make sure that the data didn't change after renaming the arrays + ASSERT_EQ(transformed_record_batch->GetColumnByName("test3"), third_array); + ASSERT_EQ(transformed_record_batch->GetColumnByName("test2"), second_array); + ASSERT_EQ(transformed_record_batch->GetColumnByName("test1"), first_array); +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.cc new file mode 100644 index 0000000000000..26e859e282a33 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.cc @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "scalar_function_reporter.h" + +#include + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +// The list of functions that can be converted from string to ODBC bitmasks is +// based on Calcite's SqlJdbcFunctionCall class. + +namespace { +static const std::unordered_map numeric_functions = { + {"ABS", SQL_FN_NUM_ABS}, {"ACOS", SQL_FN_NUM_ACOS}, + {"ASIN", SQL_FN_NUM_ASIN}, {"ATAN", SQL_FN_NUM_ATAN}, + {"ATAN2", SQL_FN_NUM_ATAN2}, {"CEILING", SQL_FN_NUM_CEILING}, + {"COS", SQL_FN_NUM_ACOS}, {"COT", SQL_FN_NUM_COT}, + {"DEGREES", SQL_FN_NUM_DEGREES}, {"EXP", SQL_FN_NUM_EXP}, + {"FLOOR", SQL_FN_NUM_FLOOR}, {"LOG", SQL_FN_NUM_LOG}, + {"LOG10", SQL_FN_NUM_LOG10}, {"MOD", SQL_FN_NUM_MOD}, + {"PI", SQL_FN_NUM_PI}, {"POWER", SQL_FN_NUM_POWER}, + {"RADIANS", SQL_FN_NUM_RADIANS}, {"RAND", SQL_FN_NUM_RAND}, + {"ROUND", SQL_FN_NUM_ROUND}, {"SIGN", SQL_FN_NUM_SIGN}, + {"SIN", SQL_FN_NUM_SIN}, {"SQRT", SQL_FN_NUM_SQRT}, + {"TAN", SQL_FN_NUM_TAN}, {"TRUNCATE", SQL_FN_NUM_TRUNCATE}}; + +static const std::unordered_map system_functions = { + {"DATABASE", SQL_FN_SYS_DBNAME}, + {"IFNULL", SQL_FN_SYS_IFNULL}, + {"USER", SQL_FN_SYS_USERNAME}}; + +static const std::unordered_map datetime_functions = { + {"CURDATE", SQL_FN_TD_CURDATE}, + {"CURTIME", SQL_FN_TD_CURTIME}, + {"DAYNAME", SQL_FN_TD_DAYNAME}, + {"DAYOFMONTH", SQL_FN_TD_DAYOFMONTH}, + {"DAYOFWEEK", SQL_FN_TD_DAYOFWEEK}, + {"DAYOFYEAR", SQL_FN_TD_DAYOFYEAR}, + {"HOUR", SQL_FN_TD_HOUR}, + {"MINUTE", SQL_FN_TD_MINUTE}, + {"MONTH", SQL_FN_TD_MONTH}, + {"MONTHNAME", SQL_FN_TD_MONTHNAME}, + {"NOW", SQL_FN_TD_NOW}, + {"QUARTER", SQL_FN_TD_QUARTER}, + {"SECOND", SQL_FN_TD_SECOND}, + {"TIMESTAMPADD", SQL_FN_TD_TIMESTAMPADD}, + {"TIMESTAMPDIFF", SQL_FN_TD_TIMESTAMPDIFF}, + {"WEEK", SQL_FN_TD_WEEK}, + {"YEAR", SQL_FN_TD_YEAR}, + // Additional functions in ODBC but not Calcite: + {"CURRENT_DATE", SQL_FN_TD_CURRENT_DATE}, + {"CURRENT_TIME", SQL_FN_TD_CURRENT_TIME}, + {"CURRENT_TIMESTAMP", SQL_FN_TD_CURRENT_TIMESTAMP}, + {"EXTRACT", SQL_FN_TD_EXTRACT}}; + +static const std::unordered_map string_functions = { + {"ASCII", SQL_FN_STR_ASCII}, + {"CHAR", SQL_FN_STR_CHAR}, + {"CONCAT", SQL_FN_STR_CONCAT}, + {"DIFFERENCE", SQL_FN_STR_DIFFERENCE}, + {"INSERT", SQL_FN_STR_INSERT}, + {"LCASE", SQL_FN_STR_LCASE}, + {"LEFT", SQL_FN_STR_LEFT}, + {"LENGTH", SQL_FN_STR_LENGTH}, + {"LOCATE", SQL_FN_STR_LOCATE}, + {"LTRIM", SQL_FN_STR_LTRIM}, + {"REPEAT", SQL_FN_STR_REPEAT}, + {"REPLACE", SQL_FN_STR_REPLACE}, + {"RIGHT", SQL_FN_STR_RIGHT}, + {"RTRIM", SQL_FN_STR_RTRIM}, + {"SOUNDEX", SQL_FN_STR_SOUNDEX}, + {"SPACE", SQL_FN_STR_SPACE}, + {"SUBSTRING", SQL_FN_STR_SUBSTRING}, + {"UCASE", SQL_FN_STR_UCASE}, + // Additional functions in ODBC but not Calcite: + {"LOCATE_2", SQL_FN_STR_LOCATE_2}, + {"BIT_LENGTH", SQL_FN_STR_BIT_LENGTH}, + {"CHAR_LENGTH", SQL_FN_STR_CHAR_LENGTH}, + {"CHARACTER_LENGTH", SQL_FN_STR_CHARACTER_LENGTH}, + {"OCTET_LENGTH", SQL_FN_STR_OCTET_LENGTH}, + {"POSTION", SQL_FN_STR_POSITION}, + {"SOUNDEX", SQL_FN_STR_SOUNDEX}}; +} // namespace + +void ReportSystemFunction(const std::string &function, + uint32_t ¤t_sys_functions, + uint32_t ¤t_convert_functions) { + const auto &result = system_functions.find(function); + if (result != system_functions.end()) { + current_sys_functions |= result->second; + } else if (function == "CONVERT") { + // CAST and CONVERT are system functions from FlightSql/Calcite, but are + // CONVERT functions in ODBC. Assume that if CONVERT is reported as a system + // function, then CAST and CONVERT are both supported. + current_convert_functions |= SQL_FN_CVT_CONVERT | SQL_FN_CVT_CAST; + } +} + +void ReportNumericFunction(const std::string &function, + uint32_t ¤t_functions) { + const auto &result = numeric_functions.find(function); + if (result != numeric_functions.end()) { + current_functions |= result->second; + } +} + +void ReportStringFunction(const std::string &function, + uint32_t ¤t_functions) { + const auto &result = string_functions.find(function); + if (result != string_functions.end()) { + current_functions |= result->second; + } +} + +void ReportDatetimeFunction(const std::string &function, + uint32_t ¤t_functions) { + const auto &result = datetime_functions.find(function); + if (result != datetime_functions.end()) { + current_functions |= result->second; + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h new file mode 100644 index 0000000000000..9e31536188f4e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +namespace driver { +namespace flight_sql { + +void ReportSystemFunction(const std::string &function, + uint32_t ¤t_sys_functions, + uint32_t ¤t_convert_functions); +void ReportNumericFunction(const std::string &function, + uint32_t ¤t_functions); +void ReportStringFunction(const std::string &function, + uint32_t ¤t_functions); +void ReportDatetimeFunction(const std::string &function, + uint32_t ¤t_functions); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc new file mode 100644 index 0000000000000..d27c6420fd871 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include "flight_sql_connection.h" +#include "config/configuration.h" +#include "config/connection_string_parser.h" +#include "ui/window.h" +#include "ui/dsn_configuration_window.h" +#include +#include +#include + +#include +#include +#include +#include + +using namespace std; +using namespace driver::flight_sql; +using namespace driver::flight_sql::config; + +BOOL CALLBACK ConfigDriver( + HWND hwndParent, + WORD fRequest, + LPCSTR lpszDriver, + LPCSTR lpszArgs, + LPSTR lpszMsg, + WORD cbMsgMax, + WORD* pcbMsgOut) { + return false; +} + +bool DisplayConnectionWindow(void* windowParent, Configuration& config) +{ + HWND hwndParent = (HWND)windowParent; + + if (!hwndParent) + return true; + + try + { + Window parent(hwndParent); + DsnConfigurationWindow window(&parent, config); + + window.Create(); + + window.Show(); + window.Update(); + + return ProcessMessages(window) == Result::OK; + } + catch (driver::odbcabstraction::DriverException& err) + { + std::stringstream buf; + buf << "Message: " << err.GetMessageText() << ", Code: " << err.GetNativeError(); + std::string message = buf.str(); + MessageBox(NULL, message.c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + + SQLPostInstallerError(err.GetNativeError(), err.GetMessageText().c_str()); + } + + return false; +} + +void PostLastInstallerError() { + + #define BUFFER_SIZE (1024) + DWORD code; + char msg[BUFFER_SIZE]; + SQLInstallerError(1, &code, msg, BUFFER_SIZE, NULL); + + std::stringstream buf; + buf << "Message: \"" << msg << "\", Code: " << code; + std::string errorMsg = buf.str(); + + MessageBox(NULL, errorMsg.c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + SQLPostInstallerError(code, errorMsg.c_str()); +} + +/** + * Unregister specified DSN. + * + * @param dsn DSN name. + * @return True on success and false on fail. + */ +bool UnregisterDsn(const std::string& dsn) +{ + if (SQLRemoveDSNFromIni(dsn.c_str())) { + return true; + } + + PostLastInstallerError(); + return false; +} + +/** + * Register DSN with specified configuration. + * + * @param config Configuration. + * @param driver Driver. + * @return True on success and false on fail. + */ +bool RegisterDsn(const Configuration& config, LPCSTR driver) +{ + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + + if (!SQLWriteDSNToIni(dsn.c_str(), driver)) + { + PostLastInstallerError(); + return false; + } + + const auto& map = config.GetProperties(); + for (auto it = map.begin(); it != map.end(); ++it) + { + const std::string& key = it->first; + if (boost::iequals(FlightSqlConnection::DSN, key) || boost::iequals(FlightSqlConnection::DRIVER, key)) { + continue; + } + + if (!SQLWritePrivateProfileString(dsn.c_str(), key.c_str(), it->second.c_str(), "ODBC.INI")) { + PostLastInstallerError(); + return false; + } + } + + return true; +} + +BOOL INSTAPI ConfigDSN(HWND hwndParent, WORD req, LPCSTR driver, LPCSTR attributes) +{ + Configuration config; + ConnectionStringParser parser(config); + parser.ParseConfigAttributes(attributes); + + switch (req) + { + case ODBC_ADD_DSN: + { + config.LoadDefaults(); + if (!DisplayConnectionWindow(hwndParent, config) || !RegisterDsn(config, driver)) + return FALSE; + + break; + } + + case ODBC_CONFIG_DSN: + { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + if (!SQLValidDSN(dsn.c_str())) + return FALSE; + + Configuration loaded(config); + loaded.LoadDsn(dsn); + + if (!DisplayConnectionWindow(hwndParent, loaded) || !UnregisterDsn(dsn.c_str()) || !RegisterDsn(loaded, driver)) + return FALSE; + + break; + } + + case ODBC_REMOVE_DSN: + { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + if (!SQLValidDSN(dsn.c_str()) || !UnregisterDsn(dsn)) + return FALSE; + + break; + } + + default: + return FALSE; + } + + return TRUE; +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc new file mode 100644 index 0000000000000..8387a5897336c --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "system_trust_store.h" + + +#if defined _WIN32 || defined _WIN64 + +namespace driver { +namespace flight_sql { + bool SystemTrustStore::HasNext() { + p_context_ = CertEnumCertificatesInStore(h_store_, p_context_); + + return p_context_ != nullptr; + } + + std::string SystemTrustStore::GetNext() const { + DWORD size = 0; + CryptBinaryToString(p_context_->pbCertEncoded, p_context_->cbCertEncoded, + CRYPT_STRING_BASE64HEADER, nullptr, &size); + + std::string cert; + cert.resize(size); + CryptBinaryToString(p_context_->pbCertEncoded, + p_context_->cbCertEncoded, CRYPT_STRING_BASE64HEADER, + &cert[0], &size); + cert.resize(size); + + return cert; + } + + bool SystemTrustStore::SystemHasStore() { + return h_store_ != nullptr; + } + + SystemTrustStore::SystemTrustStore(const char* store) : stores_(store), + h_store_(CertOpenSystemStore(NULL, store)), p_context_(nullptr) {} + + SystemTrustStore::~SystemTrustStore() { + if (p_context_) { + CertFreeCertificateContext(p_context_); + } + if (h_store_) { + CertCloseStore(h_store_, 0); + } + } +} // namespace flight_sql +} // namespace driver + +#endif diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h new file mode 100644 index 0000000000000..8504806ee361e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if defined _WIN32 || defined _WIN64 + +#include +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +/// Load the certificates from the windows system trust store. Part of the logic +/// was based in the drill connector +/// https://github.com/apache/drill/blob/master/contrib/native/client/src/clientlib/wincert.ipp. +class SystemTrustStore { +private: + const char* stores_; + HCERTSTORE h_store_; + PCCERT_CONTEXT p_context_; + +public: + explicit SystemTrustStore(const char* store); + + ~SystemTrustStore(); + + /// Check if there is a certificate inside the system trust store to be extracted + /// \return If there is a valid cert in the store. + bool HasNext(); + + /// Get the next certificate from the store. + /// \return the certificate. + std::string GetNext() const; + + /// Check if the system has the specify store. + /// \return If the specific store exist in the system. + bool SystemHasStore(); +}; +} // namespace flight_sql +} // namespace driver + +#else // Not Windows +namespace driver { +namespace flight_sql { +class SystemTrustStore; +} // namespace flight_sql +} // namespace driver + +#endif diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc new file mode 100644 index 0000000000000..7bc2555dbddcb --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "ui/add_property_window.h" + +#include +#include +#include +#include + +#include "ui/custom_window.h" +#include "ui/window.h" +#include + +namespace driver { +namespace flight_sql { +namespace config { + +AddPropertyWindow::AddPropertyWindow(Window* parent) : + CustomWindow(parent, "AddProperty", "Add Property"), + width(300), + height(120), + accepted(false), + isInitialized(false) +{ + // No-op. +} + +AddPropertyWindow::~AddPropertyWindow() +{ + // No-op. +} + +void AddPropertyWindow::Create() +{ + // Finding out parent position. + RECT parentRect; + GetWindowRect(parent->GetHandle(), &parentRect); + + // Positioning window to the center of parent window. + const int posX = parentRect.left + (parentRect.right - parentRect.left - width) / 2; + const int posY = parentRect.top + (parentRect.bottom - parentRect.top - height) / 2; + + RECT desiredRect = { posX, posY, posX + width, posY + height }; + AdjustWindowRect(&desiredRect, WS_BORDER | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME, FALSE); + + Window::Create(WS_OVERLAPPED | WS_SYSMENU, desiredRect.left, desiredRect.top, + desiredRect.right - desiredRect.left, desiredRect.bottom - desiredRect.top, 0); + + if (!handle) + { + std::stringstream buf; + buf << "Can not create window, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +bool AddPropertyWindow::GetProperty(std::string& key, std::string& value) +{ + if (accepted) + { + key = this->key; + value = this->value; + return true; + } + return false; +} + +void AddPropertyWindow::OnCreate() +{ + int groupPosY = MARGIN; + int groupSizeY = width - 2 * MARGIN; + + groupPosY += INTERVAL + CreateEdits(MARGIN, groupPosY, groupSizeY); + + int cancelPosX = width - MARGIN - BUTTON_WIDTH; + int okPosX = cancelPosX - INTERVAL - BUTTON_WIDTH; + + okButton = CreateButton(okPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, "Ok", ChildId::OK_BUTTON, BS_DEFPUSHBUTTON); + cancelButton = CreateButton(cancelPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, + "Cancel", ChildId::CANCEL_BUTTON); + isInitialized = true; + CheckEnableOk(); +} + +int AddPropertyWindow::CreateEdits(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 30 }; + + const int editSizeX = sizeX - LABEL_WIDTH - INTERVAL; + const int editPosX = posX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY; + + labels.push_back(CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Key:", ChildId::KEY_LABEL)); + keyEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, "", ChildId::KEY_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + labels.push_back(CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Value:", ChildId::VALUE_LABEL)); + valueEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, "", ChildId::VALUE_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + return rowPos - posY; +} + +void AddPropertyWindow::CheckEnableOk() { + if (!isInitialized) { + return; + } + + okButton->SetEnabled(!keyEdit->IsTextEmpty() && !valueEdit->IsTextEmpty()); +} + +bool AddPropertyWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) +{ + switch (msg) + { + case WM_COMMAND: + { + switch (LOWORD(wParam)) + { + case ChildId::OK_BUTTON: + { + keyEdit->GetText(key); + valueEdit->GetText(value); + accepted = true; + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + + break; + } + + case IDCANCEL: + case ChildId::CANCEL_BUTTON: + { + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + break; + } + + case ChildId::KEY_EDIT: + case ChildId::VALUE_EDIT: + { + if (HIWORD(wParam) == EN_CHANGE) + { + CheckEnableOk(); + } + break; + } + + default: + return false; + } + + break; + } + + case WM_DESTROY: + { + PostQuitMessage(accepted ? Result::OK : Result::CANCEL); + + break; + } + + default: + return false; + } + + return true; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc new file mode 100644 index 0000000000000..10c28cea67ae7 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include "ui/custom_window.h" +#include + +namespace driver { +namespace flight_sql { +namespace config { + +Result::Type ProcessMessages(Window& window) +{ + MSG msg; + + while (GetMessage(&msg, NULL, 0, 0) > 0) + { + if (!IsDialogMessage(window.GetHandle(), &msg)) + { + TranslateMessage(&msg); + + DispatchMessage(&msg); + } + } + + return static_cast(msg.wParam); +} + +LRESULT CALLBACK CustomWindow::WndProc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam) +{ + CustomWindow* window = reinterpret_cast(GetWindowLongPtr(hwnd, GWLP_USERDATA)); + + switch (msg) + { + case WM_NCCREATE: + { + _ASSERT(lParam != NULL); + + CREATESTRUCT* createStruct = reinterpret_cast(lParam); + + LONG_PTR longSelfPtr = reinterpret_cast(createStruct->lpCreateParams); + + SetWindowLongPtr(hwnd, GWLP_USERDATA, longSelfPtr); + + return DefWindowProc(hwnd, msg, wParam, lParam); + } + + case WM_CREATE: + { + _ASSERT(window != NULL); + + window->SetHandle(hwnd); + + window->OnCreate(); + + return 0; + } + + default: + break; + } + + if (window && window->OnMessage(msg, wParam, lParam)) + return 0; + + return DefWindowProc(hwnd, msg, wParam, lParam); +} + +CustomWindow::CustomWindow(Window* parent, const char* className, const char* title) : + Window(parent, className, title) +{ + WNDCLASS wcx; + + wcx.style = CS_HREDRAW | CS_VREDRAW; + wcx.lpfnWndProc = WndProc; + wcx.cbClsExtra = 0; + wcx.cbWndExtra = 0; + wcx.hInstance = GetHInstance(); + wcx.hIcon = NULL; + wcx.hCursor = LoadCursor(NULL, IDC_ARROW); + wcx.hbrBackground = (HBRUSH)COLOR_WINDOW; + wcx.lpszMenuName = NULL; + wcx.lpszClassName = className; + + if (!RegisterClass(&wcx)) + { + std::stringstream buf; + buf << "Can not register window class, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +CustomWindow::~CustomWindow() +{ + UnregisterClass(className.c_str(), GetHInstance()); +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc new file mode 100644 index 0000000000000..719fe4f7a9ba9 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc @@ -0,0 +1,616 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "ui/dsn_configuration_window.h" + +#include "flight_sql_connection.h" +#include +#include +#include +#include +#include +#include +#include + +#include "ui/add_property_window.h" + +#define COMMON_TAB 0 +#define ADVANCED_TAB 1 + +namespace { + std::string TestConnection(const driver::flight_sql::config::Configuration& config) { + std::unique_ptr flightSqlConn( + new driver::flight_sql::FlightSqlConnection(driver::odbcabstraction::V_3)); + + std::vector missingProperties; + flightSqlConn->Connect(config.GetProperties(), missingProperties); + + // This should have been checked before enabling the Test button. + assert(missingProperties.empty()); + std::string serverName = boost::get(flightSqlConn->GetInfo(SQL_SERVER_NAME)); + std::string serverVersion = boost::get(flightSqlConn->GetInfo(SQL_DBMS_VER)); + return "Server Name: " + serverName + "\n" + + "Server Version: " + serverVersion; + } +} + +namespace driver { +namespace flight_sql { +namespace config { + +DsnConfigurationWindow::DsnConfigurationWindow(Window* parent, config::Configuration& config) : + CustomWindow(parent, "FlightConfigureDSN", "Configure Apache Arrow Flight SQL"), + width(480), + height(375), + config(config), + accepted(false), + isInitialized(false) +{ + // No-op. +} + +DsnConfigurationWindow::~DsnConfigurationWindow() +{ + // No-op. +} + +void DsnConfigurationWindow::Create() +{ + // Finding out parent position. + RECT parentRect; + GetWindowRect(parent->GetHandle(), &parentRect); + + // Positioning window to the center of parent window. + const int posX = parentRect.left + (parentRect.right - parentRect.left - width) / 2; + const int posY = parentRect.top + (parentRect.bottom - parentRect.top - height) / 2; + + RECT desiredRect = { posX, posY, posX + width, posY + height }; + AdjustWindowRect(&desiredRect, WS_BORDER | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME, FALSE); + + Window::Create(WS_OVERLAPPED | WS_SYSMENU, desiredRect.left, desiredRect.top, + desiredRect.right - desiredRect.left, desiredRect.bottom - desiredRect.top, 0); + + if (!handle) + { + std::stringstream buf; + buf << "Can not create window, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +void DsnConfigurationWindow::OnCreate() +{ + tabControl = CreateTabControl(ChildId::TAB_CONTROL); + tabControl->AddTab("Common", COMMON_TAB); + tabControl->AddTab("Advanced", ADVANCED_TAB); + + int groupPosY = 3 * MARGIN; + int groupSizeY = width - 2 * MARGIN; + + int commonGroupPosY = groupPosY; + commonGroupPosY += INTERVAL + CreateConnectionSettingsGroup(MARGIN, commonGroupPosY, groupSizeY); + commonGroupPosY += INTERVAL + CreateAuthSettingsGroup(MARGIN, commonGroupPosY, groupSizeY); + + int advancedGroupPosY = groupPosY; + advancedGroupPosY += INTERVAL + CreateEncryptionSettingsGroup(MARGIN, advancedGroupPosY, groupSizeY); + advancedGroupPosY += INTERVAL + CreatePropertiesGroup(MARGIN, advancedGroupPosY, groupSizeY); + + int testPosX = MARGIN; + int cancelPosX = width - MARGIN - BUTTON_WIDTH; + int okPosX = cancelPosX - INTERVAL - BUTTON_WIDTH; + + int buttonPosY = std::max(commonGroupPosY, advancedGroupPosY); + testButton = CreateButton(testPosX, buttonPosY, BUTTON_WIDTH + 20, BUTTON_HEIGHT, "Test Connection", ChildId::TEST_CONNECTION_BUTTON); + okButton = CreateButton(okPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, "Ok", ChildId::OK_BUTTON); + cancelButton = CreateButton(cancelPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, + "Cancel", ChildId::CANCEL_BUTTON); + isInitialized = true; + CheckEnableOk(); + SelectTab(COMMON_TAB); +} + +int DsnConfigurationWindow::CreateConnectionSettingsGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 100 }; + + const int labelPosX = posX + INTERVAL; + + const int editSizeX = sizeX - LABEL_WIDTH - 3 * INTERVAL; + const int editPosX = labelPosX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY + 2 * INTERVAL; + + const char* val = config.Get(FlightSqlConnection::DSN).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Data Source Name:", ChildId::NAME_LABEL)); + nameEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::NAME_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::HOST).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Host Name:", ChildId::SERVER_LABEL)); + serverEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::SERVER_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::PORT).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Port:", ChildId::PORT_LABEL)); + portEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::PORT_EDIT, ES_NUMBER); + + rowPos += INTERVAL + ROW_HEIGHT; + + connectionSettingsGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Connection settings", ChildId::CONNECTION_SETTINGS_GROUP_BOX); + + return rowPos - posY; +} + +int DsnConfigurationWindow::CreateAuthSettingsGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 120 }; + + const int labelPosX = posX + INTERVAL; + + const int editSizeX = sizeX - LABEL_WIDTH - 3 * INTERVAL; + const int editPosX = labelPosX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY + 2 * INTERVAL; + + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Authentication Type:", ChildId::AUTH_TYPE_LABEL)); + authTypeComboBox = CreateComboBox(editPosX, rowPos, editSizeX, ROW_HEIGHT, + "Authentication Type:", ChildId::AUTH_TYPE_COMBOBOX); + authTypeComboBox->AddString("Basic Authentication"); + authTypeComboBox->AddString("Token Authentication"); + + rowPos += INTERVAL + ROW_HEIGHT; + + const char* val = config.Get(FlightSqlConnection::UID).c_str(); + + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "User:", ChildId::USER_LABEL)); + userEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::USER_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::PWD).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Password:", ChildId::PASSWORD_LABEL)); + passwordEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, + val, ChildId::USER_EDIT, ES_PASSWORD); + + rowPos += INTERVAL + ROW_HEIGHT; + + const auto& token = config.Get(FlightSqlConnection::TOKEN); + val = token.c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Authentication Token:", ChildId::AUTH_TOKEN_LABEL)); + authTokenEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, + val, ChildId::AUTH_TOKEN_EDIT); + authTokenEdit->SetEnabled(false); + + // Ensure the right elements are selected. + authTypeComboBox->SetSelection(token.empty() ? 0 : 1); + CheckAuthType(); + + rowPos += INTERVAL + ROW_HEIGHT; + + authSettingsGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Authentication settings", ChildId::AUTH_SETTINGS_GROUP_BOX); + + return rowPos - posY; +} + +int DsnConfigurationWindow::CreateEncryptionSettingsGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 120 }; + + const int labelPosX = posX + INTERVAL; + + const int editSizeX = sizeX - LABEL_WIDTH - 3 * INTERVAL; + const int editPosX = labelPosX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY + 2 * INTERVAL; + + const char* val = config.Get(FlightSqlConnection::USE_ENCRYPTION).c_str(); + + const bool enableEncryption = driver::odbcabstraction::AsBool(val).value_or(true); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Use Encryption:", + ChildId::ENABLE_ENCRYPTION_LABEL)); + enableEncryptionCheckBox = CreateCheckBox(editPosX, rowPos - 2, editSizeX, ROW_HEIGHT, "", + ChildId::ENABLE_ENCRYPTION_CHECKBOX, enableEncryption); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::TRUSTED_CERTS).c_str(); + + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Certificate:", ChildId::CERTIFICATE_LABEL)); + certificateEdit = CreateEdit(editPosX, rowPos, editSizeX - MARGIN - BUTTON_WIDTH, ROW_HEIGHT, val, ChildId::CERTIFICATE_EDIT); + certificateBrowseButton = CreateButton(editPosX + editSizeX - BUTTON_WIDTH, rowPos - 2, BUTTON_WIDTH, BUTTON_HEIGHT, + "Browse", ChildId::CERTIFICATE_BROWSE_BUTTON); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::USE_SYSTEM_TRUST_STORE).c_str(); + + const bool useSystemCertStore = driver::odbcabstraction::AsBool(val).value_or(true); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, "Use System Certificate Store:", + ChildId::USE_SYSTEM_CERT_STORE_LABEL)); + useSystemCertStoreCheckBox = CreateCheckBox(editPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, "", + ChildId::USE_SYSTEM_CERT_STORE_CHECKBOX, useSystemCertStore); + + + val = config.Get(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION).c_str(); + + const int rightPosX = labelPosX + (sizeX - (2 * INTERVAL)) / 2; + const int rightCheckPosX = rightPosX + (editPosX - labelPosX); + const bool disableCertVerification = driver::odbcabstraction::AsBool(val).value_or(false); + labels.push_back(CreateLabel(rightPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, "Disable Certificate Verification:", + ChildId::DISABLE_CERT_VERIFICATION_LABEL)); + disableCertVerificationCheckBox = CreateCheckBox(rightCheckPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, "", + ChildId::DISABLE_CERT_VERIFICATION_CHECKBOX, disableCertVerification); + + rowPos += INTERVAL + static_cast(1.5 * ROW_HEIGHT); + + encryptionSettingsGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Encryption settings", ChildId::AUTH_SETTINGS_GROUP_BOX); + + return rowPos - posY; +} + +int DsnConfigurationWindow::CreatePropertiesGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 120 }; + + const int labelPosX = posX + INTERVAL; + const int listSize = sizeX - 2 * INTERVAL; + const int columnSize = listSize / 2; + + int rowPos = posY + 2 * INTERVAL; + const int listHeight = 5 * ROW_HEIGHT; + + propertyList = CreateList(labelPosX, rowPos, listSize, listHeight, ChildId::PROPERTY_LIST); + propertyList->ListAddColumn("Key", 0, columnSize); + propertyList->ListAddColumn("Value", 1, columnSize); + + const auto keys = config.GetCustomKeys(); + for (const auto& key : keys) { + propertyList->ListAddItem({ key, config.Get(key) }); + } + + SendMessage(propertyList->GetHandle(), LVM_SETEXTENDEDLISTVIEWSTYLE, LVS_EX_FULLROWSELECT, LVS_EX_FULLROWSELECT); + + rowPos += INTERVAL + listHeight; + + int deletePosX = width - INTERVAL - MARGIN - BUTTON_WIDTH; + int addPosX = deletePosX - INTERVAL - BUTTON_WIDTH; + addButton = CreateButton(addPosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, "Add", ChildId::ADD_BUTTON); + deleteButton = CreateButton(deletePosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, + "Delete", ChildId::DELETE_BUTTON); + + rowPos += INTERVAL + BUTTON_HEIGHT; + + propertyGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Advanced properties", ChildId::PROPERTY_GROUP_BOX); + + return rowPos - posY; +} + +void DsnConfigurationWindow::SelectTab(int tabIndex) { + if (!isInitialized) { + return; + } + + connectionSettingsGroupBox->SetVisible(COMMON_TAB == tabIndex); + authSettingsGroupBox->SetVisible(COMMON_TAB == tabIndex); + nameEdit->SetVisible(COMMON_TAB == tabIndex); + serverEdit->SetVisible(COMMON_TAB == tabIndex); + portEdit->SetVisible(COMMON_TAB == tabIndex); + authTypeComboBox->SetVisible(COMMON_TAB == tabIndex); + userEdit->SetVisible(COMMON_TAB == tabIndex); + passwordEdit->SetVisible(COMMON_TAB == tabIndex); + authTokenEdit->SetVisible(COMMON_TAB == tabIndex); + for (size_t i = 0; i < 7; ++i) { + labels[i]->SetVisible(COMMON_TAB == tabIndex); + } + + encryptionSettingsGroupBox->SetVisible(ADVANCED_TAB == tabIndex); + enableEncryptionCheckBox->SetVisible(ADVANCED_TAB == tabIndex); + certificateEdit->SetVisible(ADVANCED_TAB == tabIndex); + certificateBrowseButton->SetVisible(ADVANCED_TAB == tabIndex); + useSystemCertStoreCheckBox->SetVisible(ADVANCED_TAB == tabIndex); + disableCertVerificationCheckBox->SetVisible(ADVANCED_TAB == tabIndex); + propertyGroupBox->SetVisible(ADVANCED_TAB == tabIndex); + propertyList->SetVisible(ADVANCED_TAB == tabIndex); + addButton->SetVisible(ADVANCED_TAB == tabIndex); + deleteButton->SetVisible(ADVANCED_TAB == tabIndex); + for (size_t i = 7; i < labels.size(); ++i) { + labels[i]->SetVisible(ADVANCED_TAB == tabIndex); + } +} + +void DsnConfigurationWindow::CheckEnableOk() { + if (!isInitialized) { + return; + } + + bool enableOk = !nameEdit->IsTextEmpty(); + enableOk = enableOk && !serverEdit->IsTextEmpty(); + enableOk = enableOk && !portEdit->IsTextEmpty(); + if (authTokenEdit->IsEnabled()) + { + enableOk = enableOk && !authTokenEdit->IsTextEmpty(); + } + else + { + enableOk = enableOk && !userEdit->IsTextEmpty(); + enableOk = enableOk && !passwordEdit->IsTextEmpty(); + } + + testButton->SetEnabled(enableOk); + okButton->SetEnabled(enableOk); +} + +void DsnConfigurationWindow::SaveParameters(Configuration& targetConfig) +{ + targetConfig.Clear(); + + std::string text; + nameEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::DSN, text); + serverEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::HOST, text); + portEdit->GetText(text); + try { + const int portInt = std::stoi(text); + if (0 > portInt || USHRT_MAX < portInt) + { + throw odbcabstraction::DriverException("Invalid port value."); + } + targetConfig.Set(FlightSqlConnection::PORT, text); + } + catch (odbcabstraction::DriverException&) { + throw; + } + catch (std::exception&) { + throw odbcabstraction::DriverException("Invalid port value."); + } + + if (0 == authTypeComboBox->GetSelection()) + { + userEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::UID, text); + passwordEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::PWD, text); + } + else + { + authTokenEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::TOKEN, text); + } + + if (enableEncryptionCheckBox->IsChecked()) + { + targetConfig.Set(FlightSqlConnection::USE_ENCRYPTION, TRUE_STR); + certificateEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::TRUSTED_CERTS, text); + targetConfig.Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, useSystemCertStoreCheckBox->IsChecked() ? TRUE_STR : FALSE_STR); + targetConfig.Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, disableCertVerificationCheckBox->IsChecked() ? TRUE_STR : FALSE_STR); + } + else + { + targetConfig.Set(FlightSqlConnection::USE_ENCRYPTION, FALSE_STR); + } + + // Get all the list properties. + const auto properties = propertyList->ListGetAll(); + for (const auto& property : properties) { + targetConfig.Set(property[0], property[1]); + } +} + +void DsnConfigurationWindow::CheckAuthType() { + const bool isBasic = COMMON_TAB == authTypeComboBox->GetSelection(); + userEdit->SetEnabled(isBasic); + passwordEdit->SetEnabled(isBasic); + authTokenEdit->SetEnabled(!isBasic); +} + +bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) +{ + switch (msg) + { + case WM_NOTIFY: + { + switch (((LPNMHDR)lParam)->code) + { + case TCN_SELCHANGING: + { + // Return FALSE to allow the selection to change. + return FALSE; + } + + case TCN_SELCHANGE: + { + SelectTab(TabCtrl_GetCurSel(tabControl->GetHandle())); + break; + } + } + break; + } + + case WM_COMMAND: + { + switch (LOWORD(wParam)) + { + case ChildId::TEST_CONNECTION_BUTTON: + { + try + { + Configuration testConfig; + SaveParameters(testConfig); + std::string testMessage = TestConnection(testConfig); + + MessageBox(NULL, testMessage.c_str(), "Test Connection Success", MB_OK); + } + catch (odbcabstraction::DriverException& err) + { + MessageBox(NULL, err.GetMessageText().c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + } + + break; + } + case ChildId::OK_BUTTON: + { + try + { + SaveParameters(config); + accepted = true; + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + } + catch (odbcabstraction::DriverException& err) + { + MessageBox(NULL, err.GetMessageText().c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + } + + break; + } + + case IDCANCEL: + case ChildId::CANCEL_BUTTON: + { + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + break; + } + + case ChildId::AUTH_TOKEN_EDIT: + case ChildId::NAME_EDIT: + case ChildId::PASSWORD_EDIT: + case ChildId::PORT_EDIT: + case ChildId::SERVER_EDIT: + case ChildId::USER_EDIT: + { + if (HIWORD(wParam) == EN_CHANGE) + { + CheckEnableOk(); + } + break; + } + + case ChildId::AUTH_TYPE_COMBOBOX: + { + CheckAuthType(); + CheckEnableOk(); + break; + } + + case ChildId::ENABLE_ENCRYPTION_CHECKBOX: + { + const bool toggle = !enableEncryptionCheckBox->IsChecked(); + enableEncryptionCheckBox->SetChecked(toggle); + certificateEdit->SetEnabled(toggle); + certificateBrowseButton->SetEnabled(toggle); + useSystemCertStoreCheckBox->SetEnabled(toggle); + disableCertVerificationCheckBox->SetEnabled(toggle); + break; + } + + case ChildId::CERTIFICATE_BROWSE_BUTTON: + { + OPENFILENAME openFileName; + char fileName[FILENAME_MAX]; + + ZeroMemory(&openFileName, sizeof(openFileName)); + openFileName.lStructSize = sizeof(openFileName); + openFileName.hwndOwner = NULL; + openFileName.lpstrFile = fileName; + openFileName.lpstrFile[0] = '\0'; + openFileName.nMaxFile = FILENAME_MAX; + // TODO: What type should this be? + openFileName.lpstrFilter = "All\0*.*"; + openFileName.nFilterIndex = 1; + openFileName.lpstrFileTitle = NULL; + openFileName.nMaxFileTitle = 0; + openFileName.lpstrInitialDir = NULL; + openFileName.Flags = OFN_PATHMUSTEXIST | OFN_FILEMUSTEXIST; + + if (GetOpenFileName(&openFileName)) { + certificateEdit->SetText(fileName); + } + break; + } + + case ChildId::USE_SYSTEM_CERT_STORE_CHECKBOX: + { + useSystemCertStoreCheckBox->SetChecked(!useSystemCertStoreCheckBox->IsChecked()); + break; + } + + case ChildId::DISABLE_CERT_VERIFICATION_CHECKBOX: + { + disableCertVerificationCheckBox->SetChecked(!disableCertVerificationCheckBox->IsChecked()); + break; + } + + case ChildId::DELETE_BUTTON: + { + propertyList->ListDeleteSelectedItem(); + break; + } + + case ChildId::ADD_BUTTON: + { + AddPropertyWindow addWindow(this); + addWindow.Create(); + addWindow.Show(); + addWindow.Update(); + + if (ProcessMessages(addWindow) == Result::OK) + { + std::string key; + std::string value; + addWindow.GetProperty(key, value); + propertyList->ListAddItem({ key, value }); + } + break; + } + + default: + return false; + } + + break; + } + + case WM_DESTROY: + { + PostQuitMessage(accepted ? Result::OK : Result::CANCEL); + + break; + } + + default: + return false; + } + + return true; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc new file mode 100644 index 0000000000000..9e48b0fdc6b57 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include "winuser.h" +#include +#include +#include + +#include "ui/window.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +HINSTANCE GetHInstance() +{ + TCHAR szFileName[MAX_PATH]; + GetModuleFileName(NULL, szFileName, MAX_PATH); + + // TODO: This needs to be the module name. + HINSTANCE hInstance = GetModuleHandle(szFileName); + + if (hInstance == NULL) + { + std::stringstream buf; + buf << "Can not get hInstance for the module, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + return hInstance; +} + +Window::Window(Window* parent, const char* className, const char* title) : + className(className), + title(title), + handle(NULL), + parent(parent), + created(false) +{ + // No-op. +} + +Window::Window(HWND handle) : + className(), + title(), + handle(handle), + parent(0), + created(false) +{ + // No-op. +} + +Window::~Window() +{ + if (created) + Destroy(); +} + +void Window::Create(DWORD style, int posX, int posY, int width, int height, int id) +{ + if (handle) + { + std::stringstream buf; + buf << "Window already created, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + handle = CreateWindow( + className.c_str(), + title.c_str(), + style, + posX, + posY, + width, + height, + parent ? parent->GetHandle() : NULL, + reinterpret_cast(static_cast(id)), + GetHInstance(), + this + ); + + if (!handle) + { + std::stringstream buf; + buf << "Can not create window, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + created = true; + + const HGDIOBJ hfDefault = GetStockObject(DEFAULT_GUI_FONT); + SendMessage(GetHandle(), WM_SETFONT, (WPARAM)hfDefault, MAKELPARAM(FALSE, 0)); +} + + +std::unique_ptr Window::CreateTabControl(int id) +{ + std::unique_ptr child(new Window(this, WC_TABCONTROL, "")); + + // Get the dimensions of the parent window's client area, and + // create a tab control child window of that size. + RECT rcClient; + GetClientRect(handle, &rcClient); + + child->Create(WS_CHILD | WS_CLIPSIBLINGS | WS_VISIBLE | WS_TABSTOP, 0, 0, rcClient.right, 20, id); + + return child; +} + +std::unique_ptr Window::CreateList(int posX, int posY, + int sizeX, int sizeY, int id) +{ + std::unique_ptr child(new Window(this, WC_LISTVIEW, "")); + + child->Create(WS_CHILD | WS_VISIBLE | WS_BORDER | LVS_REPORT | LVS_EDITLABELS | WS_TABSTOP, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateGroupBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id) +{ + std::unique_ptr child(new Window(this, "Button", title)); + + child->Create(WS_CHILD | WS_VISIBLE | BS_GROUPBOX, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateLabel(int posX, int posY, + int sizeX, int sizeY, const char* title, int id) +{ + std::unique_ptr child(new Window(this, "Static", title)); + + child->Create(WS_CHILD | WS_VISIBLE, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateEdit(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style) +{ + std::unique_ptr child(new Window(this, "Edit", title)); + + child->Create(WS_CHILD | WS_VISIBLE | WS_BORDER | ES_AUTOHSCROLL | WS_TABSTOP | style, + posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateButton(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style) +{ + std::unique_ptr child(new Window(this, "Button", title)); + + child->Create(WS_CHILD | WS_VISIBLE | WS_TABSTOP | style, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateCheckBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, bool state) +{ + std::unique_ptr child(new Window(this, "Button", title)); + + child->Create(WS_CHILD | WS_VISIBLE | BS_CHECKBOX | WS_TABSTOP, posX, posY, sizeX, sizeY, id); + + child->SetChecked(state); + + return child; +} + +std::unique_ptr Window::CreateComboBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id) +{ + std::unique_ptr child(new Window(this, "Combobox", title)); + + child->Create(WS_CHILD | WS_VISIBLE | CBS_DROPDOWNLIST | WS_TABSTOP, posX, posY, sizeX, sizeY, id); + + return child; +} + +void Window::Show() +{ + ShowWindow(handle, SW_SHOW); +} + +void Window::Update() +{ + UpdateWindow(handle); +} + +void Window::Destroy() +{ + if (handle) + DestroyWindow(handle); + + handle = NULL; +} + +void Window::SetVisible(bool isVisible) { + ShowWindow(handle, isVisible ? SW_SHOW : SW_HIDE); +} + +bool Window::IsTextEmpty() const +{ + if (!IsEnabled()) + { + return true; + } + int len = GetWindowTextLength(handle); + + return (len <= 0); +} + +void Window::ListAddColumn(const std::string& name, int index, int width) +{ + LVCOLUMN lvc; + lvc.mask = LVCF_FMT | LVCF_WIDTH | LVCF_TEXT | LVCF_SUBITEM; + lvc.fmt = LVCFMT_LEFT; + lvc.cx = width; + lvc.pszText = const_cast(name.c_str()); + lvc.iSubItem = index; + + if (ListView_InsertColumn(handle, index, &lvc) == -1) + { + std::stringstream buf; + buf << "Can not add list column, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +void Window::ListAddItem(const std::vector& items) +{ + LVITEM lvi = { 0 }; + lvi.mask = LVIF_TEXT; + lvi.pszText = const_cast(items[0].c_str()); + + int ret = ListView_InsertItem(handle, &lvi); + if (ret < 0) { + std::stringstream buf; + buf << "Can not add list item, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + for (size_t i = 1; i < items.size(); ++i) { + ListView_SetItemText(handle, ret, static_cast(i), const_cast(items[i].c_str())); + } +} + +void Window::ListDeleteSelectedItem() +{ + const int rowIndex = ListView_GetSelectionMark(handle); + if (rowIndex >= 0) { + if (ListView_DeleteItem(handle, rowIndex) == -1) { + std::stringstream buf; + buf << "Can not delete list item, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + } +} + +std::vector > Window::ListGetAll() +{ + #define BUF_LEN 1024 + char buf[BUF_LEN]; + + std::vector > values; + const int numColumns = Header_GetItemCount(ListView_GetHeader(handle)); + const int numItems = ListView_GetItemCount(handle); + for (int i = 0; i < numItems; ++i) { + std::vector row; + for (int j = 0; j < numColumns; ++j) { + ListView_GetItemText(handle, i, j, buf, BUF_LEN); + row.emplace_back(buf); + } + values.push_back(row); + } + + return values; +} + +void Window::AddTab(const std::string& name, int index) +{ + TCITEM tabControlItem; + tabControlItem.mask = TCIF_TEXT | TCIF_IMAGE; + tabControlItem.iImage = -1; + tabControlItem.pszText = const_cast(name.c_str()); + if (TabCtrl_InsertItem(handle, index, &tabControlItem) == -1) + { + std::stringstream buf; + buf << "Can not add tab, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +void Window::GetText(std::string& text) const +{ + if (!IsEnabled()) + { + text.clear(); + + return; + } + + int len = GetWindowTextLength(handle); + + if (len <= 0) + { + text.clear(); + + return; + } + + text.resize(len + 1); + + if (!GetWindowText(handle, &text[0], len + 1)) + text.clear(); + + text.resize(len); + boost::algorithm::trim(text); +} + +void Window::SetText(const std::string& text) const +{ + SNDMSG(handle, WM_SETTEXT, 0, reinterpret_cast(text.c_str())); +} + +bool Window::IsChecked() const +{ + return IsEnabled() && Button_GetCheck(handle) == BST_CHECKED; +} + +void Window::SetChecked(bool state) +{ + Button_SetCheck(handle, state ? BST_CHECKED : BST_UNCHECKED); +} + +void Window::AddString(const std::string & str) +{ + SNDMSG(handle, CB_ADDSTRING, 0, reinterpret_cast(str.c_str())); +} + +void Window::SetSelection(int idx) +{ + SNDMSG(handle, CB_SETCURSEL, static_cast(idx), 0); +} + +int Window::GetSelection() const +{ + return static_cast(SNDMSG(handle, CB_GETCURSEL, 0, 0)); +} + +void Window::SetEnabled(bool enabled) +{ + EnableWindow(GetHandle(), enabled); +} + +bool Window::IsEnabled() const +{ + return IsWindowEnabled(GetHandle()) != 0; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc new file mode 100644 index 0000000000000..ebfd50b633012 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc @@ -0,0 +1,1114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "utils.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "json_converter.h" + +#include + +#include +#include + +namespace driver { +namespace flight_sql { + +namespace { +bool IsComplexType(arrow::Type::type type_id) { + switch (type_id) { + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::MAP: + case arrow::Type::STRUCT: + return true; + default: + return false; + } +} + +odbcabstraction::SqlDataType GetDefaultSqlCharType(bool useWideChar) { + return useWideChar ? odbcabstraction::SqlDataType_WCHAR : odbcabstraction::SqlDataType_CHAR; +} +odbcabstraction::SqlDataType GetDefaultSqlVarcharType(bool useWideChar) { + return useWideChar ? odbcabstraction::SqlDataType_WVARCHAR : odbcabstraction::SqlDataType_VARCHAR; +} +odbcabstraction::CDataType GetDefaultCCharType(bool useWideChar) { + return useWideChar ? odbcabstraction::CDataType_WCHAR : odbcabstraction::CDataType_CHAR; +} + +} + +using namespace odbcabstraction; +using arrow::util::make_optional; +using arrow::util::nullopt; + +/// \brief Returns the mapping from Arrow type to SqlDataType +/// \param field the field to return the SqlDataType for +/// \return the concise SqlDataType for the field. +/// \note use GetNonConciseDataType on the output to get the verbose type +/// \note the concise and verbose types are the same for all but types relating to times and intervals +SqlDataType +GetDataTypeFromArrowField_V3(const std::shared_ptr &field, bool useWideChar) { + const std::shared_ptr &type = field->type(); + + switch (type->id()) { + case arrow::Type::BOOL: + return odbcabstraction::SqlDataType_BIT; + case arrow::Type::UINT8: + case arrow::Type::INT8: + return odbcabstraction::SqlDataType_TINYINT; + case arrow::Type::UINT16: + case arrow::Type::INT16: + return odbcabstraction::SqlDataType_SMALLINT; + case arrow::Type::UINT32: + case arrow::Type::INT32: + return odbcabstraction::SqlDataType_INTEGER; + case arrow::Type::UINT64: + case arrow::Type::INT64: + return odbcabstraction::SqlDataType_BIGINT; + case arrow::Type::HALF_FLOAT: + case arrow::Type::FLOAT: + return odbcabstraction::SqlDataType_FLOAT; + case arrow::Type::DOUBLE: + return odbcabstraction::SqlDataType_DOUBLE; + case arrow::Type::BINARY: + case arrow::Type::FIXED_SIZE_BINARY: + case arrow::Type::LARGE_BINARY: + return odbcabstraction::SqlDataType_BINARY; + case arrow::Type::STRING: + case arrow::Type::LARGE_STRING: + return GetDefaultSqlVarcharType(useWideChar); + case arrow::Type::DATE32: + case arrow::Type::DATE64: + return odbcabstraction::SqlDataType_TYPE_DATE; + case arrow::Type::TIMESTAMP: + return odbcabstraction::SqlDataType_TYPE_TIMESTAMP; + case arrow::Type::DECIMAL128: + return odbcabstraction::SqlDataType_DECIMAL; + case arrow::Type::TIME32: + case arrow::Type::TIME64: + return odbcabstraction::SqlDataType_TYPE_TIME; + case arrow::Type::INTERVAL_MONTHS: + return odbcabstraction::SqlDataType_INTERVAL_MONTH; // TODO: maybe SqlDataType_INTERVAL_YEAR_TO_MONTH + case arrow::Type::INTERVAL_DAY_TIME: + return odbcabstraction::SqlDataType_INTERVAL_DAY; + + // TODO: Handle remaining types. + case arrow::Type::INTERVAL_MONTH_DAY_NANO: + case arrow::Type::LIST: + case arrow::Type::STRUCT: + case arrow::Type::SPARSE_UNION: + case arrow::Type::DENSE_UNION: + case arrow::Type::DICTIONARY: + case arrow::Type::MAP: + case arrow::Type::EXTENSION: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::DURATION: + case arrow::Type::LARGE_LIST: + case arrow::Type::MAX_ID: + case arrow::Type::NA: + break; + } + + return GetDefaultSqlVarcharType(useWideChar); +} + +SqlDataType EnsureRightSqlCharType(SqlDataType data_type, bool useWideChar) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_WCHAR: + return GetDefaultSqlCharType(useWideChar); + case SqlDataType_VARCHAR: + case SqlDataType_WVARCHAR: + return GetDefaultSqlVarcharType(useWideChar); + default: + return data_type; + } +} + +int16_t ConvertSqlDataTypeFromV3ToV2(int16_t data_type_v3) { + switch (data_type_v3) { + case SqlDataType_TYPE_DATE: + return 9; // Same as SQL_DATE from sqlext.h + case SqlDataType_TYPE_TIME: + return 10; // Same as SQL_TIME from sqlext.h + case SqlDataType_TYPE_TIMESTAMP: + return 11; // Same as SQL_TIMESTAMP from sqlext.h + default: + return data_type_v3; + } +} + +CDataType ConvertCDataTypeFromV2ToV3(int16_t data_type_v2) { + switch (data_type_v2) { + case -6: // Same as SQL_C_TINYINT from sqlext.h + return CDataType_STINYINT; + case 4: // Same as SQL_C_LONG from sqlext.h + return CDataType_SLONG; + case 5: // Same as SQL_C_SHORT from sqlext.h + return CDataType_SSHORT; + case 7: // Same as SQL_C_FLOAT from sqlext.h + return CDataType_FLOAT; + case 8: // Same as SQL_C_DOUBLE from sqlext.h + return CDataType_DOUBLE; + case 9: // Same as SQL_C_DATE from sqlext.h + return CDataType_DATE; + case 10: // Same as SQL_C_TIME from sqlext.h + return CDataType_TIME; + case 11: // Same as SQL_C_TIMESTAMP from sqlext.h + return CDataType_TIMESTAMP; + default: + return static_cast(data_type_v2); + } +} + +std::string GetTypeNameFromSqlDataType(int16_t data_type) { + switch (data_type) { + case SqlDataType_CHAR: + return "CHAR"; + case SqlDataType_VARCHAR: + return "VARCHAR"; + case SqlDataType_LONGVARCHAR: + return "LONGVARCHAR"; + case SqlDataType_WCHAR: + return "WCHAR"; + case SqlDataType_WVARCHAR: + return "WVARCHAR"; + case SqlDataType_WLONGVARCHAR: + return "WLONGVARCHAR"; + case SqlDataType_DECIMAL: + return "DECIMAL"; + case SqlDataType_NUMERIC: + return "NUMERIC"; + case SqlDataType_SMALLINT: + return "SMALLINT"; + case SqlDataType_INTEGER: + return "INTEGER"; + case SqlDataType_REAL: + return "REAL"; + case SqlDataType_FLOAT: + return "FLOAT"; + case SqlDataType_DOUBLE: + return "DOUBLE"; + case SqlDataType_BIT: + return "BIT"; + case SqlDataType_TINYINT: + return "TINYINT"; + case SqlDataType_BIGINT: + return "BIGINT"; + case SqlDataType_BINARY: + return "BINARY"; + case SqlDataType_VARBINARY: + return "VARBINARY"; + case SqlDataType_LONGVARBINARY: + return "LONGVARBINARY"; + case SqlDataType_TYPE_DATE: + case 9: + return "DATE"; + case SqlDataType_TYPE_TIME: + case 10: + return "TIME"; + case SqlDataType_TYPE_TIMESTAMP: + case 11: + return "TIMESTAMP"; + case SqlDataType_INTERVAL_MONTH: + return "INTERVAL_MONTH"; + case SqlDataType_INTERVAL_YEAR: + return "INTERVAL_YEAR"; + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + return "INTERVAL_YEAR_TO_MONTH"; + case SqlDataType_INTERVAL_DAY: + return "INTERVAL_DAY"; + case SqlDataType_INTERVAL_HOUR: + return "INTERVAL_HOUR"; + case SqlDataType_INTERVAL_MINUTE: + return "INTERVAL_MINUTE"; + case SqlDataType_INTERVAL_SECOND: + return "INTERVAL_SECOND"; + case SqlDataType_INTERVAL_DAY_TO_HOUR: + return "INTERVAL_DAY_TO_HOUR"; + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + return "INTERVAL_DAY_TO_MINUTE"; + case SqlDataType_INTERVAL_DAY_TO_SECOND: + return "INTERVAL_DAY_TO_SECOND"; + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + return "INTERVAL_HOUR_TO_MINUTE"; + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + return "INTERVAL_HOUR_TO_SECOND"; + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return "INTERVAL_MINUTE_TO_SECOND"; + case SqlDataType_GUID: + return "GUID"; + } + + throw driver::odbcabstraction::DriverException("Unsupported data type: " + + std::to_string(data_type)); +} + +optional +GetRadixFromSqlDataType(odbcabstraction::SqlDataType data_type) { + switch (data_type) { + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + case SqlDataType_SMALLINT: + case SqlDataType_TINYINT: + case SqlDataType_INTEGER: + case SqlDataType_BIGINT: + return 10; + case SqlDataType_REAL: + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 2; + default: + return arrow::util::nullopt; + } +} + +int16_t GetNonConciseDataType(odbcabstraction::SqlDataType data_type) { + switch (data_type) { + case SqlDataType_TYPE_DATE: + case SqlDataType_TYPE_TIME: + case SqlDataType_TYPE_TIMESTAMP: + return 9; // Same as SQL_DATETIME on sql.h + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 10; // Same as SQL_INTERVAL on sqlext.h + default: + return data_type; + } +} + +optional GetSqlDateTimeSubCode(SqlDataType data_type) { + switch (data_type) { + case SqlDataType_TYPE_DATE: + return SqlDateTimeSubCode_DATE; + case SqlDataType_TYPE_TIME: + return SqlDateTimeSubCode_TIME; + case SqlDataType_TYPE_TIMESTAMP: + return SqlDateTimeSubCode_TIMESTAMP; + case SqlDataType_INTERVAL_YEAR: + return SqlDateTimeSubCode_YEAR; + case SqlDataType_INTERVAL_MONTH: + return SqlDateTimeSubCode_MONTH; + case SqlDataType_INTERVAL_DAY: + return SqlDateTimeSubCode_DAY; + case SqlDataType_INTERVAL_HOUR: + return SqlDateTimeSubCode_HOUR; + case SqlDataType_INTERVAL_MINUTE: + return SqlDateTimeSubCode_MINUTE; + case SqlDataType_INTERVAL_SECOND: + return SqlDateTimeSubCode_SECOND; + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + return SqlDateTimeSubCode_YEAR_TO_MONTH; + case SqlDataType_INTERVAL_DAY_TO_HOUR: + return SqlDateTimeSubCode_DAY_TO_HOUR; + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + return SqlDateTimeSubCode_DAY_TO_MINUTE; + case SqlDataType_INTERVAL_DAY_TO_SECOND: + return SqlDateTimeSubCode_DAY_TO_SECOND; + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + return SqlDateTimeSubCode_HOUR_TO_MINUTE; + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + return SqlDateTimeSubCode_HOUR_TO_SECOND; + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return SqlDateTimeSubCode_MINUTE_TO_SECOND; + default: + return arrow::util::nullopt; + } +} + +optional GetCharOctetLength(SqlDataType data_type, + const arrow::Result& column_size, const int32_t decimal_precison) { + switch (data_type) { + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + if (column_size.ok()) { + return column_size.ValueOrDie(); + } else { + return arrow::util::nullopt; + } + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + if (column_size.ok()) { + return column_size.ValueOrDie() * GetSqlWCharSize(); + } else { + return arrow::util::nullopt; + } + case SqlDataType_TINYINT: + case SqlDataType_BIT: + return 1; // The same as sizeof(SQL_C_BIT) + case SqlDataType_SMALLINT: + return 2; // The same as sizeof(SQL_C_SMALLINT) + case SqlDataType_INTEGER: + return 4; // The same as sizeof(SQL_C_INTEGER) + case SqlDataType_BIGINT: + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; // The same as sizeof(SQL_C_DOUBLE) + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return decimal_precison + 2; // One char for each digit and two extra chars for a sign and a decimal point + case SqlDataType_TYPE_DATE: + case SqlDataType_TYPE_TIME: + return 6; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 16; // The same as sizeof(SQL_TIMESTAMP_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 34; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} +optional GetTypeScale(SqlDataType data_type, + const optional& type_scale) { + switch (data_type) { + case SqlDataType_TYPE_TIMESTAMP: + case SqlDataType_TYPE_TIME: + return 3; + case SqlDataType_DECIMAL: + return type_scale; + case SqlDataType_NUMERIC: + return type_scale; + case SqlDataType_TINYINT: + case SqlDataType_SMALLINT: + case SqlDataType_INTEGER: + case SqlDataType_BIGINT: + return 0; + default: + return arrow::util::nullopt; + } +} +optional GetColumnSize(SqlDataType data_type, + const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + return column_size; + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + return column_size.has_value() ? arrow::util::make_optional(column_size.value() * GetSqlWCharSize()) + : arrow::util::nullopt; + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size; + case SqlDataType_DECIMAL: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_NUMERIC: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_BIT: + case SqlDataType_TINYINT: + return 1; + case SqlDataType_SMALLINT: + return 2; + case SqlDataType_INTEGER: + return 4; + case SqlDataType_BIGINT: + return 8; + case SqlDataType_REAL: + return 4; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; + case SqlDataType_TYPE_DATE: + return 10; // The same as sizeof(SQL_DATE_STRUCT) + case SqlDataType_TYPE_TIME: + return 12; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 23; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 28; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} + +optional GetBufferLength(SqlDataType data_type, + const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + return column_size; + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + return column_size.has_value() ? arrow::util::make_optional(column_size.value() * GetSqlWCharSize()) + : arrow::util::nullopt; + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size; + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_BIT: + case SqlDataType_TINYINT: + return 1; + case SqlDataType_SMALLINT: + return 2; + case SqlDataType_INTEGER: + return 4; + case SqlDataType_BIGINT: + return 8; + case SqlDataType_REAL: + return 4; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; + case SqlDataType_TYPE_DATE: + return 10; // The same as sizeof(SQL_DATE_STRUCT) + case SqlDataType_TYPE_TIME: + return 12; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 23; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 28; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} + +optional GetLength(SqlDataType data_type, const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size; + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_BIT: + case SqlDataType_TINYINT: + return 1; + case SqlDataType_SMALLINT: + return 2; + case SqlDataType_INTEGER: + return 4; + case SqlDataType_BIGINT: + return 8; + case SqlDataType_REAL: + return 4; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; + case SqlDataType_TYPE_DATE: + return 10; // The same as sizeof(SQL_DATE_STRUCT) + case SqlDataType_TYPE_TIME: + return 12; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 23; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 28; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} + +optional GetDisplaySize(SqlDataType data_type, + const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + return column_size; + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size ? make_optional(*column_size * 2) : nullopt; + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return column_size ? make_optional(*column_size + 2) : nullopt; + case SqlDataType_BIT: + return 1; + case SqlDataType_TINYINT: + return 4; + case SqlDataType_SMALLINT: + return 6; + case SqlDataType_INTEGER: + return 11; + case SqlDataType_BIGINT: + return 20; + case SqlDataType_REAL: + return 14; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 24; + case SqlDataType_TYPE_DATE: + return 10; + case SqlDataType_TYPE_TIME: + return 12; // Assuming format "hh:mm:ss.fff" + case SqlDataType_TYPE_TIMESTAMP: + return 23; // Assuming format "yyyy-mm-dd hh:mm:ss.fff" + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return nullopt; // TODO: Implement for INTERVAL types + case SqlDataType_GUID: + return 36; + default: + return nullopt; + } +} + +std::string ConvertSqlPatternToRegexString(const std::string &pattern) { + static const std::string specials = "[]()|^-+*?{}$\\."; + + std::string regex_str; + bool escape = false; + for (const auto &c : pattern) { + if (escape) { + regex_str += c; + escape = false; + continue; + } + + switch (c) { + case '\\': + escape = true; + break; + case '_': + regex_str += '.'; + break; + case '%': + regex_str += ".*"; + break; + default: + if (specials.find(c) != std::string::npos) { + regex_str += '\\'; + } + regex_str += c; + break; + } + } + return regex_str; +} + +boost::xpressive::sregex ConvertSqlPatternToRegex(const std::string &pattern) { + const std::string ®ex_str = ConvertSqlPatternToRegexString(pattern); + return boost::xpressive::sregex(boost::xpressive::sregex::compile(regex_str)); +} + +bool NeedArrayConversion(arrow::Type::type original_type_id, odbcabstraction::CDataType data_type) { + switch (original_type_id) { + case arrow::Type::DATE32: + case arrow::Type::DATE64: + return data_type != odbcabstraction::CDataType_DATE; + case arrow::Type::TIME32: + case arrow::Type::TIME64: + return data_type != odbcabstraction::CDataType_TIME; + case arrow::Type::TIMESTAMP: + return data_type != odbcabstraction::CDataType_TIMESTAMP; + case arrow::Type::STRING: + return data_type != odbcabstraction::CDataType_CHAR && + data_type != odbcabstraction::CDataType_WCHAR; + case arrow::Type::INT16: + return data_type != odbcabstraction::CDataType_SSHORT; + case arrow::Type::UINT16: + return data_type != odbcabstraction::CDataType_USHORT; + case arrow::Type::INT32: + return data_type != odbcabstraction::CDataType_SLONG; + case arrow::Type::UINT32: + return data_type != odbcabstraction::CDataType_ULONG; + case arrow::Type::FLOAT: + return data_type != odbcabstraction::CDataType_FLOAT; + case arrow::Type::DOUBLE: + return data_type != odbcabstraction::CDataType_DOUBLE; + case arrow::Type::BOOL: + return data_type != odbcabstraction::CDataType_BIT; + case arrow::Type::INT8: + return data_type != odbcabstraction::CDataType_STINYINT; + case arrow::Type::UINT8: + return data_type != odbcabstraction::CDataType_UTINYINT; + case arrow::Type::INT64: + return data_type != odbcabstraction::CDataType_SBIGINT; + case arrow::Type::UINT64: + return data_type != odbcabstraction::CDataType_UBIGINT; + case arrow::Type::BINARY: + return data_type != odbcabstraction::CDataType_BINARY; + case arrow::Type::DECIMAL128: + return data_type != odbcabstraction::CDataType_NUMERIC; + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::MAP: + case arrow::Type::STRUCT: + return data_type == odbcabstraction::CDataType_CHAR || data_type == odbcabstraction::CDataType_WCHAR; + default: + throw odbcabstraction::DriverException(std::string("Invalid conversion")); + } +} + +std::shared_ptr +GetDefaultDataTypeForTypeId(arrow::Type::type type_id) { + switch (type_id) { + case arrow::Type::STRING: + return arrow::utf8(); + case arrow::Type::INT16: + return arrow::int16(); + case arrow::Type::UINT16: + return arrow::uint16(); + case arrow::Type::INT32: + return arrow::int32(); + case arrow::Type::UINT32: + return arrow::uint32(); + case arrow::Type::FLOAT: + return arrow::float32(); + case arrow::Type::DOUBLE: + return arrow::float64(); + case arrow::Type::BOOL: + return arrow::boolean(); + case arrow::Type::INT8: + return arrow::int8(); + case arrow::Type::UINT8: + return arrow::uint8(); + case arrow::Type::INT64: + return arrow::int64(); + case arrow::Type::UINT64: + return arrow::uint64(); + case arrow::Type::BINARY: + return arrow::binary(); + case arrow::Type::DECIMAL128: + return arrow::decimal128(arrow::Decimal128Type::kMaxPrecision, 0); + case arrow::Type::DATE64: + return arrow::date64(); + case arrow::Type::TIME64: + return arrow::time64(arrow::TimeUnit::MICRO); + case arrow::Type::TIMESTAMP: + return arrow::timestamp(arrow::TimeUnit::SECOND); + } + + throw odbcabstraction::DriverException(std::string("Invalid type id: ") + std::to_string(type_id)); +} + +arrow::Type::type +ConvertCToArrowType(odbcabstraction::CDataType data_type) { + switch (data_type) { + case odbcabstraction::CDataType_CHAR: + case odbcabstraction::CDataType_WCHAR: + return arrow::Type::STRING; + case odbcabstraction::CDataType_SSHORT: + return arrow::Type::INT16; + case odbcabstraction::CDataType_USHORT: + return arrow::Type::UINT16; + case odbcabstraction::CDataType_SLONG: + return arrow::Type::INT32; + case odbcabstraction::CDataType_ULONG: + return arrow::Type::UINT32; + case odbcabstraction::CDataType_FLOAT: + return arrow::Type::FLOAT; + case odbcabstraction::CDataType_DOUBLE: + return arrow::Type::DOUBLE; + case odbcabstraction::CDataType_BIT: + return arrow::Type::BOOL; + case odbcabstraction::CDataType_STINYINT: + return arrow::Type::INT8; + case odbcabstraction::CDataType_UTINYINT: + return arrow::Type::UINT8; + case odbcabstraction::CDataType_SBIGINT: + return arrow::Type::INT64; + case odbcabstraction::CDataType_UBIGINT: + return arrow::Type::UINT64; + case odbcabstraction::CDataType_BINARY: + return arrow::Type::BINARY; + case odbcabstraction::CDataType_NUMERIC: + return arrow::Type::DECIMAL128; + case odbcabstraction::CDataType_TIMESTAMP: + return arrow::Type::TIMESTAMP; + case odbcabstraction::CDataType_TIME: + return arrow::Type::TIME64; + case odbcabstraction::CDataType_DATE: + return arrow::Type::DATE64; + default: + throw odbcabstraction::DriverException(std::string("Invalid target type: ") + std::to_string(data_type)); + } +} + +odbcabstraction::CDataType ConvertArrowTypeToC(arrow::Type::type type_id, bool useWideChar) { + switch (type_id) { + case arrow::Type::STRING: + return GetDefaultCCharType(useWideChar); + case arrow::Type::INT16: + return odbcabstraction::CDataType_SSHORT; + case arrow::Type::UINT16: + return odbcabstraction::CDataType_USHORT; + case arrow::Type::INT32: + return odbcabstraction::CDataType_SLONG; + case arrow::Type::UINT32: + return odbcabstraction::CDataType_ULONG; + case arrow::Type::FLOAT: + return odbcabstraction::CDataType_FLOAT; + case arrow::Type::DOUBLE: + return odbcabstraction::CDataType_DOUBLE; + case arrow::Type::BOOL: + return odbcabstraction::CDataType_BIT; + case arrow::Type::INT8: + return odbcabstraction::CDataType_STINYINT; + case arrow::Type::UINT8: + return odbcabstraction::CDataType_UTINYINT; + case arrow::Type::INT64: + return odbcabstraction::CDataType_SBIGINT; + case arrow::Type::UINT64: + return odbcabstraction::CDataType_UBIGINT; + case arrow::Type::BINARY: + return odbcabstraction::CDataType_BINARY; + case arrow::Type::DECIMAL128: + return odbcabstraction::CDataType_NUMERIC; + case arrow::Type::DATE64: + case arrow::Type::DATE32: + return odbcabstraction::CDataType_DATE; + case arrow::Type::TIME64: + case arrow::Type::TIME32: + return odbcabstraction::CDataType_TIME; + case arrow::Type::TIMESTAMP: + return odbcabstraction::CDataType_TIMESTAMP; + default: + throw odbcabstraction::DriverException(std::string("Invalid type id: ") + std::to_string(type_id)); + } +} + +std::shared_ptr +CheckConversion(const arrow::Result &result) { + if (result.ok()) { + const arrow::Datum &datum = result.ValueOrDie(); + return datum.make_array(); + } else { + throw odbcabstraction::DriverException(result.status().message()); + } +} + +ArrayConvertTask GetConverter(arrow::Type::type original_type_id, + odbcabstraction::CDataType target_type) { + // The else statement has a convert the works for the most case of array + // conversion. In case, we find conversion that the default one can't handle + // we can include some additional if-else statement with the logic to handle + // it + if (original_type_id == arrow::Type::STRING && + target_type == odbcabstraction::CDataType_TIME) { + return [=](const std::shared_ptr &original_array) { + arrow::compute::StrptimeOptions options("%H:%M", arrow::TimeUnit::MICRO, false); + + auto converted_result = + arrow::compute::Strptime({original_array}, options); + auto first_converted_array = CheckConversion(converted_result); + + arrow::compute::CastOptions cast_options; + cast_options.to_type = time64(arrow::TimeUnit::MICRO); + return CheckConversion(arrow::compute::CallFunction( + "cast", {first_converted_array}, &cast_options)); + }; + } else if (original_type_id == arrow::Type::TIME32 && + target_type == odbcabstraction::CDataType_TIMESTAMP) { + return [=](const std::shared_ptr &original_array) { + arrow::compute::CastOptions cast_options; + cast_options.to_type = arrow::int32(); + + auto first_converted_array = CheckConversion( + arrow::compute::Cast(original_array, cast_options)); + + cast_options.to_type = arrow::int64(); + + auto second_converted_array = CheckConversion( + arrow::compute::Cast(first_converted_array, cast_options)); + + auto seconds_from_epoch = GetTodayTimeFromEpoch(); + + auto third_converted_array = CheckConversion( + arrow::compute::Add(second_converted_array, std::make_shared(seconds_from_epoch * 1000))); + + arrow::compute::CastOptions cast_options_2; + cast_options_2.to_type = arrow::timestamp(arrow::TimeUnit::MILLI); + + return CheckConversion( + arrow::compute::Cast(third_converted_array, cast_options_2)); + }; + } else if (original_type_id == arrow::Type::TIME64 && + target_type == odbcabstraction::CDataType_TIMESTAMP) { + return [=](const std::shared_ptr &original_array) { + arrow::compute::CastOptions cast_options; + cast_options.to_type = arrow::int64(); + + auto first_converted_array = CheckConversion( + arrow::compute::Cast(original_array, cast_options)); + + auto seconds_from_epoch = GetTodayTimeFromEpoch(); + + auto second_converted_array = CheckConversion( + arrow::compute::Add(first_converted_array, + std::make_shared(seconds_from_epoch * 1000000000))); + + arrow::compute::CastOptions cast_options_2; + cast_options_2.to_type = arrow::timestamp(arrow::TimeUnit::NANO); + + return CheckConversion( + arrow::compute::Cast(second_converted_array, cast_options_2)); + }; + } else if (original_type_id == arrow::Type::STRING && + target_type == odbcabstraction::CDataType_DATE) { + return [=](const std::shared_ptr &original_array) { + // The Strptime requires a date format. Using the ISO 8601 format + arrow::compute::StrptimeOptions options("%Y-%m-%d", + arrow::TimeUnit::SECOND, false); + + auto converted_result = + arrow::compute::Strptime({original_array}, options); + + auto first_converted_array = CheckConversion(converted_result); + arrow::compute::CastOptions cast_options; + cast_options.to_type = arrow::date64(); + return CheckConversion(arrow::compute::CallFunction( + "cast", {first_converted_array}, &cast_options)); + }; + } else if (original_type_id == arrow::Type::DECIMAL128 && + (target_type == odbcabstraction::CDataType_CHAR || + target_type == odbcabstraction::CDataType_WCHAR)) { + return [=](const std::shared_ptr &original_array) { + arrow::StringBuilder builder; + int64_t length = original_array->length(); + ThrowIfNotOK(builder.ReserveData(length)); + + for (int64_t i = 0; i < length; ++i) { + if (original_array->IsNull(i)) { + ThrowIfNotOK(builder.AppendNull()); + } else { + auto result = original_array->GetScalar(i); + auto scalar = result.ValueOrDie(); + ThrowIfNotOK(builder.Append(scalar->ToString())); + } + } + + auto finish = builder.Finish(); + + return finish.ValueOrDie(); + }; + } else if (IsComplexType(original_type_id) && + (target_type == odbcabstraction::CDataType_CHAR || + target_type == odbcabstraction::CDataType_WCHAR)) { + return [=](const std::shared_ptr &original_array) { + const auto &json_conversion_result = ConvertToJson(original_array); + ThrowIfNotOK(json_conversion_result.status()); + return json_conversion_result.ValueOrDie(); + }; + } else { + // Default converter + return [=](const std::shared_ptr &original_array) { + const arrow::Type::type &target_arrow_type_id = + ConvertCToArrowType(target_type); + arrow::compute::CastOptions cast_options; + cast_options.to_type = GetDefaultDataTypeForTypeId(target_arrow_type_id); + + return CheckConversion(arrow::compute::CallFunction( + "cast", {original_array}, &cast_options)); + }; + } +} +std::string ConvertToDBMSVer(const std::string &str) { + boost::char_separator separator("."); + boost::tokenizer< boost::char_separator > tokenizer(str, separator); + std::string result; + // The permitted ODBC format is ##.##.#### + // If any of the first 3 tokens are not numbers or are greater than the permitted digits, + // assume we hit the custom-server-information early and assume the remaining version digits are zero. + size_t position = 0; + bool is_showing_custom_data = false; + auto pad_remaining_tokens = [&](size_t pos) -> std::string { + std::string padded_str; + if (pos == 0) { + padded_str += "00"; + } + if (pos <= 1) { + padded_str += ".00"; + } + if (pos <= 2) { + padded_str += ".0000"; + } + return padded_str; + }; + + for(auto token : tokenizer) + { + if (token.empty()) { + continue; + } + + if (!is_showing_custom_data && position < 3) { + std::string suffix; + try { + size_t next_pos = 0; + int version = stoi(token, &next_pos); + if (next_pos != token.size()) { + suffix = &token[0]; + } + if (version < 0 || + (position < 2 && (version > 99)) || + (position == 2 && version > 9999)) { + is_showing_custom_data = true; + } else { + std::stringstream strstream; + if (position == 2) { + strstream << std::setfill('0') << std::setw(4); + } else { + strstream << std::setfill('0') << std::setw(2); + } + strstream << version; + + if (position != 0) { + result += "."; + } + result += strstream.str(); + if (next_pos != token.size()) { + suffix = &token[next_pos]; + result += pad_remaining_tokens(++position) + suffix; + position = 4; // Prevent additional padding. + is_showing_custom_data = true; + continue; + } + ++position; + continue; + } + } catch (std::logic_error&) { + is_showing_custom_data = true; + } + + result += pad_remaining_tokens(position) + suffix; + ++position; + } + + result += "." + token; + ++position; + } + + result += pad_remaining_tokens(position); + return result; +} + +int32_t GetDecimalTypeScale(const std::shared_ptr& decimalType){ + auto decimal128Type = std::dynamic_pointer_cast(decimalType); + return decimal128Type->scale(); +} + +int32_t GetDecimalTypePrecision(const std::shared_ptr& decimalType){ + auto decimal128Type = std::dynamic_pointer_cast(decimalType); + return decimal128Type->precision(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h new file mode 100644 index 0000000000000..61b27c86713d0 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +typedef std::function< + std::shared_ptr(const std::shared_ptr &)> + ArrayConvertTask; + +using std::optional; + +inline void ThrowIfNotOK(const arrow::Status &status) { + if (!status.ok()) { + throw odbcabstraction::DriverException(status.message()); + } +} + +template +inline bool CheckIfSetToOnlyValidValue(const AttributeTypeT &value, T allowed_value) { + return boost::get(value) == allowed_value; +} + +template +arrow::Status AppendToBuilder(BUILDER &builder, optional opt_value) { + if (opt_value) { + return builder.Append(*opt_value); + } else { + return builder.AppendNull(); + } +} + +template +arrow::Status AppendToBuilder(BUILDER &builder, T value) { + return builder.Append(value); +} + +odbcabstraction::SqlDataType +GetDataTypeFromArrowField_V3(const std::shared_ptr &field, bool useWideChar); + +odbcabstraction::SqlDataType EnsureRightSqlCharType(odbcabstraction::SqlDataType data_type, bool useWideChar); + +int16_t ConvertSqlDataTypeFromV3ToV2(int16_t data_type_v3); + +odbcabstraction::CDataType ConvertCDataTypeFromV2ToV3(int16_t data_type_v2); + +std::string GetTypeNameFromSqlDataType(int16_t data_type); + +optional +GetRadixFromSqlDataType(odbcabstraction::SqlDataType data_type); + +int16_t GetNonConciseDataType(odbcabstraction::SqlDataType data_type); + +optional GetSqlDateTimeSubCode(odbcabstraction::SqlDataType data_type); + +optional GetCharOctetLength(odbcabstraction::SqlDataType data_type, + const arrow::Result& column_size, + const int32_t decimal_precison=0); + +optional GetBufferLength(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +optional GetLength(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +optional GetTypeScale(odbcabstraction::SqlDataType data_type, + const optional& type_scale); + +optional GetColumnSize(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +optional GetDisplaySize(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +std::string ConvertSqlPatternToRegexString(const std::string &pattern); + +boost::xpressive::sregex ConvertSqlPatternToRegex(const std::string &pattern); + +bool NeedArrayConversion(arrow::Type::type original_type_id, + odbcabstraction::CDataType data_type); + +std::shared_ptr GetDefaultDataTypeForTypeId(arrow::Type::type type_id); + +arrow::Type::type ConvertCToArrowType(odbcabstraction::CDataType data_type); + +odbcabstraction::CDataType ConvertArrowTypeToC(arrow::Type::type type_id, bool useWideChar); + +std::shared_ptr CheckConversion(const arrow::Result &result); + +ArrayConvertTask GetConverter(arrow::Type::type original_type_id, + odbcabstraction::CDataType target_type); + +std::string ConvertToDBMSVer(const std::string& str); + +int32_t GetDecimalTypeScale(const std::shared_ptr& decimalType); + +int32_t GetDecimalTypePrecision(const std::shared_ptr& decimalType); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc new file mode 100644 index 0000000000000..5bada8c77dc7d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "utils.h" + +#include "odbcabstraction/calendar_utils.h" + +#include "arrow/testing/builder.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/util.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +void AssertConvertedArray(const std::shared_ptr& expected_array, + const std::shared_ptr& converted_array, + uint64_t size, + arrow::Type::type arrow_type) { + ASSERT_EQ(converted_array->type_id(), arrow_type); + ASSERT_EQ(converted_array->length(),size); + ASSERT_EQ(expected_array->ToString(), converted_array->ToString()); +} + +std::shared_ptr convertArray( + const std::shared_ptr& original_array, + odbcabstraction::CDataType c_type) { + auto converter = GetConverter(original_array->type_id(), + c_type); + return converter(original_array); +} + +void TestArrayConversion(const std::vector& input, + const std::shared_ptr& expected_array, + odbcabstraction::CDataType c_type, + arrow::Type::type arrow_type) { + std::shared_ptr original_array; + arrow::ArrayFromVector(input, &original_array); + + auto converted_array = convertArray(original_array, c_type); + + AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); +} + +void TestTime32ArrayConversion(const std::vector& input, + const std::shared_ptr& expected_array, + odbcabstraction::CDataType c_type, + arrow::Type::type arrow_type) { + std::shared_ptr original_array; + arrow::ArrayFromVector(time32(arrow::TimeUnit::MILLI), + input, &original_array); + + auto converted_array = convertArray(original_array, c_type); + + AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); +} + +void TestTime64ArrayConversion(const std::vector& input, + const std::shared_ptr& expected_array, + odbcabstraction::CDataType c_type, + arrow::Type::type arrow_type) { + std::shared_ptr original_array; + arrow::ArrayFromVector(time64(arrow::TimeUnit::NANO), + input, &original_array); + + auto converted_array = convertArray(original_array, c_type); + + AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); +} + +TEST(Utils, Time32ToTimeStampArray) { + std::vector input_data = {14896, 17820}; + + const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch(); + std::vector expected_data; + expected_data.reserve(2); + + for (const auto &item : input_data) { + expected_data.emplace_back(item + seconds_from_epoch * 1000); + } + + std::shared_ptr expected; + auto timestamp_field = field("timestamp_field", timestamp(arrow::TimeUnit::MILLI)); + arrow::ArrayFromVector(timestamp_field->type(), + expected_data, &expected); + + TestTime32ArrayConversion(input_data, expected, + odbcabstraction::CDataType_TIMESTAMP, + arrow::Type::TIMESTAMP); +} + +TEST(Utils, Time64ToTimeStampArray) { + std::vector input_data = {1579489200000, 1646881200000}; + + const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch(); + std::vector expected_data; + expected_data.reserve(2); + + for (const auto &item : input_data) { + expected_data.emplace_back(item + seconds_from_epoch * 1000000000); + } + + std::shared_ptr expected; + auto timestamp_field = field("timestamp_field", timestamp(arrow::TimeUnit::NANO)); + arrow::ArrayFromVector(timestamp_field->type(), + expected_data, &expected); + + TestTime64ArrayConversion(input_data, expected, + odbcabstraction::CDataType_TIMESTAMP, + arrow::Type::TIMESTAMP); +} + +TEST(Utils, StringToDateArray) { + std::shared_ptr expected; + arrow::ArrayFromVector( + {1579489200000, 1646881200000}, &expected); + + TestArrayConversion({"2020-01-20", "2022-03-10"}, expected, + odbcabstraction::CDataType_DATE, + arrow::Type::DATE64); +} + +TEST(Utils, StringToTimeArray) { + std::shared_ptr expected; + arrow::ArrayFromVector(time64(arrow::TimeUnit::MICRO), + {36000000000, 43200000000}, &expected); + + TestArrayConversion({"10:00", "12:00"}, expected, + odbcabstraction::CDataType_TIME, arrow::Type::TIME64); +} + +TEST(Utils, ConvertSqlPatternToRegexString) { + ASSERT_EQ(std::string("XY"), ConvertSqlPatternToRegexString("XY")); + ASSERT_EQ(std::string("X.Y"), ConvertSqlPatternToRegexString("X_Y")); + ASSERT_EQ(std::string("X.*Y"), ConvertSqlPatternToRegexString("X%Y")); + ASSERT_EQ(std::string("X%Y"), ConvertSqlPatternToRegexString("X\\%Y")); + ASSERT_EQ(std::string("X_Y"), ConvertSqlPatternToRegexString("X\\_Y")); +} + +TEST(Utils, ConvertToDBMSVer) { + ASSERT_EQ(std::string("01.02.0003"), ConvertToDBMSVer("1.2.3")); + ASSERT_EQ(std::string("01.02.0003.0"), ConvertToDBMSVer("1.2.3.0")); + ASSERT_EQ(std::string("01.02.0000"), ConvertToDBMSVer("1.2")); + ASSERT_EQ(std::string("01.00.0000"), ConvertToDBMSVer("1")); + ASSERT_EQ(std::string("01.02.0000-foo"), ConvertToDBMSVer("1.2-foo")); + ASSERT_EQ(std::string("01.00.0000-foo"), ConvertToDBMSVer("1-foo")); + ASSERT_EQ(std::string("10.11.0001-foo"), ConvertToDBMSVer("10.11.1-foo")); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt new file mode 100644 index 0000000000000..88f048899f283 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# Ensure fmt is loaded as header only +add_compile_definitions(FMT_HEADER_ONLY) + +add_library(odbcabstraction + include/odbcabstraction/calendar_utils.h + include/odbcabstraction/diagnostics.h + include/odbcabstraction/error_codes.h + include/odbcabstraction/exceptions.h + include/odbcabstraction/logger.h + include/odbcabstraction/platform.h + include/odbcabstraction/spd_logger.h + include/odbcabstraction/types.h + include/odbcabstraction/utils.h + include/odbcabstraction/odbc_impl/AttributeUtils.h + include/odbcabstraction/odbc_impl/EncodingUtils.h + include/odbcabstraction/odbc_impl/ODBCConnection.h + include/odbcabstraction/odbc_impl/ODBCDescriptor.h + include/odbcabstraction/odbc_impl/ODBCEnvironment.h + include/odbcabstraction/odbc_impl/ODBCHandle.h + include/odbcabstraction/odbc_impl/ODBCStatement.h + include/odbcabstraction/odbc_impl/TypeUtilities.h + include/odbcabstraction/spi/connection.h + include/odbcabstraction/spi/driver.h + include/odbcabstraction/spi/result_set.h + include/odbcabstraction/spi/result_set_metadata.h + include/odbcabstraction/spi/statement.h + calendar_utils.cc + diagnostics.cc + encoding.cc + exceptions.cc + logger.cc + spd_logger.cc + utils.cc + whereami.h + whereami.cc + odbc_impl/ODBCConnection.cc + odbc_impl/ODBCDescriptor.cc + odbc_impl/ODBCEnvironment.cc + odbc_impl/ODBCStatement.cc +) +target_include_directories(odbcabstraction PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include) +target_link_libraries(odbcabstraction PUBLIC ODBC::ODBC Boost::headers) + +set_target_properties(odbcabstraction + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + ) + +include(FetchContent) +FetchContent_Declare( + spdlog + URL https://github.com/gabime/spdlog/archive/76fb40d95455f249bd70824ecfcae7a8f0930fa3.zip + CONFIGURE_COMMAND "" + BUILD_COMMAND "" +) +FetchContent_GetProperties(spdlog) +if(NOT spdlog_POPULATED) + FetchContent_Populate(spdlog) +endif() + +add_library(spdlog INTERFACE) +target_include_directories(spdlog INTERFACE ${spdlog_SOURCE_DIR}/include) + +target_link_libraries(odbcabstraction PUBLIC spdlog) diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/calendar_utils.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/calendar_utils.cc new file mode 100644 index 0000000000000..a8514e5934ba5 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/calendar_utils.cc @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "odbcabstraction/calendar_utils.h" + +#include +#include + +namespace driver { +namespace odbcabstraction { +int64_t GetTodayTimeFromEpoch() { + tm date{}; + int64_t t = std::time(0); + + GetTimeForSecondsSinceEpoch(date, t); + + date.tm_hour = 0; + date.tm_min = 0; + date.tm_sec = 0; + + #if defined(_WIN32) + return _mkgmtime(&date); + #else + return timegm(&date); + #endif +} + +void GetTimeForSecondsSinceEpoch(tm& date, int64_t value) { + #if defined(_WIN32) + gmtime_s(&date, &value); + #else + time_t time_value = static_cast(value); + gmtime_r(&time_value, &date); + #endif + } +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc new file mode 100644 index 0000000000000..596f76f6441b4 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include + +namespace { + void RewriteSQLStateForODBC2(std::string& sql_state) { + if (sql_state[0] == 'H' && sql_state[1] == 'Y') { + sql_state[0] = 'S'; + sql_state[1] = '1'; + } + } +} + +namespace driver { +namespace odbcabstraction { + +Diagnostics::Diagnostics( + std::string vendor, std::string data_source_component, OdbcVersion version) : + vendor_(std::move(vendor)), + data_source_component_(std::move(data_source_component)), + version_(version) +{} + +void Diagnostics::SetDataSourceComponent(std::string component) { + data_source_component_ = std::move(component); +} + +std::string Diagnostics::GetDataSourceComponent() const { + return data_source_component_; +} + +std::string Diagnostics::GetVendor() const { + return vendor_; +} + +void driver::odbcabstraction::Diagnostics::AddError( + const driver::odbcabstraction::DriverException &exception) { + auto record = std::unique_ptr(new DiagnosticsRecord{ + exception.GetMessageText(), exception.GetSqlState(), exception.GetNativeError()}); + if (version_ == OdbcVersion::V_2) { + RewriteSQLStateForODBC2(record->sql_state_); + } + TrackRecord(*record); + owned_records_.push_back(std::move(record)); +} + +void driver::odbcabstraction::Diagnostics::AddWarning( + std::string message, std::string sql_state, int32_t native_error) { +auto record = std::unique_ptr(new DiagnosticsRecord{ + std::move(message),std::move(sql_state), native_error}); + if (version_ == OdbcVersion::V_2) { + RewriteSQLStateForODBC2(record->sql_state_); + } + TrackRecord(*record); + owned_records_.push_back(std::move(record)); +} + +std::string driver::odbcabstraction::Diagnostics::GetMessageText( + uint32_t record_index) const { + std::string message; + if (!vendor_.empty()) { + message += std::string("[") + vendor_ + "]"; + } + const DiagnosticsRecord* rec = GetRecordAtIndex(record_index); + return message + "[" + data_source_component_ + "] (" + std::to_string(rec->native_error_) + ") " + rec->msg_text_; +} + +OdbcVersion Diagnostics::GetOdbcVersion() const { return version_; } + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc new file mode 100644 index 0000000000000..f84443cc485ed --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#if defined(__APPLE__) +#include +#include +#include +#endif + +namespace driver { +namespace odbcabstraction { + +#if defined(__APPLE__) +std::atomic SqlWCharSize{0}; + +namespace { +std::mutex SqlWCharSizeMutex; + +bool IsUsingIODBC() { + // Detects iODBC by looking up by symbol iodbc_version + void* handle = dlsym(RTLD_DEFAULT, "iodbc_version"); + bool using_iodbc = handle != nullptr; + dlclose(handle); + + return using_iodbc; +} +} + +void ComputeSqlWCharSize() { + std::unique_lock lock(SqlWCharSizeMutex); + if (SqlWCharSize != 0) return; // double-checked locking + + const char *env_p = std::getenv("WCHAR_ENCODING"); + if (env_p) { + if (boost::iequals(env_p, "UTF-16")) { + SqlWCharSize = sizeof(char16_t); + return; + } else if (boost::iequals(env_p, "UTF-32")) { + SqlWCharSize = sizeof(char32_t); + return; + } + } + + SqlWCharSize = IsUsingIODBC() ? sizeof(char32_t) : sizeof(char16_t); +} +#endif + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc new file mode 100644 index 0000000000000..ea676621c40c1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +DriverException::DriverException(std::string message, std::string sql_state, + int32_t native_error) + : msg_text_(std::move(message)), + sql_state_(std::move(sql_state)), + native_error_(native_error) {} + +const char *DriverException::what() const throw() { return msg_text_.c_str(); } +const std::string &DriverException::GetMessageText() const { return msg_text_; } +const std::string &DriverException::GetSqlState() const { return sql_state_; } +int32_t DriverException::GetNativeError() const { return native_error_; } + +AuthenticationException::AuthenticationException(std::string message, std::string sql_state, + int32_t native_error) + : DriverException(message, sql_state, native_error) {} + +CommunicationException::CommunicationException(std::string message, std::string sql_state, + int32_t native_error) + : DriverException(message + ". Please ensure your encryption settings match the server.", + sql_state, native_error) {} + +NullWithoutIndicatorException::NullWithoutIndicatorException( + std::string message, std::string sql_state, int32_t native_error) + : DriverException(message, sql_state, native_error) {} +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h new file mode 100644 index 0000000000000..7dee68a144483 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + + +template +class BlockingQueue { + + size_t capacity_; + std::vector buffer_; + size_t buffer_size_{0}; + size_t left_{0}; // index where variables are put inside of buffer (produced) + size_t right_{0}; // index where variables are removed from buffer (consumed) + + std::mutex mtx_; + std::condition_variable not_empty_; + std::condition_variable not_full_; + + std::vector threads_; + std::atomic active_threads_{0}; + std::atomic closed_{false}; + +public: + typedef std::function(void)> Supplier; + + BlockingQueue(size_t capacity): capacity_(capacity), buffer_(capacity) {} + + void AddProducer(Supplier supplier) { + active_threads_++; + threads_.emplace_back([=] { + while (!closed_) { + // Block while queue is full + std::unique_lock unique_lock(mtx_); + if (!WaitUntilCanPushOrClosed(unique_lock)) break; + unique_lock.unlock(); + + // Only one thread at a time be notified and call supplier + auto item = supplier(); + if (!item) break; + + Push(*item); + } + + std::unique_lock unique_lock(mtx_); + active_threads_--; + not_empty_.notify_all(); + }); + } + + void Push(T item) { + std::unique_lock unique_lock(mtx_); + if (!WaitUntilCanPushOrClosed(unique_lock)) return; + + buffer_[right_] = std::move(item); + + right_ = (right_ + 1) % capacity_; + buffer_size_++; + + not_empty_.notify_one(); + } + + bool Pop(T *result) { + std::unique_lock unique_lock(mtx_); + if (!WaitUntilCanPopOrClosed(unique_lock)) return false; + + *result = std::move(buffer_[left_]); + + left_ = (left_ + 1) % capacity_; + buffer_size_--; + + not_full_.notify_one(); + + return true; + } + + void Close() { + std::unique_lock unique_lock(mtx_); + + if (closed_) return; + closed_ = true; + not_empty_.notify_all(); + not_full_.notify_all(); + + unique_lock.unlock(); + + for (auto &item: threads_) { + item.join(); + } + } + +private: + bool WaitUntilCanPushOrClosed(std::unique_lock &unique_lock) { + not_full_.wait(unique_lock, [this]() { + return closed_ || buffer_size_ != capacity_; + }); + return !closed_; + } + + bool WaitUntilCanPopOrClosed(std::unique_lock &unique_lock) { + not_empty_.wait(unique_lock, [this]() { + return closed_ || buffer_size_ != 0 || active_threads_ == 0; + }); + + return !closed_ && buffer_size_ > 0; + } +}; + +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h new file mode 100644 index 0000000000000..555a37f67547c --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace driver { +namespace odbcabstraction { + int64_t GetTodayTimeFromEpoch(); + + void GetTimeForSecondsSinceEpoch(tm& date, int64_t value); +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h new file mode 100644 index 0000000000000..001e64be4e455 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + class Diagnostics { + public: + struct DiagnosticsRecord { + std::string msg_text_; + std::string sql_state_; + int32_t native_error_; + }; + + private: + std::vector error_records_; + std::vector warning_records_; + std::vector> owned_records_; + std::string vendor_; + std::string data_source_component_; + OdbcVersion version_; + + public: + Diagnostics(std::string vendor, std::string data_source_component, OdbcVersion version); + void AddError(const DriverException& exception); + void AddWarning(std::string message, std::string sql_state, int32_t native_error); + + /// \brief Add a pre-existing truncation warning. + inline void AddTruncationWarning() { + static const std::unique_ptr TRUNCATION_WARNING(new DiagnosticsRecord { + "String or binary data, right-truncated.", "01004", + ODBCErrorCodes_TRUNCATION_WARNING + }); + warning_records_.push_back(TRUNCATION_WARNING.get()); + } + + inline void TrackRecord(const DiagnosticsRecord& record) { + if (record.sql_state_[0] == '0' && record.sql_state_[1] == '1') { + warning_records_.push_back(&record); + } else { + error_records_.push_back(&record); + } + } + + void SetDataSourceComponent(std::string component); + std::string GetDataSourceComponent() const; + + std::string GetVendor() const; + + inline void Clear() { + error_records_.clear(); + warning_records_.clear(); + owned_records_.clear(); + } + + std::string GetMessageText(uint32_t record_index) const; + std::string GetSQLState(uint32_t record_index) const { + return GetRecordAtIndex(record_index)->sql_state_; + } + + int32_t GetNativeError(uint32_t record_index) const { + return GetRecordAtIndex(record_index)->native_error_; + } + + inline size_t GetRecordCount() const { + return error_records_.size() + warning_records_.size(); + } + + inline bool HasRecord(uint32_t record_index) const { + return error_records_.size() + warning_records_.size() > record_index; + } + + inline bool HasWarning() const { + return !warning_records_.empty(); + } + + inline bool HasError() const { + return !error_records_.empty(); + } + + OdbcVersion GetOdbcVersion() const; + + private: + inline const DiagnosticsRecord* GetRecordAtIndex(uint32_t record_index) const { + if (record_index < error_records_.size()) { + return error_records_[record_index]; + } + return warning_records_[record_index - error_records_.size()]; + } + }; +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/encoding.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/encoding.h new file mode 100644 index 0000000000000..a3996beb103e2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/encoding.h @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) +#include +#endif + +namespace driver { +namespace odbcabstraction { + +#if defined(__APPLE__) +extern std::atomic SqlWCharSize; + +void ComputeSqlWCharSize(); + +inline size_t GetSqlWCharSize() { + if (SqlWCharSize == 0) { + ComputeSqlWCharSize(); + } + + return SqlWCharSize; +} +#else +constexpr inline size_t GetSqlWCharSize() { + return sizeof(char16_t); +} +#endif + +namespace { + +template +inline size_t wcsstrlen(const void *wcs_string) { + size_t len; + for (len = 0; ((CHAR_TYPE *) wcs_string)[len]; len++); + return len; +} + +inline size_t wcsstrlen(const void *wcs_string) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return wcsstrlen(wcs_string); + case sizeof(char32_t): + return wcsstrlen(wcs_string); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +} + +template +inline void Utf8ToWcs(const char *utf8_string, size_t length, std::vector *result) { + thread_local std::wstring_convert, CHAR_TYPE> converter; + auto string = converter.from_bytes(utf8_string, utf8_string + length); + + unsigned long length_in_bytes = string.size() * GetSqlWCharSize(); + const uint8_t *data = (uint8_t*) string.data(); + + result->reserve(length_in_bytes); + result->assign(data, data + length_in_bytes); +} + +inline void Utf8ToWcs(const char *utf8_string, size_t length, std::vector *result) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return Utf8ToWcs(utf8_string, length, result); + case sizeof(char32_t): + return Utf8ToWcs(utf8_string, length, result); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +inline void Utf8ToWcs(const char *utf8_string, std::vector *result) { + return Utf8ToWcs(utf8_string, strlen(utf8_string), result); +} + +template +inline void WcsToUtf8(const void *wcs_string, size_t length_in_code_units, std::vector *result) { + thread_local std::wstring_convert, CHAR_TYPE> converter; + auto byte_string = converter.to_bytes((CHAR_TYPE*) wcs_string, (CHAR_TYPE*) wcs_string + length_in_code_units); + + unsigned long length_in_bytes = byte_string.size(); + const uint8_t *data = (uint8_t*) byte_string.data(); + + result->reserve(length_in_bytes); + result->assign(data, data + length_in_bytes); +} + +inline void WcsToUtf8(const void *wcs_string, size_t length_in_code_units, std::vector *result) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return WcsToUtf8(wcs_string, length_in_code_units, result); + case sizeof(char32_t): + return WcsToUtf8(wcs_string, length_in_code_units, result); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +inline void WcsToUtf8(const void *wcs_string, std::vector *result) { + return WcsToUtf8(wcs_string, wcsstrlen(wcs_string), result); +} + +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/error_codes.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/error_codes.h new file mode 100644 index 0000000000000..50ed8fb105499 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/error_codes.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +namespace driver { +namespace odbcabstraction { + + enum ODBCErrorCodes : int32_t { + ODBCErrorCodes_GENERAL_ERROR = 100, + ODBCErrorCodes_AUTH = 200, + ODBCErrorCodes_TLS = 300, + ODBCErrorCodes_FRACTIONAL_TRUNCATION_ERROR = 400, + ODBCErrorCodes_COMMUNICATION = 500, + ODBCErrorCodes_GENERAL_WARNING = 1000000, + ODBCErrorCodes_TRUNCATION_WARNING = 1000100, + ODBCErrorCodes_FRACTIONAL_TRUNCATION_WARNING = 1000100, + ODBCErrorCodes_INDICATOR_NEEDED = 1000200 + }; +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h new file mode 100644 index 0000000000000..b1fae7b5cb551 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief Base for all driver specific exceptions +class DriverException : public std::exception { +public: + explicit DriverException(std::string message, std::string sql_state = "HY000", + int32_t native_error = ODBCErrorCodes_GENERAL_ERROR); + + const char *what() const throw() override; + + const std::string &GetMessageText() const; + const std::string &GetSqlState() const; + int32_t GetNativeError() const; + +private: + const std::string msg_text_; + const std::string sql_state_; + const int32_t native_error_; +}; + +/// \brief Authentication specific exception +class AuthenticationException : public DriverException { +public: + explicit AuthenticationException(std::string message, std::string sql_state = "28000", + int32_t native_error = ODBCErrorCodes_AUTH); +}; + +/// \brief Communication link specific exception +class CommunicationException : public DriverException { +public: + explicit CommunicationException(std::string message, std::string sql_state = "08S01", + int32_t native_error = ODBCErrorCodes_COMMUNICATION); +}; + +/// \brief Error when null is retrieved from the database but no indicator was supplied. +/// (This means the driver has no way to report ot the application that there was a NULL value). +class NullWithoutIndicatorException : public DriverException { +public: + explicit NullWithoutIndicatorException( + std::string message = "Indicator variable required but not supplied", std::string sql_state = "22002", + int32_t native_error = ODBCErrorCodes_INDICATOR_NEEDED); +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h new file mode 100644 index 0000000000000..cdc9676ecc474 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include + +#define __LAZY_LOG(LEVEL, ...) do { \ + driver::odbcabstraction::Logger *logger = driver::odbcabstraction::Logger::GetInstance(); \ + if (logger) { \ + logger->log(driver::odbcabstraction::LogLevel::LogLevel_##LEVEL, [&]() { \ + return fmt::format(__VA_ARGS__); \ + }); \ + } \ +} while(0) +#define LOG_DEBUG(...) __LAZY_LOG(DEBUG, __VA_ARGS__) +#define LOG_INFO(...) __LAZY_LOG(INFO, __VA_ARGS__) +#define LOG_ERROR(...) __LAZY_LOG(ERROR, __VA_ARGS__) +#define LOG_TRACE(...) __LAZY_LOG(TRACE, __VA_ARGS__) +#define LOG_WARN(...) __LAZY_LOG(WARN, __VA_ARGS__) + +namespace driver { +namespace odbcabstraction { + +enum LogLevel { + LogLevel_TRACE, + LogLevel_DEBUG, + LogLevel_INFO, + LogLevel_WARN, + LogLevel_ERROR, + LogLevel_OFF +}; + +class Logger { +protected: + Logger() = default; + +public: + static Logger *GetInstance(); + static void SetInstance(std::unique_ptr logger); + + virtual ~Logger() = default; + + virtual void log(LogLevel level, const std::function &build_message) = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/AttributeUtils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/AttributeUtils.h new file mode 100644 index 0000000000000..955a56ecd94fd --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/AttributeUtils.h @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace ODBC { +template +inline void GetAttribute(T attributeValue, SQLPOINTER output, O outputSize, + O *outputLenPtr) { + if (output) { + T *typedOutput = reinterpret_cast(output); + *typedOutput = attributeValue; + } + + if (outputLenPtr) { + *outputLenPtr = sizeof(T); + } +} + +template +inline SQLRETURN GetAttributeUTF8(const std::string &attributeValue, + SQLPOINTER output, O outputSize, O *outputLenPtr) { + if (output) { + size_t outputLenBeforeNul = + std::min(static_cast(attributeValue.size()), static_cast(outputSize - 1)); + memcpy(output, attributeValue.c_str(), outputLenBeforeNul); + reinterpret_cast(output)[outputLenBeforeNul] = '\0'; + } + + if (outputLenPtr) { + *outputLenPtr = static_cast(attributeValue.size()); + } + + if (output && outputSize < attributeValue.size() + 1) { + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} + +template +inline SQLRETURN GetAttributeUTF8(const std::string &attributeValue, + SQLPOINTER output, O outputSize, O *outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { + SQLRETURN result = GetAttributeUTF8(attributeValue, output, outputSize, outputLenPtr); + if (SQL_SUCCESS_WITH_INFO == result) { + diagnostics.AddTruncationWarning(); + } + return result; +} + +template +inline SQLRETURN GetAttributeSQLWCHAR(const std::string &attributeValue, bool isLengthInBytes, + SQLPOINTER output, O outputSize, + O *outputLenPtr) { + size_t result = ConvertToSqlWChar( + attributeValue, reinterpret_cast(output), isLengthInBytes ? outputSize : outputSize * GetSqlWCharSize()); + + if (outputLenPtr) { + *outputLenPtr = static_cast(isLengthInBytes ? result : result / GetSqlWCharSize()); + } + + if (output && outputSize < result + (isLengthInBytes ? GetSqlWCharSize() : 1)) { + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} + +template +inline SQLRETURN GetAttributeSQLWCHAR(const std::string &attributeValue, bool isLengthInBytes, + SQLPOINTER output, O outputSize, + O *outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { + SQLRETURN result = GetAttributeSQLWCHAR(attributeValue, isLengthInBytes, output, outputSize, outputLenPtr); + if (SQL_SUCCESS_WITH_INFO == result) { + diagnostics.AddTruncationWarning(); + } + return result; +} + +template +inline SQLRETURN +GetStringAttribute(bool isUnicode, const std::string &attributeValue, bool isLengthInBytes, + SQLPOINTER output, O outputSize, O *outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { + SQLRETURN result = SQL_SUCCESS; + if (isUnicode) { + result = GetAttributeSQLWCHAR(attributeValue, isLengthInBytes, output, outputSize, outputLenPtr); + } else { + result = GetAttributeUTF8(attributeValue, output, outputSize, outputLenPtr); + } + + if (SQL_SUCCESS_WITH_INFO == result) { + diagnostics.AddTruncationWarning(); + } + return result; +} + +template +inline void SetAttribute(SQLPOINTER newValue, T &attributeToWrite) { + SQLLEN valueAsLen = reinterpret_cast(newValue); + attributeToWrite = static_cast(valueAsLen); +} + +template +inline void SetPointerAttribute(SQLPOINTER newValue, T &attributeToWrite) { + attributeToWrite = static_cast(newValue); +} + +inline void SetAttributeUTF8(SQLPOINTER newValue, SQLINTEGER inputLength, + std::string &attributeToWrite) { + const char *newValueAsChar = static_cast(newValue); + attributeToWrite.assign(newValueAsChar, inputLength == SQL_NTS + ? strlen(newValueAsChar) + : inputLength); +} + +inline void SetAttributeSQLWCHAR(SQLPOINTER newValue, + SQLINTEGER inputLengthInBytes, + std::string &attributeToWrite) { + thread_local std::vector utf8_str; + if (inputLengthInBytes == SQL_NTS) { + WcsToUtf8(newValue, &utf8_str); + } else { + WcsToUtf8(newValue, inputLengthInBytes / GetSqlWCharSize(), &utf8_str); + } + attributeToWrite.assign((char *) utf8_str.data()); +} + +template +void CheckIfAttributeIsSetToOnlyValidValue(SQLPOINTER value, T allowed_value) { + if (static_cast(reinterpret_cast(value)) != allowed_value) { + throw driver::odbcabstraction::DriverException("Optional feature not implemented", "HYC00"); + } +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/EncodingUtils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/EncodingUtils.h new file mode 100644 index 0000000000000..2f12a50c7b528 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/EncodingUtils.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING + +namespace ODBC { + using namespace driver::odbcabstraction; + + // Return the number of bytes required for the conversion. + template + inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, SQLLEN bufferSizeInBytes) { + thread_local std::vector wstr; + Utf8ToWcs(str.data(), str.size(), &wstr); + SQLLEN valueLengthInBytes = wstr.size(); + + if (buffer) { + memcpy(buffer, wstr.data(), std::min(static_cast(wstr.size()), bufferSizeInBytes)); + + // Write a NUL terminator + if (bufferSizeInBytes >= valueLengthInBytes + GetSqlWCharSize()) { + reinterpret_cast(buffer)[valueLengthInBytes / GetSqlWCharSize()] = '\0'; + } else { + SQLLEN numCharsWritten = bufferSizeInBytes / GetSqlWCharSize(); + // If we failed to even write one char, the buffer is too small to hold a NUL-terminator. + if (numCharsWritten > 0) { + reinterpret_cast(buffer)[numCharsWritten-1] = '\0'; + } + } + } + return valueLengthInBytes; + } + + inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, SQLLEN bufferSizeInBytes) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return ConvertToSqlWChar(str, buffer, bufferSizeInBytes); + case sizeof(char32_t): + return ConvertToSqlWChar(str, buffer, bufferSizeInBytes); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } + } +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCConnection.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCConnection.h new file mode 100644 index 0000000000000..08ca1790832a9 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCConnection.h @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace ODBC +{ + class ODBCEnvironment; + class ODBCDescriptor; + class ODBCStatement; +} + +/** + * @brief An abstraction over an ODBC connection handle. This also wraps an SPI Connection. + */ +namespace ODBC +{ +class ODBCConnection : public ODBCHandle { + public: + ODBCConnection(const ODBCConnection&) = delete; + ODBCConnection& operator=(const ODBCConnection&) = delete; + + ODBCConnection(ODBCEnvironment& environment, + std::shared_ptr spiConnection); + + driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl(); + + const std::string& GetDSN() const; + bool isConnected() const; + void connect(std::string dsn, const driver::odbcabstraction::Connection::ConnPropertyMap &properties, + std::vector &missing_properties); + + void GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, bool isUnicode); + void SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength, bool isUnicode); + void GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength, bool isUnicode); + + ~ODBCConnection() = default; + + inline ODBCStatement& GetTrackingStatement() { + return *m_attributeTrackingStatement; + } + + void disconnect(); + + void releaseConnection(); + + std::shared_ptr createStatement(); + void dropStatement(ODBCStatement* statement); + + std::shared_ptr createDescriptor(); + void dropDescriptor(ODBCDescriptor* descriptor); + + inline bool IsOdbc2Connection() const { + return m_is2xConnection; + } + + /// @return the DSN or empty string if Driver was used. + static std::string getPropertiesFromConnString(const std::string& connStr, + driver::odbcabstraction::Connection::ConnPropertyMap &properties); + + private: + ODBCEnvironment& m_environment; + std::shared_ptr m_spiConnection; + // Extra ODBC statement that's used to track and validate when statement attributes are + // set through the connection handle. These attributes get copied to new ODBC statements + // when they are allocated. + std::shared_ptr m_attributeTrackingStatement; + std::vector > m_statements; + std::vector > m_descriptors; + std::string m_dsn; + const bool m_is2xConnection; + bool m_isConnected; +}; + +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCDescriptor.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCDescriptor.h new file mode 100644 index 0000000000000..ea6d07a5030a0 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCDescriptor.h @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + class ResultSetMetadata; +} +} +namespace ODBC { + class ODBCConnection; + class ODBCStatement; +} + +namespace ODBC +{ + struct DescriptorRecord { + std::string m_baseColumnName; + std::string m_baseTableName; + std::string m_catalogName; + std::string m_label; + std::string m_literalPrefix; + std::string m_literalSuffix; + std::string m_localTypeName; + std::string m_name; + std::string m_schemaName; + std::string m_tableName; + std::string m_typeName; + SQLPOINTER m_dataPtr = NULL; + SQLLEN* m_indicatorPtr = NULL; + SQLLEN m_displaySize = 0; + SQLLEN m_octetLength = 0; + SQLULEN m_length = 0; + SQLINTEGER m_autoUniqueValue; + SQLINTEGER m_caseSensitive = SQL_TRUE; + SQLINTEGER m_datetimeIntervalPrecision = 0; + SQLINTEGER m_numPrecRadix = 0; + SQLSMALLINT m_conciseType = SQL_C_DEFAULT; + SQLSMALLINT m_datetimeIntervalCode = 0; + SQLSMALLINT m_fixedPrecScale = 0; + SQLSMALLINT m_nullable = SQL_NULLABLE_UNKNOWN; + SQLSMALLINT m_paramType = SQL_PARAM_INPUT; + SQLSMALLINT m_precision = 0; + SQLSMALLINT m_rowVer = 0; + SQLSMALLINT m_scale = 0; + SQLSMALLINT m_searchable = SQL_SEARCHABLE; + SQLSMALLINT m_type = SQL_C_DEFAULT; + SQLSMALLINT m_unnamed = SQL_TRUE; + SQLSMALLINT m_unsigned = SQL_FALSE; + SQLSMALLINT m_updatable = SQL_FALSE; + bool m_isBound = false; + + void CheckConsistency(); + }; + + class ODBCDescriptor : public ODBCHandle{ + public: + /** + * @brief Construct a new ODBCDescriptor object. Link the descriptor to a connection, + * if applicable. A nullptr should be supplied for conn if the descriptor should not be linked. + */ + ODBCDescriptor(driver::odbcabstraction::Diagnostics& baseDiagnostics, + ODBCConnection* conn, ODBCStatement* stmt, bool isAppDescriptor, bool isWritable, bool is2xConnection); + + driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl(); + + ODBCConnection &GetConnection(); + + void SetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength); + void SetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength); + void GetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength) const; + void GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength); + SQLSMALLINT getAllocType() const; + bool IsAppDescriptor() const; + + inline bool HaveBindingsChanged() const { + return m_hasBindingsChanged; + } + + void RegisterToStatement(ODBCStatement* statement, bool isApd); + void DetachFromStatement(ODBCStatement* statement, bool isApd); + void ReleaseDescriptor(); + + void PopulateFromResultSetMetadata(driver::odbcabstraction::ResultSetMetadata* rsmd); + + const std::vector& GetRecords() const; + std::vector& GetRecords(); + + void BindCol(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr); + void SetDataPtrOnRecord(SQLPOINTER dataPtr, SQLSMALLINT recNumber); + + inline SQLULEN GetBindOffset() { + return m_bindOffsetPtr ? *m_bindOffsetPtr : 0UL; + } + + inline SQLULEN GetBoundStructOffset() { + // If this is SQL_BIND_BY_COLUMN, m_bindType is zero which indicates no offset due to use of a bound struct. + // If this is non-zero, row-wise binding is being used so the app should set this to sizeof(their struct). + return m_bindType; + } + + inline SQLULEN GetArraySize() { + return m_arraySize; + } + + inline SQLUSMALLINT* GetArrayStatusPtr() { + return m_arrayStatusPtr; + } + + inline void SetRowsProcessed(SQLULEN rows) { + if (m_rowsProccessedPtr) { + *m_rowsProccessedPtr = rows; + } + } + + inline void NotifyBindingsHavePropagated() { + m_hasBindingsChanged = false; + } + + inline void NotifyBindingsHaveChanged() { + m_hasBindingsChanged = true; + } + + private: + driver::odbcabstraction::Diagnostics m_diagnostics; + std::vector m_registeredOnStatementsAsApd; + std::vector m_registeredOnStatementsAsArd; + std::vector m_records; + ODBCConnection* m_owningConnection; + ODBCStatement* m_parentStatement; + SQLUSMALLINT* m_arrayStatusPtr; + SQLULEN* m_bindOffsetPtr; + SQLULEN* m_rowsProccessedPtr; + SQLULEN m_arraySize; + SQLINTEGER m_bindType; + SQLSMALLINT m_highestOneBasedBoundRecord; + const bool m_is2xConnection; + bool m_isAppDescriptor; + bool m_isWritable; + bool m_hasBindingsChanged; + }; +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCEnvironment.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCEnvironment.h new file mode 100644 index 0000000000000..a10b1d5feb9e6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCEnvironment.h @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + class Driver; +} +} + +namespace ODBC { + class ODBCConnection; +} + +/** + * @brief An abstraction over an ODBC environment handle. + */ +namespace ODBC +{ +class ODBCEnvironment : public ODBCHandle { + public: + ODBCEnvironment(std::shared_ptr driver); + driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl(); + SQLINTEGER getODBCVersion() const; + void setODBCVersion(SQLINTEGER version); + SQLINTEGER getConnectionPooling() const; + void setConnectionPooling(SQLINTEGER pooling); + std::shared_ptr CreateConnection(); + void DropConnection(ODBCConnection* conn); + ~ODBCEnvironment() = default; + + private: + std::vector > m_connections; + std::shared_ptr m_driver; + std::unique_ptr m_diagnostics; + SQLINTEGER m_version; + SQLINTEGER m_connectionPooling; +}; + +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCHandle.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCHandle.h new file mode 100644 index 0000000000000..c97c3e54d6514 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCHandle.h @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +/** + * @brief An abstraction over a generic ODBC handle. + */ +namespace ODBC { + +template +class ODBCHandle { + +public: + inline driver::odbcabstraction::Diagnostics& GetDiagnostics() { + return static_cast(this)->GetDiagnostics_Impl(); + } + + inline driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl() { + throw std::runtime_error("Illegal state -- diagnostics requested on invalid handle"); + } + + template + inline SQLRETURN execute(SQLRETURN rc, Function function) { + try { + GetDiagnostics().Clear(); + rc = function(); + } catch (const driver::odbcabstraction::DriverException& ex) { + GetDiagnostics().AddError(ex); + } catch (const std::bad_alloc& ex) { + GetDiagnostics().AddError( + driver::odbcabstraction::DriverException("A memory allocation error occurred.", "HY001")); + } catch (const std::exception& ex) { + GetDiagnostics().AddError( + driver::odbcabstraction::DriverException(ex.what())); + } catch (...) { + GetDiagnostics().AddError( + driver::odbcabstraction::DriverException("An unknown error occurred.")); + } + + if (GetDiagnostics().HasError()) { + return SQL_ERROR; + } if (SQL_SUCCEEDED(rc) && GetDiagnostics().HasWarning()) { + return SQL_SUCCESS_WITH_INFO; + } + return rc; + } + + template + inline SQLRETURN executeWithLock(SQLRETURN rc, Function function) { + const std::lock_guard lock(mtx_); + return execute(rc, function); + } + + template + static inline SQLRETURN ExecuteWithDiagnostics(SQLHANDLE handle, SQLRETURN rc, Function func) { + if (!handle) { + return SQL_INVALID_HANDLE; + } + if (SHOULD_LOCK) { + return reinterpret_cast(handle)->executeWithLock(rc, func); + } else { + return reinterpret_cast(handle)->execute(rc, func); + } + } + + static Derived* of(SQLHANDLE handle) { + return reinterpret_cast(handle); + } + +private: + std::mutex mtx_; +}; +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCStatement.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCStatement.h new file mode 100644 index 0000000000000..ec2c191a4b281 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCStatement.h @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + class Statement; + class ResultSet; +} +} + +namespace ODBC { + class ODBCConnection; + class ODBCDescriptor; +} + +/** + * @brief An abstraction over an ODBC connection handle. This also wraps an SPI Connection. + */ +namespace ODBC +{ +class ODBCStatement : public ODBCHandle { + public: + ODBCStatement(const ODBCStatement&) = delete; + ODBCStatement& operator=(const ODBCStatement&) = delete; + + ODBCStatement(ODBCConnection& connection, + std::shared_ptr spiStatement); + + ~ODBCStatement() = default; + + inline driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl() { + return *m_diagnostics; + } + + ODBCConnection &GetConnection(); + + void CopyAttributesFromConnection(ODBCConnection& connection); + void Prepare(const std::string& query); + void ExecutePrepared(); + void ExecuteDirect(const std::string& query); + + /** + * @brief Returns true if the number of rows fetch was greater than zero. + */ + bool Fetch(size_t rows); + bool isPrepared() const; + + void GetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER output, + SQLINTEGER bufferSize, SQLINTEGER *strLenPtr, bool isUnicode); + void SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, + SQLINTEGER bufferSize, bool isUnicode); + + void RevertAppDescriptor(bool isApd); + + inline ODBCDescriptor* GetIRD() { + return m_ird.get(); + } + + inline ODBCDescriptor* GetARD() { + return m_currentArd; + } + + inline SQLULEN GetRowsetSize() { + return m_rowsetSize; + } + + bool GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr); + + /** + * @brief Closes the cursor. This does _not_ un-prepare the statement or change + * bindings. + */ + void closeCursor(bool suppressErrors); + + /** + * @brief Releases this statement from memory. + */ + void releaseStatement(); + + void GetTables(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* tableType); + void GetColumns(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* column); + void GetTypeInfo(SQLSMALLINT dataType); + void Cancel(); + + private: + ODBCConnection& m_connection; + std::shared_ptr m_spiStatement; + std::shared_ptr m_currenResult; + driver::odbcabstraction::Diagnostics* m_diagnostics; + + std::shared_ptr m_builtInArd; + std::shared_ptr m_builtInApd; + std::shared_ptr m_ipd; + std::shared_ptr m_ird; + ODBCDescriptor* m_currentArd; + ODBCDescriptor* m_currentApd; + SQLULEN m_rowNumber; + SQLULEN m_maxRows; + SQLULEN m_rowsetSize; // Used by SQLExtendedFetch instead of the ARD array size. + bool m_isPrepared; + bool m_hasReachedEndOfResult; +}; +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/TypeUtilities.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/TypeUtilities.h new file mode 100644 index 0000000000000..1fd44643cfe8d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/TypeUtilities.h @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace ODBC { + inline SQLSMALLINT GetSqlTypeForODBCVersion(SQLSMALLINT type, bool isOdbc2x) { + switch (type) { + case SQL_DATE: + case SQL_TYPE_DATE: + return isOdbc2x ? SQL_DATE : SQL_TYPE_DATE; + + case SQL_TIME: + case SQL_TYPE_TIME: + return isOdbc2x ? SQL_TIME : SQL_TYPE_TIME; + + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + return isOdbc2x ? SQL_TIMESTAMP : SQL_TYPE_TIMESTAMP; + + default: + return type; + } + } +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h new file mode 100644 index 0000000000000..b1862e7f92a89 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if defined(_WIN32) + // NOMINMAX avoids std::min/max being defined as a c macro + #ifndef NOMINMAX + #define NOMINMAX + #endif + + // Avoid including extraneous Windows headers. + #ifndef WIN32_LEAN_AND_MEAN + #define WIN32_LEAN_AND_MEAN + #endif + + #include + + #include + #include + + #include + typedef SSIZE_T ssize_t; + +#endif diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spd_logger.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spd_logger.h new file mode 100644 index 0000000000000..022bf9e8cb207 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spd_logger.h @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "odbcabstraction/logger.h" + +#include +#include + +#include + +namespace driver { +namespace odbcabstraction { + +class SPDLogger : public Logger { +protected: + std::shared_ptr logger_; + +public: + static const std::string LOG_LEVEL; + static const std::string LOG_PATH; + static const std::string MAXIMUM_FILE_SIZE; + static const std::string FILE_QUANTITY; + static const std::string LOG_ENABLED; + + SPDLogger() = default; + ~SPDLogger(); + SPDLogger(SPDLogger &other) = delete; + + void operator=(const SPDLogger &) = delete; + void init(int64_t fileQuantity, int64_t maxFileSize, + const std::string &fileNamePrefix, LogLevel level); + + void log(LogLevel level, const std::function &build_message) override; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h new file mode 100644 index 0000000000000..7e403ba3063a3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief Case insensitive comparator +struct CaseInsensitiveComparator { + bool operator()(const std::string &s1, const std::string &s2) const { + return boost::lexicographical_compare(s1, s2, boost::is_iless()); + } +}; + +// PropertyMap is case-insensitive for keys. +typedef std::map PropertyMap; + +class Statement; + +/// \brief High-level representation of an ODBC connection. +class Connection { +protected: + Connection() = default; + +public: + virtual ~Connection() = default; + + /// \brief Connection attributes + enum AttributeId { + ACCESS_MODE, // uint32_t - Tells if it should support write operations + CONNECTION_DEAD, // uint32_t - Tells if connection is still alive + CONNECTION_TIMEOUT, // uint32_t - The timeout for connection functions after connecting. + CURRENT_CATALOG, // std::string - The current catalog + LOGIN_TIMEOUT, // uint32_t - The timeout for the initial connection + PACKET_SIZE, // uint32_t - The Packet Size + }; + + typedef boost::variant Attribute; + typedef boost::variant Info; + typedef PropertyMap ConnPropertyMap; + + /// \brief Establish the connection. + /// \param properties[in] properties used to establish the connection. + /// \param missing_properties[out] vector of missing properties (if any). + virtual void Connect(const ConnPropertyMap &properties, + std::vector &missing_properties) = 0; + + /// \brief Close the connection. + virtual void Close() = 0; + + /// \brief Create a statement. + virtual std::shared_ptr CreateStatement() = 0; + + /// \brief Set a connection attribute (may be called at any time). + /// \param attribute[in] Which attribute to set. + /// \param value The value to be set. + /// \return true if the value was set successfully or false if it was substituted with + /// a similar value. + virtual bool SetAttribute(AttributeId attribute, const Attribute &value) = 0; + + /// \brief Retrieve a connection attribute + /// \param attribute[in] Attribute to be retrieved. + virtual boost::optional + GetAttribute(Connection::AttributeId attribute) = 0; + + /// \brief Retrieves info from the database (see ODBC's SQLGetInfo). + virtual Info GetInfo(uint16_t info_type) = 0; + + /// \brief Gets the diagnostics for this connection. + /// \return the diagnostics + virtual Diagnostics& GetDiagnostics() = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h new file mode 100644 index 0000000000000..f3bfc275aa021 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + +class Connection; + +/// \brief High-level representation of an ODBC driver. +class Driver { +protected: + Driver() = default; + +public: + virtual ~Driver() = default; + + /// \brief Create a connection using given ODBC version. + /// \param odbc_version ODBC version to be used. + virtual std::shared_ptr + CreateConnection(OdbcVersion odbc_version) = 0; + + /// \brief Gets the diagnostics for this connection. + /// \return the diagnostics + virtual Diagnostics& GetDiagnostics() = 0; + + /// \brief Sets the driver version. + virtual void SetVersion(std::string version) = 0; + + /// \brief Register a log to be used by the system. + virtual void RegisterLog() = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h new file mode 100644 index 0000000000000..06a22b5597571 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include + +#include + +namespace driver { +namespace odbcabstraction { + +class ResultSetMetadata; + +class ResultSet { +protected: + ResultSet() = default; + +public: + virtual ~ResultSet() = default; + + /// \brief Returns metadata for this ResultSet. + virtual std::shared_ptr GetMetadata() = 0; + + /// \brief Closes ResultSet, releasing any resources allocated by it. + virtual void Close() = 0; + + /// \brief Cancels ResultSet. + virtual void Cancel() = 0; + + /// \brief Binds a column with a result buffer. The buffer will be filled with + /// up to `GetMaxBatchSize()` values. + /// + /// \param column Column number to be bound with (starts from 1). + /// \param target_type Target data type expected by client. + /// \param precision Column's precision + /// \param scale Column's scale + /// \param buffer Target buffer to be filled with column values. + /// \param buffer_length Target buffer length. + /// \param strlen_buffer Buffer that holds the length of each value contained + /// on target buffer. + virtual void BindColumn(int column, int16_t target_type, int precision, + int scale, void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) = 0; + + /// \brief Fetches next rows from ResultSet and load values on buffers + /// previously bound with `BindColumn`. + /// + /// The parameters `buffer` and `strlen_buffer` passed to `BindColumn()` + /// should have capacity to accommodate the rows requested, otherwise data + /// will be truncated. + /// + /// \param rows The maximum number of rows to be fetched. + /// \param bind_offset The offset for bound columns and indicators. + /// \param bind_type The type of binding. Zero indicates columnar binding, non-zero indicates + /// that this holds the size of an application row buffer. This corresponds + /// directly to SQL_DESC_BIND_TYPE in ODBC. + /// \param row_status_array The array to write statuses. + /// \returns The number of rows fetched. + virtual size_t Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t *row_status_array) = 0; + + /// \brief Populates `buffer` with the value on current row for given column. + /// If the value doesn't fit the buffer this method returns true and + /// subsequent calls will fetch the rest of data. + /// + /// \param column Column number to be fetched. + /// \param target_type Target data type expected by client. + /// \param precision Column's precision + /// \param scale Column's scale + /// \param buffer Target buffer to be populated. + /// \param buffer_length Target buffer length. + /// \param strlen_buffer Buffer that holds the length of value being fetched. + /// \returns true if there is more data to fetch from the current cell; + /// false if the whole value was already fetched. + virtual bool GetData(int column, int16_t target_type, int precision, + int scale, void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h new file mode 100644 index 0000000000000..45a1b20f3ffb6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief High Level representation of the ResultSetMetadata from ODBC. +class ResultSetMetadata { +protected: + ResultSetMetadata() = default; + +public: + virtual ~ResultSetMetadata() = default; + + /// \brief It returns the total amount of the columns in the ResultSet. + /// \return the amount of columns. + virtual size_t GetColumnCount() = 0; + + /// \brief It retrieves the name of a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the column name. + virtual std::string GetColumnName(int column_position) = 0; + + /// \brief It retrieves the size of a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the column size. + virtual size_t GetPrecision(int column_position) = 0; + + /// \brief It retrieves the total of number of decimal digits. + /// \param column_position[in] the position of the column, starting from 1. + /// \return amount of decimal digits. + virtual size_t GetScale(int column_position) = 0; + + /// \brief It retrieves the SQL_DATA_TYPE of the column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the SQL_DATA_TYPE + virtual uint16_t GetDataType(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column can have + /// null values. + /// \param column_position[in] the position of the column, starting from 1. + /// \return true if column is nullable. + virtual Nullability IsNullable(int column_position) = 0; + + /// \brief It returns the Schema name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the Schema name for given column. + virtual std::string GetSchemaName(int column_position) = 0; + + /// \brief It returns the Catalog Name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the catalog name for given column. + virtual std::string GetCatalogName(int column_position) = 0; + + /// \brief It returns the Table Name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the Table name for given column. + virtual std::string GetTableName(int column_position) = 0; + + /// \brief It retrieves the column label. + /// \param column_position[in] the position of the column, starting from 1. + /// \return column label. + virtual std::string GetColumnLabel(int column_position) = 0; + + /// \brief It retrieves the designated column's normal maximum width in + /// characters. + /// \param column_position[in] the position of the column, starting from 1. + /// \return column normal maximum width. + virtual size_t GetColumnDisplaySize(int column_position) = 0; + + /// \brief It retrieves the base name for the column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the base column name. + virtual std::string GetBaseColumnName(int column_position) = 0; + + /// \brief It retrieves the base table name that contains the column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the base table name. + virtual std::string GetBaseTableName(int column_position) = 0; + + /// \brief It retrieves the concise data type (SQL_DESC_CONCISE_TYPE). + /// \param column_position[in] the position of the column, starting from 1. + /// \return the concise data type. + virtual uint16_t GetConciseType(int column_position) = 0; + + /// \brief It retrieves the maximum or the actual character length + /// of a character string or binary data type. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the maximum length + virtual size_t GetLength(int column_position) = 0; + + /// \brief It retrieves the character or characters that the driver uses + /// as prefix for literal values. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the prefix character(s). + virtual std::string GetLiteralPrefix(int column_position) = 0; + + /// \brief It retrieves the character or characters that the driver uses + /// as prefix for literal values. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the suffix character(s). + virtual std::string GetLiteralSuffix(int column_position) = 0; + + /// \brief It retrieves the local type name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the local type name. + virtual std::string GetLocalTypeName(int column_position) = 0; + + /// \brief It returns the column name alias. If it has no alias + /// it returns the column name. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the column name alias. + virtual std::string GetName(int column_position) = 0; + + /// \brief It returns a numeric value to indicate if the data + /// is an approximate or exact numeric data type. + /// \param column_position[in] the position of the column, starting from 1. + virtual size_t GetNumPrecRadix(int column_position) = 0; + + /// \brief It returns the length in bytes from a string or binary data. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the length in bytes. + virtual size_t GetOctetLength(int column_position) = 0; + + /// \brief It returns the data type as a string. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the data type string. + virtual std::string GetTypeName(int column_position) = 0; + + /// \brief It returns a numeric values indicate the updatability of the + /// column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the updatability of the column. + virtual Updatability GetUpdatable(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column is + /// autoincrementing. + /// \param column_position[in] the position of the column, starting from 1. + /// \return boolean values if column is auto incremental. + virtual bool IsAutoUnique(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column is + /// case sensitive. + /// \param column_position[in] the position of the column, starting from 1. + /// \return boolean values if column is case sensitive. + virtual bool IsCaseSensitive(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column can be used + /// in where clauses. + /// \param column_position[in] the position of the column, starting from 1. + /// \return boolean values if column can be used in where clauses. + virtual Searchability IsSearchable(int column_position) = 0; + + /// \brief It checks if a numeric column is signed or unsigned. + /// \param column_position[in] the position of the column, starting from 1. + /// \return check if the column is signed or not. + virtual bool IsUnsigned(int column_position) = 0; + + /// \brief It check if the columns has fixed precision and a nonzero + /// scale. + /// \param column_position[in] the position of the column, starting from 1. + /// \return if column has a fixed precision and non zero scale. + virtual bool IsFixedPrecScale(int column_position) = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/statement.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/statement.h new file mode 100644 index 0000000000000..7b557ec065d6e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/statement.h @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +using boost::optional; + +class ResultSet; + +class ResultSetMetadata; + +/// \brief High-level representation of an ODBC statement. +class Statement { +protected: + Statement() = default; + +public: + virtual ~Statement() = default; + + /// \brief Statement attributes that can be called at anytime. + ////TODO: Document attributes + enum StatementAttributeId { + MAX_LENGTH, // size_t - The maximum length when retrieving variable length data. 0 means no limit. + METADATA_ID, // size_t - Modifies catalog function arguments to be identifiers. SQL_TRUE or SQL_FALSE. + NOSCAN, // size_t - Indicates that the driver does not scan for escape sequences. Default to SQL_NOSCAN_OFF + QUERY_TIMEOUT, // size_t - The time to wait in seconds for queries to execute. 0 to have no timeout. + }; + + typedef boost::variant Attribute; + + /// \brief Set a statement attribute (may be called at any time) + /// + /// NOTE: Meant to be bound with SQLSetStmtAttr. + /// + /// \param attribute Attribute identifier to set. + /// \param value Value to be associated with the attribute. + /// \return true if the value was set successfully or false if it was substituted with + /// a similar value. + virtual bool SetAttribute(StatementAttributeId attribute, + const Attribute &value) = 0; + + /// \brief Retrieve a statement attribute. + /// + /// NOTE: Meant to be bound with SQLGetStmtAttr. + /// + /// \param attribute Attribute identifier to be retrieved. + /// \return Value associated with the attribute. + virtual optional + GetAttribute(Statement::StatementAttributeId attribute) = 0; + + /// \brief Prepares the statement. + /// Returns ResultSetMetadata if query returns a result set, + /// otherwise it returns `boost::none`. + /// \param query The SQL query to prepare. + virtual boost::optional> + Prepare(const std::string &query) = 0; + + /// \brief Execute the prepared statement. + /// + /// NOTE: Must call `Prepare(const std::string &query)` before, otherwise it + /// will throw an exception. + /// + /// \returns true if the first result is a ResultSet object; + /// false if it is an update count or there are no results. + virtual bool ExecutePrepared() = 0; + + /// \brief Execute the statement if it is prepared or not. + /// \param query The SQL query to execute. + /// \returns true if the first result is a ResultSet object; + /// false if it is an update count or there are no results. + virtual bool Execute(const std::string &query) = 0; + + /// \brief Returns the current result as a ResultSet object. + virtual std::shared_ptr GetResultSet() = 0; + + /// \brief Retrieves the current result as an update count; + /// if the result is a ResultSet object or there are no more results, -1 is + /// returned. + virtual long GetUpdateCount() = 0; + + /// \brief Returns the list of table, catalog, or schema names, and table + /// types, stored in a specific data source. The driver returns the + /// information as a result set. + /// + /// NOTE: This is meant to be used by ODBC 2.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param table_type The table type. + virtual std::shared_ptr + GetTables_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *table_type) = 0; + + /// \brief Returns the list of table, catalog, or schema names, and table + /// types, stored in a specific data source. The driver returns the + /// information as a result set. + /// + /// NOTE: This is meant to be used by ODBC 3.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param table_type The table type. + virtual std::shared_ptr + GetTables_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *table_type) = 0; + + /// \brief Returns the list of column names in specified tables. The driver + /// returns this information as a result set.. + /// + /// NOTE: This is meant to be used by ODBC 2.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param column_name The column name. + virtual std::shared_ptr + GetColumns_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *column_name) = 0; + + /// \brief Returns the list of column names in specified tables. The driver + /// returns this information as a result set.. + /// + /// NOTE: This is meant to be used by ODBC 3.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param column_name The column name. + virtual std::shared_ptr + GetColumns_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *column_name) = 0; + + /// \brief Returns information about data types supported by the data source. + /// The driver returns the information in the form of an SQL result set. The + /// data types are intended for use in Data Definition Language (DDL) + /// statements. + /// + /// NOTE: This is meant to be used by ODBC 2.x binding. + /// + /// \param data_type The SQL data type. + virtual std::shared_ptr GetTypeInfo_V2(int16_t data_type) = 0; + + /// \brief Returns information about data types supported by the data source. + /// The driver returns the information in the form of an SQL result set. The + /// data types are intended for use in Data Definition Language (DDL) + /// statements. + /// + /// NOTE: This is meant to be used by ODBC 3.x binding. + /// + /// \param data_type The SQL data type. + virtual std::shared_ptr GetTypeInfo_V3(int16_t data_type) = 0; + + /// \brief Gets the diagnostics for this statement. + /// \return the diagnostics + virtual Diagnostics& GetDiagnostics() = 0; + + /// \brief Cancels the processing of this statement. + virtual void Cancel() = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h new file mode 100644 index 0000000000000..959a20a521766 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h @@ -0,0 +1,184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief Supported ODBC versions. +enum OdbcVersion { V_2, V_3, V_4 }; + +// Based on ODBC sql.h and sqlext.h definitions. +enum SqlDataType : int16_t { + SqlDataType_CHAR = 1, + SqlDataType_VARCHAR = 12, + SqlDataType_LONGVARCHAR = (-1), + SqlDataType_WCHAR = (-8), + SqlDataType_WVARCHAR = (-9), + SqlDataType_WLONGVARCHAR = (-10), + SqlDataType_DECIMAL = 3, + SqlDataType_NUMERIC = 2, + SqlDataType_SMALLINT = 5, + SqlDataType_INTEGER = 4, + SqlDataType_REAL = 7, + SqlDataType_FLOAT = 6, + SqlDataType_DOUBLE = 8, + SqlDataType_BIT = (-7), + SqlDataType_TINYINT = (-6), + SqlDataType_BIGINT = (-5), + SqlDataType_BINARY = (-2), + SqlDataType_VARBINARY = (-3), + SqlDataType_LONGVARBINARY = (-4), + SqlDataType_TYPE_DATE = 91, + SqlDataType_TYPE_TIME = 92, + SqlDataType_TYPE_TIMESTAMP = 93, + SqlDataType_INTERVAL_MONTH = (100 + 2), + SqlDataType_INTERVAL_YEAR = (100 + 1), + SqlDataType_INTERVAL_YEAR_TO_MONTH = (100 + 7), + SqlDataType_INTERVAL_DAY = (100 + 3), + SqlDataType_INTERVAL_HOUR = (100 + 4), + SqlDataType_INTERVAL_MINUTE = (100 + 5), + SqlDataType_INTERVAL_SECOND = (100 + 6), + SqlDataType_INTERVAL_DAY_TO_HOUR = (100 + 8), + SqlDataType_INTERVAL_DAY_TO_MINUTE = (100 + 9), + SqlDataType_INTERVAL_DAY_TO_SECOND = (100 + 10), + SqlDataType_INTERVAL_HOUR_TO_MINUTE = (100 + 11), + SqlDataType_INTERVAL_HOUR_TO_SECOND = (100 + 12), + SqlDataType_INTERVAL_MINUTE_TO_SECOND = (100 + 13), + SqlDataType_GUID = (-11), +}; + +enum SqlDateTimeSubCode : int16_t { + SqlDateTimeSubCode_DATE = 1, + SqlDateTimeSubCode_TIME = 2, + SqlDateTimeSubCode_TIMESTAMP = 3, + SqlDateTimeSubCode_YEAR = 1, + SqlDateTimeSubCode_MONTH = 2, + SqlDateTimeSubCode_DAY = 3, + SqlDateTimeSubCode_HOUR = 4, + SqlDateTimeSubCode_MINUTE = 5, + SqlDateTimeSubCode_SECOND = 6, + SqlDateTimeSubCode_YEAR_TO_MONTH = 7, + SqlDateTimeSubCode_DAY_TO_HOUR = 8, + SqlDateTimeSubCode_DAY_TO_MINUTE = 9, + SqlDateTimeSubCode_DAY_TO_SECOND = 10, + SqlDateTimeSubCode_HOUR_TO_MINUTE = 11, + SqlDateTimeSubCode_HOUR_TO_SECOND = 12, + SqlDateTimeSubCode_MINUTE_TO_SECOND = 13, +}; + +// Based on ODBC sql.h and sqlext.h definitions. +enum CDataType { + CDataType_CHAR = 1, + CDataType_WCHAR = -8, + CDataType_SSHORT = (5 + (-20)), + CDataType_USHORT = (5 + (-22)), + CDataType_SLONG = (4 + (-20)), + CDataType_ULONG = (4 + (-22)), + CDataType_FLOAT = 7, + CDataType_DOUBLE = 8, + CDataType_BIT = -7, + CDataType_DATE = 91, + CDataType_TIME = 92, + CDataType_TIMESTAMP = 93, + CDataType_STINYINT = ((-6) + (-20)), + CDataType_UTINYINT = ((-6) + (-22)), + CDataType_SBIGINT = ((-5) + (-20)), + CDataType_UBIGINT = ((-5) + (-22)), + CDataType_BINARY = (-2), + CDataType_NUMERIC = 2, + CDataType_DEFAULT = 99, +}; + +enum Nullability { + NULLABILITY_NO_NULLS = 0, + NULLABILITY_NULLABLE = 1, + NULLABILITY_UNKNOWN = 2, +}; + +enum Searchability { + SEARCHABILITY_NONE = 0, + SEARCHABILITY_LIKE_ONLY = 1, + SEARCHABILITY_ALL_EXPECT_LIKE = 2, + SEARCHABILITY_ALL = 3, +}; + +enum Updatability { + UPDATABILITY_READONLY = 0, + UPDATABILITY_WRITE = 1, + UPDATABILITY_READWRITE_UNKNOWN = 2, +}; + +constexpr ssize_t NULL_DATA = -1; +constexpr ssize_t NO_TOTAL = -4; +constexpr ssize_t ALL_TYPES = 0; +constexpr ssize_t DAYS_TO_SECONDS_MULTIPLIER = 86400; +constexpr ssize_t MILLI_TO_SECONDS_DIVISOR = 1000; +constexpr ssize_t MICRO_TO_SECONDS_DIVISOR = 1000000; +constexpr ssize_t NANO_TO_SECONDS_DIVISOR = 1000000000; + +typedef struct tagDATE_STRUCT +{ + int16_t year; + uint16_t month; + uint16_t day; +} DATE_STRUCT; + +typedef struct tagTIME_STRUCT +{ + uint16_t hour; + uint16_t minute; + uint16_t second; +} TIME_STRUCT; + +typedef struct tagTIMESTAMP_STRUCT +{ + int16_t year; + uint16_t month; + uint16_t day; + uint16_t hour; + uint16_t minute; + uint16_t second; + uint32_t fraction; +} TIMESTAMP_STRUCT; + +typedef struct tagNUMERIC_STRUCT { + uint8_t precision; + int8_t scale; + uint8_t sign; // The sign field is 1 if positive, 0 if negative. + uint8_t val[16]; //[e], [f] +} NUMERIC_STRUCT; + +enum RowStatus: uint16_t { + RowStatus_SUCCESS = 0, // Same as SQL_ROW_SUCCESS + RowStatus_SUCCESS_WITH_INFO = 6, // Same as SQL_ROW_SUCCESS_WITH_INFO + RowStatus_ERROR = 5, // Same as SQL_ROW_ERROR + RowStatus_NOROW = 3 // Same as SQL_ROW_NOROW +}; + +struct MetadataSettings { + boost::optional string_column_length_{boost::none}; + size_t chunk_buffer_capacity_; + bool use_wide_char_; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h new file mode 100644 index 0000000000000..4db2765d7a29d --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +using driver::odbcabstraction::Connection; + +/// Parse a string value to a boolean. +/// \param value the value to be parsed. +/// \return the parsed valued. +boost::optional AsBool(const std::string& value); + +/// Looks up for a value inside the ConnPropertyMap and then try to parse it. +/// In case it does not find or it cannot parse, the default value will be returned. +/// \param connPropertyMap the map with the connection properties. +/// \param property_name the name of the property that will be looked up. +/// \return the parsed valued. +boost::optional AsBool(const Connection::ConnPropertyMap& connPropertyMap, const std::string& property_name); + +/// Looks up for a value inside the ConnPropertyMap and then try to parse it. +/// In case it does not find or it cannot parse, the default value will be returned. +/// \param min_value the minimum value to be parsed, else the default value is returned. +/// \param connPropertyMap the map with the connection properties. +/// \param property_name the name of the property that will be looked up. +/// \return the parsed valued. +/// \exception std::invalid_argument exception from \link std::stoi \endlink +/// \exception std::out_of_range exception from \link std::stoi \endlink +boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& connPropertyMap, + const std::string& property_name); + + +void ReadConfigFile(PropertyMap &properties, const std::string &configFileName); + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc new file mode 100644 index 0000000000000..1457f28166974 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +namespace driver { +namespace odbcabstraction { + +static std::unique_ptr odbc_logger_ = nullptr; + +Logger *Logger::GetInstance() { + return odbc_logger_.get(); +} + +void Logger::SetInstance(std::unique_ptrlogger) { + odbc_logger_ = std::move(logger); +} + +} +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCConnection.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCConnection.cc new file mode 100644 index 0000000000000..7796280e508cc --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCConnection.cc @@ -0,0 +1,736 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; +using driver::odbcabstraction::Connection; +using driver::odbcabstraction::DriverException; + +namespace +{ + // Key-value pairs separated by semi-colon. + // Note that the value can be wrapped in curly braces to escape other significant characters + // such as semi-colons and equals signs. + // NOTE: This can be optimized to be built statically. + const boost::xpressive::sregex CONNECTION_STR_REGEX(boost::xpressive::sregex::compile( + "([^=;]+)=({.+}|[^=;]+|[^;])")); + +// Load properties from the given DSN. The properties loaded do _not_ overwrite existing +// entries in the properties. +void loadPropertiesFromDSN(const std::string& dsn, Connection::ConnPropertyMap& properties) { + const size_t BUFFER_SIZE = 1024 * 10; + std::vector outputBuffer; + outputBuffer.resize(BUFFER_SIZE, '\0'); + SQLSetConfigMode(ODBC_BOTH_DSN); + SQLGetPrivateProfileString(dsn.c_str(), NULL, "", &outputBuffer[0], BUFFER_SIZE, "odbc.ini"); + + // The output buffer holds the list of keys in a series of NUL-terminated strings. + // The series is terminated with an empty string (eg a NUL-terminator terminating the last + // key followed by a NUL terminator after). + std::vector keys; + size_t pos = 0; + while (pos < BUFFER_SIZE) { + std::string key(&outputBuffer[pos]); + if (key.empty()) { + break; + } + size_t len = key.size(); + + // Skip over Driver or DSN keys. + if (!boost::iequals(key, "DSN") && + !boost::iequals(key, "Driver")) { + keys.emplace_back(std::move(key)); + } + pos += len + 1; + } + + for (auto& key : keys) { + outputBuffer.clear(); + outputBuffer.resize(BUFFER_SIZE, '\0'); + SQLGetPrivateProfileString(dsn.c_str(), key.c_str(), "", &outputBuffer[0], BUFFER_SIZE, "odbc.ini"); + std::string value = std::string(&outputBuffer[0]); + auto propIter = properties.find(key); + if (propIter == properties.end()) { + properties.emplace(std::make_pair(std::move(key), std::move(value))); + } + } +} + +} + +// Public ========================================================================================= +ODBCConnection::ODBCConnection(ODBCEnvironment& environment, + std::shared_ptr spiConnection) : + m_environment(environment), + m_spiConnection(std::move(spiConnection)), + m_is2xConnection(environment.getODBCVersion() == SQL_OV_ODBC2), + m_isConnected(false) +{ + +} + +Diagnostics &ODBCConnection::GetDiagnostics_Impl() { + return m_spiConnection->GetDiagnostics(); +} + +bool ODBCConnection::isConnected() const +{ + return m_isConnected; +} + +const std::string& ODBCConnection::GetDSN() const { + return m_dsn; +} + +void ODBCConnection::connect(std::string dsn, const Connection::ConnPropertyMap &properties, + std::vector &missing_properties) +{ + if (m_isConnected) { + throw DriverException("Already connected.", "HY010"); + } + + m_dsn = std::move(dsn); + m_spiConnection->Connect(properties, missing_properties); + m_isConnected = true; + std::shared_ptr spiStatement = m_spiConnection->CreateStatement(); + m_attributeTrackingStatement = std::make_shared(*this, spiStatement); +} + +void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, bool isUnicode) +{ + + switch (infoType) { + case SQL_ACTIVE_ENVIRONMENTS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + #ifdef SQL_ASYNC_DBC_FUNCTIONS + case SQL_ASYNC_DBC_FUNCTIONS: + GetAttribute(static_cast(SQL_ASYNC_DBC_NOT_CAPABLE), value, bufferLength, outputLength); + break; + #endif + case SQL_ASYNC_MODE: + GetAttribute(static_cast(SQL_AM_NONE), value, bufferLength, outputLength); + break; + #ifdef SQL_ASYNC_NOTIFICATION + case SQL_ASYNC_NOTIFICATION: + GetAttribute(static_cast(SQL_ASYNC_NOTIFICATION_NOT_CAPABLE), value, bufferLength, outputLength); + break; + #endif + case SQL_BATCH_ROW_COUNT: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_BATCH_SUPPORT: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_DATA_SOURCE_NAME: + GetStringAttribute(isUnicode, m_dsn, true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DRIVER_ODBC_VER: + GetStringAttribute(isUnicode, "03.80", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DYNAMIC_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_DYNAMIC_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(SQL_CA1_NEXT), value, bufferLength, outputLength); + break; + case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(SQL_CA2_READ_ONLY_CONCURRENCY), value, bufferLength, outputLength); + break; + case SQL_FILE_USAGE: + GetAttribute(static_cast(SQL_FILE_NOT_SUPPORTED), value, bufferLength, outputLength); + break; + case SQL_KEYSET_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_KEYSET_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_MAX_ASYNC_CONCURRENT_STATEMENTS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_ODBC_INTERFACE_CONFORMANCE: + GetAttribute(static_cast(SQL_OIC_CORE), value, bufferLength, outputLength); + break; + // case SQL_ODBC_STANDARD_CLI_CONFORMANCE: - mentioned in SQLGetInfo spec with no description + // and there is no constant for this. + case SQL_PARAM_ARRAY_ROW_COUNTS: + GetAttribute(static_cast(SQL_PARC_NO_BATCH), value, bufferLength, outputLength); + break; + case SQL_PARAM_ARRAY_SELECTS: + GetAttribute(static_cast(SQL_PAS_NO_SELECT), value, bufferLength, outputLength); + break; + case SQL_ROW_UPDATES: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_SCROLL_OPTIONS: + GetAttribute(static_cast(SQL_SO_FORWARD_ONLY), value, bufferLength, outputLength); + break; + case SQL_STATIC_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_STATIC_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_BOOKMARK_PERSISTENCE: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_DESCRIBE_PARAMETER: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_MULT_RESULT_SETS: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_MULTIPLE_ACTIVE_TXN: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_NEED_LONG_DATA_LEN: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_TXN_CAPABLE: + GetAttribute(static_cast(SQL_TC_NONE), value, bufferLength, outputLength); + break; + case SQL_TXN_ISOLATION_OPTION: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_TABLE_TERM: + GetStringAttribute(isUnicode, "table", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + // Deprecated ODBC 2.x fields required for backwards compatibility. + case SQL_ODBC_API_CONFORMANCE: + GetAttribute(static_cast(SQL_OAC_LEVEL1), value, bufferLength, outputLength); + break; + case SQL_FETCH_DIRECTION: + GetAttribute(static_cast(SQL_FETCH_NEXT), value, bufferLength, outputLength); + break; + case SQL_LOCK_TYPES: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_POS_OPERATIONS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_POSITIONED_STATEMENTS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_SCROLL_CONCURRENCY: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_STATIC_SENSITIVITY: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + + // Driver-level string properties. + case SQL_USER_NAME: + case SQL_COLUMN_ALIAS: + case SQL_DBMS_NAME: + case SQL_DBMS_VER: + case SQL_DRIVER_NAME: // TODO: This should be the driver's filename and shouldn't come from the SPI. + case SQL_DRIVER_VER: + case SQL_SEARCH_PATTERN_ESCAPE: + case SQL_SERVER_NAME: + case SQL_DATA_SOURCE_READ_ONLY: + case SQL_ACCESSIBLE_TABLES: + case SQL_ACCESSIBLE_PROCEDURES: + case SQL_CATALOG_TERM: + case SQL_COLLATION_SEQ: + case SQL_SCHEMA_TERM: + case SQL_CATALOG_NAME: + case SQL_CATALOG_NAME_SEPARATOR: + case SQL_EXPRESSIONS_IN_ORDERBY: + case SQL_IDENTIFIER_QUOTE_CHAR: + case SQL_INTEGRITY: + case SQL_KEYWORDS: + case SQL_LIKE_ESCAPE_CLAUSE: + case SQL_MAX_ROW_SIZE_INCLUDES_LONG: + case SQL_ORDER_BY_COLUMNS_IN_SELECT: + case SQL_OUTER_JOINS: // Not documented in SQLGetInfo, but other drivers return Y/N strings + case SQL_PROCEDURE_TERM: + case SQL_PROCEDURES: + case SQL_SPECIAL_CHARACTERS: + case SQL_XOPEN_CLI_YEAR: + { + const auto& info = m_spiConnection->GetInfo(infoType); + const std::string& infoValue = boost::get(info); + GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, outputLength, GetDiagnostics()); + break; + } + + // Driver-level 32-bit integer properties. + case SQL_GETDATA_EXTENSIONS: + case SQL_INFO_SCHEMA_VIEWS: + case SQL_CURSOR_SENSITIVITY: + case SQL_DEFAULT_TXN_ISOLATION: + case SQL_AGGREGATE_FUNCTIONS: + case SQL_ALTER_DOMAIN: +// case SQL_ALTER_SCHEMA: + case SQL_ALTER_TABLE: + case SQL_DATETIME_LITERALS: + case SQL_CATALOG_USAGE: + case SQL_CREATE_ASSERTION: + case SQL_CREATE_CHARACTER_SET: + case SQL_CREATE_COLLATION: + case SQL_CREATE_DOMAIN: + case SQL_CREATE_SCHEMA: + case SQL_CREATE_TABLE: + case SQL_CREATE_TRANSLATION: + case SQL_CREATE_VIEW: + case SQL_INDEX_KEYWORDS: + case SQL_INSERT_STATEMENT: + case SQL_OJ_CAPABILITIES: + case SQL_SCHEMA_USAGE: + case SQL_SQL_CONFORMANCE: + case SQL_SUBQUERIES: + case SQL_UNION: + case SQL_MAX_BINARY_LITERAL_LEN: + case SQL_MAX_CHAR_LITERAL_LEN: + case SQL_MAX_ROW_SIZE: + case SQL_MAX_STATEMENT_LEN: + case SQL_CONVERT_FUNCTIONS: + case SQL_NUMERIC_FUNCTIONS: + case SQL_STRING_FUNCTIONS: + case SQL_SYSTEM_FUNCTIONS: + case SQL_TIMEDATE_ADD_INTERVALS: + case SQL_TIMEDATE_DIFF_INTERVALS: + case SQL_TIMEDATE_FUNCTIONS: + case SQL_CONVERT_BIGINT: + case SQL_CONVERT_BINARY: + case SQL_CONVERT_BIT: + case SQL_CONVERT_CHAR: + case SQL_CONVERT_DATE: + case SQL_CONVERT_DECIMAL: + case SQL_CONVERT_DOUBLE: + case SQL_CONVERT_FLOAT: + case SQL_CONVERT_GUID: + case SQL_CONVERT_INTEGER: + case SQL_CONVERT_INTERVAL_DAY_TIME: + case SQL_CONVERT_INTERVAL_YEAR_MONTH: + case SQL_CONVERT_LONGVARBINARY: + case SQL_CONVERT_LONGVARCHAR: + case SQL_CONVERT_NUMERIC: + case SQL_CONVERT_REAL: + case SQL_CONVERT_SMALLINT: + case SQL_CONVERT_TIME: + case SQL_CONVERT_TIMESTAMP: + case SQL_CONVERT_TINYINT: + case SQL_CONVERT_VARBINARY: + case SQL_CONVERT_VARCHAR: + case SQL_CONVERT_WCHAR: + case SQL_CONVERT_WVARCHAR: + case SQL_CONVERT_WLONGVARCHAR: + case SQL_DDL_INDEX: + case SQL_DROP_ASSERTION: + case SQL_DROP_CHARACTER_SET: + case SQL_DROP_COLLATION: + case SQL_DROP_DOMAIN: + case SQL_DROP_SCHEMA: + case SQL_DROP_TABLE: + case SQL_DROP_TRANSLATION: + case SQL_DROP_VIEW: + case SQL_MAX_INDEX_SIZE: + case SQL_SQL92_DATETIME_FUNCTIONS: + case SQL_SQL92_FOREIGN_KEY_DELETE_RULE: + case SQL_SQL92_FOREIGN_KEY_UPDATE_RULE: + case SQL_SQL92_GRANT: + case SQL_SQL92_NUMERIC_VALUE_FUNCTIONS: + case SQL_SQL92_PREDICATES: + case SQL_SQL92_RELATIONAL_JOIN_OPERATORS: + case SQL_SQL92_REVOKE: + case SQL_SQL92_ROW_VALUE_CONSTRUCTOR: + case SQL_SQL92_STRING_FUNCTIONS: + case SQL_SQL92_VALUE_EXPRESSIONS: + case SQL_STANDARD_CLI_CONFORMANCE: + { + const auto& info = m_spiConnection->GetInfo(infoType); + uint32_t infoValue = boost::get(info); + GetAttribute(infoValue, value, bufferLength, outputLength); + break; + } + + // Driver-level 16-bit integer properties. + case SQL_MAX_CONCURRENT_ACTIVITIES: + case SQL_MAX_DRIVER_CONNECTIONS: + case SQL_CONCAT_NULL_BEHAVIOR: + case SQL_CURSOR_COMMIT_BEHAVIOR: + case SQL_CURSOR_ROLLBACK_BEHAVIOR: + case SQL_NULL_COLLATION: + case SQL_CATALOG_LOCATION: + case SQL_CORRELATION_NAME: + case SQL_GROUP_BY: + case SQL_IDENTIFIER_CASE: + case SQL_NON_NULLABLE_COLUMNS: + case SQL_QUOTED_IDENTIFIER_CASE: + case SQL_MAX_CATALOG_NAME_LEN: + case SQL_MAX_COLUMN_NAME_LEN: + case SQL_MAX_COLUMNS_IN_GROUP_BY: + case SQL_MAX_COLUMNS_IN_INDEX: + case SQL_MAX_COLUMNS_IN_ORDER_BY: + case SQL_MAX_COLUMNS_IN_SELECT: + case SQL_MAX_COLUMNS_IN_TABLE: + case SQL_MAX_CURSOR_NAME_LEN: + case SQL_MAX_IDENTIFIER_LEN: + case SQL_MAX_SCHEMA_NAME_LEN: + case SQL_MAX_TABLE_NAME_LEN: + case SQL_MAX_TABLES_IN_SELECT: + case SQL_MAX_PROCEDURE_NAME_LEN: + case SQL_MAX_USER_NAME_LEN: + case SQL_ODBC_SQL_CONFORMANCE: + case SQL_ODBC_SAG_CLI_CONFORMANCE: + { + const auto& info = m_spiConnection->GetInfo(infoType); + uint16_t infoValue = boost::get(info); + GetAttribute(infoValue, value, bufferLength, outputLength); + break; + } + + // Special case - SQL_DATABASE_NAME is an alias for SQL_ATTR_CURRENT_CATALOG. + case SQL_DATABASE_NAME: + { + const auto &attr = + m_spiConnection->GetAttribute(Connection::CURRENT_CATALOG); + if (!attr) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + const std::string &infoValue = boost::get(*attr); + GetStringAttribute(isUnicode, infoValue, true, value, bufferLength,outputLength, GetDiagnostics()); + break; + } + default: + throw DriverException("Unknown SQLGetInfo type: " + std::to_string(infoType)); + } +} + +void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength, bool isUnicode) { + uint32_t attributeToWrite = 0; + bool successfully_written = false; + switch (attribute) { + // Internal connection attributes +#ifdef SQL_ATR_ASYNC_DBC_EVENT + case SQL_ATTR_ASYNC_DBC_EVENT: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE + case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_PCALLBACK + case SQL_ATTR_ASYNC_DBC_PCALLBACK: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT + case SQL_ATTR_ASYNC_DBC_PCONTEXT: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif + case SQL_ATTR_AUTO_IPD: + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_AUTOCOMMIT: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_AUTOCOMMIT_ON)); + return; + case SQL_ATTR_CONNECTION_DEAD: + throw DriverException("Cannot set read-only attribute", "HY092"); +#ifdef SQL_ATTR_DBC_INFO_TOKEN + case SQL_ATTR_DBC_INFO_TOKEN: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif + case SQL_ATTR_ENLIST_IN_DTC: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_ODBC_CURSORS: // DM-only. + throw DriverException("Invalid attribute", "HY092"); + case SQL_ATTR_QUIET_MODE: + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_TRACE: // DM-only + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_TRACEFILE: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_LIB: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_OPTION: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TXN_ISOLATION: + throw DriverException("Optional feature not supported.", "HYC00"); + + // ODBCAbstraction-level attributes + case SQL_ATTR_CURRENT_CATALOG: { + std::string catalog; + if (isUnicode) { + SetAttributeUTF8(value, stringLength, catalog); + } else { + SetAttributeSQLWCHAR(value, stringLength, catalog); + } + if (!m_spiConnection->SetAttribute(Connection::CURRENT_CATALOG, catalog)) { + throw DriverException("Option value changed.", "01S02"); + } + return; + } + + // Statement attributes that can be set through the connection. + // Only applies to SQL_ATTR_METADATA_ID, SQL_ATTR_ASYNC_ENABLE, and ODBC 2.x statement attributes. + // SQL_ATTR_ROW_NUMBER is excluded because it is read-only. + // Note that SQLGetConnectAttr cannot retrieve these attributes. + case SQL_ATTR_ASYNC_ENABLE: + case SQL_ATTR_METADATA_ID: + case SQL_ATTR_CONCURRENCY: + case SQL_ATTR_CURSOR_TYPE: + case SQL_ATTR_KEYSET_SIZE: + case SQL_ATTR_MAX_LENGTH: + case SQL_ATTR_MAX_ROWS: + case SQL_ATTR_NOSCAN: + case SQL_ATTR_QUERY_TIMEOUT: + case SQL_ATTR_RETRIEVE_DATA: + case SQL_ATTR_ROW_BIND_TYPE: + case SQL_ATTR_SIMULATE_CURSOR: + case SQL_ATTR_USE_BOOKMARKS: + m_attributeTrackingStatement->SetStmtAttr(attribute, value, stringLength, isUnicode); + return; + + case SQL_ATTR_ACCESS_MODE: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::ACCESS_MODE, attributeToWrite); + break; + case SQL_ATTR_CONNECTION_TIMEOUT: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::CONNECTION_TIMEOUT, attributeToWrite); + break; + case SQL_ATTR_LOGIN_TIMEOUT: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::LOGIN_TIMEOUT, attributeToWrite); + break; + case SQL_ATTR_PACKET_SIZE: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::PACKET_SIZE, attributeToWrite); + break; + default: + throw DriverException("Invalid attribute: " + std::to_string(attribute), "HY092"); + } + + if (!successfully_written) { + GetDiagnostics().AddWarning("Option value changed.", "01S02", ODBCErrorCodes_GENERAL_WARNING); + } +} + +void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER *outputLength, bool isUnicode) { + using driver::odbcabstraction::Connection; + boost::optional spiAttribute; + + switch (attribute) { + // Internal connection attributes +#ifdef SQL_ATR_ASYNC_DBC_EVENT + case SQL_ATTR_ASYNC_DBC_EVENT: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; +#endif +#ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE + case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: + GetAttribute(static_cast(SQL_ASYNC_DBC_ENABLE_OFF), value, bufferLength, outputLength); + return; +#endif +#ifdef SQL_ATTR_ASYNC_PCALLBACK + case SQL_ATTR_ASYNC_DBC_PCALLBACK: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; +#endif +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT + case SQL_ATTR_ASYNC_DBC_PCONTEXT: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; +#endif + case SQL_ATTR_ASYNC_ENABLE: + GetAttribute(static_cast(SQL_ASYNC_ENABLE_OFF), value, bufferLength, outputLength); + return; + case SQL_ATTR_AUTO_IPD: + GetAttribute(static_cast(SQL_FALSE), value, bufferLength, outputLength); + return; + case SQL_ATTR_AUTOCOMMIT: + GetAttribute(static_cast(SQL_AUTOCOMMIT_ON), value, bufferLength, outputLength); + return; +#ifdef SQL_ATTR_DBC_INFO_TOKEN + case SQL_ATTR_DBC_INFO_TOKEN: + throw DriverException("Cannot read set-only attribute", "HY092"); +#endif + case SQL_ATTR_ENLIST_IN_DTC: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; + case SQL_ATTR_ODBC_CURSORS: // DM-only. + throw DriverException("Invalid attribute", "HY092"); + case SQL_ATTR_QUIET_MODE: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; + case SQL_ATTR_TRACE: // DM-only + throw DriverException("Invalid attribute", "HY092"); + case SQL_ATTR_TRACEFILE: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_LIB: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_OPTION: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TXN_ISOLATION: + throw DriverException("Optional feature not supported.", "HCY00"); + + // ODBCAbstraction-level connection attributes. + case SQL_ATTR_CURRENT_CATALOG: + { + const auto &catalog = + m_spiConnection->GetAttribute(Connection::CURRENT_CATALOG); + if (!catalog) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + const std::string &infoValue = boost::get(*catalog); + GetStringAttribute(isUnicode, infoValue, true, value, bufferLength,outputLength, GetDiagnostics()); + return; + } + + // These all are uint32_t attributes. + case SQL_ATTR_ACCESS_MODE: + spiAttribute = m_spiConnection->GetAttribute(Connection::ACCESS_MODE); + break; + case SQL_ATTR_CONNECTION_DEAD: + spiAttribute = m_spiConnection->GetAttribute(Connection::CONNECTION_DEAD); + break; + case SQL_ATTR_CONNECTION_TIMEOUT: + spiAttribute = m_spiConnection->GetAttribute(Connection::CONNECTION_TIMEOUT); + break; + case SQL_ATTR_LOGIN_TIMEOUT: + spiAttribute = m_spiConnection->GetAttribute(Connection::LOGIN_TIMEOUT); + break; + case SQL_ATTR_PACKET_SIZE: + spiAttribute = m_spiConnection->GetAttribute(Connection::PACKET_SIZE); + break; + default: + throw DriverException("Invalid attribute", "HY092"); + } + + if (!spiAttribute) { + throw DriverException("Invalid attribute", "HY092"); + } + + GetAttribute(static_cast(boost::get(*spiAttribute)), value, bufferLength, outputLength); +} + +void ODBCConnection::disconnect() { + if (m_isConnected) { + // Ensure that all statements (and corresponding SPI statements) get cleaned + // up before terminating the SPI connection in case they need to be de-allocated in + // the reverse of the allocation order. + m_statements.clear(); + m_spiConnection->Close(); + m_isConnected = false; + } +} + +void ODBCConnection::releaseConnection() { + disconnect(); + m_environment.DropConnection(this); +} + +std::shared_ptr ODBCConnection::createStatement() { + std::shared_ptr spiStatement = m_spiConnection->CreateStatement(); + std::shared_ptr statement = std::make_shared(*this, spiStatement); + m_statements.push_back(statement); + statement->CopyAttributesFromConnection(*this); + return statement; +} + +void ODBCConnection::dropStatement(ODBCStatement* stmt) { + auto it = std::find_if(m_statements.begin(), m_statements.end(), + [&stmt] (const std::shared_ptr& statement) { return statement.get() == stmt; }); + if (m_statements.end() != it) { + m_statements.erase(it); + } +} + +std::shared_ptr ODBCConnection::createDescriptor() { + std::shared_ptr desc = std::make_shared( + m_spiConnection->GetDiagnostics(), this, nullptr, true, true, false); + m_descriptors.push_back(desc); + return desc; +} + +void ODBCConnection::dropDescriptor(ODBCDescriptor* desc) { + auto it = std::find_if(m_descriptors.begin(), m_descriptors.end(), + [&desc] (const std::shared_ptr& descriptor) { return descriptor.get() == desc; }); + if (m_descriptors.end() != it) { + m_descriptors.erase(it); + } +} + +// Public Static =================================================================================== +std::string ODBCConnection::getPropertiesFromConnString(const std::string& connStr, + Connection::ConnPropertyMap &properties) +{ + const int groups[] = { 1, 2 }; // CONNECTION_STR_REGEX has two groups. key: 1, value: 2 + boost::xpressive::sregex_token_iterator regexIter(connStr.begin(), connStr.end(), + CONNECTION_STR_REGEX, groups), end; + + bool isDsnFirst = false; + bool isDriverFirst = false; + std::string dsn; + for (auto it = regexIter; end != regexIter; ++regexIter) { + std::string key = *regexIter; + std::string value = *++regexIter; + + // If the DSN shows up before driver key, load settings from the DSN. + // Only load values from the DSN once regardless of how many times the DSN + // key shows up. + if (boost::iequals(key, "DSN")) { + if (!isDriverFirst) { + if (!isDsnFirst) { + isDsnFirst = true; + loadPropertiesFromDSN(value, properties); + dsn.swap(value); + } + } + continue; + } else if (boost::iequals(key, "Driver")) { + if (!isDsnFirst) { + isDriverFirst = true; + } + continue; + } + + // Strip wrapping curly braces. + if (value.size() >= 2 && value[0] == '{' && value[value.size() - 1] == '}') { + value = value.substr(1, value.size() - 2); + } + + // Overwrite the existing value. Later copies of the key take precedence, + // including over entries in the DSN. + properties[key] = std::move(value); + } + return dsn; +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCDescriptor.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCDescriptor.cc new file mode 100644 index 0000000000000..d5d1e70fa1725 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCDescriptor.cc @@ -0,0 +1,547 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; + +namespace { + SQLSMALLINT CalculateHighestBoundRecord(const std::vector& records) { + // Most applications will bind every column, so optimistically assume that we'll + // find the next bound record fastest by counting backwards. + for (size_t i = records.size(); i > 0; --i) { + if (records[i-1].m_isBound) { + return i; + } + } + return 0; + } +} + +// Public ========================================================================================= +ODBCDescriptor::ODBCDescriptor(Diagnostics& baseDiagnostics, + ODBCConnection* conn, ODBCStatement* stmt, bool isAppDescriptor, bool isWritable, bool is2xConnection) : + m_diagnostics(baseDiagnostics.GetVendor(), baseDiagnostics.GetDataSourceComponent(), V_3), + m_owningConnection(conn), + m_parentStatement(stmt), + m_arrayStatusPtr(nullptr), + m_bindOffsetPtr(nullptr), + m_rowsProccessedPtr(nullptr), + m_arraySize(1), + m_bindType(SQL_BIND_BY_COLUMN), + m_highestOneBasedBoundRecord(0), + m_is2xConnection(is2xConnection), + m_isAppDescriptor(isAppDescriptor), + m_isWritable(isWritable), + m_hasBindingsChanged(true) { +} + +Diagnostics &ODBCDescriptor::GetDiagnostics_Impl() { + return m_diagnostics; +} + +ODBCConnection &ODBCDescriptor::GetConnection() { + if (m_owningConnection) { + return *m_owningConnection; + } + assert(m_parentStatement); + return m_parentStatement->GetConnection(); +} + +void ODBCDescriptor::SetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength) { + // Only these two fields can be set on the IRD. + if (!m_isWritable && fieldIdentifier != SQL_DESC_ARRAY_STATUS_PTR && fieldIdentifier != SQL_DESC_ROWS_PROCESSED_PTR) { + throw DriverException("Cannot modify read-only descriptor", "HY016"); + } + + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: + throw DriverException("Invalid descriptor field", "HY091"); + case SQL_DESC_ARRAY_SIZE: + SetAttribute(value, m_arraySize); + m_hasBindingsChanged = true; + break; + case SQL_DESC_ARRAY_STATUS_PTR: + SetPointerAttribute(value, m_arrayStatusPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_BIND_OFFSET_PTR: + SetPointerAttribute(value, m_bindOffsetPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_BIND_TYPE: + SetAttribute(value, m_bindType); + m_hasBindingsChanged = true; + break; + case SQL_DESC_ROWS_PROCESSED_PTR: + SetPointerAttribute(value, m_rowsProccessedPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_COUNT: { + SQLSMALLINT newCount; + SetAttribute(value, newCount); + m_records.resize(newCount); + + if (m_isAppDescriptor && newCount <= m_highestOneBasedBoundRecord) { + m_highestOneBasedBoundRecord = CalculateHighestBoundRecord(m_records); + } else { + m_highestOneBasedBoundRecord = newCount; + } + m_hasBindingsChanged = true; + break; + } + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +void ODBCDescriptor::SetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength) { + if (!m_isWritable) { + throw DriverException("Cannot modify read-only descriptor", "HY016"); + } + + // Handle header fields before validating the record number. + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: + case SQL_DESC_ARRAY_SIZE: + case SQL_DESC_ARRAY_STATUS_PTR: + case SQL_DESC_BIND_OFFSET_PTR: + case SQL_DESC_BIND_TYPE: + case SQL_DESC_ROWS_PROCESSED_PTR: + case SQL_DESC_COUNT: + SetHeaderField(fieldIdentifier, value, bufferLength); + return; + default: + break; + } + + if (recordNumber == 0) { + throw DriverException("Bookmarks are unsupported.", "07009"); + } + + if (recordNumber > m_records.size()) { + throw DriverException("Invalid descriptor index", "HY009"); + } + + SQLSMALLINT zeroBasedRecord = recordNumber - 1; + DescriptorRecord& record = m_records[zeroBasedRecord]; + switch (fieldIdentifier) { + case SQL_DESC_AUTO_UNIQUE_VALUE: + case SQL_DESC_BASE_COLUMN_NAME: + case SQL_DESC_BASE_TABLE_NAME: + case SQL_DESC_CASE_SENSITIVE: + case SQL_DESC_CATALOG_NAME: + case SQL_DESC_DISPLAY_SIZE: + case SQL_DESC_FIXED_PREC_SCALE: + case SQL_DESC_LABEL: + case SQL_DESC_LITERAL_PREFIX: + case SQL_DESC_LITERAL_SUFFIX: + case SQL_DESC_LOCAL_TYPE_NAME: + case SQL_DESC_NULLABLE: + case SQL_DESC_NUM_PREC_RADIX: + case SQL_DESC_ROWVER: + case SQL_DESC_SCHEMA_NAME: + case SQL_DESC_SEARCHABLE: + case SQL_DESC_TABLE_NAME: + case SQL_DESC_TYPE_NAME: + case SQL_DESC_UNNAMED: + case SQL_DESC_UNSIGNED: + case SQL_DESC_UPDATABLE: + throw DriverException("Cannot modify read-only field.", "HY092"); + case SQL_DESC_CONCISE_TYPE: + SetAttribute(value, record.m_conciseType); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_DATA_PTR: + SetDataPtrOnRecord(value, recordNumber); + break; + case SQL_DESC_DATETIME_INTERVAL_CODE: + SetAttribute(value, record.m_datetimeIntervalCode); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_DATETIME_INTERVAL_PRECISION: + SetAttribute(value, record.m_datetimeIntervalPrecision); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_INDICATOR_PTR: + case SQL_DESC_OCTET_LENGTH_PTR: + SetPointerAttribute(value, record.m_indicatorPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_LENGTH: + SetAttribute(value, record.m_length); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_NAME: + SetAttributeUTF8(value, bufferLength, record.m_name); + m_hasBindingsChanged = true; + break; + case SQL_DESC_OCTET_LENGTH: + SetAttribute(value, record.m_octetLength); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_PARAMETER_TYPE: + SetAttribute(value, record.m_paramType); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_PRECISION: + SetAttribute(value, record.m_precision); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_SCALE: + SetAttribute(value, record.m_scale); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_TYPE: + SetAttribute(value, record.m_type); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +void ODBCDescriptor::GetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength) const { + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: { + SQLSMALLINT result; + if (m_owningConnection) { + result = SQL_DESC_ALLOC_USER; + } else { + result = SQL_DESC_ALLOC_AUTO; + } + GetAttribute(result, value, bufferLength, outputLength); + break; + } + case SQL_DESC_ARRAY_SIZE: + GetAttribute(m_arraySize, value, bufferLength, outputLength); + break; + case SQL_DESC_ARRAY_STATUS_PTR: + GetAttribute(m_arrayStatusPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_BIND_OFFSET_PTR: + GetAttribute(m_bindOffsetPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_BIND_TYPE: + GetAttribute(m_bindType, value, bufferLength, outputLength); + break; + case SQL_DESC_ROWS_PROCESSED_PTR: + GetAttribute(m_rowsProccessedPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_COUNT: { + GetAttribute(m_highestOneBasedBoundRecord, value, bufferLength, outputLength); + break; + } + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +void ODBCDescriptor::GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength) { + // Handle header fields before validating the record number. + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: + case SQL_DESC_ARRAY_SIZE: + case SQL_DESC_ARRAY_STATUS_PTR: + case SQL_DESC_BIND_OFFSET_PTR: + case SQL_DESC_BIND_TYPE: + case SQL_DESC_ROWS_PROCESSED_PTR: + case SQL_DESC_COUNT: + GetHeaderField(fieldIdentifier, value, bufferLength, outputLength); + return; + default: + break; + } + + if (recordNumber == 0) { + throw DriverException("Bookmarks are unsupported.", "07009"); + } + + if (recordNumber > m_records.size()) { + throw DriverException("Invalid descriptor index", "07009"); + } + + // TODO: Restrict fields based on AppDescriptor IPD, and IRD. + + SQLSMALLINT zeroBasedRecord = recordNumber - 1; + const DescriptorRecord& record = m_records[zeroBasedRecord]; + switch (fieldIdentifier) { + case SQL_DESC_BASE_COLUMN_NAME: + GetAttributeUTF8(record.m_baseColumnName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_BASE_TABLE_NAME: + GetAttributeUTF8(record.m_baseTableName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_CATALOG_NAME: + GetAttributeUTF8(record.m_catalogName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LABEL: + GetAttributeUTF8(record.m_label, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LITERAL_PREFIX: + GetAttributeUTF8(record.m_literalPrefix, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LITERAL_SUFFIX: + GetAttributeUTF8(record.m_literalSuffix, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LOCAL_TYPE_NAME: + GetAttributeUTF8(record.m_localTypeName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_NAME: + GetAttributeUTF8(record.m_name, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_SCHEMA_NAME: + GetAttributeUTF8(record.m_schemaName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_TABLE_NAME: + GetAttributeUTF8(record.m_tableName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_TYPE_NAME: + GetAttributeUTF8(record.m_typeName, value, bufferLength, outputLength, GetDiagnostics()); + break; + + case SQL_DESC_DATA_PTR: + GetAttribute(record.m_dataPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_INDICATOR_PTR: + case SQL_DESC_OCTET_LENGTH_PTR: + GetAttribute(record.m_indicatorPtr, value, bufferLength, outputLength); + break; + + case SQL_DESC_LENGTH: + GetAttribute(record.m_length, value, bufferLength, outputLength); + break; + case SQL_DESC_OCTET_LENGTH: + GetAttribute(record.m_octetLength, value, bufferLength, outputLength); + break; + + case SQL_DESC_AUTO_UNIQUE_VALUE: + GetAttribute(record.m_autoUniqueValue, value, bufferLength, outputLength); + break; + case SQL_DESC_CASE_SENSITIVE: + GetAttribute(record.m_caseSensitive, value, bufferLength, outputLength); + break; + case SQL_DESC_DATETIME_INTERVAL_PRECISION: + GetAttribute(record.m_datetimeIntervalPrecision, value, bufferLength, outputLength); + break; + case SQL_DESC_NUM_PREC_RADIX: + GetAttribute(record.m_numPrecRadix, value, bufferLength, outputLength); + break; + + case SQL_DESC_CONCISE_TYPE: + GetAttribute(record.m_conciseType, value, bufferLength, outputLength); + break; + case SQL_DESC_DATETIME_INTERVAL_CODE: + GetAttribute(record.m_datetimeIntervalCode, value, bufferLength, outputLength); + break; + case SQL_DESC_DISPLAY_SIZE: + GetAttribute(record.m_displaySize, value, bufferLength, outputLength); + break; + case SQL_DESC_FIXED_PREC_SCALE: + GetAttribute(record.m_fixedPrecScale, value, bufferLength, outputLength); + break; + case SQL_DESC_NULLABLE: + GetAttribute(record.m_nullable, value, bufferLength, outputLength); + break; + case SQL_DESC_PARAMETER_TYPE: + GetAttribute(record.m_paramType, value, bufferLength, outputLength); + break; + case SQL_DESC_PRECISION: + GetAttribute(record.m_precision, value, bufferLength, outputLength); + break; + case SQL_DESC_ROWVER: + GetAttribute(record.m_rowVer, value, bufferLength, outputLength); + break; + case SQL_DESC_SCALE: + GetAttribute(record.m_scale, value, bufferLength, outputLength); + break; + case SQL_DESC_SEARCHABLE: + GetAttribute(record.m_searchable, value, bufferLength, outputLength); + break; + case SQL_DESC_TYPE: + GetAttribute(record.m_type, value, bufferLength, outputLength); + break; + case SQL_DESC_UNNAMED: + GetAttribute(record.m_unnamed, value, bufferLength, outputLength); + break; + case SQL_DESC_UNSIGNED: + GetAttribute(record.m_unsigned, value, bufferLength, outputLength); + break; + case SQL_DESC_UPDATABLE: + GetAttribute(record.m_updatable, value, bufferLength, outputLength); + break; + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +SQLSMALLINT ODBCDescriptor::getAllocType() const { + return m_owningConnection != nullptr ? SQL_DESC_ALLOC_USER : SQL_DESC_ALLOC_AUTO; +} + +bool ODBCDescriptor::IsAppDescriptor() const { + return m_isAppDescriptor; +} + +void ODBCDescriptor::RegisterToStatement(ODBCStatement* statement, bool isApd) { + if (isApd) { + m_registeredOnStatementsAsApd.push_back(statement); + } else { + m_registeredOnStatementsAsArd.push_back(statement); + } +} + +void ODBCDescriptor::DetachFromStatement(ODBCStatement* statement, bool isApd) { + auto& vectorToUpdate = isApd ? m_registeredOnStatementsAsApd : m_registeredOnStatementsAsArd; + auto it = std::find(vectorToUpdate.begin(), vectorToUpdate.end(), statement); + if (it != vectorToUpdate.end()) { + vectorToUpdate.erase(it); + } +} + +void ODBCDescriptor::ReleaseDescriptor() { + for (ODBCStatement* stmt : m_registeredOnStatementsAsApd) { + stmt->RevertAppDescriptor(true); + } + + for (ODBCStatement* stmt : m_registeredOnStatementsAsArd) { + stmt->RevertAppDescriptor(false); + } + + if (m_owningConnection) { + m_owningConnection->dropDescriptor(this); + } +} + +void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { + m_records.assign(rsmd->GetColumnCount(), DescriptorRecord()); + m_highestOneBasedBoundRecord = m_records.size() + 1; + + for (size_t i = 0; i < m_records.size(); ++i) { + size_t oneBasedIndex = i + 1; + m_records[i].m_baseColumnName = rsmd->GetBaseColumnName(oneBasedIndex); + m_records[i].m_baseTableName = rsmd->GetBaseTableName(oneBasedIndex); + m_records[i].m_catalogName = rsmd->GetCatalogName(oneBasedIndex); + m_records[i].m_label = rsmd->GetColumnLabel(oneBasedIndex); + m_records[i].m_literalPrefix = rsmd->GetLiteralPrefix(oneBasedIndex); + m_records[i].m_literalSuffix = rsmd->GetLiteralSuffix(oneBasedIndex); + m_records[i].m_localTypeName = rsmd->GetLocalTypeName(oneBasedIndex); + m_records[i].m_name = rsmd->GetName(oneBasedIndex); + m_records[i].m_schemaName = rsmd->GetSchemaName(oneBasedIndex); + m_records[i].m_tableName = rsmd->GetTableName(oneBasedIndex); + m_records[i].m_typeName = rsmd->GetTypeName(oneBasedIndex); + m_records[i].m_conciseType = GetSqlTypeForODBCVersion(rsmd->GetConciseType(oneBasedIndex), m_is2xConnection); + m_records[i].m_dataPtr = nullptr; + m_records[i].m_indicatorPtr = nullptr; + m_records[i].m_displaySize = rsmd->GetColumnDisplaySize(oneBasedIndex); + m_records[i].m_octetLength = rsmd->GetOctetLength(oneBasedIndex); + m_records[i].m_length = rsmd->GetLength(oneBasedIndex); + m_records[i].m_autoUniqueValue = rsmd->IsAutoUnique(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; + m_records[i].m_caseSensitive = rsmd->IsCaseSensitive(oneBasedIndex)? SQL_TRUE : SQL_FALSE; + m_records[i].m_datetimeIntervalPrecision; // TODO - update when rsmd adds this + m_records[i].m_numPrecRadix = rsmd->GetNumPrecRadix(oneBasedIndex); + m_records[i].m_datetimeIntervalCode; // TODO + m_records[i].m_fixedPrecScale = rsmd->IsFixedPrecScale(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; + m_records[i].m_nullable = rsmd->IsNullable(oneBasedIndex); + m_records[i].m_paramType = SQL_PARAM_INPUT; + m_records[i].m_precision = rsmd->GetPrecision(oneBasedIndex); + m_records[i].m_rowVer = SQL_FALSE; + m_records[i].m_scale = rsmd->GetScale(oneBasedIndex); + m_records[i].m_searchable = rsmd->IsSearchable(oneBasedIndex); + m_records[i].m_type = GetSqlTypeForODBCVersion(rsmd->GetDataType(oneBasedIndex), m_is2xConnection); + m_records[i].m_unnamed = m_records[i].m_name.empty() ? SQL_TRUE : SQL_FALSE; + m_records[i].m_unsigned = rsmd->IsUnsigned(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; + m_records[i].m_updatable = rsmd->GetUpdatable(oneBasedIndex); + } +} + +const std::vector& ODBCDescriptor::GetRecords() const { + return m_records; +} + +std::vector& ODBCDescriptor::GetRecords() { + return m_records; +} + +void ODBCDescriptor::BindCol(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr) { + assert(m_isAppDescriptor); + assert(m_isWritable); + + // The set of records auto-expands to the supplied record number. + if (m_records.size() < recordNumber) { + m_records.resize(recordNumber); + } + + SQLSMALLINT zeroBasedRecordIndex = recordNumber - 1; + DescriptorRecord& record = m_records[zeroBasedRecordIndex]; + + record.m_type = cType; + record.m_indicatorPtr = indicatorPtr; + record.m_length = bufferLength; + + // Initialize default precision and scale for SQL_C_NUMERIC. + if (record.m_type == SQL_C_NUMERIC) { + record.m_precision = 38; + record.m_scale = 0; + } + SetDataPtrOnRecord(dataPtr, recordNumber); +} + +void ODBCDescriptor::SetDataPtrOnRecord(SQLPOINTER dataPtr, SQLSMALLINT recordNumber) { + assert(recordNumber <= m_records.size()); + DescriptorRecord& record = m_records[recordNumber-1]; + if (dataPtr) { + record.CheckConsistency(); + record.m_isBound = true; + } else { + record.m_isBound = false; + } + record.m_dataPtr = dataPtr; + + // Bookkeeping on the highest bound record (used for returning SQL_DESC_COUNT) + if (m_highestOneBasedBoundRecord < recordNumber && dataPtr) { + m_highestOneBasedBoundRecord = recordNumber; + } else if (m_highestOneBasedBoundRecord == recordNumber && !dataPtr) { + m_highestOneBasedBoundRecord = CalculateHighestBoundRecord(m_records); + } + m_hasBindingsChanged = true; +} + +void DescriptorRecord::CheckConsistency() { + // TODO. +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCEnvironment.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCEnvironment.cc new file mode 100644 index 0000000000000..c5709c97b6c05 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCEnvironment.cc @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; + +// Public ========================================================================================= +ODBCEnvironment::ODBCEnvironment(std::shared_ptr driver) : + m_driver(std::move(driver)), + m_diagnostics(new Diagnostics(m_driver->GetDiagnostics().GetVendor(), + m_driver->GetDiagnostics().GetDataSourceComponent(), + V_2)), + m_version(SQL_OV_ODBC2), + m_connectionPooling(SQL_CP_OFF) { +} + +Diagnostics &ODBCEnvironment::GetDiagnostics_Impl() { + return *m_diagnostics; +} + +SQLINTEGER ODBCEnvironment::getODBCVersion() const { + return m_version; +} + +void ODBCEnvironment::setODBCVersion(SQLINTEGER version) { + if (version != m_version) { + m_version = version; + m_diagnostics.reset( + new Diagnostics(m_diagnostics->GetVendor(), + m_diagnostics->GetDataSourceComponent(), + version == SQL_OV_ODBC2 ? V_2 : V_3)); + } +} + +SQLINTEGER ODBCEnvironment::getConnectionPooling() const { + return m_connectionPooling; +} + +void ODBCEnvironment::setConnectionPooling(SQLINTEGER connectionPooling) { + m_connectionPooling = connectionPooling; +} + +std::shared_ptr ODBCEnvironment::CreateConnection() { + std::shared_ptr spiConnection = m_driver->CreateConnection(m_version == SQL_OV_ODBC2 ? V_2 : V_3); + std::shared_ptr newConn = std::make_shared(*this, spiConnection); + m_connections.push_back(newConn); + return newConn; +} + +void ODBCEnvironment::DropConnection(ODBCConnection* conn) { + auto it = std::find_if(m_connections.begin(), m_connections.end(), + [&conn] (const std::shared_ptr& connection) { return connection.get() == conn; }); + if (m_connections.end() != it) { + m_connections.erase(it); + } +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCStatement.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCStatement.cc new file mode 100644 index 0000000000000..6d7c9ee171ff1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/ODBCStatement.cc @@ -0,0 +1,739 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; + +namespace { + void DescriptorToHandle(SQLPOINTER output, ODBCDescriptor* descriptor, SQLINTEGER* lenPtr) { + if (output) { + SQLHANDLE* outputHandle = static_cast(output); + *outputHandle = reinterpret_cast(descriptor); + } + if (lenPtr) { + *lenPtr = sizeof(SQLHANDLE); + } + } + + size_t GetLength(const DescriptorRecord& record) { + switch (record.m_type) { + case SQL_C_CHAR: + case SQL_C_WCHAR: + case SQL_C_BINARY: + return record.m_length; + + case SQL_C_BIT: + case SQL_C_TINYINT: + case SQL_C_STINYINT: + case SQL_C_UTINYINT: + return sizeof(SQLSCHAR); + + case SQL_C_SHORT: + case SQL_C_SSHORT: + case SQL_C_USHORT: + return sizeof(SQLSMALLINT); + + case SQL_C_LONG: + case SQL_C_SLONG: + case SQL_C_ULONG: + case SQL_C_FLOAT: + return sizeof(SQLINTEGER); + + case SQL_C_SBIGINT: + case SQL_C_UBIGINT: + case SQL_C_DOUBLE: + return sizeof(SQLBIGINT); + + case SQL_C_NUMERIC: + return sizeof(SQL_NUMERIC_STRUCT); + + case SQL_C_DATE: + case SQL_C_TYPE_DATE: + return sizeof(SQL_DATE_STRUCT); + + case SQL_C_TIME: + case SQL_C_TYPE_TIME: + return sizeof(SQL_TIME_STRUCT); + + case SQL_C_TIMESTAMP: + case SQL_C_TYPE_TIMESTAMP: + return sizeof(SQL_TIMESTAMP_STRUCT); + + case SQL_C_INTERVAL_DAY: + case SQL_C_INTERVAL_DAY_TO_HOUR: + case SQL_C_INTERVAL_DAY_TO_MINUTE: + case SQL_C_INTERVAL_DAY_TO_SECOND: + case SQL_C_INTERVAL_HOUR: + case SQL_C_INTERVAL_HOUR_TO_MINUTE: + case SQL_C_INTERVAL_HOUR_TO_SECOND: + case SQL_C_INTERVAL_MINUTE: + case SQL_C_INTERVAL_MINUTE_TO_SECOND: + case SQL_C_INTERVAL_SECOND: + case SQL_C_INTERVAL_YEAR: + case SQL_C_INTERVAL_YEAR_TO_MONTH: + case SQL_C_INTERVAL_MONTH: + return sizeof(SQL_INTERVAL_STRUCT); + default: + return record.m_length; + } + } + + SQLSMALLINT getCTypeForSQLType(const DescriptorRecord& record) { + switch (record.m_conciseType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + return SQL_C_CHAR; + + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + return SQL_C_WCHAR; + + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + return SQL_C_BINARY; + + case SQL_TINYINT: + return record.m_unsigned ? SQL_C_UTINYINT : SQL_C_STINYINT; + + case SQL_SMALLINT: + return record.m_unsigned ? SQL_C_USHORT : SQL_C_SSHORT; + + case SQL_INTEGER: + return record.m_unsigned ? SQL_C_ULONG : SQL_C_SLONG; + + case SQL_BIGINT: + return record.m_unsigned ? SQL_C_UBIGINT : SQL_C_SBIGINT; + + case SQL_REAL: + return SQL_C_FLOAT; + + case SQL_FLOAT: + case SQL_DOUBLE: + return SQL_C_DOUBLE; + + case SQL_DATE: + case SQL_TYPE_DATE: + return SQL_C_TYPE_DATE; + + case SQL_TIME: + case SQL_TYPE_TIME: + return SQL_C_TYPE_TIME; + + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + return SQL_C_TYPE_TIMESTAMP; + + case SQL_C_INTERVAL_DAY: + return SQL_INTERVAL_DAY; + case SQL_C_INTERVAL_DAY_TO_HOUR: + return SQL_INTERVAL_DAY_TO_HOUR; + case SQL_C_INTERVAL_DAY_TO_MINUTE: + return SQL_INTERVAL_DAY_TO_MINUTE; + case SQL_C_INTERVAL_DAY_TO_SECOND: + return SQL_INTERVAL_DAY_TO_SECOND; + case SQL_C_INTERVAL_HOUR: + return SQL_INTERVAL_HOUR; + case SQL_C_INTERVAL_HOUR_TO_MINUTE: + return SQL_INTERVAL_HOUR_TO_MINUTE; + case SQL_C_INTERVAL_HOUR_TO_SECOND: + return SQL_INTERVAL_HOUR_TO_SECOND; + case SQL_C_INTERVAL_MINUTE: + return SQL_INTERVAL_MINUTE; + case SQL_C_INTERVAL_MINUTE_TO_SECOND: + return SQL_INTERVAL_MINUTE_TO_SECOND; + case SQL_C_INTERVAL_SECOND: + return SQL_INTERVAL_SECOND; + case SQL_C_INTERVAL_YEAR: + return SQL_INTERVAL_YEAR; + case SQL_C_INTERVAL_YEAR_TO_MONTH: + return SQL_INTERVAL_YEAR_TO_MONTH; + case SQL_C_INTERVAL_MONTH: + return SQL_INTERVAL_MONTH; + + default: + throw DriverException("Unknown SQL type: " + std::to_string(record.m_conciseType), "HY003"); + } + } + + void CopyAttribute(Statement& source, Statement& target, Statement::StatementAttributeId attributeId) { + auto optionalValue = source.GetAttribute(attributeId); + if (optionalValue) { + target.SetAttribute(attributeId, *optionalValue); + } + } +} + +// Public ========================================================================================= +ODBCStatement::ODBCStatement(ODBCConnection& connection, + std::shared_ptr spiStatement) : + m_connection(connection), + m_spiStatement(std::move(spiStatement)), + m_diagnostics(&m_spiStatement->GetDiagnostics()), + m_builtInArd(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, true, true, connection.IsOdbc2Connection())), + m_builtInApd(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, true, true, connection.IsOdbc2Connection())), + m_ipd(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, false, true, connection.IsOdbc2Connection())), + m_ird(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, false, false, connection.IsOdbc2Connection())), + m_currentArd(m_builtInApd.get()), + m_currentApd(m_builtInApd.get()), + m_rowNumber(0), + m_maxRows(0), + m_rowsetSize(1), + m_isPrepared(false), + m_hasReachedEndOfResult(false) { +} + +ODBCConnection &ODBCStatement::GetConnection() { + return m_connection; +} + +void ODBCStatement::CopyAttributesFromConnection(ODBCConnection& connection) { + ODBCStatement& trackingStatement = connection.GetTrackingStatement(); + + // Get abstraction attributes and copy to this m_spiStatement. + // Possible ODBC attributes are below, but many of these are not supported by warpdrive + // or ODBCAbstaction: + // SQL_ATTR_ASYNC_ENABLE: + // SQL_ATTR_METADATA_ID: + // SQL_ATTR_CONCURRENCY: + // SQL_ATTR_CURSOR_TYPE: + // SQL_ATTR_KEYSET_SIZE: + // SQL_ATTR_MAX_LENGTH: + // SQL_ATTR_MAX_ROWS: + // SQL_ATTR_NOSCAN: + // SQL_ATTR_QUERY_TIMEOUT: + // SQL_ATTR_RETRIEVE_DATA: + // SQL_ATTR_SIMULATE_CURSOR: + // SQL_ATTR_USE_BOOKMARKS: + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::METADATA_ID); + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::MAX_LENGTH); + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::NOSCAN); + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::QUERY_TIMEOUT); + + // SQL_ATTR_ROW_BIND_TYPE: + m_currentArd->SetHeaderField(SQL_DESC_BIND_TYPE, + reinterpret_cast(static_cast(trackingStatement.m_currentArd->GetBoundStructOffset())), 0); +} + +bool ODBCStatement::isPrepared() const { + return m_isPrepared; +} + +void ODBCStatement::Prepare(const std::string& query) { + boost::optional > metadata = m_spiStatement->Prepare(query); + + if (metadata) { + m_ird->PopulateFromResultSetMetadata(metadata->get()); + } + m_isPrepared = true; +} + +void ODBCStatement::ExecutePrepared() { + if (!m_isPrepared) { + throw DriverException("Function sequence error", "HY010"); + } + + if (m_spiStatement->ExecutePrepared()) { + m_currenResult = m_spiStatement->GetResultSet(); + m_ird->PopulateFromResultSetMetadata(m_spiStatement->GetResultSet()->GetMetadata().get()); + m_hasReachedEndOfResult = false; + } +} + +void ODBCStatement::ExecuteDirect(const std::string& query) { + if (m_spiStatement->Execute(query)) { + m_currenResult = m_spiStatement->GetResultSet(); + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + } + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +bool ODBCStatement::Fetch(size_t rows) { + if (m_hasReachedEndOfResult) { + m_ird->SetRowsProcessed(0); + return false; + } + + if (m_maxRows) { + rows = std::min(rows, m_maxRows - m_rowNumber); + } + + if (m_currentArd->HaveBindingsChanged()) { + // TODO: Deal handle when offset != bufferlength. + + // Wipe out all bindings in the ResultSet. + // Note that the number of ARD records can both be more or less + // than the number of columns. + for (size_t i = 0; i < m_ird->GetRecords().size(); i++) { + if (i < m_currentArd->GetRecords().size() && m_currentArd->GetRecords()[i].m_isBound) { + const DescriptorRecord& ardRecord = m_currentArd->GetRecords()[i]; + m_currenResult->BindColumn(i+1, ardRecord.m_type, ardRecord.m_precision, + ardRecord.m_scale, ardRecord.m_dataPtr, + GetLength(ardRecord), + ardRecord.m_indicatorPtr); + } else { + m_currenResult->BindColumn(i+1, CDataType_CHAR /* arbitrary type, not used */, 0, 0, nullptr, 0, nullptr); + } + } + m_currentArd->NotifyBindingsHavePropagated(); + } + + size_t rowsFetched = m_currenResult->Move(rows, m_currentArd->GetBindOffset(), + m_currentArd->GetBoundStructOffset(), m_ird->GetArrayStatusPtr()); + m_ird->SetRowsProcessed(static_cast(rowsFetched)); + + m_rowNumber += rowsFetched; + m_hasReachedEndOfResult = rowsFetched != rows; + return rowsFetched != 0; +} + +void ODBCStatement::GetStmtAttr(SQLINTEGER statementAttribute, + SQLPOINTER output, SQLINTEGER bufferSize, + SQLINTEGER *strLenPtr, bool isUnicode) { + using driver::odbcabstraction::Statement; + boost::optional spiAttribute; + switch (statementAttribute) { + // Descriptor accessor attributes + case SQL_ATTR_APP_PARAM_DESC: + DescriptorToHandle(output, m_currentApd, strLenPtr); + return; + case SQL_ATTR_APP_ROW_DESC: + DescriptorToHandle(output, m_currentArd, strLenPtr); + return; + case SQL_ATTR_IMP_PARAM_DESC: + DescriptorToHandle(output, m_ipd.get(), strLenPtr); + return; + case SQL_ATTR_IMP_ROW_DESC: + DescriptorToHandle(output, m_ird.get(), strLenPtr); + return; + + // Attributes that are descriptor fields + case SQL_ATTR_PARAM_BIND_OFFSET_PTR: + m_currentApd->GetHeaderField(SQL_DESC_BIND_OFFSET_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAM_BIND_TYPE: + m_currentApd->GetHeaderField(SQL_DESC_BIND_TYPE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAM_OPERATION_PTR: + m_currentApd->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAM_STATUS_PTR: + m_ipd->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAMS_PROCESSED_PTR: + m_ipd->GetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAMSET_SIZE: + m_currentApd->GetHeaderField(SQL_DESC_ARRAY_SIZE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_ARRAY_SIZE: + m_currentArd->GetHeaderField(SQL_DESC_ARRAY_SIZE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_BIND_OFFSET_PTR: + m_currentArd->GetHeaderField(SQL_DESC_BIND_OFFSET_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_BIND_TYPE: + m_currentArd->GetHeaderField(SQL_DESC_BIND_TYPE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_OPERATION_PTR: + m_currentArd->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_STATUS_PTR: + m_ird->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROWS_FETCHED_PTR: + m_ird->GetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_ASYNC_ENABLE: + GetAttribute(static_cast(SQL_ASYNC_ENABLE_OFF), output, bufferSize, strLenPtr); + return; + +#ifdef SQL_ATTR_ASYNC_STMT_EVENT + case SQL_ATTR_ASYNC_STMT_EVENT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK + case SQL_ATTR_ASYNC_STMT_PCALLBACK: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT + case SQL_ATTR_ASYNC_STMT_PCONTEXT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif + case SQL_ATTR_CURSOR_SCROLLABLE: + GetAttribute(static_cast(SQL_NONSCROLLABLE), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_CURSOR_SENSITIVITY: + GetAttribute(static_cast(SQL_UNSPECIFIED), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_CURSOR_TYPE: + GetAttribute(static_cast(SQL_CURSOR_FORWARD_ONLY), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_ENABLE_AUTO_IPD: + GetAttribute(static_cast(SQL_FALSE), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_FETCH_BOOKMARK_PTR: + GetAttribute(static_cast(NULL), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_KEYSET_SIZE: + GetAttribute(static_cast(0), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_ROW_NUMBER: + GetAttribute(static_cast(m_rowNumber), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_SIMULATE_CURSOR: + GetAttribute(static_cast(SQL_SC_UNIQUE), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_USE_BOOKMARKS: + GetAttribute(static_cast(SQL_UB_OFF), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_CONCURRENCY: + GetAttribute(static_cast(SQL_CONCUR_READ_ONLY), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_MAX_ROWS: + GetAttribute(static_cast(m_maxRows), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_RETRIEVE_DATA: + GetAttribute(static_cast(SQL_RD_ON), output, bufferSize, strLenPtr); + return; + case SQL_ROWSET_SIZE: + GetAttribute(static_cast(m_rowsetSize), output, bufferSize, strLenPtr); + return; + + // Driver-level statement attributes. These are all SQLULEN attributes. + case SQL_ATTR_MAX_LENGTH: + spiAttribute = m_spiStatement->GetAttribute(Statement::MAX_LENGTH); + break; + case SQL_ATTR_METADATA_ID: + spiAttribute = m_spiStatement->GetAttribute(Statement::METADATA_ID); + break; + case SQL_ATTR_NOSCAN: + spiAttribute = m_spiStatement->GetAttribute(Statement::NOSCAN); + break; + case SQL_ATTR_QUERY_TIMEOUT: + spiAttribute = m_spiStatement->GetAttribute(Statement::QUERY_TIMEOUT); + break; + default: + throw DriverException("Invalid statement attribute: " + std::to_string(statementAttribute), "HY092"); + } + + if (spiAttribute) { + GetAttribute(static_cast(boost::get(*spiAttribute)), + output, bufferSize, strLenPtr); + return; + } + + throw DriverException("Invalid statement attribute: " + std::to_string(statementAttribute), "HY092"); +} + +void ODBCStatement::SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, + SQLINTEGER bufferSize, bool isUnicode) { + size_t attributeToWrite = 0; + bool successfully_written = false; + + switch (statementAttribute) { + case SQL_ATTR_APP_PARAM_DESC: { + ODBCDescriptor* desc = static_cast(value); + if (m_currentApd != desc) { + if (m_currentApd != m_builtInApd.get()) { + m_currentApd->DetachFromStatement(this, true); + } + m_currentApd = desc; + if (m_currentApd != m_builtInApd.get()) { + desc->RegisterToStatement(this, true); + } + } + return; + } + case SQL_ATTR_APP_ROW_DESC: { + ODBCDescriptor* desc = static_cast(value); + if (m_currentArd != desc) { + if (m_currentArd != m_builtInArd.get()) { + m_currentArd->DetachFromStatement(this, false); + } + m_currentArd = desc; + if (m_currentArd != m_builtInArd.get()) { + desc->RegisterToStatement(this, false); + } + } + return; + } + case SQL_ATTR_IMP_PARAM_DESC: + throw DriverException("Cannot assign implementation descriptor.", "HY017"); + case SQL_ATTR_IMP_ROW_DESC: + throw DriverException("Cannot assign implementation descriptor.", "HY017"); + // Attributes that are descriptor fields + case SQL_ATTR_PARAM_BIND_OFFSET_PTR: + m_currentApd->SetHeaderField(SQL_DESC_BIND_OFFSET_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAM_BIND_TYPE: + m_currentApd->SetHeaderField(SQL_DESC_BIND_TYPE, value, bufferSize); + return; + case SQL_ATTR_PARAM_OPERATION_PTR: + m_currentApd->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAM_STATUS_PTR: + m_ipd->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAMS_PROCESSED_PTR: + m_ipd->SetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAMSET_SIZE: + m_currentApd->SetHeaderField(SQL_DESC_ARRAY_SIZE, value, bufferSize); + return; + case SQL_ATTR_ROW_ARRAY_SIZE: + m_currentArd->SetHeaderField(SQL_DESC_ARRAY_SIZE, value, bufferSize); + return; + case SQL_ATTR_ROW_BIND_OFFSET_PTR: + m_currentArd->SetHeaderField(SQL_DESC_BIND_OFFSET_PTR, value, bufferSize); + return; + case SQL_ATTR_ROW_BIND_TYPE: + m_currentArd->SetHeaderField(SQL_DESC_BIND_TYPE, value, bufferSize); + return; + case SQL_ATTR_ROW_OPERATION_PTR: + m_currentArd->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_ROW_STATUS_PTR: + m_ird->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_ROWS_FETCHED_PTR: + m_ird->SetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, value, bufferSize); + return; + + case SQL_ATTR_ASYNC_ENABLE: +#ifdef SQL_ATTR_ASYNC_STMT_EVENT + case SQL_ATTR_ASYNC_STMT_EVENT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK + case SQL_ATTR_ASYNC_STMT_PCALLBACK: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT + case SQL_ATTR_ASYNC_STMT_PCONTEXT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif + case SQL_ATTR_CONCURRENCY: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_CONCUR_READ_ONLY)); + return; + case SQL_ATTR_CURSOR_SCROLLABLE: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_NONSCROLLABLE)); + return; + case SQL_ATTR_CURSOR_SENSITIVITY: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_UNSPECIFIED)); + return; + case SQL_ATTR_CURSOR_TYPE: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_CURSOR_FORWARD_ONLY)); + return; + case SQL_ATTR_ENABLE_AUTO_IPD: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_FALSE)); + return; + case SQL_ATTR_FETCH_BOOKMARK_PTR: + if (value != NULL) { + throw DriverException("Optional feature not implemented", "HYC00"); + } + return; + case SQL_ATTR_KEYSET_SIZE: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(0)); + return; + case SQL_ATTR_ROW_NUMBER: + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_SIMULATE_CURSOR: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_SC_UNIQUE)); + return; + case SQL_ATTR_USE_BOOKMARKS: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_UB_OFF)); + return; + case SQL_ATTR_RETRIEVE_DATA: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_TRUE)); + return; + case SQL_ROWSET_SIZE: + SetAttribute(value, m_rowsetSize); + return; + + case SQL_ATTR_MAX_ROWS: + throw DriverException("Cannot set read-only attribute", "HY092"); + + // Driver-leve statement attributes. These are all size_t attributes + case SQL_ATTR_MAX_LENGTH: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::MAX_LENGTH, attributeToWrite); + break; + case SQL_ATTR_METADATA_ID: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::METADATA_ID, attributeToWrite); + break; + case SQL_ATTR_NOSCAN: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::NOSCAN, attributeToWrite); + break; + case SQL_ATTR_QUERY_TIMEOUT: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::QUERY_TIMEOUT, attributeToWrite); + break; + default: + throw DriverException("Invalid attribute: " + std::to_string(attributeToWrite), "HY092"); + } + if (!successfully_written) { + GetDiagnostics().AddWarning("Optional value changed.", "01S02", ODBCErrorCodes_GENERAL_WARNING); + } +} + +void ODBCStatement::RevertAppDescriptor(bool isApd) { + if (isApd) { + m_currentApd = m_builtInApd.get(); + } else { + m_currentArd = m_builtInArd.get(); + } +} + +void ODBCStatement::closeCursor(bool suppressErrors) { + if (!suppressErrors && !m_currenResult) { + throw DriverException("Invalid cursor state", "28000"); + } + + if (m_currenResult) { + m_currenResult->Close(); + m_currenResult = nullptr; + } + + // Reset the fetching state of this statement. + m_currentArd->NotifyBindingsHaveChanged(); + m_rowNumber = 0; + m_hasReachedEndOfResult = false; +} + +bool ODBCStatement::GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr) { + if (recordNumber == 0) { + throw DriverException("Bookmarks are not supported", "07009"); + } else if (recordNumber > m_ird->GetRecords().size()) { + throw DriverException("Invalid column index: " + std::to_string(recordNumber), "07009"); + } + + SQLSMALLINT evaluatedCType = cType; + + // TODO: Get proper default precision and scale from abstraction. + int precision = 38; // arrow::Decimal128Type::kMaxPrecision; + int scale = 0; + + if (cType == SQL_ARD_TYPE) { + if (recordNumber > m_currentArd->GetRecords().size()) { + throw DriverException("Invalid column index: " + std::to_string(recordNumber), "07009"); + } + const DescriptorRecord& record = m_currentArd->GetRecords()[recordNumber-1]; + evaluatedCType = record.m_conciseType; + precision = record.m_precision; + scale = record.m_scale; + } + + // Note: this is intentionally not an else if, since the type can be SQL_C_DEFAULT in the ARD. + if (evaluatedCType == SQL_C_DEFAULT) { + if (recordNumber <= m_currentArd->GetRecords().size()) { + const DescriptorRecord &ardRecord = + m_currentArd->GetRecords()[recordNumber - 1]; + precision = ardRecord.m_precision; + scale = ardRecord.m_scale; + } + + const DescriptorRecord& irdRecord = m_ird->GetRecords()[recordNumber-1]; + evaluatedCType = getCTypeForSQLType(irdRecord); + } + + return m_currenResult->GetData(recordNumber, evaluatedCType, precision, + scale, dataPtr, bufferLength, indicatorPtr); +} + +void ODBCStatement::releaseStatement() { + closeCursor(true); + m_connection.dropStatement(this); +} + +void ODBCStatement::GetTables(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* tableType) { + closeCursor(true); + if (m_connection.IsOdbc2Connection()) { + m_currenResult = m_spiStatement->GetTables_V2(catalog, schema, table, tableType); + } else { + m_currenResult = m_spiStatement->GetTables_V3(catalog, schema, table, tableType); + } + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +void ODBCStatement::GetColumns(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* column) { + closeCursor(true); + if (m_connection.IsOdbc2Connection()) { + m_currenResult = m_spiStatement->GetColumns_V2(catalog, schema, table, column); + } else { + m_currenResult = m_spiStatement->GetColumns_V3(catalog, schema, table, column); + } + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +void ODBCStatement::GetTypeInfo(SQLSMALLINT dataType) { + closeCursor(true); + if (m_connection.IsOdbc2Connection()) { + m_currenResult = m_spiStatement->GetTypeInfo_V2(dataType); + } else { + m_currenResult = m_spiStatement->GetTypeInfo_V3(dataType); + } + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +void ODBCStatement::Cancel() { + m_spiStatement->Cancel(); +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/spd_logger.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/spd_logger.cc new file mode 100644 index 0000000000000..44708349221ab --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/spd_logger.cc @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "odbcabstraction/spd_logger.h" + +#include "odbcabstraction/logger.h" + +#include +#include +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + +const std::string SPDLogger::LOG_LEVEL = "LogLevel"; +const std::string SPDLogger::LOG_PATH= "LogPath"; +const std::string SPDLogger::MAXIMUM_FILE_SIZE= "MaximumFileSize"; +const std::string SPDLogger::FILE_QUANTITY= "FileQuantity"; +const std::string SPDLogger::LOG_ENABLED= "LogEnabled"; + +namespace { +std::function shutdown_handler; +void signal_handler(int signal) { + shutdown_handler(signal); +} + +typedef void (*Handler)(int signum); + +Handler old_sigint_handler = SIG_IGN; +Handler old_sigsegv_handler = SIG_IGN; +Handler old_sigabrt_handler = SIG_IGN; +#ifdef SIGKILL +Handler old_sigkill_handler = SIG_IGN; +#endif + +Handler GetHandlerFromSignal(int signum) { + switch (signum) { + case(SIGINT): + return old_sigint_handler; + case(SIGSEGV): + return old_sigsegv_handler; + case(SIGABRT): + return old_sigabrt_handler; +#ifdef SIGKILL + case(SIGKILL): + return old_sigkill_handler; +#endif + } +} + +void SetSignalHandler(int signum) { + Handler old = signal(signum, SIG_IGN); + if (old != SIG_IGN) { + auto old_handler = GetHandlerFromSignal(signum); + old_handler = old; + } + signal(signum, signal_handler); +} + +void ResetSignalHandler(int signum) { + Handler actual_handler = signal(signum, SIG_IGN); + if (actual_handler == signal_handler) { + signal(signum, GetHandlerFromSignal(signum)); + } +} + + +inline spdlog::level::level_enum ToSpdLogLevel(LogLevel level) { + switch (level) { + case LogLevel_TRACE: + return spdlog::level::trace; + case LogLevel_DEBUG: + return spdlog::level::debug; + case LogLevel_INFO: + return spdlog::level::info; + case LogLevel_WARN: + return spdlog::level::warn; + case LogLevel_ERROR: + return spdlog::level::err; + default: + return spdlog::level::off; + } +} +} // namespace + +void SPDLogger::init(int64_t fileQuantity, int64_t maxFileSize, + const std::string &fileNamePrefix, LogLevel level) { + logger_ = spdlog::rotating_logger_mt( + "ODBC Logger", fileNamePrefix, maxFileSize, fileQuantity); + + logger_->set_level(ToSpdLogLevel(level)); + + if (level != LogLevel::LogLevel_OFF) { + SetSignalHandler(SIGINT); + SetSignalHandler(SIGSEGV); + SetSignalHandler(SIGABRT); +#ifdef SIGKILL + SetSignalHandler(SIGKILL); +#endif + shutdown_handler = [&](int signal) { + logger_->flush(); + spdlog::shutdown(); + auto handler = GetHandlerFromSignal(signal); + handler(signal); + }; + } +} + +void SPDLogger::log(LogLevel level, const std::function &build_message) { + auto level_set = logger_->level(); + spdlog::level::level_enum spdlog_level = ToSpdLogLevel(level); + if (level_set == spdlog::level::off || level_set > spdlog_level) { + return; + } + + const std::string &message = build_message(); + logger_->log(spdlog_level, message); +} + +SPDLogger::~SPDLogger() { + ResetSignalHandler(SIGINT); + ResetSignalHandler(SIGSEGV); + ResetSignalHandler(SIGABRT); +#ifdef SIGKILL + ResetSignalHandler(SIGKILL); +#endif +} + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc new file mode 100644 index 0000000000000..7fd4db579fc33 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include "whereami.h" + +#include +#include + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +boost::optional AsBool(const std::string& value) { + if (boost::iequals(value, "true") || boost::iequals(value, "1")) { + return true; + } else if (boost::iequals(value, "false") || boost::iequals(value, "0")) { + return false; + } else { + return boost::none; + } +} + +boost::optional AsBool(const Connection::ConnPropertyMap& connPropertyMap, + const std::string& property_name) { + auto extracted_property = connPropertyMap.find(property_name); + + if (extracted_property != connPropertyMap.end()) { + return AsBool(extracted_property->second); + } + + return boost::none; +} + +boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& connPropertyMap, const std::string& property_name) { + auto extracted_property = connPropertyMap.find(property_name); + + if (extracted_property != connPropertyMap.end()) { + const int32_t stringColumnLength = std::stoi(extracted_property->second); + + if (stringColumnLength >= min_value && stringColumnLength <= INT32_MAX) { + return stringColumnLength; + } + } + return boost::none; +} + +std::string GetModulePath() { + std::vector path; + int length, dirname_length; + length = wai_getModulePath(NULL, 0, &dirname_length); + + if (length != 0) { + path.resize(length); + wai_getModulePath(path.data(), length, &dirname_length); + } else { + throw DriverException("Could not find module path."); + } + + return std::string(path.begin(), path.begin() + dirname_length); +} + +void ReadConfigFile(PropertyMap &properties, const std::string &config_file_name) { + auto config_path = GetModulePath(); + + std::ifstream config_file; + auto config_file_path = config_path + "/" + config_file_name; + config_file.open(config_file_path); + + if (config_file.fail()) { + auto error_msg = "Arrow Flight SQL ODBC driver config file not found on \"" + config_file_path + "\""; + std::cerr << error_msg << std::endl; + + throw DriverException(error_msg); + } + + std::string temp_config; + + boost::char_separator separator("="); + while(config_file.good()) { + config_file >> temp_config; + boost::tokenizer> tokenizer(temp_config, separator); + + auto iterator = tokenizer.begin(); + + std::string key = *iterator; + std::string value = *++iterator; + + properties[key] = std::move(value); + } +} + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/whereami.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/whereami.cc new file mode 100644 index 0000000000000..39324d16e2cb8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/whereami.cc @@ -0,0 +1,804 @@ +// (‑●‑●)> dual licensed under the WTFPL v2 and MIT licenses +// without any warranty. +// by Gregory Pakosz (@gpakosz) +// https://github.com/gpakosz/whereami + +// in case you want to #include "whereami.c" in a larger compilation unit +#if !defined(WHEREAMI_H) +#include "whereami.h" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__linux__) || defined(__CYGWIN__) +#undef _DEFAULT_SOURCE +#define _DEFAULT_SOURCE +#elif defined(__APPLE__) +#undef _DARWIN_C_SOURCE +#define _DARWIN_C_SOURCE +#define _DARWIN_BETTER_REALPATH +#endif + +#if !defined(WAI_MALLOC) || !defined(WAI_FREE) || !defined(WAI_REALLOC) +#include +#endif + +#if !defined(WAI_MALLOC) +#define WAI_MALLOC(size) malloc(size) +#endif + +#if !defined(WAI_FREE) +#define WAI_FREE(p) free(p) +#endif + +#if !defined(WAI_REALLOC) +#define WAI_REALLOC(p, size) realloc(p, size) +#endif + +#ifndef WAI_NOINLINE +#if defined(_MSC_VER) +#define WAI_NOINLINE __declspec(noinline) +#elif defined(__GNUC__) +#define WAI_NOINLINE __attribute__((noinline)) +#else +#error unsupported compiler +#endif +#endif + +#if defined(_MSC_VER) +#define WAI_RETURN_ADDRESS() _ReturnAddress() +#elif defined(__GNUC__) +#define WAI_RETURN_ADDRESS() __builtin_extract_return_addr(__builtin_return_address(0)) +#else +#error unsupported compiler +#endif + +#if defined(_WIN32) + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#if defined(_MSC_VER) +#pragma warning(push, 3) +#endif +#include +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +#include + +static int WAI_PREFIX(getModulePath_)(HMODULE module, char* out, int capacity, int* dirname_length) +{ + wchar_t buffer1[MAX_PATH]; + wchar_t buffer2[MAX_PATH]; + wchar_t* path = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + DWORD size; + int length_, length__; + + size = GetModuleFileNameW(module, buffer1, sizeof(buffer1) / sizeof(buffer1[0])); + + if (size == 0) + break; + else if (size == (DWORD)(sizeof(buffer1) / sizeof(buffer1[0]))) + { + DWORD size_ = size; + do + { + wchar_t* path_; + + path_ = (wchar_t*)WAI_REALLOC(path, sizeof(wchar_t) * size_ * 2); + if (!path_) + break; + size_ *= 2; + path = path_; + size = GetModuleFileNameW(module, path, size_); + } + while (size == size_); + + if (size == size_) + break; + } + else + path = buffer1; + + if (!_wfullpath(buffer2, path, MAX_PATH)) + break; + length_ = (int)wcslen(buffer2); + length__ = WideCharToMultiByte(CP_UTF8, 0, buffer2, length_ , out, capacity, NULL, NULL); + + if (length__ == 0) + length__ = WideCharToMultiByte(CP_UTF8, 0, buffer2, length_, NULL, 0, NULL, NULL); + if (length__ == 0) + break; + + if (length__ <= capacity && dirname_length) + { + int i; + + for (i = length__ - 1; i >= 0; --i) + { + if (out[i] == '\\') + { + *dirname_length = i; + break; + } + } + } + + length = length__; + } + + if (path != buffer1) + WAI_FREE(path); + + return ok ? length : -1; +} + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + return WAI_PREFIX(getModulePath_)(NULL, out, capacity, dirname_length); +} + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + HMODULE module; + int length = -1; + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4054) +#endif + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, (LPCTSTR)WAI_RETURN_ADDRESS(), &module)) +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + { + length = WAI_PREFIX(getModulePath_)(module, out, capacity, dirname_length); + } + + return length; +} + +#elif defined(__linux__) || defined(__CYGWIN__) || defined(__sun) || defined(WAI_USE_PROC_SELF_EXE) + +#include +#include +#include +#if defined(__linux__) +#include +#else +#include +#endif +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif +#include +#include + +#if !defined(WAI_PROC_SELF_EXE) +#if defined(__sun) +#define WAI_PROC_SELF_EXE "/proc/self/path/a.out" +#else +#define WAI_PROC_SELF_EXE "/proc/self/exe" +#endif +#endif + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + resolved = realpath(WAI_PROC_SELF_EXE, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + return ok ? length : -1; +} + +#if !defined(WAI_PROC_SELF_MAPS_RETRY) +#define WAI_PROC_SELF_MAPS_RETRY 5 +#endif + +#if !defined(WAI_PROC_SELF_MAPS) +#if defined(__sun) +#define WAI_PROC_SELF_MAPS "/proc/self/map" +#else +#define WAI_PROC_SELF_MAPS "/proc/self/maps" +#endif +#endif + +#if defined(__ANDROID__) || defined(ANDROID) +#include +#include +#include +#endif +#include + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + int length = -1; + FILE* maps = NULL; + + for (int r = 0; r < WAI_PROC_SELF_MAPS_RETRY; ++r) + { + maps = fopen(WAI_PROC_SELF_MAPS, "r"); + if (!maps) + break; + + for (;;) + { + char buffer[PATH_MAX < 1024 ? 1024 : PATH_MAX]; + uint64_t low, high; + char perms[5]; + uint64_t offset; + uint32_t major, minor; + char path[PATH_MAX]; + uint32_t inode; + + if (!fgets(buffer, sizeof(buffer), maps)) + break; + + if (sscanf(buffer, "%" PRIx64 "-%" PRIx64 " %s %" PRIx64 " %x:%x %u %s\n", &low, &high, perms, &offset, &major, &minor, &inode, path) == 8) + { + uint64_t addr = (uintptr_t)WAI_RETURN_ADDRESS(); + if (low <= addr && addr <= high) + { + char* resolved; + + resolved = realpath(path, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); +#if defined(__ANDROID__) || defined(ANDROID) + if (length > 4 + &&buffer[length - 1] == 'k' + &&buffer[length - 2] == 'p' + &&buffer[length - 3] == 'a' + &&buffer[length - 4] == '.') + { + int fd = open(path, O_RDONLY); + if (fd == -1) + { + length = -1; // retry + break; + } + + char* begin = (char*)mmap(0, offset, PROT_READ, MAP_SHARED, fd, 0); + if (begin == MAP_FAILED) + { + close(fd); + length = -1; // retry + break; + } + + char* p = begin + offset - 30; // minimum size of local file header + while (p >= begin) // scan backwards + { + if (*((uint32_t*)p) == 0x04034b50UL) // local file header signature found + { + uint16_t length_ = *((uint16_t*)(p + 26)); + + if (length + 2 + length_ < (int)sizeof(buffer)) + { + memcpy(&buffer[length], "!/", 2); + memcpy(&buffer[length + 2], p + 30, length_); + length += 2 + length_; + } + + break; + } + + --p; + } + + munmap(begin, offset); + close(fd); + } +#endif + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + + break; + } + } + } + + fclose(maps); + maps = NULL; + + if (length != -1) + break; + } + + return length; +} + +#elif defined(__APPLE__) + +#include +#include +#include +#include +#include +#include + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[PATH_MAX]; + char buffer2[PATH_MAX]; + char* path = buffer1; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + uint32_t size = (uint32_t)sizeof(buffer1); + if (_NSGetExecutablePath(path, &size) == -1) + { + path = (char*)WAI_MALLOC(size); + if (!_NSGetExecutablePath(path, &size)) + break; + } + + resolved = realpath(path, buffer2); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + if (path != buffer1) + WAI_FREE(path); + + return ok ? length : -1; +} + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + + for(;;) + { + Dl_info info; + + if (dladdr(WAI_RETURN_ADDRESS(), &info)) + { + resolved = realpath(info.dli_fname, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + break; + } + + return length; +} + +#elif defined(__QNXNTO__) + +#include +#include +#include +#include +#include +#include + +#if !defined(WAI_PROC_SELF_EXE) +#define WAI_PROC_SELF_EXE "/proc/self/exefile" +#endif + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[PATH_MAX]; + char buffer2[PATH_MAX]; + char* resolved = NULL; + FILE* self_exe = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + self_exe = fopen(WAI_PROC_SELF_EXE, "r"); + if (!self_exe) + break; + + if (!fgets(buffer1, sizeof(buffer1), self_exe)) + break; + + resolved = realpath(buffer1, buffer2); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + fclose(self_exe); + + return ok ? length : -1; +} + +WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + + for(;;) + { + Dl_info info; + + if (dladdr(WAI_RETURN_ADDRESS(), &info)) + { + resolved = realpath(info.dli_fname, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + break; + } + + return length; +} + +#elif defined(__DragonFly__) || defined(__FreeBSD__) || \ + defined(__FreeBSD_kernel__) || defined(__NetBSD__) || defined(__OpenBSD__) + +#include +#include +#include +#include +#include +#include +#include + +#if defined(__OpenBSD__) + +#include + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[4096]; + char buffer2[PATH_MAX]; + char buffer3[PATH_MAX]; + char** argv = (char**)buffer1; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + int mib[4] = { CTL_KERN, KERN_PROC_ARGS, getpid(), KERN_PROC_ARGV }; + size_t size; + + if (sysctl(mib, 4, NULL, &size, NULL, 0) != 0) + break; + + if (size > sizeof(buffer1)) + { + argv = (char**)WAI_MALLOC(size); + if (!argv) + break; + } + + if (sysctl(mib, 4, argv, &size, NULL, 0) != 0) + break; + + if (strchr(argv[0], '/')) + { + resolved = realpath(argv[0], buffer2); + if (!resolved) + break; + } + else + { + const char* PATH = getenv("PATH"); + if (!PATH) + break; + + size_t argv0_length = strlen(argv[0]); + + const char* begin = PATH; + while (1) + { + const char* separator = strchr(begin, ':'); + const char* end = separator ? separator : begin + strlen(begin); + + if (end - begin > 0) + { + if (*(end -1) == '/') + --end; + + if (((end - begin) + 1 + argv0_length + 1) <= sizeof(buffer2)) + { + memcpy(buffer2, begin, end - begin); + buffer2[end - begin] = '/'; + memcpy(buffer2 + (end - begin) + 1, argv[0], argv0_length + 1); + + resolved = realpath(buffer2, buffer3); + if (resolved) + break; + } + } + + if (!separator) + break; + + begin = ++separator; + } + + if (!resolved) + break; + } + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + if (argv != (char**)buffer1) + WAI_FREE(argv); + + return ok ? length : -1; +} + +#else + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[PATH_MAX]; + char buffer2[PATH_MAX]; + char* path = buffer1; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { +#if defined(__NetBSD__) + int mib[4] = { CTL_KERN, KERN_PROC_ARGS, -1, KERN_PROC_PATHNAME }; +#else + int mib[4] = { CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, -1 }; +#endif + size_t size = sizeof(buffer1); + + if (sysctl(mib, 4, path, &size, NULL, 0) != 0) + break; + + resolved = realpath(path, buffer2); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + return ok ? length : -1; +} + +#endif + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + + for(;;) + { + Dl_info info; + + if (dladdr(WAI_RETURN_ADDRESS(), &info)) + { + resolved = realpath(info.dli_fname, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + break; + } + + return length; +} + +#else + +#error unsupported platform + +#endif + +#ifdef __cplusplus +} +#endif diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/whereami.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/whereami.h new file mode 100644 index 0000000000000..ca62d674cd2d1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/whereami.h @@ -0,0 +1,67 @@ +// (‑●‑●)> dual licensed under the WTFPL v2 and MIT licenses +// without any warranty. +// by Gregory Pakosz (@gpakosz) +// https://github.com/gpakosz/whereami + +#ifndef WHEREAMI_H +#define WHEREAMI_H + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef WAI_FUNCSPEC +#define WAI_FUNCSPEC +#endif +#ifndef WAI_PREFIX +#define WAI_PREFIX(function) wai_##function +#endif + +/** + * Returns the path to the current executable. + * + * Usage: + * - first call `int length = wai_getExecutablePath(NULL, 0, NULL);` to + * retrieve the length of the path + * - allocate the destination buffer with `path = (char*)malloc(length + 1);` + * - call `wai_getExecutablePath(path, length, NULL)` again to retrieve the + * path + * - add a terminal NUL character with `path[length] = '\0';` + * + * @param out destination buffer, optional + * @param capacity destination buffer capacity + * @param dirname_length optional recipient for the length of the dirname part + * of the path. + * + * @return the length of the executable path on success (without a terminal NUL + * character), otherwise `-1` + */ +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length); + +/** + * Returns the path to the current module + * + * Usage: + * - first call `int length = wai_getModulePath(NULL, 0, NULL);` to retrieve + * the length of the path + * - allocate the destination buffer with `path = (char*)malloc(length + 1);` + * - call `wai_getModulePath(path, length, NULL)` again to retrieve the path + * - add a terminal NUL character with `path[length] = '\0';` + * + * @param out destination buffer, optional + * @param capacity destination buffer capacity + * @param dirname_length optional recipient for the length of the dirname part + * of the path. + * + * @return the length of the module path on success (without a terminal NUL + * character), otherwise `-1` + */ +WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length); + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef WHEREAMI_H diff --git a/cpp/src/arrow/flight/sql/odbc/vcpkg.json b/cpp/src/arrow/flight/sql/odbc/vcpkg.json new file mode 100644 index 0000000000000..519d6441bec61 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/vcpkg.json @@ -0,0 +1,31 @@ +{ + "name": "flightsql-odbc", + "version-string": "1.0.0", + "dependencies": [ + "abseil", + "benchmark", + "boost-beast", + "boost-crc", + "boost-filesystem", + "boost-locale", + "boost-multiprecision", + "boost-optional", + "boost-process", + "boost-system", + "boost-variant", + "boost-xpressive", + "brotli", + "gflags", + "openssl", + "protobuf", + "zlib", + "re2", + "spdlog", + "grpc", + "utf8proc", + "zlib", + "zstd", + "rapidjson" + ], + "builtin-baseline": "4e485c34f5e056327ef00c14e2e3620bc50de098" +}