Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for matching struct to compatible interface type on fct calls #488

Merged
merged 2 commits into from Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 57 additions & 16 deletions media/test-project/test.spice
@@ -1,21 +1,62 @@
import "std/io/filepath";
import "std/os/os";
import "bootstrap/bindings/llvm/llvm" as llvm;
import "std/data/vector";

f<int> main() {
FilePath path = FilePath("C:\Users\Public\Documents");
path /= "test.txt";
assert len(path.toString()) == 34;
string expectedString = isWindows() ? "C:\\Users\\Public\\Documents\\test.txt" : "C:\\Users\\Public\\Documents/test.txt";
assert path.toString() == expectedString;
expectedString = isWindows() ? "C:\\Users\\Public\\Documents\\test.txt" : "C:/Users/Public/Documents/test.txt";
assert path.toNativeString() == expectedString;
assert path.toGenericString() == "C:/Users/Public/Documents/test.txt";

assert path.getFileName() == "test.txt";
assert path.getExtension() == "txt";
assert path.getBaseName() == "test";

printf("All assertions passed!");
llvm::initializeNativeTarget();
llvm::initializeNativeAsmPrinter();

heap string targetTriple = llvm::getDefaultTargetTriple();
string error;
llvm::Target target = llvm::getTargetFromTriple(targetTriple, &error);
llvm::TargetMachine targetMachine = target.createTargetMachine(targetTriple, "generic", "", llvm::LLVMCodeGenOptLevel::Default, llvm::LLVMRelocMode::Default, llvm::LLVMCodeModel::Default);

llvm::LLVMContext context;
llvm::Module module = llvm::Module("test", context);
module.setDataLayout(targetMachine.createDataLayout());
//module.setTargetTriple(targetTriple); // This emits target dependent information in the IR, which is not what we want here.
llvm::Builder builder = llvm::Builder(context);

llvm::Type returnType = builder.getInt32Ty();
Vector<llvm::Type> argTypes;
llvm::Type funcType = llvm::getFunctionType(returnType, argTypes);
llvm::Function func = llvm::Function(module, "main", funcType);
func.setLinkage(llvm::LLVMLinkage::ExternalLinkage);

llvm::BasicBlock entry = llvm::BasicBlock(context, "");
func.pushBack(entry);
builder.setInsertPoint(entry);

llvm::Value calcResult = builder.createAdd(builder.getInt32(1), builder.getInt32(2), "calcResult");

llvm::Value helloWorldStr = builder.createGlobalStringPtr("Hello, world!\n", "helloWorldStr");
Vector<llvm::Type> printfArgTypes;
printfArgTypes.pushBack(builder.getPtrTy());
printfArgTypes.pushBack(builder.getInt32Ty());
llvm::Type printfFuncType = llvm::getFunctionType(builder.getInt32Ty(), printfArgTypes, true);
llvm::Function printfFunc = module.getOrInsertFunction("printf", printfFuncType);

Vector<llvm::Value> printfArgs;
printfArgs.pushBack(helloWorldStr);
printfArgs.pushBack(calcResult);
builder.createCall(printfFunc, printfArgs);

builder.createRet(builder.getInt32(0));

assert !llvm::verifyFunction(func);
string output;
assert !llvm::verifyModule(module, &output);

printf("Unoptimized IR:\n%s", module.print());

llvm::PassBuilderOptions pto;
llvm::PassBuilder passBuilder = llvm::PassBuilder(pto);
passBuilder.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O2);
passBuilder.addPass(llvm::AlwaysInlinerPass());
passBuilder.run(module, targetMachine);

printf("Optimized IR:\n%s", module.print());

targetMachine.emitToFile(module, "this-is-a-test.o", llvm::LLVMCodeGenFileType::ObjectFile);
}

/*import "bootstrap/util/block-allocator";
Expand Down
8 changes: 1 addition & 7 deletions src-bootstrap/bindings/llvm/llvm.spice
Expand Up @@ -1085,13 +1085,7 @@ public p PassBuilder.addPass(string pass) {
this.passes.pushBack(String(pass));
}

// ToDo: Uncomment if this is working
/*public p PassBuilder.addPass(const PassInfo& pass) {
this.passes.pushBack(String(pass.getOption()));
}*/

// ToDo: Delete if the code above is working
public p PassBuilder.addPass(const AlwaysInlinerPass& pass) {
public p PassBuilder.addPass(const PassInfo& pass) {
this.passes.pushBack(String(pass.getOption()));
}

Expand Down
10 changes: 7 additions & 3 deletions src/irgenerator/GenValues.cpp
Expand Up @@ -145,15 +145,19 @@ std::any IRGenerator::visitFctCall(const FctCallNode *node) {
const SymbolType &expectedSTy = paramSTypes.at(i);
const SymbolType &actualSTy = argNode->getEvaluatedSymbolType(manIdx);

const auto matchFct = [](const SymbolType &lhsTy, const SymbolType &rhsTy) {
return lhsTy.matches(rhsTy, false, true, true) || lhsTy.matchesInterfaceImplementedByStruct(rhsTy);
};

// If the arrays are both of size -1 or 0, they are both pointers and do not need to be cast implicitly
if (expectedSTy.matches(actualSTy, false, true, true)) { // Matches the param type
if (matchFct(expectedSTy, actualSTy)) {
// Resolve address if actual type is reference, otherwise value
llvm::Value *argValue = actualSTy.isRef() ? resolveAddress(argNode) : resolveValue(argNode);
argValues.push_back(argValue);
} else if (expectedSTy.isRef() && expectedSTy.getContainedTy().matches(actualSTy, false, true, true)) { // Matches with ref
} else if (expectedSTy.isRef() && matchFct(expectedSTy.getContainedTy(), actualSTy)) { // Matches with ref
llvm::Value *argAddress = resolveAddress(argNode);
argValues.push_back(argAddress);
} else if (actualSTy.isRef() && expectedSTy.matches(actualSTy.getContainedTy(), false, true, true)) { // Matches with ref
} else if (actualSTy.isRef() && matchFct(expectedSTy, actualSTy.getContainedTy())) { // Matches with ref
llvm::Value *argAddress = resolveValue(argNode);
argValues.push_back(argAddress);
} else { // Need implicit cast
Expand Down
22 changes: 22 additions & 0 deletions src/symboltablebuilder/SymbolType.cpp
Expand Up @@ -606,6 +606,28 @@ bool SymbolType::matches(const SymbolType &otherType, bool ignoreArraySize, bool
return ignoreSpecifiers || specifiers.match(otherType.specifiers, allowConstify);
}

/**
* Check for the matching compatibility of two types in terms of interface implementation.
* Useful for function matching as well as assignment type validation and function arg matching.
*
* @param otherType Type to compare against
* @return Matching or not
*/
bool SymbolType::matchesInterfaceImplementedByStruct(const SymbolType &otherType) const {
if (!is(TY_INTERFACE) || !otherType.is(TY_STRUCT))
return false;

// Check if the rhs is a struct type that implements the lhs interface type
const Struct *spiceStruct = otherType.getStruct(nullptr);
assert(spiceStruct != nullptr);
for (const SymbolType &interfaceType : spiceStruct->interfaceTypes) {
assert(interfaceType.is(TY_INTERFACE));
if (matches(interfaceType, false, false, true))
return true;
}
return false;
}

/**
* Check if a certain input type can be bound (assigned) to the current type.
*
Expand Down
1 change: 1 addition & 0 deletions src/symboltablebuilder/SymbolType.h
Expand Up @@ -214,6 +214,7 @@ class SymbolType {
friend bool operator==(const SymbolType &lhs, const SymbolType &rhs);
friend bool operator!=(const SymbolType &lhs, const SymbolType &rhs);
[[nodiscard]] bool matches(const SymbolType &otherType, bool ignoreArraySize, bool ignoreSpecifiers, bool allowConstify) const;
[[nodiscard]] bool matchesInterfaceImplementedByStruct(const SymbolType &otherType) const;
[[nodiscard]] bool canBind(const SymbolType &otherType, bool isTemporary) const;

// Static util methods
Expand Down
17 changes: 6 additions & 11 deletions src/typechecker/OpRuleManager.cpp
Expand Up @@ -97,7 +97,7 @@ SymbolType OpRuleManager::getAssignResultTypeCommon(const ASTNode *node, const E
// Allow type to ref type of the same contained type straight away
if (lhsType.isRef() && lhsType.getContainedTy().matches(rhsType, false, false, true)) {
if (isDecl && !lhsType.canBind(rhsType, rhs.isTemporary()))
throw SemanticError(node, TEMP_TO_NON_CONST_REF, "Temporary values can only be bound to const reference parameters");
throw SemanticError(node, TEMP_TO_NON_CONST_REF, "Temporary values can only be bound to const reference variables/fields");
return lhsType;
}
// Allow dyn[] (empty array literal) to any array
Expand All @@ -109,20 +109,15 @@ SymbolType OpRuleManager::getAssignResultTypeCommon(const ASTNode *node, const E
// Allow array to pointer
if (lhsType.isPtr() && rhsType.isArray() && lhsType.getContainedTy().matches(rhsType.getContainedTy(), false, false, true))
return lhsType;
// Allow interface* = struct* that implements this interface
// Allow interface* = struct* or interface& = struct that implements this interface
const bool sameChainDepth = lhsType.typeChain.size() == rhsType.typeChain.size();
if (lhsType.isPtr() && rhsType.isPtr() && sameChainDepth && lhsType.isBaseType(TY_INTERFACE) && rhsType.isBaseType(TY_STRUCT)) {
const bool typesCompatible = (lhsType.isPtr() && rhsType.isPtr() && sameChainDepth) || lhsType.isRef();
if (typesCompatible && lhsType.isBaseType(TY_INTERFACE) && rhsType.isBaseType(TY_STRUCT)) {
SymbolType lhsTypeCopy = lhsType;
SymbolType rhsTypeCopy = rhsType;
SymbolType::unwrapBoth(lhsTypeCopy, rhsTypeCopy);

Struct *spiceStruct = rhsTypeCopy.getStruct(node);
assert(spiceStruct != nullptr);
for (const SymbolType &interfaceType : spiceStruct->interfaceTypes) {
assert(interfaceType.is(TY_INTERFACE));
if (lhsTypeCopy.matches(interfaceType, false, false, true))
return lhsType;
}
if (lhsTypeCopy.matchesInterfaceImplementedByStruct(rhsTypeCopy))
return lhsType;
}

// Nothing matched
Expand Down
5 changes: 3 additions & 2 deletions src/typechecker/TypeChecker.cpp
Expand Up @@ -1485,8 +1485,9 @@ std::any TypeChecker::visitAtomicExpr(AtomicExprNode *node) {
// Check if overloaded function was referenced
const std::vector<Function *> *manifestations = varEntry->declNode->getFctManifestations(varEntry->name);
if (manifestations->size() > 1)
SOFT_ERROR_ER(node, REFERENCED_OVERLOADED_FCT,
"Overloaded functions or functions with optional parameters cannot be referenced")
SOFT_ERROR_ER(node, REFERENCED_OVERLOADED_FCT, "Overloaded functions / functions with optional params cannot be referenced")
if (!manifestations->front()->templateTypes.empty())
SOFT_ERROR_ER(node, REFERENCED_OVERLOADED_FCT, "Generic functions cannot be referenced")
// Set referenced function to used
Function *referencedFunction = manifestations->front();
referencedFunction->used = true;
Expand Down
7 changes: 6 additions & 1 deletion src/typechecker/TypeMatcher.cpp
Expand Up @@ -32,8 +32,13 @@ bool TypeMatcher::matchRequestedToCandidateType(SymbolType candidateType, Symbol
SymbolType::unwrapBoth(candidateType, requestedType);

// If the candidate does not contain any generic parts, we can simply check for type equality
if (!candidateType.hasAnyGenericParts())
if (!candidateType.hasAnyGenericParts()) {
// Check if the right one is a struct that implements the interface on the left
if (candidateType.matchesInterfaceImplementedByStruct(requestedType))
return true;
// Normal equality check
return candidateType.matches(requestedType, true, !strictSpecifierMatching, true);
}

// Check if the candidate type itself is generic
if (candidateType.isBaseType(TY_GENERIC)) { // The candidate type itself is generic
Expand Down
@@ -0,0 +1 @@
Test
@@ -0,0 +1,71 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

%struct.Test = type { %interface.ITest }
%interface.ITest = type { ptr }

$_ZTS5ITest = comdat any

$_ZTI5ITest = comdat any

$_ZTV5ITest = comdat any

$_ZTS4Test = comdat any

$_ZTI4Test = comdat any

$_ZTV4Test = comdat any

@_ZTS5ITest = dso_local constant [7 x i8] c"5ITest\00", comdat, align 1
@_ZTV8TypeInfo = external global ptr
@_ZTI5ITest = dso_local constant { ptr, ptr } { ptr getelementptr inbounds (ptr, ptr @_ZTV8TypeInfo, i64 2), ptr @_ZTS5ITest }, comdat, align 8
@_ZTV5ITest = dso_local unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr @_ZTI5ITest, ptr null] }, comdat, align 8
@_ZTS4Test = dso_local constant [6 x i8] c"4Test\00", comdat, align 1
@_ZTI4Test = dso_local constant { ptr, ptr, ptr } { ptr getelementptr inbounds (ptr, ptr @_ZTV8TypeInfo, i64 2), ptr @_ZTS4Test, ptr @_ZTI5ITest }, comdat, align 8
@_ZTV4Test = dso_local unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr @_ZTI4Test, ptr @_ZN4Test4testEv] }, comdat, align 8
@printf.str.0 = private unnamed_addr constant [5 x i8] c"Test\00", align 1

define private void @_ZN4Test4ctorEv(ptr noundef nonnull align 8 dereferenceable(8) %0) {
%this = alloca ptr, align 8
store ptr %0, ptr %this, align 8
%2 = load ptr, ptr %this, align 8
store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test, i32 0, i32 0, i32 2), ptr %2, align 8
ret void
}

define private void @_ZN4Test4testEv(ptr noundef nonnull align 8 dereferenceable(8) %0) {
%this = alloca ptr, align 8
store ptr %0, ptr %this, align 8
%2 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0)
ret void
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #0

define private void @_Z7testFctR5ITest(ptr %0) {
%test = alloca ptr, align 8
store ptr %0, ptr %test, align 8
%2 = load ptr, ptr %test, align 8
%vtable.addr = load ptr, ptr %2, align 8
%vfct.addr = getelementptr inbounds ptr, ptr %vtable.addr, i64 0
%fct = load ptr, ptr %vfct.addr, align 8
call void %fct(ptr noundef nonnull align 8 dereferenceable(8) %2)
ret void
}

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #1 {
%result = alloca i32, align 4
%test = alloca %struct.Test, align 8
%itest = alloca ptr, align 8
store i32 0, ptr %result, align 4
call void @_ZN4Test4ctorEv(ptr noundef nonnull align 8 dereferenceable(8) %test)
store ptr %test, ptr %itest, align 8
%1 = load ptr, ptr %itest, align 8
call void @_Z7testFctR5ITest(ptr %1)
ret i32 0
}

attributes #0 = { nofree nounwind }
attributes #1 = { noinline nounwind optnone uwtable }
@@ -0,0 +1,22 @@
type ITest interface {
p test();
}

type Test struct : ITest {}

p Test.ctor() {}

p Test.test() {
printf("Test");
}

p testFct(ITest& test) {
test.test();
}

f<int> main() {
Test test = Test();
ITest& itest = test;
testFct(itest);
return 0;
}
@@ -0,0 +1,8 @@
[Error|Compiler]:
Unresolved soft errors: There are unresolved errors. Please fix them and recompile.

[Error|Semantic] ./source.spice:8:21:
Referenced overloaded function: Generic functions cannot be referenced

8 p(const T&) s = test;
^^^^
@@ -0,0 +1,10 @@
type T dyn;

p test<T>(const T& t) {
printf("%d\n", t);
}

f<int> main() {
p(const T&) s = test;
s(123);
}
Expand Up @@ -2,7 +2,7 @@
Unresolved soft errors: There are unresolved errors. Please fix them and recompile.

[Error|Semantic] ./source.spice:4:21:
Referenced overloaded function: Overloaded functions or functions with optional parameters cannot be referenced
Referenced overloaded function: Overloaded functions / functions with optional params cannot be referenced

4 p(double) fct = overloadedFct;
^^^^^^^^^^^^^

This file was deleted.

This file was deleted.

@@ -1,5 +1,5 @@
[Error|Semantic] ./source.spice:2:5:
Temporary bound to non-const reference: Temporary values can only be bound to const reference parameters
Temporary bound to non-const reference: Temporary values can only be bound to const reference variables/fields

2 int& i = 123;
^^^^^^^^^^^^