Wednesday, February 24, 2016

FlatMap (part 1)

Introduction

In this blog post, I begin to explain the outer and inner properties of the most used, most misunderstood and, at the same time, one of the most complex operators there is: flatMap.

FlatMap is most useful because it lets you replace simple values with something that can change the output in terms of time, location and value count. FlatMap is misunderstood because it is introduced late, not enough time is spent demonstrating it, and it is often surrounded with functional programming technobabble. Finally, it's complex because it has to coordinate the backpressure of a single consumer with requests to multiple sources, and we usually don't know which of them will respond with actual items. Maybe all of them.

FlatMap has a companion operator: merge. Merge lets you flatten a sequence of Observables into a single stream of values while ensuring the contract of the Observer, namely, the requirement of non-concurrent invocation of the onXXX methods and conformance to the onNext* (onError|onCompleted)? protocol. This is necessary because, although the individual Observables you merge each conform to this protocol, their events get mixed in time, location and number when you listen to all of them at once. Of course, flatMap has to do the same, so why are there two operators?

The answer is convenience and usage pattern. FlatMap is an in-sequence operator that reacts to values from the upstream by generating an Observable through a callback function; that Observable is internally subscribed to, coordinated and serialized with respect to any previous or subsequent Observables generated through the same callback function. Merge, on the other hand, works on a two-dimensional sequence: an Observable of Observables. There is no function involved here, but the operator has to subscribe to all of the inner Observables emitted by the outer Observable.

The fun thing is, you can express each of them in terms of the other:

Func1<T, Observable<R>> f = ...
source.flatMap(f) == Observable.merge(source.map(f))


Observable<Observable<R>> o = ...
Observable.merge(o) == o.flatMap(v -> v)

In the first case, flatMap can be expressed by mapping the values of the source onto Observable<R>s and merging them together. If you look at the RxJava 1.x source code, you'll see that flatMap is implemented in terms of merge in exactly this way. In the second case, given the two-dimensional sequence, each value v of the outer Observable is already an Observable<R>, so when we flatMap over them, we can simply return v as is.
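To see the first identity without any asynchrony involved, here is a small sketch that uses java.util.stream instead of RxJava, an assumption made only so the example runs on a plain JDK. Streams are synchronous and ordered, so this demonstrates the shape of the identity, not flatMap's interleaving behavior; the class and method names are made up for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Synchronous analogy with java.util.stream (not RxJava).
final class FlatMapIdentity {

    // source.flatMap(f)
    static List<Integer> viaFlatMap(List<Integer> source,
            Function<Integer, Stream<Integer>> f) {
        return source.stream().flatMap(f).collect(Collectors.toList());
    }

    // "merge(source.map(f))": map to a stream of streams, then flatten with the identity function
    static List<Integer> viaMapThenFlatten(List<Integer> source,
            Function<Integer, Stream<Integer>> f) {
        Stream<Stream<Integer>> nested = source.stream().map(f);
        return nested.flatMap(s -> s).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> source = Arrays.asList(1, 2, 3);
        Function<Integer, Stream<Integer>> f = v -> Stream.of(v, v * 10);
        System.out.println(viaFlatMap(source, f));        // prints [1, 10, 2, 20, 3, 30]
        System.out.println(viaMapThenFlatten(source, f)); // prints [1, 10, 2, 20, 3, 30]
    }
}
```

The flatten step uses `s -> s`, mirroring the second identity, where flatMap with the identity function acts as merge.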

You can think of flatMap as the join part of a fork-join operation, when all threads come together again to form a single sequence. However, there are no guarantees on when this coming together happens.

For example, given a sequence of product IDs, you may want to fire off network calls (conveniently available through Retrofit-backed services) that return some additional information for each of them. We know that networks and databases can respond unpredictably, and the calls for later IDs may come back earlier than those for earlier IDs. Now you have the responses in arbitrary order. Sometimes the order doesn't matter, sometimes it does.

But no matter how many times this lack of ordering guarantee is mentioned with flatMap, it still tends to pop up in questions everywhere. But why?

The reason is how flatMap is introduced to the reader: by showing a completely sequential example, such as mapping a range of values onto subranges:


Observable.range(1, 10).flatMap(v -> Observable.range(v, 2))
.subscribe(System.out::println);


This is a completely synchronous use and the output is nicely ordered. Then you see the same example written with concatMap, which does keep the order, and you receive the same output. So instead, let's introduce some asynchrony into the example, a fairly obvious and simple one, to show that the order of the input may not hold all the way through:


Observable.range(1, 10)
.flatMap(v -> Observable.just(v).delay(11 - v, TimeUnit.SECONDS))
.toBlocking()
.subscribe(System.out::println);


What happens here is that we map the individual integers onto a delayed scalar Observable whose delay gets smaller as the value gets bigger. In the end, the output is a sequence of decreasing values, the complete reverse of the original input range.
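The reordering can be reproduced without RxJava as well. The sketch below is a JDK-only analogy, not how delay() is implemented: a single-threaded ScheduledExecutorService stands in for the delayed scalar sources, and the hypothetical DelayedReorder.run() waits for all of them, similar to what toBlocking() does. The time unit is a parameter so the demo doesn't have to wait 10 seconds.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// JDK-only analogy (not RxJava): each value v is "emitted" after (11 - v) time units,
// so bigger values arrive first and the output order is reversed.
final class DelayedReorder {

    static List<Integer> run(long unitMillis) {
        List<Integer> out = Collections.synchronizedList(new ArrayList<>());
        CountDownLatch done = new CountDownLatch(10);
        ScheduledExecutorService timer = Executors.newSingleThreadScheduledExecutor();
        try {
            for (int v = 1; v <= 10; v++) {
                final int value = v;
                // plays the role of Observable.just(v).delay(11 - v, ...)
                timer.schedule(() -> {
                    out.add(value);
                    done.countDown();
                }, (11 - v) * unitMillis, TimeUnit.MILLISECONDS);
            }
            done.await(); // block until every delayed "inner source" has emitted
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        } finally {
            timer.shutdown();
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(run(100)); // prints [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
    }
}
```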

FlatMap also plays a big role in asynchronous continuations. That is, when some asynchronous computation or network retrieval completes, one wishes to resume with some other asynchronous computation based on the single value returned by the first. The emphasis is on single here. To my current knowledge, no reactive network library streams multiple values at this time; they give you a single result at once (which can be a List of all values, but it is still just a single object). Thus, when one encounters flatMap in such logic, one rarely experiences its property of running Observables in parallel.
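On a plain JDK, the closest analogue of such a single-valued continuation is CompletableFuture.thenCompose. The sketch below assumes two made-up remote calls, fetchUserName and fetchNameLength, simulated with supplyAsync, and chains the second onto the single result of the first, much like a flatMap over single-valued Observables.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// JDK analogy (not RxJava): thenCompose resumes with another async computation
// based on the single value produced by the first one.
final class AsyncContinuation {

    // hypothetical stand-ins for two dependent network calls
    static CompletableFuture<String> fetchUserName(int id, ExecutorService io) {
        return CompletableFuture.supplyAsync(() -> "user-" + id, io);
    }

    static CompletableFuture<Integer> fetchNameLength(String name, ExecutorService io) {
        return CompletableFuture.supplyAsync(name::length, io);
    }

    static int lookup(int id) {
        ExecutorService io = Executors.newFixedThreadPool(2);
        try {
            return fetchUserName(id, io)
                    .thenCompose(name -> fetchNameLength(name, io)) // the "flatMap" step
                    .join();                                        // wait for the final result
        } finally {
            io.shutdown();
        }
    }
}
```

For example, lookup(42) resolves "user-42" first and then its length, 7.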

The final property of flatMap is its ability to change the number of items that get emitted to the downstream. Given the proper callback function, you can make the operator emit nothing for a single input, exactly one value, multiple values or even an error. All you have to do is return empty(), just(), some chain of Observables or error() from that callback.
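The count-changing property can be mimicked with java.util.stream.Stream.flatMap as a synchronous stand-in: returning Stream.empty(), a single-element stream or a multi-element stream drops, keeps or multiplies an item. Streams have no counterpart of error() as a value, so that case is left out; the class name and the drop/keep/duplicate rule are arbitrary choices for illustration.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

final class ChangeCount {

    // multiples of 3 are dropped, v % 3 == 1 is kept once, v % 3 == 2 is emitted twice
    static final Function<Integer, Stream<Integer>> RULE = v -> {
        if (v % 3 == 0) {
            return Stream.empty();    // like returning Observable.empty()
        }
        if (v % 3 == 1) {
            return Stream.of(v);      // like Observable.just(v)
        }
        return Stream.of(v, v);       // like a multi-valued Observable
    };

    static List<Integer> expand(int from, int to) {
        return IntStream.rangeClosed(from, to).boxed()
                .flatMap(RULE)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(expand(1, 6)); // prints [1, 2, 2, 4, 5, 5]
    }
}
```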

It is often asked what one should do when, in a regular map() operation, one would like to throw an error instead of returning a value. If that exception is a RuntimeException, you can throw it directly and RxJava will turn it into an onError for you inside the map operator. However, if you have a checked Exception, such as IOException, you are out of luck with map(). Either you wrap it into a RuntimeException, but then you have to unwrap it somewhere else, or you write a custom map() operator whose callback function lets you throw checked exceptions.
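To make the wrap/unwrap dance concrete, here is a JDK-only sketch that uses java.util.stream's map() in place of RxJava's: the hypothetical squareOrFail() declares IOException, so the lambda has to wrap it into an UncheckedIOException, and the caller has to unwrap it again.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.stream.Collectors;

// JDK-only illustration (java.util.stream, not RxJava) of what a checked
// exception forces on a map()-style callback.
final class CheckedInMap {

    // hypothetical operation that declares a checked exception
    static int squareOrFail(int v) throws IOException {
        if (v >= 5) {
            throw new IOException("Why not?!");
        }
        return v * v;
    }

    static List<Integer> squares(List<Integer> source) throws IOException {
        try {
            return source.stream()
                    .map(v -> {
                        try {
                            return squareOrFail(v);
                        } catch (IOException ex) {
                            throw new UncheckedIOException(ex); // wrap here...
                        }
                    })
                    .collect(Collectors.toList());
        } catch (UncheckedIOException ex) {
            throw ex.getCause();                                // ...and unwrap there
        }
    }
}
```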

The alternative is to use flatMap because you can return an error() Observable which then directly emits your error as onError without the need of wrapping:

Observable.range(1, 10)
.flatMap(v -> {
    if (v < 5) {
        return Observable.just(v * v);
    }
    return Observable.<Integer>error(new IOException("Why not?!"));
})
.subscribe(System.out::println, Throwable::printStackTrace);

Sometimes, you flatMap Observables that themselves may signal an error for some reason. The default RxJava behavior is that whenever an onError situation is encountered, everything is torn down immediately and the stream is terminated with that specific error. The problem with this is that a failing source can waste the ongoing work of the other sources, but at the same time, you don't want to suppress that failing source with one of the onErrorXXX operators, likely because you may not know upfront which of them is going to fail.

The solution is to delay the error until all sources have terminated and report the error(s) at the very end. This allows us to apply a "global" error handler on the output of flatMap and still work with all "successfully received" values.

Therefore, flatMap has an overload which takes a boolean delayErrors parameter just after the function callback. The operator merge has a different method name for the same behavior: mergeDelayError.
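The delay-error idea can be sketched in a few lines of plain Java. This is a simplified, synchronous model under stated assumptions: each hypothetical source is a Supplier that either returns all of its values or throws, and Throwable.addSuppressed stands in for the way RxJava combines multiple errors (it uses a CompositeException for that).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

// Simplified, synchronous model of delayErrors: keep consuming every source,
// collect the successful values, and surface the error(s) only at the end.
final class DelayErrors {

    static final class Outcome {
        final List<Integer> values = new ArrayList<>();
        RuntimeException error; // null if every source succeeded
    }

    static Outcome mergeDelayError(List<Supplier<List<Integer>>> sources) {
        Outcome out = new Outcome();
        for (Supplier<List<Integer>> source : sources) {
            try {
                out.values.addAll(source.get());
            } catch (RuntimeException ex) {
                if (out.error == null) {
                    out.error = ex;              // remember the first failure
                } else {
                    out.error.addSuppressed(ex); // attach later failures to it
                }
            }
        }
        return out;
    }
}
```

The caller can then apply one "global" error handler to `error` while still using every value in `values`.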

Backpressure, the way of preventing buffer-bloat with reactive flows, is a cornerstone of RxJava. Most operators that don't have timing aspects in them apply and honor backpressure. Unfortunately, flatMap by default can only say it honors backpressure but doesn't apply it towards its main input.

This means that when you use the common overload of flatMap or merge, the operator will request Long.MAX_VALUE from its upstream and realize it all at once. This unbounded behavior leads to an unbounded number of active subscriptions to the generated inner Observables.

This property doesn't cause much trouble if the inner Observables are short or infrequently emitting, but if there is an asynchronous boundary after flatMap, let's say, observeOn, items can quite easily pile up in flatMap and degrade performance considerably.

Technically, as we will see, there is nothing preventing flatMap from applying the same backpressure to its input. However, historically, some developers coming from Rx.NET relied on its lack of backpressure and happily merged thousands of Observables at once, relying on the fact that they are all merged live. Thus, the default unbounded behavior stuck with RxJava.

However, many recognized that merging thousands of hot Observables doesn't compare to merging thousands of cold, networking Observables, and that there needs to be a way to limit the number of active Observables. Thus, flatMap and merge have overloads that take a maxConcurrency parameter to make this limit happen. The fun fact about this property is that it's much easier to implement in a backpressured environment than in the non-backpressured Rx.NET world.
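To illustrate what the limit buys, here is a JDK-only sketch that is not how RxJava implements it: a Semaphore with maxConcurrency permits stands in for the request accounting. A backpressure-based operator can get a similar effect by requesting maxConcurrency items from upstream and then one more whenever an inner source completes. Here, the main thread acquires a permit before starting each simulated inner source, the source releases it when it finishes, and the hypothetical run() returns the peak number of simultaneously active sources.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// A Semaphore stands in for "request maxConcurrency, then 1 more per completed inner".
final class BoundedMerge {

    static int run(int innerCount, int maxConcurrency) {
        Semaphore slots = new Semaphore(maxConcurrency);
        AtomicInteger active = new AtomicInteger();
        AtomicInteger peak = new AtomicInteger();
        ExecutorService pool = Executors.newCachedThreadPool();
        try {
            for (int i = 0; i < innerCount; i++) {
                slots.acquireUninterruptibly();       // wait until an inner slot is free
                pool.execute(() -> {
                    int now = active.incrementAndGet();
                    peak.accumulateAndGet(now, Math::max);
                    try {
                        Thread.sleep(5);              // simulated network call
                    } catch (InterruptedException ex) {
                        Thread.currentThread().interrupt();
                    } finally {
                        active.decrementAndGet();
                        slots.release();              // inner completed: allow one more
                    }
                });
            }
        } finally {
            pool.shutdown();
        }
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
        return peak.get();
    }
}
```

Without the permits, all innerCount sources could be active at once, which is the unbounded default described above.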


Implementing FlatMap

Given the conversion between flatMap and merge above, one may ask which operator we should implement. Clearly, merge() doesn't have to deal with a mapping function, so why not implement that one, like RxJava itself does?

The answer is: allocation. If you implement flatMap in terms of merge, you have to use two operators: merge and map. When the sequences are assembled, the application of an operator incurs allocation cost. This has to happen because operators usually hold some state: parameters, function callbacks, etc. Having more operators means having more assembly allocation, more garbage and more GC churn, especially if the sequence is short lived.

(Things got worse a bit due to a convenience decision made some time ago in RxJava: the introduction of the lift() operator and its ubiquitous usage inside the standard operators. So in total, applying operators may incur allocating 6-10 objects whereas the theoretical minimum should be 1-2.)

FlatMap has to serialize events coming from all of the active sources, so the first building block that comes to mind is SerializedSubscriber. How easy it would be to just subscribe (or route to) an instance of it and have all things nicely serialized!

Unfortunately, that doesn't work.

First of all, subscribing the same, stateful instance of SerializedSubscriber to multiple sources is a bad idea. We can't really control the request this way, plus, different sources may and often will set a Producer on their Subscribers, thus given a single instance of SerializedSubscriber, they will overwrite each other's Producer.

Second, we need a way to get rid of completed sources and not retain them indefinitely. Since there is no Subscriber.remove() to complement Subscriber.add(), even if we instantiate multiple Subscribers that forward events to the same underlying SerializedSubscriber, we have to do some CompositeSubscription juggling to get the cleanup or downstream's unsubscription working.

Lastly, SerializedSubscriber can block due to its use of a synchronized block. Blocking gives some "natural" backpressure but, at the same time, hinders progress. If one runs within an asynchronous requirement or environment, there is a great incentive to avoid blocking as much as possible.

Therefore, the tool to avoid blocking and get serialized output is to use the familiar queue-drain approach. So let's start by sketching out the skeleton of our flatMap operator:


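// imports used by OpFlatMap and its nested classes (RxJava 1.x)
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import rx.Observable;
import rx.Observable.Operator;
import rx.Producer;
import rx.Subscriber;
import rx.exceptions.Exceptions;
import rx.functions.Func1;
import rx.internal.operators.BackpressureUtils;
import rx.internal.operators.NotificationLite;
import rx.plugins.RxJavaPlugins;
import rx.subscriptions.CompositeSubscription;

public final class OpFlatMap<T, R> implements Operator<R, T> {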

    final Func1<? super T, ? extends Observable<? extends R>> mapper;
    
    final int prefetch;

    public OpFlatMap(Func1<? super T, ? extends Observable<? extends R>> mapper,
            int prefetch) {
        this.mapper = mapper;
        this.prefetch = prefetch;
    }
    
    @Override
    public Subscriber<? super T> call(Subscriber<? super R> t) {
        FlatMapSubscriber<T, R> parent = new FlatMapSubscriber<>(t, mapper, prefetch);
        parent.init();
        return parent;
    }
}

The operator takes a mapper callback function that will generate the inner Observables and a prefetch amount that tells how many items to request from each of these inner Observables. We will hand these and the incoming child subscriber over to the parent Subscriber we create. For convenience, the setting up of the unsubscription chain and backpressure is hidden inside the init() method.

Next comes the implementation of the FlatMapSubscriber that does the coordination and value collection:


static final class FlatMapSubscriber<T, R> extends Subscriber<T> {
    final Subscriber<? super R> actual;
    
    final Func1<? super T, ? extends Observable<? extends R>> mapper;
    
    final int prefetch;                                             // (1)

    final CompositeSubscription csub;                               // (2)
    
    final AtomicInteger wip;                                        // (3)
    
    final Queue<Object> queue;                                      // (4)
    
    final AtomicLong requested;                                     // (5)

    final AtomicInteger active;                                     // (6)
    
    final AtomicReference<Throwable> error;                         // (7)
    
    public FlatMapSubscriber(Subscriber<? super R> actual,
            Func1<? super T, ? extends Observable<? extends R>> mapper,
            int prefetch) {
        this.actual = actual;
        this.mapper = mapper;
        this.prefetch = prefetch;
        this.csub = new CompositeSubscription();
        this.wip = new AtomicInteger();
        this.requested = new AtomicLong();
        this.queue = new ConcurrentLinkedQueue<>();
        this.active = new AtomicInteger(1);
        this.error = new AtomicReference<>();
    }
    
    public void init() {
        // TODO implement
    }
    
    @Override
    public void onNext(T t) {
        // TODO implement
    }
    
    @Override
    public void onError(Throwable e) {
        // TODO implement
    }
    
    @Override
    public void onCompleted() {
        // TODO implement
    }
    
    void childRequested(long n) {
        // TODO implement
    }

    void innerNext(Subscriber<R> inner, R value) {
        // TODO implement
    }
    
    void innerError(Throwable ex) {
        // TODO implement
    }
    
    void innerComplete(Subscriber<?> inner) {
        // TODO implement
    }
    
    void drain() {
        // TODO implement
    }
}

So far nothing special, just the usual fields and parameters:

  1. We have the child subscriber, the mapper function and the prefetch value.
  2. We will track the inner subscribers with a CompositeSubscription so when the child unsubscribes, we can unsubscribe them all at once, plus when an inner Observable terminates, we can remove just its subscriber.
  3. We have the usual work-in-progress atomic integer indicating whether a drain is in progress, thus establishing the non-blocking queue-drain approach.
  4. We have a shared queue where all sources will submit their value before attempting to drain it towards the child subscriber. The queue takes Object instead of R because we will also use this queue to post the Subscriber who generated that particular value so we can request more from that particular source.
  5. We need to track the child's requested amount because when the child requests 1, we can't really tell which inner Observable to request that 1 from, so we have to request from all of them. However, this may yield any number of items in response, and we can't simply emit all of them to the child subscriber (possibly causing a MissingBackpressureException down the line).
  6. We need to track how many active sources there are, including the main source of Ts. When this counter reaches zero, the child subscriber can be completed.
  7. Any of the inner sources, or even the main source itself, may signal an error. For simplicity, we will only store the very first exception and route the rest to the RxJavaPlugins' error handler.
You may ask: if this so-called request amplification happens, why not request 1-by-1? The reason is twofold: a) it is very inefficient to get values 1-by-1 from most sources, and b) you still get N values for a single downstream request (one from each of the N inner sources), so you have to do some sort of delivery accounting to know when to request 1 again from the inner sources.

Before we jump into the unimplemented methods, we still need another class: FlatMapInnerSubscriber that we will use to subscribe to each individual Observable<R> generated by the mapper function. Since you can't extend two classes or implement the same generic interface with different type parameters, a separate class is required.

static final class FlatMapInnerSubscriber<T, R> extends Subscriber<R> {
    final FlatMapSubscriber<T, R> parent;

    public FlatMapInnerSubscriber(FlatMapSubscriber<T, R> parent, int prefetch) {
        this.parent = parent;
        request(prefetch);                                         // (1)
    }
    
    @Override
    public void onNext(R t) {
        parent.innerNext(this, t);                                 // (2)
    }
    
    @Override
    public void onError(Throwable e) {
        parent.innerError(e);
    }
    
    @Override
    public void onCompleted() {
        parent.innerComplete(this);
    }

    void requestMore(long n) {
        request(n);                                                // (3)
    }
}

Here, we start with an initial request of the prefetch amount (1) and delegate all onXXX methods back to the parent FlatMapSubscriber. When describing the parent FlatMapSubscriber, I mentioned that we will enqueue the sender along with the value it sends in the shared Queue<Object>. This may seem unintuitive; one might simply call request(1) just before or after (2). The problem with that is that the source would keep receiving requests and generating values, flooding the queue and not achieving backpressure at all. The solution is to request only when that particular source's value has been taken, so that it is allowed to produce a replacement. (We will see in part 2 how this can be achieved by different means.) In addition, we need to expose the protected request() method (3) to allow the drain loop to request replenishments.

Now back to FlatMapSubscriber:


    public void init() {
        add(csub);
        actual.add(this);
        actual.setProducer(new Producer() {
            @Override
            public void request(long n) {
                childRequested(n);
            }
        });
    }

In the init() method, we set up the unsubscription link with the child subscriber and route its request() calls to our childRequested() method. You could ask, why not do this in the constructor? The reason is that by keeping it separate, the this reference won't leak from the constructor before all the final fields have been set, avoiding memory visibility and other problems.


    @Override
    public void onNext(T t) {
        Observable<? extends R> o;
        
        try {
            o = mapper.call(t);
        } catch (Throwable ex) {
            Exceptions.throwOrReport(ex, this, t);
            return;
        }
        
        active.getAndIncrement();
        FlatMapInnerSubscriber<T, R> inner = 
                new FlatMapInnerSubscriber<>(this, prefetch);
        csub.add(inner);
        
        o.subscribe(inner);
    }

In the onNext() method, we call the function to generate an Observable, increment the active counter, create the inner Subscriber and add it to the composite before (!) we subscribe it to the generated Observable. This way, when the inner terminates, it won't accidentally decrement the active count to zero and will be able to remove itself from the composite. Since the mapper can throw, we wrap the call into a try-catch and use a helper method from Exceptions to either rethrow a fatal exception (such as OutOfMemoryError or StackOverflowError) or report it through ourselves, which essentially calls onError:

    @Override
    public void onError(Throwable e) {
        if (error.compareAndSet(null, e)) {
            unsubscribe();
            drain();
        } else {
            RxJavaPlugins.getInstance()
            .getErrorHandler().handleError(e);
        }
    }

We atomically try to set the Throwable in error if it is still null. If successful, we unsubscribe ourselves (and thus all active inner subscribers) and call drain(), which will take care of emitting the error to the child subscriber in a serialized fashion. If somebody else has already signaled an error, we send the Throwable to the plugin error handler instead. We don't have to worry about the active count here because onError is an immediate terminal state for us, unlike what happens in onCompleted:


    @Override
    public void onCompleted() {
        if (active.decrementAndGet() == 0) {
            drain();
        }
    }

The active count is decremented atomically and, if it reaches zero, we call drain(). The drain will make sure all queued values get emitted before the completion signal. Since we consider the main source of Ts an input as well, the counter starts at 1 and may go up and down as inner sources get created and subscribed to. We strongly expect the Observables participating in the flatMapping to honor the protocol, that is, we expect at most one onCompleted call from each of them. Thus, if the main source completes and the active counter was 1, we can be sure no further inner sources will arrive (so there is no 0 - 1 - 0 change) to mess up the accounting. If, for some reason, you don't trust all the sources or your main Observable, feel free to guard this with a CAS to ensure the decrement is executed only once:


    AtomicBoolean once = new AtomicBoolean();
    // ...
    @Override
    public void onCompleted() {
        if (once.compareAndSet(false, true)) {
            if (active.decrementAndGet() == 0) {
                drain();
            }
        }
    }
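The active accounting can be exercised in isolation with plain threads. In this hypothetical sketch, the counter starts at 1 for the main source, each simulated inner source is counted before it starts (just like onNext() increments before subscribing), and a completions counter stands in for drain() reaching onCompleted. However the threads interleave, zero is reached exactly once.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the "active" counter: 1 for the main source, +1 per inner source
// before it starts, -1 per completion; whoever reaches 0 completes the child.
final class ActiveCounter {

    static int completionsAfter(int innerCount) {
        AtomicInteger active = new AtomicInteger(1);   // the main source itself
        AtomicInteger completions = new AtomicInteger();
        CountDownLatch finished = new CountDownLatch(innerCount);
        Runnable onSourceCompleted = () -> {
            if (active.decrementAndGet() == 0) {
                completions.incrementAndGet();         // stands in for drain() -> onCompleted
            }
        };
        for (int i = 0; i < innerCount; i++) {
            active.getAndIncrement();                  // count the inner BEFORE starting it
            new Thread(() -> {
                onSourceCompleted.run();               // inner completes on its own thread
                finished.countDown();
            }).start();
        }
        onSourceCompleted.run();                       // the main source completes
        try {
            finished.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
        return completions.get();
    }
}
```

If the increment were moved into the started thread instead, an inner could complete and drive the counter to zero while other inners were still about to start, which is exactly the 0 - 1 - 0 change the text warns about.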

Due to design decisions in RxJava, the FlatMapSubscriber can't itself implement the Producer interface and needs to swing around any child request with the help of another Producer instance, as seen in the init() method. The target of that call looks as follows:


    void childRequested(long n) {
        if (n > 0) {
            BackpressureUtils.getAndAddRequest(requested, n);
            drain();
        }
    }

The utility method adds the requested amount to the requested field, capped at Long.MAX_VALUE; then we invoke the drain() method to make sure any queued value is emitted up to that total requested amount.
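For readers who want to see what such capped accounting involves, here is a sketch of the two operations in plain Java. It follows the behavior described in this post (saturating addition, and no decrement once the request is unbounded) but it is not the RxJava source; the class name RequestMath and its methods are made up.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical re-creation of capped request accounting (not RxJava's BackpressureUtils).
final class RequestMath {

    // a + b for non-negative values, saturating at Long.MAX_VALUE on overflow
    static long addCap(long a, long b) {
        long u = a + b;
        return u < 0L ? Long.MAX_VALUE : u;
    }

    // atomically adds n to requested (capped); returns the previous value
    static long getAndAddRequest(AtomicLong requested, long n) {
        for (;;) {
            long current = requested.get();
            long next = addCap(current, n);
            if (requested.compareAndSet(current, next)) {
                return current;
            }
        }
    }

    // subtracts the emitted amount unless the request is unbounded (Long.MAX_VALUE)
    static long produced(AtomicLong requested, long e) {
        for (;;) {
            long current = requested.get();
            if (current == Long.MAX_VALUE) {
                return Long.MAX_VALUE;
            }
            long next = current - e;
            if (next < 0L) {
                throw new IllegalStateException("More produced than requested: " + next);
            }
            if (requested.compareAndSet(current, next)) {
                return next;
            }
        }
    }
}
```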

The delegate methods called from the FlatMapInnerSubscriber are themselves short and I'll show them together:


    void innerNext(Subscriber<R> inner, R value) {
        queue.offer(inner);
        queue.offer(NotificationLite.instance().next(value));
        drain();
    }
    
    void innerError(Throwable ex) {
        onError(ex);
    }
    
    void innerComplete(Subscriber<?> inner) {
        csub.remove(inner);
        onCompleted();
    }

The innerNext() method puts the subscriber and the value (wrapping nulls with the help of NotificationLite) into the queue before calling drain(); innerError() just delegates to onError(); and finally, innerComplete() removes the inner Subscriber from the composite and delegates to the regular onCompleted(). Note that if you applied the once trick I mentioned above, this delegation won't work. You have to introduce the same once field on FlatMapInnerSubscriber, check it before it calls innerComplete(), and perform the decrementAndGet() == 0 check in innerComplete() directly.
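The null wrapping matters because ConcurrentLinkedQueue rejects null elements, while an Observable<R> may legitimately emit null. The following is a minimal stand-in for what NotificationLite does in this respect, using a hypothetical private sentinel object; it is not RxJava's implementation.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

// Minimal null-sentinel wrapping, in the spirit of NotificationLite.next()/getValue().
final class NullSafeQueue {
    private static final Object NULL = new Object(); // hypothetical sentinel

    static Object wrap(Object value) {
        return value == null ? NULL : value;
    }

    @SuppressWarnings("unchecked")
    static <R> R unwrap(Object o) {
        return o == NULL ? null : (R) o;
    }

    static <R> R roundTrip(R value) {
        Queue<Object> queue = new ConcurrentLinkedQueue<>();
        queue.offer(wrap(value)); // queue.offer(null) would throw NullPointerException
        return unwrap(queue.poll());
    }
}
```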

Finally, let's see the drain() method, piece by piece:


if (wip.getAndIncrement() != 0) {
    return;
}

int missed = 1;

for (;;) {
    
    long r = requested.get();
    long e = 0L;
    
    while (e != r) {

The first section is typical: we increment the wip counter, and if it was 0 before, we enter the drain loop. We'll use missed to detect whether others have also called drain() in the meantime, in which case more work has to be performed. We read the current requested amount and prepare the emission count. The loop runs as long as the emission count hasn't reached the requested amount.

        if (actual.isUnsubscribed()) {
            return;
        }
        
        boolean done = active.get() == 0;              // (1)
        Throwable ex = error.get();                    // (2)
        if (ex != null) {
            actual.onError(ex);
            return;
        }
        
        Object o = queue.poll();
        
        if (done && o == null) {                       // (3)
            actual.onCompleted();
            return;
        }
        
        if (o == null) {
            break;
        }

        Object v;
        
        for (;;) {                                     // (4)
            if (actual.isUnsubscribed()) {
                return;
            }
            v = queue.poll();
            if (v != null) {
                break;
            }
        }
        
        actual.onNext(NotificationLite
                .<R>instance().getValue(v));           // (5)
        
        ((FlatMapInnerSubscriber<?, ?>)o)
            .requestMore(1);
        
        e++;
    }


  1. The inside of the drain loop should look familiar with the exception of a secondary loop perhaps. In the while loop, we detect the completion by checking the active count against zero
  2. as well as see if the error reference holds something non-null. 
  3. Since we put 2 objects inside the queue, we have to take 2 objects out. If the first poll returns null, the queue is empty and if at the same time, done is true, we reached the end of all sources and can complete the child. 
  4. However, we can't simply expect the second poll to succeed, because the producing thread may have been preempted between the two offer() calls in innerNext() above, and thus the second object, the value to be emitted, isn't there yet. Therefore, we need an inner loop that keeps polling until the second object arrives (or the child unsubscribes). (Note that this can be avoided with specialized queues or with tuple types.)
  5. Once we have the second, actual value, we unwrap it with the help of NotificationLite and emit it to the child subscriber. Then we cast the sender back to FlatMapInnerSubscriber (with wildcard type arguments) and ask it for a replenishment. The loop body ends by incrementing the emission count so we can detect whether we fulfilled all current requests.
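Before moving on to the last section of drain(), here is a sketch of the tuple alternative mentioned in point 4: if the sender and the value travel together in a single queue element, one poll() returns both and the spin-wait disappears, at the cost of one extra allocation per value. As a side effect, a null value can live inside the tuple without any sentinel. The Sender interface and Item class are hypothetical stand-ins for FlatMapInnerSubscriber and a tuple type.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;

final class TupleQueue {

    // hypothetical stand-in for FlatMapInnerSubscriber.requestMore()
    interface Sender {
        void requestMore(long n);
    }

    // one queue element carrying both the sender and its value
    static final class Item {
        final Sender sender;
        final Object value;

        Item(Sender sender, Object value) {
            this.sender = sender;
            this.value = value;
        }
    }

    final Queue<Item> queue = new ConcurrentLinkedQueue<>();

    void innerNext(Sender sender, Object value) {
        queue.offer(new Item(sender, value)); // a single offer: nothing is ever half-enqueued
    }

    // emits one value if available, then asks its sender for a replacement
    boolean emitOne(Consumer<Object> downstream) {
        Item item = queue.poll();
        if (item == null) {
            return false;
        }
        downstream.accept(item.value);
        item.sender.requestMore(1);
        return true;
    }
}
```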



    if (e == r) {
        if (actual.isUnsubscribed()) {
            return;
        }
        boolean done = active.get() == 0;
        Throwable ex = error.get();
        if (ex != null) {
            actual.onError(ex);
            return;
        }
        
        if (done && queue.isEmpty()) {
            actual.onCompleted();
            return;
        }
    }
    
    if (e != 0L) {
        BackpressureUtils.produced(requested, e);
    }
    
    missed = wip.addAndGet(-missed);
    if (missed == 0) {
        break;
    }
}

The last section deals with the case when the emission count has reached the requested amount and all that may be left is to terminate the sequence. Incidentally, this e == r case also happens if all sources were empty and the child hasn't requested anything (r == 0); the logic ensures the eager completion of the child Subscriber. If there were emissions, we deduct that amount from the requested field with the help of another utility method, then subtract the missed amount from the wip counter. If that reaches zero, no further work needs to happen. Otherwise, the outermost loop starts over (with a new missed amount).
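The wip/missed mechanics of drain() can be tried out in isolation. The JDK-only sketch below drops backpressure and termination and keeps just the serialization: several threads offer values and call drain(), but only the thread that moves wip from 0 to 1 emits into a plain, non-thread-safe ArrayList, which therefore needs no further locking. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

// Queue-drain with the "missed" counter, without backpressure or termination.
final class QueueDrain {
    final Queue<Integer> queue = new ConcurrentLinkedQueue<>();
    final AtomicInteger wip = new AtomicInteger();
    final List<Integer> output = new ArrayList<>(); // touched only by the current drainer

    void offer(int value) {
        queue.offer(value);
        drain();
    }

    void drain() {
        if (wip.getAndIncrement() != 0) {
            return;                   // another thread is draining and will pick up our item
        }
        int missed = 1;
        for (;;) {
            Integer v;
            while ((v = queue.poll()) != null) {
                output.add(v);        // serialized "onNext"
            }
            missed = wip.addAndGet(-missed);
            if (missed == 0) {
                break;                // no drain() call arrived in the meantime
            }
        }
    }

    static int run(int threads, int perThread) {
        QueueDrain qd = new QueueDrain();
        CountDownLatch done = new CountDownLatch(threads);
        for (int t = 0; t < threads; t++) {
            new Thread(() -> {
                for (int i = 0; i < perThread; i++) {
                    qd.offer(i);
                }
                done.countDown();
            }).start();
        }
        try {
            done.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
        return qd.output.size();
    }
}
```

QueueDrain.run(4, 10_000) returns 40000: every offered value is emitted exactly once, even though no thread ever blocks waiting for another.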

Conclusion

If you look into how merge is implemented in RxJava 1.x or how flatMap is implemented in RxJava 2.x and in Reactor 2.5, you will find that they differ notably from the implementation I showed you. The difference is due to performance and functionality reasons that we will dive into in the next parts of this mini-series.

You may think, why not explain the implementation in RxJava 1.x immediately? The reason is twofold: complexity and building blocks. The implementation in this blog post is, in my opinion, more accessible for those who follow my blog posts and as usual, builds upon previously explained concepts as well as establishes a base for further concepts and tie-ins that will come later on.

