-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b5d3e47
commit 225cf6d
Showing
16 changed files
with
816 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2024 Roboto Technologies, Inc. | ||
# | ||
# This Source Code Form is subject to the terms of the Mozilla Public | ||
# License, v. 2.0. If a copy of the MPL was not distributed with this | ||
# file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
from .signal_similarity import ( | ||
Match, | ||
MatchContext, | ||
find_similar_signals, | ||
) | ||
|
||
__all__ = ("Match", "MatchContext", "find_similar_signals") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright (c) 2024 Roboto Technologies, Inc. | ||
# | ||
# This Source Code Form is subject to the terms of the Mozilla Public | ||
# License, v. 2.0. If a copy of the MPL was not distributed with this | ||
# file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
from .match import Match, MatchContext | ||
from .signal_similarity import ( | ||
find_similar_signals, | ||
) | ||
|
||
__all__ = ("Match", "MatchContext", "find_similar_signals") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) 2024 Roboto Technologies, Inc. | ||
# | ||
# This Source Code Form is subject to the terms of the Mozilla Public | ||
# License, v. 2.0. If a copy of the MPL was not distributed with this | ||
# file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
from __future__ import annotations | ||
|
||
import collections.abc | ||
import dataclasses | ||
import typing | ||
|
||
from ...domain import events | ||
from ...http import RobotoClient | ||
|
||
if typing.TYPE_CHECKING: | ||
import pandas # pants: no-infer-dep | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class MatchContext: | ||
""" | ||
Correlate a matched subsequence back to its source. | ||
""" | ||
|
||
message_paths: collections.abc.Sequence[str] | ||
topic_id: str | ||
topic_name: str | ||
|
||
dataset_id: typing.Optional[str] = None | ||
file_id: typing.Optional[str] = None | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class Match: | ||
""" | ||
A subsequence of a target signal that is similar to a query signal. | ||
""" | ||
|
||
context: MatchContext | ||
""" | ||
Correlate a matched subsequence back to its source. | ||
""" | ||
|
||
end_idx: int | ||
""" | ||
The end index in the target signal of this match. | ||
""" | ||
|
||
end_time: int | ||
""" | ||
The end time in the target signal of this match. | ||
""" | ||
|
||
distance: float | ||
""" | ||
Unitless measure of similarity between a query signal | ||
and the subsequence of the target signal this Match represents. | ||
A smaller distance relative to a larger distance indicates a "closer" match. | ||
""" | ||
|
||
start_idx: int | ||
""" | ||
The start index in the target signal of this match. | ||
""" | ||
|
||
start_time: int | ||
""" | ||
The start time in the target signal of this match. | ||
""" | ||
|
||
subsequence: pandas.DataFrame | ||
""" | ||
The subsequence of the target signal this Match represents. | ||
It is equivalent to ``target[start_idx:end_idx]``. | ||
""" | ||
|
||
def to_event( | ||
self, | ||
name: str = "Signal Similarity Match Result", | ||
caller_org_id: typing.Optional[str] = None, | ||
roboto_client: typing.Optional[RobotoClient] = None, | ||
) -> events.Event: | ||
""" | ||
Create a Roboto Platform event out of this similarity match result. | ||
""" | ||
return events.Event.create( | ||
description=f"Match score: {self.distance}", | ||
end_time=self.end_time, | ||
name=name, | ||
metadata={ | ||
"distance": self.distance, | ||
"message_paths": self.context.message_paths, | ||
"start_index": self.start_idx, | ||
"end_index": self.end_idx, | ||
}, | ||
start_time=self.start_time, | ||
topic_ids=[self.context.topic_id], | ||
caller_org_id=caller_org_id, | ||
roboto_client=roboto_client, | ||
) |
261 changes: 261 additions & 0 deletions
261
src/roboto/analytics/signal_similarity/signal_similarity.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
# Copyright (c) 2024 Roboto Technologies, Inc. | ||
# | ||
# This Source Code Form is subject to the terms of the Mozilla Public | ||
# License, v. 2.0. If a copy of the MPL was not distributed with this | ||
# file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
from __future__ import annotations | ||
|
||
import collections.abc | ||
import logging | ||
import typing | ||
|
||
import tqdm.auto | ||
|
||
from ...compat import import_optional_dependency | ||
from ...domain.topics import Topic | ||
from ...logging import default_logger | ||
from .match import Match, MatchContext | ||
|
||
if typing.TYPE_CHECKING: | ||
import numpy # pants: no-infer-dep | ||
import numpy.typing # pants: no-infer-dep | ||
import pandas # pants: no-infer-dep | ||
|
||
|
||
logger = default_logger() | ||
|
||
# Query signals must be at minimum 3 values long for results to be meaningful. | ||
MIN_QUERY_LENGTH = 3 | ||
|
||
|
||
class MatchResult(typing.NamedTuple): | ||
start_idx: int | ||
end_idx: int | ||
distance: float | ||
|
||
|
||
def _find_matches( | ||
query: numpy.typing.NDArray, | ||
target: numpy.typing.NDArray, | ||
*, | ||
max_distance: typing.Optional[float] = None, | ||
max_matches: typing.Optional[int] = None, | ||
normalize: bool = False, | ||
) -> collections.abc.Sequence[MatchResult]: | ||
""" | ||
For performing signal similarity, see :py:func:`~roboto.analytics.signal_similarity.find_similar_signals`. | ||
""" | ||
stumpy = import_optional_dependency("stumpy", "analytics") | ||
|
||
if len(query) < MIN_QUERY_LENGTH: | ||
raise ValueError( | ||
f"Query sequence must be greater than {MIN_QUERY_LENGTH} for results to be meaningful. " | ||
f"Received sequence of length {len(query)}." | ||
) | ||
|
||
if len(query) > len(target): | ||
raise ValueError("Query sequence must be shorter than target") | ||
|
||
matches: list[MatchResult] = [] | ||
for distance, start_idx in stumpy.match( | ||
query, | ||
target, | ||
max_distance=max_distance, | ||
max_matches=max_matches, | ||
normalize=normalize, | ||
): | ||
end_idx = start_idx + len(query) - 1 | ||
matches.append( | ||
MatchResult( | ||
start_idx=int(start_idx), | ||
end_idx=int(end_idx), | ||
distance=float(distance), | ||
) | ||
) | ||
|
||
return matches | ||
|
||
|
||
def _find_matches_multidimensional( | ||
query: pandas.DataFrame, | ||
target: pandas.DataFrame, | ||
*, | ||
max_distance: typing.Optional[float] = None, | ||
max_matches: typing.Optional[int] = None, | ||
normalize: bool = False, | ||
) -> collections.abc.Sequence[MatchResult]: | ||
""" | ||
For performing signal similarity, see :py:func:`~roboto.analytics.signal_similarity.find_similar_signals`. | ||
""" | ||
np = import_optional_dependency("numpy", "analytics") | ||
stumpy = import_optional_dependency("stumpy", "analytics") | ||
|
||
if len(query) < MIN_QUERY_LENGTH: | ||
raise ValueError( | ||
f"Query signal must be greater than {MIN_QUERY_LENGTH} for results to be meaningful. " | ||
f"Received DataFrame of size {query.shape}." | ||
) | ||
|
||
query_dims = set(query.columns.tolist()) | ||
target_dims = set(target.columns.tolist()) | ||
non_overlap = query_dims.difference(target_dims) | ||
if len(non_overlap): | ||
raise ValueError( | ||
"Cannot match query against target: they have non-overlapping dimensions. " | ||
f"Target signal is missing the following attributes: {non_overlap}" | ||
) | ||
|
||
# Accumulate summed distances for each subsequence (of length `query_signal`) within the target. | ||
# The distance for each subsequence starts at 0 and is incrementally updated. | ||
# Each dimension of the query signal (i.e., column in dataframe) is considered in turn: | ||
# for each dimension, compute the distance profile against the corresponding dimension in the target signal, | ||
# and then add that distance profile to the running total. | ||
# N.b.: for a target of len N and query of len M, there are a total of N - M + 1 subsequences | ||
# (the first starts at index 0, the second at index 1, ..., the last starts at index N - M) | ||
summed_distance_profile: numpy.typing.NDArray[numpy.floating] = np.zeros( | ||
len(target) - len(query) + 1 | ||
) | ||
for column in query_dims: | ||
query_sequence = query[column].to_numpy() | ||
target_sequence = target[column].to_numpy() | ||
distance_profile: numpy.typing.NDArray[numpy.floating] = stumpy.mass( | ||
query_sequence, target_sequence, normalize=normalize | ||
) | ||
summed_distance_profile += distance_profile | ||
|
||
matches: list[MatchResult] = [] | ||
for distance, start_idx in stumpy.core._find_matches( | ||
summed_distance_profile, | ||
# https://github.com/TDAmeritrade/stumpy/blob/b7b355ce4a9450357ad207dd4f04fc8e8b4db100/stumpy/motifs.py#L533C17-L533C64 | ||
excl_zone=int(np.ceil(len(query) / stumpy.core.config.STUMPY_EXCL_ZONE_DENOM)), | ||
max_distance=max_distance, | ||
max_matches=max_matches, | ||
): | ||
end_idx = start_idx + len(query) - 1 | ||
matches.append( | ||
MatchResult( | ||
start_idx=int(start_idx), | ||
end_idx=int(end_idx), | ||
distance=float(distance), | ||
) | ||
) | ||
|
||
return matches | ||
|
||
|
||
def find_similar_signals( | ||
needle: pandas.DataFrame, | ||
haystack: collections.abc.Iterable[Topic], | ||
*, | ||
max_distance: typing.Optional[float] = None, | ||
max_matches_per_target: typing.Optional[int] = None, | ||
normalize: bool = False, | ||
) -> collections.abc.Sequence[Match]: | ||
""" | ||
Find subsequences of topic data (from ``haystack``) that are similar to ``needle``. | ||
If ``needle`` is a dataframe with a single, non-index column, | ||
single-dimensional similarity search will be performed. | ||
If it instead has multiple non-index columns, multi-dimensional search will be performed. | ||
Even if there is no true similarity between the query signal and a topic's data, | ||
this will always return at least one :py:class:`~roboto.analytics.signal_similarity.Match`. | ||
Matches are expected to improve in quality as the target is more relevant to the query. | ||
Matches are returned sorted in ascending order by their distance, with the best matches (lowest distance) first. | ||
If ``max_distance`` is provided, only matches with a distance less than ``max_distance`` will be returned. | ||
Given distances computed against all comparison windows in the target, this defaults to the maximum of: | ||
1. the minimum distance | ||
2. the mean distance minus two standard deviations | ||
Use ``max_matches_per_target`` to limit the number of match results contributed by a single target. | ||
If ``normalize`` is True, values will be projected to the unit scale before matching. | ||
This is useful if you want to match windows of the target signal regardless of scale. | ||
For example, a query sequence of ``[1., 2., 3.]`` will perfectly match (distance == 0) | ||
the target ``[1000., 2000., 3000.]`` if ``normalize`` is True, | ||
but would have a distance of nearly 3800 if ``normalize`` is False. | ||
""" | ||
matches: list[Match] = [] | ||
_, cols = needle.shape | ||
|
||
if cols == 1: | ||
# Single dimensional similarity search | ||
msg_path = needle.columns[0] | ||
query_sequence = needle[msg_path].to_numpy() | ||
for topic in tqdm.auto.tqdm(iterable=haystack): | ||
match_context = MatchContext( | ||
dataset_id=topic.dataset_id, | ||
file_id=topic.file_id, | ||
message_paths=[msg_path], | ||
topic_name=topic.name, | ||
topic_id=topic.topic_id, | ||
) | ||
|
||
if logger.isEnabledFor(logging.DEBUG): | ||
tqdm.auto.tqdm.write(f"Searching for matches in {match_context!r}") | ||
|
||
topic_data = topic.get_data_as_df(message_paths_include=[msg_path]) | ||
target_signal = topic_data[msg_path].to_numpy() | ||
for match_result in _find_matches( | ||
query_sequence, | ||
target_signal, | ||
max_distance=max_distance, | ||
max_matches=max_matches_per_target, | ||
normalize=normalize, | ||
): | ||
matches.append( | ||
Match( | ||
context=match_context, | ||
end_idx=match_result.end_idx, | ||
end_time=topic_data.index[match_result.end_idx].item(), | ||
distance=match_result.distance, | ||
start_idx=match_result.start_idx, | ||
start_time=topic_data.index[match_result.start_idx].item(), | ||
subsequence=topic_data[ | ||
match_result.start_idx : match_result.end_idx + 1 | ||
], | ||
) | ||
) | ||
else: | ||
# Multi-dimensional match | ||
message_paths = needle.columns.tolist() | ||
|
||
for topic in tqdm.auto.tqdm(iterable=haystack): | ||
match_context = MatchContext( | ||
dataset_id=topic.dataset_id, | ||
file_id=topic.file_id, | ||
message_paths=message_paths, | ||
topic_name=topic.name, | ||
topic_id=topic.topic_id, | ||
) | ||
|
||
if logger.isEnabledFor(logging.DEBUG): | ||
tqdm.auto.tqdm.write(f"Searching for matches in {match_context!r}") | ||
|
||
target_signal = topic.get_data_as_df(message_paths_include=message_paths) | ||
for match_result in _find_matches_multidimensional( | ||
needle, | ||
target_signal, | ||
max_distance=max_distance, | ||
max_matches=max_matches_per_target, | ||
normalize=normalize, | ||
): | ||
matches.append( | ||
Match( | ||
context=match_context, | ||
end_idx=match_result.end_idx, | ||
end_time=target_signal.index[match_result.end_idx].item(), | ||
distance=match_result.distance, | ||
start_idx=match_result.start_idx, | ||
start_time=target_signal.index[match_result.start_idx].item(), | ||
subsequence=target_signal[ | ||
match_result.start_idx : match_result.end_idx + 1 | ||
], | ||
) | ||
) | ||
|
||
matches.sort(key=lambda match: match.distance) | ||
|
||
return matches |
Oops, something went wrong.