Skip to content

Commit

Permalink
Merge pull request #31 from gcanat/seek_to_start
Browse files Browse the repository at this point in the history
Set scaler and seek to start before decoding and get_batch
  • Loading branch information
gcanat authored Nov 7, 2024
2 parents 03812ca + 0668a89 commit 28d7fb4
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "video_reader-rs"
version = "0.2.0"
version = "0.2.1"
edition = "2021"

[lib]
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ If you have some memory limitations that won't let you decode the entire video at
from video_reader import PyVideoReader

videoname = "/path/to/your/video.mp4"
# must set pixel_format to "yuv420" to be able to use `decode_fast()`
vr = PyVideoReader(videoname, pixel_format="yuv420")
vr = PyVideoReader(videoname)

chunk_size = 800 # adjust to fit within your memory limit
video_length = vr.get_shape()[0]
Expand Down
16 changes: 4 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use ffmpeg::format::Pixel as AvPixel;
use ffmpeg_next as ffmpeg;
use numpy::ndarray::Dim;
use numpy::{IntoPyArray, PyArray, PyReadonlyArray4};
mod video_io;
use log::debug;
use pyo3::{
exceptions::PyRuntimeError,
pyclass, pymethods, pymodule,
Expand Down Expand Up @@ -32,27 +31,19 @@ struct PyVideoReader {
#[pymethods]
impl PyVideoReader {
#[new]
#[pyo3(signature = (filename, threads=None, resize_shorter_side=None, pixel_format="rgb24"))]
#[pyo3(signature = (filename, threads=None, resize_shorter_side=None))]
/// create an instance of VideoReader
/// * `filename` - path to the video file
/// * `threads` - number of threads to use. If None, let ffmpeg choose the optimal number.
/// * `resize_shorter_side` - Optional, resize the shorter side of the video to this value while
/// preserving the aspect ratio.
/// * `pixel_format` - pixel format to use for the ffmpeg scaler. If you are going to use the
/// `decode_fast` method, you must set this to "yuv420".
/// * returns a PyVideoReader instance.
fn new(
filename: &str,
threads: Option<usize>,
resize_shorter_side: Option<f64>,
pixel_format: &str,
) -> PyResult<Self> {
let pixel_format = match pixel_format {
"yuv420" => Some(AvPixel::YUV420P),
_ => None,
};
let decoder_config =
DecoderConfig::new(threads.unwrap_or(0), resize_shorter_side, pixel_format);
let decoder_config = DecoderConfig::new(threads.unwrap_or(0), resize_shorter_side);
match VideoReader::new(filename.to_string(), decoder_config) {
Ok(reader) => Ok(PyVideoReader {
inner: Mutex::new(reader),
Expand Down Expand Up @@ -228,6 +219,7 @@ impl PyVideoReader {
|| (first_key.dts() < &0)
|| start_time > 0)
{
debug!("Switching to get_batch_safe!");
match vr.get_batch_safe(indices) {
Ok(batch) => Ok(batch.into_pyarray_bound(py)),
Err(e) => Err(PyRuntimeError::new_err(format!("Error: {}", e))),
Expand Down
100 changes: 69 additions & 31 deletions src/video_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,13 @@ struct VideoParams {
pub struct DecoderConfig {
threads: usize,
resize_shorter_side: Option<f64>,
pixel_format: Option<AvPixel>,
}

impl DecoderConfig {
pub fn new(
threads: usize,
resize_shorter_side: Option<f64>,
pixel_format: Option<AvPixel>,
) -> Self {
pub fn new(threads: usize, resize_shorter_side: Option<f64>) -> Self {
Self {
threads,
resize_shorter_side,
pixel_format,
}
}
}
Expand Down Expand Up @@ -161,7 +155,6 @@ impl VideoReader {
/// Struct responsible for doing the actual decoding
pub struct VideoDecoder {
video: ffmpeg::decoder::Video,
scaler: Context,
height: u32,
width: u32,
fps: f64,
Expand Down Expand Up @@ -236,6 +229,7 @@ impl VideoDecoder {
/// Decode all frames that match the frame indices
pub fn receive_and_process_decoded_frames(
&mut self,
scaler: &mut Context,
reducer: &mut VideoReducer,
) -> Result<(), ffmpeg::Error> {
let mut decoded = Video::empty();
Expand All @@ -247,7 +241,7 @@ impl VideoDecoder {
if match_index.is_some() {
reducer.indices.remove(match_index.unwrap());
let mut rgb_frame = Video::empty();
self.scaler.run(&decoded, &mut rgb_frame)?;
scaler.run(&decoded, &mut rgb_frame)?;
let res = convert_frame_to_ndarray_rgb24(
&mut rgb_frame,
&mut reducer
Expand All @@ -267,6 +261,7 @@ impl VideoDecoder {
/// Decode frames
pub fn skip_and_decode_frames(
&mut self,
scaler: &mut Context,
reducer: &mut VideoReducer,
indices: &[usize],
frame_map: &mut HashMap<usize, FrameArray>,
Expand All @@ -277,7 +272,7 @@ impl VideoDecoder {
let mut rgb_frame = Video::empty();
let mut nd_frame =
FrameArray::zeros((self.height as usize, self.width as usize, 3_usize));
self.scaler.run(&decoded, &mut rgb_frame)?;
scaler.run(&decoded, &mut rgb_frame)?;
convert_frame_to_ndarray_rgb24(&mut rgb_frame, &mut nd_frame.view_mut())?;
frame_map.insert(reducer.frame_index, nd_frame);
}
Expand Down Expand Up @@ -335,26 +330,28 @@ impl VideoReader {
None => (video.height(), video.width()),
};

let scaler = Context::get(
video.format(),
video.width(),
video.height(),
config.pixel_format.unwrap_or(AvPixel::RGB24),
width,
height,
Flags::BILINEAR,
)?;

Ok(VideoDecoder {
video,
scaler,
height,
width,
fps,
video_info,
})
}

/// Create a scaling context for converting decoded frames to `pix_fmt`.
///
/// The source side of the context is taken from the underlying ffmpeg video
/// decoder (its pixel format, width and height); the destination side uses
/// the requested `pix_fmt` together with the output width and height stored
/// on `self.decoder`. Scaling is configured with `Flags::BILINEAR`.
///
/// * `pix_fmt` - pixel format the scaler should output.
/// * returns the new `Context`, or the `ffmpeg::Error` reported by
///   `Context::get` if the context could not be created.
pub fn get_scaler(&self, pix_fmt: AvPixel) -> Result<Context, ffmpeg::Error> {
    let source = &self.decoder.video;
    let (dst_width, dst_height) = (self.decoder.width, self.decoder.height);
    Context::get(
        source.format(),
        source.width(),
        source.height(),
        pix_fmt,
        dst_width,
        dst_height,
        Flags::BILINEAR,
    )
}

pub fn decode_video(
&mut self,
start_frame: Option<usize>,
Expand All @@ -370,6 +367,12 @@ impl VideoReader {
self.decoder.width,
);
let first_index = start_frame.unwrap_or(0);

let mut scaler = self.get_scaler(AvPixel::RGB24)?;

// make sure we are at the beginning of the stream
self.seek_to_start()?;

// check if first_index is after the first keyframe, if so we can seek
if self
.stream_info
Expand All @@ -393,7 +396,7 @@ impl VideoReader {
if stream.index() == self.stream_index {
self.decoder.video.send_packet(&packet)?;
self.decoder
.receive_and_process_decoded_frames(&mut reducer)?;
.receive_and_process_decoded_frames(&mut scaler, &mut reducer)?;
} else {
debug!("Packet for another stream");
}
Expand All @@ -404,7 +407,7 @@ impl VideoReader {
&& (&reducer.frame_index <= reducer.indices.iter().max().unwrap_or(&0))
{
self.decoder
.receive_and_process_decoded_frames(&mut reducer)?;
.receive_and_process_decoded_frames(&mut scaler, &mut reducer)?;
}
Ok(reducer.full_video)
}
Expand All @@ -423,6 +426,12 @@ impl VideoReader {
self.decoder.width,
);
let first_index = start_frame.unwrap_or(0);

let mut scaler = self.get_scaler(AvPixel::YUV420P)?;

// make sure we are at the beginning of the stream
self.seek_to_start()?;

if self
.stream_info
.key_frames
Expand All @@ -449,7 +458,7 @@ impl VideoReader {
while decoder.receive_frame(&mut decoded).is_ok() {
if reducer.indices.iter().any(|x| x == &curr_frame) {
let mut rgb_frame = Video::empty();
self.decoder.scaler.run(&decoded, &mut rgb_frame).unwrap();
scaler.run(&decoded, &mut rgb_frame).unwrap();
tasks.push(task::spawn(async move {
convert_yuv_to_ndarray_rgb24(rgb_frame)
}));
Expand Down Expand Up @@ -763,6 +772,12 @@ impl VideoReader {
self.decoder.height,
self.decoder.width,
);

let mut scaler = self.get_scaler(AvPixel::RGB24)?;

// make sure we are at the beginning of the stream
self.seek_to_start()?;

// check if closest key frames to first_index is non zero, if so we can seek
let key_pos = self.locate_keyframes(first_index, &self.stream_info.key_frames);
if key_pos > 0 {
Expand All @@ -779,8 +794,12 @@ impl VideoReader {
for (stream, packet) in self.ictx.packets() {
if stream.index() == self.stream_index {
self.decoder.video.send_packet(&packet)?;
self.decoder
.skip_and_decode_frames(&mut reducer, &indices, &mut frame_map)?;
self.decoder.skip_and_decode_frames(
&mut scaler,
&mut reducer,
&indices,
&mut frame_map,
)?;
} else {
debug!("Packet for another stream");
}
Expand All @@ -790,8 +809,12 @@ impl VideoReader {
}
self.decoder.video.send_eof()?;
if &reducer.frame_index <= last_index {
self.decoder
.skip_and_decode_frames(&mut reducer, &indices, &mut frame_map)?;
self.decoder.skip_and_decode_frames(
&mut scaler,
&mut reducer,
&indices,
&mut frame_map,
)?;
}

let mut frame_batch: VideoArray = Array4::zeros((
Expand Down Expand Up @@ -827,13 +850,19 @@ impl VideoReader {
// duration of a frame in micro seconds
let frame_duration = (1. / fps * 1_000.0).round() as usize;

let mut scaler = self.get_scaler(AvPixel::RGB24)?;

// make sure we are at the beginning of the stream
self.seek_to_start()?;

for (idx_counter, frame_index) in indices.into_iter().enumerate() {
self.n_fails = 0;
debug!("[NEXT INDICE] frame_index: {frame_index}");
self.seek_accurate(
frame_index,
&frame_duration,
&mut video_frames.slice_mut(s![idx_counter, .., .., ..]),
&mut scaler,
)?;
}
Ok(video_frames)
Expand All @@ -844,22 +873,26 @@ impl VideoReader {
frame_index: usize,
frame_duration: &usize,
frame_array: &mut ArrayViewMut3<u8>,
scaler: &mut Context,
) -> Result<(), ffmpeg::Error> {
let key_pos = self.locate_keyframes(&frame_index, &self.stream_info.key_frames);
debug!(" - Key pos: {}", key_pos);
let curr_key_pos = self.locate_keyframes(&self.curr_dec_idx, &self.stream_info.key_frames);
debug!(" - Curr key pos: {}", curr_key_pos);
if (key_pos == curr_key_pos) & (frame_index > self.curr_frame) {
// we can directly skip until frame_index
debug!("No need to seek, we can directly skip frames");
let num_skip = self.get_num_skip(&frame_index);
self.skip_frames(num_skip, &frame_index, frame_array)?;
self.skip_frames(num_skip, &frame_index, frame_array, scaler)?;
} else {
if key_pos < curr_key_pos {
debug!("Seeking back to start");
self.seek_to_start()?;
}
debug!("Seeking to key_pos: {}", key_pos);
self.seek_frame(&key_pos, frame_duration)?;
let num_skip = self.get_num_skip(&frame_index);
self.skip_frames(num_skip, &frame_index, frame_array)?;
self.skip_frames(num_skip, &frame_index, frame_array, scaler)?;
}
Ok(())
}
Expand Down Expand Up @@ -892,6 +925,7 @@ impl VideoReader {
num: usize,
frame_index: &usize,
frame_array: &mut ArrayViewMut3<u8>,
scaler: &mut Context,
) -> Result<(), ffmpeg::Error> {
let num_skip = num.min(self.stream_info.frame_count - 1);
debug!(
Expand All @@ -910,7 +944,7 @@ impl VideoReader {
while self.decoder.video.receive_frame(&mut decoded).is_ok() {
if &self.curr_frame == frame_index {
let mut rgb_frame = Video::empty();
self.decoder.scaler.run(&decoded, &mut rgb_frame)?;
scaler.run(&decoded, &mut rgb_frame)?;
convert_frame_to_ndarray_rgb24(&mut rgb_frame, frame_array)?;
self.update_indices();
return Ok(());
Expand All @@ -937,6 +971,10 @@ impl VideoReader {
.decode_order
.get(&self.curr_dec_idx)
.unwrap_or(&self.stream_info.frame_count);
debug!(
"dec_idx: {}, curr_frame: {}",
self.curr_dec_idx, self.curr_frame
);
}

// AVSEEK_FLAG_BACKWARD 1 <- seek backward
Expand Down

0 comments on commit 28d7fb4

Please sign in to comment.