Skip to content

Commit

Permalink
Add per execution process_func_args argument (#157)
Browse files Browse the repository at this point in the history
* First ideas ...

* Refine idea

* Add first tests

* Implement same changes for Feature

* Fix linter

* Update doc strings

* Add process_func_args to Segment

* Add more tests

* Add test for ProcessWithContext

* Fix docstring links
  • Loading branch information
hagenw authored Mar 21, 2024
1 parent ee0b2e8 commit d10bb27
Show file tree
Hide file tree
Showing 5 changed files with 543 additions and 43 deletions.
49 changes: 47 additions & 2 deletions audinterface/core/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def process_file(
start: Timestamp = None,
end: Timestamp = None,
root: str = None,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> pd.DataFrame:
r"""Extract features from an audio file.
Expand All @@ -420,6 +421,11 @@ def process_file(
If value is a float or integer it is treated as seconds.
See :func:`audinterface.utils.to_timedelta` for further options
root: root folder to expand relative file path
process_func_args: (keyword) arguments passed on
to the processing function.
They will temporarily overwrite
the ones stored in
:attr:`audinterface.Feature.process.process_func_args`
Raises:
RuntimeError: if sampling rates do not match
Expand All @@ -433,6 +439,7 @@ def process_file(
start=start,
end=end,
root=root,
process_func_args=process_func_args,
)
return self._series_to_frame(series)

Expand All @@ -443,6 +450,7 @@ def process_files(
starts: Timestamps = None,
ends: Timestamps = None,
root: str = None,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> pd.DataFrame:
r"""Extract features for a list of files.
Expand All @@ -459,6 +467,11 @@ def process_files(
for further options.
If a scalar is given, it is applied to all files
root: root folder to expand relative file paths
process_func_args: (keyword) arguments passed on
to the processing function.
They will temporarily overwrite
the ones stored in
:attr:`audinterface.Feature.process.process_func_args`
Raises:
RuntimeError: if sampling rates do not match
Expand All @@ -472,6 +485,7 @@ def process_files(
starts=starts,
ends=ends,
root=root,
process_func_args=process_func_args,
)
return self._series_to_frame(series)

Expand All @@ -481,6 +495,7 @@ def process_folder(
*,
filetype: str = 'wav',
include_root: bool = True,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> pd.DataFrame:
r"""Extract features from files in a folder.
Expand All @@ -493,6 +508,11 @@ def process_folder(
the file paths are absolute
in the index
of the returned result
process_func_args: (keyword) arguments passed on
to the processing function.
They will temporarily overwrite
the ones stored in
:attr:`audinterface.Feature.process.process_func_args`
Raises:
FileNotFoundError: if folder does not exist
Expand All @@ -515,7 +535,11 @@ def process_folder(
filetype=filetype,
basenames=not include_root,
)
return self.process_files(files, root=root)
return self.process_files(
files,
root=root,
process_func_args=process_func_args,
)

def process_index(
self,
Expand All @@ -524,6 +548,7 @@ def process_index(
preserve_index: bool = False,
root: str = None,
cache_root: str = None,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> pd.DataFrame:
r"""Extract features from an index conform to audformat_.
Expand All @@ -547,6 +572,11 @@ def process_index(
otherwise always a segmented index is returned
root: root folder to expand relative file paths
cache_root: cache folder (see description)
process_func_args: (keyword) arguments passed on
to the processing function.
They will temporarily overwrite
the ones stored in
:attr:`audinterface.Feature.process.process_func_args`
Raises:
RuntimeError: if sampling rates do not match
Expand All @@ -569,6 +599,7 @@ def process_index(
y = self.process.process_index(
index,
root=root,
process_func_args=process_func_args,
)
df = self._series_to_frame(y)

Expand All @@ -590,7 +621,8 @@ def process_signal(
file: str = None,
start: Timestamp = None,
end: Timestamp = None,
) -> pd.DataFrame:
process_func_args: typing.Dict[str, typing.Any] = None,
) -> pd.DataFrame:
r"""Extract features for an audio signal.
.. note:: If a ``file`` is given, the index of the returned frame
Expand All @@ -607,6 +639,11 @@ def process_signal(
end: end processing at this position.
If value is a float or integer it is treated as seconds.
See :func:`audinterface.utils.to_timedelta` for further options
process_func_args: (keyword) arguments passed on
to the processing function.
They will temporarily overwrite
the ones stored in
:attr:`audinterface.Feature.process.process_func_args`
Raises:
RuntimeError: if sampling rates do not match
Expand All @@ -627,6 +664,7 @@ def process_signal(
file=file,
start=start,
end=end,
process_func_args=process_func_args,
)
return self._series_to_frame(series)

Expand All @@ -635,6 +673,7 @@ def process_signal_from_index(
signal: np.ndarray,
sampling_rate: int,
index: pd.MultiIndex,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> pd.DataFrame:
r"""Split a signal into segments and extract features for each segment.
Expand All @@ -645,6 +684,11 @@ def process_signal_from_index(
named `start` and `end` that hold start and end
positions as :class:`pandas.Timedelta` objects.
See also :func:`audinterface.utils.signal_index`
process_func_args: (keyword) arguments passed on
to the processing function.
They will temporarily overwrite
the ones stored in
:attr:`audinterface.Feature.process.process_func_args`
Raises:
RuntimeError: if sampling rates do not match
Expand All @@ -658,6 +702,7 @@ def process_signal_from_index(
signal,
sampling_rate,
index,
process_func_args=process_func_args,
)
return self._series_to_frame(series)

Expand Down
Loading

0 comments on commit d10bb27

Please sign in to comment.