Skip to content

Commit

Permalink
config: hold shared_ptr to configs to avoid use-after-free (#33826)
Browse files Browse the repository at this point in the history
Signed-off-by: Greg Greenway <ggreenway@apple.com>
  • Loading branch information
ggreenway committed Apr 29, 2024
1 parent 74849c9 commit bfb5443
Show file tree
Hide file tree
Showing 61 changed files with 301 additions and 276 deletions.
2 changes: 1 addition & 1 deletion contrib/kafka/filters/network/source/broker/config.cc
Expand Up @@ -20,7 +20,7 @@ Network::FilterFactoryCb KafkaConfigFactory::createFilterFactoryFromProtoTyped(
std::make_shared<BrokerFilterConfig>(proto_config);
return [&context, filter_config](Network::FilterManager& filter_manager) -> void {
Network::FilterSharedPtr filter = std::make_shared<KafkaBrokerFilter>(
context.scope(), context.serverFactoryContext().timeSource(), *filter_config);
context.scope(), context.serverFactoryContext().timeSource(), filter_config);
filter_manager.addFilter(filter);
};
}
Expand Down
6 changes: 3 additions & 3 deletions contrib/kafka/filters/network/source/broker/filter.cc
Expand Up @@ -70,11 +70,11 @@ absl::flat_hash_map<int32_t, MonotonicTime>& KafkaMetricsFacadeImpl::getRequestA
}

KafkaBrokerFilter::KafkaBrokerFilter(Stats::Scope& scope, TimeSource& time_source,
const BrokerFilterConfig& filter_config)
const BrokerFilterConfigSharedPtr& filter_config)
: KafkaBrokerFilter{filter_config, std::make_shared<KafkaMetricsFacadeImpl>(
scope, time_source, filter_config.statPrefix())} {};
scope, time_source, filter_config->statPrefix())} {};

KafkaBrokerFilter::KafkaBrokerFilter(const BrokerFilterConfig& filter_config,
KafkaBrokerFilter::KafkaBrokerFilter(const BrokerFilterConfigSharedPtr& filter_config,
const KafkaMetricsFacadeSharedPtr& metrics)
: metrics_{metrics}, response_rewriter_{createRewriter(filter_config)},
response_decoder_{new ResponseDecoder({metrics, response_rewriter_})},
Expand Down
4 changes: 2 additions & 2 deletions contrib/kafka/filters/network/source/broker/filter.h
Expand Up @@ -147,7 +147,7 @@ class KafkaBrokerFilter : public Network::Filter, private Logger::Loggable<Logge
* duration calculation.
*/
KafkaBrokerFilter(Stats::Scope& scope, TimeSource& time_source,
const BrokerFilterConfig& filter_config);
const BrokerFilterConfigSharedPtr& filter_config);

/**
* Visible for testing.
Expand All @@ -173,7 +173,7 @@ class KafkaBrokerFilter : public Network::Filter, private Logger::Loggable<Logge
* Helper delegate constructor.
* Passes metrics facade as argument to decoders.
*/
KafkaBrokerFilter(const BrokerFilterConfig& filter_config,
KafkaBrokerFilter(const BrokerFilterConfigSharedPtr& filter_config,
const KafkaMetricsFacadeSharedPtr& metrics);

const KafkaMetricsFacadeSharedPtr metrics_;
Expand Down
7 changes: 4 additions & 3 deletions contrib/kafka/filters/network/source/broker/rewriter.cc
Expand Up @@ -8,7 +8,8 @@ namespace Broker {

// ResponseRewriterImpl.

ResponseRewriterImpl::ResponseRewriterImpl(const BrokerFilterConfig& config) : config_{config} {};
ResponseRewriterImpl::ResponseRewriterImpl(const BrokerFilterConfigSharedPtr& config)
: config_{config} {};

void ResponseRewriterImpl::onMessage(AbstractResponseSharedPtr response) {
responses_to_rewrite_.push_back(response);
Expand Down Expand Up @@ -92,8 +93,8 @@ void DoNothingRewriter::process(Buffer::Instance&) {}

// Factory method.

ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfig& config) {
if (config.needsResponseRewrite()) {
ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfigSharedPtr& config) {
if (config->needsResponseRewrite()) {
return std::make_shared<ResponseRewriterImpl>(config);
} else {
return std::make_shared<DoNothingRewriter>();
Expand Down
8 changes: 4 additions & 4 deletions contrib/kafka/filters/network/source/broker/rewriter.h
Expand Up @@ -38,7 +38,7 @@ using ResponseRewriterSharedPtr = std::shared_ptr<ResponseRewriter>;
*/
class ResponseRewriterImpl : public ResponseRewriter, private Logger::Loggable<Logger::Id::kafka> {
public:
ResponseRewriterImpl(const BrokerFilterConfig& config);
ResponseRewriterImpl(const BrokerFilterConfigSharedPtr& config);

// ResponseCallback
void onMessage(AbstractResponseSharedPtr response) override;
Expand Down Expand Up @@ -69,7 +69,7 @@ class ResponseRewriterImpl : public ResponseRewriter, private Logger::Loggable<L
// Pointer-to-member used to handle varying field names across the structs.
template <typename T> void maybeUpdateHostAndPort(T& arg, const int32_t T::*node_id_field) const {
const int32_t node_id = arg.*node_id_field;
const absl::optional<HostAndPort> hostAndPort = config_.findBrokerAddressOverride(node_id);
const absl::optional<HostAndPort> hostAndPort = config_->findBrokerAddressOverride(node_id);
if (hostAndPort) {
ENVOY_LOG(trace, "Changing broker [{}] from {}:{} to {}:{}", node_id, arg.host_, arg.port_,
hostAndPort->first, hostAndPort->second);
Expand All @@ -78,7 +78,7 @@ class ResponseRewriterImpl : public ResponseRewriter, private Logger::Loggable<L
}
}

const BrokerFilterConfig& config_;
const BrokerFilterConfigSharedPtr config_;
std::vector<AbstractResponseSharedPtr> responses_to_rewrite_;
};

Expand All @@ -99,7 +99,7 @@ class DoNothingRewriter : public ResponseRewriter {
/**
* Factory method that creates a rewriter depending on configuration.
*/
ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfig& config);
ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfigSharedPtr& config);

} // namespace Broker
} // namespace Kafka
Expand Down
Expand Up @@ -36,7 +36,9 @@ class KafkaBrokerFilterProtocolTest : public testing::Test,
Stats::TestUtil::TestStore store_;
Stats::Scope& scope_{*store_.rootScope()};
Event::TestRealTimeSystem time_source_;
KafkaBrokerFilter testee_{scope_, time_source_, BrokerFilterConfig{"prefix", false, {}}};
KafkaBrokerFilter testee_{scope_, time_source_,
std::make_shared<BrokerFilterConfig>(std::string("prefix"), false,
std::vector<RewriteRule>{})};

Network::FilterStatus consumeRequestFromBuffer() {
return testee_.onData(RequestB::buffer_, false);
Expand Down
36 changes: 18 additions & 18 deletions contrib/kafka/filters/network/test/broker/rewriter_unit_test.cc
Expand Up @@ -43,7 +43,7 @@ class FakeResponse : public AbstractResponse {

TEST(ResponseRewriterImplUnitTest, ShouldRewriteBuffer) {
// given
ResponseRewriterImpl testee{MockBrokerFilterConfig{}};
ResponseRewriterImpl testee{std::make_shared<MockBrokerFilterConfig>()};

auto response1 = std::make_shared<FakeResponse>(7);
auto response2 = std::make_shared<FakeResponse>(13);
Expand Down Expand Up @@ -80,13 +80,13 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteMetadataResponse) {
std::vector<MetadataResponseBroker> brokers = {b1, b2, b3};
MetadataResponse mr = {brokers, {}};

MockBrokerFilterConfig config;
auto config = std::make_shared<MockBrokerFilterConfig>();
absl::optional<HostAndPort> r1 = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(b1.node_id_)).WillOnce(Return(r1));
EXPECT_CALL(*config, findBrokerAddressOverride(b1.node_id_)).WillOnce(Return(r1));
absl::optional<HostAndPort> r2 = absl::nullopt;
EXPECT_CALL(config, findBrokerAddressOverride(b2.node_id_)).WillOnce(Return(r2));
EXPECT_CALL(*config, findBrokerAddressOverride(b2.node_id_)).WillOnce(Return(r2));
absl::optional<HostAndPort> r3 = {{"nh3", 6666}};
EXPECT_CALL(config, findBrokerAddressOverride(b3.node_id_)).WillOnce(Return(r3));
EXPECT_CALL(*config, findBrokerAddressOverride(b3.node_id_)).WillOnce(Return(r3));
ResponseRewriterImpl testee{config};

// when
Expand All @@ -107,15 +107,15 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteFindCoordinatorResponse) {
Coordinator c3 = {"k3", 3, "ch3", 4444, 0, {}, {}};
fcr.coordinators_ = {c1, c2, c3};

MockBrokerFilterConfig config;
auto config = std::make_shared<MockBrokerFilterConfig>();
absl::optional<HostAndPort> fcrhp = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(fcr.node_id_)).WillOnce(Return(fcrhp));
EXPECT_CALL(*config, findBrokerAddressOverride(fcr.node_id_)).WillOnce(Return(fcrhp));
absl::optional<HostAndPort> cr1 = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(c1.node_id_)).WillOnce(Return(cr1));
EXPECT_CALL(*config, findBrokerAddressOverride(c1.node_id_)).WillOnce(Return(cr1));
absl::optional<HostAndPort> cr2 = absl::nullopt;
EXPECT_CALL(config, findBrokerAddressOverride(c2.node_id_)).WillOnce(Return(cr2));
EXPECT_CALL(*config, findBrokerAddressOverride(c2.node_id_)).WillOnce(Return(cr2));
absl::optional<HostAndPort> cr3 = {{"nh3", 6666}};
EXPECT_CALL(config, findBrokerAddressOverride(c3.node_id_)).WillOnce(Return(cr3));
EXPECT_CALL(*config, findBrokerAddressOverride(c3.node_id_)).WillOnce(Return(cr3));
ResponseRewriterImpl testee{config};

// when
Expand All @@ -136,13 +136,13 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteDescribeClusterResponse) {
std::vector<DescribeClusterBroker> brokers = {b1, b2, b3};
DescribeClusterResponse dcr = {0, 0, absl::nullopt, "", 0, brokers, 0, {}};

MockBrokerFilterConfig config;
auto config = std::make_shared<MockBrokerFilterConfig>();
absl::optional<HostAndPort> cr1 = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(b1.broker_id_)).WillOnce(Return(cr1));
EXPECT_CALL(*config, findBrokerAddressOverride(b1.broker_id_)).WillOnce(Return(cr1));
absl::optional<HostAndPort> cr2 = absl::nullopt;
EXPECT_CALL(config, findBrokerAddressOverride(b2.broker_id_)).WillOnce(Return(cr2));
EXPECT_CALL(*config, findBrokerAddressOverride(b2.broker_id_)).WillOnce(Return(cr2));
absl::optional<HostAndPort> cr3 = {{"nh3", 6666}};
EXPECT_CALL(config, findBrokerAddressOverride(b3.broker_id_)).WillOnce(Return(cr3));
EXPECT_CALL(*config, findBrokerAddressOverride(b3.broker_id_)).WillOnce(Return(cr3));
ResponseRewriterImpl testee{config};

// when
Expand All @@ -155,13 +155,13 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteDescribeClusterResponse) {
}

TEST(ResponseRewriterUnitTest, ShouldCreateProperRewriter) {
MockBrokerFilterConfig c1;
EXPECT_CALL(c1, needsResponseRewrite()).WillOnce(Return(true));
auto c1 = std::make_shared<MockBrokerFilterConfig>();
EXPECT_CALL(*c1, needsResponseRewrite()).WillOnce(Return(true));
ResponseRewriterSharedPtr r1 = createRewriter(c1);
ASSERT_NE(std::dynamic_pointer_cast<ResponseRewriterImpl>(r1), nullptr);

MockBrokerFilterConfig c2;
EXPECT_CALL(c2, needsResponseRewrite()).WillOnce(Return(false));
auto c2 = std::make_shared<MockBrokerFilterConfig>();
EXPECT_CALL(*c2, needsResponseRewrite()).WillOnce(Return(false));
ResponseRewriterSharedPtr r2 = createRewriter(c2);
ASSERT_NE(std::dynamic_pointer_cast<DoNothingRewriter>(r2), nullptr);
}
Expand Down
2 changes: 1 addition & 1 deletion contrib/rocketmq_proxy/filters/network/source/config.cc
Expand Up @@ -24,7 +24,7 @@ Network::FilterFactoryCb RocketmqProxyFilterConfigFactory::createFilterFactoryFr
std::shared_ptr<ConfigImpl> filter_config = std::make_shared<ConfigImpl>(proto_config, context);
return [filter_config, &context](Network::FilterManager& filter_manager) -> void {
filter_manager.addReadFilter(std::make_shared<ConnectionManager>(
*filter_config, context.serverFactoryContext().mainThreadDispatcher().timeSource()));
filter_config, context.serverFactoryContext().mainThreadDispatcher().timeSource()));
};
}

Expand Down
6 changes: 3 additions & 3 deletions contrib/rocketmq_proxy/filters/network/source/conn_manager.cc
Expand Up @@ -24,8 +24,8 @@ bool ConsumerGroupMember::expired() const {
connection_manager_->config().transientObjectLifeSpan().count();
}

ConnectionManager::ConnectionManager(Config& config, TimeSource& time_source)
: config_(config), time_source_(time_source), stats_(config.stats()) {}
ConnectionManager::ConnectionManager(const ConfigSharedPtr& config, TimeSource& time_source)
: config_(config), time_source_(time_source), stats_(config_->stats()) {}

Envoy::Network::FilterStatus ConnectionManager::onData(Envoy::Buffer::Instance& data,
bool end_stream) {
Expand Down Expand Up @@ -137,7 +137,7 @@ void ConnectionManager::purgeDirectiveTable() {
for (auto it = ack_directive_table_.begin(); it != ack_directive_table_.end();) {
auto duration = current - it->second.creation_time_;
if (std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() >
config_.transientObjectLifeSpan().count()) {
config_->transientObjectLifeSpan().count()) {
ack_directive_table_.erase(it++);
} else {
it++;
Expand Down
8 changes: 5 additions & 3 deletions contrib/rocketmq_proxy/filters/network/source/conn_manager.h
Expand Up @@ -52,6 +52,8 @@ class Config {
virtual std::chrono::milliseconds transientObjectLifeSpan() const PURE;
};

using ConfigSharedPtr = std::shared_ptr<Config>;

class ConnectionManager;

/**
Expand Down Expand Up @@ -82,7 +84,7 @@ class ConsumerGroupMember {

class ConnectionManager : public Network::ReadFilter, Logger::Loggable<Logger::Id::filter> {
public:
ConnectionManager(Config& config, TimeSource& time_source);
ConnectionManager(const ConfigSharedPtr& config, TimeSource& time_source);

~ConnectionManager() override = default;

Expand Down Expand Up @@ -151,7 +153,7 @@ class ConnectionManager : public Network::ReadFilter, Logger::Loggable<Logger::I

void resetAllActiveMessages(absl::string_view error_msg);

Config& config() { return config_; }
Config& config() { return *config_; }

RocketmqFilterStats& stats() { return stats_; }

Expand Down Expand Up @@ -194,7 +196,7 @@ class ConnectionManager : public Network::ReadFilter, Logger::Loggable<Logger::I
Network::ReadFilterCallbacks* read_callbacks_{};
Buffer::OwnedImpl request_buffer_;

Config& config_;
ConfigSharedPtr config_;
TimeSource& time_source_;
RocketmqFilterStats& stats_;

Expand Down
Expand Up @@ -23,7 +23,7 @@ class ActiveMessageTest : public testing::Test {
public:
ActiveMessageTest()
: stats_(RocketmqFilterStats::generateStats("test.", *store_.rootScope())),
config_(rocketmq_proxy_config_, factory_context_),
config_(std::make_shared<ConfigImpl>(rocketmq_proxy_config_, factory_context_)),
connection_manager_(
config_, factory_context_.serverFactoryContext().mainThreadDispatcher().timeSource()) {
connection_manager_.initializeReadFilterCallbacks(filter_callbacks_);
Expand All @@ -39,7 +39,7 @@ class ActiveMessageTest : public testing::Test {
NiceMock<Server::Configuration::MockFactoryContext> factory_context_;
Stats::IsolatedStoreImpl store_;
RocketmqFilterStats stats_;
ConfigImpl config_;
std::shared_ptr<ConfigImpl> config_;
ConnectionManager connection_manager_;
};

Expand Down
Expand Up @@ -54,9 +54,9 @@ class RocketmqConnectionManagerTest : public Event::TestUsingSimulatedTime, publ
TestUtility::loadFromYaml(yaml, proto_config_);
TestUtility::validate(proto_config_);
}
config_ = std::make_unique<TestConfigImpl>(proto_config_, factory_context_, stats_);
config_ = std::make_shared<TestConfigImpl>(proto_config_, factory_context_, stats_);
conn_manager_ = std::make_unique<ConnectionManager>(
*config_, factory_context_.server_factory_context_.mainThreadDispatcher().timeSource());
config_, factory_context_.server_factory_context_.mainThreadDispatcher().timeSource());
conn_manager_->initializeReadFilterCallbacks(filter_callbacks_);
conn_manager_->onNewConnection();
current_ = factory_context_.server_factory_context_.mainThreadDispatcher()
Expand Down Expand Up @@ -84,7 +84,7 @@ class RocketmqConnectionManagerTest : public Event::TestUsingSimulatedTime, publ
RocketmqFilterStats stats_;
ConfigRocketmqProxy proto_config_;

std::unique_ptr<TestConfigImpl> config_;
std::shared_ptr<TestConfigImpl> config_;

Buffer::OwnedImpl buffer_;
NiceMock<Network::MockReadFilterCallbacks> filter_callbacks_;
Expand Down
4 changes: 2 additions & 2 deletions contrib/rocketmq_proxy/filters/network/test/router_test.cc
Expand Up @@ -21,7 +21,7 @@ namespace Router {
class RocketmqRouterTestBase {
public:
RocketmqRouterTestBase()
: config_(rocketmq_proxy_config_, context_),
: config_(std::make_shared<ConfigImpl>(rocketmq_proxy_config_, context_)),
cluster_info_(std::make_shared<Upstream::MockClusterInfo>()) {
context_.server_factory_context_.cluster_manager_.initializeThreadLocalClusters(
{"fake_cluster"});
Expand Down Expand Up @@ -147,7 +147,7 @@ class RocketmqRouterTestBase {
NiceMock<Network::MockReadFilterCallbacks> filter_callbacks_;
NiceMock<Server::Configuration::MockFactoryContext> context_;
ConfigImpl::RocketmqProxyConfig rocketmq_proxy_config_;
ConfigImpl config_;
std::shared_ptr<ConfigImpl> config_;
std::unique_ptr<ConnectionManager> conn_manager_;

std::unique_ptr<Router> router_;
Expand Down
2 changes: 1 addition & 1 deletion contrib/sip_proxy/filters/network/source/config.cc
Expand Up @@ -64,7 +64,7 @@ Network::FilterFactoryCb SipProxyFilterConfigFactory::createFilterFactoryFromPro
return
[filter_config, &context, transaction_infos](Network::FilterManager& filter_manager) -> void {
filter_manager.addReadFilter(std::make_shared<ConnectionManager>(
*filter_config, context.serverFactoryContext().api().randomGenerator(),
filter_config, context.serverFactoryContext().api().randomGenerator(),
context.serverFactoryContext().mainThreadDispatcher().timeSource(), context,
transaction_infos));
};
Expand Down
11 changes: 6 additions & 5 deletions contrib/sip_proxy/filters/network/source/conn_manager.cc
Expand Up @@ -174,11 +174,12 @@ void TrafficRoutingAssistantHandler::doSubscribe(
}
}

ConnectionManager::ConnectionManager(Config& config, Random::RandomGenerator& random_generator,
ConnectionManager::ConnectionManager(const ConfigSharedPtr& config,
Random::RandomGenerator& random_generator,
TimeSource& time_source,
Server::Configuration::FactoryContext& context,
std::shared_ptr<Router::TransactionInfos> transaction_infos)
: config_(config), stats_(config_.stats()), decoder_(std::make_unique<Decoder>(*this)),
: config_(config), stats_(config_->stats()), decoder_(std::make_unique<Decoder>(*this)),
random_generator_(random_generator), time_source_(time_source), context_(context),
transaction_infos_(transaction_infos) {}

Expand Down Expand Up @@ -340,7 +341,7 @@ void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbac
time_source_, read_callbacks_->connection().connectionInfoProviderSharedPtr(),
StreamInfo::FilterState::LifeSpan::Connection);
tra_handler_ = std::make_shared<TrafficRoutingAssistantHandler>(
*this, read_callbacks_->connection().dispatcher(), config_.settings()->traServiceConfig(),
*this, read_callbacks_->connection().dispatcher(), config_->settings()->traServiceConfig(),
context_, stream_info);
}

Expand Down Expand Up @@ -504,7 +505,7 @@ FilterStatus ConnectionManager::ActiveTrans::messageEnd() {
}

void ConnectionManager::ActiveTrans::createFilterChain() {
parent_.config_.filterFactory().createFilterChain(*this);
parent_.config_->filterFactory().createFilterChain(*this);
}

void ConnectionManager::ActiveTrans::onReset() { parent_.doDeferredTransDestroy(*this); }
Expand All @@ -526,7 +527,7 @@ const Network::Connection* ConnectionManager::ActiveTrans::connection() const {
Router::RouteConstSharedPtr ConnectionManager::ActiveTrans::route() {
if (!cached_route_) {
if (metadata_ != nullptr) {
Router::RouteConstSharedPtr route = parent_.config_.routerConfig().route(*metadata_);
Router::RouteConstSharedPtr route = parent_.config_->routerConfig().route(*metadata_);
cached_route_ = std::move(route);
} else {
cached_route_ = nullptr;
Expand Down

0 comments on commit bfb5443

Please sign in to comment.