Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow sharded execution for bazel test sharding support #2257

Merged
merged 4 commits into from Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/command-line.md
Expand Up @@ -29,6 +29,7 @@
[Specify the section to run](#specify-the-section-to-run)<br>
[Filenames as tags](#filenames-as-tags)<br>
[Override output colouring](#override-output-colouring)<br>
[Test Sharding](#test-sharding)<br>

Catch works quite nicely without any command line options at all - but for those times when you want greater control the following options are available.
Click one of the following links to take you straight to that option - or scroll on to browse the available options.
Expand Down Expand Up @@ -67,6 +68,8 @@ Click one of the following links to take you straight to that option - or scroll
<a href="#benchmark-no-analysis"> ` --benchmark-no-analysis`</a><br />
<a href="#benchmark-warmup-time"> ` --benchmark-warmup-time`</a><br />
<a href="#use-colour"> ` --use-colour`</a><br />
<a href="#test-sharding"> ` --shard-count`</a><br />
<a href="#test-sharding"> ` --shard-index`</a><br />

</br>

Expand Down Expand Up @@ -425,6 +428,20 @@ processing of output.
`--use-colour yes` forces coloured output, `--use-colour no` disables coloured
output. The default behaviour is `--use-colour auto`.

<a id="test-sharding"></a>
## Test Sharding
<pre>--shard-count <#number of shards>, --shard-index <#shard index to run></pre>
horenmar marked this conversation as resolved.
Show resolved Hide resolved

> [Introduced](https://github.com/catchorg/Catch2/pull/2257) in Catch2 X.Y.Z.

When `--shard-count <#number of shards>` is used, the tests to execute will be split evenly in to the given number of sets,
identified by indicies starting at 0. The tests in the set given by `--shard-index <#shard index to run>` will be executed.
The default shard count is `1`, and the default index to run is `0`. It is an error to specify a shard index greater than
the number of shards.

This is useful when you want to split test execution across multiple processes, as is done with [Bazel test sharding](https://docs.bazel.build/versions/main/test-encyclopedia.html#test-sharding).


---

[Home](Readme.md#top)
1 change: 1 addition & 0 deletions src/catch2/catch_all.hpp
Expand Up @@ -85,6 +85,7 @@
#include <catch2/internal/catch_result_type.hpp>
#include <catch2/internal/catch_run_context.hpp>
#include <catch2/internal/catch_section.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/internal/catch_singletons.hpp>
#include <catch2/internal/catch_source_line_info.hpp>
#include <catch2/internal/catch_startup_exception_registry.hpp>
Expand Down
2 changes: 2 additions & 0 deletions src/catch2/catch_config.cpp
Expand Up @@ -73,6 +73,8 @@ namespace Catch {
double Config::minDuration() const { return m_data.minDuration; }
TestRunOrder Config::runOrder() const { return m_data.runOrder; }
uint32_t Config::rngSeed() const { return m_data.rngSeed; }
unsigned int Config::shardCount() const { return m_data.shardCount; }
unsigned int Config::shardIndex() const { return m_data.shardIndex; }
UseColour Config::useColour() const { return m_data.useColour; }
bool Config::shouldDebugBreak() const { return m_data.shouldDebugBreak; }
int Config::abortAfter() const { return m_data.abortAfter; }
Expand Down
5 changes: 5 additions & 0 deletions src/catch2/catch_config.hpp
Expand Up @@ -37,6 +37,9 @@ namespace Catch {
int abortAfter = -1;
uint32_t rngSeed = generateRandomSeed(GenerateFrom::Default);

unsigned int shardCount = 1;
unsigned int shardIndex = 0;

bool benchmarkNoAnalysis = false;
unsigned int benchmarkSamples = 100;
double benchmarkConfidenceInterval = 0.95;
Expand Down Expand Up @@ -99,6 +102,8 @@ namespace Catch {
double minDuration() const override;
TestRunOrder runOrder() const override;
uint32_t rngSeed() const override;
unsigned int shardCount() const override;
unsigned int shardIndex() const override;
UseColour useColour() const override;
bool shouldDebugBreak() const override;
int abortAfter() const override;
Expand Down
10 changes: 10 additions & 0 deletions src/catch2/catch_session.cpp
Expand Up @@ -16,6 +16,7 @@
#include <catch2/catch_version.hpp>
#include <catch2/interfaces/catch_interfaces_reporter.hpp>
#include <catch2/internal/catch_startup_exception_registry.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/internal/catch_textflow.hpp>
#include <catch2/internal/catch_windows_h_proxy.hpp>
#include <catch2/reporters/catch_reporter_listening.hpp>
Expand Down Expand Up @@ -72,6 +73,8 @@ namespace Catch {
for (auto const& match : m_matches)
m_tests.insert(match.tests.begin(), match.tests.end());
}

m_tests = createShard(m_tests, m_config->shardCount(), m_config->shardIndex());
}

Totals execute() {
Expand Down Expand Up @@ -171,6 +174,7 @@ namespace Catch {
return 1;

auto result = m_cli.parse( Clara::Args( argc, argv ) );

if( !result ) {
config();
getCurrentMutableContext().setConfig(m_config.get());
Expand Down Expand Up @@ -253,6 +257,12 @@ namespace Catch {
if( m_startupExceptions )
return 1;


if( m_configData.shardIndex >= m_configData.shardCount ) {
Catch::cerr() << "The shard count (" << m_configData.shardCount << ") must be greater than the shard index (" << m_configData.shardIndex << ")\n" << std::flush;
return 1;
}

if (m_configData.showHelp || m_configData.libIdentify) {
return 0;
}
Expand Down
2 changes: 2 additions & 0 deletions src/catch2/interfaces/catch_interfaces_config.hpp
Expand Up @@ -74,6 +74,8 @@ namespace Catch {
virtual std::vector<std::string> const& getTestsOrTags() const = 0;
virtual TestRunOrder runOrder() const = 0;
virtual uint32_t rngSeed() const = 0;
virtual unsigned int shardCount() const = 0;
virtual unsigned int shardIndex() const = 0;
virtual UseColour useColour() const = 0;
virtual std::vector<std::string> const& getSectionsToRun() const = 0;
virtual Verbosity verbosity() const = 0;
Expand Down
42 changes: 42 additions & 0 deletions src/catch2/internal/catch_commandline.cpp
Expand Up @@ -149,6 +149,42 @@ namespace Catch {
return ParserResult::runtimeError( "Unrecognized reporter, '" + reporter + "'. Check available with --list-reporters" );
return ParserResult::ok( ParseResultType::Matched );
};
auto const setShardCount = [&]( std::string const& shardCount ) {
CATCH_TRY{
std::size_t parsedTo = 0;
int64_t parsedCount = std::stoll(shardCount, &parsedTo, 0);
if (parsedTo != shardCount.size()) {
return ParserResult::runtimeError("Could not parse '" + shardCount + "' as shard count");
}
if (parsedCount <= 0) {
return ParserResult::runtimeError("Shard count must be a positive number");
}

config.shardCount = static_cast<unsigned int>(parsedCount);
return ParserResult::ok(ParseResultType::Matched);
} CATCH_CATCH_ANON(std::exception const&) {
return ParserResult::runtimeError("Could not parse '" + shardCount + "' as shard count");
}
};

auto const setShardIndex = [&](std::string const& shardIndex) {
CATCH_TRY{
std::size_t parsedTo = 0;
int64_t parsedIndex = std::stoll(shardIndex, &parsedTo, 0);
if (parsedTo != shardIndex.size()) {
return ParserResult::runtimeError("Could not parse '" + shardIndex + "' as shard index");
}
if (parsedIndex < 0) {
return ParserResult::runtimeError("Shard index must be a non-negative number");
}

config.shardIndex = static_cast<unsigned int>(parsedIndex);
return ParserResult::ok(ParseResultType::Matched);
} CATCH_CATCH_ANON(std::exception const&) {
return ParserResult::runtimeError("Could not parse '" + shardIndex + "' as shard index");
}
};


auto cli
= ExeName( config.processName )
Expand Down Expand Up @@ -240,6 +276,12 @@ namespace Catch {
| Opt( config.benchmarkWarmupTime, "benchmarkWarmupTime" )
["--benchmark-warmup-time"]
( "amount of time in milliseconds spent on warming up each test (default: 100)" )
| Opt( setShardCount, "shard count" )
["--shard-count"]
( "split the tests to execute into this many groups" )
| Opt( setShardIndex, "shard index" )
["--shard-index"]
( "index of the group of tests to execute (see --shard-count)" )
| Arg( config.testsOrTags, "test name|pattern|tags" )
( "which test or tests to use" );

Expand Down
41 changes: 41 additions & 0 deletions src/catch2/internal/catch_sharding.hpp
@@ -0,0 +1,41 @@

// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)

// SPDX-License-Identifier: BSL-1.0
#ifndef CATCH_SHARDING_HPP_INCLUDED
#define CATCH_SHARDING_HPP_INCLUDED

#include <catch2/catch_session.hpp>

#include <cmath>

namespace Catch {

template<typename Container>
Container createShard(Container const& container, std::size_t const shardCount, std::size_t const shardIndex) {
assert(shardCount > shardIndex);

if (shardCount == 1) {
return container;
}

const std::size_t totalTestCount = container.size();

const std::size_t shardSize = totalTestCount / shardCount;
const std::size_t leftoverTests = totalTestCount % shardCount;

const std::size_t startIndex = shardIndex * shardSize + (std::min)(shardIndex, leftoverTests);
const std::size_t endIndex = (shardIndex + 1) * shardSize + (std::min)(shardIndex + 1, leftoverTests);

auto startIterator = std::next(container.begin(), startIndex);
auto endIterator = std::next(container.begin(), endIndex);

return Container(startIterator, endIterator);
}

}

#endif // CATCH_SHARDING_HPP_INCLUDED
3 changes: 2 additions & 1 deletion src/catch2/internal/catch_test_case_registry_impl.cpp
Expand Up @@ -12,6 +12,7 @@
#include <catch2/interfaces/catch_interfaces_registry_hub.hpp>
#include <catch2/internal/catch_random_number_generator.hpp>
#include <catch2/internal/catch_run_context.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/catch_test_case_info.hpp>
#include <catch2/catch_test_spec.hpp>
#include <catch2/internal/catch_move_and_forward.hpp>
Expand Down Expand Up @@ -135,7 +136,7 @@ namespace {
filtered.push_back(testCase);
}
}
return filtered;
return createShard(filtered, config.shardCount(), config.shardIndex());
}
std::vector<TestCaseHandle> const& getAllTestCasesSorted( IConfig const& config ) {
return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config );
Expand Down
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Expand Up @@ -87,6 +87,7 @@ set(TEST_SOURCES
${SELF_TEST_DIR}/IntrospectiveTests/RandomNumberGeneration.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Reporters.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Tag.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Sharding.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/String.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/StringManip.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Xml.tests.cpp
Expand Down Expand Up @@ -310,6 +311,7 @@ set_tests_properties(TagAlias PROPERTIES
add_test(NAME RandomTestOrdering COMMAND ${PYTHON_EXECUTABLE}
${CATCH_DIR}/tests/TestScripts/testRandomOrder.py $<TARGET_FILE:SelfTest>)


add_test(NAME CheckConvenienceHeaders
COMMAND
${PYTHON_EXECUTABLE} ${CATCH_DIR}/tools/scripts/checkConvenienceHeaders.py
Expand Down
16 changes: 16 additions & 0 deletions tests/ExtraTests/CMakeLists.txt
Expand Up @@ -8,6 +8,22 @@ project( Catch2ExtraTests LANGUAGES CXX )

message( STATUS "Extra tests included" )


add_test(
NAME TestShardingIntegration
COMMAND ${PYTHON_EXECUTABLE} ${CATCH_DIR}/tests/TestScripts/testSharding.py $<TARGET_FILE:SelfTest>
)

add_test(
NAME TestSharding::OverlyLargeShardIndex
COMMAND $<TARGET_FILE:SelfTest> --shard-index 5 --shard-count 5
)
set_tests_properties(
TestSharding::OverlyLargeShardIndex
PROPERTIES
PASS_REGULAR_EXPRESSION "The shard count \\(5\\) must be greater than the shard index \\(5\\)"
)

# The MinDuration reporting tests do not need separate compilation, but
# they have non-trivial execution time, so they are categorized as
# extra tests, so that they are run less.
Expand Down
1 change: 1 addition & 0 deletions tests/SelfTest/Baselines/automake.sw.approved.txt
Expand Up @@ -177,6 +177,7 @@ Nor would this
:test-result: FAIL Output from all sections is reported
:test-result: PASS Overloaded comma or address-of operators are not used
:test-result: PASS Parse test names and tags
:test-result: PASS Parsing sharding-related cli flags
:test-result: PASS Pointers can be compared to null
:test-result: PASS Precision of floating point stringification can be set
:test-result: PASS Predicate matcher can accept const char*
Expand Down
12 changes: 12 additions & 0 deletions tests/SelfTest/Baselines/compact.sw.approved.txt
Expand Up @@ -1178,6 +1178,18 @@ CmdLine.tests.cpp:<line number>: passed: !(spec.matches(*fakeTestCase("hidden an
CmdLine.tests.cpp:<line number>: passed: !(spec.matches(*fakeTestCase("only foo", "[foo]"))) for: !false
CmdLine.tests.cpp:<line number>: passed: !(spec.matches(*fakeTestCase("only hidden", "[.]"))) for: !false
CmdLine.tests.cpp:<line number>: passed: spec.matches(*fakeTestCase("neither foo nor hidden", "[bar]")) for: true
CmdLine.tests.cpp:<line number>: passed: cli.parse({ "test", "--shard-count=8" }) for: {?}
CmdLine.tests.cpp:<line number>: passed: config.shardCount == 8 for: 8 == 8
CmdLine.tests.cpp:<line number>: passed: !(result) for: !{?}
CmdLine.tests.cpp:<line number>: passed: result.errorMessage(), ContainsSubstring("Shard count must be a positive number") for: "Shard count must be a positive number" contains: "Shard count must be a positive number"
CmdLine.tests.cpp:<line number>: passed: !(result) for: !{?}
CmdLine.tests.cpp:<line number>: passed: result.errorMessage(), ContainsSubstring("Shard count must be a positive number") for: "Shard count must be a positive number" contains: "Shard count must be a positive number"
CmdLine.tests.cpp:<line number>: passed: cli.parse({ "test", "--shard-index=2" }) for: {?}
CmdLine.tests.cpp:<line number>: passed: config.shardIndex == 2 for: 2 == 2
CmdLine.tests.cpp:<line number>: passed: !(result) for: !{?}
CmdLine.tests.cpp:<line number>: passed: result.errorMessage(), ContainsSubstring("Shard index must be a non-negative number") for: "Shard index must be a non-negative number" contains: "Shard index must be a non-negative number"
CmdLine.tests.cpp:<line number>: passed: cli.parse({ "test", "--shard-index=0" }) for: {?}
CmdLine.tests.cpp:<line number>: passed: config.shardIndex == 0 for: 0 == 0
Condition.tests.cpp:<line number>: passed: p == 0 for: 0 == 0
Condition.tests.cpp:<line number>: passed: p == pNULL for: 0 == 0
Condition.tests.cpp:<line number>: passed: p != 0 for: 0x<hex digits> != 0
Expand Down
4 changes: 2 additions & 2 deletions tests/SelfTest/Baselines/console.std.approved.txt
Expand Up @@ -1426,6 +1426,6 @@ due to unexpected exception with message:
Why would you throw a std::string?

===============================================================================
test cases: 373 | 296 passed | 70 failed | 7 failed as expected
assertions: 2115 | 1959 passed | 129 failed | 27 failed as expected
test cases: 374 | 297 passed | 70 failed | 7 failed as expected
assertions: 2127 | 1971 passed | 129 failed | 27 failed as expected