Skip to content

Commit

Permalink
[Parser] Parsing multiple files with symbolic parameters (#188)
Browse files Browse the repository at this point in the history
* [Parser] support for parameters across files

* [Style] precommit changes
  • Loading branch information
ScottWe authored Dec 6, 2024
1 parent 6166a6f commit cdfcc27
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 34 deletions.
28 changes: 19 additions & 9 deletions src/quartz/parser/qasm_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,26 @@ bool ParamParser::parse_array_decl(std::stringstream &ss) {
getline(ss, name);
name = strip(name);

// Ensures that the name is unique.
if (symb_params_.count(name) > 0) {
std::cerr << "Each param must have a unique name: " << name << std::endl;
assert(false);
return false;
}
// Determines whether the parameter is being declared or reused.
if (first_file_) {
// Ensures that the name is unique.
if (symb_params_.count(name) > 0) {
std::cerr << "Each param must have a unique name: " << name << std::endl;
assert(false);
return false;
}

// Allocates a symbolic parameter for each element of the array.
for (int i = 0; i < len; ++i) {
symb_params_[name][i] = ctx_->get_new_param_id();
// Allocates a symbolic parameter for each element of the array.
for (int i = 0; i < len; ++i) {
symb_params_[name][i] = ctx_->get_new_param_id();
}
} else {
// Check that the parameter is declared and of the correct size.
if (symb_params_[name].size() != len) {
std::cerr << "Parameter size misalignment: " << name << std::endl;
assert(false);
return false;
}
}
return true;
}
Expand Down
41 changes: 33 additions & 8 deletions src/quartz/parser/qasm_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ std::string strip(const std::string &input);
*/
class ParamParser {
public:
ParamParser(Context *ctx, bool symbolic_pi)
: ctx_(ctx), symbolic_pi_(symbolic_pi) {}
ParamParser(Context *ctx)
: ctx_(ctx), symbolic_pi_(false), first_file_(true) {}

/**
* Adds an angle array declaration to the registry of symbolic parameters.
Expand All @@ -68,6 +68,21 @@ class ParamParser {
*/
int parse_expr(std::stringstream &token);

/**
* Calling this function allows for symbolic pi values to be enabled or
* disabled. When symbolic pi values are enabled, each constant pi/n will be
* replaced by the symbolic expression pi(n).
* @param v if true, then symbolic pi values will be enabled.
*/
void use_symbolic_pi(bool v) { symbolic_pi_ = v; }

/**
* Calling this function indicates that a file has been entirely parsed. In
* particular, after the first file is parsed, only the names and indices of
* existing symbolic variables may be used.
*/
void end_file() { first_file_ = false; }

private:
/**
* Implementation details for parse_expr when the expression is a constant
Expand Down Expand Up @@ -119,6 +134,14 @@ class ParamParser {
* is evaluated and stored as a floating-point constant.
*/
bool symbolic_pi_;

/**
* If true, then this parameter parser has already parsed an OpenQASM 3 file.
* When parsing the first file, it is expected that all parameter variables
* are new. When parsing subsequent files, it is expected that all parameter
* variables were defined in the original file.
*/
bool first_file_;
};

/**
Expand Down Expand Up @@ -200,7 +223,7 @@ class QubitParser {
// Parser from OpenQASM files to CircuitSeq objects.
class QASMParser {
public:
QASMParser(Context *ctx) : ctx_(ctx), symbolic_pi_(false) {}
QASMParser(Context *ctx) : ctx_(ctx), param_parser_(ctx) {}

template <class _CharT, class _Traits>
bool load_qasm_stream(std::basic_istream<_CharT, _Traits> &qasm_stream,
Expand All @@ -223,11 +246,11 @@ class QASMParser {
return res;
}

void use_symbolic_pi(bool v) { symbolic_pi_ = v; }
void use_symbolic_pi(bool v) { param_parser_.use_symbolic_pi(v); }

private:
Context *ctx_;
bool symbolic_pi_;
ParamParser param_parser_;
};

// We cannot put this template function implementation in a .cpp file.
Expand All @@ -236,7 +259,6 @@ bool QASMParser::load_qasm_stream(
std::basic_istream<_CharT, _Traits> &qasm_stream, CircuitSeq *&seq) {
// Results and sub-parsers.
seq = nullptr;
ParamParser param_parser(ctx_, symbolic_pi_);
QubitParser qubit_parser;

// Generalized control data.
Expand Down Expand Up @@ -321,7 +343,7 @@ bool QASMParser::load_qasm_stream(
}

// Parses the parameter array.
if (!param_parser.parse_array_decl(ss)) {
if (!param_parser_.parse_array_decl(ss)) {
return false;
}
} else if (is_gate_string(command, gate_type)) {
Expand All @@ -344,7 +366,7 @@ bool QASMParser::load_qasm_stream(
std::vector<int> param_indices(num_params);
for (int i = 0; i < num_params; ++i) {
assert(ss.good());
int index = param_parser.parse_expr(ss);
int index = param_parser_.parse_expr(ss);
if (index == -1) {
return false;
}
Expand Down Expand Up @@ -405,6 +427,9 @@ bool QASMParser::load_qasm_stream(
assert(false);
}
}

// Successfully parsed file.
param_parser_.end_file();
return true;
}

Expand Down
73 changes: 56 additions & 17 deletions src/test/test_qasm_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,36 +152,75 @@ void test_param_parsing() {

QASMParser parser(&ctx);

std::string str = "OPENQASM 2.0;\n"
"include \"qelib1.inc\";\n"
"qubit[2] q;\n"
"input array[angle,2] ps;\n"
"input array[float,3] params;\n"
"cx q[0], q[1];\n"
"rx(ps[0]) q[0];\n"
"rx(ps[1]) q[1];\n"
"rx(params[0]) q[0];\n"
"rx(params[1]) q[1];\n";
// Tests parsing a first file.
std::string str1 = "OPENQASM 3;\n"
"include \"stdgates.inc\";\n"
"qubit[2] q;\n"
"input array[angle,2] ps;\n"
"input array[float,3] params;\n"
"cx q[0], q[1];\n"
"rx(ps[0]) q[0];\n"
"rx(ps[1]) q[1];\n"
"rx(params[0]) q[0];\n"
"rx(params[1]) q[1];\n";

CircuitSeq *seq1 = nullptr;
bool res1 = parser.load_qasm_str(str, seq1);
bool res1 = parser.load_qasm_str(str1, seq1);
if (!res1) {
std::cout << "Parsing failed with parameter variables." << std::endl;
assert(false);
return;
}

int pnum = ctx.get_num_parameters();
if (pnum != 5) {
std::cout << "Unexpected parameter total: " << pnum << "." << std::endl;
int pnum1 = ctx.get_num_parameters();
if (pnum1 != 5) {
std::cout << "Unexpected parameter total: " << pnum1 << "." << std::endl;
assert(false);
}

int input_num = seq1->get_input_param_indices(&ctx).size();
if (input_num != 4) {
std::cout << "Unexpected input count: " << input_num << "." << std::endl;
int input_num1 = seq1->get_input_param_indices(&ctx).size();
if (input_num1 != 4) {
std::cout << "Unexpected input count: " << input_num1 << "." << std::endl;
assert(false);
}

// Tests parsing a second file.
std::string str2 = "OPENQASM 3;\n"
"include \"stdgates.inc\";\n"
"qubit[5] q;\n"
"input array[angle,2] ps;\n"
"cx q[0], q[1];\n"
"rx(ps[0]) q[0];\n"
"rx(ps[1]) q[1];\n";

CircuitSeq *seq2 = nullptr;
bool res2 = parser.load_qasm_str(str2, seq2);
if (!res2) {
std::cout << "Parsing failed with parameter variables." << std::endl;
assert(false);
return;
}

int pnum2 = ctx.get_num_parameters();
if (pnum2 != 5) {
std::cout << "Unexpected parameter total: " << pnum2 << "." << std::endl;
assert(false);
}

int input_num2 = seq2->get_input_param_indices(&ctx).size();
if (input_num2 != 2) {
std::cout << "Unexpected input count: " << input_num2 << "." << std::endl;
assert(false);
}

// Checks that parameters were reused.
auto all_indices = seq1->get_input_param_indices(&ctx);
for (auto j : seq2->get_input_param_indices(&ctx)) {
if (std::count(all_indices.begin(), all_indices.end(), j) == 0) {
std::cout << "Unexpected parameter: " << j << "." << std::endl;
assert(false);
}
}
}

int main() {
Expand Down

0 comments on commit cdfcc27

Please sign in to comment.