Skip to content

Commit

Permalink
more optim
Browse files Browse the repository at this point in the history
  • Loading branch information
hmusta committed Aug 28, 2024
1 parent a70cae5 commit 4c2ceb4
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions metagraph/src/graph/annotated_graph_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,16 +620,16 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
common::logger->trace(" Scaled Totals: in: {}\tout: {}", in_kmers, out_kmers);

auto compute_min_pval_r = [](double r, double r_in, double r_out, double lscaling_base, int64_t n) {
double half_div0 = -r_in * log(r_in) - (n + r_out) * log(n + r_out);
double half_divn = -r_out * log(r_out) - (n + r_in) * log(n + r_in);
double half_div0 = -r_in * log2(r_in) - (n + r_out) * log2(n + r_out);
double half_divn = -r_out * log2(r_out) - (n + r_in) * log2(n + r_in);

double pval = 0.0;
double lscaling = lscaling_base - lgamma(r + n);
double lscaling = lscaling_base - lgamma(r + n) / log(2);
if (half_div0 >= half_divn)
pval += exp(lscaling + lgamma(r_out + n) - lgamma(r_out));
pval += exp2(lscaling + (lgamma(r_out + n) - lgamma(r_out)) / log(2));

if (half_divn >= half_div0)
pval += exp(lscaling + lgamma(r_in + n) - lgamma(r_in));
pval += exp2(lscaling + (lgamma(r_in + n) - lgamma(r_in)) / log(2));

return std::min(1.0, pval);
};
Expand All @@ -640,7 +640,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
if (in_sum == argmin_d)
return 1.0;

double lscaling = lscaling_base - lgamma(r + n);
double lscaling = lscaling_base - lgamma(r + n) / log(2);

auto get_pval = [&](const auto &get_stat) {
double base_stat = get_stat(in_sum, out_sum,
Expand All @@ -652,7 +652,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
double ls = 0;
double lt = log2(t);
if (get_stat(s, t, ls, lt) >= base_stat) {
double base = (lscaling + lgamma(n + r_out) - lgamma(r_out)) / log(2);
double base = lscaling + (lgamma(n + r_out) - lgamma(r_out)) / log(2);
pval += exp2(base);
for (++s,--t; s <= n; ++s,--t) {
ls = log2(s);
Expand All @@ -670,7 +670,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
ls = log2(sp);
lt = 0;
if (get_stat(sp, t, ls, lt) >= base_stat) {
double base = (lscaling + lgamma(n + r_in) - lgamma(r_in)) / log(2);
double base = lscaling + (lgamma(n + r_in) - lgamma(r_in)) / log(2);
pval += exp2(base);
for (--sp,++t; sp >= s; --sp,++t) {
ls = sp > 0 ? log2(sp) : 0.0;
Expand All @@ -697,7 +697,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,

if (config.test_type == "nbinom_exact") {
double r = r_in + r_out;
double lscaling_base = lgamma(r);
double lscaling_base = lgamma(r) / log(2);
compute_min_pval = [compute_min_pval_r,lscaling_base,r,r_in,r_out](int64_t n, const PairContainer&) {
return n > 0 ? compute_min_pval_r(r, r_in, r_out, lscaling_base, n) : 1.0;
};
Expand Down Expand Up @@ -914,7 +914,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
double r_in = r_base * num_labels_in;
double r_out = r_base * num_labels_out;
double r = r_in + r_out;
double lscaling_base = lgamma(r);
double lscaling_base = lgamma(r) / log(2);

return compute_min_pval_r(r, r_in, r_out, lscaling_base, n);
};
Expand All @@ -935,7 +935,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
double r_in = r_base * num_labels_in;
double r_out = r_base * num_labels_out;
double r = r_in + r_out;
double lscaling_base = lgamma(r);
double lscaling_base = lgamma(r) / log(2);

return compute_pval_r(r, r_in, r_out, lscaling_base, in_sum, out_sum, row);
};
Expand Down

0 comments on commit 4c2ceb4

Please sign in to comment.