diff --git a/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java b/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java index fa4ed9c9466..045457df9c2 100644 --- a/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java +++ b/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java @@ -58,12 +58,6 @@ public void addCase(List keysList, IContainer c) { cases.add(new CaseInfo(keysList, c)); } - public void addDefaultCase(IContainer c) { - if (c != null) { - cases.add(new CaseInfo(Collections.singletonList(DEFAULT_CASE_KEY), c)); - } - } - public List getCases() { return cases; } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchOverStringVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchOverStringVisitor.java index bb5dbbcd3df..4467656e0b7 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchOverStringVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchOverStringVisitor.java @@ -106,10 +106,9 @@ private boolean restoreSwitchOverString(MethodNode mth, SwitchRegion switchRegio // all checks passed, replace with new switch IRegion parentRegion = switchRegion.getParent(); SwitchRegion replaceRegion = new SwitchRegion(parentRegion, switchRegion.getHeader()); - for (CaseData caseData : switchData.getCases()) { - replaceRegion.addCase(Collections.unmodifiableList(caseData.getStrValues()), caseData.getCode()); + for (SwitchRegion.CaseInfo caseInfo : switchData.getNewCases()) { + replaceRegion.addCase(Collections.unmodifiableList(caseInfo.getKeys()), caseInfo.getContainer()); } - replaceRegion.addDefaultCase(switchData.getDefaultCode()); if (!parentRegion.replaceSubBlock(switchRegion, replaceRegion)) { mth.addWarnComment("Failed to restore switch over string. Please report as a decompilation issue"); return false; @@ -216,36 +215,53 @@ private boolean mergeWithCode(SwitchData switchData) { block -> switchData.getToRemove().add(block)); } - IContainer defaultContainer = null; + final var newCases = new ArrayList(); for (SwitchRegion.CaseInfo caseInfo : codeSwitch.getCases()) { - CaseData prevCase = null; + SwitchRegion.CaseInfo newCase = null; for (Object key : caseInfo.getKeys()) { final Integer intKey = unwrapIntKey(key); if (intKey != null) { - CaseData caseData = casesMap.get(intKey); + final var caseData = casesMap.remove(intKey); if (caseData == null) { return false; } - if (prevCase == null) { - caseData.setCode(caseInfo.getContainer()); - prevCase = caseData; + if (newCase == null) { + final List keys = new ArrayList<>(caseData.getStrValues()); + newCase = new SwitchRegion.CaseInfo(keys, caseInfo.getContainer()); } else { // merge cases - prevCase.getStrValues().addAll(caseData.getStrValues()); - caseData.setCodeNum(-1); + newCase.getKeys().addAll(caseData.getStrValues()); } } else if (key == SwitchRegion.DEFAULT_CASE_KEY) { - defaultContainer = caseInfo.getContainer(); + final var iterator = casesMap.entrySet().iterator(); + while (iterator.hasNext()) { + final var caseData = iterator.next().getValue(); + if (newCase == null) { + final List keys = new ArrayList<>(caseData.getStrValues()); + newCase = new SwitchRegion.CaseInfo(keys, caseInfo.getContainer()); + } else { + // merge cases + newCase.getKeys().addAll(caseData.getStrValues()); + } + + iterator.remove(); + } + + if (newCase == null) { + newCase = new SwitchRegion.CaseInfo(new ArrayList<>(), caseInfo.getContainer()); + } + + newCase.getKeys().add(SwitchRegion.DEFAULT_CASE_KEY); } else { return false; } } + newCases.add(newCase); } - cases.removeIf(c -> c.getCodeNum() == -1); - switchData.setDefaultCode(defaultContainer); switchData.setCodeSwitch(codeSwitch); switchData.setNumArg(numArg); + switchData.setNewCases(newCases); return true; } @@ -367,7 +383,7 @@ private static final class SwitchData { private final List toRemove = new ArrayList<>(); private Map strEqInsns; private List cases; - private IContainer defaultCode; + private List newCases; private SwitchRegion codeSwitch; private RegisterArg numArg; @@ -384,12 +400,12 @@ public void setCases(List cases) { this.cases = cases; } - public IContainer getDefaultCode() { - return defaultCode; + public List getNewCases() { + return newCases; } - public void setDefaultCode(IContainer defaultCode) { - this.defaultCode = defaultCode; + public void setNewCases(List cases) { + this.newCases = cases; } public MethodNode getMth() { diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchOverStrings2.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchOverStrings2.java new file mode 100644 index 00000000000..abab633f000 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchOverStrings2.java @@ -0,0 +1,43 @@ +package jadx.tests.integration.switches; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestSwitchOverStrings2 extends IntegrationTest { + + public static class TestCls { + + public int test(String str) { + switch (str) { + case "branch1": + case "branch2": + return 1; + case "branch3": + case "branch4": + default: + return 0; + } + } + + public void check() { + assertThat(test("branch1")).isEqualTo(1); + assertThat(test("branch2")).isEqualTo(1); + assertThat(test("branch3")).isEqualTo(0); + assertThat(test("branch4")).isEqualTo(0); + assertThat(test("other")).isEqualTo(0); + assertThat(test("other2")).isEqualTo(0); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .countString(4, "case ") + .countString(1, "default:") + .countString(2, "return "); + } +}