Skip to content

Commit

Permalink
Merge pull request #28985 from vespa-engine/vekterli/improve-content-…
Browse files Browse the repository at this point in the history
…policy-thread-safety

Improve thread safety of MessageBus (DocumentAPI) ContentPolicy
  • Loading branch information
baldersheim authored Oct 17, 2023
2 parents 91ae05e + a6d24b3 commit 771b853
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 43 deletions.
10 changes: 5 additions & 5 deletions documentapi/src/tests/policies/policies_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ Test::requireThatContentPolicyIsRandomWithoutState()
ContentPolicy &policy = setupContentPolicy(
frame, param,
"storage/cluster.mycluster/distributor/*/default", 5);
ASSERT_TRUE(policy.getSystemState() == nullptr);
ASSERT_FALSE(policy.getSystemState());

std::set<string> lst;
for (uint32_t i = 0; i < 666; i++) {
Expand Down Expand Up @@ -858,12 +858,12 @@ Test::requireThatContentPolicyIsTargetedWithState()
"cluster=mycluster;slobroks=tcp/localhost:%d;clusterconfigid=%s;syncinit",
slobrok.port(), getDefaultDistributionConfig(2, 5).c_str());
ContentPolicy &policy = setupContentPolicy(frame, param, "storage/cluster.mycluster/distributor/*/default", 5);
ASSERT_TRUE(policy.getSystemState() == nullptr);
ASSERT_FALSE(policy.getSystemState());
{
std::vector<mbus::RoutingNode*> leaf;
ASSERT_TRUE(frame.select(leaf, 1));
leaf[0]->handleReply(std::make_unique<WrongDistributionReply>("distributor:5 storage:5"));
ASSERT_TRUE(policy.getSystemState() != nullptr);
ASSERT_TRUE(policy.getSystemState());
EXPECT_EQUAL(policy.getSystemState()->toString(), "distributor:5 storage:5");
}
std::set<string> lst;
Expand Down Expand Up @@ -897,12 +897,12 @@ Test::requireThatContentPolicyCombinesSystemAndSlobrokState()
ContentPolicy &policy = setupContentPolicy(
frame, param,
"storage/cluster.mycluster/distributor/*/default", 1);
ASSERT_TRUE(policy.getSystemState() == nullptr);
ASSERT_FALSE(policy.getSystemState());
{
std::vector<mbus::RoutingNode*> leaf;
ASSERT_TRUE(frame.select(leaf, 1));
leaf[0]->handleReply(std::make_unique<WrongDistributionReply>("distributor:99 storage:99"));
ASSERT_TRUE(policy.getSystemState() != nullptr);
ASSERT_TRUE(policy.getSystemState());
EXPECT_EQUAL(policy.getSystemState()->toString(), "distributor:99 storage:99");
}
for (int i = 0; i < 666; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace {
class CallBack : public config::IFetcherCallback<storage::lib::Distribution::DistributionConfig>
{
public:
CallBack(ContentPolicy & policy) : _policy(policy) { }
explicit CallBack(ContentPolicy & policy) : _policy(policy) { }
void configure(std::unique_ptr<storage::lib::Distribution::DistributionConfig> config) override {
_policy.configure(std::move(config));
}
Expand Down Expand Up @@ -78,13 +78,13 @@ string ContentPolicy::init()
ContentPolicy::~ContentPolicy() = default;

string
ContentPolicy::createConfigId(const string & clusterName) const
ContentPolicy::createConfigId(const string & clusterName)
{
return clusterName;
}

string
ContentPolicy::createPattern(const string & clusterName, int distributor) const
ContentPolicy::createPattern(const string & clusterName, int distributor)
{
vespalib::asciistream ost;

Expand All @@ -103,7 +103,8 @@ void
ContentPolicy::configure(std::unique_ptr<vespa::config::content::StorDistributionConfig> config)
{
try {
_nextDistribution = std::make_unique<storage::lib::Distribution>(*config);
std::lock_guard guard(_rw_lock);
_distribution = std::make_unique<storage::lib::Distribution>(*config);
} catch (const std::exception& e) {
LOG(warning, "Got exception when configuring distribution, config id was %s", _clusterConfigId.c_str());
throw e;
Expand All @@ -116,8 +117,9 @@ ContentPolicy::doSelect(mbus::RoutingContext &context)
const mbus::Message &msg = context.getMessage();

int distributor = -1;
auto [cur_state, cur_distribution] = internal_state_snapshot();

if (_state.get()) {
if (cur_state) {
document::BucketId id;
switch(msg.getType()) {
case DocumentProtocol::MESSAGE_PUTDOCUMENT:
Expand Down Expand Up @@ -168,15 +170,10 @@ ContentPolicy::doSelect(mbus::RoutingContext &context)

// Pick a distributor using ideal state algorithm
try {
// Update distribution here, to make it not take lock in average case
if (_nextDistribution) {
_distribution = std::move(_nextDistribution);
_nextDistribution.reset();
}
assert(_distribution.get());
distributor = _distribution->getIdealDistributorNode(*_state, id);
assert(cur_distribution);
distributor = cur_distribution->getIdealDistributorNode(*cur_state, id);
} catch (storage::lib::TooFewBucketBitsInUseException& e) {
auto reply = std::make_unique<WrongDistributionReply>(_state->toString());
auto reply = std::make_unique<WrongDistributionReply>(cur_state->toString());
reply->addError(mbus::Error(
DocumentProtocol::ERROR_WRONG_DISTRIBUTION,
"Too few distribution bits used for given cluster state"));
Expand All @@ -185,7 +182,7 @@ ContentPolicy::doSelect(mbus::RoutingContext &context)
} catch (storage::lib::NoDistributorsAvailableException& e) {
// No distributors available in current cluster state. Remove
// cluster state we cannot use and send to random target
_state.reset();
reset_state();
distributor = -1;
}
}
Expand Down Expand Up @@ -216,7 +213,7 @@ ContentPolicy::getRecipient(mbus::RoutingContext& context, int distributor)
return mbus::Hop::parse(entries[random() % entries.size()].second + "/default");
}

return mbus::Hop();
return {};
}

void
Expand All @@ -226,9 +223,9 @@ ContentPolicy::merge(mbus::RoutingContext &context)
mbus::Reply::UP reply = it.removeReply();

if (reply->getType() == DocumentProtocol::REPLY_WRONGDISTRIBUTION) {
updateStateFromReply(static_cast<WrongDistributionReply&>(*reply));
updateStateFromReply(dynamic_cast<WrongDistributionReply&>(*reply));
} else if (reply->hasErrors()) {
_state.reset();
reset_state();
}

context.setReply(std::move(reply));
Expand All @@ -237,8 +234,8 @@ ContentPolicy::merge(mbus::RoutingContext &context)
void
ContentPolicy::updateStateFromReply(WrongDistributionReply& wdr)
{
std::unique_ptr<storage::lib::ClusterState> newState(
new storage::lib::ClusterState(wdr.getSystemState()));
auto newState = std::make_unique<storage::lib::ClusterState>(wdr.getSystemState());
std::lock_guard guard(_rw_lock);
if (!_state || newState->getVersion() >= _state->getVersion()) {
if (_state) {
wdr.getTrace().trace(1, make_string("System state changed from version %u to %u",
Expand All @@ -256,4 +253,28 @@ ContentPolicy::updateStateFromReply(WrongDistributionReply& wdr)
}
}

ContentPolicy::StateSnapshot
ContentPolicy::internal_state_snapshot()
{
std::shared_lock guard(_rw_lock);
return {_state, _distribution};
}

std::shared_ptr<const storage::lib::ClusterState>
ContentPolicy::getSystemState() const noexcept
{
std::shared_lock guard(_rw_lock);
return _state;
}

void
ContentPolicy::reset_state()
{
// It's possible for the caller to race between checking and resetting the state,
// but this should never lead to a worse outcome than sending to a random distributor
// as if no state had been cached prior.
std::lock_guard guard(_rw_lock);
_state.reset();
}

} // documentapi
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,62 @@
#include <vespa/vdslib/distribution/distribution.h>
#include <vespa/document/bucket/bucketidfactory.h>
#include <vespa/messagebus/routing/hop.h>
#include <shared_mutex>

namespace config {
class ICallback;
class ConfigFetcher;
}

namespace storage {
namespace lib {
namespace storage::lib {
class Distribution;
class ClusterState;
}
}

namespace documentapi {

class ContentPolicy : public ExternSlobrokPolicy
{
private:
document::BucketIdFactory _bucketIdFactory;
std::unique_ptr<storage::lib::ClusterState> _state;
string _clusterName;
string _clusterConfigId;
std::unique_ptr<config::ICallback> _callBack;
std::unique_ptr<config::ConfigFetcher> _configFetcher;
std::unique_ptr<storage::lib::Distribution> _distribution;
std::unique_ptr<storage::lib::Distribution> _nextDistribution;
document::BucketIdFactory _bucketIdFactory;
mutable std::shared_mutex _rw_lock;
std::shared_ptr<const storage::lib::ClusterState> _state;
string _clusterName;
string _clusterConfigId;
std::unique_ptr<config::ICallback> _callBack;
std::unique_ptr<config::ConfigFetcher> _configFetcher;
std::shared_ptr<const storage::lib::Distribution> _distribution;

using StateSnapshot = std::pair<std::shared_ptr<const storage::lib::ClusterState>,
std::shared_ptr<const storage::lib::Distribution>>;

// Acquires _lock
[[nodiscard]] StateSnapshot internal_state_snapshot();

mbus::Hop getRecipient(mbus::RoutingContext& context, int distributor);
// Acquires _lock
void updateStateFromReply(WrongDistributionReply& reply);
// Acquires _lock
void reset_state();

public:
ContentPolicy(const string& param);
~ContentPolicy();
explicit ContentPolicy(const string& param);
~ContentPolicy() override;
void doSelect(mbus::RoutingContext &context) override;
void merge(mbus::RoutingContext &context) override;

void updateStateFromReply(WrongDistributionReply& reply);

/**
* @return a pointer to the system state registered with this policy. If
* we haven't received a system state yet, returns NULL.
* we haven't received a system state yet, returns nullptr.
*/
const storage::lib::ClusterState* getSystemState() const { return _state.get(); }
std::shared_ptr<const storage::lib::ClusterState> getSystemState() const noexcept;

virtual void configure(std::unique_ptr<storage::lib::Distribution::DistributionConfig> config);
string init() override;

private:
string createConfigId(const string & clusterName) const;
string createPattern(const string & clusterName, int distributor) const;
static string createConfigId(const string & clusterName);
static string createPattern(const string & clusterName, int distributor);
};

}

0 comments on commit 771b853

Please sign in to comment.