Skip to content

Commit

Permalink
Move op fct ptrs away from generic ASTNode (#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer committed Jan 5, 2024
1 parent 55ac3a7 commit f96f19a
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 17 deletions.
1 change: 0 additions & 1 deletion src/ast/ASTBuilder.cpp
Expand Up @@ -1256,7 +1256,6 @@ std::any ASTBuilder::visitAssignOp(SpiceParser::AssignOpContext *ctx) {
assignExprNode->op = AssignExprNode::OP_XOR_EQUAL;
else
assert_fail("Unknown assign operator");
assignExprNode->hasOperator = true;

return nullptr;
}
Expand Down
41 changes: 35 additions & 6 deletions src/ast/ASTNodes.h
Expand Up @@ -81,12 +81,19 @@ class ASTNode {
}
// Reserve this node
symbolTypes.resize(manifestationCount, SymbolType(TY_INVALID));
// Reserve operator functions
opFct.resize(manifestationCount, {nullptr});
// Do custom work
customItemsInitialization(manifestationCount);
}

virtual std::vector<std::vector<const Function *>> *getOpFctPointers() { // LCOV_EXCL_LINE
assert_fail("The given node does not overload the getOpFctPointers function"); // LCOV_EXCL_LINE
return nullptr; // LCOV_EXCL_LINE
} // LCOV_EXCL_LINE
virtual const std::vector<std::vector<const Function *>> *getOpFctPointers() const { // LCOV_EXCL_LINE
assert_fail("The given node does not overload the getOpFctPointers function"); // LCOV_EXCL_LINE
return nullptr; // LCOV_EXCL_LINE
} // LCOV_EXCL_LINE

virtual void customItemsInitialization(size_t) {} // Noop

SymbolType setEvaluatedSymbolType(const SymbolType &symbolType, const size_t idx) {
Expand Down Expand Up @@ -147,12 +154,11 @@ class ASTNode {
std::vector<ASTNode *> children;
const CodeLoc codeLoc;
std::vector<SymbolType> symbolTypes;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
bool unreachable = false;
};

// Make sure we have no unexpected increases in memory consumption
static_assert(sizeof(ASTNode) == 136);
static_assert(sizeof(ASTNode) == 112);

// ========================================================== EntryNode ==========================================================

Expand Down Expand Up @@ -1341,10 +1347,13 @@ class AssignExprNode : public ASTNode {
// Other methods
[[nodiscard]] bool returnsOnAllControlPaths(bool *doSetPredecessorsUnreachable) const override;
[[nodiscard]] bool isAssignExpr() const override { return true; }
[[nodiscard]] std::vector<std::vector<const Function *>> *getOpFctPointers() override { return &opFct; }
[[nodiscard]] const std::vector<std::vector<const Function *>> *getOpFctPointers() const override { return &opFct; }
void customItemsInitialization(size_t manifestationCount) override { opFct.resize(manifestationCount, {nullptr}); }

// Public members
AssignOp op = OP_NONE;
bool hasOperator = false;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
};

// ======================================================= TernaryExprNode =======================================================
Expand Down Expand Up @@ -1487,9 +1496,13 @@ class EqualityExprNode : public ASTNode {
// Other methods
[[nodiscard]] bool hasCompileTimeValue() const override;
[[nodiscard]] CompileTimeValue getCompileTimeValue() const override;
[[nodiscard]] std::vector<std::vector<const Function *>> *getOpFctPointers() override { return &opFct; }
[[nodiscard]] const std::vector<std::vector<const Function *>> *getOpFctPointers() const override { return &opFct; }
void customItemsInitialization(size_t manifestationCount) override { opFct.resize(manifestationCount, {nullptr}); }

// Public members
EqualityOp op = OP_NONE;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
};

// ==================================================== RelationalExprNode =======================================================
Expand Down Expand Up @@ -1547,9 +1560,13 @@ class ShiftExprNode : public ASTNode {
// Other methods
[[nodiscard]] bool hasCompileTimeValue() const override;
[[nodiscard]] CompileTimeValue getCompileTimeValue() const override;
[[nodiscard]] std::vector<std::vector<const Function *>> *getOpFctPointers() override { return &opFct; }
[[nodiscard]] const std::vector<std::vector<const Function *>> *getOpFctPointers() const override { return &opFct; }
void customItemsInitialization(size_t manifestationCount) override { opFct.resize(manifestationCount, {nullptr}); }

// Public members
ShiftOp op = OP_NONE;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
};

// ==================================================== AdditiveExprNode =========================================================
Expand Down Expand Up @@ -1578,9 +1595,13 @@ class AdditiveExprNode : public ASTNode {
// Other methods
[[nodiscard]] bool hasCompileTimeValue() const override;
[[nodiscard]] CompileTimeValue getCompileTimeValue() const override;
[[nodiscard]] std::vector<std::vector<const Function *>> *getOpFctPointers() override { return &opFct; }
[[nodiscard]] const std::vector<std::vector<const Function *>> *getOpFctPointers() const override { return &opFct; }
void customItemsInitialization(size_t manifestationCount) override { opFct.resize(manifestationCount, {nullptr}); }

// Public members
OpQueue opQueue;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
};

// ================================================== MultiplicativeExprNode =====================================================
Expand Down Expand Up @@ -1610,9 +1631,13 @@ class MultiplicativeExprNode : public ASTNode {
// Other methods
[[nodiscard]] bool hasCompileTimeValue() const override;
[[nodiscard]] CompileTimeValue getCompileTimeValue() const override;
[[nodiscard]] std::vector<std::vector<const Function *>> *getOpFctPointers() override { return &opFct; }
[[nodiscard]] const std::vector<std::vector<const Function *>> *getOpFctPointers() const override { return &opFct; }
void customItemsInitialization(size_t manifestationCount) override { opFct.resize(manifestationCount, {nullptr}); }

// Public members
OpQueue opQueue;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
};

// ======================================================= CastExprNode ==========================================================
Expand Down Expand Up @@ -1701,10 +1726,14 @@ class PostfixUnaryExprNode : public ASTNode {
// Other methods
[[nodiscard]] bool hasCompileTimeValue() const override;
[[nodiscard]] CompileTimeValue getCompileTimeValue() const override;
[[nodiscard]] std::vector<std::vector<const Function *>> *getOpFctPointers() override { return &opFct; }
[[nodiscard]] const std::vector<std::vector<const Function *>> *getOpFctPointers() const override { return &opFct; }
void customItemsInitialization(size_t manifestationCount) override { opFct.resize(manifestationCount, {nullptr}); }

// Public members
PostfixUnaryOp op = OP_NONE;
std::string identifier; // Only set when operator is member access
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
std::string identifier; // Only set when operator is member access
};

// ====================================================== AtomicExprNode =========================================================
Expand Down
2 changes: 1 addition & 1 deletion src/irgenerator/GenExpressions.cpp
Expand Up @@ -14,7 +14,7 @@ std::any IRGenerator::visitAssignExpr(const AssignExprNode *node) {
return visit(node->ternaryExpr());

// Assign or compound assign operation
if (node->hasOperator) {
if (node->op != AssignExprNode::OP_NONE) {
const PrefixUnaryExprNode *lhsNode = node->lhs();
const AssignExprNode *rhsNode = node->rhs();

Expand Down
4 changes: 2 additions & 2 deletions src/irgenerator/IRGenerator.cpp
Expand Up @@ -541,8 +541,8 @@ std::string IRGenerator::getIRString() const {
* @return Op fct pointer list
*/
const std::vector<const Function *> &IRGenerator::getOpFctPointers(const ASTNode *node) const {
assert(node->opFct.size() > manIdx);
return node->opFct.at(manIdx);
assert(node->getOpFctPointers()->size() > manIdx);
return node->getOpFctPointers()->at(manIdx);
}

} // namespace spice::compiler
7 changes: 4 additions & 3 deletions src/irgenerator/OpRuleConversionManager.cpp
Expand Up @@ -1665,9 +1665,10 @@ template <size_t N>
LLVMExprResult OpRuleConversionManager::callOperatorOverloadFct(const ASTNode *node, const std::array<ResolverFct, N * 2> &opV,
size_t opIdx) {
const size_t manIdx = irGenerator->manIdx;
assert(!node->opFct.empty() && node->opFct.size() > manIdx);
assert(!node->opFct.at(manIdx).empty() && node->opFct.at(manIdx).size() > opIdx);
const Function *opFct = node->opFct.at(manIdx).at(opIdx);
const std::vector<std::vector<const Function *>> *opFctPointers = node->getOpFctPointers();
assert(!opFctPointers->empty() && opFctPointers->size() > manIdx);
assert(!opFctPointers->at(manIdx).empty() && opFctPointers->at(manIdx).size() > opIdx);
const Function *opFct = opFctPointers->at(manIdx).at(opIdx);
assert(opFct != nullptr);

const std::string mangledName = opFct->getMangledName();
Expand Down
8 changes: 4 additions & 4 deletions src/typechecker/TypeChecker.cpp
Expand Up @@ -253,7 +253,7 @@ std::any TypeChecker::visitIfStmt(IfStmtNode *node) {
SOFT_ERROR_ER(node->condition(), CONDITION_MUST_BE_BOOL, "If condition must be of type bool")

// Warning for bool assignment
if (condition->hasOperator && condition->op == AssignExprNode::OP_ASSIGN)
if (condition->op == AssignExprNode::OP_ASSIGN)
sourceFile->compilerOutput.warnings.emplace_back(condition->codeLoc, BOOL_ASSIGN_AS_CONDITION,
"If you want to compare the values, use '=='");

Expand Down Expand Up @@ -800,7 +800,7 @@ std::any TypeChecker::visitAssignExpr(AssignExprNode *node) {
}

// Check if assignment
if (node->hasOperator) {
if (node->op != AssignExprNode::OP_NONE) {
// Visit the right side first
auto [rhsType, rhsEntry] = std::any_cast<ExprResult>(visit(node->rhs()));
HANDLE_UNRESOLVED_TYPE_ER(rhsType)
Expand Down Expand Up @@ -2486,8 +2486,8 @@ void TypeChecker::autoDeReference(SymbolType &symbolType) {
* @return Op fct pointer list
*/
std::vector<const Function *> &TypeChecker::getOpFctPointers(ASTNode *node) const {
assert(node->opFct.size() > manIdx);
return node->opFct.at(manIdx);
assert(node->getOpFctPointers()->size() > manIdx);
return node->getOpFctPointers()->at(manIdx);
}

/**
Expand Down

0 comments on commit f96f19a

Please sign in to comment.