Skip to content

Commit

Permalink
Support video input with data_id mechanism (#173)
Browse files Browse the repository at this point in the history
Support video input with data_id mechanism

Signed-off-by: Rafal <[email protected]>
  • Loading branch information
banasraf authored Feb 27, 2023
1 parent a3f7695 commit 52a0fa7
Show file tree
Hide file tree
Showing 14 changed files with 302 additions and 29 deletions.
2 changes: 1 addition & 1 deletion DALI_EXTRA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3a240e3b37b599975886ee65a87a2727ca98a107
1602d530ebdc8876275b9b81b0681c279683113e
2 changes: 1 addition & 1 deletion DALI_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.22
1.23
117 changes: 117 additions & 0 deletions qa/L0_video_input_decoupled/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# The MIT License (MIT)
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


from functools import partial
from itertools import cycle
import numpy as np
import queue
from os import environ
from glob import glob
import argparse

from tritonclient.utils import *
import tritonclient.grpc as t_client

import nvidia.dali.experimental.eager as eager

class UserData:

def __init__(self):
self._completed_requests = queue.Queue()


def callback(user_data, result, error):
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)


def get_dali_extra_path():
return environ['DALI_EXTRA_PATH']

def input_gen():
filenames = glob(f'{get_dali_extra_path()}/db/video/[cv]fr/*.mp4')
filenames = filter(lambda filename: 'mpeg4' not in filename, filenames)
filenames = filter(lambda filename: 'hevc' not in filename, filenames)
for filename in filenames:
yield np.fromfile(filename, dtype=np.uint8)


FRAMES_PER_SEQUENCE = 5
BATCH_SIZE = 3
FRAMES_PER_BATCH = FRAMES_PER_SEQUENCE * BATCH_SIZE
model_name = "model.dali"

user_data = UserData()

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',
help='Inference server GRPC URL. Default is localhost:8001.')
parser.add_argument('-n', '--n_iters', type=int, required=False, default=1, help='Number of iterations')
return parser.parse_args()

if __name__ == '__main__':
args = parse_args()
with t_client.InferenceServerClient(url=args.url) as triton_client:
triton_client.start_stream(callback=partial(callback, user_data))

for req_id, input_data in zip(range(args.n_iters), cycle(input_gen())):
inp = t_client.InferInput('INPUT', [1, input_data.shape[0]], 'UINT8')
inp.set_data_from_numpy(input_data.reshape((1, -1)))

outp = t_client.InferRequestedOutput('OUTPUT')

request_id = str(req_id)
triton_client.async_stream_infer(model_name=model_name,
inputs=[inp],
request_id=request_id,
outputs=[outp])


expected_result = eager.experimental.decoders.video([input_data])
expected_result = eager.pad(expected_result, axes=0, align=FRAMES_PER_SEQUENCE).at(0)
n_frames = expected_result.shape[0]
recv_count = 0
expected_count = (n_frames + FRAMES_PER_BATCH - 1) // FRAMES_PER_BATCH
result_dict = {}
while recv_count < expected_count:
data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
else:
this_id = data_item.get_response().id
if this_id not in result_dict.keys():
result_dict[this_id] = []
result_dict[this_id].append(data_item)
recv_count += 1

result_list = result_dict[request_id]
expected_result = np.split(expected_result, n_frames / FRAMES_PER_SEQUENCE)
for i, result in enumerate(result_list):
expected_batch = expected_result[i * BATCH_SIZE : min((i+1) * BATCH_SIZE, len(expected_result))]
expected_batch = np.asarray(expected_batch)
result_data = result.as_numpy('OUTPUT')
assert np.array_equal(expected_batch, result_data)

print(f'ITER {req_id}: OK')
30 changes: 30 additions & 0 deletions qa/L0_video_input_decoupled/model_repository/model.dali/1/dali.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# The MIT License (MIT)
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import nvidia.dali as dali
import nvidia.dali.fn as fn
from nvidia.dali.plugin.triton import autoserialize


@autoserialize
@dali.pipeline_def(batch_size=3, num_threads=3, device_id=0, output_ndim=4, output_dtype=dali.types.UINT8)
def pipeline():
return fn.experimental.inputs.video(sequence_length=5, name='INPUT', last_sequence_policy='pad')
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# The MIT License (MIT)
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


max_batch_size: 0

model_transaction_policy {
decoupled: True
}

input [
{
name: "INPUT"
dims: [ 1, -1 ]
}
]

output [
{
name: "OUTPUT"
}
]
24 changes: 24 additions & 0 deletions qa/L0_video_input_decoupled/setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash -ex

# The MIT License (MIT)
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

echo "model ready"
26 changes: 26 additions & 0 deletions qa/L0_video_input_decoupled/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash -ex

# The MIT License (MIT)
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

echo "RUN SEQUENTIAL CLIENT"
python client.py -u $GRPC_ADDR -n 16
echo "PASS"
22 changes: 17 additions & 5 deletions src/dali_executor/dali_executor.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2020 NVIDIA CORPORATION
// Copyright (c) 2020-2023 NVIDIA CORPORATION & AFFILIATES
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -46,10 +46,13 @@ void DaliExecutor::SetupInputs(const std::vector<IDescr>& inputs) {
}
}
WaitForCopies();
input_names_.clear();
request_id_ += 1;
std::string request_id_str = std::to_string(request_id_);
for (auto& inp : c_inputs) {
pipeline_.SetInput(inp);
input_names_.push_back(inp.meta.name);
pipeline_.SetInput(inp, {request_id_str});
}
inputs_consumed_ = false;
}


Expand Down Expand Up @@ -134,12 +137,16 @@ bool DaliExecutor::IsNoCopy(device_type_t es_device, const IDescr& input) {
}

std::vector<OutputInfo> DaliExecutor::Run(const std::vector<IDescr>& inputs) {
SetupInputs(inputs);
if (inputs_consumed_) {
SetupInputs(inputs);
inputs_consumed_ = false;
}
try {
pipeline_.Run();
pipeline_.Output();
} catch (std::runtime_error& e) {
pipeline_.Reset();
inputs_consumed_ = true;
throw e;
}
std::vector<OutputInfo> ret(pipeline_.GetNumOutput());
Expand All @@ -148,7 +155,12 @@ std::vector<OutputInfo> DaliExecutor::Run(const std::vector<IDescr>& inputs) {
ret[out_idx] = {outputs_shapes[out_idx], pipeline_.GetOutputType(out_idx),
pipeline_.GetOutputDevice(out_idx)};
}
inputs_consumed_ = true; // this will change with introduction of streamed input
for (auto &name: input_names_) {
auto trace = pipeline_.TryGetOperatorTrace(name, "next_output_data_id");
if (!trace.has_value() || std::stoull(*trace) != request_id_) {
inputs_consumed_ = true;
}
}
return ret;
}

Expand Down
6 changes: 4 additions & 2 deletions src/dali_executor/dali_executor.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2020 NVIDIA CORPORATION
// Copyright (c) 2020-2023 NVIDIA CORPORATION & AFFILIATES
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -110,7 +110,9 @@ class DaliExecutor {
ThreadPool thread_pool_;
std::map<std::string, IOBuffer<CPU>> cpu_buffers_;
std::map<std::string, IOBuffer<GPU>> gpu_buffers_;
bool inputs_consumed_;
bool inputs_consumed_ = true;
std::vector<std::string> input_names_;
uint64_t request_id_ = 0;
};

}}} // namespace triton::backend::dali
Expand Down
26 changes: 20 additions & 6 deletions src/dali_executor/dali_pipeline.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES
// Copyright (c) 2020-2023 NVIDIA CORPORATION & AFFILIATES
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -55,13 +55,16 @@ std::vector<TensorListShape<>> DaliPipeline::GetOutputShapes() {

void DaliPipeline::SetInput(const void* data_ptr, const char* name, device_type_t source_device,
dali_data_type_t data_type, span<const int64_t> inputs_shapes,
int sample_ndims, bool force_no_copy) {
int sample_ndims, const char *data_id, bool force_no_copy) {
ENFORCE(inputs_shapes.size() % sample_ndims == 0, "Incorrect inputs shapes or sample ndims");
int batch_size = inputs_shapes.size() / sample_ndims;
unsigned int flags = DALI_ext_default;
if (force_no_copy) {
flags |= DALI_ext_force_no_copy;
}
if (data_id) {
daliSetExternalInputDataId(&handle_, name, data_id);
}
const char *layout = daliGetExternalInputLayout(&handle_, name);
daliSetExternalInputBatchSize(&handle_, name, batch_size);
daliSetExternalInput(&handle_, name, source_device, data_ptr, data_type, inputs_shapes.data(),
Expand All @@ -71,16 +74,18 @@ void DaliPipeline::SetInput(const void* data_ptr, const char* name, device_type_

void DaliPipeline::SetInput(const void* ptr, const char* name, device_type_t source_device,
dali_data_type_t data_type, TensorListShape<> input_shape,
bool force_no_copy) {
std::optional<std::string_view> data_id, bool force_no_copy) {
SetInput(ptr, name, source_device, data_type, make_span(input_shape.shapes),
input_shape.sample_dim(), force_no_copy);
input_shape.sample_dim(), data_id ? data_id->data() : nullptr, force_no_copy);
}

void DaliPipeline::SetInput(const IDescr& io_descr, bool force_no_copy) {
void DaliPipeline::SetInput(const IDescr& io_descr, std::optional<std::string_view> data_id,
bool force_no_copy) {
ENFORCE(io_descr.buffers.size() == 1, "DALI pipeline input has to be a single chunk of memory");
auto meta = io_descr.meta;
auto buffer = io_descr.buffers[0];
SetInput(buffer.data, meta.name.c_str(), buffer.device, meta.type, meta.shape, force_no_copy);
SetInput(buffer.data, meta.name.c_str(), buffer.device, meta.type, meta.shape,
data_id, force_no_copy);
}

void DaliPipeline::SyncStream() {
Expand Down Expand Up @@ -146,4 +151,13 @@ int DaliPipeline::GetMaxBatchSize() {
return daliGetMaxBatchSize(&handle_);
}

std::optional<std::string> DaliPipeline::TryGetOperatorTrace(std::string_view operator_name,
std::string_view trace_name) {
if (daliHasOperatorTrace(&handle_, operator_name.data(), trace_name.data())) {
return daliGetOperatorTrace(&handle_, operator_name.data(), trace_name.data());
} else {
return std::nullopt;
}
}

}}} // namespace triton::backend::dali
Loading

0 comments on commit 52a0fa7

Please sign in to comment.