Skip to content

Commit

Permalink
Avoids starting mocks "half-way" if a superclass constructor is mocke…
Browse files Browse the repository at this point in the history
…d but an unmocked subclass is initiated.
  • Loading branch information
raphw committed Jun 11, 2022
1 parent ce4e64d commit f28941a
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 0 deletions.
Expand Up @@ -251,6 +251,7 @@ class InlineDelegateByteBuddyMockMaker

ThreadLocal<Class<?>> currentConstruction = new ThreadLocal<>();
ThreadLocal<Boolean> isSuspended = ThreadLocal.withInitial(() -> false);
Predicate<Class<?>> isCallFromSubclassConstructor = StackWalkerChecker.orFallback();
Predicate<Class<?>> isMockConstruction =
type -> {
if (isSuspended.get()) {
Expand All @@ -260,6 +261,11 @@ class InlineDelegateByteBuddyMockMaker
}
Map<Class<?>, ?> interceptors = mockedConstruction.get();
if (interceptors != null && interceptors.containsKey(type)) {
// We only initiate a construction mock, if the call originates from an
// un-mocked (as suppression is not enabled) subclass constructor.
if (isCallFromSubclassConstructor.test(type)) {
return false;
}
currentConstruction.set(type);
return true;
} else {
Expand Down
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2022 Mockito contributors
* This program is made available under the terms of the MIT License.
*/
package org.mockito.internal.creation.bytebuddy;

import java.util.function.Predicate;

class StackTraceChecker implements Predicate<Class<?>> {

@Override
public boolean test(Class<?> type) {
StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
for (int index = 1; index < stackTrace.length - 1; index++) {
if (!stackTrace[index].getClassName().startsWith("org.mockito.internal.")) {
if (stackTrace[index + 1].getMethodName().startsWith("<init>")) {
try {
if (!stackTrace[index + 1].getClassName().equals(type.getName())
&& type.isAssignableFrom(
Class.forName(
stackTrace[index + 1].getClassName(),
false,
type.getClassLoader()))) {
return true;
} else {
break;
}
} catch (ClassNotFoundException ignored) {
break;
}
}
}
}
return false;
}
}
@@ -0,0 +1,91 @@
/*
* Copyright (c) 2022 Mockito contributors
* This program is made available under the terms of the MIT License.
*/
package org.mockito.internal.creation.bytebuddy;

import java.lang.reflect.Method;
import java.util.Iterator;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;

class StackWalkerChecker implements Predicate<Class<?>> {

private final Method stackWalkerGetInstance;
private final Method stackWalkerWalk;
private final Method stackWalkerStackFrameGetDeclaringClass;
private final Enum<?> stackWalkerOptionRetainClassReference;

StackWalkerChecker() throws Exception {
Class<?> stackWalker = Class.forName("java.lang.StackWalker");
@SuppressWarnings({"unchecked", "rawtypes"})
Class<? extends Enum> stackWalkerOption =
(Class<? extends Enum<?>>) Class.forName("java.lang.StackWalker$Option");
stackWalkerGetInstance = stackWalker.getMethod("getInstance", stackWalkerOption);
stackWalkerWalk = stackWalker.getMethod("walk", Function.class);
Class<?> stackWalkerStackFrame = Class.forName("java.lang.StackWalker$StackFrame");
stackWalkerStackFrameGetDeclaringClass =
stackWalkerStackFrame.getMethod("getDeclaringClass");
@SuppressWarnings("unchecked")
Enum<?> stackWalkerOptionRetainClassReference =
Enum.valueOf(stackWalkerOption, "RETAIN_CLASS_REFERENCE");
this.stackWalkerOptionRetainClassReference = stackWalkerOptionRetainClassReference;
}

static Predicate<Class<?>> orFallback() {
try {
return new StackWalkerChecker();
} catch (Exception e) {
return new StackTraceChecker();
}
}

@Override
public boolean test(Class<?> type) {
try {
Object walker =
stackWalkerGetInstance.invoke(null, stackWalkerOptionRetainClassReference);
return (Boolean)
stackWalkerWalk.invoke(
walker,
(Function<?, ?>)
stream -> {
Iterator<?> iterator = ((Stream<?>) stream).iterator();
while (iterator.hasNext()) {
try {
Object frame = iterator.next();
if (((Class<?>)
stackWalkerStackFrameGetDeclaringClass
.invoke(frame))
.getName()
.startsWith("org.mockito.internal.")) {
continue;
}
if (iterator.hasNext()) {
Object next = iterator.next();
Class<?> declaringClass =
(Class<?>)
stackWalkerStackFrameGetDeclaringClass
.invoke(next);
if (type != declaringClass
&& type.isAssignableFrom(
declaringClass)) {
return true;
} else {
break;
}
} else {
break;
}
} catch (Exception ignored) {
return false;
}
}
return false;
});
} catch (Exception ignored) {
return false;
}
}
}
@@ -0,0 +1,43 @@
package org.mockitoinline;

import org.junit.Test;
import org.mockito.MockedConstruction;
import org.mockito.Mockito;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

public class SubconstructorMockTest {

@Test
public void does_not_mock_subclass_constructor_for_superclass_mock() {
try (MockedConstruction<SubClass> mocked = Mockito.mockConstruction(SubClass.class)) { }
try (MockedConstruction<SuperClass> mocked = Mockito.mockConstruction(SuperClass.class)) {
SubClass value = new SubClass();
assertTrue(value.sup());
assertTrue(value.sub());
}
}

@Test
public void does_mock_superclass_constructor_for_subclass_mock() {
try (MockedConstruction<SuperClass> mocked = Mockito.mockConstruction(SuperClass.class)) { }
try (MockedConstruction<SubClass> mocked = Mockito.mockConstruction(SubClass.class)) {
SubClass value = new SubClass();
assertFalse(value.sup());
assertFalse(value.sub());
}
}

public static class SuperClass {
public boolean sup() {
return true;
}
}

public static class SubClass extends SuperClass {
public boolean sub() {
return true;
}
}
}

0 comments on commit f28941a

Please sign in to comment.