Skip to content

Commit

Permalink
Set up editions codegen tests for python
Browse files Browse the repository at this point in the history
These tests aren't super useful for python because of how little codegen we actually do, but the pyi ones specifically will guard against major editions regressions.

PiperOrigin-RevId: 577495652
  • Loading branch information
mkruskal-google authored and copybara-github committed Oct 28, 2023
1 parent b2efcdc commit 57bb1e5
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 92 deletions.
157 changes: 65 additions & 92 deletions src/google/protobuf/compiler/python/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "google/protobuf/compiler/python/generator.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
Expand All @@ -35,10 +37,12 @@
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/strings/substitute.h"
#include "google/protobuf/compiler/code_generator.h"
#include "google/protobuf/compiler/python/helpers.h"
#include "google/protobuf/compiler/python/pyi_generator.h"
#include "google/protobuf/compiler/retention.h"
Expand All @@ -49,6 +53,7 @@
#include "google/protobuf/io/printer.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace google {
namespace protobuf {
Expand Down Expand Up @@ -152,21 +157,6 @@ std::string StringifyDefaultValue(const FieldDescriptor& field) {
return "";
}

// Converts a file's syntax enum into its textual name. Only proto2 and proto3
// are recognized; any other value (SYNTAX_UNKNOWN included) is reported via
// ABSL_LOG(FATAL).
std::string StringifySyntax(FileDescriptorLegacy::Syntax syntax) {
  if (syntax == FileDescriptorLegacy::Syntax::SYNTAX_PROTO2) {
    return "proto2";
  }
  if (syntax == FileDescriptorLegacy::Syntax::SYNTAX_PROTO3) {
    return "proto3";
  }
  ABSL_LOG(FATAL) << "Unsupported syntax; this generator only supports proto2 "
                     "and proto3 syntax.";
  return "";
}

} // namespace

Generator::Generator() : file_(nullptr) {}
Expand Down Expand Up @@ -194,6 +184,8 @@ GeneratorOptions Generator::ParseParameter(absl::string_view parameter,
options.generate_pyi = true;
} else if (option.first == "annotate_code") {
options.annotate_pyi = true;
} else if (option.first == "experimental_strip_nonfunctional_codegen") {
options.strip_nonfunctional_codegen = true;
} else {
*error = absl::StrCat("Unknown generator option: ", option.first);
}
Expand All @@ -211,8 +203,15 @@ bool Generator::Generate(const FileDescriptor* file,
// Generate pyi typing information
if (options.generate_pyi) {
python::PyiGenerator pyi_generator;
std::string pyi_options = options.annotate_pyi ? "annotate_code" : "";
if (!pyi_generator.Generate(file, pyi_options, context, error)) {
std::vector<std::string> pyi_options;
if (options.annotate_pyi) {
pyi_options.push_back("annotate_code");
}
if (options.strip_nonfunctional_codegen) {
pyi_options.push_back("experimental_strip_nonfunctional_codegen");
}
if (!pyi_generator.Generate(file, absl::StrJoin(pyi_options, ","), context,
error)) {
return false;
}
}
Expand Down Expand Up @@ -423,7 +422,8 @@ void Generator::PrintFileDescriptor() const {
m["descriptor_name"] = kDescriptorKey;
m["name"] = file_->name();
m["package"] = file_->package();
m["syntax"] = StringifySyntax(FileDescriptorLegacy(file_).syntax());
m["syntax"] = std::string(
FileDescriptorLegacy::SyntaxName(FileDescriptorLegacy(file_).syntax()));
m["options"] = OptionsValue(
StripLocalSourceRetentionOptions(*file_).SerializeAsString());
m["serialized_descriptor"] = absl::CHexEscape(file_descriptor_serialized_);
Expand Down Expand Up @@ -677,8 +677,7 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
"options_value", OptionsValue(options_string), "extendable",
message_descriptor.extension_range_count() > 0 ? "True" : "False",
"syntax",
StringifySyntax(
FileDescriptorLegacy(message_descriptor.file()).syntax()));
FileDescriptorLegacy::SyntaxName(FileDescriptorLegacy(file_).syntax()));
printer_->Print(",\n");

// Extension ranges
Expand Down Expand Up @@ -1167,7 +1166,7 @@ void Generator::PrintSerializedPbInterval(
const DescriptorProtoT& descriptor_proto, absl::string_view name) const {
std::string sp;
descriptor_proto.SerializeToString(&sp);
const size_t offset = file_descriptor_serialized_.find(sp);
// find() signals "not found" with npos. `offset` is unsigned, so a `>= 0`
// check would always pass; compare against npos instead.
ABSL_CHECK_NE(offset, std::string::npos);

printer_->Print(
Expand All @@ -1177,26 +1176,34 @@ void Generator::PrintSerializedPbInterval(
absl::StrCat(offset + sp.size()));
}

namespace {
void PrintDescriptorOptionsFixingCode(absl::string_view descriptor,
absl::string_view options,
io::Printer* printer) {
template <typename DescriptorT>
bool Generator::PrintDescriptorOptionsFixingCode(
const DescriptorT& descriptor, absl::string_view descriptor_str) const {
std::string options = OptionsValue(
StripLocalSourceRetentionOptions(descriptor).SerializeAsString());

// Reset the _options to None thus DescriptorBase.GetOptions() can
// parse _options again after extensions are registered.
size_t dot_pos = descriptor.find('.');
size_t dot_pos = descriptor_str.find('.');
std::string descriptor_name;
if (dot_pos == std::string::npos) {
descriptor_name = absl::StrCat("_globals['", descriptor, "']");
descriptor_name = absl::StrCat("_globals['", descriptor_str, "']");
} else {
descriptor_name = absl::StrCat("_globals['", descriptor.substr(0, dot_pos),
"']", descriptor.substr(dot_pos));
descriptor_name =
absl::StrCat("_globals['", descriptor_str.substr(0, dot_pos), "']",
descriptor_str.substr(dot_pos));
}

if (options == "None") {
return false;
}
printer->Print(

printer_->Print(
"$descriptor_name$._options = None\n"
"$descriptor_name$._serialized_options = $serialized_value$\n",
"descriptor_name", descriptor_name, "serialized_value", options);
return true;
}
} // namespace

// Generates the start and end offsets for each entity in the serialized file
// descriptor. The file argument must exactly match what was serialized into
Expand Down Expand Up @@ -1246,11 +1253,7 @@ void Generator::SetMessagePbInterval(const DescriptorProto& message_proto,
// Prints expressions that set the options field of all descriptors.
void Generator::FixAllDescriptorOptions() const {
// Prints an expression that sets the file descriptor's options.
std::string file_options = OptionsValue(
StripLocalSourceRetentionOptions(*file_).SerializeAsString());
if (file_options != "None") {
PrintDescriptorOptionsFixingCode(kDescriptorKey, file_options, printer_);
} else {
if (!PrintDescriptorOptionsFixingCode(*file_, kDescriptorKey)) {
printer_->Print("DESCRIPTOR._options = None\n");
}
// Prints expressions that set the options for all top level enums.
Expand All @@ -1275,35 +1278,23 @@ void Generator::FixAllDescriptorOptions() const {
}

// Prints the options-fixing code for a oneof. The oneof is addressed in the
// generated Python as <containing message>.oneofs_by_name['<oneof name>'].
void Generator::FixOptionsForOneof(const OneofDescriptor& oneof) const {
  const std::string oneof_name = absl::Substitute(
      "$0.$1['$2']", ModuleLevelDescriptorName(*oneof.containing_type()),
      "oneofs_by_name", oneof.name());
  // The helper prints nothing (and returns false) when the serialized options
  // evaluate to "None", so its result can be ignored here.
  PrintDescriptorOptionsFixingCode(oneof, oneof_name);
}

// Prints expressions that set the options for an enum descriptor and its
// value descriptors.
void Generator::FixOptionsForEnum(const EnumDescriptor& enum_descriptor) const {
  const std::string descriptor_name =
      ModuleLevelDescriptorName(enum_descriptor);
  PrintDescriptorOptionsFixingCode(enum_descriptor, descriptor_name);
  for (int i = 0; i < enum_descriptor.value_count(); ++i) {
    const EnumValueDescriptor& value_descriptor = *enum_descriptor.value(i);
    // Each value is addressed through the enum's values_by_name mapping.
    PrintDescriptorOptionsFixingCode(
        value_descriptor,
        absl::StrCat(descriptor_name, ".values_by_name[\"",
                     value_descriptor.name(), "\"]"));
  }
}

Expand All @@ -1313,46 +1304,33 @@ void Generator::FixOptionsForService(
const ServiceDescriptor& service_descriptor) const {
std::string descriptor_name =
ModuleLevelServiceDescriptorName(service_descriptor);
std::string service_options = OptionsValue(
StripLocalSourceRetentionOptions(service_descriptor).SerializeAsString());
if (service_options != "None") {
PrintDescriptorOptionsFixingCode(descriptor_name, service_options,
printer_);
}
PrintDescriptorOptionsFixingCode(service_descriptor, descriptor_name);

for (int i = 0; i < service_descriptor.method_count(); ++i) {
const MethodDescriptor* method = service_descriptor.method(i);
std::string method_options = OptionsValue(
StripLocalSourceRetentionOptions(*method).SerializeAsString());
if (method_options != "None") {
std::string method_name = absl::StrCat(
descriptor_name, ".methods_by_name['", method->name(), "']");
PrintDescriptorOptionsFixingCode(method_name, method_options, printer_);
}
PrintDescriptorOptionsFixingCode(
*method, absl::StrCat(descriptor_name, ".methods_by_name['",
method->name(), "']"));
}
}

// Prints expressions that set the options for field descriptors (including
// extensions).
//
// The Python expression naming the field depends on what kind of field it is:
//   - regular field:       looked up in the containing type's fields_by_name
//   - top-level extension: referenced by its bare name
//   - nested extension:    looked up in the extension scope's
//                          extensions_by_name
void Generator::FixOptionsForField(const FieldDescriptor& field) const {
  std::string field_name;
  if (!field.is_extension()) {
    field_name = FieldReferencingExpression(field.containing_type(), field,
                                            "fields_by_name");
  } else if (field.extension_scope() == nullptr) {
    // Top level extensions.
    field_name = field.name();
  } else {
    field_name = FieldReferencingExpression(field.extension_scope(), field,
                                            "extensions_by_name");
  }
  PrintDescriptorOptionsFixingCode(field, field_name);
}

// Prints expressions that set the options for a message and all its inner
Expand Down Expand Up @@ -1381,13 +1359,8 @@ void Generator::FixOptionsForMessage(const Descriptor& descriptor) const {
FixOptionsForField(field);
}
// Message option for this message.
std::string message_options = OptionsValue(
StripLocalSourceRetentionOptions(descriptor).SerializeAsString());
if (message_options != "None") {
std::string descriptor_name = ModuleLevelDescriptorName(descriptor);
PrintDescriptorOptionsFixingCode(descriptor_name, message_options,
printer_);
}
PrintDescriptorOptionsFixingCode(descriptor,
ModuleLevelDescriptorName(descriptor));
}

// If a dependency forwards other files through public dependencies, let's
Expand Down
6 changes: 6 additions & 0 deletions src/google/protobuf/compiler/python/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__
#define GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__

#include <cstdint>
#include <string>
#include <vector>

Expand Down Expand Up @@ -49,6 +50,7 @@ struct GeneratorOptions {
bool generate_pyi = false;
bool annotate_pyi = false;
bool bootstrap = false;
bool strip_nonfunctional_codegen = false;
};

class PROTOC_EXPORT Generator : public CodeGenerator {
Expand Down Expand Up @@ -141,6 +143,10 @@ class PROTOC_EXPORT Generator : public CodeGenerator {
void PrintSerializedPbInterval(const DescriptorProtoT& descriptor_proto,
absl::string_view name) const;

template <typename DescriptorT>
bool PrintDescriptorOptionsFixingCode(const DescriptorT& descriptor,
absl::string_view descriptor_str) const;

void FixAllDescriptorOptions() const;
void FixOptionsForField(const FieldDescriptor& field) const;
void FixOptionsForOneof(const OneofDescriptor& oneof) const;
Expand Down
7 changes: 7 additions & 0 deletions src/google/protobuf/compiler/python/pyi_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/compiler/code_generator.h"
#include "google/protobuf/compiler/python/helpers.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor.pb.h"
Expand Down Expand Up @@ -169,6 +170,9 @@ void PyiGenerator::PrintImports() const {
bool has_importlib = false;
for (int i = 0; i < file_->dependency_count(); ++i) {
const FileDescriptor* dep = file_->dependency(i);
if (strip_nonfunctional_codegen_ && IsKnownFeatureProto(dep->name())) {
continue;
}
PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib);
for (int j = 0; j < dep->public_dependency_count(); ++j) {
PrintImportForDescriptor(*dep->public_dependency(j), &seen_aliases,
Expand Down Expand Up @@ -570,11 +574,14 @@ bool PyiGenerator::Generate(const FileDescriptor* file,

std::string filename;
bool annotate_code = false;
strip_nonfunctional_codegen_ = false;
for (const std::pair<std::string, std::string>& option : options) {
if (option.first == "annotate_code") {
annotate_code = true;
} else if (absl::EndsWith(option.first, ".pyi")) {
filename = option.first;
} else if (option.first == "experimental_strip_nonfunctional_codegen") {
strip_nonfunctional_codegen_ = true;
} else {
*error = absl::StrCat("Unknown generator option: ", option.first);
return false;
Expand Down
8 changes: 8 additions & 0 deletions src/google/protobuf/compiler/python/pyi_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#define GOOGLE_PROTOBUF_COMPILER_PYTHON_PYI_GENERATOR_H__

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand Down Expand Up @@ -53,6 +54,12 @@ class PROTOC_EXPORT PyiGenerator : public google::protobuf::compiler::CodeGenera
GeneratorContext* generator_context,
std::string* error) const override;

// Lower bound of the edition range this generator reports: EDITION_PROTO2.
Edition GetMinimumEdition() const override {
  return Edition::EDITION_PROTO2;
}
// Upper bound of the edition range this generator reports: EDITION_2023.
Edition GetMaximumEdition() const override {
  return Edition::EDITION_2023;
}
// This generator reports no feature extensions: the returned list is always
// empty.
std::vector<const FieldDescriptor*> GetFeatureExtensions() const override {
  return std::vector<const FieldDescriptor*>{};
}

private:
void PrintImportForDescriptor(const FileDescriptor& desc,
absl::flat_hash_set<std::string>* seen_aliases,
Expand Down Expand Up @@ -83,6 +90,7 @@ class PROTOC_EXPORT PyiGenerator : public google::protobuf::compiler::CodeGenera
mutable absl::Mutex mutex_;
mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_.
mutable io::Printer* printer_; // Set in Generate(). Under mutex_.
mutable bool strip_nonfunctional_codegen_ = false; // Set in Generate().
// import_map will be a mapping from filename to module alias, e.g.
// "google3/foo/bar.py" -> "_bar"
mutable absl::flat_hash_map<std::string, std::string> import_map_;
Expand Down

0 comments on commit 57bb1e5

Please sign in to comment.