From 3a2ff8307436036d6b0263af8184a127377a8377 Mon Sep 17 00:00:00 2001 From: "Fernando J. Iglesias Garcia" Date: Tue, 27 Aug 2024 14:18:10 +0200 Subject: [PATCH] Shortest path maximizing probability --- leetcode/1514.cpp | 63 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 leetcode/1514.cpp diff --git a/leetcode/1514.cpp b/leetcode/1514.cpp new file mode 100644 index 0000000..f4a47d3 --- /dev/null +++ b/leetcode/1514.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include +#include + +#include + +using std::vector; +using std::unordered_map; + +double max_probability([[maybe_unused]] int n, const vector>& edges, const vector& probs, int start_node, int end_node) { + std::priority_queue, vector>, std::greater>> pq; + + unordered_map dist; + dist[start_node] = .0; + pq.emplace(0., start_node); + + double ans = .0; + + unordered_map> edges_map; + for (size_t i = 0; i < edges.size(); i++) { + edges_map[edges[i][0]].emplace(edges[i][1], -std::log(probs[i])); + edges_map[edges[i][1]].emplace(edges[i][0], -std::log(probs[i])); + } + + while(!pq.empty()) { + auto const top{pq.top()}; + pq.pop(); + for (auto const& [dst, cost] : edges_map[top.second]) { + if (dst == end_node) ans = std::max(ans, std::exp(-top.first-cost)); + if (!dist.contains(dst) or dist[dst] > top.first + cost) { + dist[dst] = top.first + cost; + pq.emplace(top.first + cost, dst); + } + } + } + + return ans; +} + +TEST(path_with_maximum_probability, a) { + const vector> edges{{0, 1}, {1, 2}, {0, 2}}; + const vector probs{.5, .5, .2}; + EXPECT_EQ(max_probability(3, edges, probs, 0, 2), .25); +} + +TEST(path_with_maximum_probability, b) { + const vector> edges{{0, 1}, {1, 2}, {0, 2}}; + const vector probs{.5, .5, .3}; + EXPECT_EQ(max_probability(3, edges, probs, 0, 2), .3); +} + +TEST(path_with_maximum_probability, c) { + const vector> edges{{0, 1}}; + const vector probs{.5}; + EXPECT_EQ(max_probability(3, edges, probs, 0, 2), .0); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}