diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index dee5f934fb7..76d0024680a 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -1005,22 +1005,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_len ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt << ", query_text: " << static_cast(query_text) << ", text_length: " << text_length; - // GH-47711 TODO: Implement SQLExecDirect - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + std::string query = ODBC::SqlWcharToString(query_text, text_length); + + statement->Prepare(query); + statement->ExecutePrepared(); + + return SQL_SUCCESS; + }); } SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) { ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt << ", query_text: " << static_cast(query_text) << ", text_length: " << text_length; - // GH-47712 TODO: Implement SQLPrepare - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + std::string query = ODBC::SqlWcharToString(query_text, text_length); + + statement->Prepare(query); + + return SQL_SUCCESS; + }); } SQLRETURN SQLExecute(SQLHSTMT stmt) { ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt; - // GH-47712 TODO: Implement SQLExecute - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + statement->ExecutePrepared(); + + return SQL_SUCCESS; + }); } SQLRETURN SQLFetch(SQLHSTMT stmt) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc index 785a04c7b0e..f6c6da860df 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc @@ -69,6 +69,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics, call_options_.timeout = TimeoutDuration{-1}; } +FlightSqlStatement::~FlightSqlStatement() { + ClosePreparedStatementIfAny(prepared_statement_, call_options_); +} + bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, const Attribute& value) { switch (attribute) { @@ -119,7 +123,6 @@ bool FlightSqlStatement::ExecutePrepared() { Result> result = prepared_statement_->Execute(call_options_); - ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h index 3593b2f774d..d61f8ef3787 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h @@ -49,6 +49,7 @@ class FlightSqlStatement : public Statement { FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings); + ~FlightSqlStatement(); bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override; diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc index 9d6d42c4a11..a83855c2182 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc @@ -37,9 +37,86 @@ class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {}; using TestTypes = ::testing::Types; TYPED_TEST_SUITE(StatementTest, TestTypes); +TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + // GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation + /* + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Verify 1 is returned + EXPECT_EQ(1, val); + + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); + */ +} + +TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) { + std::wstring wsql = L"SELECT;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_ERROR, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + // ODBC provides generic error code HY000 to all statement errors + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000); +} + +TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLPrepare(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt)); + + // GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation + /* + // Fetch data + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + + // Verify 1 is returned + EXPECT_EQ(1, val); + + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); + */ +} + +TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) { + std::wstring wsql = L"SELECT;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_ERROR, + SQLPrepare(this->stmt, &sql0[0], static_cast(sql0.size()))); + // ODBC provides generic error code HY000 to all statement errors + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000); + + ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt)); + // Verify function sequence error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010); +} + TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) { SQLWCHAR buf[1024]; - SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize(); SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; SQLINTEGER input_char_len = static_cast(wcslen(input_str)); SQLINTEGER output_char_len = 0; @@ -58,7 +135,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) { TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) { SQLWCHAR buf[1024]; - SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize(); SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; SQLINTEGER input_char_len = static_cast(wcslen(input_str)); SQLINTEGER output_char_len = 0; @@ -95,7 +172,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputStringLength) { TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) { const SQLINTEGER small_buf_size_in_char = 11; SQLWCHAR small_buf[small_buf_size_in_char]; - SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize(); + SQLINTEGER small_buf_char_len = sizeof(small_buf) / GetSqlWCharSize(); SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; SQLINTEGER input_char_len = static_cast(wcslen(input_str)); SQLINTEGER output_char_len = 0; @@ -122,7 +199,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) { TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) { SQLWCHAR buf[1024]; - SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize(); SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; SQLINTEGER input_char_len = static_cast(wcslen(input_str)); SQLINTEGER output_char_len = 0;