Skip to content

Commit

Permalink
Add support for matching struct to compatible interface type on fct c…
Browse files Browse the repository at this point in the history
…alls (#488)
  • Loading branch information
marcauberer committed Mar 3, 2024
1 parent c3d782d commit 9c6888b
Show file tree
Hide file tree
Showing 17 changed files with 217 additions and 54 deletions.
73 changes: 57 additions & 16 deletions media/test-project/test.spice
@@ -1,21 +1,62 @@
import "std/io/filepath";
import "std/os/os";
import "bootstrap/bindings/llvm/llvm" as llvm;
import "std/data/vector";

f<int> main() {
FilePath path = FilePath("C:\Users\Public\Documents");
path /= "test.txt";
assert len(path.toString()) == 34;
string expectedString = isWindows() ? "C:\\Users\\Public\\Documents\\test.txt" : "C:\\Users\\Public\\Documents/test.txt";
assert path.toString() == expectedString;
expectedString = isWindows() ? "C:\\Users\\Public\\Documents\\test.txt" : "C:/Users/Public/Documents/test.txt";
assert path.toNativeString() == expectedString;
assert path.toGenericString() == "C:/Users/Public/Documents/test.txt";

assert path.getFileName() == "test.txt";
assert path.getExtension() == "txt";
assert path.getBaseName() == "test";

printf("All assertions passed!");
llvm::initializeNativeTarget();
llvm::initializeNativeAsmPrinter();

heap string targetTriple = llvm::getDefaultTargetTriple();
string error;
llvm::Target target = llvm::getTargetFromTriple(targetTriple, &error);
llvm::TargetMachine targetMachine = target.createTargetMachine(targetTriple, "generic", "", llvm::LLVMCodeGenOptLevel::Default, llvm::LLVMRelocMode::Default, llvm::LLVMCodeModel::Default);

llvm::LLVMContext context;
llvm::Module module = llvm::Module("test", context);
module.setDataLayout(targetMachine.createDataLayout());
//module.setTargetTriple(targetTriple); // This emits target dependent information in the IR, which is not what we want here.
llvm::Builder builder = llvm::Builder(context);

llvm::Type returnType = builder.getInt32Ty();
Vector<llvm::Type> argTypes;
llvm::Type funcType = llvm::getFunctionType(returnType, argTypes);
llvm::Function func = llvm::Function(module, "main", funcType);
func.setLinkage(llvm::LLVMLinkage::ExternalLinkage);

llvm::BasicBlock entry = llvm::BasicBlock(context, "");
func.pushBack(entry);
builder.setInsertPoint(entry);

llvm::Value calcResult = builder.createAdd(builder.getInt32(1), builder.getInt32(2), "calcResult");

llvm::Value helloWorldStr = builder.createGlobalStringPtr("Hello, world!\n", "helloWorldStr");
Vector<llvm::Type> printfArgTypes;
printfArgTypes.pushBack(builder.getPtrTy());
printfArgTypes.pushBack(builder.getInt32Ty());
llvm::Type printfFuncType = llvm::getFunctionType(builder.getInt32Ty(), printfArgTypes, true);
llvm::Function printfFunc = module.getOrInsertFunction("printf", printfFuncType);

Vector<llvm::Value> printfArgs;
printfArgs.pushBack(helloWorldStr);
printfArgs.pushBack(calcResult);
builder.createCall(printfFunc, printfArgs);

builder.createRet(builder.getInt32(0));

assert !llvm::verifyFunction(func);
string output;
assert !llvm::verifyModule(module, &output);

printf("Unoptimized IR:\n%s", module.print());

llvm::PassBuilderOptions pto;
llvm::PassBuilder passBuilder = llvm::PassBuilder(pto);
passBuilder.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O2);
passBuilder.addPass(llvm::AlwaysInlinerPass());
passBuilder.run(module, targetMachine);

printf("Optimized IR:\n%s", module.print());

targetMachine.emitToFile(module, "this-is-a-test.o", llvm::LLVMCodeGenFileType::ObjectFile);
}

/*import "bootstrap/util/block-allocator";
Expand Down
8 changes: 1 addition & 7 deletions src-bootstrap/bindings/llvm/llvm.spice
Expand Up @@ -1085,13 +1085,7 @@ public p PassBuilder.addPass(string pass) {
this.passes.pushBack(String(pass));
}

// ToDo: Uncomment if this is working
/*public p PassBuilder.addPass(const PassInfo& pass) {
this.passes.pushBack(String(pass.getOption()));
}*/

// ToDo: Delete if the code above is working
public p PassBuilder.addPass(const AlwaysInlinerPass& pass) {
public p PassBuilder.addPass(const PassInfo& pass) {
this.passes.pushBack(String(pass.getOption()));
}

Expand Down
10 changes: 7 additions & 3 deletions src/irgenerator/GenValues.cpp
Expand Up @@ -145,15 +145,19 @@ std::any IRGenerator::visitFctCall(const FctCallNode *node) {
const SymbolType &expectedSTy = paramSTypes.at(i);
const SymbolType &actualSTy = argNode->getEvaluatedSymbolType(manIdx);

const auto matchFct = [](const SymbolType &lhsTy, const SymbolType &rhsTy) {
return lhsTy.matches(rhsTy, false, true, true) || lhsTy.matchesInterfaceImplementedByStruct(rhsTy);
};

// If the arrays are both of size -1 or 0, they are both pointers and do not need to be cast implicitly
if (expectedSTy.matches(actualSTy, false, true, true)) { // Matches the param type
if (matchFct(expectedSTy, actualSTy)) {
// Resolve address if actual type is reference, otherwise value
llvm::Value *argValue = actualSTy.isRef() ? resolveAddress(argNode) : resolveValue(argNode);
argValues.push_back(argValue);
} else if (expectedSTy.isRef() && expectedSTy.getContainedTy().matches(actualSTy, false, true, true)) { // Matches with ref
} else if (expectedSTy.isRef() && matchFct(expectedSTy.getContainedTy(), actualSTy)) { // Matches with ref
llvm::Value *argAddress = resolveAddress(argNode);
argValues.push_back(argAddress);
} else if (actualSTy.isRef() && expectedSTy.matches(actualSTy.getContainedTy(), false, true, true)) { // Matches with ref
} else if (actualSTy.isRef() && matchFct(expectedSTy, actualSTy.getContainedTy())) { // Matches with ref
llvm::Value *argAddress = resolveValue(argNode);
argValues.push_back(argAddress);
} else { // Need implicit cast
Expand Down
22 changes: 22 additions & 0 deletions src/symboltablebuilder/SymbolType.cpp
Expand Up @@ -606,6 +606,28 @@ bool SymbolType::matches(const SymbolType &otherType, bool ignoreArraySize, bool
return ignoreSpecifiers || specifiers.match(otherType.specifiers, allowConstify);
}

/**
* Check for the matching compatibility of two types in terms of interface implementation.
* Useful for function matching as well as assignment type validation and function arg matching.
*
* @param otherType Type to compare against
* @return Matching or not
*/
bool SymbolType::matchesInterfaceImplementedByStruct(const SymbolType &otherType) const {
if (!is(TY_INTERFACE) || !otherType.is(TY_STRUCT))
return false;

// Check if the rhs is a struct type that implements the lhs interface type
const Struct *spiceStruct = otherType.getStruct(nullptr);
assert(spiceStruct != nullptr);
for (const SymbolType &interfaceType : spiceStruct->interfaceTypes) {
assert(interfaceType.is(TY_INTERFACE));
if (matches(interfaceType, false, false, true))
return true;
}
return false;
}

/**
* Check if a certain input type can be bound (assigned) to the current type.
*
Expand Down
1 change: 1 addition & 0 deletions src/symboltablebuilder/SymbolType.h
Expand Up @@ -214,6 +214,7 @@ class SymbolType {
friend bool operator==(const SymbolType &lhs, const SymbolType &rhs);
friend bool operator!=(const SymbolType &lhs, const SymbolType &rhs);
[[nodiscard]] bool matches(const SymbolType &otherType, bool ignoreArraySize, bool ignoreSpecifiers, bool allowConstify) const;
[[nodiscard]] bool matchesInterfaceImplementedByStruct(const SymbolType &otherType) const;
[[nodiscard]] bool canBind(const SymbolType &otherType, bool isTemporary) const;

// Static util methods
Expand Down
17 changes: 6 additions & 11 deletions src/typechecker/OpRuleManager.cpp
Expand Up @@ -97,7 +97,7 @@ SymbolType OpRuleManager::getAssignResultTypeCommon(const ASTNode *node, const E
// Allow type to ref type of the same contained type straight away
if (lhsType.isRef() && lhsType.getContainedTy().matches(rhsType, false, false, true)) {
if (isDecl && !lhsType.canBind(rhsType, rhs.isTemporary()))
throw SemanticError(node, TEMP_TO_NON_CONST_REF, "Temporary values can only be bound to const reference parameters");
throw SemanticError(node, TEMP_TO_NON_CONST_REF, "Temporary values can only be bound to const reference variables/fields");
return lhsType;
}
// Allow dyn[] (empty array literal) to any array
Expand All @@ -109,20 +109,15 @@ SymbolType OpRuleManager::getAssignResultTypeCommon(const ASTNode *node, const E
// Allow array to pointer
if (lhsType.isPtr() && rhsType.isArray() && lhsType.getContainedTy().matches(rhsType.getContainedTy(), false, false, true))
return lhsType;
// Allow interface* = struct* that implements this interface
// Allow interface* = struct* or interface& = struct that implements this interface
const bool sameChainDepth = lhsType.typeChain.size() == rhsType.typeChain.size();
if (lhsType.isPtr() && rhsType.isPtr() && sameChainDepth && lhsType.isBaseType(TY_INTERFACE) && rhsType.isBaseType(TY_STRUCT)) {
const bool typesCompatible = (lhsType.isPtr() && rhsType.isPtr() && sameChainDepth) || lhsType.isRef();
if (typesCompatible && lhsType.isBaseType(TY_INTERFACE) && rhsType.isBaseType(TY_STRUCT)) {
SymbolType lhsTypeCopy = lhsType;
SymbolType rhsTypeCopy = rhsType;
SymbolType::unwrapBoth(lhsTypeCopy, rhsTypeCopy);

Struct *spiceStruct = rhsTypeCopy.getStruct(node);
assert(spiceStruct != nullptr);
for (const SymbolType &interfaceType : spiceStruct->interfaceTypes) {
assert(interfaceType.is(TY_INTERFACE));
if (lhsTypeCopy.matches(interfaceType, false, false, true))
return lhsType;
}
if (lhsTypeCopy.matchesInterfaceImplementedByStruct(rhsTypeCopy))
return lhsType;
}

// Nothing matched
Expand Down
5 changes: 3 additions & 2 deletions src/typechecker/TypeChecker.cpp
Expand Up @@ -1485,8 +1485,9 @@ std::any TypeChecker::visitAtomicExpr(AtomicExprNode *node) {
// Check if overloaded function was referenced
const std::vector<Function *> *manifestations = varEntry->declNode->getFctManifestations(varEntry->name);
if (manifestations->size() > 1)
SOFT_ERROR_ER(node, REFERENCED_OVERLOADED_FCT,
"Overloaded functions or functions with optional parameters cannot be referenced")
SOFT_ERROR_ER(node, REFERENCED_OVERLOADED_FCT, "Overloaded functions / functions with optional params cannot be referenced")
if (!manifestations->front()->templateTypes.empty())
SOFT_ERROR_ER(node, REFERENCED_OVERLOADED_FCT, "Generic functions cannot be referenced")
// Set referenced function to used
Function *referencedFunction = manifestations->front();
referencedFunction->used = true;
Expand Down
7 changes: 6 additions & 1 deletion src/typechecker/TypeMatcher.cpp
Expand Up @@ -32,8 +32,13 @@ bool TypeMatcher::matchRequestedToCandidateType(SymbolType candidateType, Symbol
SymbolType::unwrapBoth(candidateType, requestedType);

// If the candidate does not contain any generic parts, we can simply check for type equality
if (!candidateType.hasAnyGenericParts())
if (!candidateType.hasAnyGenericParts()) {
// Check if the right one is a struct that implements the interface on the left
if (candidateType.matchesInterfaceImplementedByStruct(requestedType))
return true;
// Normal equality check
return candidateType.matches(requestedType, true, !strictSpecifierMatching, true);
}

// Check if the candidate type itself is generic
if (candidateType.isBaseType(TY_GENERIC)) { // The candidate type itself is generic
Expand Down
@@ -0,0 +1 @@
Test
@@ -0,0 +1,71 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

%struct.Test = type { %interface.ITest }
%interface.ITest = type { ptr }

$_ZTS5ITest = comdat any

$_ZTI5ITest = comdat any

$_ZTV5ITest = comdat any

$_ZTS4Test = comdat any

$_ZTI4Test = comdat any

$_ZTV4Test = comdat any

@_ZTS5ITest = dso_local constant [7 x i8] c"5ITest\00", comdat, align 1
@_ZTV8TypeInfo = external global ptr
@_ZTI5ITest = dso_local constant { ptr, ptr } { ptr getelementptr inbounds (ptr, ptr @_ZTV8TypeInfo, i64 2), ptr @_ZTS5ITest }, comdat, align 8
@_ZTV5ITest = dso_local unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr @_ZTI5ITest, ptr null] }, comdat, align 8
@_ZTS4Test = dso_local constant [6 x i8] c"4Test\00", comdat, align 1
@_ZTI4Test = dso_local constant { ptr, ptr, ptr } { ptr getelementptr inbounds (ptr, ptr @_ZTV8TypeInfo, i64 2), ptr @_ZTS4Test, ptr @_ZTI5ITest }, comdat, align 8
@_ZTV4Test = dso_local unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr @_ZTI4Test, ptr @_ZN4Test4testEv] }, comdat, align 8
@printf.str.0 = private unnamed_addr constant [5 x i8] c"Test\00", align 1

define private void @_ZN4Test4ctorEv(ptr noundef nonnull align 8 dereferenceable(8) %0) {
%this = alloca ptr, align 8
store ptr %0, ptr %this, align 8
%2 = load ptr, ptr %this, align 8
store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test, i32 0, i32 0, i32 2), ptr %2, align 8
ret void
}

define private void @_ZN4Test4testEv(ptr noundef nonnull align 8 dereferenceable(8) %0) {
%this = alloca ptr, align 8
store ptr %0, ptr %this, align 8
%2 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0)
ret void
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #0

define private void @_Z7testFctR5ITest(ptr %0) {
%test = alloca ptr, align 8
store ptr %0, ptr %test, align 8
%2 = load ptr, ptr %test, align 8
%vtable.addr = load ptr, ptr %2, align 8
%vfct.addr = getelementptr inbounds ptr, ptr %vtable.addr, i64 0
%fct = load ptr, ptr %vfct.addr, align 8
call void %fct(ptr noundef nonnull align 8 dereferenceable(8) %2)
ret void
}

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #1 {
%result = alloca i32, align 4
%test = alloca %struct.Test, align 8
%itest = alloca ptr, align 8
store i32 0, ptr %result, align 4
call void @_ZN4Test4ctorEv(ptr noundef nonnull align 8 dereferenceable(8) %test)
store ptr %test, ptr %itest, align 8
%1 = load ptr, ptr %itest, align 8
call void @_Z7testFctR5ITest(ptr %1)
ret i32 0
}

attributes #0 = { nofree nounwind }
attributes #1 = { noinline nounwind optnone uwtable }
@@ -0,0 +1,22 @@
type ITest interface {
p test();
}

type Test struct : ITest {}

p Test.ctor() {}

p Test.test() {
printf("Test");
}

p testFct(ITest& test) {
test.test();
}

f<int> main() {
Test test = Test();
ITest& itest = test;
testFct(itest);
return 0;
}
@@ -0,0 +1,8 @@
[Error|Compiler]:
Unresolved soft errors: There are unresolved errors. Please fix them and recompile.

[Error|Semantic] ./source.spice:8:21:
Referenced overloaded function: Generic functions cannot be referenced

8 p(const T&) s = test;
^^^^
@@ -0,0 +1,10 @@
type T dyn;

p test<T>(const T& t) {
printf("%d\n", t);
}

f<int> main() {
p(const T&) s = test;
s(123);
}
Expand Up @@ -2,7 +2,7 @@
Unresolved soft errors: There are unresolved errors. Please fix them and recompile.

[Error|Semantic] ./source.spice:4:21:
Referenced overloaded function: Overloaded functions or functions with optional parameters cannot be referenced
Referenced overloaded function: Overloaded functions / functions with optional params cannot be referenced

4 p(double) fct = overloadedFct;
^^^^^^^^^^^^^

This file was deleted.

This file was deleted.

@@ -1,5 +1,5 @@
[Error|Semantic] ./source.spice:2:5:
Temporary bound to non-const reference: Temporary values can only be bound to const reference parameters
Temporary bound to non-const reference: Temporary values can only be bound to const reference variables/fields

2 int& i = 123;
^^^^^^^^^^^^

0 comments on commit 9c6888b

Please sign in to comment.