diff --git a/src/quartz/parser/qasm_parser.cpp b/src/quartz/parser/qasm_parser.cpp index 41ec7d8a..202c031e 100644 --- a/src/quartz/parser/qasm_parser.cpp +++ b/src/quartz/parser/qasm_parser.cpp @@ -144,16 +144,26 @@ bool ParamParser::parse_array_decl(std::stringstream &ss) { getline(ss, name); name = strip(name); - // Ensures that the name is unique. - if (symb_params_.count(name) > 0) { - std::cerr << "Each param must have a unique name: " << name << std::endl; - assert(false); - return false; - } + // Determines whether the parameter is being declared or reused. + if (first_file_) { + // Ensures that the name is unique. + if (symb_params_.count(name) > 0) { + std::cerr << "Each param must have a unique name: " << name << std::endl; + assert(false); + return false; + } - // Allocates a symbolic parameter for each element of the array. - for (int i = 0; i < len; ++i) { - symb_params_[name][i] = ctx_->get_new_param_id(); + // Allocates a symbolic parameter for each element of the array. + for (int i = 0; i < len; ++i) { + symb_params_[name][i] = ctx_->get_new_param_id(); + } + } else { + // Check that the parameter is declared and of the correct size. + if (symb_params_[name].size() != len) { + std::cerr << "Parameter size misalignment: " << name << std::endl; + assert(false); + return false; + } } return true; } diff --git a/src/quartz/parser/qasm_parser.h b/src/quartz/parser/qasm_parser.h index ee413e73..1216646a 100644 --- a/src/quartz/parser/qasm_parser.h +++ b/src/quartz/parser/qasm_parser.h @@ -40,8 +40,8 @@ std::string strip(const std::string &input); */ class ParamParser { public: - ParamParser(Context *ctx, bool symbolic_pi) - : ctx_(ctx), symbolic_pi_(symbolic_pi) {} + ParamParser(Context *ctx) + : ctx_(ctx), symbolic_pi_(false), first_file_(true) {} /** * Adds an angle array declaration to the registry of symbolic parameters. @@ -68,6 +68,21 @@ class ParamParser { */ int parse_expr(std::stringstream &token); + /** + * Calling this function allows for symbolic pi values to be enabled or + * disabled. When symbolic pi values are enabled, each constant pi/n will be + * replaced by the symbolic expression pi(n). + * @param v if true, then symbolic pi values will be enabled. + */ + void use_symbolic_pi(bool v) { symbolic_pi_ = v; } + + /** + * Calling this function indicates that a file has been entirely parsed. In + * particular, after the first file is parsed, only the names and indices of + * existing symbolic variables may be used. + */ + void end_file() { first_file_ = false; } + private: /** * Implementation details for parse_expr when the expression is a constant @@ -119,6 +134,14 @@ class ParamParser { * is evaluated and stored as a floating-point constant. */ bool symbolic_pi_; + + /** + * If true, then this parameter parser has already parsed an OpenQASM 3 file. + * When parsing the first file, it is expected that all parameter variables + * are new. When parsing subsequent files, it is expected that all parameter + * variables were defined in the original file. + */ + bool first_file_; }; /** @@ -200,7 +223,7 @@ class QubitParser { // Parser from OpenQASM files to CircuitSeq objects. class QASMParser { public: - QASMParser(Context *ctx) : ctx_(ctx), symbolic_pi_(false) {} + QASMParser(Context *ctx) : ctx_(ctx), param_parser_(ctx) {} template bool load_qasm_stream(std::basic_istream<_CharT, _Traits> &qasm_stream, @@ -223,11 +246,11 @@ class QASMParser { return res; } - void use_symbolic_pi(bool v) { symbolic_pi_ = v; } + void use_symbolic_pi(bool v) { param_parser_.use_symbolic_pi(v); } private: Context *ctx_; - bool symbolic_pi_; + ParamParser param_parser_; }; // We cannot put this template function implementation in a .cpp file. @@ -236,7 +259,6 @@ bool QASMParser::load_qasm_stream( std::basic_istream<_CharT, _Traits> &qasm_stream, CircuitSeq *&seq) { // Results and sub-parsers. seq = nullptr; - ParamParser param_parser(ctx_, symbolic_pi_); QubitParser qubit_parser; // Generalized control data. @@ -321,7 +343,7 @@ bool QASMParser::load_qasm_stream( } // Parses the parameter array. - if (!param_parser.parse_array_decl(ss)) { + if (!param_parser_.parse_array_decl(ss)) { return false; } } else if (is_gate_string(command, gate_type)) { @@ -344,7 +366,7 @@ bool QASMParser::load_qasm_stream( std::vector param_indices(num_params); for (int i = 0; i < num_params; ++i) { assert(ss.good()); - int index = param_parser.parse_expr(ss); + int index = param_parser_.parse_expr(ss); if (index == -1) { return false; } @@ -405,6 +427,9 @@ bool QASMParser::load_qasm_stream( assert(false); } } + + // Successfully parsed file. + param_parser_.end_file(); return true; } diff --git a/src/test/test_qasm_parser.cpp b/src/test/test_qasm_parser.cpp index 4b33ae56..ac143157 100644 --- a/src/test/test_qasm_parser.cpp +++ b/src/test/test_qasm_parser.cpp @@ -152,36 +152,75 @@ void test_param_parsing() { QASMParser parser(&ctx); - std::string str = "OPENQASM 2.0;\n" - "include \"qelib1.inc\";\n" - "qubit[2] q;\n" - "input array[angle,2] ps;\n" - "input array[float,3] params;\n" - "cx q[0], q[1];\n" - "rx(ps[0]) q[0];\n" - "rx(ps[1]) q[1];\n" - "rx(params[0]) q[0];\n" - "rx(params[1]) q[1];\n"; + // Tests parsing a first file. + std::string str1 = "OPENQASM 3;\n" + "include \"stdgates.inc\";\n" + "qubit[2] q;\n" + "input array[angle,2] ps;\n" + "input array[float,3] params;\n" + "cx q[0], q[1];\n" + "rx(ps[0]) q[0];\n" + "rx(ps[1]) q[1];\n" + "rx(params[0]) q[0];\n" + "rx(params[1]) q[1];\n"; CircuitSeq *seq1 = nullptr; - bool res1 = parser.load_qasm_str(str, seq1); + bool res1 = parser.load_qasm_str(str1, seq1); if (!res1) { std::cout << "Parsing failed with parameter variables." << std::endl; assert(false); return; } - int pnum = ctx.get_num_parameters(); - if (pnum != 5) { - std::cout << "Unexpected parameter total: " << pnum << "." << std::endl; + int pnum1 = ctx.get_num_parameters(); + if (pnum1 != 5) { + std::cout << "Unexpected parameter total: " << pnum1 << "." << std::endl; assert(false); } - int input_num = seq1->get_input_param_indices(&ctx).size(); - if (input_num != 4) { - std::cout << "Unexpected input count: " << input_num << "." << std::endl; + int input_num1 = seq1->get_input_param_indices(&ctx).size(); + if (input_num1 != 4) { + std::cout << "Unexpected input count: " << input_num1 << "." << std::endl; assert(false); } + + // Tests parsing a second file. + std::string str2 = "OPENQASM 3;\n" + "include \"stdgates.inc\";\n" + "qubit[5] q;\n" + "input array[angle,2] ps;\n" + "cx q[0], q[1];\n" + "rx(ps[0]) q[0];\n" + "rx(ps[1]) q[1];\n"; + + CircuitSeq *seq2 = nullptr; + bool res2 = parser.load_qasm_str(str2, seq2); + if (!res2) { + std::cout << "Parsing failed with parameter variables." << std::endl; + assert(false); + return; + } + + int pnum2 = ctx.get_num_parameters(); + if (pnum2 != 5) { + std::cout << "Unexpected parameter total: " << pnum2 << "." << std::endl; + assert(false); + } + + int input_num2 = seq2->get_input_param_indices(&ctx).size(); + if (input_num2 != 2) { + std::cout << "Unexpected input count: " << input_num2 << "." << std::endl; + assert(false); + } + + // Checks that parameters were reused. + auto all_indices = seq1->get_input_param_indices(&ctx); + for (auto j : seq2->get_input_param_indices(&ctx)) { + if (std::count(all_indices.begin(), all_indices.end(), j) == 0) { + std::cout << "Unexpected parameter: " << j << "." << std::endl; + assert(false); + } + } } int main() {