Introduction
In this post, we will look into expanding the features of our flatMap implementation and improving its performance.
RxJava's flatMap implementation can limit the maximum concurrency, that is, the maximum number of active subscriptions to the generated sources, and it can delay errors coming from any of the sources, including the main one.
Limiting concurrency
For historical reasons, RxJava's flatMap (and our version of it from part 1) is unbounded towards the main source. This may work with infrequent main emissions and/or short-lived inner Observable sequences. However, even if the main source, such as range(), can emit at any rate, the mapped inner Observables may consume limited resources such as network connections. So the question is: how can we make sure only a user-defined number of active Observables are being merged at once? How can we make sure a source emits only a limited number of values?
The answer is, of course, backpressure.
To limit the concurrency in flatMap, the idea is to request a maxConcurrency amount upfront via request(), and then whenever a source completes, request(1) extra.
Let's change our OpFlatMap and FlatMapSubscriber's implementation to include this maxConcurrency parameter:
final int maxConcurrency;

public OpFlatMap(Func1<? super T, ? extends Observable<? extends R>> mapper,
        int prefetch, int maxConcurrency) {
    this.mapper = mapper;
    this.prefetch = prefetch;
    this.maxConcurrency = maxConcurrency;
}

@Override
public Subscriber<T> call(Subscriber<? super R> t) {
    FlatMapSubscriber<T, R> parent = new FlatMapSubscriber<>(t, mapper, prefetch, maxConcurrency);
    parent.init();
    return parent;
}
As a contract, we will handle Integer.MAX_VALUE as an indicator for the original unbounded mode:
final int maxConcurrency;

public FlatMapSubscriber(Subscriber<? super R> actual,
        Func1<? super T, ? extends Observable<? extends R>> mapper,
        int prefetch, int maxConcurrency) {
    this.actual = actual;
    this.mapper = mapper;
    this.prefetch = prefetch;
    this.csub = new CompositeSubscription();
    this.wip = new AtomicInteger();
    this.requested = new AtomicLong();
    this.queue = new ConcurrentLinkedQueue<>();
    this.active = new AtomicInteger(1);
    this.error = new AtomicReference<>();
    this.maxConcurrency = maxConcurrency;
    if (maxConcurrency != Integer.MAX_VALUE) {
        // bounded mode: only let maxConcurrency main values (inner sources) in upfront
        request(maxConcurrency);
    }
}
Finally, we need to update innerComplete to request another value from the main source:
void innerComplete(Subscriber<?> inner) {
    csub.remove(inner);
    request(1);      // replace the finished inner source with a new one
    onCompleted();
}
Quite straightforward use of backpressure. Note here that this innerComplete may happen concurrently, triggered by the different inner Observables, therefore, the main source's request handler must be thread safe and reentrant-safe.
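To see the request accounting in isolation, here is a single-threaded model of it. It is a sketch, not RxJava code: BoundedMergeModel is a hypothetical class, the main source is simulated by a counter that only emits while there is outstanding demand, and the "inner sources" are just entries in a deque that complete in FIFO order. It requests maxConcurrency items upfront and one more on each inner completion, so the number of active inner sources never exceeds the limit.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical single-threaded model of flatMap's concurrency limiting. */
final class BoundedMergeModel {
    private final int totalMainItems;   // how many values the simulated main source has
    private final int maxConcurrency;
    private int emittedMain;            // main values emitted so far
    private long mainRequested;         // outstanding demand towards the main source
    private final ArrayDeque<Integer> active = new ArrayDeque<>(); // running inner sources
    int maxActiveObserved;
    final List<Integer> completedOrder = new ArrayList<>();

    BoundedMergeModel(int totalMainItems, int maxConcurrency) {
        this.totalMainItems = totalMainItems;
        this.maxConcurrency = maxConcurrency;
    }

    /** The simulated main source honors backpressure: it emits only while there is demand. */
    private void requestMain(long n) {
        mainRequested += n;
        while (mainRequested > 0 && emittedMain < totalMainItems) {
            mainRequested--;
            onNextMain(emittedMain++);
        }
    }

    /** Each main value "subscribes" to an inner source, which is tracked as active. */
    private void onNextMain(int item) {
        active.addLast(item);
        maxActiveObserved = Math.max(maxActiveObserved, active.size());
    }

    /** An inner source completes: stop tracking it and replenish with request(1). */
    private void innerComplete() {
        completedOrder.add(active.pollFirst());
        requestMain(1);
    }

    /** Initial request of maxConcurrency, then complete the inner sources one by one. */
    void run() {
        requestMain(maxConcurrency);
        while (!active.isEmpty()) {
            innerComplete();
        }
    }
}
```

In the real operator, innerComplete may run on several threads at once, which is why the main source's request handling must be thread-safe and reentrant; this model sidesteps that by being single-threaded.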
Delaying errors
By default, many standard operators terminate eagerly whenever they encounter an onError signal. If such an operator works with multiple sources, one sometimes wishes to process all non-error values first and only then act on any error signal that has popped up. To support this, both OpFlatMap and FlatMapSubscriber get a new delayErrors parameter:
final boolean delayErrors;

public OpFlatMap(Func1<? super T, ? extends Observable<? extends R>> mapper,
        int prefetch, int maxConcurrency, boolean delayErrors) {
    this.mapper = mapper;
    this.prefetch = prefetch;
    this.maxConcurrency = maxConcurrency;
    this.delayErrors = delayErrors;
}

@Override
public Subscriber<T> call(Subscriber<? super R> t) {
    FlatMapSubscriber<T, R> parent = new FlatMapSubscriber<>(t, mapper, prefetch, maxConcurrency, delayErrors);
    parent.init();
    return parent;
}

// ...

final boolean delayErrors;

public FlatMapSubscriber(Subscriber<? super R> actual,
        Func1<? super T, ? extends Observable<? extends R>> mapper,
        int prefetch, int maxConcurrency, boolean delayErrors) {
    this.actual = actual;
    this.mapper = mapper;
    this.prefetch = prefetch;
    this.csub = new CompositeSubscription();
    this.wip = new AtomicInteger();
    this.requested = new AtomicLong();
    this.queue = new ConcurrentLinkedQueue<>();
    this.active = new AtomicInteger(1);
    this.error = new AtomicReference<>();
    this.maxConcurrency = maxConcurrency;
    if (maxConcurrency != Integer.MAX_VALUE) {
        request(maxConcurrency);
    }
    this.delayErrors = delayErrors;
}
Delaying errors within flatMap is simple in terms of the delay part, but requires some extra effort on the errors part: at the very end, we need to emit a single onError signal no matter how many sources (main or inner) signalled onError before. Keeping only the very first error until the end is certainly an option, but silently dropping the rest of the errors may not be desirable either. The solution is to collect all Throwables in some data structure and emit a CompositeException at the end.
Using a concurrent Queue<Throwable> for this purpose is an option (RxJava does this), but we can reuse our existing error AtomicReference and perform a compare-and-swap loop to accumulate all the exceptions:
@Override
public void onError(Throwable e) {
    if (delayErrors) {
        for (;;) {
            Throwable current = error.get();
            Throwable next;
            if (current == null) {
                next = e;
            } else {
                List<Throwable> list = new ArrayList<>();
                if (current instanceof CompositeException) {
                    list.addAll(((CompositeException)current).getExceptions());
                } else {
                    list.add(current);
                }
                list.add(e);
                next = new CompositeException(list);
            }
            if (error.compareAndSet(current, next)) {
                if (active.decrementAndGet() == 0) {
                    drain();
                }
                return;
            }
        }
    } else {
        if (error.compareAndSet(null, e)) {
            unsubscribe();
            drain();
        } else {
            RxJavaPlugins.getInstance()
                .getErrorHandler().handleError(e);
        }
    }
}
In the loop, we take the current error and if it is null, we update it to the given exception. If there is an error already, we create a CompositeException to hold the current and the given new exception. However, if the current error happens to be a CompositeException, we flatten the whole list of the previous errors; this gives a nice and flat array of errors at the end inside a single CompositeException. Since we now accept both onError and onCompleted as being non-global terminal events, we decrement the active count and trigger a drain if it reaches zero.
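The compare-and-swap accumulation can be tried out without RxJava. The sketch below assumes a hypothetical SimpleCompositeException as a stand-in for rx.exceptions.CompositeException; the loop follows the same steps as the delayErrors branch of onError above: store the first error as-is, otherwise build a new composite from the flattened old errors plus the new one, and retry if another thread changed the reference in the meantime.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical stand-in for rx.exceptions.CompositeException. */
final class SimpleCompositeException extends RuntimeException {
    private final List<Throwable> exceptions;

    SimpleCompositeException(List<Throwable> exceptions) {
        super(exceptions.size() + " exceptions occurred");
        this.exceptions = Collections.unmodifiableList(new ArrayList<>(exceptions));
    }

    List<Throwable> getExceptions() {
        return exceptions;
    }
}

/** Lock-free accumulation of errors into one flat composite. */
final class ErrorAccumulator {
    final AtomicReference<Throwable> error = new AtomicReference<>();

    void addError(Throwable e) {
        for (;;) {
            Throwable current = error.get();
            Throwable next;
            if (current == null) {
                next = e;                         // first error: keep it as-is
            } else {
                List<Throwable> list = new ArrayList<>();
                if (current instanceof SimpleCompositeException) {
                    // flatten previously collected errors instead of nesting composites
                    list.addAll(((SimpleCompositeException) current).getExceptions());
                } else {
                    list.add(current);
                }
                list.add(e);
                next = new SimpleCompositeException(list);
            }
            if (error.compareAndSet(current, next)) {
                return;
            }
            // lost the race against another thread: retry with the fresh value
        }
    }
}
```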
Given Java 7's Throwable.addSuppressed, you may be tempted to use that to collect exceptions, but it has some drawbacks: it uses synchronized and needs a parent exception upfront, which costs time to create even if no exception occurs after all. In addition, modifying an existing exception that already has suppressed exceptions of its own can make the result more confusing to figure out.
Since an innerError is no longer an immediate terminal condition, we need to adjust the method to remove the inner subscriber from the tracking structure as well as ask for replenishment in case the flatMap operator also runs with limited concurrency:
void innerError(Throwable ex, Subscriber<?> inner) {
    if (delayErrors) {
        csub.remove(inner);
        request(1);
    }
    onError(ex);
}
Lastly, the drain() method needs adjustments as well. The previous implementation signalled onError the moment it detected an error. This has to change so that an error is emitted only after all values inside the shared queue have been emitted (just like the completion event):
boolean done = active.get() == 0;
if (!delayErrors) {
    Throwable ex = error.get();
    if (ex != null) {
        actual.onError(ex);
        return;
    }
}
Object o = queue.poll();
if (done && o == null) {
    Throwable ex = error.get();
    if (ex != null) {
        actual.onError(ex);
    } else {
        actual.onCompleted();
    }
    return;
}
if (o == null) {
    break;
}
The original error-emission case is now behind a check for delayErrors being false. Otherwise, we check whether all sources have terminated and the queue is empty, and only then check whether there is any error. We emit the terminal event accordingly and quit.
In addition, we need to update the e == r case (i.e., the case when we emitted the requested amount and the next signal would be a terminal event):
if (e == r) {
    if (actual.isUnsubscribed()) {
        return;
    }
    boolean done = active.get() == 0;
    if (!delayErrors) {
        Throwable ex = error.get();
        if (ex != null) {
            actual.onError(ex);
            return;
        }
    }
    if (done && queue.isEmpty()) {
        Throwable ex = error.get();
        if (ex != null) {
            actual.onError(ex);
        } else {
            actual.onCompleted();
        }
        return;
    }
}
Practically the same as above, except that it checks isEmpty() instead of calling poll(), as we don't want to consume a value if there is one.
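Both terminal checks above boil down to the same decision, which this small sketch isolates as a pure function. TerminalSignal and TerminalCheck are hypothetical names; done stands for "all sources have terminated" (active.get() == 0) and queueEmpty for the result of poll() or isEmpty().

```java
/** Hypothetical outcome of the terminal check inside drain(). */
enum TerminalSignal { NONE, ERROR, COMPLETE }

final class TerminalCheck {
    /**
     * Mirrors the order of the checks in the drain loop: in eager mode an error
     * overtakes any queued values; otherwise a terminal event is emitted only
     * once every source has terminated and no queued value remains.
     */
    static TerminalSignal check(boolean delayErrors, boolean done,
                                boolean queueEmpty, Throwable error) {
        if (!delayErrors && error != null) {
            return TerminalSignal.ERROR;
        }
        if (done && queueEmpty) {
            return error != null ? TerminalSignal.ERROR : TerminalSignal.COMPLETE;
        }
        return TerminalSignal.NONE; // values pending or sources still active
    }
}
```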
Now we are done with the extra features of OpFlatMap (don't forget to change the FlatMapInnerSubscriber.onError to parent.innerError(e, this); by the way).
Increasing the performance of the Queue
Our flatMap implementation performs decently, but all those thread-safety features take an extra toll on the throughput. The default queue implementation we employed, ConcurrentLinkedQueue, is nice, but its unused Queue features entail unnecessary overhead; plus, our usage pattern is just multiple-producer single-consumer.
Fortunately, the JCTools library offers higher-performance Queue implementations, often with 5-20x lower overhead. We can simply replace the queue implementation with MpscLinkedQueue (or the fresh MpscGrowableArrayQueue). In addition, if the flatMap runs with some maxConcurrency value, you can even use an MpscArrayQueue directly (with a capacity of maxConcurrency * prefetch), but note that the capacity of array-based queues is rounded up to the next power of two, which may waste space.
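To put a number on the wasted space, here is a sketch of power-of-two rounding, assuming the array-based queue rounds the requested capacity up to the next power of two as described above. The helper names are hypothetical. For example, maxConcurrency = 3 and prefetch = 128 need 384 slots, which rounds up to 512 and leaves 128 slots unused.

```java
/** Hypothetical helpers illustrating power-of-two capacity rounding. */
final class QueueCapacity {
    /** Smallest power of two that is >= value (valid for 1 <= value <= 2^30). */
    static int roundToPowerOfTwo(int value) {
        return value <= 1 ? 1 : 1 << (32 - Integer.numberOfLeadingZeros(value - 1));
    }

    /** Slots allocated but never needed when capacity = maxConcurrency * prefetch. */
    static int wastedSlots(int maxConcurrency, int prefetch) {
        int needed = maxConcurrency * prefetch;
        return roundToPowerOfTwo(needed) - needed;
    }
}
```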
This change gives a decent throughput increase, but can we do better? Let me answer with another question: what is the overhead of not using the queue at all? Practically zero! If we bypassed the queue completely, that would save even more overhead!
Now the question is how and when we can bypass the queue. In other words, when can we emit a value directly?
To emit a value two conditions must be met: 1) no other source is trying to emit at the same time and 2) downstream has requested some.
The first condition is ensured by the wip counter and the second is checked in the drain() loop. If you remember, the wip counter actually encodes 3 states. If zero, nobody is emitting, if 1, there is a drain going on and 2+ indicates more work has to be performed. Therefore, if we can change the wip from 0 to 1, that would meet condition 1). Next we have to check the requested amount for condition 2) and if met, emit the value directly to the downstream's Subscriber.
To accomplish this, we have to extend the innerNext() method with some bypass logic:
void innerNext(FlatMapInnerSubscriber<T, R> inner, R value) {
    Object v = NotificationLite.instance().next(value);          // (1)
    if (wip.get() == 0 && wip.compareAndSet(0, 1)) {             // (2)
        if (requested.get() != 0L) {                             // (3)
            actual.onNext(value);                                // (4)
            BackpressureUtils.produced(requested, 1);
            inner.requestMore(1);
        } else {
            queue.offer(inner);                                  // (5)
            queue.offer(v);
        }
        if (wip.decrementAndGet() != 0) {                        // (6)
            drainLoop();
        }
        return;
    }
    queue.offer(inner);                                          // (7)
    queue.offer(v);
    drain();
}

void drain() {
    if (wip.getAndIncrement() != 0) {
        return;
    }
    drainLoop();
}

void drainLoop() {
    int missed = 1;
    // ...
This pattern should be somewhat familiar too; it is a fast-path queue-drain approach introduced at the very beginning.
1. First, we convert the value (which may be null) into its NotificationLite form, as it may be needed in two places later on.
2. If the wip value is zero and can be successfully changed to 1, we enter the serialized drain mode and are now free to emit ...
3. ... if the requested amount is non-zero.
4. Therefore, if there is no contention and there is a request from downstream, we emit without touching the queue. Once emitted, we have to reduce the requested amount by 1 and request a replenishment from the inner source.
5. If the downstream hasn't requested anything, we have to revert to the original queuing behavior and store the value for later use.
6. In the drain mode, we decrement the wip counter and, if more work has arrived in the meantime, resume with the loop part of the former drain. This also means the drain() method has to be refactored into drain() and drainLoop() methods. Using the original drain() here wouldn't work because it would skip its loop due to wip being non-zero already.
7. In case there was contention, i.e., wip was non-zero, we have to revert to the original queue-drain behavior.
You may recall we use wip.getAndIncrement() == 0 to enter the serialized mode elsewhere but not here. The reason for it is that although getAndIncrement scales better and is intrinsified into a single CPU instruction on x86, it has more overhead compared to a single compareAndSet call when there is no contention - and we base our fast-path optimization on this no contention property.
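Below is a stripped-down sketch of the fast-path queue-drain pattern, independent of flatMap: FastPathEmitter is a hypothetical class, a Consumer stands in for the downstream Subscriber and a ConcurrentLinkedQueue for the shared queue. Unlike the innerNext() above, it also requires the queue to be empty before bypassing it (so a value cannot overtake earlier queued ones) and always drains after queuing while it owns wip; treat it as a conservative variant, not the original code.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

/** Hypothetical fast-path emitter; simplified: no Long.MAX_VALUE or overflow handling. */
final class FastPathEmitter<T> {
    final AtomicInteger wip = new AtomicInteger();
    final AtomicLong requested = new AtomicLong();
    final Queue<T> queue = new ConcurrentLinkedQueue<>();
    final Consumer<? super T> downstream;
    int fastPathHits; // only modified while the calling thread owns wip

    FastPathEmitter(Consumer<? super T> downstream) {
        this.downstream = downstream;
    }

    void onNext(T value) {
        // cheap volatile read first, then a single CAS: succeeds only without contention
        if (wip.get() == 0 && wip.compareAndSet(0, 1)) {
            if (requested.get() != 0L && queue.isEmpty()) {
                downstream.accept(value);          // bypass the queue entirely
                requested.decrementAndGet();
                fastPathHits++;
                if (wip.decrementAndGet() == 0) {
                    return;                        // nothing else happened meanwhile
                }
            } else {
                queue.offer(value);                // no demand (or backlog): queue it
            }
            drainLoop();                           // we still own wip: handle pending work
            return;
        }
        queue.offer(value);                        // contention: regular queue-drain
        drain();
    }

    void request(long n) {
        requested.addAndGet(n);
        drain();
    }

    void drain() {
        if (wip.getAndIncrement() != 0) {
            return;
        }
        drainLoop();
    }

    void drainLoop() {
        int missed = 1;
        for (;;) {
            long r = requested.get();
            long e = 0L;
            while (e != r) {
                T v = queue.poll();
                if (v == null) {
                    break;
                }
                downstream.accept(v);
                e++;
            }
            if (e != 0L) {
                requested.addAndGet(-e);
            }
            missed = wip.addAndGet(-missed);
            if (missed == 0) {
                break;
            }
        }
    }
}
```

A value arriving before any request is queued; once request() drains it, the next value with outstanding demand takes the fast path and never touches the queue.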
Increasing performance in high-contention case
The queue-bypass optimization has its limits; it's almost never triggered when all sources emit rapidly, causing contention all the time. This contention affects both the shared queue and the wip counter; therefore, we could gain some performance by getting rid of one or both contention points. Unfortunately, wip is essential and unavoidable, so let's look at the queue instead.
The problem is that all concurrent sources use the same queue and contend on the queue's offer() side, thus requiring a multi-producer queue instance that uses the heavyweight getAndSet() or getAndIncrement() atomic operations internally.
However, since each source is sequential by nature, we practically have N single-threaded producers at once, and due to the drain loop, there is only a single consumer for all of those sources.
The solution is to use a single-producer single-consumer queue for each source and in the drain loop, collect values from all of them individually. A great opportunity for JCTools again with its ultra-high performance SpscArrayQueue. We can use the array variant because our prefetch value is expected to be reasonably low; RxJava runs with 128 by default.
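For intuition about why a single-producer single-consumer queue is cheap, here is a minimal bounded SPSC ring buffer. It is a simplified sketch, not JCTools' SpscArrayQueue (which adds padding, lookahead and other tricks): each index is written by only one side, so offer() and poll() need no compare-and-swap at all, just ordered lazySet stores and volatile reads. The capacity is rounded up to a power of two so the slot index can be computed with a mask.

```java
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReferenceArray;

/** Hypothetical minimal SPSC queue: exactly one thread may offer(), one thread may poll(). */
final class SimpleSpscQueue<E> {
    private final AtomicReferenceArray<E> buffer;
    private final int mask;
    private final AtomicLong producerIndex = new AtomicLong(); // written by the producer only
    private final AtomicLong consumerIndex = new AtomicLong(); // written by the consumer only

    SimpleSpscQueue(int requestedCapacity) {
        int capacity = requestedCapacity <= 1
                ? 1 : 1 << (32 - Integer.numberOfLeadingZeros(requestedCapacity - 1));
        buffer = new AtomicReferenceArray<>(capacity);
        mask = capacity - 1;
    }

    int capacity() {
        return mask + 1;
    }

    boolean offer(E e) {
        if (e == null) {
            throw new NullPointerException("null is the 'empty' marker");
        }
        long p = producerIndex.get();
        if (p - consumerIndex.get() == capacity()) {
            return false;                       // full
        }
        buffer.lazySet((int) p & mask, e);
        producerIndex.lazySet(p + 1);           // publish the index after storing the element
        return true;
    }

    E poll() {
        long c = consumerIndex.get();
        if (c == producerIndex.get()) {
            return null;                        // empty
        }
        int offset = (int) c & mask;
        E e = buffer.get(offset);
        buffer.lazySet(offset, null);           // free the slot for GC and reuse
        consumerIndex.lazySet(c + 1);
        return e;
    }

    boolean isEmpty() {
        return consumerIndex.get() == producerIndex.get();
    }
}
```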
This requires some modest changes to both the FlatMapInnerSubscriber and its FlatMapSubscriber parent as well:
static final class FlatMapInnerSubscriber<T, R> extends Subscriber<R> {
    final FlatMapSubscriber<T, R> parent;

    final int prefetch;

    volatile Queue<Object> queue;

    volatile boolean done;

    public FlatMapInnerSubscriber(
            FlatMapSubscriber<T, R> parent, int prefetch) {
        this.parent = parent;
        this.prefetch = prefetch;
        request(prefetch);
    }

    @Override
    public void onNext(R t) {
        parent.innerNext(this, t);
    }

    @Override
    public void onError(Throwable e) {
        done = true;
        parent.innerError(e, this);
    }

    @Override
    public void onCompleted() {
        done = true;
        parent.innerComplete(this);
    }

    void requestMore(long n) {
        request(n);
    }

    Queue<Object> getOrCreateQueue() {
        Queue<Object> q = queue;
        if (q == null) {
            // created lazily by this source's emitting thread, read by the draining thread
            q = new SpscArrayQueue<>(prefetch);
            queue = q;
        }
        return q;
    }
}
The FlatMapInnerSubscriber gets two new fields: one storing the prefetch amount, used later when creating the SpscArrayQueue, and one holding the Queue instance itself. In addition, we need to know whether the source has finished emitting events, via its own done flag. Of course, we could pre-create the queue, but that would waste the benefit of the fast path from the previous section, which doesn't need a queue when it succeeds. Regardless, if a queue is eventually needed, getOrCreateQueue() will make that happen. Note that such a queue is created by a single thread (the one emitting for that source) yet may be read from the draining thread, and thus the field has to be volatile.
Next step is to change the innerNext() to work with this per-source queue instead of the shared one:
void innerNext(FlatMapInnerSubscriber<T, R> inner, R value) {
    Object v = NotificationLite.instance().next(value);
    if (wip.get() == 0 && wip.compareAndSet(0, 1)) {
        if (requested.get() != 0L) {
            actual.onNext(value);
            BackpressureUtils.produced(requested, 1);
            inner.requestMore(1);
        } else {
            Queue<Object> q = inner.getOrCreateQueue();
            q.offer(v);
        }
        if (wip.decrementAndGet() != 0) {
            drainLoop();
        }
        return;
    }
    Queue<Object> q = inner.getOrCreateQueue();
    q.offer(v);
    drain();
}
This incurs only a small change, in the form of inner.getOrCreateQueue() as the target queue in case of contention or missing downstream requests. (At this point, one could remove the main queue from the parent class, but let's hold onto it a bit longer.)
Unfortunately, this per-source queue causes some trouble: the drainLoop() can no longer use the shared queue and has to know about the currently active sources in some way, but CompositeSubscription doesn't expose its content. In addition, CompositeSubscription uses an internal HashSet which would have to be iterated in a thread-safe manner, adding so much overhead to the common case that all our hard work would go to waste.
Instead, we can use the same copy-on-write subscriber tracking we did with Subjects and ConnectableObservables. This gives us a plain array of FlatMapInnerSubscribers and lets us get rid of both the csub and active fields.
@SuppressWarnings("rawtypes")
static final FlatMapInnerSubscriber[] EMPTY = new FlatMapInnerSubscriber[0];

@SuppressWarnings("rawtypes")
static final FlatMapInnerSubscriber[] TERMINATED = new FlatMapInnerSubscriber[0];

final AtomicReference<FlatMapInnerSubscriber<T, R>[]> subscribers;

volatile boolean done;

@SuppressWarnings("unchecked")
public FlatMapSubscriber(Subscriber<? super R> actual,
        Func1<? super T, ? extends Observable<? extends R>> mapper,
        int prefetch, int maxConcurrency, boolean delayErrors) {
    this.actual = actual;
    this.mapper = mapper;
    this.prefetch = prefetch;
    this.wip = new AtomicInteger();
    this.requested = new AtomicLong();
    this.error = new AtomicReference<>();
    this.subscribers = new AtomicReference<>(EMPTY);
    this.maxConcurrency = maxConcurrency;
    if (maxConcurrency != Integer.MAX_VALUE) {
        request(maxConcurrency);
    }
    this.delayErrors = delayErrors;
}
We have the usual empty and terminated indicator arrays and a volatile done field which will be true if the main source completes. The initialization logic has to change as well, plus, we need the usual add(), remove() and terminate() methods:
public void init() {
    add(Subscriptions.create(this::terminate));
    actual.add(this);
    actual.setProducer(new Producer() {
        @Override
        public void request(long n) {
            childRequested(n);
        }
    });
}

@SuppressWarnings("unchecked")
void terminate() {
    FlatMapInnerSubscriber<T, R>[] a = subscribers.get();
    if (a != TERMINATED) {
        a = subscribers.getAndSet(TERMINATED);
        if (a != TERMINATED) {
            for (FlatMapInnerSubscriber<T, R> inner : a) {
                inner.unsubscribe();
            }
        }
    }
}

boolean add(FlatMapInnerSubscriber<T, R> inner) {
    for (;;) {
        FlatMapInnerSubscriber<T, R>[] a = subscribers.get();
        if (a == TERMINATED) {
            return false;
        }
        int n = a.length;
        @SuppressWarnings("unchecked")
        FlatMapInnerSubscriber<T, R>[] b = new FlatMapInnerSubscriber[n + 1];
        System.arraycopy(a, 0, b, 0, n);
        b[n] = inner;
        if (subscribers.compareAndSet(a, b)) {
            return true;
        }
    }
}

@SuppressWarnings("unchecked")
void remove(FlatMapInnerSubscriber<T, R> inner) {
    for (;;) {
        FlatMapInnerSubscriber<T, R>[] a = subscribers.get();
        if (a == TERMINATED || a == EMPTY) {
            return;
        }
        int n = a.length;
        int j = -1;
        for (int i = 0; i < n; i++) {
            if (a[i] == inner) {
                j = i;
                break;
            }
        }
        if (j < 0) {
            return;
        }
        FlatMapInnerSubscriber<T, R>[] b;
        if (n == 1) {
            b = EMPTY;
        } else {
            b = new FlatMapInnerSubscriber[n - 1];
            System.arraycopy(a, 0, b, 0, j);
            System.arraycopy(a, j + 1, b, j, n - j - 1);
        }
        if (subscribers.compareAndSet(a, b)) {
            return;
        }
    }
}
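The add/remove/terminate trio is a reusable pattern. The following generic sketch (CowTracker is a hypothetical name; it stores Object[] to avoid generic array creation) shows its key properties: every change publishes a fresh array via compareAndSet, readers can iterate a snapshot without locking, and once TERMINATED has been swapped in, add() refuses newcomers so the caller knows not to subscribe them.

```java
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical copy-on-write tracker with a terminal state. */
final class CowTracker<T> {
    private static final Object[] EMPTY = new Object[0];
    private static final Object[] TERMINATED = new Object[0]; // distinct identity from EMPTY

    private final AtomicReference<Object[]> items = new AtomicReference<>(EMPTY);

    /** Returns false if already terminated; the caller should then not subscribe the item. */
    boolean add(T item) {
        for (;;) {
            Object[] a = items.get();
            if (a == TERMINATED) {
                return false;
            }
            Object[] b = Arrays.copyOf(a, a.length + 1);
            b[a.length] = item;
            if (items.compareAndSet(a, b)) {
                return true;
            }
        }
    }

    void remove(T item) {
        for (;;) {
            Object[] a = items.get();
            if (a == TERMINATED || a == EMPTY) {
                return;
            }
            int j = -1;
            for (int i = 0; i < a.length; i++) {
                if (a[i] == item) {          // identity, like the operator's remove()
                    j = i;
                    break;
                }
            }
            if (j < 0) {
                return;
            }
            Object[] b;
            if (a.length == 1) {
                b = EMPTY;
            } else {
                b = new Object[a.length - 1];
                System.arraycopy(a, 0, b, 0, j);
                System.arraycopy(a, j + 1, b, j, a.length - j - 1);
            }
            if (items.compareAndSet(a, b)) {
                return;
            }
        }
    }

    /** Swaps in TERMINATED once and returns the items tracked at that moment. */
    Object[] terminate() {
        Object[] a = items.get();
        if (a != TERMINATED) {
            a = items.getAndSet(TERMINATED);
            if (a != TERMINATED) {
                return a;
            }
        }
        return EMPTY;
    }

    int size() {
        return items.get().length;
    }
}
```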
The onNext method changes slightly: it only subscribes to the inner Observable if add() succeeds, that is, if the operator hasn't been terminated (unsubscribed) in the meantime:
@Override
public void onNext(T t) {
    Observable<? extends R> o;
    try {
        o = mapper.call(t);
    } catch (Throwable ex) {
        Exceptions.throwOrReport(ex, this, t);
        return;
    }
    FlatMapInnerSubscriber<T, R> inner = new FlatMapInnerSubscriber<>(this, prefetch);
    if (add(inner)) {        // false if terminated: don't subscribe at all
        o.subscribe(inner);
    }
}
The onError method has a small change too; there is no active field to decrement and thus the drain is always called:
if (error.compareAndSet(current, next)) {
    drain();
    return;
}
The onCompleted no longer decrements the active field but instead has to set the done flag:
@Override
public void onCompleted() {
    done = true;
    drain();
}
The innerError and innerComplete methods also get simpler:
void innerError(Throwable ex, FlatMapInnerSubscriber<T, R> inner) {
    onError(ex);
}

void innerComplete(FlatMapInnerSubscriber<T, R> inner) {
    drain();
}
As usual, all that simplification is eventually offset by complication somewhere else. In our case, the drain loop gets more complicated: it now has to iterate over all active sources, drain their queues, and ask for replenishment from both the inner sources and the main source.
void drainLoop() {
    int missed = 1;
    for (;;) {
        boolean d = done;
        FlatMapInnerSubscriber<T, R>[] a = subscribers.get();
        long r = requested.get();
        long e = 0L;
        int requestMain = 0;
        boolean again = false;
        // check the downstream rather than this subscriber: a non-delayed onError()
        // calls unsubscribe() on this subscriber before drain(), so checking
        // isUnsubscribed() here would silently drop that error
        if (actual.isUnsubscribed()) {
            return;
        }
The drain loop now has some additional variables. We get the fresh array of subscribers upfront and introduce a counter for requesting more from the main source, plus a flag indicating that the outer loop has to be restarted from the top. Note that the done flag has to be read before getting the current value of the subscribers array due to a possible race with onNext.
        if (!delayErrors) {
            Throwable ex = error.get();
            if (ex != null) {
                actual.onError(ex);
                return;
            }
        }
        if (d && a.length == 0) {
            Throwable ex = error.get();
            if (ex != null) {
                actual.onError(ex);
            } else {
                actual.onCompleted();
            }
            return;
        }
The next section deals with a delayed or non-delayed error condition. Note also that the done indicator is not used on its own but in conjunction with the length of the inner subscriber array; we know we reached a terminal state if both the main has terminated and we don't have any active inner subscribers (empty array).
        for (FlatMapInnerSubscriber<T, R> inner : a) {
            if (actual.isUnsubscribed()) {
                return;
            }
            d = inner.done;
            Queue<Object> q = inner.queue;
            if (q == null) {
                if (d) {
                    remove(inner);
                    requestMain++;
                    again = true;
                }
            } else {
The next part loops over all subscribers to see if any of them has values in its respective queue, provided it actually has a queue to begin with; it is possible that the fast path was taken for this source all the way through and no queue was ever created via getOrCreateQueue(). In this case, if the source is done, all that remains is to remove the inner subscriber and indicate replenishment from the main source.
                long f = 0L;
                while (e != r) {
                    if (actual.isUnsubscribed()) {
                        return;
                    }
                    d = inner.done;
                    Object v = q.poll();
                    boolean empty = v == null;
                    if (d && empty) {
                        remove(inner);
                        requestMain++;
                        again = true;
                    }
                    if (empty) {
                        break;
                    }
                    actual.onNext(NotificationLite.<R>instance().getValue(v));
                    e++;
                    f++;
                }
                if (f != 0L) {
                    inner.requestMore(f);
                }
                if (e == r) {
                    if (inner.done && q.isEmpty()) {
                        remove(inner);
                        requestMain++;
                        again = true;
                    }
                    break;
                }
This is a fairly usual drain loop, with the addition of the remove/replenishment logic and a break statement that stops iterating over the sources once the emission count has reached the request count, since no more values can be emitted. Note the use of the f counter, which tracks how many items were consumed from that particular FlatMapInnerSubscriber so the same amount can be requested from it.
            }
        }
        if (e != 0L) {
            BackpressureUtils.produced(requested, e);
        }
        if (requestMain != 0) {
            request(requestMain);
        }
        if (again) {
            continue;
        }
        missed = wip.addAndGet(-missed);
        if (missed == 0) {
            break;
        }
    }
}
}
The final part performs the necessary request tracking, replenishment and missed work handling.
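To see the per-source bookkeeping of the drain loop in isolation, here is a single-threaded model of one pass. PerSourceDrain and its Source class are hypothetical; plain ArrayDeques replace the SPSC queues, replenished accumulates what inner.requestMore(f) would request, and mainRequested counts the replenishments towards the main source. Sources without a queue (the q == null case) are not modeled.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

/** Hypothetical single-threaded model of one pass over the per-source queues. */
final class PerSourceDrain {
    static final class Source {
        final Queue<Integer> queue = new ArrayDeque<>();
        boolean done;
        long replenished;               // total of the requestMore(f) calls
    }

    final List<Source> sources = new ArrayList<>();
    final List<Integer> emitted = new ArrayList<>();
    int mainRequested;                  // accumulated request(requestMain)

    /** Emits at most r values and returns how many were emitted (e). */
    long drainOnce(long r) {
        long e = 0L;
        for (Source s : new ArrayList<>(sources)) {   // snapshot, like subscribers.get()
            long f = 0L;
            while (e != r) {
                Integer v = s.queue.poll();
                if (v == null) {
                    if (s.done) {                     // finished and drained: replace it
                        sources.remove(s);
                        mainRequested++;
                    }
                    break;
                }
                emitted.add(v);
                e++;
                f++;
            }
            if (f != 0L) {
                s.replenished += f;                   // inner.requestMore(f)
            }
            if (e == r) {
                if (s.done && s.queue.isEmpty()) {
                    sources.remove(s);
                    mainRequested++;
                }
                break;                                // downstream demand exhausted
            }
        }
        return e;
    }
}
```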
Given the hundreds of lines of complicated code, you should now understand why flatMap is one of the most complicated operators we have.
Inner request rebatching
Before finishing up with this post, let's do a small, final optimization to the latest flatMap structure. If you look at the innerNext() logic, you'll see that whenever the fast path is taken, 1 item is requested as a replacement for each emitted item. If the source is a range() operator, such 1-by-1 requests yield an atomic increment after each and every value emitted by range(), adding more overhead.
Fortunately, since the inner sources use a fixed prefetch amount, we can define a re-request point and batch up those 1-by-1 requests into a much larger amount, amortizing the request-tracking overhead in the source sequence.
This point can be anywhere between 1 and the prefetch amount, and the best value generally depends on how the source emits values; different sources may work best at different points in this range. Unfortunately, libraries can't do much to tune this per source, and any adaptive logic adds so much overhead that it likely negates any benefit. Therefore, RxJava chose to re-request after half of the prefetch amount has been emitted (and lately, I tend to use 75% of the prefetch for it).
The solution requires two additional fields and a change to the requestMore() method in FlatMapInnerSubscriber:
final int limit;

long produced;   // only accessed while the parent's wip is held (fast path or drainLoop)

public FlatMapInnerSubscriber(
        FlatMapSubscriber<T, R> parent, int prefetch) {
    this.parent = parent;
    this.prefetch = prefetch;
    this.limit = prefetch - (prefetch >> 2);   // re-request after 75% of prefetch
    request(prefetch);
}

void requestMore(long n) {
    long p = produced + n;
    if (p >= limit) {
        produced = 0;
        request(p);
    } else {
        produced = p;
    }
}
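The rebatching arithmetic can be checked with a standalone model of requestMore(). RequestBatcher is a hypothetical class that records the amounts that would be passed upstream instead of calling request(n). With prefetch = 128, limit is 128 - 32 = 96, so after the initial request(128), the 1-by-1 replenishments are coalesced into a single request(96) every 96 items.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of batched replenishment; upstreamRequests stands in for request(n). */
final class RequestBatcher {
    final int prefetch;
    final int limit;
    long produced;
    final List<Long> upstreamRequests = new ArrayList<>();

    RequestBatcher(int prefetch) {
        this.prefetch = prefetch;
        this.limit = prefetch - (prefetch >> 2);   // 75% of prefetch
        upstreamRequests.add((long) prefetch);     // the initial request(prefetch)
    }

    void requestMore(long n) {
        long p = produced + n;
        if (p >= limit) {
            produced = 0;
            upstreamRequests.add(p);               // one batched request(p)
        } else {
            produced = p;                          // just remember, don't request yet
        }
    }
}
```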
Conclusion
In this post, I showed ways to improve the functionality and performance of our flatMap operator. The diligent reader may check whether we have reached the structure of RxJava's flatMap implementation, but the answer is: not yet. There are a couple more optimizations we could apply, but I left them out to keep an already long post short.
The first one exploits the likelihood that the source that was last cut short due to a lack of requests is the best candidate to resume emitting values once more requests come in. Saving and restoring an index in the for-loop over the FlatMapInnerSubscriber array helps with that.
The second is called the scalar-optimization and handles the case when the mapper returns Observable.just() sources, avoiding the subscription overhead of such sources. This optimization adds significantly more logic to our drainLoop() method, and it has its own queue-bypass optimization as well.
In the next part of this series, I'll add these two remaining optimizations as well as something even better. However, to understand this mysterious optimization, of which the scalar-optimization is actually a member, we have to learn about something that requires almost perfect knowledge of the internals of not just flatMap, but every other operator as well.
We call it operator fusion.