Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config: hold shared_ptr to configs to avoid use-after-free #33826

Merged
merged 1 commit into from Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion contrib/kafka/filters/network/source/broker/config.cc
Expand Up @@ -20,7 +20,7 @@ Network::FilterFactoryCb KafkaConfigFactory::createFilterFactoryFromProtoTyped(
std::make_shared<BrokerFilterConfig>(proto_config);
return [&context, filter_config](Network::FilterManager& filter_manager) -> void {
Network::FilterSharedPtr filter = std::make_shared<KafkaBrokerFilter>(
context.scope(), context.serverFactoryContext().timeSource(), *filter_config);
context.scope(), context.serverFactoryContext().timeSource(), filter_config);
filter_manager.addFilter(filter);
};
}
Expand Down
6 changes: 3 additions & 3 deletions contrib/kafka/filters/network/source/broker/filter.cc
Expand Up @@ -70,11 +70,11 @@ absl::flat_hash_map<int32_t, MonotonicTime>& KafkaMetricsFacadeImpl::getRequestA
}

KafkaBrokerFilter::KafkaBrokerFilter(Stats::Scope& scope, TimeSource& time_source,
const BrokerFilterConfig& filter_config)
const BrokerFilterConfigSharedPtr& filter_config)
: KafkaBrokerFilter{filter_config, std::make_shared<KafkaMetricsFacadeImpl>(
scope, time_source, filter_config.statPrefix())} {};
scope, time_source, filter_config->statPrefix())} {};

KafkaBrokerFilter::KafkaBrokerFilter(const BrokerFilterConfig& filter_config,
KafkaBrokerFilter::KafkaBrokerFilter(const BrokerFilterConfigSharedPtr& filter_config,
const KafkaMetricsFacadeSharedPtr& metrics)
: metrics_{metrics}, response_rewriter_{createRewriter(filter_config)},
response_decoder_{new ResponseDecoder({metrics, response_rewriter_})},
Expand Down
4 changes: 2 additions & 2 deletions contrib/kafka/filters/network/source/broker/filter.h
Expand Up @@ -147,7 +147,7 @@ class KafkaBrokerFilter : public Network::Filter, private Logger::Loggable<Logge
* duration calculation.
*/
KafkaBrokerFilter(Stats::Scope& scope, TimeSource& time_source,
const BrokerFilterConfig& filter_config);
const BrokerFilterConfigSharedPtr& filter_config);

/**
* Visible for testing.
Expand All @@ -173,7 +173,7 @@ class KafkaBrokerFilter : public Network::Filter, private Logger::Loggable<Logge
* Helper delegate constructor.
* Passes metrics facade as argument to decoders.
*/
KafkaBrokerFilter(const BrokerFilterConfig& filter_config,
KafkaBrokerFilter(const BrokerFilterConfigSharedPtr& filter_config,
const KafkaMetricsFacadeSharedPtr& metrics);

const KafkaMetricsFacadeSharedPtr metrics_;
Expand Down
7 changes: 4 additions & 3 deletions contrib/kafka/filters/network/source/broker/rewriter.cc
Expand Up @@ -8,7 +8,8 @@ namespace Broker {

// ResponseRewriterImpl.

ResponseRewriterImpl::ResponseRewriterImpl(const BrokerFilterConfig& config) : config_{config} {};
ResponseRewriterImpl::ResponseRewriterImpl(const BrokerFilterConfigSharedPtr& config)
: config_{config} {};

void ResponseRewriterImpl::onMessage(AbstractResponseSharedPtr response) {
responses_to_rewrite_.push_back(response);
Expand Down Expand Up @@ -92,8 +93,8 @@ void DoNothingRewriter::process(Buffer::Instance&) {}

// Factory method.

ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfig& config) {
if (config.needsResponseRewrite()) {
ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfigSharedPtr& config) {
if (config->needsResponseRewrite()) {
return std::make_shared<ResponseRewriterImpl>(config);
} else {
return std::make_shared<DoNothingRewriter>();
Expand Down
8 changes: 4 additions & 4 deletions contrib/kafka/filters/network/source/broker/rewriter.h
Expand Up @@ -38,7 +38,7 @@ using ResponseRewriterSharedPtr = std::shared_ptr<ResponseRewriter>;
*/
class ResponseRewriterImpl : public ResponseRewriter, private Logger::Loggable<Logger::Id::kafka> {
public:
ResponseRewriterImpl(const BrokerFilterConfig& config);
ResponseRewriterImpl(const BrokerFilterConfigSharedPtr& config);

// ResponseCallback
void onMessage(AbstractResponseSharedPtr response) override;
Expand Down Expand Up @@ -69,7 +69,7 @@ class ResponseRewriterImpl : public ResponseRewriter, private Logger::Loggable<L
// Pointer-to-member used to handle varying field names across the structs.
template <typename T> void maybeUpdateHostAndPort(T& arg, const int32_t T::*node_id_field) const {
const int32_t node_id = arg.*node_id_field;
const absl::optional<HostAndPort> hostAndPort = config_.findBrokerAddressOverride(node_id);
const absl::optional<HostAndPort> hostAndPort = config_->findBrokerAddressOverride(node_id);
if (hostAndPort) {
ENVOY_LOG(trace, "Changing broker [{}] from {}:{} to {}:{}", node_id, arg.host_, arg.port_,
hostAndPort->first, hostAndPort->second);
Expand All @@ -78,7 +78,7 @@ class ResponseRewriterImpl : public ResponseRewriter, private Logger::Loggable<L
}
}

const BrokerFilterConfig& config_;
const BrokerFilterConfigSharedPtr config_;
std::vector<AbstractResponseSharedPtr> responses_to_rewrite_;
};

Expand All @@ -99,7 +99,7 @@ class DoNothingRewriter : public ResponseRewriter {
/**
* Factory method that creates a rewriter depending on configuration.
*/
ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfig& config);
ResponseRewriterSharedPtr createRewriter(const BrokerFilterConfigSharedPtr& config);

} // namespace Broker
} // namespace Kafka
Expand Down
Expand Up @@ -36,7 +36,9 @@ class KafkaBrokerFilterProtocolTest : public testing::Test,
Stats::TestUtil::TestStore store_;
Stats::Scope& scope_{*store_.rootScope()};
Event::TestRealTimeSystem time_source_;
KafkaBrokerFilter testee_{scope_, time_source_, BrokerFilterConfig{"prefix", false, {}}};
KafkaBrokerFilter testee_{scope_, time_source_,
std::make_shared<BrokerFilterConfig>(std::string("prefix"), false,
std::vector<RewriteRule>{})};

Network::FilterStatus consumeRequestFromBuffer() {
return testee_.onData(RequestB::buffer_, false);
Expand Down
36 changes: 18 additions & 18 deletions contrib/kafka/filters/network/test/broker/rewriter_unit_test.cc
Expand Up @@ -43,7 +43,7 @@ class FakeResponse : public AbstractResponse {

TEST(ResponseRewriterImplUnitTest, ShouldRewriteBuffer) {
// given
ResponseRewriterImpl testee{MockBrokerFilterConfig{}};
ResponseRewriterImpl testee{std::make_shared<MockBrokerFilterConfig>()};

auto response1 = std::make_shared<FakeResponse>(7);
auto response2 = std::make_shared<FakeResponse>(13);
Expand Down Expand Up @@ -80,13 +80,13 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteMetadataResponse) {
std::vector<MetadataResponseBroker> brokers = {b1, b2, b3};
MetadataResponse mr = {brokers, {}};

MockBrokerFilterConfig config;
auto config = std::make_shared<MockBrokerFilterConfig>();
absl::optional<HostAndPort> r1 = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(b1.node_id_)).WillOnce(Return(r1));
EXPECT_CALL(*config, findBrokerAddressOverride(b1.node_id_)).WillOnce(Return(r1));
absl::optional<HostAndPort> r2 = absl::nullopt;
EXPECT_CALL(config, findBrokerAddressOverride(b2.node_id_)).WillOnce(Return(r2));
EXPECT_CALL(*config, findBrokerAddressOverride(b2.node_id_)).WillOnce(Return(r2));
absl::optional<HostAndPort> r3 = {{"nh3", 6666}};
EXPECT_CALL(config, findBrokerAddressOverride(b3.node_id_)).WillOnce(Return(r3));
EXPECT_CALL(*config, findBrokerAddressOverride(b3.node_id_)).WillOnce(Return(r3));
ResponseRewriterImpl testee{config};

// when
Expand All @@ -107,15 +107,15 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteFindCoordinatorResponse) {
Coordinator c3 = {"k3", 3, "ch3", 4444, 0, {}, {}};
fcr.coordinators_ = {c1, c2, c3};

MockBrokerFilterConfig config;
auto config = std::make_shared<MockBrokerFilterConfig>();
absl::optional<HostAndPort> fcrhp = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(fcr.node_id_)).WillOnce(Return(fcrhp));
EXPECT_CALL(*config, findBrokerAddressOverride(fcr.node_id_)).WillOnce(Return(fcrhp));
absl::optional<HostAndPort> cr1 = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(c1.node_id_)).WillOnce(Return(cr1));
EXPECT_CALL(*config, findBrokerAddressOverride(c1.node_id_)).WillOnce(Return(cr1));
absl::optional<HostAndPort> cr2 = absl::nullopt;
EXPECT_CALL(config, findBrokerAddressOverride(c2.node_id_)).WillOnce(Return(cr2));
EXPECT_CALL(*config, findBrokerAddressOverride(c2.node_id_)).WillOnce(Return(cr2));
absl::optional<HostAndPort> cr3 = {{"nh3", 6666}};
EXPECT_CALL(config, findBrokerAddressOverride(c3.node_id_)).WillOnce(Return(cr3));
EXPECT_CALL(*config, findBrokerAddressOverride(c3.node_id_)).WillOnce(Return(cr3));
ResponseRewriterImpl testee{config};

// when
Expand All @@ -136,13 +136,13 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteDescribeClusterResponse) {
std::vector<DescribeClusterBroker> brokers = {b1, b2, b3};
DescribeClusterResponse dcr = {0, 0, absl::nullopt, "", 0, brokers, 0, {}};

MockBrokerFilterConfig config;
auto config = std::make_shared<MockBrokerFilterConfig>();
absl::optional<HostAndPort> cr1 = {{"nh1", 4444}};
EXPECT_CALL(config, findBrokerAddressOverride(b1.broker_id_)).WillOnce(Return(cr1));
EXPECT_CALL(*config, findBrokerAddressOverride(b1.broker_id_)).WillOnce(Return(cr1));
absl::optional<HostAndPort> cr2 = absl::nullopt;
EXPECT_CALL(config, findBrokerAddressOverride(b2.broker_id_)).WillOnce(Return(cr2));
EXPECT_CALL(*config, findBrokerAddressOverride(b2.broker_id_)).WillOnce(Return(cr2));
absl::optional<HostAndPort> cr3 = {{"nh3", 6666}};
EXPECT_CALL(config, findBrokerAddressOverride(b3.broker_id_)).WillOnce(Return(cr3));
EXPECT_CALL(*config, findBrokerAddressOverride(b3.broker_id_)).WillOnce(Return(cr3));
ResponseRewriterImpl testee{config};

// when
Expand All @@ -155,13 +155,13 @@ TEST(ResponseRewriterImplUnitTest, ShouldRewriteDescribeClusterResponse) {
}

TEST(ResponseRewriterUnitTest, ShouldCreateProperRewriter) {
MockBrokerFilterConfig c1;
EXPECT_CALL(c1, needsResponseRewrite()).WillOnce(Return(true));
auto c1 = std::make_shared<MockBrokerFilterConfig>();
EXPECT_CALL(*c1, needsResponseRewrite()).WillOnce(Return(true));
ResponseRewriterSharedPtr r1 = createRewriter(c1);
ASSERT_NE(std::dynamic_pointer_cast<ResponseRewriterImpl>(r1), nullptr);

MockBrokerFilterConfig c2;
EXPECT_CALL(c2, needsResponseRewrite()).WillOnce(Return(false));
auto c2 = std::make_shared<MockBrokerFilterConfig>();
EXPECT_CALL(*c2, needsResponseRewrite()).WillOnce(Return(false));
ResponseRewriterSharedPtr r2 = createRewriter(c2);
ASSERT_NE(std::dynamic_pointer_cast<DoNothingRewriter>(r2), nullptr);
}
Expand Down
2 changes: 1 addition & 1 deletion contrib/rocketmq_proxy/filters/network/source/config.cc
Expand Up @@ -24,7 +24,7 @@ Network::FilterFactoryCb RocketmqProxyFilterConfigFactory::createFilterFactoryFr
std::shared_ptr<ConfigImpl> filter_config = std::make_shared<ConfigImpl>(proto_config, context);
return [filter_config, &context](Network::FilterManager& filter_manager) -> void {
filter_manager.addReadFilter(std::make_shared<ConnectionManager>(
*filter_config, context.serverFactoryContext().mainThreadDispatcher().timeSource()));
filter_config, context.serverFactoryContext().mainThreadDispatcher().timeSource()));
};
}

Expand Down
Expand Up @@ -24,8 +24,8 @@ bool ConsumerGroupMember::expired() const {
connection_manager_->config().transientObjectLifeSpan().count();
}

ConnectionManager::ConnectionManager(Config& config, TimeSource& time_source)
: config_(config), time_source_(time_source), stats_(config.stats()) {}
ConnectionManager::ConnectionManager(const ConfigSharedPtr& config, TimeSource& time_source)
: config_(config), time_source_(time_source), stats_(config_->stats()) {}

Envoy::Network::FilterStatus ConnectionManager::onData(Envoy::Buffer::Instance& data,
bool end_stream) {
Expand Down Expand Up @@ -137,7 +137,7 @@ void ConnectionManager::purgeDirectiveTable() {
for (auto it = ack_directive_table_.begin(); it != ack_directive_table_.end();) {
auto duration = current - it->second.creation_time_;
if (std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() >
config_.transientObjectLifeSpan().count()) {
config_->transientObjectLifeSpan().count()) {
ack_directive_table_.erase(it++);
} else {
it++;
Expand Down
Expand Up @@ -52,6 +52,8 @@ class Config {
virtual std::chrono::milliseconds transientObjectLifeSpan() const PURE;
};

using ConfigSharedPtr = std::shared_ptr<Config>;

class ConnectionManager;

/**
Expand Down Expand Up @@ -82,7 +84,7 @@ class ConsumerGroupMember {

class ConnectionManager : public Network::ReadFilter, Logger::Loggable<Logger::Id::filter> {
public:
ConnectionManager(Config& config, TimeSource& time_source);
ConnectionManager(const ConfigSharedPtr& config, TimeSource& time_source);

~ConnectionManager() override = default;

Expand Down Expand Up @@ -151,7 +153,7 @@ class ConnectionManager : public Network::ReadFilter, Logger::Loggable<Logger::I

void resetAllActiveMessages(absl::string_view error_msg);

Config& config() { return config_; }
Config& config() { return *config_; }

RocketmqFilterStats& stats() { return stats_; }

Expand Down Expand Up @@ -194,7 +196,7 @@ class ConnectionManager : public Network::ReadFilter, Logger::Loggable<Logger::I
Network::ReadFilterCallbacks* read_callbacks_{};
Buffer::OwnedImpl request_buffer_;

Config& config_;
ConfigSharedPtr config_;
TimeSource& time_source_;
RocketmqFilterStats& stats_;

Expand Down
Expand Up @@ -23,7 +23,7 @@ class ActiveMessageTest : public testing::Test {
public:
ActiveMessageTest()
: stats_(RocketmqFilterStats::generateStats("test.", *store_.rootScope())),
config_(rocketmq_proxy_config_, factory_context_),
config_(std::make_shared<ConfigImpl>(rocketmq_proxy_config_, factory_context_)),
connection_manager_(
config_, factory_context_.serverFactoryContext().mainThreadDispatcher().timeSource()) {
connection_manager_.initializeReadFilterCallbacks(filter_callbacks_);
Expand All @@ -39,7 +39,7 @@ class ActiveMessageTest : public testing::Test {
NiceMock<Server::Configuration::MockFactoryContext> factory_context_;
Stats::IsolatedStoreImpl store_;
RocketmqFilterStats stats_;
ConfigImpl config_;
std::shared_ptr<ConfigImpl> config_;
ConnectionManager connection_manager_;
};

Expand Down
Expand Up @@ -54,9 +54,9 @@ class RocketmqConnectionManagerTest : public Event::TestUsingSimulatedTime, publ
TestUtility::loadFromYaml(yaml, proto_config_);
TestUtility::validate(proto_config_);
}
config_ = std::make_unique<TestConfigImpl>(proto_config_, factory_context_, stats_);
config_ = std::make_shared<TestConfigImpl>(proto_config_, factory_context_, stats_);
conn_manager_ = std::make_unique<ConnectionManager>(
*config_, factory_context_.server_factory_context_.mainThreadDispatcher().timeSource());
config_, factory_context_.server_factory_context_.mainThreadDispatcher().timeSource());
conn_manager_->initializeReadFilterCallbacks(filter_callbacks_);
conn_manager_->onNewConnection();
current_ = factory_context_.server_factory_context_.mainThreadDispatcher()
Expand Down Expand Up @@ -84,7 +84,7 @@ class RocketmqConnectionManagerTest : public Event::TestUsingSimulatedTime, publ
RocketmqFilterStats stats_;
ConfigRocketmqProxy proto_config_;

std::unique_ptr<TestConfigImpl> config_;
std::shared_ptr<TestConfigImpl> config_;

Buffer::OwnedImpl buffer_;
NiceMock<Network::MockReadFilterCallbacks> filter_callbacks_;
Expand Down
4 changes: 2 additions & 2 deletions contrib/rocketmq_proxy/filters/network/test/router_test.cc
Expand Up @@ -21,7 +21,7 @@ namespace Router {
class RocketmqRouterTestBase {
public:
RocketmqRouterTestBase()
: config_(rocketmq_proxy_config_, context_),
: config_(std::make_shared<ConfigImpl>(rocketmq_proxy_config_, context_)),
cluster_info_(std::make_shared<Upstream::MockClusterInfo>()) {
context_.server_factory_context_.cluster_manager_.initializeThreadLocalClusters(
{"fake_cluster"});
Expand Down Expand Up @@ -147,7 +147,7 @@ class RocketmqRouterTestBase {
NiceMock<Network::MockReadFilterCallbacks> filter_callbacks_;
NiceMock<Server::Configuration::MockFactoryContext> context_;
ConfigImpl::RocketmqProxyConfig rocketmq_proxy_config_;
ConfigImpl config_;
std::shared_ptr<ConfigImpl> config_;
std::unique_ptr<ConnectionManager> conn_manager_;

std::unique_ptr<Router> router_;
Expand Down
2 changes: 1 addition & 1 deletion contrib/sip_proxy/filters/network/source/config.cc
Expand Up @@ -64,7 +64,7 @@ Network::FilterFactoryCb SipProxyFilterConfigFactory::createFilterFactoryFromPro
return
[filter_config, &context, transaction_infos](Network::FilterManager& filter_manager) -> void {
filter_manager.addReadFilter(std::make_shared<ConnectionManager>(
*filter_config, context.serverFactoryContext().api().randomGenerator(),
filter_config, context.serverFactoryContext().api().randomGenerator(),
context.serverFactoryContext().mainThreadDispatcher().timeSource(), context,
transaction_infos));
};
Expand Down
11 changes: 6 additions & 5 deletions contrib/sip_proxy/filters/network/source/conn_manager.cc
Expand Up @@ -174,11 +174,12 @@ void TrafficRoutingAssistantHandler::doSubscribe(
}
}

ConnectionManager::ConnectionManager(Config& config, Random::RandomGenerator& random_generator,
ConnectionManager::ConnectionManager(const ConfigSharedPtr& config,
Random::RandomGenerator& random_generator,
TimeSource& time_source,
Server::Configuration::FactoryContext& context,
std::shared_ptr<Router::TransactionInfos> transaction_infos)
: config_(config), stats_(config_.stats()), decoder_(std::make_unique<Decoder>(*this)),
: config_(config), stats_(config_->stats()), decoder_(std::make_unique<Decoder>(*this)),
random_generator_(random_generator), time_source_(time_source), context_(context),
transaction_infos_(transaction_infos) {}

Expand Down Expand Up @@ -340,7 +341,7 @@ void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbac
time_source_, read_callbacks_->connection().connectionInfoProviderSharedPtr(),
StreamInfo::FilterState::LifeSpan::Connection);
tra_handler_ = std::make_shared<TrafficRoutingAssistantHandler>(
*this, read_callbacks_->connection().dispatcher(), config_.settings()->traServiceConfig(),
*this, read_callbacks_->connection().dispatcher(), config_->settings()->traServiceConfig(),
context_, stream_info);
}

Expand Down Expand Up @@ -504,7 +505,7 @@ FilterStatus ConnectionManager::ActiveTrans::messageEnd() {
}

void ConnectionManager::ActiveTrans::createFilterChain() {
parent_.config_.filterFactory().createFilterChain(*this);
parent_.config_->filterFactory().createFilterChain(*this);
}

void ConnectionManager::ActiveTrans::onReset() { parent_.doDeferredTransDestroy(*this); }
Expand All @@ -526,7 +527,7 @@ const Network::Connection* ConnectionManager::ActiveTrans::connection() const {
Router::RouteConstSharedPtr ConnectionManager::ActiveTrans::route() {
if (!cached_route_) {
if (metadata_ != nullptr) {
Router::RouteConstSharedPtr route = parent_.config_.routerConfig().route(*metadata_);
Router::RouteConstSharedPtr route = parent_.config_->routerConfig().route(*metadata_);
cached_route_ = std::move(route);
} else {
cached_route_ = nullptr;
Expand Down