Skip to content

Commit

Permalink
In this CL, we update generated_message_tctable_lite to support both …
Browse files Browse the repository at this point in the history
…length-prefixed and delimited when it comes to parsing submessages.

In the past, the parser would do error checking based on the schema. If a .proto file declared a submessage to be length-prefixed but an SGROUP was encountered on the wire, it would invoke the fallback. This behavior is now updated: the parser will solely look at the tag when dealing with submessages.

This increases the flexibility of the parser and will make our Editions rollout smoother.

PiperOrigin-RevId: 617275654
  • Loading branch information
honglooker authored and copybara-github committed Mar 28, 2024
1 parent 1f6580d commit 51d95e0
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 52 deletions.
45 changes: 16 additions & 29 deletions src/google/protobuf/generated_message_tctable_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2297,44 +2297,35 @@ PROTOBUF_NOINLINE const char* TcParser::MpMessage(PROTOBUF_TC_PARAM_DECL) {
const uint16_t type_card = entry.type_card;
const uint16_t card = type_card & field_layout::kFcMask;

const uint32_t decoded_tag = data.tag();
const uint32_t decoded_wiretype = decoded_tag & 7;

// Check for repeated parsing:
if (card == field_layout::kFcRepeated) {
const uint16_t rep = type_card & field_layout::kRepMask;
switch (rep) {
case field_layout::kRepMessage:
switch (decoded_wiretype) {
case WireFormatLite::WIRETYPE_LENGTH_DELIMITED:
PROTOBUF_MUSTTAIL return MpRepeatedMessageOrGroup<is_split, false>(
PROTOBUF_TC_PARAM_PASS);
case field_layout::kRepGroup:
case WireFormatLite::WIRETYPE_START_GROUP:
PROTOBUF_MUSTTAIL return MpRepeatedMessageOrGroup<is_split, true>(
PROTOBUF_TC_PARAM_PASS);
default:
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
}

const uint32_t decoded_tag = data.tag();
const uint32_t decoded_wiretype = decoded_tag & 7;
const uint16_t rep = type_card & field_layout::kRepMask;
const bool is_group = rep == field_layout::kRepGroup;

// Validate wiretype:
switch (rep) {
case field_layout::kRepMessage:
if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
goto fallback;
}
break;
case field_layout::kRepGroup:
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) {
goto fallback;
}
break;
default: {
fallback:
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
// note that we solely rely on wiretype for parsing messages (schema ignored)
const bool is_group =
decoded_wiretype == WireFormatLite::WIRETYPE_START_GROUP;

// If we don't see a wiretype of START_GROUP or DELIM even though we're in the
// entry point for MpMessage, something is wrong. Bail out!
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP &&
decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}


const bool is_oneof = card == field_layout::kFcOneof;
bool need_init = false;
if (card == field_layout::kFcOptional) {
Expand Down Expand Up @@ -2386,14 +2377,10 @@ const char* TcParser::MpRepeatedMessageOrGroup(PROTOBUF_TC_PARAM_DECL) {

// Validate wiretype:
if (!is_group) {
ABSL_DCHECK_EQ(type_card & field_layout::kRepMask,
static_cast<uint16_t>(field_layout::kRepMessage));
if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
} else {
ABSL_DCHECK_EQ(type_card & field_layout::kRepMask,
static_cast<uint16_t>(field_layout::kRepGroup));
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
Expand Down
20 changes: 20 additions & 0 deletions src/google/protobuf/lite_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,26 @@ TYPED_TEST(LiteTest, CorrectEnding) {
}
}

TYPED_TEST(LiteTest, MessageEncoding) {
protobuf_unittest::TestAllTypesLite msg;
{
// Make sure that we support length-prefixed encoding for submsgs
static const char kWireFormat[] = "\n\002\010\003"; // 1: {1: 3}
io::CodedInputStream cis(reinterpret_cast<const uint8_t*>(kWireFormat), 4);
EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis));
EXPECT_TRUE(cis.ConsumedEntireMessage());
EXPECT_TRUE(cis.LastTagWas(0));
}
{
// Make sure that we support delimited encoding for submsgs
static const char kWireFormat[] = "\013\010\003\014"; // 1: !{1: 3}
io::CodedInputStream cis(reinterpret_cast<const uint8_t*>(kWireFormat), 4);
EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis));
EXPECT_TRUE(cis.ConsumedEntireMessage());
EXPECT_TRUE(cis.LastTagWas(0));
}
}

TYPED_TEST(LiteTest, DebugString) {
protobuf_unittest::TestAllTypesLite message1, message2;
EXPECT_TRUE(absl::StartsWith(message1.DebugString(), "MessageLite at 0x"));
Expand Down
114 changes: 91 additions & 23 deletions src/google/protobuf/wire_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,31 +578,59 @@ bool WireFormat::ParseAndMergeField(

case FieldDescriptor::TYPE_GROUP: {
Message* sub_message;
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
input, sub_message))
return false;
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
input, sub_message))
return false;
if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
}
break;
}

case FieldDescriptor::TYPE_MESSAGE: {
Message* sub_message;
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
input, sub_message))
return false;
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
}
break;
}
}
Expand Down Expand Up @@ -997,19 +1025,59 @@ const char* WireFormat::_InternalParseAndMergeField(

case FieldDescriptor::TYPE_GROUP: {
Message* sub_message;
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);

if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}

return ctx->ParseGroup(sub_message, ptr, tag);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}
ptr = ctx->ParseMessage(sub_message, ptr);

// For map entries, if the value is an unknown enum we have to push it
// into the unknown field set and remove it from the list.
if (ptr != nullptr && field->is_map()) {
auto* value_field = field->message_type()->map_value();
auto* enum_type = value_field->enum_type();
if (enum_type != nullptr &&
!internal::cpp::HasPreservingUnknownEnumSemantics(value_field) &&
enum_type->FindValueByNumber(
sub_message->GetReflection()->GetEnumValue(
*sub_message, value_field)) == nullptr) {
reflection->MutableUnknownFields(msg)->AddLengthDelimited(
field->number(), sub_message->SerializeAsString());
reflection->RemoveLast(msg, field);
}
}

return ctx->ParseGroup(sub_message, ptr, tag);
return ptr;
}
}

case FieldDescriptor::TYPE_MESSAGE: {
Message* sub_message;
if (field->is_repeated()) {
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}

return ctx->ParseGroup(sub_message, ptr, tag);
} else if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
Expand Down

0 comments on commit 51d95e0

Please sign in to comment.