diff --git a/tutorial/source/tensor_shapes.ipynb b/tutorial/source/tensor_shapes.ipynb index 50e5092cf0..4ac3bd140a 100644 --- a/tutorial/source/tensor_shapes.ipynb +++ b/tutorial/source/tensor_shapes.ipynb @@ -29,12 +29,13 @@ "- [Declaring independence with iarange](#Declaring-independent-dims-with-iarange)\n", "- [Subsampling inside iarange](#Subsampling-tensors-inside-an-iarange)\n", "- [Broadcasting to allow Parallel Enumeration](#Broadcasting-to-allow-parallel-enumeration)\n", - " - [Writing parallelizable code](#Writing-parallelizable-code)" + " - [Writing parallelizable code](#Writing-parallelizable-code)\n", + "- [Automatic broadcasting via broadcast poutine](#Automatic-broadcasting-via-broadcast-poutine)" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -44,6 +45,7 @@ "from torch.distributions import constraints\n", "from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal\n", "from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate\n", + "import pyro.poutine as poutine\n", "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", @@ -90,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -111,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -154,7 +156,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -252,7 +254,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ @@ -329,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -396,7 +398,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -458,7 +460,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -521,6 +523,110 @@ "enumerated = True\n", "test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Automatic broadcasting via broadcast poutine\n", + "\n", + "Note that in all our model/guide specifications, we had to expand sample shapes by hand to satisfy the constraints on batch shape enforced by `pyro.iarange` statements. This code can be simplified by using [poutine.broadcast](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.broadcast), which automatically broadcasts the batch shape of `pyro.sample` statements when inside a single or nested iarange context. \n", + "\n", + "We will demonstrate this using `model4` from the [previous section](#Writing-parallelizable-code). Note the following changes to the code from earlier:\n", + "\n", + " - For the purpose of this example, we will only consider \"parallel\" enumeration, but broadcasting should work as expected without enumeration or with \"sequential\" enumeration.\n", + " - We have separated out the sampling function which returns the tensors corresponding to the active pixels. Modularizing the model code into components is a common practice, and helps with maintainability of large models. 
The first sampling function is identical to what we had in `model4`, and the remaining sampling functions use `poutine.broadcast` to implicitly expand sample sites to conform to the shape requirements imposed by the `iarange` contexts in which they are embedded.\n", + " - We would also like to use the `pyro.iarange` construct to parallelize the ELBO estimator over [num_particles](http://docs.pyro.ai/en/latest/inference_algos.html#pyro.infer.elbo.ELBO). This is done by wrapping the contents of model/guide inside an outermost `pyro.iarange` context." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_particles = 100 # Number of samples for the ELBO estimator\n", + "width = 8\n", + "height = 10\n", + "sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])\n", + "\n", + "def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):\n", + " with x_axis:\n", + " x_active = pyro.sample(\"x_active\", Bernoulli(p_x).expand_by([num_particles, width, 1]))\n", + " with y_axis:\n", + " y_active = pyro.sample(\"y_active\", Bernoulli(p_y).expand_by([num_particles, 1, height]))\n", + " return x_active, y_active\n", + "\n", + "def sample_pixel_locations_automatic_broadcasting(p_x, p_y, x_axis, y_axis):\n", + " with x_axis:\n", + " x_active = pyro.sample(\"x_active\", Bernoulli(p_x))\n", + " with y_axis:\n", + " y_active = pyro.sample(\"y_active\", Bernoulli(p_y))\n", + " return x_active, y_active \n", + "\n", + "def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):\n", + " with x_axis:\n", + " x_active = pyro.sample(\"x_active\", Bernoulli(p_x).expand_by([width, 1]))\n", + " with y_axis:\n", + " y_active = pyro.sample(\"y_active\", Bernoulli(p_y).expand_by([height]))\n", + " return x_active, y_active \n", + "\n", + "def fun(observe, sample_fn):\n", + " p_x = pyro.param(\"p_x\", torch.tensor(0.1), constraint=constraints.unit_interval)\n", + " p_y = pyro.param(\"p_y\", 
torch.tensor(0.1), constraint=constraints.unit_interval)\n", + " x_axis = pyro.iarange('x_axis', width, dim=-2)\n", + " y_axis = pyro.iarange('y_axis', height, dim=-1)\n", + "\n", + " with pyro.iarange(\"num_particles\", num_particles, dim=-3):\n", + " x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)\n", + " # Indices corresponding to \"parallel\" enumeration are appended \n", + " # to the left of the \"num_particles\" iarange dim.\n", + " assert x_active.shape == (2, num_particles, width, 1)\n", + " assert y_active.shape == (2, 1, num_particles, 1, height)\n", + " p = 0.1 + 0.5 * x_active * y_active\n", + " assert p.shape == (2, 2, num_particles, width, height)\n", + "\n", + " dense_pixels = torch.zeros_like(p)\n", + " for x, y in sparse_pixels:\n", + " dense_pixels[..., x, y] = 1\n", + " assert dense_pixels.shape == (2, 2, num_particles, width, height)\n", + "\n", + " with x_axis, y_axis: \n", + " if observe:\n", + " pyro.sample(\"pixels\", Bernoulli(p), obs=dense_pixels)\n", + "\n", + "def test_model_with_sample_fn(sample_fn, broadcast=False):\n", + " def model():\n", + " fun(observe=True, sample_fn=sample_fn)\n", + "\n", + " @config_enumerate(default=\"parallel\")\n", + " def guide():\n", + " fun(observe=False, sample_fn=sample_fn)\n", + "\n", + " if broadcast:\n", + " model = poutine.broadcast(model)\n", + " guide = poutine.broadcast(guide)\n", + " test_model(model, guide, TraceEnum_ELBO(max_iarange_nesting=3))\n", + "\n", + "test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)\n", + "test_model_with_sample_fn(sample_pixel_locations_automatic_broadcasting, broadcast=True)\n", + "test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting, broadcast=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the first sampling function, we had to do some manual book-keeping and expand the `Bernoulli` distribution's batch shape to account for the independent dimensions added by the `pyro.iarange` contexts. 
In particular, note how `sample_pixel_locations_no_broadcasting` needs knowledge of `num_particles`, `width` and `height` and is accessing these variables from the global scope, which is not ideal. \n", + "\n", + "The next two sampling functions are run under [poutine.broadcast](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.broadcast) (the enclosing model and guide are wrapped with it), so that this expansion is achieved automatically via an effect handler. Note the following in the next two modified sampling functions:\n", + "\n", + " - The second argument to `pyro.iarange`, i.e. the optional `size` argument, needs to be provided for implicit broadcasting, so that `poutine.broadcast` can infer the batch shape requirement for each of the sample sites. \n", + " - The existing `batch_shape` of the sample site must be broadcastable with the size of the `pyro.iarange` contexts. In our particular example, `Bernoulli(p_x)` has an empty batch shape which is universally broadcastable.\n", + " - `poutine.broadcast` is idempotent, and is also safe to use when the sample sites have been partially broadcast to the size of some of the `iarange`s but not all. In the third sampling function, the user has partially expanded `x_active` and `y_active`, and the broadcast effect handler expands the other batch dimensions to the size of the remaining `iarange`s.\n", + "\n", + "Note how simple it is to achieve parallelization via tensorized operations using `pyro.iarange` and `poutine.broadcast`! `poutine.broadcast` also helps in code modularization because model components can be written agnostic of the `iarange` contexts in which they may subsequently be embedded." + ] } ], "metadata": {