diff --git a/src/main/scala/usql/RawSql.scala b/src/main/scala/usql/RawSql.scala index bf3ad6e..b44f676 100644 --- a/src/main/scala/usql/RawSql.scala +++ b/src/main/scala/usql/RawSql.scala @@ -5,10 +5,12 @@ import scala.util.Using /** Raw SQL Query string. */ case class RawSql(sql: String) extends SqlBase { - override def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T = { + override def withPreparedStatement[T]( + f: PreparedStatement => T + )(using cp: ConnectionProvider, sp: StatementPreparator): T = { cp.withConnection { val c = summon[Connection] - Using.resource(c.prepareStatement(sql)) { statement => + Using.resource(sp.prepare(c, sql)) { statement => f(statement) } } diff --git a/src/main/scala/usql/Sql.scala b/src/main/scala/usql/Sql.scala index 8762c1e..8ce1181 100644 --- a/src/main/scala/usql/Sql.scala +++ b/src/main/scala/usql/Sql.scala @@ -35,14 +35,16 @@ extension (sc: StringContext) { val replacedPart = part.stripSuffix("#") + param.dataType.serialize(param.value) fix(restParts, restParams, (replacedPart, SqlInterpolationParameter.Empty) :: builder) case (part :: restParts, (param: InnerSql) :: restParams) => - // Innre Sql - val combined = param.sql.parts.toList.reverse ++ builder - val newParts = if (part.isEmpty) { - combined + // Inner Sql + + val inner = if (part.isEmpty) { + Nil } else { - (part, SqlInterpolationParameter.Empty) :: combined + List(part -> SqlInterpolationParameter.Empty) } - fix(restParts, restParams, newParts) + + val combined = param.sql.parts.toList.reverse ++ inner ++ builder + fix(restParts, restParams, combined) case (part :: restParts, param :: restParams) => // Regular Case fix(restParts, restParams, (part, param) :: builder) @@ -62,10 +64,12 @@ case class Sql(parts: Seq[(String, SqlInterpolationParameter)]) extends SqlBase p } - override def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T = { + override def withPreparedStatement[T]( + f: PreparedStatement => T + )(using cp: ConnectionProvider, sp: StatementPreparator): T = { cp.withConnection { val c = summon[Connection] - Using.resource(c.prepareStatement(sql)) { statement => + Using.resource(sp.prepare(c, sql)) { statement => sqlParameters.zipWithIndex.foreach { case (param, idx) => param.dataType.fillByZeroBasedIdx(idx, statement, param.value) } diff --git a/src/main/scala/usql/SqlBase.scala b/src/main/scala/usql/SqlBase.scala index aead7ce..7e444a7 100644 --- a/src/main/scala/usql/SqlBase.scala +++ b/src/main/scala/usql/SqlBase.scala @@ -1,12 +1,14 @@ package usql -import java.sql.PreparedStatement +import java.sql.{Connection, PreparedStatement, Statement} /** Something which can create prepared statements. */ trait SqlBase { /** Prepares a statement which can then be further filled or executed. */ - def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T + def withPreparedStatement[T]( + f: PreparedStatement => T + )(using cp: ConnectionProvider, prep: StatementPreparator = StatementPreparator.default): T /** Turns into a query */ def query: Query = Query(this) @@ -32,9 +34,33 @@ trait SqlBase { } } +/** Hook for changing the preparation of SQL. */ +trait StatementPreparator { + def prepare(connection: Connection, sql: String): PreparedStatement +} + +object StatementPreparator { + + /** Default Implementation */ + object default extends StatementPreparator { + override def prepare(connection: Connection, sql: String): PreparedStatement = { + connection.prepareStatement(sql) + } + } + + /** Statement should return generated keys */ + object withGeneratedKeys extends StatementPreparator { + override def prepare(connection: Connection, sql: String): PreparedStatement = { + connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS) + } + } +} + /** With supplied arguments */ case class AppliedSql[T](base: SqlBase, parameter: T, parameterFiller: ParameterFiller[T]) extends SqlBase { - override def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T = { + override def withPreparedStatement[T]( + f: PreparedStatement => T + )(using cp: ConnectionProvider, sp: StatementPreparator): T = { base.withPreparedStatement { ps => parameterFiller.fill(ps, parameter) f(ps) diff --git a/src/main/scala/usql/Update.scala b/src/main/scala/usql/Update.scala index a37728a..62f6e14 100644 --- a/src/main/scala/usql/Update.scala +++ b/src/main/scala/usql/Update.scala @@ -1,5 +1,10 @@ package usql +import usql.Update.SqlResultMissingGenerated + +import java.sql.SQLException +import scala.util.Using + /** Encapsulates an update statement */ case class Update(sql: SqlBase) { @@ -7,4 +12,27 @@ case class Update(sql: SqlBase) { def run()(using c: ConnectionProvider): Int = { sql.withPreparedStatement(_.executeUpdate()) } + + /** + * Run the update statement and get generated values. See [[java.sql.PreparedStatement.getGeneratedKeys()]] + */ + def runAndGetGenerated[T]()(using d: ResultRowDecoder[T], c: ConnectionProvider): T = { + given sp: StatementPreparator = StatementPreparator.withGeneratedKeys + sql.withPreparedStatement { statement => + statement.executeUpdate() + Using.resource(statement.getGeneratedKeys) { resultSet => + if (resultSet.next()) { + d.parseRow(resultSet) + } else { + throw new SqlResultMissingGenerated("Missing row for getGeneratedKeys") + } + } + } + } +} + +object Update { + + /** Exception thrown if the result set has no generated data. */ + class SqlResultMissingGenerated(msg: String) extends SQLException(msg) } diff --git a/src/main/scala/usql/dao/Crd.scala b/src/main/scala/usql/dao/Crd.scala index 4325105..5348c9b 100644 --- a/src/main/scala/usql/dao/Crd.scala +++ b/src/main/scala/usql/dao/Crd.scala @@ -59,6 +59,10 @@ abstract class CrdBase[T] extends Crd[T] { private lazy val insertStatement = sql"INSERT INTO ${tabular.tableName} (${tabular.columns}) VALUES (${tabular.columns.placeholders})" + override def insert(value: T)(using ConnectionProvider): Int = { + insertStatement.one(value).update.run() + } + override def insert(values: Seq[T])(using ConnectionProvider): Int = { insertStatement.batch(values).run().sum } diff --git a/src/test/scala/usql/AutoGeneratedUpdateTest.scala b/src/test/scala/usql/AutoGeneratedUpdateTest.scala new file mode 100644 index 0000000..d9814f1 --- /dev/null +++ b/src/test/scala/usql/AutoGeneratedUpdateTest.scala @@ -0,0 +1,53 @@ +package usql + +import usql.dao.{KeyedCrudBase, SqlColumnar, SqlTabular} +import usql.util.TestBaseWithH2 + +class AutoGeneratedUpdateTest extends TestBaseWithH2 { + + override protected def baseSql: String = + """ + |CREATE TABLE tenant ( + | id SERIAL NOT NULL PRIMARY KEY, + | name TEXT + |); + |""".stripMargin + + case class Tenant( + id: Int, + name: Option[String] + ) derives SqlTabular + + object Tenant extends KeyedCrudBase[Int, Tenant] { + override val keyColumn: SqlIdentifier = "id" + + override def keyOf(value: Tenant): Int = value.id + + override lazy val tabular: SqlTabular[Tenant] = summon + } + + it should "be possible to insert values" in { + Tenant.findAll() shouldBe empty + val sample = Tenant(1, Some("Hello World")) + Tenant.insert(sample) + Tenant.findAll() shouldBe Seq(sample) + } + + it should "be possible to use auto generated keys" in { + sql"INSERT INTO tenant (name) VALUES (${"Alice"})".update.run() + sql"INSERT INTO tenant (name) VALUES (${"Bob"})".update.run() + Tenant.findAll() should contain theSameElementsAs Seq( + Tenant(1, Some("Alice")), + Tenant(2, Some("Bob")) + ) + } + + it should "be possible to return auto generated keys" in { + val id1 = sql"INSERT INTO tenant (name) VALUES (${"Alice"})".update.runAndGetGenerated[Int]() + val id2 = sql"INSERT INTO tenant (name) VALUES (${"Bob"})".update.runAndGetGenerated[Int]() + Tenant.findAll() should contain theSameElementsAs Seq( + Tenant(id1, Some("Alice")), + Tenant(id2, Some("Bob")) + ) + } +} diff --git a/src/test/scala/usql/SqlInterpolationTest.scala b/src/test/scala/usql/SqlInterpolationTest.scala index 423dd1f..edb06f3 100644 --- a/src/test/scala/usql/SqlInterpolationTest.scala +++ b/src/test/scala/usql/SqlInterpolationTest.scala @@ -80,4 +80,18 @@ class SqlInterpolationTest extends TestBase { ) ) } + + it should "also work in another case" in { + val inner = sql"C = ${2}" + val foo = sql"HELLO a = ${1} AND" + val combined = (sql"HELLO a = ${1} AND ${inner}") + combined shouldBe Sql( + Seq( + "HELLO a = " -> SqlParameter(1), + " AND " -> Empty, + "C = " -> SqlParameter(2) + ) + ) + combined.sql shouldBe "HELLO a = ? AND C = ?" + } }