-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
flatten
extension method for Source[Source[T]]
- Loading branch information
1 parent
8971901
commit 2a3831d
Showing
2 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package ox.channels | ||
|
||
import ox.* | ||
import ox.channels.* | ||
import ox.channels.ChannelClosedUnion.isValue | ||
|
||
extension [U](parentSource: Source[Source[U]]) { | ||
|
||
/** Pipes the elements of child sources into the output source. If the parent source or any of the child sources emit an error, the | ||
* pulling stops and the output source emits the error. | ||
*/ | ||
def flatten(using Ox, StageCapacity): Source[U] = { | ||
val c2 = StageCapacity.newChannel[U] | ||
|
||
forkPropagate(c2) { | ||
var pool = List[Source[Source[U]] | Source[U]](parentSource) | ||
repeatWhile { | ||
selectOrClosed(pool) match { | ||
case ChannelClosed.Done => | ||
// TODO: best to remove the specific channel that signalled to be Done | ||
pool = pool.filterNot(_.isClosedForReceiveDetail.contains(ChannelClosed.Done)) | ||
if pool.isEmpty then | ||
c2.doneOrClosed() | ||
false | ||
else true | ||
case ChannelClosed.Error(r) => | ||
c2.errorOrClosed(r) | ||
false | ||
case t: Source[U] @unchecked => | ||
pool = t :: pool | ||
true | ||
case r: U @unchecked => | ||
c2.sendOrClosed(r).isValue | ||
} | ||
} | ||
} | ||
|
||
c2 | ||
} | ||
} |
145 changes: 145 additions & 0 deletions
145
core/src/test/scala/ox/channels/SourceOfSourceOpsTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package ox.channels | ||
|
||
import org.scalatest.flatspec.AnyFlatSpec | ||
import org.scalatest.matchers.should.Matchers | ||
import ox.* | ||
|
||
import java.util.concurrent.CountDownLatch | ||
import scala.collection.mutable.ListBuffer | ||
|
||
class SourceOfSourceOpsTest extends AnyFlatSpec with Matchers { | ||
|
||
"flatten" should "pipe all elements of the child sources into the output source" in { | ||
supervised { | ||
val source = Source.fromValues( | ||
Source.fromValues(10), | ||
Source.fromValues(20, 30), | ||
Source.fromValues(40, 50, 60) | ||
) | ||
source.flatten.toList should contain theSameElementsAs List(10, 20, 30, 40, 50, 60) | ||
} | ||
} | ||
|
||
it should "handle empty source" in { | ||
supervised { | ||
val source = Source.empty[Source[Int]] | ||
source.flatten.toList should contain theSameElementsAs Nil | ||
} | ||
} | ||
|
||
it should "handle singleton source" in { | ||
supervised { | ||
val source = Source.fromValues(Source.fromValues(10)) | ||
source.flatten.toList should contain theSameElementsAs List(10) | ||
} | ||
} | ||
|
||
it should "pipe elements realtime" in { | ||
supervised { | ||
val source = Channel.bufferedDefault[Source[Int]] | ||
val lockA = CountDownLatch(1) | ||
val lockB = CountDownLatch(1) | ||
source.send(Source.fromValues(10)) | ||
source.send { | ||
val subSource = Channel.bufferedDefault[Int] | ||
subSource.send(20) | ||
forkUnsupervised { | ||
lockA.await() // 30 won't be added until, lockA is released after 20 consumption | ||
subSource.send(30) | ||
subSource.done() | ||
} | ||
subSource | ||
} | ||
forkUnsupervised { | ||
lockB.await() // 40 won't be added until, lockB is released after 30 consumption | ||
source.send(Source.fromValues(40)) | ||
source.done() | ||
} | ||
|
||
val collected = ListBuffer[Int]() | ||
source.flatten.foreachOrError { e => | ||
collected += e | ||
if e == 20 then lockA.countDown() | ||
else if e == 30 then lockB.countDown() | ||
} | ||
collected should contain theSameElementsAs List(10, 20, 30, 40) | ||
} | ||
} | ||
|
||
it should "propagate error of any of the child sources and stop piping" in { | ||
val error = new Exception("intentional failure") | ||
supervised { | ||
val child1 = Channel.rendezvous[Int] | ||
val lock = CountDownLatch(1) | ||
val child1Producer = fork { | ||
child1.send(10) | ||
// wait for child2 to emit an error | ||
lock.await() | ||
// `flatten` will not receive this, as it will be short-circuited by the error | ||
child1.sendOrClosed(30) | ||
|
||
} | ||
val child2 = Channel.rendezvous[Int] | ||
fork { | ||
child2.send(20) | ||
child2.error(error) | ||
lock.countDown() | ||
} | ||
val source = Source.fromValues(child1, child2) | ||
|
||
val (collectedElems, collectedError) = source.flatten.toPartialList() | ||
collectedError shouldBe Some(error) | ||
collectedElems should contain theSameElementsAs List(10, 20) | ||
child1.receive() shouldBe 30 | ||
} | ||
} | ||
|
||
it should "propagate error of the parent source and stop piping" in { | ||
val error = new Exception("intentional failure") | ||
supervised { | ||
val child1 = Channel.rendezvous[Int] | ||
val lock = CountDownLatch(1) | ||
fork { | ||
child1.send(10) | ||
lock.countDown() | ||
// depending on how quick it picks up the error from the parent | ||
// `flatten` may or may not receive this | ||
child1.send(20) | ||
child1.done() | ||
} | ||
val source = Channel.rendezvous[Source[Int]] | ||
fork { | ||
source.send(child1) | ||
// make sure the first element of child1 is consumed before emitting error | ||
lock.await() | ||
source.error(error) | ||
} | ||
|
||
val (collectedElems, collectedError) = source.flatten.toPartialList() | ||
collectedError shouldBe Some(error) | ||
collectedElems should contain atLeastOneElementOf List(10, 20) | ||
} | ||
} | ||
|
||
it should "stop pulling from the sources when the receiver is closed" in { | ||
// TODO: implement this test | ||
} | ||
|
||
extension [T](source: Source[T]) { | ||
def toPartialList(cb: T | Throwable => Unit = (_: Any) => ()): (List[T], Option[Throwable]) = { | ||
val elementCapture = ListBuffer[T]() | ||
var errorCapture = Option.empty[Throwable] | ||
try { | ||
for (t <- source) { | ||
cb(t) | ||
elementCapture += t | ||
} | ||
} catch { | ||
case ChannelClosedException.Error(e) => | ||
cb(e) | ||
errorCapture = Some(e) | ||
} | ||
(elementCapture.toList, errorCapture) | ||
} | ||
} | ||
} |