diff --git a/docs/README.md b/docs/README.md index 4536b23..6c49cde 100644 --- a/docs/README.md +++ b/docs/README.md @@ -27,7 +27,7 @@ Welcome to the sqlgen documentation. This guide provides detailed information ab - [sqlgen::exec](exec.md) - How to execute raw SQL statements - [sqlgen::group_by and Aggregations](group_by_and_aggregations.md) - How generate GROUP BY queries and aggregate data - [sqlgen::inner_join, sqlgen::left_join, sqlgen::right_join, sqlgen::full_join](joins.md) - How to join different tables -- [sqlgen::insert, sqlgen::insert_or_replace](insert.md) - How to insert data within transactions +- [sqlgen::insert, sqlgen::insert_or_replace, sqlgen::returning](insert.md) - How to insert data within transactions - [sqlgen::select_from](select_from.md) - How to read data from a database using more complex queries - [sqlgen::unite and sqlgen::unite_all](unite.md) - How to combine results from multiple SELECT statements - [sqlgen::update](update.md) - How to update data in a table diff --git a/docs/insert.md b/docs/insert.md index 89ca2ea..2c8359b 100644 --- a/docs/insert.md +++ b/docs/insert.md @@ -1,4 +1,4 @@ -# `sqlgen::insert`, `sqlgen::insert_or_replace` +# `sqlgen::insert`, `sqlgen::insert_or_replace`, `sqlgen::returning` The `sqlgen::insert` interface provides a type-safe way to insert data from C++ containers or ranges into a SQL database. Unlike `sqlgen::write`, it does not create tables automatically and is designed to be used within transactions. It's particularly useful when you need fine-grained control over table creation and transaction boundaries. @@ -68,76 +68,67 @@ sqlgen::sqlite::connect("database.db") .value(); ``` -### With Replacement (`insert_or_replace`) +### Conflict Policies (`or_replace`, `or_ignore`) -The `insert_or_replace` helper inserts rows and updates existing rows when a primary key or unique constraint would be violated by the insert. It is a thin wrapper over the same insertion paths used by `insert`, but it sets the internal `or_replace` flag so the transpiler emits backend-specific "upsert" SQL. - -Function signatures (examples): +`insert(...)` supports typed conflict-policy tags: ```cpp -// Use with an explicit connection (or a Result>) -template -auto insert_or_replace(const auto& conn, const ContainerType& data); +using namespace sqlgen; + +insert(people, or_replace); +insert(people, or_ignore); -// Use as a pipeline element (returns a callable that accepts a connection) -template -auto insert_or_replace(const ContainerType& data); +// Pipeline style is also supported (suggest): +insert(people) | or_replace; +insert(people) | or_ignore; ``` -Compile-time requirement +Behavior by backend: -- The table type must have a primary key or at least one unique constraint. This is enforced at compile time via a static_assert: +- SQLite: `OR REPLACE`, `OR IGNORE` +- PostgreSQL: `ON CONFLICT (...) DO UPDATE ...`, `ON CONFLICT DO NOTHING` +- DuckDB: `OR REPLACE`, `OR IGNORE` +- MySQL: `ON DUPLICATE KEY UPDATE`, `INSERT IGNORE` - "The table must have a primary key or unique column for insert_or_replace(...) to work." +Compile-time rules: -Behavior notes +- You can set at most one conflict policy (`or_replace` or `or_ignore`). +- `or_replace` requires at least one primary key or unique constraint. -- SQLite, PostgreSQL and DuckDB backends emit `ON CONFLICT (...) DO UPDATE ...` (using `excluded.*` to reference the incoming values). -- MySQL backend emits `ON DUPLICATE KEY UPDATE` and uses `VALUES(...)` to reference incoming values. -- The transpilation helper `to_insert_or_write<..., dynamic::Insert>(true)` is used internally to produce the correct SQL. +### Returning Auto-generated IDs (`returning(ids)`) -Example: +Use `returning(ids)` to collect auto-generated primary keys during `insert`: ```cpp -const auto people1 = std::vector({ - Person{.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, - Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10} -}); - -const auto people2 = std::vector({ - Person{.id = 1, .first_name = "Bartholomew", .last_name = "Simpson", .age = 10} -}); +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + int age; +}; -using namespace sqlgen; +auto ids = std::vector{}; -const auto result = sqlite::connect() +sqlite::connect() .and_then(create_table | if_not_exists) - .and_then(insert(std::ref(people1))) - .and_then(insert_or_replace(std::ref(people2))) + .and_then(insert(people, returning(ids))) .value(); ``` -Generated SQL (SQLite/Postgres/DuckDB style): +Compile-time rules: -```sql -INSERT INTO "Person" ("id", "first_name", "last_name", "age") VALUES (?, ?, ?, ?) -ON CONFLICT (id) DO UPDATE SET - id=excluded.id, - first_name=excluded.first_name, - last_name=excluded.last_name, - age=excluded.age; -``` +- The target type must contain an auto-incrementing primary key. +- `returning(ids)` cannot be combined with `or_ignore`. +- The `ids` container must support `clear()` and `push_back(value_type)`. +- On MySQL, `returning(ids)` is supported for single-object inserts only. -Generated SQL (MySQL style): +Backend behavior: -```sql -INSERT INTO `Person` (`id`, `first_name`, `last_name`, `age`) VALUES (?, ?, ?, ?) -ON DUPLICATE KEY UPDATE - id=VALUES(id), - first_name=VALUES(first_name), - last_name=VALUES(last_name), - age=VALUES(age); -``` +- SQLite/PostgreSQL/DuckDB: generated SQL uses `RETURNING`. +- MySQL: no `RETURNING` SQL is emitted; IDs are read via the MySQL C API. + +### Backward Compatibility (`insert_or_replace`) + +`insert_or_replace(...)` is still available and works like before. Internally it is now a thin wrapper over `insert(..., or_replace)`. ## Example: Full Transaction Usage @@ -259,11 +250,9 @@ While both `insert` and `write` can be used to add data to a database, they serv ## Notes - The `Result>` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed -- The function has several overloads: - 1. Takes a connection reference and iterators - 2. Takes a `Result>` and iterators - 3. Takes a connection and a container directly - 4. Takes a connection and a reference wrapper to a container +- `insert(...)` accepts optional modifiers: `or_replace`, `or_ignore`, `returning(ids)` +- Modifiers can be passed directly (`insert(data, or_replace)`) or in pipeline style (`insert(data) | or_replace`) - Unlike `write`, `insert` does not create tables automatically - you must create tables separately using `create_table` - The insert operation is atomic within a transaction - When using reference wrappers (`std::ref`), the data is not copied, which can be more efficient for large datasets +- On MySQL, `returning(ids)` is limited to single-object inserts diff --git a/include/sqlgen/Session.hpp b/include/sqlgen/Session.hpp index 353f0d2..2916988 100644 --- a/include/sqlgen/Session.hpp +++ b/include/sqlgen/Session.hpp @@ -22,6 +22,11 @@ class Session { using Connection = _Connection; using ConnPtr = Ref; + static constexpr bool supports_returning_ids = + Connection::supports_returning_ids; + static constexpr bool supports_multirow_returning_ids = + Connection::supports_multirow_returning_ids; + Session(const Ref& _conn, const Ref& _flag) : conn_(_conn), flag_(_flag.ptr()) {} @@ -47,9 +52,10 @@ class Session { } template - Result insert(const dynamic::Insert& _stmt, ItBegin _begin, - ItEnd _end) { - return conn_->insert(_stmt, _begin, _end); + Result insert( + const dynamic::Insert& _stmt, ItBegin _begin, ItEnd _end, + std::vector>* _returned_ids = nullptr) { + return conn_->insert(_stmt, _begin, _end, _returned_ids); } Session& operator=(const Session& _other) = delete; diff --git a/include/sqlgen/Transaction.hpp b/include/sqlgen/Transaction.hpp index b980dca..703ae18 100644 --- a/include/sqlgen/Transaction.hpp +++ b/include/sqlgen/Transaction.hpp @@ -1,6 +1,9 @@ #ifndef SQLGEN_TRANSACTION_HPP_ #define SQLGEN_TRANSACTION_HPP_ +#include +#include + #include "Ref.hpp" #include "internal/iterator_t.hpp" #include "is_connection.hpp" @@ -13,6 +16,11 @@ class Transaction { public: using ConnType = _ConnType; + static constexpr bool supports_returning_ids = + ConnType::supports_returning_ids; + static constexpr bool supports_multirow_returning_ids = + ConnType::supports_multirow_returning_ids; + Transaction(const Ref& _conn) : conn_(_conn), transaction_ended_(false) {} @@ -57,9 +65,10 @@ class Transaction { } template - Result insert(const dynamic::Insert& _stmt, ItBegin _begin, - ItEnd _end) { - return conn_->insert(_stmt, _begin, _end); + Result insert( + const dynamic::Insert& _stmt, ItBegin _begin, ItEnd _end, + std::vector>* _returned_ids = nullptr) { + return conn_->insert(_stmt, _begin, _end, _returned_ids); } Transaction& operator=(const Transaction& _other) = delete; diff --git a/include/sqlgen/duckdb/Connection.hpp b/include/sqlgen/duckdb/Connection.hpp index b348bbc..fed0861 100644 --- a/include/sqlgen/duckdb/Connection.hpp +++ b/include/sqlgen/duckdb/Connection.hpp @@ -22,7 +22,9 @@ #include "../dynamic/Write.hpp" #include "../internal/iterator_t.hpp" #include "../internal/remove_auto_incr_primary_t.hpp" +#include "../internal/strings/strings.hpp" #include "../internal/to_container.hpp" +#include "../internal/to_str_vec.hpp" #include "../is_connection.hpp" #include "../sqlgen_api.hpp" #include "./parsing/Parser_default.hpp" @@ -38,10 +40,13 @@ class SQLGEN_API Connection { using ConnPtr = Ref; public: - Connection(const ConnPtr &_conn) : appender_(nullptr), conn_(_conn) {} + static constexpr bool supports_returning_ids = true; + static constexpr bool supports_multirow_returning_ids = true; + + Connection(const ConnPtr& _conn) : appender_(nullptr), conn_(_conn) {} static rfl::Result> make( - const std::optional &_fname) noexcept; + const std::optional& _fname) noexcept; ~Connection() = default; @@ -49,44 +54,50 @@ class SQLGEN_API Connection { Result commit() noexcept; - Result execute(const std::string &_sql) noexcept; + Result execute(const std::string& _sql) noexcept; template - Result insert(const dynamic::Insert &_insert_stmt, ItBegin _begin, - ItEnd _end) noexcept { + Result insert(const dynamic::Insert& _insert_stmt, ItBegin _begin, + ItEnd _end, + std::vector>* + _returned_ids = nullptr) noexcept { + if (_returned_ids) { + return insert_with_returning(_insert_stmt, _begin, _end, _returned_ids); + } + using namespace std::ranges::views; const auto sql = to_sql(_insert_stmt); auto columns = internal::collect::vector( _insert_stmt.columns | - transform([](const auto &_str) { return _str.c_str(); })); + transform([](const auto& _str) { return _str.c_str(); })); return get_duckdb_logical_types(_insert_stmt.table, _insert_stmt.columns) - .and_then([&](const auto &_types) { + .and_then([&](const auto& _types) { return DuckDBAppender::make(sql, conn_, columns, _types); }) .and_then([&](auto _appender) { return write_to_appender(_begin, _end, _appender->appender()) - .and_then([&](const auto &) { return _appender->close(); }); + .and_then([&](const auto&) { return _appender->close(); }); }); } template - auto read(const rfl::Variant &_query) { + auto read(const rfl::Variant& _query) { using ValueType = transpilation::value_t; - const auto sql = _query.visit([&](const auto &_q) { return to_sql(_q); }); + const auto sql = _query.visit([&](const auto& _q) { return to_sql(_q); }); return internal::to_container>( Iterator(sql, conn_)); } Result rollback() noexcept; - std::string to_sql(const dynamic::Statement &_stmt) noexcept { + std::string to_sql(const dynamic::Statement& _stmt) noexcept { return duckdb::to_sql_impl(_stmt); } - Result start_write(const dynamic::Write &_write_stmt) { + Result start_write(const dynamic::Write& _write_stmt) { if (appender_) { return error( "Write operation already in progress - you cannot start another."); @@ -96,7 +107,7 @@ class SQLGEN_API Connection { auto columns = internal::collect::vector( _write_stmt.columns | - transform([](const auto &_str) { return _str.c_str(); })); + transform([](const auto& _str) { return _str.c_str(); })); const auto sql = to_sql(_write_stmt); @@ -104,7 +115,7 @@ class SQLGEN_API Connection { .and_then([&](auto _types) { return DuckDBAppender::make(sql, conn_, columns, _types); }) - .transform([&](auto &&_appender) { + .transform([&](auto&& _appender) { appender_ = _appender.ptr(); return Nothing{}; }); @@ -127,12 +138,193 @@ class SQLGEN_API Connection { } private: + std::string build_batched_insert_sql(const dynamic::Insert& _stmt, + size_t _num_rows) { + using namespace std::ranges::views; + + const auto wrap_in_quotes = [](const auto& _name) { + return "\"" + _name + "\""; + }; + + std::stringstream stream; + + stream << "INSERT "; + + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::replace) { + stream << "OR REPLACE "; + } else if (_stmt.conflict_policy == + dynamic::Insert::ConflictPolicy::ignore) { + stream << "OR IGNORE "; + } + + stream << "INTO "; + + if (_stmt.table.schema) { + stream << wrap_in_quotes(*_stmt.table.schema) << "."; + } + + stream << wrap_in_quotes(_stmt.table.name); + + stream << " ("; + stream << internal::strings::join( + ", ", + internal::collect::vector(_stmt.columns | transform(wrap_in_quotes))); + stream << ") VALUES "; + + // Build single row placeholder: (?, ?, ...) + std::string row_placeholder = "("; + for (size_t i = 0; i < _stmt.columns.size(); ++i) { + if (i > 0) row_placeholder += ", "; + row_placeholder += "?"; + } + row_placeholder += ")"; + + // Repeat for _num_rows + for (size_t r = 0; r < _num_rows; ++r) { + if (r > 0) stream << ", "; + stream << row_placeholder; + } + + if (_stmt.returning.size() != 0) { + stream << " RETURNING "; + stream << internal::strings::join( + ", ", internal::collect::vector(_stmt.returning | + transform(wrap_in_quotes))); + } + + stream << ";"; + + return stream.str(); + } + + template + Result insert_with_returning( + const dynamic::Insert& _insert_stmt, ItBegin _begin, ItEnd _end, + std::vector>* _returned_ids) { + if (_insert_stmt.returning.size() != 1) { + return error( + "DuckDB returning(ids) requires exactly one returned " + "auto-incrementing primary key column."); + } + + constexpr size_t BATCH_SIZE = 1000; + const size_t num_cols = _insert_stmt.columns.size(); + + auto it = _begin; + while (it != _end) { + // 1. Collect current batch of rows + std::vector>> batch; + batch.reserve(BATCH_SIZE); + + for (size_t i = 0; i < BATCH_SIZE && it != _end; ++i, ++it) { + auto row = internal::to_str_vec(*it); + if (row.size() != num_cols) { + return error("Expected " + std::to_string(num_cols) + + " values, got " + std::to_string(row.size()) + "."); + } + batch.push_back(std::move(row)); + } + + if (batch.empty()) { + break; + } + + // 2. Build batched INSERT SQL + const auto sql = build_batched_insert_sql(_insert_stmt, batch.size()); + + // 3. Prepare statement + duckdb_prepared_statement stmt = nullptr; + + if (duckdb_prepare(conn_->conn(), sql.c_str(), &stmt) == DuckDBError) { + const auto* raw_err = stmt ? duckdb_prepare_error(stmt) : nullptr; + const auto err = + raw_err ? std::string(raw_err) + : std::string("Failed to prepare INSERT statement."); + if (stmt) { + duckdb_destroy_prepare(&stmt); + } + return error(err); + } + + // 4. Bind all parameters (batch.size() * num_cols) + idx_t param_idx = 1; + for (const auto& row : batch) { + for (const auto& val : row) { + const auto state = + val ? duckdb_bind_varchar(stmt, param_idx, val->c_str()) + : duckdb_bind_null(stmt, param_idx); + + if (state == DuckDBError) { + const auto* raw_err = duckdb_prepare_error(stmt); + const auto err = raw_err ? std::string(raw_err) + : std::string("Failed to bind parameter."); + duckdb_destroy_prepare(&stmt); + return error(err); + } + ++param_idx; + } + } + + // 5. Execute once + duckdb_result result{}; + const auto execute_state = duckdb_execute_prepared(stmt, &result); + + if (execute_state == DuckDBError) { + const auto* raw_err = duckdb_result_error(&result); + const auto* stmt_err = duckdb_prepare_error(stmt); + const auto err = + raw_err ? std::string(raw_err) + : (stmt_err + ? std::string(stmt_err) + : std::string("Failed to execute prepared INSERT.")); + duckdb_destroy_result(&result); + duckdb_destroy_prepare(&stmt); + return error(err); + } + + // 6. Read all returned IDs + const idx_t row_count = duckdb_row_count(&result); + + if (row_count != batch.size() || duckdb_column_count(&result) < 1) { + duckdb_destroy_result(&result); + duckdb_destroy_prepare(&stmt); + return error( + "INSERT ... RETURNING must return exactly one row per input row " + "and at least one column. Expected " + + std::to_string(batch.size()) + " rows, got " + + std::to_string(row_count) + "."); + } + + for (idx_t r = 0; r < row_count; ++r) { + if (duckdb_value_is_null(&result, 0, r)) { + _returned_ids->emplace_back(std::nullopt); + } else { + char* raw_value = duckdb_value_varchar(&result, 0, r); + + if (!raw_value) { + duckdb_destroy_result(&result); + duckdb_destroy_prepare(&stmt); + return error("Failed to read returned id from DuckDB."); + } + + _returned_ids->emplace_back(raw_value); + duckdb_free(raw_value); + } + } + + duckdb_destroy_result(&result); + duckdb_destroy_prepare(&stmt); + } + + return Nothing{}; + } + Result> get_duckdb_logical_types( - const dynamic::Table &_table, const std::vector &_columns) { + const dynamic::Table& _table, const std::vector& _columns) { using namespace std::ranges::views; const auto fields = internal::collect::vector( - _columns | transform([](const auto &_name) { + _columns | transform([](const auto& _name) { return dynamic::SelectFrom::Field{ .val = dynamic::Operation{dynamic::Column{.alias = std::nullopt, .name = _name}}, @@ -145,7 +337,7 @@ class SQLGEN_API Connection { .offset = dynamic::Offset{0}}; return DuckDBResult::make(to_sql(select_from), conn_) - .transform([&](const auto &_res) { + .transform([&](const auto& _res) { return internal::collect::vector( iota(static_cast(0), static_cast(fields.size())) | transform( @@ -156,7 +348,7 @@ class SQLGEN_API Connection { template Result write_to_appender(ItBegin _begin, ItEnd _end, duckdb_appender _appender) { - for (auto it = _begin; it < _end; ++it) { + for (auto it = _begin; it != _end; ++it) { const auto res = write_row(*it, _appender); if (!res) { return res; @@ -170,18 +362,18 @@ class SQLGEN_API Connection { } template - Result write_row(const StructT &_struct, + Result write_row(const StructT& _struct, duckdb_appender _appender) noexcept { using ViewType = internal::remove_auto_incr_primary_t>; try { - ViewType(rfl::to_view(_struct)).apply([&](const auto &_field) { + ViewType(rfl::to_view(_struct)).apply([&](const auto& _field) { using ValueType = std::remove_cvref_t::Type>>; duckdb::parsing::Parser::write(*_field.value(), _appender) .value(); }); - } catch (const std::exception &e) { + } catch (const std::exception& e) { return error(e.what()); } return Nothing{}; diff --git a/include/sqlgen/dynamic/Insert.hpp b/include/sqlgen/dynamic/Insert.hpp index 07ed8c4..3ef04cf 100644 --- a/include/sqlgen/dynamic/Insert.hpp +++ b/include/sqlgen/dynamic/Insert.hpp @@ -9,13 +9,18 @@ namespace sqlgen::dynamic { struct Insert { + enum class ConflictPolicy { none, replace, ignore }; + Table table; std::vector columns; - bool or_replace; + ConflictPolicy conflict_policy = ConflictPolicy::none; std::vector non_primary_keys; - /// Holds primary keys and unique columns when or_replace is true. + /// Holds primary keys and unique columns when conflict_policy is replace. std::vector constraints; + + /// The columns to be returned after insert. + std::vector returning; }; } // namespace sqlgen::dynamic diff --git a/include/sqlgen/insert.hpp b/include/sqlgen/insert.hpp index ee897a5..2d4d2bb 100644 --- a/include/sqlgen/insert.hpp +++ b/include/sqlgen/insert.hpp @@ -4,103 +4,520 @@ #include #include #include +#include #include #include #include #include #include -#include "internal/batch_size.hpp" +#include "dynamic/Insert.hpp" +#include "internal/has_auto_incr_primary_key.hpp" #include "internal/has_constraint.hpp" -#include "internal/to_str_vec.hpp" #include "is_connection.hpp" +#include "parsing/Parser.hpp" #include "transpilation/to_insert_or_write.hpp" #include "transpilation/value_t.hpp" namespace sqlgen { -template - requires is_connection +namespace conflict_policy { + +struct replace {}; + +struct ignore {}; + +inline constexpr replace or_replace{}; + +inline constexpr ignore or_ignore{}; + +} // namespace conflict_policy + +using conflict_policy::or_ignore; +using conflict_policy::or_replace; + +namespace internal::insert { + +template +concept OutputIDRange = + std::ranges::range && requires(T& t, typename T::value_type v) { + t.clear(); + t.push_back(v); + }; + +template +struct ReturningModifier { + IDsType* ids; +}; + +enum class ModifierKind { unsupported, conflict_policy, returning }; + +template +struct modifier_traits { + static constexpr ModifierKind kind = ModifierKind::unsupported; + static constexpr dynamic::Insert::ConflictPolicy conflict_policy = + dynamic::Insert::ConflictPolicy::none; + using ids_type = void; +}; + +template <> +struct modifier_traits { + static constexpr ModifierKind kind = ModifierKind::conflict_policy; + static constexpr dynamic::Insert::ConflictPolicy conflict_policy = + dynamic::Insert::ConflictPolicy::replace; + using ids_type = void; +}; + +template <> +struct modifier_traits { + static constexpr ModifierKind kind = ModifierKind::conflict_policy; + static constexpr dynamic::Insert::ConflictPolicy conflict_policy = + dynamic::Insert::ConflictPolicy::ignore; + using ids_type = void; +}; + +template +struct modifier_traits> { + static constexpr ModifierKind kind = ModifierKind::returning; + static constexpr dynamic::Insert::ConflictPolicy conflict_policy = + dynamic::Insert::ConflictPolicy::none; + using ids_type = IDsType; +}; + +template +using modifier_traits_t = modifier_traits>; + +template +constexpr bool is_supported_modifier_v = + modifier_traits_t::kind != ModifierKind::unsupported; + +template +constexpr bool is_conflict_policy_modifier_v = + modifier_traits_t::kind == ModifierKind::conflict_policy; + +template +constexpr bool is_returning_modifier_v = + modifier_traits_t::kind == ModifierKind::returning; + +template +constexpr bool is_ignore_modifier_v = + modifier_traits_t::conflict_policy == + dynamic::Insert::ConflictPolicy::ignore; + +template +Result assign_returning_ids( + IDsType* _ids, const std::vector>& _raw_ids) { + using ValueType = typename IDsType::value_type; + + _ids->clear(); + + if constexpr (requires(IDsType& _v, size_t _n) { _v.reserve(_n); }) { + _ids->reserve(_raw_ids.size()); + } + + for (const auto& raw_id : _raw_ids) { + auto parsed = parsing::Parser::read(raw_id); + + if (!parsed) { + return error("Could not parse returned id: " + + std::string(parsed.error().what())); + } + + _ids->push_back(parsed.value()); + } + + return Nothing{}; +} + +template +constexpr void validate_modifiers() { + static_assert((true && ... && is_supported_modifier_v), + "Unsupported insert modifier. Supported modifiers are " + "sqlgen::or_replace, " + "sqlgen::or_ignore and sqlgen::returning(ids).\n" + "Example: insert(rows, or_replace, returning(ids))."); + + constexpr auto num_conflict_policies = + (0 + ... + (is_conflict_policy_modifier_v ? 1 : 0)); + constexpr auto num_returning = + (0 + ... + (is_returning_modifier_v ? 1 : 0)); + constexpr bool has_ignore = (false || ... || is_ignore_modifier_v); + + static_assert(num_conflict_policies <= 1, + "You can only set one conflict policy on insert(...).\n" + "Use either sqlgen::or_replace or sqlgen::or_ignore."); + + static_assert(num_returning <= 1, + "You can only call returning(ids) once on insert(...)."); + + static_assert(!(has_ignore && num_returning != 0), + "You cannot combine returning(ids) with or_ignore."); +} + +template +constexpr dynamic::Insert::ConflictPolicy conflict_policy_from_modifiers() { + constexpr bool has_replace = + (false || ... || + (modifier_traits_t::conflict_policy == + dynamic::Insert::ConflictPolicy::replace)); + constexpr bool has_ignore = (false || ... || is_ignore_modifier_v); + + if constexpr (has_replace) { + return dynamic::Insert::ConflictPolicy::replace; + } else if constexpr (has_ignore) { + return dynamic::Insert::ConflictPolicy::ignore; + } else { + return dynamic::Insert::ConflictPolicy::none; + } +} + +template +struct ids_type_from_modifiers { + using type = void; +}; + +template +struct ids_type_from_modifiers { + private: + using current_ids_type = typename modifier_traits_t::ids_type; + using rest_ids_type = typename ids_type_from_modifiers::type; + + public: + using type = std::conditional_t, + rest_ids_type, current_ids_type>; +}; + +template +using ids_type_t = + typename ids_type_from_modifiers...>::type; + +inline std::nullptr_t extract_ids_ptr() { return nullptr; } + +template +auto extract_ids_ptr(const Modifier& _modifier, const Rest&... _rest) { + if constexpr (is_returning_modifier_v) { + return _modifier.ids; + } else { + return extract_ids_ptr(_rest...); + } +} + +template +struct ParsedModifiers { + static constexpr dynamic::Insert::ConflictPolicy conflict_policy = + _conflict_policy; + using ids_type = IDsType; + std::conditional_t, std::nullptr_t, IDsType*> + ids_ptr = nullptr; +}; + +template +auto parse_modifiers(const Modifiers&... _modifiers) { + validate_modifiers(); + + constexpr auto conflict_policy = + conflict_policy_from_modifiers(); + using IDsType = ids_type_t; + + if constexpr (std::is_void_v) { + return ParsedModifiers{}; + } else { + return ParsedModifiers{ + .ids_ptr = extract_ids_ptr(_modifiers...)}; + } +} + +template +struct is_connection_handle : std::false_type {}; + +template +struct is_connection_handle> + : std::bool_constant> {}; + +template +struct is_connection_handle>> + : std::bool_constant> {}; + +template +constexpr bool is_connection_handle_v = + is_connection_handle>::value; + +template +struct returning_capabilities { + static constexpr bool supports_returning_ids = false; + static constexpr bool supports_multirow_returning_ids = false; +}; + +template +struct returning_capabilities< + Connection, + std::void_t< + decltype(std::remove_cvref_t::supports_returning_ids), + decltype(std::remove_cvref_t< + Connection>::supports_multirow_returning_ids)>> { + static constexpr bool supports_returning_ids = + std::remove_cvref_t::supports_returning_ids; + static constexpr bool supports_multirow_returning_ids = + std::remove_cvref_t::supports_multirow_returning_ids; +}; + +template +constexpr void validate_insert_usage() { + if constexpr (_conflict_policy == dynamic::Insert::ConflictPolicy::replace) { + static_assert(internal::has_constraint_v, + "The table must have a primary key or unique column for " + "insert_or_replace(...) to work."); + } + + if constexpr (_has_returning) { + static_assert(internal::has_auto_incr_primary_key_v, + "The table must have an auto-incrementing primary key for " + "returning(ids) to work."); + + static_assert(_conflict_policy != dynamic::Insert::ConflictPolicy::ignore, + "You cannot combine returning(ids) with or_ignore."); + + static_assert(returning_capabilities::supports_returning_ids, + "The current backend does not support returning(ids)."); + + if constexpr (!returning_capabilities< + Connection>::supports_multirow_returning_ids) { + static_assert(_single_row_hint, + "This backend only supports returning(ids) for single-" + "object inserts."); + } + } +} + +} // namespace internal::insert + +template + requires(!std::is_const_v>) +auto returning(ContainerType& _ids) { + return internal::insert::ReturningModifier{.ids = &_ids}; +} + +template +auto returning(const std::reference_wrapper _ids) { + return returning(_ids.get()); +} + +template + requires is_connection && + std::input_or_output_iterator && + std::sentinel_for Result> insert_impl(const Ref& _conn, ItBegin _begin, ItEnd _end, - bool _or_replace) { + IDsPtr _ids = nullptr) { using T = std::remove_cvref_t::value_type>; + constexpr bool has_returning = !std::is_same_v; + + internal::insert::validate_insert_usage<_conflict_policy, has_returning, + _single_row_hint, T, Connection>(); + const auto insert_stmt = - transpilation::to_insert_or_write(_or_replace); + transpilation::to_insert_or_write(_conflict_policy, + has_returning); + + if constexpr (has_returning) { + std::vector> raw_ids; + + return _conn->insert(insert_stmt, _begin, _end, &raw_ids) + .and_then([&](const auto&) { + return internal::insert::assign_returning_ids(_ids, raw_ids); + }) + .transform([&](const auto&) { return _conn; }); + } return _conn->insert(insert_stmt, _begin, _end).transform([&](const auto&) { return _conn; }); } -template - requires is_connection +template + requires is_connection && + std::input_or_output_iterator && + std::sentinel_for Result> insert_impl(const Result>& _res, ItBegin _begin, ItEnd _end, - bool _or_replace) { + IDsPtr _ids = nullptr) { return _res.and_then([&](const auto& _conn) { - return insert_impl(_conn, _begin, _end, _or_replace); + return insert_impl<_conflict_policy, _single_row_hint>(_conn, _begin, _end, + _ids); }); } -template +template auto insert_impl(const auto& _conn, const ContainerType& _data, - bool _or_replace) { + IDsPtr _ids = nullptr) { if constexpr (std::ranges::input_range>) { - return insert_impl(_conn, _data.begin(), _data.end(), _or_replace); + return insert_impl<_conflict_policy, false>(_conn, _data.begin(), + _data.end(), _ids); } else { - return insert_impl(_conn, &_data, &_data + 1, _or_replace); + return insert_impl<_conflict_policy, true>(_conn, &_data, &_data + 1, _ids); } } -template +template auto insert_impl(const auto& _conn, const std::reference_wrapper& _data, - bool _or_replace) { - return insert_impl(_conn, _data.get(), _or_replace); + IDsPtr _ids = nullptr) { + return insert_impl<_conflict_policy>(_conn, _data.get(), _ids); } -template +template struct Insert { + using ValueType = transpilation::value_t; + auto operator()(const auto& _conn) const { - return insert_impl(_conn, data_, or_replace_); + if constexpr (std::is_void_v) { + return insert_impl<_conflict_policy>(_conn, data_); + } else { + return insert_impl<_conflict_policy>(_conn, data_, ids_); + } + } + + template + friend auto operator|(const Insert& _insert, const Modifier& _modifier) { + using ModifierType = std::remove_cvref_t; + + static_assert(internal::insert::is_supported_modifier_v, + "Unsupported insert modifier. Supported modifiers are " + "sqlgen::or_replace, " + "sqlgen::or_ignore and sqlgen::returning(ids).\n" + "Example: insert(rows, or_replace, returning(ids))."); + + if constexpr (internal::insert::is_returning_modifier_v) { + using NewIDsType = + typename internal::insert::modifier_traits_t::ids_type; + + static_assert(std::is_void_v, + "You can only call returning(ids) once on insert(...)."); + static_assert(_conflict_policy != dynamic::Insert::ConflictPolicy::ignore, + "You cannot combine returning(ids) with or_ignore."); + static_assert(internal::has_auto_incr_primary_key_v, + "The table must have an auto-incrementing primary key for " + "returning(ids) to work."); + + return Insert{ + .data_ = _insert.data_, .ids_ = _modifier.ids}; + } else { + constexpr auto next_conflict_policy = + internal::insert::modifier_traits_t::conflict_policy; + + static_assert( + next_conflict_policy != dynamic::Insert::ConflictPolicy::none, + "Unsupported insert modifier."); + static_assert(_conflict_policy == dynamic::Insert::ConflictPolicy::none, + "You can only set one conflict policy on insert(...).\n" + "Use either sqlgen::or_replace or sqlgen::or_ignore."); + static_assert( + !(next_conflict_policy == dynamic::Insert::ConflictPolicy::ignore && + !std::is_void_v), + "You cannot combine returning(ids) with or_ignore."); + + if constexpr (next_conflict_policy == + dynamic::Insert::ConflictPolicy::replace) { + static_assert(internal::has_constraint_v, + "The table must have a primary key or unique column for " + "insert_or_replace(...) to work."); + } + + return Insert{ + .data_ = _insert.data_, .ids_ = _insert.ids_}; + } } ContainerType data_; - bool or_replace_; + [[no_unique_address]] + std::conditional_t, std::monostate, IDsType*> ids_{}; }; -template -Insert insert_impl(const ContainerType& _data, - bool _or_replace) { - return Insert{.data_ = _data, .or_replace_ = _or_replace}; +template + requires(!internal::insert::is_connection_handle_v) +auto insert(const ContainerType& _data, const Modifiers&... _modifiers) { + const auto parsed_modifiers = + internal::insert::parse_modifiers(_modifiers...); + using ParsedModifiers = std::remove_cvref_t; + using ValueType = transpilation::value_t; + using IDsType = typename ParsedModifiers::ids_type; + + constexpr auto conflict_policy = ParsedModifiers::conflict_policy; + + if constexpr (conflict_policy == dynamic::Insert::ConflictPolicy::replace) { + static_assert(internal::has_constraint_v, + "The table must have a primary key or unique column for " + "insert_or_replace(...) to work."); + } + + if constexpr (!std::is_void_v) { + static_assert(internal::has_auto_incr_primary_key_v, + "The table must have an auto-incrementing primary key for " + "returning(ids) to work."); + + return Insert{ + .data_ = _data, .ids_ = parsed_modifiers.ids_ptr}; + } else { + return Insert{.data_ = _data}; + } } -template -auto insert(const Args&... args) { - return insert_impl(args..., false); +template + requires internal::insert::is_connection_handle_v && + (!std::input_or_output_iterator>) +auto insert(const ConnectionHandle& _conn, const ContainerType& _data, + const Modifiers&... _modifiers) { + return insert(_data, _modifiers...)(_conn); } -template -auto insert_or_replace(const auto& _conn, const ContainerType& _data) { +template + requires internal::insert::is_connection_handle_v && + std::input_or_output_iterator && + std::sentinel_for +auto insert(const ConnectionHandle& _conn, ItBegin _begin, ItEnd _end, + const Modifiers&... _modifiers) { + return insert(std::ranges::subrange(_begin, _end), _modifiers...)(_conn); +} + +template + requires internal::insert::is_connection_handle_v +[[deprecated( + "Use `insert(...) | or_replace` instead of `insert_or_replace(...)`.")]] +auto insert_or_replace(const ConnectionHandle& _conn, + const ContainerType& _data) { static_assert( internal::has_constraint_v>, "The table must have a primary key or unique column for " "insert_or_replace(...) to work."); - return insert_impl(_conn, _data, true); + + return insert(_conn, _data, or_replace); } template +[[deprecated( + "Use `insert(...) | or_replace` instead of `insert_or_replace(...)`.")]] auto insert_or_replace(const ContainerType& _data) { static_assert( internal::has_constraint_v>, "The table must have a primary key or unique column for " "insert_or_replace(...) to work."); - return insert_impl(_data, true); + + return insert(_data, or_replace); } -}; // namespace sqlgen +} // namespace sqlgen #endif diff --git a/include/sqlgen/internal/has_auto_incr_primary_key.hpp b/include/sqlgen/internal/has_auto_incr_primary_key.hpp new file mode 100644 index 0000000..3aba40b --- /dev/null +++ b/include/sqlgen/internal/has_auto_incr_primary_key.hpp @@ -0,0 +1,36 @@ +#ifndef SQLGEN_INTERNAL_HAS_AUTO_INCR_PRIMARY_KEY_HPP_ +#define SQLGEN_INTERNAL_HAS_AUTO_INCR_PRIMARY_KEY_HPP_ + +#include +#include + +#include "is_primary_key.hpp" + +namespace sqlgen::internal { + +template +struct is_auto_incr_primary_key : std::false_type {}; + +template +struct is_auto_incr_primary_key< + T, std::void_t::auto_incr)>> + : std::bool_constant> && + std::remove_cvref_t::auto_incr> {}; + +template +struct has_auto_incr_primary_key; + +template +struct has_auto_incr_primary_key> { + constexpr static bool value = + (false || ... || + is_auto_incr_primary_key::value); +}; + +template +constexpr bool has_auto_incr_primary_key_v = + has_auto_incr_primary_key>::value; + +} // namespace sqlgen::internal + +#endif diff --git a/include/sqlgen/mysql/Connection.hpp b/include/sqlgen/mysql/Connection.hpp index 323162e..2f473be 100644 --- a/include/sqlgen/mysql/Connection.hpp +++ b/include/sqlgen/mysql/Connection.hpp @@ -36,6 +36,9 @@ class SQLGEN_API Connection { using StmtPtr = std::shared_ptr; public: + static constexpr bool supports_returning_ids = true; + static constexpr bool supports_multirow_returning_ids = false; + Connection(const Credentials& _credentials); static rfl::Result> make( @@ -51,10 +54,14 @@ class SQLGEN_API Connection { template Result insert(const dynamic::Insert& _stmt, ItBegin _begin, - ItEnd _end) noexcept { + ItEnd _end, + std::vector>* + _returned_ids = nullptr) noexcept { return internal::write_or_insert( - [&](const auto& _data) { return insert_impl(_stmt, _data); }, _begin, - _end); + [&](const auto& _data) { + return insert_impl(_stmt, _data, _returned_ids); + }, + _begin, _end); } template @@ -85,12 +92,13 @@ class SQLGEN_API Connection { /// used by both .insert(...) and .write(...). Result actual_insert( const std::vector>>& _data, - MYSQL_STMT* _stmt) const noexcept; + MYSQL_STMT* _stmt, + std::vector>* _returned_ids) const noexcept; Result insert_impl( const dynamic::Insert& _stmt, - const std::vector>>& - _data) noexcept; + const std::vector>>& _data, + std::vector>* _returned_ids) noexcept; static ConnPtr make_conn(const Credentials& _credentials); diff --git a/include/sqlgen/postgres/Connection.hpp b/include/sqlgen/postgres/Connection.hpp index 45d6d58..caddb9a 100644 --- a/include/sqlgen/postgres/Connection.hpp +++ b/include/sqlgen/postgres/Connection.hpp @@ -3,13 +3,13 @@ #include +#include #include #include #include #include #include #include -#include #include "../Iterator.hpp" #include "../Ref.hpp" @@ -37,21 +37,24 @@ namespace sqlgen::postgres { enum class NotificationWaitResult { - Ready, // Data available (possibly a NOTIFY) - Timeout, // Timeout elapsed - Error // I/O or connection error + Ready, // Data available (possibly a NOTIFY) + Timeout, // Timeout elapsed + Error // I/O or connection error }; struct Notification { - std::string channel; - std::string payload; - int backend_pid; + std::string channel; + std::string payload; + int backend_pid; }; class SQLGEN_API Connection { using Conn = PostgresV2Connection; public: + static constexpr bool supports_returning_ids = true; + static constexpr bool supports_multirow_returning_ids = true; + Connection(const Conn& _conn); Connection(const Credentials& _credentials); @@ -69,10 +72,14 @@ class SQLGEN_API Connection { template Result insert(const dynamic::Insert& _stmt, ItBegin _begin, - ItEnd _end) noexcept { + ItEnd _end, + std::vector>* + _returned_ids = nullptr) noexcept { return internal::write_or_insert( - [&](const auto& _data) { return insert_impl(_stmt, _data); }, _begin, - _end); + [&](const auto& _data) { + return insert_impl(_stmt, _data, _returned_ids); + }, + _begin, _end); } template @@ -103,17 +110,18 @@ class SQLGEN_API Connection { rfl::Result listen(const std::string& channel) noexcept; - rfl::Result unlisten(const std:: string& channel) noexcept; + rfl::Result unlisten(const std::string& channel) noexcept; - rfl::Result notify(const std::string& channel, const std::string& payload = "") noexcept; + rfl::Result notify(const std::string& channel, + const std::string& payload = "") noexcept; bool consume_input() noexcept; private: Result insert_impl( const dynamic::Insert& _stmt, - const std::vector>>& - _data) noexcept; + const std::vector>>& _data, + std::vector>* _returned_ids) noexcept; Result> read_impl( const rfl::Variant& _query); diff --git a/include/sqlgen/sqlite/Connection.hpp b/include/sqlgen/sqlite/Connection.hpp index 670f276..1075226 100644 --- a/include/sqlgen/sqlite/Connection.hpp +++ b/include/sqlgen/sqlite/Connection.hpp @@ -32,6 +32,9 @@ class SQLGEN_API Connection { using StmtPtr = std::shared_ptr; public: + static constexpr bool supports_returning_ids = true; + static constexpr bool supports_multirow_returning_ids = true; + Connection(const std::string& _fname); static rfl::Result> make(const std::string& _fname) noexcept; @@ -46,10 +49,14 @@ class SQLGEN_API Connection { template Result insert(const dynamic::Insert& _stmt, ItBegin _begin, - ItEnd _end) noexcept { + ItEnd _end, + std::vector>* + _returned_ids = nullptr) noexcept { return internal::write_or_insert( - [&](const auto& _data) { return insert_impl(_stmt, _data); }, _begin, - _end); + [&](const auto& _data) { + return insert_impl(_stmt, _data, _returned_ids); + }, + _begin, _end); } template @@ -83,13 +90,14 @@ class SQLGEN_API Connection { /// used by both .insert(...) and .write(...). Result actual_insert( const std::vector>>& _data, - sqlite3_stmt* _stmt) const noexcept; + sqlite3_stmt* _stmt, + std::vector>* _returned_ids) const noexcept; /// Implements the actual insert. Result insert_impl( const dynamic::Insert& _stmt, - const std::vector>>& - _data) noexcept; + const std::vector>>& _data, + std::vector>* _returned_ids) noexcept; /// Generates a prepared statment, usually for inserts. Result prepare_statement(const std::string& _sql) const noexcept; diff --git a/include/sqlgen/transpilation/to_insert_or_write.hpp b/include/sqlgen/transpilation/to_insert_or_write.hpp index 3c59eaf..a9cc6c6 100644 --- a/include/sqlgen/transpilation/to_insert_or_write.hpp +++ b/include/sqlgen/transpilation/to_insert_or_write.hpp @@ -21,16 +21,24 @@ namespace sqlgen::transpilation { template requires std::is_class_v> && std::is_aggregate_v> -InsertOrWrite to_insert_or_write(bool or_replace) { +InsertOrWrite to_insert_or_write( + const dynamic::Insert::ConflictPolicy _conflict_policy, + const bool _returning_auto_incr_ids = false) { using namespace std::ranges::views; using NamedTupleType = sqlgen::internal::remove_auto_incr_primary_t< rfl::named_tuple_t>>; using Fields = typename NamedTupleType::Fields; + using FullNamedTupleType = rfl::named_tuple_t>; + using FullFields = typename FullNamedTupleType::Fields; + const auto columns = make_columns( std::make_integer_sequence>()); + const auto full_columns = make_columns( + std::make_integer_sequence>()); + const auto get_name = [](const auto& _col) { return _col.name; }; auto result = InsertOrWrite{ @@ -52,23 +60,51 @@ InsertOrWrite to_insert_or_write(bool or_replace) { }); }; - result.or_replace = or_replace; + const auto is_auto_incr_primary = [](const auto& _c) { + return _c.type.visit([](const auto& _t) { + return _t.properties.primary && _t.properties.auto_incr; + }); + }; + + result.conflict_policy = _conflict_policy; result.non_primary_keys = sqlgen::internal::collect::vector( columns | filter(is_non_primary) | transform(get_name)); - if (or_replace) { + + if (_conflict_policy == dynamic::Insert::ConflictPolicy::replace) { result.constraints = sqlgen::internal::collect::vector( - columns | filter(is_constraint) | transform(get_name)); + full_columns | filter(is_constraint) | transform(get_name)); + } + + if (_returning_auto_incr_ids) { + result.returning = sqlgen::internal::collect::vector( + full_columns | filter(is_auto_incr_primary) | transform(get_name)); } } return result; } +template + requires std::is_class_v> && + std::is_aggregate_v> +InsertOrWrite to_insert_or_write(const bool _or_replace) { + if constexpr (std::is_same_v) { + return to_insert_or_write( + _or_replace ? dynamic::Insert::ConflictPolicy::replace + : dynamic::Insert::ConflictPolicy::none, + false); + } else { + return to_insert_or_write( + dynamic::Insert::ConflictPolicy::none, false); + } +} + template requires std::is_class_v> && std::is_aggregate_v> InsertOrWrite to_insert_or_write() { - return to_insert_or_write(false); + return to_insert_or_write( + dynamic::Insert::ConflictPolicy::none, false); } } // namespace sqlgen::transpilation diff --git a/include/sqlgen/transpilation/to_sql.hpp b/include/sqlgen/transpilation/to_sql.hpp index e479dc1..d4aa848 100644 --- a/include/sqlgen/transpilation/to_sql.hpp +++ b/include/sqlgen/transpilation/to_sql.hpp @@ -35,9 +35,9 @@ struct ToSQL; template -struct ToSQL< - CreateAs> { +struct ToSQL> { dynamic::Statement operator()(const auto& _create_as) const { using TableTupleType = transpilation::table_tuple_t; @@ -82,19 +82,24 @@ struct ToSQL> { } }; -template -struct ToSQL> { +template +struct ToSQL> { dynamic::Statement operator()(const auto&) const { - return to_insert_or_write(); + constexpr bool has_returning = !std::is_void_v; + return to_insert_or_write(_conflict_policy, + has_returning); } }; template -struct ToSQL> { +struct ToSQL< + Read> { dynamic::Statement operator()(const auto& _read) const { return read_to_select_from, WhereType, OrderByType, - LimitType, OffsetType>(_read.where_, _read.limit_, _read.offset_); + LimitType, OffsetType>( + _read.where_, _read.limit_, _read.offset_); } }; diff --git a/include/sqlgen/transpilation/value_t.hpp b/include/sqlgen/transpilation/value_t.hpp index cf9ba11..92d5c6f 100644 --- a/include/sqlgen/transpilation/value_t.hpp +++ b/include/sqlgen/transpilation/value_t.hpp @@ -21,7 +21,7 @@ struct ValueType { template requires std::ranges::input_range struct ValueType { - using Type = std::remove_cvref_t; + using Type = std::remove_cvref_t>; }; template diff --git a/src/sqlgen/duckdb/to_sql.cpp b/src/sqlgen/duckdb/to_sql.cpp index 034aaed..696b84c 100644 --- a/src/sqlgen/duckdb/to_sql.cpp +++ b/src/sqlgen/duckdb/to_sql.cpp @@ -514,8 +514,10 @@ std::string insert_to_sql(const dynamic::Insert& _stmt) noexcept { stream << "INSERT "; - if (_stmt.or_replace) { + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::replace) { stream << "OR REPLACE "; + } else if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::ignore) { + stream << "OR IGNORE "; } stream << "INTO "; @@ -534,6 +536,13 @@ std::string insert_to_sql(const dynamic::Insert& _stmt) noexcept { }))); stream << " FROM sqlgen_appended_data)"; + if (_stmt.returning.size() != 0) { + stream << " RETURNING "; + stream << internal::strings::join( + ", ", + internal::collect::vector(_stmt.returning | transform(wrap_in_quotes))); + } + stream << ";"; return stream.str(); diff --git a/src/sqlgen/mysql/Connection.cpp b/src/sqlgen/mysql/Connection.cpp index 1eefa8c..fc86ea1 100644 --- a/src/sqlgen/mysql/Connection.cpp +++ b/src/sqlgen/mysql/Connection.cpp @@ -21,7 +21,12 @@ Connection::~Connection() = default; Result Connection::actual_insert( const std::vector>>& _data, - MYSQL_STMT* _stmt) const noexcept { + MYSQL_STMT* _stmt, + std::vector>* _returned_ids) const noexcept { + if (_returned_ids && _data.size() > 1) { + return error("MySQL returning(ids) only supports single-row inserts."); + } + const auto num_params = static_cast(mysql_stmt_param_count(_stmt)); std::vector bind(num_params); @@ -70,6 +75,11 @@ Result Connection::actual_insert( if (err) { return make_error(conn_); } + + if (_returned_ids) { + _returned_ids->emplace_back(std::to_string( + static_cast(mysql_insert_id(conn_.get())))); + } } return Nothing{}; @@ -87,13 +97,14 @@ Result Connection::execute(const std::string& _sql) noexcept { Result Connection::insert_impl( const dynamic::Insert& _stmt, - const std::vector>>& - _data) noexcept { + const std::vector>>& _data, + std::vector>* _returned_ids) noexcept { if (_data.size() == 0) { return Nothing{}; } - return prepare_statement(_stmt).and_then( - [&](auto&& _stmt_ptr) { return actual_insert(_data, _stmt_ptr.get()); }); + return prepare_statement(_stmt).and_then([&](auto&& _stmt_ptr) { + return actual_insert(_data, _stmt_ptr.get(), _returned_ids); + }); } rfl::Result> Connection::make( @@ -186,11 +197,12 @@ Result Connection::write_impl( " You need to call .start_write(...) before you can call " ".write(...)."); } - return actual_insert(_data, stmt_.get()).or_else([&](const auto& _err) { - rollback(); - stmt_ = nullptr; - return error(_err.what()); - }); + return actual_insert(_data, stmt_.get(), nullptr) + .or_else([&](const auto& _err) { + rollback(); + stmt_ = nullptr; + return error(_err.what()); + }); } Result Connection::end_write() { diff --git a/src/sqlgen/mysql/to_sql.cpp b/src/sqlgen/mysql/to_sql.cpp index 4683a78..d7795d8 100644 --- a/src/sqlgen/mysql/to_sql.cpp +++ b/src/sqlgen/mysql/to_sql.cpp @@ -561,7 +561,15 @@ std::string insert_or_write_to_sql(const InsertOrWrite& _stmt) noexcept { std::stringstream stream; - stream << "INSERT INTO "; + stream << "INSERT "; + + if constexpr (std::is_same_v) { + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::ignore) { + stream << "IGNORE "; + } + } + + stream << "INTO "; if (_stmt.table.schema) { stream << wrap_in_quotes(*_stmt.table.schema) << "."; } @@ -579,7 +587,7 @@ std::string insert_or_write_to_sql(const InsertOrWrite& _stmt) noexcept { internal::collect::vector(_stmt.columns | transform(to_questionmark))); stream << ")"; if constexpr (std::is_same_v) { - if (_stmt.or_replace) { + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::replace) { stream << " ON DUPLICATE KEY UPDATE "; stream << internal::strings::join( ", ", diff --git a/src/sqlgen/postgres/Connection.cpp b/src/sqlgen/postgres/Connection.cpp index 1494536..3b0512a 100644 --- a/src/sqlgen/postgres/Connection.cpp +++ b/src/sqlgen/postgres/Connection.cpp @@ -112,8 +112,8 @@ bool Connection::consume_input() noexcept { Result Connection::insert_impl( const dynamic::Insert& _stmt, - const std::vector>>& - _data) noexcept { + const std::vector>>& _data, + std::vector>* _returned_ids) noexcept { if (_data.size() == 0) { return Nothing{}; } @@ -164,7 +164,28 @@ Result Connection::insert_impl( const auto status = PQresultStatus(res.ptr()); - if (status != PGRES_COMMAND_OK) { + if (_returned_ids) { + if (status != PGRES_TUPLES_OK) { + const auto err = error( + std::string("Executing INSERT ... RETURNING failed: ") + + PQresultErrorMessage(res.ptr())); + execute("DEALLOCATE " + name + ";"); + return err; + } + + if (PQnfields(res.ptr()) < 1 || PQntuples(res.ptr()) != 1) { + execute("DEALLOCATE " + name + ";"); + return error( + "INSERT ... RETURNING must return exactly one row " + "and at least one column per input row."); + } + + if (PQgetisnull(res.ptr(), 0, 0) == 1) { + _returned_ids->emplace_back(std::nullopt); + } else { + _returned_ids->emplace_back(PQgetvalue(res.ptr(), 0, 0)); + } + } else if (status != PGRES_COMMAND_OK) { const auto err = error(std::string("Executing INSERT failed: ") + PQresultErrorMessage(res.ptr())); execute("DEALLOCATE " + name + ";"); diff --git a/src/sqlgen/postgres/to_sql.cpp b/src/sqlgen/postgres/to_sql.cpp index 534c650..7455078 100644 --- a/src/sqlgen/postgres/to_sql.cpp +++ b/src/sqlgen/postgres/to_sql.cpp @@ -494,7 +494,7 @@ std::string insert_to_sql(const dynamic::Insert& _stmt) noexcept { transform(to_placeholder))); stream << ")"; - if (_stmt.or_replace) { + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::replace) { stream << " ON CONFLICT ("; stream << internal::strings::join( ", ", internal::collect::vector(_stmt.constraints)); @@ -504,6 +504,15 @@ std::string insert_to_sql(const dynamic::Insert& _stmt) noexcept { stream << internal::strings::join( ", ", internal::collect::vector(_stmt.columns | transform(as_excluded))); + } else if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::ignore) { + stream << " ON CONFLICT DO NOTHING"; + } + + if (_stmt.returning.size() != 0) { + stream << " RETURNING "; + stream << internal::strings::join( + ", ", + internal::collect::vector(_stmt.returning | transform(wrap_in_quotes))); } stream << ";"; diff --git a/src/sqlgen/sqlite/Connection.cpp b/src/sqlgen/sqlite/Connection.cpp index 3d60ccb..9ef0a33 100644 --- a/src/sqlgen/sqlite/Connection.cpp +++ b/src/sqlgen/sqlite/Connection.cpp @@ -18,7 +18,8 @@ Connection::~Connection() = default; Result Connection::actual_insert( const std::vector>>& _data, - sqlite3_stmt* _stmt) const noexcept { + sqlite3_stmt* _stmt, + std::vector>* _returned_ids) const noexcept { for (const auto& row : _data) { const auto num_cols = static_cast(row.size()); @@ -39,7 +40,32 @@ Result Connection::actual_insert( } auto res = sqlite3_step(_stmt); - if (res != SQLITE_OK && res != SQLITE_ROW && res != SQLITE_DONE) { + + if (_returned_ids) { + if (res != SQLITE_ROW) { + return error("INSERT ... RETURNING did not return a row: " + + std::string(sqlite3_errmsg(conn_.get()))); + } + + if (sqlite3_column_count(_stmt) < 1) { + return error("INSERT ... RETURNING did not produce any columns."); + } + + if (sqlite3_column_type(_stmt, 0) == SQLITE_NULL) { + _returned_ids->emplace_back(std::nullopt); + } else { + const auto* value = sqlite3_column_text(_stmt, 0); + _returned_ids->emplace_back( + std::string(reinterpret_cast(value))); + } + + res = sqlite3_step(_stmt); + if (res != SQLITE_DONE) { + return error( + "INSERT ... RETURNING produced more than one row per " + "input row."); + } + } else if (res != SQLITE_OK && res != SQLITE_ROW && res != SQLITE_DONE) { return error(sqlite3_errmsg(conn_.get())); } @@ -85,11 +111,12 @@ Result Connection::execute(const std::string& _sql) noexcept { Result Connection::insert_impl( const dynamic::Insert& _stmt, - const std::vector>>& - _data) noexcept { + const std::vector>>& _data, + std::vector>* _returned_ids) noexcept { const auto sql = to_sql_impl(_stmt); - return prepare_statement(sql).and_then( - [&](auto _p_stmt) { return actual_insert(_data, _p_stmt.get()); }); + return prepare_statement(sql).and_then([&](auto _p_stmt) { + return actual_insert(_data, _p_stmt.get(), _returned_ids); + }); } typename Connection::ConnPtr Connection::make_conn(const std::string& _fname) { @@ -173,7 +200,7 @@ Result Connection::write_impl( ".write(...)."); } - return actual_insert(_data, stmt_.get()) + return actual_insert(_data, stmt_.get(), nullptr) .or_else([&](const auto& err) -> Result { rollback(); return error(err.what()); diff --git a/src/sqlgen/sqlite/to_sql.cpp b/src/sqlgen/sqlite/to_sql.cpp index 1a53328..63e2ccb 100644 --- a/src/sqlgen/sqlite/to_sql.cpp +++ b/src/sqlgen/sqlite/to_sql.cpp @@ -420,7 +420,15 @@ std::string insert_or_write_to_sql(const InsertOrWrite& _stmt) noexcept { }; std::stringstream stream; - stream << "INSERT INTO "; + stream << "INSERT "; + + if constexpr (std::is_same_v) { + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::ignore) { + stream << "OR IGNORE "; + } + } + + stream << "INTO "; if (_stmt.table.schema) { stream << wrap_in_quotes(*_stmt.table.schema) << "."; @@ -439,7 +447,7 @@ std::string insert_or_write_to_sql(const InsertOrWrite& _stmt) noexcept { stream << ")"; if constexpr (std::is_same_v) { - if (_stmt.or_replace) { + if (_stmt.conflict_policy == dynamic::Insert::ConflictPolicy::replace) { stream << " ON CONFLICT ("; stream << internal::strings::join( ", ", internal::collect::vector(_stmt.constraints)); @@ -450,6 +458,13 @@ std::string insert_or_write_to_sql(const InsertOrWrite& _stmt) noexcept { ", ", internal::collect::vector(_stmt.columns | transform(as_excluded))); } + + if (_stmt.returning.size() != 0) { + stream << " RETURNING "; + stream << internal::strings::join( + ", ", + internal::collect::vector(_stmt.returning | transform(in_quotes))); + } } stream << ';'; diff --git a/tests/duckdb/test_insert_or_replace.cpp b/tests/duckdb/test_insert_or_replace.cpp index e2f1c28..c55021a 100644 --- a/tests/duckdb/test_insert_or_replace.cpp +++ b/tests/duckdb/test_insert_or_replace.cpp @@ -56,7 +56,7 @@ TEST(duckdb, test_insert_or_replace) { .and_then(insert(std::ref(people1))) .and_then(commit) .and_then(begin_transaction) - .and_then(insert_or_replace(std::ref(people2))) + .and_then(insert(std::ref(people2)) | or_replace) .and_then(commit) .and_then(sqlgen::read> | order_by("id"_c)) .value(); diff --git a/tests/duckdb/test_insert_returning_ids.cpp b/tests/duckdb/test_insert_returning_ids.cpp new file mode 100644 index 0000000..b6d9fbd --- /dev/null +++ b/tests/duckdb/test_insert_returning_ids.cpp @@ -0,0 +1,46 @@ +#include + +#include +#include +#include + +namespace test_insert_returning_ids { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + int age; +}; + +TEST(duckdb, test_insert_returning_ids) { + const auto people = + std::vector({Person{.first_name = "Homer", .age = 45}, + Person{.first_name = "Bart", .age = 10}, + Person{.first_name = "Lisa", .age = 8}}); + + auto ids = std::vector{}; + + using namespace sqlgen; + using namespace sqlgen::literals; + + const auto people_from_db = + duckdb::connect() + .and_then(drop | if_exists) + .and_then(create_table | if_not_exists) + .and_then(insert(std::ref(people), returning(ids))) + .and_then(sqlgen::read> | order_by("id"_c)) + .value(); + + ASSERT_EQ(ids.size(), people.size()); + EXPECT_EQ(ids, (std::vector{1, 2, 3})); + + auto ids_from_db = std::vector{}; + + for (const auto& person : people_from_db) { + ids_from_db.push_back(person.id()); + } + + EXPECT_EQ(ids, ids_from_db); +} + +} // namespace test_insert_returning_ids diff --git a/tests/duckdb/test_to_insert_or_ignore.cpp b/tests/duckdb/test_to_insert_or_ignore.cpp new file mode 100644 index 0000000..87ac539 --- /dev/null +++ b/tests/duckdb/test_to_insert_or_ignore.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include + +namespace test_to_insert_or_ignore { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(duckdb, test_to_insert_or_ignore) { + const auto query = + sqlgen::Insert{}; + + const auto expected = + R"(INSERT OR IGNORE INTO "TestTable" BY NAME ( SELECT "field1" AS "field1", "id" AS "id" FROM sqlgen_appended_data);)"; + + EXPECT_EQ(sqlgen::duckdb::to_sql(query), expected); +} + +} // namespace test_to_insert_or_ignore diff --git a/tests/duckdb/test_to_insert_returning.cpp b/tests/duckdb/test_to_insert_returning.cpp new file mode 100644 index 0000000..8e48b77 --- /dev/null +++ b/tests/duckdb/test_to_insert_returning.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include + +namespace test_to_insert_returning { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(duckdb, test_to_insert_returning) { + const auto query = + sqlgen::Insert>{}; + + const auto expected = + R"(INSERT INTO "TestTable" BY NAME ( SELECT "field1" AS "field1" FROM sqlgen_appended_data) RETURNING "id";)"; + + EXPECT_EQ(sqlgen::duckdb::to_sql(query), expected); +} + +} // namespace test_to_insert_returning diff --git a/tests/mysql/test_insert_or_ignore_dry.cpp b/tests/mysql/test_insert_or_ignore_dry.cpp new file mode 100644 index 0000000..890f0f3 --- /dev/null +++ b/tests/mysql/test_insert_or_ignore_dry.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include + +namespace test_insert_or_ignore_dry { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(mysql, test_insert_or_ignore_dry) { + const auto query = + sqlgen::Insert{}; + + const auto expected = + R"(INSERT IGNORE INTO `TestTable` (`field1`, `id`) VALUES (?, ?);)"; + + EXPECT_EQ(sqlgen::mysql::to_sql(query), expected); +} + +} // namespace test_insert_or_ignore_dry diff --git a/tests/mysql/test_insert_or_replace.cpp b/tests/mysql/test_insert_or_replace.cpp index 079c3a3..5aee5bd 100644 --- a/tests/mysql/test_insert_or_replace.cpp +++ b/tests/mysql/test_insert_or_replace.cpp @@ -62,7 +62,7 @@ TEST(mysql, test_insert_or_replace) { .and_then(insert(people1)) .and_then(commit) .and_then(begin_transaction) - .and_then(insert_or_replace(people2)) + .and_then(insert(std::ref(people2)) | or_replace) .and_then(commit) .and_then(sqlgen::read> | order_by("id"_c)) .value(); diff --git a/tests/mysql/test_insert_returning_dry.cpp b/tests/mysql/test_insert_returning_dry.cpp new file mode 100644 index 0000000..ccccdf6 --- /dev/null +++ b/tests/mysql/test_insert_returning_dry.cpp @@ -0,0 +1,24 @@ +#include + +#include +#include +#include + +namespace test_insert_returning_dry { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(mysql, test_insert_returning_dry) { + const auto query = + sqlgen::Insert>{}; + + const auto expected = R"(INSERT INTO `TestTable` (`field1`) VALUES (?);)"; + + EXPECT_EQ(sqlgen::mysql::to_sql(query), expected); +} + +} // namespace test_insert_returning_dry diff --git a/tests/mysql/test_insert_returning_ids.cpp b/tests/mysql/test_insert_returning_ids.cpp new file mode 100644 index 0000000..be25518 --- /dev/null +++ b/tests/mysql/test_insert_returning_ids.cpp @@ -0,0 +1,44 @@ +#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY + +#include + +#include +#include +#include + +namespace test_insert_returning_ids { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + int age; +}; + +TEST(mysql, test_insert_returning_ids) { + const auto person = Person{.first_name = "Homer", .age = 45}; + + auto ids = std::vector{}; + + const auto credentials = sqlgen::mysql::Credentials{.host = "localhost", + .user = "sqlgen", + .password = "password", + .dbname = "mysql"}; + + using namespace sqlgen; + + const auto people_from_db = + mysql::connect(credentials) + .and_then(drop | if_exists) + .and_then(create_table | if_not_exists) + .and_then(insert(std::ref(person), returning(ids))) + .and_then(sqlgen::read>) + .value(); + + ASSERT_EQ(ids.size(), 1u); + ASSERT_EQ(people_from_db.size(), 1u); + EXPECT_EQ(ids.at(0), people_from_db.at(0).id()); +} + +} // namespace test_insert_returning_ids + +#endif diff --git a/tests/postgres/test_insert_or_ignore_dry.cpp b/tests/postgres/test_insert_or_ignore_dry.cpp new file mode 100644 index 0000000..e1b6f5d --- /dev/null +++ b/tests/postgres/test_insert_or_ignore_dry.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include + +namespace test_insert_or_ignore_dry { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(postgres, test_insert_or_ignore_dry) { + const auto query = + sqlgen::Insert{}; + + const auto expected = + R"(INSERT INTO "TestTable" ("field1", "id") VALUES ($1, $2) ON CONFLICT DO NOTHING;)"; + + EXPECT_EQ(sqlgen::postgres::to_sql(query), expected); +} + +} // namespace test_insert_or_ignore_dry diff --git a/tests/postgres/test_insert_or_replace.cpp b/tests/postgres/test_insert_or_replace.cpp index 4b76d3c..c170a67 100644 --- a/tests/postgres/test_insert_or_replace.cpp +++ b/tests/postgres/test_insert_or_replace.cpp @@ -62,7 +62,7 @@ TEST(postgres, test_insert_or_replace) { .and_then(insert(people1)) .and_then(commit) .and_then(begin_transaction) - .and_then(insert_or_replace(people2)) + .and_then(insert(people2, or_replace)) .and_then(commit) .and_then(sqlgen::read> | order_by("id"_c)) .value(); diff --git a/tests/postgres/test_insert_returning_dry.cpp b/tests/postgres/test_insert_returning_dry.cpp new file mode 100644 index 0000000..fbbc101 --- /dev/null +++ b/tests/postgres/test_insert_returning_dry.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include + +namespace test_insert_returning_dry { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(postgres, test_insert_returning_dry) { + const auto query = + sqlgen::Insert>{}; + + const auto expected = + R"(INSERT INTO "TestTable" ("field1") VALUES ($1) RETURNING "id";)"; + + EXPECT_EQ(sqlgen::postgres::to_sql(query), expected); +} + +} // namespace test_insert_returning_dry diff --git a/tests/postgres/test_insert_returning_ids.cpp b/tests/postgres/test_insert_returning_ids.cpp new file mode 100644 index 0000000..6943841 --- /dev/null +++ b/tests/postgres/test_insert_returning_ids.cpp @@ -0,0 +1,55 @@ +#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY + +#include + +#include +#include +#include + +namespace test_insert_returning_ids { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + int age; +}; + +TEST(postgres, test_insert_returning_ids) { + const auto people = + std::vector({Person{.first_name = "Homer", .age = 45}, + Person{.first_name = "Bart", .age = 10}, + Person{.first_name = "Lisa", .age = 8}}); + + auto ids = std::vector{}; + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + using namespace sqlgen; + using namespace sqlgen::literals; + + const auto people_from_db = + postgres::connect(credentials) + .and_then(drop | if_exists) + .and_then(create_table | if_not_exists) + .and_then(insert(std::ref(people), returning(ids))) + .and_then(sqlgen::read> | order_by("id"_c)) + .value(); + + ASSERT_EQ(ids.size(), people.size()); + EXPECT_EQ(ids, (std::vector{1, 2, 3})); + + auto ids_from_db = std::vector{}; + + for (const auto& person : people_from_db) { + ids_from_db.push_back(person.id()); + } + + EXPECT_EQ(ids, ids_from_db); +} + +} // namespace test_insert_returning_ids + +#endif diff --git a/tests/sqlite/test_insert_or_replace.cpp b/tests/sqlite/test_insert_or_replace.cpp index eddce8a..8de9e30 100644 --- a/tests/sqlite/test_insert_or_replace.cpp +++ b/tests/sqlite/test_insert_or_replace.cpp @@ -56,7 +56,7 @@ TEST(sqlite, test_insert_or_replace) { .and_then(insert(people1)) .and_then(commit) .and_then(begin_transaction) - .and_then(insert_or_replace(std::ref(people2))) + .and_then(insert(std::ref(people2)) | or_replace) .and_then(commit) .and_then(sqlgen::read> | order_by("id"_c)) .value(); diff --git a/tests/sqlite/test_insert_returning_ids.cpp b/tests/sqlite/test_insert_returning_ids.cpp new file mode 100644 index 0000000..052be09 --- /dev/null +++ b/tests/sqlite/test_insert_returning_ids.cpp @@ -0,0 +1,46 @@ +#include + +#include +#include +#include + +namespace test_insert_returning_ids { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + int age; +}; + +TEST(sqlite, test_insert_returning_ids) { + const auto people = + std::vector({Person{.first_name = "Homer", .age = 45}, + Person{.first_name = "Bart", .age = 10}, + Person{.first_name = "Lisa", .age = 8}}); + + auto ids = std::vector{}; + + using namespace sqlgen; + using namespace sqlgen::literals; + + const auto people_from_db = + sqlite::connect() + .and_then(drop | if_exists) + .and_then(create_table | if_not_exists) + .and_then(insert(std::ref(people), returning(ids))) + .and_then(sqlgen::read> | order_by("id"_c)) + .value(); + + ASSERT_EQ(ids.size(), people.size()); + EXPECT_EQ(ids, (std::vector{1, 2, 3})); + + auto ids_from_db = std::vector{}; + + for (const auto& person : people_from_db) { + ids_from_db.push_back(person.id()); + } + + EXPECT_EQ(ids, ids_from_db); +} + +} // namespace test_insert_returning_ids diff --git a/tests/sqlite/test_to_insert_or_ignore.cpp b/tests/sqlite/test_to_insert_or_ignore.cpp new file mode 100644 index 0000000..44333e1 --- /dev/null +++ b/tests/sqlite/test_to_insert_or_ignore.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include + +namespace test_to_insert_or_ignore { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(sqlite, test_to_insert_or_ignore) { + const auto query = + sqlgen::Insert{}; + + const auto expected = + R"(INSERT OR IGNORE INTO "TestTable" ("field1", "id") VALUES (?, ?);)"; + + EXPECT_EQ(sqlgen::sqlite::to_sql(query), expected); +} + +} // namespace test_to_insert_or_ignore diff --git a/tests/sqlite/test_to_insert_or_replace_tag.cpp b/tests/sqlite/test_to_insert_or_replace_tag.cpp new file mode 100644 index 0000000..d490830 --- /dev/null +++ b/tests/sqlite/test_to_insert_or_replace_tag.cpp @@ -0,0 +1,28 @@ +#include + +#include +#include +#include + +namespace test_to_insert_or_replace_tag { + +struct TestTable { + std::string field1; + int32_t field2; + sqlgen::Unique field3; + sqlgen::PrimaryKey id; + std::optional nullable; +}; + +TEST(sqlite, test_to_insert_or_replace_tag) { + const auto query = + sqlgen::Insert{}; + + const auto expected = + R"(INSERT INTO "TestTable" ("field1", "field2", "field3", "id", "nullable") VALUES (?, ?, ?, ?, ?) ON CONFLICT (field3, id) DO UPDATE SET field1=excluded.field1, field2=excluded.field2, field3=excluded.field3, id=excluded.id, nullable=excluded.nullable;)"; + + EXPECT_EQ(sqlgen::sqlite::to_sql(query), expected); +} + +} // namespace test_to_insert_or_replace_tag diff --git a/tests/sqlite/test_to_insert_returning.cpp b/tests/sqlite/test_to_insert_returning.cpp new file mode 100644 index 0000000..eecb0d4 --- /dev/null +++ b/tests/sqlite/test_to_insert_returning.cpp @@ -0,0 +1,26 @@ +#include + +#include +#include +#include + +namespace test_to_insert_returning { + +struct TestTable { + std::string field1; + sqlgen::PrimaryKey id; +}; + +TEST(sqlite, test_to_insert_returning) { + // Use a concrete IDs type (e.g., std::vector) instead of bool + const auto query = + sqlgen::Insert>{}; + + const auto expected = + R"(INSERT INTO "TestTable" ("field1") VALUES (?) RETURNING "id";)"; + + EXPECT_EQ(sqlgen::sqlite::to_sql(query), expected); +} + +} // namespace test_to_insert_returning