diff --git a/acceptance/cases/models/binary_identifiers.rb b/acceptance/cases/models/binary_identifiers.rb new file mode 100644 index 00000000..94e859b2 --- /dev/null +++ b/acceptance/cases/models/binary_identifiers.rb @@ -0,0 +1,87 @@ +# Copyright 2025 Google LLC +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +# frozen_string_literal: true + +require "test_helper" +require "test_helpers/with_separate_database" +require_relative "../../models/user" +require_relative "../../models/binary_project" + +module Models + class DefaultValueTest < SpannerAdapter::TestCase + include TestHelpers::WithSeparateDatabase + + def setup + super + + connection.create_table :users, id: :binary do |t| + t.string :email, null: false + t.string :full_name, null: false + end + connection.create_table :binary_projects, id: :binary do |t| + t.string :name, null: false + t.string :description, null: false + t.binary :owner_id, null: false + t.foreign_key :users, column: :owner_id + end + end + + def test_includes_works + user = User.create!( + email: "test@example.com", + full_name: "Test User" + ) + 3.times do |i| + Project.create!( + name: "Project #{i}", + description: "Description #{i}", + owner: user + ) + end + + # First verify the association works without includes + projects = Project.all + assert_equal 3, projects.count + + # Compare the base64 content instead of the StringIO objects + first_project = projects.first + assert_equal to_base64(user.id), to_base64(first_project.owner_id) + + # Now verify includes is working + query_count = count_queries do + loaded_projects = Project.includes(:owner).to_a + loaded_projects.each do |project| + # Access the owner to ensure it's preloaded + assert_equal user.full_name, project.owner.full_name + end + end + + # Spanner should execute 2 queries: one for projects and one for users + assert_equal 2, query_count + end + + private + + def to_base64 buffer + buffer.rewind + value = buffer.read + Base64.strict_encode64 value.force_encoding("ASCII-8BIT") + end + + def count_queries(&block) + count = 0 + counter_fn = ->(name, started, finished, unique_id, payload) { + unless %w[CACHE SCHEMA].include?(payload[:name]) + count += 1 + end + } + + ActiveSupport::Notifications.subscribed(counter_fn, "sql.active_record", &block) + count + end + end +end diff --git a/acceptance/models/binary_project.rb b/acceptance/models/binary_project.rb new file mode 100644 index 00000000..8e817b37 --- /dev/null +++ b/acceptance/models/binary_project.rb @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +# frozen_string_literal: true + +class BinaryProject < ActiveRecord::Base + belongs_to :owner, class_name: 'User' + + before_create :set_uuid + private + + def set_uuid + self.id ||= StringIO.new(SecureRandom.random_bytes(16)) + end +end diff --git a/acceptance/models/user.rb b/acceptance/models/user.rb new file mode 100644 index 00000000..f5fbd330 --- /dev/null +++ b/acceptance/models/user.rb @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +# frozen_string_literal: true + +class User < ActiveRecord::Base + has_many :binary_projects, foreign_key: :owner_id + + before_create :set_uuid + private + + def set_uuid + self.id ||= StringIO.new(SecureRandom.random_bytes(16)) + end +end diff --git a/lib/active_record/type/spanner/bytes.rb b/lib/active_record/type/spanner/bytes.rb index 1cee6452..2a07446d 100644 --- a/lib/active_record/type/spanner/bytes.rb +++ b/lib/active_record/type/spanner/bytes.rb @@ -10,6 +10,16 @@ module ActiveRecord module Type module Spanner class Bytes < ActiveRecord::Type::Binary + def deserialize value + # Set this environment variable to disable de-serializing BYTES + # to a StringIO instance. + return super if ENV["SPANNER_BYTES_DESERIALIZE_DISABLED"] + + return super value if value.nil? + return StringIO.new Base64.strict_decode64(value) if value.is_a? ::String + value + end + def serialize value return super value if value.nil? diff --git a/test/activerecord_spanner_mock_server/base_spanner_mock_server_test.rb b/test/activerecord_spanner_mock_server/base_spanner_mock_server_test.rb index 35325e07..3235fcf4 100644 --- a/test/activerecord_spanner_mock_server/base_spanner_mock_server_test.rb +++ b/test/activerecord_spanner_mock_server/base_spanner_mock_server_test.rb @@ -15,6 +15,8 @@ require_relative "models/table_with_commit_timestamp" require_relative "models/table_with_sequence" require_relative "models/versioned_singer" +require_relative "models/user" +require_relative "models/binary_project" require "securerandom" @@ -57,6 +59,15 @@ def setup MockServerTests::register_table_with_sequence_columns_result @mock MockServerTests::register_table_with_sequence_primary_key_columns_result @mock MockServerTests::register_table_with_sequence_primary_and_parent_key_columns_result @mock + + MockServerTests::register_users_columns_result @mock + MockServerTests::register_users_primary_key_columns_result @mock + MockServerTests::register_users_primary_and_parent_key_columns_result @mock + + MockServerTests::register_binary_projects_columns_result @mock + MockServerTests::register_binary_projects_primary_key_columns_result @mock + MockServerTests::register_binary_projects_primary_and_parent_key_columns_result @mock + # Connect ActiveRecord to the mock server ActiveRecord::Base.establish_connection( adapter: "spanner", diff --git a/test/activerecord_spanner_mock_server/model_helper.rb b/test/activerecord_spanner_mock_server/model_helper.rb index bcbd28bb..28617f14 100644 --- a/test/activerecord_spanner_mock_server/model_helper.rb +++ b/test/activerecord_spanner_mock_server/model_helper.rb @@ -192,6 +192,24 @@ def self.register_select_tables_result spanner_mock_server Value.new(null_value: "NULL_VALUE"), ) result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: ""), + Value.new(string_value: ""), + Value.new(string_value: "users"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + ) + result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: ""), + Value.new(string_value: ""), + Value.new(string_value: "binary_projects"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + ) + result_set.rows.push row spanner_mock_server.put_statement_result sql, StatementResult.new(result_set) end @@ -727,6 +745,136 @@ def self.register_table_with_sequence_primary_and_parent_key_columns_result span register_key_columns_result spanner_mock_server, sql end + def self.register_users_columns_result spanner_mock_server + register_commit_timestamps_result spanner_mock_server, "users" + + sql = table_columns_sql "users" + + column_name = Field.new name: "COLUMN_NAME", type: Type.new(code: TypeCode::STRING) + spanner_type = Field.new name: "SPANNER_TYPE", type: Type.new(code: TypeCode::STRING) + is_nullable = Field.new name: "IS_NULLABLE", type: Type.new(code: TypeCode::STRING) + generation_expression = Field.new name: "GENERATION_EXPRESSION", type: Type.new(code: TypeCode::STRING) + column_default = Field.new name: "COLUMN_DEFAULT", type: Type.new(code: TypeCode::STRING) + ordinal_position = Field.new name: "ORDINAL_POSITION", type: Type.new(code: TypeCode::INT64) + + metadata = ResultSetMetadata.new row_type: StructType.new + metadata.row_type.fields.push column_name, spanner_type, is_nullable, generation_expression, column_default, ordinal_position + result_set = ResultSet.new metadata: metadata + + row = ListValue.new + row.values.push( + Value.new(string_value: "id"), + Value.new(string_value: "BYTES(16)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "1") + ) + result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: "email"), + Value.new(string_value: "STRING(MAX)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "2") + ) + result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: "full_name"), + Value.new(string_value: "STRING(MAX)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "3") + ) + result_set.rows.push row + + spanner_mock_server.put_statement_result sql, StatementResult.new(result_set) + end + + def self.register_users_primary_key_columns_result spanner_mock_server + sql = primary_key_columns_sql "users", parent_keys: false + register_key_columns_result spanner_mock_server, sql + end + + def self.register_users_primary_and_parent_key_columns_result spanner_mock_server + sql = primary_key_columns_sql "users", parent_keys: true + register_key_columns_result spanner_mock_server, sql + end + + def self.register_binary_projects_columns_result spanner_mock_server + register_commit_timestamps_result spanner_mock_server, "binary_projects" + + sql = table_columns_sql "binary_projects" + + column_name = Field.new name: "COLUMN_NAME", type: Type.new(code: TypeCode::STRING) + spanner_type = Field.new name: "SPANNER_TYPE", type: Type.new(code: TypeCode::STRING) + is_nullable = Field.new name: "IS_NULLABLE", type: Type.new(code: TypeCode::STRING) + generation_expression = Field.new name: "GENERATION_EXPRESSION", type: Type.new(code: TypeCode::STRING) + column_default = Field.new name: "COLUMN_DEFAULT", type: Type.new(code: TypeCode::STRING) + ordinal_position = Field.new name: "ORDINAL_POSITION", type: Type.new(code: TypeCode::INT64) + + metadata = ResultSetMetadata.new row_type: StructType.new + metadata.row_type.fields.push column_name, spanner_type, is_nullable, generation_expression, column_default, ordinal_position + result_set = ResultSet.new metadata: metadata + + row = ListValue.new + row.values.push( + Value.new(string_value: "id"), + Value.new(string_value: "BYTES(16)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "1") + ) + result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: "name"), + Value.new(string_value: "STRING(MAX)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "2") + ) + result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: "description"), + Value.new(string_value: "STRING(MAX)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "3") + ) + result_set.rows.push row + row = ListValue.new + row.values.push( + Value.new(string_value: "owner_id"), + Value.new(string_value: "BYTES(16)"), + Value.new(string_value: "NO"), + Value.new(null_value: "NULL_VALUE"), + Value.new(null_value: "NULL_VALUE"), + Value.new(string_value: "4") + ) + result_set.rows.push row + + spanner_mock_server.put_statement_result sql, StatementResult.new(result_set) + end + + def self.register_binary_projects_primary_key_columns_result spanner_mock_server + sql = primary_key_columns_sql "binary_projects", parent_keys: false + register_key_columns_result spanner_mock_server, sql + end + + def self.register_binary_projects_primary_and_parent_key_columns_result spanner_mock_server + sql = primary_key_columns_sql "binary_projects", parent_keys: true + register_key_columns_result spanner_mock_server, sql + end + def self.register_empty_select_indexes_result spanner_mock_server, sql col_index_name = Field.new name: "INDEX_NAME", type: Type.new(code: TypeCode::STRING) col_index_type = Field.new name: "INDEX_TYPE", type: Type.new(code: TypeCode::STRING) diff --git a/test/activerecord_spanner_mock_server/models/binary_project.rb b/test/activerecord_spanner_mock_server/models/binary_project.rb new file mode 100644 index 00000000..8e817b37 --- /dev/null +++ b/test/activerecord_spanner_mock_server/models/binary_project.rb @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +# frozen_string_literal: true + +class BinaryProject < ActiveRecord::Base + belongs_to :owner, class_name: 'User' + + before_create :set_uuid + private + + def set_uuid + self.id ||= StringIO.new(SecureRandom.random_bytes(16)) + end +end diff --git a/test/activerecord_spanner_mock_server/models/user.rb b/test/activerecord_spanner_mock_server/models/user.rb new file mode 100644 index 00000000..f5fbd330 --- /dev/null +++ b/test/activerecord_spanner_mock_server/models/user.rb @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +# frozen_string_literal: true + +class User < ActiveRecord::Base + has_many :binary_projects, foreign_key: :owner_id + + before_create :set_uuid + private + + def set_uuid + self.id ||= StringIO.new(SecureRandom.random_bytes(16)) + end +end diff --git a/test/activerecord_spanner_mock_server/spanner_active_record_with_mock_server_test.rb b/test/activerecord_spanner_mock_server/spanner_active_record_with_mock_server_test.rb index a832e585..2c03f099 100644 --- a/test/activerecord_spanner_mock_server/spanner_active_record_with_mock_server_test.rb +++ b/test/activerecord_spanner_mock_server/spanner_active_record_with_mock_server_test.rb @@ -1259,6 +1259,86 @@ def test_upsert_all_dml assert_equal 1, execute_requests.length end + def test_binary_id + user = User.create!( + email: "test@example.com", + full_name: "Test User" + ) + # Verify that an ID was generated for the User. + assert user.id + assert user.id.is_a?(StringIO) + + commit_requests = @mock.requests.select { |req| req.is_a?(CommitRequest) } + assert_equal 1, commit_requests.length + assert_equal 1, commit_requests[0].mutations.length + mutation = commit_requests[0].mutations[0] + assert_equal :insert, mutation.operation + assert_equal "users", mutation.insert.table + + assert_equal 1, mutation.insert.values.length + assert_equal 3, mutation.insert.values[0].length + assert_equal to_base64(user.id), mutation.insert.values[0][0] + assert_equal "test@example.com", mutation.insert.values[0][1] + assert_equal "Test User", mutation.insert.values[0][2] + end + + def test_binary_id_association + user = User.create!( + email: "test@example.com", + full_name: "Test User" + ) + project1 = BinaryProject.create!( + name: "Test Project 1", + description: "Test Description 1", + owner: user + ) + project2 = BinaryProject.create!( + name: "Test Project 2", + description: "Test Description 2", + owner: user + ) + # Verify that an ID was generated for the records. + assert user.id + assert project1.id + assert project2.id + + commit_requests = @mock.requests.select { |req| req.is_a?(CommitRequest) } + assert_equal 3, commit_requests.length + assert_equal 1, commit_requests[1].mutations.length + mutation = commit_requests[1].mutations[0] + assert_equal :insert, mutation.operation + assert_equal "binary_projects", mutation.insert.table + + assert_equal 1, mutation.insert.values.length + assert_equal 4, mutation.insert.values[0].length + assert_equal to_base64(project1.id), mutation.insert.values[0][0] + assert_equal "Test Project 1", mutation.insert.values[0][1] + assert_equal "Test Description 1", mutation.insert.values[0][2] + assert_equal to_base64(user.id), mutation.insert.values[0][3] + end + + def test_skip_binary_deserialization + ENV["SPANNER_BYTES_DESERIALIZE_DISABLED"] = "true" + begin + user = User.create!( + email: "test@example.com", + full_name: "Test User" + ) + # Verify that the ID is returned as a Base64 string. + assert user.id + assert user.id.is_a?(String) + assert_equal user.id, Base64.strict_encode64(Base64.strict_decode64(user.id)) + ensure + ENV.delete("SPANNER_BYTES_DESERIALIZE_DISABLED") + end + end + + def to_base64 buffer + buffer.rewind + value = buffer.read + Base64.strict_encode64 value.force_encoding("ASCII-8BIT") + end + private def verify_insert_upsert_all operation