Skip to content

Commit

Permalink
update runtime extension 1.12 document (#1012)
Browse files Browse the repository at this point in the history
* fix the package name of Task Example

* change the API name space to ipex
  • Loading branch information
leslie-fang-intel authored Jul 27, 2022
1 parent dcabe00 commit f092200
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions docs/tutorials/features/runtime_extension.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ Runtime Extension

Intel® Extension for PyTorch\* Runtime Extension provides a couple of PyTorch frontend APIs for users to get finer-grained control of the thread runtime. It provides:

1. Multi-stream inference via the Python frontend module `intel_extension_for_pytorch.cpu.runtime.MultiStreamModule`.
2. Spawn asynchronous tasks via the Python frontend module `intel_extension_for_pytorch.cpu.runtime.Task`.
3. Program core bindings for OpenMP threads via the Python frontend `intel_extension_for_pytorch.cpu.runtime.pin`.
1. Multi-stream inference via the Python frontend module `ipex.cpu.runtime.MultiStreamModule`.
2. Spawn asynchronous tasks via the Python frontend module `ipex.cpu.runtime.Task`.
3. Program core bindings for OpenMP threads via the Python frontend `ipex.cpu.runtime.pin`.

**note**: Intel® Extension for PyTorch\* Runtime extension is in the **experimental** stage. The API is subject to change. More detailed descriptions are available at [API Documentation page](../api_doc.rst).

Expand All @@ -27,6 +27,9 @@ If the inputs' batchsize is larger than and divisible by ``num_streams``, the ba

Let's create some ExampleNets that will be used by further examples:
```
import torch
import intel_extension_for_pytorch as ipex
class ExampleNet1(torch.nn.Module):
def __init__(self):
super(ExampleNet1, self).__init__()
Expand Down Expand Up @@ -70,8 +73,8 @@ with torch.no_grad():
Here is the example of a model with single tensor input/output. We create a CPUPool with all the cores available on numa node 0. And creating a `MultiStreamModule` with stream number of 2 to do inference.
```
# Convert the model into multi_Stream_model
cpu_pool = intel_extension_for_pytorch.cpu.runtime.CPUPool(node_id=0)
multi_Stream_model = intel_extension_for_pytorch.cpu.runtime.MultiStreamModule(traced_model1, num_streams=2, cpu_pool=cpu_pool)
cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
multi_Stream_model = ipex.cpu.runtime.MultiStreamModule(traced_model1, num_streams=2, cpu_pool=cpu_pool)
with torch.no_grad():
y = multi_Stream_model(x)
Expand All @@ -81,7 +84,7 @@ with torch.no_grad():
When creating a `MultiStreamModule`, we have default settings for `num_streams` ("AUTO") and `cpu_pool` (with all the cores available on numa node 0). For the `num_streams` of "AUTO", there are limitations to use with int8 datatype as we mentioned in below performance receipts section.
```
# Convert the model into multi_Stream_model
multi_Stream_model = intel_extension_for_pytorch.cpu.runtime.MultiStreamModule(traced_model1)
multi_Stream_model = ipex.cpu.runtime.MultiStreamModule(traced_model1)
with torch.no_grad():
y = multi_Stream_model(x)
Expand All @@ -91,17 +94,17 @@ with torch.no_grad():
For module such as ExampleNet2 with structure input/output tensors, user needs to create `MultiStreamModuleHint` as input hint and output hint. `MultiStreamModuleHint` tells `MultiStreamModule` how to auto split the input into streams and concat the output from each steam.
```
# Convert the model into multi_Stream_model
cpu_pool = intel_extension_for_pytorch.cpu.runtime.CPUPool(node_id=0)
cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
# Create the input hint object
input_hint = intel_extension_for_pytorch.cpu.runtime.MultiStreamModuleHint(0, 0)
input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, 0)
# Create the output hint object
# When Python module has multi output tensors, it will be auto pack into a tuple, So we pass a tuple(0, 0) to create the output_hint
output_hint = intel_extension_for_pytorch.cpu.runtime.MultiStreamModuleHint((0, 0))
multi_Stream_model = intel_extension_for_pytorch.cpu.runtime.MultiStreamModule(traced_model2,
num_streams=2,
cpu_pool=cpu_pool,
input_split_hint=input_hint,
output_concat_hint=output_hint)
output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0))
multi_Stream_model = ipex.cpu.runtime.MultiStreamModule(traced_model2,
num_streams=2,
cpu_pool=cpu_pool,
input_split_hint=input_hint,
output_concat_hint=output_hint)
with torch.no_grad():
y = multi_Stream_model(x, x2)
Expand Down Expand Up @@ -133,12 +136,11 @@ Here are some performance receipes that we recommend for better multi-stream per
Here is an example for using asynchronous tasks. With the support of a runtime API, you can run 2 modules simultaneously. Each module runs on the corresponding cpu pool.

```
# Create the cpu pool and numa aware memory allocator
cpu_pool1 = ipex.runtime.CPUPool([0, 1, 2, 3])
cpu_pool2 = ipex.runtime.CPUPool([4, 5, 6, 7])
cpu_pool1 = ipex.cpu.runtime.CPUPool([0, 1, 2, 3])
cpu_pool2 = ipex.cpu.runtime.CPUPool([4, 5, 6, 7])
task1 = ipex.runtime.Task(traced_model1, cpu_pool1)
task2 = ipex.runtime.Task(traced_model1, cpu_pool2)
task1 = ipex.cpu.runtime.Task(traced_model1, cpu_pool1)
task2 = ipex.cpu.runtime.Task(traced_model1, cpu_pool2)
y1_future = task1(x)
y2_future = task2(x)
Expand All @@ -149,11 +151,11 @@ y2 = y2_future.get()

### Example of configuring core binding

Runtime Extension provides API of `intel_extension_for_pytorch.cpu.runtime.pin` to a CPU Pool for binding physical cores. We can use it without the async task feature. Here is the example to use `intel_extension_for_pytorch.cpu.runtime.pin` in the `with` context.
Runtime Extension provides API of `ipex.cpu.runtime.pin` to a CPU Pool for binding physical cores. We can use it without the async task feature. Here is the example to use `ipex.cpu.runtime.pin` in the `with` context.

```
cpu_pool = intel_extension_for_pytorch.cpu.runtime.CPUPool(node_id=0)
with intel_extension_for_pytorch.cpu.runtime.pin(cpu_pool):
cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
with ipex.cpu.runtime.pin(cpu_pool):
y_runtime = traced_model1(x)
```

Expand Down

0 comments on commit f092200

Please sign in to comment.