Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow sampling from uniform and categorical distributions #483

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,7 @@ $probabilities = ['a' => 0.3, 'b' => 0.2, 'c' => 0.5]; // probabilities for cate
$categorical = new Discrete\Categorical($k, $probabilities);
$pmf_a = $categorical->pmf('a');
$mode = $categorical->mode();
$random = $categorical->rand(); // returns 'a' or 'b' or 'c'

// Geometric distribution (failures before the first success)
$p = 0.5; // success probability
Expand Down Expand Up @@ -1549,6 +1550,7 @@ $cdf = $uniform->cdf($k);
$μ = $uniform->mean();
$median = $uniform->median();
$σ² = $uniform->variance();
$random = $uniform->rand();

// Zipf distribution
$k = 2; // rank
Expand Down
49 changes: 49 additions & 0 deletions src/Probability/Distribution/Discrete/Categorical.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
*/
private $probabilities;

/**
* @var array<int|string, int|float>|null
* Cached CDF when pmf sorted from most probable category
* to least probable category.
* This is only useful for repeated sampling using Categorical::rand()
*/
private $sorted_cdf = null;

/**
* Distribution constructor
*
Expand Down Expand Up @@ -123,4 +131,45 @@
throw new Exception\BadDataException("$name is not a valid gettable parameter");
}
}

/**
* Sample a random category and return its key
*
* @return int|string
*/
public function rand()
{
// calculate sorted cdf or use cached array
if (is_null($this->sorted_cdf)) {
// sort probabilities in descending order
$sorted_probabilities = $this->probabilities; // copy as arsort works in place
arsort($sorted_probabilities, SORT_NUMERIC);

// calculate cdf
$cdf = [];
$sum = 0.0;
foreach ($sorted_probabilities as $category => $pᵢ) {
$sum += $pᵢ;
$cdf[$category] = $sum;
}

$this->sorted_cdf = $cdf;
}

$rand = \random_int(0, \PHP_INT_MAX) / \PHP_INT_MAX; // [0, 1]

// find first element in sorted cdf that is larger than $rand
// for large arrays, performance could be improved by using binary search instead
// also possible with array_find_key in PHP >=8.4
foreach ($this->sorted_cdf as $category => $v) {
if ($v >= $rand) {
return $category;
}
}

// should only end up here if due to rounding errors the sum of probabilities
// is less than 1.0 and the generated random value is larger than the sum
// should be very unlikely, but possible
return array_key_last($this->sorted_cdf);

Check failure on line 173 in src/Probability/Distribution/Discrete/Categorical.php

View workflow job for this annotation

GitHub Actions / Static Analysis (7.2)

Method MathPHP\Probability\Distribution\Discrete\Categorical::rand() should return int|string but returns int|string|null.
}
}
10 changes: 10 additions & 0 deletions src/Probability/Distribution/Discrete/Uniform.php
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,14 @@

return (($b - $a + 1) ** 2 - 1) / 12;
}

/**
* Random number sampled from the distribution
*
* @return int
*/
public function rand(): int
{
return \random_int($this->a, $this->b);

Check failure on line 165 in src/Probability/Distribution/Discrete/Uniform.php

View workflow job for this annotation

GitHub Actions / Static Analysis (7.2)

Parameter #2 $max of function random_int expects int, float given.
}
}
35 changes: 35 additions & 0 deletions tests/Probability/Distribution/Discrete/CategoricalTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,39 @@ public function testGetException()
// When
$does_not_exist = $categorical->does_not_exist;
}


/**
* @test rand
*/
public function testRand()
{
// Given
$k = 3;
$probabilities = ['a' => 0.2, 'b' => 0.5, 'c' => 0.3];
$categorical = new Categorical($k, $probabilities);

// When
$rand = $categorical->rand();

// Then
$this->assertContains($rand, ['a', 'b', 'c']);
}

/**
* @test rand with certainty
*/
public function testRandCertain()
{
// Given
$k = 3;
$probabilities = ['a' => 0.0, 'b' => 1.0, 'c' => 0.0];
$categorical = new Categorical($k, $probabilities);

// When
$rand = $categorical->rand();

// Then
$this->assertEquals('b', $rand);
}
}
19 changes: 19 additions & 0 deletions tests/Probability/Distribution/Discrete/UniformTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,23 @@ public function dataProviderForVariance(): array
[2, 4, 0.66666666666667],
];
}

/**
* @test rand
*/
public function testRand()
{
// Given
$a = 10;
$b = 11;
$uniform = new Uniform($a, $b);

// When
$random = $uniform->rand();

// Then
$this->assertTrue(\is_numeric($random));
$this->assertTrue($a <= $random);
$this->assertTrue($random <= $b);
}
}
Loading