From 08c47996fb7adc64246b2db59fdfd1588549d1db Mon Sep 17 00:00:00 2001 From: Hong Shin Date: Tue, 19 Mar 2024 13:26:42 -0700 Subject: [PATCH] In this CL, we update generated_message_tctable_lite to support both length-prefixed and delimited when it comes to parsing submessages. In the past, the parser would do error checking based on the schema. If a .proto file declared a submessage to be length-prefixed but an SGROUP was encountered on the wire, it would invoke the fallback. This behavior is now updated: the parser will solely look at the tag when dealing with submessages. This increases the flexibility of the parser and will make our Editions rollout smoother. PiperOrigin-RevId: 617275654 --- .../generated_message_tctable_lite.cc | 45 ++-- src/google/protobuf/lite_unittest.cc | 20 ++ src/google/protobuf/wire_format.cc | 213 ++++++++++++++++-- 3 files changed, 226 insertions(+), 52 deletions(-) diff --git a/src/google/protobuf/generated_message_tctable_lite.cc b/src/google/protobuf/generated_message_tctable_lite.cc index f8b0d87942834..73e70f985f1ba 100644 --- a/src/google/protobuf/generated_message_tctable_lite.cc +++ b/src/google/protobuf/generated_message_tctable_lite.cc @@ -2297,44 +2297,35 @@ PROTOBUF_NOINLINE const char* TcParser::MpMessage(PROTOBUF_TC_PARAM_DECL) { const uint16_t type_card = entry.type_card; const uint16_t card = type_card & field_layout::kFcMask; + const uint32_t decoded_tag = data.tag(); + const uint32_t decoded_wiretype = decoded_tag & 7; + // Check for repeated parsing: if (card == field_layout::kFcRepeated) { - const uint16_t rep = type_card & field_layout::kRepMask; - switch (rep) { - case field_layout::kRepMessage: + switch (decoded_wiretype) { + case WireFormatLite::WIRETYPE_LENGTH_DELIMITED: PROTOBUF_MUSTTAIL return MpRepeatedMessageOrGroup<false>( PROTOBUF_TC_PARAM_PASS); - case field_layout::kRepGroup: + case WireFormatLite::WIRETYPE_START_GROUP: PROTOBUF_MUSTTAIL return MpRepeatedMessageOrGroup<true>( PROTOBUF_TC_PARAM_PASS); default:
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); } } - - const uint32_t decoded_tag = data.tag(); - const uint32_t decoded_wiretype = decoded_tag & 7; const uint16_t rep = type_card & field_layout::kRepMask; - const bool is_group = rep == field_layout::kRepGroup; - - // Validate wiretype: - switch (rep) { - case field_layout::kRepMessage: - if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { - goto fallback; - } - break; - case field_layout::kRepGroup: - if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) { - goto fallback; - } - break; - default: { - fallback: - PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); - } + // note that we solely rely on wiretype for parsing messages (schema ignored) + const bool is_group = + decoded_wiretype == WireFormatLite::WIRETYPE_START_GROUP; + + // If we don't see a wiretype of START_GROUP or DELIM even though we're in the + // entry point for MpMessage, something is wrong. Bail out! + if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP && + decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); } + const bool is_oneof = card == field_layout::kFcOneof; bool need_init = false; if (card == field_layout::kFcOptional) { @@ -2386,14 +2377,10 @@ const char* TcParser::MpRepeatedMessageOrGroup(PROTOBUF_TC_PARAM_DECL) { // Validate wiretype: if (!is_group) { - ABSL_DCHECK_EQ(type_card & field_layout::kRepMask, - static_cast<uint16_t>(field_layout::kRepMessage)); if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); } } else { - ABSL_DCHECK_EQ(type_card & field_layout::kRepMask, - static_cast<uint16_t>(field_layout::kRepGroup)); if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) { PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); } diff --git a/src/google/protobuf/lite_unittest.cc
b/src/google/protobuf/lite_unittest.cc index 978307528a795..c405a1042c610 100644 --- a/src/google/protobuf/lite_unittest.cc +++ b/src/google/protobuf/lite_unittest.cc @@ -1136,6 +1136,26 @@ TYPED_TEST(LiteTest, CorrectEnding) { } } +TYPED_TEST(LiteTest, MessageEncoding) { + protobuf_unittest::TestAllTypesLite msg; + { + // Make sure that we support length-prefixed encoding for submsgs + static const char kWireFormat[] = "\n\002\010\003"; // 1: {1: 3} + io::CodedInputStream cis(reinterpret_cast<const uint8_t*>(kWireFormat), 4); + EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis)); + EXPECT_TRUE(cis.ConsumedEntireMessage()); + EXPECT_TRUE(cis.LastTagWas(0)); + } + { + // Make sure that we support delimited encoding for submsgs + static const char kWireFormat[] = "\013\010\003\014"; // 1: !{1: 3} + io::CodedInputStream cis(reinterpret_cast<const uint8_t*>(kWireFormat), 4); + EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis)); + EXPECT_TRUE(cis.ConsumedEntireMessage()); + EXPECT_TRUE(cis.LastTagWas(0)); + } +} + TYPED_TEST(LiteTest, DebugString) { protobuf_unittest::TestAllTypesLite message1, message2; EXPECT_TRUE(absl::StartsWith(message1.DebugString(), "MessageLite at 0x")); diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index 72e13c4ba7f75..5c20d3eaedfc4 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -30,6 +30,7 @@ #include "google/protobuf/message_lite.h" #include "google/protobuf/parse_context.h" #include "google/protobuf/unknown_field_set.h" +#include "google/protobuf/wire_format_lite.h" // Must be included last.
@@ -578,31 +579,59 @@ bool WireFormat::ParseAndMergeField( case FieldDescriptor::TYPE_GROUP: { Message* sub_message; - if (field->is_repeated()) { - sub_message = message_reflection->AddMessage( - message, field, input->GetExtensionFactory()); + if (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP) { + if (field->is_repeated()) { + sub_message = message_reflection->AddMessage( + message, field, input->GetExtensionFactory()); + } else { + sub_message = message_reflection->MutableMessage( + message, field, input->GetExtensionFactory()); + } + + if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag), + input, sub_message)) + return false; } else { - sub_message = message_reflection->MutableMessage( - message, field, input->GetExtensionFactory()); - } + if (field->is_repeated()) { + sub_message = message_reflection->AddMessage( + message, field, input->GetExtensionFactory()); + } else { + sub_message = message_reflection->MutableMessage( + message, field, input->GetExtensionFactory()); + } - if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag), - input, sub_message)) - return false; + if (!WireFormatLite::ReadMessage(input, sub_message)) return false; + } break; } case FieldDescriptor::TYPE_MESSAGE: { Message* sub_message; - if (field->is_repeated()) { - sub_message = message_reflection->AddMessage( - message, field, input->GetExtensionFactory()); + if (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP) { + if (field->is_repeated()) { + sub_message = message_reflection->AddMessage( + message, field, input->GetExtensionFactory()); + } else { + sub_message = message_reflection->MutableMessage( + message, field, input->GetExtensionFactory()); + } + + if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag), + input, sub_message)) + return false; } else { - sub_message = message_reflection->MutableMessage( - message, field, input->GetExtensionFactory()); - } + if 
(field->is_repeated()) { + sub_message = message_reflection->AddMessage( + message, field, input->GetExtensionFactory()); + } else { + sub_message = message_reflection->MutableMessage( + message, field, input->GetExtensionFactory()); + } - if (!WireFormatLite::ReadMessage(input, sub_message)) return false; + if (!WireFormatLite::ReadMessage(input, sub_message)) return false; + } break; } } @@ -871,6 +900,104 @@ const char* WireFormat::_InternalParseAndMergeField( ABSL_LOG(FATAL) << "Can't reach"; return nullptr; } + } else if (WireFormatLite::GetTagWireType(tag) != + WireTypeForFieldType(field->type()) && + (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP || + WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_LENGTH_DELIMITED)) { + switch (field->type()) { + case FieldDescriptor::TYPE_GROUP: { + Message* sub_message; + + if (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP) { + if (field->is_repeated()) { + sub_message = + reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + + return ctx->ParseGroup(sub_message, ptr, tag); + } else { + if (field->is_repeated()) { + sub_message = + reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + ptr = ctx->ParseMessage(sub_message, ptr); + + // For map entries, if the value is an unknown enum we have to push + // it into the unknown field set and remove it from the list. 
+ if (ptr != nullptr && field->is_map()) { + auto* value_field = field->message_type()->map_value(); + auto* enum_type = value_field->enum_type(); + if (enum_type != nullptr && + !internal::cpp::HasPreservingUnknownEnumSemantics( + value_field) && + enum_type->FindValueByNumber( + sub_message->GetReflection()->GetEnumValue( + *sub_message, value_field)) == nullptr) { + reflection->MutableUnknownFields(msg)->AddLengthDelimited( + field->number(), sub_message->SerializeAsString()); + reflection->RemoveLast(msg, field); + } + } + + return ptr; + } + } + + case FieldDescriptor::TYPE_MESSAGE: { + Message* sub_message; + if (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP) { + if (field->is_repeated()) { + sub_message = + reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + + return ctx->ParseGroup(sub_message, ptr, tag); + } else if (field->is_repeated()) { + sub_message = + reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + ptr = ctx->ParseMessage(sub_message, ptr); + + // For map entries, if the value is an unknown enum we have to push it + // into the unknown field set and remove it from the list. 
+ if (ptr != nullptr && field->is_map()) { + auto* value_field = field->message_type()->map_value(); + auto* enum_type = value_field->enum_type(); + if (enum_type != nullptr && + !internal::cpp::HasPreservingUnknownEnumSemantics( + value_field) && + enum_type->FindValueByNumber( + sub_message->GetReflection()->GetEnumValue( + *sub_message, value_field)) == nullptr) { + reflection->MutableUnknownFields(msg)->AddLengthDelimited( + field->number(), sub_message->SerializeAsString()); + reflection->RemoveLast(msg, field); + } + } + + return ptr; + } + default: { + return internal::UnknownFieldParse( + tag, reflection->MutableUnknownFields(msg), ptr, ctx); + } + } } else { // mismatched wiretype; return internal::UnknownFieldParse( @@ -997,19 +1124,59 @@ const char* WireFormat::_InternalParseAndMergeField( case FieldDescriptor::TYPE_GROUP: { Message* sub_message; - if (field->is_repeated()) { - sub_message = reflection->AddMessage(msg, field, ctx->data().factory); + + if (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP) { + if (field->is_repeated()) { + sub_message = reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + + return ctx->ParseGroup(sub_message, ptr, tag); } else { - sub_message = - reflection->MutableMessage(msg, field, ctx->data().factory); - } + if (field->is_repeated()) { + sub_message = reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + ptr = ctx->ParseMessage(sub_message, ptr); + + // For map entries, if the value is an unknown enum we have to push it + // into the unknown field set and remove it from the list. 
+ if (ptr != nullptr && field->is_map()) { + auto* value_field = field->message_type()->map_value(); + auto* enum_type = value_field->enum_type(); + if (enum_type != nullptr && + !internal::cpp::HasPreservingUnknownEnumSemantics(value_field) && + enum_type->FindValueByNumber( + sub_message->GetReflection()->GetEnumValue( + *sub_message, value_field)) == nullptr) { + reflection->MutableUnknownFields(msg)->AddLengthDelimited( + field->number(), sub_message->SerializeAsString()); + reflection->RemoveLast(msg, field); + } + } - return ctx->ParseGroup(sub_message, ptr, tag); + return ptr; + } } case FieldDescriptor::TYPE_MESSAGE: { Message* sub_message; - if (field->is_repeated()) { + if (WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_START_GROUP) { + if (field->is_repeated()) { + sub_message = reflection->AddMessage(msg, field, ctx->data().factory); + } else { + sub_message = + reflection->MutableMessage(msg, field, ctx->data().factory); + } + + return ctx->ParseGroup(sub_message, ptr, tag); + } else if (field->is_repeated()) { sub_message = reflection->AddMessage(msg, field, ctx->data().factory); } else { sub_message =