
Chain ‐ an immutable List with O(1) concatenation

Ben Yu edited this page Jun 12, 2024 · 3 revisions

Once upon a time

Agents Aragorn (son of Arathorn), Johnny English and Smith collaborated on a bunch of missions.

The Mission class has a signature like:

abstract class Mission {
  abstract MissionId id();
  abstract Range<LocalDate> timeWindow();
  abstract ImmutableSet<Agent> agents();
}

The goal is to create an ImmutableRangeMap<LocalDate, Set<Agent>> to account for all the agents during each time window. Note that missions can have overlapping time windows, and agents could work on multiple missions at the same time.

So for missions like:

missions = [{
  timeWindow: [10/01..10/30],
  agents: [Aragorn, English]
},
{
  timeWindow: [10/15..11/15],
  agents: [Aragorn, Smith]
}]

I want the result to be:

[10/01..10/15): [Aragorn, English]
[10/15..10/30]: [Aragorn, English, Smith]
(10/30..11/15]: [Aragorn, Smith] 

At first I thought of using the toImmutableRangeMap() collector:

missions.stream()
    .collect(toImmutableRangeMap(Mission::timeWindow, Mission::agents));

Voila, done, right?

No. My colleague pointed out that toImmutableRangeMap() does not allow overlapping ranges: it expects all input time windows to be disjoint, and building the map throws IllegalArgumentException if any two ranges overlap.

Guava RangeMap is your friend

Luckily, TreeRangeMap has a merge() method that already does the heavy lifting: it finds the overlapping ranges, splits them, and then merges the values mapped to each overlapping subrange.

With some effort, I created a toImmutableRangeMap(merger) BiCollector on top of the merge() function.

So if all I needed was to count the number of agents, I could have done:

import static ...BiStream.biStream;

ImmutableRangeMap<LocalDate, Integer> agentCounts =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .mapValues(mission -> mission.agents().size())
        .collect(toImmutableRangeMap(Integer::sum));

(It will double-count agents who work on more than one of the overlapping missions, though.)

Anyhoo, here comes the interesting part: how do I merge the Set<Agent>s?

Quadratic runtime

I could use Guava's Sets.union():

import com.google.common.collect.Sets;

ImmutableRangeMap<LocalDate, ImmutableSet<Agent>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .mapValues(mission -> mission.agents())
        .collect(toImmutableRangeMap((set1, set2) -> Sets.union(set1, set2).immutableCopy()));

The gotcha is that every merge copies both sets into a new one, which is O(n) where n is the number of agents in the two overlapping ranges. If we are unlucky, a time window is repeatedly found to overlap with yet another time window, and we keep copying the same agents over and over again. The time complexity is quadratic.

Stack overflow

Could I remove the .immutableCopy()? Sets.union() returns a view in constant time, so we should be good, right?

Well, not really. We don't know how many times merging will happen: a Set can be unioned, then unioned again, an unknown number of times. In the worst case, we'd create a union-of-union-of-union N levels deep. If N is large, we'll get a StackOverflowError when we try to access the final SetView!

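The same thing happens if, for example, I use Iterables.concat(). The Stream.concat() javadoc warns about this problem: accessing an element of a deeply concatenated stream can result in deep call chains, or even StackOverflowError.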
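To see why nesting is the problem, here is a toy model in plain Java. The View interface below is hypothetical and much simpler than Guava's SetView or Iterables.concat(): each union is O(1) to build, but answering a query walks down one level per merge, so the call depth grows linearly with the number of merges. With enough merges, that recursion exhausts the thread stack.

```java
import java.util.Set;

final class NestedUnionDemo {
  /** Hypothetical lazy set view (a stand-in for a SetView): cheap to build, queried by delegation. */
  interface View {
    /** Returns whether x is present; records the deepest call depth reached in maxDepth[0]. */
    boolean contains(int x, int depth, int[] maxDepth);
  }

  static View of(Set<Integer> set) {
    return (x, depth, maxDepth) -> {
      maxDepth[0] = Math.max(maxDepth[0], depth);
      return set.contains(x);
    };
  }

  /** O(1) to create, like a union view, but each query recurses one level into each operand. */
  static View union(View a, View b) {
    return (x, depth, maxDepth) ->
        a.contains(x, depth + 1, maxDepth) || b.contains(x, depth + 1, maxDepth);
  }

  /** Builds union(union(...union(of(0), of(1))...), of(n - 1)) and measures the query depth. */
  static int queryDepth(int n) {
    View acc = of(Set.of(0));
    for (int i = 1; i < n; i++) {
      acc = union(acc, of(Set.of(i)));
    }
    int[] maxDepth = {0};
    acc.contains(-1, 0, maxDepth); // a miss has to visit every operand
    return maxDepth[0];
  }

  public static void main(String[] args) {
    System.out.println(queryDepth(10));   // 9
    System.out.println(queryDepth(1000)); // 999
  }
}
```

Each level here costs at least one stack frame, so a few hundred thousand merges would likely be enough to overflow a default-sized stack.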

Put it in a tree

I slept on this problem for about two days before an idea came to me: could we use something like Haskell's List?

TL;DR: Haskell's List is like a LinkedList, except that it's immutable. Given the list [2, 3], you can cons the number 1 onto it to get a new list [1, 2, 3]. Under the hood, that's as simple as creating one new object whose internal tail pointer points to the old [2, 3] list, which is shared rather than copied.
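A minimal Java sketch of that idea, assuming nothing beyond the JDK (the ConsList name is hypothetical, not a Mug or Haskell library type): cons allocates exactly one node and reuses the existing list as its tail.

```java
import java.util.ArrayList;
import java.util.List;

/** A minimal immutable singly linked list, in the spirit of Haskell's list. */
final class ConsList<T> {
  final T head;
  final ConsList<T> tail; // null marks the end of the list

  private ConsList(T head, ConsList<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  /** O(1): one new node whose tail points at the existing list, which is shared, not copied. */
  static <T> ConsList<T> cons(T head, ConsList<T> tail) {
    return new ConsList<>(head, tail);
  }

  /** Copies the elements into a java.util.List by following the tail pointers. */
  List<T> toList() {
    List<T> out = new ArrayList<>();
    for (ConsList<T> node = this; node != null; node = node.tail) {
      out.add(node.head);
    }
    return out;
  }

  public static void main(String[] args) {
    ConsList<Integer> twoThree = cons(2, cons(3, null));
    ConsList<Integer> oneTwoThree = cons(1, twoThree);
    System.out.println(oneTwoThree.toList());         // [1, 2, 3]
    System.out.println(twoThree.toList());            // [2, 3] (the old list is unchanged)
    System.out.println(oneTwoThree.tail == twoThree); // true: the tail is shared
  }
}
```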

If I can do this, each merge only costs O(1). The resulting object is probably less efficient for random access than ArrayList or Guava's ImmutableList because of all the pointers and indirection. But that's okay: when the whole split-merge process is done, I can perform a final O(n) copy into an ImmutableList.

The only problem? Haskell's cons only adds one element at a time, while I have two List<Agent>s to concatenate. (I can't cons every element of one list onto the other, because that takes me right back to quadratic.)

To support concat(list1, list2), I decided to use a binary tree to represent the List's state:

private static final class Tree<T> {
  final T mid;
  @Nullable final Tree<T> left;   // null means empty
  @Nullable final Tree<T> right;  // null means empty

  Tree(T mid, @Nullable Tree<T> left, @Nullable Tree<T> right) {
    this.mid = mid;
    this.left = left;
    this.right = right;
  }
}

In the list, the elements in left show up first, followed by mid, then followed by the elements in right. In other words, an in-order traversal will give us back the list.

The key trick is to figure out how to concatenate two binary trees into one. Intuitively, I need to find the new "mid point" value, which can be either the left tree's last element or the right tree's first element. Say I take the right tree's first element: then the new tree's left remains the old left, while the new tree's right needs to be the old right with its first element removed.

But the Tree is immutable, so how do I remove an element? And finding the first element of this binary tree can take up to O(n) time, because it's not a balanced tree.

It turns out there's a law in computer science:

All problems in computer science can be solved by another level of indirection

In human language: if a problem can't be solved with one layer of indirection, add a second layer of indirection. :)

Here is my second layer of indirection, which takes care of the "remove the first element from an immutable list" task:

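public final class Chain<T> {
  private final T head;
  @Nullable private final Tree<T> tail;  // elements after head, in in-order sequence

  private Chain(T head, @Nullable Tree<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  public static <T> Chain<T> of(T value) {
    return new Chain<>(value, null);
  }

  public static <T> Chain<T> concat(Chain<T> left, Chain<T> right) {
    // left.head stays first; right.head becomes the mid of a new Tree node,
    // with left's tail on its left and right's tail on its right. O(1).
    T newHead = left.head;
    Tree<T> newTail = new Tree<>(right.head, left.tail, right.tail);
    return new Chain<>(newHead, newTail);
  }
}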

It takes a bit of brain gymnastics. But if you sit down and think about it for a minute, it's actually pretty straightforward.
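One way to convince yourself that concat() preserves order is to run it. The sketch below pairs the Tree and Chain classes from above with a hypothetical toList() helper (not part of Mug) that emits the head and then walks the tail in order. The helper uses plain recursion, so it is only suitable for shallow demo chains.

```java
import java.util.ArrayList;
import java.util.List;

final class Chain<T> {
  private final T head;
  private final Tree<T> tail; // null means no elements after head

  private Chain(T head, Tree<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  static <T> Chain<T> of(T value) {
    return new Chain<>(value, null);
  }

  /** O(1): left.head stays first; right.head becomes the mid of a new Tree node. */
  static <T> Chain<T> concat(Chain<T> left, Chain<T> right) {
    return new Chain<>(left.head, new Tree<>(right.head, left.tail, right.tail));
  }

  /** Hypothetical helper: head, then an in-order walk of tail. Recursive, so demo-only. */
  List<T> toList() {
    List<T> out = new ArrayList<>();
    out.add(head);
    inOrder(tail, out);
    return out;
  }

  private static <T> void inOrder(Tree<T> tree, List<T> out) {
    if (tree == null) return;
    inOrder(tree.left, out);
    out.add(tree.mid);
    inOrder(tree.right, out);
  }

  private static final class Tree<T> {
    final T mid;
    final Tree<T> left;  // null means empty
    final Tree<T> right; // null means empty

    Tree(T mid, Tree<T> left, Tree<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }
  }

  public static void main(String[] args) {
    Chain<Integer> ab = concat(of(1), of(2));
    Chain<Integer> cd = concat(of(3), of(4));
    System.out.println(concat(ab, cd).toList()); // [1, 2, 3, 4]
    System.out.println(concat(cd, ab).toList()); // [3, 4, 1, 2]
  }
}
```

For concat(ab, cd), the result's head is 1 and its tail is a Tree with mid 3, left [2] and right [4], which reads back as 1, 2, 3, 4.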

This gives us O(1) concatenation. And the nice part is that no matter how deeply concat() calls are nested, the result is always a single Chain object with a heap-allocated Tree.

Now we just need to make sure that accessing the Chain takes no more than O(n) time and constant stack space.

Converting tree to List

My secret weapon is Walker.inBinaryTree() from Mug. It already does everything I need:

  1. O(n) time in-order traversal.
  2. Constant stack space.
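Walker takes care of this for us, but the underlying technique can be sketched with only the JDK: keep pending nodes on a heap-allocated ArrayDeque instead of recursing, and expose the traversal as a lazy Stream. This is an illustration of the general explicit-stack technique, not Walker's actual implementation; the InOrder and Node names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

final class InOrder {
  /** Binary tree node; null children mean empty subtrees. */
  static final class Node<T> {
    final T mid;
    final Node<T> left;
    final Node<T> right;

    Node(T mid, Node<T> left, Node<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }
  }

  /**
   * Lazily streams the tree in order. Pending nodes live on a heap-allocated deque, so the
   * call stack stays constant however deep the tree is; the deque may grow to O(n), and
   * each node is pushed and popped exactly once, giving O(n) total time.
   */
  static <T> Stream<T> stream(Node<T> root) {
    Iterator<T> it = new Iterator<T>() {
      private final Deque<Node<T>> pending = new ArrayDeque<>();
      { pushLeftSpine(root); }

      private void pushLeftSpine(Node<T> node) {
        for (; node != null; node = node.left) pending.push(node);
      }

      @Override public boolean hasNext() { return !pending.isEmpty(); }

      @Override public T next() {
        if (pending.isEmpty()) throw new NoSuchElementException();
        Node<T> node = pending.pop();
        pushLeftSpine(node.right); // the right subtree comes after mid
        return node.mid;
      }
    };
    return StreamSupport.stream(Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false);
  }

  public static void main(String[] args) {
    // A left-degenerate tree 200,000 levels deep; a naive recursive walk would likely overflow.
    Node<Integer> root = null;
    for (int i = 0; i < 200_000; i++) root = new Node<>(i, root, null);
    System.out.println(stream(root).count());           // 200000
    System.out.println(stream(root).findFirst().get()); // 0 (the deepest left node comes first)
  }
}
```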

Using it is pretty simple. First we add a stream() method to the Tree class:

private static final class Tree<T> {
  ...

  Stream<T> stream() {
    return Walker.<Tree<T>>inBinaryTree(t -> t.left, t -> t.right)
        .inOrderFrom(this)
        .map(t -> t.mid);
  }
}

The inOrderFrom() method returns a lazy stream, which in the worst case takes O(n) heap space and constant stack space.

Then we wrap it up and polish it in the Chain class:

public final class Chain<T> {
  ...

  /**
   * Returns a <em>lazy</em> stream of the elements in this list.
   * The returned stream is lazy in that concatenated chains aren't consumed until the stream
   * reaches their elements.
   */
  public Stream<T> stream() {
    return tail == null
        ? Stream.of(head)
        : Stream.concat(Stream.of(head), tail.stream());
  }
}

This gives me O(n)-time read access to the tree, and I can easily collect() it into an ImmutableList.

In the actual implementation, I also made Chain implement List to make it nicer to use, and used lazy initialization so that the flattening cost is paid only once. But that's just some extra API makeup; the meat is all here.
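The List adapter itself isn't shown here, so the following is only a hypothetical JDK-only sketch of the "implement List, flatten once" idea: extend AbstractList and materialize the elements on first access. LazyList and its Supplier<Stream<T>> parameter are illustrative assumptions, not Mug's Chain code.

```java
import java.util.AbstractList;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Read-only List that flattens its source on first access, then reuses the result. */
final class LazyList<T> extends AbstractList<T> {
  private final Supplier<Stream<T>> source;
  private List<T> materialized; // null until first access; not synchronized (single-threaded sketch)

  LazyList(Supplier<Stream<T>> source) {
    this.source = source;
  }

  private List<T> elements() {
    if (materialized == null) {
      // The O(n) copy is paid once; later get()/size() calls are O(1).
      materialized = source.get().collect(Collectors.toUnmodifiableList());
    }
    return materialized;
  }

  @Override public T get(int index) { return elements().get(index); }

  @Override public int size() { return elements().size(); }

  public static void main(String[] args) {
    int[] flattenCount = {0};
    List<String> list = new LazyList<>(() -> {
      flattenCount[0]++;
      return Stream.of("Aragorn", "English", "Smith");
    });
    System.out.println(list.size() + " " + list.get(2) + " " + list); // 3 Smith [Aragorn, English, Smith]
    System.out.println(flattenCount[0]); // 1
  }
}
```

In a Chain, the supplier would be something like the stream() method above; extending AbstractList means only get() and size() need to be written.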

A bit of googling shows that people have run into similar needs, but I didn't find an existing implementation that addresses both O(1) concatenation and the stack overflow concern.

Putting it together

So to build the RangeMap, we can first wrap each Mission in a Chain, let the merge process run, and finally flatten the merged mission chain:

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.mu.util.stream.BiStream.biStream;
import static com.google.mu.util.stream.GuavaCollectors.toDisjointRanges;

import com.google.common.collect.ImmutableRangeMap;
import com.google.common.collect.ImmutableSet;
import java.time.LocalDate;

ImmutableRangeMap<LocalDate, ImmutableSet<Agent>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .mapValues(Chain::of)
        .collect(toDisjointRanges(Chain::concat))  // BiStream<Range<LocalDate>, Chain<Mission>>
        .mapValues(
            chain -> chain.stream()
                .flatMap(mission -> mission.agents().stream())
                .collect(toImmutableSet()))
        .collect(ImmutableRangeMap::toImmutableRangeMap);

It still feels a bit verbose to first wrap each Mission in a Chain and then unwrap them at the end. So I created another toDisjointRanges() overload to hide some of the implementation details. The above code then simplifies to:

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.mu.util.stream.BiStream.biStream;
import static com.google.mu.util.stream.GuavaCollectors.toDisjointRanges;

ImmutableRangeMap<LocalDate, ImmutableSet<Agent>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .collect(toDisjointRanges())  // BiStream<Range<LocalDate>, Chain<Mission>>
        .mapValues(
            chain -> chain.stream()
                .flatMap(mission -> mission.agents().stream())
                .collect(toImmutableSet()))
        .collect(ImmutableRangeMap::toImmutableRangeMap);

How generic is this API?

To be honest, I simplified the use case a bit to make it easier to explain. The original motivating use case also involves nested Maps.

Assume that instead of an ImmutableRangeMap<LocalDate, Set<Agent>>, which gives no information about which mission(s) each agent worked on at the time, we want an ImmutableRangeMap<LocalDate, SetMultimap<Agent, MissionId>>, where for each time window we get both the agents and the missions they worked on.

We can change the above code to turn a Mission into a stream of mappings from Agent to MissionId (if you know the BiStream API):

import static com.google.mu.util.stream.BiStream.biStream;
import static com.google.mu.util.stream.GuavaCollectors.toDisjointRanges;
import static com.google.mu.util.stream.GuavaCollectors.toImmutableSetMultimap;

ImmutableRangeMap<LocalDate, ImmutableSetMultimap<Agent, MissionId>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .collect(toDisjointRanges())  // BiStream<Range<LocalDate>, Chain<Mission>>
        .mapValues(
            chain -> biStream(chain)  // Chain implements List, so it can feed biStream()
                .flatMapKeys(mission -> mission.agents().stream())
                .mapValues(Mission::id)
                .collect(toImmutableSetMultimap()))
        .collect(ImmutableRangeMap::toImmutableRangeMap);

About that quadratic thing

Before we conclude, there is one more gotcha: even with all that effort, you can still end up with quadratic range merging.

Consider this:

{
  [0..1]: a,
  [0..2]: b,
  [0..3]: c,
  [0..4]: d,
  [0..5]: e,
  ...
}

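When TreeRangeMap merges them in the given order, each new range overlaps every subrange created so far: [0..2] merges with 1 subrange, [0..3] with 2, [0..4] with 3, and so on. The total number of merges is 1 + 2 + 3 + ... + (n - 1), which is O(n²). There is not much the Chain class can do to help, because the merging is managed and triggered by TreeRangeMap itself.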

To avoid the quadratic merging, perhaps it's prudent to first sort the ranges by length in descending order. For the five ranges shown above, merging in descending order of length cuts the number of merges to 4: each shorter range overlaps only one existing subrange, so n ranges take n - 1 merges, i.e. O(n).
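Those counts can be sanity-checked with a simplified model in plain Java. It does not use Guava: the segment bookkeeping below is a hypothetical stand-in that mimics how a RangeMap splits overlapping ranges, counting one merge per existing segment that an inserted range overlaps. Ranges are half-open int intervals [0, k) rather than dates.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

final class MergeCountDemo {
  /** Inserts half-open ranges [lo, hi) into disjoint segments; returns the number of merges. */
  static long countMerges(List<int[]> ranges) {
    TreeMap<Integer, Integer> segments = new TreeMap<>(); // start -> exclusive end
    long merges = 0;
    for (int[] range : ranges) {
      int lo = range[0], hi = range[1];
      // At most one segment starting before lo can overlap [lo, hi).
      Map.Entry<Integer, Integer> before = segments.lowerEntry(lo);
      int scanFrom = (before != null && before.getValue() > lo) ? before.getKey() : lo;
      for (int[] seg : snapshot(segments, scanFrom, hi)) {
        segments.remove(seg[0]);
        if (seg[0] < lo) segments.put(seg[0], lo);                // part before the new range
        segments.put(Math.max(seg[0], lo), Math.min(seg[1], hi)); // merged intersection
        if (seg[1] > hi) segments.put(hi, seg[1]);                // part after the new range
        merges++;
      }
      // Gaps inside [lo, hi) that no existing segment covered just get the new value.
      int cursor = lo;
      for (int[] seg : snapshot(segments, lo, hi)) {
        if (seg[0] > cursor) segments.put(cursor, seg[0]);
        cursor = seg[1];
      }
      if (cursor < hi) segments.put(cursor, hi);
    }
    return merges;
  }

  /** Copies the segments whose start lies in [from, to) so the map can be modified safely. */
  private static List<int[]> snapshot(TreeMap<Integer, Integer> segments, int from, int to) {
    List<int[]> out = new ArrayList<>();
    for (Map.Entry<Integer, Integer> e : segments.subMap(from, true, to, false).entrySet()) {
      out.add(new int[] {e.getKey(), e.getValue()});
    }
    return out;
  }

  public static void main(String[] args) {
    int n = 1000;
    List<int[]> ranges = new ArrayList<>();
    for (int k = 1; k <= n; k++) ranges.add(new int[] {0, k}); // [0, 1), [0, 2), ...
    System.out.println(countMerges(ranges)); // 499500 = n(n - 1) / 2

    List<int[]> longestFirst = new ArrayList<>(ranges);
    longestFirst.sort(Comparator.comparingInt((int[] r) -> r[1] - r[0]).reversed());
    System.out.println(countMerges(longestFirst)); // 999 = n - 1
  }
}
```

Longest-first helps this nested pattern, but it is not a general guarantee: many equal-length ranges that are each shifted by one (so sorting changes nothing) still produce a quadratic number of merges in this model.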