diff --git a/.buildkite/engineer b/.buildkite/engineer index 5de99cea5390..838758ab565e 100755 --- a/.buildkite/engineer +++ b/.buildkite/engineer @@ -54,7 +54,7 @@ fi # Check if the system has engineer installed, if not, use a local copy. if ! type "engineer" &> /dev/null; then # Setup Prisma engine build & test tool (engineer). - curl --fail -sSL "https://prisma-engineer.s3-eu-west-1.amazonaws.com/1.60/latest/$OS/engineer.gz" --output engineer.gz + curl --fail -sSL "https://prisma-engineer.s3-eu-west-1.amazonaws.com/1.65/latest/$OS/engineer.gz" --output engineer.gz gzip -d engineer.gz chmod +x engineer diff --git a/.envrc b/.envrc index 48b1254c1700..5488da9e10e7 100644 --- a/.envrc +++ b/.envrc @@ -23,7 +23,7 @@ export QE_LOG_LEVEL=debug # Set it to "trace" to enable query-graph debugging lo # export FMT_SQL=1 # Uncomment it to enable logging formatted SQL queries ### Uncomment to run driver adapters tests. See query-engine-driver-adapters.yml workflow for how tests run in CI. -# export EXTERNAL_TEST_EXECUTOR="$(pwd)/query-engine/driver-adapters/js/connector-test-kit-executor/script/start_node.sh" +# export EXTERNAL_TEST_EXECUTOR="napi" # export DRIVER_ADAPTER=pg # Set to pg, neon or planetscale # export PRISMA_DISABLE_QUAINT_EXECUTORS=1 # Disable quaint executors for driver adapters # export DRIVER_ADAPTER_URL_OVERRIDE ="postgres://USER:PASSWORD@DATABASExxxx" # Override the database url for the driver adapter tests diff --git a/.github/workflows/build-apple-intel.yml b/.github/workflows/build-engines-apple-intel.yml similarity index 97% rename from .github/workflows/build-apple-intel.yml rename to .github/workflows/build-engines-apple-intel.yml index 994cbfbb0ad0..9d4e66e1b2fa 100644 --- a/.github/workflows/build-apple-intel.yml +++ b/.github/workflows/build-engines-apple-intel.yml @@ -1,3 +1,4 @@ +name: Build Engines for Apple Intel on: workflow_dispatch: inputs: diff --git a/.github/workflows/build-apple-silicon.yml b/.github/workflows/build-engines-apple-silicon.yml similarity index 97% rename from .github/workflows/build-apple-silicon.yml rename to .github/workflows/build-engines-apple-silicon.yml index 74c49c5154fa..2ba7cb341cc9 100644 --- a/.github/workflows/build-apple-silicon.yml +++ b/.github/workflows/build-engines-apple-silicon.yml @@ -1,3 +1,4 @@ +name: Build Engines for Apple Silicon on: workflow_dispatch: inputs: diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-engines-windows.yml similarity index 97% rename from .github/workflows/build-windows.yml rename to .github/workflows/build-engines-windows.yml index 1dcd11f092ee..ca826698d7e8 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-engines-windows.yml @@ -1,3 +1,4 @@ +name: Build Engines for Windows on: workflow_dispatch: inputs: diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-prisma-schema-wasm.yml similarity index 94% rename from .github/workflows/build-wasm.yml rename to .github/workflows/build-prisma-schema-wasm.yml index 7969cd2dd462..20906326401b 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-prisma-schema-wasm.yml @@ -1,4 +1,4 @@ -name: WASM build +name: Build prisma-schema-wasm on: push: branches: diff --git a/.github/workflows/benchmark.yml b/.github/workflows/codspeed.yml similarity index 97% rename from .github/workflows/benchmark.yml rename to .github/workflows/codspeed.yml index 4dbfa4855fc9..62131fe3b572 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/codspeed.yml @@ -1,4 +1,4 @@ 
-name: Benchmark +name: Codspeed Benchmark on: push: branches: diff --git a/.github/workflows/qe-wasm-check.yml b/.github/workflows/qe-wasm-check.yml new file mode 100644 index 000000000000..f67d2d247b27 --- /dev/null +++ b/.github/workflows/qe-wasm-check.yml @@ -0,0 +1,27 @@ +name: WASM engine compile check +on: + push: + branches: + - main + pull_request: + paths-ignore: + - '.github/**' + - '!.github/workflows/qe-wasm-check.yml' + - '.buildkite/**' + - '*.md' + - 'LICENSE' + - 'CODEOWNERS' + - 'renovate.json' + +jobs: + build: + name: 'Compilation check for query-engine-wasm' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - name: Install wasm-pack + run: cargo install wasm-pack + - name: Build wasm query engine + run: ./build.sh + working-directory: ./query-engine/query-engine-wasm diff --git a/.github/workflows/send-main-push-event.yml b/.github/workflows/send-main-push-event.yml new file mode 100644 index 000000000000..fa9294cba03f --- /dev/null +++ b/.github/workflows/send-main-push-event.yml @@ -0,0 +1,20 @@ +name: Trigger prisma-engines-builds run +run-name: Trigger prisma-engines-builds run for ${{ github.sha }} + +on: + push: + branches: + - main + +jobs: + send-commit-hash: + runs-on: ubuntu-22.04 + steps: + - run: echo "Sending event for commit $GITHUB_SHA" + - name: Workflow dispatch to prisma/prisma-engines-builds + uses: benc-uk/workflow-dispatch@v1 + with: + workflow: .github/workflows/build-engines.yml + repo: prisma/prisma-engines-builds + token: ${{ secrets.BOT_TOKEN_PRISMA_ENGINES_BUILD }} + inputs: '{ "commit": "${{ github.sha }}" }' diff --git a/.github/workflows/send-tag-event.yml b/.github/workflows/send-tag-event.yml index 2088e258ed49..eb33406d9580 100644 --- a/.github/workflows/send-tag-event.yml +++ b/.github/workflows/send-tag-event.yml @@ -1,4 +1,4 @@ -name: Send Tag Event +name: Send Tag Event to PDP on: push: diff --git a/.github/workflows/compilation.yml b/.github/workflows/test-compilation.yml similarity index 96% rename from .github/workflows/compilation.yml rename to .github/workflows/test-compilation.yml index d9f81f47772b..562c60d2718a 100644 --- a/.github/workflows/compilation.yml +++ b/.github/workflows/test-compilation.yml @@ -1,4 +1,4 @@ -name: 'Release binary compilation test' +name: Test release binary compilation on: pull_request: paths-ignore: diff --git a/.github/workflows/quaint.yml b/.github/workflows/test-quaint.yml similarity index 99% rename from .github/workflows/quaint.yml rename to .github/workflows/test-quaint.yml index 7b49e80a7bd0..6df094302dfc 100644 --- a/.github/workflows/quaint.yml +++ b/.github/workflows/test-quaint.yml @@ -1,4 +1,4 @@ -name: Quaint +name: Test Quaint on: push: branches: diff --git a/.github/workflows/query-engine-black-box.yml b/.github/workflows/test-query-engine-black-box.yml similarity index 96% rename from .github/workflows/query-engine-black-box.yml rename to .github/workflows/test-query-engine-black-box.yml index 5ebcd79cec4c..6494572cc13b 100644 --- a/.github/workflows/query-engine-black-box.yml +++ b/.github/workflows/test-query-engine-black-box.yml @@ -1,4 +1,4 @@ -name: Query Engine Black Box +name: Test Query Engine (Black Box) on: push: branches: @@ -19,7 +19,7 @@ concurrency: jobs: rust-tests: - name: 'Test query-engine as a black-box' + name: 'query-engine as a black-box' strategy: fail-fast: false diff --git a/.github/workflows/query-engine-driver-adapters.yml b/.github/workflows/test-query-engine-driver-adapters.yml similarity index 66% 
rename from .github/workflows/query-engine-driver-adapters.yml rename to .github/workflows/test-query-engine-driver-adapters.yml index b8434e2fa04c..0a9c933c9b58 100644 --- a/.github/workflows/query-engine-driver-adapters.yml +++ b/.github/workflows/test-query-engine-driver-adapters.yml @@ -1,4 +1,4 @@ -name: Driver Adapters +name: Test Driver Adapters on: push: branches: @@ -19,20 +19,33 @@ concurrency: jobs: rust-query-engine-tests: - name: 'Test `${{ matrix.adapter.name }}` on node v${{ matrix.node_version }}' + name: '${{ matrix.adapter.name }} on node v${{ matrix.node_version }}' strategy: fail-fast: false matrix: adapter: - - name: 'pg' - setup_task: 'dev-pg-postgres13' - - name: 'neon:ws' - setup_task: 'dev-neon-ws-postgres13' - - name: 'libsql' - setup_task: 'dev-libsql-sqlite' - - name: 'planetscale' - setup_task: 'dev-planetscale-vitess8' + - name: '@prisma/adapter-planetscale' + setup_task: 'dev-planetscale-js' + - name: '@prisma/adapter-pg (napi)' + setup_task: 'dev-pg-js' + - name: '@prisma/adapter-neon (ws) (napi)' + setup_task: 'dev-neon-js' + - name: '@prisma/adapter-libsql (Turso) (napi)' + setup_task: 'dev-libsql-js' + # TODO: uncomment when WASM engine is functional + # - name: '@prisma/adapter-planetscale' + # setup_task: 'dev-planetscale-wasm' + # needs_wasm_pack: true + # - name: '@prisma/adapter-pg (wasm)' + # setup_task: 'dev-pg-wasm' + # needs_wasm_pack: true + # - name: '@prisma/adapter-neon (ws) (wasm)' + # setup_task: 'dev-neon-wasm' + # needs_wasm_pack: true + # - name: '@prisma/adapter-libsql (Turso) (wasm)' + # setup_task: 'dev-libsql-wasm' + # needs_wasm_pack: true node_version: ['18'] env: LOG_LEVEL: 'info' # Set to "debug" to trace the query engine and node process running the driver adapter @@ -87,9 +100,13 @@ jobs: echo "DRIVER_ADAPTERS_BRANCH=$branch" >> "$GITHUB_ENV" fi - - run: make ${{ matrix.adapter.setup_task }} - - uses: dtolnay/rust-toolchain@stable + - name: 'Install wasm-pack' + if: ${{ matrix.adapter.needs_wasm_pack }} + run: cargo install wasm-pack + + - run: make ${{ matrix.adapter.setup_task }} + - name: 'Run tests' run: cargo test --package query-engine-tests -- --test-threads=1 diff --git a/.github/workflows/query-engine.yml b/.github/workflows/test-query-engine.yml similarity index 92% rename from .github/workflows/query-engine.yml rename to .github/workflows/test-query-engine.yml index 762c3da4a50a..6d5e0ada4eb3 100644 --- a/.github/workflows/query-engine.yml +++ b/.github/workflows/test-query-engine.yml @@ -1,4 +1,4 @@ -name: Query Engine +name: Test Query Engine on: push: branches: @@ -19,16 +19,12 @@ concurrency: jobs: rust-query-engine-tests: - name: 'Test ${{ matrix.database.name }} (${{ matrix.engine_protocol }}) on Linux' + name: '${{ matrix.database.name }} (${{ matrix.engine_protocol }}) on Linux' strategy: fail-fast: false matrix: database: - - name: 'vitess_5_7' - single_threaded: true - connector: 'vitess' - version: '5.7' - name: 'vitess_8_0' single_threaded: true connector: 'vitess' @@ -41,6 +37,10 @@ jobs: single_threaded: false connector: 'sqlserver' version: '2022' + - name: 'sqlite' + single_threaded: false + connector: 'sqlite' + version: '3' - name: 'mongodb_4_2' single_threaded: true connector: 'mongodb' diff --git a/.github/workflows/schema-engine.yml b/.github/workflows/test-schema-engine.yml similarity index 95% rename from .github/workflows/schema-engine.yml rename to .github/workflows/test-schema-engine.yml index 03d23317bbd0..425085d3af48 100644 --- a/.github/workflows/schema-engine.yml +++ 
b/.github/workflows/test-schema-engine.yml @@ -1,4 +1,4 @@ -name: Schema Engine +name: Test Schema Engine on: push: branches: @@ -22,7 +22,7 @@ concurrency: jobs: test-mongodb-schema-connector: - name: 'Test ${{ matrix.database.name }} on Linux' + name: '${{ matrix.database.name }} on Linux' strategy: fail-fast: false matrix: @@ -54,7 +54,7 @@ jobs: TEST_DATABASE_URL: ${{ matrix.database.url }} test-linux: - name: 'Test ${{ matrix.database.name }} on Linux' + name: '${{ matrix.database.name }} on Linux' strategy: fail-fast: false @@ -94,11 +94,6 @@ jobs: url: 'postgresql://prisma@localhost:26257' - name: sqlite url: sqlite - - name: vitess_5_7 - url: 'mysql://root:prisma@localhost:33577/test' - shadow_database_url: 'mysql://root:prisma@localhost:33578/shadow' - is_vitess: true - single_threaded: true - name: vitess_8_0 url: 'mysql://root:prisma@localhost:33807/test' shadow_database_url: 'mysql://root:prisma@localhost:33808/shadow' @@ -212,7 +207,7 @@ jobs: runs-on: ${{ matrix.os }} - name: 'Test ${{ matrix.db.name }} on Windows' + name: '${{ matrix.db.name }} on Windows' steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/test-unit-tests.yml similarity index 98% rename from .github/workflows/unit-tests.yml rename to .github/workflows/test-unit-tests.yml index b852499205e9..631e13b19e96 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/test-unit-tests.yml @@ -1,4 +1,4 @@ -name: Unit tests +name: Test Unit tests on: push: branches: diff --git a/.test_database_urls/vitess_5_7 b/.test_database_urls/vitess_5_7 deleted file mode 100644 index 2259628658ac..000000000000 --- a/.test_database_urls/vitess_5_7 +++ /dev/null @@ -1,2 +0,0 @@ -export TEST_DATABASE_URL="mysql://root:prisma@localhost:33577/test" -export TEST_SHADOW_DATABASE_URL="mysql://root:prisma@localhost:33578/shadow" \ No newline at end of file diff --git a/CODEOWNERS b/CODEOWNERS index c1a996de1f21..cb8fc144133d 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -* @prisma/team-orm-rust +* @prisma/ORM-Rust diff --git a/Cargo.lock b/Cargo.lock index 35eff530999a..cd2fe6c6b5c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -326,6 +326,7 @@ dependencies = [ "query-engine-metrics", "query-engine-tests", "query-tests-setup", + "regex", "reqwest", "serde_json", "tokio", @@ -2396,9 +2397,9 @@ dependencies = [ [[package]] name = "mobc" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bdeff49b387edef305eccfe166af3e1483bb57902dbf369dddc42dc824df23b" +checksum = "90eb49dc5d193287ff80e72a86f34cfb27aae562299d22fea215e06ea1059dd3" dependencies = [ "async-trait", "futures-channel", @@ -3570,6 +3571,7 @@ dependencies = [ "connection-string", "either", "futures", + "getrandom 0.2.10", "hex", "indoc 0.3.6", "lru-cache", @@ -3678,6 +3680,7 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", + "pin-project", "prisma-models", "psl", "query-connector", @@ -3693,6 +3696,7 @@ dependencies = [ "tracing-subscriber", "user-facing-errors", "uuid", + "wasm-bindgen-futures", ] [[package]] @@ -3820,9 +3824,14 @@ dependencies = [ "log", "prisma-models", "psl", + "quaint", + "query-connector", + "query-core", + "request-handlers", "serde", "serde-wasm-bindgen", "serde_json", + "sql-query-connector", "thiserror", "tokio", "tracing", @@ -6009,9 +6018,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.88" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -6019,9 +6028,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" dependencies = [ "bumpalo", "log", @@ -6046,9 +6055,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6056,9 +6065,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" dependencies = [ "proc-macro2", "quote", @@ -6069,9 +6078,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" [[package]] name = "wasm-logger" diff --git a/Cargo.toml b/Cargo.toml index 4a3cd1450caf..b32a1a85cf18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", + "native", ] [profile.dev.package.backtrace] diff --git a/Makefile b/Makefile index a30a32ca1871..3aec261dc2f0 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,11 @@ ifndef DRIVER_ADAPTER cargo test --package query-engine-tests else @echo "Executing query engine tests with $(DRIVER_ADAPTER) driver adapter"; \ - # Add your actual command for the "test-driver-adapter" task here - $(MAKE) test-driver-adapter-$(DRIVER_ADAPTER); + if [ "$(ENGINE)" = "wasm" ]; then \ + $(MAKE) test-driver-adapter-$(DRIVER_ADAPTER)-wasm; \ + else \ + $(MAKE) test-driver-adapter-$(DRIVER_ADAPTER); \ + fi endif test-qe-verbose: @@ -84,12 +87,18 @@ start-sqlite: dev-sqlite: cp $(CONFIG_PATH)/sqlite $(CONFIG_FILE) -dev-libsql-sqlite: build-qe-napi build-connector-kit-js - cp $(CONFIG_PATH)/libsql-sqlite $(CONFIG_FILE) +dev-libsql-js: build-qe-napi build-connector-kit-js + cp $(CONFIG_PATH)/libsql-js $(CONFIG_FILE) + +test-libsql-js: dev-libsql-js test-qe-st + +test-driver-adapter-libsql: test-libsql-js -test-libsql-sqlite: dev-libsql-sqlite test-qe-st +dev-libsql-wasm: build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/libsql-wasm $(CONFIG_FILE) -test-driver-adapter-libsql: test-libsql-sqlite +test-libsql-wasm: dev-libsql-wasm test-qe-st +test-driver-adapter-libsql-wasm: test-libsql-wasm start-postgres9: docker compose -f docker-compose.yml up --wait -d --remove-orphans postgres9 @@ -121,24 +130,36 @@ start-postgres13: dev-postgres13: start-postgres13 cp $(CONFIG_PATH)/postgres13 $(CONFIG_FILE) -start-pg-postgres13: build-qe-napi build-connector-kit-js 
start-postgres13 +start-pg-js: start-postgres13 + +dev-pg-js: start-pg-js build-qe-napi build-connector-kit-js + cp $(CONFIG_PATH)/pg-js $(CONFIG_FILE) + +test-pg-js: dev-pg-js test-qe-st + +dev-pg-wasm: start-pg-js build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/pg-wasm $(CONFIG_FILE) -dev-pg-postgres13: start-pg-postgres13 - cp $(CONFIG_PATH)/pg-postgres13 $(CONFIG_FILE) +test-pg-wasm: dev-pg-wasm test-qe-st -test-pg-postgres13: dev-pg-postgres13 test-qe-st +test-driver-adapter-pg: test-pg-js +test-driver-adapter-pg-wasm: test-pg-wasm -test-driver-adapter-pg: test-pg-postgres13 +start-neon-js: + docker compose -f docker-compose.yml up --wait -d --remove-orphans neon-proxy -start-neon-postgres13: build-qe-napi build-connector-kit-js - docker compose -f docker-compose.yml up --wait -d --remove-orphans neon-postgres13 +dev-neon-js: start-neon-js build-qe-napi build-connector-kit-js + cp $(CONFIG_PATH)/neon-js $(CONFIG_FILE) -dev-neon-ws-postgres13: start-neon-postgres13 - cp $(CONFIG_PATH)/neon-ws-postgres13 $(CONFIG_FILE) +test-neon-js: dev-neon-js test-qe-st -test-neon-ws-postgres13: dev-neon-ws-postgres13 test-qe-st +dev-neon-wasm: start-neon-js build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/neon-wasm $(CONFIG_FILE) -test-driver-adapter-neon: test-neon-ws-postgres13 +test-neon-wasm: dev-neon-wasm test-qe-st + +test-driver-adapter-neon: test-neon-js +test-driver-adapter-neon-wasm: test-neon-wasm start-postgres14: docker compose -f docker-compose.yml up --wait -d --remove-orphans postgres14 @@ -256,27 +277,27 @@ dev-mongodb_5: start-mongodb_5 dev-mongodb_4_2: start-mongodb_4_2 cp $(CONFIG_PATH)/mongodb42 $(CONFIG_FILE) -start-vitess_5_7: - docker compose -f docker-compose.yml up --wait -d --remove-orphans vitess-test-5_7 vitess-shadow-5_7 - -dev-vitess_5_7: start-vitess_5_7 - cp $(CONFIG_PATH)/vitess_5_7 $(CONFIG_FILE) - start-vitess_8_0: docker compose -f docker-compose.yml up --wait -d --remove-orphans vitess-test-8_0 vitess-shadow-8_0 dev-vitess_8_0: start-vitess_8_0 cp $(CONFIG_PATH)/vitess_8_0 $(CONFIG_FILE) -start-planetscale-vitess8: build-qe-napi build-connector-kit-js - docker compose -f docker-compose.yml up -d --remove-orphans planetscale-vitess8 +start-planetscale-js: + docker compose -f docker-compose.yml up -d --remove-orphans planetscale-proxy -dev-planetscale-vitess8: start-planetscale-vitess8 - cp $(CONFIG_PATH)/planetscale-vitess8 $(CONFIG_FILE) +dev-planetscale-js: start-planetscale-js build-qe-napi build-connector-kit-js + cp $(CONFIG_PATH)/planetscale-js $(CONFIG_FILE) -test-planetscale-vitess8: dev-planetscale-vitess8 test-qe-st +test-planetscale-js: dev-planetscale-js test-qe-st -test-driver-adapter-planetscale: test-planetscale-vitess8 +dev-planetscale-wasm: start-planetscale-js build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/planetscale-wasm $(CONFIG_FILE) + +test-planetscale-wasm: dev-planetscale-wasm test-qe-st + +test-driver-adapter-planetscale: test-planetscale-js +test-driver-adapter-planetscale-wasm: test-planetscale-wasm ###################### # Local dev commands # @@ -285,6 +306,9 @@ test-driver-adapter-planetscale: test-planetscale-vitess8 build-qe-napi: cargo build --package query-engine-node-api +build-qe-wasm: + cd query-engine/query-engine-wasm && ./build.sh + build-connector-kit-js: build-driver-adapters cd query-engine/driver-adapters && pnpm i && pnpm build diff --git a/README.md b/README.md index 49c7c1a8ab39..502f0d31e8ae 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Prisma Engines -[![Query 
Engine](https://github.com/prisma/prisma-engines/actions/workflows/query-engine.yml/badge.svg)](https://github.com/prisma/prisma-engines/actions/workflows/query-engine.yml) -[![Schema Engine + sql_schema_describer](https://github.com/prisma/prisma-engines/actions/workflows/schema-engine.yml/badge.svg)](https://github.com/prisma/prisma-engines/actions/workflows/schema-engine.yml) +[![Query Engine](https://github.com/prisma/prisma-engines/actions/workflows/test-query-engine.yml/badge.svg)](https://github.com/prisma/prisma-engines/actions/workflows/test-query-engine.yml) +[![Schema Engine + sql_schema_describer](https://github.com/prisma/prisma-engines/actions/workflows/test-schema-engine.yml/badge.svg)](https://github.com/prisma/prisma-engines/actions/workflows/test-schema-engine.yml) [![Cargo docs](https://github.com/prisma/prisma-engines/actions/workflows/on-push-to-main.yml/badge.svg)](https://github.com/prisma/prisma-engines/actions/workflows/on-push-to-main.yml) This repository contains a collection of engines that power the core stack for @@ -259,6 +259,29 @@ GitHub actions will then pick up the branch name and use it to clone that branch When it's time to merge the sibling PRs, you'll need to merge the prisma/prisma PR first, so when merging the engines PR you have the code of the adapters ready in prisma/prisma `main` branch. +### Testing engines in `prisma/prisma` + +You can trigger releases from this repository to npm that can be used for testing the engines in `prisma/prisma` either automatically or manually: + +#### Automated integration releases from this repository to npm + +Since July 2022, any branch name starting with `integration/` will, first, run the full test suite in Buildkite `[Test] Prisma Engines` and, second, if passing, run the publish pipeline (build and upload engines to S3 & R2). + +The journey through the pipeline is the same as a commit on the `main` branch. +- It will trigger [`prisma/engines-wrapper`](https://github.com/prisma/engines-wrapper) and publish a new [`@prisma/engines-version`](https://www.npmjs.com/package/@prisma/engines-version) npm package but on the `integration` tag. +- This triggers [`prisma/prisma`](https://github.com/prisma/prisma) to create a `chore(Automated Integration PR): [...]` PR with a branch name also starting with `integration/` +- Since in `prisma/prisma` we also trigger the publish pipeline when a branch name starts with `integration/`, this will publish all `prisma/prisma` monorepo packages to npm on the `integration` tag. +- Our [ecosystem-tests](https://github.com/prisma/ecosystem-tests/) tests will automatically pick up this new version and run tests; results will show in [GitHub Actions](https://github.com/prisma/ecosystem-tests/actions?query=branch%3Aintegration) + +This end-to-end flow takes a minimum of ~1h20 to complete, but is completely automated :robot: + +Notes: +- in the `prisma/prisma` repository, we do not run tests for `integration/` branches; this is much faster and also means that there is no risk of tests failing (e.g. flaky tests, snapshots) that would stop the publishing process. +- in `prisma/prisma-engines` the Buildkite test pipeline must first pass, then the engines will be built and uploaded to our storage via the Buildkite release pipeline. These 2 pipelines can fail for different reasons; it's recommended to keep an eye on them (check notifications in Slack) and restart jobs as needed. Finally, it will trigger [`prisma/engines-wrapper`](https://github.com/prisma/engines-wrapper).
+ +#### Manual integration releases from this repository to npm + +In addition to the automated integration releases for `integration/` branches, you can also trigger a publish **manually** in the Buildkite `[Test] Prisma Engines` job if that succeeds for _any_ branch name. Click "🚀 Publish binaries" at the bottom of the test list to unlock the publishing step. When all the jobs in `[Release] Prisma Engines` succeed, you also have to unlock the next step by clicking "🚀 Publish client". This will then trigger the same journey as described above. ## Parallel rust-analyzer builds @@ -269,22 +292,25 @@ rust-analyzer. To avoid this. Open VSCode settings and search for `Check on Save --target-dir:/tmp/rust-analyzer-check ``` -### Automated integration releases from this repository to npm -(Since July 2022). Any branch name starting with `integration/` will, first, run the full test suite and, second, if passing, run the publish pipeline (build and upload engines to S3) +## Community PRs: create a local branch for a branch coming from a fork -The journey through the pipeline is the same as a commit on the `main` branch. -- It will trigger [prisma/engines-wrapper](https://github.com/prisma/engines-wrapper) and publish a new [`@prisma/engines-version`](https://www.npmjs.com/package/@prisma/engines-version) npm package but on the `integration` tag. -- Which triggers [prisma/prisma](https://github.com/prisma/prisma) to create a `chore(Automated Integration PR): [...]` PR with a branch name also starting with `integration/` -- Since in prisma/prisma we also trigger the publish pipeline when a branch name starts with `integration/`, this will publish all prisma/prisma monorepo packages to npm on the `integration` tag. +To trigger [Automated integration releases from this repository to npm](#automated-integration-releases-from-this-repository-to-npm) or [Manual integration releases from this repository to npm](#manual-integration-releases-from-this-repository-to-npm), branches from forks need to be pulled into this repository so that the Buildkite job is triggered. You can use these GitHub and git CLI commands to achieve that easily: -This end to end will take minimum ~1h20 to complete, but is completely automated :robot: +``` +gh pr checkout 4375 +git checkout -b integration/sql-nested-transactions +git push --set-upstream origin integration/sql-nested-transactions +``` -Notes: -- in prisma/prisma repository, we do not run tests for `integration/` branches, it is much faster and also means that there is no risk of test failing (e.g. flaky tests, snapshots) that would stop the publishing process. -- in prisma/prisma-engines tests must first pass, before publishing starts. So better keep an eye on them and restart them as needed. +If this branch needs to be re-created because the fork's branch has been updated, deleting and re-creating it makes sure the content is identical and avoids any conflicts.
+``` +git branch --delete integration/sql-nested-transactions +gh pr checkout 4375 +git checkout -b integration/sql-nested-transactions +git push --set-upstream origin integration/sql-nested-transactions --force +``` ## Security diff --git a/docker-compose.yml b/docker-compose.yml index c0d4f179e0a4..9be5e6978f87 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: '3' services: cockroach_23_1: - image: prismagraphql/cockroachdb-custom:23.1 + image: prismagraphql/cockroachdb-custom:23.1@sha256:c5a97355d56a7692ed34d835dfd8e3663d642219ea90736658a24840ea26862d restart: unless-stopped command: | start-single-node --insecure @@ -107,7 +107,7 @@ services: networks: - databases - neon-postgres13: + neon-proxy: image: ghcr.io/neondatabase/wsproxy:latest restart: unless-stopped environment: @@ -123,22 +123,69 @@ services: networks: - databases - planetscale-vitess8: + # Tests using a vitess backend were not running properly for query-engine tests, and in + # https://github.com/prisma/prisma-engines/pull/4423 work was done to have coverage of the correctness of planetscale + # driver adapters. + # + # Given that these tests run against the planetscale proxy, and that a different test suite will exist for vitess with + # rust drivers, we opted for the path of least friction when running the driver adapter tests, which is putting a + # single mysql box behind the planetscale proxy instead of a full vttest cluster. + # + # The tradeoffs are: + # + # - we don't exercise vitess but mysql. This is a close approximation, but there might be small differences in + # behavior (e.g. vttest can return different error messages than mysql). + # + # - however, we 1) do exercise the planetscale proxy, 2) we use relationMode=prisma and this resembles what actually + # happens within the query engine, where vitess does not exist as a provider, and as such there isn't any particular + # capability or conditional code making the engine behave differently than when using MySQL. + # In the end Vitess is just an abstraction existing in the test kit to a) use the mysql provider, b) run the suite + # with relationMode=prisma; c) be able to run or exclude specific tests for that configuration. But the existence + # of this testing connector is misleading, and it should probably be just a version of the MySQL testing connector + # instead.
+ planetscale-proxy: build: ./docker/planetscale_proxy environment: - MYSQL_HOST: 'vitess-test-8_0' - MYSQL_PORT: 33807 - MYSQL_DATABASE: 'test' + MYSQL_HOST: 'mysql-planetscale' + MYSQL_PORT: 3306 + MYSQL_DATABASE: prisma ports: - '8085:8085' depends_on: - - vitess-test-8_0 + mysql-planetscale: + condition: service_healthy restart: unless-stopped healthcheck: test: ['CMD', 'nc', '-z', '127.0.0.1', '8085'] interval: 5s timeout: 2s retries: 20 + networks: + - databases + + mysql-planetscale: + image: mysql:8.0.28 + command: mysqld + restart: unless-stopped + platform: linux/x86_64 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: prisma + ports: + - '3310:3306' + networks: + - databases + tmpfs: /var/lib/planetscale-mysql + volumes: + - "./docker/planetscale-mysql/my.cnf:/etc/mysql/my.cnf" + ulimits: + nofile: + soft: 20000 + hard: 40000 + healthcheck: + test: [ "CMD", "mysqladmin" ,"ping", "-h", "localhost" ] + timeout: 20s + retries: 10 postgres14: image: postgres:14 @@ -180,6 +227,10 @@ services: networks: - databases tmpfs: /var/lib/mysql + healthcheck: + test: [ "CMD", "mysqladmin" ,"ping", "-h", "localhost" ] + timeout: 20s + retries: 10 mysql-5-7: image: mysql:5.7.44 @@ -187,7 +238,6 @@ services: restart: unless-stopped platform: linux/x86_64 environment: - MYSQL_USER: root MYSQL_ROOT_PASSWORD: prisma MYSQL_DATABASE: prisma ports: @@ -195,6 +245,10 @@ services: networks: - databases tmpfs: /var/lib/mysql + healthcheck: + test: [ "CMD", "mysqladmin" ,"ping", "-h", "localhost" ] + timeout: 20s + retries: 10 mysql-8-0: image: mysql:8.0.28 @@ -209,6 +263,10 @@ services: networks: - databases tmpfs: /var/lib/mysql8 + healthcheck: + test: [ "CMD", "mysqladmin" ,"ping", "-h", "localhost" ] + timeout: 20s + retries: 10 mariadb-10-0: image: mariadb:10 @@ -223,26 +281,6 @@ services: - databases tmpfs: /var/lib/mariadb - vitess-test-5_7: - image: vitess/vttestserver:mysql57@sha256:23863a518b34330109c502ac61a396008f5f023e96263bcb2bb1b0f7f7d5dc7f - restart: unless-stopped - ports: - - 33577:33577 - environment: - PORT: 33574 - KEYSPACES: 'test' - NUM_SHARDS: '1' - MYSQL_BIND_HOST: '0.0.0.0' - FOREIGN_KEY_MODE: 'disallow' - ENABLE_ONLINE_DDL: false - MYSQL_MAX_CONNECTIONS: 100000 - TABLET_REFRESH_INTERVAL: '500ms' - healthcheck: - test: ['CMD', 'mysqladmin', 'ping', '-h127.0.0.1', '-P33577'] - interval: 5s - timeout: 2s - retries: 20 - vitess-test-8_0: image: vitess/vttestserver:mysql80@sha256:8bec2644d83cb322eb2cdd596d33c0f858243ba6ade9164c95dfcc519643094e restart: unless-stopped @@ -263,26 +301,6 @@ services: timeout: 2s retries: 20 - vitess-shadow-5_7: - image: vitess/vttestserver:mysql57@sha256:23863a518b34330109c502ac61a396008f5f023e96263bcb2bb1b0f7f7d5dc7f - restart: unless-stopped - ports: - - 33578:33577 - environment: - PORT: 33574 - KEYSPACES: 'shadow' - NUM_SHARDS: '1' - MYSQL_BIND_HOST: '0.0.0.0' - FOREIGN_KEY_MODE: 'disallow' - ENABLE_ONLINE_DDL: false - MYSQL_MAX_CONNECTIONS: 100000 - TABLET_REFRESH_INTERVAL: '500ms' - healthcheck: - test: ['CMD', 'mysqladmin', 'ping', '-h127.0.0.1', '-P33577'] - interval: 5s - timeout: 2s - retries: 20 - vitess-shadow-8_0: image: vitess/vttestserver:mysql80@sha256:8bec2644d83cb322eb2cdd596d33c0f858243ba6ade9164c95dfcc519643094e restart: unless-stopped diff --git a/docker/planetscale-mysql/my.cnf b/docker/planetscale-mysql/my.cnf new file mode 100644 index 000000000000..f47f808a75fb --- /dev/null +++ b/docker/planetscale-mysql/my.cnf @@ -0,0 +1,6 @@ +[mysqld] +pid-file = /var/run/mysqld/mysqld.pid +socket = 
/var/run/mysqld/mysqld.sock +datadir = /var/lib/mysql +secure-file-priv= NULL +max_connections=1000 \ No newline at end of file diff --git a/docker/planetscale_proxy/Dockerfile b/docker/planetscale_proxy/Dockerfile index 2411894d88f0..9d6cca2f5dd8 100644 --- a/docker/planetscale_proxy/Dockerfile +++ b/docker/planetscale_proxy/Dockerfile @@ -9,7 +9,7 @@ ENTRYPOINT /go/bin/ps-http-sim \ -http-port=8085 \ -mysql-addr=$MYSQL_HOST \ -mysql-port=$MYSQL_PORT \ - -mysql-idle-timeout=1200s \ + -mysql-idle-timeout=1s \ -mysql-no-pass \ -mysql-max-rows=1000 \ -mysql-dbname=$MYSQL_DATABASE diff --git a/flake.lock b/flake.lock index c2750d0435ed..b887051dac9b 100644 --- a/flake.lock +++ b/flake.lock @@ -2,23 +2,16 @@ "nodes": { "crane": { "inputs": { - "flake-compat": "flake-compat", - "flake-utils": [ - "flake-utils" - ], "nixpkgs": [ "nixpkgs" - ], - "rust-overlay": [ - "rust-overlay" ] }, "locked": { - "lastModified": 1696384830, - "narHash": "sha256-j8ZsVqzmj5sOm5MW9cqwQJUZELFFwOislDmqDDEMl6k=", + "lastModified": 1699548976, + "narHash": "sha256-xnpxms0koM8mQpxIup9JnT0F7GrKdvv0QvtxvRuOYR4=", "owner": "ipetkov", "repo": "crane", - "rev": "f2143cd27f8bd09ee4f0121336c65015a2a0a19c", + "rev": "6849911446e18e520970cc6b7a691e64ee90d649", "type": "github" }, "original": { @@ -27,22 +20,6 @@ "type": "github" } }, - "flake-compat": { - "flake": false, - "locked": { - "lastModified": 1696267196, - "narHash": "sha256-AAQ/2sD+0D18bb8hKuEEVpHUYD1GmO2Uh/taFamn6XQ=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "4f910c9827911b1ec2bf26b5a062cd09f8d89f85", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, "flake-parts": { "inputs": { "nixpkgs-lib": [ @@ -50,11 +27,11 @@ ] }, "locked": { - "lastModified": 1696343447, - "narHash": "sha256-B2xAZKLkkeRFG5XcHHSXXcP7To9Xzr59KXeZiRf4vdQ=", + "lastModified": 1698882062, + "narHash": "sha256-HkhafUayIqxXyHH1X8d9RDl1M2CkFgZLjKD3MzabiEo=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "c9afaba3dfa4085dbd2ccb38dfade5141e33d9d4", + "rev": "8c9fa2545007b49a5db5f650ae91f227672c3877", "type": "github" }, "original": { @@ -105,11 +82,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1696193975, - "narHash": "sha256-mnQjUcYgp9Guu3RNVAB2Srr1TqKcPpRXmJf4LJk6KRY=", + "lastModified": 1699963925, + "narHash": "sha256-LE7OV/SwkIBsCpAlIPiFhch/J+jBDGEZjNfdnzCnCrY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "fdd898f8f79e8d2f99ed2ab6b3751811ef683242", + "rev": "bf744fe90419885eefced41b3e5ae442d732712d", "type": "github" }, "original": { @@ -139,11 +116,11 @@ ] }, "locked": { - "lastModified": 1696558324, - "narHash": "sha256-TnnP4LGwDB8ZGE7h2n4nA9Faee8xPkMdNcyrzJ57cbw=", + "lastModified": 1700187354, + "narHash": "sha256-RRIVKv+tiI1yn1PqZiVGQ9YlQGZ+/9iEkA4rst1QiNk=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "fdb37574a04df04aaa8cf7708f94a9309caebe2b", + "rev": "e3ebc177291f5de627d6dfbac817b4a661b15d1c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 67f4042d8c68..e62a09803d3d 100644 --- a/flake.nix +++ b/flake.nix @@ -3,8 +3,6 @@ crane = { url = "github:ipetkov/crane"; inputs.nixpkgs.follows = "nixpkgs"; - inputs.rust-overlay.follows = "rust-overlay"; - inputs.flake-utils.follows = "flake-utils"; }; flake-utils = { url = "github:numtide/flake-utils"; diff --git a/libs/user-facing-errors/Cargo.toml b/libs/user-facing-errors/Cargo.toml index 9900892209c6..3049a19712b1 100644 --- a/libs/user-facing-errors/Cargo.toml +++ b/libs/user-facing-errors/Cargo.toml @@ 
-11,7 +11,7 @@ backtrace = "0.3.40" tracing = "0.1" indoc.workspace = true itertools = "0.10" -quaint = { workspace = true, optional = true } +quaint = { path = "../../quaint", optional = true } [features] default = [] diff --git a/prisma-schema-wasm/Cargo.toml b/prisma-schema-wasm/Cargo.toml index 248c726c9ba4..51638e55b1c1 100644 --- a/prisma-schema-wasm/Cargo.toml +++ b/prisma-schema-wasm/Cargo.toml @@ -7,6 +7,6 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -wasm-bindgen = "=0.2.87" +wasm-bindgen = "=0.2.88" wasm-logger = { version = "0.2.0", optional = true } prisma-fmt = { path = "../prisma-fmt" } diff --git a/psl/psl-core/src/validate/validation_pipeline/validations.rs b/psl/psl-core/src/validate/validation_pipeline/validations.rs index 4040844bb767..90f8ec9fe79e 100644 --- a/psl/psl-core/src/validate/validation_pipeline/validations.rs +++ b/psl/psl-core/src/validate/validation_pipeline/validations.rs @@ -123,7 +123,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { indexes::supports_clustering_setting(index, ctx); indexes::clustering_can_be_defined_only_once(index, ctx); indexes::opclasses_are_not_allowed_with_other_than_normal_indices(index, ctx); - indexes::composite_types_are_not_allowed_in_index(index, ctx); + indexes::composite_type_in_compound_unique_index(index, ctx); for field_attribute in index.scalar_field_attributes() { let span = index.ast_attribute().span; diff --git a/psl/psl-core/src/validate/validation_pipeline/validations/indexes.rs b/psl/psl-core/src/validate/validation_pipeline/validations/indexes.rs index 5f3288264016..7a7d0e1d105e 100644 --- a/psl/psl-core/src/validate/validation_pipeline/validations/indexes.rs +++ b/psl/psl-core/src/validate/validation_pipeline/validations/indexes.rs @@ -386,20 +386,25 @@ pub(crate) fn opclasses_are_not_allowed_with_other_than_normal_indices(index: In } } -pub(crate) fn composite_types_are_not_allowed_in_index(index: IndexWalker<'_>, ctx: &mut Context<'_>) { - for field in index.fields() { - if field.scalar_field_type().as_composite_type().is_some() { - let message = format!( - "Indexes can only contain scalar attributes. Please remove {:?} from the argument list of the indexes.", - field.name() - ); - ctx.push_error(DatamodelError::new_attribute_validation_error( - &message, - index.attribute_name(), - index.ast_attribute().span, - )); - return; - } +pub(crate) fn composite_type_in_compound_unique_index(index: IndexWalker<'_>, ctx: &mut Context<'_>) { + if !index.is_unique() { + return; + } + + let composite_type = index + .fields() + .find(|f| f.scalar_field_type().as_composite_type().is_some()); + + if index.fields().len() > 1 && composite_type.is_some() { + let message = format!( + "Prisma does not currently support composite types in compound unique indices, please remove {:?} from the index. See https://pris.ly/d/mongodb-composite-compound-indices for more details", + composite_type.unwrap().name() + ); + ctx.push_error(DatamodelError::new_attribute_validation_error( + &message, + index.attribute_name(), + index.ast_attribute().span, + )); } } diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b699518d0910..52a7edf72aca 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -23,20 +23,28 @@ resolver = "2" features = ["docs", "all"] [features] -default = [] +default = ["mysql", "postgresql", "mssql", "sqlite"] docs = [] # Expose the underlying database drivers when a connector is enabled. This is a # way to access database-specific methods when you need extra control. 
expose-drivers = [] -all = ["mssql", "mysql", "pooled", "postgresql", "sqlite"] +native = [ + "postgresql-native", + "mysql-native", + "mssql-native", + "sqlite-native", +] + +all = ["native", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql = [ +postgresql-native = [ + "postgresql", "native-tls", "tokio-postgres", "postgres-types", @@ -47,11 +55,24 @@ postgresql = [ "lru-cache", "byteorder", ] +postgresql = [] + +mssql-native = [ + "mssql", + "tiberius", + "tokio-util", + "tokio/time", + "tokio/net", +] +mssql = [] + +mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql = ["chrono/std"] -mssql = ["tiberius", "tokio-util", "tokio/time", "tokio/net", "either"] -mysql = ["mysql_async", "tokio/time", "lru-cache"] pooled = ["mobc"] -sqlite = ["rusqlite", "tokio/sync"] +sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite = [] + fmt-sql = ["sqlformat"] [dependencies] @@ -67,7 +88,7 @@ futures = "0.3" url = "2.1" hex = "0.4" -either = { version = "1.6", optional = true } +either = { version = "1.6" } base64 = { version = "0.12.3" } chrono = { version = "0.4", default-features = false, features = ["serde"] } lru-cache = { version = "0.1", optional = true } @@ -88,7 +109,11 @@ paste = "1.0" serde = { version = "1.0", features = ["derive"] } quaint-test-macros = { path = "quaint-test-macros" } quaint-test-setup = { path = "quaint-test-setup" } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "time"] } +tokio = { version = "1.0", features = ["macros", "time"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] +version = "0.2" +features = ["js"] [dependencies.byteorder] default-features = false @@ -102,7 +127,7 @@ branch = "vendored-openssl" [dependencies.rusqlite] version = "0.29" -features = ["chrono", "bundled", "column_decltype"] +features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] diff --git a/quaint/README.md b/quaint/README.md index 92033db269b1..03108d9090d3 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,9 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. + - On non-WebAssembly targets, choose `mysql-native` instead. - `postgresql`: Support for PostgreSQL databases. + - On non-WebAssembly targets, choose `postgresql-native` instead. - `sqlite`: Support for SQLite databases. + - On non-WebAssembly targets, choose `sqlite-native` instead. - `mssql`: Support for Microsoft SQL Server databases. + - On non-WebAssembly targets, choose `mssql-native` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index de8bc64d22bb..dddb3c953ad7 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -10,37 +10,49 @@ //! querying interface. 
mod connection_info; + pub mod metrics; mod queryable; mod result_set; -#[cfg(any(feature = "mssql", feature = "postgresql", feature = "mysql"))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql")] -pub(crate) mod mssql; -#[cfg(feature = "mysql")] -pub(crate) mod mysql; -#[cfg(feature = "postgresql")] -pub(crate) mod postgres; -#[cfg(feature = "sqlite")] -pub(crate) mod sqlite; - -#[cfg(feature = "mysql")] -pub use self::mysql::*; -#[cfg(feature = "postgresql")] -pub use self::postgres::*; pub use self::result_set::*; pub use connection_info::*; -#[cfg(feature = "mssql")] -pub use mssql::*; pub use queryable::*; -#[cfg(feature = "sqlite")] -pub use sqlite::*; pub use transaction::*; -#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] #[allow(unused_imports)] pub(crate) use type_identifier::*; pub use self::metrics::query; + +#[cfg(feature = "postgresql")] +pub(crate) mod postgres; +#[cfg(feature = "postgresql-native")] +pub use postgres::native::*; +#[cfg(feature = "postgresql")] +pub use postgres::*; + +#[cfg(feature = "mysql")] +pub(crate) mod mysql; +#[cfg(feature = "mysql-native")] +pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::*; + +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite; +#[cfg(feature = "sqlite-native")] +pub use sqlite::native::*; +#[cfg(feature = "sqlite")] +pub use sqlite::*; + +#[cfg(feature = "mssql")] +pub(crate) mod mssql; +#[cfg(feature = "mssql-native")] +pub use mssql::native::*; +#[cfg(feature = "mssql")] +pub use mssql::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..5a493ba17b24 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,614 +1,8 @@ -mod conversion; -mod error; +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. +pub(crate) mod url; -use super::{IsolationLevel, Transaction, TransactionOptions}; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use connection_string::JdbcString; -use futures::lock::Mutex; -use std::{ - convert::TryFrom, - fmt, - future::Future, - str::FromStr, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tiberius::*; -use tokio::net::TcpStream; -use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +pub use self::url::*; -/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tiberius; - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct MssqlUrl { - connection_string: String, - query_params: MssqlQueryParams, -} - -/// TLS mode when connecting to SQL Server. -#[derive(Debug, Clone, Copy)] -pub enum EncryptMode { - /// All traffic is encrypted. - On, - /// Only the login credentials are encrypted. - Off, - /// Nothing is encrypted. 
- DangerPlainText, -} - -impl fmt::Display for EncryptMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::On => write!(f, "true"), - Self::Off => write!(f, "false"), - Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), - } - } -} - -impl FromStr for EncryptMode { - type Err = Error; - - fn from_str(s: &str) -> crate::Result { - let mode = match s.parse::() { - Ok(true) => Self::On, - _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, - _ => Self::Off, - }; - - Ok(mode) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MssqlQueryParams { - encrypt: EncryptMode, - port: Option, - host: Option, - user: Option, - password: Option, - database: String, - schema: String, - trust_server_certificate: bool, - trust_server_certificate_ca: Option, - connection_limit: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - transaction_isolation_level: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, -} - -static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; - -#[async_trait] -impl TransactionCapable for Mssql { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> crate::Result> { - // Isolation levels in SQL Server are set on the connection and live until they're changed. - // Always explicitly setting the isolation level each time a tx is started (either to the given value - // or by using the default/connection string value) prevents transactions started on connections from - // the pool to have unexpected isolation levels set. - let isolation = isolation - .or(self.url.query_params.transaction_isolation_level) - .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) - } -} - -impl MssqlUrl { - /// Maximum number of connections the pool can have (if used together with - /// pooled Quaint). - pub fn connection_limit(&self) -> Option { - self.query_params.connection_limit() - } - - /// A duration how long one query can take. - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout() - } - - /// A duration how long we can try to connect to the database. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout() - } - - /// A pool check_out timeout. - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout() - } - - /// The isolation level of a transaction. - fn transaction_isolation_level(&self) -> Option { - self.query_params.transaction_isolation_level - } - - /// Name of the database. - pub fn dbname(&self) -> &str { - self.query_params.database() - } - - /// The prefix which to use when querying database. - pub fn schema(&self) -> &str { - self.query_params.schema() - } - - /// Database hostname. - pub fn host(&self) -> &str { - self.query_params.host() - } - - /// The username to use when connecting to the database. - pub fn username(&self) -> Option<&str> { - self.query_params.user() - } - - /// The password to use when connecting to the database. - pub fn password(&self) -> Option<&str> { - self.query_params.password() - } - - /// The TLS mode to use when connecting to the database. - pub fn encrypt(&self) -> EncryptMode { - self.query_params.encrypt() - } - - /// If true, we allow invalid certificates (self-signed, or otherwise - /// dangerous) when connecting. 
Should be true only for development and - /// testing. - pub fn trust_server_certificate(&self) -> bool { - self.query_params.trust_server_certificate() - } - - /// Path to a custom server certificate file. - pub fn trust_server_certificate_ca(&self) -> Option<&str> { - self.query_params.trust_server_certificate_ca() - } - - /// Database port. - pub fn port(&self) -> u16 { - self.query_params.port() - } - - /// The JDBC connection string - pub fn connection_string(&self) -> &str { - &self.connection_string - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime() - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime() - } -} - -impl MssqlQueryParams { - fn port(&self) -> u16 { - self.port.unwrap_or(1433) - } - - fn host(&self) -> &str { - self.host.as_deref().unwrap_or("localhost") - } - - fn user(&self) -> Option<&str> { - self.user.as_deref() - } - - fn password(&self) -> Option<&str> { - self.password.as_deref() - } - - fn encrypt(&self) -> EncryptMode { - self.encrypt - } - - fn trust_server_certificate(&self) -> bool { - self.trust_server_certificate - } - - fn trust_server_certificate_ca(&self) -> Option<&str> { - self.trust_server_certificate_ca.as_deref() - } - - fn database(&self) -> &str { - &self.database - } - - fn schema(&self) -> &str { - &self.schema - } - - fn socket_timeout(&self) -> Option { - self.socket_timeout - } - - fn connect_timeout(&self) -> Option { - self.connect_timeout - } - - fn connection_limit(&self) -> Option { - self.connection_limit - } - - fn pool_timeout(&self) -> Option { - self.pool_timeout - } - - fn max_connection_lifetime(&self) -> Option { - self.max_connection_lifetime - } - - fn max_idle_connection_lifetime(&self) -> Option { - self.max_idle_connection_lifetime - } -} - -/// A connector interface for the SQL Server database. -#[derive(Debug)] -pub struct Mssql { - client: Mutex>>, - url: MssqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, -} - -impl Mssql { - /// Creates a new connection to SQL Server. - pub async fn new(url: MssqlUrl) -> crate::Result { - let config = Config::from_jdbc_string(&url.connection_string)?; - let tcp = TcpStream::connect_named(&config).await?; - let socket_timeout = url.socket_timeout(); - - let connecting = async { - match Client::connect(config, tcp.compat_write()).await { - Ok(client) => Ok(client), - Err(tiberius::error::Error::Routing { host, port }) => { - let mut config = Config::from_jdbc_string(&url.connection_string)?; - config.host(host); - config.port(port); - - let tcp = TcpStream::connect_named(&config).await?; - Client::connect(config, tcp.compat_write()).await - } - Err(e) => Err(e), - } - }; - - let client = super::timeout::connect(url.connect_timeout(), connecting).await?; - - let this = Self { - client: Mutex::new(client), - url, - socket_timeout, - is_healthy: AtomicBool::new(true), - }; - - if let Some(isolation) = this.url.transaction_isolation_level() { - this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) - .await?; - }; - - Ok(this) - } - - /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. - /// This is a lower level API when you need to get into database specific features. 
- #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &Mutex>> { - &self.client - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } -} - -#[async_trait] -impl Queryable for Mssql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.query_raw(&sql, ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { - let mut client = self.client.lock().await; - - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; - - match results.pop() { - Some(rows) => { - let mut columns_set = false; - let mut columns = Vec::new(); - let mut result_rows = Vec::with_capacity(rows.len()); - - for row in rows.into_iter() { - if !columns_set { - columns = row.columns().iter().map(|c| c.name().to_string()).collect(); - columns_set = true; - } - - let mut values: Vec> = Vec::with_capacity(row.len()); - - for val in row.into_iter() { - values.push(Value::try_from(val)?); - } - - result_rows.push(values); - } - - Ok(ResultSet::new(columns, result_rows)) - } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), - } - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.execute_raw(&sql, ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut client = self.client.lock().await; - let changes = self.perform_io(query.execute(&mut client)).await?.total(); - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { - let mut client = self.client.lock().await; - self.perform_io(client.simple_query(cmd)).await?.into_results().await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@VERSION AS version"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -impl MssqlUrl { - pub fn new(jdbc_connection_string: &str) -> crate::Result { - let query_params = Self::parse_query_params(jdbc_connection_string)?; - let connection_string = 
Self::with_jdbc_prefix(jdbc_connection_string); - - Ok(Self { - connection_string, - query_params, - }) - } - - fn with_jdbc_prefix(input: &str) -> String { - if input.starts_with("jdbc:sqlserver") { - input.into() - } else { - format!("jdbc:{input}") - } - } - - fn parse_query_params(input: &str) -> crate::Result { - let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; - - let host = conn.server_name().map(|server_name| match conn.instance_name() { - Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), - None => server_name.to_string(), - }); - - let port = conn.port(); - let props = conn.properties_mut(); - let user = props.remove("user"); - let password = props.remove("password"); - let database = props.remove("database").unwrap_or_else(|| String::from("master")); - let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); - - let connection_limit = props - .remove("connectionlimit") - .or_else(|| props.remove("connection_limit")) - .map(|param| param.parse()) - .transpose()?; - - let transaction_isolation_level = props - .remove("isolationlevel") - .or_else(|| props.remove("isolation_level")) - .map(|level| { - IsolationLevel::from_str(&level).map_err(|_| { - let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); - Error::builder(kind).build() - }) - }) - .transpose()?; - - let mut connect_timeout = props - .remove("logintimeout") - .or_else(|| props.remove("login_timeout")) - .or_else(|| props.remove("connecttimeout")) - .or_else(|| props.remove("connect_timeout")) - .or_else(|| props.remove("connectiontimeout")) - .or_else(|| props.remove("connection_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match connect_timeout { - None => connect_timeout = Some(Duration::from_secs(5)), - Some(dur) if dur.as_secs() == 0 => connect_timeout = None, - _ => (), - } - - let mut pool_timeout = props - .remove("pooltimeout") - .or_else(|| props.remove("pool_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match pool_timeout { - None => pool_timeout = Some(Duration::from_secs(10)), - Some(dur) if dur.as_secs() == 0 => pool_timeout = None, - _ => (), - } - - let socket_timeout = props - .remove("sockettimeout") - .or_else(|| props.remove("socket_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - let encrypt = props - .remove("encrypt") - .map(|param| EncryptMode::from_str(¶m)) - .transpose()? - .unwrap_or(EncryptMode::On); - - let trust_server_certificate = props - .remove("trustservercertificate") - .or_else(|| props.remove("trust_server_certificate")) - .map(|param| param.parse()) - .transpose()? 
- .unwrap_or(false); - - let trust_server_certificate_ca: Option = props - .remove("trustservercertificateca") - .or_else(|| props.remove("trust_server_certificate_ca")); - - let mut max_connection_lifetime = props - .remove("max_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_connection_lifetime { - Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, - _ => (), - } - - let mut max_idle_connection_lifetime = props - .remove("max_idle_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_idle_connection_lifetime { - None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), - Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, - _ => (), - } - - Ok(MssqlQueryParams { - encrypt, - port, - host, - user, - password, - database, - schema, - trust_server_certificate, - trust_server_certificate_ca, - connection_limit, - socket_timeout, - connect_timeout, - pool_timeout, - transaction_isolation_level, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } -} - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mssql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/mssql/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mssql/conversion.rs rename to quaint/src/connector/mssql/native/conversion.rs diff --git a/quaint/src/connector/mssql/error.rs b/quaint/src/connector/mssql/native/error.rs similarity index 100% rename from quaint/src/connector/mssql/error.rs rename to quaint/src/connector/mssql/native/error.rs diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs new file mode 100644 index 000000000000..d22aa7a15dd6 --- /dev/null +++ b/quaint/src/connector/mssql/native/mod.rs @@ -0,0 +1,239 @@ +//! Definitions for the MSSQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mssql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::mssql::MssqlUrl; +use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::lock::Mutex; +use std::{ + convert::TryFrom, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tiberius::*; +use tokio::net::TcpStream; +use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; + +/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. 
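// The wasm/native split described in the module docs above: everything in
// `mssql/url.rs` stays compilable on wasm32, while this `native` module needs
// tokio and tiberius. The parent `mssql.rs` wires it up behind a feature flag,
// mirroring the mysql.rs version of the same change shown later in this diff:
//
//     pub(crate) mod url;    // connection-string parsing, wasm-compatible
//     pub use self::url::*;
//
//     #[cfg(feature = "mssql-native")]
//     pub(crate) mod native; // the tiberius-backed connector below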
+#[cfg(feature = "expose-drivers")] +pub use tiberius; + +static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; + +#[async_trait] +impl TransactionCapable for Mssql { + async fn start_transaction<'a>( + &'a self, + isolation: Option<IsolationLevel>, + ) -> crate::Result<Box<dyn Transaction + 'a>> { + // Isolation levels in SQL Server are set on the connection and live until they're changed. + // Always explicitly setting the isolation level each time a tx is started (either to the given value + // or to the default/connection-string value) prevents transactions started on pooled connections + // from inheriting an unexpected isolation level. + let isolation = isolation + .or(self.url.query_params.transaction_isolation_level) + .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); + + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } +} + +/// A connector interface for the SQL Server database. +#[derive(Debug)] +pub struct Mssql { + client: Mutex<Client<Compat<TcpStream>>>, + url: MssqlUrl, + socket_timeout: Option<Duration>, + is_healthy: AtomicBool, +} + +impl Mssql { + /// Creates a new connection to SQL Server. + pub async fn new(url: MssqlUrl) -> crate::Result<Self> { + let config = Config::from_jdbc_string(&url.connection_string)?; + let tcp = TcpStream::connect_named(&config).await?; + let socket_timeout = url.socket_timeout(); + + let connecting = async { + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(tiberius::error::Error::Routing { host, port }) => { + let mut config = Config::from_jdbc_string(&url.connection_string)?; + config.host(host); + config.port(port); + + let tcp = TcpStream::connect_named(&config).await?; + Client::connect(config, tcp.compat_write()).await + } + Err(e) => Err(e), + } + }; + + let client = timeout::connect(url.connect_timeout(), connecting).await?; + + let this = Self { + client: Mutex::new(client), + url, + socket_timeout, + is_healthy: AtomicBool::new(true), + }; + + if let Some(isolation) = this.url.transaction_isolation_level() { + this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) + .await?; + }; + + Ok(this) + } + + /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. + /// This is a lower level API when you need to get into database specific features.
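// Illustrative sketch (hypothetical pool API, not part of this diff) of the
// leak that the explicit reset above prevents. SQL Server keeps the isolation
// level on the connection, so without the fallback a pooled connection would
// silently reuse whatever the previous transaction configured:
//
//     let conn = pool.check_out().await?;   // hypothetical pool call
//     conn.raw_cmd("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE").await?;
//     // ...transaction runs, connection goes back to the pool...
//     let conn = pool.check_out().await?;   // may hand back the same connection
//     conn.start_transaction(None).await?;  // resets to ReadCommitted, as above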
+ #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &Mutex<Client<Compat<TcpStream>>> { + &self.client + } + + async fn perform_io<F, T>(&self, fut: F) -> crate::Result<T> + where + F: Future<Output = crate::Result<T>>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } +} + +#[async_trait] +impl Queryable for Mssql { + async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> { + let (sql, params) = visitor::Mssql::build(q)?; + self.query_raw(&sql, &params[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> { + metrics::query("mssql.query_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; + + match results.pop() { + Some(rows) => { + let mut columns_set = false; + let mut columns = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); + + for row in rows.into_iter() { + if !columns_set { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + columns_set = true; + } + + let mut values: Vec<Value<'_>> = Vec::with_capacity(row.len()); + + for val in row.into_iter() { + values.push(Value::try_from(val)?); + } + + result_rows.push(values); + } + + Ok(ResultSet::new(columns, result_rows)) + } + None => Ok(ResultSet::new(Vec::new(), Vec::new())), + } + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result<u64> { + let (sql, params) = visitor::Mssql::build(q)?; + self.execute_raw(&sql, &params[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> { + metrics::query("mssql.execute_raw", sql, params, move || async move { + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut client = self.client.lock().await; + let changes = self.perform_io(query.execute(&mut client)).await?.total(); + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + let mut client = self.client.lock().await; + self.perform_io(client.simple_query(cmd)).await?.into_results().await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result<Option<String>> { + let query = r#"SELECT @@VERSION AS version"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } + + fn requires_isolation_first(&self) -> bool { + true + } +} diff --git a/quaint/src/connector/mssql/url.rs b/quaint/src/connector/mssql/url.rs new file mode 100644 index 000000000000..42cc0868f9bf --- /dev/null +++ b/quaint/src/connector/mssql/url.rs @@ -0,0 +1,384 @@
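The `url.rs` module that follows is the wasm-compatible half of the connector: it only parses and exposes connection options. A minimal usage sketch, assuming quaint is built with the `mssql` feature; the connection string is made up, and `master`, `dbo`, port 1433 and the 5-second connect timeout are the defaults established by the parsing code below:

use quaint::connector::MssqlUrl;
use std::time::Duration;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The `jdbc:` prefix is optional; with_jdbc_prefix() adds it when missing.
    let url = MssqlUrl::new("sqlserver://localhost;user=SA;password=hunter2")?;

    assert_eq!("master", url.dbname()); // default when `database` is absent
    assert_eq!("dbo", url.schema()); // default when `schema` is absent
    assert_eq!(1433, url.port()); // default SQL Server port
    assert_eq!(Some(Duration::from_secs(5)), url.connect_timeout()); // parse-time default

    Ok(())
}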
+#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::{ + connector::IsolationLevel, + error::{Error, ErrorKind}, +}; +use connection_string::JdbcString; +use std::{fmt, str::FromStr, time::Duration}; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct MssqlUrl { + pub(crate) connection_string: String, + pub(crate) query_params: MssqlQueryParams, +} + +/// TLS mode when connecting to SQL Server. +#[derive(Debug, Clone, Copy)] +pub enum EncryptMode { + /// All traffic is encrypted. + On, + /// Only the login credentials are encrypted. + Off, + /// Nothing is encrypted. + DangerPlainText, +} + +impl fmt::Display for EncryptMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::On => write!(f, "true"), + Self::Off => write!(f, "false"), + Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), + } + } +} + +impl FromStr for EncryptMode { + type Err = Error; + + fn from_str(s: &str) -> crate::Result<Self> { + let mode = match s.parse::<bool>() { + Ok(true) => Self::On, + _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, + _ => Self::Off, + }; + + Ok(mode) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MssqlQueryParams { + pub(crate) encrypt: EncryptMode, + pub(crate) port: Option<u16>, + pub(crate) host: Option<String>, + pub(crate) user: Option<String>, + pub(crate) password: Option<String>, + pub(crate) database: String, + pub(crate) schema: String, + pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option<String>, + pub(crate) connection_limit: Option<usize>, + pub(crate) socket_timeout: Option<Duration>, + pub(crate) connect_timeout: Option<Duration>, + pub(crate) pool_timeout: Option<Duration>, + pub(crate) transaction_isolation_level: Option<IsolationLevel>, + pub(crate) max_connection_lifetime: Option<Duration>, + pub(crate) max_idle_connection_lifetime: Option<Duration>, +} + +impl MssqlUrl { + /// Maximum number of connections the pool can have (if used together with + /// pooled Quaint). + pub fn connection_limit(&self) -> Option<usize> { + self.query_params.connection_limit() + } + + /// The maximum duration a single query is allowed to run. + pub fn socket_timeout(&self) -> Option<Duration> { + self.query_params.socket_timeout() + } + + /// The maximum duration allowed for establishing a connection to the database. + pub fn connect_timeout(&self) -> Option<Duration> { + self.query_params.connect_timeout() + } + + /// The pool `check_out` timeout. + pub fn pool_timeout(&self) -> Option<Duration> { + self.query_params.pool_timeout() + } + + /// The isolation level of a transaction. + pub(crate) fn transaction_isolation_level(&self) -> Option<IsolationLevel> { + self.query_params.transaction_isolation_level + } + + /// Name of the database. + pub fn dbname(&self) -> &str { + self.query_params.database() + } + + /// The schema prefix to use when querying the database. + pub fn schema(&self) -> &str { + self.query_params.schema() + } + + /// Database hostname. + pub fn host(&self) -> &str { + self.query_params.host() + } + + /// The username to use when connecting to the database. + pub fn username(&self) -> Option<&str> { + self.query_params.user() + } + + /// The password to use when connecting to the database. + pub fn password(&self) -> Option<&str> { + self.query_params.password() + } + + /// The TLS mode to use when connecting to the database. + pub fn encrypt(&self) -> EncryptMode { + self.query_params.encrypt() + } + + /// If true, we allow invalid certificates (self-signed, or otherwise + /// dangerous) when connecting. Should be true only for development and + /// testing.
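// Illustrative behaviour of EncryptMode::from_str above (anything that is
// neither `true` nor the exact literal `DANGER_PLAINTEXT` falls back to Off):
//
//     assert!(matches!("true".parse::<EncryptMode>().unwrap(), EncryptMode::On));
//     assert!(matches!("false".parse::<EncryptMode>().unwrap(), EncryptMode::Off));
//     assert!(matches!("DANGER_PLAINTEXT".parse::<EncryptMode>().unwrap(), EncryptMode::DangerPlainText));
//     assert!(matches!("anything-else".parse::<EncryptMode>().unwrap(), EncryptMode::Off));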
+ pub fn trust_server_certificate(&self) -> bool { + self.query_params.trust_server_certificate() + } + + /// Path to a custom server certificate file. + pub fn trust_server_certificate_ca(&self) -> Option<&str> { + self.query_params.trust_server_certificate_ca() + } + + /// Database port. + pub fn port(&self) -> u16 { + self.query_params.port() + } + + /// The JDBC connection string + pub fn connection_string(&self) -> &str { + &self.connection_string + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option<Duration> { + self.query_params.max_connection_lifetime() + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option<Duration> { + self.query_params.max_idle_connection_lifetime() + } +} + +impl MssqlQueryParams { + fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + fn host(&self) -> &str { + self.host.as_deref().unwrap_or("localhost") + } + + fn user(&self) -> Option<&str> { + self.user.as_deref() + } + + fn password(&self) -> Option<&str> { + self.password.as_deref() + } + + fn encrypt(&self) -> EncryptMode { + self.encrypt + } + + fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + fn trust_server_certificate_ca(&self) -> Option<&str> { + self.trust_server_certificate_ca.as_deref() + } + + fn database(&self) -> &str { + &self.database + } + + fn schema(&self) -> &str { + &self.schema + } + + fn socket_timeout(&self) -> Option<Duration> { + self.socket_timeout + } + + fn connect_timeout(&self) -> Option<Duration> { + self.connect_timeout + } + + fn connection_limit(&self) -> Option<usize> { + self.connection_limit + } + + fn pool_timeout(&self) -> Option<Duration> { + self.pool_timeout + } + + fn max_connection_lifetime(&self) -> Option<Duration> { + self.max_connection_lifetime + } + + fn max_idle_connection_lifetime(&self) -> Option<Duration> { + self.max_idle_connection_lifetime + } +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result<Self> { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); + + Ok(Self { + connection_string, + query_params, + }) + } + + fn with_jdbc_prefix(input: &str) -> String { + if input.starts_with("jdbc:sqlserver") { + input.into() + } else { + format!("jdbc:{input}") + } + } + + fn parse_query_params(input: &str) -> crate::Result<MssqlQueryParams> { + let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; + + let host = conn.server_name().map(|server_name| match conn.instance_name() { + Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), + None => server_name.to_string(), + }); + + let port = conn.port(); + let props = conn.properties_mut(); + let user = props.remove("user"); + let password = props.remove("password"); + let database = props.remove("database").unwrap_or_else(|| String::from("master")); + let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); + + let connection_limit = props + .remove("connectionlimit") + .or_else(|| props.remove("connection_limit")) + .map(|param| param.parse()) + .transpose()?; + + let transaction_isolation_level = props + .remove("isolationlevel") + .or_else(|| props.remove("isolation_level")) + .map(|level| { + IsolationLevel::from_str(&level).map_err(|_| { + let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); + Error::builder(kind).build() + }) + }) + .transpose()?; + + let mut connect_timeout = props + .remove("logintimeout") + .or_else(|| props.remove("login_timeout"))
.or_else(|| props.remove("connecttimeout")) + .or_else(|| props.remove("connect_timeout")) + .or_else(|| props.remove("connectiontimeout")) + .or_else(|| props.remove("connection_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match connect_timeout { + None => connect_timeout = Some(Duration::from_secs(5)), + Some(dur) if dur.as_secs() == 0 => connect_timeout = None, + _ => (), + } + + let mut pool_timeout = props + .remove("pooltimeout") + .or_else(|| props.remove("pool_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match pool_timeout { + None => pool_timeout = Some(Duration::from_secs(10)), + Some(dur) if dur.as_secs() == 0 => pool_timeout = None, + _ => (), + } + + let socket_timeout = props + .remove("sockettimeout") + .or_else(|| props.remove("socket_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + let encrypt = props + .remove("encrypt") + .map(|param| EncryptMode::from_str(¶m)) + .transpose()? + .unwrap_or(EncryptMode::On); + + let trust_server_certificate = props + .remove("trustservercertificate") + .or_else(|| props.remove("trust_server_certificate")) + .map(|param| param.parse()) + .transpose()? + .unwrap_or(false); + + let trust_server_certificate_ca: Option = props + .remove("trustservercertificateca") + .or_else(|| props.remove("trust_server_certificate_ca")); + + let mut max_connection_lifetime = props + .remove("max_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_connection_lifetime { + Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, + _ => (), + } + + let mut max_idle_connection_lifetime = props + .remove("max_idle_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_idle_connection_lifetime { + None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), + Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, + _ => (), + } + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + schema, + trust_server_certificate, + trust_server_certificate_ca, + connection_limit, + socket_timeout, + connect_timeout, + pool_timeout, + transaction_isolation_level, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 4b6f27a583da..77bb6e0d1b8a 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,669 +1,10 @@ -mod conversion; -mod error; - -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use lru_cache::LruCache; -use mysql_async::{ - self as my, - prelude::{Query as _, Queryable as _}, -}; -use percent_encoding::percent_decode; -use std::{ - borrow::Cow, - future::Future, - path::{Path, PathBuf}, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use 
tokio::sync::Mutex; -use url::{Host, Url}; +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. +pub(crate) mod error; +pub(crate) mod url; +pub use self::url::*; pub use error::MysqlError; -/// The underlying MySQL driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use mysql_async; - -use super::IsolationLevel; - -/// A connector interface for the MySQL database. -#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - query_params: MysqlUrlQueryParams, -} - -impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. 
- pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - - pub(crate) fn cache(&self) -> LruCache { - LruCache::new(self.query_params.statement_cache_size) - } - - fn parse_query_params(url: &Url) -> Result { - let mut ssl_opts = my::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported 
SSL accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - - Ok(MysqlUrlQueryParams { - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { - let mut config = my::OptsBuilder::default() - .stmt_cache_size(Some(0)) - .user(Some(self.username())) - .pass(self.password()) - .db_name(Some(self.dbname())); - - match self.socket() { - Some(ref socket) => { - config = config.socket(Some(socket)); - } - None => { - config = config.ip_or_hostname(self.host()).tcp_port(self.port()); - } - } - - config = config.conn_ttl(Some(Duration::from_secs(5))); - - if self.query_params.use_ssl { - config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); - } - - if self.query_params.prefer_socket.is_some() { - config = config.prefer_socket(self.query_params.prefer_socket); - } - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - -impl Mysql { - /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. - pub async fn new(url: MysqlUrl) -> crate::Result { - let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; - - Ok(Self { - socket_timeout: url.query_params.socket_timeout, - conn: Mutex::new(conn), - statement_cache: Mutex::new(url.cache()), - url, - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying mysql_async::Conn. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. 
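// Illustrative behaviour (not part of this diff) of the statement cache used by
// prepared() below, keyed by SQL text with LRU eviction:
//
//     statement_cache_size=0    -> prepare, run, and close every statement per call
//     statement_cache_size=100  -> default; repeated SQL strings hit the cache
//     cache full                -> remove_lru() evicts, conn.close(stmt) frees the
//                                  statement on the server before caching the new one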
- #[cfg(feature = "expose-drivers")] - pub fn conn(&self) -> &Mutex { - &self.conn - } - - async fn perform_io(&self, op: U) -> crate::Result - where - F: Future>, - U: FnOnce() -> F, - { - match super::timeout::socket(self.socket_timeout, op()).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => Ok(res?), - } - } - - async fn prepared(&self, sql: &str, op: U) -> crate::Result - where - F: Future>, - U: Fn(my::Statement) -> F, - { - if self.url.statement_cache_size() == 0 { - self.perform_io(|| async move { - let stmt = { - let mut conn = self.conn.lock().await; - conn.prep(sql).await? - }; - - let res = op(stmt.clone()).await; - - { - let mut conn = self.conn.lock().await; - conn.close(stmt).await?; - } - - res - }) - .await - } else { - self.perform_io(|| async move { - let stmt = self.fetch_cached(sql).await?; - op(stmt).await - }) - .await - } - } - - async fn fetch_cached(&self, sql: &str) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let mut conn = self.conn.lock().await; - if cache.capacity() == cache.len() { - if let Some((_, stmt)) = cache.remove_lru() { - conn.close(stmt).await?; - } - } - - let stmt = conn.prep(sql).await?; - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } -} - -impl_default_TransactionCapable!(Mysql); - -#[async_trait] -impl Queryable for Mysql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); - - let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); - - for mut row in rows { - result_set.rows.push(row.take_result_row()?); - } - - if let Some(id) = last_id { - result_set.set_last_insert_id(id); - }; - - Ok(result_set) - }) - .await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - conn.exec_drop(stmt, conversion::conv_params(params)?).await?; - - Ok(conn.affected_rows()) - }) - .await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { - self.perform_io(|| async 
move { - let mut conn = self.conn.lock().await; - let mut result = cmd.run(&mut *conn).await?; - - loop { - result.map(drop).await?; - - if result.is_empty() { - result.map(drop).await?; - break; - } - } - - Ok(()) - }) - .await - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@GLOBAL.version version"#; - let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - 
assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mysql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/error.rs index dd7c3d3bfa66..7b4813bf0223 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/error.rs @@ -1,22 +1,23 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use mysql_async as my; +use thiserror::Error; + +// This is a partial copy of the `mysql_async::Error` using only the enum variant used by Prisma. +// This avoids pulling in `mysql_async`, which would break Wasm compilation. +#[derive(Debug, Error)] +enum MysqlAsyncError { + #[error("Server error: `{}'", _0)] + Server(#[source] MysqlError), +} +/// This type represents MySql server error. +#[derive(Debug, Error, Clone, Eq, PartialEq)] +#[error("ERROR {} ({}): {}", state, code, message)] pub struct MysqlError { pub code: u16, pub message: String, pub state: String, } -impl From<&my::ServerError> for MysqlError { - fn from(value: &my::ServerError) -> Self { - MysqlError { - code: value.code, - message: value.message.to_owned(), - state: value.state.to_owned(), - } - } -} - impl From for Error { fn from(error: MysqlError) -> Self { let code = error.code; @@ -232,7 +233,7 @@ impl From for Error { } _ => { let kind = ErrorKind::QueryError( - my::Error::Server(my::ServerError { + MysqlAsyncError::Server(MysqlError { message: error.message.clone(), code, state: error.state.clone(), @@ -249,24 +250,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: my::Error) -> Error { - match e { - my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { - message: err.to_string(), - }) - .build(), - my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { - Error::builder(ErrorKind::ConnectionClosed).build() - } - my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), - my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), - my::Error::Server(ref server_error) => { - let mysql_error: MysqlError = server_error.into(); - mysql_error.into() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} diff --git a/quaint/src/connector/mysql/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mysql/conversion.rs rename to quaint/src/connector/mysql/native/conversion.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs new file mode 100644 index 000000000000..89c21fb706f6 --- /dev/null +++ b/quaint/src/connector/mysql/native/error.rs @@ -0,0 +1,36 @@ +use crate::{ + connector::mysql::error::MysqlError, + error::{Error, ErrorKind}, +}; +use mysql_async as my; + +impl From<&my::ServerError> for MysqlError { + fn from(value: &my::ServerError) -> Self { + MysqlError { + code: value.code, + message: value.message.to_owned(), + state: value.state.to_owned(), + } + } +} + +impl From for Error { + fn from(e: my::Error) -> Error { + match e { + my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { + message: err.to_string(), + }) + .build(), + my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + Error::builder(ErrorKind::ConnectionClosed).build() + } + my::Error::Io(io_error) => 
Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), + my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), + my::Error::Server(ref server_error) => { + let mysql_error: MysqlError = server_error.into(); + mysql_error.into() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs new file mode 100644 index 000000000000..fdcc3a6276d1 --- /dev/null +++ b/quaint/src/connector/mysql/native/mod.rs @@ -0,0 +1,297 @@ +//! Definitions for the MySQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mysql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::mysql::MysqlUrl; +use crate::connector::{timeout, IsolationLevel}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use lru_cache::LruCache; +use mysql_async::{ + self as my, + prelude::{Query as _, Queryable as _}, +}; +use std::{ + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio::sync::Mutex; + +/// The underlying MySQL driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use mysql_async; + +impl MysqlUrl { + pub(crate) fn cache(&self) -> LruCache<String, my::Statement> { + LruCache::new(self.query_params.statement_cache_size) + } + + pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { + let mut config = my::OptsBuilder::default() + .stmt_cache_size(Some(0)) + .user(Some(self.username())) + .pass(self.password()) + .db_name(Some(self.dbname())); + + match self.socket() { + Some(ref socket) => { + config = config.socket(Some(socket)); + } + None => { + config = config.ip_or_hostname(self.host()).tcp_port(self.port()); + } + } + + config = config.conn_ttl(Some(Duration::from_secs(5))); + + if self.query_params.use_ssl { + config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); + } + + if self.query_params.prefer_socket.is_some() { + config = config.prefer_socket(self.query_params.prefer_socket); + } + + config + } +} + +/// A connector interface for the MySQL database. +#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex<my::Conn>, + pub(crate) url: MysqlUrl, + socket_timeout: Option<Duration>, + is_healthy: AtomicBool, + statement_cache: Mutex<LruCache<String, my::Statement>>, +} + +impl Mysql { + /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. + pub async fn new(url: MysqlUrl) -> crate::Result<Self> { + let conn = timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; + + Ok(Self { + socket_timeout: url.query_params.socket_timeout, + conn: Mutex::new(conn), + statement_cache: Mutex::new(url.cache()), + url, + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying mysql_async::Conn. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features.
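// Illustrative sketch (made-up URLs) of how to_opts_builder() above routes a
// connection either over a Unix socket or over TCP, based on the parsed params:
//
//     // ?socket=(/tmp/mysql.sock) -> parens stripped -> config.socket(...)
//     let url = MysqlUrl::new(Url::parse("mysql://root@localhost/db?socket=(/tmp/mysql.sock)")?)?;
//
//     // no socket param -> config.ip_or_hostname("localhost").tcp_port(3307)
//     let url = MysqlUrl::new(Url::parse("mysql://root@localhost:3307/db")?)?;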
+ #[cfg(feature = "expose-drivers")] + pub fn conn(&self) -> &Mutex<my::Conn> { + &self.conn + } + + async fn perform_io<F, T, U>(&self, op: U) -> crate::Result<T> + where + F: Future<Output = crate::Result<T>>, + U: FnOnce() -> F, + { + match timeout::socket(self.socket_timeout, op()).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => Ok(res?), + } + } + + async fn prepared<F, T, U>(&self, sql: &str, op: U) -> crate::Result<T> + where + F: Future<Output = crate::Result<T>>, + U: Fn(my::Statement) -> F, + { + if self.url.statement_cache_size() == 0 { + self.perform_io(|| async move { + let stmt = { + let mut conn = self.conn.lock().await; + conn.prep(sql).await? + }; + + let res = op(stmt.clone()).await; + + { + let mut conn = self.conn.lock().await; + conn.close(stmt).await?; + } + + res + }) + .await + } else { + self.perform_io(|| async move { + let stmt = self.fetch_cached(sql).await?; + op(stmt).await + }) + .await + } + } + + async fn fetch_cached(&self, sql: &str) -> crate::Result<my::Statement> { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let mut conn = self.conn.lock().await; + if cache.capacity() == cache.len() { + if let Some((_, stmt)) = cache.remove_lru() { + conn.close(stmt).await?; + } + } + + let stmt = conn.prep(sql).await?; + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } +} + +impl_default_TransactionCapable!(Mysql); + +#[async_trait] +impl Queryable for Mysql { + async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> { + let (sql, params) = visitor::Mysql::build(q)?; + self.query_raw(&sql, &params).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> { + metrics::query("mysql.query_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + let rows: Vec<my::Row> = conn.exec(&stmt, conversion::conv_params(params)?).await?; + let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); + + let last_id = conn.last_insert_id(); + let mut result_set = ResultSet::new(columns, Vec::new()); + + for mut row in rows { + result_set.rows.push(row.take_result_row()?); + } + + if let Some(id) = last_id { + result_set.set_last_insert_id(id); + }; + + Ok(result_set) + }) + .await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result<u64> { + let (sql, params) = visitor::Mysql::build(q)?; + self.execute_raw(&sql, &params).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> { + metrics::query("mysql.execute_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + conn.exec_drop(stmt, conversion::conv_params(params)?).await?; + + Ok(conn.affected_rows()) + }) + .await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + self.perform_io(|| async move { 
let mut conn = self.conn.lock().await; + let mut result = cmd.run(&mut *conn).await?; + + loop { + result.map(drop).await?; + + if result.is_empty() { + result.map(drop).await?; + break; + } + } + + Ok(()) + }) + .await + }) + .await + } + + async fn version(&self) -> crate::Result<Option<String>> { + let query = r#"SELECT @@GLOBAL.version version"#; + let rows = timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + true + } +} diff --git a/quaint/src/connector/mysql/url.rs b/quaint/src/connector/mysql/url.rs new file mode 100644 index 000000000000..f0756fa95833 --- /dev/null +++ b/quaint/src/connector/mysql/url.rs @@ -0,0 +1,401 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(crate) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result<Self, Error> { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow<str> { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option<Cow<str>> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.url.host(), self.url.host_str()) { + (Some(Host::Ipv6(_)), Some(host)) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (_, Some(host)) => host, + _ => "localhost", + } + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option<String> { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. 
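// Illustrative behaviour (made-up URL) of the accessors defined above:
//
//     let url = MysqlUrl::new(Url::parse("mysql://ro%40ot@[2001:db8::1]")?)?;
//     assert_eq!("ro@ot", url.username());   // percent-decoded
//     assert_eq!("2001:db8::1", url.host()); // IPv6 brackets stripped
//     assert_eq!("mysql", url.dbname());     // default database name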
+ pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option<Duration> { + self.query_params.connect_timeout + } + + /// The pool check_out timeout + pub fn pool_timeout(&self) -> Option<Duration> { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option<Duration> { + self.query_params.socket_timeout + } + + /// Prefer socket connection + pub fn prefer_socket(&self) -> Option<bool> { + self.query_params.prefer_socket + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option<Duration> { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option<Duration> { + self.query_params.max_idle_connection_lifetime + } + + pub(crate) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + fn parse_query_params(url: &Url) -> Result<MysqlUrlQueryParams, Error> { + #[cfg(feature = "mysql-native")] + let mut ssl_opts = { + let mut ssl_opts = mysql_async::SslOpts::default(); + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); + ssl_opts + }; + + let mut connection_limit = None; + let mut use_ssl = false; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut prefer_socket = None; + let mut statement_cache_size = 100; + let mut identity: Option<(Option<PathBuf>, Option<String>)> = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslcert" => { + use_ssl = true; + + #[cfg(feature = "mysql-native")] + { + ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); + } + } + "sslidentity" => { + use_ssl = true; + + identity = match identity { + Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), + None => Some((Some(Path::new(&*v).to_path_buf()), None)), + }; + } + "sslpassword" => { + use_ssl = true; + + identity = match identity { + Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::<bool>() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::<u64>() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::<u64>() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = 
"mysql-native")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + } + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-native")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-native")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-native")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + 
assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 766be38b27e4..befc980ce29e 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,1593 +1,10 @@ -mod conversion; -mod error; - -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; -use native_tls::{Certificate, Identity, TlsConnector}; -use percent_encoding::percent_decode; -use postgres_native_tls::MakeTlsConnector; -use std::{ - borrow::{Borrow, Cow}, - fmt::{Debug, Display}, - fs, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio_postgres::{ - config::{ChannelBinding, SslMode}, - Client, Config, Statement, -}; -use url::{Host, Url}; +//! Wasm-compatible definitions for the PostgreSQL connector. +//! This module is only available with the `postgresql` feature. +pub(crate) mod error; +pub(crate) mod url; +pub use self::url::*; pub use error::PostgresError; -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - -/// The underlying postgres driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tokio_postgres; - -use super::{IsolationLevel, Transaction}; - -#[derive(Clone)] -struct Hidden(T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - -struct PostgresClient(Client); - -impl Debug for PostgresClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("PostgresClient") - } -} - -/// A connector interface for the PostgreSQL database. 
-#[derive(Debug)] -pub struct PostgreSql { - client: PostgresClient, - pg_bouncer: bool, - socket_timeout: Option, - statement_cache: Mutex>, - is_healthy: AtomicBool, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - certificate_file: Option, - identity_file: Option, - identity_password: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -#[derive(Debug)] -struct SslAuth { - certificate: Hidden>, - identity: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -impl Default for SslAuth { - fn default() -> Self { - Self { - certificate: Hidden(None), - identity: Hidden(None), - ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, - } - } -} - -impl SslAuth { - fn certificate(&mut self, certificate: Certificate) -> &mut Self { - self.certificate = Hidden(Some(certificate)); - self - } - - fn identity(&mut self, identity: Identity) -> &mut Self { - self.identity = Hidden(Some(identity)); - self - } - - fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { - self.ssl_accept_mode = mode; - self - } -} - -impl SslParams { - async fn into_auth(self) -> crate::Result { - let mut auth = SslAuth::default(); - auth.accept_mode(self.ssl_accept_mode); - - if let Some(ref cert_file) = self.certificate_file { - let cert = fs::read(cert_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("cert file not found ({err})"), - }) - .build() - })?; - - auth.certificate(Certificate::from_pem(&cert)?); - } - - if let Some(ref identity_file) = self.identity_file { - let db = fs::read(identity_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("identity file not found ({err})"), - }) - .build() - })?; - let password = self.identity_password.0.as_deref().unwrap_or(""); - let identity = Identity::from_pkcs12(&db, password)?; - - auth.identity(identity); - } - - Ok(auth) - } -} - -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - url: Url, - query_params: PostgresUrlQueryParams, - flavour: PostgresFlavour, -} - -impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. 
- pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - - pub(crate) fn cache(&self) -> LruCache { - if self.query_params.pg_bouncer { - LruCache::new(0) - } else { - LruCache::new(self.query_params.statement_cache_size) - } - } - - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. 
- /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut ssl_mode = SslMode::Prefer; - let mut host = None; - let mut application_name = None; - let mut channel_binding = ChannelBinding::Prefer; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if 
as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - ssl_mode, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - channel_binding, - options, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - fn set_search_path(&self, config: &mut Config) { - // PGBouncer does not support the search_path connection parameter. 
- // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if self.query_params.pg_bouncer { - return; - } - - if let Some(schema) = &self.query_params.schema { - if self.flavour().is_cockroach() && is_safe_identifier(schema) { - config.search_path(CockroachSearchPath(schema).to_string()); - } - - if self.flavour().is_postgres() { - config.search_path(PostgresSearchPath(schema).to_string()); - } - } - } - - pub(crate) fn to_config(&self) -> Config { - let mut config = Config::new(); - - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); - config.host(self.host()); - config.port(self.port()); - config.dbname(self.dbname()); - config.pgbouncer_mode(self.query_params.pg_bouncer); - - if let Some(options) = self.options() { - config.options(options); - } - - if let Some(application_name) = self.application_name() { - config.application_name(application_name); - } - - if let Some(connect_timeout) = self.query_params.connect_timeout { - config.connect_timeout(connect_timeout); - } - - self.set_search_path(&mut config); - - config.ssl_mode(self.query_params.ssl_mode); - - config.channel_binding(self.query_params.channel_binding); - - config - } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - ssl_params: SslParams, - connection_limit: Option, - schema: Option, - ssl_mode: SslMode, - pg_bouncer: bool, - host: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - statement_cache_size: usize, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - application_name: Option, - channel_binding: ChannelBinding, - options: Option, -} - -impl PostgreSql { - /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { - let config = url.to_config(); - - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); - let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?; - - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); - - // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. - if let Some(schema) = &url.query_params.schema { - // PGBouncer does not support the search_path connection parameter. 
- // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if url.query_params.pg_bouncer - || url.flavour().is_unknown() - || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) - { - let session_variables = format!( - r##"{set_search_path}"##, - set_search_path = SetSearchPath(url.query_params.schema.as_deref()) - ); - - client.simple_query(session_variables.as_str()).await?; - } - } - - Ok(Self { - client: PostgresClient(client), - socket_timeout: url.query_params.socket_timeout, - pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying tokio_postgres::Client. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &tokio_postgres::Client { - &self.client.0 - } - - async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let param_types = conversion::params_to_types(params); - let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; - - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } - - fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { - if params.len() > i16::MAX as usize { - // tokio_postgres would return an error here. Let's avoid calling the driver - // and return an error early. - let kind = ErrorKind::QueryInvalidInput(format!( - "too many bind variables in prepared statement, expected maximum of {}, received {}", - i16::MAX, - params.len() - )); - Err(Error::builder(kind).build()) - } else { - Ok(()) - } - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. 
-struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -impl_default_TransactionCapable!(PostgreSql); - -#[async_trait] -impl Queryable for PostgreSql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.query_raw(sql.as_str(), ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.execute_raw(sql.as_str(), ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - 
metrics::query("postgres.raw_cmd", cmd, &[], move || async move { - self.perform_io(self.client.0.simple_query(cmd)).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT version()"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { - if self.pg_bouncer { - tx.raw_cmd("DEALLOCATE ALL").await - } else { - Ok(()) - } - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. 
- if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::test_api::postgres::CONN_STR; - use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - 
assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - - #[tokio::test] - async fn test_custom_search_path_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_pg_pgbouncer() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - url.query_pairs_mut().append_pair("pbbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - 
assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - 
assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - 
assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. 
} => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - - #[test] - fn test_safe_ident() { - // Safe - assert!(is_safe_identifier("hello")); - assert!(is_safe_identifier("_hello")); - assert!(is_safe_identifier("àbracadabra")); - assert!(is_safe_identifier("h3ll0")); - assert!(is_safe_identifier("héllo")); - assert!(is_safe_identifier("héll0$")); - assert!(is_safe_identifier("héll_0$")); - assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); - - // Not safe - assert!(!is_safe_identifier("")); - assert!(!is_safe_identifier("Hello")); - assert!(!is_safe_identifier("hEllo")); - assert!(!is_safe_identifier("$hello")); - assert!(!is_safe_identifier("hello!")); - assert!(!is_safe_identifier("hello#")); - assert!(!is_safe_identifier("he llo")); - assert!(!is_safe_identifier(" hello")); - assert!(!is_safe_identifier("he-llo")); - assert!(!is_safe_identifier("hÉllo")); - assert!(!is_safe_identifier("1337")); - assert!(!is_safe_identifier("_HELLO")); - assert!(!is_safe_identifier("HELLO")); - assert!(!is_safe_identifier("HELLO$")); - assert!(!is_safe_identifier("ÀBRACADABRA")); - - for ident in RESERVED_KEYWORDS { - assert!(!is_safe_identifier(ident)); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert!(!is_safe_identifier(ident)); - } - } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. 
- assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. - assert_eq!(config.get_search_path(), None); - } -} +#[cfg(feature = "postgresql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/error.rs index d4e5ec7837fe..ab6ec7b07847 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/error.rs @@ -1,7 +1,5 @@ use std::fmt::{Display, Formatter}; -use tokio_postgres::error::DbError; - use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; #[derive(Debug)] @@ -17,7 +15,7 @@ pub struct PostgresError { impl std::error::Error for PostgresError {} impl Display for PostgresError { - // copy of DbError::fmt + // copy of tokio_postgres::error::DbError::fmt fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result { write!(fmt, "{}: {}", self.severity, self.message)?; if let Some(detail) = &self.detail { @@ -30,19 +28,6 @@ impl Display for PostgresError { } } -impl From<&DbError> for PostgresError { - fn from(value: &DbError) -> Self { - PostgresError { - code: value.code().code().to_string(), - severity: value.severity().to_string(), - message: value.message().to_string(), - detail: value.detail().map(ToString::to_string), - column: value.column().map(ToString::to_string), - hint: value.hint().map(ToString::to_string), - } - } -} - impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -245,110 +230,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: tokio_postgres::error::Error) -> Error { - if e.is_closed() { - return Error::builder(ErrorKind::ConnectionClosed).build(); - } - - if let Some(db_error) = e.as_db_error() { - return PostgresError::from(db_error).into(); - } - - if let Some(tls_error) = try_extracting_tls_error(&e) { - return tls_error; - } - - // Same for IO errors. - if let Some(io_error) = try_extracting_io_error(&e) { - return io_error; - } - - if let Some(uuid_error) = try_extracting_uuid_error(&e) { - return uuid_error; - } - - let reason = format!("{e}"); - let code = e.code().map(|c| c.code()); - - match reason.as_str() { - "error connecting to server: timed out" => { - let mut builder = Error::builder(ErrorKind::ConnectTimeout); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // sigh... 
-            // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37
-            "error performing TLS handshake: server does not support TLS" => {
-                let mut builder = Error::builder(ErrorKind::TlsError {
-                    message: reason.clone(),
-                });
-
-                if let Some(code) = code {
-                    builder.set_original_code(code);
-                };
-
-                builder.set_original_message(reason);
-                builder.build()
-            } // double sigh
-            _ => {
-                let code = code.map(|c| c.to_string());
-                let mut builder = Error::builder(ErrorKind::QueryError(e.into()));
-
-                if let Some(code) = code {
-                    builder.set_original_code(code);
-                };
-
-                builder.set_original_message(reason);
-                builder.build()
-            }
-        }
-    }
-}
-
-fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option<Error> {
-    use std::error::Error as _;
-
-    err.source()
-        .and_then(|err| err.downcast_ref::<uuid::Error>())
-        .map(|err| ErrorKind::UUIDError(format!("{err}")))
-        .map(|kind| Error::builder(kind).build())
-}
-
-fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option<Error> {
-    use std::error::Error;
-
-    err.source()
-        .and_then(|err| err.downcast_ref::<native_tls::Error>())
-        .map(|err| err.into())
-}
-
-fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option<Error> {
-    use std::error::Error as _;
-
-    err.source()
-        .and_then(|err| err.downcast_ref::<std::io::Error>())
-        .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}")))))
-        .map(|kind| Error::builder(kind).build())
-}
-
-impl From<native_tls::Error> for Error {
-    fn from(e: native_tls::Error) -> Error {
-        Error::from(&e)
-    }
-}
-
-impl From<&native_tls::Error> for Error {
-    fn from(e: &native_tls::Error) -> Error {
-        let kind = ErrorKind::TlsError {
-            message: format!("{e}"),
-        };
-
-        Error::builder(kind).build()
-    }
-}
diff --git a/quaint/src/connector/postgres/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs
similarity index 100%
rename from quaint/src/connector/postgres/conversion.rs
rename to quaint/src/connector/postgres/native/conversion.rs
diff --git a/quaint/src/connector/postgres/conversion/decimal.rs b/quaint/src/connector/postgres/native/conversion/decimal.rs
similarity index 100%
rename from quaint/src/connector/postgres/conversion/decimal.rs
rename to quaint/src/connector/postgres/native/conversion/decimal.rs
diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs
new file mode 100644
index 000000000000..c353e397705c
--- /dev/null
+++ b/quaint/src/connector/postgres/native/error.rs
@@ -0,0 +1,126 @@
+use tokio_postgres::error::DbError;
+
+use crate::{
+    connector::postgres::error::PostgresError,
+    error::{Error, ErrorKind},
+};
+
+impl From<&DbError> for PostgresError {
+    fn from(value: &DbError) -> Self {
+        PostgresError {
+            code: value.code().code().to_string(),
+            severity: value.severity().to_string(),
+            message: value.message().to_string(),
+            detail: value.detail().map(ToString::to_string),
+            column: value.column().map(ToString::to_string),
+            hint: value.hint().map(ToString::to_string),
+        }
+    }
+}
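The effect of this split is that the driver-agnostic `PostgresError` (defined in `postgres/error.rs`) stays wasm-compatible, while this file owns the `tokio_postgres`-specific conversions. Roughly, the `From` impls in this file compose as in this sketch (the helper function is hypothetical):

```rust
// Hypothetical helper showing how the conversions chain together:
// tokio_postgres::error::Error -> &DbError -> PostgresError -> crate::error::Error
fn to_quaint_error(e: tokio_postgres::error::Error) -> crate::error::Error {
    e.into() // routed through `PostgresError` whenever a `DbError` is attached
}
```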
+
+impl From<tokio_postgres::error::Error> for Error {
+    fn from(e: tokio_postgres::error::Error) -> Error {
+        if e.is_closed() {
+            return Error::builder(ErrorKind::ConnectionClosed).build();
+        }
+
+        if let Some(db_error) = e.as_db_error() {
+            return PostgresError::from(db_error).into();
+        }
+
+        if let Some(tls_error) = try_extracting_tls_error(&e) {
+            return tls_error;
+        }
+
+        // Same for IO errors.
+        if let Some(io_error) = try_extracting_io_error(&e) {
+            return io_error;
+        }
+
+        if let Some(uuid_error) = try_extracting_uuid_error(&e) {
+            return uuid_error;
+        }
+
+        let reason = format!("{e}");
+        let code = e.code().map(|c| c.code());
+
+        match reason.as_str() {
+            "error connecting to server: timed out" => {
+                let mut builder = Error::builder(ErrorKind::ConnectTimeout);
+
+                if let Some(code) = code {
+                    builder.set_original_code(code);
+                };
+
+                builder.set_original_message(reason);
+                builder.build()
+            } // sigh...
+            // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37
+            "error performing TLS handshake: server does not support TLS" => {
+                let mut builder = Error::builder(ErrorKind::TlsError {
+                    message: reason.clone(),
+                });
+
+                if let Some(code) = code {
+                    builder.set_original_code(code);
+                };
+
+                builder.set_original_message(reason);
+                builder.build()
+            } // double sigh
+            _ => {
+                let code = code.map(|c| c.to_string());
+                let mut builder = Error::builder(ErrorKind::QueryError(e.into()));
+
+                if let Some(code) = code {
+                    builder.set_original_code(code);
+                };
+
+                builder.set_original_message(reason);
+                builder.build()
+            }
+        }
+    }
+}
+
+fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option<Error> {
+    use std::error::Error as _;
+
+    err.source()
+        .and_then(|err| err.downcast_ref::<uuid::Error>())
+        .map(|err| ErrorKind::UUIDError(format!("{err}")))
+        .map(|kind| Error::builder(kind).build())
+}
+
+fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option<Error> {
+    use std::error::Error;
+
+    err.source()
+        .and_then(|err| err.downcast_ref::<native_tls::Error>())
+        .map(|err| err.into())
+}
+
+fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option<Error> {
+    use std::error::Error as _;
+
+    err.source()
+        .and_then(|err| err.downcast_ref::<std::io::Error>())
+        .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}")))))
+        .map(|kind| Error::builder(kind).build())
+}
+
+impl From<native_tls::Error> for Error {
+    fn from(e: native_tls::Error) -> Error {
+        Error::from(&e)
+    }
+}
+
+impl From<&native_tls::Error> for Error {
+    fn from(e: &native_tls::Error) -> Error {
+        let kind = ErrorKind::TlsError {
+            message: format!("{e}"),
+        };
+
+        Error::builder(kind).build()
+    }
+}
diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs
new file mode 100644
index 000000000000..30f34e7002be
--- /dev/null
+++ b/quaint/src/connector/postgres/native/mod.rs
@@ -0,0 +1,972 @@
+//! Definitions for the Postgres connector.
+//! This module is not compatible with wasm32-* targets.
+//! This module is only available with the `postgresql-native` feature.
+mod conversion;
+mod error;
+
+pub(crate) use crate::connector::postgres::url::PostgresUrl;
+use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams};
+use crate::connector::{timeout, IsolationLevel, Transaction};
+
+use crate::{
+    ast::{Query, Value},
+    connector::{metrics, queryable::*, ResultSet},
+    error::{Error, ErrorKind},
+    visitor::{self, Visitor},
+};
+use async_trait::async_trait;
+use futures::{future::FutureExt, lock::Mutex};
+use lru_cache::LruCache;
+use native_tls::{Certificate, Identity, TlsConnector};
+use postgres_native_tls::MakeTlsConnector;
+use std::{
+    borrow::Borrow,
+    fmt::{Debug, Display},
+    fs,
+    future::Future,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
+use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
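For downstream users, the practical effect of this layout is that URL parsing compiles everywhere while the connector itself is feature-gated. A hedged consumer-side sketch; the exact re-export paths are assumptions, not confirmed by this patch:

```rust
// Only compiled when the `postgresql-native` feature is enabled; with just
// `postgresql`, the `PostgresUrl` parsing half remains usable (e.g. on wasm32).
#[cfg(feature = "postgresql-native")]
async fn connect(raw: &str) -> quaint::Result<quaint::connector::PostgreSql> {
    let url = url::Url::parse(raw).expect("invalid connection string");
    let pg_url = quaint::connector::PostgresUrl::new(url)?;
    quaint::connector::PostgreSql::new(pg_url).await
}
```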
+#[cfg(feature = "expose-drivers")]
+pub use tokio_postgres;
+
+struct PostgresClient(Client);
+
+impl Debug for PostgresClient {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("PostgresClient")
+    }
+}
+
+/// A connector interface for the PostgreSQL database.
+#[derive(Debug)]
+pub struct PostgreSql {
+    client: PostgresClient,
+    pg_bouncer: bool,
+    socket_timeout: Option<Duration>,
+    statement_cache: Mutex<LruCache<String, Statement>>,
+    is_healthy: AtomicBool,
+}
+
+#[derive(Debug)]
+struct SslAuth {
+    certificate: Hidden<Option<Certificate>>,
+    identity: Hidden<Option<Identity>>,
+    ssl_accept_mode: SslAcceptMode,
+}
+
+impl Default for SslAuth {
+    fn default() -> Self {
+        Self {
+            certificate: Hidden(None),
+            identity: Hidden(None),
+            ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts,
+        }
+    }
+}
+
+impl SslAuth {
+    fn certificate(&mut self, certificate: Certificate) -> &mut Self {
+        self.certificate = Hidden(Some(certificate));
+        self
+    }
+
+    fn identity(&mut self, identity: Identity) -> &mut Self {
+        self.identity = Hidden(Some(identity));
+        self
+    }
+
+    fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self {
+        self.ssl_accept_mode = mode;
+        self
+    }
+}
+
+impl SslParams {
+    async fn into_auth(self) -> crate::Result<SslAuth> {
+        let mut auth = SslAuth::default();
+        auth.accept_mode(self.ssl_accept_mode);
+
+        if let Some(ref cert_file) = self.certificate_file {
+            let cert = fs::read(cert_file).map_err(|err| {
+                Error::builder(ErrorKind::TlsError {
+                    message: format!("cert file not found ({err})"),
+                })
+                .build()
+            })?;
+
+            auth.certificate(Certificate::from_pem(&cert)?);
+        }
+
+        if let Some(ref identity_file) = self.identity_file {
+            let db = fs::read(identity_file).map_err(|err| {
+                Error::builder(ErrorKind::TlsError {
+                    message: format!("identity file not found ({err})"),
+                })
+                .build()
+            })?;
+            let password = self.identity_password.0.as_deref().unwrap_or("");
+            let identity = Identity::from_pkcs12(&db, password)?;
+
+            auth.identity(identity);
+        }
+
+        Ok(auth)
+    }
+}
+
+impl PostgresUrl {
+    pub(crate) fn cache(&self) -> LruCache<String, Statement> {
+        if self.query_params.pg_bouncer {
+            LruCache::new(0)
+        } else {
+            LruCache::new(self.query_params.statement_cache_size)
+        }
+    }
+
+    pub fn channel_binding(&self) -> ChannelBinding {
+        self.query_params.channel_binding
+    }
+
+    /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection.
+    /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed.
+    /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter.
+    fn set_search_path(&self, config: &mut Config) {
+        // PGBouncer does not support the search_path connection parameter.
+ // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if self.query_params.pg_bouncer { + return; + } + + if let Some(schema) = &self.query_params.schema { + if self.flavour().is_cockroach() && is_safe_identifier(schema) { + config.search_path(CockroachSearchPath(schema).to_string()); + } + + if self.flavour().is_postgres() { + config.search_path(PostgresSearchPath(schema).to_string()); + } + } + } + + pub(crate) fn to_config(&self) -> Config { + let mut config = Config::new(); + + config.user(self.username().borrow()); + config.password(self.password().borrow() as &str); + config.host(self.host()); + config.port(self.port()); + config.dbname(self.dbname()); + config.pgbouncer_mode(self.query_params.pg_bouncer); + + if let Some(options) = self.options() { + config.options(options); + } + + if let Some(application_name) = self.application_name() { + config.application_name(application_name); + } + + if let Some(connect_timeout) = self.query_params.connect_timeout { + config.connect_timeout(connect_timeout); + } + + self.set_search_path(&mut config); + + config.ssl_mode(self.query_params.ssl_mode); + + config.channel_binding(self.query_params.channel_binding); + + config + } +} + +impl PostgreSql { + /// Create a new connection to the database. + pub async fn new(url: PostgresUrl) -> crate::Result { + let config = url.to_config(); + + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls = MakeTlsConnector::new(tls_builder.build()?); + let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; + + tokio::spawn(conn.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL connection: {:?}", e); + } + })); + + // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. + if let Some(schema) = &url.query_params.schema { + // PGBouncer does not support the search_path connection parameter. 
+ // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if url.query_params.pg_bouncer + || url.flavour().is_unknown() + || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) + { + let session_variables = format!( + r##"{set_search_path}"##, + set_search_path = SetSearchPath(url.query_params.schema.as_deref()) + ); + + client.simple_query(session_variables.as_str()).await?; + } + } + + Ok(Self { + client: PostgresClient(client), + socket_timeout: url.query_params.socket_timeout, + pg_bouncer: url.query_params.pg_bouncer, + statement_cache: Mutex::new(url.cache()), + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying tokio_postgres::Client. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &tokio_postgres::Client { + &self.client.0 + } + + async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let param_types = conversion::params_to_types(params); + let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; + + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } + + fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { + if params.len() > i16::MAX as usize { + // tokio_postgres would return an error here. Let's avoid calling the driver + // and return an error early. + let kind = ErrorKind::QueryInvalidInput(format!( + "too many bind variables in prepared statement, expected maximum of {}, received {}", + i16::MAX, + params.len() + )); + Err(Error::builder(kind).build()) + } else { + Ok(()) + } + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. 
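+// For example, `SetSearchPath(Some("myschema"))` displays as
+// `SET search_path = "myschema";` while `SetSearchPath(None)` renders nothing.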
+struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +impl_default_TransactionCapable!(PostgreSql); + +#[async_trait] +impl Queryable for PostgreSql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.query_raw(sql.as_str(), ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.execute_raw(sql.as_str(), ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + 
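+        // `simple_query` below uses the simple query protocol: `cmd` may contain
+        // several semicolon-separated statements, but no bind parameters.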
metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + self.perform_io(self.client.0.simple_query(cmd)).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT version()"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { + if self.pg_bouncer { + tx.raw_cmd("DEALLOCATE ALL").await + } else { + Ok(()) + } + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. 
+ if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::connector::Queryable; + use crate::tests::test_api::postgres::CONN_STR; + use crate::tests::test_api::CRDB_CONN_STR; + use url::Url; + + #[tokio::test] + async fn test_custom_search_path_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_pg_pgbouncer() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + 
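+            // With pgbouncer=true, the search_path is set via a query after
+            // connecting rather than through a connection parameter.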
url.query_pairs_mut().append_pair("pbbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + 
assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", 
schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[test] + fn test_safe_ident() { + // Safe + assert!(is_safe_identifier("hello")); + assert!(is_safe_identifier("_hello")); + assert!(is_safe_identifier("àbracadabra")); + assert!(is_safe_identifier("h3ll0")); + assert!(is_safe_identifier("héllo")); + assert!(is_safe_identifier("héll0$")); + assert!(is_safe_identifier("héll_0$")); + assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); + + // Not safe + assert!(!is_safe_identifier("")); + assert!(!is_safe_identifier("Hello")); + assert!(!is_safe_identifier("hEllo")); + assert!(!is_safe_identifier("$hello")); + assert!(!is_safe_identifier("hello!")); + assert!(!is_safe_identifier("hello#")); + assert!(!is_safe_identifier("he llo")); + assert!(!is_safe_identifier(" hello")); + assert!(!is_safe_identifier("he-llo")); + assert!(!is_safe_identifier("hÉllo")); + assert!(!is_safe_identifier("1337")); + assert!(!is_safe_identifier("_HELLO")); + assert!(!is_safe_identifier("HELLO")); + assert!(!is_safe_identifier("HELLO$")); + assert!(!is_safe_identifier("ÀBRACADABRA")); + + for ident in RESERVED_KEYWORDS { + assert!(!is_safe_identifier(ident)); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert!(!is_safe_identifier(ident)); + } + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs new file mode 100644 index 000000000000..f0b60d88a848 --- /dev/null +++ 
b/quaint/src/connector/postgres/url.rs
@@ -0,0 +1,695 @@
+#![cfg_attr(target_arch = "wasm32", allow(dead_code))]
+
+use std::{
+    borrow::Cow,
+    fmt::{Debug, Display},
+    time::Duration,
+};
+
+use percent_encoding::percent_decode;
+use url::{Host, Url};
+
+use crate::error::{Error, ErrorKind};
+
+#[cfg(feature = "postgresql-native")]
+use tokio_postgres::config::{ChannelBinding, SslMode};
+
+#[derive(Clone)]
+pub(crate) struct Hidden<T>(pub(crate) T);
+
+impl<T> Debug for Hidden<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("<HIDDEN>")
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SslAcceptMode {
+    Strict,
+    AcceptInvalidCerts,
+}
+
+#[derive(Debug, Clone)]
+pub struct SslParams {
+    pub(crate) certificate_file: Option<String>,
+    pub(crate) identity_file: Option<String>,
+    pub(crate) identity_password: Hidden<Option<String>>,
+    pub(crate) ssl_accept_mode: SslAcceptMode,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum PostgresFlavour {
+    Postgres,
+    Cockroach,
+    Unknown,
+}
+
+impl PostgresFlavour {
+    /// Returns `true` if the postgres flavour is [`Postgres`].
+    ///
+    /// [`Postgres`]: PostgresFlavour::Postgres
+    pub(crate) fn is_postgres(&self) -> bool {
+        matches!(self, Self::Postgres)
+    }
+
+    /// Returns `true` if the postgres flavour is [`Cockroach`].
+    ///
+    /// [`Cockroach`]: PostgresFlavour::Cockroach
+    pub(crate) fn is_cockroach(&self) -> bool {
+        matches!(self, Self::Cockroach)
+    }
+
+    /// Returns `true` if the postgres flavour is [`Unknown`].
+    ///
+    /// [`Unknown`]: PostgresFlavour::Unknown
+    pub(crate) fn is_unknown(&self) -> bool {
+        matches!(self, Self::Unknown)
+    }
+}
+
+/// Wraps a connection url and exposes the parsing logic used by Quaint,
+/// including default values.
+#[derive(Debug, Clone)]
+pub struct PostgresUrl {
+    pub(crate) url: Url,
+    pub(crate) query_params: PostgresUrlQueryParams,
+    pub(crate) flavour: PostgresFlavour,
+}
+
+pub(crate) const DEFAULT_SCHEMA: &str = "public";
+
+impl PostgresUrl {
+    /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection
+    /// parameters.
+    pub fn new(url: Url) -> Result<Self, Error> {
+        let query_params = Self::parse_query_params(&url)?;
+
+        Ok(Self {
+            url,
+            query_params,
+            flavour: PostgresFlavour::Unknown,
+        })
+    }
+
+    /// The bare `Url` to the database.
+    pub fn url(&self) -> &Url {
+        &self.url
+    }
+
+    /// The percent-decoded database username.
+    pub fn username(&self) -> Cow<'_, str> {
+        match percent_decode(self.url.username().as_bytes()).decode_utf8() {
+            Ok(username) => username,
+            Err(_) => {
+                tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
+
+                self.url.username().into()
+            }
+        }
+    }
+
+    /// The database host. Taken first from the `host` query parameter, then
+    /// from the `host` part of the URL. For socket connections, the query
+    /// parameter must be used.
+    ///
+    /// If none of them are set, defaults to `localhost`.
+    pub fn host(&self) -> &str {
+        match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) {
+            (Some(host), _, _) => host.as_str(),
+            (None, Some(""), _) => "localhost",
+            (None, None, _) => "localhost",
+            (None, Some(host), Some(Host::Ipv6(_))) => {
+                // The `url` crate may return an IPv6 address in brackets, which must be stripped.
+                if host.starts_with('[') && host.ends_with(']') {
+                    &host[1..host.len() - 1]
+                } else {
+                    host
+                }
+            }
+            (None, Some(host), _) => host,
+        }
+    }
+
+    /// Name of the database connected. Defaults to `postgres`.
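+    ///
+    /// For example, `postgresql://localhost:5432/dbname` yields `dbname`,
+    /// since the first path segment is taken as the database name.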
+ pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. 
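+    ///
+    /// Illustrative usage (a sketch; assumes an already-parsed `url: Url`):
+    ///
+    /// ```ignore
+    /// let mut pg_url = PostgresUrl::new(url)?;
+    /// pg_url.set_flavour(PostgresFlavour::Cockroach);
+    /// ```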
+ pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-native")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-native")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-native")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| 
Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-native")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-native")] + channel_binding, + #[cfg(feature = "postgresql-native")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + pub(crate) options: Option, + + #[cfg(feature = "postgresql-native")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-native")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. 
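+// As with the copy in native/mod.rs, `SetSearchPath(Some("myschema"))` displays
+// as `SET search_path = "myschema";`.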
+struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::Value; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::tests::test_api::postgres::CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + 
ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. } => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. 
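+        // "hello" is lowercase and not a reserved word, so it is a safe
+        // identifier and no quoting is applied.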
+ assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 3a1ef72b4883..c59c947b8dc1 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,353 +1,11 @@ -mod conversion; -mod error; +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. +pub(crate) mod error; +mod ffi; +pub(crate) mod params; pub use error::SqliteError; +pub use params::*; -pub use rusqlite::{params_from_iter, version as sqlite_version}; - -use super::IsolationLevel; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use std::{convert::TryFrom, path::Path, time::Duration}; -use tokio::sync::Mutex; - -pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; - -/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use rusqlite; - -/// A connector interface for the SQLite database -pub struct Sqlite { - pub(crate) client: Mutex, -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug)] -pub struct SqliteParams { - pub connection_limit: Option, - /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can - /// only be done with UTF-8 paths. 
- pub file_path: String, - pub db_name: String, - pub socket_timeout: Option, - pub max_connection_lifetime: Option, - pub max_idle_connection_lifetime: Option, -} - -impl TryFrom<&str> for SqliteParams { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let path = if path.starts_with("file:") { - path.trim_start_matches("file:") - } else { - path.trim_start_matches("sqlite:") - }; - - let path_parts: Vec<&str> = path.split('?').collect(); - let path_str = path_parts[0]; - let path = Path::new(path_str); - - if path.is_dir() { - Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) - } else { - let mut connection_limit = None; - let mut socket_timeout = None; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = None; - - if path_parts.len() > 1 { - let params = path_parts.last().unwrap().split('&').map(|kv| { - let splitted: Vec<&str> = kv.split('=').collect(); - (splitted[0], splitted[1]) - }); - - for (k, v) in params { - match k { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - socket_timeout = Some(Duration::from_secs(as_int)); - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = k); - } - }; - } - } - - Ok(Self { - connection_limit, - file_path: path_str.to_owned(), - db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), - socket_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } - } -} - -impl TryFrom<&str> for Sqlite { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let params = SqliteParams::try_from(path)?; - let file_path = params.file_path; - - let conn = rusqlite::Connection::open(file_path.as_str())?; - - if let Some(timeout) = params.socket_timeout { - conn.busy_timeout(timeout)?; - }; - - let client = Mutex::new(conn); - - Ok(Sqlite { client }) - } -} - -impl Sqlite { - pub fn new(file_path: &str) -> crate::Result { - Self::try_from(file_path) - } - - /// Open a new SQLite database in memory. - pub fn new_in_memory() -> crate::Result { - let client = rusqlite::Connection::open_in_memory()?; - - Ok(Sqlite { - client: Mutex::new(client), - }) - } - - /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo - /// feature. This is a lower level API when you need to get into database specific features. 
- #[cfg(feature = "expose-drivers")] - pub fn connection(&self) -> &Mutex { - &self.client - } -} - -impl_default_TransactionCapable!(Sqlite); - -#[async_trait] -impl Queryable for Sqlite { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - - let mut stmt = client.prepare_cached(sql)?; - - let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); - - while let Some(row) = rows.next()? { - result.rows.push(row.get_result_row()?); - } - - result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - let mut stmt = client.prepare_cached(sql)?; - let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; - - Ok(res) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { - let client = self.client.lock().await; - client.execute_batch(cmd)?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - Ok(Some(rusqlite::version().into())) - } - - fn is_healthy(&self) -> bool { - true - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - // SQLite is always "serializable", other modes involve pragmas - // and shared cache mode, which is out of scope for now and should be implemented - // as part of a separate effort. 
- if !matches!(isolation_level, IsolationLevel::Serializable) { - let kind = ErrorKind::invalid_isolation_level(&isolation_level); - return Err(Error::builder(kind).build()); - } - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ast::*, - connector::Queryable, - error::{ErrorKind, Name}, - }; - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[tokio::test] - async fn unknown_table_should_give_a_good_error() { - let conn = Sqlite::try_from("file:db/test.db").unwrap(); - let select = Select::from_table("not_there"); - - let err = conn.select(select).await.unwrap_err(); - - match err.kind() { - ErrorKind::TableDoesNotExist { table } => { - assert_eq!(&Name::available("not_there"), table); - } - e => panic!("Expected error TableDoesNotExist, got {:?}", e), - } - } - - #[tokio::test] - async fn in_memory_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); - - // Check that we do get a separate, new database. - let other_conn = Sqlite::new_in_memory().unwrap(); - - let err = other_conn.select(select).await.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. 
})); - } - - #[tokio::test] - async fn quoting_in_returning_in_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - let insert: Insert = Insert::from(insert).returning(["txt space"]); - - let result = conn.insert(insert).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - } -} +#[cfg(feature = "sqlite-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/error.rs b/quaint/src/connector/sqlite/error.rs index c10b335cb3c0..2c6ff11350fd 100644 --- a/quaint/src/connector/sqlite/error.rs +++ b/quaint/src/connector/sqlite/error.rs @@ -1,8 +1,4 @@ -use std::fmt; - use crate::error::*; -use rusqlite::ffi; -use rusqlite::types::FromSqlError; #[derive(Debug)] pub struct SqliteError { @@ -10,14 +6,10 @@ pub struct SqliteError { pub message: Option<String>, } -impl fmt::Display for SqliteError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Error code {}: {}", - self.extended_code, - ffi::code_to_str(self.extended_code) - ) +#[cfg(not(feature = "sqlite-native"))] +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error code {}", self.extended_code) } } @@ -37,7 +29,7 @@ impl From<SqliteError> for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: super::ffi::SQLITE_CONSTRAINT_UNIQUE | super::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -58,7 +50,7 @@ impl From<SqliteError> for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: super::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -79,7 +71,7 @@ impl From<SqliteError> for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_FOREIGNKEY | ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: super::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | super::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -92,7 +84,7 @@ impl From<SqliteError> for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == super::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ -152,55 +144,3 @@ impl From<SqliteError> for Error { } } } - -impl From<rusqlite::Error> for Error { - fn from(e: rusqlite::Error) -> Error { - match e { - rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::<SqliteError>() { - Ok(error) => *error, - Err(error) => { - let mut builder = Error::builder(ErrorKind::QueryError(error)); - - builder.set_original_message("Could not interpret parameters in an
SQLite query."); - - builder.build() - } - }, - rusqlite::Error::InvalidQuery => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - builder.set_original_message( - "Could not interpret the query or its parameters. Check the syntax and parameter types.", - ); - - builder.build() - } - rusqlite::Error::ExecuteReturnedResults => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_message("Execute returned results, which is not allowed in SQLite."); - - builder.build() - } - - rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), - - rusqlite::Error::SqliteFailure(ffi::Error { code: _, extended_code }, message) => { - SqliteError::new(extended_code, message).into() - } - - rusqlite::Error::SqlInputError { - error: ffi::Error { extended_code, .. }, - msg, - .. - } => SqliteError::new(extended_code, Some(msg)).into(), - - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} - -impl From for Error { - fn from(e: FromSqlError) -> Error { - Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() - } -} diff --git a/quaint/src/connector/sqlite/ffi.rs b/quaint/src/connector/sqlite/ffi.rs new file mode 100644 index 000000000000..c510a459be81 --- /dev/null +++ b/quaint/src/connector/sqlite/ffi.rs @@ -0,0 +1,8 @@ +//! Here, we export only the constants we need to avoid pulling in `rusqlite::ffi::*`, in the sibling `error.rs` file, +//! which would break Wasm compilation. +pub const SQLITE_BUSY: i32 = 5; +pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; +pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; +pub const SQLITE_CONSTRAINT_PRIMARYKEY: i32 = 1555; +pub const SQLITE_CONSTRAINT_TRIGGER: i32 = 1811; +pub const SQLITE_CONSTRAINT_UNIQUE: i32 = 2067; diff --git a/quaint/src/connector/sqlite/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs similarity index 100% rename from quaint/src/connector/sqlite/conversion.rs rename to quaint/src/connector/sqlite/native/conversion.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs new file mode 100644 index 000000000000..51b2417ed821 --- /dev/null +++ b/quaint/src/connector/sqlite/native/error.rs @@ -0,0 +1,66 @@ +use crate::connector::sqlite::error::SqliteError; + +use crate::error::*; + +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error code {}: {}", + self.extended_code, + rusqlite::ffi::code_to_str(self.extended_code) + ) + } +} + +impl From for Error { + fn from(e: rusqlite::Error) -> Error { + match e { + rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { + Ok(error) => *error, + Err(error) => { + let mut builder = Error::builder(ErrorKind::QueryError(error)); + + builder.set_original_message("Could not interpret parameters in an SQLite query."); + + builder.build() + } + }, + rusqlite::Error::InvalidQuery => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + builder.set_original_message( + "Could not interpret the query or its parameters. 
Check the syntax and parameter types.", + ); + + builder.build() + } + rusqlite::Error::ExecuteReturnedResults => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + builder.set_original_message("Execute returned results, which is not allowed in SQLite."); + + builder.build() + } + + rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), + + rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code: _, extended_code }, message) => { + SqliteError::new(extended_code, message).into() + } + + rusqlite::Error::SqlInputError { + error: rusqlite::ffi::Error { extended_code, .. }, + msg, + .. + } => SqliteError::new(extended_code, Some(msg)).into(), + + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} + +impl From<rusqlite::types::FromSqlError> for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { + Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs new file mode 100644 index 000000000000..3bf0c46a7db5 --- /dev/null +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -0,0 +1,234 @@ +//! Definitions for the SQLite connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `sqlite-native` feature. +mod conversion; +mod error; + +use crate::connector::sqlite::params::SqliteParams; +use crate::connector::IsolationLevel; + +pub use rusqlite::{params_from_iter, version as sqlite_version}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use std::convert::TryFrom; +use tokio::sync::Mutex; + +/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use rusqlite; + +/// A connector interface for the SQLite database +pub struct Sqlite { + pub(crate) client: Mutex<rusqlite::Connection>, +} + +impl TryFrom<&str> for Sqlite { + type Error = Error; + + fn try_from(path: &str) -> crate::Result<Self> { + let params = SqliteParams::try_from(path)?; + let file_path = params.file_path; + + let conn = rusqlite::Connection::open(file_path.as_str())?; + + if let Some(timeout) = params.socket_timeout { + conn.busy_timeout(timeout)?; + }; + + let client = Mutex::new(conn); + + Ok(Sqlite { client }) + } +} + +impl Sqlite { + pub fn new(file_path: &str) -> crate::Result<Sqlite> { + Self::try_from(file_path) + } + + /// Open a new SQLite database in memory. + pub fn new_in_memory() -> crate::Result<Sqlite> { + let client = rusqlite::Connection::open_in_memory()?; + + Ok(Sqlite { + client: Mutex::new(client), + }) + } + + /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo + /// feature. This is a lower level API when you need to get into database specific features.
+ #[cfg(feature = "expose-drivers")] + pub fn connection(&self) -> &Mutex<rusqlite::Connection> { + &self.client + } +} + +impl_default_TransactionCapable!(Sqlite); + +#[async_trait] +impl Queryable for Sqlite { + async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> { + let (sql, params) = visitor::Sqlite::build(q)?; + self.query_raw(&sql, &params).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + + let mut stmt = client.prepare_cached(sql)?; + + let mut rows = stmt.query(params_from_iter(params.iter()))?; + let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + + while let Some(row) = rows.next()? { + result.rows.push(row.get_result_row()?); + } + + result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result<u64> { + let (sql, params) = visitor::Sqlite::build(q)?; + self.execute_raw(&sql, &params).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + let mut stmt = client.prepare_cached(sql)?; + let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; + + Ok(res) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + let client = self.client.lock().await; + client.execute_batch(cmd)?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result<Option<String>> { + Ok(Some(rusqlite::version().into())) + } + + fn is_healthy(&self) -> bool { + true + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + // SQLite is always "serializable", other modes involve pragmas + // and shared cache mode, which is out of scope for now and should be implemented + // as part of a separate effort.
+ if !matches!(isolation_level, IsolationLevel::Serializable) { + let kind = ErrorKind::invalid_isolation_level(&isolation_level); + return Err(Error::builder(kind).build()); + } + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ast::*, + connector::Queryable, + error::{ErrorKind, Name}, + }; + + #[tokio::test] + async fn unknown_table_should_give_a_good_error() { + let conn = Sqlite::try_from("file:db/test.db").unwrap(); + let select = Select::from_table("not_there"); + + let err = conn.select(select).await.unwrap_err(); + + match err.kind() { + ErrorKind::TableDoesNotExist { table } => { + assert_eq!(&Name::available("not_there"), table); + } + e => panic!("Expected error TableDoesNotExist, got {:?}", e), + } + } + + #[tokio::test] + async fn in_memory_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); + + // Check that we do get a separate, new database. + let other_conn = Sqlite::new_in_memory().unwrap(); + + let err = other_conn.select(select).await.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); + } + + #[tokio::test] + async fn quoting_in_returning_in_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + let insert: Insert = Insert::from(insert).returning(["txt space"]); + + let result = conn.insert(insert).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + } +} diff --git a/quaint/src/connector/sqlite/params.rs b/quaint/src/connector/sqlite/params.rs new file mode 100644 index 000000000000..f024aa97a694 --- /dev/null +++ b/quaint/src/connector/sqlite/params.rs @@ -0,0 +1,131 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use std::{convert::TryFrom, path::Path, time::Duration}; + +pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug)] +pub struct SqliteParams { + pub connection_limit: Option<usize>, + /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can + /// only be done with UTF-8 paths.
+ pub file_path: String, + pub db_name: String, + pub socket_timeout: Option<Duration>, + pub max_connection_lifetime: Option<Duration>, + pub max_idle_connection_lifetime: Option<Duration>, +} + +impl TryFrom<&str> for SqliteParams { + type Error = Error; + + fn try_from(path: &str) -> crate::Result<Self> { + let path = if path.starts_with("file:") { + path.trim_start_matches("file:") + } else { + path.trim_start_matches("sqlite:") + }; + + let path_parts: Vec<&str> = path.split('?').collect(); + let path_str = path_parts[0]; + let path = Path::new(path_str); + + if path.is_dir() { + Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) + } else { + let mut connection_limit = None; + let mut socket_timeout = None; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = None; + + if path_parts.len() > 1 { + let params = path_parts.last().unwrap().split('&').map(|kv| { + let splitted: Vec<&str> = kv.split('=').collect(); + (splitted[0], splitted[1]) + }); + + for (k, v) in params { + match k { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + socket_timeout = Some(Duration::from_secs(as_int)); + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = k); + } + }; + } + } + + Ok(Self { + connection_limit, + file_path: path_str.to_owned(), + db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), + socket_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } +} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..a77513876726 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] pub(crate) fn value_out_of_range(msg: impl Into<String>) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..73441b7609ba 100644 ---
a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql")] +#[cfg(feature = "postgresql-native")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result<Self::Connection> { let conn = match self { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..1a4dbdf52a61 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -1,7 +1,5 @@ //! A single connection abstraction to a SQL database.
-#[cfg(feature = "sqlite")] -use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; use crate::{ ast, connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, @@ -9,7 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite")] +#[cfg(feature = "sqlite-native")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -127,30 +125,31 @@ impl Quaint { /// - `isolationLevel` the transaction isolation level. Possible values: /// `READ UNCOMMITTED`, `READ COMMITTED`, `REPEATABLE READ`, `SNAPSHOT`, /// `SERIALIZABLE`. + #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -166,9 +165,11 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { + use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; + Ok(Quaint { inner: Arc::new(connector::Sqlite::new_in_memory()?), connection_info: Arc::new(ConnectionInfo::InMemorySqlite { diff --git a/quaint/src/visitor/postgres.rs b/quaint/src/visitor/postgres.rs index fda8a6132037..b587a7b5b0ec 100644 --- a/quaint/src/visitor/postgres.rs +++ b/quaint/src/visitor/postgres.rs @@ -78,33 +78,27 @@ impl<'a> Visitor<'a> for Postgres<'a> { variants: Vec>, name: Option>, ) -> visitor::Result { - let len = variants.len(); - // Since enums are user-defined custom types, tokio-postgres fires an additional query // when parameterizing values of type enum to know which custom type the value refers to. // Casting the enum value to `TEXT` avoid this roundtrip since `TEXT` is a builtin type. if let Some(enum_name) = name.clone() { - self.surround_with("ARRAY[", "]", |s| { - for (i, variant) in variants.into_iter().enumerate() { - s.add_parameter(variant.into_text()); - s.parameter_substitution()?; - s.write("::text")?; - - if i < (len - 1) { - s.write(", ")?; - } + self.add_parameter(Value::array(variants.into_iter().map(|v| v.into_text()))); + + self.surround_with("CAST(", ")", |s| { + s.parameter_substitution()?; + s.write("::text[]")?; + s.write(" AS ")?; + + if let Some(schema_name) = enum_name.schema_name { + s.surround_with_backticks(schema_name.deref())?; + s.write(".")? 
} + s.surround_with_backticks(enum_name.name.deref())?; + s.write("[]")?; + Ok(()) })?; - - self.write("::")?; - if let Some(schema_name) = enum_name.schema_name { - self.surround_with_backticks(schema_name.deref())?; - self.write(".")? - } - self.surround_with_backticks(enum_name.name.deref())?; - self.write("[]")?; } else { self.visit_parameterized(Value::array( variants.into_iter().map(|variant| variant.into_enum(name.clone())), diff --git a/query-engine/black-box-tests/Cargo.toml b/query-engine/black-box-tests/Cargo.toml index 056ee2bcdb43..cc9e99b8ca3c 100644 --- a/query-engine/black-box-tests/Cargo.toml +++ b/query-engine/black-box-tests/Cargo.toml @@ -15,3 +15,4 @@ user-facing-errors.workspace = true insta = "1.7.1" enumflags2 = "0.7" query-engine-metrics = {path = "../metrics"} +regex = "1.9.3" diff --git a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs index 3397de75af99..69207f3fff5d 100644 --- a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs +++ b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs @@ -4,6 +4,7 @@ use query_engine_tests::*; /// Asserts common basics for composite type writes. #[test_suite(schema(schema))] mod smoke_tests { + use regex::Regex; fn schema() -> String { let schema = indoc! { r#"model Person { @@ -14,6 +15,24 @@ mod smoke_tests { schema.to_owned() } + fn assert_value_in_range(metrics: &str, metric: &str, low: f64, high: f64) { + let regex = Regex::new(format!(r"{metric}\s+([+-]?\d+(\.\d+)?)").as_str()).unwrap(); + match regex.captures(metrics) { + Some(capture) => { + let value = capture.get(1).unwrap().as_str().parse::<f64>().unwrap(); + assert!( + value >= low && value <= high, + "expected {} value of {} to be between {} and {}", + metric, + value, + low, + high + ); + } + None => panic!("Metric {} not found in metrics text", metric), + } + } + #[connector_test] + #[rustfmt::skip] async fn expected_metrics_rendered(r: Runner) -> TestResult<()> { @@ -62,6 +81,8 @@ mod smoke_tests { // counters assert_eq!(metrics.matches("# HELP prisma_client_queries_total The total number of Prisma Client queries executed").count(), 1); assert_eq!(metrics.matches("# TYPE prisma_client_queries_total counter").count(), 1); + assert_eq!(metrics.matches("prisma_client_queries_total 1").count(), 1); + assert_eq!(metrics.matches("# HELP prisma_datasource_queries_total The total number of datasource queries executed").count(), 1); assert_eq!(metrics.matches("# TYPE prisma_datasource_queries_total counter").count(), 1); @@ -81,13 +102,15 @@ mod smoke_tests { assert_eq!(metrics.matches("# HELP prisma_pool_connections_busy The number of pool connections currently executing datasource queries").count(), 1); assert_eq!(metrics.matches("# TYPE prisma_pool_connections_busy gauge").count(), 1); + assert_value_in_range(&metrics, "prisma_pool_connections_busy", 0f64, 1f64); assert_eq!(metrics.matches("# HELP prisma_pool_connections_idle The number of pool connections that are not busy running a query").count(), 1); assert_eq!(metrics.matches("# TYPE prisma_pool_connections_idle gauge").count(), 1); assert_eq!(metrics.matches("# HELP prisma_pool_connections_open The number of pool connections currently open").count(), 1); assert_eq!(metrics.matches("# TYPE prisma_pool_connections_open gauge").count(), 1); - + assert_value_in_range(&metrics, "prisma_pool_connections_open", 0f64, 1f64); + // histograms assert_eq!(metrics.matches("# HELP prisma_client_queries_duration_histogram_ms The distribution of the
time Prisma Client queries took to run end to end").count(), 1); assert_eq!(metrics.matches("# TYPE prisma_client_queries_duration_histogram_ms histogram").count(), 1); diff --git a/query-engine/connector-test-kit-rs/README.md b/query-engine/connector-test-kit-rs/README.md index 97d19467879a..993f636e0d28 100644 --- a/query-engine/connector-test-kit-rs/README.md +++ b/query-engine/connector-test-kit-rs/README.md @@ -82,15 +82,16 @@ drivers the code that actually communicates with the databases. See [`adapter-*` To run tests through driver adapters, you should also configure the following environment variables: -* `EXTERNAL_TEST_EXECUTOR`: tells the query engine test kit to use an external process to run the queries, this is a node process running a program that will read the queries to run from STDIN, and return responses to STDOUT. The connector kit follows a protocol over JSON RPC for this communication. * `DRIVER_ADAPTER`: tells the test executor to use a particular driver adapter. Set to `neon`, `planetscale` or any other supported adapter. * `DRIVER_ADAPTER_CONFIG`: a JSON string with the configuration for the driver adapter. This is adapter specific. See the [github workflow for driver adapter tests](.github/workflows/query-engine-driver-adapters.yml) for examples on how to configure the driver adapters. +* `ENGINE`: can be used to run either the `wasm` or the `napi` version of the engine. Example: ```shell export EXTERNAL_TEST_EXECUTOR="$WORKSPACE_ROOT/query-engine/driver-adapters/connector-test-kit-executor/script/start_node.sh" export DRIVER_ADAPTER=neon +export ENGINE=wasm export DRIVER_ADAPTER_CONFIG='{ "proxyUrl": "127.0.0.1:5488/v1" }' ``` @@ -98,7 +99,7 @@ We have provided helpers to run the query-engine tests with driver adapters; these set the required variables for you: ```shell -DRIVER_ADAPTER=$adapter make test-qe +DRIVER_ADAPTER=$adapter ENGINE=$engine make test-qe ``` Where `$adapter` is one of the supported adapters: `neon`, `planetscale`, `libsql`.
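For example, to run the suite against the Neon adapter with the Wasm engine (one concrete instance of the helper invocation above; any supported adapter/engine pair from the lists above works the same way):

```shell
DRIVER_ADAPTER=neon ENGINE=wasm make test-qe
```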
diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs index 62c4e3005f71..73455011d04e 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs @@ -1,8 +1,8 @@ use query_engine_tests::*; -#[test_suite(schema(generic), only(Postgres))] +#[test_suite(schema(generic))] mod raw_params { - #[connector_test] + #[connector_test(only(Postgres), exclude(Postgres("neon.js"), Postgres("pg.js")))] async fn value_too_many_bind_variables(runner: Runner) -> TestResult<()> { let n = 32768; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index 9aa34a943560..33908a9e079e 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -1,7 +1,7 @@ use query_engine_tests::test_suite; use std::borrow::Cow; -#[test_suite(schema(generic))] +#[test_suite(schema(generic), exclude(Vitess("planetscale.js")))] mod interactive_tx { use query_engine_tests::*; use tokio::time; @@ -213,7 +213,7 @@ mod interactive_tx { Ok(()) } - #[connector_test(exclude(JS))] + #[connector_test] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five seconds. let tx_id = runner.start_tx(5000, 5000, None).await?; @@ -256,7 +256,7 @@ mod interactive_tx { Ok(()) } - #[connector_test(exclude(JS))] + #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one second. let tx_id = runner.start_tx(5000, 1000, None).await?; @@ -573,7 +573,7 @@ mod itx_isolation { use query_engine_tests::*; // All (SQL) connectors support serializable.
- #[connector_test(exclude(MongoDb))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; runner.set_active_tx(tx_id.clone()); @@ -595,7 +595,7 @@ mod itx_isolation { Ok(()) } - #[connector_test(exclude(MongoDb))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; runner.set_active_tx(tx_id.clone()); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index 77a56f46c34b..cd270bb334c6 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -1,6 +1,14 @@ use query_engine_tests::test_suite; -#[test_suite(schema(generic))] +#[test_suite( + schema(generic), + exclude( + Vitess("planetscale.js"), + Postgres("neon.js"), + Postgres("pg.js"), + Sqlite("libsql.js") + ) +)] mod metrics { use query_engine_metrics::{ PRISMA_CLIENT_QUERIES_ACTIVE, PRISMA_CLIENT_QUERIES_TOTAL, PRISMA_DATASOURCE_QUERIES_TOTAL, @@ -9,7 +17,7 @@ mod metrics { use query_engine_tests::*; use serde_json::Value; - #[connector_test(exclude(Js))] + #[connector_test] async fn metrics_are_recorded(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), @@ -27,7 +35,7 @@ mod metrics { let total_operations = get_counter(&json, PRISMA_CLIENT_QUERIES_TOTAL); match runner.connector_version() { - Sqlite => assert_eq!(total_queries, 9), + Sqlite(_) => assert_eq!(total_queries, 9), SqlServer(_) => assert_eq!(total_queries, 17), MongoDb(_) => assert_eq!(total_queries, 5), CockroachDb(_) => (), // not deterministic @@ -40,7 +48,7 @@ mod metrics { Ok(()) } - #[connector_test(exclude(Js))] + #[connector_test] async fn metrics_tx_do_not_go_negative(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, None).await?; runner.set_active_tx(tx_id.clone()); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs index 10d7c376b38e..db0f020e029c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs @@ -1,6 +1,6 @@ use query_engine_tests::test_suite; -#[test_suite(capabilities(MultiSchema), exclude(Mysql))] +#[test_suite(capabilities(MultiSchema), exclude(Mysql, Vitess("planetscale.js")))] mod multi_schema { use query_engine_tests::*; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs index dc247f98f948..b495c8627e5a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs @@ -112,7 +112,7 @@ mod occ { assert_eq!(booked_user_id, found_booked_user_id); } - #[connector_test(schema(occ_simple), exclude(MongoDB, CockroachDb))] + #[connector_test(schema(occ_simple), exclude(MongoDB, CockroachDb, Vitess("planetscale.js")))] async fn 
occ_update_many_test(runner: Runner) -> TestResult<()> { let runner = Arc::new(runner); @@ -127,7 +127,7 @@ mod occ { Ok(()) } - #[connector_test(schema(occ_simple), exclude(CockroachDb))] + #[connector_test(schema(occ_simple), exclude(CockroachDb, Vitess("planetscale.js")))] async fn occ_update_test(runner: Runner) -> TestResult<()> { let runner = Arc::new(runner); @@ -158,7 +158,7 @@ mod occ { Ok(()) } - #[connector_test(schema(occ_simple))] + #[connector_test(schema(occ_simple), exclude(Vitess("planetscale.js")))] async fn occ_delete_test(runner: Runner) -> TestResult<()> { let runner = Arc::new(runner); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs index 8ea08acc85da..40ef54ed11f1 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs @@ -2,7 +2,7 @@ use indoc::indoc; use query_engine_tests::*; -#[test_suite(suite = "setdefault_onD_1to1_req", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onD_1to1_req", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] mod one2one_req { fn required_with_default() -> String { let schema = indoc! { @@ -66,7 +66,7 @@ mod one2one_req { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, child: { create: { id: 1 }}}) { id }}"#), @@ -103,7 +103,7 @@ mod one2one_req { } } -#[test_suite(suite = "setdefault_onD_1to1_opt", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onD_1to1_opt", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] mod one2one_opt { fn optional_with_default() -> String { let schema = indoc! { @@ -167,7 +167,7 @@ mod one2one_opt { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, child: { create: { id: 1 }}}) { id }}"#), @@ -206,7 +206,7 @@ mod one2one_opt { } } -#[test_suite(suite = "setdefault_onD_1toM_req", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onD_1toM_req", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] mod one2many_req { fn required_with_default() -> String { let schema = indoc! { @@ -270,7 +270,7 @@ mod one2many_req { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). 
- #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, children: { create: { id: 1 }}}) { id }}"#), @@ -307,7 +307,7 @@ mod one2many_req { } } -#[test_suite(suite = "setdefault_onD_1toM_opt", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onD_1toM_opt", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] mod one2many_opt { fn optional_with_default() -> String { let schema = indoc! { @@ -371,7 +371,7 @@ mod one2many_opt { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, children: { create: { id: 1 }}}) { id }}"#), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs index b0e566ffcb55..b942d6f0bc7b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs @@ -2,7 +2,7 @@ use indoc::indoc; use query_engine_tests::*; -#[test_suite(suite = "setdefault_onU_1to1_req", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onU_1to1_req", exclude(MongoDb, MySQL, Vitess))] mod one2one_req { fn required_with_default() -> String { let schema = indoc! { @@ -68,7 +68,7 @@ mod one2one_req { } /// Updating the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, Vitess))] async fn update_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1", child: { create: { id: 1 }}}) { id }}"#), @@ -105,7 +105,7 @@ mod one2one_req { } } -#[test_suite(suite = "setdefault_onU_1to1_opt", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onU_1to1_opt", exclude(MongoDb, MySQL, Vitess))] mod one2one_opt { fn optional_with_default() -> String { let schema = indoc! { @@ -171,7 +171,7 @@ mod one2one_opt { } /// Updating the parent reconnects the child to the default and fails (the default doesn't exist). 
- #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, Vitess))] async fn update_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1", child: { create: { id: 1 }}}) { id }}"#), @@ -210,7 +210,7 @@ mod one2one_opt { } } -#[test_suite(suite = "setdefault_onU_1toM_req", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onU_1toM_req", exclude(MongoDb, MySQL, Vitess))] mod one2many_req { fn required_with_default() -> String { let schema = indoc! { @@ -276,7 +276,7 @@ mod one2many_req { } /// Updating the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, Vitess))] async fn update_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1", children: { create: { id: 1 }}}) { id }}"#), @@ -313,7 +313,7 @@ mod one2many_req { } } -#[test_suite(suite = "setdefault_onU_1toM_opt", exclude(MongoDb, MySQL))] +#[test_suite(suite = "setdefault_onU_1toM_opt", exclude(MongoDb, MySQL, Vitess))] mod one2many_opt { fn optional_with_default() -> String { let schema = indoc! { @@ -379,7 +379,7 @@ mod one2many_opt { } /// Updating the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, JS))] + #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, Vitess))] async fn update_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1", children: { create: { id: 1 }}}) { id }}"#), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs index 581bc21bebe8..78206f6394a6 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs @@ -187,8 +187,8 @@ mod max_integer { schema.to_owned() } - #[connector_test(schema(overflow_pg), only(Postgres))] - async fn unfitted_int_should_fail_pg(runner: Runner) -> TestResult<()> { + #[connector_test(schema(overflow_pg), only(Postgres), exclude(Postgres("neon.js"), Postgres("pg.js")))] + async fn unfitted_int_should_fail_pg_quaint(runner: Runner) -> TestResult<()> { // int assert_error!( runner, @@ -234,6 +234,55 @@ mod max_integer { Ok(()) } + // The driver adapter for neon provides different error messages on overflow + #[connector_test(schema(overflow_pg), only(Postgres("neon.js"), Postgres("pg.js")))] + async fn unfitted_int_should_fail_pg_js(runner: Runner) -> TestResult<()> { + // int + assert_error!( + runner, + format!("mutation {{ createOneTest(data: {{ int: {I32_OVERFLOW_MAX} }}) {{ id }} }}"), + None, + "value \\\"2147483648\\\" is out of range for type integer" + ); + assert_error!( + runner, + format!("mutation {{ createOneTest(data: {{ int: {I32_OVERFLOW_MIN} }}) {{ id }} }}"), + None, + "value \\\"-2147483649\\\" is out of range for type integer" + ); + + // smallint + assert_error!( + runner, + 
format!("mutation {{ createOneTest(data: {{ smallint: {I16_OVERFLOW_MAX} }}) {{ id }} }}"), + None, + "value \\\"32768\\\" is out of range for type smallint" + ); + assert_error!( + runner, + format!("mutation {{ createOneTest(data: {{ smallint: {I16_OVERFLOW_MIN} }}) {{ id }} }}"), + None, + "value \\\"-32769\\\" is out of range for type smallint" + ); + + //oid + assert_error!( + runner, + format!("mutation {{ createOneTest(data: {{ oid: {U32_OVERFLOW_MAX} }}) {{ id }} }}"), + None, + "value \\\"4294967296\\\" is out of range for type oid" + ); + + // The underlying driver swallows a negative id by interpreting it as unsigned. + // {"data":{"createOneTest":{"id":1,"oid":4294967295}}} + run_query!( + runner, + format!("mutation {{ createOneTest(data: {{ oid: {OVERFLOW_MIN} }}) {{ id, oid }} }}") + ); + + Ok(()) + } + #[connector_test(schema(overflow_pg), only(Postgres))] async fn fitted_int_should_work_pg(runner: Runner) -> TestResult<()> { // int diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs index 8a2cbc7f24a2..0714015efd06 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs @@ -20,6 +20,7 @@ mod prisma_17103; mod prisma_18517; mod prisma_20799; mod prisma_21369; +mod prisma_21901; mod prisma_5952; mod prisma_6173; mod prisma_7010; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs index c1df015c577b..8582c14d0bc0 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs @@ -24,8 +24,8 @@ mod conversion_error { schema.to_owned() } - #[connector_test(schema(schema_int))] - async fn convert_to_int(runner: Runner) -> TestResult<()> { + #[connector_test(schema(schema_int), only(Sqlite), exclude(Sqlite("libsql.js")))] + async fn convert_to_int_sqlite_quaint(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; assert_error!( @@ -38,8 +38,22 @@ mod conversion_error { Ok(()) } - #[connector_test(schema(schema_bigint))] - async fn convert_to_bigint(runner: Runner) -> TestResult<()> { + #[connector_test(schema(schema_int), only(Sqlite("libsql.js")))] + async fn convert_to_int_sqlite_js(runner: Runner) -> TestResult<()> { + create_test_data(&runner).await?; + + assert_error!( + runner, + r#"query { findManyTestModel { field } }"#, + 2023, + "Inconsistent column data: Conversion failed: number must be an integer in column 'field', got '1.84467440724388e19'" + ); + + Ok(()) + } + + #[connector_test(schema(schema_bigint), only(Sqlite), exclude(Sqlite("libsql.js")))] + async fn convert_to_bigint_sqlite_quaint(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; assert_error!( @@ -52,6 +66,20 @@ mod conversion_error { Ok(()) } + #[connector_test(schema(schema_bigint), only(Sqlite("libsql.js")))] + async fn convert_to_bigint_sqlite_js(runner: Runner) -> TestResult<()> { + create_test_data(&runner).await?; + + assert_error!( + runner, + r#"query { findManyTestModel { field } }"#, + 2023, + "Inconsistent column data: Conversion failed: number must be an integer in column 'field', got '1.84467440724388e19'" + ); + + Ok(()) 
+ } + async fn create_test_data(runner: &Runner) -> TestResult<()> { run_query!( runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs index d82f7bd17bc4..c9065ec54c58 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs @@ -21,7 +21,7 @@ mod prisma_17103 { schema.to_owned() } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn regression(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_21901.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_21901.rs new file mode 100644 index 000000000000..5b9dd4f46dcc --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_21901.rs @@ -0,0 +1,50 @@ +use indoc::indoc; +use query_engine_tests::*; + +#[test_suite(schema(schema), capabilities(Enums, ScalarLists), exclude(MongoDb))] +mod prisma_21901 { + fn schema() -> String { + let schema = indoc! { + r#"model Test { + #id(id, Int, @id) + colors Color[] + } + + enum Color { + red + blue + green + } + "# + }; + + schema.to_owned() + } + + // fixes https://github.com/prisma/prisma/issues/21901 + #[connector_test] + async fn test(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!( + runner, + r#"mutation { createOneTest(data: { id: 1, colors: ["red"] }) { colors } }"# + ), + @r###"{"data":{"createOneTest":{"colors":["red"]}}}"### + ); + + insta::assert_snapshot!( + run_query!(runner, fmt_execute_raw(r#"TRUNCATE TABLE "prisma_21901_test"."Test" CASCADE;"#, [])), + @r###"{"data":{"executeRaw":0}}"### + ); + + insta::assert_snapshot!( + run_query!( + runner, + r#"mutation { createOneTest(data: { id: 2, colors: ["blue"] }) { colors } }"# + ), + @r###"{"data":{"createOneTest":{"colors":["blue"]}}}"### + ); + + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs index a155090c7d56..4793fa24ae2a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs @@ -33,7 +33,7 @@ mod aggregation_avg { Ok(()) } - #[connector_test(exclude(MongoDb))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn avg_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5" }"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10" }"#).await?; @@ -126,7 +126,7 @@ mod decimal_aggregation_avg { Ok(()) } - #[connector_test(exclude(MongoDb))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn avg_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs index 
46bdd77ddb58..3c1f1b092690 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs @@ -87,7 +87,7 @@ mod combinations { } // Mongo precision issue. - #[connector_test(exclude(MongoDB))] + #[connector_test(exclude(MongoDB, Vitess("planetscale.js")))] async fn with_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: "1", float: 5.5, int: 5 }"#).await?; create_row(&runner, r#"{ id: "2", float: 4.5, int: 10 }"#).await?; @@ -369,7 +369,7 @@ mod decimal_combinations { } // Mongo precision issue. - #[connector_test(exclude(MongoDB))] + #[connector_test(exclude(MongoDB, Vitess("planetscale.js")))] async fn with_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: "1", dec: "5.5" }"#).await?; create_row(&runner, r#"{ id: "2", dec: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs index 3d5572650c13..78ab88fd59c6 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs @@ -27,7 +27,7 @@ mod aggregation_count { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn count_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, string: "1" }"#).await?; create_row(&runner, r#"{ id: 2, string: "2" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs index d4ef72ee3cf6..12f9b6861892 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs @@ -30,7 +30,7 @@ mod aggregation_max { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn max_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5", string: "2" }"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10", string: "f" }"#).await?; @@ -120,7 +120,7 @@ mod decimal_aggregation_max { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn max_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs index 1927beba7ea5..332a5e10707f 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs @@ -30,7 +30,7 @@ mod aggregation_min { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn min_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5", string: "2" }"#).await?; 
create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10", string: "f" }"#).await?; @@ -120,7 +120,7 @@ mod decimal_aggregation_min { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn min_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs index 59a89cdff930..14d194a1a4f4 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs @@ -30,7 +30,7 @@ mod aggregation_sum { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn sum_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5" }"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10" }"#).await?; @@ -120,7 +120,7 @@ mod decimal_aggregation_sum { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn sum_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs index 8c6e24db67ea..2c332f95f29a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs @@ -44,7 +44,7 @@ mod transactional { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn one_success_one_fail(runner: Runner) -> TestResult<()> { let queries = vec![ r#"mutation { createOneModelA(data: { id: 1 }) { id }}"#.to_string(), @@ -77,7 +77,7 @@ mod transactional { Ok(()) } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn one_query(runner: Runner) -> TestResult<()> { // Existing ModelA in the DB will prevent the nested ModelA creation in the batch. 
insta::assert_snapshot!( @@ -104,7 +104,7 @@ mod transactional { Ok(()) } - #[connector_test(exclude(MongoDb))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn valid_isolation_level(runner: Runner) -> TestResult<()> { let queries = vec![r#"mutation { createOneModelB(data: { id: 1 }) { id }}"#.to_string()]; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs index 1fe86702ef5b..b865731161c2 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs @@ -126,7 +126,7 @@ mod json_filter { Ok(()) } - #[connector_test(schema(schema))] + #[connector_test(schema(schema), exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn string_comparison_filters(runner: Runner) -> TestResult<()> { test_string_data(&runner).await?; @@ -169,7 +169,7 @@ mod json_filter { Ok(()) } - #[connector_test(schema(schema))] + #[connector_test(schema(schema), exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn array_comparison_filters(runner: Runner) -> TestResult<()> { test_array_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs index 2fe8af850120..d1b62a086153 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs @@ -207,7 +207,18 @@ mod json { Ok(()) } - #[connector_test(schema(json_opt))] + // The external runner for driver adapters, regardless of the protocol selected in the test matrix, + // uses the JSON representation of queries, so this test should not apply to driver adapters (hence the exclusions below). + #[connector_test( + schema(json_opt), + exclude( + Vitess("planetscale.js"), + Postgres("neon.js"), + Postgres("pg.js"), + Sqlite("libsql.js"), + MySQL(5.6) + ) + )] async fn nested_not_shorthand(runner: Runner) -> TestResult<()> { // Those tests pass with the JSON protocol because the entire object is parsed as JSON. // They remain useful to ensure we don't ever allow a full JSON filter input object type at the schema level.
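A note on how these `exclude(...)` annotations resolve: each entry is parsed into a `ConnectorVersion` pattern, and a versionless pattern (a bare `Vitess`) matches every version of that connector, while a versioned one such as `Vitess("planetscale.js")` matches only that flavor. The standalone sketch below models that rule; the `versions_match` name mirrors the helper used by `ConnectorVersion::matches_pattern` later in this diff, while the plain string versions are illustrative stand-ins for the real version enums.

```rust
// Minimal model of the test-kit's exclude/only pattern matching. Strings
// stand in for version enums such as VitessVersion::PlanetscaleJs.
fn versions_match(pattern: Option<&str>, under_test: Option<&str>) -> bool {
    match (pattern, under_test) {
        // A versionless pattern matches any concrete version.
        (None, _) => true,
        // A versioned pattern matches only the exact same version.
        (Some(p), Some(v)) => p == v,
        (Some(_), None) => false,
    }
}

fn main() {
    // exclude(Vitess("planetscale.js")) skips only the planetscale.js runs...
    assert!(versions_match(Some("planetscale.js"), Some("planetscale.js")));
    assert!(!versions_match(Some("planetscale.js"), Some("8.0")));
    // ...whereas a bare exclude(Vitess) would skip every Vitess version.
    assert!(versions_match(None, Some("8.0")));
}
```

This is why most hunks in this diff can exclude the PlanetScale driver adapter without touching the regular Vitess 8.0 runs.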
diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs index a30808902c1d..e2ab83cfd62f 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs @@ -27,7 +27,7 @@ mod json_filters { schema.to_owned() } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn no_path_without_filter(runner: Runner) -> TestResult<()> { assert_error!( runner, @@ -262,7 +262,7 @@ mod json_filters { Ok(()) } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn array_contains(runner: Runner) -> TestResult<()> { array_contains_runner(runner).await?; @@ -371,7 +371,7 @@ mod json_filters { Ok(()) } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn array_starts_with(runner: Runner) -> TestResult<()> { array_starts_with_runner(runner).await?; @@ -478,7 +478,7 @@ mod json_filters { Ok(()) } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn array_ends_with(runner: Runner) -> TestResult<()> { array_ends_with_runner(runner).await?; @@ -517,7 +517,7 @@ mod json_filters { Ok(()) } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn string_contains(runner: Runner) -> TestResult<()> { string_contains_runner(runner).await?; @@ -557,7 +557,7 @@ mod json_filters { Ok(()) } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn string_starts_with(runner: Runner) -> TestResult<()> { string_starts_with_runner(runner).await?; @@ -596,7 +596,7 @@ mod json_filters { Ok(()) } - #[connector_test] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] async fn string_ends_with(runner: Runner) -> TestResult<()> { string_ends_with_runner(runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs index 218ecb7eb877..51637d3bbcb8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs @@ -229,7 +229,7 @@ mod search_filter_with_index { super::ensure_filter_tree_shake_works(runner).await } - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn throws_error_on_missing_index(runner: Runner) -> TestResult<()> { super::create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs index 5337806756ed..27c04241288c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs @@ -80,7 +80,7 @@ mod nested_pagination { ***************/ // should skip the first item - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn mid_lvl_skip_1(runner: Runner) -> TestResult<()> 
{ create_test_data(&runner).await?; @@ -102,7 +102,7 @@ mod nested_pagination { } // should "skip all items" - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn mid_lvl_skip_3(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -124,7 +124,7 @@ mod nested_pagination { } // should "skip all items" - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn mid_lvl_skip_4(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs index 357051280efc..b4c6e7b5ef34 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs @@ -158,7 +158,7 @@ mod order_by_dependent { } }"#, // Depends on how null values are handled. - MongoDb(_) | Sqlite => vec![r#"{"data":{"findManyModelA":[{"id":2,"b":{"c":null}},{"id":3,"b":null},{"id":1,"b":{"c":{"id":1}}}]}}"#], + MongoDb(_) | Sqlite(_) => vec![r#"{"data":{"findManyModelA":[{"id":2,"b":{"c":null}},{"id":3,"b":null},{"id":1,"b":{"c":{"id":1}}}]}}"#], SqlServer(_) => vec![r#"{"data":{"findManyModelA":[{"id":3,"b":null},{"id":2,"b":{"c":null}},{"id":1,"b":{"c":{"id":1}}}]}}"#], Postgres(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"id":1}}},{"id":2,"b":{"c":null}},{"id":3,"b":null}]}}"#], _ => vec![ @@ -223,7 +223,7 @@ mod order_by_dependent { } // "[Circular with differing records] Ordering by related record field ascending" should "work" - #[connector_test(exclude(SqlServer))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] async fn circular_diff_related_record_asc(runner: Runner) -> TestResult<()> { // Records form circles with their relations create_row(&runner, 1, Some(1), Some(1), Some(3)).await?; @@ -243,7 +243,7 @@ mod order_by_dependent { } } }"#, - MongoDb(_) | Sqlite => vec![r#"{"data":{"findManyModelA":[{"id":3,"b":null},{"id":4,"b":null},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}}]}}"#], + MongoDb(_) | Sqlite(_) => vec![r#"{"data":{"findManyModelA":[{"id":3,"b":null},{"id":4,"b":null},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}}]}}"#], MySql(_) | CockroachDb(_) => vec![ r#"{"data":{"findManyModelA":[{"id":4,"b":null},{"id":3,"b":null},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}}]}}"#, r#"{"data":{"findManyModelA":[{"id":3,"b":null},{"id":4,"b":null},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}}]}}"#, @@ -255,7 +255,7 @@ mod order_by_dependent { } // "[Circular with differing records] Ordering by related record field descending" should "work" - #[connector_test(exclude(SqlServer))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] async fn circular_diff_related_record_desc(runner: Runner) -> TestResult<()> { // Records form circles with their relations create_row(&runner, 1, Some(1), Some(1), Some(3)).await?; @@ -275,7 +275,7 @@ mod order_by_dependent { } } }"#, - MongoDb(_) | Sqlite => vec![r#"{"data":{"findManyModelA":[{"id":2,"b":{"c":{"a":{"id":4}}}},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":3,"b":null},{"id":4,"b":null}]}}"#], + MongoDb(_) | Sqlite(_)=> 
vec![r#"{"data":{"findManyModelA":[{"id":2,"b":{"c":{"a":{"id":4}}}},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":3,"b":null},{"id":4,"b":null}]}}"#], MySql(_) | CockroachDb(_) => vec![ r#"{"data":{"findManyModelA":[{"id":2,"b":{"c":{"a":{"id":4}}}},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":4,"b":null},{"id":3,"b":null}]}}"#, r#"{"data":{"findManyModelA":[{"id":2,"b":{"c":{"a":{"id":4}}}},{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":3,"b":null},{"id":4,"b":null}]}}"#, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs index ab0820ef0a35..f8e5e831971b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs @@ -79,7 +79,7 @@ mod order_by_dependent_pag { // "[Hops: 1] Ordering by related record field ascending with nulls" should "work" // TODO(julius): should enable for SQL Server when partial indices are in the PSL - #[connector_test(exclude(SqlServer))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] async fn hop_1_related_record_asc_nulls(runner: Runner) -> TestResult<()> { // 1 record has the "full chain", one half, one none create_row(&runner, 1, Some(1), Some(1), None).await?; @@ -97,7 +97,7 @@ mod order_by_dependent_pag { } }"#, // Depends on how null values are handled. - MongoDb(_) | Sqlite | MySql(_) | CockroachDb(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"id":1}},{"id":2,"b":{"id":2}}]}}"#], + MongoDb(_) | Sqlite(_) | MySql(_) | CockroachDb(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"id":1}},{"id":2,"b":{"id":2}}]}}"#], _ => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"id":1}},{"id":2,"b":{"id":2}},{"id":3,"b":null}]}}"#] ); @@ -146,7 +146,7 @@ mod order_by_dependent_pag { // "[Hops: 2] Ordering by related record field ascending with nulls" should "work" // TODO(garren): should enable for SQL Server when partial indices are in the PSL - #[connector_test(exclude(SqlServer))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] async fn hop_2_related_record_asc_null(runner: Runner) -> TestResult<()> { // 1 record has the "full chain", one half, one none create_row(&runner, 1, Some(1), Some(1), None).await?; @@ -166,7 +166,7 @@ mod order_by_dependent_pag { } }"#, // Depends on how null values are handled. 
- MongoDb(_) | Sqlite | MySql(_) | CockroachDb(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"id":1}}}]}}"#], + MongoDb(_) | Sqlite(_) | MySql(_) | CockroachDb(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"id":1}}}]}}"#], _ => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"id":1}}},{"id":2,"b":{"c":null}},{"id":3,"b":null}]}}"#] ); @@ -227,7 +227,7 @@ mod order_by_dependent_pag { // "[Circular with differing records] Ordering by related record field ascending" should "work" // TODO(julius): should enable for SQL Server when partial indices are in the PSL - #[connector_test(exclude(SqlServer))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] async fn circular_diff_related_record_asc(runner: Runner) -> TestResult<()> { // Records form circles with their relations create_row(&runner, 1, Some(1), Some(1), Some(3)).await?; @@ -248,7 +248,7 @@ mod order_by_dependent_pag { } }"#, // Depends on how null values are handled. - MongoDb(_) | MySql(_) | Sqlite | CockroachDb(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}}]}}"#], + MongoDb(_) | MySql(_) | Sqlite(_) | CockroachDb(_) => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}}]}}"#], _ => vec![r#"{"data":{"findManyModelA":[{"id":1,"b":{"c":{"a":{"id":3}}}},{"id":2,"b":{"c":{"a":{"id":4}}}},{"id":3,"b":null},{"id":4,"b":null}]}}"#] ); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs index f0874cae02c8..83c472a064e7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs @@ -277,7 +277,7 @@ mod pagination { ********************/ // "A skip" should "return all records after the offset specified" - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn skip_returns_all_after_offset(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -296,7 +296,7 @@ mod pagination { } // "A skip with order reversed" should "return all records after the offset specified" - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn skip_reversed_order(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -315,7 +315,7 @@ mod pagination { } // "A skipping beyond all records" should "return no records" - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn skipping_beyond_all_records(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/simple/multi_field_unique.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/simple/multi_field_unique.rs index fea6cfba0078..cb691971132d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/simple/multi_field_unique.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/simple/multi_field_unique.rs @@ -176,7 +176,7 @@ mod multi_field_unique { schema.to_owned() } - #[connector_test(schema(many_unique_fields), exclude(MySQL))] + #[connector_test(schema(many_unique_fields), exclude(MySQL, Vitess))] async fn ludicrous_fields(runner: Runner) -> TestResult<()> { 
create_user( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/views.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/views.rs index 4177f4d3a07d..feab8a87f2fe 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/views.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/views.rs @@ -2,7 +2,7 @@ use query_engine_tests::*; // https://stackoverflow.com/questions/4380813/how-to-get-rid-of-mysql-error-prepared-statement-needs-to-be-re-prepared // Looks like there's a bug with create view stmt on MariaDB -#[test_suite(schema(schema), exclude(MongoDb, MySql("mariadb")))] +#[test_suite(schema(schema), exclude(MongoDb, MySQL("mariadb"), Vitess))] mod views { use query_engine_tests::{connector_test, run_query, Runner}; @@ -146,7 +146,7 @@ mod views { => { r#"CREATE VIEW TestView AS SELECT TestModel.*, CONCAT(TestModel.firstName, ' ', TestModel.lastName) AS "fullName" FROM TestModel"#.to_owned() }, - ConnectorVersion::Sqlite => { + ConnectorVersion::Sqlite(_) => { r#"CREATE VIEW TestView AS SELECT TestModel.*, TestModel.firstName || ' ' || TestModel.lastName AS "fullName" FROM TestModel"#.to_owned() } ConnectorVersion::SqlServer(_) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs index 0039b924108c..c03067eed818 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs @@ -5,7 +5,20 @@ use query_engine_tests::*; mod casts { use query_engine_tests::{fmt_query_raw, run_query, RawParam}; - #[connector_test] + // The following tests are excluded for driver adapters. The underlying + // driver rejects queries where the values of the positional arguments do + // not match the expected types. 
As an example, the following query to the + driver + // + // ```json + // { + // sql: 'SELECT $1::int4 AS decimal_to_i4; ', + // args: [ 42.51 ] + // } + // ``` + // + // Bails with: ERROR: invalid input syntax for type integer: "42.51" + // + #[connector_test(only(Postgres), exclude(Postgres("neon.js"), Postgres("pg.js")))] async fn query_numeric_casts(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query_pretty!(&runner, fmt_query_raw(r#" diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs index 88409d8d17f6..cb44a2285ff2 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs @@ -34,8 +34,12 @@ mod raw_errors { Ok(()) } - #[connector_test(schema(common_nullable_types))] - async fn list_param_for_scalar_column_should_not_panic(runner: Runner) -> TestResult<()> { + #[connector_test( + schema(common_nullable_types), + only(Postgres), + exclude(Postgres("neon.js"), Postgres("pg.js")) + )] + async fn list_param_for_scalar_column_should_not_panic_quaint(runner: Runner) -> TestResult<()> { assert_error!( runner, fmt_execute_raw( @@ -48,4 +52,19 @@ Ok(()) } + + #[connector_test(schema(common_nullable_types), only(Postgres("neon.js"), Postgres("pg.js")))] + async fn list_param_for_scalar_column_should_not_panic_pg_js(runner: Runner) -> TestResult<()> { + assert_error!( + runner, + fmt_execute_raw( + r#"INSERT INTO "TestModel" ("id") VALUES ($1);"#, + vec![RawParam::array(vec![1])], + ), + 2010, + r#"invalid input syntax for type integer"# + ); + + Ok(()) + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs index a89bafd4c7ed..791b0a2137fb 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs @@ -77,7 +77,7 @@ mod bytes { Ok(()) } - #[connector_test(schema(bytes_id), exclude(MySQL, SqlServer))] + #[connector_test(schema(bytes_id), exclude(MySQL, Vitess, SqlServer))] async fn byte_id_coercion(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(runner, r#" diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/json.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/json.rs index e34604815a7b..609607e22f1f 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/json.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/json.rs @@ -17,7 +17,7 @@ mod json { } // "Json float accuracy" should "work" - #[connector_test(exclude(SqlServer, Mysql, Sqlite))] + #[connector_test(exclude(SqlServer, MySQL, Vitess, Sqlite))] async fn json_float_accuracy(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs index 66ca6defd4e9..5493ff7f2778 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs +++ 
b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs @@ -45,7 +45,11 @@ mod byoid { } // "A Create Mutation" should "create and return item with own Id" - #[connector_test(schema(schema_1))] + #[connector_test( + schema(schema_1), + only(MySql, Postgres, Sqlite, Vitess), + exclude(Vitess("planetscale.js")) + )] async fn create_and_return_item_woi_1(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { @@ -73,7 +77,11 @@ mod byoid { } // "A Create Mutation" should "create and return item with own Id" - #[connector_test(schema(schema_2))] + #[connector_test( + schema(schema_2), + only(MySql, Postgres, Sqlite, Vitess), + exclude(Vitess("planetscale.js")) + )] async fn create_and_return_item_woi_2(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { @@ -131,7 +139,11 @@ mod byoid { } // "A Nested Create Mutation" should "create and return item with own Id" - #[connector_test(schema(schema_1))] + #[connector_test( + schema(schema_1), + only(MySql, Postgres, Sqlite, Vitess), + exclude(Vitess("planetscale.js")) + )] async fn nested_create_return_item_woi_1(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { @@ -159,7 +171,11 @@ mod byoid { } // "A Nested Create Mutation" should "create and return item with own Id" - #[connector_test(schema(schema_2))] + #[connector_test( + schema(schema_2), + only(MySql, Postgres, Sqlite, Vitess), + exclude(Vitess("planetscale.js")) + )] async fn nested_create_return_item_woi_2(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs index 3cd6be2eabe2..45562b5f6be8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs @@ -78,7 +78,7 @@ mod nested_create_many { // "Nested createMany" should "error on duplicates by default" // TODO(dom): Not working for mongo - #[connector_test(exclude(Sqlite, MongoDb))] + #[connector_test(exclude(Sqlite, MongoDb, Vitess("planetscale.js")))] async fn nested_createmany_fail_dups(runner: Runner) -> TestResult<()> { assert_error!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs index 2ed763a1f123..808af82deec4 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs @@ -26,7 +26,7 @@ mod compound_fks { } // "A One to Many relation with mixed requiredness" should "be writable and readable" - #[connector_test(exclude(MySql(5.6), MongoDb))] + #[connector_test(exclude(MySql(5.6), MongoDb, Vitess("planetscale.js")))] async fn one2m_mix_required_writable_readable(runner: Runner) -> TestResult<()> { // Setup user insta::assert_snapshot!( diff --git 
a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs index 1247b3e27bea..1507ea0c082b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs @@ -205,7 +205,7 @@ mod create { // TODO(dom): Not working on mongo // TODO(dom): 'Expected result to return an error, but found success: {"data":{"createOneScalarModel":{"optUnique":"test"}}}' // Comment(dom): Expected, we're not enforcing uniqueness for the test setup yet. - #[connector_test(exclude(MongoDb))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn gracefully_fails_when_uniq_violation(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs index 35a044b1473d..94118b669c1b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs @@ -165,7 +165,7 @@ mod create_many { } // "createMany" should "error on duplicates by default" - #[connector_test(schema(schema_4))] + #[connector_test(schema(schema_4), exclude(Vitess("planetscale.js")))] async fn create_many_error_dups(runner: Runner) -> TestResult<()> { assert_error!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs index 7e969e21cdce..749048fd3edc 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs @@ -123,7 +123,7 @@ mod update_many { } // "An updateMany mutation" should "correctly apply all number operations for Int" - #[connector_test(exclude(CockroachDb))] + #[connector_test(exclude(Vitess("planetscale.js"), CockroachDb))] async fn apply_number_ops_for_int(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?; create_row(&runner, r#"{ id: 2, optStr: "str2", optInt: 2 }"#).await?; @@ -240,7 +240,7 @@ mod update_many { } // "An updateMany mutation" should "correctly apply all number operations for Float" - #[connector_test] + #[connector_test(exclude(Vitess("planetscale.js")))] async fn apply_number_ops_for_float(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?; create_row(&runner, r#"{ id: 2, optStr: "str2", optFloat: 2 }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs index 2b3dee14f8e7..f4f43eda05ac 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs @@ -674,7 +674,7 @@ mod upsert { Ok(()) } - #[connector_test(schema(generic))] + 
#[connector_test(schema(generic), exclude(Vitess("planetscale.js")))] async fn upsert_fails_if_filter_dont_match(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index 088a0d4b2d34..f257d9e52162 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -10,7 +10,7 @@ once_cell = "1" qe-setup = { path = "../qe-setup" } request-handlers = { path = "../../request-handlers" } tokio.workspace = true -query-core = { path = "../../core" } +query-core = { path = "../../core", features = ["metrics"] } sql-query-connector = { path = "../../connectors/sql-query-connector" } query-engine = { path = "../../query-engine"} psl.workspace = true
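Before the config.rs hunk below, a quick illustration of the idea behind it: the external executor kind becomes a closed enum instead of a free-form path string. It is deserialized from the `.test_config` file (or the `EXTERNAL_TEST_EXECUTOR` env var) and re-serialized through `Display` when forwarded to the spawned executor process. A minimal sketch of that round-trip, assuming the serde and serde_json crates; the enum shape mirrors the `TestExecutor` introduced below:

```rust
use serde::Deserialize;
use std::fmt::{self, Display};

// Mirrors the TestExecutor enum added in config.rs: only two executor
// kinds exist, so parsing rejects anything that is not "Napi" or "Wasm".
#[derive(Debug, Deserialize, Clone)]
enum TestExecutor {
    Napi,
    Wasm,
}

impl Display for TestExecutor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TestExecutor::Napi => f.write_str("Napi"),
            TestExecutor::Wasm => f.write_str("Wasm"),
        }
    }
}

fn main() {
    // As in the rewritten test-config files: "external_test_executor": "Wasm"
    let executor: TestExecutor = serde_json::from_str(r#""Wasm""#).unwrap();
    println!("EXTERNAL_TEST_EXECUTOR={executor}"); // prints `Wasm`

    // A stale value like the old "default" now fails fast at parse time
    // instead of being treated as a script path.
    assert!(serde_json::from_str::<TestExecutor>(r#""default""#).is_err());
}
```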
diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs index 4af4e763298a..49ca5440a25c 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs @@ -2,43 +2,64 @@ use crate::{ CockroachDbConnectorTag, ConnectorTag, ConnectorVersion, MongoDbConnectorTag, MySqlConnectorTag, PostgresConnectorTag, SqlServerConnectorTag, SqliteConnectorTag, TestResult, VitessConnectorTag, }; -use serde::Deserialize; -use std::{convert::TryFrom, env, fs::File, io::Read, path::PathBuf}; +use serde::{Deserialize, Serialize}; +use std::{convert::TryFrom, env, fmt::Display, fs::File, io::Read, path::PathBuf}; static TEST_CONFIG_FILE_NAME: &str = ".test_config"; +#[derive(Debug, Deserialize, Clone)] +pub enum TestExecutor { + Napi, + Wasm, +} + +impl Display for TestExecutor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TestExecutor::Napi => f.write_str("Napi"), + TestExecutor::Wasm => f.write_str("Wasm"), + } + } +} + /// The central test configuration. #[derive(Debug, Default, Deserialize)] pub struct TestConfig { /// The connector that tests should run for. /// Env key: `TEST_CONNECTOR` - connector: String, + pub(crate) connector: String, /// The connector version tests should run for. /// If the test connector is versioned, this option is required. /// Env key: `TEST_CONNECTOR_VERSION` #[serde(rename = "version")] - connector_version: Option<String>, + pub(crate) connector_version: Option<String>, /// An external process used to execute the test queries and produce responses for assertion. /// Used when testing driver adapters; this process is expected to be a JavaScript process /// loading the library engine (as a library, or WASM modules) and providing it with a /// driver adapter. + /// Possible values: Napi, Wasm /// Env key: `EXTERNAL_TEST_EXECUTOR` - external_test_executor: Option<String>, + pub(crate) external_test_executor: Option<TestExecutor>, /// The driver adapter to use when running tests, will be forwarded to the external test /// executor by setting the `DRIVER_ADAPTER` env var when spawning the executor process - driver_adapter: Option<String>, + pub(crate) driver_adapter: Option<String>, /// The driver adapter configuration to forward as a stringified JSON object to the external /// test executor by setting the `DRIVER_ADAPTER_CONFIG` env var when spawning the executor - driver_adapter_config: Option<serde_json::Value>, + pub(crate) driver_adapter_config: Option<DriverAdapterConfig>, /// Indicates whether or not the tests are running in CI context. /// Env key: `BUILDKITE` #[serde(default)] - is_ci: bool, + pub(crate) is_ci: bool, +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub(crate) struct DriverAdapterConfig { + pub(crate) proxy_url: Option<String>, } const CONFIG_LOAD_FAILED: &str = r####" @@ -85,12 +106,11 @@ fn exit_with_message(msg: &str) -> ! { impl TestConfig { /// Loads a configuration. File-based config has precedence over env config. pub(crate) fn load() -> Self { - let mut config = match Self::from_file().or_else(Self::from_env) { + let config = match Self::from_file().or_else(Self::from_env) { Some(config) => config, None => exit_with_message(CONFIG_LOAD_FAILED), }; - config.fill_defaults(); config.validate(); config.log_info(); @@ -107,10 +127,10 @@ impl TestConfig { self.connector_version().unwrap_or_default() ); println!("* CI? {}", self.is_ci); - if self.external_test_executor.as_ref().is_some() { - println!("* External test executor: {}", self.external_test_executor().unwrap_or_default()); + if let Some(external_test_executor) = self.external_test_executor.as_ref() { + println!("* External test executor: {}", external_test_executor); println!("* Driver adapter: {}", self.driver_adapter().unwrap_or_default()); - println!("* Driver adapter url override: {}", self.json_stringify_driver_adapter_config()); + println!("* Driver adapter config: {}", self.json_stringify_driver_adapter_config()); } println!("******************************"); } @@ -118,10 +138,13 @@ impl TestConfig { fn from_env() -> Option<Self> { let connector = std::env::var("TEST_CONNECTOR").ok(); let connector_version = std::env::var("TEST_CONNECTOR_VERSION").ok(); - let external_test_executor = std::env::var("EXTERNAL_TEST_EXECUTOR").ok(); + let external_test_executor = std::env::var("EXTERNAL_TEST_EXECUTOR") + .map(|value| serde_json::from_str::<TestExecutor>(&value).ok()) + .unwrap_or_default(); + let driver_adapter = std::env::var("DRIVER_ADAPTER").ok(); let driver_adapter_config = std::env::var("DRIVER_ADAPTER_CONFIG") - .map(|config| serde_json::from_str::<serde_json::Value>(config.as_str()).ok()) + .map(|config| serde_json::from_str::<DriverAdapterConfig>(config.as_str()).ok()) .unwrap_or_default(); // Just care for a set value for now. @@ -155,31 +178,24 @@ }) } - /// if the loaded value for external_test_executor is "default" (case insensitive), - /// and the workspace_root is set, then use the default external test executor. 
- fn fill_defaults(&mut self) { + fn workspace_root() -> Option<PathBuf> { + env::var("WORKSPACE_ROOT").ok().map(PathBuf::from) + } + + pub fn external_test_executor_path(&self) -> Option<String> { const DEFAULT_TEST_EXECUTOR: &str = "query-engine/driver-adapters/connector-test-kit-executor/script/start_node.sh"; - - if self - .external_test_executor + self.external_test_executor .as_ref() - .filter(|s| s.eq_ignore_ascii_case("default")) - .is_some() - { - self.external_test_executor = Self::workspace_root() - .map(|path| path.join(DEFAULT_TEST_EXECUTOR)) - .or_else(|| { + .and_then(|_| { + Self::workspace_root().or_else(|| { exit_with_message( "WORKSPACE_ROOT needs to be correctly set to the root of the prisma-engines repository", ) }) - .and_then(|path| path.to_str().map(|s| s.to_owned())); - } - } - - fn workspace_root() -> Option<PathBuf> { - env::var("WORKSPACE_ROOT").ok().map(PathBuf::from) + }) + .map(|path| path.join(DEFAULT_TEST_EXECUTOR)) + .and_then(|path| path.to_str().map(|s| s.to_owned())) } fn validate(&self) { @@ -193,7 +209,8 @@ | Ok(ConnectorVersion::SqlServer(None)) | Ok(ConnectorVersion::MongoDb(None)) | Ok(ConnectorVersion::CockroachDb(None)) - | Ok(ConnectorVersion::Postgres(None)) => { + | Ok(ConnectorVersion::Postgres(None)) + | Ok(ConnectorVersion::Sqlite(None)) => { exit_with_message("The current test connector requires a version to be set to run."); } Ok(ConnectorVersion::Vitess(Some(_))) @@ -202,11 +219,11 @@ | Ok(ConnectorVersion::MongoDb(Some(_))) | Ok(ConnectorVersion::CockroachDb(Some(_))) | Ok(ConnectorVersion::Postgres(Some(_))) - | Ok(ConnectorVersion::Sqlite) => (), + | Ok(ConnectorVersion::Sqlite(Some(_))) => (), Err(err) => exit_with_message(&err.to_string()), } - if let Some(file) = self.external_test_executor.as_ref() { + if let Some(file) = self.external_test_executor_path().as_ref() { let path = PathBuf::from(file); let md = path.metadata(); if !path.exists() || md.is_err() || !md.unwrap().is_file() { @@ -259,19 +276,16 @@ self.is_ci } - pub fn external_test_executor(&self) -> Option<&str> { - self.external_test_executor.as_deref() + pub fn external_test_executor(&self) -> Option<TestExecutor> { + self.external_test_executor.clone() } pub fn driver_adapter(&self) -> Option<&str> { self.driver_adapter.as_deref() } - pub fn json_stringify_driver_adapter_config(&self) -> String { - self.driver_adapter_config - .as_ref() - .map(|value| value.to_string()) - .unwrap_or("{}".to_string()) + fn json_stringify_driver_adapter_config(&self) -> String { + serde_json::to_string(&self.driver_adapter_config).unwrap_or_default() } pub fn test_connector(&self) -> TestResult<(ConnectorTag, ConnectorVersion)> { @@ -281,7 +295,7 @@ ConnectorVersion::Postgres(_) => &PostgresConnectorTag, ConnectorVersion::MySql(_) => &MySqlConnectorTag, ConnectorVersion::MongoDb(_) => &MongoDbConnectorTag, - ConnectorVersion::Sqlite => &SqliteConnectorTag, + ConnectorVersion::Sqlite(_) => &SqliteConnectorTag, ConnectorVersion::CockroachDb(_) => &CockroachDbConnectorTag, ConnectorVersion::Vitess(_) => &VitessConnectorTag, }; @@ -294,11 +308,16 @@ vec!( ( "DRIVER_ADAPTER".to_string(), - self.driver_adapter.clone().unwrap_or_default()), + self.driver_adapter.clone().unwrap_or_default() + ), ( "DRIVER_ADAPTER_CONFIG".to_string(), self.json_stringify_driver_adapter_config() ), + ( + "EXTERNAL_TEST_EXECUTOR".to_string(), + self.external_test_executor.clone().unwrap_or(TestExecutor::Napi).to_string(), + ), ( 
"PRISMA_DISABLE_QUAINT_EXECUTORS".to_string(), "1".to_string(), diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs index 583d5058c62e..1abfedbaf8ee 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs @@ -74,7 +74,7 @@ impl ExecutorProcess { }; self.task_handle.send((method_call, sender)).await?; - let raw_response = receiver.await?; + let raw_response = receiver.await??; tracing::debug!(%raw_response); let response = serde_json::from_value(raw_response)?; Ok(response) @@ -91,14 +91,17 @@ pub(super) static EXTERNAL_PROCESS: Lazy = } }); -type ReqImpl = (jsonrpc_core::MethodCall, oneshot::Sender); +type ReqImpl = ( + jsonrpc_core::MethodCall, + oneshot::Sender>, +); fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { use std::process::Stdio; use tokio::process::Command; let path = crate::CONFIG - .external_test_executor() + .external_test_executor_path() .unwrap_or_else(|| exit_with_message(1, "start_rpc_thread() error: external test executor is not set")); tokio::runtime::Builder::new_current_thread() @@ -106,7 +109,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { .build() .unwrap() .block_on(async move { - let process = match Command::new(path) + let process = match Command::new(&path) .envs(CONFIG.for_external_executor()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) @@ -119,7 +122,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { let mut stdout = BufReader::new(process.stdout.unwrap()).lines(); let mut stdin = process.stdin.unwrap(); - let mut pending_requests: HashMap> = + let mut pending_requests: HashMap>> = HashMap::new(); loop { @@ -140,10 +143,11 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { // The other end may be dropped if the whole // request future was dropped and not polled to // completion, so we ignore send errors here. - _ = sender.send(success.result); + _ = sender.send(Ok(success.result)); } jsonrpc_core::Output::Failure(err) => { - panic!("error response from jsonrpc: {err:?}") + tracing::error!("error response from jsonrpc: {err:?}"); + _ = sender.send(Err(Box::new(err.error))); } } } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs index d92bb5e96314..6cc6120f71c8 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs @@ -18,7 +18,7 @@ pub(crate) use sql_server::*; pub(crate) use sqlite::*; pub(crate) use vitess::*; -use crate::{datamodel_rendering::DatamodelRenderer, BoxFuture, TestError, CONFIG}; +use crate::{datamodel_rendering::DatamodelRenderer, BoxFuture, TestConfig, TestError, CONFIG}; use psl::datamodel_connector::ConnectorCapabilities; use std::{convert::TryFrom, fmt}; @@ -49,12 +49,13 @@ pub trait ConnectorTagInterface { /// - `is_ci` signals whether or not the test run is done on CI or not. May be important if local /// test run connection strings and CI connection strings differ because of networking. 
pub(crate) fn connection_string( + test_config: &TestConfig, version: &ConnectorVersion, database: &str, - is_ci: bool, is_multi_schema: bool, isolation_level: Option<&'static str>, ) -> String { + let is_ci = test_config.is_ci; match version { ConnectorVersion::SqlServer(v) => { let database = if is_multi_schema { @@ -98,7 +99,7 @@ pub(crate) fn connection_string( Some(PostgresVersion::V12) if is_ci => { format!("postgresql://postgres:prisma@test-db-postgres-12:5432/{database}") } - Some(PostgresVersion::V13) if is_ci => { + Some(PostgresVersion::V13) | Some(PostgresVersion::NeonJs) | Some(PostgresVersion::PgJs) if is_ci => { format!("postgresql://postgres:prisma@test-db-postgres-13:5432/{database}") } Some(PostgresVersion::V14) if is_ci => { @@ -115,7 +116,9 @@ pub(crate) fn connection_string( Some(PostgresVersion::V10) => format!("postgresql://postgres:prisma@127.0.0.1:5432/{database}"), Some(PostgresVersion::V11) => format!("postgresql://postgres:prisma@127.0.0.1:5433/{database}"), Some(PostgresVersion::V12) => format!("postgresql://postgres:prisma@127.0.0.1:5434/{database}"), - Some(PostgresVersion::V13) => format!("postgresql://postgres:prisma@127.0.0.1:5435/{database}"), + Some(PostgresVersion::V13) | Some(PostgresVersion::NeonJs) | Some(PostgresVersion::PgJs) => { + format!("postgresql://postgres:prisma@127.0.0.1:5435/{database}") + } Some(PostgresVersion::V14) => format!("postgresql://postgres:prisma@127.0.0.1:5437/{database}"), Some(PostgresVersion::V15) => format!("postgresql://postgres:prisma@127.0.0.1:5438/{database}"), Some(PostgresVersion::PgBouncer) => { @@ -162,7 +165,7 @@ } None => unreachable!("A versioned connector must have a concrete version to run."), }, - ConnectorVersion::Sqlite => { + ConnectorVersion::Sqlite(_) => { let workspace_root = std::env::var("WORKSPACE_ROOT") .unwrap_or_else(|_| ".".to_owned()) .trim_end_matches('/') @@ -196,8 +199,12 @@ None => unreachable!("A versioned connector must have a concrete version to run."), } } - ConnectorVersion::Vitess(Some(VitessVersion::V5_7)) => "mysql://root@localhost:33577/test".into(), + ConnectorVersion::Vitess(Some(VitessVersion::V8_0)) => "mysql://root@localhost:33807/test".into(), + ConnectorVersion::Vitess(Some(VitessVersion::PlanetscaleJs)) => { + format!("mysql://root@127.0.0.1:3310/{database}") + } + ConnectorVersion::Vitess(None) => unreachable!("A versioned connector must have a concrete version to run."), } } @@ -210,12 +217,25 @@ pub enum ConnectorVersion { Postgres(Option<PostgresVersion>), MySql(Option<MySqlVersion>), MongoDb(Option<MongoDbVersion>), - Sqlite, + Sqlite(Option<SqliteVersion>), CockroachDb(Option<CockroachDbVersion>), Vitess(Option<VitessVersion>), } impl ConnectorVersion { + fn is_broader(&self, other: &ConnectorVersion) -> bool { + matches!( + (self, other), + (Self::SqlServer(None), Self::SqlServer(_)) + | (Self::Postgres(None), Self::Postgres(_)) + | (Self::MySql(None), Self::MySql(_)) + | (Self::MongoDb(None), Self::MongoDb(_)) + | (Self::Sqlite(None), Self::Sqlite(_)) + | (Self::CockroachDb(None), Self::CockroachDb(_)) + | (Self::Vitess(None), Self::Vitess(_)) + ) + } + fn matches_pattern(&self, pat: &ConnectorVersion) -> bool { use ConnectorVersion::*; @@ -233,14 +253,14 @@ (MongoDb(a), MongoDb(b)) => versions_match(a, b), (CockroachDb(a), CockroachDb(b)) => versions_match(a, b), (Vitess(a), Vitess(b)) => versions_match(a, b), - (Sqlite, Sqlite) => true, + (Sqlite(a), Sqlite(b)) => versions_match(a, b), (MongoDb(..), _) | (_, MongoDb(..)) | (SqlServer(..), _) | (_, SqlServer(..)) - | (Sqlite, _) - | (_, Sqlite) + | (Sqlite(..), _) + | (_, Sqlite(..)) | (CockroachDb(..), _) | (_, CockroachDb(..)) | (Vitess(..), _) @@ -270,7 +290,10 @@ impl fmt::Display for ConnectorVersion { Some(v) => format!("MongoDB ({})", v.to_string()), None => "MongoDB (unknown)".to_string(), }, - Self::Sqlite => "SQLite".to_string(), + Self::Sqlite(v) => match v { + Some(v) => format!("SQLite ({})", v.to_string()), + None => "SQLite (unknown)".to_string(), + }, Self::Vitess(v) => match v { Some(v) => format!("Vitess ({v})"), None => "Vitess (unknown)".to_string(), @@ -285,38 +308,47 @@ /// Determines whether or not a test should run for the given enabled connectors and capabilities /// a connector is required to have. pub(crate) fn should_run( + connector: &ConnectorTag, + version: &ConnectorVersion, only: &[(&str, Option<&str>)], exclude: &[(&str, Option<&str>)], capabilities: ConnectorCapabilities, ) -> bool { - let (connector, version) = CONFIG.test_connector().unwrap(); - if !capabilities.is_empty() && !connector.capabilities().contains(capabilities) { println!("Connector excluded. Missing required capability."); return false; } - if !only.is_empty() { - return only - .iter() - .any(|only| ConnectorVersion::try_from(*only).unwrap().matches_pattern(&version)); - } + let exclusions = exclude + .iter() + .filter_map(|c| ConnectorVersion::try_from(*c).ok()) + .collect::<Vec<_>>(); - if CONFIG.external_test_executor().is_some() && exclude.iter().any(|excl| excl.0.to_uppercase() == "JS") { - println!("Excluded test execution for JS driver adapters. Skipping test"); - return false; - }; + let inclusions = only + .iter() + .filter_map(|c| ConnectorVersion::try_from(*c).ok()) + .collect::<Vec<_>>(); + + for exclusion in exclusions.iter() { + for inclusion in inclusions.iter() { + if exclusion.is_broader(inclusion) { + panic!("Error in connector test execution rules. Version `{exclusion}` in `excluded()` is broader than `{inclusion}` in `only()`"); + } + } + } - if exclude.iter().any(|excl| { ConnectorVersion::try_from(*excl).map_or(false, |connector_version| connector_version.matches_pattern(&version)) }) { println!("Connector excluded. Skipping test."); return false; } + if !inclusions.is_empty() { + return inclusions.iter().any(|incl| incl.matches_pattern(version)); + } + // FIXME: This skips vitess unless explicitly opted in. 
Replace with `true` when fixing // https://github.com/prisma/client-planning/issues/332 - !matches!(version, ConnectorVersion::Vitess(_)) + CONFIG.external_test_executor().is_some() || !matches!(version, ConnectorVersion::Vitess(_)) } impl TryFrom<(&str, Option<&str>)> for ConnectorVersion { @@ -325,7 +357,7 @@ #[track_caller] fn try_from((connector, version): (&str, Option<&str>)) -> Result<Self, Self::Error> { Ok(match connector.to_lowercase().as_str() { - "sqlite" => ConnectorVersion::Sqlite, + "sqlite" => ConnectorVersion::Sqlite(version.map(SqliteVersion::try_from).transpose()?), "sqlserver" => ConnectorVersion::SqlServer(version.map(SqlServerVersion::try_from).transpose()?), "cockroachdb" => ConnectorVersion::CockroachDb(version.map(CockroachDbVersion::try_from).transpose()?), "postgres" => ConnectorVersion::Postgres(version.map(PostgresVersion::try_from).transpose()?), @@ -336,3 +368,44 @@ }) } } + +#[cfg(test)] +mod tests { + use crate::connector_tag::{PostgresConnectorTag, PostgresVersion}; + use crate::{ConnectorTag, ConnectorVersion}; + + #[test] + #[rustfmt::skip] + fn test_should_run() { + let only = vec![("postgres", None)]; + let exclude = vec![("postgres", Some("neon.js"))]; + let postgres = &PostgresConnectorTag as ConnectorTag; + let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJs)); + let pg = ConnectorVersion::Postgres(Some(PostgresVersion::PgJs)); + + assert!(!super::should_run(&postgres, &neon, &only, &exclude, Default::default())); + assert!(super::should_run(&postgres, &pg, &only, &exclude, Default::default())); + } + + #[test] + #[should_panic] + fn test_should_run_wrong_definition_versionless() { + let only = vec![("postgres", None)]; + let exclude = vec![("postgres", None)]; + let postgres = &PostgresConnectorTag as ConnectorTag; + let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJs)); + + super::should_run(&postgres, &neon, &only, &exclude, Default::default()); + } + + #[test] + #[should_panic] + fn test_should_run_wrong_definition_wider_exclusion() { + let only = vec![("postgres", Some("neon.js"))]; + let exclude = vec![("postgres", None)]; + let postgres = &PostgresConnectorTag as ConnectorTag; + let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJs)); + + super::should_run(&postgres, &neon, &only, &exclude, Default::default()); + } +} diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs index 039231a3f74e..42d0a8c7afdc 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs @@ -36,6 +36,8 @@ pub enum PostgresVersion { V14, V15, PgBouncer, + NeonJs, + PgJs, } impl TryFrom<&str> for PostgresVersion { @@ -51,6 +53,8 @@ "14" => Self::V14, "15" => Self::V15, "pgbouncer" => Self::PgBouncer, + "neon.js" => Self::NeonJs, + "pg.js" => Self::PgJs, _ => return Err(TestError::parse_error(format!("Unknown Postgres version `{s}`"))), }; @@ -69,6 +73,8 @@ impl ToString for PostgresVersion { PostgresVersion::V14 => "14", PostgresVersion::V15 => "15", PostgresVersion::PgBouncer => "pgbouncer", + PostgresVersion::NeonJs => "neon.js", + PostgresVersion::PgJs => "pg.js", } .to_owned() }
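For context on the `should_run` rework above: the new `is_broader` guard turns contradictory annotations into a hard failure instead of a silent skip. Below is a simplified, self-contained model of that guard under stated assumptions (a single Postgres arm and plain strings instead of the real version enums); the actual logic lives in `connector_tag/mod.rs` and is exercised by the `#[should_panic]` tests above.

```rust
// Simplified model of the is_broader() guard: a versionless exclusion is
// broader than any inclusion for the same connector, so combining the two
// can never select a connector and is rejected loudly.
#[derive(Debug)]
enum Pattern {
    Postgres(Option<&'static str>), // None = "any Postgres version"
}

impl Pattern {
    fn is_broader(&self, other: &Pattern) -> bool {
        matches!((self, other), (Pattern::Postgres(None), Pattern::Postgres(_)))
    }
}

fn validate(only: &[Pattern], exclude: &[Pattern]) {
    for excl in exclude {
        for incl in only {
            if excl.is_broader(incl) {
                panic!("exclusion {excl:?} is broader than inclusion {incl:?}");
            }
        }
    }
}

fn main() {
    // Fine: carve one driver adapter out of a broad inclusion.
    validate(
        &[Pattern::Postgres(None)],
        &[Pattern::Postgres(Some("neon.js"))],
    );
    // Would panic: exclude(Postgres) swallows only(Postgres("neon.js")).
    // validate(&[Pattern::Postgres(Some("neon.js"))], &[Pattern::Postgres(None)]);
}
```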
diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs index 16b16cf5ba22..5f4dab56784a 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs @@ -25,3 +25,31 @@ impl ConnectorTagInterface for SqliteConnectorTag { psl::builtin_connectors::SQLITE.capabilities() } } + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum SqliteVersion { + V3, + LibsqlJS, +} + +impl ToString for SqliteVersion { + fn to_string(&self) -> String { + match self { + SqliteVersion::V3 => "3".to_string(), + SqliteVersion::LibsqlJS => "libsql.js".to_string(), + } + } +} + +impl TryFrom<&str> for SqliteVersion { + type Error = TestError; + + fn try_from(s: &str) -> Result<Self, Self::Error> { + let version = match s { + "3" => Self::V3, + "libsql.js" => Self::LibsqlJS, + _ => return Err(TestError::parse_error(format!("Unknown SQLite version `{s}`"))), + }; + Ok(version) + } +} diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs index 7afb78bab630..ce827927b403 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs @@ -33,8 +33,8 @@ impl ConnectorTagInterface for VitessConnectorTag { #[derive(Debug, Clone, Copy, PartialEq)] pub enum VitessVersion { - V5_7, V8_0, + PlanetscaleJs, } impl FromStr for VitessVersion { @@ -42,8 +42,8 @@ fn from_str(s: &str) -> Result<Self, Self::Err> { let version = match s { - "5.7" => Self::V5_7, "8.0" => Self::V8_0, + "planetscale.js" => Self::PlanetscaleJs, _ => return Err(TestError::parse_error(format!("Unknown Vitess version `{s}`"))), }; @@ -54,8 +54,8 @@ impl Display for VitessVersion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::V5_7 => write!(f, "5.7"), Self::V8_0 => write!(f, "8.0"), + Self::PlanetscaleJs => write!(f, "planetscale.js"), } } } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs index ea7360c56fa6..7295972f9812 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs @@ -62,13 +62,7 @@ pub fn render_test_datamodel( }} "#}, tag.datamodel_provider(), - connection_string( - &version, - test_database, - CONFIG.is_ci(), - is_multi_schema, - isolation_level - ), + connection_string(&CONFIG, &version, test_database, is_multi_schema, isolation_level), relation_mode_override.unwrap_or_else(|| tag.relation_mode().to_string()), schema_def, preview_features diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs index 2e79581a0c78..af99d9a7a7d3 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs @@ -152,8 +152,9 @@ fn run_relation_link_test_impl( let required_capabilities_for_test = required_capabilities | caps; let test_db_name = format!("{suite_name}_{test_name}_{i}"); let template = 
dm.datamodel().to_owned(); + let (connector, version) = CONFIG.test_connector().unwrap(); - if !should_run(only, exclude, required_capabilities_for_test) { + if !should_run(&connector, &version, only, exclude, required_capabilities_for_test) { continue; } @@ -250,7 +251,9 @@ fn run_connector_test_impl( referential_override: Option<String>, test_fn: &dyn Fn(Runner) -> BoxFuture<'static, TestResult<()>>, ) { - if !should_run(only, exclude, capabilities) { + let (connector, version) = CONFIG.test_connector().unwrap(); + + if !should_run(&connector, &version, only, exclude, capabilities) { return; } diff --git a/query-engine/connector-test-kit-rs/test-configs/libsql-js b/query-engine/connector-test-kit-rs/test-configs/libsql-js new file mode 100644 index 000000000000..9c7ffe8f8473 --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/libsql-js @@ -0,0 +1,6 @@ +{ + "connector": "sqlite", + "version": "libsql.js", + "driver_adapter": "libsql", + "external_test_executor": "Napi" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite b/query-engine/connector-test-kit-rs/test-configs/libsql-wasm similarity index 60% rename from query-engine/connector-test-kit-rs/test-configs/libsql-sqlite rename to query-engine/connector-test-kit-rs/test-configs/libsql-wasm index 9638e3a22840..b93966875dea 100644 --- a/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite +++ b/query-engine/connector-test-kit-rs/test-configs/libsql-wasm @@ -1,5 +1,5 @@ { "connector": "sqlite", "driver_adapter": "libsql", - "external_test_executor": "default" + "external_test_executor": "Wasm" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/neon-js b/query-engine/connector-test-kit-rs/test-configs/neon-js new file mode 100644 index 000000000000..ac76377ac66d --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/neon-js @@ -0,0 +1,7 @@ +{ + "connector": "postgres", + "version": "neon.js", + "driver_adapter": "neon:ws", + "driver_adapter_config": { "proxy_url": "127.0.0.1:5488/v1" }, + "external_test_executor": "Napi" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/neon-wasm b/query-engine/connector-test-kit-rs/test-configs/neon-wasm new file mode 100644 index 000000000000..2697c5227399 --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/neon-wasm @@ -0,0 +1,7 @@ +{ + "connector": "postgres", + "version": "13", + "driver_adapter": "neon:ws", + "driver_adapter_config": { "proxy_url": "127.0.0.1:5488/v1" }, + "external_test_executor": "Wasm" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 b/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 deleted file mode 100644 index 0097d8c91f57..000000000000 --- a/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 +++ /dev/null @@ -1,7 +0,0 @@ -{ - "connector": "postgres", - "version": "13", - "driver_adapter": "neon:ws", - "driver_adapter_config": { "proxyUrl": "127.0.0.1:5488/v1" }, - "external_test_executor": "default" -} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/pg-js b/query-engine/connector-test-kit-rs/test-configs/pg-js new file mode 100644 index 000000000000..23fddfd72a06 --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/pg-js @@ -0,0 +1,6 @@ +{ + "connector": "postgres", + "version": "pg.js", + "driver_adapter": "pg", + "external_test_executor": "Napi" 
+} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/pg-postgres13 b/query-engine/connector-test-kit-rs/test-configs/pg-wasm similarity index 66% rename from query-engine/connector-test-kit-rs/test-configs/pg-postgres13 rename to query-engine/connector-test-kit-rs/test-configs/pg-wasm index 00f0c75ed736..b5d8ac3c7b15 100644 --- a/query-engine/connector-test-kit-rs/test-configs/pg-postgres13 +++ b/query-engine/connector-test-kit-rs/test-configs/pg-wasm @@ -2,5 +2,5 @@ "connector": "postgres", "version": "13", "driver_adapter": "pg", - "external_test_executor": "default" + "external_test_executor": "Wasm" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/planetscale-js b/query-engine/connector-test-kit-rs/test-configs/planetscale-js new file mode 100644 index 000000000000..327ad94ba661 --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/planetscale-js @@ -0,0 +1,9 @@ +{ + "connector": "vitess", + "version": "planetscale.js", + "driver_adapter": "planetscale", + "driver_adapter_config": { + "proxy_url": "http://root:root@127.0.0.1:8085" + }, + "external_test_executor": "Napi" +} diff --git a/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 b/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 deleted file mode 100644 index 48c89c79427c..000000000000 --- a/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 +++ /dev/null @@ -1,7 +0,0 @@ -{ - "connector": "vitess", - "version": "8.0", - "driver_adapter": "planetscale", - "driver_adapter_config": { "proxyUrl": "http://root:root@127.0.0.1:8085" }, - "external_test_executor": "default" -} diff --git a/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm b/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm new file mode 100644 index 000000000000..62dd895e970c --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm @@ -0,0 +1,9 @@ +{ + "connector": "vitess", + "version": "planetscale.js", + "driver_adapter": "planetscale", + "driver_adapter_config": { + "proxy_url": "http://root:root@127.0.0.1:8085" + }, + "external_test_executor": "Wasm" +} diff --git a/query-engine/connector-test-kit-rs/test-configs/sqlite b/query-engine/connector-test-kit-rs/test-configs/sqlite index cfbcc7e8829b..092f9182ec91 100644 --- a/query-engine/connector-test-kit-rs/test-configs/sqlite +++ b/query-engine/connector-test-kit-rs/test-configs/sqlite @@ -1,2 +1,4 @@ { - "connector": "sqlite"} \ No newline at end of file + "connector": "sqlite", + "version": "3" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/vitess_5_7 b/query-engine/connector-test-kit-rs/test-configs/vitess_5_7 deleted file mode 100644 index 64fb5162ac41..000000000000 --- a/query-engine/connector-test-kit-rs/test-configs/vitess_5_7 +++ /dev/null @@ -1,3 +0,0 @@ -{ - "connector": "vitess", - "version": "5.7"} \ No newline at end of file diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index d16771aa3daf..788b8ca65576 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -14,6 +14,6 @@ prisma-value = {path = "../../../libs/prisma-value"} serde.workspace = true serde_json.workspace = true thiserror = "1.0" -user-facing-errors = {path = "../../../libs/user-facing-errors"} +user-facing-errors = {path = "../../../libs/user-facing-errors", 
features = ["sql"]} uuid = "1" indexmap = "1.7" diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index 62d0be640761..9ed0b4070056 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -5,6 +5,8 @@ version = "0.1.0" [features] vendored-openssl = ["quaint/vendored-openssl"] + +# Enable Driver Adapters driver-adapters = [] [dependencies] @@ -18,15 +20,20 @@ once_cell = "1.3" rand = "0.7" serde_json = {version = "1.0", features = ["float_roundtrip"]} thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = "0.1" tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint.workspace = true cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +quaint = { path = "../../../quaint" } + [dependencies.connector-interface] package = "query-connector" path = "../query-connector" diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 0247e8c4b601..7895e838399a 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use super::{catch, transaction::SqlConnectorTransaction}; use crate::{database::operations::*, Context, SqlError}; use async_trait::async_trait; diff --git a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 695db13b6620..e693769373b0 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -1,12 +1,16 @@ mod connection; #[cfg(feature = "driver-adapters")] mod js; -mod mssql; -mod mysql; -mod postgresql; -mod sqlite; mod transaction; +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod native { + pub(crate) mod mssql; + pub(crate) mod mysql; + pub(crate) mod postgresql; + pub(crate) mod sqlite; +} + pub(crate) mod operations; use async_trait::async_trait; @@ -14,10 +18,9 @@ use connector_interface::{error::ConnectorError, Connector}; #[cfg(feature = "driver-adapters")] pub use js::*; -pub use mssql::*; -pub use mysql::*; -pub use postgresql::*; -pub use sqlite::*; + +#[cfg(not(target_arch = "wasm32"))] +pub use native::{mssql::*, mysql::*, postgresql::*, sqlite::*}; #[async_trait] pub trait FromSource { diff --git a/query-engine/connectors/sql-query-connector/src/database/mssql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs similarity index 94% rename from query-engine/connectors/sql-query-connector/src/database/mssql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mssql.rs index 9655d205e4ca..19d3580bba9f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mssql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use 
async_trait::async_trait; use connector_interface::{ @@ -60,7 +60,7 @@ impl FromSource for Mssql { #[async_trait] impl Connector for Mssql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/mysql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mysql.rs index deb3e6a4f35f..477d687b995b 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -65,7 +65,7 @@ impl FromSource for Mysql { #[async_trait] impl Connector for Mysql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let runtime_conn = self.pool.check_out().await?; // Note: `runtime_conn` must be `Sized`, as that's required by `TransactionCapable` diff --git a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/postgresql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 242b2b63090e..0e49a1de8bbd 100644 --- a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -67,7 +67,7 @@ impl FromSource for PostgreSql { #[async_trait] impl Connector for PostgreSql { async fn get_connection<'a>(&'a self) -> connector_interface::Result> { - super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); Ok(Box::new(conn) as Box) diff --git a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs similarity index 96% rename from query-engine/connectors/sql-query-connector/src/database/sqlite.rs rename to query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs index 6be9faeac54d..e38bccb861f4 100644 --- a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -80,7 +80,7 @@ fn 
invalid_file_path_error(file_path: &str, connection_info: &ConnectionInfo) -> #[async_trait] impl Connector for Sqlite { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info().clone(), async move { + catch(self.connection_info().clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, self.connection_info(), self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 425f4ac1d4b3..611557c4f3ba 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -18,9 +18,28 @@ use std::{ ops::Deref, usize, }; -use tracing::log::trace; use user_facing_errors::query_engine::DatabaseConstraint; +#[cfg(target_arch = "wasm32")] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => {{ + // No-op in WebAssembly + }}; + ($($arg:tt)+) => {{ + // No-op in WebAssembly + }}; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => { + tracing::log::trace!(target: $target, $($arg)+); + }; + ($($arg:tt)+) => { + tracing::log::trace!($($arg)+); + }; +} + async fn generate_id( conn: &dyn Queryable, id_field: &FieldSelection, diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index ed1528ded6b5..74c0a4aab5d3 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -22,9 +22,12 @@ mod value_ext; use self::{column_metadata::*, context::Context, query_ext::QueryExt, row::*}; use quaint::prelude::Queryable; +pub use database::FromSource; #[cfg(feature = "driver-adapters")] pub use database::{activate_driver_adapter, Js}; -pub use database::{FromSource, Mssql, Mysql, PostgreSql, Sqlite}; pub use error::SqlError; +#[cfg(not(target_arch = "wasm32"))] +pub use database::{Mssql, Mysql, PostgreSql, Sqlite}; + type Result = std::result::Result; diff --git a/query-engine/core-tests/Cargo.toml b/query-engine/core-tests/Cargo.toml index 9a2c3f5686eb..bac9219c3522 100644 --- a/query-engine/core-tests/Cargo.toml +++ b/query-engine/core-tests/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" dissimilar = "1.0.4" user-facing-errors = { path = "../../libs/user-facing-errors" } request-handlers = { path = "../request-handlers" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } schema = { path = "../schema" } psl.workspace = true serde_json.workspace = true diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index caadf6cdba00..9e0f03517cb5 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -3,6 +3,9 @@ edition = "2021" name = "query-core" version = "0.1.0" +[features] +metrics = ["query-engine-metrics"] + [dependencies] async-trait = "0.1" bigdecimal = "0.3" @@ -18,11 +21,11 @@ once_cell = "1" petgraph = "0.4" prisma-models = { path = "../prisma-models", features = ["default_generators"] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } -query-engine-metrics = {path = "../metrics"} +query-engine-metrics = { path = "../metrics", optional = true } serde.workspace = true serde_json.workspace = true thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", 
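The `trace!` shim added to `operations/write.rs` above keeps every call site source-compatible across targets: on native builds the tokens forward to `tracing::log::trace!`, on `wasm32` they are swallowed at compile time. A condensed, self-contained version of the same idea (assuming the `tracing` crate with its `log` feature enabled):

```rust
#[cfg(not(target_arch = "wasm32"))]
macro_rules! trace {
    ($($arg:tt)+) => { tracing::log::trace!($($arg)+); };
}

#[cfg(target_arch = "wasm32")]
macro_rules! trace {
    ($($arg:tt)+) => {{ /* no-op on Wasm */ }};
}

fn main() {
    // Identical call site on both targets; on wasm32 it expands to nothing.
    trace!("inserted {} rows", 42);
}
```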
features = ["macros", "time"] } tracing = { version = "0.1", features = ["attributes"] } tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -34,3 +37,9 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" +pin-project = "1" +wasm-bindgen-futures = "0.4" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +pin-project = "1" +wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 06452fcdd865..6ba21d37f9ff 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(unused_variables))] + use super::pipeline::QueryPipeline; use crate::{ executor::request_context, protocol::EngineProtocol, CoreError, IrSerializer, Operation, QueryGraph, @@ -5,9 +7,12 @@ use crate::{ }; use connector::{Connection, ConnectionLike, Connector}; use futures::future; + +#[cfg(feature = "metrics")] use query_engine_metrics::{ histogram, increment_counter, metrics, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_CLIENT_QUERIES_TOTAL, }; + use schema::{QuerySchema, QuerySchemaRef}; use std::time::{Duration, Instant}; use tracing::Instrument; @@ -24,6 +29,7 @@ pub async fn execute_single_operation( let (graph, serializer) = build_graph(&query_schema, operation.clone())?; let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -45,6 +51,8 @@ pub async fn execute_many_operations( for (i, (graph, serializer)) in queries.into_iter().enumerate() { let operation_timer = Instant::now(); let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); match result { @@ -98,6 +106,7 @@ pub async fn execute_many_self_contained( let dispatcher = crate::get_current_dispatcher(); for op in operations { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let conn_span = info_span!( @@ -158,6 +167,7 @@ async fn execute_self_contained( execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, trace_id).await }; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -259,6 +269,7 @@ async fn execute_on<'a>( query_schema: &'a QuerySchema, trace_id: Option, ) -> crate::Result { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let interpreter = QueryInterpreter::new(conn); diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ddbb7dfc8429..ba2784d3c71a 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -10,6 +10,7 @@ mod execute_operation; mod interpreting_executor; mod pipeline; mod request_context; +pub(crate) mod task; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs new file mode 100644 index 000000000000..8d1c39bbcd06 --- /dev/null +++ b/query-engine/core/src/executor/task.rs @@ -0,0 +1,59 @@ +//! 
This module provides a unified interface for spawning asynchronous tasks, regardless of the target platform.
+
+pub use arch::{spawn, JoinHandle};
+use futures::Future;
+
+// On native targets, `tokio::spawn` spawns a new asynchronous task.
+#[cfg(not(target_arch = "wasm32"))]
+mod arch {
+    use super::*;
+
+    pub type JoinHandle<T> = tokio::task::JoinHandle<T>;
+
+    pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
+    where
+        T: Future + Send + 'static,
+        T::Output: Send + 'static,
+    {
+        tokio::spawn(future)
+    }
+}
+
+// On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task.
+#[cfg(target_arch = "wasm32")]
+mod arch {
+    use super::*;
+    use tokio::sync::oneshot::{self};
+
+    // Wasm-compatible alternative to `tokio::task::JoinHandle<T>`.
+    pub struct JoinHandle<T>(oneshot::Receiver<T>);
+
+    impl<T> Future for JoinHandle<T> {
+        type Output = Result<T, oneshot::error::RecvError>;
+
+        fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
+            // `oneshot::Receiver<T>` is `Unpin`, so re-pinning it here is sound.
+            core::pin::Pin::new(&mut self.0).poll(cx)
+        }
+    }
+
+    impl<T> JoinHandle<T> {
+        pub fn abort(&mut self) {
+            // abort is a no-op on Wasm targets
+        }
+    }
+
+    pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
+    where
+        T: Future + Send + 'static,
+        T::Output: Send + 'static,
+    {
+        let (sender, receiver) = oneshot::channel();
+        wasm_bindgen_futures::spawn_local(async move {
+            let result = future.await;
+            sender.send(result).ok();
+        });
+        JoinHandle(receiver)
+    }
+}
diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs
index 98208343d28a..105733be4166 100644
--- a/query-engine/core/src/interactive_transactions/actor_manager.rs
+++ b/query-engine/core/src/interactive_transactions/actor_manager.rs
@@ -1,3 +1,4 @@
+use crate::executor::task::JoinHandle;
 use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData};
 use connector::Connection;
 use lru::LruCache;
@@ -9,7 +10,6 @@ use tokio::{
         mpsc::{channel, Sender},
         RwLock,
     },
-    task::JoinHandle,
     time::Duration,
 };
diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs
index 88402d86fedd..104ffc26812f 100644
--- a/query-engine/core/src/interactive_transactions/actors.rs
+++ b/query-engine/core/src/interactive_transactions/actors.rs
@@ -1,7 +1,8 @@
 use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse};
+use crate::executor::task::{spawn, JoinHandle};
 use crate::{
-    execute_many_operations, execute_single_operation, protocol::EngineProtocol,
-    telemetry::helpers::set_span_link_from_traceparent, ClosedTx, Operation, ResponseData, TxId,
+    execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData,
+    TxId,
 };
 use connector::Connection;
 use schema::QuerySchemaRef;
@@ -11,13 +12,15 @@ use tokio::{
         mpsc::{channel, Receiver, Sender},
         oneshot, RwLock,
     },
-    task::JoinHandle,
     time::{self, Duration, Instant},
 };
 use tracing::Span;
 use tracing_futures::Instrument;
 use tracing_futures::WithSubscriber;
+#[cfg(feature = "metrics")]
+use crate::telemetry::helpers::set_span_link_from_traceparent;
+
 #[derive(PartialEq)]
 enum RunState {
     Continue,
@@ -81,6 +84,8 @@ impl<'a> ITXServer<'a> {
         traceparent: Option<String>,
     ) -> crate::Result<ResponseData> {
         let span = info_span!("prisma:engine:itx_query_builder", user_facing = true);
+
+        #[cfg(feature = 
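With `executor::task` in place, call sites such as `spawn_itx_actor` (in the actors.rs hunks that follow) no longer name `tokio::task` directly. A small usage sketch against the native arm of the module, assuming tokio with the `rt` and `macros` features; on `wasm32` the identical call compiles via `spawn_local`, with `abort` degrading to a no-op:

```rust
// Runs against the native arm of task.rs (tokio); the same call site works
// on wasm32 through wasm_bindgen_futures::spawn_local.
pub type JoinHandle<T> = tokio::task::JoinHandle<T>;

pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
where
    T: std::future::Future + Send + 'static,
    T::Output: Send + 'static,
{
    tokio::spawn(future)
}

#[tokio::main]
async fn main() {
    let handle = spawn(async { 21 * 2 });
    assert_eq!(handle.await.unwrap(), 42);
}
```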
"metrics")] set_span_link_from_traceparent(&span, traceparent.clone()); let conn = self.cached_tx.as_open()?; @@ -267,7 +272,7 @@ pub(crate) async fn spawn_itx_actor( }; let (open_transaction_send, open_transaction_rcv) = oneshot::channel(); - tokio::task::spawn( + spawn( crate::executor::with_request_context(engine_protocol, async move { // We match on the result in order to send the error to the parent task and abort this // task, on error. This is a separate task (actor), not a function where we can just bubble up the @@ -380,7 +385,7 @@ pub(crate) fn spawn_client_list_clear_actor( closed_txs: Arc>>>, mut rx: Receiver<(TxId, Option)>, ) -> JoinHandle<()> { - tokio::task::spawn(async move { + spawn(async move { loop { if let Some((id, closed_tx)) = rx.recv().await { trace!("removing {} from client list", id); diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index 7970c96139b7..38f39e9fb5d9 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -9,6 +9,8 @@ pub mod protocol; pub mod query_document; pub mod query_graph_builder; pub mod response_ir; + +#[cfg(feature = "metrics")] pub mod telemetry; pub use self::{ @@ -16,8 +18,11 @@ pub use self::{ executor::{QueryExecutor, TransactionOptions}, interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId}, query_document::*, - telemetry::*, }; + +#[cfg(feature = "metrics")] +pub use self::telemetry::*; + pub use connector::{ error::{ConnectorError, ErrorKind as ConnectorErrorKind}, Connector, diff --git a/query-engine/driver-adapters/connector-test-kit-executor/package.json b/query-engine/driver-adapters/connector-test-kit-executor/package.json index a5b3f20ebdbc..81640a61b003 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/package.json +++ b/query-engine/driver-adapters/connector-test-kit-executor/package.json @@ -12,6 +12,11 @@ "scripts": { "build": "tsup ./src/index.ts --format esm --dts" }, + "tsup": { + "external": [ + "../../../query-engine-wasm/pkg/query_engine_bg.js" + ] + }, "keywords": [], "author": "", "sideEffects": false, diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts index b89348fb3e77..4e847742e51b 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts @@ -1,5 +1,4 @@ import * as qe from './qe' -import * as engines from './engines/Library' import * as readline from 'node:readline' import * as jsonRpc from './jsonRpc' @@ -18,7 +17,7 @@ import { createClient } from '@libsql/client' import { PrismaLibSQL } from '@prisma/adapter-libsql' // planetscale dependencies -import { connect as planetscaleConnect } from '@planetscale/database' +import { Client as PlanetscaleClient } from '@planetscale/database' import { PrismaPlanetScale } from '@prisma/adapter-planetscale' @@ -76,7 +75,7 @@ async function main(): Promise { } const state: Record = {} @@ -215,10 +214,10 @@ function respondOk(requestId: number, payload: unknown) { console.log(JSON.stringify(msg)) } -async function initQe(url: string, prismaSchema: string, logCallback: qe.QueryLogCallback): Promise<[engines.QueryEngineInstance, ErrorCapturingDriverAdapter]> { +async function initQe(url: string, prismaSchema: string, logCallback: qe.QueryLogCallback): Promise<[qe.QueryEngine, ErrorCapturingDriverAdapter]> { const adapter = await adapterFromEnv(url) as DriverAdapter const 
errorCapturingAdapter = bindAdapter(adapter) - const engineInstance = qe.initQueryEngine(errorCapturingAdapter, prismaSchema, logCallback, debug) + const engineInstance = await qe.initQueryEngine(errorCapturingAdapter, prismaSchema, logCallback, debug) return [engineInstance, errorCapturingAdapter]; } @@ -251,7 +250,7 @@ async function pgAdapter(url: string): Promise { } async function neonWsAdapter(url: string): Promise { - const proxyURL = JSON.parse(process.env.DRIVER_ADAPTER_CONFIG || '{}').proxyUrl ?? '' + const proxyURL = JSON.parse(process.env.DRIVER_ADAPTER_CONFIG || '{}').proxy_url ?? '' if (proxyURL == '') { throw new Error("DRIVER_ADAPTER_CONFIG is not defined or empty, but its required for neon adapter."); } @@ -271,17 +270,17 @@ async function libsqlAdapter(url: string): Promise { } async function planetscaleAdapter(url: string): Promise { - const proxyURL = JSON.parse(process.env.DRIVER_ADAPTER_CONFIG || '{}').proxyUrl ?? '' + const proxyURL = JSON.parse(process.env.DRIVER_ADAPTER_CONFIG || '{}').proxy_url ?? '' if (proxyURL == '') { throw new Error("DRIVER_ADAPTER_CONFIG is not defined or empty, but its required for planetscale adapter."); } - const connection = planetscaleConnect({ + const client = new PlanetscaleClient({ url: proxyURL, fetch, }) - return new PrismaPlanetScale(connection) + return new PrismaPlanetScale(client) } main().catch(err) diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts index 186d7a9e80d2..20e9a4917fb5 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts @@ -1,22 +1,24 @@ import type { ErrorCapturingDriverAdapter } from '@prisma/driver-adapter-utils' -import * as lib from './engines/Library' +import * as napi from './engines/Library' import * as os from 'node:os' import * as path from 'node:path' +import { fileURLToPath } from 'node:url' -export type QueryLogCallback = (log: string) => void +const dirname = path.dirname(fileURLToPath(import.meta.url)) -export function initQueryEngine(adapter: ErrorCapturingDriverAdapter, datamodel: string, queryLogCallback: QueryLogCallback, debug: (...args: any[]) => void): lib.QueryEngineInstance { - // I assume nobody will run this on Windows ¯\_(ツ)_/¯ - const libExt = os.platform() === 'darwin' ? 
'dylib' : 'so'
-  const dirname = path.dirname(new URL(import.meta.url).pathname)
+export interface QueryEngine {
+  connect(trace: string): Promise<void>
+  disconnect(trace: string): Promise<void>;
+  query(body: string, trace: string, tx_id?: string): Promise<string>;
+  startTransaction(input: string, trace: string): Promise<string>;
+  commitTransaction(tx_id: string, trace: string): Promise<string>;
+  rollbackTransaction(tx_id: string, trace: string): Promise<string>;
+}
-  const libQueryEnginePath = path.join(dirname, `../../../../target/debug/libquery_engine.${libExt}`)
+export type QueryLogCallback = (log: string) => void
-  const libqueryEngine = { exports: {} as unknown as lib.Library }
-  // @ts-ignore
-  process.dlopen(libqueryEngine, libQueryEnginePath)
-  const QueryEngine = libqueryEngine.exports.QueryEngine
+export async function initQueryEngine(adapter: ErrorCapturingDriverAdapter, datamodel: string, queryLogCallback: QueryLogCallback, debug: (...args: any[]) => void): Promise<QueryEngine> {
     const queryEngineOptions = {
         datamodel,
@@ -37,5 +39,29 @@ export function initQueryEngine(adapter: ErrorCapturingDriverAdapter, datamodel:
         debug(parsed)
     }
-    return new QueryEngine(queryEngineOptions, logCallback, adapter)
+    const engineFromEnv = process.env.EXTERNAL_TEST_EXECUTOR ?? 'Napi'
+    if (engineFromEnv === 'Wasm') {
+        const { WasmQueryEngine } = await import('./wasm')
+        return new WasmQueryEngine(queryEngineOptions, logCallback, adapter)
+    } else if (engineFromEnv === 'Napi') {
+        const { QueryEngine } = loadNapiEngine()
+        return new QueryEngine(queryEngineOptions, logCallback, adapter)
+    } else {
+        throw new TypeError(`Invalid EXTERNAL_TEST_EXECUTOR value: ${engineFromEnv}. Expected Napi or Wasm`)
+    }
+
+}
+
+function loadNapiEngine(): napi.Library {
+    // I assume nobody will run this on Windows ¯\_(ツ)_/¯
+    const libExt = os.platform() === 'darwin' ? 
'dylib' : 'so' + + const libQueryEnginePath = path.join(dirname, `../../../../target/debug/libquery_engine.${libExt}`) + + const libqueryEngine = { exports: {} as unknown as napi.Library } + // @ts-ignore + process.dlopen(libqueryEngine, libQueryEnginePath) + + return libqueryEngine.exports +} \ No newline at end of file diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts new file mode 100644 index 000000000000..439fd0c3f94f --- /dev/null +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts @@ -0,0 +1,14 @@ +import * as wasm from '../../../query-engine-wasm/pkg/query_engine_bg.js' +import fs from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' + +const dirname = path.dirname(fileURLToPath(import.meta.url)) + +const bytes = await fs.readFile(path.resolve(dirname, '..', '..', '..', 'query-engine-wasm', 'pkg', 'query_engine_bg.wasm')) +const module = new WebAssembly.Module(bytes) +const instance = new WebAssembly.Instance(module, { './query_engine_bg.js': wasm }) +wasm.__wbg_set_wasm(instance.exports); +wasm.init() + +export const WasmQueryEngine = wasm.QueryEngine \ No newline at end of file diff --git a/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json b/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json index 516c114b3e15..20fc4bd62ff7 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json +++ b/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json @@ -2,7 +2,7 @@ "compilerOptions": { "target": "ES2022", "module": "ESNext", - "lib": ["ES2022"], + "lib": ["ES2022", "DOM"], "moduleResolution": "Bundler", "esModuleInterop": false, "isolatedModules": true, @@ -17,7 +17,7 @@ "skipDefaultLibCheck": true, "skipLibCheck": true, "emitDeclarationOnly": true, - "resolveJsonModule": true + "resolveJsonModule": true, }, "exclude": ["**/dist", "**/declaration", "**/node_modules", "**/src/__tests__"] } \ No newline at end of file diff --git a/query-engine/driver-adapters/package.json b/query-engine/driver-adapters/package.json index e137d6a524b2..869da0a98173 100644 --- a/query-engine/driver-adapters/package.json +++ b/query-engine/driver-adapters/package.json @@ -11,7 +11,7 @@ "scripts": { "build": "pnpm -r run build", "lint": "pnpm -r run lint", - "clean": "git clean -nXd -e !query-engine/driver-adapters" + "clean": "git clean -dXf -e !query-engine/driver-adapters" }, "keywords": [], "author": "", diff --git a/query-engine/driver-adapters/src/conversion.rs b/query-engine/driver-adapters/src/conversion.rs index c6ea87f1bfa2..00061d72de44 100644 --- a/query-engine/driver-adapters/src/conversion.rs +++ b/query-engine/driver-adapters/src/conversion.rs @@ -1,3 +1,4 @@ +pub(crate) mod mysql; pub(crate) mod postgres; pub(crate) mod sqlite; @@ -6,10 +7,9 @@ use napi::NapiValue; use serde::Serialize; use serde_json::value::Value as JsonValue; -#[derive(Debug, Serialize)] +#[derive(Debug, PartialEq, Serialize)] #[serde(untagged)] pub enum JSArg { - RawString(String), Value(serde_json::Value), Buffer(Vec), Array(Vec), @@ -34,7 +34,6 @@ impl FromNapiValue for JSArg { impl ToNapiValue for JSArg { unsafe fn to_napi_value(env: napi::sys::napi_env, value: Self) -> napi::Result { match value { - JSArg::RawString(s) => ToNapiValue::to_napi_value(env, s), JSArg::Value(v) => ToNapiValue::to_napi_value(env, v), JSArg::Buffer(bytes) => { ToNapiValue::to_napi_value(env, 
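The new `wasm.ts` above instantiates `query_engine_bg.wasm` manually and hands the instance's exports back to the wasm-bindgen glue before re-exporting its `QueryEngine` class. For orientation only, a Rust export of roughly that shape as wasm-bindgen would produce it; this is a sketch, not the actual `query-engine-wasm` source, which also takes constructor options, a log callback and the driver adapter:

```rust
use wasm_bindgen::prelude::*;

// Sketch only: the real engine's constructor and methods carry real arguments
// and state; everything here is a placeholder.
#[wasm_bindgen]
pub struct QueryEngine;

#[wasm_bindgen]
impl QueryEngine {
    #[wasm_bindgen(constructor)]
    pub fn new() -> QueryEngine {
        QueryEngine
    }

    // Exposed to JS as a Promise-returning method.
    pub async fn connect(&self, _trace: String) -> Result<(), JsError> {
        Ok(())
    }
}
```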
napi::Env::from_raw(env).create_buffer_with_data(bytes)?.into_raw()) @@ -50,7 +49,7 @@ impl ToNapiValue for JSArg { for (index, item) in items.into_iter().enumerate() { let js_value = ToNapiValue::to_napi_value(env.raw(), item)?; // TODO: NapiRaw could be implemented for sys::napi_value directly, there should - // be no need for re-wrapping; submit a patch to napi-rs and simplify here. + // be no need for re-wrapping; submit a patch to napi-rs and simplify here. array.set(index as u32, napi::JsUnknown::from_raw_unchecked(env.raw(), js_value))?; } @@ -59,32 +58,3 @@ impl ToNapiValue for JSArg { } } } - -pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { - let res = match &value.typed { - quaint::ValueType::Json(s) => match s { - Some(ref s) => { - let json_str = serde_json::to_string(s)?; - JSArg::RawString(json_str) - } - None => JsonValue::Null.into(), - }, - quaint::ValueType::Bytes(bytes) => match bytes { - Some(bytes) => JSArg::Buffer(bytes.to_vec()), - None => JsonValue::Null.into(), - }, - quaint::ValueType::Numeric(bd) => match bd { - // converting decimal to string to preserve the precision - Some(bd) => JSArg::RawString(bd.to_string()), - None => JsonValue::Null.into(), - }, - quaint::ValueType::Array(Some(ref items)) => JSArg::Array(values_to_js_args(items)?), - quaint_value => JSArg::from(JsonValue::from(quaint_value.clone())), - }; - - Ok(res) -} - -pub fn values_to_js_args(values: &[quaint::Value<'_>]) -> serde_json::Result> { - values.iter().map(value_to_js_arg).collect() -} diff --git a/query-engine/driver-adapters/src/conversion/mysql.rs b/query-engine/driver-adapters/src/conversion/mysql.rs new file mode 100644 index 000000000000..aab33213431a --- /dev/null +++ b/query-engine/driver-adapters/src/conversion/mysql.rs @@ -0,0 +1,107 @@ +use crate::conversion::JSArg; +use serde_json::value::Value as JsonValue; + +const DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S%.f"; +const DATE_FORMAT: &str = "%Y-%m-%d"; +const TIME_FORMAT: &str = "%H:%M:%S%.f"; + +#[rustfmt::skip] +pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { + let res = match &value.typed { + quaint::ValueType::Numeric(Some(bd)) => JSArg::Value(JsonValue::String(bd.to_string())), + quaint::ValueType::Json(Some(s)) => JSArg::Value(JsonValue::String(serde_json::to_string(s)?)), + quaint::ValueType::Bytes(Some(bytes)) => JSArg::Buffer(bytes.to_vec()), + quaint::ValueType::Date(Some(d)) => JSArg::Value(JsonValue::String(d.format(DATE_FORMAT).to_string())), + quaint::ValueType::DateTime(Some(dt)) => JSArg::Value(JsonValue::String(dt.format(DATETIME_FORMAT).to_string())), + quaint::ValueType::Time(Some(t)) => JSArg::Value(JsonValue::String(t.format(TIME_FORMAT).to_string())), + quaint::ValueType::Array(Some(ref items)) => JSArg::Array( + items + .iter() + .map(value_to_js_arg) + .collect::>>()?, + ), + quaint_value => JSArg::from(JsonValue::from(quaint_value.clone())), + }; + Ok(res) +} + +#[cfg(test)] +mod test { + use super::*; + use bigdecimal::BigDecimal; + use chrono::*; + use quaint::ValueType; + use std::str::FromStr; + + #[test] + #[rustfmt::skip] + fn test_value_to_js_arg() { + let test_cases = vec![ + ( + ValueType::Numeric(Some(1.into())), + JSArg::Value(JsonValue::String("1".to_string())) + ), + ( + ValueType::Numeric(Some(BigDecimal::from_str("-1.1").unwrap())), + JSArg::Value(JsonValue::String("-1.1".to_string())) + ), + ( + ValueType::Numeric(None), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Json(Some(serde_json::json!({"a": 1}))), + 
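The new MySQL converter above pins explicit chrono format strings for dates, datetimes and times. Their behaviour, checked in isolation (assuming chrono 0.4; note that `%.f` prints the fractional part in groups of 3, 6 or 9 digits, and nothing at all when it is zero):

```rust
use chrono::{NaiveDate, NaiveTime, TimeZone, Timelike, Utc};

fn main() {
    // The same strings as DATE_FORMAT, DATETIME_FORMAT and TIME_FORMAT above.
    let date = NaiveDate::from_ymd_opt(2020, 2, 29).unwrap();
    assert_eq!(date.format("%Y-%m-%d").to_string(), "2020-02-29");

    let dt = Utc
        .with_ymd_and_hms(2020, 1, 1, 23, 13, 1)
        .unwrap()
        .with_nanosecond(100)
        .unwrap();
    assert_eq!(dt.format("%Y-%m-%d %H:%M:%S%.f").to_string(), "2020-01-01 23:13:01.000000100");

    // Zero fraction: `%.f` prints nothing, matching "23:13:01" in the tests.
    let time = NaiveTime::from_hms_opt(23, 13, 1).unwrap();
    assert_eq!(time.format("%H:%M:%S%.f").to_string(), "23:13:01");
}
```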
JSArg::Value(JsonValue::String("{\"a\":1}".to_string())) + ), + ( + ValueType::Json(None), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Date(Some(NaiveDate::from_ymd_opt(2020, 2, 29).unwrap())), + JSArg::Value(JsonValue::String("2020-02-29".to_string())) + ), + ( + ValueType::Date(None), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::DateTime(Some(Utc.with_ymd_and_hms(2020, 1, 1, 23, 13, 1).unwrap().with_nanosecond(100).unwrap())), + JSArg::Value(JsonValue::String("2020-01-01 23:13:01.000000100".to_string())) + ), + ( + ValueType::DateTime(None), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Time(Some(NaiveTime::from_hms_opt(23, 13, 1).unwrap().with_nanosecond(1200).unwrap())), + JSArg::Value(JsonValue::String("23:13:01.000001200".to_string())) + ), + ( + ValueType::Time(None), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Array(Some(vec!( + ValueType::Numeric(Some(1.into())).into_value(), + ValueType::Numeric(None).into_value(), + ValueType::Time(Some(NaiveTime::from_hms_opt(23, 13, 1).unwrap())).into_value(), + ))), + JSArg::Array(vec!( + JSArg::Value(JsonValue::String("1".to_string())), + JSArg::Value(JsonValue::Null), + JSArg::Value(JsonValue::String("23:13:01".to_string())) + )) + ), + ]; + + let mut errors: Vec = vec![]; + for (val, expected) in test_cases { + let actual = value_to_js_arg(&val.clone().into_value()).unwrap(); + if actual != expected { + errors.push(format!("transforming: {:?}, expected: {:?}, actual: {:?}", &val, expected, actual)); + } + } + assert_eq!(errors.len(), 0, "{}", errors.join("\n")); + } +} diff --git a/query-engine/driver-adapters/src/conversion/postgres.rs b/query-engine/driver-adapters/src/conversion/postgres.rs index 8c00d0aae59f..113be5170a84 100644 --- a/query-engine/driver-adapters/src/conversion/postgres.rs +++ b/query-engine/driver-adapters/src/conversion/postgres.rs @@ -5,31 +5,115 @@ use serde_json::value::Value as JsonValue; static TIME_FMT: Lazy = Lazy::new(|| StrftimeItems::new("%H:%M:%S%.f")); +#[rustfmt::skip] pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { let res = match (&value.typed, value.native_column_type_name()) { - (quaint::ValueType::DateTime(value), Some("DATE")) => match value { - Some(value) => JSArg::RawString(value.date_naive().to_string()), - None => JsonValue::Null.into(), - }, - (quaint::ValueType::DateTime(value), Some("TIME")) => match value { - Some(value) => JSArg::RawString(value.time().to_string()), - None => JsonValue::Null.into(), - }, - (quaint::ValueType::DateTime(value), Some("TIMETZ")) => match value { - Some(value) => JSArg::RawString(value.time().format_with_items(TIME_FMT.clone()).to_string()), - None => JsonValue::Null.into(), - }, - (quaint::ValueType::DateTime(value), _) => match value { - Some(value) => JSArg::RawString(value.naive_utc().to_string()), - None => JsonValue::Null.into(), - }, - (quaint::ValueType::Array(Some(items)), _) => JSArg::Array(values_to_js_args(items)?), - _ => super::value_to_js_arg(value)?, + (quaint::ValueType::DateTime(Some(dt)), Some("DATE")) => JSArg::Value(JsonValue::String(dt.date_naive().to_string())), + (quaint::ValueType::DateTime(Some(dt)), Some("TIME")) => JSArg::Value(JsonValue::String(dt.time().to_string())), + (quaint::ValueType::DateTime(Some(dt)), Some("TIMETZ")) => JSArg::Value(JsonValue::String(dt.time().format_with_items(TIME_FMT.clone()).to_string())), + (quaint::ValueType::DateTime(Some(dt)), _) => JSArg::Value(JsonValue::String(dt.naive_utc().to_string())), + (quaint::ValueType::Json(Some(s)), _) => 
JSArg::Value(JsonValue::String(serde_json::to_string(s)?)), + (quaint::ValueType::Bytes(Some(bytes)), _) => JSArg::Buffer(bytes.to_vec()), + (quaint::ValueType::Numeric(Some(bd)), _) => JSArg::Value(JsonValue::String(bd.to_string())), + (quaint::ValueType::Array(Some(items)), _) => JSArg::Array( + items + .iter() + .map(value_to_js_arg) + .collect::>>()?, + ), + (quaint_value, _) => JSArg::from(JsonValue::from(quaint_value.clone())), }; Ok(res) } -pub fn values_to_js_args(values: &[quaint::Value<'_>]) -> serde_json::Result> { - values.iter().map(value_to_js_arg).collect() +#[cfg(test)] +mod test { + use super::*; + use bigdecimal::BigDecimal; + use chrono::*; + use quaint::ValueType; + use std::str::FromStr; + + #[test] + #[rustfmt::skip] + fn test_value_to_js_arg() { + let test_cases: Vec<(quaint::Value, JSArg)> = vec![ + ( + ValueType::Numeric(Some(1.into())).into_value(), + JSArg::Value(JsonValue::String("1".to_string())) + ), + ( + ValueType::Numeric(Some(BigDecimal::from_str("-1.1").unwrap())).into_value(), + JSArg::Value(JsonValue::String("-1.1".to_string())) + ), + ( + ValueType::Numeric(None).into_value(), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Json(Some(serde_json::json!({"a": 1}))).into_value(), + JSArg::Value(JsonValue::String("{\"a\":1}".to_string())) + ), + ( + ValueType::Json(None).into_value(), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Date(Some(NaiveDate::from_ymd_opt(2020, 2, 29).unwrap())).into_value(), + JSArg::Value(JsonValue::String("2020-02-29".to_string())) + ), + ( + ValueType::Date(None).into_value(), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::DateTime(Some(Utc.with_ymd_and_hms(2020, 1, 1, 23, 13, 1).unwrap())).into_value().with_native_column_type(Some("DATE")), + JSArg::Value(JsonValue::String("2020-01-01".to_string())) + ), + ( + ValueType::DateTime(Some(Utc.with_ymd_and_hms(2020, 1, 1, 23, 13, 1).unwrap())).into_value().with_native_column_type(Some("TIME")), + JSArg::Value(JsonValue::String("23:13:01".to_string())) + ), + ( + ValueType::DateTime(Some(Utc.with_ymd_and_hms(2020, 1, 1, 23, 13, 1).unwrap())).into_value().with_native_column_type(Some("TIMETZ")), + JSArg::Value(JsonValue::String("23:13:01".to_string())) + ), + ( + ValueType::DateTime(None).into_value(), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Time(Some(NaiveTime::from_hms_opt(23, 13, 1).unwrap())).into_value(), + JSArg::Value(JsonValue::String("23:13:01".to_string())) + ), + ( + ValueType::Time(None).into_value(), + JSArg::Value(JsonValue::Null) + ), + ( + ValueType::Array(Some(vec!( + ValueType::Numeric(Some(1.into())).into_value(), + ValueType::Numeric(None).into_value(), + ValueType::Time(Some(NaiveTime::from_hms_opt(23, 13, 1).unwrap())).into_value(), + ValueType::Time(None).into_value(), + ))).into_value(), + JSArg::Array(vec!( + JSArg::Value(JsonValue::String("1".to_string())), + JSArg::Value(JsonValue::Null), + JSArg::Value(JsonValue::String("23:13:01".to_string())), + JSArg::Value(JsonValue::Null), + )) + ), + ]; + + let mut errors: Vec = vec![]; + for (val, expected) in test_cases { + let actual = value_to_js_arg(&val).unwrap(); + if actual != expected { + errors.push(format!("transforming: {:?}, expected: {:?}, actual: {:?}", &val, expected, actual)); + } + } + assert_eq!(errors.len(), 0, "{}", errors.join("\n")); + } } diff --git a/query-engine/driver-adapters/src/conversion/sqlite.rs b/query-engine/driver-adapters/src/conversion/sqlite.rs index 4e6e56cb274a..032c16923256 100644 --- a/query-engine/driver-adapters/src/conversion/sqlite.rs 
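For Postgres the converter above keys off the *native* column type: a `DateTime` value bound to a `DATE`, `TIME` or `TIMETZ` column is narrowed before being stringified. The narrowing itself, isolated (assuming chrono 0.4):

```rust
use chrono::{TimeZone, Utc};

fn main() {
    let dt = Utc.with_ymd_and_hms(2020, 1, 1, 23, 13, 1).unwrap();
    // What conversion/postgres.rs sends for each native column type:
    assert_eq!(dt.date_naive().to_string(), "2020-01-01");         // DATE
    assert_eq!(dt.time().to_string(), "23:13:01");                 // TIME / TIMETZ
    assert_eq!(dt.naive_utc().to_string(), "2020-01-01 23:13:01"); // default
}
```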
+++ b/query-engine/driver-adapters/src/conversion/sqlite.rs @@ -3,21 +3,106 @@ use serde_json::value::Value as JsonValue; pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { let res = match &value.typed { - quaint::ValueType::Numeric(bd) => match bd { - // converting decimal to string to preserve the precision - Some(bd) => match bd.to_string().parse::() { - Ok(double) => JSArg::from(JsonValue::from(double)), - Err(_) => JSArg::from(JsonValue::from(value.clone())), - }, - None => JsonValue::Null.into(), + quaint::ValueType::Numeric(Some(bd)) => match bd.to_string().parse::() { + Ok(double) => JSArg::from(JsonValue::from(double)), + Err(_) => JSArg::from(JsonValue::from(value.clone())), }, - quaint::ValueType::Array(Some(ref items)) => JSArg::Array(values_to_js_args(items)?), - _ => super::value_to_js_arg(value)?, + quaint::ValueType::Json(Some(s)) => JSArg::Value(s.to_owned()), + quaint::ValueType::Bytes(Some(bytes)) => JSArg::Buffer(bytes.to_vec()), + quaint::ValueType::Array(Some(ref items)) => JSArg::Array( + items + .iter() + .map(value_to_js_arg) + .collect::>>()?, + ), + quaint_value => JSArg::from(JsonValue::from(quaint_value.clone())), }; Ok(res) } -pub fn values_to_js_args(values: &[quaint::Value<'_>]) -> serde_json::Result> { - values.iter().map(value_to_js_arg).collect() +// unit tests for value_to_js_arg +#[cfg(test)] +mod test { + use super::*; + use bigdecimal::BigDecimal; + use chrono::*; + use quaint::ValueType; + use serde_json::Value; + use std::str::FromStr; + + #[test] + #[rustfmt::skip] + fn test_value_to_js_arg() { + let test_cases = vec![ + ( + // This is different than how mysql or postgres processes integral BigInt values. + ValueType::Numeric(Some(1.into())), + JSArg::Value(Value::Number("1.0".parse().unwrap())) + ), + ( + ValueType::Numeric(Some(BigDecimal::from_str("-1.1").unwrap())), + JSArg::Value(Value::Number("-1.1".parse().unwrap())), + ), + ( + ValueType::Numeric(None), + JSArg::Value(Value::Null) + ), + ( + ValueType::Json(Some(serde_json::json!({"a": 1}))), + JSArg::Value(serde_json::json!({"a": 1})), + ), + ( + ValueType::Json(None), + JSArg::Value(Value::Null) + ), + ( + ValueType::Date(Some(NaiveDate::from_ymd_opt(2020, 2, 29).unwrap())), + JSArg::Value(Value::String("2020-02-29".to_string())), + ), + ( + ValueType::Date(None), + JSArg::Value(Value::Null) + ), + ( + ValueType::DateTime(Some(Utc.with_ymd_and_hms(2020, 1, 1, 23, 13, 1).unwrap())), + JSArg::Value(Value::String("2020-01-01T23:13:01+00:00".to_string())), + ), + ( + ValueType::DateTime(None), + JSArg::Value(Value::Null) + ), + ( + ValueType::Time(Some(NaiveTime::from_hms_opt(23, 13, 1).unwrap())), + JSArg::Value(Value::String("23:13:01".to_string())), + ), + ( + ValueType::Time(None), + JSArg::Value(Value::Null) + ), + ( + ValueType::Array(Some(vec!( + ValueType::Numeric(Some(1.into())).into_value(), + ValueType::Numeric(None).into_value(), + ValueType::Time(Some(NaiveTime::from_hms_opt(23, 13, 1).unwrap())).into_value(), + ValueType::Time(None).into_value(), + ))), + JSArg::Array(vec!( + JSArg::Value(Value::Number("1.0".parse().unwrap())), + JSArg::Value(Value::Null), + JSArg::Value(Value::String("23:13:01".to_string())), + JSArg::Value(Value::Null), + )) + ), + ]; + + let mut errors: Vec = vec![]; + for (val, expected) in test_cases { + let actual = value_to_js_arg(&val.clone().into_value()).unwrap(); + if actual != expected { + errors.push(format!("transforming: {:?}, expected: {:?}, actual: {:?}", &val, expected, actual)); + } + } + assert_eq!(errors.len(), 0, "{}", 
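Unlike the MySQL and Postgres converters, which keep `Numeric` values as strings to preserve precision, the SQLite converter above reparses them as `f64`; that is exactly what the test comment flags for integral BigInts (`1` round-trips as `1.0`). The lossy path in isolation (assuming the `bigdecimal` crate):

```rust
use bigdecimal::BigDecimal;
use std::str::FromStr;

// Stringify-then-parse, as in conversion/sqlite.rs; anything beyond f64
// precision is silently lost.
fn numeric_to_f64(bd: &BigDecimal) -> Option<f64> {
    bd.to_string().parse::<f64>().ok()
}

fn main() {
    assert_eq!(numeric_to_f64(&BigDecimal::from_str("1").unwrap()), Some(1.0));
    assert_eq!(numeric_to_f64(&BigDecimal::from_str("-1.1").unwrap()), Some(-1.1));
}
```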
errors.join("\n")); + } } diff --git a/query-engine/driver-adapters/src/error.rs b/query-engine/driver-adapters/src/error.rs index f2fbb7dd9caf..4f4128088f49 100644 --- a/query-engine/driver-adapters/src/error.rs +++ b/query-engine/driver-adapters/src/error.rs @@ -12,7 +12,7 @@ pub(crate) fn into_quaint_error(napi_err: NapiError) -> QuaintError { QuaintError::raw_connector_error(status, reason) } -/// catches a panic thrown during the executuin of an asynchronous closure and transforms it into +/// catches a panic thrown during the execution of an asynchronous closure and transforms it into /// the Error variant of a napi::Result. pub(crate) async fn async_unwinding_panic(fut: F) -> napi::Result where diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 62086a245199..642c2491757a 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -249,6 +249,12 @@ fn js_value_to_quaint( column_type: ColumnType, column_name: &str, ) -> quaint::Result> { + let parse_number_as_i64 = |n: &serde_json::Number| { + n.as_i64().ok_or(conversion_error!( + "number must be an integer in column '{column_name}', got '{n}'" + )) + }; + // Note for the future: it may be worth revisiting how much bloat so many panics with different static // strings add to the compiled artefact, and in case we should come up with a restricted set of panic // messages, or even find a way of removing them altogether. @@ -256,8 +262,7 @@ fn js_value_to_quaint( ColumnType::Int32 => match json_value { serde_json::Value::Number(n) => { // n.as_i32() is not implemented, so we need to downcast from i64 instead - n.as_i64() - .ok_or(conversion_error!("number must be an integer in column '{column_name}'")) + parse_number_as_i64(&n) .and_then(|n| -> quaint::Result { n.try_into() .map_err(|e| conversion_error!("cannot convert {n} to i32 in column '{column_name}': {e}")) @@ -273,9 +278,7 @@ fn js_value_to_quaint( )), }, ColumnType::Int64 => match json_value { - serde_json::Value::Number(n) => n.as_i64().map(QuaintValue::int64).ok_or(conversion_error!( - "number must be an i64 in column '{column_name}', got {n}" - )), + serde_json::Value::Number(n) => parse_number_as_i64(&n).map(QuaintValue::int64), serde_json::Value::String(s) => s.parse::().map(QuaintValue::int64).map_err(|e| { conversion_error!("string-encoded number must be an i64 in column '{column_name}', got {s}: {e}") }), @@ -850,7 +853,7 @@ mod proxy_test { let s = "13:02:20.321"; let json_value = serde_json::Value::String(s.to_string()); let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 02, 20, 321).unwrap(); + let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 2, 20, 321).unwrap(); assert_eq!(quaint_value, QuaintValue::time(time)); } diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index b9a8cfe6564d..ab154eccc139 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -49,11 +49,19 @@ impl JsBaseQueryable { async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { let sql: String = sql.to_string(); - let args = match self.flavour { - Flavour::Postgres => conversion::postgres::values_to_js_args(values), - Flavour::Sqlite => conversion::sqlite::values_to_js_args(values), - _ => conversion::values_to_js_args(values), - }?; + + let converter = 
match self.flavour { + Flavour::Postgres => conversion::postgres::value_to_js_arg, + Flavour::Sqlite => conversion::sqlite::value_to_js_arg, + Flavour::Mysql => conversion::mysql::value_to_js_arg, + _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), + }; + + let args = values + .iter() + .map(converter) + .collect::>>()?; + Ok(Query { sql, args }) } } diff --git a/query-engine/driver-adapters/src/result.rs b/query-engine/driver-adapters/src/result.rs index 53133e037b6f..ad4ce7cbb546 100644 --- a/query-engine/driver-adapters/src/result.rs +++ b/query-engine/driver-adapters/src/result.rs @@ -1,5 +1,5 @@ use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; -use quaint::error::{Error as QuaintError, MysqlError, PostgresError, SqliteError}; +use quaint::error::{Error as QuaintError, ErrorKind, MysqlError, PostgresError, SqliteError}; use serde::Deserialize; #[derive(Deserialize)] @@ -36,7 +36,10 @@ pub(crate) enum DriverAdapterError { GenericJs { id: i32, }, - + UnsupportedNativeDataType { + #[serde(rename = "type")] + native_type: String, + }, Postgres(#[serde(with = "PostgresErrorDef")] PostgresError), Mysql(#[serde(with = "MysqlErrorDef")] MysqlError), Sqlite(#[serde(with = "SqliteErrorDef")] SqliteError), @@ -53,6 +56,12 @@ impl FromNapiValue for DriverAdapterError { impl From for QuaintError { fn from(value: DriverAdapterError) -> Self { match value { + DriverAdapterError::UnsupportedNativeDataType { native_type } => { + QuaintError::builder(ErrorKind::UnsupportedColumnType { + column_type: native_type, + }) + .build() + } DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), DriverAdapterError::Postgres(e) => e.into(), DriverAdapterError::Mysql(e) => e.into(), diff --git a/query-engine/prisma-models/src/field/scalar.rs b/query-engine/prisma-models/src/field/scalar.rs index 92039da53663..b8ef8ab204e2 100644 --- a/query-engine/prisma-models/src/field/scalar.rs +++ b/query-engine/prisma-models/src/field/scalar.rs @@ -91,7 +91,7 @@ impl ScalarField { match scalar_field_type { ScalarFieldType::CompositeType(_) => { - unreachable!("Cannot convert a composite type to a type identifier. 
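`result.rs` above gains an `UnsupportedNativeDataType` variant, deserialized from the JS side and mapped onto quaint's `ErrorKind::UnsupportedColumnType`. A reduced sketch of that deserialization; the `kind` tag name is an assumption for illustration, not something the hunk confirms:

```rust
use serde::Deserialize;

// Reduced mirror of DriverAdapterError; the "kind" tag is assumed.
#[derive(Debug, Deserialize)]
#[serde(tag = "kind")]
enum DriverAdapterError {
    GenericJs { id: i32 },
    UnsupportedNativeDataType {
        #[serde(rename = "type")]
        native_type: String,
    },
}

fn main() {
    let err: DriverAdapterError = serde_json::from_str(
        r#"{ "kind": "UnsupportedNativeDataType", "type": "money" }"#,
    )
    .unwrap();
    println!("{err:?}"); // UnsupportedNativeDataType { native_type: "money" }
}
```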
This error is typically caused by mistakenly using a composite type within a composite index.",) + unreachable!("This shouldn't be reached; composite types are not supported in compound unique indices.",) } ScalarFieldType::Enum(x) => TypeIdentifier::Enum(x), ScalarFieldType::BuiltInScalar(scalar) => scalar.into(), diff --git a/query-engine/prisma-models/tests/datamodel_converter_tests.rs b/query-engine/prisma-models/tests/datamodel_converter_tests.rs index 0a45c80ed163..a2ee28ca6c0d 100644 --- a/query-engine/prisma-models/tests/datamodel_converter_tests.rs +++ b/query-engine/prisma-models/tests/datamodel_converter_tests.rs @@ -38,31 +38,159 @@ fn converting_enums() { } } +// region: composite #[test] -fn converting_composite_types() { +fn converting_composite_types_compound() { let res = psl::parse_schema( r#" - datasource db { - provider = "mongodb" - url = "mongodb://localhost:27017/hello" - } + datasource db { + provider = "mongodb" + url = "mongodb://localhost:27017/hello" + } - model MyModel { - id String @id @default(auto()) @map("_id") @db.ObjectId - attribute Attribute + model Post { + id String @id @default(auto()) @map("_id") @db.ObjectId + author User @relation(fields: [authorId], references: [id]) + authorId String @db.ObjectId + attributes Attribute[] + + @@index([authorId, attributes]) + } + + type Attribute { + name String + value String + group String + } + + model User { + id String @id @default(auto()) @map("_id") @db.ObjectId + Post Post[] + } + "#, + ); - @@unique([attribute], name: "composite_index") - } + assert!(res.is_ok()); +} - type Attribute { - name String - value String - group String - } +#[test] +fn converting_composite_types_compound_unique() { + let res = psl::parse_schema( + r#" + datasource db { + provider = "mongodb" + url = "mongodb://localhost:27017/hello" + } + + model Post { + id String @id @default(auto()) @map("_id") @db.ObjectId + author User @relation(fields: [authorId], references: [id]) + authorId String @db.ObjectId + attributes Attribute[] + + @@unique([authorId, attributes]) + // ^^^^^^^^^^^^^^^^^^^^^^ + // Prisma does not currently support composite types in compound unique indices... + } + + type Attribute { + name String + value String + group String + } + + model User { + id String @id @default(auto()) @map("_id") @db.ObjectId + Post Post[] + } "#, ); - assert!(res.unwrap_err().contains("Indexes can only contain scalar attributes. Please remove \"attribute\" from the argument list of the indexes.")); + + assert!(res + .unwrap_err() + .contains(r#"Prisma does not currently support composite types in compound unique indices, please remove "attributes" from the index. See https://pris.ly/d/mongodb-composite-compound-indices for more details"#)); +} + +#[test] +fn converting_composite_types_nested() { + let res = psl::parse_schema( + r#" + datasource db { + provider = "mongodb" + url = "mongodb://localhost:27017/hello" + } + + type TheatersLocation { + address TheatersLocationAddress + geo TheatersLocationGeo + } + + type TheatersLocationAddress { + city String + state String + street1 String + street2 String? 
+ zipcode String + } + + type TheatersLocationGeo { + coordinates Float[] + type String + } + + model theaters { + id String @id @default(auto()) @map("_id") @db.ObjectId + location TheatersLocation + theaterId Int + + @@index([location.geo], map: "geo index") + } + "#, + ); + + assert!(res.is_ok()); +} + +#[test] +fn converting_composite_types_nested_scalar() { + let res = psl::parse_schema( + r#" + datasource db { + provider = "mongodb" + url = "mongodb://localhost:27017/hello" + } + + type TheatersLocation { + address TheatersLocationAddress + geo TheatersLocationGeo + } + + type TheatersLocationAddress { + city String + state String + street1 String + street2 String? + zipcode String + } + + type TheatersLocationGeo { + coordinates Float[] + type String + } + + model theaters { + id String @id @default(auto()) @map("_id") @db.ObjectId + location TheatersLocation + theaterId Int + + @@index([location.geo.type], map: "geo index") + } + "#, + ); + + assert!(res.is_ok()); } +// endregion #[test] fn models_with_only_scalar_fields() { diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index 74f9686189fc..0eaed9eff7ce 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -16,7 +16,7 @@ driver-adapters = ["request-handlers/driver-adapters", "sql-connector/driver-ada [dependencies] anyhow = "1" async-trait = "0.1" -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } query-connector = { path = "../connectors/query-connector" } user-facing-errors = { path = "../../libs/user-facing-errors" } diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index a8bc393aee3f..fdccc773eaf3 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -14,15 +14,23 @@ async-trait = "0.1" user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true prisma-models = { path = "../prisma-models" } +quaint = { path = "../../quaint" } +request-handlers = { path = "../request-handlers", default-features = false, features = [ + "sql", + "driver-adapters", +] } +connector = { path = "../connectors/query-connector", package = "query-connector" } +sql-query-connector = { path = "../connectors/sql-query-connector" } +query-core = { path = "../core" } thiserror = "1" -connection-string.workspace = true +connection-string.workspace = true url = "2" serde_json.workspace = true serde.workspace = true tokio = { version = "1.25", features = ["macros", "sync", "io-util", "time"] } futures = "0.3" -wasm-bindgen = "=0.2.87" +wasm-bindgen = "=0.2.88" wasm-bindgen-futures = "0.4" serde-wasm-bindgen = "0.5" js-sys = "0.3" diff --git a/query-engine/query-engine-wasm/package-lock.json b/query-engine/query-engine-wasm/package-lock.json index 1c66eec352d2..c2d5a7a1162e 100644 --- a/query-engine/query-engine-wasm/package-lock.json +++ b/query-engine/query-engine-wasm/package-lock.json @@ -6,8 +6,8 @@ "": { "dependencies": { "@neondatabase/serverless": "0.6.0", - "@prisma/adapter-neon": "5.4.1", - "@prisma/driver-adapter-utils": "5.4.1" + "@prisma/adapter-neon": "5.5.2", + "@prisma/driver-adapter-utils": "5.5.2" } }, "node_modules/@neondatabase/serverless": { @@ -19,24 +19,41 @@ } }, "node_modules/@prisma/adapter-neon": { - "version": "5.4.1", - "resolved": 
"https://registry.npmjs.org/@prisma/adapter-neon/-/adapter-neon-5.4.1.tgz", - "integrity": "sha512-mIwLmwyAwDV9HXar9lSyM2uVm9H+X8noG4reKLnC3NjFsBxBfSUgW9vS8dPGqGW/rJWX3hg4pIffjEjmX4TDqg==", + "version": "5.5.2", + "resolved": "https://registry.npmjs.org/@prisma/adapter-neon/-/adapter-neon-5.5.2.tgz", + "integrity": "sha512-XcpJ/fgh/sP7mlBFkqjIzEcU/kWnNyiZf19MBP366HF7vXg2UQTbGxmbbeFiohXSJ/rwyu1Qmos7IrKK+QJOgg==", "dependencies": { - "@prisma/driver-adapter-utils": "5.4.1" + "@prisma/driver-adapter-utils": "5.5.2", + "postgres-array": "^3.0.2" }, "peerDependencies": { "@neondatabase/serverless": "^0.6.0" } }, + "node_modules/@prisma/adapter-neon/node_modules/postgres-array": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-3.0.2.tgz", + "integrity": "sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==", + "engines": { + "node": ">=12" + } + }, "node_modules/@prisma/driver-adapter-utils": { - "version": "5.4.1", - "resolved": "https://registry.npmjs.org/@prisma/driver-adapter-utils/-/driver-adapter-utils-5.4.1.tgz", - "integrity": "sha512-muYjkzf6qdxz4uGBi7nKyPaGRGLnSgiRautqAhZiMwbTOr9hMgyNI+aCJTCaKfYfNWjYCx2r5J6R1mJtPhzFhQ==", + "version": "5.5.2", + "resolved": "https://registry.npmjs.org/@prisma/driver-adapter-utils/-/driver-adapter-utils-5.5.2.tgz", + "integrity": "sha512-lRkxjboGcIl2VkJNomZQ9b6vc2qGFnVwjaR/o3cTPGmmSxETx71cYRYcG/NHKrhvKxI6oKNZ/xzyuzPpg1+kJQ==", "dependencies": { "debug": "^4.3.4" } }, + "node_modules/@types/node": { + "version": "20.8.10", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.8.10.tgz", + "integrity": "sha512-TlgT8JntpcbmKUFzjhsyhGfP2fsiz1Mv56im6enJ905xG1DAYesxJaeSbGqQmAw8OWPdhyJGhGSQGKRNJ45u9w==", + "dependencies": { + "undici-types": "~5.26.4" + } + }, "node_modules/@types/pg": { "version": "8.6.6", "resolved": "https://registry.npmjs.org/@types/pg/-/pg-8.6.6.tgz", @@ -131,6 +148,11 @@ "node": ">=0.10.0" } }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==" + }, "node_modules/xtend": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", diff --git a/query-engine/query-engine-wasm/package.json b/query-engine/query-engine-wasm/package.json index b4447ffcfb71..102db2ce14b5 100644 --- a/query-engine/query-engine-wasm/package.json +++ b/query-engine/query-engine-wasm/package.json @@ -3,7 +3,7 @@ "main": "./example.js", "dependencies": { "@neondatabase/serverless": "0.6.0", - "@prisma/adapter-neon": "5.4.1", - "@prisma/driver-adapter-utils": "5.4.1" + "@prisma/adapter-neon": "5.5.2", + "@prisma/driver-adapter-utils": "5.5.2" } } diff --git a/query-engine/query-engine-wasm/pnpm-lock.yaml b/query-engine/query-engine-wasm/pnpm-lock.yaml new file mode 100644 index 000000000000..89591aef9869 --- /dev/null +++ b/query-engine/query-engine-wasm/pnpm-lock.yaml @@ -0,0 +1,130 @@ +lockfileVersion: '6.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +dependencies: + '@neondatabase/serverless': + specifier: 0.6.0 + version: 0.6.0 + '@prisma/adapter-neon': + specifier: 5.6.0 + version: 5.6.0(@neondatabase/serverless@0.6.0) + '@prisma/driver-adapter-utils': + specifier: 5.6.0 + version: 5.6.0 + +packages: + + /@neondatabase/serverless@0.6.0: + resolution: {integrity: 
sha512-qXxBRYN0m2v8kVQBfMxbzNGn2xFAhTXFibzQlE++NfJ56Shz3m7+MyBBtXDlEH+3Wfa6lToDXf1MElocY4sJ3w==} + dependencies: + '@types/pg': 8.6.6 + dev: false + + /@prisma/adapter-neon@5.6.0(@neondatabase/serverless@0.6.0): + resolution: {integrity: sha512-IUkIE5NKyP2wCXMMAByM78fizfaJl7YeWDEajvyqQafXgRwmxl+2HhxsevvHly8jT4RlELdhjK6IP1eciGvXVA==} + peerDependencies: + '@neondatabase/serverless': ^0.6.0 + dependencies: + '@neondatabase/serverless': 0.6.0 + '@prisma/driver-adapter-utils': 5.6.0 + postgres-array: 3.0.2 + transitivePeerDependencies: + - supports-color + dev: false + + /@prisma/driver-adapter-utils@5.6.0: + resolution: {integrity: sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==} + dependencies: + debug: 4.3.4 + transitivePeerDependencies: + - supports-color + dev: false + + /@types/node@20.9.1: + resolution: {integrity: sha512-HhmzZh5LSJNS5O8jQKpJ/3ZcrrlG6L70hpGqMIAoM9YVD0YBRNWYsfwcXq8VnSjlNpCpgLzMXdiPo+dxcvSmiA==} + dependencies: + undici-types: 5.26.5 + dev: false + + /@types/pg@8.6.6: + resolution: {integrity: sha512-O2xNmXebtwVekJDD+02udOncjVcMZQuTEQEMpKJ0ZRf5E7/9JJX3izhKUcUifBkyKpljyUM6BTgy2trmviKlpw==} + dependencies: + '@types/node': 20.9.1 + pg-protocol: 1.6.0 + pg-types: 2.2.0 + dev: false + + /debug@4.3.4: + resolution: {integrity: sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + dependencies: + ms: 2.1.2 + dev: false + + /ms@2.1.2: + resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} + dev: false + + /pg-int8@1.0.1: + resolution: {integrity: sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==} + engines: {node: '>=4.0.0'} + dev: false + + /pg-protocol@1.6.0: + resolution: {integrity: sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==} + dev: false + + /pg-types@2.2.0: + resolution: {integrity: sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==} + engines: {node: '>=4'} + dependencies: + pg-int8: 1.0.1 + postgres-array: 2.0.0 + postgres-bytea: 1.0.0 + postgres-date: 1.0.7 + postgres-interval: 1.2.0 + dev: false + + /postgres-array@2.0.0: + resolution: {integrity: sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==} + engines: {node: '>=4'} + dev: false + + /postgres-array@3.0.2: + resolution: {integrity: sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==} + engines: {node: '>=12'} + dev: false + + /postgres-bytea@1.0.0: + resolution: {integrity: sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==} + engines: {node: '>=0.10.0'} + dev: false + + /postgres-date@1.0.7: + resolution: {integrity: sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==} + engines: {node: '>=0.10.0'} + dev: false + + /postgres-interval@1.2.0: + resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} + engines: {node: '>=0.10.0'} + dependencies: + xtend: 4.0.2 + dev: false + + /undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + dev: false + + /xtend@4.0.2: + 
resolution: {integrity: sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==} + engines: {node: '>=0.4'} + dev: false diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index be36e4f842dc..c70d8590d0ff 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ -20,7 +20,7 @@ enumflags2 = { version = "0.7"} psl.workspace = true graphql-parser = { git = "https://github.com/prisma/graphql-parser" } mongodb-connector = { path = "../connectors/mongodb-query-connector", optional = true, package = "mongodb-query-connector" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } serde.workspace = true serde_json.workspace = true diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index f5fb433b13ba..f04d742c448e 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" prisma-models = { path = "../prisma-models" } query-core = { path = "../core" } user-facing-errors = { path = "../../libs/user-facing-errors" } +quaint = { path = "../../quaint" } psl.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } itertools = "0.10" @@ -20,7 +21,6 @@ thiserror = "1" tracing = "0.1" url = "2" connection-string.workspace = true -quaint.workspace = true once_cell = "1.15" mongodb-query-connector = { path = "../connectors/mongodb-query-connector", optional = true } @@ -32,10 +32,11 @@ schema = { path = "../schema" } codspeed-criterion-compat = "1.1.0" [features] -default = ["mongodb", "sql"] +default = ["sql", "mongodb", "native"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] -driver-adapters = ["sql-query-connector"] +driver-adapters = ["sql-query-connector/driver-adapters"] +native = ["mongodb", "sql-query-connector", "quaint/native", "query-core/metrics"] [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/connector_mode.rs b/query-engine/request-handlers/src/connector_mode.rs index 00e0515a596e..be03fbab5820 100644 --- a/query-engine/request-handlers/src/connector_mode.rs +++ b/query-engine/request-handlers/src/connector_mode.rs @@ -1,6 +1,7 @@ #[derive(Copy, Clone, PartialEq, Eq)] pub enum ConnectorMode { /// Indicates that Rust drivers are used in Query Engine. + #[cfg(feature = "native")] Rust, /// Indicates that JS drivers are used in Query Engine. diff --git a/query-engine/request-handlers/src/load_executor.rs b/query-engine/request-handlers/src/load_executor.rs index 652ad3108f0d..26728605f92a 100644 --- a/query-engine/request-handlers/src/load_executor.rs +++ b/query-engine/request-handlers/src/load_executor.rs @@ -1,14 +1,12 @@ +#![allow(unused_imports)] + use psl::{builtin_connectors::*, Datasource, PreviewFeatures}; use query_core::{executor::InterpretingExecutor, Connector, QueryExecutor}; use sql_query_connector::*; use std::collections::HashMap; use std::env; -use tracing::trace; use url::Url; -#[cfg(feature = "mongodb")] -use mongodb_query_connector::MongoDb; - use super::ConnectorMode; /// Loads a query executor based on the parsed Prisma schema (datasource). 
@@ -27,6 +25,7 @@ pub async fn load(
             driver_adapter(source, url, features).await
         }
 
+        #[cfg(feature = "native")]
         ConnectorMode::Rust => {
             if let Ok(value) = env::var("PRISMA_DISABLE_QUAINT_EXECUTORS") {
                 let disable = value.to_uppercase();
@@ -36,14 +35,14 @@ pub async fn load(
             }
 
             match source.active_provider {
-                p if SQLITE.is_provider(p) => sqlite(source, url, features).await,
-                p if MYSQL.is_provider(p) => mysql(source, url, features).await,
-                p if POSTGRES.is_provider(p) => postgres(source, url, features).await,
-                p if MSSQL.is_provider(p) => mssql(source, url, features).await,
-                p if COCKROACH.is_provider(p) => postgres(source, url, features).await,
+                p if SQLITE.is_provider(p) => native::sqlite(source, url, features).await,
+                p if MYSQL.is_provider(p) => native::mysql(source, url, features).await,
+                p if POSTGRES.is_provider(p) => native::postgres(source, url, features).await,
+                p if MSSQL.is_provider(p) => native::mssql(source, url, features).await,
+                p if COCKROACH.is_provider(p) => native::postgres(source, url, features).await,
                 #[cfg(feature = "mongodb")]
-                p if MONGODB.is_provider(p) => mongodb(source, url, features).await,
+                p if MONGODB.is_provider(p) => native::mongodb(source, url, features).await,
 
                 x => Err(query_core::CoreError::ConfigurationError(format!(
                     "Unsupported connector type: {x}"
@@ -53,57 +52,88 @@ pub async fn load(
     }
 }
 
-async fn sqlite(
+#[cfg(feature = "driver-adapters")]
+async fn driver_adapter(
     source: &Datasource,
     url: &str,
     features: PreviewFeatures,
-) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
-    trace!("Loading SQLite query connector...");
-    let sqlite = Sqlite::from_source(source, url, features).await?;
-    trace!("Loaded SQLite query connector.");
-    Ok(executor_for(sqlite, false))
+) -> Result<Box<dyn QueryExecutor + Send + Sync>, query_core::CoreError> {
+    let js = Js::from_source(source, url, features).await?;
+    Ok(executor_for(js, false))
 }
 
-async fn postgres(
-    source: &Datasource,
-    url: &str,
-    features: PreviewFeatures,
-) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
-    trace!("Loading Postgres query connector...");
-    let database_str = url;
-    let psql = PostgreSql::from_source(source, url, features).await?;
-
-    let url = Url::parse(database_str)
-        .map_err(|err| query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")))?;
-    let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
-
-    let force_transactions = params
-        .get("pgbouncer")
-        .and_then(|flag| flag.parse().ok())
-        .unwrap_or(false);
-    trace!("Loaded Postgres query connector.");
-    Ok(executor_for(psql, force_transactions))
-}
+#[cfg(feature = "native")]
+mod native {
+    use super::*;
+    use tracing::trace;
+
+    pub(crate) async fn sqlite(
+        source: &Datasource,
+        url: &str,
+        features: PreviewFeatures,
+    ) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
+        trace!("Loading SQLite query connector...");
+        let sqlite = Sqlite::from_source(source, url, features).await?;
+        trace!("Loaded SQLite query connector.");
+        Ok(executor_for(sqlite, false))
+    }
 
-async fn mysql(
-    source: &Datasource,
-    url: &str,
-    features: PreviewFeatures,
-) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
-    let mysql = Mysql::from_source(source, url, features).await?;
-    trace!("Loaded MySQL query connector.");
-    Ok(executor_for(mysql, false))
-}
+    pub(crate) async fn postgres(
+        source: &Datasource,
+        url: &str,
+        features: PreviewFeatures,
+    ) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
+        trace!("Loading Postgres query connector...");
+        let database_str = url;
+        let psql = PostgreSql::from_source(source, url, features).await?;
+
+        let url = Url::parse(database_str).map_err(|err| {
+            query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}"))
+        })?;
+        let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
+
+        let force_transactions = params
+            .get("pgbouncer")
+            .and_then(|flag| flag.parse().ok())
+            .unwrap_or(false);
+        trace!("Loaded Postgres query connector.");
+        Ok(executor_for(psql, force_transactions))
+    }
 
-async fn mssql(
-    source: &Datasource,
-    url: &str,
-    features: PreviewFeatures,
-) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
-    trace!("Loading SQL Server query connector...");
-    let mssql = Mssql::from_source(source, url, features).await?;
-    trace!("Loaded SQL Server query connector.");
-    Ok(executor_for(mssql, false))
+    pub(crate) async fn mysql(
+        source: &Datasource,
+        url: &str,
+        features: PreviewFeatures,
+    ) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
+        let mysql = Mysql::from_source(source, url, features).await?;
+        trace!("Loaded MySQL query connector.");
+        Ok(executor_for(mysql, false))
+    }
+
+    pub(crate) async fn mssql(
+        source: &Datasource,
+        url: &str,
+        features: PreviewFeatures,
+    ) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
+        trace!("Loading SQL Server query connector...");
+        let mssql = Mssql::from_source(source, url, features).await?;
+        trace!("Loaded SQL Server query connector.");
+        Ok(executor_for(mssql, false))
+    }
+
+    #[cfg(feature = "mongodb")]
+    pub(crate) async fn mongodb(
+        source: &Datasource,
+        url: &str,
+        _features: PreviewFeatures,
+    ) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
+        use mongodb_query_connector::MongoDb;
+
+        trace!("Loading MongoDB query connector...");
+        let mongo = MongoDb::new(source, url).await?;
+        trace!("Loaded MongoDB query connector.");
+        Ok(executor_for(mongo, false))
+    }
 }
 
 fn executor_for<T>(connector: T, force_transactions: bool) -> Box<dyn QueryExecutor + Send + Sync>
@@ -112,27 +142,3 @@ where
 {
     Box::new(InterpretingExecutor::new(connector, force_transactions))
 }
-
-#[cfg(feature = "mongodb")]
-async fn mongodb(
-    source: &Datasource,
-    url: &str,
-    _features: PreviewFeatures,
-) -> query_core::Result<Box<dyn QueryExecutor + Send + Sync>> {
-    trace!("Loading MongoDB query connector...");
-    let mongo = MongoDb::new(source, url).await?;
-    trace!("Loaded MongoDB query connector.");
-    Ok(executor_for(mongo, false))
-}
-
-#[cfg(feature = "driver-adapters")]
-async fn driver_adapter(
-    source: &Datasource,
-    url: &str,
-    features: PreviewFeatures,
-) -> Result<Box<dyn QueryExecutor + Send + Sync>, query_core::CoreError> {
-    trace!("Loading driver adapter...");
-    let js = Js::from_source(source, url, features).await?;
-    trace!("Loaded driver adapter...");
-    Ok(executor_for(js, false))
-}
diff --git a/renovate.json b/renovate.json
index c25ec3daa6e8..4d8e7d2511d0 100644
--- a/renovate.json
+++ b/renovate.json
@@ -1,6 +1,7 @@
 {
+  "$schema": "https://docs.renovatebot.com/renovate-schema.json",
   "extends": [
-    "config:base"
+    "config:recommended"
   ],
   "cargo": {
     "enabled": false
@@ -15,25 +16,44 @@
   ],
   "rangeStrategy": "pin",
   "separateMinorPatch": true,
+  "configMigration": true,
   "packageRules": [
     {
-      "matchFiles": ["docker-compose.yml"],
-      "matchUpdateTypes": ["minor", "major"],
+      "matchFileNames": [
+        "docker-compose.yml"
+      ],
+      "matchUpdateTypes": [
+        "minor",
+        "major"
+      ],
      "enabled": false
     },
     {
       "groupName": "Weekly vitess docker image version update",
-      "packageNames": ["vitess/vttestserver"],
-      "schedule": ["before 7am on Wednesday"]
+      "matchPackageNames": [
+        "vitess/vttestserver"
+      ],
+      "schedule": [
+        "before 7am on Wednesday"
+      ]
     },
     {
-      "groupName": ["Prisma Driver Adapters"],
-      "matchPackageNames": ["@prisma/driver-adapter-utils"],
-      "matchPackagePrefixes": ["@prisma/adapter"],
-      "schedule": ["at any time"]
+      "groupName": "Prisma Driver 
Adapters", + "matchPackageNames": [ + "@prisma/driver-adapter-utils" + ], + "matchPackagePrefixes": [ + "@prisma/adapter" + ], + "schedule": [ + "at any time" + ] }, { - "packageNames": ["node", "pnpm"], + "matchPackageNames": [ + "node", + "pnpm" + ], "enabled": false } ] diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs index 18a0b8e94b3c..51a8f5ef54be 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs @@ -23,7 +23,7 @@ impl SqlSchemaCalculatorFlavour for MssqlFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut data = MssqlSchemaExt::default(); for model in context.datamodel.db.walk_models() { diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs index 40577d68a35d..656fe432a970 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs @@ -37,7 +37,7 @@ impl SqlSchemaCalculatorFlavour for PostgresFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut postgres_ext = PostgresSchemaExt::default(); let db = &context.datamodel.db; diff --git a/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs b/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs index d712b17f684e..537e2233e9ee 100644 --- a/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs +++ b/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs @@ -4,7 +4,7 @@ use indoc::indoc; use sql_introspection_tests::test_api::*; use test_macros::test_connector; -// Note: MySQL 5.6 ad 5.7 do not support check constraints, so this test is only run on MySQL 8.0. +// Note: MySQL 5.6 and 5.7 do not support check constraints, so this test is only run on MySQL 8.0. #[test_connector(tags(Mysql8), exclude(Vitess))] async fn check_constraints_stopgap(api: &mut TestApi) -> TestResult { let raw_sql = indoc! 
{r#" diff --git a/schema-engine/sql-migration-tests/tests/native_types/mysql.rs b/schema-engine/sql-migration-tests/tests/native_types/mysql.rs index d8cf62f5767c..b74f3dd6bac4 100644 --- a/schema-engine/sql-migration-tests/tests/native_types/mysql.rs +++ b/schema-engine/sql-migration-tests/tests/native_types/mysql.rs @@ -697,8 +697,8 @@ fn filter_from_types(api: &TestApi, cases: Cases) -> Cow<'static, [Case]> { return Cow::Owned( cases .iter() + .filter(|&(ty, _, _)| !type_is_unsupported_mariadb(ty)) .cloned() - .filter(|(ty, _, _)| !type_is_unsupported_mariadb(ty)) .collect(), ); } @@ -707,8 +707,8 @@ fn filter_from_types(api: &TestApi, cases: Cases) -> Cow<'static, [Case]> { return Cow::Owned( cases .iter() + .filter(|&(ty, _, _)| !type_is_unsupported_mysql_5_6(ty)) .cloned() - .filter(|(ty, _, _)| !type_is_unsupported_mysql_5_6(ty)) .collect(), ); }