Skip to content

Commit

Permalink
Adding poutine.broadcast to tensor shapes tutorial (#1154)
Browse files Browse the repository at this point in the history
* Adding poutine.broadcast to tensor shapes tutorial

* address review comment

* minor cosmetic fixes

* clarify description

* fix newlines

* fix nit

* Fix bug in tensor_shapes tutorial

* Simplify poutine.broadcast example
  • Loading branch information
neerajprad authored and jpchen committed May 22, 2018
1 parent b92a679 commit c7449bb
Showing 1 changed file with 117 additions and 11 deletions.
128 changes: 117 additions & 11 deletions tutorial/source/tensor_shapes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@
"- [Declaring independence with iarange](#Declaring-independent-dims-with-iarange)\n",
"- [Subsampling inside iarange](#Subsampling-tensors-inside-an-iarange)\n",
"- [Broadcasting to allow Parallel Enumeration](#Broadcasting-to-allow-parallel-enumeration)\n",
" - [Writing parallelizable code](#Writing-parallelizable-code)"
" - [Writing parallelizable code](#Writing-parallelizable-code)\n",
"- [Automatic broadcasting via broadcast poutine](#Automatic-broadcasting-via-broadcast-poutine)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,6 +45,7 @@
"from torch.distributions import constraints\n",
"from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal\n",
"from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate\n",
"import pyro.poutine as poutine\n",
"from pyro.optim import Adam\n",
"\n",
"smoke_test = ('CI' in os.environ)\n",
Expand Down Expand Up @@ -90,7 +92,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -111,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -133,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -154,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -177,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -252,7 +254,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -329,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -396,7 +398,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -458,7 +460,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -521,6 +523,110 @@
"enumerated = True\n",
"test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Automatic broadcasting via broadcast poutine<a class=\"anchor\" id=\"Automatic-broadcasting-via-broadcast-poutine\"></a>\n",
"\n",
"Note that in all our model/guide specifications, we had to expand sample shapes by hand to satisfy the constraints on batch shape enforced by `pyro.iarange` statements. This code can be simplified by using [poutine.broadcast](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.broadcast), which automatically broadcasts the batch shape of `pyro.sample` statements when inside a single or nested iarange context. \n",
"\n",
"We will demonstrate this using `model4` from the [previous section](#Writing-parallelizable-code). Note the following changes to the code from earlier:\n",
"\n",
" - For the purpose of this example, we will only consider \"parallel\" enumeration, but broadcasting should work as expected without enumeration or with \"sequential\" enumeration.\n",
" - We have separated out the sampling function which returns the tensors corresponding to the active pixels. Modularizing the model code into components is a common practice, and helps with maintainability of large models. The first sampling function is identical to what we had in `model4`, and the remaining sampling functions use `poutine.broadcast` to implicitly expand sample sites to confirm to the shape requirements imposed by the `iarange` contexts in which they are embedded.\n",
" - We would also like to use the `pyro.iarange` construct to parallelize the ELBO estimator over [num_particles](http://docs.pyro.ai/en/latest/inference_algos.html#pyro.infer.elbo.ELBO). This is done by wrapping the contents of model/guide inside an outermost `pyro.iarange` context."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_particles = 100 # Number of samples for the ELBO estimator\n",
"width = 8\n",
"height = 10\n",
"sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])\n",
"\n",
"def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):\n",
" with x_axis:\n",
" x_active = pyro.sample(\"x_active\", Bernoulli(p_x).expand_by([num_particles, width, 1]))\n",
" with y_axis:\n",
" y_active = pyro.sample(\"y_active\", Bernoulli(p_y).expand_by([num_particles, 1, height]))\n",
" return x_active, y_active\n",
"\n",
"def sample_pixel_locations_automatic_broadcasting(p_x, p_y, x_axis, y_axis):\n",
" with x_axis:\n",
" x_active = pyro.sample(\"x_active\", Bernoulli(p_x))\n",
" with y_axis:\n",
" y_active = pyro.sample(\"y_active\", Bernoulli(p_y))\n",
" return x_active, y_active \n",
"\n",
"def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):\n",
" with x_axis:\n",
" x_active = pyro.sample(\"x_active\", Bernoulli(p_x).expand_by([width, 1]))\n",
" with y_axis:\n",
" y_active = pyro.sample(\"y_active\", Bernoulli(p_y).expand_by([height]))\n",
" return x_active, y_active \n",
"\n",
"def fun(observe, sample_fn):\n",
" p_x = pyro.param(\"p_x\", torch.tensor(0.1), constraint=constraints.unit_interval)\n",
" p_y = pyro.param(\"p_y\", torch.tensor(0.1), constraint=constraints.unit_interval)\n",
" x_axis = pyro.iarange('x_axis', width, dim=-2)\n",
" y_axis = pyro.iarange('y_axis', height, dim=-1)\n",
"\n",
" with pyro.iarange(\"num_particles\", 100, dim=-3):\n",
" x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)\n",
" # Indices corresponding to \"parallel\" enumeration are appended \n",
" # to the left of the \"num_particles\" iarange dim.\n",
" assert x_active.shape == (2, num_particles, width, 1)\n",
" assert y_active.shape == (2, 1, num_particles, 1, height)\n",
" p = 0.1 + 0.5 * x_active * y_active\n",
" assert p.shape == (2, 2, num_particles, width, height)\n",
"\n",
" dense_pixels = torch.zeros_like(p)\n",
" for x, y in sparse_pixels:\n",
" dense_pixels[..., x, y] = 1\n",
" assert dense_pixels.shape == (2, 2, num_particles, width, height)\n",
"\n",
" with x_axis, y_axis: \n",
" if observe:\n",
" pyro.sample(\"pixels\", Bernoulli(p), obs=dense_pixels)\n",
"\n",
"def test_model_with_sample_fn(sample_fn, broadcast=False):\n",
" def model():\n",
" fun(observe=True, sample_fn=sample_fn)\n",
"\n",
" @config_enumerate(default=\"parallel\")\n",
" def guide():\n",
" fun(observe=False, sample_fn=sample_fn)\n",
"\n",
" if broadcast:\n",
" model = poutine.broadcast(model)\n",
" guide = poutine.broadcast(guide)\n",
" test_model(model, guide, TraceEnum_ELBO(max_iarange_nesting=3))\n",
"\n",
"test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)\n",
"test_model_with_sample_fn(sample_pixel_locations_automatic_broadcasting, broadcast=True)\n",
"test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting, broadcast=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the first sampling function, we had to do some manual book-keeping and expand the `Bernoulli` distribution's batch shape to account for the independent dimensions added by the `pyro.iarange` contexts. In particular, note how `sample_pixel_locations` needs knowledge of `num_particles`, `width` and `height` and is accessing these variables from the global scope, which is not ideal. \n",
"\n",
"The next two sampling functions are annotated with [poutine.broadcast](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.broadcast), so that this can be automatically achieved via an effect handler. Note the following in the next two modified sampling functions:\n",
"\n",
" - The second argument to `pyro.iarange`, i.e. the optional `size` argument needs to be provided for implicit broadasting, so that `poutine.broadcast` can infer the batch shape requirement for each of the sample sites. \n",
" - The existing `batch_shape` of the sample site must be broadcastable with the size of the `pyro.iarange` contexts. In our particular example, `Bernoulli(p_x)` has an empty batch shape which is universally broadcastable.\n",
" - `poutine.broadcast` is idempotent, and is also safe to use when the sample sites have been partially broadcasted to the size of some of the `iarange`s but not all. In the third sampling function, the user has partially expanded `x_active` and `y_active`, and the broadcast effect handler expands the other batch dimensions to the size of remaining `iarange`s.\n",
"\n",
"Note how simple it is to achieve parallelization via tensorized operations using `pyro.iarange` and `poutine.broadcast`! `poutine.broadcast` also helps in code modularization because model components can be written agnostic of the `iarange` contexts in which they may subsequently get embedded in."
]
}
],
"metadata": {
Expand Down

0 comments on commit c7449bb

Please sign in to comment.