diff --git a/src/mcts/params.cc b/src/mcts/params.cc
index 22310477e1..5d9b58eee4 100644
--- a/src/mcts/params.cc
+++ b/src/mcts/params.cc
@@ -299,6 +299,14 @@ const OptionId SearchParams::kPolicySoftmaxTempId{
     "policy-softmax-temp", "PolicyTemperature",
     "Policy softmax temperature. Higher values make priors of move candidates "
     "closer to each other, widening the search."};
+const OptionId SearchParams::kPolicyDecayExponentId{
+    "policy-decay-exponent", "PolicyDecayExponent",
+    "Policy decay exponent. Sets the exponent of the visit based policy decay "
+    "term."};
+const OptionId SearchParams::kPolicyDecayFactorId{
+    "policy-decay-factor", "PolicyDecayFactor",
+    "Policy decay factor. Scales the visit count for the visit based policy "
+    "decay term."};
 const OptionId SearchParams::kMaxCollisionVisitsId{
     "max-collision-visits", "MaxCollisionVisits",
     "Total allowed node collision visits, per batch."};
@@ -522,6 +530,8 @@ void SearchParams::Populate(OptionsParser* options) {
   options->Add(kFpuValueAtRootId, -100.0f, 100.0f) = 1.0f;
   options->Add(kCacheHistoryLengthId, 0, 7) = 0;
   options->Add(kPolicySoftmaxTempId, 0.1f, 10.0f) = 1.359f;
+  options->Add(kPolicyDecayExponentId, 0.0f, 10.0f) = 0.5f;
+  options->Add(kPolicyDecayFactorId, 0.0f, 1.0f) = 0.0001f;
   options->Add(kMaxCollisionEventsId, 1, 65536) = 917;
   options->Add(kMaxCollisionVisitsId, 1, 100000000) = 80000;
   options->Add(kMaxCollisionVisitsScalingStartId, 1, 100000) = 28;
@@ -637,6 +647,8 @@ SearchParams::SearchParams(const OptionsDict& options)
                           : options.Get(kFpuValueAtRootId)),
       kCacheHistoryLength(options.Get(kCacheHistoryLengthId)),
       kPolicySoftmaxTemp(options.Get(kPolicySoftmaxTempId)),
+      kPolicyDecayExponent(options.Get(kPolicyDecayExponentId)),
+      kPolicyDecayFactor(options.Get(kPolicyDecayFactorId)),
       kMaxCollisionEvents(options.Get(kMaxCollisionEventsId)),
       kMaxCollisionVisits(options.Get(kMaxCollisionVisitsId)),
       kOutOfOrderEval(options.Get(kOutOfOrderEvalId)),
diff --git a/src/mcts/params.h b/src/mcts/params.h
index df02f124fb..0a59f66a37 100644
--- a/src/mcts/params.h
+++ b/src/mcts/params.h
@@ -95,6 +95,8 @@ class SearchParams {
   }
   int GetCacheHistoryLength() const { return kCacheHistoryLength; }
   float GetPolicySoftmaxTemp() const { return kPolicySoftmaxTemp; }
+  float GetPolicyDecayExponent() const { return kPolicyDecayExponent; }
+  float GetPolicyDecayFactor() const { return kPolicyDecayFactor; }
   int GetMaxCollisionEvents() const { return kMaxCollisionEvents; }
   int GetMaxCollisionVisits() const { return kMaxCollisionVisits; }
   bool GetOutOfOrderEval() const { return kOutOfOrderEval; }
@@ -191,6 +193,8 @@ class SearchParams {
   static const OptionId kFpuValueAtRootId;
   static const OptionId kCacheHistoryLengthId;
   static const OptionId kPolicySoftmaxTempId;
+  static const OptionId kPolicyDecayExponentId;
+  static const OptionId kPolicyDecayFactorId;
   static const OptionId kMaxCollisionEventsId;
   static const OptionId kMaxCollisionVisitsId;
   static const OptionId kOutOfOrderEvalId;
@@ -259,6 +263,8 @@ class SearchParams {
   const float kFpuValueAtRoot;
   const int kCacheHistoryLength;
   const float kPolicySoftmaxTemp;
+  const float kPolicyDecayExponent;
+  const float kPolicyDecayFactor;
   const int kMaxCollisionEvents;
   const int kMaxCollisionVisits;
   const bool kOutOfOrderEval;
diff --git a/src/mcts/search.cc b/src/mcts/search.cc
index b3326a7661..884175e100 100644
--- a/src/mcts/search.cc
+++ b/src/mcts/search.cc
@@ -456,6 +456,30 @@ inline float ComputeCpuct(const SearchParams& params, uint32_t N,
   const float base = params.GetCpuctBase(is_root_node);
   return init + (k ? k * FastLog((N + base) / base) : 0.0f);
 }
+
+// Returns the visit-based policy decay factor for a node with N visits,
+// FastExp(-FastLog(1 + PolicyDecayFactor * N) * PolicyDecayExponent). Assuming
+// FastExp/FastLog approximate exp/log this is
+// (1 + PolicyDecayFactor * N)^-PolicyDecayExponent. A zero exponent or a zero
+// factor disables the decay and yields exactly 1.0f.
+inline float ComputePolicyDecayFactor(const SearchParams& params, uint32_t N) {
+  const float exponent = params.GetPolicyDecayExponent();
+  const float proportionality_factor = params.GetPolicyDecayFactor();
+  return (exponent == 0.0f || proportionality_factor == 0.0f)
+             ? 1.0f
+             : FastExp(-FastLog(1.0f + proportionality_factor * N) * exponent);
+}
+
+// Applies the decay factor to a single prior:
+//   pol / (pol + (1 - pol) * factor).
+// A factor of exactly 1.0f returns pol unchanged.
+inline float ComputePolicyDecay(const float factor, const float pol) {
+  if (factor == 1.0f) return pol;
+  const float denominator = pol + (1.0f - pol) * factor;
+  // pol == 0 combined with factor == 0 would evaluate 0 / 0 and put a NaN
+  // into the caller's policy array; leave the prior untouched in that case.
+  return denominator != 0.0f ? pol / denominator : pol;
+}
 }  // namespace
 
 std::vector Search::GetVerboseStats(Node* node) const {
@@ -1720,10 +1744,18 @@ void SearchWorker::PickNodesToExtendTask(
                                   ? odd_draw_score
                                   : even_draw_score;
       m_evaluator.SetParent(node);
+      // The policy decay factor depends only on the visit count of `node`, not
+      // on its children, so it is computed once before the loop.
+      const float policy_decay_factor =
+          ComputePolicyDecayFactor(params_, node->GetN());
       float visited_pol = 0.0f;
       for (Node* child : node->VisitedNodes()) {
         int index = child->Index();
         visited_pol += current_pol[index];
+        // Overwrite the stored prior with its decayed value. Note that
+        // visited_pol above was accumulated from the undecayed prior.
+        current_pol[index] = ComputePolicyDecay(policy_decay_factor,
+                                                current_pol[index]);
         float q = child->GetQ(draw_score);
         current_util[index] = q + m_evaluator.GetMUtility(child, q);
       }