Skip to content

Commit

Permalink
Update handling of pattern matching in switch
Browse files Browse the repository at this point in the history
for AST changes in Java 17, in particular the addition of CaseTree#getLabels,
which returns a single `DefaultCaseLabelTree` to represent the default case.

#683
#684

PiperOrigin-RevId: 411074295
  • Loading branch information
cushon authored and google-java-format Team committed Nov 19, 2021
1 parent 58978e6 commit 8eb478a
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.google.googlejavaformat.java.java14;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
Expand All @@ -27,7 +28,6 @@
import com.sun.source.tree.CaseTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.InstanceOfTree;
import com.sun.source.tree.ModifiersTree;
import com.sun.source.tree.ModuleTree;
Expand Down Expand Up @@ -59,6 +59,7 @@ public class Java14InputAstVisitor extends JavaInputAstVisitor {
maybeGetMethod(BindingPatternTree.class, "getType");
private static final Method BINDING_PATTERN_TREE_GET_BINDING =
maybeGetMethod(BindingPatternTree.class, "getBinding");
private static final Method CASE_TREE_GET_LABELS = maybeGetMethod(CaseTree.class, "getLabels");

public Java14InputAstVisitor(OpsBuilder builder, int indentMultiplier) {
super(builder, indentMultiplier);
Expand Down Expand Up @@ -247,14 +248,25 @@ public Void visitCase(CaseTree node, Void unused) {
sync(node);
markForPartialFormat();
builder.forcedBreak();
if (node.getExpressions().isEmpty()) {
List<? extends Tree> labels;
boolean isDefault;
if (CASE_TREE_GET_LABELS != null) {
labels = (List<? extends Tree>) invoke(CASE_TREE_GET_LABELS, node);
isDefault =
labels.size() == 1
&& getOnlyElement(labels).getKind().name().equals("DEFAULT_CASE_LABEL");
} else {
labels = node.getExpressions();
isDefault = labels.isEmpty();
}
if (isDefault) {
token("default", plusTwo);
} else {
token("case", plusTwo);
builder.open(node.getExpressions().size() > 1 ? plusFour : ZERO);
builder.open(labels.size() > 1 ? plusFour : ZERO);
builder.space();
boolean first = true;
for (ExpressionTree expression : node.getExpressions()) {
for (Tree expression : labels) {
if (!first) {
token(",");
builder.breakOp(" ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,28 @@

package com.google.googlejavaformat.java;

import static com.google.common.base.StandardSystemProperty.JAVA_CLASS_VERSION;
import static com.google.common.base.StandardSystemProperty.JAVA_SPECIFICATION_VERSION;
import static com.google.common.collect.MoreCollectors.toOptional;
import static com.google.common.io.Files.getFileExtension;
import static com.google.common.io.Files.getNameWithoutExtension;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.io.CharStreams;
import com.google.common.reflect.ClassPath;
import com.google.common.reflect.ClassPath.ResourceInfo;
import com.google.googlejavaformat.Newlines;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Method;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -47,12 +46,13 @@
@RunWith(Parameterized.class)
public class FormatterIntegrationTest {

private static final ImmutableSet<String> JAVA14_TESTS =
ImmutableSet.of("I477", "Records", "RSLs", "Var", "ExpressionSwitch", "I574", "I594");

private static final ImmutableSet<String> JAVA15_TESTS = ImmutableSet.of("I603");

private static final ImmutableSet<String> JAVA16_TESTS = ImmutableSet.of("I588");
private static final ImmutableMultimap<Integer, String> VERSIONED_TESTS =
ImmutableMultimap.<Integer, String>builder()
.putAll(14, "I477", "Records", "RSLs", "Var", "ExpressionSwitch", "I574", "I594")
.putAll(15, "I603")
.putAll(16, "I588")
.putAll(17, "I683", "I684")
.build();

@Parameters(name = "{index}: {0}")
public static Iterable<Object[]> data() throws IOException {
Expand Down Expand Up @@ -91,35 +91,16 @@ public static Iterable<Object[]> data() throws IOException {
String input = inputs.get(fileName);
assertTrue("unmatched input", outputs.containsKey(fileName));
String expectedOutput = outputs.get(fileName);
if (JAVA14_TESTS.contains(fileName) && getMajor() < 14) {
continue;
}
if (JAVA15_TESTS.contains(fileName) && getMajor() < 15) {
continue;
}
if (JAVA16_TESTS.contains(fileName) && getMajor() < 16) {
Optional<Integer> version =
VERSIONED_TESTS.inverse().get(fileName).stream().collect(toOptional());
if (version.isPresent() && Runtime.version().feature() < version.getAsInt()) {
continue;
}
testInputs.add(new Object[] {fileName, input, expectedOutput});
}
return testInputs;
}

private static int getMajor() {
try {
Method versionMethod = Runtime.class.getMethod("version");
Object version = versionMethod.invoke(null);
return (int) version.getClass().getMethod("major").invoke(version);
} catch (Exception e) {
// continue below
}
int version = (int) Double.parseDouble(JAVA_CLASS_VERSION.value());
if (49 <= version && version <= 52) {
return version - (49 - 5);
}
throw new IllegalStateException("Unknown Java version: " + JAVA_SPECIFICATION_VERSION.value());
}

private final String name;
private final String input;
private final String expected;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
interface Test {

static class Test1 implements Test{}
static class Test2 implements Test{}

public static void main(String[] args) {
Test test = new Test1();
switch (test) {
case Test1 test1 -> {}
case Test2 test2 -> {}
default -> throw new IllegalStateException("Unexpected value: " + test);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
interface Test {

static class Test1 implements Test {}

static class Test2 implements Test {}

public static void main(String[] args) {
Test test = new Test1();
switch (test) {
case Test1 test1 -> {}
case Test2 test2 -> {}
default -> throw new IllegalStateException("Unexpected value: " + test);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package example;

import example.model.SealedInterface;
import example.model.TypeA;
import example.model.TypeB;

public class Main {
public void apply(SealedInterface sealedInterface) {
switch(sealedInterface) {
case TypeA a -> System.out.println("A!");
case TypeB b -> System.out.println("B!");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package example;

import example.model.SealedInterface;
import example.model.TypeA;
import example.model.TypeB;

public class Main {
public void apply(SealedInterface sealedInterface) {
switch (sealedInterface) {
case TypeA a -> System.out.println("A!");
case TypeB b -> System.out.println("B!");
}
}
}

0 comments on commit 8eb478a

Please sign in to comment.