Skip to content

Commit

Permalink
Allow for static values as array sizes.
Browse files Browse the repository at this point in the history
This applies to both array types `[T * n]`, as well as RepeatArrayExpr
`[E; n]`. In both cases, n must be defined as `static n = ...`, or the
emitter might fail.
  • Loading branch information
m-kurtenacker committed Nov 5, 2024
1 parent b9a39d4 commit 45a32b3
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 27 deletions.
13 changes: 7 additions & 6 deletions include/artic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,15 @@ struct ArrayType : public Type {

/// Sized array type.
struct SizedArrayType : public ArrayType {
size_t size;
std::variant<size_t, ast::Path> size;
bool is_simd;

SizedArrayType(const Loc& loc, Ptr<Type>&& elem, size_t size, bool is_simd)
: ArrayType(loc, std::move(elem)), size(size), is_simd(is_simd)
SizedArrayType(const Loc& loc, Ptr<Type>&& elem, std::variant<size_t, ast::Path>&& size, bool is_simd)
: ArrayType(loc, std::move(elem)), size(std::move(size)), is_simd(is_simd)
{}

const artic::Type* infer(TypeChecker&) override;
void bind(NameBinder&) override;
void print(Printer&) const override;
};

Expand Down Expand Up @@ -673,11 +674,11 @@ struct ArrayExpr : public Expr {
/// Array expression repeating a given value a given number of times.
struct RepeatArrayExpr : public Expr {
Ptr<Expr> elem;
size_t size;
std::variant<size_t, ast::Path> size;
bool is_simd;

RepeatArrayExpr(const Loc& loc, Ptr<Expr>&& elem, size_t size, bool is_simd)
: Expr(loc), elem(std::move(elem)), size(size), is_simd(is_simd)
RepeatArrayExpr(const Loc& loc, Ptr<Expr>&& elem, std::variant<size_t, ast::Path>&& size, bool is_simd)
: Expr(loc), elem(std::move(elem)), size(std::move(size)), is_simd(is_simd)
{}

bool is_jumping() const override;
Expand Down
3 changes: 2 additions & 1 deletion include/artic/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ class Parser : public Logger {
ast::AsmExpr::Constr parse_constr();
Literal parse_lit();
std::string parse_str();
std::optional<size_t> parse_array_size();
size_t parse_addr_space();

std::optional<std::variant<size_t, ast::Path>> parse_array_size();

std::pair<Ptr<ast::Expr>, Ptr<ast::Expr>> parse_cond_and_block();

struct Tracker {
Expand Down
8 changes: 8 additions & 0 deletions src/bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ void ArrayType::bind(NameBinder& binder) {
binder.bind(*elem);
}

void SizedArrayType::bind(NameBinder& binder) {
binder.bind(*elem);
if (std::holds_alternative<ast::Path>(size))
binder.bind(std::get<ast::Path>(size));
}

void FnType::bind(NameBinder& binder) {
binder.bind(*from);
if (to) binder.bind(*to);
Expand Down Expand Up @@ -179,6 +185,8 @@ void ArrayExpr::bind(NameBinder& binder) {

void RepeatArrayExpr::bind(NameBinder& binder) {
binder.bind(*elem);
if (std::holds_alternative<ast::Path>(size))
binder.bind(std::get<ast::Path>(size));
}

void FnExpr::bind(NameBinder& binder, bool in_for_loop) {
Expand Down
95 changes: 92 additions & 3 deletions src/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,37 @@ const artic::Type* SizedArrayType::infer(TypeChecker& checker) {
auto elem_type = checker.infer(*elem);
if (is_simd && !elem_type->isa<artic::PrimType>())
return checker.invalid_simd(loc, elem_type);
return checker.type_table.sized_array_type(elem_type, size, is_simd);

if (std::holds_alternative<ast::Path>(size)) {
auto &path = std::get<ast::Path>(size);
const auto* decl = path.start_decl;

for (size_t i = 0, n = path.elems.size(); i < n; ++i) {
if (path.elems[i].is_super())
decl = i == 0 ? path.start_decl : decl->as<ModDecl>()->super;
if (auto mod_type = path.elems[i].type->isa<ModType>()) {
decl = &mod_type->member(path.elems[i + 1].index);
} else if (!path.is_ctor) {
assert(path.elems[i].inferred_args.empty());
assert(decl->isa<StaticDecl>() && "The only supported type right now.");
break;
} else if (match_app<StructType>(path.elems[i].type).second) {
assert(false && "This is not supported as a size for repeated arrays.");
} else if (auto [type_app, enum_type] = match_app<artic::EnumType>(path.elems[i].type); enum_type) {
assert(false && "This is not supported as a size for repeated arrays.");
}
}

auto static_decl = decl->as<StaticDecl>();
assert(!static_decl->is_mut);
assert(static_decl->init);
auto& value = static_decl->init;
auto lit_value = value->as<LiteralExpr>()->lit;

size = lit_value.as_integer();
}

return checker.type_table.sized_array_type(elem_type, std::get<size_t>(size), is_simd);
}

const artic::Type* UnsizedArrayType::infer(TypeChecker& checker) {
Expand Down Expand Up @@ -994,12 +1024,71 @@ const artic::Type* RepeatArrayExpr::infer(TypeChecker& checker) {
auto elem_type = checker.deref(elem);
if (is_simd && !elem_type->isa<artic::PrimType>())
return checker.invalid_simd(loc, elem_type);
return checker.type_table.sized_array_type(elem_type, size, is_simd);

if (std::holds_alternative<ast::Path>(size)) {
auto &path = std::get<ast::Path>(size);
const auto* decl = path.start_decl;

for (size_t i = 0, n = path.elems.size(); i < n; ++i) {
if (path.elems[i].is_super())
decl = i == 0 ? path.start_decl : decl->as<ModDecl>()->super;
if (auto mod_type = path.elems[i].type->isa<ModType>()) {
decl = &mod_type->member(path.elems[i + 1].index);
} else if (!path.is_ctor) {
assert(path.elems[i].inferred_args.empty());
assert(decl->isa<StaticDecl>() && "The only supported type right now.");
break;
} else if (match_app<StructType>(path.elems[i].type).second) {
assert(false && "This is not supported as a size for repeated arrays.");
} else if (auto [type_app, enum_type] = match_app<artic::EnumType>(path.elems[i].type); enum_type) {
assert(false && "This is not supported as a size for repeated arrays.");
}
}

auto static_decl = decl->as<StaticDecl>();
assert(!static_decl->is_mut);
assert(static_decl->init);
auto& value = static_decl->init;
auto lit_value = value->as<LiteralExpr>()->lit;

size = lit_value.as_integer();
}

return checker.type_table.sized_array_type(elem_type, std::get<size_t>(size), is_simd);
}

const artic::Type* RepeatArrayExpr::check(TypeChecker& checker, const artic::Type* expected) {
if (std::holds_alternative<ast::Path>(size)) {
auto &path = std::get<ast::Path>(size);
const auto* decl = path.start_decl;

for (size_t i = 0, n = path.elems.size(); i < n; ++i) {
if (path.elems[i].is_super())
decl = i == 0 ? path.start_decl : decl->as<ModDecl>()->super;
if (auto mod_type = path.elems[i].type->isa<ModType>()) {
decl = &mod_type->member(path.elems[i + 1].index);
} else if (!path.is_ctor) {
assert(path.elems[i].inferred_args.empty());
assert(decl->isa<StaticDecl>() && "The only supported type right now.");
break;
} else if (match_app<StructType>(path.elems[i].type).second) {
assert(false && "This is not supported as a size for repeated arrays.");
} else if (auto [type_app, enum_type] = match_app<artic::EnumType>(path.elems[i].type); enum_type) {
assert(false && "This is not supported as a size for repeated arrays.");
}
}

auto static_decl = decl->as<StaticDecl>();
assert(!static_decl->is_mut);
assert(static_decl->init);
auto& value = static_decl->init;
auto lit_value = value->as<LiteralExpr>()->lit;

size = lit_value.as_integer();
}

return checker.check_array(loc, "array expression",
expected, size, is_simd, [&] (auto elem_type) {
expected, std::get<size_t>(size), is_simd, [&] (auto elem_type) {
checker.coerce(elem, elem_type);
});
}
Expand Down
2 changes: 1 addition & 1 deletion src/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ const thorin::Def* ArrayExpr::emit(Emitter& emitter) const {
}

const thorin::Def* RepeatArrayExpr::emit(Emitter& emitter) const {
thorin::Array<const thorin::Def*> ops(size, emitter.emit(*elem));
thorin::Array<const thorin::Def*> ops(std::get<size_t>(size), emitter.emit(*elem));
return is_simd
? emitter.world.vector(ops, emitter.debug_info(*this))
: emitter.world.definite_array(ops, emitter.debug_info(*this));
Expand Down
28 changes: 14 additions & 14 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ Ptr<ast::Expr> Parser::parse_array_expr() {
auto size = parse_array_size();
expect(Token::RBracket);
if (size)
return make_ptr<ast::RepeatArrayExpr>(tracker(), std::move(elems.front()), *size, is_simd);
return make_ptr<ast::RepeatArrayExpr>(tracker(), std::move(elems.front()), std::move(*size), is_simd);
return make_ptr<ast::ArrayExpr>(tracker(), std::move(elems), is_simd);
} else if (accept(Token::Comma)) {
parse_list(Token::RBracket, Token::Comma, [&] {
Expand Down Expand Up @@ -1069,15 +1069,17 @@ Ptr<ast::ArrayType> Parser::parse_array_type() {
bool is_simd = accept(Token::Simd);
expect(Token::LBracket);
auto elem = parse_type();
std::optional<size_t> size;
if (is_simd || ahead().tag() == Token::Mul) {
expect(Token::Mul);
size = parse_array_size();
auto size = parse_array_size();
expect(Token::RBracket);
if (size)
return make_ptr<ast::SizedArrayType>(tracker(), std::move(elem), std::move(*size), is_simd);
return make_ptr<ast::UnsizedArrayType>(tracker(), std::move(elem));
} else {
expect(Token::RBracket);
return make_ptr<ast::UnsizedArrayType>(tracker(), std::move(elem));
}
expect(Token::RBracket);
if (size)
return make_ptr<ast::SizedArrayType>(tracker(), std::move(elem), *size, is_simd);
return make_ptr<ast::UnsizedArrayType>(tracker(), std::move(elem));
}

Ptr<ast::FnType> Parser::parse_fn_type() {
Expand Down Expand Up @@ -1246,17 +1248,15 @@ std::string Parser::parse_str() {
return str;
}

std::optional<size_t> Parser::parse_array_size() {
std::optional<size_t> size;
std::optional<std::variant<size_t, ast::Path>> Parser::parse_array_size() {
if (ahead().is_literal() && ahead().literal().is_integer()) {
size = ahead().literal().as_integer();
auto size = ahead().literal().as_integer();
eat(Token::Lit);
return size;
} else {
error(ahead().loc(), "expected integer literal as array size");
if (ahead().tag() != Token::RBracket)
next();
auto path = parse_path();
return path;
}
return size;
}

size_t Parser::parse_addr_space() {
Expand Down
20 changes: 18 additions & 2 deletions src/print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,15 @@ void RepeatArrayExpr::print(Printer& p) const {
p << log::keyword_style("simd");
p << '[';
elem->print(p);
p << "; " << size << ']';
p << "; ";
std::visit([&] (auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, size_t>)
p << arg;
else if constexpr (std::is_same_v<T, ast::Path&>)
arg->print(p);
}, size);
p << ']';
}

void FnExpr::print(Printer& p) const {
Expand Down Expand Up @@ -647,7 +655,15 @@ void SizedArrayType::print(Printer& p) const {
p << log::keyword_style("simd");
p << '[';
elem->print(p);
p << " * " << size << ']';
p << " * ";
std::visit([&] (auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, size_t>)
p << arg;
else if constexpr (std::is_same_v<T, ast::Path&>)
arg->print(p);
}, size);
p << ']';
}

void UnsizedArrayType::print(Printer& p) const {
Expand Down

0 comments on commit 45a32b3

Please sign in to comment.