Skip to content

Commit

Permalink
Add TypedSyntaxNode::cast and Terminal::cast_token methods (#7044)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkaput authored Jan 10, 2025
1 parent ec7c129 commit baee641
Show file tree
Hide file tree
Showing 7 changed files with 3,120 additions and 42 deletions.
7 changes: 3 additions & 4 deletions crates/cairo-lang-doc/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use cairo_lang_semantic::items::functions::GenericFunctionId;
use cairo_lang_semantic::resolve::{AsSegments, ResolvedGenericItem, Resolver};
use cairo_lang_syntax::node::ast::{Expr, ExprPath, ItemModule};
use cairo_lang_syntax::node::helpers::GetIdentifier;
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{SyntaxNode, TypedSyntaxNode};
use cairo_lang_utils::Intern;
use pulldown_cmark::{
Expand Down Expand Up @@ -230,9 +229,9 @@ impl<'a> DocumentationCommentParser<'a> {
// Get the stack (bottom-up) of submodule names in the file containing the node, in the main
// module, that lead to the node.
iter::successors(node.parent(), SyntaxNode::parent)
.filter(|node| node.kind(syntax_db) == SyntaxKind::ItemModule)
.map(|node| {
ItemModule::from_syntax_node(syntax_db, node)
.filter_map(|node| ItemModule::cast(syntax_db, node))
.map(|item_module| {
item_module
.stable_ptr()
.name_green(syntax_db)
.identifier(syntax_db)
Expand Down
11 changes: 4 additions & 7 deletions crates/cairo-lang-formatter/src/formatter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1081,13 +1081,10 @@ impl<'a> FormatterImpl<'a> {
/// Returns whether the node has only whitespace trivia.
fn has_only_whitespace_trivia(&self, node: &SyntaxNode) -> bool {
node.descendants(self.db).all(|descendant| {
if descendant.kind(self.db) == SyntaxKind::Trivia {
ast::Trivia::from_syntax_node(self.db, descendant)
.elements(self.db)
.into_iter()
.all(|element| {
matches!(element, ast::Trivium::Whitespace(_) | ast::Trivium::Newline(_))
})
if let Some(trivia) = ast::Trivia::cast(self.db, descendant) {
trivia.elements(self.db).into_iter().all(|element| {
matches!(element, ast::Trivium::Whitespace(_) | ast::Trivium::Newline(_))
})
} else {
true
}
Expand Down
8 changes: 3 additions & 5 deletions crates/cairo-lang-plugins/src/plugins/generate_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use cairo_lang_defs::plugin::{
use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};

#[derive(Debug, Default)]
Expand Down Expand Up @@ -160,10 +159,7 @@ fn generate_trait_for_impl(db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> Plu
for node in
db.get_children(signature.parameters(db).node.clone()).iter().cloned()
{
if node.kind(db) != SyntaxKind::Param {
builder.add_node(node);
} else {
let param = ast::Param::from_syntax_node(db, node);
if let Some(param) = ast::Param::cast(db, node.clone()) {
for modifier in param.modifiers(db).elements(db) {
// `mut` modifiers are only relevant for impls, not traits.
if !matches!(modifier, ast::Modifier::Mut(_)) {
Expand All @@ -172,6 +168,8 @@ fn generate_trait_for_impl(db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> Plu
}
builder.add_node(param.name(db).as_syntax_node());
builder.add_node(param.type_clause(db).as_syntax_node());
} else {
builder.add_node(node);
}
}
let rparen = signature.rparen(db);
Expand Down
6 changes: 2 additions & 4 deletions crates/cairo-lang-starknet/src/plugin/starknet_module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ use cairo_lang_plugins::plugins::HasItemsInCfgEx;
use cairo_lang_syntax::node::ast::MaybeModuleBody;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::{BodyItems, QueryAttrs};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{SyntaxNode, Terminal, TypedSyntaxNode, ast};
use cairo_lang_utils::{extract_matches, require};
use cairo_lang_utils::extract_matches;

use self::component::generate_component_specific_code;
use self::contract::generate_contract_specific_code;
Expand Down Expand Up @@ -244,8 +243,7 @@ fn grand_grand_parent_starknet_module(
// Get the containing module node. The parent is the item list, the grand parent is the module
// body, and the grand grand parent is the module.
let module_node = item_node.parent()?.parent()?.parent()?;
require(module_node.kind(db) == SyntaxKind::ItemModule)?;
let module_ast = ast::ItemModule::from_syntax_node(db, module_node);
let module_ast = ast::ItemModule::cast(db, module_node)?;
let (module_kind, attr) = StarknetModuleKind::from_module(db, &module_ast)?;
Some((module_ast, module_kind, attr))
}
Expand Down
32 changes: 32 additions & 0 deletions crates/cairo-lang-syntax-codegen/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,13 @@ fn gen_common_list_code(name: &str, green_name: &str, ptr_name: &str) -> rust::T
fn from_syntax_node(db: &dyn SyntaxGroup, node: SyntaxNode) -> Self {
Self(ElementList::new(node))
}
fn cast(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<Self> {
if node.kind(db) == SyntaxKind::$name {
Some(Self(ElementList::new(node)))
} else {
None
}
}
fn as_syntax_node(&self) -> SyntaxNode {
self.node.clone()
}
Expand All @@ -390,6 +397,7 @@ fn gen_enum_code(
let green_name = format!("{name}Green");
let mut enum_body = quote! {};
let mut from_node_body = quote! {};
let mut cast_body = quote! {};
let mut ptr_conversions = quote! {};
let mut green_conversions = quote! {};
for variant in &variants {
Expand All @@ -402,6 +410,9 @@ fn gen_enum_code(
from_node_body.extend(quote! {
SyntaxKind::$k => $(&name)::$n($k::from_syntax_node(db, node)),
});
cast_body.extend(quote! {
SyntaxKind::$k => Some($(&name)::$n($k::from_syntax_node(db, node))),
});
let variant_ptr = format!("{k}Ptr");
ptr_conversions.extend(quote! {
impl From<$(&variant_ptr)> for $(&ptr_name) {
Expand Down Expand Up @@ -469,6 +480,13 @@ fn gen_enum_code(
$[str]($[const](&name))),
}
}
fn cast(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<Self> {
let kind = node.kind(db);
match kind {
$cast_body
_ => None,
}
}
fn as_syntax_node(&self) -> SyntaxNode {
match self {
$(for v in &variants => $(&name)::$(&v.name)(x) => x.as_syntax_node(),)
Expand Down Expand Up @@ -556,6 +574,12 @@ fn gen_token_code(name: String) -> rust::Tokens {
),
}
}
fn cast(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<Self> {
match node.0.green.lookup_intern(db).details {
GreenNodeDetails::Token(_) => Some(Self { node }),
GreenNodeDetails::Node { .. } => None,
}
}
fn as_syntax_node(&self) -> SyntaxNode {
self.node.clone()
}
Expand Down Expand Up @@ -706,6 +730,14 @@ fn gen_struct_code(name: String, members: Vec<Member>, is_terminal: bool) -> rus
let children = db.get_children(node.clone());
Self { node, children }
}
fn cast(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<Self> {
let kind = node.kind(db);
if kind == SyntaxKind::$(&name) {
Some(Self::from_syntax_node(db, node))
} else {
None
}
}
fn as_syntax_node(&self) -> SyntaxNode {
self.node.clone()
}
Expand Down
Loading

0 comments on commit baee641

Please sign in to comment.