Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add audinterface.Segment.process_table() #172

Merged
merged 34 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
fd35a83
Initial implementation of Segment.process_table()
Apr 29, 2024
75e33bd
Merge branch 'main' into process_table
Apr 29, 2024
717fe92
fix
Apr 29, 2024
31a3734
Adding notes on usage and tests for process_table().
Apr 29, 2024
450f91e
fix
Apr 29, 2024
3b47dfe
code formatting fixed by pre-commit
May 6, 2024
f7bea4e
Fixing tests for process_table() with relative path
May 7, 2024
f4b391b
renaming test
May 7, 2024
d3c5693
Adding test for calling process_table with an index (ValueError, code…
May 7, 2024
6219510
Test assignment of labels for dataframe with a segmented index and se…
May 7, 2024
589ed6e
fix
May 7, 2024
df02e33
Fixing transfer of dtype and corresponding test
May 7, 2024
3c86379
Fixing other tests not to expect a different dtype
May 7, 2024
2980bc3
fix
May 7, 2024
79b4833
fix if processing function returns empty table
May 7, 2024
da4d433
better fix and 1D-dataframe test
May 7, 2024
187fe10
fixing dtype for empty segments test
May 7, 2024
faa9968
trying to resolve file loading issue in documentation
May 7, 2024
8bd24aa
Fixing issue for category type columns
May 7, 2024
ee4b742
Update audinterface/core/segment.py
maxschmitt May 15, 2024
351de8c
Update audinterface/core/segment.py
maxschmitt May 15, 2024
90e1da3
Update audinterface/core/segment.py
maxschmitt May 15, 2024
5a8d9e8
Update audinterface/core/segment.py
maxschmitt May 15, 2024
0d47e90
moving dtypes and adding description
May 15, 2024
fe6f86e
check error
May 15, 2024
5ef1f7e
revert
May 15, 2024
991d8a2
Update audinterface/core/segment.py
maxschmitt May 15, 2024
607888d
chaning indexes to usual convention
May 15, 2024
b6aa70c
Update docs/usage.rst
maxschmitt May 15, 2024
dfdd96e
Update tests/test_segment.py
maxschmitt May 15, 2024
f4b71d7
Update docs/usage.rst
maxschmitt May 15, 2024
4c197b9
adapt header of usage.rst to make new example work
May 15, 2024
e393dc8
Adding test for overlapping segments
May 15, 2024
b9b6ee2
Update docs/usage.rst
maxschmitt May 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions audinterface/core/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,94 @@ def process_index(

return audformat.segmented_index(files, starts, ends)

def process_table(
    self,
    table: typing.Union[pd.Series, pd.DataFrame],
    *,
    root: str = None,
    cache_root: str = None,
    process_func_args: typing.Dict[str, typing.Any] = None,
) -> typing.Union[pd.Series, pd.DataFrame]:
    r"""Segment files or segments from a table.

    Every file or segment in the index of ``table``
    is segmented,
    and the label(s) of the original row
    are assigned to each of the resulting segments.
    Rows for which the segmentation returns no segments
    do not appear in the result.
    Labels are stacked with NumPy,
    so the dtype(s) of the result
    may differ from the dtype(s) of ``table``.

    If ``cache_root`` is not ``None``,
    a hash value is created from the index
    using :func:`audformat.utils.hash` and
    the result is stored as
    ``<cache_root>/<hash>.pkl``.
    When called again with the same index,
    results will be read from the cached file.

    Args:
        table: ``pd.Series`` or ``pd.DataFrame``
            with an index conform to audformat_
        root: root folder to expand relative file paths
        cache_root: cache folder (see description)
        process_func_args: (keyword) arguments passed on
            to the processing function.
            They will temporarily overwrite
            the ones stored in
            :attr:`audinterface.Segment.process.process_func_args`

    Returns:
        Segmented table with an index conform to audformat_.
        If the index of ``table`` is empty,
        ``table`` is returned unchanged

    Raises:
        ValueError: if ``table`` is neither a ``pd.Series``
            nor a ``pd.DataFrame``
        RuntimeError: if sampling rates do not match
        RuntimeError: if channel selection is invalid

    .. _audformat: https://audeering.github.io/audformat/data-format.html

    """
    if not isinstance(table, (pd.Series, pd.DataFrame)):
        raise ValueError("table has to be pd.Series or pd.DataFrame")

    index = audformat.utils.to_segmented_index(table.index)
    utils.assert_index(index)

    if index.empty:
        return table

    y = self.process.process_index(
        index,
        preserve_index=False,
        root=root,
        cache_root=cache_root,
        process_func_args=process_func_args,
    )

    is_series = isinstance(table, pd.Series)

    files = []
    starts = []
    ends = []
    labels = []
    # ``y`` holds one entry per row of ``table`` (in the same order),
    # so ``j`` is the positional index of the row providing the labels
    for j, ((file, start, _), segments) in enumerate(y.items()):
        n_segments = len(segments)
        files.extend([file] * n_segments)
        # Segment times are relative to the start of the original segment
        starts.extend(segments.get_level_values("start") + start)
        ends.extend(segments.get_level_values("end") + start)
        row = table.iloc[j] if is_series else table.iloc[j].values
        # Keep one flat entry per new segment,
        # rows without segments contribute nothing.
        # Stacking per-row chunks instead would fail (DataFrame)
        # or change the dtype to float (Series)
        # as soon as one of the chunks is empty
        labels.extend([row] * n_segments)

    if labels:
        labels = np.hstack(labels) if is_series else np.vstack(labels)
    elif is_series:
        # No segments at all: empty float64 values
        labels = np.empty(0)
    else:
        # No segments at all: zero rows, but keep the number of columns
        labels = np.empty((0, len(table.columns)))

    index = audformat.segmented_index(files, starts, ends)

    if is_series:
        table = pd.Series(labels, index, name=table.name)
    else:
        table = pd.DataFrame(labels, index, columns=table.columns)

    return table

def process_signal(
self,
signal: np.ndarray,
Expand Down
15 changes: 15 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,21 @@ would be a voice activity detection algorithm.
idx = interface.process_file(files[0])
idx

Sometimes a table (i.e., a ``pd.Series``
or ``pd.DataFrame``) needs to be segmented
while keeping the labels of the original segments.
For this, :class:`audinterface.Segment` provides the dedicated method
:meth:`audinterface.Segment.process_table`.
It is useful if a segmentation (e.g., voice activity detection)
is applied to an already labelled dataset,
for example to do data augmentation or teacher-student training
in order to improve model performance on shorter chunks.
maxschmitt marked this conversation as resolved.
Show resolved Hide resolved

.. jupyter-execute::

df_segmented = interface.process_table(df)
df_segmented


Special processing function arguments
-------------------------------------
Expand Down
93 changes: 91 additions & 2 deletions tests/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,15 @@ def process_func(signal, sampling_rate):
path = os.path.join(root, file)
af.write(path, signal, sampling_rate)

# empty index
# empty index and table
index = audformat.segmented_index()
result = segment.process_index(index)
assert result.empty
result = segment.process_signal_from_index(signal, sampling_rate, index)
assert result.empty
table = audformat.Table(index)
result = segment.process_table(table.get())
assert result.index.empty

# segmented index without file level
index = audinterface.utils.signal_index(
Expand Down Expand Up @@ -191,6 +194,25 @@ def process_func(signal, sampling_rate):
result = segment.process_signal_from_index(signal, sampling_rate, index)
pd.testing.assert_index_equal(result, expected)

# segmented index with absolute paths: series and dataframe
table = audformat.Table(index)
table["values"] = audformat.Column()
table.set({"values": [0, 1, 2]})
expected_series = pd.Series(table.get()["values"].values,
index=expected,
name="values",
dtype=np.int64)
result = segment.process_table(table.get()["values"])
pd.testing.assert_series_equal(result, expected_series)
table_df = table.copy()
table_df["string"] = audformat.Column()
table_df.set({"string": ["a", "b", "c"]})
expected_dataframe = pd.DataFrame(table_df.get().values,
index=expected,
columns=["values", "string"])
result = segment.process_table(table_df.get())
pd.testing.assert_frame_equal(result, expected_dataframe)

# filewise index with absolute paths
index = pd.Index([path], name="file")
expected = audformat.segmented_index(path, "0.1s", "2.9s")
Expand All @@ -199,6 +221,25 @@ def process_func(signal, sampling_rate):
result = segment.process_signal_from_index(signal, sampling_rate, index)
pd.testing.assert_index_equal(result, expected)

# filewise index with absolute paths: series and dataframe
table = audformat.Table(index)
table["values"] = audformat.Column()
table.set({"values": [5]})
expected_series = pd.Series(table.get()["values"].values,
index=expected,
name="values",
dtype=np.int64)
result = segment.process_table(table.get()["values"])
pd.testing.assert_series_equal(result, expected_series)
table_df = table.copy()
table_df["string"] = audformat.Column()
table_df.set({"string": ["d"]})
expected_dataframe = pd.DataFrame(table_df.get().values,
index=expected,
columns=["values", "string"])
result = segment.process_table(table_df.get())
pd.testing.assert_frame_equal(result, expected_dataframe)

# segmented index with relative paths
index = audformat.segmented_index(
[file] * 3,
Expand All @@ -215,6 +256,25 @@ def process_func(signal, sampling_rate):
result = segment.process_signal_from_index(signal, sampling_rate, index)
pd.testing.assert_index_equal(result, expected)

# segmented index with relative paths: series and dataframe
table = audformat.Table(index)
table["values"] = audformat.Column()
table.set({"values": [0, 1, 2]})
expected_series = pd.Series(table.get()["values"].values,
index=expected,
name="values",
dtype=np.int64)
maxschmitt marked this conversation as resolved.
Show resolved Hide resolved
result = segment.process_table(table.get()["values"])
pd.testing.assert_series_equal(result, expected_series)
table_df = table.copy()
table_df["string"] = audformat.Column()
table_df.set({"string": ["a", "b", "c"]})
expected_dataframe = pd.DataFrame(table_df.get().values,
index=expected,
columns=["values", "string"])
result = segment.process_table(table_df.get())
pd.testing.assert_frame_equal(result, expected_dataframe)

# filewise index with relative paths
index = pd.Index([file], name="file")
expected = audformat.segmented_index(file, "0.1s", "2.9s")
Expand All @@ -223,7 +283,26 @@ def process_func(signal, sampling_rate):
result = segment.process_signal_from_index(signal, sampling_rate, index)
pd.testing.assert_index_equal(result, expected)

# empty index returned by process func
# filewise index with relative paths: series and dataframe
table = audformat.Table(index)
table["values"] = audformat.Column()
table.set({"values": [5]})
expected_series = pd.Series(table.get()["values"].values,
index=expected,
name="values",
dtype=np.int64)
result = segment.process_table(table.get()["values"])
pd.testing.assert_series_equal(result, expected_series)
table_df = table.copy()
table_df["string"] = audformat.Column()
table_df.set({"string": ["d"]})
expected_dataframe = pd.DataFrame(table_df.get().values,
index=expected,
columns=["values", "string"])
result = segment.process_table(table_df.get())
pd.testing.assert_frame_equal(result, expected_dataframe)

# empty index / series / dataframe returned by process func

def process_func(x, sr):
return audinterface.utils.signal_index()
Expand All @@ -241,6 +320,16 @@ def process_func(x, sr):
result = segment.process_index(index)
pd.testing.assert_index_equal(result, expected)

table = pd.Series([0], index)
expected_series = pd.Series([], expected, dtype=np.float64)
result = segment.process_table(table)
pd.testing.assert_series_equal(result, expected_series)

table_df = pd.DataFrame([0], index, columns=["col"])
expected_df = pd.DataFrame([], expected, columns=["col"], dtype=np.float64)
result = segment.process_table(table_df)
pd.testing.assert_frame_equal(result, expected_df)


@pytest.mark.parametrize(
"signal, sampling_rate, segment_func, result",
Expand Down
Loading