Skip to content

Commit

Permalink
Generated Keys (#6)
Browse files Browse the repository at this point in the history
- Support for generated keys
- Fix broken string interpolation
  • Loading branch information
nob13 authored Sep 23, 2024
1 parent 1f6216d commit 9956e84
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 13 deletions.
6 changes: 4 additions & 2 deletions src/main/scala/usql/RawSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import scala.util.Using

/** Raw SQL Query string. */
case class RawSql(sql: String) extends SqlBase {
override def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T = {
override def withPreparedStatement[T](
f: PreparedStatement => T
)(using cp: ConnectionProvider, sp: StatementPreparator): T = {
cp.withConnection {
val c = summon[Connection]
Using.resource(c.prepareStatement(sql)) { statement =>
Using.resource(sp.prepare(c, sql)) { statement =>
f(statement)
}
}
Expand Down
20 changes: 12 additions & 8 deletions src/main/scala/usql/Sql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ extension (sc: StringContext) {
val replacedPart = part.stripSuffix("#") + param.dataType.serialize(param.value)
fix(restParts, restParams, (replacedPart, SqlInterpolationParameter.Empty) :: builder)
case (part :: restParts, (param: InnerSql) :: restParams) =>
// Innre Sql
val combined = param.sql.parts.toList.reverse ++ builder
val newParts = if (part.isEmpty) {
combined
// Inner Sql

val inner = if (part.isEmpty) {
Nil
} else {
(part, SqlInterpolationParameter.Empty) :: combined
List(part -> SqlInterpolationParameter.Empty)
}
fix(restParts, restParams, newParts)

val combined = param.sql.parts.toList.reverse ++ inner ++ builder
fix(restParts, restParams, combined)
case (part :: restParts, param :: restParams) =>
// Regular Case
fix(restParts, restParams, (part, param) :: builder)
Expand All @@ -62,10 +64,12 @@ case class Sql(parts: Seq[(String, SqlInterpolationParameter)]) extends SqlBase
p
}

override def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T = {
override def withPreparedStatement[T](
f: PreparedStatement => T
)(using cp: ConnectionProvider, sp: StatementPreparator): T = {
cp.withConnection {
val c = summon[Connection]
Using.resource(c.prepareStatement(sql)) { statement =>
Using.resource(sp.prepare(c, sql)) { statement =>
sqlParameters.zipWithIndex.foreach { case (param, idx) =>
param.dataType.fillByZeroBasedIdx(idx, statement, param.value)
}
Expand Down
32 changes: 29 additions & 3 deletions src/main/scala/usql/SqlBase.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package usql

import java.sql.PreparedStatement
import java.sql.{Connection, PreparedStatement, Statement}

/** Something which can create prepared statements. */
trait SqlBase {

/** Prepares a statement which can then be further filled or executed. */
def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T
def withPreparedStatement[T](
f: PreparedStatement => T
)(using cp: ConnectionProvider, prep: StatementPreparator = StatementPreparator.default): T

/** Turns into a query */
def query: Query = Query(this)
Expand All @@ -32,9 +34,33 @@ trait SqlBase {
}
}

/** Hook for changing the preparation of SQL. */
trait StatementPreparator {
def prepare(connection: Connection, sql: String): PreparedStatement
}

object StatementPreparator {

/** Default Implementation */
object default extends StatementPreparator {
override def prepare(connection: Connection, sql: String): PreparedStatement = {
connection.prepareStatement(sql)
}
}

/** Statement should return generated keys */
object withGeneratedKeys extends StatementPreparator {
override def prepare(connection: Connection, sql: String): PreparedStatement = {
connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
}
}
}

/** With supplied arguments */
case class AppliedSql[T](base: SqlBase, parameter: T, parameterFiller: ParameterFiller[T]) extends SqlBase {
override def withPreparedStatement[T](f: PreparedStatement => T)(using cp: ConnectionProvider): T = {
override def withPreparedStatement[T](
f: PreparedStatement => T
)(using cp: ConnectionProvider, sp: StatementPreparator): T = {
base.withPreparedStatement { ps =>
parameterFiller.fill(ps, parameter)
f(ps)
Expand Down
28 changes: 28 additions & 0 deletions src/main/scala/usql/Update.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,38 @@
package usql

import usql.Update.SqlResultMissingGenerated

import java.sql.SQLException
import scala.util.Using

/** Encapsulates an update statement */
case class Update(sql: SqlBase) {

/** Run the update statement */
def run()(using c: ConnectionProvider): Int = {
sql.withPreparedStatement(_.executeUpdate())
}

/**
* Run the update statement and get generated values. See [[java.sql.PreparedStatement.getGeneratedKeys()]]
*/
def runAndGetGenerated[T]()(using d: ResultRowDecoder[T], c: ConnectionProvider): T = {
given sp: StatementPreparator = StatementPreparator.withGeneratedKeys
sql.withPreparedStatement { statement =>
statement.executeUpdate()
Using.resource(statement.getGeneratedKeys) { resultSet =>
if (resultSet.next()) {
d.parseRow(resultSet)
} else {
throw new SqlResultMissingGenerated("Missing row for getGeneratedKeys")
}
}
}
}
}

object Update {

/** Exception thrown if the result set has no generated data. */
class SqlResultMissingGenerated(msg: String) extends SQLException(msg)
}
4 changes: 4 additions & 0 deletions src/main/scala/usql/dao/Crd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ abstract class CrdBase[T] extends Crd[T] {
private lazy val insertStatement =
sql"INSERT INTO ${tabular.tableName} (${tabular.columns}) VALUES (${tabular.columns.placeholders})"

override def insert(value: T)(using ConnectionProvider): Int = {
insertStatement.one(value).update.run()
}

override def insert(values: Seq[T])(using ConnectionProvider): Int = {
insertStatement.batch(values).run().sum
}
Expand Down
53 changes: 53 additions & 0 deletions src/test/scala/usql/AutoGeneratedUpdateTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package usql

import usql.dao.{KeyedCrudBase, SqlColumnar, SqlTabular}
import usql.util.TestBaseWithH2

class AutoGeneratedUpdateTest extends TestBaseWithH2 {

override protected def baseSql: String =
"""
|CREATE TABLE tenant (
| id SERIAL NOT NULL PRIMARY KEY,
| name TEXT
|);
|""".stripMargin

case class Tenant(
id: Int,
name: Option[String]
) derives SqlTabular

object Tenant extends KeyedCrudBase[Int, Tenant] {
override val keyColumn: SqlIdentifier = "id"

override def keyOf(value: Tenant): Int = value.id

override lazy val tabular: SqlTabular[Tenant] = summon
}

it should "be possible to insert values" in {
Tenant.findAll() shouldBe empty
val sample = Tenant(1, Some("Hello World"))
Tenant.insert(sample)
Tenant.findAll() shouldBe Seq(sample)
}

it should "be possible to use auto generated keys" in {
sql"INSERT INTO tenant (name) VALUES (${"Alice"})".update.run()
sql"INSERT INTO tenant (name) VALUES (${"Bob"})".update.run()
Tenant.findAll() should contain theSameElementsAs Seq(
Tenant(1, Some("Alice")),
Tenant(2, Some("Bob"))
)
}

it should "be possible to return auto generated keys" in {
val id1 = sql"INSERT INTO tenant (name) VALUES (${"Alice"})".update.runAndGetGenerated[Int]()
val id2 = sql"INSERT INTO tenant (name) VALUES (${"Bob"})".update.runAndGetGenerated[Int]()
Tenant.findAll() should contain theSameElementsAs Seq(
Tenant(id1, Some("Alice")),
Tenant(id2, Some("Bob"))
)
}
}
14 changes: 14 additions & 0 deletions src/test/scala/usql/SqlInterpolationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,18 @@ class SqlInterpolationTest extends TestBase {
)
)
}

it should "also work in another case" in {
val inner = sql"C = ${2}"
val foo = sql"HELLO a = ${1} AND"
val combined = (sql"HELLO a = ${1} AND ${inner}")
combined shouldBe Sql(
Seq(
"HELLO a = " -> SqlParameter(1),
" AND " -> Empty,
"C = " -> SqlParameter(2)
)
)
combined.sql shouldBe "HELLO a = ? AND C = ?"
}
}

0 comments on commit 9956e84

Please sign in to comment.