Skip to content

Commit

Permalink
[Parser] support from multi-term expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottWe committed Dec 8, 2024
1 parent 7c3cabf commit 6b632ae
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 8 deletions.
66 changes: 62 additions & 4 deletions src/quartz/parser/qasm_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,70 @@ int ParamParser::parse_expr(std::stringstream &ss) {
std::string token;
ss >> token;

// Determines if angle is negative or positive.
bool negative = (token[0] == '-');
if (negative) {
// Determines if the first term is negative.
bool neg_prefix = (token != "" && token[0] == '-');
if (neg_prefix) {
token = token.substr(1);
}
return parse_term(negative, token);

// Ensures that the string is token is non-empty.
if (token == "") {
std::cerr << "Unexpected end-of-line while parsing expr." << std::endl;
assert(false);
return -1;
}

// Parses all (+) and (-) deliminators, starting from right-to-left.
// Along the way, all terms will be parsed, and converted to parameters.
// The param_expr_id for a running sum of all terms is given by id.
int id = -1;
while (token != "") {
// Determines where the expression splits into terms, when applicable.
// The right-most (last) deliminator will identify the next term to parse.
size_t pos = token.find_last_of("+-");

// Determines which case this corresponds to.
int tid;
if (pos == std::string::npos) {
// Case: t, -t
tid = parse_term(neg_prefix, token);
token = "";
} else if (pos > 0) {
// Case: t+e, t-e
bool is_minus = (token[pos] == '-');

// Splits the token at the deliminator.
auto term = token.substr(pos + 1);
token = token.substr(0, pos);

// Parses the right-hand side as a token.
// The substraction is absorbed by this term as a negative sign.
tid = parse_term(is_minus, term);
} else {
std::cerr << "Unexpected (+) or (-) at index 0: " << token << std::endl;
assert(false);
return -1;
}

// Adds the new term to the expression, if this is not the right-most term.
if (id != -1) {
if (sum_params_[tid].count(id) == 0) {
auto g = ctx_->get_gate(GateType::add);
sum_params_[tid][id] = ctx_->get_new_param_expression_id({tid, id}, g);
}
id = sum_params_[tid][id];
} else {
id = tid;
}

// Ensures that the new term was created successfully.
if (id == -1) {
std::cerr << "Unexpected error: failed to construct sum." << std::endl;
assert(false);
return -1;
}
}
return id;
}

int ParamParser::parse_term(bool negative, std::string token) {
Expand Down
19 changes: 16 additions & 3 deletions src/quartz/parser/qasm_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,15 @@ class ParamParser {
bool parse_array_decl(std::stringstream &ss);

/**
* Parses a stream which is known to contain a parameter expression. The
* following formats are supported, where n and m are decimal literals, i is
* an integer literal, and name is a string:
* Parses a stream which is known to contain a parameter expression. Each
* parameter expression consists of one or more terms composed together as
* follows, where t is a term and e is an expression:
* | t
* | -t
* | e+t
* | e-t
* For each term t, the following formats are supported, where n and m are
* decimal literals, i is an integer literal, and name is a string:
* | pi*n
* | n*pi
* | n*pi/m
Expand Down Expand Up @@ -144,6 +150,13 @@ class ParamParser {
*/
std::unordered_map<std::string, std::unordered_map<int, int>> symb_params_;

/**
* Maps a pair of parameter identifiers, to the the identifier of a symbolic
* parameter, in the context of ctx_, which corresponds to sum of their
* corresponding parameters.
*/
std::unordered_map<int, std::unordered_map<int, int>> sum_params_;

/**
* If true, then rational multiples of pi are evaluated exactly as symbolic
* values. For example, 3*pi/2 becomes mult(3, pi(2)). If false, then 3*pi/2
Expand Down
65 changes: 64 additions & 1 deletion src/test/test_qasm_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void test_symbolic_exprs() {
auto mat2 = seq2->get_matrix(&ctx);

assert(mat1.size() == mat2.size());
for (int i = 0; i < mat1.size(); ++i) {
for (size_t i = 0; i < mat1.size(); ++i) {
assert(mat1[i].size() == mat2[i].size());
for (int j = 0; j < mat1[i].size(); ++j) {
if (mat1[i][j] != mat2[i][j]) {
Expand Down Expand Up @@ -223,6 +223,67 @@ void test_param_parsing() {
}
}

void test_sum_parsing() {
ParamInfo param_info(0);
Context ctx({GateType::rx, GateType::mult, GateType::add, GateType::pi}, 2,
&param_info);

QASMParser parser(&ctx);
parser.use_symbolic_pi(true);

std::string str1 = "OPENQASM 2.0;\n"
"include \"qelib1.inc\";\n"
"qubit[1] q;\n"
"rx(pi/5) q[1];\n"
"rx(3*pi/2) q[1];\n"
"rx(-0.32) q[1];\n";

CircuitSeq *seq1 = nullptr;
bool res1 = parser.load_qasm_str(str1, seq1);
if (!res1) {
std::cout << "Unexpected parsing failure." << std::endl;
assert(false);
return;
}

std::string str2 = "OPENQASM 3;\n"
"include \"stdgates.inc\";\n"
"qubit[1] q;\n"
"rx(pi/5+3*pi/2-0.32) q[1];\n";

CircuitSeq *seq2 = nullptr;
bool res2 = parser.load_qasm_str(str2, seq2);
if (!res2) {
std::cout << "Parsing failed with sums of terms." << std::endl;
assert(false);
return;
}

int pnum = ctx.get_num_parameters();
if (pnum != 9) {
// Expected caching.
// - Terms: 2, 3, 5, pi/2, pi/5, 3*pi/2, -0.32
// - Exprs: 3*pi.2-0.32, pi/5+3*pi/2-0.32
std::cout << "Failed to cache all intermediate values." << std::endl;
std::cout << "Number of parameters: " << pnum << std::endl;
assert(false);
}

auto mat1 = seq1->get_matrix(&ctx);
auto mat2 = seq2->get_matrix(&ctx);

assert(mat1.size() == mat2.size());
for (size_t i = 0; i < mat1.size(); ++i) {
assert(mat1[i].size() == mat2[i].size());
for (int j = 0; j < mat1[i].size(); ++j) {
if (mat1[i][j] != mat2[i][j]) {
std::cout << "Disagree at " << i << ", " << j << "." << std::endl;
assert(false);
}
}
}
}

int main() {
std::cout << "[Symbolic Expression Tests]" << std::endl;
test_symbolic_exprs();
Expand All @@ -232,4 +293,6 @@ int main() {
test_qasm3_qubits();
std::cout << "[Sybmolic Parameter Parsing Tests]" << std::endl;
test_param_parsing();
std::cout << "[Sybmolic Summation Parsing Tests]" << std::endl;
test_sum_parsing();
}

0 comments on commit 6b632ae

Please sign in to comment.