Skip to content

Commit

Permalink
Merge pull request #50 from otsob/additional-set-operations
Browse files Browse the repository at this point in the history
Add set diff and containment operations to PointSet2d
  • Loading branch information
otsob authored Jul 13, 2024
2 parents 5ed689f + d7bd3c1 commit 6b32389
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 17 deletions.
83 changes: 70 additions & 13 deletions musii_kit/point_set/point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,20 @@ def _read_elem_to_points(elem, measure_offset, points_and_notes, pitch_extractor

@staticmethod
def from_numpy(points_array, piece_name=None, pitch_type=None):
points = []
for i in range(len(points_array)):
row = points_array[i, :]
points.append(Point2d(row[0], row[1]))
points = PointSet2d.__array_to_point_list(points_array)

point_set = PointSet2d(points, piece_name, dtype=points_array.dtype)
point_set._pitch_type = pitch_type
return point_set

@staticmethod
def __array_to_point_list(points_array):
points = []
for i in range(len(points_array)):
row = points_array[i, :]
points.append(Point2d(row[0], row[1]))
return points

@staticmethod
def from_dict(input_dict):
piece_name = input_dict['piece_name']
Expand Down Expand Up @@ -457,6 +462,48 @@ def __or__(self, other):
all_points = [p for p in self] + [p for p in other]
return PointSet2d(all_points, self.piece_name, self._dtype)

def __contains__(self, point: Point2d):
return any((self._points[:, 0:2] == np.array([point.onset_time, point.pitch_number])).all(1))

def __repr__(self):

num_of_points_to_show = 10
points_string = (','.join([str(p) for p in self][:num_of_points_to_show])
+ (' ...' if len(self) > num_of_points_to_show else ''))

return (f'PointSet2d[len={len(self)}, {points_string}], piece={self.piece_name},'
f'dtype={self.dtype}, pitch_type={self.pitch_type}')

def __sub__(self, other):
"""
Returns a new set that is the set difference between this and other, i.e.,
a set that has all points included in this that are not included in other.
:param other: the set of points to remove from this
:return: a new set that is the set difference between this and other
"""
i = 0
j = 0

included_points = []

while i < len(self) and j < len(other):
this = self[i]
that = other[j]

if this == that:
i += 1
j += 1
elif this < that:
i += 1
included_points.append(this)
else:
j += 1

# Append all the rest of the points from self that are left
included_points += [self[p] for p in range(i, len(self))]

return self.__deep_copy_other_fields(included_points)

def get_range(self, start, end) -> List[Point2d]:
"""
Returns the points in the given time-range (inclusive) in ascending lexicographic order.
Expand All @@ -474,6 +521,24 @@ def get_range(self, start, end) -> List[Point2d]:

return points

def __deep_copy_other_fields(self, points):
copied = PointSet2d(points,
piece_name=self.piece_name,
dtype=self.dtype,
quarter_length=self.quarter_length,
measure_line_positions=deepcopy(self.measure_line_positions),
score=deepcopy(self.score),
points_to_notes=deepcopy(self._point_to_notes),
pitch_extractor=self.pitch_extractor,
# Generate new id
point_set_id=None,
has_expanded_repetitions=self.has_expanded_repetitions,
tie_continuations=deepcopy(self.__tie_continuations),
time_signatures=deepcopy(self.time_signatures))
copied._pitch_type = self._pitch_type

return copied

def time_scaled(self, factor):
"""
Returns a time-scaled copy of this point-set. The onset times are multiplied by the given factor.
Expand All @@ -486,15 +551,7 @@ def time_scaled(self, factor):
scaled_point_array[:, 0] = self._points[:, 2] * factor
scaled_point_array[:, 1] = self._points[:, 1]

scaled = PointSet2d.from_numpy(scaled_point_array, self.piece_name)
scaled.quarter_length = self.quarter_length
scaled.measure_line_positions = self.measure_line_positions
scaled._score = self.score
scaled._point_to_notes = self._point_to_notes
scaled._pitch_extractor = self._pitch_extractor
scaled._pitch_type = self._pitch_type

return scaled
return self.__deep_copy_other_fields(self.__array_to_point_list(scaled_point_array))

def get_measure(self, point):
""" Returns the number of the measure in which the point is located.
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 43 additions & 1 deletion tests/test_point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

class TestPointSet2d:
test_path = Path(os.path.dirname(os.path.realpath(__file__)))
test_points = [Point2d(1.0000001, 20.0), Point2d(1.0, 20.0), Point2d(0.0, 21.0), Point2d(2.0, 20.0),
test_points = [Point2d(1.0000001, 20.0),
Point2d(1.0, 20.0),
Point2d(0.0, 21.0),
Point2d(2.0, 20.0),
Point2d(2.0, 21.0)]

def test_given_float_points_then_correct_point_set_is_created(self):
Expand Down Expand Up @@ -186,6 +189,45 @@ def test_given_pattern_score_region_with_tolerance_is_correctly_retrieved(self):
region = point_set.get_score_region(pattern, boundaries='exclude', tolerance=1.0)
assert len(region.flatten().notes) == 6

def test_given_points_in_point_set_contains_is_true(self):
point_set = PointSet2d(self.test_points, piece_name='Test piece', dtype=float)
for point in self.test_points:
assert point in point_set

def test_given_points_not_in_point_set_contains_is_false(self):
point_set = PointSet2d(self.test_points, piece_name='Test piece', dtype=float)
assert Point2d(1.2, 20.0) not in point_set
assert Point2d(1.0, 19.0) not in point_set

def test_given_equal_point_sets_difference_is_empty(self):
ps_a = PointSet2d(self.test_points, piece_name='Test piece', dtype=float)
ps_b = PointSet2d(self.test_points, piece_name='Test piece', dtype=float)

assert len(ps_a - ps_b) == 0
assert len(ps_b - ps_a) == 0

def test_given_point_sets_with_no_common_points_difference_has_no_effect(self):
ps_a = PointSet2d(self.test_points, piece_name='Test piece', dtype=float)
ps_b = PointSet2d([Point2d(1.0, 21.0),
Point2d(0.5, 21.0),
Point2d(2.0, 24.0),
Point2d(2.0, 11.0)], piece_name='Test piece', dtype=float)

assert (ps_a - ps_b) == ps_a
assert (ps_b - ps_a) == ps_b

def test_given_point_sets_with_some_common_points_difference_is_correct(self):
ps_a = PointSet2d(self.test_points, piece_name='Test piece', dtype=float)
ps_b = PointSet2d([Point2d(0.0, 21.0),
Point2d(2.0, 20.0),
Point2d(2.0, 21.0),
Point2d(2.5, 21.0)], piece_name='Test piece', dtype=float)

expected = PointSet2d([Point2d(1.0, 20.0)],
piece_name='Difference', dtype=float)

assert (ps_a - ps_b) == expected


class TestPattern2d:
test_points = [Point2d(1.0000001, 20.0), Point2d(1.0, 20.0), Point2d(0.0, 21.0)]
Expand Down

0 comments on commit 6b32389

Please sign in to comment.